mirror of
https://github.com/danny-avila/LibreChat.git
synced 2026-06-09 17:31:19 +00:00
🧠 refactor: Memoize MCP Permission Checks Per Request (#13419)
This commit is contained in:
parent
100871c3ec
commit
479e9d59b7
9 changed files with 314 additions and 25 deletions
|
|
@ -38,8 +38,8 @@ const {
|
|||
const {
|
||||
createMCPTool,
|
||||
createMCPTools,
|
||||
createMCPPermissionContext,
|
||||
resolveConfigServers,
|
||||
userCanUseMCPServers,
|
||||
} = require('~/server/services/MCP');
|
||||
const { createFileSearchTool, primeFiles: primeSearchFiles } = require('./fileSearch');
|
||||
const { primeFiles: primeCodeFiles } = require('~/server/services/Files/Code/process');
|
||||
|
|
@ -233,7 +233,11 @@ const loadTools = async ({
|
|||
|
||||
const requestedTools = {};
|
||||
const hasMCPTools = tools.some((toolName) => toolName && mcpToolPattern.test(toolName));
|
||||
const canUseMCP = hasMCPTools ? await userCanUseMCPServers(options.req?.user) : true;
|
||||
const mcpPermissionContext =
|
||||
options.mcpPermissionContext ?? createMCPPermissionContext(options.req);
|
||||
const canUseMCP = hasMCPTools
|
||||
? await mcpPermissionContext.canUseServers(options.req?.user)
|
||||
: true;
|
||||
let loggedMCPDenied = false;
|
||||
|
||||
if (functions === true) {
|
||||
|
|
@ -458,6 +462,7 @@ const loadTools = async ({
|
|||
continue;
|
||||
}
|
||||
const mcpParams = {
|
||||
mcpPermissionContext,
|
||||
index,
|
||||
signal,
|
||||
user: safeUser,
|
||||
|
|
|
|||
|
|
@ -25,6 +25,9 @@ jest.mock('~/config', () => ({
|
|||
|
||||
jest.mock('~/server/services/MCP', () => ({
|
||||
resolveConfigServers: jest.fn().mockResolvedValue({}),
|
||||
createMCPPermissionContext: jest.fn((req) => ({
|
||||
canUseServers: (user) => mockUserCanUseMCPServers(user, req),
|
||||
})),
|
||||
userCanUseMCPServers: (...args) => mockUserCanUseMCPServers(...args),
|
||||
}));
|
||||
|
||||
|
|
|
|||
|
|
@ -43,7 +43,11 @@ const { resizeAvatar } = require('~/server/services/Files/images/avatar');
|
|||
const { getFileStrategy } = require('~/server/utils/getFileStrategy');
|
||||
const { filterFile } = require('~/server/services/Files/process');
|
||||
const { getCachedTools } = require('~/server/services/Config');
|
||||
const { resolveConfigServers, userCanUseMCPServers } = require('~/server/services/MCP');
|
||||
const {
|
||||
createMCPPermissionContext,
|
||||
resolveConfigServers,
|
||||
userCanUseMCPServers,
|
||||
} = require('~/server/services/MCP');
|
||||
const { getMCPServersRegistry } = require('~/config');
|
||||
const { getLogStores } = require('~/cache');
|
||||
const db = require('~/models');
|
||||
|
|
@ -192,6 +196,7 @@ const isSubagentsCapabilityEnabled = (req) => {
|
|||
* @param {string} params.userId - Requesting user ID for MCP server access check
|
||||
* @param {string} [params.role] - Requesting user's role for ACL principal resolution
|
||||
* @param {object} [params.user] - Requesting user for MCP server use permission checks
|
||||
* @param {{ canUseServers: (user?: object) => Promise<boolean> }} [params.mcpPermissionContext] - Request-scoped MCP permission context
|
||||
* @param {Record<string, unknown>} params.availableTools - Global non-MCP tool cache
|
||||
* @param {string[]} [params.existingTools] - Tools already persisted on the agent document
|
||||
* @param {Record<string, unknown>} [params.configServers] - Config-source MCP servers resolved from appConfig overrides
|
||||
|
|
@ -202,6 +207,7 @@ const filterAuthorizedTools = async ({
|
|||
userId,
|
||||
role,
|
||||
user,
|
||||
mcpPermissionContext,
|
||||
availableTools,
|
||||
existingTools,
|
||||
configServers,
|
||||
|
|
@ -211,7 +217,11 @@ const filterAuthorizedTools = async ({
|
|||
let registryUnavailable = false;
|
||||
const existingToolSet = existingTools?.length ? new Set(existingTools) : null;
|
||||
const hasMCPTools = tools.some((tool) => tool?.includes(Constants.mcp_delimiter));
|
||||
const canUseMCP = hasMCPTools ? await userCanUseMCPServers(user) : true;
|
||||
const canUseMCP = hasMCPTools
|
||||
? await (mcpPermissionContext
|
||||
? mcpPermissionContext.canUseServers(user)
|
||||
: userCanUseMCPServers(user))
|
||||
: true;
|
||||
let loggedMCPDenied = false;
|
||||
|
||||
for (const tool of tools) {
|
||||
|
|
@ -403,11 +413,13 @@ const createAgentHandler = async (req, res) => {
|
|||
getCachedTools().then((t) => t ?? {}),
|
||||
hasMCPTools ? resolveConfigServers(req) : Promise.resolve(undefined),
|
||||
]);
|
||||
const mcpPermissionContext = createMCPPermissionContext(req);
|
||||
agentData.tools = await filterAuthorizedTools({
|
||||
tools,
|
||||
userId,
|
||||
role: req.user.role,
|
||||
user: req.user,
|
||||
mcpPermissionContext,
|
||||
availableTools,
|
||||
configServers,
|
||||
});
|
||||
|
|
@ -637,7 +649,8 @@ const updateAgentHandler = async (req, res) => {
|
|||
const existingMCPTools = existingTools.filter(isMCPTool);
|
||||
|
||||
if (requestedMCPTools.length > 0 || (hasToolUpdate && existingMCPTools.length > 0)) {
|
||||
if (!(await userCanUseMCPServers(req.user))) {
|
||||
const mcpPermissionContext = createMCPPermissionContext(req);
|
||||
if (!(await mcpPermissionContext.canUseServers(req.user))) {
|
||||
if (editingOwnAgent) {
|
||||
updateData.tools = effectiveTools.filter((t) => !isMCPTool(t));
|
||||
} else if (hasToolUpdate) {
|
||||
|
|
@ -667,6 +680,7 @@ const updateAgentHandler = async (req, res) => {
|
|||
userId: req.user.id,
|
||||
role: req.user.role,
|
||||
user: req.user,
|
||||
mcpPermissionContext,
|
||||
availableTools,
|
||||
configServers,
|
||||
});
|
||||
|
|
@ -826,11 +840,13 @@ const duplicateAgentHandler = async (req, res) => {
|
|||
getCachedTools().then((t) => t ?? {}),
|
||||
resolveConfigServers(req),
|
||||
]);
|
||||
const mcpPermissionContext = createMCPPermissionContext(req);
|
||||
newAgentData.tools = await filterAuthorizedTools({
|
||||
tools: newAgentData.tools,
|
||||
userId,
|
||||
role: req.user.role,
|
||||
user: req.user,
|
||||
mcpPermissionContext,
|
||||
availableTools,
|
||||
existingTools: newAgentData.tools,
|
||||
configServers,
|
||||
|
|
@ -1202,11 +1218,13 @@ const revertAgentVersionHandler = async (req, res) => {
|
|||
getCachedTools().then((t) => t ?? {}),
|
||||
resolveConfigServers(req),
|
||||
]);
|
||||
const mcpPermissionContext = createMCPPermissionContext(req);
|
||||
const filteredTools = await filterAuthorizedTools({
|
||||
tools: updatedAgent.tools,
|
||||
userId: req.user.id,
|
||||
role: req.user.role,
|
||||
user: req.user,
|
||||
mcpPermissionContext,
|
||||
availableTools,
|
||||
existingTools: updatedAgent.tools,
|
||||
configServers,
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@ const {
|
|||
Constants: AgentConstants,
|
||||
} = require('@librechat/agents');
|
||||
const {
|
||||
checkAccess,
|
||||
sendEvent,
|
||||
MCPOAuthHandler,
|
||||
isMCPDomainAllowed,
|
||||
|
|
@ -16,6 +15,7 @@ const {
|
|||
GenerationJobManager,
|
||||
resolveJsonSchemaRefs,
|
||||
buildOAuthToolCallName,
|
||||
checkAccessWithRequestCache,
|
||||
} = require('@librechat/api');
|
||||
const {
|
||||
Time,
|
||||
|
|
@ -45,13 +45,14 @@ const RECONNECT_THROTTLE_MS = 10_000;
|
|||
const missingToolCache = new Map();
|
||||
const MISSING_TOOL_TTL_MS = 10_000;
|
||||
|
||||
async function userCanUseMCPServers(user) {
|
||||
async function userCanUseMCPServers(user, req) {
|
||||
if (!user?.id || !user?.role) {
|
||||
return false;
|
||||
}
|
||||
|
||||
try {
|
||||
return await checkAccess({
|
||||
return await checkAccessWithRequestCache({
|
||||
req,
|
||||
user,
|
||||
permissionType: PermissionTypes.MCP_SERVERS,
|
||||
permissions: [Permissions.USE],
|
||||
|
|
@ -63,6 +64,12 @@ async function userCanUseMCPServers(user) {
|
|||
}
|
||||
}
|
||||
|
||||
function createMCPPermissionContext(req) {
|
||||
return {
|
||||
canUseServers: (user = req?.user) => userCanUseMCPServers(user, req),
|
||||
};
|
||||
}
|
||||
|
||||
function evictStale(map, ttl) {
|
||||
if (map.size <= MAX_CACHE_SIZE) {
|
||||
return;
|
||||
|
|
@ -436,6 +443,7 @@ async function reconnectServer({
|
|||
*
|
||||
* @param {Object} params
|
||||
* @param {ServerResponse} params.res - The Express response object for sending events.
|
||||
* @param {{ canUseServers: (user?: IUser) => Promise<boolean> }} [params.mcpPermissionContext] - Request-scoped MCP permission context.
|
||||
* @param {IUser} params.user - The user from the request object.
|
||||
* @param {string} params.serverName
|
||||
* @param {string} params.model
|
||||
|
|
@ -449,6 +457,7 @@ async function reconnectServer({
|
|||
*/
|
||||
async function createMCPTools({
|
||||
res,
|
||||
mcpPermissionContext,
|
||||
user,
|
||||
index,
|
||||
signal,
|
||||
|
|
@ -503,6 +512,7 @@ async function createMCPTools({
|
|||
for (const tool of result.tools) {
|
||||
const toolInstance = await createMCPTool({
|
||||
res,
|
||||
mcpPermissionContext,
|
||||
user,
|
||||
provider,
|
||||
userMCPAuthMap,
|
||||
|
|
@ -524,6 +534,7 @@ async function createMCPTools({
|
|||
* Creates a single tool from the specified MCP Server via `toolKey`.
|
||||
* @param {Object} params
|
||||
* @param {ServerResponse} params.res - The Express response object for sending events.
|
||||
* @param {{ canUseServers: (user?: IUser) => Promise<boolean> }} [params.mcpPermissionContext] - Request-scoped MCP permission context.
|
||||
* @param {IUser} params.user - The user from the request object.
|
||||
* @param {string} params.toolKey - The toolKey for the tool.
|
||||
* @param {string} params.model - The model for the tool.
|
||||
|
|
@ -538,6 +549,7 @@ async function createMCPTools({
|
|||
*/
|
||||
async function createMCPTool({
|
||||
res,
|
||||
mcpPermissionContext,
|
||||
user,
|
||||
index,
|
||||
signal,
|
||||
|
|
@ -613,6 +625,7 @@ async function createMCPTool({
|
|||
|
||||
return createToolInstance({
|
||||
res,
|
||||
mcpPermissionContext,
|
||||
user,
|
||||
provider,
|
||||
toolName,
|
||||
|
|
@ -625,6 +638,7 @@ async function createMCPTool({
|
|||
|
||||
function createToolInstance({
|
||||
res,
|
||||
mcpPermissionContext,
|
||||
user: capturedUser = null,
|
||||
toolName,
|
||||
serverName,
|
||||
|
|
@ -663,7 +677,9 @@ function createToolInstance({
|
|||
|
||||
try {
|
||||
const provider = (config?.metadata?.provider || capturedProvider)?.toLowerCase();
|
||||
const canUseMCP = await userCanUseMCPServers(permissionUser);
|
||||
const canUseMCP = mcpPermissionContext
|
||||
? await mcpPermissionContext.canUseServers(permissionUser)
|
||||
: await userCanUseMCPServers(permissionUser);
|
||||
if (!canUseMCP) {
|
||||
throw new Error('Forbidden: Insufficient MCP server permissions');
|
||||
}
|
||||
|
|
@ -923,6 +939,7 @@ async function getServerConnectionStatus(
|
|||
module.exports = {
|
||||
createMCPTool,
|
||||
createMCPTools,
|
||||
createMCPPermissionContext,
|
||||
userCanUseMCPServers,
|
||||
getMCPSetupData,
|
||||
resolveConfigServers,
|
||||
|
|
|
|||
|
|
@ -43,6 +43,7 @@ const D = Constants.mcp_delimiter;
|
|||
const {
|
||||
createMCPTool,
|
||||
createMCPTools,
|
||||
createMCPPermissionContext,
|
||||
getMCPSetupData,
|
||||
checkOAuthFlowStatus,
|
||||
getServerConnectionStatus,
|
||||
|
|
@ -864,6 +865,78 @@ describe('User parameter passing tests', () => {
|
|||
);
|
||||
expect(mockGetMCPManager).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should reuse request-scoped MCP permission checks across tool executions', async () => {
|
||||
const mockUser = { id: 'mcp-allowed-user', role: 'USER' };
|
||||
const mockReq = { user: mockUser };
|
||||
const mockRes = { write: jest.fn(), flush: jest.fn() };
|
||||
const { getRoleByName } = require('~/models');
|
||||
getRoleByName.mockResolvedValue({
|
||||
permissions: {
|
||||
[PermissionTypes.MCP_SERVERS]: {
|
||||
[Permissions.USE]: true,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const mockCallTool = jest.fn().mockResolvedValue(['ok', null]);
|
||||
mockGetMCPManager.mockReturnValue({
|
||||
callTool: mockCallTool,
|
||||
});
|
||||
|
||||
const availableTools = {
|
||||
[`search${D}test-server`]: {
|
||||
function: {
|
||||
description: 'Search tool',
|
||||
parameters: { type: 'object', properties: {} },
|
||||
},
|
||||
},
|
||||
[`fetch${D}test-server`]: {
|
||||
function: {
|
||||
description: 'Fetch tool',
|
||||
parameters: { type: 'object', properties: {} },
|
||||
},
|
||||
},
|
||||
};
|
||||
const mcpPermissionContext = createMCPPermissionContext(mockReq);
|
||||
|
||||
const searchTool = await createMCPTool({
|
||||
mcpPermissionContext,
|
||||
res: mockRes,
|
||||
user: mockUser,
|
||||
toolKey: `search${D}test-server`,
|
||||
provider: 'openai',
|
||||
userMCPAuthMap: {},
|
||||
availableTools,
|
||||
});
|
||||
const fetchTool = await createMCPTool({
|
||||
mcpPermissionContext,
|
||||
res: mockRes,
|
||||
user: mockUser,
|
||||
toolKey: `fetch${D}test-server`,
|
||||
provider: 'openai',
|
||||
userMCPAuthMap: {},
|
||||
availableTools,
|
||||
});
|
||||
|
||||
const invocationConfig = {
|
||||
configurable: {
|
||||
user: mockUser,
|
||||
},
|
||||
metadata: {
|
||||
provider: 'openai',
|
||||
thread_id: 'thread-1',
|
||||
run_id: 'run-1',
|
||||
},
|
||||
toolCall: {},
|
||||
};
|
||||
|
||||
await expect(searchTool.invoke({}, invocationConfig)).resolves.toBe('ok');
|
||||
await expect(fetchTool.invoke({}, invocationConfig)).resolves.toBe('ok');
|
||||
|
||||
expect(getRoleByName).toHaveBeenCalledTimes(1);
|
||||
expect(mockCallTool).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
});
|
||||
|
||||
describe('reinitMCPServer (via reconnectServer)', () => {
|
||||
|
|
|
|||
|
|
@ -63,7 +63,7 @@ const { primeFiles: primeCodeFiles } = require('~/server/services/Files/Code/pro
|
|||
const { manifestToolMap, toolkits } = require('~/app/clients/tools/manifest');
|
||||
const { createOnSearchResults } = require('~/server/services/Tools/search');
|
||||
const { reinitMCPServer } = require('~/server/services/Tools/mcp');
|
||||
const { resolveConfigServers, userCanUseMCPServers } = require('~/server/services/MCP');
|
||||
const { createMCPPermissionContext, resolveConfigServers } = require('~/server/services/MCP');
|
||||
const { recordUsage } = require('~/server/services/Threads');
|
||||
const { loadTools } = require('~/app/clients/tools/util');
|
||||
const { redactMessage } = require('~/config/parsers');
|
||||
|
|
@ -552,7 +552,8 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to
|
|||
agent.tools?.includes(Tools.execute_code) === true &&
|
||||
enabledCapabilities.has(AgentCapabilities.execute_code);
|
||||
const hasMCPTools = agent.tools?.some((tool) => tool?.includes(Constants.mcp_delimiter));
|
||||
const canUseMCP = hasMCPTools ? await userCanUseMCPServers(req.user) : true;
|
||||
const mcpPermissionContext = createMCPPermissionContext(req);
|
||||
const canUseMCP = hasMCPTools ? await mcpPermissionContext.canUseServers(req.user) : true;
|
||||
|
||||
const filteredTools = agent.tools?.filter((tool) => {
|
||||
if (tool === Tools.file_search) {
|
||||
|
|
@ -989,7 +990,8 @@ async function loadAgentTools({
|
|||
const areToolsEnabled = checkCapability(AgentCapabilities.tools);
|
||||
const actionsEnabled = checkCapability(AgentCapabilities.actions);
|
||||
const hasMCPTools = agent.tools?.some((tool) => tool?.includes(Constants.mcp_delimiter));
|
||||
const canUseMCP = hasMCPTools ? await userCanUseMCPServers(req.user) : true;
|
||||
const mcpPermissionContext = createMCPPermissionContext(req);
|
||||
const canUseMCP = hasMCPTools ? await mcpPermissionContext.canUseServers(req.user) : true;
|
||||
|
||||
let includesWebSearch = false;
|
||||
const _agentTools = agent.tools?.filter((tool) => {
|
||||
|
|
@ -1044,6 +1046,7 @@ async function loadAgentTools({
|
|||
processFileURL,
|
||||
uploadImageBuffer,
|
||||
returnMetadata: true,
|
||||
mcpPermissionContext,
|
||||
[Tools.web_search]: webSearchCallbacks,
|
||||
},
|
||||
webSearch: appConfig.webSearch,
|
||||
|
|
|
|||
|
|
@ -37,6 +37,7 @@ const mockDecryptMetadata = jest.fn();
|
|||
const mockCreateActionTool = jest.fn();
|
||||
const mockGetServerConfig = jest.fn();
|
||||
const mockResolveConfigServers = jest.fn();
|
||||
const mockUserCanUseMCPServers = jest.fn().mockResolvedValue(true);
|
||||
jest.mock('~/server/services/Tools/credentials', () => ({
|
||||
loadAuthValues: jest.fn().mockResolvedValue({}),
|
||||
}));
|
||||
|
|
@ -77,7 +78,10 @@ jest.mock('~/config', () => ({
|
|||
}));
|
||||
jest.mock('~/server/services/MCP', () => ({
|
||||
resolveConfigServers: (...args) => mockResolveConfigServers(...args),
|
||||
userCanUseMCPServers: jest.fn().mockResolvedValue(true),
|
||||
createMCPPermissionContext: jest.fn((req) => ({
|
||||
canUseServers: (user) => mockUserCanUseMCPServers(user, req),
|
||||
})),
|
||||
userCanUseMCPServers: mockUserCanUseMCPServers,
|
||||
}));
|
||||
jest.mock('~/cache', () => ({
|
||||
getLogStores: jest.fn(() => ({})),
|
||||
|
|
|
|||
|
|
@ -6,7 +6,12 @@ import {
|
|||
EndpointURLs,
|
||||
} from 'librechat-data-provider';
|
||||
import type { IRole, IUser } from '@librechat/data-schemas';
|
||||
import { checkAccess, generateCheckAccess, skipAgentCheck } from './access';
|
||||
import {
|
||||
checkAccess,
|
||||
checkAccessWithRequestCache,
|
||||
generateCheckAccess,
|
||||
skipAgentCheck,
|
||||
} from './access';
|
||||
|
||||
// Mock logger
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
|
|
@ -269,6 +274,100 @@ describe('access middleware', () => {
|
|||
});
|
||||
});
|
||||
|
||||
describe('checkAccessWithRequestCache', () => {
|
||||
const defaultParams = {
|
||||
user: {
|
||||
id: 'user123',
|
||||
role: 'user',
|
||||
email: 'test@example.com',
|
||||
emailVerified: true,
|
||||
provider: 'local',
|
||||
} as IUser,
|
||||
permissionType: PermissionTypes.MCP_SERVERS,
|
||||
permissions: [Permissions.USE],
|
||||
getRoleByName: jest.fn(),
|
||||
};
|
||||
|
||||
const allowedRole = {
|
||||
name: 'user',
|
||||
permissions: {
|
||||
[PermissionTypes.MCP_SERVERS]: {
|
||||
[Permissions.USE]: true,
|
||||
[Permissions.CREATE]: true,
|
||||
},
|
||||
},
|
||||
} as unknown as IRole;
|
||||
|
||||
it('should memoize permission checks for the same request', async () => {
|
||||
defaultParams.getRoleByName.mockResolvedValue(allowedRole);
|
||||
|
||||
const params = {
|
||||
...defaultParams,
|
||||
req: mockReq as Request,
|
||||
};
|
||||
|
||||
await expect(checkAccessWithRequestCache(params)).resolves.toBe(true);
|
||||
await expect(checkAccessWithRequestCache(params)).resolves.toBe(true);
|
||||
|
||||
expect(defaultParams.getRoleByName).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it('should share an in-flight permission check for the same request', async () => {
|
||||
defaultParams.getRoleByName.mockResolvedValue(allowedRole);
|
||||
|
||||
const params = {
|
||||
...defaultParams,
|
||||
req: mockReq as Request,
|
||||
};
|
||||
|
||||
await expect(
|
||||
Promise.all([checkAccessWithRequestCache(params), checkAccessWithRequestCache(params)]),
|
||||
).resolves.toEqual([true, true]);
|
||||
|
||||
expect(defaultParams.getRoleByName).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it('should isolate memoized checks between requests', async () => {
|
||||
defaultParams.getRoleByName.mockResolvedValue(allowedRole);
|
||||
|
||||
await expect(
|
||||
checkAccessWithRequestCache({
|
||||
...defaultParams,
|
||||
req: mockReq as Request,
|
||||
}),
|
||||
).resolves.toBe(true);
|
||||
await expect(
|
||||
checkAccessWithRequestCache({
|
||||
...defaultParams,
|
||||
req: { ...mockReq } as Request,
|
||||
}),
|
||||
).resolves.toBe(true);
|
||||
|
||||
expect(defaultParams.getRoleByName).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
|
||||
it('should use separate cache entries for different permissions', async () => {
|
||||
defaultParams.getRoleByName.mockResolvedValue(allowedRole);
|
||||
|
||||
await expect(
|
||||
checkAccessWithRequestCache({
|
||||
...defaultParams,
|
||||
req: mockReq as Request,
|
||||
permissions: [Permissions.USE],
|
||||
}),
|
||||
).resolves.toBe(true);
|
||||
await expect(
|
||||
checkAccessWithRequestCache({
|
||||
...defaultParams,
|
||||
req: mockReq as Request,
|
||||
permissions: [Permissions.CREATE],
|
||||
}),
|
||||
).resolves.toBe(true);
|
||||
|
||||
expect(defaultParams.getRoleByName).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
});
|
||||
|
||||
describe('generateCheckAccess', () => {
|
||||
it('should create middleware that allows access when user has permissions', async () => {
|
||||
const mockRole = {
|
||||
|
|
|
|||
|
|
@ -24,6 +24,54 @@ export function skipAgentCheck(req?: ServerRequest): boolean {
|
|||
return !isAgentsEndpoint(req.body.endpoint);
|
||||
}
|
||||
|
||||
export interface CheckAccessParams {
|
||||
user: IUser;
|
||||
req?: ServerRequest;
|
||||
permissionType: PermissionTypes;
|
||||
permissions: Permissions[];
|
||||
bodyProps?: Record<Permissions, string[]>;
|
||||
checkObject?: object;
|
||||
/** If skipCheck function is provided and returns true, skip permission checking */
|
||||
skipCheck?: (req?: ServerRequest) => boolean;
|
||||
getRoleByName: (roleName: string, fieldsToSelect?: string | string[]) => Promise<IRole | null>;
|
||||
}
|
||||
|
||||
export type CheckAccessWithRequestCacheParams = Omit<
|
||||
CheckAccessParams,
|
||||
'bodyProps' | 'checkObject' | 'skipCheck'
|
||||
>;
|
||||
|
||||
type RequestPermissionCache = Map<string, Promise<boolean>>;
|
||||
|
||||
const requestPermissionCacheKey = '__librechatRequestPermissionCache';
|
||||
|
||||
function getRequestPermissionCache(req?: ServerRequest): RequestPermissionCache | null {
|
||||
if (!req) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const reqWithCache = req as ServerRequest & {
|
||||
[requestPermissionCacheKey]?: RequestPermissionCache;
|
||||
};
|
||||
|
||||
if (!reqWithCache[requestPermissionCacheKey]) {
|
||||
Object.defineProperty(reqWithCache, requestPermissionCacheKey, {
|
||||
value: new Map<string, Promise<boolean>>(),
|
||||
enumerable: false,
|
||||
});
|
||||
}
|
||||
|
||||
return reqWithCache[requestPermissionCacheKey] ?? null;
|
||||
}
|
||||
|
||||
function getRequestPermissionCacheKey({
|
||||
user,
|
||||
permissionType,
|
||||
permissions,
|
||||
}: CheckAccessWithRequestCacheParams): string {
|
||||
return [permissionType, [...permissions].sort().join(','), user.id, user.role].join(':');
|
||||
}
|
||||
|
||||
/**
|
||||
* Core function to check if a user has one or more required permissions
|
||||
* @param user - The user object
|
||||
|
|
@ -43,17 +91,7 @@ export const checkAccess = async ({
|
|||
bodyProps = {} as Record<Permissions, string[]>,
|
||||
checkObject = {},
|
||||
skipCheck,
|
||||
}: {
|
||||
user: IUser;
|
||||
req?: ServerRequest;
|
||||
permissionType: PermissionTypes;
|
||||
permissions: Permissions[];
|
||||
bodyProps?: Record<Permissions, string[]>;
|
||||
checkObject?: object;
|
||||
/** If skipCheck function is provided and returns true, skip permission checking */
|
||||
skipCheck?: (req?: ServerRequest) => boolean;
|
||||
getRoleByName: (roleName: string, fieldsToSelect?: string | string[]) => Promise<IRole | null>;
|
||||
}): Promise<boolean> => {
|
||||
}: CheckAccessParams): Promise<boolean> => {
|
||||
if (skipCheck && skipCheck(req)) {
|
||||
return true;
|
||||
}
|
||||
|
|
@ -85,6 +123,35 @@ export const checkAccess = async ({
|
|||
return false;
|
||||
};
|
||||
|
||||
/**
|
||||
* Checks simple role permissions using a per-request promise cache.
|
||||
* Use this only for checks whose result is fully described by user, role, permission type, and permissions.
|
||||
*/
|
||||
export const checkAccessWithRequestCache = async (
|
||||
params: CheckAccessWithRequestCacheParams,
|
||||
): Promise<boolean> => {
|
||||
if (!params.req || !params.user?.id || !params.user?.role) {
|
||||
return await checkAccess(params);
|
||||
}
|
||||
|
||||
const cache = getRequestPermissionCache(params.req);
|
||||
if (!cache) {
|
||||
return await checkAccess(params);
|
||||
}
|
||||
|
||||
const cacheKey = getRequestPermissionCacheKey(params);
|
||||
let cachedCheck = cache.get(cacheKey);
|
||||
if (!cachedCheck) {
|
||||
cachedCheck = checkAccess(params).catch((error) => {
|
||||
cache.delete(cacheKey);
|
||||
throw error;
|
||||
});
|
||||
cache.set(cacheKey, cachedCheck);
|
||||
}
|
||||
|
||||
return await cachedCheck;
|
||||
};
|
||||
|
||||
/**
|
||||
* Middleware to check if a user has one or more required permissions, optionally based on `req.body` properties.
|
||||
* @param permissionType - The type of permission to check.
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue