refactor: Fast-Fail MCP Tool Discovery on 401 for Non-OAuth Servers (#12395)

* fix: fast-fail MCP discovery for non-OAuth servers on auth errors

Always attach oauthHandler in discoverToolsInternal regardless of
useOAuth flag. Previously, non-OAuth servers hitting 401 would hang
for 30s because connectClient's oauthHandledPromise had no listener
to emit oauthFailed, waiting until withTimeout killed it.

* chore: import order
This commit is contained in:
Danny Avila 2026-03-25 13:18:02 -04:00 committed by GitHub
parent 3f805d68a1
commit 221e49222d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 80 additions and 17 deletions

View file

@ -81,7 +81,7 @@ export class MCPConnectionFactory {
useSSRFProtection: this.useSSRFProtection, useSSRFProtection: this.useSSRFProtection,
}); });
const oauthHandler = async () => { const oauthHandler = () => {
logger.info( logger.info(
`${this.logPrefix} [Discovery] OAuth required; skipping URL generation in discovery mode`, `${this.logPrefix} [Discovery] OAuth required; skipping URL generation in discovery mode`,
); );
@ -89,9 +89,9 @@ export class MCPConnectionFactory {
connection.emit('oauthFailed', new Error('OAuth required during tool discovery')); connection.emit('oauthFailed', new Error('OAuth required during tool discovery'));
}; };
if (this.useOAuth) { // Register unconditionally: non-OAuth servers that return 401 also emit 'oauthRequired',
connection.on('oauthRequired', oauthHandler); // and without this listener, connectClient()'s oauthHandledPromise hangs for 30s+.
} connection.once('oauthRequired', oauthHandler);
try { try {
const connectTimeout = this.connectionTimeout ?? this.serverConfig.initTimeout ?? 30000; const connectTimeout = this.connectionTimeout ?? this.serverConfig.initTimeout ?? 30000;
@ -103,9 +103,7 @@ export class MCPConnectionFactory {
if (await connection.isConnected()) { if (await connection.isConnected()) {
const tools = await connection.fetchTools(); const tools = await connection.fetchTools();
if (this.useOAuth) { connection.removeListener('oauthRequired', oauthHandler);
connection.removeListener('oauthRequired', oauthHandler);
}
return { tools, connection, oauthRequired: false, oauthUrl: null }; return { tools, connection, oauthRequired: false, oauthUrl: null };
} }
} catch { } catch {
@ -117,9 +115,7 @@ export class MCPConnectionFactory {
try { try {
const tools = await this.attemptUnauthenticatedToolListing(); const tools = await this.attemptUnauthenticatedToolListing();
if (this.useOAuth) { connection.removeListener('oauthRequired', oauthHandler);
connection.removeListener('oauthRequired', oauthHandler);
}
if (tools && tools.length > 0) { if (tools && tools.length > 0) {
logger.info( logger.info(
`${this.logPrefix} [Discovery] Successfully discovered ${tools.length} tools without auth`, `${this.logPrefix} [Discovery] Successfully discovered ${tools.length} tools without auth`,
@ -137,9 +133,7 @@ export class MCPConnectionFactory {
logger.debug(`${this.logPrefix} [Discovery] Unauthenticated tool listing failed:`, listError); logger.debug(`${this.logPrefix} [Discovery] Unauthenticated tool listing failed:`, listError);
} }
if (this.useOAuth) { connection.removeListener('oauthRequired', oauthHandler);
connection.removeListener('oauthRequired', oauthHandler);
}
try { try {
await connection.disconnect(); await connection.disconnect();

View file

@ -825,17 +825,17 @@ describe('MCPConnectionFactory', () => {
mockConnectionInstance.isConnected.mockResolvedValue(false); mockConnectionInstance.isConnected.mockResolvedValue(false);
mockConnectionInstance.disconnect = jest.fn().mockResolvedValue(undefined); mockConnectionInstance.disconnect = jest.fn().mockResolvedValue(undefined);
let oauthHandler: (() => Promise<void>) | undefined; let oauthHandler: (() => void) | undefined;
mockConnectionInstance.on.mockImplementation((event, handler) => { mockConnectionInstance.once.mockImplementation((event, handler) => {
if (event === 'oauthRequired') { if (event === 'oauthRequired') {
oauthHandler = handler as () => Promise<void>; oauthHandler = handler as () => void;
} }
return mockConnectionInstance; return mockConnectionInstance;
}); });
mockConnectionInstance.connect.mockImplementation(async () => { mockConnectionInstance.connect.mockImplementation(async () => {
if (oauthHandler) { if (oauthHandler) {
await oauthHandler(); oauthHandler();
} }
throw new Error('OAuth required'); throw new Error('OAuth required');
}); });
@ -849,6 +849,46 @@ describe('MCPConnectionFactory', () => {
expect(mockOAuthStart).not.toHaveBeenCalled(); expect(mockOAuthStart).not.toHaveBeenCalled();
}); });
it('should fast-fail discovery when non-OAuth server returns 401', async () => {
const basicOptions = {
serverName: 'github',
serverConfig: {
...mockServerConfig,
url: 'https://api.githubcopilot.com/mcp/',
type: 'streamable-http' as const,
initTimeout: 30000,
} as t.StreamableHTTPOptions,
};
mockConnectionInstance.isConnected.mockResolvedValue(false);
mockConnectionInstance.disconnect = jest.fn().mockResolvedValue(undefined);
let oauthHandler: (() => void) | undefined;
mockConnectionInstance.once.mockImplementation((event, handler) => {
if (event === 'oauthRequired') {
oauthHandler = handler as () => void;
}
return mockConnectionInstance;
});
mockConnectionInstance.connect.mockImplementation(async () => {
if (oauthHandler) {
oauthHandler();
}
throw Object.assign(new Error('unauthorized'), { code: 401 });
});
const start = Date.now();
const result = await MCPConnectionFactory.discoverTools(basicOptions);
const elapsed = Date.now() - start;
expect(elapsed).toBeLessThan(5000);
expect(result.tools).toBeNull();
expect(result.oauthRequired).toBe(true);
expect(result.oauthUrl).toBeNull();
expect(result.connection).toBeNull();
});
it('should return null tools when discovery fails completely', async () => { it('should return null tools when discovery fails completely', async () => {
const basicOptions = { const basicOptions = {
serverName: 'test-server', serverName: 'test-server',

View file

@ -6,8 +6,10 @@
*/ */
import { MCPConnection } from '~/mcp/connection'; import { MCPConnection } from '~/mcp/connection';
import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory';
import { createOAuthMCPServer } from './helpers/oauthTestServer'; import { createOAuthMCPServer } from './helpers/oauthTestServer';
import type { OAuthTestServer } from './helpers/oauthTestServer'; import type { OAuthTestServer } from './helpers/oauthTestServer';
import type { StreamableHTTPOptions } from '~/mcp/types';
import type { MCPOAuthTokens } from '~/mcp/oauth'; import type { MCPOAuthTokens } from '~/mcp/oauth';
jest.mock('@librechat/data-schemas', () => ({ jest.mock('@librechat/data-schemas', () => ({
@ -265,4 +267,31 @@ describe('MCPConnection OAuth Events — Real Server', () => {
expect(await connection.isConnected()).toBe(true); expect(await connection.isConnected()).toBe(true);
}); });
}); });
describe('MCPConnectionFactory.discoverTools — non-OAuth 401 fast-fail', () => {
beforeEach(async () => {
server = await createOAuthMCPServer({ tokenTTLMs: 60000 });
});
it('should fast-fail when a non-OAuth discovery hits 401', async () => {
const basicOptions = {
serverName: 'test-server',
serverConfig: {
type: 'streamable-http',
url: server.url,
initTimeout: 15000,
} as StreamableHTTPOptions,
};
const start = Date.now();
const result = await MCPConnectionFactory.discoverTools(basicOptions);
const elapsed = Date.now() - start;
expect(elapsed).toBeLessThan(5000);
expect(result.tools).toBeNull();
expect(result.oauthRequired).toBe(true);
expect(result.oauthUrl).toBeNull();
expect(result.connection).toBeNull();
});
});
}); });