#pragma once #include "shared/mcp/MCPTransport.hpp" #include "shared/mcp/MCPTypes.hpp" #include #include #include namespace aissia::tests { using namespace aissia::mcp; /** * @brief Mock implementation of IMCPTransport for testing MCPClient */ class MockTransport : public IMCPTransport { public: // ======================================================================== // IMCPTransport Interface // ======================================================================== bool start() override { if (m_startShouldFail) { return false; } m_running = true; return true; } void stop() override { m_running = false; } bool isRunning() const override { return m_running; } JsonRpcResponse sendRequest(const JsonRpcRequest& request, int timeoutMs = 30000) override { m_sentRequests.push_back(request); // If we have a custom handler, use it if (m_requestHandler) { return m_requestHandler(request); } // Otherwise, use prepared responses if (!m_preparedResponses.empty()) { auto response = m_preparedResponses.front(); m_preparedResponses.pop(); response.id = request.id; // Match the request ID return response; } // Default: return error JsonRpcResponse errorResponse; errorResponse.id = request.id; errorResponse.error = json{{"code", -32603}, {"message", "No prepared response"}}; return errorResponse; } void sendNotification(const std::string& method, const json& params) override { m_sentNotifications.emplace_back(method, params); } // ======================================================================== // Test Configuration // ======================================================================== /** * @brief Make start() fail */ void setStartShouldFail(bool fail) { m_startShouldFail = fail; } /** * @brief Add a response to be returned on next sendRequest */ void prepareResponse(const JsonRpcResponse& response) { m_preparedResponses.push(response); } /** * @brief Prepare a successful response with result */ void prepareSuccessResponse(const json& result) { JsonRpcResponse response; response.result = result; m_preparedResponses.push(response); } /** * @brief Prepare an error response */ void prepareErrorResponse(int code, const std::string& message) { JsonRpcResponse response; response.error = json{{"code", code}, {"message", message}}; m_preparedResponses.push(response); } /** * @brief Set a custom handler for all requests */ void setRequestHandler(std::function handler) { m_requestHandler = std::move(handler); } /** * @brief Simulate MCP server with initialize and tools/list */ void setupAsMCPServer(const std::string& serverName, const std::vector& tools) { m_requestHandler = [serverName, tools](const JsonRpcRequest& req) -> JsonRpcResponse { JsonRpcResponse resp; resp.id = req.id; if (req.method == "initialize") { resp.result = json{ {"protocolVersion", "2024-11-05"}, {"capabilities", {{"tools", json::object()}}}, {"serverInfo", {{"name", serverName}, {"version", "1.0.0"}}} }; } else if (req.method == "tools/list") { json toolsJson = json::array(); for (const auto& tool : tools) { toolsJson.push_back(tool.toJson()); } resp.result = json{{"tools", toolsJson}}; } else if (req.method == "tools/call") { resp.result = json{ {"content", json::array({{{"type", "text"}, {"text", "Tool executed"}}})} }; } else { resp.error = json{{"code", -32601}, {"message", "Method not found"}}; } return resp; }; } // ======================================================================== // Test Verification // ======================================================================== /** * @brief Get all sent requests */ const std::vector& getSentRequests() const { return m_sentRequests; } /** * @brief Check if a method was called */ bool wasMethodCalled(const std::string& method) const { return std::any_of(m_sentRequests.begin(), m_sentRequests.end(), [&method](const auto& req) { return req.method == method; }); } /** * @brief Get count of calls to a method */ size_t countMethodCalls(const std::string& method) const { return std::count_if(m_sentRequests.begin(), m_sentRequests.end(), [&method](const auto& req) { return req.method == method; }); } /** * @brief Clear all state */ void clear() { m_sentRequests.clear(); m_sentNotifications.clear(); while (!m_preparedResponses.empty()) { m_preparedResponses.pop(); } m_requestHandler = nullptr; } // ======================================================================== // Test State // ======================================================================== bool m_running = false; bool m_startShouldFail = false; std::vector m_sentRequests; std::vector> m_sentNotifications; std::queue m_preparedResponses; std::function m_requestHandler; }; } // namespace aissia::tests