📡 fix: Tighten Streaming Message Cache Preservation (#13271)

* fix: tighten streaming message cache preservation

* fix: sync streaming ref after commit
This commit is contained in:
Danny Avila 2026-05-23 11:53:25 -04:00 committed by GitHub
parent 294bf7c87d
commit 1e0ffcf2fd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 177 additions and 24 deletions

View file

@ -32,6 +32,7 @@ function LoadingSpinner() {
function ChatView({ index = 0 }: { index?: number }) {
const { conversationId } = useParams();
const rootSubmission = useRecoilValue(store.submissionByIndex(index));
const isSubmitting = useRecoilValue(store.isSubmittingFamily(index));
const centerFormOnLanding = useRecoilValue(store.centerFormOnLanding);
const methods = useForm<ChatFormValues>({
@ -40,16 +41,20 @@ function ChatView({ index = 0 }: { index?: number }) {
const fileMap = useFileMapContext();
const { data: messagesTree = null, isLoading } = useGetMessagesByConvoId(conversationId ?? '', {
select: useCallback(
(data: TMessage[]) => {
const dataTree = buildTree({ messages: data, fileMap });
return dataTree?.length === 0 ? null : (dataTree ?? null);
},
[fileMap],
),
enabled: !!fileMap,
});
const { data: messagesTree = null, isLoading } = useGetMessagesByConvoId(
conversationId ?? '',
{
select: useCallback(
(data: TMessage[]) => {
const dataTree = buildTree({ messages: data, fileMap });
return dataTree?.length === 0 ? null : (dataTree ?? null);
},
[fileMap],
),
enabled: !!fileMap,
},
{ isStreaming: isSubmitting },
);
const chatHelpers = useChatHelpers(index, conversationId);
const addedChatHelpers = useAddedResponse();

View file

@ -4,7 +4,7 @@ import { ContentTypes } from 'librechat-data-provider';
import { HoverCard, HoverCardTrigger, HoverCardPortal, HoverCardContent } from '@librechat/client';
import type { TMessage, TMessageContentParts } from 'librechat-data-provider';
import { useGetMessagesByConvoId } from '~/data-provider';
import { useMessagesConversation } from '~/Providers';
import { useMessagesConversation, useMessagesSubmission } from '~/Providers';
import { useLocalize } from '~/hooks';
import { cn } from '~/utils';
@ -145,9 +145,14 @@ const chevronButtonClasses = cn(
function MessageNav({ scrollableRef }: { scrollableRef: React.RefObject<HTMLDivElement> }) {
const localize = useLocalize();
const { conversationId } = useMessagesConversation();
const { data: messages } = useGetMessagesByConvoId(conversationId ?? '', {
enabled: !!conversationId,
});
const { isSubmitting } = useMessagesSubmission();
const { data: messages } = useGetMessagesByConvoId(
conversationId ?? '',
{
enabled: !!conversationId,
},
{ isStreaming: isSubmitting },
);
const messagesById = useMemo(() => {
const map = new Map<string, TMessage>();
if (messages) {

View file

@ -13,6 +13,7 @@ type TestMessage = {
const mockUseGetMessagesByConvoId = jest.fn();
const mockUseMessagesConversation = jest.fn();
const mockUseMessagesSubmission = jest.fn();
jest.mock('~/data-provider', () => ({
useGetMessagesByConvoId: (...args: unknown[]) => mockUseGetMessagesByConvoId(...args),
@ -20,6 +21,7 @@ jest.mock('~/data-provider', () => ({
jest.mock('~/Providers', () => ({
useMessagesConversation: (...args: unknown[]) => mockUseMessagesConversation(...args),
useMessagesSubmission: (...args: unknown[]) => mockUseMessagesSubmission(...args),
}));
jest.mock('~/hooks', () => ({
@ -158,6 +160,7 @@ beforeEach(() => {
).IntersectionObserver = MockIntersectionObserver;
jest.useFakeTimers();
mockUseMessagesConversation.mockReturnValue({ conversationId: 'test-convo' });
mockUseMessagesSubmission.mockReturnValue({ isSubmitting: false });
mockUseGetMessagesByConvoId.mockReturnValue({ data: [] });
});

View file

@ -21,6 +21,7 @@ describe('getStableMessages', () => {
message({ messageId: 'user-2', createdAt: undefined, updatedAt: undefined }),
message({
messageId: 'user-2_',
parentMessageId: 'user-2',
isCreatedByUser: false,
createdAt: undefined,
updatedAt: undefined,
@ -31,26 +32,122 @@ describe('getStableMessages', () => {
pathname: '/c/convo-id',
result: [],
currentMessages,
isStreaming: true,
});
expect(result).toBe(currentMessages);
});
it('keeps cache when a one-message result returns for a larger cache', () => {
it('keeps cache when a prefix result races with a pending assistant tail', () => {
const currentMessages = [
message({ messageId: 'persisted-1' }),
message({ messageId: 'persisted-2' }),
message({ messageId: 'user-2', createdAt: undefined, updatedAt: undefined }),
message({
messageId: 'user-2_',
parentMessageId: 'user-2',
isCreatedByUser: false,
createdAt: undefined,
updatedAt: undefined,
}),
];
const result = getStableMessages({
pathname: '/c/convo-id',
result: [currentMessages[0]],
currentMessages,
isStreaming: true,
});
expect(result).toBe(currentMessages);
});
it('accepts a shorter result when the app is not streaming', () => {
const currentMessages = [
message({ messageId: 'persisted-1' }),
message({
messageId: 'user-2_',
parentMessageId: 'user-2',
isCreatedByUser: false,
createdAt: undefined,
updatedAt: undefined,
}),
];
const serverMessages = [currentMessages[0]];
const result = getStableMessages({
pathname: '/c/convo-id',
result: serverMessages,
currentMessages,
isStreaming: false,
});
expect(result).toBe(serverMessages);
});
it('accepts a shorter result when an unhydrated message is not the assistant tail', () => {
const currentMessages = [
message({ messageId: 'persisted-1' }),
message({ messageId: 'old-unhydrated', createdAt: undefined, updatedAt: undefined }),
message({ messageId: 'persisted-3' }),
];
const serverMessages = [currentMessages[0]];
const result = getStableMessages({
pathname: '/c/convo-id',
result: serverMessages,
currentMessages,
isStreaming: true,
});
expect(result).toBe(serverMessages);
});
it('accepts a shorter result when the server result is not a prefix of cache', () => {
const currentMessages = [
message({ messageId: 'persisted-1' }),
message({ messageId: 'persisted-2' }),
message({
messageId: 'persisted-2_',
parentMessageId: 'persisted-2',
isCreatedByUser: false,
createdAt: undefined,
updatedAt: undefined,
}),
];
const serverMessages = [message({ messageId: 'different-1' })];
const result = getStableMessages({
pathname: '/c/convo-id',
result: serverMessages,
currentMessages,
isStreaming: true,
});
expect(result).toBe(serverMessages);
});
it('accepts a shorter result when an unhydrated assistant tail has no parent turn', () => {
const currentMessages = [
message({ messageId: 'persisted-1' }),
message({
messageId: 'orphaned-response_',
isCreatedByUser: false,
createdAt: undefined,
updatedAt: undefined,
}),
];
const serverMessages = [currentMessages[0]];
const result = getStableMessages({
pathname: '/c/convo-id',
result: serverMessages,
currentMessages,
isStreaming: true,
});
expect(result).toBe(serverMessages);
});
it('accepts fewer persisted messages when the cache is fully hydrated', () => {
const currentMessages = [
message({ messageId: 'persisted-1' }),
@ -63,6 +160,7 @@ describe('getStableMessages', () => {
pathname: '/c/convo-id',
result: serverMessages,
currentMessages,
isStreaming: true,
});
expect(result).toBe(serverMessages);
@ -75,6 +173,7 @@ describe('getStableMessages', () => {
pathname: '/c/new',
result: [],
currentMessages,
isStreaming: true,
});
expect(result).toEqual([]);

View file

@ -1,26 +1,46 @@
import { useLayoutEffect, useRef } from 'react';
import { useLocation } from 'react-router-dom';
import { useQuery, useQueryClient } from '@tanstack/react-query';
import type { UseQueryOptions, QueryObserverResult } from '@tanstack/react-query';
import { QueryKeys, dataService } from 'librechat-data-provider';
import type { UseQueryOptions, QueryObserverResult, QueryClient } from '@tanstack/react-query';
import { Constants, QueryKeys, dataService } from 'librechat-data-provider';
import type * as t from 'librechat-data-provider';
import { logger } from '~/utils';
type StableMessagesParams = {
pathname: string;
result: t.TMessage[];
isStreaming?: boolean;
currentMessages?: t.TMessage[];
};
function hasUnhydratedMessage(messages: t.TMessage[]) {
return messages.some((message) => {
const messageId = message.messageId ?? '';
return message.createdAt == null || message.updatedAt == null || messageId.endsWith('_');
});
type ActiveJobs = {
activeJobIds?: string[];
};
function isUnhydratedMessage(message: t.TMessage) {
const messageId = message.messageId ?? '';
return message.createdAt == null || message.updatedAt == null || messageId.endsWith('_');
}
function hasPendingAssistantTail(messages: t.TMessage[]) {
const lastMessage = messages[messages.length - 1];
const parentMessageId = lastMessage?.parentMessageId ?? '';
return (
lastMessage?.isCreatedByUser !== true &&
parentMessageId !== '' &&
parentMessageId !== Constants.NO_PARENT &&
isUnhydratedMessage(lastMessage)
);
}
function isMessagePrefix(result: t.TMessage[], currentMessages: t.TMessage[]) {
return result.every((message, index) => message.messageId === currentMessages[index]?.messageId);
}
export function getStableMessages({
pathname,
result,
isStreaming = false,
currentMessages,
}: StableMessagesParams): t.TMessage[] {
if (pathname.includes('/c/new') || !currentMessages?.length) {
@ -31,19 +51,39 @@ export function getStableMessages({
return result;
}
if (result.length === 1 || hasUnhydratedMessage(currentMessages)) {
if (
isStreaming &&
hasPendingAssistantTail(currentMessages) &&
isMessagePrefix(result, currentMessages)
) {
return currentMessages;
}
return result;
}
function hasActiveJob(queryClient: QueryClient, id: string) {
if (!id) {
return false;
}
const activeJobs = queryClient.getQueryData<ActiveJobs>([QueryKeys.activeJobs]);
return activeJobs?.activeJobIds?.includes(id) === true;
}
export const useGetMessagesByConvoId = <TData = t.TMessage[]>(
id: string,
config?: UseQueryOptions<t.TMessage[], unknown, TData>,
options?: { isStreaming?: boolean },
): QueryObserverResult<TData> => {
const location = useLocation();
const queryClient = useQueryClient();
const isStreaming = options?.isStreaming === true;
const isStreamingRef = useRef(isStreaming);
useLayoutEffect(() => {
isStreamingRef.current = isStreaming;
}, [isStreaming]);
return useQuery<t.TMessage[], unknown, TData>(
[QueryKeys.messages, id],
async () => {
@ -53,6 +93,7 @@ export const useGetMessagesByConvoId = <TData = t.TMessage[]>(
pathname: location.pathname,
result,
currentMessages,
isStreaming: isStreamingRef.current || hasActiveJob(queryClient, id),
});
if (stableMessages === currentMessages) {