mirror of
https://github.com/danny-avila/LibreChat.git
synced 2026-05-13 16:07:30 +00:00
* 🪟 feat: Add allowedAddresses Exemption List For SSRF-Guarded Targets

LibreChat already blocks SSRF-prone targets (private IPs, loopback, link-local, .internal/.local TLDs) at every server-side fetch site that consumes user-controllable URLs: custom-endpoint baseURLs, MCP servers, OpenAPI Actions, and OAuth endpoints. The only existing escape hatch is `allowedDomains`, but that flips the field into a strict whitelist: adding `127.0.0.1` to permit a self-hosted Ollama also blocks every public destination that isn't in the list.

Introduce `allowedAddresses` as the orthogonal primitive: a private-IP-space exemption list. When a hostname or its resolved IP appears in the list, the SSRF block is bypassed for that target. Public destinations remain reachable. Operators can now run self-hosted LLMs / MCP servers / Action endpoints on private addresses without weakening the default-deny posture for everything else.

Schema additions in `packages/data-provider/src/config.ts`:
- `endpoints.allowedAddresses` (new; gates `validateEndpointURL`)
- `mcpSettings.allowedAddresses` (parallel to `allowedDomains`)
- `actions.allowedAddresses` (parallel to `allowedDomains`)

Core changes in `packages/api/src/auth/`:
- New `isAddressAllowed(hostnameOrIP, allowedAddresses)`: pure, case-insensitive, bracket-stripped literal match.
- Threaded the list through `isSSRFTarget`, `resolveHostnameSSRF`, `isDomainAllowedCore`, `isActionDomainAllowed`, `isMCPDomainAllowed`, `isOAuthUrlAllowed`, and `validateEndpointURL`.
- Extended `createSSRFSafeAgents` and `createSSRFSafeUndiciConnect` to accept the list, building an SSRF-safe DNS lookup that exempts matching hostnames/IPs at TCP connect time (TOCTOU-safe).

Wiring:
- Custom and OpenAI endpoint initialize sites pass `endpoints.allowedAddresses` to `validateEndpointURL`.
- `MCPServersRegistry` stores `allowedAddresses` and exposes it via `getAllowedAddresses()`. The factory, connection class, manager, `UserConnectionManager`, and `ConnectionsRepository` all thread it through to the SSRF utilities.
- `MCPOAuthHandler.initiateOAuthFlow`, `refreshOAuthTokens`, and `validateOAuthUrl` accept the list and consult it on every URL validation along the OAuth chain.
- `ToolService`, `ActionService`, and the assistants/agents action routes pass `actions.allowedAddresses` to `isActionDomainAllowed` and to `createSSRFSafeAgents` for runtime action calls.
- `initializeMCPs.js` reads `mcpSettings.allowedAddresses` from the app config and forwards it to the registry constructor.

Documentation:
- `librechat.example.yaml` shows the new field next to each existing `allowedDomains` block, with a note clarifying that `allowedAddresses` is an exemption list (not a whitelist).

Tests:
- Unit tests for `isAddressAllowed` covering literal IPs, hostnames, IPv6 brackets, case insensitivity, and partial-match rejection.
- Exemption tests for every entry point: `isSSRFTarget`, `resolveHostnameSSRF`, `validateEndpointURL`, `isActionDomainAllowed`, `isMCPDomainAllowed`, `isOAuthUrlAllowed`.
- Existing tests updated to reflect the new optional parameter.

Default behavior is unchanged: omitted = empty list = no exemptions.

* 🩹 fix: Plumb allowedAddresses Through AppConfig endpoints Type

The initial PR added `endpoints.allowedAddresses` to the data-provider config schema and consumed it in the endpoint initialize sites, but the runtime `AppConfig.endpoints` shape in `@librechat/data-schemas` was a hand-maintained subset that didn't include the new field, so `tsc` rejected `appConfig.endpoints.allowedAddresses`.

Add the field to `AppConfig['endpoints']` in `packages/data-schemas/src/types/app.ts` and forward it from the loaded config in `packages/data-schemas/src/app/endpoints.ts` so the runtime config carries the value.

Update `initializeMCPs.spec.js` to expect the third positional argument (`allowedAddresses`) on the `createMCPServersRegistry` call.
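The first commit describes `isAddressAllowed(hostnameOrIP, allowedAddresses)` as a pure, case-insensitive, bracket-stripped literal match. A minimal sketch of that matching rule follows, assuming string inputs and an optional array. It is an illustration, not the LibreChat source, and it deliberately omits the private-IP scoping that a later commit in this series adds.

```javascript
// Illustrative sketch of a case-insensitive, bracket-stripped literal match.
// Not the LibreChat implementation; private-IP scoping is omitted here.
function normalizeAddress(value) {
  const lowered = String(value).trim().toLowerCase();
  // "[::1]" (the URL form of an IPv6 literal) should compare equal to "::1"
  return lowered.startsWith('[') && lowered.endsWith(']') ? lowered.slice(1, -1) : lowered;
}

function isAddressAllowed(hostnameOrIP, allowedAddresses) {
  // Omitted or empty list means no exemptions, so default-deny is unchanged
  if (!hostnameOrIP || !Array.isArray(allowedAddresses) || allowedAddresses.length === 0) {
    return false;
  }
  const candidate = normalizeAddress(hostnameOrIP);
  // Whole-string equality only: "10.0.0.5" must not match "10.0.0.50"
  return allowedAddresses.some((entry) => normalizeAddress(entry) === candidate);
}

console.log(isAddressAllowed('OLLAMA.lan', ['ollama.lan'])); // true
console.log(isAddressAllowed('[::1]', ['::1'])); // true
console.log(isAddressAllowed('10.0.0.50', ['10.0.0.5'])); // false
console.log(isAddressAllowed('127.0.0.1', undefined)); // false
```

Exact equality, rather than substring or prefix matching, is what the partial-match rejection tests listed above are meant to pin down.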
* 🩹 fix: Enforce allowedDomains Before allowedAddresses In isOAuthUrlAllowed

The initial implementation checked the address exemption first, so a URL whose hostname appeared in `allowedAddresses` would return true even when the admin had configured `allowedDomains` as a strict bound on OAuth endpoints. A malicious MCP server could advertise OAuth metadata, token, or revocation URLs at any address the admin had permitted for an unrelated reason (a self-hosted LLM at `127.0.0.1`, for example) and pass validation, expanding SSRF reach beyond the configured domain whitelist.

Reorder: when `allowedDomains` is set, treat it as authoritative; return true only if the URL matches a domain entry, otherwise fall through to false. The address exemption only applies when no `allowedDomains` is configured (mirrors how the downstream SSRF check in `validateOAuthUrl` consults `allowedAddresses`).

Add a regression test asserting that an `allowedAddresses` entry does not broaden a configured `allowedDomains` list.

Reported by chatgpt-codex-connector on PR #12933.

* 🩹 fix: Forward allowedAddresses To Remaining OAuth Callers

Two `MCPOAuthHandler` callers still used the pre-feature signatures and were silently dropping the new `allowedAddresses` argument:
- `api/server/routes/mcp.js` invoked `initiateOAuthFlow` with the old 5-argument shape, so OAuth flows initiated through the route handler ignored the registry's `getAllowedAddresses()` and would reject any metadata/authorization/token URL on a permitted private host.
- `api/server/controllers/UserController.js#maybeUninstallOAuthMCP` invoked `revokeOAuthToken` without the address exemption, so uninstalling an OAuth-backed MCP server on a permitted private host would fail at the revocation step even though the rest of the MCP connection path now permits it.

Both sites now read `allowedAddresses` from the registry alongside `allowedDomains` and forward it.

Reported by Copilot on PR #12933.
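The reordering in `isOAuthUrlAllowed` can be sketched as follows. This is a hypothetical illustration: the name `isOAuthUrlAllowedSketch` is invented here, domain matching is reduced to exact case-insensitive hostname equality (the real `allowedDomains` matching goes through `parseDomainSpec` and is richer), and the strict `new URL()` gate at the top is the one a later commit in this series adds.

```javascript
// Hypothetical sketch of the ordering in isOAuthUrlAllowed. Domain matching is
// simplified to exact, case-insensitive hostname equality (an assumption).
function hostnameOf(url) {
  // Throws on schemeless input such as "10.0.0.5/oauth"
  const { hostname } = new URL(url);
  // WHATWG URL keeps IPv6 hosts bracketed ("[::1]"); strip them for comparison
  return hostname.replace(/^\[(.*)\]$/, '$1').toLowerCase();
}

function isOAuthUrlAllowedSketch(url, allowedDomains, allowedAddresses) {
  let host;
  try {
    host = hostnameOf(url);
  } catch {
    // Not an absolute URL: leave it to the caller's explicit rejection
    return false;
  }

  if (Array.isArray(allowedDomains) && allowedDomains.length > 0) {
    // allowedDomains is authoritative; the address exemption is never consulted
    return allowedDomains.some((domain) => domain.toLowerCase() === host);
  }

  // No domain whitelist configured: the address exemption may apply
  return (
    Array.isArray(allowedAddresses) &&
    allowedAddresses.some((entry) => entry.toLowerCase() === host)
  );
}

console.log(isOAuthUrlAllowedSketch('http://127.0.0.1/token', ['auth.example.com'], ['127.0.0.1'])); // false
console.log(isOAuthUrlAllowedSketch('http://127.0.0.1/token', undefined, ['127.0.0.1'])); // true
console.log(isOAuthUrlAllowedSketch('10.0.0.5/oauth', undefined, ['10.0.0.5'])); // false
```

The first call is the regression case: a permitted private address does not widen a configured domain list.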
* 🩹 fix: Update Test Mocks And Assertions For OAuth allowedAddresses

The previous commit started passing `allowedAddresses` to `MCPOAuthHandler.initiateOAuthFlow` from `api/server/routes/mcp.js` and to `MCPOAuthHandler.revokeOAuthToken` from `api/server/controllers/UserController.js`, but the corresponding test files mocked the registry without `getAllowedAddresses` (causing `TypeError`s) and asserted the old positional shape on `toHaveBeenCalledWith`. Update the mocks and assertions to match the new arity:
- `api/server/routes/__tests__/mcp.spec.js`: add `getAllowedDomains`/`getAllowedAddresses` to the registry mock and expect the additional positional args on `initiateOAuthFlow`.
- `api/server/controllers/__tests__/maybeUninstallOAuthMCP.spec.js`: add a `getAllowedAddresses` mock alongside the existing `getAllowedDomains` and seed it in `setupOAuthServerFound`.
- `api/server/controllers/__tests__/UserController.mcpOAuth.spec.js`: add `getAllowedAddresses` to the registry mock and expect the trailing `null` arg on the three `revokeOAuthToken` assertions.

* 🛡️ fix: Address Comprehensive Review - Scope allowedAddresses To Private IP Space

Major findings from the comprehensive PR review (severity → fix):

**CRITICAL: `validateOAuthUrl` SSRF fallback bypass.** When `allowedDomains` is configured and a URL fails the whitelist, the SSRF fallback in `validateOAuthUrl` was still passing `allowedAddresses` to `isSSRFTarget` / `resolveHostnameSSRF`, letting a malicious MCP server advertise OAuth endpoints at any address the admin had permitted for an unrelated reason. Suppress `allowedAddresses` in the fallback when `allowedDomains` is active; the address exemption is opt-in for the no-whitelist mode only.
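The CRITICAL fix amounts to choosing which exemption list the SSRF fallback is allowed to see. A hypothetical sketch follows; `passesSSRFFallback` and `isSSRFTargetStub` are invented names, and the stub only treats `127.0.0.0/8`, `10.0.0.0/8` and `localhost` as private, whereas the real `isSSRFTarget` / `resolveHostnameSSRF` checks are much broader and also resolve DNS.

```javascript
// Stand-in for the real SSRF predicate (assumption): a few literal ranges only
function isSSRFTargetStub(hostname, allowedAddresses) {
  const isPrivate =
    hostname === 'localhost' || /^(127|10)\.\d{1,3}\.\d{1,3}\.\d{1,3}$/.test(hostname);
  const exempt = Array.isArray(allowedAddresses) && allowedAddresses.includes(hostname);
  return isPrivate && !exempt;
}

// Hypothetical fallback, run only after a URL has failed the allowedDomains check
function passesSSRFFallback(url, allowedDomains, allowedAddresses) {
  const { hostname } = new URL(url);
  const domainsActive = Array.isArray(allowedDomains) && allowedDomains.length > 0;
  // With a domain whitelist active, the address exemption is suppressed entirely
  const exemptions = domainsActive ? undefined : allowedAddresses;
  return !isSSRFTargetStub(hostname, exemptions);
}

console.log(passesSSRFFallback('http://127.0.0.1/token', ['auth.example.com'], ['127.0.0.1'])); // false
console.log(passesSSRFFallback('http://127.0.0.1/token', undefined, ['127.0.0.1'])); // true
```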
**MAJOR: WebSocket transport SSRF check ignored exemptions.** The `constructTransport` WebSocket branch called `resolveHostnameSSRF(wsHostname)` without `this.allowedAddresses`, so a permitted private MCP server would pass `isMCPDomainAllowed` but be blocked at transport creation. Forward the exemption.

**Scope `allowedAddresses` to private IP space only (operator directive).** The exemption list is for permitting private/internal targets; it must not become a backdoor that broadens trust to public destinations.
- Schema (`packages/data-provider/src/config.ts`): new `allowedAddressesSchema` rejects URLs (`://`), paths/CIDR (`/`), whitespace, and public IPv4/IPv6 literals at config-load time. Wired into `endpoints`, `mcpSettings`, and `actions`.
- Runtime (`packages/api/src/auth/domain.ts`): `isAddressAllowed` now drops public-IP candidates and public-IP entries on the match path, as defense in depth so a misconfigured runtime list never grants an exemption.
- Hot path (`packages/api/src/auth/agent.ts`): `buildSSRFSafeLookup` pre-normalizes the list into a `Set<string>` once at construction and applies the same scoping filter, so the connect-time DNS lookup is an O(1) Set membership check instead of a full re-iterate-and-normalize on every outbound request.

**Test coverage for the connect-time and OAuth-fallback paths.**
- `agent.spec.ts`: new describe block exercising `buildSSRFSafeLookup` and `createSSRFSafe*` with `allowedAddresses`, covering hostname-literal exemption, resolved-IP exemption, public-IP scoping, URL/CIDR/whitespace rejection, and the default no-list block.
- `handler.allowedAddresses.test.ts` (new): integration tests for `validateOAuthUrl`, covering both the no-domains-set "permit private" path and the strict-bound regression where `allowedAddresses` must NOT bypass `allowedDomains`.
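The connect-time check can be pictured as a custom `lookup` function of the kind Node's `net.connect` and `http.request` accept as an option. The sketch below is hypothetical and is not LibreChat's `buildSSRFSafeLookup`: `isPrivateIPStub` covers only a few common IPv4/IPv6 ranges and ignores IPv4-mapped IPv6, and the resolver is injectable so the example can run without real DNS. The allow-list is normalized into a `Set` once, when the lookup is built, so each connection pays only for `Set#has`.

```javascript
const dns = require('node:dns');
const net = require('node:net');

// Simplified private-range check (assumption; not LibreChat's isPrivateIP)
function isPrivateIPStub(ip) {
  const version = net.isIP(ip);
  if (version === 4) {
    const [a, b] = ip.split('.').map(Number);
    return a === 10 || a === 127 || (a === 169 && b === 254) ||
      (a === 172 && b >= 16 && b <= 31) || (a === 192 && b === 168);
  }
  if (version === 6) {
    const v = ip.toLowerCase();
    return v === '::1' || /^f[cd]/.test(v) || /^fe[89ab]/.test(v); // ULA, link-local
  }
  return false;
}

// Built once: drops URLs, paths/CIDR, embedded whitespace, and public IP literals
function normalizeAllowedSet(list) {
  const set = new Set();
  for (const raw of list ?? []) {
    const entry = String(raw).trim().toLowerCase().replace(/^\[(.*)\]$/, '$1');
    if (!entry || entry.includes('://') || entry.includes('/') || /\s/.test(entry)) continue;
    if (net.isIP(entry) && !isPrivateIPStub(entry)) continue;
    set.add(entry);
  }
  return set;
}

function buildSafeLookup(allowedAddresses, resolve = dns.lookup) {
  const allowed = normalizeAllowedSet(allowedAddresses);
  return function lookup(hostname, options, callback) {
    if (typeof options === 'function') {
      callback = options;
      options = {};
    }
    const opts = typeof options === 'number' ? { family: options } : { ...options };
    resolve(hostname, { ...opts, all: true }, (err, addresses) => {
      if (err) return callback(err);
      if (!addresses || addresses.length === 0) {
        return callback(new Error(`No addresses for ${hostname}`));
      }
      // A hostname entry trusts whatever that name resolves to
      const hostTrusted = allowed.has(hostname.toLowerCase());
      for (const { address } of addresses) {
        if (isPrivateIPStub(address) && !hostTrusted && !allowed.has(address.toLowerCase())) {
          return callback(new Error(`SSRF blocked: ${hostname} resolved to ${address}`));
        }
      }
      if (opts.all) return callback(null, addresses);
      return callback(null, addresses[0].address, addresses[0].family);
    });
  };
}
```

Checking at connect time, on the addresses actually about to be dialed, is what makes this style of check resistant to DNS rebinding between a preflight check and the connection.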
**Documentation & cleanup.**
- `connection.ts` redirect SSRF check: explicit comment that `allowedAddresses` is intentionally NOT consulted for redirect targets (server-controlled, must not inherit the admin's exemption).
- `MCPConnectionFactory.test.ts`: replaced an `eslint-disable` with a proper `import { getTenantId } from '@librechat/data-schemas'`. The disable had been added to quiet a pre-existing `require()`; the cleaner fix is to use the existing top-level import.

Updated `MCPConnectionSSRF.test.ts` WebSocket SSRF assertions to match the new two-argument call shape (`hostname, allowedAddresses`).

* 🩹 fix: Require Absolute URL Before allowedAddresses Trust Bypass In isOAuthUrlAllowed

`parseDomainSpec` is lenient: it silently prepends `https://` to schemeless inputs so it can match patterns like bare `example.com`. That leniency leaked into `isOAuthUrlAllowed`'s new `allowedAddresses` short-circuit: a value like `10.0.0.5/oauth` (no scheme) would parse successfully via the prepended default, hit the address-exemption path, return `true`, and skip `validateOAuthUrl`'s strict `new URL(url)` parse-or-throw, only to fail later in OAuth discovery with a less clear runtime error.

Add a strict `new URL(url)` gate at the top of `isOAuthUrlAllowed`. Schemeless inputs now fall through to `validateOAuthUrl`'s explicit "Invalid OAuth <field>" rejection. Tests added in both `auth/domain.spec.ts` (unit) and the OAuth handler integration spec (end-to-end).

Reported by chatgpt-codex-connector (P2) on PR #12933.

* 🛡️ fix: Address Follow-Up Comprehensive Review - Schema Tests, Shared Normalization, host:port

Auditing the second comprehensive review:

**F1 MAJOR: schema validation untested.** `allowedAddressesSchema` had zero coverage, so a regression in the three refinement stages or the three wiring locations (`endpoints` / `mcpSettings` / `actions`) would silently let invalid entries reach the runtime. Added a dedicated `describe('allowedAddressesSchema')` block in `config.spec.ts` covering: valid private IPs (v4 + v6, including the previously missed 192.0.0.0/24 range), accepted hostnames, all rejection categories (URLs, CIDR, paths, whitespace tabs/newlines, host:port, public IP literals), and full `configSchema.parse()` integration at each of the three nesting points.

**F2 MINOR: `isPrivateIPv4Literal` divergence.** The schema reimplementation in `packages/data-provider` was discarding the `c` octet, so the `192.0.0.0/24` (RFC 5736 IETF protocol assignments) range that the authoritative `isPrivateIPv4` accepts was being rejected with a misleading "public IP" error. Destructure `c` and add the missing range check; covered by the new schema tests.

**F3 MINOR: DRY violation across `domain.ts` and `agent.ts`.** Both files had independent normalization implementations with a subtle whitespace-check divergence (`/\s/` vs `.includes(' ')`). Extracted the shared logic into a new `packages/api/src/auth/allowedAddresses.ts` module that both consumers import:
- `normalizeAddressEntry(entry)`: single-entry shape check
- `looksLikeHostPort(entry)`: host:port detector (used by F4)
- `normalizeAllowedAddressesSet(list)`: pre-normalized Set for the connect-time hot path
- `isAddressInAllowedSet(candidate, set)`: membership check that enforces private-IP scoping on the candidate

Both `isAddressAllowed` (preflight) and `buildSSRFSafeLookup` (connect) now go through the same primitives; the whitespace divergence is gone. To break the import cycle (`allowedAddresses` needs `isPrivateIP`, which `domain` previously owned), extracted IP private-range detection into a leaf `auth/ip.ts` module. `domain.ts` re-exports `isPrivateIP` for backward compatibility with existing call sites.
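F2 turns on the third octet: `192.0.0.0/24` can only be recognized if `c` is compared, because a check that looks at `a` and `b` alone cannot tell `192.0.0.8` (inside the block) from `192.0.2.1` (outside it). A hypothetical schema-side check is sketched below. The range list is an assumption covering the common private, loopback, and link-local blocks plus `192.0.0.0/24`; LibreChat's authoritative `isPrivateIPv4` may include more.

```javascript
// Hypothetical schema-side IPv4 check. The range list is an assumption and
// may not match LibreChat's authoritative isPrivateIPv4 exactly.
function isPrivateIPv4Literal(value) {
  const match = /^(\d{1,3})\.(\d{1,3})\.(\d{1,3})\.(\d{1,3})$/.exec(value);
  if (!match) {
    return false;
  }
  const octets = match.slice(1).map(Number);
  if (octets.some((octet) => octet > 255)) {
    return false;
  }
  // `c` must be kept: 192.0.0.0/24 is defined by the first three octets
  const [a, b, c] = octets;
  return (
    a === 10 || // 10.0.0.0/8
    a === 127 || // 127.0.0.0/8 loopback
    (a === 169 && b === 254) || // 169.254.0.0/16 link-local
    (a === 172 && b >= 16 && b <= 31) || // 172.16.0.0/12
    (a === 192 && b === 168) || // 192.168.0.0/16
    (a === 192 && b === 0 && c === 0) // 192.0.0.0/24 IETF protocol assignments
  );
}

console.log(isPrivateIPv4Literal('192.0.0.8')); // true
console.log(isPrivateIPv4Literal('192.0.2.1')); // false
console.log(isPrivateIPv4Literal('8.8.8.8')); // false
```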
**F4 MINOR: `host:port` silently misclassified.** Entries like `localhost:8080` previously slipped through the URL/path guard, were misdetected as IPv6, failed `isPrivateIP`, and were silently dropped with a misleading "public IP" schema error. Added an explicit `looksLikeHostPort` check with a clear error: "allowedAddresses entries must not include a port — list the bare hostname or IP only." Bare `::1`, `[::1]`, and other valid IPv6 literals are intentionally not matched (the regex distinguishes by colon count and by the bracketed `[ipv6]:port` form).

**F5 MINOR: hostname-trust documentation gap.** Hostname entries short-circuit `resolveHostnameSSRF` before any DNS lookup. That is a deliberate design (the admin trusts the name), but it means the exemption follows whatever the name resolves to at runtime. Added an explicit note in `librechat.example.yaml` for both `mcpSettings.allowedAddresses` and `endpoints.allowedAddresses`: "a hostname entry trusts whatever IP that name resolves to. Only list hostnames whose DNS you control. Prefer literal IPs when you can."

**F6** (8 positional params) is flagged for follow-up; refactoring to an options object is a breaking API change deferred to a separate PR.

**F7** (redirect/WebSocket asymmetry, NIT, conf 40): skipping; the existing inline comment is sufficient.

* 🧹 chore: Address Follow-Up NITs - Import Order And Mirror-Function Naming

Three NITs from the latest comprehensive review:

**NIT #1 (conf 85): local import order.** AGENTS.md requires local imports sorted longest-to-shortest. Both `domain.ts` and `agent.ts` had `./ip` (shorter) before `./allowedAddresses` (longer). Swapped.

**NIT #2 (conf 60): missing cross-reference.** The schema-side `isHostPortShape` in `packages/data-provider/src/config.ts` had no note pointing at the canonical runtime mirror. Added a JSDoc paragraph explaining the mirror relationship and why a local copy exists (the data-provider package can't import from `@librechat/api` without creating a circular dependency).

**NIT #3 (conf 50): naming inconsistency.** Renamed `isHostPortShape` → `looksLikeHostPort` so the schema mirror matches the runtime helper exactly. Kept as a separate function (not a shared import) for the same circular-dependency reason; the matching name makes it obvious they should stay in lockstep.
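The F4 detector has to reject `host:port` while leaving every bare IPv6 literal alone. One way to do that, consistent with the description of distinguishing by colon count and by the bracketed `[ipv6]:port` form, is sketched below. The regexes are a hypothetical reconstruction, not the expressions LibreChat ships.

```javascript
// Hypothetical host:port detector (not LibreChat's actual regex)
function looksLikeHostPort(entry) {
  const value = String(entry).trim();
  // Bracketed IPv6 literal followed by a port, e.g. "[::1]:8080"
  if (/^\[[^\]]+\]:\d+$/.test(value)) {
    return true;
  }
  // Exactly one colon, followed only by digits, e.g. "localhost:8080".
  // Bare IPv6 literals contain at least two colons, so they never match.
  return /^[^:\s[\]]+:\d+$/.test(value);
}

for (const entry of ['localhost:8080', '10.0.0.5:11434', '[::1]:8080', '::1', '[::1]', 'fe80::1', 'ollama.lan']) {
  console.log(entry, looksLikeHostPort(entry));
}
// localhost:8080 true, 10.0.0.5:11434 true, [::1]:8080 true,
// ::1 false, [::1] false, fe80::1 false, ollama.lan false
```

Because the data-provider schema keeps its own copy of this helper, any change to the pattern would have to be made in both places.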
1349 lines
43 KiB
JavaScript
// Mock all dependencies - define mocks before imports
jest.mock('@librechat/data-schemas', () => ({
  logger: {
    debug: jest.fn(),
    error: jest.fn(),
    info: jest.fn(),
    warn: jest.fn(),
  },
}));

// Create mock registry instance
const mockRegistryInstance = {
  getOAuthServers: jest.fn(() => Promise.resolve(new Set())),
  getAllServerConfigs: jest.fn(() => Promise.resolve({})),
  getServerConfig: jest.fn(() => Promise.resolve(null)),
  ensureConfigServers: jest.fn(() => Promise.resolve({})),
};

// Create isMCPDomainAllowed mock that can be configured per-test
const mockIsMCPDomainAllowed = jest.fn(() => Promise.resolve(true));

const mockGetAppConfig = jest.fn(() => Promise.resolve({}));

jest.mock('@librechat/api', () => {
  const actual = jest.requireActual('@librechat/api');
  return {
    ...actual,
    sendEvent: jest.fn(),
    get isMCPDomainAllowed() {
      return mockIsMCPDomainAllowed;
    },
    GenerationJobManager: {
      emitChunk: jest.fn(),
    },
  };
});

const { logger } = require('@librechat/data-schemas');
const { MCPOAuthHandler } = require('@librechat/api');
const { CacheKeys, Constants } = require('librechat-data-provider');
const D = Constants.mcp_delimiter;
const {
  createMCPTool,
  createMCPTools,
  getMCPSetupData,
  checkOAuthFlowStatus,
  getServerConnectionStatus,
  createUnavailableToolStub,
} = require('./MCP');

jest.mock('./Config', () => ({
  loadCustomConfig: jest.fn(),
  get getAppConfig() {
    return mockGetAppConfig;
  },
}));

jest.mock('~/config', () => ({
  getMCPManager: jest.fn(),
  getFlowStateManager: jest.fn(),
  getOAuthReconnectionManager: jest.fn(),
  getMCPServersRegistry: jest.fn(() => mockRegistryInstance),
}));

jest.mock('~/cache', () => ({
  getLogStores: jest.fn(),
}));

jest.mock('~/models', () => ({
  findToken: jest.fn(),
  createToken: jest.fn(),
  updateToken: jest.fn(),
}));

jest.mock('./Tools/mcp', () => ({
  reinitMCPServer: jest.fn(),
}));

jest.mock('./GraphTokenService', () => ({
  getGraphApiToken: jest.fn(),
}));

describe('tests for the new helper functions used by the MCP connection status endpoints', () => {
  let mockGetMCPManager;
  let mockGetFlowStateManager;
  let mockGetLogStores;
  let mockGetOAuthReconnectionManager;

  beforeEach(() => {
    jest.clearAllMocks();
    jest.spyOn(MCPOAuthHandler, 'generateFlowId');

    mockGetMCPManager = require('~/config').getMCPManager;
    mockGetFlowStateManager = require('~/config').getFlowStateManager;
    mockGetLogStores = require('~/cache').getLogStores;
    mockGetOAuthReconnectionManager = require('~/config').getOAuthReconnectionManager;
  });

  describe('getMCPSetupData', () => {
    const mockUserId = 'user-123';
    const mockConfig = {
      server1: { type: 'stdio' },
      server2: { type: 'http' },
    };

    beforeEach(() => {
      mockGetMCPManager.mockReturnValue({
        appConnections: { getLoaded: jest.fn(() => new Map()) },
        getUserConnections: jest.fn(() => new Map()),
      });
      mockRegistryInstance.getOAuthServers.mockResolvedValue(new Set());
      mockRegistryInstance.getAllServerConfigs.mockResolvedValue(mockConfig);
    });

    it('should successfully return MCP setup data', async () => {
      const mockConfigWithOAuth = {
        server1: { type: 'stdio' },
        server2: { type: 'http', requiresOAuth: true },
      };
      mockRegistryInstance.getAllServerConfigs.mockResolvedValue(mockConfigWithOAuth);

      const mockAppConnections = new Map([['server1', { status: 'connected' }]]);
      const mockUserConnections = new Map([['server2', { status: 'disconnected' }]]);

      const mockMCPManager = {
        appConnections: { getLoaded: jest.fn(() => Promise.resolve(mockAppConnections)) },
        getUserConnections: jest.fn(() => mockUserConnections),
      };
      mockGetMCPManager.mockReturnValue(mockMCPManager);

      const result = await getMCPSetupData(mockUserId);

      expect(mockRegistryInstance.ensureConfigServers).toHaveBeenCalled();
      expect(mockRegistryInstance.getAllServerConfigs).toHaveBeenCalledWith(
        mockUserId,
        expect.any(Object),
      );
      expect(mockGetMCPManager).toHaveBeenCalledWith(mockUserId);
      expect(mockMCPManager.appConnections.getLoaded).toHaveBeenCalled();
      expect(mockMCPManager.getUserConnections).toHaveBeenCalledWith(mockUserId);

      expect(result.mcpConfig).toEqual(mockConfigWithOAuth);
      expect(result.appConnections).toEqual(mockAppConnections);
      expect(result.userConnections).toEqual(mockUserConnections);
      expect(result.oauthServers).toEqual(new Set(['server2']));
    });

    it('should return empty data when no servers are configured', async () => {
      mockRegistryInstance.getAllServerConfigs.mockResolvedValue({});
      const result = await getMCPSetupData(mockUserId);
      expect(result.mcpConfig).toEqual({});
      expect(result.oauthServers).toEqual(new Set());
    });

    it('should handle null values from MCP manager gracefully', async () => {
      mockRegistryInstance.getAllServerConfigs.mockResolvedValue(mockConfig);

      const mockMCPManager = {
        appConnections: { getLoaded: jest.fn(() => Promise.resolve(null)) },
        getUserConnections: jest.fn(() => null),
      };
      mockGetMCPManager.mockReturnValue(mockMCPManager);
      mockRegistryInstance.getOAuthServers.mockResolvedValue(new Set());

      const result = await getMCPSetupData(mockUserId);

      expect(result).toEqual({
        mcpConfig: mockConfig,
        appConnections: new Map(),
        userConnections: new Map(),
        oauthServers: new Set(),
      });
    });
  });

  describe('checkOAuthFlowStatus', () => {
    const mockUserId = 'user-123';
    const mockServerName = 'test-server';
    const mockFlowId = 'flow-123';

    beforeEach(() => {
      const mockFlowsCache = {};
      const mockFlowManager = {
        getFlowState: jest.fn(),
      };

      mockGetLogStores.mockReturnValue(mockFlowsCache);
      mockGetFlowStateManager.mockReturnValue(mockFlowManager);
      MCPOAuthHandler.generateFlowId.mockReturnValue(mockFlowId);
    });

    it('should return false flags when no flow state exists', async () => {
      const mockFlowManager = { getFlowState: jest.fn(() => null) };
      mockGetFlowStateManager.mockReturnValue(mockFlowManager);

      const result = await checkOAuthFlowStatus(mockUserId, mockServerName);

      expect(mockGetLogStores).toHaveBeenCalledWith(CacheKeys.FLOWS);
      expect(MCPOAuthHandler.generateFlowId).toHaveBeenCalledWith(mockUserId, mockServerName);
      expect(mockFlowManager.getFlowState).toHaveBeenCalledWith(mockFlowId, 'mcp_oauth');
      expect(result).toEqual({ hasActiveFlow: false, hasFailedFlow: false });
    });

    it('should detect failed flow when status is FAILED', async () => {
      const mockFlowState = {
        status: 'FAILED',
        createdAt: Date.now() - 60000, // 1 minute ago
        ttl: 180000,
      };
      const mockFlowManager = { getFlowState: jest.fn(() => mockFlowState) };
      mockGetFlowStateManager.mockReturnValue(mockFlowManager);

      const result = await checkOAuthFlowStatus(mockUserId, mockServerName);

      expect(result).toEqual({ hasActiveFlow: false, hasFailedFlow: true });
      expect(logger.debug).toHaveBeenCalledWith(
        expect.stringContaining('Found failed OAuth flow'),
        expect.objectContaining({
          flowId: mockFlowId,
          status: 'FAILED',
        }),
      );
    });

    it('should detect failed flow when flow has timed out', async () => {
      const mockFlowState = {
        status: 'PENDING',
        createdAt: Date.now() - 200000, // 200 seconds ago (> 180s TTL)
        ttl: 180000,
      };
      const mockFlowManager = { getFlowState: jest.fn(() => mockFlowState) };
      mockGetFlowStateManager.mockReturnValue(mockFlowManager);

      const result = await checkOAuthFlowStatus(mockUserId, mockServerName);

      expect(result).toEqual({ hasActiveFlow: false, hasFailedFlow: true });
      expect(logger.debug).toHaveBeenCalledWith(
        expect.stringContaining('Found failed OAuth flow'),
        expect.objectContaining({
          timedOut: true,
        }),
      );
    });

    it('should detect failed flow when TTL not specified and flow exceeds default TTL', async () => {
      const mockFlowState = {
        status: 'PENDING',
        createdAt: Date.now() - 200000, // 200 seconds ago (> 180s default TTL)
        // ttl not specified, should use 180000 default
      };
      const mockFlowManager = { getFlowState: jest.fn(() => mockFlowState) };
      mockGetFlowStateManager.mockReturnValue(mockFlowManager);

      const result = await checkOAuthFlowStatus(mockUserId, mockServerName);

      expect(result).toEqual({ hasActiveFlow: false, hasFailedFlow: true });
    });

    it('should detect active flow when status is PENDING and within TTL', async () => {
      const mockFlowState = {
        status: 'PENDING',
        createdAt: Date.now() - 60000, // 1 minute ago (< 180s TTL)
        ttl: 180000,
      };
      const mockFlowManager = { getFlowState: jest.fn(() => mockFlowState) };
      mockGetFlowStateManager.mockReturnValue(mockFlowManager);

      const result = await checkOAuthFlowStatus(mockUserId, mockServerName);

      expect(result).toEqual({ hasActiveFlow: true, hasFailedFlow: false });
      expect(logger.debug).toHaveBeenCalledWith(
        expect.stringContaining('Found active OAuth flow'),
        expect.objectContaining({
          flowId: mockFlowId,
        }),
      );
    });

    it('should return false flags for other statuses', async () => {
      const mockFlowState = {
        status: 'COMPLETED',
        createdAt: Date.now() - 60000,
        ttl: 180000,
      };
      const mockFlowManager = { getFlowState: jest.fn(() => mockFlowState) };
      mockGetFlowStateManager.mockReturnValue(mockFlowManager);

      const result = await checkOAuthFlowStatus(mockUserId, mockServerName);

      expect(result).toEqual({ hasActiveFlow: false, hasFailedFlow: false });
    });

    it('should handle errors gracefully', async () => {
      const mockError = new Error('Flow state error');
      const mockFlowManager = {
        getFlowState: jest.fn(() => {
          throw mockError;
        }),
      };
      mockGetFlowStateManager.mockReturnValue(mockFlowManager);

      const result = await checkOAuthFlowStatus(mockUserId, mockServerName);

      expect(result).toEqual({ hasActiveFlow: false, hasFailedFlow: false });
      expect(logger.error).toHaveBeenCalledWith(
        expect.stringContaining('Error checking OAuth flows'),
        mockError,
      );
    });
  });

  describe('getServerConnectionStatus', () => {
    const mockUserId = 'user-123';
    const mockServerName = 'test-server';
    const mockConfig = { updatedAt: Date.now() };

    it('should return app connection state when available', async () => {
      const appConnections = new Map([
        [
          mockServerName,
          {
            connectionState: 'connected',
            isStale: jest.fn(() => false),
          },
        ],
      ]);
      const userConnections = new Map();
      const oauthServers = new Set();

      const result = await getServerConnectionStatus(
        mockUserId,
        mockServerName,
        mockConfig,
        appConnections,
        userConnections,
        oauthServers,
      );

      expect(result).toEqual({
        requiresOAuth: false,
        connectionState: 'connected',
      });
    });

    it('should fall back to user connection state when app connection not available', async () => {
      const appConnections = new Map();
      const userConnections = new Map([
        [
          mockServerName,
          {
            connectionState: 'connecting',
            isStale: jest.fn(() => false),
          },
        ],
      ]);
      const oauthServers = new Set();

      const result = await getServerConnectionStatus(
        mockUserId,
        mockServerName,
        mockConfig,
        appConnections,
        userConnections,
        oauthServers,
      );

      expect(result).toEqual({
        requiresOAuth: false,
        connectionState: 'connecting',
      });
    });

    it('should default to disconnected when no connections exist', async () => {
      const appConnections = new Map();
      const userConnections = new Map();
      const oauthServers = new Set();

      const result = await getServerConnectionStatus(
        mockUserId,
        mockServerName,
        mockConfig,
        appConnections,
        userConnections,
        oauthServers,
      );

      expect(result).toEqual({
        requiresOAuth: false,
        connectionState: 'disconnected',
      });
    });

    it('should prioritize app connection over user connection', async () => {
      const appConnections = new Map([
        [
          mockServerName,
          {
            connectionState: 'connected',
            isStale: jest.fn(() => false),
          },
        ],
      ]);
      const userConnections = new Map([
        [
          mockServerName,
          {
            connectionState: 'disconnected',
            isStale: jest.fn(() => false),
          },
        ],
      ]);
      const oauthServers = new Set();

      const result = await getServerConnectionStatus(
        mockUserId,
        mockServerName,
        mockConfig,
        appConnections,
        userConnections,
        oauthServers,
      );

      expect(result).toEqual({
        requiresOAuth: false,
        connectionState: 'connected',
      });
    });

    it('should indicate OAuth requirement when server is in OAuth servers set', async () => {
const appConnections = new Map();
|
|
const userConnections = new Map();
|
|
const oauthServers = new Set([mockServerName]);
|
|
|
|
// Mock OAuthReconnectionManager
|
|
const mockOAuthReconnectionManager = {
|
|
isReconnecting: jest.fn(() => false),
|
|
};
|
|
mockGetOAuthReconnectionManager.mockReturnValue(mockOAuthReconnectionManager);
|
|
|
|
const result = await getServerConnectionStatus(
|
|
mockUserId,
|
|
mockServerName,
|
|
mockConfig,
|
|
appConnections,
|
|
userConnections,
|
|
oauthServers,
|
|
);
|
|
|
|
expect(result.requiresOAuth).toBe(true);
|
|
});
|
|
|
|
it('should handle OAuth flow status when disconnected and requires OAuth with failed flow', async () => {
|
|
const appConnections = new Map();
|
|
const userConnections = new Map();
|
|
const oauthServers = new Set([mockServerName]);
|
|
|
|
// Mock OAuthReconnectionManager
|
|
const mockOAuthReconnectionManager = {
|
|
isReconnecting: jest.fn(() => false),
|
|
};
|
|
mockGetOAuthReconnectionManager.mockReturnValue(mockOAuthReconnectionManager);
|
|
|
|
// Mock flow state to return failed flow
|
|
const mockFlowManager = {
|
|
getFlowState: jest.fn(() => ({
|
|
status: 'FAILED',
|
|
createdAt: Date.now() - 60000,
|
|
ttl: 180000,
|
|
})),
|
|
};
|
|
mockGetFlowStateManager.mockReturnValue(mockFlowManager);
|
|
mockGetLogStores.mockReturnValue({});
|
|
MCPOAuthHandler.generateFlowId.mockReturnValue('test-flow-id');
|
|
|
|
const result = await getServerConnectionStatus(
|
|
mockUserId,
|
|
mockServerName,
|
|
mockConfig,
|
|
appConnections,
|
|
userConnections,
|
|
oauthServers,
|
|
);
|
|
|
|
expect(result).toEqual({
|
|
requiresOAuth: true,
|
|
connectionState: 'error',
|
|
});
|
|
});
|
|
|
|
it('should handle OAuth flow status when disconnected and requires OAuth with active flow', async () => {
|
|
const appConnections = new Map();
|
|
const userConnections = new Map();
|
|
const oauthServers = new Set([mockServerName]);
|
|
|
|
// Mock OAuthReconnectionManager
|
|
const mockOAuthReconnectionManager = {
|
|
isReconnecting: jest.fn(() => false),
|
|
};
|
|
mockGetOAuthReconnectionManager.mockReturnValue(mockOAuthReconnectionManager);
|
|
|
|
// Mock flow state to return active flow
|
|
const mockFlowManager = {
|
|
getFlowState: jest.fn(() => ({
|
|
status: 'PENDING',
|
|
createdAt: Date.now() - 60000, // 1 minute ago
|
|
ttl: 180000, // 3 minutes TTL
|
|
})),
|
|
};
|
|
mockGetFlowStateManager.mockReturnValue(mockFlowManager);
|
|
mockGetLogStores.mockReturnValue({});
|
|
MCPOAuthHandler.generateFlowId.mockReturnValue('test-flow-id');
|
|
|
|
const result = await getServerConnectionStatus(
|
|
mockUserId,
|
|
mockServerName,
|
|
mockConfig,
|
|
appConnections,
|
|
userConnections,
|
|
oauthServers,
|
|
);
|
|
|
|
expect(result).toEqual({
|
|
requiresOAuth: true,
|
|
connectionState: 'connecting',
|
|
});
|
|
});

    it('should handle OAuth flow status when disconnected and requires OAuth with no flow', async () => {
      const appConnections = new Map();
      const userConnections = new Map();
      const oauthServers = new Set([mockServerName]);

      // Mock OAuthReconnectionManager
      const mockOAuthReconnectionManager = {
        isReconnecting: jest.fn(() => false),
      };
      mockGetOAuthReconnectionManager.mockReturnValue(mockOAuthReconnectionManager);

      // Mock flow state to return no flow
      const mockFlowManager = {
        getFlowState: jest.fn(() => null),
      };
      mockGetFlowStateManager.mockReturnValue(mockFlowManager);
      mockGetLogStores.mockReturnValue({});
      MCPOAuthHandler.generateFlowId.mockReturnValue('test-flow-id');

      const result = await getServerConnectionStatus(
        mockUserId,
        mockServerName,
        mockConfig,
        appConnections,
        userConnections,
        oauthServers,
      );

      expect(result).toEqual({
        requiresOAuth: true,
        connectionState: 'disconnected',
      });
    });

    it('should return connecting state when OAuth server is reconnecting', async () => {
      const appConnections = new Map();
      const userConnections = new Map();
      const oauthServers = new Set([mockServerName]);

      // Mock OAuthReconnectionManager to return true for isReconnecting
      const mockOAuthReconnectionManager = {
        isReconnecting: jest.fn(() => true),
      };
      mockGetOAuthReconnectionManager.mockReturnValue(mockOAuthReconnectionManager);

      const result = await getServerConnectionStatus(
        mockUserId,
        mockServerName,
        mockConfig,
        appConnections,
        userConnections,
        oauthServers,
      );

      expect(result).toEqual({
        requiresOAuth: true,
        connectionState: 'connecting',
      });
      expect(mockOAuthReconnectionManager.isReconnecting).toHaveBeenCalledWith(
        mockUserId,
        mockServerName,
      );
    });

    it('should not check OAuth flow status when server is connected', async () => {
      const mockFlowManager = {
        getFlowState: jest.fn(),
      };
      mockGetFlowStateManager.mockReturnValue(mockFlowManager);
      mockGetLogStores.mockReturnValue({});

      const appConnections = new Map([
        [
          mockServerName,
          {
            connectionState: 'connected',
            isStale: jest.fn(() => false),
          },
        ],
      ]);
      const userConnections = new Map();
      const oauthServers = new Set([mockServerName]);

      const result = await getServerConnectionStatus(
        mockUserId,
        mockServerName,
        mockConfig,
        appConnections,
        userConnections,
        oauthServers,
      );

      expect(result).toEqual({
        requiresOAuth: true,
        connectionState: 'connected',
      });

      // Should not call flow manager since server is connected
      expect(mockFlowManager.getFlowState).not.toHaveBeenCalled();
    });

    it('should not check OAuth flow status when server does not require OAuth', async () => {
      const mockFlowManager = {
        getFlowState: jest.fn(),
      };
      mockGetFlowStateManager.mockReturnValue(mockFlowManager);
      mockGetLogStores.mockReturnValue({});

      const appConnections = new Map();
      const userConnections = new Map();
      const oauthServers = new Set(); // Server not in OAuth servers

      const result = await getServerConnectionStatus(
        mockUserId,
        mockServerName,
        mockConfig,
        appConnections,
        userConnections,
        oauthServers,
      );

      expect(result).toEqual({
        requiresOAuth: false,
        connectionState: 'disconnected',
      });

      // Should not call flow manager since server doesn't require OAuth
      expect(mockFlowManager.getFlowState).not.toHaveBeenCalled();
    });
  });
});
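
// Illustrative sketch (hypothetical helper, not LibreChat's getServerConnectionStatus): a pure
// model of the status mapping the tests above pin down. The check order (live connection,
// OAuth requirement, in-progress reconnect, stored flow) is an assumption. Staleness handling
// is omitted. Treating a PENDING flow past its TTL as 'disconnected' is also assumed, because
// the tests only cover a flow created 1 minute into a 3 minute TTL.
function sketchDeriveConnectionStatus({
  connection = null,
  requiresOAuth,
  isReconnecting = false,
  flowState = null,
  now = Date.now(),
}) {
  // A live connection wins; flow state is never consulted.
  if (connection) {
    return { requiresOAuth, connectionState: connection.connectionState };
  }
  // Servers that do not require OAuth never consult flow state either.
  if (!requiresOAuth) {
    return { requiresOAuth, connectionState: 'disconnected' };
  }
  if (isReconnecting) {
    return { requiresOAuth, connectionState: 'connecting' };
  }
  if (flowState && flowState.status === 'FAILED') {
    return { requiresOAuth, connectionState: 'error' };
  }
  if (flowState && flowState.status === 'PENDING' && now - flowState.createdAt < flowState.ttl) {
    return { requiresOAuth, connectionState: 'connecting' };
  }
  // No flow, or an expired one: OAuth has not been started (or must be restarted).
  return { requiresOAuth, connectionState: 'disconnected' };
}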

describe('User parameter passing tests', () => {
  let mockReinitMCPServer;
  let mockGetFlowStateManager;
  let mockGetLogStores;

  beforeEach(() => {
    jest.clearAllMocks();
    mockReinitMCPServer = require('./Tools/mcp').reinitMCPServer;
    mockGetFlowStateManager = require('~/config').getFlowStateManager;
    mockGetLogStores = require('~/cache').getLogStores;

    // Setup default mocks
    mockGetLogStores.mockReturnValue({});
    mockGetFlowStateManager.mockReturnValue({
      createFlowWithHandler: jest.fn(),
      failFlow: jest.fn(),
    });

    // Reset domain validation mock to default (allow all)
    mockIsMCPDomainAllowed.mockReset();
    mockIsMCPDomainAllowed.mockResolvedValue(true);

    // Reset registry mocks
    mockRegistryInstance.getServerConfig.mockReset();
    mockRegistryInstance.getServerConfig.mockResolvedValue(null);

    // Reset getAppConfig mock to default (no restrictions)
    mockGetAppConfig.mockReset();
    mockGetAppConfig.mockResolvedValue({});
  });

  describe('createMCPTools', () => {
    it('should pass user parameter to reinitMCPServer when calling reconnectServer internally', async () => {
      const mockUser = { id: 'test-user-123', name: 'Test User' };
      const mockRes = { write: jest.fn(), flush: jest.fn() };
      const mockSignal = new AbortController().signal;

      mockReinitMCPServer.mockResolvedValue({
        tools: [{ name: 'test-tool' }],
        availableTools: {
          [`test-tool${D}test-server`]: {
            function: {
              description: 'Test tool',
              parameters: { type: 'object', properties: {} },
            },
          },
        },
      });

      await createMCPTools({
        res: mockRes,
        user: mockUser,
        serverName: 'test-server',
        provider: 'openai',
        signal: mockSignal,
        userMCPAuthMap: {},
      });

      // Verify reinitMCPServer was called with the user
      expect(mockReinitMCPServer).toHaveBeenCalledWith(
        expect.objectContaining({
          user: mockUser,
          serverName: 'test-server',
        }),
      );
      expect(mockReinitMCPServer.mock.calls[0][0].user).toBe(mockUser);
    });

    it('should throw error if user is not provided', async () => {
      const mockRes = { write: jest.fn(), flush: jest.fn() };

      mockReinitMCPServer.mockResolvedValue({
        tools: [],
        availableTools: {},
      });

      // Calling without a user rejects: reading `user.id` throws a TypeError before any reconnect
      await expect(
        createMCPTools({
          res: mockRes,
          user: undefined,
          serverName: 'test-server',
          provider: 'openai',
          userMCPAuthMap: {},
        }),
      ).rejects.toThrow("Cannot read properties of undefined (reading 'id')");

      // Verify reinitMCPServer was not called due to early error
      expect(mockReinitMCPServer).not.toHaveBeenCalled();
    });
  });

  describe('createMCPTool', () => {
    it('should pass user parameter to reinitMCPServer when tool not in cache', async () => {
      const mockUser = { id: 'test-user-456', email: 'test@example.com' };
      const mockRes = { write: jest.fn(), flush: jest.fn() };
      const mockSignal = new AbortController().signal;

      mockReinitMCPServer.mockResolvedValue({
        availableTools: {
          [`test-tool${D}test-server`]: {
            function: {
              description: 'Test tool',
              parameters: { type: 'object', properties: {} },
            },
          },
        },
      });

      // Call without availableTools to trigger reinit
      await createMCPTool({
        res: mockRes,
        user: mockUser,
        toolKey: `test-tool${D}test-server`,
        provider: 'openai',
        signal: mockSignal,
        userMCPAuthMap: {},
        availableTools: undefined, // Force reinit
      });

      // Verify reinitMCPServer was called with the user
      expect(mockReinitMCPServer).toHaveBeenCalledWith(
        expect.objectContaining({
          user: mockUser,
          serverName: 'test-server',
        }),
      );
      expect(mockReinitMCPServer.mock.calls[0][0].user).toBe(mockUser);
    });
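
    // Illustrative sketch (hypothetical helper): the test above expects createMCPTool to derive
    // serverName 'test-server' from the toolKey `test-tool${D}test-server`. This models that split.
    // The delimiter is passed in rather than read from `D`. Splitting on the LAST occurrence is an
    // assumption about how tool names that themselves contain the delimiter would be handled.
    function sketchParseToolKey(toolKey, delimiter) {
      const index = toolKey.lastIndexOf(delimiter);
      const serverStart = index + delimiter.length;
      // Reject keys with no delimiter, an empty tool name, or an empty server name.
      if (index <= 0 || serverStart >= toolKey.length) {
        return null;
      }
      return {
        toolName: toolKey.slice(0, index),
        serverName: toolKey.slice(serverStart),
      };
    }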

    it('should not call reinitMCPServer when tool is in cache', async () => {
      const mockUser = { id: 'test-user-789' };
      const mockRes = { write: jest.fn(), flush: jest.fn() };

      const availableTools = {
        [`test-tool${D}test-server`]: {
          function: {
            description: 'Cached tool',
            parameters: { type: 'object', properties: {} },
          },
        },
      };

      await createMCPTool({
        res: mockRes,
        user: mockUser,
        toolKey: `test-tool${D}test-server`,
        provider: 'openai',
        userMCPAuthMap: {},
        availableTools,
      });

      // Verify reinitMCPServer was NOT called since tool was in cache
      expect(mockReinitMCPServer).not.toHaveBeenCalled();
    });
  });

  describe('reinitMCPServer (via reconnectServer)', () => {
    it('should always receive user parameter when called from createMCPTools', async () => {
      const mockUser = { id: 'user-001', role: 'admin' };
      const mockRes = { write: jest.fn(), flush: jest.fn() };

      // Track all calls to reinitMCPServer
      const reinitCalls = [];
      mockReinitMCPServer.mockImplementation((params) => {
        reinitCalls.push(params);
        return Promise.resolve({
          tools: [{ name: 'tool1' }, { name: 'tool2' }],
          availableTools: {
            [`tool1${D}server1`]: { function: { description: 'Tool 1', parameters: {} } },
            [`tool2${D}server1`]: { function: { description: 'Tool 2', parameters: {} } },
          },
        });
      });

      await createMCPTools({
        res: mockRes,
        user: mockUser,
        serverName: 'server1',
        provider: 'anthropic',
        userMCPAuthMap: {},
      });

      // Verify all calls to reinitMCPServer had the user
      expect(reinitCalls.length).toBeGreaterThan(0);
      reinitCalls.forEach((call) => {
        expect(call.user).toBe(mockUser);
        expect(call.user.id).toBe('user-001');
      });
    });

    it('should always receive user parameter when called from createMCPTool', async () => {
      const mockUser = { id: 'user-002', permissions: ['read', 'write'] };
      const mockRes = { write: jest.fn(), flush: jest.fn() };

      // Track all calls to reinitMCPServer
      const reinitCalls = [];
      mockReinitMCPServer.mockImplementation((params) => {
        reinitCalls.push(params);
        return Promise.resolve({
          availableTools: {
            [`my-tool${D}my-server`]: {
              function: { description: 'My Tool', parameters: {} },
            },
          },
        });
      });

      await createMCPTool({
        res: mockRes,
        user: mockUser,
        toolKey: `my-tool${D}my-server`,
        provider: 'google',
        userMCPAuthMap: {},
        availableTools: undefined, // Force reinit
      });

      // Verify the call to reinitMCPServer had the user
      expect(reinitCalls.length).toBe(1);
      expect(reinitCalls[0].user).toBe(mockUser);
      expect(reinitCalls[0].user.id).toBe('user-002');
    });
  });

  describe('Runtime domain validation', () => {
    it('should skip tool creation when domain is not allowed', async () => {
      const mockUser = { id: 'domain-test-user', role: 'user' };
      const mockRes = { write: jest.fn(), flush: jest.fn() };

      // Mock server config with URL (remote server)
      mockRegistryInstance.getServerConfig.mockResolvedValue({
        url: 'https://disallowed-domain.com/sse',
      });

      // Mock getAppConfig to return domain restrictions
      mockGetAppConfig.mockResolvedValue({
        mcpSettings: { allowedDomains: ['allowed-domain.com'] },
      });

      // Mock domain validation to return false (domain not allowed)
      mockIsMCPDomainAllowed.mockResolvedValueOnce(false);

      const result = await createMCPTool({
        res: mockRes,
        user: mockUser,
        toolKey: `test-tool${D}test-server`,
        provider: 'openai',
        userMCPAuthMap: {},
        availableTools: {
          [`test-tool${D}test-server`]: {
            function: {
              description: 'Test tool',
              parameters: { type: 'object', properties: {} },
            },
          },
        },
      });

      // Should return undefined for disallowed domain
      expect(result).toBeUndefined();

      // Should not call reinitMCPServer since domain check failed
      expect(mockReinitMCPServer).not.toHaveBeenCalled();

      // Verify getAppConfig was called with user role
      expect(mockGetAppConfig).toHaveBeenCalledWith({ role: 'user' });

      // Verify domain validation received the server config, the role's allowedDomains, and
      // an undefined allowedAddresses list (none is configured in mcpSettings here)
      expect(mockIsMCPDomainAllowed).toHaveBeenCalledWith(
        { url: 'https://disallowed-domain.com/sse' },
        ['allowed-domain.com'],
        undefined,
      );
    });
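
    // Illustrative sketch, assuming the third argument checked above is the `allowedAddresses`
    // exemption list from `mcpSettings` (undefined when not configured). The commit describes the
    // matcher as a pure, case-insensitive, bracket-stripped literal match. The hypothetical helper
    // below models only that description; it is not the real `isAddressAllowed`.
    function sketchIsAddressAllowed(hostnameOrIP, allowedAddresses) {
      if (!hostnameOrIP || !Array.isArray(allowedAddresses) || allowedAddresses.length === 0) {
        return false;
      }
      // Normalize: lowercase and strip IPv6 brackets, e.g. "[::1]" -> "::1".
      const normalize = (value) => value.toLowerCase().replace(/^\[(.*)\]$/, '$1');
      const target = normalize(hostnameOrIP);
      // Literal comparison only: no CIDR ranges, wildcards, or DNS resolution.
      return allowedAddresses.some(
        (entry) => typeof entry === 'string' && normalize(entry) === target,
      );
    }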

    it('should allow tool creation when domain is allowed', async () => {
      const mockUser = { id: 'domain-test-user', role: 'admin' };
      const mockRes = { write: jest.fn(), flush: jest.fn() };

      // Mock server config with URL (remote server)
      mockRegistryInstance.getServerConfig.mockResolvedValue({
        url: 'https://allowed-domain.com/sse',
      });

      // Mock getAppConfig to return domain restrictions
      mockGetAppConfig.mockResolvedValue({
        mcpSettings: { allowedDomains: ['allowed-domain.com'] },
      });

      // Mock domain validation to return true (domain allowed)
      mockIsMCPDomainAllowed.mockResolvedValueOnce(true);

      const availableTools = {
        [`test-tool${D}test-server`]: {
          function: {
            description: 'Test tool',
            parameters: { type: 'object', properties: {} },
          },
        },
      };

      const result = await createMCPTool({
        res: mockRes,
        user: mockUser,
        toolKey: `test-tool${D}test-server`,
        provider: 'openai',
        userMCPAuthMap: {},
        availableTools,
      });

      // Should create tool successfully
      expect(result).toBeDefined();

      // Verify getAppConfig was called with user role
      expect(mockGetAppConfig).toHaveBeenCalledWith({ role: 'admin' });
    });

    it('should skip domain validation for stdio transports (no URL)', async () => {
      const mockUser = { id: 'stdio-test-user' };
      const mockRes = { write: jest.fn(), flush: jest.fn() };

      // Mock server config without URL (stdio transport)
      mockRegistryInstance.getServerConfig.mockResolvedValue({
        command: 'npx',
        args: ['@modelcontextprotocol/server'],
      });

      // Mock getAppConfig (should not be called for stdio)
      mockGetAppConfig.mockResolvedValue({
        mcpSettings: { allowedDomains: ['restricted-domain.com'] },
      });

      const availableTools = {
        [`test-tool${D}test-server`]: {
          function: {
            description: 'Test tool',
            parameters: { type: 'object', properties: {} },
          },
        },
      };

      const result = await createMCPTool({
        res: mockRes,
        user: mockUser,
        toolKey: `test-tool${D}test-server`,
        provider: 'openai',
        userMCPAuthMap: {},
        availableTools,
      });

      // Should create tool successfully without domain check
      expect(result).toBeDefined();

      // Should not call getAppConfig or isMCPDomainAllowed for stdio transport (no URL)
      expect(mockGetAppConfig).not.toHaveBeenCalled();
      expect(mockIsMCPDomainAllowed).not.toHaveBeenCalled();
    });
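
    // Illustrative sketch (hypothetical, synchronous for brevity) of the gate these domain tests
    // exercise. A config without `url` (stdio) skips both the role-scoped config lookup and the
    // domain check. A remote config fetches settings for the caller's role and passes the server
    // config, allowedDomains, and (assumed) allowedAddresses to the checker. `getConfigForRole`
    // and `isDomainAllowed` are injected stand-ins, not LibreChat functions.
    function sketchPassesDomainGate(serverConfig, user, getConfigForRole, isDomainAllowed) {
      if (!serverConfig || !serverConfig.url) {
        // stdio transport: there is no remote host to validate.
        return true;
      }
      const appConfig = getConfigForRole({ role: user && user.role });
      const mcpSettings = (appConfig && appConfig.mcpSettings) || {};
      return isDomainAllowed(serverConfig, mcpSettings.allowedDomains, mcpSettings.allowedAddresses);
    }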

    it('should return empty array from createMCPTools when domain is not allowed', async () => {
      const mockUser = { id: 'domain-test-user', role: 'user' };
      const mockRes = { write: jest.fn(), flush: jest.fn() };

      // Mock server config with URL (remote server)
      const serverConfig = { url: 'https://disallowed-domain.com/sse' };
      mockRegistryInstance.getServerConfig.mockResolvedValue(serverConfig);

      // Mock getAppConfig to return domain restrictions
      mockGetAppConfig.mockResolvedValue({
        mcpSettings: { allowedDomains: ['allowed-domain.com'] },
      });

      // Mock domain validation to return false (domain not allowed)
      mockIsMCPDomainAllowed.mockResolvedValueOnce(false);

      const result = await createMCPTools({
        res: mockRes,
        user: mockUser,
        serverName: 'test-server',
        provider: 'openai',
        userMCPAuthMap: {},
        config: serverConfig,
      });

      // Should return empty array for disallowed domain
      expect(result).toEqual([]);

      // Should not call reinitMCPServer since domain check failed early
      expect(mockReinitMCPServer).not.toHaveBeenCalled();

      // Verify getAppConfig was called with user role
      expect(mockGetAppConfig).toHaveBeenCalledWith({ role: 'user' });
    });

    it('should use user role when fetching domain restrictions', async () => {
      const adminUser = { id: 'admin-user', role: 'admin' };
      const regularUser = { id: 'regular-user', role: 'user' };
      const mockRes = { write: jest.fn(), flush: jest.fn() };

      mockRegistryInstance.getServerConfig.mockResolvedValue({
        url: 'https://some-domain.com/sse',
      });

      // Return a different config on each call: admin settings first, then regular-user settings.
      // mockResolvedValueOnce values are consumed in call order, not matched by role.
      mockGetAppConfig
        .mockResolvedValueOnce({ mcpSettings: { allowedDomains: ['admin-allowed.com'] } })
        .mockResolvedValueOnce({ mcpSettings: { allowedDomains: ['user-allowed.com'] } });

      mockIsMCPDomainAllowed.mockResolvedValue(true);

      const availableTools = {
        [`test-tool${D}test-server`]: {
          function: {
            description: 'Test tool',
            parameters: { type: 'object', properties: {} },
          },
        },
      };

      // Call with admin user
      await createMCPTool({
        res: mockRes,
        user: adminUser,
        toolKey: `test-tool${D}test-server`,
        provider: 'openai',
        userMCPAuthMap: {},
        availableTools,
      });

      // Re-stub the server config, then call with the regular user
      mockRegistryInstance.getServerConfig.mockResolvedValue({
        url: 'https://some-domain.com/sse',
      });

      await createMCPTool({
        res: mockRes,
        user: regularUser,
        toolKey: `test-tool${D}test-server`,
        provider: 'openai',
        userMCPAuthMap: {},
        availableTools,
      });

      // Verify getAppConfig was called with each caller's role, in order
      expect(mockGetAppConfig).toHaveBeenNthCalledWith(1, { role: 'admin' });
      expect(mockGetAppConfig).toHaveBeenNthCalledWith(2, { role: 'user' });
    });
  });

  describe('createUnavailableToolStub', () => {
    it('should return a tool whose _call returns a valid CONTENT_AND_ARTIFACT two-tuple', async () => {
      const stub = createUnavailableToolStub('myTool', 'myServer');
      // invoke() goes through langchain's base tool, which checks responseFormat.
      // CONTENT_AND_ARTIFACT requires [content, artifact]; a bare string would throw:
      // "Tool response format is "content_and_artifact" but the output was not a two-tuple"
      const result = await stub.invoke({});
      // Reaching this line without a throw means the two-tuple format is correct.
      // invoke() returns only the content portion of [content, artifact], here a string.
      expect(result).toContain('temporarily unavailable');
    });
  });
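
  // Illustrative sketch of the constraint described in the comments above. These are hypothetical
  // helpers, not langchain's source: with a "content_and_artifact" response format, a tool's raw
  // output must be a two-element array, and only the first element is surfaced as the content.
  function sketchSplitContentAndArtifact(output) {
    if (!Array.isArray(output) || output.length !== 2) {
      throw new Error(
        'Tool response format is "content_and_artifact" but the output was not a two-tuple',
      );
    }
    const [content, artifact] = output;
    return { content, artifact };
  }

  // A hypothetical stub output in the required shape; the exact message wording is an assumption.
  function sketchUnavailableToolOutput(toolName, serverName) {
    return [`Tool "${toolName}" on server "${serverName}" is temporarily unavailable.`, undefined];
  }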

  describe('negative tool cache and throttle interaction', () => {
    it('should cache tool as missing even when throttled (cross-user dedup)', async () => {
      const mockUser = { id: 'throttle-test-user' };
      const mockRes = { write: jest.fn(), flush: jest.fn() };

      // First call: reconnect succeeds but tool not found
      mockReinitMCPServer.mockResolvedValueOnce({
        availableTools: {},
      });

      await createMCPTool({
        res: mockRes,
        user: mockUser,
        toolKey: `missing-tool${D}cache-dedup-server`,
        provider: 'openai',
        userMCPAuthMap: {},
        availableTools: undefined,
      });

      // Second call within 10s for a DIFFERENT tool on the same server:
      // the reconnect is throttled (returns null), yet the tool is still cached as missing.
      // This is intentional: the cache acts as cross-user dedup, because the
      // throttle is per-user-per-server and cannot stop N different users
      // from each triggering their own reconnect.
      const result2 = await createMCPTool({
        res: mockRes,
        user: mockUser,
        toolKey: `other-tool${D}cache-dedup-server`,
        provider: 'openai',
        userMCPAuthMap: {},
        availableTools: undefined,
      });

      expect(result2).toBeDefined();
      expect(result2.name).toContain('other-tool');
      expect(mockReinitMCPServer).toHaveBeenCalledTimes(1);
    });

    it('should prevent user B from triggering reconnect when user A already cached the tool', async () => {
      const userA = { id: 'cache-user-A' };
      const userB = { id: 'cache-user-B' };
      const mockRes = { write: jest.fn(), flush: jest.fn() };

      // User A: real reconnect, tool not found → cached
      mockReinitMCPServer.mockResolvedValueOnce({
        availableTools: {},
      });

      await createMCPTool({
        res: mockRes,
        user: userA,
        toolKey: `shared-tool${D}cross-user-server`,
        provider: 'openai',
        userMCPAuthMap: {},
        availableTools: undefined,
      });

      expect(mockReinitMCPServer).toHaveBeenCalledTimes(1);

      // User B requests the SAME tool within 10s.
      // The negative cache is keyed by toolKey (no user prefix), so user B
      // gets a cache hit and no reconnect fires. This is the cross-user
      // storm protection: without it, user B's unthrottled first request
      // would trigger a second reconnect to the same server.
      const result = await createMCPTool({
        res: mockRes,
        user: userB,
        toolKey: `shared-tool${D}cross-user-server`,
        provider: 'openai',
        userMCPAuthMap: {},
        availableTools: undefined,
      });

      expect(result).toBeDefined();
      expect(result.name).toContain('shared-tool');
      // reinitMCPServer still called only once: user B hit the cache
      expect(mockReinitMCPServer).toHaveBeenCalledTimes(1);
    });

    it('should prevent user B from triggering reconnect for throttle-cached tools', async () => {
      const userA = { id: 'storm-user-A' };
      const userB = { id: 'storm-user-B' };
      const mockRes = { write: jest.fn(), flush: jest.fn() };

      // User A: real reconnect for tool-1, tool not found → cached
      mockReinitMCPServer.mockResolvedValueOnce({
        availableTools: {},
      });

      await createMCPTool({
        res: mockRes,
        user: userA,
        toolKey: `tool-1${D}storm-server`,
        provider: 'openai',
        userMCPAuthMap: {},
        availableTools: undefined,
      });

      // User A: tool-2 on the same server within 10s → throttled → cached as missing by the throttle path
      await createMCPTool({
        res: mockRes,
        user: userA,
        toolKey: `tool-2${D}storm-server`,
        provider: 'openai',
        userMCPAuthMap: {},
        availableTools: undefined,
      });

      expect(mockReinitMCPServer).toHaveBeenCalledTimes(1);

      // User B requests tool-2 and gets a cache hit from the throttle-cached entry.
      // Without this caching, user B would trigger a real reconnect, since
      // user B has their own throttle key and has not reconnected yet.
      const result = await createMCPTool({
        res: mockRes,
        user: userB,
        toolKey: `tool-2${D}storm-server`,
        provider: 'openai',
        userMCPAuthMap: {},
        availableTools: undefined,
      });

      expect(result).toBeDefined();
      expect(result.name).toContain('tool-2');
      // Still only 1 real reconnect: user B was protected by the cache
      expect(mockReinitMCPServer).toHaveBeenCalledTimes(1);
    });
  });
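
  // Illustrative sketch (hypothetical class, not LibreChat's implementation) of the interaction
  // the tests above describe. Reconnects are throttled per user per server, while the negative
  // tool cache is keyed by toolKey alone, so one user's miss also shields other users. Using a
  // single 10 second window for both mirrors the "within 10s" comments and is an assumption.
  // The clock is injected for determinism, and `reconnect(serverName)` returns exposed tool keys.
  class SketchToolResolver {
    constructor({ reconnect, now = () => Date.now(), windowMs = 10000 }) {
      this.reconnect = reconnect;
      this.now = now;
      this.windowMs = windowMs;
      this.lastReconnect = new Map(); // `${userId}:${serverName}` -> last reconnect time
      this.missingTools = new Map(); // toolKey -> time it was cached as missing
    }

    resolve(userId, toolKey, serverName) {
      const now = this.now();
      // 1. Cross-user negative cache: any user's recent miss short-circuits.
      const missingAt = this.missingTools.get(toolKey);
      if (missingAt !== undefined && now - missingAt < this.windowMs) {
        return 'cached-missing';
      }
      // 2. Per-user-per-server throttle.
      const throttleKey = `${userId}:${serverName}`;
      const last = this.lastReconnect.get(throttleKey);
      if (last !== undefined && now - last < this.windowMs) {
        // Throttled: still record the miss so other users get a cache hit.
        this.missingTools.set(toolKey, now);
        return 'throttled';
      }
      // 3. Real reconnect.
      this.lastReconnect.set(throttleKey, now);
      const found = this.reconnect(serverName).includes(toolKey);
      if (!found) {
        this.missingTools.set(toolKey, now);
      }
      return found ? 'found' : 'missing';
    }
  }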

  describe('createMCPTools throttle handling', () => {
    it('should return empty array with debug log when reconnect is throttled', async () => {
      const mockUser = { id: 'throttle-tools-user' };
      const mockRes = { write: jest.fn(), flush: jest.fn() };

      // First call: real reconnect
      mockReinitMCPServer.mockResolvedValueOnce({
        tools: [{ name: 'tool1' }],
        availableTools: {
          [`tool1${D}throttle-tools-server`]: {
            function: { description: 'Tool 1', parameters: {} },
          },
        },
      });

      await createMCPTools({
        res: mockRes,
        user: mockUser,
        serverName: 'throttle-tools-server',
        provider: 'openai',
        userMCPAuthMap: {},
      });

      // Second call within 10s: throttled
      const result = await createMCPTools({
        res: mockRes,
        user: mockUser,
        serverName: 'throttle-tools-server',
        provider: 'openai',
        userMCPAuthMap: {},
      });

      expect(result).toEqual([]);
      // reinitMCPServer called only once; the second call was throttled
      expect(mockReinitMCPServer).toHaveBeenCalledTimes(1);
      // Should log at debug level (not warn) for the throttled case
      expect(logger.debug).toHaveBeenCalledWith(expect.stringContaining('Reconnect throttled'));
    });
  });

  describe('User parameter integrity', () => {
    it('should preserve user object properties through the call chain', async () => {
      const complexUser = {
        id: 'complex-user',
        name: 'John Doe',
        email: 'john@example.com',
        metadata: { subscription: 'premium', settings: { theme: 'dark' } },
      };
      const mockRes = { write: jest.fn(), flush: jest.fn() };

      let capturedUser = null;
      mockReinitMCPServer.mockImplementation((params) => {
        capturedUser = params.user;
        return Promise.resolve({
          tools: [{ name: 'test' }],
          availableTools: {
            [`test${D}server`]: { function: { description: 'Test', parameters: {} } },
          },
        });
      });

      await createMCPTools({
        res: mockRes,
        user: complexUser,
        serverName: 'server',
        provider: 'openai',
        userMCPAuthMap: {},
      });

      // Verify the complete user object was passed
      expect(capturedUser).toEqual(complexUser);
      expect(capturedUser.id).toBe('complex-user');
      expect(capturedUser.metadata.subscription).toBe('premium');
      expect(capturedUser.metadata.settings.theme).toBe('dark');
    });

    it('should throw error when user is null', async () => {
      const mockRes = { write: jest.fn(), flush: jest.fn() };

      mockReinitMCPServer.mockResolvedValue({
        tools: [],
        availableTools: {},
      });

      await expect(
        createMCPTools({
          res: mockRes,
          user: null,
          serverName: 'test-server',
          provider: 'openai',
          userMCPAuthMap: {},
        }),
      ).rejects.toThrow("Cannot read properties of null (reading 'id')");

      // Verify reinitMCPServer was not called due to early error
      expect(mockReinitMCPServer).not.toHaveBeenCalled();
    });
  });
});
|