diff --git a/api/cache/keyvRedis.js b/api/cache/keyvRedis.js index 1f4727e156..992e789ae3 100644 --- a/api/cache/keyvRedis.js +++ b/api/cache/keyvRedis.js @@ -9,7 +9,7 @@ const { REDIS_URI, USE_REDIS, USE_REDIS_CLUSTER, REDIS_CA, REDIS_KEY_PREFIX, RED let keyvRedis; const redis_prefix = REDIS_KEY_PREFIX || ''; -const redis_max_listeners = Number(REDIS_MAX_LISTENERS) || 10; +const redis_max_listeners = Number(REDIS_MAX_LISTENERS) || 40; function mapURI(uri) { const regex = diff --git a/api/package.json b/api/package.json index 7227d36bb4..b04b049ce8 100644 --- a/api/package.json +++ b/api/package.json @@ -103,6 +103,7 @@ "passport-jwt": "^4.0.1", "passport-ldapauth": "^3.0.1", "passport-local": "^1.0.0", + "rate-limit-redis": "^4.2.0", "sharp": "^0.32.6", "tiktoken": "^1.0.15", "traverse": "^0.6.7", diff --git a/api/server/middleware/checkBan.js b/api/server/middleware/checkBan.js index c397ca7d1a..67540bb009 100644 --- a/api/server/middleware/checkBan.js +++ b/api/server/middleware/checkBan.js @@ -41,7 +41,7 @@ const banResponse = async (req, res) => { * @function * @param {Object} req - Express request object. * @param {Object} res - Express response object. - * @param {Function} next - Next middleware function. + * @param {import('express').NextFunction} next - Next middleware function. * * @returns {Promise} - Returns a Promise which when resolved calls next middleware if user or source IP is not banned. Otherwise calls `banResponse()` and sets ban details in `banCache`. */ diff --git a/api/server/middleware/concurrentLimiter.js b/api/server/middleware/concurrentLimiter.js index 58ff689a0b..21b3a86903 100644 --- a/api/server/middleware/concurrentLimiter.js +++ b/api/server/middleware/concurrentLimiter.js @@ -21,7 +21,7 @@ const { * @function * @param {Object} req - Express request object containing user information. * @param {Object} res - Express response object. - * @param {function} next - Express next middleware function. + * @param {import('express').NextFunction} next - Next middleware function. * @throws {Error} Throws an error if the user exceeds the concurrent request limit. */ const concurrentLimiter = async (req, res, next) => { diff --git a/api/server/middleware/index.js b/api/server/middleware/index.js index 3da9e06bd6..789ec6a82d 100644 --- a/api/server/middleware/index.js +++ b/api/server/middleware/index.js @@ -14,6 +14,7 @@ const checkInviteUser = require('./checkInviteUser'); const requireJwtAuth = require('./requireJwtAuth'); const validateModel = require('./validateModel'); const moderateText = require('./moderateText'); +const logHeaders = require('./logHeaders'); const setHeaders = require('./setHeaders'); const validate = require('./validate'); const limiters = require('./limiters'); @@ -31,6 +32,7 @@ module.exports = { checkBan, uaParser, setHeaders, + logHeaders, moderateText, validateModel, requireJwtAuth, diff --git a/api/server/middleware/limiters/importLimiters.js b/api/server/middleware/limiters/importLimiters.js index a21fa6453e..5e50046a30 100644 --- a/api/server/middleware/limiters/importLimiters.js +++ b/api/server/middleware/limiters/importLimiters.js @@ -1,6 +1,11 @@ +const Keyv = require('keyv'); const rateLimit = require('express-rate-limit'); +const { RedisStore } = require('rate-limit-redis'); const { ViolationTypes } = require('librechat-data-provider'); const logViolation = require('~/cache/logViolation'); +const { isEnabled } = require('~/server/utils'); +const keyvRedis = require('~/cache/keyvRedis'); +const { logger } = require('~/config'); const getEnvironmentVariables = () => { const IMPORT_IP_MAX = parseInt(process.env.IMPORT_IP_MAX) || 100; @@ -48,21 +53,39 @@ const createImportLimiters = () => { const { importIpWindowMs, importIpMax, importUserWindowMs, importUserMax } = getEnvironmentVariables(); - const importIpLimiter = rateLimit({ + const ipLimiterOptions = { windowMs: importIpWindowMs, max: importIpMax, handler: createImportHandler(), - }); - - const importUserLimiter = rateLimit({ + }; + const userLimiterOptions = { windowMs: importUserWindowMs, max: importUserMax, handler: createImportHandler(false), keyGenerator: function (req) { return req.user?.id; // Use the user ID or NULL if not available }, - }); + }; + if (isEnabled(process.env.USE_REDIS)) { + logger.debug('Using Redis for import rate limiters.'); + const keyv = new Keyv({ store: keyvRedis }); + const client = keyv.opts.store.redis; + const sendCommand = (...args) => client.call(...args); + const ipStore = new RedisStore({ + sendCommand, + prefix: 'import_ip_limiter:', + }); + const userStore = new RedisStore({ + sendCommand, + prefix: 'import_user_limiter:', + }); + ipLimiterOptions.store = ipStore; + userLimiterOptions.store = userStore; + } + + const importIpLimiter = rateLimit(ipLimiterOptions); + const importUserLimiter = rateLimit(userLimiterOptions); return { importIpLimiter, importUserLimiter }; }; diff --git a/api/server/middleware/limiters/loginLimiter.js b/api/server/middleware/limiters/loginLimiter.js index 937723e859..8cf10ccb12 100644 --- a/api/server/middleware/limiters/loginLimiter.js +++ b/api/server/middleware/limiters/loginLimiter.js @@ -1,6 +1,10 @@ +const Keyv = require('keyv'); const rateLimit = require('express-rate-limit'); -const { removePorts } = require('~/server/utils'); +const { RedisStore } = require('rate-limit-redis'); +const { removePorts, isEnabled } = require('~/server/utils'); +const keyvRedis = require('~/cache/keyvRedis'); const { logViolation } = require('~/cache'); +const { logger } = require('~/config'); const { LOGIN_WINDOW = 5, LOGIN_MAX = 7, LOGIN_VIOLATION_SCORE: score } = process.env; const windowMs = LOGIN_WINDOW * 60 * 1000; @@ -20,11 +24,25 @@ const handler = async (req, res) => { return res.status(429).json({ message }); }; -const loginLimiter = rateLimit({ +const limiterOptions = { windowMs, max, handler, keyGenerator: removePorts, -}); +}; + +if (isEnabled(process.env.USE_REDIS)) { + logger.debug('Using Redis for login rate limiter.'); + const keyv = new Keyv({ store: keyvRedis }); + const client = keyv.opts.store.redis; + const sendCommand = (...args) => client.call(...args); + const store = new RedisStore({ + sendCommand, + prefix: 'login_limiter:', + }); + limiterOptions.store = store; +} + +const loginLimiter = rateLimit(limiterOptions); module.exports = loginLimiter; diff --git a/api/server/middleware/limiters/messageLimiters.js b/api/server/middleware/limiters/messageLimiters.js index c84db1043c..fe4f75a9c6 100644 --- a/api/server/middleware/limiters/messageLimiters.js +++ b/api/server/middleware/limiters/messageLimiters.js @@ -1,6 +1,11 @@ +const Keyv = require('keyv'); const rateLimit = require('express-rate-limit'); +const { RedisStore } = require('rate-limit-redis'); const denyRequest = require('~/server/middleware/denyRequest'); +const { isEnabled } = require('~/server/utils'); +const keyvRedis = require('~/cache/keyvRedis'); const { logViolation } = require('~/cache'); +const { logger } = require('~/config'); const { MESSAGE_IP_MAX = 40, @@ -41,25 +46,49 @@ const createHandler = (ip = true) => { }; /** - * Message request rate limiter by IP + * Message request rate limiters */ -const messageIpLimiter = rateLimit({ +const ipLimiterOptions = { windowMs: ipWindowMs, max: ipMax, handler: createHandler(), -}); +}; -/** - * Message request rate limiter by userId - */ -const messageUserLimiter = rateLimit({ +const userLimiterOptions = { windowMs: userWindowMs, max: userMax, handler: createHandler(false), keyGenerator: function (req) { return req.user?.id; // Use the user ID or NULL if not available }, -}); +}; + +if (isEnabled(process.env.USE_REDIS)) { + logger.debug('Using Redis for message rate limiters.'); + const keyv = new Keyv({ store: keyvRedis }); + const client = keyv.opts.store.redis; + const sendCommand = (...args) => client.call(...args); + const ipStore = new RedisStore({ + sendCommand, + prefix: 'message_ip_limiter:', + }); + const userStore = new RedisStore({ + sendCommand, + prefix: 'message_user_limiter:', + }); + ipLimiterOptions.store = ipStore; + userLimiterOptions.store = userStore; +} + +/** + * Message request rate limiter by IP + */ +const messageIpLimiter = rateLimit(ipLimiterOptions); + +/** + * Message request rate limiter by userId + */ +const messageUserLimiter = rateLimit(userLimiterOptions); module.exports = { messageIpLimiter, diff --git a/api/server/middleware/limiters/registerLimiter.js b/api/server/middleware/limiters/registerLimiter.js index b069798b03..f9bf1215cd 100644 --- a/api/server/middleware/limiters/registerLimiter.js +++ b/api/server/middleware/limiters/registerLimiter.js @@ -1,6 +1,10 @@ +const Keyv = require('keyv'); const rateLimit = require('express-rate-limit'); -const { removePorts } = require('~/server/utils'); +const { RedisStore } = require('rate-limit-redis'); +const { removePorts, isEnabled } = require('~/server/utils'); +const keyvRedis = require('~/cache/keyvRedis'); const { logViolation } = require('~/cache'); +const { logger } = require('~/config'); const { REGISTER_WINDOW = 60, REGISTER_MAX = 5, REGISTRATION_VIOLATION_SCORE: score } = process.env; const windowMs = REGISTER_WINDOW * 60 * 1000; @@ -20,11 +24,25 @@ const handler = async (req, res) => { return res.status(429).json({ message }); }; -const registerLimiter = rateLimit({ +const limiterOptions = { windowMs, max, handler, keyGenerator: removePorts, -}); +}; + +if (isEnabled(process.env.USE_REDIS)) { + logger.debug('Using Redis for register rate limiter.'); + const keyv = new Keyv({ store: keyvRedis }); + const client = keyv.opts.store.redis; + const sendCommand = (...args) => client.call(...args); + const store = new RedisStore({ + sendCommand, + prefix: 'register_limiter:', + }); + limiterOptions.store = store; +} + +const registerLimiter = rateLimit(limiterOptions); module.exports = registerLimiter; diff --git a/api/server/middleware/limiters/resetPasswordLimiter.js b/api/server/middleware/limiters/resetPasswordLimiter.js index 5d2deb0282..9f56bd7949 100644 --- a/api/server/middleware/limiters/resetPasswordLimiter.js +++ b/api/server/middleware/limiters/resetPasswordLimiter.js @@ -1,7 +1,11 @@ +const Keyv = require('keyv'); const rateLimit = require('express-rate-limit'); +const { RedisStore } = require('rate-limit-redis'); const { ViolationTypes } = require('librechat-data-provider'); -const { removePorts } = require('~/server/utils'); +const { removePorts, isEnabled } = require('~/server/utils'); +const keyvRedis = require('~/cache/keyvRedis'); const { logViolation } = require('~/cache'); +const { logger } = require('~/config'); const { RESET_PASSWORD_WINDOW = 2, @@ -25,11 +29,25 @@ const handler = async (req, res) => { return res.status(429).json({ message }); }; -const resetPasswordLimiter = rateLimit({ +const limiterOptions = { windowMs, max, handler, keyGenerator: removePorts, -}); +}; + +if (isEnabled(process.env.USE_REDIS)) { + logger.debug('Using Redis for reset password rate limiter.'); + const keyv = new Keyv({ store: keyvRedis }); + const client = keyv.opts.store.redis; + const sendCommand = (...args) => client.call(...args); + const store = new RedisStore({ + sendCommand, + prefix: 'reset_password_limiter:', + }); + limiterOptions.store = store; +} + +const resetPasswordLimiter = rateLimit(limiterOptions); module.exports = resetPasswordLimiter; diff --git a/api/server/middleware/limiters/sttLimiters.js b/api/server/middleware/limiters/sttLimiters.js index 76f2944f0a..f9304637c4 100644 --- a/api/server/middleware/limiters/sttLimiters.js +++ b/api/server/middleware/limiters/sttLimiters.js @@ -1,6 +1,11 @@ +const Keyv = require('keyv'); const rateLimit = require('express-rate-limit'); +const { RedisStore } = require('rate-limit-redis'); const { ViolationTypes } = require('librechat-data-provider'); const logViolation = require('~/cache/logViolation'); +const { isEnabled } = require('~/server/utils'); +const keyvRedis = require('~/cache/keyvRedis'); +const { logger } = require('~/config'); const getEnvironmentVariables = () => { const STT_IP_MAX = parseInt(process.env.STT_IP_MAX) || 100; @@ -47,20 +52,40 @@ const createSTTHandler = (ip = true) => { const createSTTLimiters = () => { const { sttIpWindowMs, sttIpMax, sttUserWindowMs, sttUserMax } = getEnvironmentVariables(); - const sttIpLimiter = rateLimit({ + const ipLimiterOptions = { windowMs: sttIpWindowMs, max: sttIpMax, handler: createSTTHandler(), - }); + }; - const sttUserLimiter = rateLimit({ + const userLimiterOptions = { windowMs: sttUserWindowMs, max: sttUserMax, handler: createSTTHandler(false), keyGenerator: function (req) { return req.user?.id; // Use the user ID or NULL if not available }, - }); + }; + + if (isEnabled(process.env.USE_REDIS)) { + logger.debug('Using Redis for STT rate limiters.'); + const keyv = new Keyv({ store: keyvRedis }); + const client = keyv.opts.store.redis; + const sendCommand = (...args) => client.call(...args); + const ipStore = new RedisStore({ + sendCommand, + prefix: 'stt_ip_limiter:', + }); + const userStore = new RedisStore({ + sendCommand, + prefix: 'stt_user_limiter:', + }); + ipLimiterOptions.store = ipStore; + userLimiterOptions.store = userStore; + } + + const sttIpLimiter = rateLimit(ipLimiterOptions); + const sttUserLimiter = rateLimit(userLimiterOptions); return { sttIpLimiter, sttUserLimiter }; }; diff --git a/api/server/middleware/limiters/toolCallLimiter.js b/api/server/middleware/limiters/toolCallLimiter.js index 47dcaeabb4..7a867b5bcd 100644 --- a/api/server/middleware/limiters/toolCallLimiter.js +++ b/api/server/middleware/limiters/toolCallLimiter.js @@ -1,25 +1,46 @@ +const Keyv = require('keyv'); const rateLimit = require('express-rate-limit'); +const { RedisStore } = require('rate-limit-redis'); const { ViolationTypes } = require('librechat-data-provider'); const logViolation = require('~/cache/logViolation'); +const { isEnabled } = require('~/server/utils'); +const keyvRedis = require('~/cache/keyvRedis'); +const { logger } = require('~/config'); -const toolCallLimiter = rateLimit({ +const handler = async (req, res) => { + const type = ViolationTypes.TOOL_CALL_LIMIT; + const errorMessage = { + type, + max: 1, + limiter: 'user', + windowInMinutes: 1, + }; + + await logViolation(req, res, type, errorMessage, 0); + res.status(429).json({ message: 'Too many tool call requests. Try again later' }); +}; + +const limiterOptions = { windowMs: 1000, max: 1, - handler: async (req, res) => { - const type = ViolationTypes.TOOL_CALL_LIMIT; - const errorMessage = { - type, - max: 1, - limiter: 'user', - windowInMinutes: 1, - }; - - await logViolation(req, res, type, errorMessage, 0); - res.status(429).json({ message: 'Too many tool call requests. Try again later' }); - }, + handler, keyGenerator: function (req) { return req.user?.id; }, -}); +}; + +if (isEnabled(process.env.USE_REDIS)) { + logger.debug('Using Redis for tool call rate limiter.'); + const keyv = new Keyv({ store: keyvRedis }); + const client = keyv.opts.store.redis; + const sendCommand = (...args) => client.call(...args); + const store = new RedisStore({ + sendCommand, + prefix: 'tool_call_limiter:', + }); + limiterOptions.store = store; +} + +const toolCallLimiter = rateLimit(limiterOptions); module.exports = toolCallLimiter; diff --git a/api/server/middleware/limiters/ttsLimiters.js b/api/server/middleware/limiters/ttsLimiters.js index 5619a49b63..e13aaf48c3 100644 --- a/api/server/middleware/limiters/ttsLimiters.js +++ b/api/server/middleware/limiters/ttsLimiters.js @@ -1,6 +1,11 @@ +const Keyv = require('keyv'); const rateLimit = require('express-rate-limit'); +const { RedisStore } = require('rate-limit-redis'); const { ViolationTypes } = require('librechat-data-provider'); const logViolation = require('~/cache/logViolation'); +const { isEnabled } = require('~/server/utils'); +const keyvRedis = require('~/cache/keyvRedis'); +const { logger } = require('~/config'); const getEnvironmentVariables = () => { const TTS_IP_MAX = parseInt(process.env.TTS_IP_MAX) || 100; @@ -47,20 +52,40 @@ const createTTSHandler = (ip = true) => { const createTTSLimiters = () => { const { ttsIpWindowMs, ttsIpMax, ttsUserWindowMs, ttsUserMax } = getEnvironmentVariables(); - const ttsIpLimiter = rateLimit({ + const ipLimiterOptions = { windowMs: ttsIpWindowMs, max: ttsIpMax, handler: createTTSHandler(), - }); + }; - const ttsUserLimiter = rateLimit({ + const userLimiterOptions = { windowMs: ttsUserWindowMs, max: ttsUserMax, handler: createTTSHandler(false), keyGenerator: function (req) { return req.user?.id; // Use the user ID or NULL if not available }, - }); + }; + + if (isEnabled(process.env.USE_REDIS)) { + logger.debug('Using Redis for TTS rate limiters.'); + const keyv = new Keyv({ store: keyvRedis }); + const client = keyv.opts.store.redis; + const sendCommand = (...args) => client.call(...args); + const ipStore = new RedisStore({ + sendCommand, + prefix: 'tts_ip_limiter:', + }); + const userStore = new RedisStore({ + sendCommand, + prefix: 'tts_user_limiter:', + }); + ipLimiterOptions.store = ipStore; + userLimiterOptions.store = userStore; + } + + const ttsIpLimiter = rateLimit(ipLimiterOptions); + const ttsUserLimiter = rateLimit(userLimiterOptions); return { ttsIpLimiter, ttsUserLimiter }; }; diff --git a/api/server/middleware/limiters/uploadLimiters.js b/api/server/middleware/limiters/uploadLimiters.js index 71af164fde..9fffface61 100644 --- a/api/server/middleware/limiters/uploadLimiters.js +++ b/api/server/middleware/limiters/uploadLimiters.js @@ -1,6 +1,11 @@ +const Keyv = require('keyv'); const rateLimit = require('express-rate-limit'); +const { RedisStore } = require('rate-limit-redis'); const { ViolationTypes } = require('librechat-data-provider'); const logViolation = require('~/cache/logViolation'); +const { isEnabled } = require('~/server/utils'); +const keyvRedis = require('~/cache/keyvRedis'); +const { logger } = require('~/config'); const getEnvironmentVariables = () => { const FILE_UPLOAD_IP_MAX = parseInt(process.env.FILE_UPLOAD_IP_MAX) || 100; @@ -52,20 +57,40 @@ const createFileLimiters = () => { const { fileUploadIpWindowMs, fileUploadIpMax, fileUploadUserWindowMs, fileUploadUserMax } = getEnvironmentVariables(); - const fileUploadIpLimiter = rateLimit({ + const ipLimiterOptions = { windowMs: fileUploadIpWindowMs, max: fileUploadIpMax, handler: createFileUploadHandler(), - }); + }; - const fileUploadUserLimiter = rateLimit({ + const userLimiterOptions = { windowMs: fileUploadUserWindowMs, max: fileUploadUserMax, handler: createFileUploadHandler(false), keyGenerator: function (req) { return req.user?.id; // Use the user ID or NULL if not available }, - }); + }; + + if (isEnabled(process.env.USE_REDIS)) { + logger.debug('Using Redis for file upload rate limiters.'); + const keyv = new Keyv({ store: keyvRedis }); + const client = keyv.opts.store.redis; + const sendCommand = (...args) => client.call(...args); + const ipStore = new RedisStore({ + sendCommand, + prefix: 'file_upload_ip_limiter:', + }); + const userStore = new RedisStore({ + sendCommand, + prefix: 'file_upload_user_limiter:', + }); + ipLimiterOptions.store = ipStore; + userLimiterOptions.store = userStore; + } + + const fileUploadIpLimiter = rateLimit(ipLimiterOptions); + const fileUploadUserLimiter = rateLimit(userLimiterOptions); return { fileUploadIpLimiter, fileUploadUserLimiter }; }; diff --git a/api/server/middleware/limiters/verifyEmailLimiter.js b/api/server/middleware/limiters/verifyEmailLimiter.js index 770090dba5..0b245afbd1 100644 --- a/api/server/middleware/limiters/verifyEmailLimiter.js +++ b/api/server/middleware/limiters/verifyEmailLimiter.js @@ -1,7 +1,11 @@ +const Keyv = require('keyv'); const rateLimit = require('express-rate-limit'); +const { RedisStore } = require('rate-limit-redis'); const { ViolationTypes } = require('librechat-data-provider'); -const { removePorts } = require('~/server/utils'); +const { removePorts, isEnabled } = require('~/server/utils'); +const keyvRedis = require('~/cache/keyvRedis'); const { logViolation } = require('~/cache'); +const { logger } = require('~/config'); const { VERIFY_EMAIL_WINDOW = 2, @@ -25,11 +29,25 @@ const handler = async (req, res) => { return res.status(429).json({ message }); }; -const verifyEmailLimiter = rateLimit({ +const limiterOptions = { windowMs, max, handler, keyGenerator: removePorts, -}); +}; + +if (isEnabled(process.env.USE_REDIS)) { + logger.debug('Using Redis for verify email rate limiter.'); + const keyv = new Keyv({ store: keyvRedis }); + const client = keyv.opts.store.redis; + const sendCommand = (...args) => client.call(...args); + const store = new RedisStore({ + sendCommand, + prefix: 'verify_email_limiter:', + }); + limiterOptions.store = store; +} + +const verifyEmailLimiter = rateLimit(limiterOptions); module.exports = verifyEmailLimiter; diff --git a/api/server/middleware/logHeaders.js b/api/server/middleware/logHeaders.js new file mode 100644 index 0000000000..26ca04da38 --- /dev/null +++ b/api/server/middleware/logHeaders.js @@ -0,0 +1,32 @@ +const { logger } = require('~/config'); + +/** + * Middleware to log Forwarded Headers + * @function + * @param {ServerRequest} req - Express request object containing user information. + * @param {ServerResponse} res - Express response object. + * @param {import('express').NextFunction} next - Next middleware function. + * @throws {Error} Throws an error if the user exceeds the concurrent request limit. + */ +const logHeaders = (req, res, next) => { + try { + const forwardedHeaders = {}; + if (req.headers['x-forwarded-for']) { + forwardedHeaders['x-forwarded-for'] = req.headers['x-forwarded-for']; + } + if (req.headers['x-forwarded-host']) { + forwardedHeaders['x-forwarded-host'] = req.headers['x-forwarded-host']; + } + if (req.headers['x-forwarded-proto']) { + forwardedHeaders['x-forwarded-proto'] = req.headers['x-forwarded-proto']; + } + if (Object.keys(forwardedHeaders).length > 0) { + logger.debug('X-Forwarded headers detected in OAuth request:', forwardedHeaders); + } + } catch (error) { + logger.error('Error logging X-Forwarded headers:', error); + } + next(); +}; + +module.exports = logHeaders; diff --git a/api/server/routes/auth.js b/api/server/routes/auth.js index 6536b98e92..2d9fae7ae7 100644 --- a/api/server/routes/auth.js +++ b/api/server/routes/auth.js @@ -17,6 +17,7 @@ const { } = require('~/server/controllers/TwoFactorController'); const { checkBan, + logHeaders, loginLimiter, requireJwtAuth, checkInviteUser, @@ -35,6 +36,7 @@ const ldapAuth = !!process.env.LDAP_URL && !!process.env.LDAP_USER_SEARCH_BASE; router.post('/logout', requireJwtAuth, logoutController); router.post( '/login', + logHeaders, loginLimiter, checkBan, ldapAuth ? requireLdapAuth : requireLocalAuth, diff --git a/api/server/routes/oauth.js b/api/server/routes/oauth.js index 9006b25c5b..9ea896e30e 100644 --- a/api/server/routes/oauth.js +++ b/api/server/routes/oauth.js @@ -1,7 +1,7 @@ // file deepcode ignore NoRateLimitingForLogin: Rate limiting is handled by the `loginLimiter` middleware const express = require('express'); const passport = require('passport'); -const { loginLimiter, checkBan, checkDomainAllowed } = require('~/server/middleware'); +const { loginLimiter, logHeaders, checkBan, checkDomainAllowed } = require('~/server/middleware'); const { setAuthTokens } = require('~/server/services/AuthService'); const { logger } = require('~/config'); @@ -12,6 +12,7 @@ const domains = { server: process.env.DOMAIN_SERVER, }; +router.use(logHeaders); router.use(loginLimiter); const oauthHandler = async (req, res) => { diff --git a/api/server/socialLogins.js b/api/server/socialLogins.js index f39d1da596..af80a3b880 100644 --- a/api/server/socialLogins.js +++ b/api/server/socialLogins.js @@ -1,4 +1,4 @@ -const Redis = require('ioredis'); +const Keyv = require('keyv'); const passport = require('passport'); const session = require('express-session'); const MemoryStore = require('memorystore')(session); @@ -12,6 +12,7 @@ const { appleLogin, } = require('~/strategies'); const { isEnabled } = require('~/server/utils'); +const keyvRedis = require('~/cache/keyvRedis'); const { logger } = require('~/config'); /** @@ -19,6 +20,8 @@ const { logger } = require('~/config'); * @param {Express.Application} app */ const configureSocialLogins = (app) => { + logger.info('Configuring social logins...'); + if (process.env.GOOGLE_CLIENT_ID && process.env.GOOGLE_CLIENT_SECRET) { passport.use(googleLogin()); } @@ -41,18 +44,17 @@ const configureSocialLogins = (app) => { process.env.OPENID_SCOPE && process.env.OPENID_SESSION_SECRET ) { + logger.info('Configuring OpenID Connect...'); const sessionOptions = { secret: process.env.OPENID_SESSION_SECRET, resave: false, saveUninitialized: false, }; if (isEnabled(process.env.USE_REDIS)) { - const client = new Redis(process.env.REDIS_URI); - client - .on('error', (err) => logger.error('ioredis error:', err)) - .on('ready', () => logger.info('ioredis successfully initialized.')) - .on('reconnecting', () => logger.info('ioredis reconnecting...')); - sessionOptions.store = new RedisStore({ client, prefix: 'librechat' }); + logger.debug('Using Redis for session storage in OpenID...'); + const keyv = new Keyv({ store: keyvRedis }); + const client = keyv.opts.store.redis; + sessionOptions.store = new RedisStore({ client, prefix: 'openid_session' }); } else { sessionOptions.store = new MemoryStore({ checkPeriod: 86400000, // prune expired entries every 24h @@ -61,7 +63,9 @@ const configureSocialLogins = (app) => { app.use(session(sessionOptions)); app.use(passport.session()); setupOpenId(); + + logger.info('OpenID Connect configured.'); } }; -module.exports = configureSocialLogins; \ No newline at end of file +module.exports = configureSocialLogins; diff --git a/package-lock.json b/package-lock.json index c391726fae..38c4ec0520 100644 --- a/package-lock.json +++ b/package-lock.json @@ -119,6 +119,7 @@ "passport-jwt": "^4.0.1", "passport-ldapauth": "^3.0.1", "passport-local": "^1.0.0", + "rate-limit-redis": "^4.2.0", "sharp": "^0.32.6", "tiktoken": "^1.0.15", "traverse": "^0.6.7", @@ -37941,6 +37942,17 @@ "node": ">= 0.6" } }, + "node_modules/rate-limit-redis": { + "version": "4.2.0", + "resolved": "https://registry.npmjs.org/rate-limit-redis/-/rate-limit-redis-4.2.0.tgz", + "integrity": "sha512-wV450NQyKC24NmPosJb2131RoczLdfIJdKCReNwtVpm5998U8SgKrAZrIHaN/NfQgqOHaan8Uq++B4sa5REwjA==", + "engines": { + "node": ">= 16" + }, + "peerDependencies": { + "express-rate-limit": ">= 6" + } + }, "node_modules/raw-body": { "version": "2.5.2", "resolved": "https://registry.npmjs.org/raw-body/-/raw-body-2.5.2.tgz",