mirror of
https://github.com/danny-avila/LibreChat.git
synced 2026-06-26 01:16:24 +00:00
📡 fix: Tighten Streaming Message Cache Preservation (#13271)
* fix: tighten streaming message cache preservation * fix: sync streaming ref after commit
This commit is contained in:
parent
294bf7c87d
commit
1e0ffcf2fd
5 changed files with 177 additions and 24 deletions
|
|
@ -32,6 +32,7 @@ function LoadingSpinner() {
|
|||
function ChatView({ index = 0 }: { index?: number }) {
|
||||
const { conversationId } = useParams();
|
||||
const rootSubmission = useRecoilValue(store.submissionByIndex(index));
|
||||
const isSubmitting = useRecoilValue(store.isSubmittingFamily(index));
|
||||
const centerFormOnLanding = useRecoilValue(store.centerFormOnLanding);
|
||||
|
||||
const methods = useForm<ChatFormValues>({
|
||||
|
|
@ -40,16 +41,20 @@ function ChatView({ index = 0 }: { index?: number }) {
|
|||
|
||||
const fileMap = useFileMapContext();
|
||||
|
||||
const { data: messagesTree = null, isLoading } = useGetMessagesByConvoId(conversationId ?? '', {
|
||||
select: useCallback(
|
||||
(data: TMessage[]) => {
|
||||
const dataTree = buildTree({ messages: data, fileMap });
|
||||
return dataTree?.length === 0 ? null : (dataTree ?? null);
|
||||
},
|
||||
[fileMap],
|
||||
),
|
||||
enabled: !!fileMap,
|
||||
});
|
||||
const { data: messagesTree = null, isLoading } = useGetMessagesByConvoId(
|
||||
conversationId ?? '',
|
||||
{
|
||||
select: useCallback(
|
||||
(data: TMessage[]) => {
|
||||
const dataTree = buildTree({ messages: data, fileMap });
|
||||
return dataTree?.length === 0 ? null : (dataTree ?? null);
|
||||
},
|
||||
[fileMap],
|
||||
),
|
||||
enabled: !!fileMap,
|
||||
},
|
||||
{ isStreaming: isSubmitting },
|
||||
);
|
||||
|
||||
const chatHelpers = useChatHelpers(index, conversationId);
|
||||
const addedChatHelpers = useAddedResponse();
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import { ContentTypes } from 'librechat-data-provider';
|
|||
import { HoverCard, HoverCardTrigger, HoverCardPortal, HoverCardContent } from '@librechat/client';
|
||||
import type { TMessage, TMessageContentParts } from 'librechat-data-provider';
|
||||
import { useGetMessagesByConvoId } from '~/data-provider';
|
||||
import { useMessagesConversation } from '~/Providers';
|
||||
import { useMessagesConversation, useMessagesSubmission } from '~/Providers';
|
||||
import { useLocalize } from '~/hooks';
|
||||
import { cn } from '~/utils';
|
||||
|
||||
|
|
@ -145,9 +145,14 @@ const chevronButtonClasses = cn(
|
|||
function MessageNav({ scrollableRef }: { scrollableRef: React.RefObject<HTMLDivElement> }) {
|
||||
const localize = useLocalize();
|
||||
const { conversationId } = useMessagesConversation();
|
||||
const { data: messages } = useGetMessagesByConvoId(conversationId ?? '', {
|
||||
enabled: !!conversationId,
|
||||
});
|
||||
const { isSubmitting } = useMessagesSubmission();
|
||||
const { data: messages } = useGetMessagesByConvoId(
|
||||
conversationId ?? '',
|
||||
{
|
||||
enabled: !!conversationId,
|
||||
},
|
||||
{ isStreaming: isSubmitting },
|
||||
);
|
||||
const messagesById = useMemo(() => {
|
||||
const map = new Map<string, TMessage>();
|
||||
if (messages) {
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ type TestMessage = {
|
|||
|
||||
const mockUseGetMessagesByConvoId = jest.fn();
|
||||
const mockUseMessagesConversation = jest.fn();
|
||||
const mockUseMessagesSubmission = jest.fn();
|
||||
|
||||
jest.mock('~/data-provider', () => ({
|
||||
useGetMessagesByConvoId: (...args: unknown[]) => mockUseGetMessagesByConvoId(...args),
|
||||
|
|
@ -20,6 +21,7 @@ jest.mock('~/data-provider', () => ({
|
|||
|
||||
jest.mock('~/Providers', () => ({
|
||||
useMessagesConversation: (...args: unknown[]) => mockUseMessagesConversation(...args),
|
||||
useMessagesSubmission: (...args: unknown[]) => mockUseMessagesSubmission(...args),
|
||||
}));
|
||||
|
||||
jest.mock('~/hooks', () => ({
|
||||
|
|
@ -158,6 +160,7 @@ beforeEach(() => {
|
|||
).IntersectionObserver = MockIntersectionObserver;
|
||||
jest.useFakeTimers();
|
||||
mockUseMessagesConversation.mockReturnValue({ conversationId: 'test-convo' });
|
||||
mockUseMessagesSubmission.mockReturnValue({ isSubmitting: false });
|
||||
mockUseGetMessagesByConvoId.mockReturnValue({ data: [] });
|
||||
});
|
||||
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ describe('getStableMessages', () => {
|
|||
message({ messageId: 'user-2', createdAt: undefined, updatedAt: undefined }),
|
||||
message({
|
||||
messageId: 'user-2_',
|
||||
parentMessageId: 'user-2',
|
||||
isCreatedByUser: false,
|
||||
createdAt: undefined,
|
||||
updatedAt: undefined,
|
||||
|
|
@ -31,26 +32,122 @@ describe('getStableMessages', () => {
|
|||
pathname: '/c/convo-id',
|
||||
result: [],
|
||||
currentMessages,
|
||||
isStreaming: true,
|
||||
});
|
||||
|
||||
expect(result).toBe(currentMessages);
|
||||
});
|
||||
|
||||
it('keeps cache when a one-message result returns for a larger cache', () => {
|
||||
it('keeps cache when a prefix result races with a pending assistant tail', () => {
|
||||
const currentMessages = [
|
||||
message({ messageId: 'persisted-1' }),
|
||||
message({ messageId: 'persisted-2' }),
|
||||
message({ messageId: 'user-2', createdAt: undefined, updatedAt: undefined }),
|
||||
message({
|
||||
messageId: 'user-2_',
|
||||
parentMessageId: 'user-2',
|
||||
isCreatedByUser: false,
|
||||
createdAt: undefined,
|
||||
updatedAt: undefined,
|
||||
}),
|
||||
];
|
||||
|
||||
const result = getStableMessages({
|
||||
pathname: '/c/convo-id',
|
||||
result: [currentMessages[0]],
|
||||
currentMessages,
|
||||
isStreaming: true,
|
||||
});
|
||||
|
||||
expect(result).toBe(currentMessages);
|
||||
});
|
||||
|
||||
it('accepts a shorter result when the app is not streaming', () => {
|
||||
const currentMessages = [
|
||||
message({ messageId: 'persisted-1' }),
|
||||
message({
|
||||
messageId: 'user-2_',
|
||||
parentMessageId: 'user-2',
|
||||
isCreatedByUser: false,
|
||||
createdAt: undefined,
|
||||
updatedAt: undefined,
|
||||
}),
|
||||
];
|
||||
const serverMessages = [currentMessages[0]];
|
||||
|
||||
const result = getStableMessages({
|
||||
pathname: '/c/convo-id',
|
||||
result: serverMessages,
|
||||
currentMessages,
|
||||
isStreaming: false,
|
||||
});
|
||||
|
||||
expect(result).toBe(serverMessages);
|
||||
});
|
||||
|
||||
it('accepts a shorter result when an unhydrated message is not the assistant tail', () => {
|
||||
const currentMessages = [
|
||||
message({ messageId: 'persisted-1' }),
|
||||
message({ messageId: 'old-unhydrated', createdAt: undefined, updatedAt: undefined }),
|
||||
message({ messageId: 'persisted-3' }),
|
||||
];
|
||||
const serverMessages = [currentMessages[0]];
|
||||
|
||||
const result = getStableMessages({
|
||||
pathname: '/c/convo-id',
|
||||
result: serverMessages,
|
||||
currentMessages,
|
||||
isStreaming: true,
|
||||
});
|
||||
|
||||
expect(result).toBe(serverMessages);
|
||||
});
|
||||
|
||||
it('accepts a shorter result when the server result is not a prefix of cache', () => {
|
||||
const currentMessages = [
|
||||
message({ messageId: 'persisted-1' }),
|
||||
message({ messageId: 'persisted-2' }),
|
||||
message({
|
||||
messageId: 'persisted-2_',
|
||||
parentMessageId: 'persisted-2',
|
||||
isCreatedByUser: false,
|
||||
createdAt: undefined,
|
||||
updatedAt: undefined,
|
||||
}),
|
||||
];
|
||||
const serverMessages = [message({ messageId: 'different-1' })];
|
||||
|
||||
const result = getStableMessages({
|
||||
pathname: '/c/convo-id',
|
||||
result: serverMessages,
|
||||
currentMessages,
|
||||
isStreaming: true,
|
||||
});
|
||||
|
||||
expect(result).toBe(serverMessages);
|
||||
});
|
||||
|
||||
it('accepts a shorter result when an unhydrated assistant tail has no parent turn', () => {
|
||||
const currentMessages = [
|
||||
message({ messageId: 'persisted-1' }),
|
||||
message({
|
||||
messageId: 'orphaned-response_',
|
||||
isCreatedByUser: false,
|
||||
createdAt: undefined,
|
||||
updatedAt: undefined,
|
||||
}),
|
||||
];
|
||||
const serverMessages = [currentMessages[0]];
|
||||
|
||||
const result = getStableMessages({
|
||||
pathname: '/c/convo-id',
|
||||
result: serverMessages,
|
||||
currentMessages,
|
||||
isStreaming: true,
|
||||
});
|
||||
|
||||
expect(result).toBe(serverMessages);
|
||||
});
|
||||
|
||||
it('accepts fewer persisted messages when the cache is fully hydrated', () => {
|
||||
const currentMessages = [
|
||||
message({ messageId: 'persisted-1' }),
|
||||
|
|
@ -63,6 +160,7 @@ describe('getStableMessages', () => {
|
|||
pathname: '/c/convo-id',
|
||||
result: serverMessages,
|
||||
currentMessages,
|
||||
isStreaming: true,
|
||||
});
|
||||
|
||||
expect(result).toBe(serverMessages);
|
||||
|
|
@ -75,6 +173,7 @@ describe('getStableMessages', () => {
|
|||
pathname: '/c/new',
|
||||
result: [],
|
||||
currentMessages,
|
||||
isStreaming: true,
|
||||
});
|
||||
|
||||
expect(result).toEqual([]);
|
||||
|
|
|
|||
|
|
@ -1,26 +1,46 @@
|
|||
import { useLayoutEffect, useRef } from 'react';
|
||||
import { useLocation } from 'react-router-dom';
|
||||
import { useQuery, useQueryClient } from '@tanstack/react-query';
|
||||
import type { UseQueryOptions, QueryObserverResult } from '@tanstack/react-query';
|
||||
import { QueryKeys, dataService } from 'librechat-data-provider';
|
||||
import type { UseQueryOptions, QueryObserverResult, QueryClient } from '@tanstack/react-query';
|
||||
import { Constants, QueryKeys, dataService } from 'librechat-data-provider';
|
||||
import type * as t from 'librechat-data-provider';
|
||||
import { logger } from '~/utils';
|
||||
|
||||
type StableMessagesParams = {
|
||||
pathname: string;
|
||||
result: t.TMessage[];
|
||||
isStreaming?: boolean;
|
||||
currentMessages?: t.TMessage[];
|
||||
};
|
||||
|
||||
function hasUnhydratedMessage(messages: t.TMessage[]) {
|
||||
return messages.some((message) => {
|
||||
const messageId = message.messageId ?? '';
|
||||
return message.createdAt == null || message.updatedAt == null || messageId.endsWith('_');
|
||||
});
|
||||
type ActiveJobs = {
|
||||
activeJobIds?: string[];
|
||||
};
|
||||
|
||||
function isUnhydratedMessage(message: t.TMessage) {
|
||||
const messageId = message.messageId ?? '';
|
||||
return message.createdAt == null || message.updatedAt == null || messageId.endsWith('_');
|
||||
}
|
||||
|
||||
function hasPendingAssistantTail(messages: t.TMessage[]) {
|
||||
const lastMessage = messages[messages.length - 1];
|
||||
const parentMessageId = lastMessage?.parentMessageId ?? '';
|
||||
return (
|
||||
lastMessage?.isCreatedByUser !== true &&
|
||||
parentMessageId !== '' &&
|
||||
parentMessageId !== Constants.NO_PARENT &&
|
||||
isUnhydratedMessage(lastMessage)
|
||||
);
|
||||
}
|
||||
|
||||
function isMessagePrefix(result: t.TMessage[], currentMessages: t.TMessage[]) {
|
||||
return result.every((message, index) => message.messageId === currentMessages[index]?.messageId);
|
||||
}
|
||||
|
||||
export function getStableMessages({
|
||||
pathname,
|
||||
result,
|
||||
isStreaming = false,
|
||||
currentMessages,
|
||||
}: StableMessagesParams): t.TMessage[] {
|
||||
if (pathname.includes('/c/new') || !currentMessages?.length) {
|
||||
|
|
@ -31,19 +51,39 @@ export function getStableMessages({
|
|||
return result;
|
||||
}
|
||||
|
||||
if (result.length === 1 || hasUnhydratedMessage(currentMessages)) {
|
||||
if (
|
||||
isStreaming &&
|
||||
hasPendingAssistantTail(currentMessages) &&
|
||||
isMessagePrefix(result, currentMessages)
|
||||
) {
|
||||
return currentMessages;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
function hasActiveJob(queryClient: QueryClient, id: string) {
|
||||
if (!id) {
|
||||
return false;
|
||||
}
|
||||
const activeJobs = queryClient.getQueryData<ActiveJobs>([QueryKeys.activeJobs]);
|
||||
return activeJobs?.activeJobIds?.includes(id) === true;
|
||||
}
|
||||
|
||||
export const useGetMessagesByConvoId = <TData = t.TMessage[]>(
|
||||
id: string,
|
||||
config?: UseQueryOptions<t.TMessage[], unknown, TData>,
|
||||
options?: { isStreaming?: boolean },
|
||||
): QueryObserverResult<TData> => {
|
||||
const location = useLocation();
|
||||
const queryClient = useQueryClient();
|
||||
const isStreaming = options?.isStreaming === true;
|
||||
const isStreamingRef = useRef(isStreaming);
|
||||
|
||||
useLayoutEffect(() => {
|
||||
isStreamingRef.current = isStreaming;
|
||||
}, [isStreaming]);
|
||||
|
||||
return useQuery<t.TMessage[], unknown, TData>(
|
||||
[QueryKeys.messages, id],
|
||||
async () => {
|
||||
|
|
@ -53,6 +93,7 @@ export const useGetMessagesByConvoId = <TData = t.TMessage[]>(
|
|||
pathname: location.pathname,
|
||||
result,
|
||||
currentMessages,
|
||||
isStreaming: isStreamingRef.current || hasActiveJob(queryClient, id),
|
||||
});
|
||||
|
||||
if (stableMessages === currentMessages) {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue