🔐 fix: Handle Multiple Concurrent MCP OAuth Login Prompts (#13200)

* fix: handle multiple MCP OAuth prompts

* fix: address MCP OAuth review feedback

* fix: address MCP OAuth prompt lifecycle review

* fix: narrow OAuth prompt slot cleanup

* fix: format OAuth prompt test

---------

Co-authored-by: Danny Avila <danny@librechat.ai>
This commit is contained in:
janluedemann-esome 2026-06-05 23:18:24 +02:00 committed by GitHub
parent aeb5adff34
commit 2ed59ac98a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 436 additions and 32 deletions

View file

@ -596,11 +596,15 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to
const flowManager = getFlowStateManager(flowsCache);
const configServers = await resolveConfigServers(req);
const pendingOAuthServers = new Set();
const oauthToolCallIds = new Map();
const oauthStepIndexes = new Map();
const createOAuthEmitter = (serverName) => {
const createOAuthEmitter = (serverName, index) => {
return async (authURL) => {
const flowId = `${req.user.id}:${serverName}:${Date.now()}`;
const stepId = 'step_oauth_login_' + serverName;
oauthToolCallIds.set(serverName, flowId);
oauthStepIndexes.set(serverName, index);
const toolCall = {
id: flowId,
name: buildOAuthToolCallName(serverName),
@ -611,7 +615,7 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to
runId: Constants.USE_PRELIM_RESPONSE_MESSAGE_ID,
id: stepId,
type: StepTypes.TOOL_CALLS,
index: 0,
index,
stepDetails: {
type: StepTypes.TOOL_CALLS,
tool_calls: [toolCall],
@ -645,6 +649,40 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to
};
};
const createOAuthEndEmitter = (serverName) => {
return async () => {
const stepId = 'step_oauth_login_' + serverName;
const toolCall = {
id: oauthToolCallIds.get(serverName),
name: buildOAuthToolCallName(serverName),
args: '',
output: 'OAuth authentication completed',
type: 'tool_call',
};
const runStepCompletedEvent = {
event: GraphEvents.ON_RUN_STEP_COMPLETED,
data: {
result: {
id: stepId,
index: oauthStepIndexes.get(serverName) ?? 0,
tool_call: toolCall,
},
},
};
if (streamId) {
await GenerationJobManager.emitChunk(streamId, runStepCompletedEvent);
} else if (res && !res.writableEnded) {
sendEvent(res, runStepCompletedEvent);
} else {
logger.warn(
`[Tool Definitions] Cannot emit OAuth completion for ${serverName}: no streamId and res not available`,
);
}
};
};
const getOrFetchMCPServerTools = async (userId, serverName) => {
let serverConfig;
try {
@ -781,7 +819,7 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to
`[Tool Definitions] OAuth required for ${serverNames.length} server(s): ${serverNames.join(', ')}. Emitting events and waiting.`,
);
const oauthWaitPromises = serverNames.map(async (serverName) => {
const oauthWaitPromises = serverNames.map(async (serverName, index) => {
try {
const result = await reinitMCPServer({
user: req.user,
@ -790,7 +828,8 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to
userMCPAuthMap,
flowManager,
returnOnOAuth: false,
oauthStart: createOAuthEmitter(serverName),
oauthStart: createOAuthEmitter(serverName, index),
oauthEnd: createOAuthEndEmitter(serverName),
connectionTimeout: Time.TWO_MINUTES,
});

View file

@ -21,6 +21,7 @@ const { getLogStores } = require('~/cache');
* @param {number} [params.connectionTimeout]
* @param {FlowStateManager<any>} [params.flowManager]
* @param {(authURL: string) => Promise<void>} [params.oauthStart]
* @param {() => Promise<void>} [params.oauthEnd]
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap]
*/
async function reinitMCPServer({
@ -35,6 +36,7 @@ async function reinitMCPServer({
oauthStart: _oauthStart,
flowManager: _flowManager,
serverConfig: providedConfig,
oauthEnd,
}) {
/** @type {MCPConnection | null} */
let connection = null;
@ -63,28 +65,29 @@ async function reinitMCPServer({
oauthUrl: null,
tools: null,
};
}
logger.info(
`[MCP Reinitialize] Server ${serverName} had failed inspection, attempting reinspection`,
);
try {
const storageLocation = serverConfig.source === 'user' ? 'DB' : 'CACHE';
await registry.reinspectServer(serverName, storageLocation, user?.id);
logger.info(`[MCP Reinitialize] Reinspection succeeded for server: ${serverName}`);
} catch (reinspectError) {
logger.error(
`[MCP Reinitialize] Reinspection failed for server ${serverName}:`,
reinspectError,
} else {
logger.info(
`[MCP Reinitialize] Server ${serverName} had failed inspection, attempting reinspection`,
);
return {
availableTools: null,
success: false,
message: `MCP server '${serverName}' is still unreachable`,
oauthRequired: false,
serverName,
oauthUrl: null,
tools: null,
};
try {
const storageLocation = serverConfig.source === 'user' ? 'DB' : 'CACHE';
await registry.reinspectServer(serverName, storageLocation, user?.id);
logger.info(`[MCP Reinitialize] Reinspection succeeded for server: ${serverName}`);
} catch (reinspectError) {
logger.error(
`[MCP Reinitialize] Reinspection failed for server ${serverName}:`,
reinspectError,
);
return {
availableTools: null,
success: false,
message: `MCP server '${serverName}' is still unreachable`,
oauthRequired: false,
serverName,
oauthUrl: null,
tools: null,
};
}
}
}
@ -132,6 +135,7 @@ async function reinitMCPServer({
flowManager,
tokenMethods,
returnOnOAuth,
oauthEnd,
customUserVars,
connectionTimeout,
serverConfig,

View file

@ -11,6 +11,8 @@ const {
const mockGetEndpointsConfig = jest.fn();
const mockGetMCPServerTools = jest.fn();
const mockGetCachedTools = jest.fn();
const mockSendEvent = jest.fn();
const mockEmitChunk = jest.fn();
jest.mock('~/server/services/Config', () => ({
getEndpointsConfig: (...args) => mockGetEndpointsConfig(...args),
getMCPServerTools: (...args) => mockGetMCPServerTools(...args),
@ -23,6 +25,10 @@ jest.mock('@librechat/api', () => ({
...jest.requireActual('@librechat/api'),
loadToolDefinitions: (...args) => mockLoadToolDefinitions(...args),
getUserMCPAuthMap: (...args) => mockGetUserMCPAuthMap(...args),
sendEvent: (...args) => mockSendEvent(...args),
GenerationJobManager: {
emitChunk: (...args) => mockEmitChunk(...args),
},
}));
const mockLoadToolsUtil = jest.fn();
@ -93,6 +99,7 @@ const {
processRequiredActions,
resolveAgentCapabilities,
} = require('../ToolService');
const { reinitMCPServer } = require('~/server/services/Tools/mcp');
function createMockReq(capabilities) {
return {
@ -297,6 +304,85 @@ describe('ToolService - Action Capability Gating', () => {
expect(result.actionsEnabled).toBe(false);
});
it('emits separate MCP OAuth login steps and completion events for multiple pending servers', async () => {
const req = createMockReq([AgentCapabilities.tools]);
const res = { writableEnded: false };
const servers = ['ELI', 'Vespa'];
mockGetEndpointsConfig.mockResolvedValue(createEndpointsConfig([AgentCapabilities.tools]));
mockResolveConfigServers.mockResolvedValue(
Object.fromEntries(
servers.map((serverName) => [
serverName,
{
type: 'streamable-http',
url: `https://mcp.example.com/${serverName}`,
requiresOAuth: true,
},
]),
),
);
mockLoadToolDefinitions
.mockImplementationOnce(async (_args, deps) => {
await deps.getOrFetchMCPServerTools(req.user.id, servers[0]);
await deps.getOrFetchMCPServerTools(req.user.id, servers[1]);
return {
toolDefinitions: [],
toolRegistry: new Map(),
hasDeferredTools: false,
};
})
.mockResolvedValue({
toolDefinitions: [],
toolRegistry: new Map(),
hasDeferredTools: false,
});
reinitMCPServer.mockImplementation(
async ({ serverName, returnOnOAuth, oauthStart, oauthEnd }) => {
if (returnOnOAuth === false) {
await oauthStart(`https://auth.example.com/${serverName}`);
await oauthEnd();
return { availableTools: { [`tool_${serverName}`]: {} } };
}
await oauthStart(`https://auth.example.com/${serverName}`);
return { availableTools: null };
},
);
await loadAgentTools({
req,
res,
agent: {
id: 'agent_123',
tools: servers.map((server) => `search${Constants.mcp_delimiter}${server}`),
},
definitionsOnly: true,
});
const runStepEvents = mockSendEvent.mock.calls
.map(([, event]) => event)
.filter((event) => event.data?.stepDetails?.type === 'tool_calls');
const deltaEvents = mockSendEvent.mock.calls
.map(([, event]) => event)
.filter((event) => event.data?.delta?.type === 'tool_calls');
const authDeltaEvents = deltaEvents.filter((event) => event.data.delta.auth);
const completionEvents = mockSendEvent.mock.calls
.map(([, event]) => event)
.filter((event) => event.data?.result?.tool_call?.name?.startsWith('oauth'));
expect(runStepEvents.map((event) => event.data.index)).toEqual([0, 1]);
expect(authDeltaEvents.map((event) => event.data.id)).toEqual([
'step_oauth_login_ELI',
'step_oauth_login_Vespa',
]);
expect(completionEvents.map((event) => event.data.result.id)).toEqual([
'step_oauth_login_ELI',
'step_oauth_login_Vespa',
]);
});
it('should not expose cached MCP tool definitions when the registry lookup fails', async () => {
const serverName = 'private-server';
const mcpTool = `search${Constants.mcp_delimiter}${serverName}`;

View file

@ -234,6 +234,182 @@ describe('useStepHandler', () => {
);
});
it('should preserve multiple tool call steps for the same preliminary response', () => {
const responseMessage = createResponseMessage();
let currentMessages = [responseMessage];
mockGetMessages.mockImplementation(() => currentMessages);
mockSetMessages.mockImplementation((messages) => {
currentMessages = messages;
});
const { result } = renderHook(() => useStepHandler(createHookParams()));
const submission = createSubmission({ initialResponse: responseMessage });
const firstRunStep = createToolCallRunStep({
id: 'step-oauth-eli',
runId: Constants.USE_PRELIM_RESPONSE_MESSAGE_ID,
index: 0,
stepDetails: {
type: StepTypes.TOOL_CALLS,
tool_calls: [
{
id: 'tool-call-eli',
name: `oauth${Constants.mcp_delimiter}ELI`,
args: '',
type: ToolCallTypes.TOOL_CALL,
},
],
},
});
const secondRunStep = createToolCallRunStep({
id: 'step-oauth-vespa',
runId: Constants.USE_PRELIM_RESPONSE_MESSAGE_ID,
index: 1,
stepDetails: {
type: StepTypes.TOOL_CALLS,
tool_calls: [
{
id: 'tool-call-vespa',
name: `oauth${Constants.mcp_delimiter}Vespa`,
args: '',
type: ToolCallTypes.TOOL_CALL,
},
],
},
});
act(() => {
result.current.stepHandler(
{ event: StepEvents.ON_RUN_STEP, data: firstRunStep },
submission,
);
result.current.stepHandler(
{ event: StepEvents.ON_RUN_STEP, data: secondRunStep },
submission,
);
});
const responseMsg = currentMessages.find((m) => !m.isCreatedByUser);
expect(responseMsg?.content).toHaveLength(2);
expect(responseMsg?.content?.[0]?.tool_call?.name).toBe(`oauth${Constants.mcp_delimiter}ELI`);
expect(responseMsg?.content?.[1]?.tool_call?.name).toBe(
`oauth${Constants.mcp_delimiter}Vespa`,
);
});
it('should clear OAuth prompt slots when one occupies the real response slot', () => {
const responseMessage = createResponseMessage();
let currentMessages = [responseMessage];
mockGetMessages.mockImplementation(() => currentMessages);
mockSetMessages.mockImplementation((messages) => {
currentMessages = messages;
});
const { result } = renderHook(() => useStepHandler(createHookParams()));
const submission = createSubmission({ initialResponse: responseMessage });
act(() => {
result.current.stepHandler(
{
event: StepEvents.ON_RUN_STEP,
data: createToolCallRunStep({
id: 'step-oauth-eli',
runId: Constants.USE_PRELIM_RESPONSE_MESSAGE_ID,
index: 0,
stepDetails: {
type: StepTypes.TOOL_CALLS,
tool_calls: [
{
id: 'tool-call-eli',
name: `oauth${Constants.mcp_delimiter}ELI`,
args: '',
type: ToolCallTypes.TOOL_CALL,
},
],
},
}),
},
submission,
);
result.current.stepHandler(
{
event: StepEvents.ON_RUN_STEP,
data: createToolCallRunStep({
id: 'step-oauth-vespa',
runId: Constants.USE_PRELIM_RESPONSE_MESSAGE_ID,
index: 1,
stepDetails: {
type: StepTypes.TOOL_CALLS,
tool_calls: [
{
id: 'tool-call-vespa',
name: `oauth${Constants.mcp_delimiter}Vespa`,
args: '',
type: ToolCallTypes.TOOL_CALL,
},
],
},
}),
},
submission,
);
result.current.stepHandler(
{
event: StepEvents.ON_RUN_STEP,
data: createRunStep({
id: 'step-message',
runId: Constants.USE_PRELIM_RESPONSE_MESSAGE_ID,
index: 0,
}),
},
submission,
);
result.current.stepHandler(
{ event: StepEvents.ON_MESSAGE_DELTA, data: createMessageDelta('step-message', 'Ready') },
submission,
);
});
const responseMsg = currentMessages.find((m) => !m.isCreatedByUser);
expect(responseMsg?.content).toEqual([{ type: ContentTypes.TEXT, text: 'Ready' }]);
});
it('should not replace the message list from a shorter refresh during tool call steps', () => {
const userMessage = createUserMessage();
const responseMessage = createResponseMessage();
mockGetMessages.mockReturnValueOnce([userMessage, responseMessage]).mockReturnValueOnce([]);
const { result } = renderHook(() => useStepHandler(createHookParams()));
const runStep = createToolCallRunStep({
runId: responseMessage.messageId,
stepDetails: {
type: StepTypes.TOOL_CALLS,
tool_calls: [
{
id: 'tool-call-eli',
name: `oauth${Constants.mcp_delimiter}ELI`,
args: '',
type: ToolCallTypes.TOOL_CALL,
},
],
},
});
act(() => {
result.current.stepHandler(
{ event: StepEvents.ON_RUN_STEP, data: runStep },
createSubmission({ userMessage, initialResponse: responseMessage }),
);
});
const lastCall = mockSetMessages.mock.calls[mockSetMessages.mock.calls.length - 1][0];
expect(lastCall.map((message: TMessage) => message.messageId)).toEqual([
userMessage.messageId,
responseMessage.messageId,
]);
});
it('should replay buffered deltas after registering step', () => {
const responseMessage = createResponseMessage();
mockGetMessages.mockReturnValue([responseMessage]);
@ -756,6 +932,59 @@ describe('useStepHandler', () => {
);
consoleSpy.mockRestore();
});
it('should mark completed OAuth prompts as finished', () => {
const responseMessage = createResponseMessage();
mockGetMessages.mockReturnValue([responseMessage]);
const { result } = renderHook(() => useStepHandler(createHookParams()));
const runStep = createToolCallRunStep({
id: 'step-oauth-eli',
stepDetails: {
type: StepTypes.TOOL_CALLS,
tool_calls: [
{
id: 'tool-call-eli',
name: `oauth${Constants.mcp_delimiter}ELI`,
args: '',
type: ToolCallTypes.TOOL_CALL,
},
],
},
});
const submission = createSubmission();
act(() => {
result.current.stepHandler({ event: StepEvents.ON_RUN_STEP, data: runStep }, submission);
});
act(() => {
result.current.stepHandler(
{
event: StepEvents.ON_RUN_STEP_COMPLETED,
data: {
result: {
id: 'step-oauth-eli',
index: 0,
tool_call: {
id: 'tool-call-eli',
name: `oauth${Constants.mcp_delimiter}ELI`,
args: '',
output: 'OAuth authentication completed',
type: ToolCallTypes.TOOL_CALL,
},
},
},
},
submission,
);
});
const lastCall = mockSetMessages.mock.calls[mockSetMessages.mock.calls.length - 1][0];
const responseMsg = lastCall.find((m: TMessage) => !m.isCreatedByUser);
expect(responseMsg?.content?.[0]?.tool_call?.progress).toBe(1);
});
});
describe('clearStepMaps', () => {

View file

@ -63,6 +63,14 @@ type AllContentTypes =
| ContentTypes.SUMMARY
| ContentTypes.ERROR;
const isOAuthToolCallName = (name?: string) =>
typeof name === 'string' && name.startsWith(`oauth${Constants.mcp_delimiter}`);
const isOAuthToolCallContent = (part?: Partial<TMessageContentParts>) =>
part?.type === ContentTypes.TOOL_CALL &&
'tool_call' in part &&
isOAuthToolCallName(part.tool_call?.name);
export default function useStepHandler({
setMessages,
getMessages,
@ -100,6 +108,14 @@ export default function useStepHandler({
*/
const knownSubagentAtomKeys = useRef(new Set<string>());
const getCurrentMessages = useCallback(
(messages: TMessage[]) => {
const freshMessages = getMessages();
return freshMessages && freshMessages.length >= messages.length ? freshMessages : messages;
},
[getMessages],
);
/** Both content parts and ticker lines are aggregated incrementally
* into the atom as each `ON_SUBAGENT_UPDATE` arrives we never
* retain the raw event array, so no rolling window is needed. A
@ -270,9 +286,20 @@ export default function useStepHandler({
return message;
}
const updatedContent = [...(message.content || [])] as Array<
const incomingOAuthToolCall =
contentType === ContentTypes.TOOL_CALL &&
'tool_call' in contentPart &&
isOAuthToolCallName(contentPart.tool_call?.name);
let updatedContent = [...(message.content || [])] as Array<
Partial<TMessageContentParts> | undefined
>;
const oauthPromptOccupiesSlot = isOAuthToolCallContent(updatedContent[index]);
if (!incomingOAuthToolCall && oauthPromptOccupiesSlot) {
updatedContent = updatedContent.filter((part) => !isOAuthToolCallContent(part));
}
if (!updatedContent[index] && contentType !== ContentTypes.TOOL_CALL) {
updatedContent[index] = { type: contentPart.type as AllContentTypes };
}
@ -517,9 +544,15 @@ export default function useStepHandler({
});
messageMap.current.set(responseMessageId, updatedResponse);
const updatedMessages = messages.map((msg) =>
msg.messageId === responseMessageId ? updatedResponse : msg,
const currentMessages = getCurrentMessages(messages);
const hasResponseMessage = currentMessages.some(
(msg) => msg.messageId === responseMessageId,
);
const updatedMessages = hasResponseMessage
? currentMessages.map((msg) =>
msg.messageId === responseMessageId ? updatedResponse : msg,
)
: [...currentMessages, updatedResponse];
setMessages(updatedMessages);
}
@ -724,9 +757,15 @@ export default function useStepHandler({
});
messageMap.current.set(responseMessageId, updatedResponse);
const updatedMessages = messages.map((msg) =>
msg.messageId === responseMessageId ? updatedResponse : msg,
const currentMessages = getCurrentMessages(messages);
const hasResponseMessage = currentMessages.some(
(msg) => msg.messageId === responseMessageId,
);
const updatedMessages = hasResponseMessage
? currentMessages.map((msg) =>
msg.messageId === responseMessageId ? updatedResponse : msg,
)
: [...currentMessages, updatedResponse];
setMessages(updatedMessages);
}
@ -767,9 +806,15 @@ export default function useStepHandler({
);
messageMap.current.set(responseMessageId, updatedResponse);
const updatedMessages = messages.map((msg) =>
msg.messageId === responseMessageId ? updatedResponse : msg,
const currentMessages = getCurrentMessages(messages);
const hasResponseMessage = currentMessages.some(
(msg) => msg.messageId === responseMessageId,
);
const updatedMessages = hasResponseMessage
? currentMessages.map((msg) =>
msg.messageId === responseMessageId ? updatedResponse : msg,
)
: [...currentMessages, updatedResponse];
setMessages(updatedMessages);
}
@ -883,6 +928,7 @@ export default function useStepHandler({
announcePolite,
setMessages,
calculateContentIndex,
getCurrentMessages,
applySubagentUpdate,
],
);