// aissia/tests/mocks/MockTransport.hpp (193 lines, 6.0 KiB, C++)

#pragma once
#include "shared/mcp/MCPTransport.hpp"
#include "shared/mcp/MCPTypes.hpp"
#include <algorithm>
#include <cstddef>
#include <functional>
#include <queue>
#include <string>
#include <utility>
#include <vector>
namespace aissia::tests {
using namespace aissia::mcp;
/**
 * @brief Mock implementation of IMCPTransport for testing MCPClient.
 *
 * Every request and notification handed to the mock is recorded so tests can
 * assert on the traffic afterwards. Requests are answered from, in priority
 * order:
 *   1. the custom request handler, if one is installed;
 *   2. the FIFO queue of prepared responses, if it is not empty;
 *   3. a default JSON-RPC error (-32603, "No prepared response").
 *
 * The state members are public so tests can inspect or tweak them directly.
 */
class MockTransport : public IMCPTransport {
public:
    // ========================================================================
    // IMCPTransport Interface
    // ========================================================================

    /**
     * @brief Mark the transport as running unless configured to fail.
     * @return false (leaving the running flag untouched) when
     *         setStartShouldFail(true) is in effect, true otherwise.
     */
    bool start() override {
        if (m_startShouldFail) {
            return false;
        }
        m_running = true;
        return true;
    }

    /**
     * @brief Mark the transport as stopped.
     */
    void stop() override {
        m_running = false;
    }

    /**
     * @brief Report whether start() succeeded and stop() has not been called since.
     */
    bool isRunning() const override {
        return m_running;
    }

    /**
     * @brief Record the request and produce a canned response.
     *
     * @param request   Request to record and answer.
     * @param timeoutMs Ignored; the mock answers synchronously.
     * @return The custom handler's response if a handler is installed;
     *         otherwise the oldest prepared response with its id overwritten
     *         by the request id; otherwise a -32603 error response carrying
     *         the request id.
     */
    JsonRpcResponse sendRequest(const JsonRpcRequest& request, int timeoutMs = 30000) override {
        (void)timeoutMs;  // Never consulted by the mock; silences unused-parameter warnings.
        m_sentRequests.push_back(request);

        // A custom handler takes precedence over any prepared responses.
        if (m_requestHandler) {
            return m_requestHandler(request);
        }

        // Otherwise consume prepared responses in FIFO order.
        if (!m_preparedResponses.empty()) {
            JsonRpcResponse response = std::move(m_preparedResponses.front());
            m_preparedResponses.pop();
            response.id = request.id;  // Match the request ID
            return response;
        }

        // Nothing configured: answer with an error so the test fails visibly.
        JsonRpcResponse errorResponse;
        errorResponse.id = request.id;
        errorResponse.error = json{{"code", -32603}, {"message", "No prepared response"}};
        return errorResponse;
    }

    /**
     * @brief Record a notification as a (method, params) pair; nothing is sent.
     */
    void sendNotification(const std::string& method, const json& params) override {
        m_sentNotifications.emplace_back(method, params);
    }

    // ========================================================================
    // Test Configuration
    // ========================================================================

    /**
     * @brief Make subsequent start() calls fail (true) or succeed (false).
     */
    void setStartShouldFail(bool fail) {
        m_startShouldFail = fail;
    }

    /**
     * @brief Queue a response to be returned by a later sendRequest.
     *
     * Responses are consumed in the order they were queued. The id stored in
     * @p response is replaced by the id of the request it answers.
     */
    void prepareResponse(const JsonRpcResponse& response) {
        m_preparedResponses.push(response);
    }

    /**
     * @brief Queue a successful response whose result is @p result.
     */
    void prepareSuccessResponse(const json& result) {
        JsonRpcResponse response;
        response.result = result;
        m_preparedResponses.push(std::move(response));
    }

