mirror of
https://github.com/danny-avila/LibreChat.git
synced 2026-06-09 17:31:19 +00:00
🦣 fix: Response Size Limits for Streamable HTTP MCP Responses (#13219)
Some checks are pending
Docker Dev Branch Images Build / build (Dockerfile, lc-dev, node) (push) Waiting to run
Docker Dev Branch Images Build / build (Dockerfile.multi, lc-dev-api, api-build) (push) Waiting to run
GitNexus Index / index (push) Waiting to run
GitNexus Index / post-index (push) Blocked by required conditions
Some checks are pending
Docker Dev Branch Images Build / build (Dockerfile, lc-dev, node) (push) Waiting to run
Docker Dev Branch Images Build / build (Dockerfile.multi, lc-dev-api, api-build) (push) Waiting to run
GitNexus Index / index (push) Waiting to run
GitNexus Index / post-index (push) Blocked by required conditions
* fix: implement response size limits for streamable HTTP MCP responses - Added environment variables for maximum response and line sizes in streamable HTTP responses. - Introduced functions to handle response size validation and error logging. - Updated MCP connection logic to enforce these limits, ensuring safe handling of large responses. * fix: address MCP response guard review findings * fix: satisfy logger spy typings * fix: clarify blocked MCP response errors * fix: harden MCP guard review edge cases * test: cover MCP oversized SSE call failures
This commit is contained in:
parent
830d124e4d
commit
1ed84ee4eb
5 changed files with 786 additions and 6 deletions
|
|
@ -893,6 +893,14 @@ OPENWEATHER_API_KEY=
|
|||
# Cache connection status checks for this many milliseconds to avoid expensive verification
|
||||
# MCP_CONNECTION_CHECK_TTL=60000
|
||||
|
||||
# Max bytes allowed in a non-GET streamable HTTP MCP response before rejecting it.
|
||||
# Set to 0 to disable. Default: 16777216 (16 MiB)
|
||||
# MCP_STREAMABLE_HTTP_MAX_RESPONSE_BYTES=16777216
|
||||
|
||||
# Max bytes allowed in a single SSE line for non-GET streamable HTTP MCP responses.
|
||||
# Set to 0 to disable. Default: 1048576 (1 MiB)
|
||||
# MCP_STREAMABLE_HTTP_MAX_LINE_BYTES=1048576
|
||||
|
||||
# Skip code challenge method validation (e.g., for AWS Cognito that supports S256 but doesn't advertise it)
|
||||
# When set to true, forces S256 code challenge even if not advertised in .well-known/openid-configuration
|
||||
# MCP_SKIP_CODE_CHALLENGE_CHECK=false
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import { Constants } from '@librechat/agents';
|
||||
import { logger } from '@librechat/data-schemas';
|
||||
import type {
|
||||
ToolExecuteBatchRequest,
|
||||
ToolExecuteResult,
|
||||
|
|
@ -278,6 +279,128 @@ describe('createToolExecuteHandler', () => {
|
|||
});
|
||||
});
|
||||
|
||||
describe('tool error handling', () => {
|
||||
it('truncates oversized tool errors in the result and log context', async () => {
|
||||
const oversizedMessage = `tool failed: ${'x'.repeat(15_000)}`;
|
||||
const thrown = new Error(oversizedMessage);
|
||||
thrown.stack = `Error: ${oversizedMessage}\n${'stack-line\n'.repeat(600)}`;
|
||||
const loadTools: ToolExecuteOptions['loadTools'] = jest.fn(async () => ({
|
||||
loadedTools: [
|
||||
{
|
||||
name: 'bad_tool',
|
||||
invoke: jest.fn(async () => {
|
||||
throw thrown;
|
||||
}),
|
||||
},
|
||||
] as never[],
|
||||
}));
|
||||
const errorSpy = jest.spyOn(logger, 'error').mockReturnValue(logger);
|
||||
try {
|
||||
const handler = createToolExecuteHandler({ loadTools });
|
||||
const [result] = await invokeHandler(handler, [
|
||||
{
|
||||
id: 'call_bad',
|
||||
name: 'bad_tool',
|
||||
args: {},
|
||||
},
|
||||
]);
|
||||
|
||||
expect(result.status).toBe('error');
|
||||
expect(result.errorMessage).toContain('truncated');
|
||||
expect(result.errorMessage!.length).toBeLessThanOrEqual(12_000);
|
||||
expect(errorSpy).toHaveBeenCalledWith(
|
||||
'[ON_TOOL_EXECUTE] Tool bad_tool error',
|
||||
expect.objectContaining({
|
||||
messageTruncated: true,
|
||||
messageLength: oversizedMessage.length,
|
||||
}),
|
||||
);
|
||||
const [, logContext] = errorSpy.mock.calls[0] as unknown as [string, { stack?: string }];
|
||||
expect(logContext.stack!.length).toBeLessThanOrEqual(4_000);
|
||||
} finally {
|
||||
errorSpy.mockRestore();
|
||||
}
|
||||
});
|
||||
|
||||
it('returns a per-tool error when thrown value stringification fails', async () => {
|
||||
const thrown = {
|
||||
toString() {
|
||||
throw new Error('toString failed');
|
||||
},
|
||||
};
|
||||
const loadTools: ToolExecuteOptions['loadTools'] = jest.fn(async () => ({
|
||||
loadedTools: [
|
||||
{
|
||||
name: 'bad_to_string_tool',
|
||||
invoke: jest.fn(async () => {
|
||||
throw thrown;
|
||||
}),
|
||||
},
|
||||
] as never[],
|
||||
}));
|
||||
const errorSpy = jest.spyOn(logger, 'error').mockReturnValue(logger);
|
||||
try {
|
||||
const handler = createToolExecuteHandler({ loadTools });
|
||||
const [result] = await invokeHandler(handler, [
|
||||
{
|
||||
id: 'call_bad_to_string',
|
||||
name: 'bad_to_string_tool',
|
||||
args: {},
|
||||
},
|
||||
]);
|
||||
|
||||
expect(result.status).toBe('error');
|
||||
expect(result.errorMessage).toBe('[Thrown value could not be converted to string]');
|
||||
expect(errorSpy).toHaveBeenCalledWith(
|
||||
'[ON_TOOL_EXECUTE] Tool bad_to_string_tool error',
|
||||
expect.objectContaining({
|
||||
name: 'object',
|
||||
messageTruncated: false,
|
||||
}),
|
||||
);
|
||||
} finally {
|
||||
errorSpy.mockRestore();
|
||||
}
|
||||
});
|
||||
|
||||
it('preserves message from thrown plain objects', async () => {
|
||||
const thrown = { message: 'plain object timeout' };
|
||||
const loadTools: ToolExecuteOptions['loadTools'] = jest.fn(async () => ({
|
||||
loadedTools: [
|
||||
{
|
||||
name: 'plain_object_tool',
|
||||
invoke: jest.fn(async () => {
|
||||
throw thrown;
|
||||
}),
|
||||
},
|
||||
] as never[],
|
||||
}));
|
||||
const errorSpy = jest.spyOn(logger, 'error').mockReturnValue(logger);
|
||||
try {
|
||||
const handler = createToolExecuteHandler({ loadTools });
|
||||
const [result] = await invokeHandler(handler, [
|
||||
{
|
||||
id: 'call_plain_object',
|
||||
name: 'plain_object_tool',
|
||||
args: {},
|
||||
},
|
||||
]);
|
||||
|
||||
expect(result.status).toBe('error');
|
||||
expect(result.errorMessage).toBe('plain object timeout');
|
||||
expect(errorSpy).toHaveBeenCalledWith(
|
||||
'[ON_TOOL_EXECUTE] Tool plain_object_tool error',
|
||||
expect.objectContaining({
|
||||
message: 'plain object timeout',
|
||||
messageTruncated: false,
|
||||
}),
|
||||
);
|
||||
} finally {
|
||||
errorSpy.mockRestore();
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
describe('skill tool model-invocation gate', () => {
|
||||
function createSkillHandler(getSkillByName: ToolExecuteOptions['getSkillByName']) {
|
||||
const loadTools: ToolExecuteOptions['loadTools'] = jest.fn(async () => ({
|
||||
|
|
|
|||
|
|
@ -154,9 +154,77 @@ export interface ToolExecuteOptions {
|
|||
const MAX_READABLE_BYTES = 262_144;
|
||||
const MAX_BINARY_BYTES = 5 * 1024 * 1024;
|
||||
const MAX_CACHE_BYTES = 512 * 1024;
|
||||
const MAX_TOOL_ERROR_MESSAGE_CHARS = 12_000;
|
||||
const MAX_TOOL_ERROR_STACK_CHARS = 4_000;
|
||||
|
||||
const IMAGE_MIMES = new Set(['image/png', 'image/jpeg', 'image/gif', 'image/webp']);
|
||||
|
||||
function truncateMiddle(value: string, maxChars: number): string {
|
||||
if (value.length <= maxChars) {
|
||||
return value;
|
||||
}
|
||||
|
||||
const indicator = `\n\n... [truncated: ${value.length} chars exceeded ${maxChars} limit] ...\n\n`;
|
||||
const available = maxChars - indicator.length;
|
||||
if (available <= 0) {
|
||||
return value.slice(0, maxChars);
|
||||
}
|
||||
|
||||
const headSize = Math.ceil(available * 0.7);
|
||||
const tailSize = available - headSize;
|
||||
return value.slice(0, headSize) + indicator + value.slice(value.length - tailSize);
|
||||
}
|
||||
|
||||
function stringifyThrownValue(error: unknown): string {
|
||||
try {
|
||||
return String(error);
|
||||
} catch {
|
||||
return '[Thrown value could not be converted to string]';
|
||||
}
|
||||
}
|
||||
|
||||
function getThrownValueMessage(error: unknown): string {
|
||||
if (error instanceof Error) {
|
||||
return error.message;
|
||||
}
|
||||
|
||||
if (error != null && typeof error === 'object') {
|
||||
try {
|
||||
const message = (error as { message?: unknown }).message;
|
||||
if (typeof message === 'string') {
|
||||
return message;
|
||||
}
|
||||
if (message != null) {
|
||||
return stringifyThrownValue(message);
|
||||
}
|
||||
} catch {
|
||||
// Fall through to whole-value stringification.
|
||||
}
|
||||
}
|
||||
|
||||
return stringifyThrownValue(error);
|
||||
}
|
||||
|
||||
function getSafeToolError(error: unknown): {
|
||||
message: string;
|
||||
logContext: Record<string, unknown>;
|
||||
} {
|
||||
const rawMessage = getThrownValueMessage(error);
|
||||
const message = truncateMiddle(rawMessage, MAX_TOOL_ERROR_MESSAGE_CHARS);
|
||||
const stack = error instanceof Error && error.stack ? error.stack : undefined;
|
||||
|
||||
return {
|
||||
message,
|
||||
logContext: {
|
||||
name: error instanceof Error ? error.name : typeof error,
|
||||
message,
|
||||
messageLength: rawMessage.length,
|
||||
messageTruncated: message.length !== rawMessage.length,
|
||||
stack: stack ? truncateMiddle(stack, MAX_TOOL_ERROR_STACK_CHARS) : undefined,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
function addLineNumbers(content: string): string {
|
||||
const lines = content.split('\n');
|
||||
const w = String(lines.length).length;
|
||||
|
|
@ -1195,13 +1263,13 @@ export function createToolExecuteHandler(options: ToolExecuteOptions): EventHand
|
|||
status: 'success' as const,
|
||||
};
|
||||
} catch (toolError) {
|
||||
const error = toolError as Error;
|
||||
logger.error(`[ON_TOOL_EXECUTE] Tool ${tc.name} error:`, error);
|
||||
const { message, logContext } = getSafeToolError(toolError);
|
||||
logger.error(`[ON_TOOL_EXECUTE] Tool ${tc.name} error`, logContext);
|
||||
return {
|
||||
toolCallId: tc.id,
|
||||
status: 'error' as const,
|
||||
content: '',
|
||||
errorMessage: error.message,
|
||||
errorMessage: message,
|
||||
};
|
||||
}
|
||||
}),
|
||||
|
|
|
|||
|
|
@ -258,6 +258,45 @@ async function createStreamableServer(): Promise<Omit<TestServer, 'redirectHit'>
|
|||
};
|
||||
}
|
||||
|
||||
async function createOversizedToolResultStreamableServer(
|
||||
payloadSize: number,
|
||||
): Promise<Omit<TestServer, 'redirectHit'>> {
|
||||
const sessions = new Map<string, StreamableHTTPServerTransport>();
|
||||
|
||||
const httpServer = http.createServer(async (req, res) => {
|
||||
const sid = req.headers['mcp-session-id'] as string | undefined;
|
||||
let transport = sid ? sessions.get(sid) : undefined;
|
||||
|
||||
if (!transport) {
|
||||
transport = new StreamableHTTPServerTransport({ sessionIdGenerator: () => randomUUID() });
|
||||
const mcp = new McpServer({ name: 'oversized-tool-result', version: '0.0.1' });
|
||||
mcp.tool('oversized', 'Returns an oversized text payload', {}, async () => ({
|
||||
content: [{ type: 'text', text: 'x'.repeat(payloadSize) }],
|
||||
}));
|
||||
await mcp.connect(transport);
|
||||
}
|
||||
|
||||
await transport.handleRequest(req, res);
|
||||
|
||||
if (transport.sessionId && !sessions.has(transport.sessionId)) {
|
||||
sessions.set(transport.sessionId, transport);
|
||||
transport.onclose = () => sessions.delete(transport!.sessionId!);
|
||||
}
|
||||
});
|
||||
|
||||
const destroySockets = trackSockets(httpServer);
|
||||
const port = await getFreePort();
|
||||
await new Promise<void>((resolve) => httpServer.listen(port, '127.0.0.1', resolve));
|
||||
|
||||
return {
|
||||
url: `http://127.0.0.1:${port}/`,
|
||||
close: async () => {
|
||||
await closeMCPSessions(sessions);
|
||||
await destroySockets();
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
describe('MCP SSRF protection – redirect blocking', () => {
|
||||
let redirectServer: TestServer;
|
||||
let conn: MCPConnection | null;
|
||||
|
|
@ -457,6 +496,11 @@ interface HeaderCaptureServer {
|
|||
close: () => Promise<void>;
|
||||
}
|
||||
|
||||
interface RawResponseServer {
|
||||
url: string;
|
||||
close: () => Promise<void>;
|
||||
}
|
||||
|
||||
/**
|
||||
* Captures every incoming request's headers, method, and body, then replies
|
||||
* with a benign 200 so tests can assert what actually crossed a redirect
|
||||
|
|
@ -490,6 +534,17 @@ async function createHeaderCaptureServer(): Promise<HeaderCaptureServer> {
|
|||
};
|
||||
}
|
||||
|
||||
async function createRawResponseServer(handler: http.RequestListener): Promise<RawResponseServer> {
|
||||
const server = http.createServer(handler);
|
||||
const destroySockets = trackSockets(server);
|
||||
const port = await getFreePort();
|
||||
await new Promise<void>((resolve) => server.listen(port, '127.0.0.1', resolve));
|
||||
return {
|
||||
url: `http://127.0.0.1:${port}/`,
|
||||
close: destroySockets,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Issues a single 307/308 redirect to `redirectTarget` (which lives on a
|
||||
* different port and is therefore a different origin). Used to verify
|
||||
|
|
@ -823,8 +878,20 @@ describe('MCP SSRF protection – cross-origin credential stripping on redirect'
|
|||
describe('MCP SSRF protection – customFetch input shapes', () => {
|
||||
let target: Omit<TestServer, 'redirectHit'> | undefined;
|
||||
let conn: MCPConnection | null;
|
||||
const originalMaxResponseBytes = process.env.MCP_STREAMABLE_HTTP_MAX_RESPONSE_BYTES;
|
||||
const originalMaxLineBytes = process.env.MCP_STREAMABLE_HTTP_MAX_LINE_BYTES;
|
||||
|
||||
afterEach(async () => {
|
||||
if (originalMaxResponseBytes == null) {
|
||||
delete process.env.MCP_STREAMABLE_HTTP_MAX_RESPONSE_BYTES;
|
||||
} else {
|
||||
process.env.MCP_STREAMABLE_HTTP_MAX_RESPONSE_BYTES = originalMaxResponseBytes;
|
||||
}
|
||||
if (originalMaxLineBytes == null) {
|
||||
delete process.env.MCP_STREAMABLE_HTTP_MAX_LINE_BYTES;
|
||||
} else {
|
||||
process.env.MCP_STREAMABLE_HTTP_MAX_LINE_BYTES = originalMaxLineBytes;
|
||||
}
|
||||
await safeDisconnect(conn);
|
||||
conn = null;
|
||||
if (target) {
|
||||
|
|
@ -846,12 +913,33 @@ describe('MCP SSRF protection – customFetch input shapes', () => {
|
|||
connection as unknown as {
|
||||
createFetchFunction: (
|
||||
getHeaders: () => Record<string, string> | null | undefined,
|
||||
timeout?: number,
|
||||
sseBodyTimeout?: number,
|
||||
configuredSecretHeaderKeys?: ReadonlySet<string>,
|
||||
baseUrl?: string,
|
||||
guardStreamableHTTPResponses?: boolean,
|
||||
) => CustomFetch;
|
||||
}
|
||||
).createFetchFunction;
|
||||
return factory.call(connection, () => null);
|
||||
}
|
||||
|
||||
function getGuardedStreamableHTTPCustomFetch(connection: MCPConnection): CustomFetch {
|
||||
const factory = (
|
||||
connection as unknown as {
|
||||
createFetchFunction: (
|
||||
getHeaders: () => Record<string, string> | null | undefined,
|
||||
timeout?: number,
|
||||
sseBodyTimeout?: number,
|
||||
configuredSecretHeaderKeys?: ReadonlySet<string>,
|
||||
baseUrl?: string,
|
||||
guardStreamableHTTPResponses?: boolean,
|
||||
) => CustomFetch;
|
||||
}
|
||||
).createFetchFunction;
|
||||
return factory.call(connection, () => null, undefined, undefined, undefined, undefined, true);
|
||||
}
|
||||
|
||||
it.each<['string' | 'URL' | 'Request']>([['string'], ['URL'], ['Request']])(
|
||||
'should accept a %s input without throwing on URL derivation',
|
||||
async (shape) => {
|
||||
|
|
@ -1015,6 +1103,142 @@ describe('MCP SSRF protection – customFetch input shapes', () => {
|
|||
}
|
||||
}
|
||||
});
|
||||
|
||||
it('should not apply streamable HTTP response caps unless the transport opts in', async () => {
|
||||
process.env.MCP_STREAMABLE_HTTP_MAX_RESPONSE_BYTES = '8';
|
||||
const server = await createRawResponseServer((_req, res) => {
|
||||
res.writeHead(200, { 'Content-Type': 'application/json' });
|
||||
res.end('{"jsonrpc":"2.0","id":1,"result":{"too":"large"}}');
|
||||
});
|
||||
try {
|
||||
conn = new MCPConnection({
|
||||
serverName: 'customfetch-unguarded-byte-limit',
|
||||
serverConfig: { type: 'sse', url: server.url },
|
||||
useSSRFProtection: false,
|
||||
});
|
||||
|
||||
const customFetch = getCustomFetch(conn);
|
||||
const response = await customFetch(server.url, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ jsonrpc: '2.0', method: 'ping', id: 1 }),
|
||||
});
|
||||
|
||||
await expect(response.text()).resolves.toContain('"too":"large"');
|
||||
} finally {
|
||||
await server.close();
|
||||
}
|
||||
});
|
||||
|
||||
it('should reject oversized JSON POST responses with the streamable HTTP byte cap', async () => {
|
||||
process.env.MCP_STREAMABLE_HTTP_MAX_RESPONSE_BYTES = '8';
|
||||
const server = await createRawResponseServer((_req, res) => {
|
||||
res.writeHead(200, { 'Content-Type': 'application/json' });
|
||||
res.end('{"jsonrpc":"2.0","id":1,"result":{"too":"large"}}');
|
||||
});
|
||||
try {
|
||||
conn = new MCPConnection({
|
||||
serverName: 'customfetch-json-byte-limit',
|
||||
serverConfig: { type: 'streamable-http', url: server.url },
|
||||
useSSRFProtection: false,
|
||||
});
|
||||
|
||||
const customFetch = getGuardedStreamableHTTPCustomFetch(conn);
|
||||
const response = await customFetch(server.url, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ jsonrpc: '2.0', method: 'ping', id: 1 }),
|
||||
});
|
||||
|
||||
await expect(response.text()).rejects.toThrow(
|
||||
/MCP response exceeded byte limit.*limit=8 bytes/,
|
||||
);
|
||||
} finally {
|
||||
await server.close();
|
||||
}
|
||||
});
|
||||
|
||||
it('should reject a POST response with an oversized SSE line before the SSE parser can grow it', async () => {
|
||||
process.env.MCP_STREAMABLE_HTTP_MAX_LINE_BYTES = '16';
|
||||
const server = await createRawResponseServer((_req, res) => {
|
||||
res.writeHead(200, { 'Content-Type': 'text/event-stream' });
|
||||
res.end(`data: ${'x'.repeat(64)}\n\n`);
|
||||
});
|
||||
try {
|
||||
conn = new MCPConnection({
|
||||
serverName: 'customfetch-sse-line-limit',
|
||||
serverConfig: { type: 'streamable-http', url: server.url },
|
||||
useSSRFProtection: false,
|
||||
});
|
||||
|
||||
const customFetch = getGuardedStreamableHTTPCustomFetch(conn);
|
||||
const response = await customFetch(server.url, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ jsonrpc: '2.0', method: 'notifications/cancelled' }),
|
||||
});
|
||||
await expect(response.text()).rejects.toThrow(
|
||||
/MCP response contained an oversized SSE line.*lineLimit=16 bytes.*observedLine=17 bytes/,
|
||||
);
|
||||
} finally {
|
||||
await server.close();
|
||||
}
|
||||
});
|
||||
|
||||
it('should fail an actual streamable HTTP tool call promptly with a clear oversized SSE line error', async () => {
|
||||
process.env.MCP_STREAMABLE_HTTP_MAX_LINE_BYTES = '512';
|
||||
target = await createOversizedToolResultStreamableServer(2048);
|
||||
conn = new MCPConnection({
|
||||
serverName: 'streamable-http-tool-call-sse-line-limit',
|
||||
serverConfig: { type: 'streamable-http', url: target.url },
|
||||
useSSRFProtection: false,
|
||||
});
|
||||
|
||||
await conn.connect();
|
||||
const startedAt = Date.now();
|
||||
|
||||
await expect(
|
||||
conn.client.callTool({ name: 'oversized', arguments: {} }, undefined, { timeout: 3000 }),
|
||||
).rejects.toThrow(/MCP response contained an oversized SSE line/);
|
||||
expect(Date.now() - startedAt).toBeLessThan(1500);
|
||||
});
|
||||
|
||||
it('should stream valid SSE POST responses without waiting for EOF', async () => {
|
||||
process.env.MCP_STREAMABLE_HTTP_MAX_LINE_BYTES = '4096';
|
||||
let finish!: () => void;
|
||||
const finished = new Promise<void>((resolve) => {
|
||||
finish = resolve;
|
||||
});
|
||||
const server = await createRawResponseServer((_req, res) => {
|
||||
res.writeHead(200, { 'Content-Type': 'text/event-stream' });
|
||||
res.write('data: {"jsonrpc":"2.0","id":1,"result":{}}\n\n');
|
||||
finished.then(() => res.end()).catch(() => res.end());
|
||||
});
|
||||
try {
|
||||
conn = new MCPConnection({
|
||||
serverName: 'customfetch-sse-streaming',
|
||||
serverConfig: { type: 'streamable-http', url: server.url },
|
||||
useSSRFProtection: false,
|
||||
});
|
||||
|
||||
const customFetch = getGuardedStreamableHTTPCustomFetch(conn);
|
||||
const response = await customFetch(server.url, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ jsonrpc: '2.0', method: 'ping', id: 1 }),
|
||||
});
|
||||
const reader = response.body!.getReader();
|
||||
const { value, done } = await reader.read();
|
||||
|
||||
expect(done).toBe(false);
|
||||
expect(Buffer.from(value as Uint8Array).toString('utf8')).toContain('"result":{}');
|
||||
await reader.cancel().catch(() => undefined);
|
||||
finish();
|
||||
} finally {
|
||||
finish();
|
||||
await server.close();
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
describe('MCP SSRF protection – WebSocket DNS resolution', () => {
|
||||
|
|
|
|||
|
|
@ -74,6 +74,351 @@ const SSE_CONNECT_TIMEOUT = 120000;
|
|||
const DEFAULT_INIT_TIMEOUT = 30000;
|
||||
/** Max 307/308 redirects to follow per request (prevents redirect loops) */
|
||||
const MAX_REDIRECTS = 5;
|
||||
const DEFAULT_MCP_STREAMABLE_HTTP_MAX_RESPONSE_BYTES = 16 * 1024 * 1024;
|
||||
const DEFAULT_MCP_STREAMABLE_HTTP_MAX_LINE_BYTES = 1024 * 1024;
|
||||
|
||||
function getNonNegativeIntegerEnv(name: string, defaultValue: number): number {
|
||||
const raw = process.env[name];
|
||||
if (raw == null || raw.trim() === '') {
|
||||
return defaultValue;
|
||||
}
|
||||
|
||||
const trimmed = raw.trim();
|
||||
if (!/^\d+$/.test(trimmed)) {
|
||||
return defaultValue;
|
||||
}
|
||||
|
||||
const parsed = Number(trimmed);
|
||||
return Number.isSafeInteger(parsed) ? parsed : defaultValue;
|
||||
}
|
||||
|
||||
function bytesToMiB(bytes: number): string {
|
||||
return `${(bytes / 1024 / 1024).toFixed(2)} MiB`;
|
||||
}
|
||||
|
||||
function getMemoryDebugSnapshot(): Record<string, string> {
|
||||
const mem = process.memoryUsage();
|
||||
return {
|
||||
rss: bytesToMiB(mem.rss),
|
||||
heapUsed: bytesToMiB(mem.heapUsed),
|
||||
heapTotal: bytesToMiB(mem.heapTotal),
|
||||
external: bytesToMiB(mem.external),
|
||||
arrayBuffers: bytesToMiB(mem.arrayBuffers ?? 0),
|
||||
};
|
||||
}
|
||||
|
||||
const textEncoder = new TextEncoder();
|
||||
const textDecoder = new TextDecoder();
|
||||
type JSONRPCRequestId = string | number;
|
||||
|
||||
function getChunkBytes(chunk: unknown): Uint8Array {
|
||||
if (typeof chunk === 'string') {
|
||||
return textEncoder.encode(chunk);
|
||||
}
|
||||
if (chunk instanceof ArrayBuffer) {
|
||||
return new Uint8Array(chunk);
|
||||
}
|
||||
if (ArrayBuffer.isView(chunk)) {
|
||||
const view = new Uint8Array(chunk.buffer, chunk.byteOffset, chunk.byteLength);
|
||||
return new Uint8Array(view);
|
||||
}
|
||||
return new Uint8Array();
|
||||
}
|
||||
|
||||
function copyBytes(bytes: Uint8Array): Uint8Array {
|
||||
const copy = new Uint8Array(bytes.byteLength);
|
||||
copy.set(bytes);
|
||||
return copy;
|
||||
}
|
||||
|
||||
function concatBytes(chunks: Uint8Array[]): Uint8Array {
|
||||
if (chunks.length === 0) {
|
||||
return new Uint8Array();
|
||||
}
|
||||
if (chunks.length === 1) {
|
||||
return chunks[0];
|
||||
}
|
||||
const totalLength = chunks.reduce((total, chunk) => total + chunk.byteLength, 0);
|
||||
const combined = new Uint8Array(totalLength);
|
||||
let offset = 0;
|
||||
for (const chunk of chunks) {
|
||||
combined.set(chunk, offset);
|
||||
offset += chunk.byteLength;
|
||||
}
|
||||
return combined;
|
||||
}
|
||||
|
||||
function getBodyText(body: unknown): string | null {
|
||||
if (typeof body === 'string') {
|
||||
return body;
|
||||
}
|
||||
if (body instanceof ArrayBuffer) {
|
||||
return textDecoder.decode(new Uint8Array(body));
|
||||
}
|
||||
if (ArrayBuffer.isView(body)) {
|
||||
return textDecoder.decode(new Uint8Array(body.buffer, body.byteOffset, body.byteLength));
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
function getJSONRPCRequestIds(body: unknown): JSONRPCRequestId[] {
|
||||
const bodyText = getBodyText(body);
|
||||
if (!bodyText) {
|
||||
return [];
|
||||
}
|
||||
|
||||
let parsed: unknown;
|
||||
try {
|
||||
parsed = JSON.parse(bodyText);
|
||||
} catch {
|
||||
return [];
|
||||
}
|
||||
|
||||
const messages = Array.isArray(parsed) ? parsed : [parsed];
|
||||
return messages.flatMap((message) => {
|
||||
if (!message || typeof message !== 'object') {
|
||||
return [];
|
||||
}
|
||||
const jsonrpcMessage = message as { id?: unknown; method?: unknown };
|
||||
const { id } = jsonrpcMessage;
|
||||
if (typeof jsonrpcMessage.method !== 'string') {
|
||||
return [];
|
||||
}
|
||||
if (typeof id !== 'string' && typeof id !== 'number') {
|
||||
return [];
|
||||
}
|
||||
return [id];
|
||||
});
|
||||
}
|
||||
|
||||
function buildBlockedMCPResponseMessage(
|
||||
reason: string,
|
||||
details: {
|
||||
maxResponseBytes: number;
|
||||
maxLineBytes: number;
|
||||
totalBytes: number;
|
||||
currentLineBytes: number;
|
||||
chunkCount: number;
|
||||
},
|
||||
): string {
|
||||
const limitDetails =
|
||||
reason === 'MCP response exceeded byte limit'
|
||||
? `limit=${details.maxResponseBytes} bytes, observed=${details.totalBytes} bytes`
|
||||
: `lineLimit=${details.maxLineBytes} bytes, observedLine=${details.currentLineBytes} bytes, observedTotal=${details.totalBytes} bytes`;
|
||||
|
||||
return `[MCP] ${reason} (${limitDetails}, chunks=${details.chunkCount}). The MCP server returned an unsafe streamable HTTP response; narrow the tool result or retry after the server response is fixed.`;
|
||||
}
|
||||
|
||||
function buildBlockedMCPResponseSSE(requestIds: JSONRPCRequestId[], message: string): Uint8Array {
|
||||
const events = requestIds
|
||||
.map((id) => {
|
||||
const payload = {
|
||||
jsonrpc: '2.0',
|
||||
id,
|
||||
error: {
|
||||
code: -32000,
|
||||
message,
|
||||
},
|
||||
};
|
||||
return `data: ${JSON.stringify(payload)}\n\n`;
|
||||
})
|
||||
.join('');
|
||||
return textEncoder.encode(events);
|
||||
}
|
||||
|
||||
function getMCPStreamableHTTPResponseLimits(): {
|
||||
maxResponseBytes: number;
|
||||
maxLineBytes: number;
|
||||
} {
|
||||
return {
|
||||
maxResponseBytes: getNonNegativeIntegerEnv(
|
||||
'MCP_STREAMABLE_HTTP_MAX_RESPONSE_BYTES',
|
||||
DEFAULT_MCP_STREAMABLE_HTTP_MAX_RESPONSE_BYTES,
|
||||
),
|
||||
maxLineBytes: getNonNegativeIntegerEnv(
|
||||
'MCP_STREAMABLE_HTTP_MAX_LINE_BYTES',
|
||||
DEFAULT_MCP_STREAMABLE_HTTP_MAX_LINE_BYTES,
|
||||
),
|
||||
};
|
||||
}
|
||||
|
||||
async function guardMCPStreamableHTTPResponse(
|
||||
response: UndiciResponse,
|
||||
context: {
|
||||
logPrefix: string;
|
||||
method: string;
|
||||
url: string;
|
||||
requestIds?: JSONRPCRequestId[];
|
||||
},
|
||||
): Promise<UndiciResponse> {
|
||||
if (context.method === 'GET' || !response.body) {
|
||||
return response;
|
||||
}
|
||||
|
||||
const contentType = response.headers.get('content-type') ?? '';
|
||||
const isEventStream = contentType.toLowerCase().includes('text/event-stream');
|
||||
const { maxResponseBytes, maxLineBytes } = getMCPStreamableHTTPResponseLimits();
|
||||
const canEmitFallbackSSEError = isEventStream && maxLineBytes > 0;
|
||||
if (!isEventStream && maxResponseBytes === 0) {
|
||||
return response;
|
||||
}
|
||||
if (maxResponseBytes === 0 && maxLineBytes === 0) {
|
||||
return response;
|
||||
}
|
||||
|
||||
let totalBytes = 0;
|
||||
let currentLineBytes = 0;
|
||||
let chunkCount = 0;
|
||||
let pendingSSELineChunks: Uint8Array[] = [];
|
||||
const sseEventDataLines: string[] = [];
|
||||
const unresolvedRequestIds = new Set(context.requestIds ?? []);
|
||||
|
||||
const buildAndLogBlockedError = (
|
||||
reason: string,
|
||||
details: Record<string, unknown>,
|
||||
): Error => {
|
||||
const message = buildBlockedMCPResponseMessage(reason, {
|
||||
maxResponseBytes,
|
||||
maxLineBytes,
|
||||
totalBytes,
|
||||
currentLineBytes,
|
||||
chunkCount,
|
||||
});
|
||||
logger.warn(`${context.logPrefix} MCP streamable HTTP response blocked: ${reason}`, {
|
||||
method: context.method,
|
||||
url: sanitizeUrlForLogging(context.url),
|
||||
status: response.status,
|
||||
contentType,
|
||||
maxResponseBytes,
|
||||
maxLineBytes,
|
||||
totalBytes,
|
||||
currentLineBytes,
|
||||
chunkCount,
|
||||
...details,
|
||||
memory: getMemoryDebugSnapshot(),
|
||||
});
|
||||
return new Error(message);
|
||||
};
|
||||
|
||||
const trackSSELineForResolvedIds = (lineBytes: Uint8Array): void => {
|
||||
if (unresolvedRequestIds.size === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const rawLine = textDecoder.decode(lineBytes).replace(/[\r\n]+$/, '');
|
||||
if (rawLine === '') {
|
||||
if (sseEventDataLines.length === 0) {
|
||||
return;
|
||||
}
|
||||
const data = sseEventDataLines.join('\n');
|
||||
sseEventDataLines.length = 0;
|
||||
try {
|
||||
const parsed = JSON.parse(data) as { id?: unknown };
|
||||
if (typeof parsed.id === 'string' || typeof parsed.id === 'number') {
|
||||
unresolvedRequestIds.delete(parsed.id);
|
||||
}
|
||||
} catch {
|
||||
/** Ignore malformed SSE data here; the SDK parser will report it. */
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
const separatorIndex = rawLine.indexOf(':');
|
||||
const field = separatorIndex === -1 ? rawLine : rawLine.slice(0, separatorIndex);
|
||||
if (field !== 'data') {
|
||||
return;
|
||||
}
|
||||
let value = separatorIndex === -1 ? '' : rawLine.slice(separatorIndex + 1);
|
||||
if (value.startsWith(' ')) {
|
||||
value = value.slice(1);
|
||||
}
|
||||
sseEventDataLines.push(value);
|
||||
};
|
||||
|
||||
const enqueuePendingSSELine = (controller: TransformStreamDefaultController<Uint8Array>) => {
|
||||
if (pendingSSELineChunks.length === 0) {
|
||||
return;
|
||||
}
|
||||
const lineBytes = concatBytes(pendingSSELineChunks);
|
||||
pendingSSELineChunks = [];
|
||||
trackSSELineForResolvedIds(lineBytes);
|
||||
controller.enqueue(lineBytes);
|
||||
};
|
||||
|
||||
const blockResponse = (
|
||||
controller: TransformStreamDefaultController<Uint8Array>,
|
||||
reason: string,
|
||||
details: Record<string, unknown>,
|
||||
) => {
|
||||
const error = buildAndLogBlockedError(reason, details);
|
||||
const fallbackRequestIds = [...unresolvedRequestIds];
|
||||
if (canEmitFallbackSSEError && fallbackRequestIds.length > 0) {
|
||||
controller.enqueue(buildBlockedMCPResponseSSE(fallbackRequestIds, error.message));
|
||||
controller.terminate();
|
||||
return;
|
||||
}
|
||||
throw error;
|
||||
};
|
||||
|
||||
const guardedBody = (response.body as unknown as ReadableStream<unknown>).pipeThrough(
|
||||
new TransformStream<unknown, Uint8Array>({
|
||||
transform(chunk, controller) {
|
||||
const bytes = getChunkBytes(chunk);
|
||||
if (bytes.byteLength === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
chunkCount += 1;
|
||||
totalBytes += bytes.byteLength;
|
||||
|
||||
if (maxResponseBytes > 0 && totalBytes > maxResponseBytes) {
|
||||
blockResponse(controller, 'MCP response exceeded byte limit', {
|
||||
chunkBytes: bytes.byteLength,
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
if (isEventStream && maxLineBytes > 0) {
|
||||
let segmentStart = 0;
|
||||
for (let i = 0; i < bytes.byteLength; i++) {
|
||||
const byte = bytes[i];
|
||||
if (byte === 10 || byte === 13) {
|
||||
if (i + 1 > segmentStart) {
|
||||
pendingSSELineChunks.push(copyBytes(bytes.subarray(segmentStart, i + 1)));
|
||||
}
|
||||
enqueuePendingSSELine(controller);
|
||||
segmentStart = i + 1;
|
||||
currentLineBytes = 0;
|
||||
continue;
|
||||
}
|
||||
currentLineBytes += 1;
|
||||
if (currentLineBytes > maxLineBytes) {
|
||||
blockResponse(controller, 'MCP response contained an oversized SSE line', {
|
||||
chunkBytes: bytes.byteLength,
|
||||
});
|
||||
return;
|
||||
}
|
||||
}
|
||||
if (segmentStart < bytes.byteLength) {
|
||||
pendingSSELineChunks.push(copyBytes(bytes.subarray(segmentStart)));
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
controller.enqueue(bytes);
|
||||
},
|
||||
flush(controller) {
|
||||
enqueuePendingSSELine(controller);
|
||||
},
|
||||
}),
|
||||
);
|
||||
|
||||
return new Response(guardedBody as unknown as BodyInit, {
|
||||
status: response.status,
|
||||
statusText: response.statusText,
|
||||
headers: response.headers as unknown as HeadersInit,
|
||||
}) as unknown as UndiciResponse;
|
||||
}
|
||||
|
||||
/**
|
||||
* Headers stripped before forwarding a request across an origin boundary on
|
||||
|
|
@ -579,6 +924,7 @@ export class MCPConnection extends EventEmitter {
|
|||
sseBodyTimeout?: number,
|
||||
configuredSecretHeaderKeys?: ReadonlySet<string>,
|
||||
baseUrl?: string,
|
||||
guardStreamableHTTPResponses = false,
|
||||
): (input: UndiciRequestInfo, init?: UndiciRequestInit) => Promise<UndiciResponse> {
|
||||
const basePort = baseUrl ? getUrlPort(baseUrl) : '';
|
||||
const ssrfConnect = this.useSSRFProtection
|
||||
|
|
@ -587,6 +933,7 @@ export class MCPConnection extends EventEmitter {
|
|||
const connectOpts = ssrfConnect != null ? { connect: ssrfConnect } : {};
|
||||
/** Capture only the fields needed by the fetch closure; see factory note above. */
|
||||
const agents = this.agents;
|
||||
const logPrefix = this.getLogPrefix();
|
||||
const effectiveTimeout = timeout || DEFAULT_TIMEOUT;
|
||||
const postAgent = new Agent({
|
||||
bodyTimeout: effectiveTimeout,
|
||||
|
|
@ -664,18 +1011,27 @@ export class MCPConnection extends EventEmitter {
|
|||
let currentInit = buildFetchInit(resolvedInit, dispatcher, requestHeaders);
|
||||
let currentUrlString = urlString;
|
||||
const originalOrigin = new URL(currentUrlString).origin;
|
||||
|
||||
for (let redirects = 0; ; redirects++) {
|
||||
const response = await undiciFetch(currentUrlString, currentInit);
|
||||
const isMethodPreservingRedirect = response.status === 307 || response.status === 308;
|
||||
const responseContext = {
|
||||
logPrefix,
|
||||
method: (currentInit?.method ?? 'GET').toUpperCase(),
|
||||
url: currentUrlString,
|
||||
requestIds: getJSONRPCRequestIds(currentInit?.body),
|
||||
};
|
||||
|
||||
if (!isMethodPreservingRedirect || redirects >= MAX_REDIRECTS) {
|
||||
return response;
|
||||
return guardStreamableHTTPResponses
|
||||
? guardMCPStreamableHTTPResponse(response, responseContext)
|
||||
: response;
|
||||
}
|
||||
|
||||
const location = response.headers.get('location');
|
||||
if (!location) {
|
||||
return response;
|
||||
return guardStreamableHTTPResponses
|
||||
? guardMCPStreamableHTTPResponse(response, responseContext)
|
||||
: response;
|
||||
}
|
||||
|
||||
const targetUrl = new URL(location, currentUrlString);
|
||||
|
|
@ -904,6 +1260,7 @@ export class MCPConnection extends EventEmitter {
|
|||
this.sseReadTimeout || DEFAULT_SSE_READ_TIMEOUT,
|
||||
httpConfiguredSecretHeaderKeys,
|
||||
options.url,
|
||||
true,
|
||||
) as unknown as FetchLike,
|
||||
});
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue