193 lines
6.0 KiB
C++
193 lines
6.0 KiB
C++
#pragma once
|
|
|
|
#include "shared/mcp/MCPTransport.hpp"
#include "shared/mcp/MCPTypes.hpp"

#include <algorithm>
#include <functional>
#include <queue>
#include <string>
#include <utility>
#include <vector>
|
|
|
|
namespace aissia::tests {
|
|
|
|
using namespace aissia::mcp;
|
|
|
|
/**
|
|
* @brief Mock implementation of IMCPTransport for testing MCPClient
|
|
*/
|
|
class MockTransport : public IMCPTransport {
|
|
public:
|
|
// ========================================================================
|
|
// IMCPTransport Interface
|
|
// ========================================================================
|
|
|
|
bool start() override {
|
|
if (m_startShouldFail) {
|
|
return false;
|
|
}
|
|
m_running = true;
|
|
return true;
|
|
}
|
|
|
|
void stop() override {
|
|
m_running = false;
|
|
}
|
|
|
|
bool isRunning() const override {
|
|
return m_running;
|
|
}
|
|
|
|
JsonRpcResponse sendRequest(const JsonRpcRequest& request, int timeoutMs = 30000) override {
|
|
m_sentRequests.push_back(request);
|
|
|
|
// If we have a custom handler, use it
|
|
if (m_requestHandler) {
|
|
return m_requestHandler(request);
|
|
}
|
|
|
|
// Otherwise, use prepared responses
|
|
if (!m_preparedResponses.empty()) {
|
|
auto response = m_preparedResponses.front();
|
|
m_preparedResponses.pop();
|
|
response.id = request.id; // Match the request ID
|
|
return response;
|
|
}
|
|
|
|
// Default: return error
|
|
JsonRpcResponse errorResponse;
|
|
errorResponse.id = request.id;
|
|
errorResponse.error = json{{"code", -32603}, {"message", "No prepared response"}};
|
|
return errorResponse;
|
|
}
|
|
|
|
void sendNotification(const std::string& method, const json& params) override {
|
|
m_sentNotifications.emplace_back(method, params);
|
|
}
|
|
|
|
// ========================================================================
|
|
// Test Configuration
|
|
// ========================================================================
|
|
|
|
/**
|
|
* @brief Make start() fail
|
|
*/
|
|
void setStartShouldFail(bool fail) {
|
|
m_startShouldFail = fail;
|
|
}
|
|
|
|
/**
|
|
* @brief Add a response to be returned on next sendRequest
|
|
*/
|
|
void prepareResponse(const JsonRpcResponse& response) {
|
|
m_preparedResponses.push(response);
|
|
}
|
|
|
|
/**
|
|
* @brief Prepare a successful response with result
|
|
*/
|
|
void prepareSuccessResponse(const json& result) {
|
|
JsonRpcResponse response;
|
|
response.result = result;
|
|
m_preparedResponses.push(response);
|
|
}
|
|
|
|
/**
|
|
* @brief Prepare an error response
|
|
*/
|
|
void prepareErrorResponse(int code, const std::string& message) {
|
|
JsonRpcResponse response;
|
|
response.error = json{{"code", code}, {"message", message}};
|
|
m_preparedResponses.push(response);
|
|
}
|
|
|
|
/**
|
|
* @brief Set a custom handler for all requests
|
|
*/
|
|
void setRequestHandler(std::function<JsonRpcResponse(const JsonRpcRequest&)> handler) {
|
|
m_requestHandler = std::move(handler);
|
|
}
|
|
|
|
/**
|
|
* @brief Simulate MCP server with initialize and tools/list
|
|
*/
|
|
void setupAsMCPServer(const std::string& serverName, const std::vector<MCPTool>& tools) {
|
|
m_requestHandler = [serverName, tools](const JsonRpcRequest& req) -> JsonRpcResponse {
|
|
JsonRpcResponse resp;
|
|
resp.id = req.id;
|
|
|
|
if (req.method == "initialize") {
|
|
resp.result = json{
|
|
{"protocolVersion", "2024-11-05"},
|
|
{"capabilities", {{"tools", json::object()}}},
|
|
{"serverInfo", {{"name", serverName}, {"version", "1.0.0"}}}
|
|
};
|
|
} else if (req.method == "tools/list") {
|
|
json toolsJson = json::array();
|
|
for (const auto& tool : tools) {
|
|
toolsJson.push_back(tool.toJson());
|
|
}
|
|
resp.result = json{{"tools", toolsJson}};
|
|
} else if (req.method == "tools/call") {
|
|
resp.result = json{
|
|
{"content", json::array({{{"type", "text"}, {"text", "Tool executed"}}})}
|
|
};
|
|
} else {
|
|
resp.error = json{{"code", -32601}, {"message", "Method not found"}};
|
|
}
|
|
|
|
return resp;
|
|
};
|
|
}
|
|
|
|
// ========================================================================
|
|
// Test Verification
|
|
// ========================================================================
|
|
|
|
/**
|
|
* @brief Get all sent requests
|
|
*/
|
|
const std::vector<JsonRpcRequest>& getSentRequests() const {
|
|
return m_sentRequests;
|
|
}
|
|
|
|
/**
|
|
* @brief Check if a method was called
|
|
*/
|
|
bool wasMethodCalled(const std::string& method) const {
|
|
return std::any_of(m_sentRequests.begin(), m_sentRequests.end(),
|
|
[&method](const auto& req) { return req.method == method; });
|
|
}
|
|
|
|
/**
|
|
* @brief Get count of calls to a method
|
|
*/
|
|
size_t countMethodCalls(const std::string& method) const {
|
|
return std::count_if(m_sentRequests.begin(), m_sentRequests.end(),
|
|
[&method](const auto& req) { return req.method == method; });
|
|
}
|
|
|
|
/**
|
|
* @brief Clear all state
|
|
*/
|
|
void clear() {
|
|
m_sentRequests.clear();
|
|
m_sentNotifications.clear();
|
|
while (!m_preparedResponses.empty()) {
|
|
m_preparedResponses.pop();
|
|
}
|
|
m_requestHandler = nullptr;
|
|
}
|
|
|
|
// ========================================================================
|
|
// Test State
|
|
// ========================================================================
|
|
|
|
bool m_running = false;
|
|
bool m_startShouldFail = false;
|
|
std::vector<JsonRpcRequest> m_sentRequests;
|
|
std::vector<std::pair<std::string, json>> m_sentNotifications;
|
|
std::queue<JsonRpcResponse> m_preparedResponses;
|
|
std::function<JsonRpcResponse(const JsonRpcRequest&)> m_requestHandler;
|
|
};
|
|
|
|
} // namespace aissia::tests
|