    /**
     * @brief Queue an error response with the given code and message.
     */
    void prepareErrorResponse(int code, const std::string& message) {
        JsonRpcResponse response;
        response.error = json{{"code", code}, {"message", message}};
        m_preparedResponses.push(std::move(response));
    }

    /**
     * @brief Install a custom handler that answers every request.
     *
     * While a handler is installed, prepared responses are not consumed.
     * Passing an empty std::function restores the prepared-response behavior.
     */
    void setRequestHandler(std::function<JsonRpcResponse(const JsonRpcRequest&)> handler) {
        m_requestHandler = std::move(handler);
    }

    /**
     * @brief Simulate an MCP server answering initialize, tools/list and tools/call.
     *
     * Replaces any previously installed request handler. Any other method is
     * answered with a -32601 "Method not found" error.
     *
     * @param serverName Name reported in the initialize result's serverInfo.
     * @param tools      Tools serialized (via toJson()) into the tools/list result.
     */
    void setupAsMCPServer(const std::string& serverName, const std::vector<MCPTool>& tools) {
        // Capture by value: the handler outlives this call.
        m_requestHandler = [serverName, tools](const JsonRpcRequest& req) -> JsonRpcResponse {
            JsonRpcResponse resp;
            resp.id = req.id;

            if (req.method == "initialize") {
                resp.result = json{
                    {"protocolVersion", "2024-11-05"},
                    {"capabilities", {{"tools", json::object()}}},
                    {"serverInfo", {{"name", serverName}, {"version", "1.0.0"}}}
                };
            } else if (req.method == "tools/list") {
                json toolsJson = json::array();
                for (const auto& tool : tools) {
                    toolsJson.push_back(tool.toJson());
                }
                resp.result = json{{"tools", toolsJson}};
            } else if (req.method == "tools/call") {
                // Every tool call yields the same fixed text content.
                resp.result = json{
                    {"content", json::array({{{"type", "text"}, {"text", "Tool executed"}}})}
                };
            } else {
                resp.error = json{{"code", -32601}, {"message", "Method not found"}};
            }
            return resp;
        };
    }

    // ========================================================================
    // Test Verification
    // ========================================================================

    /**
     * @brief Get all requests passed to sendRequest, in call order.
     */
    const std::vector<JsonRpcRequest>& getSentRequests() const {
        return m_sentRequests;
    }

    /**
     * @brief Check whether any recorded request used @p method.
     */
    bool wasMethodCalled(const std::string& method) const {
        return std::any_of(m_sentRequests.begin(), m_sentRequests.end(),
                           [&method](const auto& req) { return req.method == method; });
    }

    /**
     * @brief Count the recorded requests that used @p method.
     */
    size_t countMethodCalls(const std::string& method) const {
        // std::count_if yields a signed difference_type; it is never negative here.
        return static_cast<size_t>(
            std::count_if(m_sentRequests.begin(), m_sentRequests.end(),
                          [&method](const auto& req) { return req.method == method; }));
    }

    /**
     * @brief Forget recorded traffic, queued responses and the custom handler.
     *
     * The running flag and the start-should-fail flag are left untouched.
     */
    void clear() {
        m_sentRequests.clear();
        m_sentNotifications.clear();
        m_preparedResponses = std::queue<JsonRpcResponse>{};
        m_requestHandler = nullptr;
    }

    // ========================================================================
    // Test State
    // ========================================================================

    bool m_running = false;          ///< Set by start(), cleared by stop().
    bool m_startShouldFail = false;  ///< When true, start() returns false.
    std::vector<JsonRpcRequest> m_sentRequests;                      ///< Requests seen by sendRequest.
    std::vector<std::pair<std::string, json>> m_sentNotifications;   ///< (method, params) seen by sendNotification.
    std::queue<JsonRpcResponse> m_preparedResponses;                 ///< FIFO of canned responses.
    std::function<JsonRpcResponse(const JsonRpcRequest&)> m_requestHandler;  ///< Optional custom responder.
};
} // namespace aissia::tests