From 77854decdfd35933c231b0e426f7371081bdf6bf Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Tue, 23 Jun 2026 08:43:09 -0400 Subject: [PATCH] =?UTF-8?q?=F0=9F=AA=A3=20fix:=20Cap=20Context=20Projectio?= =?UTF-8?q?n=20Workload=20Before=20Tokenization=20(#13910)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: bound context projection workload * fix: Address context projection CI failures * fix: Bound context projection database reads * fix: Sort projection spec imports * fix: Cap projection body reads with stats --- .../ContextProjectionController.js | 6 +- .../limiters/contextProjectionLimiter.js | 19 ++ api/server/middleware/limiters/index.js | 2 + api/server/routes/endpoints.js | 9 +- packages/api/src/endpoints/projection.spec.ts | 206 ++++++++++++++++++ packages/api/src/endpoints/projection.ts | 189 +++++++++++++++- .../data-schemas/src/methods/message.spec.ts | 58 +++++ packages/data-schemas/src/methods/message.ts | 135 +++++++++++- 8 files changed, 606 insertions(+), 18 deletions(-) create mode 100644 api/server/middleware/limiters/contextProjectionLimiter.js create mode 100644 packages/api/src/endpoints/projection.spec.ts diff --git a/api/server/controllers/ContextProjectionController.js b/api/server/controllers/ContextProjectionController.js index eaf9592e73..9c56b2ae34 100644 --- a/api/server/controllers/ContextProjectionController.js +++ b/api/server/controllers/ContextProjectionController.js @@ -18,7 +18,11 @@ async function contextProjectionController(req, res) { return; } const projection = await resolveContextProjection( - { userId: req.user?.id, getMessages: db.getMessages }, + { + userId: req.user?.id, + getMessages: db.getMessages, + getMessageTextStats: db.getMessageTextStats, + }, params, ); res.json(projection ?? null); diff --git a/api/server/middleware/limiters/contextProjectionLimiter.js b/api/server/middleware/limiters/contextProjectionLimiter.js new file mode 100644 index 0000000000..1f70c7ea8e --- /dev/null +++ b/api/server/middleware/limiters/contextProjectionLimiter.js @@ -0,0 +1,19 @@ +const rateLimit = require('express-rate-limit'); +const { limiterCache } = require('@librechat/api'); + +const { CONTEXT_PROJECTION_WINDOW = 1, CONTEXT_PROJECTION_MAX = 20 } = process.env; + +const windowMs = (parseInt(CONTEXT_PROJECTION_WINDOW, 10) || 1) * 60 * 1000; +const max = parseInt(CONTEXT_PROJECTION_MAX, 10) || 20; + +const contextProjectionLimiter = rateLimit({ + windowMs, + max, + handler: (_req, res) => { + res.status(429).json({ message: 'Too many context projection requests. Try again later' }); + }, + keyGenerator: (req) => req.user?.id, + store: limiterCache('context_projection_limiter'), +}); + +module.exports = contextProjectionLimiter; diff --git a/api/server/middleware/limiters/index.js b/api/server/middleware/limiters/index.js index 4a569e2698..19f246d039 100644 --- a/api/server/middleware/limiters/index.js +++ b/api/server/middleware/limiters/index.js @@ -9,6 +9,7 @@ const registerLimiter = require('./registerLimiter'); const toolCallLimiter = require('./toolCallLimiter'); const messageLimiters = require('./messageLimiters'); const promptUsageLimiter = require('./promptUsageLimiter'); +const contextProjectionLimiter = require('./contextProjectionLimiter'); const verifyEmailLimiter = require('./verifyEmailLimiter'); const resetPasswordLimiter = require('./resetPasswordLimiter'); const twoFactorTempLimiter = require('./twoFactorTempLimiter'); @@ -22,6 +23,7 @@ module.exports = { loginLimiter, registerLimiter, toolCallLimiter, + contextProjectionLimiter, createTTSLimiters, createSTTLimiters, verifyEmailLimiter, diff --git a/api/server/routes/endpoints.js b/api/server/routes/endpoints.js index b11de153df..ea55a9e54a 100644 --- a/api/server/routes/endpoints.js +++ b/api/server/routes/endpoints.js @@ -4,11 +4,18 @@ const configMiddleware = require('~/server/middleware/config/app'); const endpointController = require('~/server/controllers/EndpointController'); const tokenConfigController = require('~/server/controllers/TokenConfigController'); const contextProjectionController = require('~/server/controllers/ContextProjectionController'); +const { contextProjectionLimiter } = require('~/server/middleware/limiters'); const router = express.Router(); /** Auth required for role/tenant-scoped endpoint config resolution. */ router.get('/', requireJwtAuth, endpointController); router.get('/token-config', requireJwtAuth, configMiddleware, tokenConfigController); -router.post('/context-projection', requireJwtAuth, configMiddleware, contextProjectionController); +router.post( + '/context-projection', + requireJwtAuth, + contextProjectionLimiter, + configMiddleware, + contextProjectionController, +); module.exports = router; diff --git a/packages/api/src/endpoints/projection.spec.ts b/packages/api/src/endpoints/projection.spec.ts new file mode 100644 index 0000000000..1fb7f91a4c --- /dev/null +++ b/packages/api/src/endpoints/projection.spec.ts @@ -0,0 +1,206 @@ +import { resolveContextProjection } from './projection'; +import { QUOTE_MAX_COUNT } from '~/utils/quotes'; + +jest.mock('@librechat/agents', () => ({ + Providers: { OPENAI: 'openai' }, + createTokenCounter: jest.fn(async () => jest.fn(() => 1)), + projectAgentContextUsage: jest.fn(() => ({ tokenCount: 1, maxContextTokens: 1000 })), +})); + +const GRAPH_SELECT = 'messageId parentMessageId metadata.summaryUsedTokens'; +const BODY_SELECT = 'messageId parentMessageId tokenCount isCreatedByUser text quotes'; + +function textStats(messageId: string, textBytes = 5) { + return { + messageId, + textBytes, + quoteCount: 0, + quoteBytes: 0, + quoteLineCount: 0, + nonStringQuoteCount: 0, + }; +} + +describe('resolveContextProjection', () => { + const baseParams = { + conversationId: 'conversation-1', + messageId: 'message-1', + endpoint: 'openai', + maxContextTokens: 1000, + model: 'gpt-4o', + }; + + beforeEach(() => { + jest.clearAllMocks(); + }); + + it('returns null before tokenization when the conversation is too large', async () => { + const { createTokenCounter } = jest.requireMock('@librechat/agents'); + const messages = Array.from({ length: 513 }, (_, index) => ({ + messageId: `message-${index}`, + parentMessageId: index === 0 ? null : `message-${index - 1}`, + isCreatedByUser: true, + text: 'hello', + })); + const getMessages = jest.fn(async () => messages); + const getMessageTextStats = jest.fn(); + + const result = await resolveContextProjection( + { userId: 'user-1', getMessages, getMessageTextStats }, + { ...baseParams, messageId: 'message-512' }, + ); + + expect(result).toBeNull(); + expect(getMessages).toHaveBeenCalledTimes(1); + expect(getMessages).toHaveBeenCalledWith( + { conversationId: 'conversation-1', user: 'user-1' }, + GRAPH_SELECT, + { limit: 513, sort: false }, + ); + expect(getMessageTextStats).not.toHaveBeenCalled(); + expect(createTokenCounter).not.toHaveBeenCalled(); + }); + + it('returns null before tokenization when the branch is too long', async () => { + const { createTokenCounter } = jest.requireMock('@librechat/agents'); + const messages = Array.from({ length: 257 }, (_, index) => ({ + messageId: `message-${index}`, + parentMessageId: index === 0 ? null : `message-${index - 1}`, + isCreatedByUser: true, + text: 'hello', + })); + const getMessages = jest.fn(async () => messages); + const getMessageTextStats = jest.fn(); + + const result = await resolveContextProjection( + { userId: 'user-1', getMessages, getMessageTextStats }, + { ...baseParams, messageId: 'message-256' }, + ); + + expect(result).toBeNull(); + expect(getMessages).toHaveBeenCalledTimes(1); + expect(getMessageTextStats).not.toHaveBeenCalled(); + expect(createTokenCounter).not.toHaveBeenCalled(); + }); + + it('returns null before loading bodies when the branch text is too large', async () => { + const { createTokenCounter } = jest.requireMock('@librechat/agents'); + const getMessages = jest.fn(async () => [ + { + messageId: 'message-1', + parentMessageId: null, + }, + ]); + const getMessageTextStats = jest.fn(async () => [textStats('message-1', 512 * 1024 + 1)]); + const result = await resolveContextProjection( + { + userId: 'user-1', + getMessages, + getMessageTextStats, + }, + baseParams, + ); + + expect(result).toBeNull(); + expect(getMessages).toHaveBeenCalledTimes(1); + expect(getMessageTextStats).toHaveBeenCalledWith( + { + conversationId: 'conversation-1', + user: 'user-1', + messageId: { $in: ['message-1'] }, + }, + { limit: 1 }, + ); + expect(createTokenCounter).not.toHaveBeenCalled(); + }); + + it('loads only branch message bodies after resolving the graph', async () => { + const graph = [ + { messageId: 'message-1', parentMessageId: null }, + { messageId: 'message-2', parentMessageId: 'message-1' }, + { messageId: 'off-branch', parentMessageId: null }, + ]; + const bodies = [ + { + messageId: 'message-1', + parentMessageId: null, + isCreatedByUser: true, + text: 'first', + tokenCount: 5, + }, + { + messageId: 'message-2', + parentMessageId: 'message-1', + isCreatedByUser: false, + text: 'second', + tokenCount: 6, + }, + ]; + const getMessages = jest.fn(async (_filter: object, select?: string) => + select === GRAPH_SELECT ? graph : bodies, + ); + const getMessageTextStats = jest.fn(async () => [ + textStats('message-1', 5), + textStats('message-2', 6), + ]); + + const result = await resolveContextProjection( + { userId: 'user-1', getMessages, getMessageTextStats }, + { ...baseParams, messageId: 'message-2' }, + ); + + expect(result).toEqual({ tokenCount: 1, maxContextTokens: 1000 }); + expect(getMessages).toHaveBeenNthCalledWith( + 1, + { conversationId: 'conversation-1', user: 'user-1' }, + GRAPH_SELECT, + { limit: 513, sort: false }, + ); + expect(getMessageTextStats).toHaveBeenCalledWith( + { + conversationId: 'conversation-1', + user: 'user-1', + messageId: { $in: ['message-1', 'message-2'] }, + }, + { limit: 2 }, + ); + expect(getMessages).toHaveBeenNthCalledWith( + 2, + { + conversationId: 'conversation-1', + user: 'user-1', + messageId: { $in: ['message-1', 'message-2'] }, + }, + BODY_SELECT, + { limit: 2, sort: false }, + ); + }); + + it('returns null before loading bodies when a branch message has too many quotes', async () => { + const { createTokenCounter } = jest.requireMock('@librechat/agents'); + const getMessages = jest.fn(async () => [ + { + messageId: 'message-1', + parentMessageId: null, + }, + ]); + const getMessageTextStats = jest.fn(async () => [ + { + ...textStats('message-1'), + quoteCount: QUOTE_MAX_COUNT + 1, + quoteBytes: 10, + quoteLineCount: QUOTE_MAX_COUNT + 1, + }, + ]); + + const result = await resolveContextProjection( + { userId: 'user-1', getMessages, getMessageTextStats }, + baseParams, + ); + + expect(result).toBeNull(); + expect(getMessages).toHaveBeenCalledTimes(1); + expect(getMessageTextStats).toHaveBeenCalledTimes(1); + expect(createTokenCounter).not.toHaveBeenCalled(); + }); +}); diff --git a/packages/api/src/endpoints/projection.ts b/packages/api/src/endpoints/projection.ts index 365cccc4e8..8622e4b981 100644 --- a/packages/api/src/endpoints/projection.ts +++ b/packages/api/src/endpoints/projection.ts @@ -2,7 +2,13 @@ import { HumanMessage, AIMessage } from '@langchain/core/messages'; import { Providers, createTokenCounter, projectAgentContextUsage } from '@librechat/agents'; import type { TContextProjectionRequest, TContextUsageEvent } from 'librechat-data-provider'; import type { BaseMessage } from '@langchain/core/messages'; -import { mergeQuotedText } from '~/utils/quotes'; +import { QUOTE_MAX_COUNT, mergeQuotedText } from '~/utils/quotes'; + +const MAX_PROJECTION_MESSAGES = 512; +const MAX_PROJECTION_BRANCH_MESSAGES = 256; +const MAX_PROJECTION_BRANCH_TEXT_BYTES = 512 * 1024; +const PROJECTION_GRAPH_SELECT = 'messageId parentMessageId metadata.summaryUsedTokens'; +const PROJECTION_BODY_SELECT = 'messageId parentMessageId tokenCount isCreatedByUser text quotes'; interface ProjectionMessage { messageId: string; @@ -18,13 +24,42 @@ interface ProjectionMessage { metadata?: { summaryUsedTokens?: number }; } +interface ProjectionMessageFilter { + conversationId: string; + user?: string; + messageId?: string | { $in: string[] }; +} + +interface ProjectionMessageQueryOptions { + limit?: number; + sort?: false; +} + +interface ProjectionMessageTextStats { + messageId: string; + textBytes: number; + quoteCount: number; + quoteBytes: number; + quoteLineCount: number; + nonStringQuoteCount: number; +} + +interface ProjectionMessageTextStatsOptions { + limit?: number; +} + export interface ContextProjectionDeps { /** Authenticated requester — branch lookups are scoped to this user. */ userId?: string; getMessages: ( - filter: { conversationId: string; user?: string }, + filter: ProjectionMessageFilter, select?: string, + options?: ProjectionMessageQueryOptions, ) => Promise; + getMessageTextStats: ( + filter: ProjectionMessageFilter, + options?: ProjectionMessageTextStatsOptions, + ) => Promise; } /** @@ -51,6 +86,83 @@ function resolveBranch(messages: ProjectionMessage[], tailId: string): Projectio return branch.reverse(); } +function hasValidProjectionIds(params: TContextProjectionRequest): boolean { + return typeof params.conversationId === 'string' && typeof params.messageId === 'string'; +} + +function getProjectionText(message: ProjectionMessage): string | null { + const hasQuotes = + message.isCreatedByUser === true && Array.isArray(message.quotes) && message.quotes.length > 0; + if (!hasQuotes) { + return message.text ?? ''; + } + if (message.quotes == null || message.quotes.length > QUOTE_MAX_COUNT) { + return null; + } + for (const quote of message.quotes) { + if (typeof quote !== 'string') { + return null; + } + } + return mergeQuotedText(message.text ?? '', message.quotes); +} + +function hasExceededBranchTextLimit(branch: ProjectionMessage[]): boolean { + let bytes = 0; + for (const message of branch) { + const text = getProjectionText(message); + if (text == null) { + return true; + } + bytes += Buffer.byteLength(text, 'utf8'); + if (bytes > MAX_PROJECTION_BRANCH_TEXT_BYTES) { + return true; + } + } + return false; +} + +function getEstimatedMergedTextBytes(stats: ProjectionMessageTextStats): number | null { + if ( + stats.nonStringQuoteCount > 0 || + stats.quoteCount > QUOTE_MAX_COUNT || + stats.quoteLineCount < stats.quoteCount + ) { + return null; + } + if (stats.quoteCount === 0) { + return stats.textBytes; + } + + const quotePrefixBytes = stats.quoteLineCount * 2; + const quoteLineBreakBytes = stats.quoteLineCount - stats.quoteCount; + const quoteSeparatorBytes = (stats.quoteCount - 1) * 2; + const bodySeparatorBytes = stats.textBytes > 0 ? 2 : 0; + return ( + stats.textBytes + + stats.quoteBytes + + quotePrefixBytes + + quoteLineBreakBytes + + quoteSeparatorBytes + + bodySeparatorBytes + ); +} + +function hasExceededBranchTextStatsLimit(stats: ProjectionMessageTextStats[]): boolean { + let bytes = 0; + for (const messageStats of stats) { + const messageBytes = getEstimatedMergedTextBytes(messageStats); + if (messageBytes == null) { + return true; + } + bytes += messageBytes; + if (bytes > MAX_PROJECTION_BRANCH_TEXT_BYTES) { + return true; + } + } + return false; +} + /** Maps an endpoint/provider string to the agents `Providers` enum. */ function resolveProvider(value?: string): Providers { if (value == null || value === '') { @@ -74,6 +186,43 @@ function resolveProvider(value?: string): Providers { return Providers.OPENAI; } +async function getBranchMessages( + deps: ContextProjectionDeps, + baseFilter: ProjectionMessageFilter, + branch: ProjectionMessage[], +): Promise { + const branchIds = branch.map((message) => message.messageId); + const stats = await deps.getMessageTextStats( + { ...baseFilter, messageId: { $in: branchIds } }, + { limit: branchIds.length }, + ); + if (stats.length !== branchIds.length || hasExceededBranchTextStatsLimit(stats)) { + return null; + } + + const stored = await deps.getMessages( + { ...baseFilter, messageId: { $in: branchIds } }, + PROJECTION_BODY_SELECT, + { limit: branchIds.length, sort: false }, + ); + if (stored.length !== branchIds.length) { + return null; + } + const byId = new Map(); + for (const message of stored) { + byId.set(message.messageId, message); + } + const ordered: ProjectionMessage[] = []; + for (const messageId of branchIds) { + const message = byId.get(messageId); + if (message == null) { + return null; + } + ordered.push(message); + } + return ordered; +} + /** * Server-side context-usage projection: reconstructs the viewed branch and asks * the agents SDK what the next call's context would be, WITHOUT invoking the @@ -90,19 +239,31 @@ export async function resolveContextProjection( deps: ContextProjectionDeps, params: TContextProjectionRequest, ): Promise { + if (!hasValidProjectionIds(params)) { + return null; + } + const maxContextTokens = params.maxContextTokens; if (maxContextTokens == null || maxContextTokens <= 0) { return null; } - const stored = await deps.getMessages( - { conversationId: params.conversationId, user: deps.userId }, - 'messageId parentMessageId tokenCount isCreatedByUser text quotes metadata', - ); + const baseFilter = { conversationId: params.conversationId, user: deps.userId }; + const stored = await deps.getMessages(baseFilter, PROJECTION_GRAPH_SELECT, { + limit: MAX_PROJECTION_MESSAGES + 1, + sort: false, + }); + if (stored.length > MAX_PROJECTION_MESSAGES) { + return null; + } + const branch = resolveBranch(stored, params.messageId); if (branch.length === 0) { return null; } + if (branch.length > MAX_PROJECTION_BRANCH_MESSAGES) { + return null; + } /** A summarized/compacted branch's next call sends the saved summary + the * post-summary tail, NOT this raw parent chain — projecting from the full @@ -114,23 +275,29 @@ export async function resolveContextProjection( return null; } + const bodyBranch = await getBranchMessages(deps, baseFilter, branch); + if (bodyBranch == null || hasExceededBranchTextLimit(bodyBranch)) { + return null; + } + const model = params.model; const encoding = (model ?? '').toLowerCase().includes('claude') ? 'claude' : 'o200k_base'; const tokenCounter = await createTokenCounter(encoding); const messages: BaseMessage[] = []; const indexTokenCountMap: Record = {}; - for (let i = 0; i < branch.length; i++) { - const message = branch[i]; + for (let i = 0; i < bodyBranch.length; i++) { + const message = bodyBranch[i]; /** Mirror the live path: prepend quoted excerpts into the user text the model * receives so the gauge counts the same prompt. */ const hasQuotes = message.isCreatedByUser === true && Array.isArray(message.quotes) && message.quotes.length > 0; - const text = hasQuotes - ? mergeQuotedText(message.text ?? '', message.quotes ?? []) - : (message.text ?? ''); + const text = getProjectionText(message); + if (text == null) { + return null; + } const lcMessage = message.isCreatedByUser === true ? new HumanMessage(text) : new AIMessage(text); messages.push(lcMessage); diff --git a/packages/data-schemas/src/methods/message.spec.ts b/packages/data-schemas/src/methods/message.spec.ts index 3c3986de93..7e3747d12c 100644 --- a/packages/data-schemas/src/methods/message.spec.ts +++ b/packages/data-schemas/src/methods/message.spec.ts @@ -21,6 +21,7 @@ let mongoServer: InstanceType; let Message: mongoose.Model; let saveMessage: ReturnType['saveMessage']; let getMessages: ReturnType['getMessages']; +let getMessageTextStats: ReturnType['getMessageTextStats']; let updateMessage: ReturnType['updateMessage']; let deleteMessages: ReturnType['deleteMessages']; let bulkSaveMessages: ReturnType['bulkSaveMessages']; @@ -39,6 +40,7 @@ beforeAll(async () => { const methods = createMessageMethods(mongoose); saveMessage = methods.saveMessage; getMessages = methods.getMessages; + getMessageTextStats = methods.getMessageTextStats; updateMessage = methods.updateMessage; deleteMessages = methods.deleteMessages; bulkSaveMessages = methods.bulkSaveMessages; @@ -240,6 +242,62 @@ describe('Message Operations', () => { expect(messages[0].text).toBe('First message'); expect(messages[1].text).toBe('Second message'); }); + + it('should limit retrieved messages when requested', async () => { + const conversationId = uuidv4(); + + await saveMessage(mockCtx, { + messageId: 'msg1', + conversationId, + text: 'First message', + user: 'user123', + }); + + await saveMessage(mockCtx, { + messageId: 'msg2', + conversationId, + text: 'Second message', + user: 'user123', + }); + + await saveMessage(mockCtx, { + messageId: 'msg3', + conversationId, + text: 'Third message', + user: 'user123', + }); + + const messages = await getMessages({ conversationId }, undefined, { limit: 2 }); + + expect(messages).toHaveLength(2); + expect(messages[0].text).toBe('First message'); + expect(messages[1].text).toBe('Second message'); + }); + + it('should retrieve message text stats without returning message bodies', async () => { + const conversationId = uuidv4(); + + await saveMessage(mockCtx, { + messageId: 'msg1', + conversationId, + text: 'hello', + quotes: ['a\nb', ''], + user: 'user123', + }); + + const stats = await getMessageTextStats({ conversationId, user: 'user123' }, { limit: 1 }); + + expect(stats).toEqual([ + { + messageId: 'msg1', + textBytes: 5, + quoteCount: 2, + quoteBytes: 3, + quoteLineCount: 3, + nonStringQuoteCount: 0, + }, + ]); + }); }); describe('deleteMessages', () => { diff --git a/packages/data-schemas/src/methods/message.ts b/packages/data-schemas/src/methods/message.ts index 9ee15c7978..42273cbf5f 100644 --- a/packages/data-schemas/src/methods/message.ts +++ b/packages/data-schemas/src/methods/message.ts @@ -1,5 +1,5 @@ import { RetentionMode } from 'librechat-data-provider'; -import type { DeleteResult, FilterQuery, Model } from 'mongoose'; +import type { DeleteResult, FilterQuery, Model, PipelineStage } from 'mongoose'; import type { AppConfig, IMessage } from '~/types'; import { createTempChatExpirationDate } from '~/utils/tempChatRetention'; import { createFallbackRetentionDate } from '~/utils/retention'; @@ -9,6 +9,24 @@ import logger from '~/config/winston'; /** Simple UUID v4 regex to replace zod validation */ const UUID_REGEX = /^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$/i; +interface MessageQueryOptions { + limit?: number; + sort?: Record | false; +} + +interface MessageTextStatsOptions { + limit?: number; +} + +export interface MessageTextStats { + messageId: string; + textBytes: number; + quoteCount: number; + quoteBytes: number; + quoteLineCount: number; + nonStringQuoteCount: number; +} + export interface MessageMethods { saveMessage( ctx: { userId: string; isTemporary?: boolean; interfaceConfig?: AppConfig['interfaceConfig'] }, @@ -37,7 +55,15 @@ export interface MessageMethods { userId: string, params: { messageId: string; conversationId: string }, ): Promise; - getMessages(filter: FilterQuery, select?: string): Promise; + getMessages( + filter: FilterQuery, + select?: string, + options?: MessageQueryOptions, + ): Promise; + getMessageTextStats( + filter: FilterQuery, + options?: MessageTextStatsOptions, + ): Promise; getMessage(params: { user: string; messageId: string }): Promise; getMessagesByCursor( filter: FilterQuery, @@ -323,20 +349,118 @@ export function createMessageMethods(mongoose: typeof import('mongoose')): Messa /** * Retrieves messages from the database. */ - async function getMessages(filter: FilterQuery, select?: string) { + async function getMessages( + filter: FilterQuery, + select?: string, + options: MessageQueryOptions = {}, + ) { try { const Message = mongoose.models.Message as Model; + const query = Message.find(filter); if (select) { - return await Message.find(filter).select(select).sort({ createdAt: 1 }).lean(); + query.select(select); + } + if (options.sort !== false) { + query.sort(options.sort ?? { createdAt: 1 }); + } + if (options.limit != null && options.limit > 0) { + query.limit(options.limit); } - return await Message.find(filter).sort({ createdAt: 1 }).lean(); + return await query.lean(); } catch (err) { logger.error('Error getting messages:', err); throw err; } } + async function getMessageTextStats( + filter: FilterQuery, + options: MessageTextStatsOptions = {}, + ) { + try { + const Message = mongoose.models.Message as Model; + const pipeline: PipelineStage[] = [{ $match: filter }]; + if (options.limit != null && options.limit > 0) { + pipeline.push({ $limit: options.limit }); + } + pipeline.push({ + $project: { + _id: 0, + messageId: 1, + textBytes: { + $cond: [{ $eq: [{ $type: '$text' }, 'string'] }, { $strLenBytes: '$text' }, 0], + }, + quoteCount: { + $cond: [{ $isArray: '$quotes' }, { $size: '$quotes' }, 0], + }, + quoteBytes: { + $cond: [ + { $isArray: '$quotes' }, + { + $sum: { + $map: { + input: '$quotes', + as: 'quote', + in: { + $cond: [ + { $eq: [{ $type: '$$quote' }, 'string'] }, + { $strLenBytes: '$$quote' }, + 0, + ], + }, + }, + }, + }, + 0, + ], + }, + quoteLineCount: { + $cond: [ + { $isArray: '$quotes' }, + { + $sum: { + $map: { + input: '$quotes', + as: 'quote', + in: { + $cond: [ + { $eq: [{ $type: '$$quote' }, 'string'] }, + { $size: { $split: ['$$quote', '\n'] } }, + 0, + ], + }, + }, + }, + }, + 0, + ], + }, + nonStringQuoteCount: { + $cond: [ + { $isArray: '$quotes' }, + { + $size: { + $filter: { + input: '$quotes', + as: 'quote', + cond: { $ne: [{ $type: '$$quote' }, 'string'] }, + }, + }, + }, + 0, + ], + }, + }, + }); + + return await Message.aggregate(pipeline); + } catch (err) { + logger.error('Error getting message text stats:', err); + throw err; + } + } + /** * Retrieves a single message from the database. */ @@ -423,6 +547,7 @@ export function createMessageMethods(mongoose: typeof import('mongoose')): Messa updateMessage, deleteMessagesSince, getMessages, + getMessageTextStats, getMessage, getMessagesByCursor, searchMessages,