diff --git a/api/models/ConversationTag.js b/api/models/ConversationTag.js index e6dc96be64..47a6c2bbf5 100644 --- a/api/models/ConversationTag.js +++ b/api/models/ConversationTag.js @@ -239,10 +239,46 @@ const updateTagsForConversation = async (user, conversationId, tags) => { } }; +/** + * Increments tag counts for existing tags only. + * @param {string} user - The user ID. + * @param {string[]} tags - Array of tag names to increment + * @returns {Promise} + */ +const bulkIncrementTagCounts = async (user, tags) => { + if (!tags || tags.length === 0) { + return; + } + + try { + const uniqueTags = [...new Set(tags.filter(Boolean))]; + if (uniqueTags.length === 0) { + return; + } + + const bulkOps = uniqueTags.map((tag) => ({ + updateOne: { + filter: { user, tag }, + update: { $inc: { count: 1 } }, + }, + })); + + const result = await ConversationTag.bulkWrite(bulkOps); + if (result && result.modifiedCount > 0) { + logger.debug( + `user: ${user} | Incremented tag counts - modified ${result.modifiedCount} tags`, + ); + } + } catch (error) { + logger.error('[bulkIncrementTagCounts] Error incrementing tag counts', error); + } +}; + module.exports = { getConversationTags, createConversationTag, updateConversationTag, deleteConversationTag, + bulkIncrementTagCounts, updateTagsForConversation, }; diff --git a/api/server/utils/import/fork.spec.js b/api/server/utils/import/fork.spec.js index 4520e977bf..552620dc89 100644 --- a/api/server/utils/import/fork.spec.js +++ b/api/server/utils/import/fork.spec.js @@ -10,6 +10,10 @@ jest.mock('~/models/Message', () => ({ bulkSaveMessages: jest.fn(), })); +jest.mock('~/models/ConversationTag', () => ({ + bulkIncrementTagCounts: jest.fn(), +})); + let mockIdCounter = 0; jest.mock('uuid', () => { return { @@ -22,11 +26,13 @@ jest.mock('uuid', () => { const { forkConversation, + duplicateConversation, splitAtTargetLevel, getAllMessagesUpToParent, getMessagesUpToTargetLevel, cloneMessagesWithTimestamps, } = require('./fork'); +const { bulkIncrementTagCounts } = require('~/models/ConversationTag'); const { getConvo, bulkSaveConvos } = require('~/models/Conversation'); const { getMessages, bulkSaveMessages } = require('~/models/Message'); const { createImportBatchBuilder } = require('./importBatchBuilder'); @@ -181,6 +187,120 @@ describe('forkConversation', () => { }), ).rejects.toThrow('Failed to fetch messages'); }); + + test('should increment tag counts when forking conversation with tags', async () => { + const mockConvoWithTags = { + ...mockConversation, + tags: ['bookmark1', 'bookmark2'], + }; + getConvo.mockResolvedValue(mockConvoWithTags); + + await forkConversation({ + originalConvoId: 'abc123', + targetMessageId: '3', + requestUserId: 'user1', + option: ForkOptions.DIRECT_PATH, + }); + + // Verify that bulkIncrementTagCounts was called with correct tags + expect(bulkIncrementTagCounts).toHaveBeenCalledWith('user1', ['bookmark1', 'bookmark2']); + }); + + test('should handle conversation without tags when forking', async () => { + const mockConvoWithoutTags = { + ...mockConversation, + // No tags field + }; + getConvo.mockResolvedValue(mockConvoWithoutTags); + + await forkConversation({ + originalConvoId: 'abc123', + targetMessageId: '3', + requestUserId: 'user1', + option: ForkOptions.DIRECT_PATH, + }); + + // bulkIncrementTagCounts will be called with array containing undefined + expect(bulkIncrementTagCounts).toHaveBeenCalled(); + }); + + test('should handle empty tags array when forking', async () => { + const mockConvoWithEmptyTags = { + ...mockConversation, + tags: [], + }; + getConvo.mockResolvedValue(mockConvoWithEmptyTags); + + await forkConversation({ + originalConvoId: 'abc123', + targetMessageId: '3', + requestUserId: 'user1', + option: ForkOptions.DIRECT_PATH, + }); + + // bulkIncrementTagCounts will be called with empty array + expect(bulkIncrementTagCounts).toHaveBeenCalledWith('user1', []); + }); +}); + +describe('duplicateConversation', () => { + beforeEach(() => { + jest.clearAllMocks(); + mockIdCounter = 0; + getConvo.mockResolvedValue(mockConversation); + getMessages.mockResolvedValue(mockMessages); + bulkSaveConvos.mockResolvedValue(null); + bulkSaveMessages.mockResolvedValue(null); + bulkIncrementTagCounts.mockResolvedValue(null); + }); + + test('should duplicate conversation and increment tag counts', async () => { + const mockConvoWithTags = { + ...mockConversation, + tags: ['important', 'work', 'project'], + }; + getConvo.mockResolvedValue(mockConvoWithTags); + + await duplicateConversation({ + userId: 'user1', + conversationId: 'abc123', + }); + + // Verify that bulkIncrementTagCounts was called with correct tags + expect(bulkIncrementTagCounts).toHaveBeenCalledWith('user1', ['important', 'work', 'project']); + }); + + test('should duplicate conversation without tags', async () => { + const mockConvoWithoutTags = { + ...mockConversation, + // No tags field + }; + getConvo.mockResolvedValue(mockConvoWithoutTags); + + await duplicateConversation({ + userId: 'user1', + conversationId: 'abc123', + }); + + // bulkIncrementTagCounts will be called with array containing undefined + expect(bulkIncrementTagCounts).toHaveBeenCalled(); + }); + + test('should handle empty tags array when duplicating', async () => { + const mockConvoWithEmptyTags = { + ...mockConversation, + tags: [], + }; + getConvo.mockResolvedValue(mockConvoWithEmptyTags); + + await duplicateConversation({ + userId: 'user1', + conversationId: 'abc123', + }); + + // bulkIncrementTagCounts will be called with empty array + expect(bulkIncrementTagCounts).toHaveBeenCalledWith('user1', []); + }); }); const mockMessagesComplex = [ diff --git a/api/server/utils/import/importBatchBuilder.js b/api/server/utils/import/importBatchBuilder.js index f42e675eb9..d20f200939 100644 --- a/api/server/utils/import/importBatchBuilder.js +++ b/api/server/utils/import/importBatchBuilder.js @@ -1,5 +1,6 @@ const { v4: uuidv4 } = require('uuid'); const { EModelEndpoint, Constants, openAISettings } = require('librechat-data-provider'); +const { bulkIncrementTagCounts } = require('~/models/ConversationTag'); const { bulkSaveConvos } = require('~/models/Conversation'); const { bulkSaveMessages } = require('~/models/Message'); const { logger } = require('~/config'); @@ -93,13 +94,22 @@ class ImportBatchBuilder { /** * Saves the batch of conversations and messages to the DB. + * Also increments tag counts for any existing tags. * @returns {Promise} A promise that resolves when the batch is saved. * @throws {Error} If there is an error saving the batch. */ async saveBatch() { try { - await bulkSaveConvos(this.conversations); - await bulkSaveMessages(this.messages, true); + const promises = []; + promises.push(bulkSaveConvos(this.conversations)); + promises.push(bulkSaveMessages(this.messages, true)); + promises.push( + bulkIncrementTagCounts( + this.requestUserId, + this.conversations.flatMap((convo) => convo.tags), + ), + ); + await Promise.all(promises); logger.debug( `user: ${this.requestUserId} | Added ${this.conversations.length} conversations and ${this.messages.length} messages to the DB.`, ); diff --git a/client/src/data-provider/mutations.ts b/client/src/data-provider/mutations.ts index cda530ed38..21fb765fb1 100644 --- a/client/src/data-provider/mutations.ts +++ b/client/src/data-provider/mutations.ts @@ -96,7 +96,7 @@ export const useArchiveConvoMutation = ( const queryClient = useQueryClient(); const convoQueryKey = [QueryKeys.allConversations]; const archivedConvoQueryKey = [QueryKeys.archivedConversations]; - const { onMutate, onError, onSettled, onSuccess, ..._options } = options || {}; + const { onMutate, onError, onSuccess, ..._options } = options || {}; return useMutation( (payload: t.TArchiveConversationRequest) => dataService.archiveConversation(payload), @@ -567,6 +567,19 @@ export const useDuplicateConversationMutation = ( queryKey: [QueryKeys.allConversations], refetchPage: (_, index) => index === 0, }); + + if (duplicatedConversation.tags && duplicatedConversation.tags.length > 0) { + queryClient.setQueryData([QueryKeys.conversationTags], (oldTags) => { + if (!oldTags) return oldTags; + return oldTags.map((tag) => { + if (duplicatedConversation.tags?.includes(tag.tag)) { + return { ...tag, count: tag.count + 1 }; + } + return tag; + }); + }); + } + onSuccess?.(data, vars, context); }, ..._options, @@ -597,6 +610,19 @@ export const useForkConvoMutation = ( queryKey: [QueryKeys.allConversations], refetchPage: (_, index) => index === 0, }); + + if (forkedConversation.tags && forkedConversation.tags.length > 0) { + queryClient.setQueryData([QueryKeys.conversationTags], (oldTags) => { + if (!oldTags) return oldTags; + return oldTags.map((tag) => { + if (forkedConversation.tags?.includes(tag.tag)) { + return { ...tag, count: tag.count + 1 }; + } + return tag; + }); + }); + } + onSuccess?.(data, vars, context); }, ..._options, @@ -871,7 +897,7 @@ export const useUploadAssistantAvatarMutation = ( unknown // context > => { return useMutation([MutationKeys.assistantAvatarUpload], { - mutationFn: ({ postCreation, ...variables }: t.AssistantAvatarVariables) => + mutationFn: ({ postCreation: _postCreation, ...variables }: t.AssistantAvatarVariables) => dataService.uploadAssistantAvatar(variables), ...(options || {}), });