mirror of
https://github.com/danny-avila/LibreChat.git
synced 2026-06-09 17:31:19 +00:00
🩻 refactor: Replace Opaque OAuth Errors with Structured Failure Diagnostics (#13471)
* Improve OAuth failure logging * Improve OAuth failure logging * test: type oauth failure request helper * refactor: move OpenID callback helper to api package
This commit is contained in:
parent
8ba0249f1e
commit
317b8dfbd5
7 changed files with 1032 additions and 11 deletions
|
|
@ -4,7 +4,13 @@ const passport = require('passport');
|
|||
const { randomState } = require('openid-client');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { ErrorTypes } = require('librechat-data-provider');
|
||||
const { createSetBalanceConfig } = require('@librechat/api');
|
||||
const {
|
||||
buildOAuthFailureLog,
|
||||
createOpenIDCallbackAuthenticator,
|
||||
createSetBalanceConfig,
|
||||
getOAuthFailureMessage,
|
||||
redirectToAuthFailure,
|
||||
} = require('@librechat/api');
|
||||
const { checkDomainAllowed, loginLimiter, logHeaders } = require('~/server/middleware');
|
||||
const { createOAuthHandler } = require('~/server/controllers/auth/oauth');
|
||||
const { findBalanceByUser, upsertBalanceFields } = require('~/models');
|
||||
|
|
@ -23,19 +29,35 @@ const domains = {
|
|||
server: process.env.DOMAIN_SERVER,
|
||||
};
|
||||
|
||||
const authFailureRedirectOptions = {
|
||||
clientDomain: domains.client,
|
||||
authFailedError: ErrorTypes.AUTH_FAILED,
|
||||
};
|
||||
|
||||
router.use(logHeaders);
|
||||
router.use(loginLimiter);
|
||||
|
||||
const oauthHandler = createOAuthHandler();
|
||||
const authenticateOpenIDCallback = createOpenIDCallbackAuthenticator({
|
||||
passport,
|
||||
logger,
|
||||
...authFailureRedirectOptions,
|
||||
});
|
||||
|
||||
router.get('/error', (req, res) => {
|
||||
/** A single error message is pushed by passport when authentication fails. */
|
||||
const errorMessage = req.session?.messages?.pop() || 'Unknown OAuth error';
|
||||
logger.error('Error in OAuth authentication:', {
|
||||
message: errorMessage,
|
||||
});
|
||||
const errorMessage = getOAuthFailureMessage(req);
|
||||
logger.warn(
|
||||
'[OAuth] Authentication failed',
|
||||
buildOAuthFailureLog({
|
||||
provider: 'unknown',
|
||||
req,
|
||||
info: { message: errorMessage },
|
||||
defaultMessage: errorMessage,
|
||||
}),
|
||||
);
|
||||
|
||||
res.redirect(`${domains.client}/login?redirect=false&error=${ErrorTypes.AUTH_FAILED}`);
|
||||
redirectToAuthFailure(res, authFailureRedirectOptions);
|
||||
});
|
||||
|
||||
/**
|
||||
|
|
@ -100,11 +122,7 @@ router.get('/openid', (req, res, next) => {
|
|||
|
||||
router.get(
|
||||
'/openid/callback',
|
||||
passport.authenticate('openid', {
|
||||
failureRedirect: `${domains.client}/oauth/error`,
|
||||
failureMessage: true,
|
||||
session: false,
|
||||
}),
|
||||
authenticateOpenIDCallback,
|
||||
setBalanceConfig,
|
||||
checkDomainAllowed,
|
||||
oauthHandler,
|
||||
|
|
|
|||
190
api/server/routes/oauth.test.js
Normal file
190
api/server/routes/oauth.test.js
Normal file
|
|
@ -0,0 +1,190 @@
|
|||
const express = require('express');
|
||||
const request = require('supertest');
|
||||
|
||||
const originalDomainClient = process.env.DOMAIN_CLIENT;
|
||||
process.env.DOMAIN_CLIENT = 'http://client.test';
|
||||
|
||||
const mockLogger = {
|
||||
warn: jest.fn(),
|
||||
error: jest.fn(),
|
||||
info: jest.fn(),
|
||||
debug: jest.fn(),
|
||||
};
|
||||
|
||||
const mockOAuthHandler = jest.fn((_req, res) => res.status(204).end());
|
||||
const mockOpenIDCallbackMiddleware = jest.fn((_req, _res, next) => next());
|
||||
let mockOpenIDCallbackAuthenticatorOptions;
|
||||
const mockCreateOpenIDCallbackAuthenticator = jest.fn((options) => {
|
||||
mockOpenIDCallbackAuthenticatorOptions = options;
|
||||
return mockOpenIDCallbackMiddleware;
|
||||
});
|
||||
const mockBuildOAuthFailureLog = jest.fn(({ provider, req, err, info, defaultMessage }) => ({
|
||||
provider,
|
||||
code: err?.code ?? info?.code ?? info?.error ?? req.query?.error,
|
||||
name: err?.name ?? info?.name,
|
||||
message:
|
||||
err?.message ??
|
||||
info?.message ??
|
||||
info?.error_description ??
|
||||
req.query?.error_description ??
|
||||
defaultMessage,
|
||||
cause_code: err?.cause?.code ?? info?.cause?.code,
|
||||
cause_name: err?.cause?.name ?? info?.cause?.name,
|
||||
has_code: req.query?.code != null,
|
||||
has_state: req.query?.state != null,
|
||||
query_error: req.query?.error,
|
||||
query_error_description: req.query?.error_description,
|
||||
path: req.path,
|
||||
forwarded_for: req.headers?.['x-forwarded-for'],
|
||||
user_agent: req.headers?.['user-agent'],
|
||||
}));
|
||||
const mockGetOAuthFailureMessage = jest.fn(
|
||||
(req) =>
|
||||
req.session?.messages?.pop() ??
|
||||
req.query?.error_description ??
|
||||
req.query?.error ??
|
||||
'OAuth authentication failed',
|
||||
);
|
||||
const mockRedirectToAuthFailure = jest.fn((res, { clientDomain, authFailedError }) =>
|
||||
res.redirect(`${clientDomain}/login?redirect=false&error=${authFailedError}`),
|
||||
);
|
||||
const mockPassportAuthenticate = jest.fn(() => (_req, _res, next) => next());
|
||||
|
||||
jest.mock('passport', () => ({
|
||||
authenticate: (...args) => mockPassportAuthenticate(...args),
|
||||
}));
|
||||
|
||||
jest.mock('openid-client', () => ({
|
||||
randomState: jest.fn(() => 'random-state'),
|
||||
}));
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
logger: mockLogger,
|
||||
}));
|
||||
|
||||
jest.mock('librechat-data-provider', () => ({
|
||||
ErrorTypes: {
|
||||
AUTH_FAILED: 'auth_failed',
|
||||
},
|
||||
}));
|
||||
|
||||
jest.mock(
|
||||
'@librechat/api',
|
||||
() => ({
|
||||
buildOAuthFailureLog: (...args) => mockBuildOAuthFailureLog(...args),
|
||||
createOpenIDCallbackAuthenticator: (...args) => mockCreateOpenIDCallbackAuthenticator(...args),
|
||||
createSetBalanceConfig: jest.fn(() => (_req, _res, next) => next()),
|
||||
getOAuthFailureMessage: (...args) => mockGetOAuthFailureMessage(...args),
|
||||
redirectToAuthFailure: (...args) => mockRedirectToAuthFailure(...args),
|
||||
}),
|
||||
{ virtual: true },
|
||||
);
|
||||
|
||||
jest.mock('~/server/middleware', () => ({
|
||||
checkDomainAllowed: jest.fn((_req, _res, next) => next()),
|
||||
loginLimiter: jest.fn((_req, _res, next) => next()),
|
||||
logHeaders: jest.fn((_req, _res, next) => next()),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/controllers/auth/oauth', () => ({
|
||||
createOAuthHandler: jest.fn(() => mockOAuthHandler),
|
||||
}));
|
||||
|
||||
jest.mock('~/models', () => ({
|
||||
findBalanceByUser: jest.fn(),
|
||||
upsertBalanceFields: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Config', () => ({
|
||||
getAppConfig: jest.fn(),
|
||||
}));
|
||||
|
||||
const oauthRouter = require('./oauth');
|
||||
|
||||
afterAll(() => {
|
||||
if (originalDomainClient === undefined) {
|
||||
delete process.env.DOMAIN_CLIENT;
|
||||
return;
|
||||
}
|
||||
process.env.DOMAIN_CLIENT = originalDomainClient;
|
||||
});
|
||||
|
||||
function createApp(sessionMessages) {
|
||||
const app = express();
|
||||
app.use((req, _res, next) => {
|
||||
if (sessionMessages) {
|
||||
req.session = { messages: [...sessionMessages] };
|
||||
}
|
||||
next();
|
||||
});
|
||||
app.use('/oauth', oauthRouter);
|
||||
app.use((err, _req, res, _next) => {
|
||||
res.status(500).json({ message: err.message });
|
||||
});
|
||||
return app;
|
||||
}
|
||||
|
||||
describe('OAuth route failure logging', () => {
|
||||
beforeEach(() => {
|
||||
mockLogger.warn.mockClear();
|
||||
mockLogger.error.mockClear();
|
||||
mockLogger.info.mockClear();
|
||||
mockLogger.debug.mockClear();
|
||||
mockOAuthHandler.mockClear();
|
||||
mockOpenIDCallbackMiddleware.mockClear();
|
||||
mockBuildOAuthFailureLog.mockClear();
|
||||
mockGetOAuthFailureMessage.mockClear();
|
||||
mockRedirectToAuthFailure.mockClear();
|
||||
mockPassportAuthenticate.mockClear();
|
||||
mockPassportAuthenticate.mockImplementation(() => (_req, _res, next) => next());
|
||||
mockOpenIDCallbackMiddleware.mockImplementation((_req, _res, next) => next());
|
||||
});
|
||||
|
||||
it('wires the package OpenID callback middleware into the route', async () => {
|
||||
const app = createApp();
|
||||
|
||||
await request(app)
|
||||
.get('/oauth/openid/callback?code=secret-code&state=secret-state')
|
||||
.expect(204);
|
||||
|
||||
expect(mockOpenIDCallbackAuthenticatorOptions).toEqual({
|
||||
passport: expect.objectContaining({ authenticate: expect.any(Function) }),
|
||||
logger: mockLogger,
|
||||
clientDomain: 'http://client.test',
|
||||
authFailedError: 'auth_failed',
|
||||
});
|
||||
expect(mockOpenIDCallbackMiddleware).toHaveBeenCalledWith(
|
||||
expect.any(Object),
|
||||
expect.any(Object),
|
||||
expect.any(Function),
|
||||
);
|
||||
expect(mockOAuthHandler).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('logs structured fallback errors without using Unknown OAuth error', async () => {
|
||||
const app = createApp();
|
||||
|
||||
const response = await request(app)
|
||||
.get('/oauth/error?error=access_denied&error_description=Denied%20by%20provider')
|
||||
.set('x-forwarded-for', '203.0.113.10')
|
||||
.expect(302);
|
||||
|
||||
expect(response.headers.location).toBe(
|
||||
'http://client.test/login?redirect=false&error=auth_failed',
|
||||
);
|
||||
expect(mockLogger.warn).toHaveBeenCalledWith(
|
||||
'[OAuth] Authentication failed',
|
||||
expect.objectContaining({
|
||||
provider: 'unknown',
|
||||
code: 'access_denied',
|
||||
message: 'Denied by provider',
|
||||
query_error: 'access_denied',
|
||||
query_error_description: 'Denied by provider',
|
||||
has_code: false,
|
||||
has_state: false,
|
||||
forwarded_for: '203.0.113.10',
|
||||
}),
|
||||
);
|
||||
expect(JSON.stringify(mockLogger.warn.mock.calls[0])).not.toContain('Unknown OAuth error');
|
||||
});
|
||||
});
|
||||
274
packages/api/src/oauth/callback.spec.ts
Normal file
274
packages/api/src/oauth/callback.spec.ts
Normal file
|
|
@ -0,0 +1,274 @@
|
|||
import type { NextFunction, Response } from 'express';
|
||||
import {
|
||||
createOpenIDCallbackAuthenticator,
|
||||
logOpenIDCallbackFailure,
|
||||
redirectToAuthFailure,
|
||||
type OpenIDCallbackRequest,
|
||||
type OpenIDCallbackAuthenticatorOptions,
|
||||
} from './callback';
|
||||
|
||||
type CallbackFn = (err: unknown, user: unknown, info: unknown) => void;
|
||||
type TestRequest = OpenIDCallbackRequest;
|
||||
|
||||
const logger = {
|
||||
warn: jest.fn(),
|
||||
error: jest.fn(),
|
||||
};
|
||||
|
||||
function createRequest(overrides: Partial<TestRequest> = {}): TestRequest {
|
||||
return {
|
||||
headers: {},
|
||||
method: 'GET',
|
||||
path: '/openid/callback',
|
||||
originalUrl: '/openid/callback',
|
||||
query: {},
|
||||
...overrides,
|
||||
} as TestRequest;
|
||||
}
|
||||
|
||||
function createResponse(): Response {
|
||||
return {
|
||||
redirect: jest.fn(),
|
||||
} as unknown as Response;
|
||||
}
|
||||
|
||||
function createNext(): jest.MockedFunction<NextFunction> {
|
||||
return jest.fn() as jest.MockedFunction<NextFunction>;
|
||||
}
|
||||
|
||||
function createAuthenticator(
|
||||
callbackHandler: (
|
||||
callback: CallbackFn,
|
||||
req: TestRequest,
|
||||
res: Response,
|
||||
next: NextFunction,
|
||||
) => void,
|
||||
) {
|
||||
const passport = {
|
||||
authenticate: jest.fn((_strategy: 'openid', _options, callback: CallbackFn) => {
|
||||
return (req: TestRequest, res: Response, next: NextFunction) => {
|
||||
callbackHandler(callback, req, res, next);
|
||||
};
|
||||
}),
|
||||
};
|
||||
const options: OpenIDCallbackAuthenticatorOptions = {
|
||||
passport,
|
||||
logger,
|
||||
clientDomain: 'http://client.test',
|
||||
authFailedError: 'auth_failed',
|
||||
};
|
||||
|
||||
return {
|
||||
middleware: createOpenIDCallbackAuthenticator(options),
|
||||
passport,
|
||||
};
|
||||
}
|
||||
|
||||
describe('OpenID OAuth callback helpers', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
it('redirects failed auth attempts to the login failure URL', () => {
|
||||
const res = createResponse();
|
||||
|
||||
redirectToAuthFailure(res, {
|
||||
clientDomain: 'http://client.test',
|
||||
authFailedError: 'auth_failed',
|
||||
});
|
||||
|
||||
expect(res.redirect).toHaveBeenCalledWith(
|
||||
'http://client.test/login?redirect=false&error=auth_failed',
|
||||
);
|
||||
});
|
||||
|
||||
it('logs OpenID callback failures with structured OAuth context', () => {
|
||||
const req = createRequest({
|
||||
query: {
|
||||
code: 'secret-code',
|
||||
state: 'secret-state',
|
||||
},
|
||||
});
|
||||
const error = Object.assign(new Error('invalid response encountered'), {
|
||||
code: 'OAUTH_INVALID_RESPONSE',
|
||||
name: 'ClientError',
|
||||
});
|
||||
|
||||
logOpenIDCallbackFailure({
|
||||
logger,
|
||||
req,
|
||||
err: error,
|
||||
info: { message: 'provider info' },
|
||||
});
|
||||
|
||||
expect(logger.warn).toHaveBeenCalledWith(
|
||||
'[OpenID OAuth] Callback authentication failed',
|
||||
expect.objectContaining({
|
||||
provider: 'openid',
|
||||
code: 'OAUTH_INVALID_RESPONSE',
|
||||
name: 'ClientError',
|
||||
message: 'invalid response encountered',
|
||||
has_code: true,
|
||||
has_state: true,
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('continues the successful callback path after logging in without a session', () => {
|
||||
const user = { id: 'user-1' };
|
||||
const req = createRequest();
|
||||
const res = createResponse();
|
||||
const next = createNext();
|
||||
const logIn = jest.fn((loginUser, _options, done) => {
|
||||
req.user = loginUser;
|
||||
done();
|
||||
});
|
||||
req.logIn = logIn;
|
||||
const { middleware, passport } = createAuthenticator((callback) =>
|
||||
callback(null, user, { message: 'ok' }),
|
||||
);
|
||||
|
||||
middleware(req, res, next);
|
||||
|
||||
expect(passport.authenticate).toHaveBeenCalledWith(
|
||||
'openid',
|
||||
{ failureMessage: true, session: false },
|
||||
expect.any(Function),
|
||||
);
|
||||
expect(logIn).toHaveBeenCalledWith(user, { session: false }, expect.any(Function));
|
||||
expect(req.user).toBe(user);
|
||||
expect(next).toHaveBeenCalledWith();
|
||||
expect(logger.warn).not.toHaveBeenCalled();
|
||||
expect(logger.error).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('sets req.user and continues when req.logIn is unavailable', () => {
|
||||
const user = { id: 'user-1' };
|
||||
const req = createRequest();
|
||||
const res = createResponse();
|
||||
const next = createNext();
|
||||
const { middleware } = createAuthenticator((callback) => callback(null, user, undefined));
|
||||
|
||||
middleware(req, res, next);
|
||||
|
||||
expect(req.user).toBe(user);
|
||||
expect(next).toHaveBeenCalledWith();
|
||||
});
|
||||
|
||||
it('logs OpenID protocol failures and redirects without escalating', () => {
|
||||
const req = createRequest({
|
||||
query: {
|
||||
code: 'secret-code',
|
||||
state: 'secret-state',
|
||||
},
|
||||
});
|
||||
const res = createResponse();
|
||||
const next = createNext();
|
||||
const error = Object.assign(new Error('invalid response encountered'), {
|
||||
code: 'OAUTH_INVALID_RESPONSE',
|
||||
name: 'ClientError',
|
||||
});
|
||||
const { middleware } = createAuthenticator((callback) => callback(error, false, undefined));
|
||||
|
||||
middleware(req, res, next);
|
||||
|
||||
expect(logger.warn).toHaveBeenCalledWith(
|
||||
'[OpenID OAuth] Callback authentication failed',
|
||||
expect.objectContaining({
|
||||
provider: 'openid',
|
||||
code: 'OAUTH_INVALID_RESPONSE',
|
||||
name: 'ClientError',
|
||||
message: 'invalid response encountered',
|
||||
has_code: true,
|
||||
has_state: true,
|
||||
}),
|
||||
);
|
||||
expect(res.redirect).toHaveBeenCalledWith(
|
||||
'http://client.test/login?redirect=false&error=auth_failed',
|
||||
);
|
||||
expect(next).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('logs unexpected OpenID errors with context before escalating', () => {
|
||||
const req = createRequest({
|
||||
query: {
|
||||
code: 'secret-code',
|
||||
state: 'secret-state',
|
||||
},
|
||||
});
|
||||
const res = createResponse();
|
||||
const next = createNext();
|
||||
const error = Object.assign(new Error('database exploded'), {
|
||||
name: 'DatabaseError',
|
||||
});
|
||||
const { middleware } = createAuthenticator((callback) => callback(error, false, undefined));
|
||||
|
||||
middleware(req, res, next);
|
||||
|
||||
expect(logger.error).toHaveBeenCalledWith(
|
||||
'[OpenID OAuth] Callback authentication error',
|
||||
expect.objectContaining({
|
||||
provider: 'openid',
|
||||
name: 'DatabaseError',
|
||||
message: 'database exploded',
|
||||
has_code: true,
|
||||
has_state: true,
|
||||
}),
|
||||
);
|
||||
expect(next).toHaveBeenCalledWith(error);
|
||||
expect(res.redirect).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('logs Passport info failures and redirects without escalating', () => {
|
||||
const req = createRequest();
|
||||
const res = createResponse();
|
||||
const next = createNext();
|
||||
const { middleware } = createAuthenticator((callback) =>
|
||||
callback(null, false, {
|
||||
code: 'DOMAIN_DENIED',
|
||||
message: 'Email domain not allowed',
|
||||
}),
|
||||
);
|
||||
|
||||
middleware(req, res, next);
|
||||
|
||||
expect(logger.warn).toHaveBeenCalledWith(
|
||||
'[OpenID OAuth] Callback authentication failed',
|
||||
expect.objectContaining({
|
||||
provider: 'openid',
|
||||
code: 'DOMAIN_DENIED',
|
||||
message: 'Email domain not allowed',
|
||||
}),
|
||||
);
|
||||
expect(res.redirect).toHaveBeenCalledWith(
|
||||
'http://client.test/login?redirect=false&error=auth_failed',
|
||||
);
|
||||
expect(next).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('logs login errors and sends them to the error handler', () => {
|
||||
const user = { id: 'user-1' };
|
||||
const req = createRequest();
|
||||
const res = createResponse();
|
||||
const next = createNext();
|
||||
const error = Object.assign(new Error('login failed'), {
|
||||
name: 'LoginError',
|
||||
});
|
||||
req.logIn = jest.fn((_loginUser, _options, done) => done(error));
|
||||
const { middleware } = createAuthenticator((callback) =>
|
||||
callback(null, user, { message: 'provider info' }),
|
||||
);
|
||||
|
||||
middleware(req, res, next);
|
||||
|
||||
expect(logger.error).toHaveBeenCalledWith(
|
||||
'[OpenID OAuth] Callback authentication error',
|
||||
expect.objectContaining({
|
||||
provider: 'openid',
|
||||
name: 'LoginError',
|
||||
message: 'login failed',
|
||||
}),
|
||||
);
|
||||
expect(next).toHaveBeenCalledWith(error);
|
||||
});
|
||||
});
|
||||
140
packages/api/src/oauth/callback.ts
Normal file
140
packages/api/src/oauth/callback.ts
Normal file
|
|
@ -0,0 +1,140 @@
|
|||
import type { NextFunction, Response } from 'express';
|
||||
import {
|
||||
buildOAuthFailureLog,
|
||||
isOAuthProtocolFailure,
|
||||
type OAuthFailureLog,
|
||||
type OAuthFailureRequest,
|
||||
} from './failure';
|
||||
|
||||
type LoginFunction = (
|
||||
user: unknown,
|
||||
options: { session: false },
|
||||
done: (err?: unknown) => void,
|
||||
) => void;
|
||||
|
||||
export type OpenIDCallbackRequest = OAuthFailureRequest & {
|
||||
logIn?: LoginFunction;
|
||||
user?: unknown;
|
||||
};
|
||||
|
||||
type OpenIDCallback = (err: unknown, user: unknown, info: unknown) => void;
|
||||
|
||||
type PassportMiddleware = (
|
||||
req: OpenIDCallbackRequest,
|
||||
res: Response,
|
||||
next: NextFunction,
|
||||
) => unknown;
|
||||
|
||||
type PassportLike = {
|
||||
authenticate: (
|
||||
strategy: 'openid',
|
||||
options: {
|
||||
failureMessage: true;
|
||||
session: false;
|
||||
},
|
||||
callback: OpenIDCallback,
|
||||
) => PassportMiddleware;
|
||||
};
|
||||
|
||||
type OAuthCallbackLogLevel = 'warn' | 'error';
|
||||
|
||||
type OAuthCallbackLogger = Record<
|
||||
OAuthCallbackLogLevel,
|
||||
(message: string, details: OAuthFailureLog) => void
|
||||
>;
|
||||
|
||||
export type AuthFailureRedirectOptions = {
|
||||
clientDomain?: string;
|
||||
authFailedError: string;
|
||||
};
|
||||
|
||||
export type LogOpenIDCallbackFailureOptions = {
|
||||
logger: OAuthCallbackLogger;
|
||||
req: OAuthFailureRequest;
|
||||
err?: unknown;
|
||||
info?: unknown;
|
||||
level?: OAuthCallbackLogLevel;
|
||||
};
|
||||
|
||||
export type OpenIDCallbackAuthenticatorOptions = AuthFailureRedirectOptions & {
|
||||
logger: OAuthCallbackLogger;
|
||||
passport: PassportLike;
|
||||
};
|
||||
|
||||
export function redirectToAuthFailure(
|
||||
res: Response,
|
||||
{ clientDomain, authFailedError }: AuthFailureRedirectOptions,
|
||||
): void {
|
||||
res.redirect(`${clientDomain}/login?redirect=false&error=${authFailedError}`);
|
||||
}
|
||||
|
||||
export function logOpenIDCallbackFailure({
|
||||
logger,
|
||||
req,
|
||||
err,
|
||||
info,
|
||||
level = 'warn',
|
||||
}: LogOpenIDCallbackFailureOptions): void {
|
||||
logger[level](
|
||||
level === 'error'
|
||||
? '[OpenID OAuth] Callback authentication error'
|
||||
: '[OpenID OAuth] Callback authentication failed',
|
||||
buildOAuthFailureLog({
|
||||
provider: 'openid',
|
||||
req,
|
||||
err,
|
||||
info,
|
||||
defaultMessage: 'OpenID authentication failed',
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
export function createOpenIDCallbackAuthenticator({
|
||||
passport,
|
||||
logger,
|
||||
clientDomain,
|
||||
authFailedError,
|
||||
}: OpenIDCallbackAuthenticatorOptions): (
|
||||
req: OpenIDCallbackRequest,
|
||||
res: Response,
|
||||
next: NextFunction,
|
||||
) => unknown {
|
||||
return (req: OpenIDCallbackRequest, res: Response, next: NextFunction): unknown => {
|
||||
return passport.authenticate(
|
||||
'openid',
|
||||
{
|
||||
failureMessage: true,
|
||||
session: false,
|
||||
},
|
||||
(err: unknown, user: unknown, info: unknown) => {
|
||||
if (err) {
|
||||
if (isOAuthProtocolFailure(err, info)) {
|
||||
logOpenIDCallbackFailure({ logger, req, err, info });
|
||||
return redirectToAuthFailure(res, { clientDomain, authFailedError });
|
||||
}
|
||||
|
||||
logOpenIDCallbackFailure({ logger, req, err, info, level: 'error' });
|
||||
return next(err);
|
||||
}
|
||||
|
||||
if (!user) {
|
||||
logOpenIDCallbackFailure({ logger, req, err, info });
|
||||
return redirectToAuthFailure(res, { clientDomain, authFailedError });
|
||||
}
|
||||
|
||||
if (typeof req.logIn !== 'function') {
|
||||
req.user = user;
|
||||
return next();
|
||||
}
|
||||
|
||||
return req.logIn(user, { session: false }, (loginErr?: unknown) => {
|
||||
if (loginErr) {
|
||||
logOpenIDCallbackFailure({ logger, req, err: loginErr, info, level: 'error' });
|
||||
return next(loginErr);
|
||||
}
|
||||
return next();
|
||||
});
|
||||
},
|
||||
)(req, res, next);
|
||||
};
|
||||
}
|
||||
139
packages/api/src/oauth/failure.spec.ts
Normal file
139
packages/api/src/oauth/failure.spec.ts
Normal file
|
|
@ -0,0 +1,139 @@
|
|||
import { buildOAuthFailureLog, getOAuthFailureMessage, isOAuthProtocolFailure } from './failure';
|
||||
import type { OAuthFailureRequest } from './failure';
|
||||
|
||||
function createRequest(overrides: Partial<OAuthFailureRequest> = {}): OAuthFailureRequest {
|
||||
return {
|
||||
headers: {},
|
||||
method: 'GET',
|
||||
path: '/openid/callback',
|
||||
originalUrl: '/openid/callback',
|
||||
query: {},
|
||||
...overrides,
|
||||
};
|
||||
}
|
||||
|
||||
describe('OAuth failure logging helpers', () => {
|
||||
it('prefers session failure messages and removes the consumed message', () => {
|
||||
const req = createRequest({
|
||||
session: {
|
||||
messages: ['first', 'latest'],
|
||||
},
|
||||
});
|
||||
|
||||
expect(getOAuthFailureMessage(req)).toBe('latest');
|
||||
expect(req.session?.messages).toEqual(['first']);
|
||||
});
|
||||
|
||||
it('falls back to provider query error details without returning Unknown OAuth error', () => {
|
||||
const req = createRequest({
|
||||
query: {
|
||||
error: 'access_denied',
|
||||
error_description: 'Denied by provider',
|
||||
},
|
||||
});
|
||||
|
||||
expect(getOAuthFailureMessage(req)).toBe('Denied by provider');
|
||||
});
|
||||
|
||||
it('logs OpenID protocol failure metadata without raw code or state values', () => {
|
||||
const req = createRequest({
|
||||
headers: {
|
||||
host: 'chat.example.com',
|
||||
'x-forwarded-for': '203.0.113.10',
|
||||
'x-forwarded-proto': 'https',
|
||||
'user-agent': 'test-agent',
|
||||
},
|
||||
id: 'request-id',
|
||||
originalUrl: '/openid/callback?code=secret-code&state=secret-state',
|
||||
query: {
|
||||
code: 'secret-code',
|
||||
state: 'secret-state',
|
||||
},
|
||||
});
|
||||
const error = Object.assign(new Error('invalid response encountered'), {
|
||||
code: 'OAUTH_INVALID_RESPONSE',
|
||||
name: 'ClientError',
|
||||
cause: {
|
||||
code: 'OAUTH_INVALID_RESPONSE',
|
||||
name: 'OperationProcessingError',
|
||||
message: 'invalid response encountered',
|
||||
},
|
||||
});
|
||||
|
||||
const log = buildOAuthFailureLog({
|
||||
provider: 'openid',
|
||||
req,
|
||||
err: error,
|
||||
defaultMessage: 'OpenID authentication failed',
|
||||
});
|
||||
|
||||
expect(log).toEqual({
|
||||
provider: 'openid',
|
||||
code: 'OAUTH_INVALID_RESPONSE',
|
||||
name: 'ClientError',
|
||||
message: 'invalid response encountered',
|
||||
cause_code: 'OAUTH_INVALID_RESPONSE',
|
||||
cause_name: 'OperationProcessingError',
|
||||
cause_message: 'invalid response encountered',
|
||||
has_code: true,
|
||||
has_state: true,
|
||||
method: 'GET',
|
||||
path: '/openid/callback',
|
||||
request_id: 'request-id',
|
||||
host: 'chat.example.com',
|
||||
forwarded_proto: 'https',
|
||||
forwarded_for: '203.0.113.10',
|
||||
user_agent: 'test-agent',
|
||||
});
|
||||
expect(JSON.stringify(log)).not.toContain('secret-code');
|
||||
expect(JSON.stringify(log)).not.toContain('secret-state');
|
||||
});
|
||||
|
||||
it('captures provider response error fields from Passport info', () => {
|
||||
const log = buildOAuthFailureLog({
|
||||
provider: 'openid',
|
||||
req: createRequest(),
|
||||
info: {
|
||||
error: 'access_denied',
|
||||
error_description: 'User denied consent',
|
||||
},
|
||||
});
|
||||
|
||||
expect(log).toEqual({
|
||||
provider: 'openid',
|
||||
code: 'access_denied',
|
||||
message: 'User denied consent',
|
||||
has_code: false,
|
||||
has_state: false,
|
||||
method: 'GET',
|
||||
path: '/openid/callback',
|
||||
});
|
||||
});
|
||||
|
||||
it('truncates very long messages', () => {
|
||||
const longMessage = 'x'.repeat(320);
|
||||
|
||||
const log = buildOAuthFailureLog({
|
||||
provider: 'openid',
|
||||
req: createRequest(),
|
||||
info: {
|
||||
message: longMessage,
|
||||
},
|
||||
});
|
||||
|
||||
expect(log.message).toHaveLength(315);
|
||||
expect(log.message?.endsWith('... [truncated]')).toBe(true);
|
||||
});
|
||||
|
||||
it.each([
|
||||
[{ code: 'OAUTH_INVALID_RESPONSE' }, true],
|
||||
[{ name: 'AuthorizationResponseError' }, true],
|
||||
[
|
||||
{ cause: { name: 'OperationProcessingError', message: 'invalid response encountered' } },
|
||||
true,
|
||||
],
|
||||
[{ name: 'DatabaseError', message: 'database exploded' }, false],
|
||||
])('classifies OAuth protocol failure %j as %s', (error, expected) => {
|
||||
expect(isOAuthProtocolFailure(error)).toBe(expected);
|
||||
});
|
||||
});
|
||||
258
packages/api/src/oauth/failure.ts
Normal file
258
packages/api/src/oauth/failure.ts
Normal file
|
|
@ -0,0 +1,258 @@
|
|||
import type { Request } from 'express';
|
||||
|
||||
const MAX_LOG_VALUE_LENGTH = 300;
|
||||
|
||||
type LogValue = string | boolean;
|
||||
|
||||
type FailureLike = {
|
||||
code?: unknown;
|
||||
error?: unknown;
|
||||
name?: unknown;
|
||||
message?: unknown;
|
||||
error_description?: unknown;
|
||||
cause?: unknown;
|
||||
};
|
||||
|
||||
export type OAuthFailureRequest = Pick<
|
||||
Request,
|
||||
'headers' | 'method' | 'path' | 'originalUrl' | 'query'
|
||||
> & {
|
||||
id?: string;
|
||||
requestId?: string;
|
||||
session?: {
|
||||
messages?: unknown[];
|
||||
};
|
||||
};
|
||||
|
||||
export type OAuthFailureLog = {
|
||||
provider: string;
|
||||
code?: string;
|
||||
name?: string;
|
||||
message?: string;
|
||||
cause_code?: string;
|
||||
cause_name?: string;
|
||||
cause_message?: string;
|
||||
has_code: boolean;
|
||||
has_state: boolean;
|
||||
query_error?: string;
|
||||
query_error_description?: string;
|
||||
method?: string;
|
||||
path?: string;
|
||||
request_id?: string;
|
||||
host?: string;
|
||||
forwarded_host?: string;
|
||||
forwarded_proto?: string;
|
||||
forwarded_for?: string;
|
||||
real_ip?: string;
|
||||
user_agent?: string;
|
||||
};
|
||||
|
||||
export type BuildOAuthFailureLogParams = {
|
||||
provider: string;
|
||||
req: OAuthFailureRequest;
|
||||
err?: unknown;
|
||||
info?: unknown;
|
||||
defaultMessage?: string;
|
||||
};
|
||||
|
||||
function normalizeLogValue(value: unknown): string | undefined {
|
||||
if (value == null) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
if (Array.isArray(value)) {
|
||||
for (const entry of value) {
|
||||
const normalized = normalizeLogValue(entry);
|
||||
if (normalized) {
|
||||
return normalized;
|
||||
}
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
if (typeof value === 'string') {
|
||||
const trimmed = value.replace(/\s+/g, ' ').trim();
|
||||
if (!trimmed) {
|
||||
return undefined;
|
||||
}
|
||||
if (trimmed.length <= MAX_LOG_VALUE_LENGTH) {
|
||||
return trimmed;
|
||||
}
|
||||
return `${trimmed.slice(0, MAX_LOG_VALUE_LENGTH)}... [truncated]`;
|
||||
}
|
||||
|
||||
if (typeof value === 'number' || typeof value === 'boolean') {
|
||||
return String(value);
|
||||
}
|
||||
|
||||
return undefined;
|
||||
}
|
||||
|
||||
function compactLogObject(
|
||||
log: Partial<OAuthFailureLog> & Pick<OAuthFailureLog, 'provider' | 'has_code' | 'has_state'>,
|
||||
): OAuthFailureLog {
|
||||
const compacted: Partial<OAuthFailureLog> = {};
|
||||
const keys = Object.keys(log) as Array<keyof OAuthFailureLog>;
|
||||
for (const key of keys) {
|
||||
const value = log[key];
|
||||
if (value !== undefined) {
|
||||
Object.assign(compacted, { [key]: value as LogValue });
|
||||
}
|
||||
}
|
||||
return compacted as OAuthFailureLog;
|
||||
}
|
||||
|
||||
function getField(source: unknown, field: keyof FailureLike): unknown {
|
||||
if (!source) {
|
||||
return undefined;
|
||||
}
|
||||
if (typeof source === 'string') {
|
||||
return field === 'message' ? source : undefined;
|
||||
}
|
||||
if (typeof source === 'object') {
|
||||
return (source as FailureLike)[field];
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
function firstLogValue(...values: unknown[]): string | undefined {
|
||||
for (const value of values) {
|
||||
const normalized = normalizeLogValue(value);
|
||||
if (normalized) {
|
||||
return normalized;
|
||||
}
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
function getCause(source: unknown): unknown {
|
||||
const cause = getField(source, 'cause');
|
||||
return cause && typeof cause === 'object' ? cause : undefined;
|
||||
}
|
||||
|
||||
function getHeader(req: OAuthFailureRequest, headerName: string): string | undefined {
|
||||
return normalizeLogValue(req.headers?.[headerName]);
|
||||
}
|
||||
|
||||
function getQueryValue(req: OAuthFailureRequest, queryName: string): string | undefined {
|
||||
return normalizeLogValue(req.query?.[queryName]);
|
||||
}
|
||||
|
||||
function hasQueryValue(req: OAuthFailureRequest, queryName: string): boolean {
|
||||
return getQueryValue(req, queryName) !== undefined;
|
||||
}
|
||||
|
||||
function getRequestPath(req: OAuthFailureRequest): string | undefined {
|
||||
return firstLogValue(req.path, req.originalUrl?.split('?')[0]);
|
||||
}
|
||||
|
||||
function popSessionFailureMessage(req: OAuthFailureRequest): unknown {
|
||||
const messages = req.session?.messages;
|
||||
if (!Array.isArray(messages) || messages.length === 0) {
|
||||
return undefined;
|
||||
}
|
||||
return messages.pop();
|
||||
}
|
||||
|
||||
export function getOAuthFailureMessage(
|
||||
req: OAuthFailureRequest,
|
||||
defaultMessage = 'OAuth authentication failed',
|
||||
): string {
|
||||
return (
|
||||
firstLogValue(
|
||||
popSessionFailureMessage(req),
|
||||
getQueryValue(req, 'error_description'),
|
||||
getQueryValue(req, 'error'),
|
||||
defaultMessage,
|
||||
) ?? defaultMessage
|
||||
);
|
||||
}
|
||||
|
||||
export function buildOAuthFailureLog({
|
||||
provider,
|
||||
req,
|
||||
err,
|
||||
info,
|
||||
defaultMessage,
|
||||
}: BuildOAuthFailureLogParams): OAuthFailureLog {
|
||||
const errCause = getCause(err);
|
||||
const infoCause = getCause(info);
|
||||
return compactLogObject({
|
||||
provider,
|
||||
code: firstLogValue(
|
||||
getField(err, 'code'),
|
||||
getField(err, 'error'),
|
||||
getField(errCause, 'code'),
|
||||
getField(errCause, 'error'),
|
||||
getField(info, 'code'),
|
||||
getField(info, 'error'),
|
||||
getField(infoCause, 'code'),
|
||||
getField(infoCause, 'error'),
|
||||
getQueryValue(req, 'error'),
|
||||
),
|
||||
name: firstLogValue(getField(err, 'name'), getField(info, 'name')),
|
||||
message: firstLogValue(
|
||||
getField(err, 'message'),
|
||||
getField(err, 'error_description'),
|
||||
getField(info, 'message'),
|
||||
getField(info, 'error_description'),
|
||||
getQueryValue(req, 'error_description'),
|
||||
getQueryValue(req, 'error'),
|
||||
defaultMessage,
|
||||
),
|
||||
cause_code: firstLogValue(getField(errCause, 'code'), getField(infoCause, 'code')),
|
||||
cause_name: firstLogValue(getField(errCause, 'name'), getField(infoCause, 'name')),
|
||||
cause_message: firstLogValue(getField(errCause, 'message'), getField(infoCause, 'message')),
|
||||
has_code: hasQueryValue(req, 'code'),
|
||||
has_state: hasQueryValue(req, 'state'),
|
||||
query_error: getQueryValue(req, 'error'),
|
||||
query_error_description: getQueryValue(req, 'error_description'),
|
||||
method: normalizeLogValue(req.method),
|
||||
path: getRequestPath(req),
|
||||
request_id: firstLogValue(req.requestId, req.id, getHeader(req, 'x-request-id')),
|
||||
host: getHeader(req, 'host'),
|
||||
forwarded_host: getHeader(req, 'x-forwarded-host'),
|
||||
forwarded_proto: getHeader(req, 'x-forwarded-proto'),
|
||||
forwarded_for: getHeader(req, 'x-forwarded-for'),
|
||||
real_ip: getHeader(req, 'x-real-ip'),
|
||||
user_agent: getHeader(req, 'user-agent'),
|
||||
});
|
||||
}
|
||||
|
||||
export function isOAuthProtocolFailure(err?: unknown, info?: unknown): boolean {
|
||||
const errCause = getCause(err);
|
||||
const infoCause = getCause(info);
|
||||
const code = firstLogValue(
|
||||
getField(err, 'code'),
|
||||
getField(err, 'error'),
|
||||
getField(errCause, 'code'),
|
||||
getField(errCause, 'error'),
|
||||
getField(info, 'code'),
|
||||
getField(info, 'error'),
|
||||
getField(infoCause, 'code'),
|
||||
getField(infoCause, 'error'),
|
||||
);
|
||||
|
||||
if (code?.startsWith('OAUTH_')) {
|
||||
return true;
|
||||
}
|
||||
|
||||
const name = firstLogValue(
|
||||
getField(err, 'name'),
|
||||
getField(errCause, 'name'),
|
||||
getField(info, 'name'),
|
||||
getField(infoCause, 'name'),
|
||||
);
|
||||
if (name === 'AuthorizationResponseError') {
|
||||
return true;
|
||||
}
|
||||
|
||||
const message = firstLogValue(
|
||||
getField(err, 'message'),
|
||||
getField(errCause, 'message'),
|
||||
getField(info, 'message'),
|
||||
getField(infoCause, 'message'),
|
||||
);
|
||||
|
||||
return name === 'OperationProcessingError' && /invalid response/i.test(message ?? '');
|
||||
}
|
||||
|
|
@ -1,3 +1,5 @@
|
|||
export * from './csrf';
|
||||
export * from './callback';
|
||||
export * from './failure';
|
||||
export * from './tokens';
|
||||
export * from './validation';
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue