mirror of
https://github.com/danny-avila/LibreChat.git
synced 2026-06-28 10:21:39 +00:00
🪣 fix: Cap Context Projection Workload Before Tokenization (#13910)
* fix: bound context projection workload * fix: Address context projection CI failures * fix: Bound context projection database reads * fix: Sort projection spec imports * fix: Cap projection body reads with stats
This commit is contained in:
parent
2f800c5b52
commit
77854decdf
8 changed files with 606 additions and 18 deletions
|
|
@ -18,7 +18,11 @@ async function contextProjectionController(req, res) {
|
|||
return;
|
||||
}
|
||||
const projection = await resolveContextProjection(
|
||||
{ userId: req.user?.id, getMessages: db.getMessages },
|
||||
{
|
||||
userId: req.user?.id,
|
||||
getMessages: db.getMessages,
|
||||
getMessageTextStats: db.getMessageTextStats,
|
||||
},
|
||||
params,
|
||||
);
|
||||
res.json(projection ?? null);
|
||||
|
|
|
|||
19
api/server/middleware/limiters/contextProjectionLimiter.js
Normal file
19
api/server/middleware/limiters/contextProjectionLimiter.js
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
const rateLimit = require('express-rate-limit');
|
||||
const { limiterCache } = require('@librechat/api');
|
||||
|
||||
const { CONTEXT_PROJECTION_WINDOW = 1, CONTEXT_PROJECTION_MAX = 20 } = process.env;
|
||||
|
||||
const windowMs = (parseInt(CONTEXT_PROJECTION_WINDOW, 10) || 1) * 60 * 1000;
|
||||
const max = parseInt(CONTEXT_PROJECTION_MAX, 10) || 20;
|
||||
|
||||
const contextProjectionLimiter = rateLimit({
|
||||
windowMs,
|
||||
max,
|
||||
handler: (_req, res) => {
|
||||
res.status(429).json({ message: 'Too many context projection requests. Try again later' });
|
||||
},
|
||||
keyGenerator: (req) => req.user?.id,
|
||||
store: limiterCache('context_projection_limiter'),
|
||||
});
|
||||
|
||||
module.exports = contextProjectionLimiter;
|
||||
|
|
@ -9,6 +9,7 @@ const registerLimiter = require('./registerLimiter');
|
|||
const toolCallLimiter = require('./toolCallLimiter');
|
||||
const messageLimiters = require('./messageLimiters');
|
||||
const promptUsageLimiter = require('./promptUsageLimiter');
|
||||
const contextProjectionLimiter = require('./contextProjectionLimiter');
|
||||
const verifyEmailLimiter = require('./verifyEmailLimiter');
|
||||
const resetPasswordLimiter = require('./resetPasswordLimiter');
|
||||
const twoFactorTempLimiter = require('./twoFactorTempLimiter');
|
||||
|
|
@ -22,6 +23,7 @@ module.exports = {
|
|||
loginLimiter,
|
||||
registerLimiter,
|
||||
toolCallLimiter,
|
||||
contextProjectionLimiter,
|
||||
createTTSLimiters,
|
||||
createSTTLimiters,
|
||||
verifyEmailLimiter,
|
||||
|
|
|
|||
|
|
@ -4,11 +4,18 @@ const configMiddleware = require('~/server/middleware/config/app');
|
|||
const endpointController = require('~/server/controllers/EndpointController');
|
||||
const tokenConfigController = require('~/server/controllers/TokenConfigController');
|
||||
const contextProjectionController = require('~/server/controllers/ContextProjectionController');
|
||||
const { contextProjectionLimiter } = require('~/server/middleware/limiters');
|
||||
|
||||
const router = express.Router();
|
||||
/** Auth required for role/tenant-scoped endpoint config resolution. */
|
||||
router.get('/', requireJwtAuth, endpointController);
|
||||
router.get('/token-config', requireJwtAuth, configMiddleware, tokenConfigController);
|
||||
router.post('/context-projection', requireJwtAuth, configMiddleware, contextProjectionController);
|
||||
router.post(
|
||||
'/context-projection',
|
||||
requireJwtAuth,
|
||||
contextProjectionLimiter,
|
||||
configMiddleware,
|
||||
contextProjectionController,
|
||||
);
|
||||
|
||||
module.exports = router;
|
||||
|
|
|
|||
206
packages/api/src/endpoints/projection.spec.ts
Normal file
206
packages/api/src/endpoints/projection.spec.ts
Normal file
|
|
@ -0,0 +1,206 @@
|
|||
import { resolveContextProjection } from './projection';
|
||||
import { QUOTE_MAX_COUNT } from '~/utils/quotes';
|
||||
|
||||
jest.mock('@librechat/agents', () => ({
|
||||
Providers: { OPENAI: 'openai' },
|
||||
createTokenCounter: jest.fn(async () => jest.fn(() => 1)),
|
||||
projectAgentContextUsage: jest.fn(() => ({ tokenCount: 1, maxContextTokens: 1000 })),
|
||||
}));
|
||||
|
||||
const GRAPH_SELECT = 'messageId parentMessageId metadata.summaryUsedTokens';
|
||||
const BODY_SELECT = 'messageId parentMessageId tokenCount isCreatedByUser text quotes';
|
||||
|
||||
function textStats(messageId: string, textBytes = 5) {
|
||||
return {
|
||||
messageId,
|
||||
textBytes,
|
||||
quoteCount: 0,
|
||||
quoteBytes: 0,
|
||||
quoteLineCount: 0,
|
||||
nonStringQuoteCount: 0,
|
||||
};
|
||||
}
|
||||
|
||||
describe('resolveContextProjection', () => {
|
||||
const baseParams = {
|
||||
conversationId: 'conversation-1',
|
||||
messageId: 'message-1',
|
||||
endpoint: 'openai',
|
||||
maxContextTokens: 1000,
|
||||
model: 'gpt-4o',
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
it('returns null before tokenization when the conversation is too large', async () => {
|
||||
const { createTokenCounter } = jest.requireMock('@librechat/agents');
|
||||
const messages = Array.from({ length: 513 }, (_, index) => ({
|
||||
messageId: `message-${index}`,
|
||||
parentMessageId: index === 0 ? null : `message-${index - 1}`,
|
||||
isCreatedByUser: true,
|
||||
text: 'hello',
|
||||
}));
|
||||
const getMessages = jest.fn(async () => messages);
|
||||
const getMessageTextStats = jest.fn();
|
||||
|
||||
const result = await resolveContextProjection(
|
||||
{ userId: 'user-1', getMessages, getMessageTextStats },
|
||||
{ ...baseParams, messageId: 'message-512' },
|
||||
);
|
||||
|
||||
expect(result).toBeNull();
|
||||
expect(getMessages).toHaveBeenCalledTimes(1);
|
||||
expect(getMessages).toHaveBeenCalledWith(
|
||||
{ conversationId: 'conversation-1', user: 'user-1' },
|
||||
GRAPH_SELECT,
|
||||
{ limit: 513, sort: false },
|
||||
);
|
||||
expect(getMessageTextStats).not.toHaveBeenCalled();
|
||||
expect(createTokenCounter).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('returns null before tokenization when the branch is too long', async () => {
|
||||
const { createTokenCounter } = jest.requireMock('@librechat/agents');
|
||||
const messages = Array.from({ length: 257 }, (_, index) => ({
|
||||
messageId: `message-${index}`,
|
||||
parentMessageId: index === 0 ? null : `message-${index - 1}`,
|
||||
isCreatedByUser: true,
|
||||
text: 'hello',
|
||||
}));
|
||||
const getMessages = jest.fn(async () => messages);
|
||||
const getMessageTextStats = jest.fn();
|
||||
|
||||
const result = await resolveContextProjection(
|
||||
{ userId: 'user-1', getMessages, getMessageTextStats },
|
||||
{ ...baseParams, messageId: 'message-256' },
|
||||
);
|
||||
|
||||
expect(result).toBeNull();
|
||||
expect(getMessages).toHaveBeenCalledTimes(1);
|
||||
expect(getMessageTextStats).not.toHaveBeenCalled();
|
||||
expect(createTokenCounter).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('returns null before loading bodies when the branch text is too large', async () => {
|
||||
const { createTokenCounter } = jest.requireMock('@librechat/agents');
|
||||
const getMessages = jest.fn(async () => [
|
||||
{
|
||||
messageId: 'message-1',
|
||||
parentMessageId: null,
|
||||
},
|
||||
]);
|
||||
const getMessageTextStats = jest.fn(async () => [textStats('message-1', 512 * 1024 + 1)]);
|
||||
const result = await resolveContextProjection(
|
||||
{
|
||||
userId: 'user-1',
|
||||
getMessages,
|
||||
getMessageTextStats,
|
||||
},
|
||||
baseParams,
|
||||
);
|
||||
|
||||
expect(result).toBeNull();
|
||||
expect(getMessages).toHaveBeenCalledTimes(1);
|
||||
expect(getMessageTextStats).toHaveBeenCalledWith(
|
||||
{
|
||||
conversationId: 'conversation-1',
|
||||
user: 'user-1',
|
||||
messageId: { $in: ['message-1'] },
|
||||
},
|
||||
{ limit: 1 },
|
||||
);
|
||||
expect(createTokenCounter).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('loads only branch message bodies after resolving the graph', async () => {
|
||||
const graph = [
|
||||
{ messageId: 'message-1', parentMessageId: null },
|
||||
{ messageId: 'message-2', parentMessageId: 'message-1' },
|
||||
{ messageId: 'off-branch', parentMessageId: null },
|
||||
];
|
||||
const bodies = [
|
||||
{
|
||||
messageId: 'message-1',
|
||||
parentMessageId: null,
|
||||
isCreatedByUser: true,
|
||||
text: 'first',
|
||||
tokenCount: 5,
|
||||
},
|
||||
{
|
||||
messageId: 'message-2',
|
||||
parentMessageId: 'message-1',
|
||||
isCreatedByUser: false,
|
||||
text: 'second',
|
||||
tokenCount: 6,
|
||||
},
|
||||
];
|
||||
const getMessages = jest.fn(async (_filter: object, select?: string) =>
|
||||
select === GRAPH_SELECT ? graph : bodies,
|
||||
);
|
||||
const getMessageTextStats = jest.fn(async () => [
|
||||
textStats('message-1', 5),
|
||||
textStats('message-2', 6),
|
||||
]);
|
||||
|
||||
const result = await resolveContextProjection(
|
||||
{ userId: 'user-1', getMessages, getMessageTextStats },
|
||||
{ ...baseParams, messageId: 'message-2' },
|
||||
);
|
||||
|
||||
expect(result).toEqual({ tokenCount: 1, maxContextTokens: 1000 });
|
||||
expect(getMessages).toHaveBeenNthCalledWith(
|
||||
1,
|
||||
{ conversationId: 'conversation-1', user: 'user-1' },
|
||||
GRAPH_SELECT,
|
||||
{ limit: 513, sort: false },
|
||||
);
|
||||
expect(getMessageTextStats).toHaveBeenCalledWith(
|
||||
{
|
||||
conversationId: 'conversation-1',
|
||||
user: 'user-1',
|
||||
messageId: { $in: ['message-1', 'message-2'] },
|
||||
},
|
||||
{ limit: 2 },
|
||||
);
|
||||
expect(getMessages).toHaveBeenNthCalledWith(
|
||||
2,
|
||||
{
|
||||
conversationId: 'conversation-1',
|
||||
user: 'user-1',
|
||||
messageId: { $in: ['message-1', 'message-2'] },
|
||||
},
|
||||
BODY_SELECT,
|
||||
{ limit: 2, sort: false },
|
||||
);
|
||||
});
|
||||
|
||||
it('returns null before loading bodies when a branch message has too many quotes', async () => {
|
||||
const { createTokenCounter } = jest.requireMock('@librechat/agents');
|
||||
const getMessages = jest.fn(async () => [
|
||||
{
|
||||
messageId: 'message-1',
|
||||
parentMessageId: null,
|
||||
},
|
||||
]);
|
||||
const getMessageTextStats = jest.fn(async () => [
|
||||
{
|
||||
...textStats('message-1'),
|
||||
quoteCount: QUOTE_MAX_COUNT + 1,
|
||||
quoteBytes: 10,
|
||||
quoteLineCount: QUOTE_MAX_COUNT + 1,
|
||||
},
|
||||
]);
|
||||
|
||||
const result = await resolveContextProjection(
|
||||
{ userId: 'user-1', getMessages, getMessageTextStats },
|
||||
baseParams,
|
||||
);
|
||||
|
||||
expect(result).toBeNull();
|
||||
expect(getMessages).toHaveBeenCalledTimes(1);
|
||||
expect(getMessageTextStats).toHaveBeenCalledTimes(1);
|
||||
expect(createTokenCounter).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
|
@ -2,7 +2,13 @@ import { HumanMessage, AIMessage } from '@langchain/core/messages';
|
|||
import { Providers, createTokenCounter, projectAgentContextUsage } from '@librechat/agents';
|
||||
import type { TContextProjectionRequest, TContextUsageEvent } from 'librechat-data-provider';
|
||||
import type { BaseMessage } from '@langchain/core/messages';
|
||||
import { mergeQuotedText } from '~/utils/quotes';
|
||||
import { QUOTE_MAX_COUNT, mergeQuotedText } from '~/utils/quotes';
|
||||
|
||||
const MAX_PROJECTION_MESSAGES = 512;
|
||||
const MAX_PROJECTION_BRANCH_MESSAGES = 256;
|
||||
const MAX_PROJECTION_BRANCH_TEXT_BYTES = 512 * 1024;
|
||||
const PROJECTION_GRAPH_SELECT = 'messageId parentMessageId metadata.summaryUsedTokens';
|
||||
const PROJECTION_BODY_SELECT = 'messageId parentMessageId tokenCount isCreatedByUser text quotes';
|
||||
|
||||
interface ProjectionMessage {
|
||||
messageId: string;
|
||||
|
|
@ -18,13 +24,42 @@ interface ProjectionMessage {
|
|||
metadata?: { summaryUsedTokens?: number };
|
||||
}
|
||||
|
||||
interface ProjectionMessageFilter {
|
||||
conversationId: string;
|
||||
user?: string;
|
||||
messageId?: string | { $in: string[] };
|
||||
}
|
||||
|
||||
interface ProjectionMessageQueryOptions {
|
||||
limit?: number;
|
||||
sort?: false;
|
||||
}
|
||||
|
||||
interface ProjectionMessageTextStats {
|
||||
messageId: string;
|
||||
textBytes: number;
|
||||
quoteCount: number;
|
||||
quoteBytes: number;
|
||||
quoteLineCount: number;
|
||||
nonStringQuoteCount: number;
|
||||
}
|
||||
|
||||
interface ProjectionMessageTextStatsOptions {
|
||||
limit?: number;
|
||||
}
|
||||
|
||||
export interface ContextProjectionDeps {
|
||||
/** Authenticated requester — branch lookups are scoped to this user. */
|
||||
userId?: string;
|
||||
getMessages: (
|
||||
filter: { conversationId: string; user?: string },
|
||||
filter: ProjectionMessageFilter,
|
||||
select?: string,
|
||||
options?: ProjectionMessageQueryOptions,
|
||||
) => Promise<ProjectionMessage[]>;
|
||||
getMessageTextStats: (
|
||||
filter: ProjectionMessageFilter,
|
||||
options?: ProjectionMessageTextStatsOptions,
|
||||
) => Promise<ProjectionMessageTextStats[]>;
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -51,6 +86,83 @@ function resolveBranch(messages: ProjectionMessage[], tailId: string): Projectio
|
|||
return branch.reverse();
|
||||
}
|
||||
|
||||
function hasValidProjectionIds(params: TContextProjectionRequest): boolean {
|
||||
return typeof params.conversationId === 'string' && typeof params.messageId === 'string';
|
||||
}
|
||||
|
||||
function getProjectionText(message: ProjectionMessage): string | null {
|
||||
const hasQuotes =
|
||||
message.isCreatedByUser === true && Array.isArray(message.quotes) && message.quotes.length > 0;
|
||||
if (!hasQuotes) {
|
||||
return message.text ?? '';
|
||||
}
|
||||
if (message.quotes == null || message.quotes.length > QUOTE_MAX_COUNT) {
|
||||
return null;
|
||||
}
|
||||
for (const quote of message.quotes) {
|
||||
if (typeof quote !== 'string') {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
return mergeQuotedText(message.text ?? '', message.quotes);
|
||||
}
|
||||
|
||||
function hasExceededBranchTextLimit(branch: ProjectionMessage[]): boolean {
|
||||
let bytes = 0;
|
||||
for (const message of branch) {
|
||||
const text = getProjectionText(message);
|
||||
if (text == null) {
|
||||
return true;
|
||||
}
|
||||
bytes += Buffer.byteLength(text, 'utf8');
|
||||
if (bytes > MAX_PROJECTION_BRANCH_TEXT_BYTES) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
function getEstimatedMergedTextBytes(stats: ProjectionMessageTextStats): number | null {
|
||||
if (
|
||||
stats.nonStringQuoteCount > 0 ||
|
||||
stats.quoteCount > QUOTE_MAX_COUNT ||
|
||||
stats.quoteLineCount < stats.quoteCount
|
||||
) {
|
||||
return null;
|
||||
}
|
||||
if (stats.quoteCount === 0) {
|
||||
return stats.textBytes;
|
||||
}
|
||||
|
||||
const quotePrefixBytes = stats.quoteLineCount * 2;
|
||||
const quoteLineBreakBytes = stats.quoteLineCount - stats.quoteCount;
|
||||
const quoteSeparatorBytes = (stats.quoteCount - 1) * 2;
|
||||
const bodySeparatorBytes = stats.textBytes > 0 ? 2 : 0;
|
||||
return (
|
||||
stats.textBytes +
|
||||
stats.quoteBytes +
|
||||
quotePrefixBytes +
|
||||
quoteLineBreakBytes +
|
||||
quoteSeparatorBytes +
|
||||
bodySeparatorBytes
|
||||
);
|
||||
}
|
||||
|
||||
function hasExceededBranchTextStatsLimit(stats: ProjectionMessageTextStats[]): boolean {
|
||||
let bytes = 0;
|
||||
for (const messageStats of stats) {
|
||||
const messageBytes = getEstimatedMergedTextBytes(messageStats);
|
||||
if (messageBytes == null) {
|
||||
return true;
|
||||
}
|
||||
bytes += messageBytes;
|
||||
if (bytes > MAX_PROJECTION_BRANCH_TEXT_BYTES) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/** Maps an endpoint/provider string to the agents `Providers` enum. */
|
||||
function resolveProvider(value?: string): Providers {
|
||||
if (value == null || value === '') {
|
||||
|
|
@ -74,6 +186,43 @@ function resolveProvider(value?: string): Providers {
|
|||
return Providers.OPENAI;
|
||||
}
|
||||
|
||||
async function getBranchMessages(
|
||||
deps: ContextProjectionDeps,
|
||||
baseFilter: ProjectionMessageFilter,
|
||||
branch: ProjectionMessage[],
|
||||
): Promise<ProjectionMessage[] | null> {
|
||||
const branchIds = branch.map((message) => message.messageId);
|
||||
const stats = await deps.getMessageTextStats(
|
||||
{ ...baseFilter, messageId: { $in: branchIds } },
|
||||
{ limit: branchIds.length },
|
||||
);
|
||||
if (stats.length !== branchIds.length || hasExceededBranchTextStatsLimit(stats)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const stored = await deps.getMessages(
|
||||
{ ...baseFilter, messageId: { $in: branchIds } },
|
||||
PROJECTION_BODY_SELECT,
|
||||
{ limit: branchIds.length, sort: false },
|
||||
);
|
||||
if (stored.length !== branchIds.length) {
|
||||
return null;
|
||||
}
|
||||
const byId = new Map<string, ProjectionMessage>();
|
||||
for (const message of stored) {
|
||||
byId.set(message.messageId, message);
|
||||
}
|
||||
const ordered: ProjectionMessage[] = [];
|
||||
for (const messageId of branchIds) {
|
||||
const message = byId.get(messageId);
|
||||
if (message == null) {
|
||||
return null;
|
||||
}
|
||||
ordered.push(message);
|
||||
}
|
||||
return ordered;
|
||||
}
|
||||
|
||||
/**
|
||||
* Server-side context-usage projection: reconstructs the viewed branch and asks
|
||||
* the agents SDK what the next call's context would be, WITHOUT invoking the
|
||||
|
|
@ -90,19 +239,31 @@ export async function resolveContextProjection(
|
|||
deps: ContextProjectionDeps,
|
||||
params: TContextProjectionRequest,
|
||||
): Promise<TContextUsageEvent | null> {
|
||||
if (!hasValidProjectionIds(params)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const maxContextTokens = params.maxContextTokens;
|
||||
if (maxContextTokens == null || maxContextTokens <= 0) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const stored = await deps.getMessages(
|
||||
{ conversationId: params.conversationId, user: deps.userId },
|
||||
'messageId parentMessageId tokenCount isCreatedByUser text quotes metadata',
|
||||
);
|
||||
const baseFilter = { conversationId: params.conversationId, user: deps.userId };
|
||||
const stored = await deps.getMessages(baseFilter, PROJECTION_GRAPH_SELECT, {
|
||||
limit: MAX_PROJECTION_MESSAGES + 1,
|
||||
sort: false,
|
||||
});
|
||||
if (stored.length > MAX_PROJECTION_MESSAGES) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const branch = resolveBranch(stored, params.messageId);
|
||||
if (branch.length === 0) {
|
||||
return null;
|
||||
}
|
||||
if (branch.length > MAX_PROJECTION_BRANCH_MESSAGES) {
|
||||
return null;
|
||||
}
|
||||
|
||||
/** A summarized/compacted branch's next call sends the saved summary + the
|
||||
* post-summary tail, NOT this raw parent chain — projecting from the full
|
||||
|
|
@ -114,23 +275,29 @@ export async function resolveContextProjection(
|
|||
return null;
|
||||
}
|
||||
|
||||
const bodyBranch = await getBranchMessages(deps, baseFilter, branch);
|
||||
if (bodyBranch == null || hasExceededBranchTextLimit(bodyBranch)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const model = params.model;
|
||||
const encoding = (model ?? '').toLowerCase().includes('claude') ? 'claude' : 'o200k_base';
|
||||
const tokenCounter = await createTokenCounter(encoding);
|
||||
|
||||
const messages: BaseMessage[] = [];
|
||||
const indexTokenCountMap: Record<string, number> = {};
|
||||
for (let i = 0; i < branch.length; i++) {
|
||||
const message = branch[i];
|
||||
for (let i = 0; i < bodyBranch.length; i++) {
|
||||
const message = bodyBranch[i];
|
||||
/** Mirror the live path: prepend quoted excerpts into the user text the model
|
||||
* receives so the gauge counts the same prompt. */
|
||||
const hasQuotes =
|
||||
message.isCreatedByUser === true &&
|
||||
Array.isArray(message.quotes) &&
|
||||
message.quotes.length > 0;
|
||||
const text = hasQuotes
|
||||
? mergeQuotedText(message.text ?? '', message.quotes ?? [])
|
||||
: (message.text ?? '');
|
||||
const text = getProjectionText(message);
|
||||
if (text == null) {
|
||||
return null;
|
||||
}
|
||||
const lcMessage =
|
||||
message.isCreatedByUser === true ? new HumanMessage(text) : new AIMessage(text);
|
||||
messages.push(lcMessage);
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ let mongoServer: InstanceType<typeof MongoMemoryServer>;
|
|||
let Message: mongoose.Model<IMessage>;
|
||||
let saveMessage: ReturnType<typeof createMessageMethods>['saveMessage'];
|
||||
let getMessages: ReturnType<typeof createMessageMethods>['getMessages'];
|
||||
let getMessageTextStats: ReturnType<typeof createMessageMethods>['getMessageTextStats'];
|
||||
let updateMessage: ReturnType<typeof createMessageMethods>['updateMessage'];
|
||||
let deleteMessages: ReturnType<typeof createMessageMethods>['deleteMessages'];
|
||||
let bulkSaveMessages: ReturnType<typeof createMessageMethods>['bulkSaveMessages'];
|
||||
|
|
@ -39,6 +40,7 @@ beforeAll(async () => {
|
|||
const methods = createMessageMethods(mongoose);
|
||||
saveMessage = methods.saveMessage;
|
||||
getMessages = methods.getMessages;
|
||||
getMessageTextStats = methods.getMessageTextStats;
|
||||
updateMessage = methods.updateMessage;
|
||||
deleteMessages = methods.deleteMessages;
|
||||
bulkSaveMessages = methods.bulkSaveMessages;
|
||||
|
|
@ -240,6 +242,62 @@ describe('Message Operations', () => {
|
|||
expect(messages[0].text).toBe('First message');
|
||||
expect(messages[1].text).toBe('Second message');
|
||||
});
|
||||
|
||||
it('should limit retrieved messages when requested', async () => {
|
||||
const conversationId = uuidv4();
|
||||
|
||||
await saveMessage(mockCtx, {
|
||||
messageId: 'msg1',
|
||||
conversationId,
|
||||
text: 'First message',
|
||||
user: 'user123',
|
||||
});
|
||||
|
||||
await saveMessage(mockCtx, {
|
||||
messageId: 'msg2',
|
||||
conversationId,
|
||||
text: 'Second message',
|
||||
user: 'user123',
|
||||
});
|
||||
|
||||
await saveMessage(mockCtx, {
|
||||
messageId: 'msg3',
|
||||
conversationId,
|
||||
text: 'Third message',
|
||||
user: 'user123',
|
||||
});
|
||||
|
||||
const messages = await getMessages({ conversationId }, undefined, { limit: 2 });
|
||||
|
||||
expect(messages).toHaveLength(2);
|
||||
expect(messages[0].text).toBe('First message');
|
||||
expect(messages[1].text).toBe('Second message');
|
||||
});
|
||||
|
||||
it('should retrieve message text stats without returning message bodies', async () => {
|
||||
const conversationId = uuidv4();
|
||||
|
||||
await saveMessage(mockCtx, {
|
||||
messageId: 'msg1',
|
||||
conversationId,
|
||||
text: 'hello',
|
||||
quotes: ['a\nb', ''],
|
||||
user: 'user123',
|
||||
});
|
||||
|
||||
const stats = await getMessageTextStats({ conversationId, user: 'user123' }, { limit: 1 });
|
||||
|
||||
expect(stats).toEqual([
|
||||
{
|
||||
messageId: 'msg1',
|
||||
textBytes: 5,
|
||||
quoteCount: 2,
|
||||
quoteBytes: 3,
|
||||
quoteLineCount: 3,
|
||||
nonStringQuoteCount: 0,
|
||||
},
|
||||
]);
|
||||
});
|
||||
});
|
||||
|
||||
describe('deleteMessages', () => {
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import { RetentionMode } from 'librechat-data-provider';
|
||||
import type { DeleteResult, FilterQuery, Model } from 'mongoose';
|
||||
import type { DeleteResult, FilterQuery, Model, PipelineStage } from 'mongoose';
|
||||
import type { AppConfig, IMessage } from '~/types';
|
||||
import { createTempChatExpirationDate } from '~/utils/tempChatRetention';
|
||||
import { createFallbackRetentionDate } from '~/utils/retention';
|
||||
|
|
@ -9,6 +9,24 @@ import logger from '~/config/winston';
|
|||
/** Simple UUID v4 regex to replace zod validation */
|
||||
const UUID_REGEX = /^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$/i;
|
||||
|
||||
interface MessageQueryOptions {
|
||||
limit?: number;
|
||||
sort?: Record<string, 1 | -1> | false;
|
||||
}
|
||||
|
||||
interface MessageTextStatsOptions {
|
||||
limit?: number;
|
||||
}
|
||||
|
||||
export interface MessageTextStats {
|
||||
messageId: string;
|
||||
textBytes: number;
|
||||
quoteCount: number;
|
||||
quoteBytes: number;
|
||||
quoteLineCount: number;
|
||||
nonStringQuoteCount: number;
|
||||
}
|
||||
|
||||
export interface MessageMethods {
|
||||
saveMessage(
|
||||
ctx: { userId: string; isTemporary?: boolean; interfaceConfig?: AppConfig['interfaceConfig'] },
|
||||
|
|
@ -37,7 +55,15 @@ export interface MessageMethods {
|
|||
userId: string,
|
||||
params: { messageId: string; conversationId: string },
|
||||
): Promise<DeleteResult>;
|
||||
getMessages(filter: FilterQuery<IMessage>, select?: string): Promise<IMessage[]>;
|
||||
getMessages(
|
||||
filter: FilterQuery<IMessage>,
|
||||
select?: string,
|
||||
options?: MessageQueryOptions,
|
||||
): Promise<IMessage[]>;
|
||||
getMessageTextStats(
|
||||
filter: FilterQuery<IMessage>,
|
||||
options?: MessageTextStatsOptions,
|
||||
): Promise<MessageTextStats[]>;
|
||||
getMessage(params: { user: string; messageId: string }): Promise<IMessage | null>;
|
||||
getMessagesByCursor(
|
||||
filter: FilterQuery<IMessage>,
|
||||
|
|
@ -323,20 +349,118 @@ export function createMessageMethods(mongoose: typeof import('mongoose')): Messa
|
|||
/**
|
||||
* Retrieves messages from the database.
|
||||
*/
|
||||
async function getMessages(filter: FilterQuery<IMessage>, select?: string) {
|
||||
async function getMessages(
|
||||
filter: FilterQuery<IMessage>,
|
||||
select?: string,
|
||||
options: MessageQueryOptions = {},
|
||||
) {
|
||||
try {
|
||||
const Message = mongoose.models.Message as Model<IMessage>;
|
||||
const query = Message.find(filter);
|
||||
if (select) {
|
||||
return await Message.find(filter).select(select).sort({ createdAt: 1 }).lean<IMessage[]>();
|
||||
query.select(select);
|
||||
}
|
||||
if (options.sort !== false) {
|
||||
query.sort(options.sort ?? { createdAt: 1 });
|
||||
}
|
||||
if (options.limit != null && options.limit > 0) {
|
||||
query.limit(options.limit);
|
||||
}
|
||||
|
||||
return await Message.find(filter).sort({ createdAt: 1 }).lean<IMessage[]>();
|
||||
return await query.lean<IMessage[]>();
|
||||
} catch (err) {
|
||||
logger.error('Error getting messages:', err);
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
|
||||
async function getMessageTextStats(
|
||||
filter: FilterQuery<IMessage>,
|
||||
options: MessageTextStatsOptions = {},
|
||||
) {
|
||||
try {
|
||||
const Message = mongoose.models.Message as Model<IMessage>;
|
||||
const pipeline: PipelineStage[] = [{ $match: filter }];
|
||||
if (options.limit != null && options.limit > 0) {
|
||||
pipeline.push({ $limit: options.limit });
|
||||
}
|
||||
pipeline.push({
|
||||
$project: {
|
||||
_id: 0,
|
||||
messageId: 1,
|
||||
textBytes: {
|
||||
$cond: [{ $eq: [{ $type: '$text' }, 'string'] }, { $strLenBytes: '$text' }, 0],
|
||||
},
|
||||
quoteCount: {
|
||||
$cond: [{ $isArray: '$quotes' }, { $size: '$quotes' }, 0],
|
||||
},
|
||||
quoteBytes: {
|
||||
$cond: [
|
||||
{ $isArray: '$quotes' },
|
||||
{
|
||||
$sum: {
|
||||
$map: {
|
||||
input: '$quotes',
|
||||
as: 'quote',
|
||||
in: {
|
||||
$cond: [
|
||||
{ $eq: [{ $type: '$$quote' }, 'string'] },
|
||||
{ $strLenBytes: '$$quote' },
|
||||
0,
|
||||
],
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
0,
|
||||
],
|
||||
},
|
||||
quoteLineCount: {
|
||||
$cond: [
|
||||
{ $isArray: '$quotes' },
|
||||
{
|
||||
$sum: {
|
||||
$map: {
|
||||
input: '$quotes',
|
||||
as: 'quote',
|
||||
in: {
|
||||
$cond: [
|
||||
{ $eq: [{ $type: '$$quote' }, 'string'] },
|
||||
{ $size: { $split: ['$$quote', '\n'] } },
|
||||
0,
|
||||
],
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
0,
|
||||
],
|
||||
},
|
||||
nonStringQuoteCount: {
|
||||
$cond: [
|
||||
{ $isArray: '$quotes' },
|
||||
{
|
||||
$size: {
|
||||
$filter: {
|
||||
input: '$quotes',
|
||||
as: 'quote',
|
||||
cond: { $ne: [{ $type: '$$quote' }, 'string'] },
|
||||
},
|
||||
},
|
||||
},
|
||||
0,
|
||||
],
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
return await Message.aggregate<MessageTextStats>(pipeline);
|
||||
} catch (err) {
|
||||
logger.error('Error getting message text stats:', err);
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves a single message from the database.
|
||||
*/
|
||||
|
|
@ -423,6 +547,7 @@ export function createMessageMethods(mongoose: typeof import('mongoose')): Messa
|
|||
updateMessage,
|
||||
deleteMessagesSince,
|
||||
getMessages,
|
||||
getMessageTextStats,
|
||||
getMessage,
|
||||
getMessagesByCursor,
|
||||
searchMessages,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue