diff --git a/packages/api/src/mcp/MCPConnectionFactory.ts b/packages/api/src/mcp/MCPConnectionFactory.ts index 2c63b81294..8b8fa17503 100644 --- a/packages/api/src/mcp/MCPConnectionFactory.ts +++ b/packages/api/src/mcp/MCPConnectionFactory.ts @@ -74,9 +74,23 @@ export class MCPConnectionFactory { oauthTokens, }); - if (this.useOAuth) this.handleOAuthEvents(connection); - await this.attemptToConnect(connection); - return connection; + let cleanupOAuthHandlers: (() => void) | null = null; + if (this.useOAuth) { + cleanupOAuthHandlers = this.handleOAuthEvents(connection); + } + + try { + await this.attemptToConnect(connection); + if (cleanupOAuthHandlers) { + cleanupOAuthHandlers(); + } + return connection; + } catch (error) { + if (cleanupOAuthHandlers) { + cleanupOAuthHandlers(); + } + throw error; + } } /** Retrieves existing OAuth tokens from storage or returns null */ @@ -133,8 +147,8 @@ export class MCPConnectionFactory { } /** Sets up OAuth event handlers for the connection */ - protected handleOAuthEvents(connection: MCPConnection): void { - connection.on('oauthRequired', async (data) => { + protected handleOAuthEvents(connection: MCPConnection): () => void { + const oauthHandler = async (data: { serverUrl?: string }) => { logger.info(`${this.logPrefix} oauthRequired event received`); // If we just want to initiate OAuth and return, handle it differently @@ -202,7 +216,23 @@ export class MCPConnectionFactory { logger.warn(`${this.logPrefix} OAuth failed, emitting oauthFailed event`); connection.emit('oauthFailed', new Error('OAuth authentication failed')); } - }); + }; + + connection.on('oauthRequired', oauthHandler); + + /** Handler reference for cleanup when connection state changes to disconnected */ + const cleanupHandler = (state: string) => { + if (state === 'disconnected') { + connection.removeListener('oauthRequired', oauthHandler); + connection.removeListener('connectionChange', cleanupHandler); + } + }; + connection.on('connectionChange', cleanupHandler); + + return () => { + connection.removeListener('oauthRequired', oauthHandler); + connection.removeListener('connectionChange', cleanupHandler); + }; } /** Attempts to establish connection with timeout handling */ diff --git a/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts b/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts index c739b8ffce..e96d207f29 100644 --- a/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts +++ b/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts @@ -56,6 +56,9 @@ describe('MCPConnectionFactory', () => { isConnected: jest.fn(), setOAuthTokens: jest.fn(), on: jest.fn().mockReturnValue(mockConnectionInstance), + once: jest.fn().mockReturnValue(mockConnectionInstance), + off: jest.fn().mockReturnValue(mockConnectionInstance), + removeListener: jest.fn().mockReturnValue(mockConnectionInstance), emit: jest.fn(), } as unknown as jest.Mocked;