diff --git a/.env.example b/.env.example index a2e3de8bf1..1aa95fd5be 100644 --- a/.env.example +++ b/.env.example @@ -893,6 +893,14 @@ OPENWEATHER_API_KEY= # Cache connection status checks for this many milliseconds to avoid expensive verification # MCP_CONNECTION_CHECK_TTL=60000 +# Max bytes allowed in a non-GET streamable HTTP MCP response before rejecting it. +# Set to 0 to disable. Default: 16777216 (16 MiB) +# MCP_STREAMABLE_HTTP_MAX_RESPONSE_BYTES=16777216 + +# Max bytes allowed in a single SSE line for non-GET streamable HTTP MCP responses. +# Set to 0 to disable. Default: 1048576 (1 MiB) +# MCP_STREAMABLE_HTTP_MAX_LINE_BYTES=1048576 + # Skip code challenge method validation (e.g., for AWS Cognito that supports S256 but doesn't advertise it) # When set to true, forces S256 code challenge even if not advertised in .well-known/openid-configuration # MCP_SKIP_CODE_CHALLENGE_CHECK=false diff --git a/packages/api/src/agents/handlers.spec.ts b/packages/api/src/agents/handlers.spec.ts index 41200f31cd..dca3a85005 100644 --- a/packages/api/src/agents/handlers.spec.ts +++ b/packages/api/src/agents/handlers.spec.ts @@ -1,4 +1,5 @@ import { Constants } from '@librechat/agents'; +import { logger } from '@librechat/data-schemas'; import type { ToolExecuteBatchRequest, ToolExecuteResult, @@ -278,6 +279,128 @@ describe('createToolExecuteHandler', () => { }); }); + describe('tool error handling', () => { + it('truncates oversized tool errors in the result and log context', async () => { + const oversizedMessage = `tool failed: ${'x'.repeat(15_000)}`; + const thrown = new Error(oversizedMessage); + thrown.stack = `Error: ${oversizedMessage}\n${'stack-line\n'.repeat(600)}`; + const loadTools: ToolExecuteOptions['loadTools'] = jest.fn(async () => ({ + loadedTools: [ + { + name: 'bad_tool', + invoke: jest.fn(async () => { + throw thrown; + }), + }, + ] as never[], + })); + const errorSpy = jest.spyOn(logger, 'error').mockReturnValue(logger); + try { + const handler = createToolExecuteHandler({ loadTools }); + const [result] = await invokeHandler(handler, [ + { + id: 'call_bad', + name: 'bad_tool', + args: {}, + }, + ]); + + expect(result.status).toBe('error'); + expect(result.errorMessage).toContain('truncated'); + expect(result.errorMessage!.length).toBeLessThanOrEqual(12_000); + expect(errorSpy).toHaveBeenCalledWith( + '[ON_TOOL_EXECUTE] Tool bad_tool error', + expect.objectContaining({ + messageTruncated: true, + messageLength: oversizedMessage.length, + }), + ); + const [, logContext] = errorSpy.mock.calls[0] as unknown as [string, { stack?: string }]; + expect(logContext.stack!.length).toBeLessThanOrEqual(4_000); + } finally { + errorSpy.mockRestore(); + } + }); + + it('returns a per-tool error when thrown value stringification fails', async () => { + const thrown = { + toString() { + throw new Error('toString failed'); + }, + }; + const loadTools: ToolExecuteOptions['loadTools'] = jest.fn(async () => ({ + loadedTools: [ + { + name: 'bad_to_string_tool', + invoke: jest.fn(async () => { + throw thrown; + }), + }, + ] as never[], + })); + const errorSpy = jest.spyOn(logger, 'error').mockReturnValue(logger); + try { + const handler = createToolExecuteHandler({ loadTools }); + const [result] = await invokeHandler(handler, [ + { + id: 'call_bad_to_string', + name: 'bad_to_string_tool', + args: {}, + }, + ]); + + expect(result.status).toBe('error'); + expect(result.errorMessage).toBe('[Thrown value could not be converted to string]'); + expect(errorSpy).toHaveBeenCalledWith( + '[ON_TOOL_EXECUTE] Tool bad_to_string_tool error', + expect.objectContaining({ + name: 'object', + messageTruncated: false, + }), + ); + } finally { + errorSpy.mockRestore(); + } + }); + + it('preserves message from thrown plain objects', async () => { + const thrown = { message: 'plain object timeout' }; + const loadTools: ToolExecuteOptions['loadTools'] = jest.fn(async () => ({ + loadedTools: [ + { + name: 'plain_object_tool', + invoke: jest.fn(async () => { + throw thrown; + }), + }, + ] as never[], + })); + const errorSpy = jest.spyOn(logger, 'error').mockReturnValue(logger); + try { + const handler = createToolExecuteHandler({ loadTools }); + const [result] = await invokeHandler(handler, [ + { + id: 'call_plain_object', + name: 'plain_object_tool', + args: {}, + }, + ]); + + expect(result.status).toBe('error'); + expect(result.errorMessage).toBe('plain object timeout'); + expect(errorSpy).toHaveBeenCalledWith( + '[ON_TOOL_EXECUTE] Tool plain_object_tool error', + expect.objectContaining({ + message: 'plain object timeout', + messageTruncated: false, + }), + ); + } finally { + errorSpy.mockRestore(); + } + }); + }); + describe('skill tool model-invocation gate', () => { function createSkillHandler(getSkillByName: ToolExecuteOptions['getSkillByName']) { const loadTools: ToolExecuteOptions['loadTools'] = jest.fn(async () => ({ diff --git a/packages/api/src/agents/handlers.ts b/packages/api/src/agents/handlers.ts index 5a6e3834bf..4690cf73e4 100644 --- a/packages/api/src/agents/handlers.ts +++ b/packages/api/src/agents/handlers.ts @@ -154,9 +154,77 @@ export interface ToolExecuteOptions { const MAX_READABLE_BYTES = 262_144; const MAX_BINARY_BYTES = 5 * 1024 * 1024; const MAX_CACHE_BYTES = 512 * 1024; +const MAX_TOOL_ERROR_MESSAGE_CHARS = 12_000; +const MAX_TOOL_ERROR_STACK_CHARS = 4_000; const IMAGE_MIMES = new Set(['image/png', 'image/jpeg', 'image/gif', 'image/webp']); +function truncateMiddle(value: string, maxChars: number): string { + if (value.length <= maxChars) { + return value; + } + + const indicator = `\n\n... [truncated: ${value.length} chars exceeded ${maxChars} limit] ...\n\n`; + const available = maxChars - indicator.length; + if (available <= 0) { + return value.slice(0, maxChars); + } + + const headSize = Math.ceil(available * 0.7); + const tailSize = available - headSize; + return value.slice(0, headSize) + indicator + value.slice(value.length - tailSize); +} + +function stringifyThrownValue(error: unknown): string { + try { + return String(error); + } catch { + return '[Thrown value could not be converted to string]'; + } +} + +function getThrownValueMessage(error: unknown): string { + if (error instanceof Error) { + return error.message; + } + + if (error != null && typeof error === 'object') { + try { + const message = (error as { message?: unknown }).message; + if (typeof message === 'string') { + return message; + } + if (message != null) { + return stringifyThrownValue(message); + } + } catch { + // Fall through to whole-value stringification. + } + } + + return stringifyThrownValue(error); +} + +function getSafeToolError(error: unknown): { + message: string; + logContext: Record; +} { + const rawMessage = getThrownValueMessage(error); + const message = truncateMiddle(rawMessage, MAX_TOOL_ERROR_MESSAGE_CHARS); + const stack = error instanceof Error && error.stack ? error.stack : undefined; + + return { + message, + logContext: { + name: error instanceof Error ? error.name : typeof error, + message, + messageLength: rawMessage.length, + messageTruncated: message.length !== rawMessage.length, + stack: stack ? truncateMiddle(stack, MAX_TOOL_ERROR_STACK_CHARS) : undefined, + }, + }; +} + function addLineNumbers(content: string): string { const lines = content.split('\n'); const w = String(lines.length).length; @@ -1195,13 +1263,13 @@ export function createToolExecuteHandler(options: ToolExecuteOptions): EventHand status: 'success' as const, }; } catch (toolError) { - const error = toolError as Error; - logger.error(`[ON_TOOL_EXECUTE] Tool ${tc.name} error:`, error); + const { message, logContext } = getSafeToolError(toolError); + logger.error(`[ON_TOOL_EXECUTE] Tool ${tc.name} error`, logContext); return { toolCallId: tc.id, status: 'error' as const, content: '', - errorMessage: error.message, + errorMessage: message, }; } }), diff --git a/packages/api/src/mcp/__tests__/MCPConnectionSSRF.test.ts b/packages/api/src/mcp/__tests__/MCPConnectionSSRF.test.ts index 98f698c9cd..fcc6b72dcf 100644 --- a/packages/api/src/mcp/__tests__/MCPConnectionSSRF.test.ts +++ b/packages/api/src/mcp/__tests__/MCPConnectionSSRF.test.ts @@ -258,6 +258,45 @@ async function createStreamableServer(): Promise }; } +async function createOversizedToolResultStreamableServer( + payloadSize: number, +): Promise> { + const sessions = new Map(); + + const httpServer = http.createServer(async (req, res) => { + const sid = req.headers['mcp-session-id'] as string | undefined; + let transport = sid ? sessions.get(sid) : undefined; + + if (!transport) { + transport = new StreamableHTTPServerTransport({ sessionIdGenerator: () => randomUUID() }); + const mcp = new McpServer({ name: 'oversized-tool-result', version: '0.0.1' }); + mcp.tool('oversized', 'Returns an oversized text payload', {}, async () => ({ + content: [{ type: 'text', text: 'x'.repeat(payloadSize) }], + })); + await mcp.connect(transport); + } + + await transport.handleRequest(req, res); + + if (transport.sessionId && !sessions.has(transport.sessionId)) { + sessions.set(transport.sessionId, transport); + transport.onclose = () => sessions.delete(transport!.sessionId!); + } + }); + + const destroySockets = trackSockets(httpServer); + const port = await getFreePort(); + await new Promise((resolve) => httpServer.listen(port, '127.0.0.1', resolve)); + + return { + url: `http://127.0.0.1:${port}/`, + close: async () => { + await closeMCPSessions(sessions); + await destroySockets(); + }, + }; +} + describe('MCP SSRF protection – redirect blocking', () => { let redirectServer: TestServer; let conn: MCPConnection | null; @@ -457,6 +496,11 @@ interface HeaderCaptureServer { close: () => Promise; } +interface RawResponseServer { + url: string; + close: () => Promise; +} + /** * Captures every incoming request's headers, method, and body, then replies * with a benign 200 so tests can assert what actually crossed a redirect @@ -490,6 +534,17 @@ async function createHeaderCaptureServer(): Promise { }; } +async function createRawResponseServer(handler: http.RequestListener): Promise { + const server = http.createServer(handler); + const destroySockets = trackSockets(server); + const port = await getFreePort(); + await new Promise((resolve) => server.listen(port, '127.0.0.1', resolve)); + return { + url: `http://127.0.0.1:${port}/`, + close: destroySockets, + }; +} + /** * Issues a single 307/308 redirect to `redirectTarget` (which lives on a * different port and is therefore a different origin). Used to verify @@ -823,8 +878,20 @@ describe('MCP SSRF protection – cross-origin credential stripping on redirect' describe('MCP SSRF protection – customFetch input shapes', () => { let target: Omit | undefined; let conn: MCPConnection | null; + const originalMaxResponseBytes = process.env.MCP_STREAMABLE_HTTP_MAX_RESPONSE_BYTES; + const originalMaxLineBytes = process.env.MCP_STREAMABLE_HTTP_MAX_LINE_BYTES; afterEach(async () => { + if (originalMaxResponseBytes == null) { + delete process.env.MCP_STREAMABLE_HTTP_MAX_RESPONSE_BYTES; + } else { + process.env.MCP_STREAMABLE_HTTP_MAX_RESPONSE_BYTES = originalMaxResponseBytes; + } + if (originalMaxLineBytes == null) { + delete process.env.MCP_STREAMABLE_HTTP_MAX_LINE_BYTES; + } else { + process.env.MCP_STREAMABLE_HTTP_MAX_LINE_BYTES = originalMaxLineBytes; + } await safeDisconnect(conn); conn = null; if (target) { @@ -846,12 +913,33 @@ describe('MCP SSRF protection – customFetch input shapes', () => { connection as unknown as { createFetchFunction: ( getHeaders: () => Record | null | undefined, + timeout?: number, + sseBodyTimeout?: number, + configuredSecretHeaderKeys?: ReadonlySet, + baseUrl?: string, + guardStreamableHTTPResponses?: boolean, ) => CustomFetch; } ).createFetchFunction; return factory.call(connection, () => null); } + function getGuardedStreamableHTTPCustomFetch(connection: MCPConnection): CustomFetch { + const factory = ( + connection as unknown as { + createFetchFunction: ( + getHeaders: () => Record | null | undefined, + timeout?: number, + sseBodyTimeout?: number, + configuredSecretHeaderKeys?: ReadonlySet, + baseUrl?: string, + guardStreamableHTTPResponses?: boolean, + ) => CustomFetch; + } + ).createFetchFunction; + return factory.call(connection, () => null, undefined, undefined, undefined, undefined, true); + } + it.each<['string' | 'URL' | 'Request']>([['string'], ['URL'], ['Request']])( 'should accept a %s input without throwing on URL derivation', async (shape) => { @@ -1015,6 +1103,142 @@ describe('MCP SSRF protection – customFetch input shapes', () => { } } }); + + it('should not apply streamable HTTP response caps unless the transport opts in', async () => { + process.env.MCP_STREAMABLE_HTTP_MAX_RESPONSE_BYTES = '8'; + const server = await createRawResponseServer((_req, res) => { + res.writeHead(200, { 'Content-Type': 'application/json' }); + res.end('{"jsonrpc":"2.0","id":1,"result":{"too":"large"}}'); + }); + try { + conn = new MCPConnection({ + serverName: 'customfetch-unguarded-byte-limit', + serverConfig: { type: 'sse', url: server.url }, + useSSRFProtection: false, + }); + + const customFetch = getCustomFetch(conn); + const response = await customFetch(server.url, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ jsonrpc: '2.0', method: 'ping', id: 1 }), + }); + + await expect(response.text()).resolves.toContain('"too":"large"'); + } finally { + await server.close(); + } + }); + + it('should reject oversized JSON POST responses with the streamable HTTP byte cap', async () => { + process.env.MCP_STREAMABLE_HTTP_MAX_RESPONSE_BYTES = '8'; + const server = await createRawResponseServer((_req, res) => { + res.writeHead(200, { 'Content-Type': 'application/json' }); + res.end('{"jsonrpc":"2.0","id":1,"result":{"too":"large"}}'); + }); + try { + conn = new MCPConnection({ + serverName: 'customfetch-json-byte-limit', + serverConfig: { type: 'streamable-http', url: server.url }, + useSSRFProtection: false, + }); + + const customFetch = getGuardedStreamableHTTPCustomFetch(conn); + const response = await customFetch(server.url, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ jsonrpc: '2.0', method: 'ping', id: 1 }), + }); + + await expect(response.text()).rejects.toThrow( + /MCP response exceeded byte limit.*limit=8 bytes/, + ); + } finally { + await server.close(); + } + }); + + it('should reject a POST response with an oversized SSE line before the SSE parser can grow it', async () => { + process.env.MCP_STREAMABLE_HTTP_MAX_LINE_BYTES = '16'; + const server = await createRawResponseServer((_req, res) => { + res.writeHead(200, { 'Content-Type': 'text/event-stream' }); + res.end(`data: ${'x'.repeat(64)}\n\n`); + }); + try { + conn = new MCPConnection({ + serverName: 'customfetch-sse-line-limit', + serverConfig: { type: 'streamable-http', url: server.url }, + useSSRFProtection: false, + }); + + const customFetch = getGuardedStreamableHTTPCustomFetch(conn); + const response = await customFetch(server.url, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ jsonrpc: '2.0', method: 'notifications/cancelled' }), + }); + await expect(response.text()).rejects.toThrow( + /MCP response contained an oversized SSE line.*lineLimit=16 bytes.*observedLine=17 bytes/, + ); + } finally { + await server.close(); + } + }); + + it('should fail an actual streamable HTTP tool call promptly with a clear oversized SSE line error', async () => { + process.env.MCP_STREAMABLE_HTTP_MAX_LINE_BYTES = '512'; + target = await createOversizedToolResultStreamableServer(2048); + conn = new MCPConnection({ + serverName: 'streamable-http-tool-call-sse-line-limit', + serverConfig: { type: 'streamable-http', url: target.url }, + useSSRFProtection: false, + }); + + await conn.connect(); + const startedAt = Date.now(); + + await expect( + conn.client.callTool({ name: 'oversized', arguments: {} }, undefined, { timeout: 3000 }), + ).rejects.toThrow(/MCP response contained an oversized SSE line/); + expect(Date.now() - startedAt).toBeLessThan(1500); + }); + + it('should stream valid SSE POST responses without waiting for EOF', async () => { + process.env.MCP_STREAMABLE_HTTP_MAX_LINE_BYTES = '4096'; + let finish!: () => void; + const finished = new Promise((resolve) => { + finish = resolve; + }); + const server = await createRawResponseServer((_req, res) => { + res.writeHead(200, { 'Content-Type': 'text/event-stream' }); + res.write('data: {"jsonrpc":"2.0","id":1,"result":{}}\n\n'); + finished.then(() => res.end()).catch(() => res.end()); + }); + try { + conn = new MCPConnection({ + serverName: 'customfetch-sse-streaming', + serverConfig: { type: 'streamable-http', url: server.url }, + useSSRFProtection: false, + }); + + const customFetch = getGuardedStreamableHTTPCustomFetch(conn); + const response = await customFetch(server.url, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ jsonrpc: '2.0', method: 'ping', id: 1 }), + }); + const reader = response.body!.getReader(); + const { value, done } = await reader.read(); + + expect(done).toBe(false); + expect(Buffer.from(value as Uint8Array).toString('utf8')).toContain('"result":{}'); + await reader.cancel().catch(() => undefined); + finish(); + } finally { + finish(); + await server.close(); + } + }); }); describe('MCP SSRF protection – WebSocket DNS resolution', () => { diff --git a/packages/api/src/mcp/connection.ts b/packages/api/src/mcp/connection.ts index fb9481b8b1..3eddf26a91 100644 --- a/packages/api/src/mcp/connection.ts +++ b/packages/api/src/mcp/connection.ts @@ -74,6 +74,351 @@ const SSE_CONNECT_TIMEOUT = 120000; const DEFAULT_INIT_TIMEOUT = 30000; /** Max 307/308 redirects to follow per request (prevents redirect loops) */ const MAX_REDIRECTS = 5; +const DEFAULT_MCP_STREAMABLE_HTTP_MAX_RESPONSE_BYTES = 16 * 1024 * 1024; +const DEFAULT_MCP_STREAMABLE_HTTP_MAX_LINE_BYTES = 1024 * 1024; + +function getNonNegativeIntegerEnv(name: string, defaultValue: number): number { + const raw = process.env[name]; + if (raw == null || raw.trim() === '') { + return defaultValue; + } + + const trimmed = raw.trim(); + if (!/^\d+$/.test(trimmed)) { + return defaultValue; + } + + const parsed = Number(trimmed); + return Number.isSafeInteger(parsed) ? parsed : defaultValue; +} + +function bytesToMiB(bytes: number): string { + return `${(bytes / 1024 / 1024).toFixed(2)} MiB`; +} + +function getMemoryDebugSnapshot(): Record { + const mem = process.memoryUsage(); + return { + rss: bytesToMiB(mem.rss), + heapUsed: bytesToMiB(mem.heapUsed), + heapTotal: bytesToMiB(mem.heapTotal), + external: bytesToMiB(mem.external), + arrayBuffers: bytesToMiB(mem.arrayBuffers ?? 0), + }; +} + +const textEncoder = new TextEncoder(); +const textDecoder = new TextDecoder(); +type JSONRPCRequestId = string | number; + +function getChunkBytes(chunk: unknown): Uint8Array { + if (typeof chunk === 'string') { + return textEncoder.encode(chunk); + } + if (chunk instanceof ArrayBuffer) { + return new Uint8Array(chunk); + } + if (ArrayBuffer.isView(chunk)) { + const view = new Uint8Array(chunk.buffer, chunk.byteOffset, chunk.byteLength); + return new Uint8Array(view); + } + return new Uint8Array(); +} + +function copyBytes(bytes: Uint8Array): Uint8Array { + const copy = new Uint8Array(bytes.byteLength); + copy.set(bytes); + return copy; +} + +function concatBytes(chunks: Uint8Array[]): Uint8Array { + if (chunks.length === 0) { + return new Uint8Array(); + } + if (chunks.length === 1) { + return chunks[0]; + } + const totalLength = chunks.reduce((total, chunk) => total + chunk.byteLength, 0); + const combined = new Uint8Array(totalLength); + let offset = 0; + for (const chunk of chunks) { + combined.set(chunk, offset); + offset += chunk.byteLength; + } + return combined; +} + +function getBodyText(body: unknown): string | null { + if (typeof body === 'string') { + return body; + } + if (body instanceof ArrayBuffer) { + return textDecoder.decode(new Uint8Array(body)); + } + if (ArrayBuffer.isView(body)) { + return textDecoder.decode(new Uint8Array(body.buffer, body.byteOffset, body.byteLength)); + } + return null; +} + +function getJSONRPCRequestIds(body: unknown): JSONRPCRequestId[] { + const bodyText = getBodyText(body); + if (!bodyText) { + return []; + } + + let parsed: unknown; + try { + parsed = JSON.parse(bodyText); + } catch { + return []; + } + + const messages = Array.isArray(parsed) ? parsed : [parsed]; + return messages.flatMap((message) => { + if (!message || typeof message !== 'object') { + return []; + } + const jsonrpcMessage = message as { id?: unknown; method?: unknown }; + const { id } = jsonrpcMessage; + if (typeof jsonrpcMessage.method !== 'string') { + return []; + } + if (typeof id !== 'string' && typeof id !== 'number') { + return []; + } + return [id]; + }); +} + +function buildBlockedMCPResponseMessage( + reason: string, + details: { + maxResponseBytes: number; + maxLineBytes: number; + totalBytes: number; + currentLineBytes: number; + chunkCount: number; + }, +): string { + const limitDetails = + reason === 'MCP response exceeded byte limit' + ? `limit=${details.maxResponseBytes} bytes, observed=${details.totalBytes} bytes` + : `lineLimit=${details.maxLineBytes} bytes, observedLine=${details.currentLineBytes} bytes, observedTotal=${details.totalBytes} bytes`; + + return `[MCP] ${reason} (${limitDetails}, chunks=${details.chunkCount}). The MCP server returned an unsafe streamable HTTP response; narrow the tool result or retry after the server response is fixed.`; +} + +function buildBlockedMCPResponseSSE(requestIds: JSONRPCRequestId[], message: string): Uint8Array { + const events = requestIds + .map((id) => { + const payload = { + jsonrpc: '2.0', + id, + error: { + code: -32000, + message, + }, + }; + return `data: ${JSON.stringify(payload)}\n\n`; + }) + .join(''); + return textEncoder.encode(events); +} + +function getMCPStreamableHTTPResponseLimits(): { + maxResponseBytes: number; + maxLineBytes: number; +} { + return { + maxResponseBytes: getNonNegativeIntegerEnv( + 'MCP_STREAMABLE_HTTP_MAX_RESPONSE_BYTES', + DEFAULT_MCP_STREAMABLE_HTTP_MAX_RESPONSE_BYTES, + ), + maxLineBytes: getNonNegativeIntegerEnv( + 'MCP_STREAMABLE_HTTP_MAX_LINE_BYTES', + DEFAULT_MCP_STREAMABLE_HTTP_MAX_LINE_BYTES, + ), + }; +} + +async function guardMCPStreamableHTTPResponse( + response: UndiciResponse, + context: { + logPrefix: string; + method: string; + url: string; + requestIds?: JSONRPCRequestId[]; + }, +): Promise { + if (context.method === 'GET' || !response.body) { + return response; + } + + const contentType = response.headers.get('content-type') ?? ''; + const isEventStream = contentType.toLowerCase().includes('text/event-stream'); + const { maxResponseBytes, maxLineBytes } = getMCPStreamableHTTPResponseLimits(); + const canEmitFallbackSSEError = isEventStream && maxLineBytes > 0; + if (!isEventStream && maxResponseBytes === 0) { + return response; + } + if (maxResponseBytes === 0 && maxLineBytes === 0) { + return response; + } + + let totalBytes = 0; + let currentLineBytes = 0; + let chunkCount = 0; + let pendingSSELineChunks: Uint8Array[] = []; + const sseEventDataLines: string[] = []; + const unresolvedRequestIds = new Set(context.requestIds ?? []); + + const buildAndLogBlockedError = ( + reason: string, + details: Record, + ): Error => { + const message = buildBlockedMCPResponseMessage(reason, { + maxResponseBytes, + maxLineBytes, + totalBytes, + currentLineBytes, + chunkCount, + }); + logger.warn(`${context.logPrefix} MCP streamable HTTP response blocked: ${reason}`, { + method: context.method, + url: sanitizeUrlForLogging(context.url), + status: response.status, + contentType, + maxResponseBytes, + maxLineBytes, + totalBytes, + currentLineBytes, + chunkCount, + ...details, + memory: getMemoryDebugSnapshot(), + }); + return new Error(message); + }; + + const trackSSELineForResolvedIds = (lineBytes: Uint8Array): void => { + if (unresolvedRequestIds.size === 0) { + return; + } + + const rawLine = textDecoder.decode(lineBytes).replace(/[\r\n]+$/, ''); + if (rawLine === '') { + if (sseEventDataLines.length === 0) { + return; + } + const data = sseEventDataLines.join('\n'); + sseEventDataLines.length = 0; + try { + const parsed = JSON.parse(data) as { id?: unknown }; + if (typeof parsed.id === 'string' || typeof parsed.id === 'number') { + unresolvedRequestIds.delete(parsed.id); + } + } catch { + /** Ignore malformed SSE data here; the SDK parser will report it. */ + } + return; + } + + const separatorIndex = rawLine.indexOf(':'); + const field = separatorIndex === -1 ? rawLine : rawLine.slice(0, separatorIndex); + if (field !== 'data') { + return; + } + let value = separatorIndex === -1 ? '' : rawLine.slice(separatorIndex + 1); + if (value.startsWith(' ')) { + value = value.slice(1); + } + sseEventDataLines.push(value); + }; + + const enqueuePendingSSELine = (controller: TransformStreamDefaultController) => { + if (pendingSSELineChunks.length === 0) { + return; + } + const lineBytes = concatBytes(pendingSSELineChunks); + pendingSSELineChunks = []; + trackSSELineForResolvedIds(lineBytes); + controller.enqueue(lineBytes); + }; + + const blockResponse = ( + controller: TransformStreamDefaultController, + reason: string, + details: Record, + ) => { + const error = buildAndLogBlockedError(reason, details); + const fallbackRequestIds = [...unresolvedRequestIds]; + if (canEmitFallbackSSEError && fallbackRequestIds.length > 0) { + controller.enqueue(buildBlockedMCPResponseSSE(fallbackRequestIds, error.message)); + controller.terminate(); + return; + } + throw error; + }; + + const guardedBody = (response.body as unknown as ReadableStream).pipeThrough( + new TransformStream({ + transform(chunk, controller) { + const bytes = getChunkBytes(chunk); + if (bytes.byteLength === 0) { + return; + } + + chunkCount += 1; + totalBytes += bytes.byteLength; + + if (maxResponseBytes > 0 && totalBytes > maxResponseBytes) { + blockResponse(controller, 'MCP response exceeded byte limit', { + chunkBytes: bytes.byteLength, + }); + return; + } + + if (isEventStream && maxLineBytes > 0) { + let segmentStart = 0; + for (let i = 0; i < bytes.byteLength; i++) { + const byte = bytes[i]; + if (byte === 10 || byte === 13) { + if (i + 1 > segmentStart) { + pendingSSELineChunks.push(copyBytes(bytes.subarray(segmentStart, i + 1))); + } + enqueuePendingSSELine(controller); + segmentStart = i + 1; + currentLineBytes = 0; + continue; + } + currentLineBytes += 1; + if (currentLineBytes > maxLineBytes) { + blockResponse(controller, 'MCP response contained an oversized SSE line', { + chunkBytes: bytes.byteLength, + }); + return; + } + } + if (segmentStart < bytes.byteLength) { + pendingSSELineChunks.push(copyBytes(bytes.subarray(segmentStart))); + } + return; + } + + controller.enqueue(bytes); + }, + flush(controller) { + enqueuePendingSSELine(controller); + }, + }), + ); + + return new Response(guardedBody as unknown as BodyInit, { + status: response.status, + statusText: response.statusText, + headers: response.headers as unknown as HeadersInit, + }) as unknown as UndiciResponse; +} /** * Headers stripped before forwarding a request across an origin boundary on @@ -579,6 +924,7 @@ export class MCPConnection extends EventEmitter { sseBodyTimeout?: number, configuredSecretHeaderKeys?: ReadonlySet, baseUrl?: string, + guardStreamableHTTPResponses = false, ): (input: UndiciRequestInfo, init?: UndiciRequestInit) => Promise { const basePort = baseUrl ? getUrlPort(baseUrl) : ''; const ssrfConnect = this.useSSRFProtection @@ -587,6 +933,7 @@ export class MCPConnection extends EventEmitter { const connectOpts = ssrfConnect != null ? { connect: ssrfConnect } : {}; /** Capture only the fields needed by the fetch closure; see factory note above. */ const agents = this.agents; + const logPrefix = this.getLogPrefix(); const effectiveTimeout = timeout || DEFAULT_TIMEOUT; const postAgent = new Agent({ bodyTimeout: effectiveTimeout, @@ -664,18 +1011,27 @@ export class MCPConnection extends EventEmitter { let currentInit = buildFetchInit(resolvedInit, dispatcher, requestHeaders); let currentUrlString = urlString; const originalOrigin = new URL(currentUrlString).origin; - for (let redirects = 0; ; redirects++) { const response = await undiciFetch(currentUrlString, currentInit); const isMethodPreservingRedirect = response.status === 307 || response.status === 308; + const responseContext = { + logPrefix, + method: (currentInit?.method ?? 'GET').toUpperCase(), + url: currentUrlString, + requestIds: getJSONRPCRequestIds(currentInit?.body), + }; if (!isMethodPreservingRedirect || redirects >= MAX_REDIRECTS) { - return response; + return guardStreamableHTTPResponses + ? guardMCPStreamableHTTPResponse(response, responseContext) + : response; } const location = response.headers.get('location'); if (!location) { - return response; + return guardStreamableHTTPResponses + ? guardMCPStreamableHTTPResponse(response, responseContext) + : response; } const targetUrl = new URL(location, currentUrlString); @@ -904,6 +1260,7 @@ export class MCPConnection extends EventEmitter { this.sseReadTimeout || DEFAULT_SSE_READ_TIMEOUT, httpConfiguredSecretHeaderKeys, options.url, + true, ) as unknown as FetchLike, });