diff --git a/api/app/clients/tools/util/handleTools.js b/api/app/clients/tools/util/handleTools.js index 4577d8a9ea..45623f9a9e 100644 --- a/api/app/clients/tools/util/handleTools.js +++ b/api/app/clients/tools/util/handleTools.js @@ -38,8 +38,8 @@ const { const { createMCPTool, createMCPTools, + createMCPPermissionContext, resolveConfigServers, - userCanUseMCPServers, } = require('~/server/services/MCP'); const { createFileSearchTool, primeFiles: primeSearchFiles } = require('./fileSearch'); const { primeFiles: primeCodeFiles } = require('~/server/services/Files/Code/process'); @@ -233,7 +233,11 @@ const loadTools = async ({ const requestedTools = {}; const hasMCPTools = tools.some((toolName) => toolName && mcpToolPattern.test(toolName)); - const canUseMCP = hasMCPTools ? await userCanUseMCPServers(options.req?.user) : true; + const mcpPermissionContext = + options.mcpPermissionContext ?? createMCPPermissionContext(options.req); + const canUseMCP = hasMCPTools + ? await mcpPermissionContext.canUseServers(options.req?.user) + : true; let loggedMCPDenied = false; if (functions === true) { @@ -458,6 +462,7 @@ const loadTools = async ({ continue; } const mcpParams = { + mcpPermissionContext, index, signal, user: safeUser, diff --git a/api/server/controllers/agents/filterAuthorizedTools.spec.js b/api/server/controllers/agents/filterAuthorizedTools.spec.js index 4bdb870524..89835fac06 100644 --- a/api/server/controllers/agents/filterAuthorizedTools.spec.js +++ b/api/server/controllers/agents/filterAuthorizedTools.spec.js @@ -25,6 +25,9 @@ jest.mock('~/config', () => ({ jest.mock('~/server/services/MCP', () => ({ resolveConfigServers: jest.fn().mockResolvedValue({}), + createMCPPermissionContext: jest.fn((req) => ({ + canUseServers: (user) => mockUserCanUseMCPServers(user, req), + })), userCanUseMCPServers: (...args) => mockUserCanUseMCPServers(...args), })); diff --git a/api/server/controllers/agents/v1.js b/api/server/controllers/agents/v1.js index 79ebbbd2ae..adb78d1b56 100644 --- a/api/server/controllers/agents/v1.js +++ b/api/server/controllers/agents/v1.js @@ -43,7 +43,11 @@ const { resizeAvatar } = require('~/server/services/Files/images/avatar'); const { getFileStrategy } = require('~/server/utils/getFileStrategy'); const { filterFile } = require('~/server/services/Files/process'); const { getCachedTools } = require('~/server/services/Config'); -const { resolveConfigServers, userCanUseMCPServers } = require('~/server/services/MCP'); +const { + createMCPPermissionContext, + resolveConfigServers, + userCanUseMCPServers, +} = require('~/server/services/MCP'); const { getMCPServersRegistry } = require('~/config'); const { getLogStores } = require('~/cache'); const db = require('~/models'); @@ -192,6 +196,7 @@ const isSubagentsCapabilityEnabled = (req) => { * @param {string} params.userId - Requesting user ID for MCP server access check * @param {string} [params.role] - Requesting user's role for ACL principal resolution * @param {object} [params.user] - Requesting user for MCP server use permission checks + * @param {{ canUseServers: (user?: object) => Promise }} [params.mcpPermissionContext] - Request-scoped MCP permission context * @param {Record} params.availableTools - Global non-MCP tool cache * @param {string[]} [params.existingTools] - Tools already persisted on the agent document * @param {Record} [params.configServers] - Config-source MCP servers resolved from appConfig overrides @@ -202,6 +207,7 @@ const filterAuthorizedTools = async ({ userId, role, user, + mcpPermissionContext, availableTools, existingTools, configServers, @@ -211,7 +217,11 @@ const filterAuthorizedTools = async ({ let registryUnavailable = false; const existingToolSet = existingTools?.length ? new Set(existingTools) : null; const hasMCPTools = tools.some((tool) => tool?.includes(Constants.mcp_delimiter)); - const canUseMCP = hasMCPTools ? await userCanUseMCPServers(user) : true; + const canUseMCP = hasMCPTools + ? await (mcpPermissionContext + ? mcpPermissionContext.canUseServers(user) + : userCanUseMCPServers(user)) + : true; let loggedMCPDenied = false; for (const tool of tools) { @@ -403,11 +413,13 @@ const createAgentHandler = async (req, res) => { getCachedTools().then((t) => t ?? {}), hasMCPTools ? resolveConfigServers(req) : Promise.resolve(undefined), ]); + const mcpPermissionContext = createMCPPermissionContext(req); agentData.tools = await filterAuthorizedTools({ tools, userId, role: req.user.role, user: req.user, + mcpPermissionContext, availableTools, configServers, }); @@ -637,7 +649,8 @@ const updateAgentHandler = async (req, res) => { const existingMCPTools = existingTools.filter(isMCPTool); if (requestedMCPTools.length > 0 || (hasToolUpdate && existingMCPTools.length > 0)) { - if (!(await userCanUseMCPServers(req.user))) { + const mcpPermissionContext = createMCPPermissionContext(req); + if (!(await mcpPermissionContext.canUseServers(req.user))) { if (editingOwnAgent) { updateData.tools = effectiveTools.filter((t) => !isMCPTool(t)); } else if (hasToolUpdate) { @@ -667,6 +680,7 @@ const updateAgentHandler = async (req, res) => { userId: req.user.id, role: req.user.role, user: req.user, + mcpPermissionContext, availableTools, configServers, }); @@ -826,11 +840,13 @@ const duplicateAgentHandler = async (req, res) => { getCachedTools().then((t) => t ?? {}), resolveConfigServers(req), ]); + const mcpPermissionContext = createMCPPermissionContext(req); newAgentData.tools = await filterAuthorizedTools({ tools: newAgentData.tools, userId, role: req.user.role, user: req.user, + mcpPermissionContext, availableTools, existingTools: newAgentData.tools, configServers, @@ -1202,11 +1218,13 @@ const revertAgentVersionHandler = async (req, res) => { getCachedTools().then((t) => t ?? {}), resolveConfigServers(req), ]); + const mcpPermissionContext = createMCPPermissionContext(req); const filteredTools = await filterAuthorizedTools({ tools: updatedAgent.tools, userId: req.user.id, role: req.user.role, user: req.user, + mcpPermissionContext, availableTools, existingTools: updatedAgent.tools, configServers, diff --git a/api/server/services/MCP.js b/api/server/services/MCP.js index 2d81d66070..e8f6820d6e 100644 --- a/api/server/services/MCP.js +++ b/api/server/services/MCP.js @@ -7,7 +7,6 @@ const { Constants: AgentConstants, } = require('@librechat/agents'); const { - checkAccess, sendEvent, MCPOAuthHandler, isMCPDomainAllowed, @@ -16,6 +15,7 @@ const { GenerationJobManager, resolveJsonSchemaRefs, buildOAuthToolCallName, + checkAccessWithRequestCache, } = require('@librechat/api'); const { Time, @@ -45,13 +45,14 @@ const RECONNECT_THROTTLE_MS = 10_000; const missingToolCache = new Map(); const MISSING_TOOL_TTL_MS = 10_000; -async function userCanUseMCPServers(user) { +async function userCanUseMCPServers(user, req) { if (!user?.id || !user?.role) { return false; } try { - return await checkAccess({ + return await checkAccessWithRequestCache({ + req, user, permissionType: PermissionTypes.MCP_SERVERS, permissions: [Permissions.USE], @@ -63,6 +64,12 @@ async function userCanUseMCPServers(user) { } } +function createMCPPermissionContext(req) { + return { + canUseServers: (user = req?.user) => userCanUseMCPServers(user, req), + }; +} + function evictStale(map, ttl) { if (map.size <= MAX_CACHE_SIZE) { return; @@ -436,6 +443,7 @@ async function reconnectServer({ * * @param {Object} params * @param {ServerResponse} params.res - The Express response object for sending events. + * @param {{ canUseServers: (user?: IUser) => Promise }} [params.mcpPermissionContext] - Request-scoped MCP permission context. * @param {IUser} params.user - The user from the request object. * @param {string} params.serverName * @param {string} params.model @@ -449,6 +457,7 @@ async function reconnectServer({ */ async function createMCPTools({ res, + mcpPermissionContext, user, index, signal, @@ -503,6 +512,7 @@ async function createMCPTools({ for (const tool of result.tools) { const toolInstance = await createMCPTool({ res, + mcpPermissionContext, user, provider, userMCPAuthMap, @@ -524,6 +534,7 @@ async function createMCPTools({ * Creates a single tool from the specified MCP Server via `toolKey`. * @param {Object} params * @param {ServerResponse} params.res - The Express response object for sending events. + * @param {{ canUseServers: (user?: IUser) => Promise }} [params.mcpPermissionContext] - Request-scoped MCP permission context. * @param {IUser} params.user - The user from the request object. * @param {string} params.toolKey - The toolKey for the tool. * @param {string} params.model - The model for the tool. @@ -538,6 +549,7 @@ async function createMCPTools({ */ async function createMCPTool({ res, + mcpPermissionContext, user, index, signal, @@ -613,6 +625,7 @@ async function createMCPTool({ return createToolInstance({ res, + mcpPermissionContext, user, provider, toolName, @@ -625,6 +638,7 @@ async function createMCPTool({ function createToolInstance({ res, + mcpPermissionContext, user: capturedUser = null, toolName, serverName, @@ -663,7 +677,9 @@ function createToolInstance({ try { const provider = (config?.metadata?.provider || capturedProvider)?.toLowerCase(); - const canUseMCP = await userCanUseMCPServers(permissionUser); + const canUseMCP = mcpPermissionContext + ? await mcpPermissionContext.canUseServers(permissionUser) + : await userCanUseMCPServers(permissionUser); if (!canUseMCP) { throw new Error('Forbidden: Insufficient MCP server permissions'); } @@ -923,6 +939,7 @@ async function getServerConnectionStatus( module.exports = { createMCPTool, createMCPTools, + createMCPPermissionContext, userCanUseMCPServers, getMCPSetupData, resolveConfigServers, diff --git a/api/server/services/MCP.spec.js b/api/server/services/MCP.spec.js index 31d8055239..38fbb192a4 100644 --- a/api/server/services/MCP.spec.js +++ b/api/server/services/MCP.spec.js @@ -43,6 +43,7 @@ const D = Constants.mcp_delimiter; const { createMCPTool, createMCPTools, + createMCPPermissionContext, getMCPSetupData, checkOAuthFlowStatus, getServerConnectionStatus, @@ -864,6 +865,78 @@ describe('User parameter passing tests', () => { ); expect(mockGetMCPManager).not.toHaveBeenCalled(); }); + + it('should reuse request-scoped MCP permission checks across tool executions', async () => { + const mockUser = { id: 'mcp-allowed-user', role: 'USER' }; + const mockReq = { user: mockUser }; + const mockRes = { write: jest.fn(), flush: jest.fn() }; + const { getRoleByName } = require('~/models'); + getRoleByName.mockResolvedValue({ + permissions: { + [PermissionTypes.MCP_SERVERS]: { + [Permissions.USE]: true, + }, + }, + }); + + const mockCallTool = jest.fn().mockResolvedValue(['ok', null]); + mockGetMCPManager.mockReturnValue({ + callTool: mockCallTool, + }); + + const availableTools = { + [`search${D}test-server`]: { + function: { + description: 'Search tool', + parameters: { type: 'object', properties: {} }, + }, + }, + [`fetch${D}test-server`]: { + function: { + description: 'Fetch tool', + parameters: { type: 'object', properties: {} }, + }, + }, + }; + const mcpPermissionContext = createMCPPermissionContext(mockReq); + + const searchTool = await createMCPTool({ + mcpPermissionContext, + res: mockRes, + user: mockUser, + toolKey: `search${D}test-server`, + provider: 'openai', + userMCPAuthMap: {}, + availableTools, + }); + const fetchTool = await createMCPTool({ + mcpPermissionContext, + res: mockRes, + user: mockUser, + toolKey: `fetch${D}test-server`, + provider: 'openai', + userMCPAuthMap: {}, + availableTools, + }); + + const invocationConfig = { + configurable: { + user: mockUser, + }, + metadata: { + provider: 'openai', + thread_id: 'thread-1', + run_id: 'run-1', + }, + toolCall: {}, + }; + + await expect(searchTool.invoke({}, invocationConfig)).resolves.toBe('ok'); + await expect(fetchTool.invoke({}, invocationConfig)).resolves.toBe('ok'); + + expect(getRoleByName).toHaveBeenCalledTimes(1); + expect(mockCallTool).toHaveBeenCalledTimes(2); + }); }); describe('reinitMCPServer (via reconnectServer)', () => { diff --git a/api/server/services/ToolService.js b/api/server/services/ToolService.js index 261ac42547..67b68bdad1 100644 --- a/api/server/services/ToolService.js +++ b/api/server/services/ToolService.js @@ -63,7 +63,7 @@ const { primeFiles: primeCodeFiles } = require('~/server/services/Files/Code/pro const { manifestToolMap, toolkits } = require('~/app/clients/tools/manifest'); const { createOnSearchResults } = require('~/server/services/Tools/search'); const { reinitMCPServer } = require('~/server/services/Tools/mcp'); -const { resolveConfigServers, userCanUseMCPServers } = require('~/server/services/MCP'); +const { createMCPPermissionContext, resolveConfigServers } = require('~/server/services/MCP'); const { recordUsage } = require('~/server/services/Threads'); const { loadTools } = require('~/app/clients/tools/util'); const { redactMessage } = require('~/config/parsers'); @@ -552,7 +552,8 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to agent.tools?.includes(Tools.execute_code) === true && enabledCapabilities.has(AgentCapabilities.execute_code); const hasMCPTools = agent.tools?.some((tool) => tool?.includes(Constants.mcp_delimiter)); - const canUseMCP = hasMCPTools ? await userCanUseMCPServers(req.user) : true; + const mcpPermissionContext = createMCPPermissionContext(req); + const canUseMCP = hasMCPTools ? await mcpPermissionContext.canUseServers(req.user) : true; const filteredTools = agent.tools?.filter((tool) => { if (tool === Tools.file_search) { @@ -989,7 +990,8 @@ async function loadAgentTools({ const areToolsEnabled = checkCapability(AgentCapabilities.tools); const actionsEnabled = checkCapability(AgentCapabilities.actions); const hasMCPTools = agent.tools?.some((tool) => tool?.includes(Constants.mcp_delimiter)); - const canUseMCP = hasMCPTools ? await userCanUseMCPServers(req.user) : true; + const mcpPermissionContext = createMCPPermissionContext(req); + const canUseMCP = hasMCPTools ? await mcpPermissionContext.canUseServers(req.user) : true; let includesWebSearch = false; const _agentTools = agent.tools?.filter((tool) => { @@ -1044,6 +1046,7 @@ async function loadAgentTools({ processFileURL, uploadImageBuffer, returnMetadata: true, + mcpPermissionContext, [Tools.web_search]: webSearchCallbacks, }, webSearch: appConfig.webSearch, diff --git a/api/server/services/__tests__/ToolService.spec.js b/api/server/services/__tests__/ToolService.spec.js index 451120eaa3..3fc9e58eb6 100644 --- a/api/server/services/__tests__/ToolService.spec.js +++ b/api/server/services/__tests__/ToolService.spec.js @@ -37,6 +37,7 @@ const mockDecryptMetadata = jest.fn(); const mockCreateActionTool = jest.fn(); const mockGetServerConfig = jest.fn(); const mockResolveConfigServers = jest.fn(); +const mockUserCanUseMCPServers = jest.fn().mockResolvedValue(true); jest.mock('~/server/services/Tools/credentials', () => ({ loadAuthValues: jest.fn().mockResolvedValue({}), })); @@ -77,7 +78,10 @@ jest.mock('~/config', () => ({ })); jest.mock('~/server/services/MCP', () => ({ resolveConfigServers: (...args) => mockResolveConfigServers(...args), - userCanUseMCPServers: jest.fn().mockResolvedValue(true), + createMCPPermissionContext: jest.fn((req) => ({ + canUseServers: (user) => mockUserCanUseMCPServers(user, req), + })), + userCanUseMCPServers: mockUserCanUseMCPServers, })); jest.mock('~/cache', () => ({ getLogStores: jest.fn(() => ({})), diff --git a/packages/api/src/middleware/access.spec.ts b/packages/api/src/middleware/access.spec.ts index 99257adf6d..c77508b153 100644 --- a/packages/api/src/middleware/access.spec.ts +++ b/packages/api/src/middleware/access.spec.ts @@ -6,7 +6,12 @@ import { EndpointURLs, } from 'librechat-data-provider'; import type { IRole, IUser } from '@librechat/data-schemas'; -import { checkAccess, generateCheckAccess, skipAgentCheck } from './access'; +import { + checkAccess, + checkAccessWithRequestCache, + generateCheckAccess, + skipAgentCheck, +} from './access'; // Mock logger jest.mock('@librechat/data-schemas', () => ({ @@ -269,6 +274,100 @@ describe('access middleware', () => { }); }); + describe('checkAccessWithRequestCache', () => { + const defaultParams = { + user: { + id: 'user123', + role: 'user', + email: 'test@example.com', + emailVerified: true, + provider: 'local', + } as IUser, + permissionType: PermissionTypes.MCP_SERVERS, + permissions: [Permissions.USE], + getRoleByName: jest.fn(), + }; + + const allowedRole = { + name: 'user', + permissions: { + [PermissionTypes.MCP_SERVERS]: { + [Permissions.USE]: true, + [Permissions.CREATE]: true, + }, + }, + } as unknown as IRole; + + it('should memoize permission checks for the same request', async () => { + defaultParams.getRoleByName.mockResolvedValue(allowedRole); + + const params = { + ...defaultParams, + req: mockReq as Request, + }; + + await expect(checkAccessWithRequestCache(params)).resolves.toBe(true); + await expect(checkAccessWithRequestCache(params)).resolves.toBe(true); + + expect(defaultParams.getRoleByName).toHaveBeenCalledTimes(1); + }); + + it('should share an in-flight permission check for the same request', async () => { + defaultParams.getRoleByName.mockResolvedValue(allowedRole); + + const params = { + ...defaultParams, + req: mockReq as Request, + }; + + await expect( + Promise.all([checkAccessWithRequestCache(params), checkAccessWithRequestCache(params)]), + ).resolves.toEqual([true, true]); + + expect(defaultParams.getRoleByName).toHaveBeenCalledTimes(1); + }); + + it('should isolate memoized checks between requests', async () => { + defaultParams.getRoleByName.mockResolvedValue(allowedRole); + + await expect( + checkAccessWithRequestCache({ + ...defaultParams, + req: mockReq as Request, + }), + ).resolves.toBe(true); + await expect( + checkAccessWithRequestCache({ + ...defaultParams, + req: { ...mockReq } as Request, + }), + ).resolves.toBe(true); + + expect(defaultParams.getRoleByName).toHaveBeenCalledTimes(2); + }); + + it('should use separate cache entries for different permissions', async () => { + defaultParams.getRoleByName.mockResolvedValue(allowedRole); + + await expect( + checkAccessWithRequestCache({ + ...defaultParams, + req: mockReq as Request, + permissions: [Permissions.USE], + }), + ).resolves.toBe(true); + await expect( + checkAccessWithRequestCache({ + ...defaultParams, + req: mockReq as Request, + permissions: [Permissions.CREATE], + }), + ).resolves.toBe(true); + + expect(defaultParams.getRoleByName).toHaveBeenCalledTimes(2); + }); + }); + describe('generateCheckAccess', () => { it('should create middleware that allows access when user has permissions', async () => { const mockRole = { diff --git a/packages/api/src/middleware/access.ts b/packages/api/src/middleware/access.ts index 8b3d83d037..3b90d49d40 100644 --- a/packages/api/src/middleware/access.ts +++ b/packages/api/src/middleware/access.ts @@ -24,6 +24,54 @@ export function skipAgentCheck(req?: ServerRequest): boolean { return !isAgentsEndpoint(req.body.endpoint); } +export interface CheckAccessParams { + user: IUser; + req?: ServerRequest; + permissionType: PermissionTypes; + permissions: Permissions[]; + bodyProps?: Record; + checkObject?: object; + /** If skipCheck function is provided and returns true, skip permission checking */ + skipCheck?: (req?: ServerRequest) => boolean; + getRoleByName: (roleName: string, fieldsToSelect?: string | string[]) => Promise; +} + +export type CheckAccessWithRequestCacheParams = Omit< + CheckAccessParams, + 'bodyProps' | 'checkObject' | 'skipCheck' +>; + +type RequestPermissionCache = Map>; + +const requestPermissionCacheKey = '__librechatRequestPermissionCache'; + +function getRequestPermissionCache(req?: ServerRequest): RequestPermissionCache | null { + if (!req) { + return null; + } + + const reqWithCache = req as ServerRequest & { + [requestPermissionCacheKey]?: RequestPermissionCache; + }; + + if (!reqWithCache[requestPermissionCacheKey]) { + Object.defineProperty(reqWithCache, requestPermissionCacheKey, { + value: new Map>(), + enumerable: false, + }); + } + + return reqWithCache[requestPermissionCacheKey] ?? null; +} + +function getRequestPermissionCacheKey({ + user, + permissionType, + permissions, +}: CheckAccessWithRequestCacheParams): string { + return [permissionType, [...permissions].sort().join(','), user.id, user.role].join(':'); +} + /** * Core function to check if a user has one or more required permissions * @param user - The user object @@ -43,17 +91,7 @@ export const checkAccess = async ({ bodyProps = {} as Record, checkObject = {}, skipCheck, -}: { - user: IUser; - req?: ServerRequest; - permissionType: PermissionTypes; - permissions: Permissions[]; - bodyProps?: Record; - checkObject?: object; - /** If skipCheck function is provided and returns true, skip permission checking */ - skipCheck?: (req?: ServerRequest) => boolean; - getRoleByName: (roleName: string, fieldsToSelect?: string | string[]) => Promise; -}): Promise => { +}: CheckAccessParams): Promise => { if (skipCheck && skipCheck(req)) { return true; } @@ -85,6 +123,35 @@ export const checkAccess = async ({ return false; }; +/** + * Checks simple role permissions using a per-request promise cache. + * Use this only for checks whose result is fully described by user, role, permission type, and permissions. + */ +export const checkAccessWithRequestCache = async ( + params: CheckAccessWithRequestCacheParams, +): Promise => { + if (!params.req || !params.user?.id || !params.user?.role) { + return await checkAccess(params); + } + + const cache = getRequestPermissionCache(params.req); + if (!cache) { + return await checkAccess(params); + } + + const cacheKey = getRequestPermissionCacheKey(params); + let cachedCheck = cache.get(cacheKey); + if (!cachedCheck) { + cachedCheck = checkAccess(params).catch((error) => { + cache.delete(cacheKey); + throw error; + }); + cache.set(cacheKey, cachedCheck); + } + + return await cachedCheck; +}; + /** * Middleware to check if a user has one or more required permissions, optionally based on `req.body` properties. * @param permissionType - The type of permission to check.