diff --git a/api/server/services/ToolService.js b/api/server/services/ToolService.js index 4375f6e307..2f3599f1ec 100644 --- a/api/server/services/ToolService.js +++ b/api/server/services/ToolService.js @@ -596,11 +596,15 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to const flowManager = getFlowStateManager(flowsCache); const configServers = await resolveConfigServers(req); const pendingOAuthServers = new Set(); + const oauthToolCallIds = new Map(); + const oauthStepIndexes = new Map(); - const createOAuthEmitter = (serverName) => { + const createOAuthEmitter = (serverName, index) => { return async (authURL) => { const flowId = `${req.user.id}:${serverName}:${Date.now()}`; const stepId = 'step_oauth_login_' + serverName; + oauthToolCallIds.set(serverName, flowId); + oauthStepIndexes.set(serverName, index); const toolCall = { id: flowId, name: buildOAuthToolCallName(serverName), @@ -611,7 +615,7 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to runId: Constants.USE_PRELIM_RESPONSE_MESSAGE_ID, id: stepId, type: StepTypes.TOOL_CALLS, - index: 0, + index, stepDetails: { type: StepTypes.TOOL_CALLS, tool_calls: [toolCall], @@ -645,6 +649,40 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to }; }; + const createOAuthEndEmitter = (serverName) => { + return async () => { + const stepId = 'step_oauth_login_' + serverName; + const toolCall = { + id: oauthToolCallIds.get(serverName), + name: buildOAuthToolCallName(serverName), + args: '', + output: 'OAuth authentication completed', + type: 'tool_call', + }; + + const runStepCompletedEvent = { + event: GraphEvents.ON_RUN_STEP_COMPLETED, + data: { + result: { + id: stepId, + index: oauthStepIndexes.get(serverName) ?? 0, + tool_call: toolCall, + }, + }, + }; + + if (streamId) { + await GenerationJobManager.emitChunk(streamId, runStepCompletedEvent); + } else if (res && !res.writableEnded) { + sendEvent(res, runStepCompletedEvent); + } else { + logger.warn( + `[Tool Definitions] Cannot emit OAuth completion for ${serverName}: no streamId and res not available`, + ); + } + }; + }; + const getOrFetchMCPServerTools = async (userId, serverName) => { let serverConfig; try { @@ -781,7 +819,7 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to `[Tool Definitions] OAuth required for ${serverNames.length} server(s): ${serverNames.join(', ')}. Emitting events and waiting.`, ); - const oauthWaitPromises = serverNames.map(async (serverName) => { + const oauthWaitPromises = serverNames.map(async (serverName, index) => { try { const result = await reinitMCPServer({ user: req.user, @@ -790,7 +828,8 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to userMCPAuthMap, flowManager, returnOnOAuth: false, - oauthStart: createOAuthEmitter(serverName), + oauthStart: createOAuthEmitter(serverName, index), + oauthEnd: createOAuthEndEmitter(serverName), connectionTimeout: Time.TWO_MINUTES, }); diff --git a/api/server/services/Tools/mcp.js b/api/server/services/Tools/mcp.js index e13fa9989b..5a4aab7c15 100644 --- a/api/server/services/Tools/mcp.js +++ b/api/server/services/Tools/mcp.js @@ -21,6 +21,7 @@ const { getLogStores } = require('~/cache'); * @param {number} [params.connectionTimeout] * @param {FlowStateManager} [params.flowManager] * @param {(authURL: string) => Promise} [params.oauthStart] + * @param {() => Promise} [params.oauthEnd] * @param {Record>} [params.userMCPAuthMap] */ async function reinitMCPServer({ @@ -35,6 +36,7 @@ async function reinitMCPServer({ oauthStart: _oauthStart, flowManager: _flowManager, serverConfig: providedConfig, + oauthEnd, }) { /** @type {MCPConnection | null} */ let connection = null; @@ -63,28 +65,29 @@ async function reinitMCPServer({ oauthUrl: null, tools: null, }; - } - logger.info( - `[MCP Reinitialize] Server ${serverName} had failed inspection, attempting reinspection`, - ); - try { - const storageLocation = serverConfig.source === 'user' ? 'DB' : 'CACHE'; - await registry.reinspectServer(serverName, storageLocation, user?.id); - logger.info(`[MCP Reinitialize] Reinspection succeeded for server: ${serverName}`); - } catch (reinspectError) { - logger.error( - `[MCP Reinitialize] Reinspection failed for server ${serverName}:`, - reinspectError, + } else { + logger.info( + `[MCP Reinitialize] Server ${serverName} had failed inspection, attempting reinspection`, ); - return { - availableTools: null, - success: false, - message: `MCP server '${serverName}' is still unreachable`, - oauthRequired: false, - serverName, - oauthUrl: null, - tools: null, - }; + try { + const storageLocation = serverConfig.source === 'user' ? 'DB' : 'CACHE'; + await registry.reinspectServer(serverName, storageLocation, user?.id); + logger.info(`[MCP Reinitialize] Reinspection succeeded for server: ${serverName}`); + } catch (reinspectError) { + logger.error( + `[MCP Reinitialize] Reinspection failed for server ${serverName}:`, + reinspectError, + ); + return { + availableTools: null, + success: false, + message: `MCP server '${serverName}' is still unreachable`, + oauthRequired: false, + serverName, + oauthUrl: null, + tools: null, + }; + } } } @@ -132,6 +135,7 @@ async function reinitMCPServer({ flowManager, tokenMethods, returnOnOAuth, + oauthEnd, customUserVars, connectionTimeout, serverConfig, diff --git a/api/server/services/__tests__/ToolService.spec.js b/api/server/services/__tests__/ToolService.spec.js index 3fc9e58eb6..28edda80fe 100644 --- a/api/server/services/__tests__/ToolService.spec.js +++ b/api/server/services/__tests__/ToolService.spec.js @@ -11,6 +11,8 @@ const { const mockGetEndpointsConfig = jest.fn(); const mockGetMCPServerTools = jest.fn(); const mockGetCachedTools = jest.fn(); +const mockSendEvent = jest.fn(); +const mockEmitChunk = jest.fn(); jest.mock('~/server/services/Config', () => ({ getEndpointsConfig: (...args) => mockGetEndpointsConfig(...args), getMCPServerTools: (...args) => mockGetMCPServerTools(...args), @@ -23,6 +25,10 @@ jest.mock('@librechat/api', () => ({ ...jest.requireActual('@librechat/api'), loadToolDefinitions: (...args) => mockLoadToolDefinitions(...args), getUserMCPAuthMap: (...args) => mockGetUserMCPAuthMap(...args), + sendEvent: (...args) => mockSendEvent(...args), + GenerationJobManager: { + emitChunk: (...args) => mockEmitChunk(...args), + }, })); const mockLoadToolsUtil = jest.fn(); @@ -93,6 +99,7 @@ const { processRequiredActions, resolveAgentCapabilities, } = require('../ToolService'); +const { reinitMCPServer } = require('~/server/services/Tools/mcp'); function createMockReq(capabilities) { return { @@ -297,6 +304,85 @@ describe('ToolService - Action Capability Gating', () => { expect(result.actionsEnabled).toBe(false); }); + it('emits separate MCP OAuth login steps and completion events for multiple pending servers', async () => { + const req = createMockReq([AgentCapabilities.tools]); + const res = { writableEnded: false }; + const servers = ['ELI', 'Vespa']; + mockGetEndpointsConfig.mockResolvedValue(createEndpointsConfig([AgentCapabilities.tools])); + mockResolveConfigServers.mockResolvedValue( + Object.fromEntries( + servers.map((serverName) => [ + serverName, + { + type: 'streamable-http', + url: `https://mcp.example.com/${serverName}`, + requiresOAuth: true, + }, + ]), + ), + ); + + mockLoadToolDefinitions + .mockImplementationOnce(async (_args, deps) => { + await deps.getOrFetchMCPServerTools(req.user.id, servers[0]); + await deps.getOrFetchMCPServerTools(req.user.id, servers[1]); + return { + toolDefinitions: [], + toolRegistry: new Map(), + hasDeferredTools: false, + }; + }) + .mockResolvedValue({ + toolDefinitions: [], + toolRegistry: new Map(), + hasDeferredTools: false, + }); + + reinitMCPServer.mockImplementation( + async ({ serverName, returnOnOAuth, oauthStart, oauthEnd }) => { + if (returnOnOAuth === false) { + await oauthStart(`https://auth.example.com/${serverName}`); + await oauthEnd(); + return { availableTools: { [`tool_${serverName}`]: {} } }; + } + + await oauthStart(`https://auth.example.com/${serverName}`); + return { availableTools: null }; + }, + ); + + await loadAgentTools({ + req, + res, + agent: { + id: 'agent_123', + tools: servers.map((server) => `search${Constants.mcp_delimiter}${server}`), + }, + definitionsOnly: true, + }); + + const runStepEvents = mockSendEvent.mock.calls + .map(([, event]) => event) + .filter((event) => event.data?.stepDetails?.type === 'tool_calls'); + const deltaEvents = mockSendEvent.mock.calls + .map(([, event]) => event) + .filter((event) => event.data?.delta?.type === 'tool_calls'); + const authDeltaEvents = deltaEvents.filter((event) => event.data.delta.auth); + const completionEvents = mockSendEvent.mock.calls + .map(([, event]) => event) + .filter((event) => event.data?.result?.tool_call?.name?.startsWith('oauth')); + + expect(runStepEvents.map((event) => event.data.index)).toEqual([0, 1]); + expect(authDeltaEvents.map((event) => event.data.id)).toEqual([ + 'step_oauth_login_ELI', + 'step_oauth_login_Vespa', + ]); + expect(completionEvents.map((event) => event.data.result.id)).toEqual([ + 'step_oauth_login_ELI', + 'step_oauth_login_Vespa', + ]); + }); + it('should not expose cached MCP tool definitions when the registry lookup fails', async () => { const serverName = 'private-server'; const mcpTool = `search${Constants.mcp_delimiter}${serverName}`; diff --git a/client/src/hooks/SSE/__tests__/useStepHandler.spec.ts b/client/src/hooks/SSE/__tests__/useStepHandler.spec.ts index ae28a9e6f1..2e1b730ff5 100644 --- a/client/src/hooks/SSE/__tests__/useStepHandler.spec.ts +++ b/client/src/hooks/SSE/__tests__/useStepHandler.spec.ts @@ -234,6 +234,182 @@ describe('useStepHandler', () => { ); }); + it('should preserve multiple tool call steps for the same preliminary response', () => { + const responseMessage = createResponseMessage(); + let currentMessages = [responseMessage]; + mockGetMessages.mockImplementation(() => currentMessages); + mockSetMessages.mockImplementation((messages) => { + currentMessages = messages; + }); + + const { result } = renderHook(() => useStepHandler(createHookParams())); + const submission = createSubmission({ initialResponse: responseMessage }); + + const firstRunStep = createToolCallRunStep({ + id: 'step-oauth-eli', + runId: Constants.USE_PRELIM_RESPONSE_MESSAGE_ID, + index: 0, + stepDetails: { + type: StepTypes.TOOL_CALLS, + tool_calls: [ + { + id: 'tool-call-eli', + name: `oauth${Constants.mcp_delimiter}ELI`, + args: '', + type: ToolCallTypes.TOOL_CALL, + }, + ], + }, + }); + const secondRunStep = createToolCallRunStep({ + id: 'step-oauth-vespa', + runId: Constants.USE_PRELIM_RESPONSE_MESSAGE_ID, + index: 1, + stepDetails: { + type: StepTypes.TOOL_CALLS, + tool_calls: [ + { + id: 'tool-call-vespa', + name: `oauth${Constants.mcp_delimiter}Vespa`, + args: '', + type: ToolCallTypes.TOOL_CALL, + }, + ], + }, + }); + + act(() => { + result.current.stepHandler( + { event: StepEvents.ON_RUN_STEP, data: firstRunStep }, + submission, + ); + result.current.stepHandler( + { event: StepEvents.ON_RUN_STEP, data: secondRunStep }, + submission, + ); + }); + + const responseMsg = currentMessages.find((m) => !m.isCreatedByUser); + expect(responseMsg?.content).toHaveLength(2); + expect(responseMsg?.content?.[0]?.tool_call?.name).toBe(`oauth${Constants.mcp_delimiter}ELI`); + expect(responseMsg?.content?.[1]?.tool_call?.name).toBe( + `oauth${Constants.mcp_delimiter}Vespa`, + ); + }); + + it('should clear OAuth prompt slots when one occupies the real response slot', () => { + const responseMessage = createResponseMessage(); + let currentMessages = [responseMessage]; + mockGetMessages.mockImplementation(() => currentMessages); + mockSetMessages.mockImplementation((messages) => { + currentMessages = messages; + }); + + const { result } = renderHook(() => useStepHandler(createHookParams())); + const submission = createSubmission({ initialResponse: responseMessage }); + + act(() => { + result.current.stepHandler( + { + event: StepEvents.ON_RUN_STEP, + data: createToolCallRunStep({ + id: 'step-oauth-eli', + runId: Constants.USE_PRELIM_RESPONSE_MESSAGE_ID, + index: 0, + stepDetails: { + type: StepTypes.TOOL_CALLS, + tool_calls: [ + { + id: 'tool-call-eli', + name: `oauth${Constants.mcp_delimiter}ELI`, + args: '', + type: ToolCallTypes.TOOL_CALL, + }, + ], + }, + }), + }, + submission, + ); + result.current.stepHandler( + { + event: StepEvents.ON_RUN_STEP, + data: createToolCallRunStep({ + id: 'step-oauth-vespa', + runId: Constants.USE_PRELIM_RESPONSE_MESSAGE_ID, + index: 1, + stepDetails: { + type: StepTypes.TOOL_CALLS, + tool_calls: [ + { + id: 'tool-call-vespa', + name: `oauth${Constants.mcp_delimiter}Vespa`, + args: '', + type: ToolCallTypes.TOOL_CALL, + }, + ], + }, + }), + }, + submission, + ); + result.current.stepHandler( + { + event: StepEvents.ON_RUN_STEP, + data: createRunStep({ + id: 'step-message', + runId: Constants.USE_PRELIM_RESPONSE_MESSAGE_ID, + index: 0, + }), + }, + submission, + ); + result.current.stepHandler( + { event: StepEvents.ON_MESSAGE_DELTA, data: createMessageDelta('step-message', 'Ready') }, + submission, + ); + }); + + const responseMsg = currentMessages.find((m) => !m.isCreatedByUser); + expect(responseMsg?.content).toEqual([{ type: ContentTypes.TEXT, text: 'Ready' }]); + }); + + it('should not replace the message list from a shorter refresh during tool call steps', () => { + const userMessage = createUserMessage(); + const responseMessage = createResponseMessage(); + mockGetMessages.mockReturnValueOnce([userMessage, responseMessage]).mockReturnValueOnce([]); + + const { result } = renderHook(() => useStepHandler(createHookParams())); + + const runStep = createToolCallRunStep({ + runId: responseMessage.messageId, + stepDetails: { + type: StepTypes.TOOL_CALLS, + tool_calls: [ + { + id: 'tool-call-eli', + name: `oauth${Constants.mcp_delimiter}ELI`, + args: '', + type: ToolCallTypes.TOOL_CALL, + }, + ], + }, + }); + + act(() => { + result.current.stepHandler( + { event: StepEvents.ON_RUN_STEP, data: runStep }, + createSubmission({ userMessage, initialResponse: responseMessage }), + ); + }); + + const lastCall = mockSetMessages.mock.calls[mockSetMessages.mock.calls.length - 1][0]; + expect(lastCall.map((message: TMessage) => message.messageId)).toEqual([ + userMessage.messageId, + responseMessage.messageId, + ]); + }); + it('should replay buffered deltas after registering step', () => { const responseMessage = createResponseMessage(); mockGetMessages.mockReturnValue([responseMessage]); @@ -756,6 +932,59 @@ describe('useStepHandler', () => { ); consoleSpy.mockRestore(); }); + + it('should mark completed OAuth prompts as finished', () => { + const responseMessage = createResponseMessage(); + mockGetMessages.mockReturnValue([responseMessage]); + + const { result } = renderHook(() => useStepHandler(createHookParams())); + + const runStep = createToolCallRunStep({ + id: 'step-oauth-eli', + stepDetails: { + type: StepTypes.TOOL_CALLS, + tool_calls: [ + { + id: 'tool-call-eli', + name: `oauth${Constants.mcp_delimiter}ELI`, + args: '', + type: ToolCallTypes.TOOL_CALL, + }, + ], + }, + }); + const submission = createSubmission(); + + act(() => { + result.current.stepHandler({ event: StepEvents.ON_RUN_STEP, data: runStep }, submission); + }); + + act(() => { + result.current.stepHandler( + { + event: StepEvents.ON_RUN_STEP_COMPLETED, + data: { + result: { + id: 'step-oauth-eli', + index: 0, + tool_call: { + id: 'tool-call-eli', + name: `oauth${Constants.mcp_delimiter}ELI`, + args: '', + output: 'OAuth authentication completed', + type: ToolCallTypes.TOOL_CALL, + }, + }, + }, + }, + submission, + ); + }); + + const lastCall = mockSetMessages.mock.calls[mockSetMessages.mock.calls.length - 1][0]; + const responseMsg = lastCall.find((m: TMessage) => !m.isCreatedByUser); + expect(responseMsg?.content?.[0]?.tool_call?.progress).toBe(1); + }); }); describe('clearStepMaps', () => { diff --git a/client/src/hooks/SSE/useStepHandler.ts b/client/src/hooks/SSE/useStepHandler.ts index e4c6283269..e67821e6bd 100644 --- a/client/src/hooks/SSE/useStepHandler.ts +++ b/client/src/hooks/SSE/useStepHandler.ts @@ -63,6 +63,14 @@ type AllContentTypes = | ContentTypes.SUMMARY | ContentTypes.ERROR; +const isOAuthToolCallName = (name?: string) => + typeof name === 'string' && name.startsWith(`oauth${Constants.mcp_delimiter}`); + +const isOAuthToolCallContent = (part?: Partial) => + part?.type === ContentTypes.TOOL_CALL && + 'tool_call' in part && + isOAuthToolCallName(part.tool_call?.name); + export default function useStepHandler({ setMessages, getMessages, @@ -100,6 +108,14 @@ export default function useStepHandler({ */ const knownSubagentAtomKeys = useRef(new Set()); + const getCurrentMessages = useCallback( + (messages: TMessage[]) => { + const freshMessages = getMessages(); + return freshMessages && freshMessages.length >= messages.length ? freshMessages : messages; + }, + [getMessages], + ); + /** Both content parts and ticker lines are aggregated incrementally * into the atom as each `ON_SUBAGENT_UPDATE` arrives — we never * retain the raw event array, so no rolling window is needed. A @@ -270,9 +286,20 @@ export default function useStepHandler({ return message; } - const updatedContent = [...(message.content || [])] as Array< + const incomingOAuthToolCall = + contentType === ContentTypes.TOOL_CALL && + 'tool_call' in contentPart && + isOAuthToolCallName(contentPart.tool_call?.name); + + let updatedContent = [...(message.content || [])] as Array< Partial | undefined >; + + const oauthPromptOccupiesSlot = isOAuthToolCallContent(updatedContent[index]); + if (!incomingOAuthToolCall && oauthPromptOccupiesSlot) { + updatedContent = updatedContent.filter((part) => !isOAuthToolCallContent(part)); + } + if (!updatedContent[index] && contentType !== ContentTypes.TOOL_CALL) { updatedContent[index] = { type: contentPart.type as AllContentTypes }; } @@ -517,9 +544,15 @@ export default function useStepHandler({ }); messageMap.current.set(responseMessageId, updatedResponse); - const updatedMessages = messages.map((msg) => - msg.messageId === responseMessageId ? updatedResponse : msg, + const currentMessages = getCurrentMessages(messages); + const hasResponseMessage = currentMessages.some( + (msg) => msg.messageId === responseMessageId, ); + const updatedMessages = hasResponseMessage + ? currentMessages.map((msg) => + msg.messageId === responseMessageId ? updatedResponse : msg, + ) + : [...currentMessages, updatedResponse]; setMessages(updatedMessages); } @@ -724,9 +757,15 @@ export default function useStepHandler({ }); messageMap.current.set(responseMessageId, updatedResponse); - const updatedMessages = messages.map((msg) => - msg.messageId === responseMessageId ? updatedResponse : msg, + const currentMessages = getCurrentMessages(messages); + const hasResponseMessage = currentMessages.some( + (msg) => msg.messageId === responseMessageId, ); + const updatedMessages = hasResponseMessage + ? currentMessages.map((msg) => + msg.messageId === responseMessageId ? updatedResponse : msg, + ) + : [...currentMessages, updatedResponse]; setMessages(updatedMessages); } @@ -767,9 +806,15 @@ export default function useStepHandler({ ); messageMap.current.set(responseMessageId, updatedResponse); - const updatedMessages = messages.map((msg) => - msg.messageId === responseMessageId ? updatedResponse : msg, + const currentMessages = getCurrentMessages(messages); + const hasResponseMessage = currentMessages.some( + (msg) => msg.messageId === responseMessageId, ); + const updatedMessages = hasResponseMessage + ? currentMessages.map((msg) => + msg.messageId === responseMessageId ? updatedResponse : msg, + ) + : [...currentMessages, updatedResponse]; setMessages(updatedMessages); } @@ -883,6 +928,7 @@ export default function useStepHandler({ announcePolite, setMessages, calculateContentIndex, + getCurrentMessages, applySubagentUpdate, ], );