🍪 refactor: Refresh CloudFront Media Cookies (#13091)

* fix: refresh CloudFront media cookies

* fix: satisfy changed-file lint

* fix: centralize CloudFront image retry

* fix: honor base path for CloudFront refresh

* fix: bypass auth refresh for CloudFront cookie retry

* fix: pass app auth header to CloudFront retry

* test: cover CloudFront refresh with OpenID reuse

* fix: avoid duplicate CloudFront refresh retries

* fix: clear CloudFront scope cookie with matching flags
This commit is contained in:
Danny Avila 2026-05-12 13:26:05 -04:00 committed by GitHub
parent 05d4e90f91
commit 6b5596ec36
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 1294 additions and 31 deletions

View file

@ -53,6 +53,7 @@ jest.mock('@librechat/api', () => {
const { tenantStorage } = require('@librechat/data-schemas');
return {
isEnabled: jest.fn(() => false),
maybeRefreshCloudFrontAuthCookiesMiddleware: jest.fn((req, res, next) => next()),
tenantContextMiddleware: (req, res, next) => {
const tenantId = req.user?.tenantId;
if (!tenantId) {
@ -67,7 +68,7 @@ jest.mock('@librechat/api', () => {
const requireJwtAuth = require('../requireJwtAuth');
const { getTenantId } = require('@librechat/data-schemas');
const { isEnabled } = require('@librechat/api');
const { isEnabled, maybeRefreshCloudFrontAuthCookiesMiddleware } = require('@librechat/api');
const passport = require('passport');
const jwtSecret = 'test-refresh-secret';
@ -108,6 +109,7 @@ describe('requireJwtAuth tenant context chaining', () => {
mockPassportError = null;
mockRegisteredStrategies = new Set(['jwt']);
isEnabled.mockReturnValue(false);
maybeRefreshCloudFrontAuthCookiesMiddleware.mockClear();
passport.authenticate.mockClear();
passport._strategy.mockClear();
if (originalJwtSecret === undefined) {
@ -134,6 +136,21 @@ describe('requireJwtAuth tenant context chaining', () => {
expect(tenantId).toBe('tenant-abc');
});
it('refreshes CloudFront auth cookies after passport auth succeeds', () => {
const req = mockReq({ tenantId: 'tenant-abc', role: 'user' });
const res = mockRes();
const next = jest.fn();
requireJwtAuth(req, res, next);
expect(maybeRefreshCloudFrontAuthCookiesMiddleware).toHaveBeenCalledWith(
req,
res,
expect.any(Function),
);
expect(next).toHaveBeenCalled();
});
it('ALS tenant context is NOT set when user has no tenantId', async () => {
const tenantId = await runAuth({ role: 'user' });
expect(tenantId).toBeUndefined();
@ -201,6 +218,11 @@ describe('requireJwtAuth tenant context chaining', () => {
{ session: false },
expect.any(Function),
);
expect(maybeRefreshCloudFrontAuthCookiesMiddleware).toHaveBeenCalledWith(
req,
res,
expect.any(Function),
);
});
it('does not authenticate OpenID JWT when the reuse cookie belongs to another user', () => {
@ -236,6 +258,7 @@ describe('requireJwtAuth tenant context chaining', () => {
{ session: false },
expect.any(Function),
);
expect(maybeRefreshCloudFrontAuthCookiesMiddleware).not.toHaveBeenCalled();
});
it('does not use OpenID JWT when the signed OpenID reuse cookie is missing', () => {
@ -262,6 +285,7 @@ describe('requireJwtAuth tenant context chaining', () => {
{ session: false },
expect.any(Function),
);
expect(maybeRefreshCloudFrontAuthCookiesMiddleware).not.toHaveBeenCalled();
});
it('does not use OpenID JWT when the OpenID reuse cookie is invalid', () => {
@ -288,6 +312,7 @@ describe('requireJwtAuth tenant context chaining', () => {
{ session: false },
expect.any(Function),
);
expect(maybeRefreshCloudFrontAuthCookiesMiddleware).not.toHaveBeenCalled();
});
it('skips OpenID JWT fallback when the strategy was not registered', async () => {

View file

@ -1,7 +1,11 @@
const cookies = require('cookie');
const jwt = require('jsonwebtoken');
const passport = require('passport');
const { isEnabled, tenantContextMiddleware } = require('@librechat/api');
const {
isEnabled,
tenantContextMiddleware,
maybeRefreshCloudFrontAuthCookiesMiddleware,
} = require('@librechat/api');
const hasPassportStrategy = (strategy) =>
typeof passport._strategy === 'function' && passport._strategy(strategy) != null;
@ -23,6 +27,8 @@ const getValidOpenIdReuseUserId = (parsedCookies) => {
};
const getAuthenticatedUserId = (user) => user?.id?.toString?.() ?? user?._id?.toString?.();
const refreshCloudFrontCookies =
maybeRefreshCloudFrontAuthCookiesMiddleware ?? ((_req, _res, next) => next());
/**
* Custom Middleware to handle JWT authentication, with support for OpenID token reuse.
@ -65,8 +71,13 @@ const requireJwtAuth = (req, res, next) => {
}
req.user = user;
req.authStrategy = strategy;
// req.user is now populated by passport — set up tenant ALS context
tenantContextMiddleware(req, res, next);
refreshCloudFrontCookies(req, res, (refreshErr) => {
if (refreshErr) {
return next(refreshErr);
}
// req.user is now populated by passport — set up tenant ALS context
tenantContextMiddleware(req, res, next);
});
})(req, res, next);
};

View file

@ -20,6 +20,12 @@ jest.mock('@librechat/data-schemas', () => ({
getTenantId: (...args) => mockGetTenantId(...args),
}));
const mockGetCloudFrontConfig = jest.fn(() => null);
jest.mock('@librechat/api', () => ({
...jest.requireActual('@librechat/api'),
getCloudFrontConfig: (...args) => mockGetCloudFrontConfig(...args),
}));
const request = require('supertest');
const express = require('express');
const configRoute = require('../config');
@ -187,6 +193,53 @@ describe('GET /api/config', () => {
expect(response.body).toHaveProperty('serverDomain');
});
it('should advertise CloudFront cookie refresh only when signed-cookie mode is active', async () => {
mockGetAppConfig.mockResolvedValue(baseAppConfig);
mockGetCloudFrontConfig.mockReturnValue({
domain: 'https://cdn.example.com',
imageSigning: 'cookies',
cookieDomain: '.example.com',
privateKey: 'test-private-key',
keyPairId: 'K123ABC',
});
const app = createApp(null);
const response = await request(app).get('/api/config');
expect(response.body.cloudFront).toEqual({
cookieRefresh: {
endpoint: '/api/auth/cloudfront/refresh',
domain: 'https://cdn.example.com',
},
});
});
it('should omit CloudFront cookie refresh when signed-cookie mode is inactive', async () => {
mockGetAppConfig.mockResolvedValue(baseAppConfig);
mockGetCloudFrontConfig.mockReturnValue({
domain: 'https://cdn.example.com',
imageSigning: 'url',
});
const app = createApp(null);
const response = await request(app).get('/api/config');
expect(response.body).not.toHaveProperty('cloudFront');
});
it('should omit CloudFront cookie refresh when cookie mode cannot mint cookies', async () => {
mockGetAppConfig.mockResolvedValue(baseAppConfig);
mockGetCloudFrontConfig.mockReturnValue({
domain: 'https://cdn.example.com',
imageSigning: 'cookies',
});
const app = createApp(null);
const response = await request(app).get('/api/config');
expect(response.body).not.toHaveProperty('cloudFront');
});
it('should default allowAccountDeletion to true when env var is unset', async () => {
mockGetAppConfig.mockResolvedValue(baseAppConfig);
const app = createApp(null);

View file

@ -0,0 +1,154 @@
const express = require('express');
const request = require('supertest');
const mockForceRefreshCloudFrontAuthCookies = jest.fn();
jest.mock('@librechat/api', () => ({
createSetBalanceConfig: jest.fn(() => (req, res, next) => next()),
forceRefreshCloudFrontAuthCookies: (...args) => mockForceRefreshCloudFrontAuthCookies(...args),
}));
jest.mock('~/server/controllers/AuthController', () => ({
refreshController: jest.fn((req, res) => res.status(200).end()),
registrationController: jest.fn((req, res) => res.status(200).end()),
resetPasswordController: jest.fn((req, res) => res.status(200).end()),
resetPasswordRequestController: jest.fn((req, res) => res.status(200).end()),
graphTokenController: jest.fn((req, res) => res.status(200).end()),
}));
jest.mock('~/server/controllers/TwoFactorController', () => ({
enable2FA: jest.fn((req, res) => res.status(200).end()),
verify2FA: jest.fn((req, res) => res.status(200).end()),
confirm2FA: jest.fn((req, res) => res.status(200).end()),
disable2FA: jest.fn((req, res) => res.status(200).end()),
regenerateBackupCodes: jest.fn((req, res) => res.status(200).end()),
}));
jest.mock('~/server/controllers/auth/TwoFactorAuthController', () => ({
verify2FAWithTempToken: jest.fn((req, res) => res.status(200).end()),
}));
jest.mock('~/server/controllers/auth/LogoutController', () => ({
logoutController: jest.fn((req, res) => res.status(200).end()),
}));
jest.mock('~/server/controllers/auth/LoginController', () => ({
loginController: jest.fn((req, res) => res.status(200).end()),
}));
jest.mock('~/models', () => ({
findBalanceByUser: jest.fn(),
upsertBalanceFields: jest.fn(),
}));
jest.mock('~/server/services/Config', () => ({
getAppConfig: jest.fn(),
}));
jest.mock('~/server/middleware', () => {
const pass = (req, res, next) => next();
return {
logHeaders: pass,
loginLimiter: pass,
checkBan: pass,
requireLocalAuth: pass,
requireLdapAuth: pass,
registerLimiter: pass,
checkInviteUser: pass,
validateRegistration: pass,
resetPasswordLimiter: pass,
validatePasswordReset: pass,
requireJwtAuth: jest.fn((req, res, next) => {
if (req.headers.authorization !== 'Bearer ok') {
return res.status(401).json({ message: 'Unauthorized' });
}
req.user = { _id: 'user123', tenantId: 'tenantA' };
if (req.headers['x-cloudfront-warmed'] === 'true') {
req.cloudFrontAuthCookieRefreshResult = {
enabled: true,
attempted: true,
refreshed: true,
expiresInSec: 1800,
refreshAfterSec: 1500,
};
}
return next();
}),
};
});
const authRouter = require('./auth');
describe('POST /api/auth/cloudfront/refresh', () => {
let app;
beforeEach(() => {
jest.clearAllMocks();
app = express();
app.use(express.json());
app.use('/api/auth', authRouter);
});
it('requires authentication', async () => {
await request(app).post('/api/auth/cloudfront/refresh').expect(401);
expect(mockForceRefreshCloudFrontAuthCookies).not.toHaveBeenCalled();
});
it('returns 404 when CloudFront cookie mode is disabled', async () => {
mockForceRefreshCloudFrontAuthCookies.mockReturnValue({
enabled: false,
attempted: false,
refreshed: false,
reason: 'cloudfront_disabled',
});
const response = await request(app)
.post('/api/auth/cloudfront/refresh')
.set('Authorization', 'Bearer ok')
.expect(404);
expect(response.status).toBe(404);
});
it('returns cookie refresh timing when CloudFront cookies are refreshed', async () => {
mockForceRefreshCloudFrontAuthCookies.mockReturnValue({
enabled: true,
attempted: true,
refreshed: true,
expiresInSec: 1800,
refreshAfterSec: 1500,
});
const response = await request(app)
.post('/api/auth/cloudfront/refresh')
.set('Authorization', 'Bearer ok')
.expect(200);
expect(response.body).toEqual({
ok: true,
expiresInSec: 1800,
refreshAfterSec: 1500,
});
expect(mockForceRefreshCloudFrontAuthCookies).toHaveBeenCalledWith(
expect.objectContaining({ user: { _id: 'user123', tenantId: 'tenantA' } }),
expect.any(Object),
{ _id: 'user123', tenantId: 'tenantA' },
);
});
it('reuses the auth middleware refresh result instead of minting cookies twice', async () => {
const response = await request(app)
.post('/api/auth/cloudfront/refresh')
.set('Authorization', 'Bearer ok')
.set('x-cloudfront-warmed', 'true')
.expect(200);
expect(response.body).toEqual({
ok: true,
expiresInSec: 1800,
refreshAfterSec: 1500,
});
expect(mockForceRefreshCloudFrontAuthCookies).not.toHaveBeenCalled();
});
});

View file

@ -1,5 +1,5 @@
const express = require('express');
const { createSetBalanceConfig } = require('@librechat/api');
const { createSetBalanceConfig, forceRefreshCloudFrontAuthCookies } = require('@librechat/api');
const {
resetPasswordRequestController,
resetPasswordController,
@ -28,6 +28,14 @@ const setBalanceConfig = createSetBalanceConfig({
});
const router = express.Router();
const getCloudFrontAuthCookieRefreshResult = (req, res) => {
const warmedResult = req.cloudFrontAuthCookieRefreshResult;
if (warmedResult && (warmedResult.attempted || !warmedResult.enabled)) {
return warmedResult;
}
return forceRefreshCloudFrontAuthCookies(req, res, req.user);
};
const ldapAuth = !!process.env.LDAP_URL && !!process.env.LDAP_USER_SEARCH_BASE;
//Local
@ -42,6 +50,19 @@ router.post(
loginController,
);
router.post('/refresh', refreshController);
router.post('/cloudfront/refresh', middleware.requireJwtAuth, (req, res) => {
const result = getCloudFrontAuthCookieRefreshResult(req, res);
if (!result.enabled) {
return res.sendStatus(404);
}
const status = result.refreshed ? 200 : 500;
return res.status(status).json({
ok: result.refreshed,
expiresInSec: result.expiresInSec,
refreshAfterSec: result.refreshAfterSec,
});
});
router.post(
'/register',
middleware.registerLimiter,

View file

@ -1,5 +1,5 @@
const express = require('express');
const { isEnabled, getBalanceConfig } = require('@librechat/api');
const { isEnabled, getBalanceConfig, getCloudFrontConfig } = require('@librechat/api');
const { defaultSocialLogins } = require('librechat-data-provider');
const { logger, getTenantId, SystemCapabilities } = require('@librechat/data-schemas');
const { hasCapability } = require('~/server/middleware/roles/capabilities');
@ -116,9 +116,30 @@ function buildWebSearchConfig(appConfig) {
};
}
function buildCloudFrontStartupConfig() {
const config = getCloudFrontConfig();
if (
config?.imageSigning !== 'cookies' ||
!config.domain ||
!config.cookieDomain ||
!config.privateKey ||
!config.keyPairId
) {
return undefined;
}
return {
cookieRefresh: {
endpoint: '/api/auth/cloudfront/refresh',
domain: config.domain,
},
};
}
router.get('/', async function (req, res) {
try {
const sharedPayload = buildSharedPayload();
const cloudFront = buildCloudFrontStartupConfig();
if (!req.user) {
const tenantId = getTenantId();
@ -129,6 +150,7 @@ router.get('/', async function (req, res) {
...sharedPayload,
socialLogins: baseConfig?.registration?.socialLogins ?? defaultSocialLogins,
turnstile: baseConfig?.turnstileConfig,
...(cloudFront ? { cloudFront } : {}),
};
const interfaceConfig = baseConfig?.interfaceConfig;
@ -170,6 +192,7 @@ router.get('/', async function (req, res) {
conversationImportMaxFileSize: process.env.CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES
? parseInt(process.env.CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES, 10)
: 0,
...(cloudFront ? { cloudFront } : {}),
};
const webSearch = buildWebSearchConfig(appConfig);

View file

@ -7,6 +7,21 @@ import type { TUser } from 'librechat-data-provider';
const mockUseHasAccess = jest.fn();
const mockUseMCPServersQuery = jest.fn();
const mockUseMCPToolsQuery = jest.fn();
const mockInstallCloudFrontImageRetry = jest.fn(() => jest.fn());
const mockGetTokenHeader = jest.fn();
jest.mock('@librechat/client', () => ({
installCloudFrontImageRetry: (startupConfig: unknown, options: unknown) =>
mockInstallCloudFrontImageRetry(startupConfig, options),
}));
jest.mock('librechat-data-provider', () => {
const actual = jest.requireActual('librechat-data-provider');
return {
...actual,
getTokenHeader: () => mockGetTokenHeader(),
};
});
jest.mock('~/hooks', () => ({
useHasAccess: (args: unknown) => mockUseHasAccess(args),
@ -52,6 +67,7 @@ const wrapper: React.FC<{ children: React.ReactNode }> = ({ children }) => (
describe('useAppStartup — MCP permission gating', () => {
beforeEach(() => {
mockInstallCloudFrontImageRetry.mockClear();
mockUseMCPServersQuery.mockReturnValue({ data: undefined, isLoading: false });
mockUseMCPToolsQuery.mockReturnValue({ data: undefined, isLoading: false });
});
@ -120,4 +136,27 @@ describe('useAppStartup — MCP permission gating', () => {
expect(mockUseMCPToolsQuery).toHaveBeenCalledWith({ enabled: false });
});
it('installs CloudFront image retry from startup config', () => {
mockUseHasAccess.mockReturnValue(false);
const startupConfig = {
cloudFront: {
cookieRefresh: {
endpoint: '/api/auth/cloudfront/refresh',
domain: 'https://cdn.example.com',
},
},
} as never;
renderHook(() => useAppStartup({ startupConfig, user: mockUser }), { wrapper });
expect(mockInstallCloudFrontImageRetry).toHaveBeenCalledWith(startupConfig, {
getAuthorizationHeader: expect.any(Function),
});
const [, options] = mockInstallCloudFrontImageRetry.mock.calls[0];
mockGetTokenHeader.mockReturnValue('Bearer app-token');
expect(options.getAuthorizationHeader()).toBe('Bearer app-token');
expect(mockGetTokenHeader).toHaveBeenCalledTimes(1);
});
});

View file

@ -1,7 +1,13 @@
import { useEffect } from 'react';
import { useRecoilState } from 'recoil';
import TagManager from 'react-gtm-module';
import { LocalStorageKeys, PermissionTypes, Permissions } from 'librechat-data-provider';
import { installCloudFrontImageRetry } from '@librechat/client';
import {
getTokenHeader,
LocalStorageKeys,
PermissionTypes,
Permissions,
} from 'librechat-data-provider';
import type { TStartupConfig, TUser } from 'librechat-data-provider';
import { useMCPToolsQuery, useMCPServersQuery } from '~/data-provider';
import { cleanupTimestampedStorage } from '~/utils/timestamps';
@ -76,6 +82,10 @@ export default function useAppStartup({
});
}, [defaultPreset, setDefaultPreset, startupConfig?.modelSpecs?.list]);
useEffect(() => {
return installCloudFrontImageRetry(startupConfig, { getAuthorizationHeader: getTokenHeader });
}, [startupConfig]);
useEffect(() => {
if (startupConfig?.analyticsGtmId != null && typeof window.google_tag_manager === 'undefined') {
const tagManagerArgs = {

View file

@ -19,6 +19,8 @@ import type { Response } from 'express';
import {
setCloudFrontCookies,
clearCloudFrontCookies,
forceRefreshCloudFrontAuthCookies,
maybeRefreshCloudFrontAuthCookies,
parseCloudFrontCookieScope,
} from '../cloudfront-cookies';
@ -27,6 +29,24 @@ const { logger: mockLogger } = jest.requireMock('@librechat/data-schemas') as {
};
const defaultScope = { userId: 'user123' };
const encodeScope = (scope: object) =>
Buffer.from(JSON.stringify(scope), 'utf8').toString('base64url');
afterEach(() => {
jest.restoreAllMocks();
});
function defaultCookieConfig(overrides: object = {}) {
return {
domain: 'https://cdn.example.com',
imageSigning: 'cookies',
cookieExpiry: 1800,
cookieDomain: '.example.com',
privateKey: '-----BEGIN RSA PRIVATE KEY-----\ntest\n-----END RSA PRIVATE KEY-----',
keyPairId: 'K123ABC',
...overrides,
};
}
describe('setCloudFrontCookies', () => {
let mockRes: Partial<Response>;
@ -156,6 +176,33 @@ describe('setCloudFrontCookies', () => {
expect(cookieNames).toContain('CloudFront-Key-Pair-Id');
});
it('sets a non-HttpOnly scope cookie with issuedAt and expiresAt timing', () => {
jest.spyOn(Date, 'now').mockReturnValue(1_700_000_000_000);
mockGetCloudFrontConfig.mockReturnValue(defaultCookieConfig({ cookieExpiry: 1800 }));
mockGetSignedCookies.mockReturnValue({
'CloudFront-Policy': 'policy-value',
'CloudFront-Signature': 'signature-value',
'CloudFront-Key-Pair-Id': 'K123ABC',
});
const result = setCloudFrontCookies(mockRes as Response, {
userId: 'user123',
tenantId: 'tenantA',
storageRegion: 'us-east-2',
});
expect(result).toBe(true);
const [, value, options] = cookieArgs.find(([name]) => name === 'LibreChat-CloudFront-Scope')!;
expect(options).toMatchObject({ httpOnly: false, path: '/' });
expect(parseCloudFrontCookieScope(value)).toEqual({
userId: 'user123',
tenantId: 'tenantA',
storageRegion: 'us-east-2',
issuedAt: 1_700_000_000,
expiresAt: 1_700_001_800,
});
});
it('uses cookieDomain from config with path-scoped cookies', () => {
mockGetCloudFrontConfig.mockReturnValue({
domain: 'https://cdn.example.com',
@ -299,8 +346,8 @@ describe('setCloudFrontCookies', () => {
const [name, value, options] = cookieArgs[cookieArgs.length - 1];
expect(name).toBe('LibreChat-CloudFront-Scope');
expect(options).toMatchObject({ domain: '.example.com', path: '/' });
expect(Buffer.from(value, 'base64url').toString('utf8')).toBe(
JSON.stringify({ userId: 'user123', tenantId: 'tenantA', storageRegion: null }),
expect(parseCloudFrontCookieScope(value)).toEqual(
expect.objectContaining({ userId: 'user123', tenantId: 'tenantA' }),
);
});
@ -417,8 +464,12 @@ describe('setCloudFrontCookies', () => {
expect(cookieArgs[3][2]).toMatchObject({ path: '/a' });
const [, scopeValue] = cookieArgs[cookieArgs.length - 1];
expect(Buffer.from(scopeValue, 'base64url').toString('utf8')).toBe(
JSON.stringify({ userId: 'user123', tenantId: 'tenantA', storageRegion: 'us-east-2' }),
expect(parseCloudFrontCookieScope(scopeValue)).toEqual(
expect.objectContaining({
userId: 'user123',
tenantId: 'tenantA',
storageRegion: 'us-east-2',
}),
);
});
@ -461,8 +512,8 @@ describe('setCloudFrontCookies', () => {
Resource: 'https://cdn.example.com/a/r/*/t/tenantA/avatars/*',
}),
]);
expect(Buffer.from(scopeValue, 'base64url').toString('utf8')).toBe(
JSON.stringify({ userId: 'user123', tenantId: 'tenantA', storageRegion: null }),
expect(parseCloudFrontCookieScope(scopeValue)).toEqual(
expect.objectContaining({ userId: 'user123', tenantId: 'tenantA' }),
);
} finally {
if (originalRegion == null) {
@ -640,9 +691,6 @@ describe('setCloudFrontCookies', () => {
});
describe('parseCloudFrontCookieScope', () => {
const encodeScope = (scope: object) =>
Buffer.from(JSON.stringify(scope), 'utf8').toString('base64url');
it('round-trips a valid user and tenant scope', () => {
const value = encodeScope({ userId: 'user123', tenantId: 'tenantA' });
@ -667,6 +715,163 @@ describe('parseCloudFrontCookieScope', () => {
parseCloudFrontCookieScope(encodeScope({ userId: 'user123', tenantId: 'tenant A' })),
).toBeNull();
});
it('handles old scope cookies without timing fields', () => {
expect(parseCloudFrontCookieScope(encodeScope({ userId: 'user123' }))).toEqual({
userId: 'user123',
});
});
it('drops invalid timing fields while preserving valid scope', () => {
expect(
parseCloudFrontCookieScope(
encodeScope({ userId: 'user123', issuedAt: 'bad', expiresAt: Number.NaN }),
),
).toEqual({ userId: 'user123' });
});
});
describe('maybeRefreshCloudFrontAuthCookies', () => {
let mockRes: Partial<Response>;
beforeEach(() => {
jest.clearAllMocks();
mockRes = {
cookie: jest.fn().mockReturnThis(),
clearCookie: jest.fn().mockReturnThis(),
};
mockGetSignedCookies.mockReturnValue({
'CloudFront-Policy': 'policy-value',
'CloudFront-Signature': 'signature-value',
'CloudFront-Key-Pair-Id': 'K123ABC',
});
mockGetCloudFrontConfig.mockReturnValue(defaultCookieConfig());
jest.spyOn(Date, 'now').mockReturnValue(1_700_000_000_000);
});
it('refreshes when the scope cookie is missing', () => {
const result = maybeRefreshCloudFrontAuthCookies({ cookies: {} }, mockRes as Response, {
_id: 'user123',
});
expect(result).toMatchObject({ enabled: true, attempted: true, refreshed: true });
expect(mockGetSignedCookies).toHaveBeenCalled();
});
it('refreshes when the scope cookie is near expiry', () => {
const result = maybeRefreshCloudFrontAuthCookies(
{
cookies: {
'LibreChat-CloudFront-Scope': encodeScope({
userId: 'user123',
expiresAt: 1_700_000_250,
}),
},
},
mockRes as Response,
{ _id: 'user123' },
);
expect(result).toMatchObject({ attempted: true, refreshed: true, reason: 'near_expiry' });
});
it('refreshes when the tenant or user scope mismatches', () => {
const userMismatch = maybeRefreshCloudFrontAuthCookies(
{
cookies: {
'LibreChat-CloudFront-Scope': encodeScope({
userId: 'old-user',
tenantId: 'tenantA',
expiresAt: 1_700_001_000,
}),
},
},
mockRes as Response,
{ _id: 'user123', tenantId: 'tenantA' },
);
const tenantMismatch = maybeRefreshCloudFrontAuthCookies(
{
cookies: {
'LibreChat-CloudFront-Scope': encodeScope({
userId: 'user123',
tenantId: 'old-tenant',
expiresAt: 1_700_001_000,
}),
},
},
mockRes as Response,
{ _id: 'user123', tenantId: 'tenantA' },
);
expect(userMismatch).toMatchObject({ attempted: true, reason: 'user_scope_mismatch' });
expect(tenantMismatch).toMatchObject({ attempted: true, reason: 'tenant_scope_mismatch' });
});
it('does not refresh when the scope cookie is still fresh', () => {
const result = maybeRefreshCloudFrontAuthCookies(
{
cookies: {
'LibreChat-CloudFront-Scope': encodeScope({
userId: 'user123',
expiresAt: 1_700_001_000,
}),
},
},
mockRes as Response,
{ _id: 'user123' },
);
expect(result).toMatchObject({
enabled: true,
attempted: false,
refreshed: false,
reason: 'fresh',
});
expect(mockGetSignedCookies).not.toHaveBeenCalled();
});
it('does not refresh when CloudFront is disabled', () => {
mockGetCloudFrontConfig.mockReturnValue(null);
const result = maybeRefreshCloudFrontAuthCookies({ cookies: {} }, mockRes as Response, {
_id: 'user123',
});
expect(result).toMatchObject({ enabled: false, attempted: false, refreshed: false });
expect(mockGetSignedCookies).not.toHaveBeenCalled();
});
it('does not refresh when imageSigning is not cookies', () => {
mockGetCloudFrontConfig.mockReturnValue(defaultCookieConfig({ imageSigning: 'url' }));
const result = maybeRefreshCloudFrontAuthCookies({ cookies: {} }, mockRes as Response, {
_id: 'user123',
});
expect(result).toMatchObject({ enabled: false, attempted: false, refreshed: false });
expect(mockGetSignedCookies).not.toHaveBeenCalled();
});
it('force-refreshes even when the scope cookie is fresh without calling OIDC refresh', () => {
const oidcRefresh = jest.fn();
const result = forceRefreshCloudFrontAuthCookies(
{
cookies: {
'LibreChat-CloudFront-Scope': encodeScope({
userId: 'user123',
expiresAt: 1_700_001_000,
}),
},
},
mockRes as Response,
{ _id: 'user123' },
);
expect(result).toMatchObject({ attempted: true, refreshed: true, reason: 'forced' });
expect(oidcRefresh).not.toHaveBeenCalled();
});
});
describe('clearCloudFrontCookies', () => {
@ -742,12 +947,17 @@ describe('clearCloudFrontCookies', () => {
secure: true,
sameSite: 'none',
};
const scopePathOptions = {
...rootPathOptions,
httpOnly: false,
};
expect(clearedCookies).toContainEqual(['CloudFront-Policy', legacyPathOptions]);
expect(clearedCookies).toContainEqual(['CloudFront-Signature', legacyPathOptions]);
expect(clearedCookies).toContainEqual(['CloudFront-Key-Pair-Id', legacyPathOptions]);
expect(clearedCookies).toContainEqual(['CloudFront-Policy', rootPathOptions]);
expect(clearedCookies).toContainEqual(['CloudFront-Signature', rootPathOptions]);
expect(clearedCookies).toContainEqual(['CloudFront-Key-Pair-Id', rootPathOptions]);
expect(clearedCookies).toContainEqual(['LibreChat-CloudFront-Scope', scopePathOptions]);
expect(clearedCookies).toContainEqual([
'CloudFront-Policy',
expect.objectContaining({ path: '/r' }),
@ -799,7 +1009,7 @@ describe('clearCloudFrontCookies', () => {
{
domain: '.example.com',
path: '/',
httpOnly: true,
httpOnly: false,
secure: true,
sameSite: 'none',
},

View file

@ -1,7 +1,7 @@
import { getSignedCookies } from '@aws-sdk/cloudfront-signer';
import { logger } from '@librechat/data-schemas';
import type { Response } from 'express';
import type { NextFunction, Response } from 'express';
import { INLINE_AVATAR_PATH_PREFIX, INLINE_IMAGE_PATH_PREFIX } from '~/storage/constants';
import { assertPathSegment } from '~/storage/validation';
@ -23,8 +23,44 @@ export interface CloudFrontCookieScope {
userId?: string | null;
tenantId?: string | null;
storageRegion?: string | null;
issuedAt?: number | null;
expiresAt?: number | null;
}
type CloudFrontScopeValue = string | number | { toString(): string } | null | undefined;
type CloudFrontScopeUser = {
_id?: CloudFrontScopeValue;
id?: CloudFrontScopeValue;
tenantId?: CloudFrontScopeValue;
orgId?: CloudFrontScopeValue;
storageRegion?: CloudFrontScopeValue;
};
type CloudFrontCookieRequest = {
cookies?: Partial<Record<string, string>>;
user?: CloudFrontScopeUser | null;
};
type CloudFrontAuthCookieRefreshRequest = CloudFrontCookieRequest & {
cloudFrontAuthCookieRefreshResult?: CloudFrontAuthCookieRefreshResult;
};
export type CloudFrontAuthCookieRefreshResult = {
enabled: boolean;
attempted: boolean;
refreshed: boolean;
reason?: string;
expiresInSec?: number;
refreshAfterSec?: number;
};
type CloudFrontCookieRefreshOptions = CloudFrontCookieScope & {
orgId?: CloudFrontScopeValue;
force?: boolean;
refreshWindowSec?: number;
};
type CookieOptions = {
domain: string;
httpOnly: boolean;
@ -106,6 +142,41 @@ function getPolicyScopes(
];
}
function getConfiguredCookieExpiry(): number {
const config = getCloudFrontConfig();
return config?.cookieExpiry ?? DEFAULT_COOKIE_EXPIRY;
}
export function getCloudFrontCookieRefreshWindowSec(cookieExpiry = getConfiguredCookieExpiry()) {
return Math.min(300, Math.floor(cookieExpiry / 4));
}
export function getCloudFrontCookieTiming() {
const expiresInSec = getConfiguredCookieExpiry();
const refreshWindowSec = getCloudFrontCookieRefreshWindowSec(expiresInSec);
return {
expiresInSec,
refreshAfterSec: Math.max(0, expiresInSec - refreshWindowSec),
refreshWindowSec,
};
}
function getEffectiveCloudFrontScope(
scope: CloudFrontCookieScope,
includeRegionInPath: boolean,
): CloudFrontCookieScope {
const configuredStorageRegion =
scope.storageRegion ??
getCloudFrontConfig()?.storageRegion ??
s3Config.AWS_REGION ??
process.env.AWS_REGION;
const scopedStorageRegion = includeRegionInPath ? configuredStorageRegion : scope.storageRegion;
return {
...scope,
...(scopedStorageRegion ? { storageRegion: scopedStorageRegion } : {}),
};
}
function getScopeCookiePaths(
scope: CloudFrontCookieScope,
{ includeTenantRoot = false }: { includeTenantRoot?: boolean } = {},
@ -136,6 +207,8 @@ function encodeCloudFrontCookieScope(scope: CloudFrontCookieScope): string {
userId: scope.userId ?? null,
tenantId: scope.tenantId ?? null,
storageRegion: scope.storageRegion ?? null,
issuedAt: scope.issuedAt ?? null,
expiresAt: scope.expiresAt ?? null,
};
return Buffer.from(JSON.stringify(payload), 'utf8').toString('base64url');
}
@ -152,6 +225,8 @@ export function parseCloudFrontCookieScope(
userId?: unknown;
tenantId?: unknown;
storageRegion?: unknown;
issuedAt?: unknown;
expiresAt?: unknown;
};
const scope: CloudFrontCookieScope = {};
if (typeof parsed.userId === 'string') {
@ -163,12 +238,111 @@ export function parseCloudFrontCookieScope(
if (typeof parsed.storageRegion === 'string') {
scope.storageRegion = assertPolicyPathSegment('storageRegion', parsed.storageRegion);
}
if (typeof parsed.issuedAt === 'number' && Number.isFinite(parsed.issuedAt)) {
scope.issuedAt = parsed.issuedAt;
}
if (typeof parsed.expiresAt === 'number' && Number.isFinite(parsed.expiresAt)) {
scope.expiresAt = parsed.expiresAt;
}
return scope.userId ? scope : null;
} catch {
return null;
}
}
function normalizeCloudFrontScopeValue(value: CloudFrontScopeValue): string | undefined {
if (value == null) {
return undefined;
}
const normalized = String(value);
return normalized.length > 0 ? normalized : undefined;
}
function getCloudFrontScopeValue(
optionsValue: CloudFrontScopeValue,
userValue: CloudFrontScopeValue,
requestValue: CloudFrontScopeValue,
): string | undefined {
return normalizeCloudFrontScopeValue(optionsValue ?? userValue ?? requestValue);
}
export function resolveCloudFrontCookieScope(
req: CloudFrontCookieRequest | null | undefined,
user: CloudFrontScopeUser | null | undefined,
options: CloudFrontCookieRefreshOptions = {},
): CloudFrontCookieScope {
const storageRegion = getCloudFrontScopeValue(
options.storageRegion,
user?.storageRegion,
req?.user?.storageRegion,
);
return {
userId: getCloudFrontScopeValue(
options.userId,
user?._id ?? user?.id,
req?.user?._id ?? req?.user?.id,
),
tenantId: getCloudFrontScopeValue(
options.tenantId ?? options.orgId,
user?.tenantId ?? user?.orgId,
req?.user?.tenantId ?? req?.user?.orgId,
),
...(storageRegion ? { storageRegion } : {}),
};
}
function getPreviousCloudFrontScope(
req: CloudFrontCookieRequest | null | undefined,
): CloudFrontCookieScope | null {
return parseCloudFrontCookieScope(req?.cookies?.[CLOUDFRONT_SCOPE_COOKIE]);
}
function getCloudFrontCookieSkipReason(scope: CloudFrontCookieScope): string | null {
const config = getCloudFrontConfig();
if (!config || config.imageSigning !== 'cookies' || !config.privateKey || !config.keyPairId) {
return 'cloudfront_disabled';
}
if (!config.cookieDomain) {
return 'missing_cookie_domain';
}
if (!scope.userId) {
return 'missing_user_id';
}
return null;
}
function getScopeRefreshReason(
previousScope: CloudFrontCookieScope | null,
currentScope: CloudFrontCookieScope,
refreshWindowSec: number,
): string | null {
if (!previousScope?.userId) {
return 'missing_scope';
}
if (previousScope.userId !== currentScope.userId) {
return 'user_scope_mismatch';
}
if ((previousScope.tenantId ?? null) !== (currentScope.tenantId ?? null)) {
return 'tenant_scope_mismatch';
}
if ((previousScope.storageRegion ?? null) !== (currentScope.storageRegion ?? null)) {
return 'storage_region_scope_mismatch';
}
const expiresAt = Number(previousScope.expiresAt);
if (!Number.isFinite(expiresAt)) {
return 'missing_expiry';
}
const now = Math.floor(Date.now() / 1000);
if (expiresAt - now <= refreshWindowSec) {
return 'near_expiry';
}
return null;
}
function clearCookiePaths(
res: Response,
baseOptions: CookieOptions,
@ -215,7 +389,7 @@ export function clearCloudFrontCookies(res: Response, scope: CloudFrontCookieSco
}
clearCookiePaths(res, baseOptions, paths);
res.clearCookie(CLOUDFRONT_SCOPE_COOKIE, { ...baseOptions, path: '/' });
res.clearCookie(CLOUDFRONT_SCOPE_COOKIE, { ...baseOptions, httpOnly: false, path: '/' });
} catch (error) {
logger.warn('[clearCloudFrontCookies] Failed to clear cookies:', error);
}
@ -254,20 +428,15 @@ export function setCloudFrontCookies(
try {
const { keyPairId, privateKey } = config;
const cookieExpiry = config.cookieExpiry ?? DEFAULT_COOKIE_EXPIRY;
const expiresAtMs = Date.now() + cookieExpiry * 1000;
const cookieExpiry = getConfiguredCookieExpiry();
const issuedAtEpoch = Math.floor(Date.now() / 1000);
const expiresAtEpoch = issuedAtEpoch + cookieExpiry;
const expiresAtMs = expiresAtEpoch * 1000;
const expiresAt = new Date(expiresAtMs);
const expiresAtEpoch = Math.floor(expiresAtMs / 1000);
const cleanDomain = config.domain.replace(/\/+$/, '');
const includeRegionInPath = config.includeRegionInPath ?? false;
const configuredStorageRegion =
scope.storageRegion ?? config.storageRegion ?? s3Config.AWS_REGION ?? process.env.AWS_REGION;
const scopedStorageRegion = includeRegionInPath ? configuredStorageRegion : scope.storageRegion;
const effectiveScope = {
...scope,
...(scopedStorageRegion ? { storageRegion: scopedStorageRegion } : {}),
};
const effectiveScope = getEffectiveCloudFrontScope(scope, includeRegionInPath);
const policyScopes = getPolicyScopes(cleanDomain, effectiveScope, includeRegionInPath);
const resourcesByPath = new Map<string, string[]>();
for (const { resource, path } of policyScopes) {
@ -336,8 +505,14 @@ export function setCloudFrontCookies(
res.cookie(key, cookies[key], cookieOptions);
}
}
res.cookie(CLOUDFRONT_SCOPE_COOKIE, encodeCloudFrontCookieScope(effectiveScope), {
const scopeCookieValue = encodeCloudFrontCookieScope({
...effectiveScope,
issuedAt: issuedAtEpoch,
expiresAt: expiresAtEpoch,
});
res.cookie(CLOUDFRONT_SCOPE_COOKIE, scopeCookieValue, {
...baseCookieOptions,
httpOnly: false,
path: '/',
});
@ -351,3 +526,120 @@ export function setCloudFrontCookies(
return false;
}
}
export function maybeRefreshCloudFrontAuthCookies(
req: CloudFrontCookieRequest | null | undefined,
res: Response,
user: CloudFrontScopeUser | null | undefined,
options: CloudFrontCookieRefreshOptions = {},
): CloudFrontAuthCookieRefreshResult {
try {
const config = getCloudFrontConfig();
const scope = resolveCloudFrontCookieScope(req, user, options);
const skipReason = getCloudFrontCookieSkipReason(scope);
const timing = getCloudFrontCookieTiming();
if (skipReason) {
logger.debug('[maybeRefreshCloudFrontAuthCookies] CloudFront auth cookies skipped', {
attempted: false,
refreshed: false,
reason: skipReason,
has_user_id: Boolean(scope.userId),
has_tenant_scope: Boolean(scope.tenantId),
has_storage_region: Boolean(scope.storageRegion),
});
return {
enabled: false,
attempted: false,
refreshed: false,
reason: skipReason,
};
}
const includeRegionInPath = config?.includeRegionInPath ?? false;
const effectiveScope = getEffectiveCloudFrontScope(scope, includeRegionInPath);
const previousScope = getPreviousCloudFrontScope(req);
const refreshWindowSec = options.refreshWindowSec ?? timing.refreshWindowSec;
const refreshReason = options.force
? 'forced'
: getScopeRefreshReason(previousScope, effectiveScope, refreshWindowSec);
if (!refreshReason) {
logger.debug('[maybeRefreshCloudFrontAuthCookies] CloudFront auth cookies still fresh', {
attempted: false,
refreshed: false,
reason: 'fresh',
refresh_window_sec: refreshWindowSec,
});
return {
enabled: true,
attempted: false,
refreshed: false,
reason: 'fresh',
expiresInSec: timing.expiresInSec,
refreshAfterSec: timing.refreshAfterSec,
};
}
const cookiesSet = setCloudFrontCookies(res, effectiveScope, previousScope);
const logPayload = {
attempted: true,
refreshed: cookiesSet,
reason: cookiesSet ? refreshReason : 'set_failed',
refresh_window_sec: refreshWindowSec,
has_tenant_scope: Boolean(effectiveScope.tenantId),
has_storage_region: Boolean(effectiveScope.storageRegion),
has_previous_scope: Boolean(previousScope?.userId),
};
if (cookiesSet) {
logger.debug(
'[maybeRefreshCloudFrontAuthCookies] CloudFront auth cookies refreshed',
logPayload,
);
} else {
logger.warn(
'[maybeRefreshCloudFrontAuthCookies] CloudFront auth cookie refresh failed',
logPayload,
);
}
return {
enabled: true,
attempted: true,
refreshed: cookiesSet,
reason: cookiesSet ? refreshReason : 'set_failed',
expiresInSec: timing.expiresInSec,
refreshAfterSec: timing.refreshAfterSec,
};
} catch (error) {
logger.warn(
'[maybeRefreshCloudFrontAuthCookies] Failed to refresh CloudFront auth cookies:',
error,
);
return {
enabled: false,
attempted: false,
refreshed: false,
reason: 'error',
};
}
}
export function forceRefreshCloudFrontAuthCookies(
req: CloudFrontCookieRequest | null | undefined,
res: Response,
user: CloudFrontScopeUser | null | undefined,
options: CloudFrontCookieRefreshOptions = {},
): CloudFrontAuthCookieRefreshResult {
return maybeRefreshCloudFrontAuthCookies(req, res, user, { ...options, force: true });
}
export function maybeRefreshCloudFrontAuthCookiesMiddleware(
req: CloudFrontAuthCookieRefreshRequest,
res: Response,
next: NextFunction,
): void {
req.cloudFrontAuthCookieRefreshResult = maybeRefreshCloudFrontAuthCookies(req, res, req.user);
next();
}

View file

@ -0,0 +1,202 @@
import { fireEvent, waitFor } from '@testing-library/react';
const mockApiBaseUrl = jest.fn(() => '');
const mockGetTokenHeader = jest.fn(() => 'Bearer test-token');
jest.mock('librechat-data-provider', () => ({
apiBaseUrl: () => mockApiBaseUrl(),
}));
import {
isCloudFrontMediaUrl,
refreshCloudFrontCookiesOnce,
installCloudFrontImageRetry,
configureCloudFrontCookieRefresh,
} from './cloudfront';
const cloudFrontStartupConfig = {
cloudFront: {
cookieRefresh: {
endpoint: '/api/auth/cloudfront/refresh',
domain: 'https://cdn.example.com',
},
},
};
function refreshResponse(payload: { ok?: boolean }, ok = true): Response {
return {
ok,
json: () => Promise.resolve(payload),
} as Response;
}
describe('CloudFront cookie refresh helpers', () => {
let fetchMock: jest.MockedFunction<typeof fetch>;
const originalFetch = global.fetch;
beforeEach(() => {
mockApiBaseUrl.mockReturnValue('');
mockGetTokenHeader.mockReturnValue('Bearer test-token');
fetchMock = jest.fn(() =>
Promise.resolve(refreshResponse({ ok: true })),
) as jest.MockedFunction<typeof fetch>;
global.fetch = fetchMock;
configureCloudFrontCookieRefresh(undefined);
jest.spyOn(Date, 'now').mockReturnValue(1_700_000_000_000);
});
afterEach(() => {
global.fetch = originalFetch;
jest.restoreAllMocks();
});
it('no-ops when startup config has no CloudFront refresh capability', async () => {
configureCloudFrontCookieRefresh({});
await expect(refreshCloudFrontCookiesOnce()).resolves.toBe(false);
expect(fetchMock).not.toHaveBeenCalled();
});
it('dedupes concurrent refresh calls', async () => {
let resolveRefresh: ((value: Response) => void) | undefined;
fetchMock.mockReturnValue(
new Promise((resolve) => {
resolveRefresh = resolve;
}),
);
configureCloudFrontCookieRefresh(cloudFrontStartupConfig, {
getAuthorizationHeader: mockGetTokenHeader,
});
const first = refreshCloudFrontCookiesOnce();
const second = refreshCloudFrontCookiesOnce();
expect(fetchMock).toHaveBeenCalledTimes(1);
expect(fetchMock).toHaveBeenCalledWith(
'/api/auth/cloudfront/refresh',
expect.objectContaining({
method: 'POST',
credentials: 'include',
headers: expect.objectContaining({ Authorization: 'Bearer test-token' }),
body: '{}',
}),
);
resolveRefresh?.(refreshResponse({ ok: true }));
await expect(first).resolves.toBe(true);
await expect(second).resolves.toBe(true);
});
it('returns false on 401 without retrying the refresh request', async () => {
fetchMock.mockResolvedValue(refreshResponse({}, false));
configureCloudFrontCookieRefresh(cloudFrontStartupConfig, {
getAuthorizationHeader: mockGetTokenHeader,
});
await expect(refreshCloudFrontCookiesOnce()).resolves.toBe(false);
expect(fetchMock).toHaveBeenCalledTimes(1);
});
it('prefixes the refresh endpoint with the configured app base path', async () => {
mockApiBaseUrl.mockReturnValue('/chat');
configureCloudFrontCookieRefresh(cloudFrontStartupConfig, {
getAuthorizationHeader: mockGetTokenHeader,
});
await expect(refreshCloudFrontCookiesOnce()).resolves.toBe(true);
expect(fetchMock).toHaveBeenCalledWith(
'/chat/api/auth/cloudfront/refresh',
expect.objectContaining({ method: 'POST' }),
);
});
it('detects only the configured CloudFront domain', () => {
expect(
isCloudFrontMediaUrl(
'https://cdn.example.com/i/images/user/file.png',
cloudFrontStartupConfig,
),
).toBe(true);
expect(
isCloudFrontMediaUrl(
'https://images.example.net/i/images/user/file.png',
cloudFrontStartupConfig,
),
).toBe(false);
});
it('retries a configured CloudFront image only once from the global listener', async () => {
const cleanup = installCloudFrontImageRetry(cloudFrontStartupConfig);
const img = document.createElement('img');
const onFailure = jest.fn();
img.src = 'https://cdn.example.com/i/images/user/file.png';
img.addEventListener('error', onFailure);
document.body.appendChild(img);
fireEvent.error(img);
await waitFor(() =>
expect(img).toHaveAttribute(
'src',
'https://cdn.example.com/i/images/user/file.png?_cf_refresh=1700000000000',
),
);
fireEvent.error(img);
expect(fetchMock).toHaveBeenCalledTimes(1);
expect(onFailure).toHaveBeenCalledTimes(1);
cleanup();
img.remove();
});
it('does not consume the one retry when cookie refresh fails', async () => {
fetchMock
.mockResolvedValueOnce(refreshResponse({ ok: false }))
.mockResolvedValueOnce(refreshResponse({ ok: true }));
const cleanup = installCloudFrontImageRetry(cloudFrontStartupConfig);
const img = document.createElement('img');
const onFailure = jest.fn();
img.src = 'https://cdn.example.com/i/images/user/file.png';
img.addEventListener('error', onFailure);
document.body.appendChild(img);
fireEvent.error(img);
await waitFor(() => expect(onFailure).toHaveBeenCalledTimes(1));
expect(img).toHaveAttribute('src', 'https://cdn.example.com/i/images/user/file.png');
fireEvent.error(img);
await waitFor(() =>
expect(img).toHaveAttribute(
'src',
'https://cdn.example.com/i/images/user/file.png?_cf_refresh=1700000000000',
),
);
expect(fetchMock).toHaveBeenCalledTimes(2);
cleanup();
img.remove();
});
it('does not retry arbitrary external images', () => {
const cleanup = installCloudFrontImageRetry(cloudFrontStartupConfig);
const img = document.createElement('img');
const onFailure = jest.fn();
img.src = 'https://example.com/photo.png';
img.addEventListener('error', onFailure);
document.body.appendChild(img);
fireEvent.error(img);
expect(fetchMock).not.toHaveBeenCalled();
expect(onFailure).toHaveBeenCalledTimes(1);
cleanup();
img.remove();
});
});

View file

@ -0,0 +1,209 @@
import { apiBaseUrl } from 'librechat-data-provider';
import type { TStartupConfig } from 'librechat-data-provider';
type CloudFrontCookieRefreshConfig = NonNullable<
NonNullable<TStartupConfig['cloudFront']>['cookieRefresh']
>;
type CloudFrontCookieRefreshResponse = {
ok?: boolean;
};
type CloudFrontCookieRefreshOptions = {
getAuthorizationHeader?: () => string | undefined;
};
let cookieRefreshConfig: CloudFrontCookieRefreshConfig | undefined;
let getAuthorizationHeader: CloudFrontCookieRefreshOptions['getAuthorizationHeader'];
let refreshPromise: Promise<boolean> | null = null;
let removeImageErrorListener: (() => void) | null = null;
const retriedImageSources = new WeakMap<HTMLImageElement, string>();
const pendingImageRefreshes = new WeakMap<HTMLImageElement, string>();
const forwardedImageErrors = new WeakSet<HTMLImageElement>();
function getRefreshConfig(
startupConfig?: Pick<TStartupConfig, 'cloudFront'> | null,
): CloudFrontCookieRefreshConfig | undefined {
return startupConfig?.cloudFront?.cookieRefresh ?? cookieRefreshConfig;
}
function getBaseUrl(): string {
return typeof window === 'undefined' ? 'http://localhost' : window.location.origin;
}
function parseUrl(value: string): URL | null {
try {
return new URL(value, getBaseUrl());
} catch {
return null;
}
}
export function configureCloudFrontCookieRefresh(
startupConfig?: Pick<TStartupConfig, 'cloudFront'> | null,
options: CloudFrontCookieRefreshOptions = {},
): void {
cookieRefreshConfig = startupConfig?.cloudFront?.cookieRefresh;
getAuthorizationHeader = options.getAuthorizationHeader;
}
export function isCloudFrontMediaUrl(
url: string | null | undefined,
startupConfig?: Pick<TStartupConfig, 'cloudFront'> | null,
): boolean {
const config = getRefreshConfig(startupConfig);
if (!url || !config?.domain) {
return false;
}
const mediaUrl = parseUrl(url);
const cloudFrontUrl = parseUrl(config.domain);
return mediaUrl?.origin === cloudFrontUrl?.origin;
}
export function withCloudFrontCacheBuster(url: string): string {
const parsed = parseUrl(url);
if (!parsed) {
return url;
}
parsed.searchParams.set('_cf_refresh', Date.now().toString());
return parsed.toString();
}
function getRetryKey(url: string): string {
const parsed = parseUrl(url);
if (!parsed) {
return url;
}
parsed.searchParams.delete('_cf_refresh');
return parsed.toString();
}
function dispatchImageError(img: HTMLImageElement): void {
forwardedImageErrors.add(img);
img.dispatchEvent(new Event('error'));
}
function getRefreshEndpoint(endpoint: string): string {
if (/^https?:\/\//i.test(endpoint)) {
return endpoint;
}
const baseUrl = apiBaseUrl();
if (!baseUrl || endpoint === baseUrl || endpoint.startsWith(`${baseUrl}/`)) {
return endpoint;
}
return `${baseUrl}${endpoint.startsWith('/') ? '' : '/'}${endpoint}`;
}
async function postCloudFrontCookieRefresh(endpoint: string): Promise<boolean> {
const authorization = getAuthorizationHeader?.();
const headers: Record<string, string> = {
Accept: 'application/json',
'Content-Type': 'application/json',
};
if (authorization) {
headers.Authorization = authorization;
}
const response = await fetch(endpoint, {
method: 'POST',
credentials: 'include',
headers,
body: '{}',
});
if (!response.ok) {
return false;
}
const payload = (await response.json()) as CloudFrontCookieRefreshResponse;
return payload.ok === true;
}
export function refreshCloudFrontCookiesOnce(): Promise<boolean> {
const config = getRefreshConfig();
if (!config?.endpoint) {
return Promise.resolve(false);
}
if (refreshPromise) {
return refreshPromise;
}
const endpoint = getRefreshEndpoint(config.endpoint);
refreshPromise = postCloudFrontCookieRefresh(endpoint)
.catch(() => false)
.finally(() => {
refreshPromise = null;
});
return refreshPromise;
}
export function installCloudFrontImageRetry(
startupConfig?: Pick<TStartupConfig, 'cloudFront'> | null,
options: CloudFrontCookieRefreshOptions = {},
): () => void {
configureCloudFrontCookieRefresh(startupConfig, options);
removeImageErrorListener?.();
removeImageErrorListener = null;
const config = getRefreshConfig();
if (typeof window === 'undefined' || !config?.endpoint || !config.domain) {
return () => undefined;
}
const handleImageError = (event: Event) => {
const img = event.target;
if (!(img instanceof HTMLImageElement)) {
return;
}
if (forwardedImageErrors.has(img)) {
forwardedImageErrors.delete(img);
return;
}
const failedSrc = img.currentSrc || img.src || img.getAttribute('src') || '';
if (!isCloudFrontMediaUrl(failedSrc)) {
return;
}
const retryKey = getRetryKey(failedSrc);
if (retriedImageSources.get(img) === retryKey) {
return;
}
event.preventDefault();
event.stopPropagation();
event.stopImmediatePropagation();
if (pendingImageRefreshes.get(img) === retryKey) {
return;
}
pendingImageRefreshes.set(img, retryKey);
void refreshCloudFrontCookiesOnce().then((refreshed) => {
pendingImageRefreshes.delete(img);
if (!refreshed || !img.isConnected) {
dispatchImageError(img);
return;
}
retriedImageSources.set(img, retryKey);
img.src = withCloudFrontCacheBuster(failedSrc);
});
};
window.addEventListener('error', handleImageError, true);
const cleanup = () => {
window.removeEventListener('error', handleImageError, true);
if (removeImageErrorListener === cleanup) {
removeImageErrorListener = null;
}
};
removeImageErrorListener = cleanup;
return cleanup;
}

View file

@ -1,3 +1,4 @@
export * from './utils';
export * from './theme';
export * from './cloudfront';
export { default as logger } from './logger';

View file

@ -1,5 +1,5 @@
import axios from 'axios';
import { setTokenHeader } from '../src/headers-helpers';
import { getTokenHeader, setTokenHeader } from '../src/headers-helpers';
describe('setTokenHeader', () => {
afterEach(() => {
@ -9,12 +9,14 @@ describe('setTokenHeader', () => {
it('sets the Authorization header with a Bearer token', () => {
setTokenHeader('my-token');
expect(axios.defaults.headers.common['Authorization']).toBe('Bearer my-token');
expect(getTokenHeader()).toBe('Bearer my-token');
});
it('deletes the Authorization header when called with undefined', () => {
axios.defaults.headers.common['Authorization'] = 'Bearer old-token';
setTokenHeader(undefined);
expect(axios.defaults.headers.common['Authorization']).toBeUndefined();
expect(getTokenHeader()).toBeUndefined();
});
it('is a no-op when clearing an already absent header', () => {

View file

@ -1103,6 +1103,12 @@ export type TStartupConfig = {
scraperProvider?: ScraperProviders;
rerankerType?: RerankerTypes;
};
cloudFront?: {
cookieRefresh?: {
endpoint: string;
domain: string;
};
};
mcpServers?: Record<
string,
{

View file

@ -11,3 +11,8 @@ export function setTokenHeader(token: string | undefined) {
axios.defaults.headers.common['Authorization'] = 'Bearer ' + token;
}
}
export function getTokenHeader(): string | undefined {
const authorization = axios.defaults.headers.common['Authorization'];
return typeof authorization === 'string' ? authorization : undefined;
}