mirror of
https://github.com/danny-avila/LibreChat.git
synced 2026-06-10 01:44:44 +00:00
🔐 fix: Handle Multiple Concurrent MCP OAuth Login Prompts (#13200)
* fix: handle multiple MCP OAuth prompts * fix: address MCP OAuth review feedback * fix: address MCP OAuth prompt lifecycle review * fix: narrow OAuth prompt slot cleanup * fix: format OAuth prompt test --------- Co-authored-by: Danny Avila <danny@librechat.ai>
This commit is contained in:
parent
aeb5adff34
commit
2ed59ac98a
5 changed files with 436 additions and 32 deletions
|
|
@ -596,11 +596,15 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to
|
|||
const flowManager = getFlowStateManager(flowsCache);
|
||||
const configServers = await resolveConfigServers(req);
|
||||
const pendingOAuthServers = new Set();
|
||||
const oauthToolCallIds = new Map();
|
||||
const oauthStepIndexes = new Map();
|
||||
|
||||
const createOAuthEmitter = (serverName) => {
|
||||
const createOAuthEmitter = (serverName, index) => {
|
||||
return async (authURL) => {
|
||||
const flowId = `${req.user.id}:${serverName}:${Date.now()}`;
|
||||
const stepId = 'step_oauth_login_' + serverName;
|
||||
oauthToolCallIds.set(serverName, flowId);
|
||||
oauthStepIndexes.set(serverName, index);
|
||||
const toolCall = {
|
||||
id: flowId,
|
||||
name: buildOAuthToolCallName(serverName),
|
||||
|
|
@ -611,7 +615,7 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to
|
|||
runId: Constants.USE_PRELIM_RESPONSE_MESSAGE_ID,
|
||||
id: stepId,
|
||||
type: StepTypes.TOOL_CALLS,
|
||||
index: 0,
|
||||
index,
|
||||
stepDetails: {
|
||||
type: StepTypes.TOOL_CALLS,
|
||||
tool_calls: [toolCall],
|
||||
|
|
@ -645,6 +649,40 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to
|
|||
};
|
||||
};
|
||||
|
||||
const createOAuthEndEmitter = (serverName) => {
|
||||
return async () => {
|
||||
const stepId = 'step_oauth_login_' + serverName;
|
||||
const toolCall = {
|
||||
id: oauthToolCallIds.get(serverName),
|
||||
name: buildOAuthToolCallName(serverName),
|
||||
args: '',
|
||||
output: 'OAuth authentication completed',
|
||||
type: 'tool_call',
|
||||
};
|
||||
|
||||
const runStepCompletedEvent = {
|
||||
event: GraphEvents.ON_RUN_STEP_COMPLETED,
|
||||
data: {
|
||||
result: {
|
||||
id: stepId,
|
||||
index: oauthStepIndexes.get(serverName) ?? 0,
|
||||
tool_call: toolCall,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
if (streamId) {
|
||||
await GenerationJobManager.emitChunk(streamId, runStepCompletedEvent);
|
||||
} else if (res && !res.writableEnded) {
|
||||
sendEvent(res, runStepCompletedEvent);
|
||||
} else {
|
||||
logger.warn(
|
||||
`[Tool Definitions] Cannot emit OAuth completion for ${serverName}: no streamId and res not available`,
|
||||
);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
const getOrFetchMCPServerTools = async (userId, serverName) => {
|
||||
let serverConfig;
|
||||
try {
|
||||
|
|
@ -781,7 +819,7 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to
|
|||
`[Tool Definitions] OAuth required for ${serverNames.length} server(s): ${serverNames.join(', ')}. Emitting events and waiting.`,
|
||||
);
|
||||
|
||||
const oauthWaitPromises = serverNames.map(async (serverName) => {
|
||||
const oauthWaitPromises = serverNames.map(async (serverName, index) => {
|
||||
try {
|
||||
const result = await reinitMCPServer({
|
||||
user: req.user,
|
||||
|
|
@ -790,7 +828,8 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to
|
|||
userMCPAuthMap,
|
||||
flowManager,
|
||||
returnOnOAuth: false,
|
||||
oauthStart: createOAuthEmitter(serverName),
|
||||
oauthStart: createOAuthEmitter(serverName, index),
|
||||
oauthEnd: createOAuthEndEmitter(serverName),
|
||||
connectionTimeout: Time.TWO_MINUTES,
|
||||
});
|
||||
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ const { getLogStores } = require('~/cache');
|
|||
* @param {number} [params.connectionTimeout]
|
||||
* @param {FlowStateManager<any>} [params.flowManager]
|
||||
* @param {(authURL: string) => Promise<void>} [params.oauthStart]
|
||||
* @param {() => Promise<void>} [params.oauthEnd]
|
||||
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap]
|
||||
*/
|
||||
async function reinitMCPServer({
|
||||
|
|
@ -35,6 +36,7 @@ async function reinitMCPServer({
|
|||
oauthStart: _oauthStart,
|
||||
flowManager: _flowManager,
|
||||
serverConfig: providedConfig,
|
||||
oauthEnd,
|
||||
}) {
|
||||
/** @type {MCPConnection | null} */
|
||||
let connection = null;
|
||||
|
|
@ -63,28 +65,29 @@ async function reinitMCPServer({
|
|||
oauthUrl: null,
|
||||
tools: null,
|
||||
};
|
||||
}
|
||||
logger.info(
|
||||
`[MCP Reinitialize] Server ${serverName} had failed inspection, attempting reinspection`,
|
||||
);
|
||||
try {
|
||||
const storageLocation = serverConfig.source === 'user' ? 'DB' : 'CACHE';
|
||||
await registry.reinspectServer(serverName, storageLocation, user?.id);
|
||||
logger.info(`[MCP Reinitialize] Reinspection succeeded for server: ${serverName}`);
|
||||
} catch (reinspectError) {
|
||||
logger.error(
|
||||
`[MCP Reinitialize] Reinspection failed for server ${serverName}:`,
|
||||
reinspectError,
|
||||
} else {
|
||||
logger.info(
|
||||
`[MCP Reinitialize] Server ${serverName} had failed inspection, attempting reinspection`,
|
||||
);
|
||||
return {
|
||||
availableTools: null,
|
||||
success: false,
|
||||
message: `MCP server '${serverName}' is still unreachable`,
|
||||
oauthRequired: false,
|
||||
serverName,
|
||||
oauthUrl: null,
|
||||
tools: null,
|
||||
};
|
||||
try {
|
||||
const storageLocation = serverConfig.source === 'user' ? 'DB' : 'CACHE';
|
||||
await registry.reinspectServer(serverName, storageLocation, user?.id);
|
||||
logger.info(`[MCP Reinitialize] Reinspection succeeded for server: ${serverName}`);
|
||||
} catch (reinspectError) {
|
||||
logger.error(
|
||||
`[MCP Reinitialize] Reinspection failed for server ${serverName}:`,
|
||||
reinspectError,
|
||||
);
|
||||
return {
|
||||
availableTools: null,
|
||||
success: false,
|
||||
message: `MCP server '${serverName}' is still unreachable`,
|
||||
oauthRequired: false,
|
||||
serverName,
|
||||
oauthUrl: null,
|
||||
tools: null,
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -132,6 +135,7 @@ async function reinitMCPServer({
|
|||
flowManager,
|
||||
tokenMethods,
|
||||
returnOnOAuth,
|
||||
oauthEnd,
|
||||
customUserVars,
|
||||
connectionTimeout,
|
||||
serverConfig,
|
||||
|
|
|
|||
|
|
@ -11,6 +11,8 @@ const {
|
|||
const mockGetEndpointsConfig = jest.fn();
|
||||
const mockGetMCPServerTools = jest.fn();
|
||||
const mockGetCachedTools = jest.fn();
|
||||
const mockSendEvent = jest.fn();
|
||||
const mockEmitChunk = jest.fn();
|
||||
jest.mock('~/server/services/Config', () => ({
|
||||
getEndpointsConfig: (...args) => mockGetEndpointsConfig(...args),
|
||||
getMCPServerTools: (...args) => mockGetMCPServerTools(...args),
|
||||
|
|
@ -23,6 +25,10 @@ jest.mock('@librechat/api', () => ({
|
|||
...jest.requireActual('@librechat/api'),
|
||||
loadToolDefinitions: (...args) => mockLoadToolDefinitions(...args),
|
||||
getUserMCPAuthMap: (...args) => mockGetUserMCPAuthMap(...args),
|
||||
sendEvent: (...args) => mockSendEvent(...args),
|
||||
GenerationJobManager: {
|
||||
emitChunk: (...args) => mockEmitChunk(...args),
|
||||
},
|
||||
}));
|
||||
|
||||
const mockLoadToolsUtil = jest.fn();
|
||||
|
|
@ -93,6 +99,7 @@ const {
|
|||
processRequiredActions,
|
||||
resolveAgentCapabilities,
|
||||
} = require('../ToolService');
|
||||
const { reinitMCPServer } = require('~/server/services/Tools/mcp');
|
||||
|
||||
function createMockReq(capabilities) {
|
||||
return {
|
||||
|
|
@ -297,6 +304,85 @@ describe('ToolService - Action Capability Gating', () => {
|
|||
expect(result.actionsEnabled).toBe(false);
|
||||
});
|
||||
|
||||
it('emits separate MCP OAuth login steps and completion events for multiple pending servers', async () => {
|
||||
const req = createMockReq([AgentCapabilities.tools]);
|
||||
const res = { writableEnded: false };
|
||||
const servers = ['ELI', 'Vespa'];
|
||||
mockGetEndpointsConfig.mockResolvedValue(createEndpointsConfig([AgentCapabilities.tools]));
|
||||
mockResolveConfigServers.mockResolvedValue(
|
||||
Object.fromEntries(
|
||||
servers.map((serverName) => [
|
||||
serverName,
|
||||
{
|
||||
type: 'streamable-http',
|
||||
url: `https://mcp.example.com/${serverName}`,
|
||||
requiresOAuth: true,
|
||||
},
|
||||
]),
|
||||
),
|
||||
);
|
||||
|
||||
mockLoadToolDefinitions
|
||||
.mockImplementationOnce(async (_args, deps) => {
|
||||
await deps.getOrFetchMCPServerTools(req.user.id, servers[0]);
|
||||
await deps.getOrFetchMCPServerTools(req.user.id, servers[1]);
|
||||
return {
|
||||
toolDefinitions: [],
|
||||
toolRegistry: new Map(),
|
||||
hasDeferredTools: false,
|
||||
};
|
||||
})
|
||||
.mockResolvedValue({
|
||||
toolDefinitions: [],
|
||||
toolRegistry: new Map(),
|
||||
hasDeferredTools: false,
|
||||
});
|
||||
|
||||
reinitMCPServer.mockImplementation(
|
||||
async ({ serverName, returnOnOAuth, oauthStart, oauthEnd }) => {
|
||||
if (returnOnOAuth === false) {
|
||||
await oauthStart(`https://auth.example.com/${serverName}`);
|
||||
await oauthEnd();
|
||||
return { availableTools: { [`tool_${serverName}`]: {} } };
|
||||
}
|
||||
|
||||
await oauthStart(`https://auth.example.com/${serverName}`);
|
||||
return { availableTools: null };
|
||||
},
|
||||
);
|
||||
|
||||
await loadAgentTools({
|
||||
req,
|
||||
res,
|
||||
agent: {
|
||||
id: 'agent_123',
|
||||
tools: servers.map((server) => `search${Constants.mcp_delimiter}${server}`),
|
||||
},
|
||||
definitionsOnly: true,
|
||||
});
|
||||
|
||||
const runStepEvents = mockSendEvent.mock.calls
|
||||
.map(([, event]) => event)
|
||||
.filter((event) => event.data?.stepDetails?.type === 'tool_calls');
|
||||
const deltaEvents = mockSendEvent.mock.calls
|
||||
.map(([, event]) => event)
|
||||
.filter((event) => event.data?.delta?.type === 'tool_calls');
|
||||
const authDeltaEvents = deltaEvents.filter((event) => event.data.delta.auth);
|
||||
const completionEvents = mockSendEvent.mock.calls
|
||||
.map(([, event]) => event)
|
||||
.filter((event) => event.data?.result?.tool_call?.name?.startsWith('oauth'));
|
||||
|
||||
expect(runStepEvents.map((event) => event.data.index)).toEqual([0, 1]);
|
||||
expect(authDeltaEvents.map((event) => event.data.id)).toEqual([
|
||||
'step_oauth_login_ELI',
|
||||
'step_oauth_login_Vespa',
|
||||
]);
|
||||
expect(completionEvents.map((event) => event.data.result.id)).toEqual([
|
||||
'step_oauth_login_ELI',
|
||||
'step_oauth_login_Vespa',
|
||||
]);
|
||||
});
|
||||
|
||||
it('should not expose cached MCP tool definitions when the registry lookup fails', async () => {
|
||||
const serverName = 'private-server';
|
||||
const mcpTool = `search${Constants.mcp_delimiter}${serverName}`;
|
||||
|
|
|
|||
|
|
@ -234,6 +234,182 @@ describe('useStepHandler', () => {
|
|||
);
|
||||
});
|
||||
|
||||
it('should preserve multiple tool call steps for the same preliminary response', () => {
|
||||
const responseMessage = createResponseMessage();
|
||||
let currentMessages = [responseMessage];
|
||||
mockGetMessages.mockImplementation(() => currentMessages);
|
||||
mockSetMessages.mockImplementation((messages) => {
|
||||
currentMessages = messages;
|
||||
});
|
||||
|
||||
const { result } = renderHook(() => useStepHandler(createHookParams()));
|
||||
const submission = createSubmission({ initialResponse: responseMessage });
|
||||
|
||||
const firstRunStep = createToolCallRunStep({
|
||||
id: 'step-oauth-eli',
|
||||
runId: Constants.USE_PRELIM_RESPONSE_MESSAGE_ID,
|
||||
index: 0,
|
||||
stepDetails: {
|
||||
type: StepTypes.TOOL_CALLS,
|
||||
tool_calls: [
|
||||
{
|
||||
id: 'tool-call-eli',
|
||||
name: `oauth${Constants.mcp_delimiter}ELI`,
|
||||
args: '',
|
||||
type: ToolCallTypes.TOOL_CALL,
|
||||
},
|
||||
],
|
||||
},
|
||||
});
|
||||
const secondRunStep = createToolCallRunStep({
|
||||
id: 'step-oauth-vespa',
|
||||
runId: Constants.USE_PRELIM_RESPONSE_MESSAGE_ID,
|
||||
index: 1,
|
||||
stepDetails: {
|
||||
type: StepTypes.TOOL_CALLS,
|
||||
tool_calls: [
|
||||
{
|
||||
id: 'tool-call-vespa',
|
||||
name: `oauth${Constants.mcp_delimiter}Vespa`,
|
||||
args: '',
|
||||
type: ToolCallTypes.TOOL_CALL,
|
||||
},
|
||||
],
|
||||
},
|
||||
});
|
||||
|
||||
act(() => {
|
||||
result.current.stepHandler(
|
||||
{ event: StepEvents.ON_RUN_STEP, data: firstRunStep },
|
||||
submission,
|
||||
);
|
||||
result.current.stepHandler(
|
||||
{ event: StepEvents.ON_RUN_STEP, data: secondRunStep },
|
||||
submission,
|
||||
);
|
||||
});
|
||||
|
||||
const responseMsg = currentMessages.find((m) => !m.isCreatedByUser);
|
||||
expect(responseMsg?.content).toHaveLength(2);
|
||||
expect(responseMsg?.content?.[0]?.tool_call?.name).toBe(`oauth${Constants.mcp_delimiter}ELI`);
|
||||
expect(responseMsg?.content?.[1]?.tool_call?.name).toBe(
|
||||
`oauth${Constants.mcp_delimiter}Vespa`,
|
||||
);
|
||||
});
|
||||
|
||||
it('should clear OAuth prompt slots when one occupies the real response slot', () => {
|
||||
const responseMessage = createResponseMessage();
|
||||
let currentMessages = [responseMessage];
|
||||
mockGetMessages.mockImplementation(() => currentMessages);
|
||||
mockSetMessages.mockImplementation((messages) => {
|
||||
currentMessages = messages;
|
||||
});
|
||||
|
||||
const { result } = renderHook(() => useStepHandler(createHookParams()));
|
||||
const submission = createSubmission({ initialResponse: responseMessage });
|
||||
|
||||
act(() => {
|
||||
result.current.stepHandler(
|
||||
{
|
||||
event: StepEvents.ON_RUN_STEP,
|
||||
data: createToolCallRunStep({
|
||||
id: 'step-oauth-eli',
|
||||
runId: Constants.USE_PRELIM_RESPONSE_MESSAGE_ID,
|
||||
index: 0,
|
||||
stepDetails: {
|
||||
type: StepTypes.TOOL_CALLS,
|
||||
tool_calls: [
|
||||
{
|
||||
id: 'tool-call-eli',
|
||||
name: `oauth${Constants.mcp_delimiter}ELI`,
|
||||
args: '',
|
||||
type: ToolCallTypes.TOOL_CALL,
|
||||
},
|
||||
],
|
||||
},
|
||||
}),
|
||||
},
|
||||
submission,
|
||||
);
|
||||
result.current.stepHandler(
|
||||
{
|
||||
event: StepEvents.ON_RUN_STEP,
|
||||
data: createToolCallRunStep({
|
||||
id: 'step-oauth-vespa',
|
||||
runId: Constants.USE_PRELIM_RESPONSE_MESSAGE_ID,
|
||||
index: 1,
|
||||
stepDetails: {
|
||||
type: StepTypes.TOOL_CALLS,
|
||||
tool_calls: [
|
||||
{
|
||||
id: 'tool-call-vespa',
|
||||
name: `oauth${Constants.mcp_delimiter}Vespa`,
|
||||
args: '',
|
||||
type: ToolCallTypes.TOOL_CALL,
|
||||
},
|
||||
],
|
||||
},
|
||||
}),
|
||||
},
|
||||
submission,
|
||||
);
|
||||
result.current.stepHandler(
|
||||
{
|
||||
event: StepEvents.ON_RUN_STEP,
|
||||
data: createRunStep({
|
||||
id: 'step-message',
|
||||
runId: Constants.USE_PRELIM_RESPONSE_MESSAGE_ID,
|
||||
index: 0,
|
||||
}),
|
||||
},
|
||||
submission,
|
||||
);
|
||||
result.current.stepHandler(
|
||||
{ event: StepEvents.ON_MESSAGE_DELTA, data: createMessageDelta('step-message', 'Ready') },
|
||||
submission,
|
||||
);
|
||||
});
|
||||
|
||||
const responseMsg = currentMessages.find((m) => !m.isCreatedByUser);
|
||||
expect(responseMsg?.content).toEqual([{ type: ContentTypes.TEXT, text: 'Ready' }]);
|
||||
});
|
||||
|
||||
it('should not replace the message list from a shorter refresh during tool call steps', () => {
|
||||
const userMessage = createUserMessage();
|
||||
const responseMessage = createResponseMessage();
|
||||
mockGetMessages.mockReturnValueOnce([userMessage, responseMessage]).mockReturnValueOnce([]);
|
||||
|
||||
const { result } = renderHook(() => useStepHandler(createHookParams()));
|
||||
|
||||
const runStep = createToolCallRunStep({
|
||||
runId: responseMessage.messageId,
|
||||
stepDetails: {
|
||||
type: StepTypes.TOOL_CALLS,
|
||||
tool_calls: [
|
||||
{
|
||||
id: 'tool-call-eli',
|
||||
name: `oauth${Constants.mcp_delimiter}ELI`,
|
||||
args: '',
|
||||
type: ToolCallTypes.TOOL_CALL,
|
||||
},
|
||||
],
|
||||
},
|
||||
});
|
||||
|
||||
act(() => {
|
||||
result.current.stepHandler(
|
||||
{ event: StepEvents.ON_RUN_STEP, data: runStep },
|
||||
createSubmission({ userMessage, initialResponse: responseMessage }),
|
||||
);
|
||||
});
|
||||
|
||||
const lastCall = mockSetMessages.mock.calls[mockSetMessages.mock.calls.length - 1][0];
|
||||
expect(lastCall.map((message: TMessage) => message.messageId)).toEqual([
|
||||
userMessage.messageId,
|
||||
responseMessage.messageId,
|
||||
]);
|
||||
});
|
||||
|
||||
it('should replay buffered deltas after registering step', () => {
|
||||
const responseMessage = createResponseMessage();
|
||||
mockGetMessages.mockReturnValue([responseMessage]);
|
||||
|
|
@ -756,6 +932,59 @@ describe('useStepHandler', () => {
|
|||
);
|
||||
consoleSpy.mockRestore();
|
||||
});
|
||||
|
||||
it('should mark completed OAuth prompts as finished', () => {
|
||||
const responseMessage = createResponseMessage();
|
||||
mockGetMessages.mockReturnValue([responseMessage]);
|
||||
|
||||
const { result } = renderHook(() => useStepHandler(createHookParams()));
|
||||
|
||||
const runStep = createToolCallRunStep({
|
||||
id: 'step-oauth-eli',
|
||||
stepDetails: {
|
||||
type: StepTypes.TOOL_CALLS,
|
||||
tool_calls: [
|
||||
{
|
||||
id: 'tool-call-eli',
|
||||
name: `oauth${Constants.mcp_delimiter}ELI`,
|
||||
args: '',
|
||||
type: ToolCallTypes.TOOL_CALL,
|
||||
},
|
||||
],
|
||||
},
|
||||
});
|
||||
const submission = createSubmission();
|
||||
|
||||
act(() => {
|
||||
result.current.stepHandler({ event: StepEvents.ON_RUN_STEP, data: runStep }, submission);
|
||||
});
|
||||
|
||||
act(() => {
|
||||
result.current.stepHandler(
|
||||
{
|
||||
event: StepEvents.ON_RUN_STEP_COMPLETED,
|
||||
data: {
|
||||
result: {
|
||||
id: 'step-oauth-eli',
|
||||
index: 0,
|
||||
tool_call: {
|
||||
id: 'tool-call-eli',
|
||||
name: `oauth${Constants.mcp_delimiter}ELI`,
|
||||
args: '',
|
||||
output: 'OAuth authentication completed',
|
||||
type: ToolCallTypes.TOOL_CALL,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
submission,
|
||||
);
|
||||
});
|
||||
|
||||
const lastCall = mockSetMessages.mock.calls[mockSetMessages.mock.calls.length - 1][0];
|
||||
const responseMsg = lastCall.find((m: TMessage) => !m.isCreatedByUser);
|
||||
expect(responseMsg?.content?.[0]?.tool_call?.progress).toBe(1);
|
||||
});
|
||||
});
|
||||
|
||||
describe('clearStepMaps', () => {
|
||||
|
|
|
|||
|
|
@ -63,6 +63,14 @@ type AllContentTypes =
|
|||
| ContentTypes.SUMMARY
|
||||
| ContentTypes.ERROR;
|
||||
|
||||
const isOAuthToolCallName = (name?: string) =>
|
||||
typeof name === 'string' && name.startsWith(`oauth${Constants.mcp_delimiter}`);
|
||||
|
||||
const isOAuthToolCallContent = (part?: Partial<TMessageContentParts>) =>
|
||||
part?.type === ContentTypes.TOOL_CALL &&
|
||||
'tool_call' in part &&
|
||||
isOAuthToolCallName(part.tool_call?.name);
|
||||
|
||||
export default function useStepHandler({
|
||||
setMessages,
|
||||
getMessages,
|
||||
|
|
@ -100,6 +108,14 @@ export default function useStepHandler({
|
|||
*/
|
||||
const knownSubagentAtomKeys = useRef(new Set<string>());
|
||||
|
||||
const getCurrentMessages = useCallback(
|
||||
(messages: TMessage[]) => {
|
||||
const freshMessages = getMessages();
|
||||
return freshMessages && freshMessages.length >= messages.length ? freshMessages : messages;
|
||||
},
|
||||
[getMessages],
|
||||
);
|
||||
|
||||
/** Both content parts and ticker lines are aggregated incrementally
|
||||
* into the atom as each `ON_SUBAGENT_UPDATE` arrives — we never
|
||||
* retain the raw event array, so no rolling window is needed. A
|
||||
|
|
@ -270,9 +286,20 @@ export default function useStepHandler({
|
|||
return message;
|
||||
}
|
||||
|
||||
const updatedContent = [...(message.content || [])] as Array<
|
||||
const incomingOAuthToolCall =
|
||||
contentType === ContentTypes.TOOL_CALL &&
|
||||
'tool_call' in contentPart &&
|
||||
isOAuthToolCallName(contentPart.tool_call?.name);
|
||||
|
||||
let updatedContent = [...(message.content || [])] as Array<
|
||||
Partial<TMessageContentParts> | undefined
|
||||
>;
|
||||
|
||||
const oauthPromptOccupiesSlot = isOAuthToolCallContent(updatedContent[index]);
|
||||
if (!incomingOAuthToolCall && oauthPromptOccupiesSlot) {
|
||||
updatedContent = updatedContent.filter((part) => !isOAuthToolCallContent(part));
|
||||
}
|
||||
|
||||
if (!updatedContent[index] && contentType !== ContentTypes.TOOL_CALL) {
|
||||
updatedContent[index] = { type: contentPart.type as AllContentTypes };
|
||||
}
|
||||
|
|
@ -517,9 +544,15 @@ export default function useStepHandler({
|
|||
});
|
||||
|
||||
messageMap.current.set(responseMessageId, updatedResponse);
|
||||
const updatedMessages = messages.map((msg) =>
|
||||
msg.messageId === responseMessageId ? updatedResponse : msg,
|
||||
const currentMessages = getCurrentMessages(messages);
|
||||
const hasResponseMessage = currentMessages.some(
|
||||
(msg) => msg.messageId === responseMessageId,
|
||||
);
|
||||
const updatedMessages = hasResponseMessage
|
||||
? currentMessages.map((msg) =>
|
||||
msg.messageId === responseMessageId ? updatedResponse : msg,
|
||||
)
|
||||
: [...currentMessages, updatedResponse];
|
||||
|
||||
setMessages(updatedMessages);
|
||||
}
|
||||
|
|
@ -724,9 +757,15 @@ export default function useStepHandler({
|
|||
});
|
||||
|
||||
messageMap.current.set(responseMessageId, updatedResponse);
|
||||
const updatedMessages = messages.map((msg) =>
|
||||
msg.messageId === responseMessageId ? updatedResponse : msg,
|
||||
const currentMessages = getCurrentMessages(messages);
|
||||
const hasResponseMessage = currentMessages.some(
|
||||
(msg) => msg.messageId === responseMessageId,
|
||||
);
|
||||
const updatedMessages = hasResponseMessage
|
||||
? currentMessages.map((msg) =>
|
||||
msg.messageId === responseMessageId ? updatedResponse : msg,
|
||||
)
|
||||
: [...currentMessages, updatedResponse];
|
||||
|
||||
setMessages(updatedMessages);
|
||||
}
|
||||
|
|
@ -767,9 +806,15 @@ export default function useStepHandler({
|
|||
);
|
||||
|
||||
messageMap.current.set(responseMessageId, updatedResponse);
|
||||
const updatedMessages = messages.map((msg) =>
|
||||
msg.messageId === responseMessageId ? updatedResponse : msg,
|
||||
const currentMessages = getCurrentMessages(messages);
|
||||
const hasResponseMessage = currentMessages.some(
|
||||
(msg) => msg.messageId === responseMessageId,
|
||||
);
|
||||
const updatedMessages = hasResponseMessage
|
||||
? currentMessages.map((msg) =>
|
||||
msg.messageId === responseMessageId ? updatedResponse : msg,
|
||||
)
|
||||
: [...currentMessages, updatedResponse];
|
||||
|
||||
setMessages(updatedMessages);
|
||||
}
|
||||
|
|
@ -883,6 +928,7 @@ export default function useStepHandler({
|
|||
announcePolite,
|
||||
setMessages,
|
||||
calculateContentIndex,
|
||||
getCurrentMessages,
|
||||
applySubagentUpdate,
|
||||
],
|
||||
);
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue