diff --git a/.env.example b/.env.example index d4788e71a5..385cbbc7ef 100644 --- a/.env.example +++ b/.env.example @@ -13,12 +13,44 @@ APP_TITLE=LibreChat HOST=localhost PORT=3080 +# Automated Moderation System +# The Automated Moderation System uses a scoring mechanism to track user violations. As users commit actions +# like excessive logins, registrations, or messaging, they accumulate violation scores. Upon reaching +# a set threshold, the user and their IP are temporarily banned. This system ensures platform security +# by monitoring and penalizing rapid or suspicious activities. + +BAN_VIOLATIONS=true # Whether or not to enable banning users for violations (they will still be logged) +BAN_DURATION=1000 * 60 * 60 * 2 # how long the user and associated IP are banned for +BAN_INTERVAL=20 # a user will be banned everytime their score reaches/crosses over the interval threshold + +# The score for each violation + +LOGIN_VIOLATION_SCORE=1 +REGISTRATION_VIOLATION_SCORE=1 +CONCURRENT_VIOLATION_SCORE=1 +MESSAGE_VIOLATION_SCORE=1 +NON_BROWSER_VIOLATION_SCORE=20 + # Login and registration rate limiting. LOGIN_MAX=7 # The max amount of logins allowed per IP per LOGIN_WINDOW -LOGIN_WINDOW=5 # in minutes, determines how long an IP is banned for after LOGIN_MAX logins +LOGIN_WINDOW=5 # in minutes, determines the window of time for LOGIN_MAX logins REGISTER_MAX=5 # The max amount of registrations allowed per IP per REGISTER_WINDOW -REGISTER_WINDOW=60 # in minutes, determines how long an IP is banned for after REGISTER_MAX registrations +REGISTER_WINDOW=60 # in minutes, determines the window of time for REGISTER_MAX registrations + +# Message rate limiting (per user & IP) + +LIMIT_CONCURRENT_MESSAGES=true # Whether to limit the amount of messages a user can send per request +CONCURRENT_MESSAGE_MAX=2 # The max amount of messages a user can send per request + +LIMIT_MESSAGE_IP=true # Whether to limit the amount of messages an IP can send per MESSAGE_IP_WINDOW +MESSAGE_IP_MAX=40 # The max amount of messages an IP can send per MESSAGE_IP_WINDOW +MESSAGE_IP_WINDOW=1 # in minutes, determines the window of time for MESSAGE_IP_MAX messages + +# Note: You can utilize both limiters, but default is to limit by IP only. +LIMIT_MESSAGE_USER=false # Whether to limit the amount of messages an IP can send per MESSAGE_USER_WINDOW +MESSAGE_USER_MAX=40 # The max amount of messages an IP can send per MESSAGE_USER_WINDOW +MESSAGE_USER_WINDOW=1 # in minutes, determines the window of time for MESSAGE_USER_MAX messages # Change this to proxy any API request. # It's useful if your machine has difficulty calling the original API server. diff --git a/.github/workflows/backend-review.yml b/.github/workflows/backend-review.yml index e5c86caa4c..a005c10efb 100644 --- a/.github/workflows/backend-review.yml +++ b/.github/workflows/backend-review.yml @@ -18,6 +18,9 @@ jobs: JWT_SECRET: ${{ secrets.JWT_SECRET }} CREDS_KEY: ${{ secrets.CREDS_KEY }} CREDS_IV: ${{ secrets.CREDS_IV }} + BAN_VIOLATIONS: ${{ secrets.BAN_VIOLATIONS }} + BAN_DURATION: ${{ secrets.BAN_DURATION }} + BAN_INTERVAL: ${{ secrets.BAN_INTERVAL }} NODE_ENV: ci steps: - uses: actions/checkout@v2 diff --git a/.gitignore b/.gitignore index c294b25b48..52ce79baa7 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ # Logs data-node meili_data +data/ logs *.log diff --git a/README.md b/README.md index bfe3cf4b22..b5fca4c448 100644 --- a/README.md +++ b/README.md @@ -96,7 +96,8 @@ Keep up with the latest updates by visiting the releases page - [Releases](https * [Using official ChatGPT Plugins](docs/features/plugins/chatgpt_plugins_openapi.md) - * [Third-Party Tools](docs/features/third-party.md) + * [Automated Moderation](docs/features/mod_system.md) + * [Third-Party Tools](docs/features/third_party.md) * [Proxy](docs/features/proxy.md) * [Bing Jailbreak](docs/features/bing_jailbreak.md) diff --git a/api/app/clients/tools/dynamic/OpenAPIPlugin.spec.js b/api/app/clients/tools/dynamic/OpenAPIPlugin.spec.js index 5fe7f1cb36..83bc5e9397 100644 --- a/api/app/clients/tools/dynamic/OpenAPIPlugin.spec.js +++ b/api/app/clients/tools/dynamic/OpenAPIPlugin.spec.js @@ -1,7 +1,14 @@ const fs = require('fs'); const { createOpenAPIPlugin, getSpec, readSpecFile } = require('./OpenAPIPlugin'); -jest.mock('node-fetch'); +global.fetch = jest.fn().mockImplementationOnce(() => { + return new Promise((resolve) => { + resolve({ + ok: true, + json: () => Promise.resolve({ key: 'value' }), + }); + }); +}); jest.mock('fs', () => ({ promises: { readFile: jest.fn(), diff --git a/api/cache/banViolation.js b/api/cache/banViolation.js new file mode 100644 index 0000000000..f00296d3b3 --- /dev/null +++ b/api/cache/banViolation.js @@ -0,0 +1,68 @@ +const Session = require('../models/Session'); +const getLogStores = require('./getLogStores'); +const { isEnabled, math, removePorts } = require('../server/utils'); +const { BAN_VIOLATIONS, BAN_INTERVAL } = process.env ?? {}; +const interval = math(BAN_INTERVAL, 20); + +/** + * Bans a user based on violation criteria. + * + * If the user's violation count is a multiple of the BAN_INTERVAL, the user will be banned. + * The duration of the ban is determined by the BAN_DURATION environment variable. + * If BAN_DURATION is not set or invalid, the user will not be banned. + * Sessions will be deleted and the refreshToken cookie will be cleared even with + * an invalid or nill duration, which is a "soft" ban; the user can remain active until + * access token expiry. + * + * @async + * @param {Object} req - Express request object containing user information. + * @param {Object} res - Express response object. + * @param {Object} errorMessage - Object containing user violation details. + * @param {string} errorMessage.type - Type of the violation. + * @param {string} errorMessage.user_id - ID of the user who committed the violation. + * @param {number} errorMessage.violation_count - Number of violations committed by the user. + * + * @returns {Promise} + * + */ +const banViolation = async (req, res, errorMessage) => { + if (!isEnabled(BAN_VIOLATIONS)) { + return; + } + + if (!errorMessage) { + return; + } + + const { type, user_id, prev_count, violation_count } = errorMessage; + + const prevThreshold = Math.floor(prev_count / interval); + const currentThreshold = Math.floor(violation_count / interval); + + if (prevThreshold >= currentThreshold) { + return; + } + + await Session.deleteAllUserSessions(user_id); + res.clearCookie('refreshToken'); + + const banLogs = getLogStores('ban'); + const duration = banLogs.opts.ttl; + + if (duration <= 0) { + return; + } + + req.ip = removePorts(req); + console.log(`[BAN] Banning user ${user_id} @ ${req.ip} for ${duration / 1000 / 60} minutes`); + const expiresAt = Date.now() + duration; + await banLogs.set(user_id, { type, violation_count, duration, expiresAt }); + await banLogs.set(req.ip, { type, user_id, violation_count, duration, expiresAt }); + + errorMessage.ban = true; + errorMessage.ban_duration = duration; + + return; +}; + +module.exports = banViolation; diff --git a/api/cache/banViolation.spec.js b/api/cache/banViolation.spec.js new file mode 100644 index 0000000000..ba8e78a1ed --- /dev/null +++ b/api/cache/banViolation.spec.js @@ -0,0 +1,155 @@ +const banViolation = require('./banViolation'); + +jest.mock('keyv'); +jest.mock('../models/Session'); +// Mocking the getLogStores function +jest.mock('./getLogStores', () => { + return jest.fn().mockImplementation(() => { + const EventEmitter = require('events'); + const math = require('../server/utils/math'); + const mockGet = jest.fn(); + const mockSet = jest.fn(); + class KeyvMongo extends EventEmitter { + constructor(url = 'mongodb://127.0.0.1:27017', options) { + super(); + this.ttlSupport = false; + url = url ?? {}; + if (typeof url === 'string') { + url = { url }; + } + if (url.uri) { + url = { url: url.uri, ...url }; + } + this.opts = { + url, + collection: 'keyv', + ...url, + ...options, + }; + } + + get = mockGet; + set = mockSet; + } + + return new KeyvMongo('', { + namespace: 'bans', + ttl: math(process.env.BAN_DURATION, 7200000), + }); + }); +}); + +describe('banViolation', () => { + let req, res, errorMessage; + + beforeEach(() => { + req = { + ip: '127.0.0.1', + cookies: { + refreshToken: 'someToken', + }, + }; + res = { + clearCookie: jest.fn(), + }; + errorMessage = { + type: 'someViolation', + user_id: '12345', + prev_count: 0, + violation_count: 0, + }; + process.env.BAN_VIOLATIONS = 'true'; + process.env.BAN_DURATION = '7200000'; // 2 hours in ms + process.env.BAN_INTERVAL = '20'; + }); + + afterEach(() => { + jest.clearAllMocks(); + }); + + it('should not ban if BAN_VIOLATIONS are not enabled', async () => { + process.env.BAN_VIOLATIONS = 'false'; + await banViolation(req, res, errorMessage); + expect(errorMessage.ban).toBeFalsy(); + }); + + it('should not ban if errorMessage is not provided', async () => { + await banViolation(req, res, null); + expect(errorMessage.ban).toBeFalsy(); + }); + + it('[1/3] should ban if violation_count crosses the interval threshold: 19 -> 39', async () => { + errorMessage.prev_count = 19; + errorMessage.violation_count = 39; + await banViolation(req, res, errorMessage); + expect(errorMessage.ban).toBeTruthy(); + }); + + it('[2/3] should ban if violation_count crosses the interval threshold: 19 -> 20', async () => { + errorMessage.prev_count = 19; + errorMessage.violation_count = 20; + await banViolation(req, res, errorMessage); + expect(errorMessage.ban).toBeTruthy(); + }); + + const randomValueAbove = Math.floor(20 + Math.random() * 100); + it(`[3/3] should ban if violation_count crosses the interval threshold: 19 -> ${randomValueAbove}`, async () => { + errorMessage.prev_count = 19; + errorMessage.violation_count = randomValueAbove; + await banViolation(req, res, errorMessage); + expect(errorMessage.ban).toBeTruthy(); + }); + + it('should handle invalid BAN_INTERVAL and default to 20', async () => { + process.env.BAN_INTERVAL = 'invalid'; + errorMessage.prev_count = 19; + errorMessage.violation_count = 39; + await banViolation(req, res, errorMessage); + expect(errorMessage.ban).toBeTruthy(); + }); + + it('should ban if BAN_DURATION is invalid as default is 2 hours', async () => { + process.env.BAN_DURATION = 'invalid'; + errorMessage.prev_count = 19; + errorMessage.violation_count = 39; + await banViolation(req, res, errorMessage); + expect(errorMessage.ban).toBeTruthy(); + }); + + it('should not ban if BAN_DURATION is 0 but should clear cookies', async () => { + process.env.BAN_DURATION = '0'; + errorMessage.prev_count = 19; + errorMessage.violation_count = 39; + await banViolation(req, res, errorMessage); + expect(res.clearCookie).toHaveBeenCalledWith('refreshToken'); + }); + + it('should not ban if violation_count does not change', async () => { + errorMessage.prev_count = 0; + errorMessage.violation_count = 0; + await banViolation(req, res, errorMessage); + expect(errorMessage.ban).toBeFalsy(); + }); + + it('[1/2] should not ban if violation_count does not cross the interval threshold: 0 -> 19', async () => { + errorMessage.prev_count = 0; + errorMessage.violation_count = 19; + await banViolation(req, res, errorMessage); + expect(errorMessage.ban).toBeFalsy(); + }); + + const randomValueUnder = Math.floor(1 + Math.random() * 19); + it(`[2/2] should not ban if violation_count does not cross the interval threshold: 0 -> ${randomValueUnder}`, async () => { + errorMessage.prev_count = 0; + errorMessage.violation_count = randomValueUnder; + await banViolation(req, res, errorMessage); + expect(errorMessage.ban).toBeFalsy(); + }); + + it('[EDGE CASE] should not ban if violation_count is lower', async () => { + errorMessage.prev_count = 0; + errorMessage.violation_count = -10; + await banViolation(req, res, errorMessage); + expect(errorMessage.ban).toBeFalsy(); + }); +}); diff --git a/api/cache/clearPendingReq.js b/api/cache/clearPendingReq.js new file mode 100644 index 0000000000..d31d51d78a --- /dev/null +++ b/api/cache/clearPendingReq.js @@ -0,0 +1,29 @@ +const Keyv = require('keyv'); +const { pendingReqFile } = require('./keyvFiles'); +const { LIMIT_CONCURRENT_MESSAGES } = process.env ?? {}; + +const keyv = new Keyv({ store: pendingReqFile, namespace: 'pendingRequests' }); + +/** + * Clear pending requests from the cache. + * Checks the environmental variable LIMIT_CONCURRENT_MESSAGES; + * if the rule is enabled ('true'), pending requests in the cache are cleared. + * + * @module clearPendingReq + * @requires keyv + * @requires keyvFiles + * @requires process + * + * @async + * @function + * @returns {Promise} A promise that either clears 'pendingRequests' from store or resolves with no value. + */ +const clearPendingReq = async () => { + if (LIMIT_CONCURRENT_MESSAGES?.toLowerCase() !== 'true') { + return; + } + + await keyv.clear(); +}; + +module.exports = clearPendingReq; diff --git a/api/cache/getLogStores.js b/api/cache/getLogStores.js new file mode 100644 index 0000000000..5bc703fe5e --- /dev/null +++ b/api/cache/getLogStores.js @@ -0,0 +1,40 @@ +const Keyv = require('keyv'); +const keyvMongo = require('./keyvMongo'); +const { math } = require('../server/utils'); +const { logFile, violationFile } = require('./keyvFiles'); +const { BAN_DURATION } = process.env ?? {}; + +const duration = math(BAN_DURATION, 7200000); + +const namespaces = { + ban: new Keyv({ store: keyvMongo, ttl: duration, namespace: 'bans' }), + general: new Keyv({ store: logFile, namespace: 'violations' }), + concurrent: new Keyv({ store: violationFile, namespace: 'concurrent' }), + non_browser: new Keyv({ store: violationFile, namespace: 'non_browser' }), + message_limit: new Keyv({ store: violationFile, namespace: 'message_limit' }), + registrations: new Keyv({ store: violationFile, namespace: 'registrations' }), + logins: new Keyv({ store: violationFile, namespace: 'logins' }), +}; + +/** + * Returns either the logs of violations specified by type if a type is provided + * or it returns the general log if no type is specified. If an invalid type is passed, + * an error will be thrown. + * + * @module getLogStores + * @requires keyv - a simple key-value storage that allows you to easily switch out storage adapters. + * @requires keyvFiles - a module that includes the logFile and violationFile. + * + * @param {string} type - The type of violation, which can be 'concurrent', 'message_limit', 'registrations' or 'logins'. + * @returns {Keyv} - If a valid type is passed, returns an object containing the logs for violations of the specified type. + * @throws Will throw an error if an invalid violation type is passed. + */ +const getLogStores = (type) => { + if (!type) { + throw new Error(`Invalid store type: ${type}`); + } + const logs = namespaces[type]; + return logs; +}; + +module.exports = getLogStores; diff --git a/api/cache/index.js b/api/cache/index.js new file mode 100644 index 0000000000..1edbf981d9 --- /dev/null +++ b/api/cache/index.js @@ -0,0 +1,6 @@ +const keyvFiles = require('./keyvFiles'); +const getLogStores = require('./getLogStores'); +const logViolation = require('./logViolation'); +const clearPendingReq = require('./clearPendingReq'); + +module.exports = { ...keyvFiles, getLogStores, logViolation, clearPendingReq }; diff --git a/api/cache/keyvFiles.js b/api/cache/keyvFiles.js new file mode 100644 index 0000000000..f969174b7d --- /dev/null +++ b/api/cache/keyvFiles.js @@ -0,0 +1,11 @@ +const { KeyvFile } = require('keyv-file'); + +const logFile = new KeyvFile({ filename: './data/logs.json' }); +const pendingReqFile = new KeyvFile({ filename: './data/pendingReqCache.json' }); +const violationFile = new KeyvFile({ filename: './data/violations.json' }); + +module.exports = { + logFile, + pendingReqFile, + violationFile, +}; diff --git a/api/cache/keyvMongo.js b/api/cache/keyvMongo.js new file mode 100644 index 0000000000..429329adc6 --- /dev/null +++ b/api/cache/keyvMongo.js @@ -0,0 +1,7 @@ +const KeyvMongo = require('@keyv/mongo'); +const { MONGO_URI } = process.env ?? {}; + +const keyvMongo = new KeyvMongo(MONGO_URI, { collection: 'logs' }); +keyvMongo.on('error', (err) => console.error('KeyvMongo connection error:', err)); + +module.exports = keyvMongo; diff --git a/api/cache/logViolation.js b/api/cache/logViolation.js new file mode 100644 index 0000000000..0e35cf1859 --- /dev/null +++ b/api/cache/logViolation.js @@ -0,0 +1,36 @@ +const getLogStores = require('./getLogStores'); +const banViolation = require('./banViolation'); + +/** + * Logs the violation. + * + * @param {Object} req - Express request object containing user information. + * @param {Object} res - Express response object. + * @param {string} type - The type of violation. + * @param {Object} errorMessage - The error message to log. + * @param {number} [score=1] - The severity of the violation. Defaults to 1 + */ +const logViolation = async (req, res, type, errorMessage, score = 1) => { + const userId = req.user?.id ?? req.user?._id; + if (!userId) { + return; + } + const logs = getLogStores('general'); + const violationLogs = getLogStores(type); + + const userViolations = (await violationLogs.get(userId)) ?? 0; + const violationCount = userViolations + score; + await violationLogs.set(userId, violationCount); + + errorMessage.user_id = userId; + errorMessage.prev_count = userViolations; + errorMessage.violation_count = violationCount; + errorMessage.date = new Date().toISOString(); + + await banViolation(req, res, errorMessage); + const userLogs = (await logs.get(userId)) ?? []; + userLogs.push(errorMessage); + await logs.set(userId, userLogs); +}; + +module.exports = logViolation; diff --git a/api/jest.config.js b/api/jest.config.js index a877e75980..a2147b2216 100644 --- a/api/jest.config.js +++ b/api/jest.config.js @@ -3,5 +3,5 @@ module.exports = { clearMocks: true, roots: [''], coverageDirectory: 'coverage', - setupFiles: ['./test/jestSetup.js'], + setupFiles: ['./test/jestSetup.js', './test/__mocks__/KeyvMongo.js'], }; diff --git a/api/lib/db/index.js b/api/lib/db/index.js new file mode 100644 index 0000000000..fa7a460d05 --- /dev/null +++ b/api/lib/db/index.js @@ -0,0 +1,4 @@ +const connectDb = require('./connectDb'); +const indexSync = require('./indexSync'); + +module.exports = { connectDb, indexSync }; diff --git a/api/models/Message.js b/api/models/Message.js index 64dba3dbfb..adcdd9e56d 100644 --- a/api/models/Message.js +++ b/api/models/Message.js @@ -1,5 +1,8 @@ +const { z } = require('zod'); const Message = require('./schema/messageSchema'); +const idSchema = z.string().uuid(); + module.exports = { Message, @@ -22,8 +25,9 @@ module.exports = { model = null, }) { try { - if (!conversationId) { - return console.log('Message not saved: no conversationId'); + const validConvoId = idSchema.safeParse(conversationId); + if (!validConvoId.success) { + return; } // may also need to update the conversation here await Message.findOneAndUpdate( diff --git a/api/models/Session.js b/api/models/Session.js index e1b9898bb9..697fa66343 100644 --- a/api/models/Session.js +++ b/api/models/Session.js @@ -54,6 +54,21 @@ sessionSchema.methods.generateRefreshToken = async function () { } }; +sessionSchema.statics.deleteAllUserSessions = async function (userId) { + try { + if (!userId) { + return; + } + const result = await this.deleteMany({ user: userId }); + if (result && result?.deletedCount > 0) { + console.log(`Deleted ${result.deletedCount} sessions for user ${userId}.`); + } + } catch (error) { + console.log('Error in deleting user sessions:', error); + throw error; + } +}; + const Session = mongoose.model('Session', sessionSchema); module.exports = Session; diff --git a/api/package.json b/api/package.json index d548a74fa8..5d1d1d6029 100644 --- a/api/package.json +++ b/api/package.json @@ -6,6 +6,7 @@ "start": "echo 'please run this from the root directory'", "server-dev": "echo 'please run this from the root directory'", "test": "cross-env NODE_ENV=test jest", + "b:test": "NODE_ENV=test bun jest", "test:ci": "jest --ci" }, "repository": { @@ -45,7 +46,7 @@ "joi": "^17.9.2", "js-yaml": "^4.1.0", "jsonwebtoken": "^9.0.0", - "keyv": "^4.5.2", + "keyv": "^4.5.3", "keyv-file": "^0.2.0", "langchain": "^0.0.144", "lodash": "^4.17.21", @@ -64,6 +65,7 @@ "pino": "^8.12.1", "sanitize": "^2.1.2", "sharp": "^0.32.5", + "ua-parser-js": "^1.0.36", "zod": "^3.22.2" }, "devDependencies": { diff --git a/api/server/controllers/AuthController.js b/api/server/controllers/AuthController.js index fbf85bec22..361c3464be 100644 --- a/api/server/controllers/AuthController.js +++ b/api/server/controllers/AuthController.js @@ -80,7 +80,7 @@ const refreshController = async (req, res) => { const userId = payload.id; const user = await User.findOne({ _id: userId }); if (!user) { - return res.status(401).send('User not found'); + return res.status(401).redirect('/login'); } if (process.env.NODE_ENV === 'development') { @@ -99,6 +99,8 @@ const refreshController = async (req, res) => { const token = await setAuthTokens(userId, res, session._id); const userObj = user.toJSON(); res.status(200).send({ token, user: userObj }); + } else if (payload.exp > Date.now() / 1000) { + res.status(403).redirect('/login'); } else { res.status(401).send('Refresh token expired or not found for this user'); } diff --git a/api/server/index.js b/api/server/index.js index 8a6f438197..496f0ac420 100644 --- a/api/server/index.js +++ b/api/server/index.js @@ -1,16 +1,16 @@ const express = require('express'); const mongoSanitize = require('express-mongo-sanitize'); -const connectDb = require('../lib/db/connectDb'); -const indexSync = require('../lib/db/indexSync'); +const { connectDb, indexSync } = require('../lib/db'); const path = require('path'); const cors = require('cors'); const routes = require('./routes'); const errorController = require('./controllers/ErrorController'); const passport = require('passport'); const configureSocialLogins = require('./socialLogins'); +const { PORT, HOST, ALLOW_SOCIAL_LOGIN } = process.env ?? {}; -const port = Number(process.env.PORT) || 3080; -const host = process.env.HOST || 'localhost'; +const port = Number(PORT) || 3080; +const host = HOST || 'localhost'; const projectPath = path.join(__dirname, '..', '..', 'client'); const { jwtLogin, passportLogin } = require('../strategies'); @@ -31,7 +31,7 @@ const startServer = async () => { app.set('trust proxy', 1); // trust first proxy app.use(cors()); - if (!process.env.ALLOW_SOCIAL_LOGIN) { + if (!ALLOW_SOCIAL_LOGIN) { console.warn( 'Social logins are disabled. Set Envrionment Variable "ALLOW_SOCIAL_LOGIN" to true to enable them.', ); @@ -42,7 +42,7 @@ const startServer = async () => { passport.use(await jwtLogin()); passport.use(passportLogin()); - if (process.env.ALLOW_SOCIAL_LOGIN === 'true') { + if (ALLOW_SOCIAL_LOGIN?.toLowerCase() === 'true') { configureSocialLogins(app); } diff --git a/api/server/middleware/abortMiddleware.js b/api/server/middleware/abortMiddleware.js index 54e11b7f5d..80cf26ba48 100644 --- a/api/server/middleware/abortMiddleware.js +++ b/api/server/middleware/abortMiddleware.js @@ -1,6 +1,5 @@ -const crypto = require('crypto'); const { saveMessage, getConvo, getConvoTitle } = require('../../models'); -const { sendMessage, handleError } = require('../utils'); +const { sendMessage, sendError } = require('../utils'); const abortControllers = require('./abortControllers'); async function abortMessage(req, res) { @@ -27,8 +26,9 @@ const handleAbort = () => { }; }; -const createAbortController = (res, req, endpointOption, getAbortData) => { +const createAbortController = (req, res, getAbortData) => { const abortController = new AbortController(); + const { endpointOption } = req.body; const onStart = (userMessage) => { sendMessage(res, { message: userMessage, created: true }); const abortKey = userMessage?.conversationId ?? req.user.id; @@ -73,25 +73,23 @@ const handleAbortError = async (res, req, error, data) => { const { sender, conversationId, messageId, parentMessageId, partialText } = data; const respondWithError = async () => { - const errorMessage = { + const options = { sender, - messageId: messageId ?? crypto.randomUUID(), + messageId, conversationId, parentMessageId, - unfinished: false, - cancelled: false, - error: true, - final: true, text: error.message, - isCreatedByUser: false, + shouldSaveMessage: true, }; - if (abortControllers.has(conversationId)) { - const { abortController } = abortControllers.get(conversationId); - abortController.abort(); - abortControllers.delete(conversationId); - } - await saveMessage(errorMessage); - handleError(res, errorMessage); + const callback = async () => { + if (abortControllers.has(conversationId)) { + const { abortController } = abortControllers.get(conversationId); + abortController.abort(); + abortControllers.delete(conversationId); + } + }; + + await sendError(res, options, callback); }; if (partialText && partialText.length > 5) { diff --git a/api/server/middleware/checkBan.js b/api/server/middleware/checkBan.js new file mode 100644 index 0000000000..294f4a668b --- /dev/null +++ b/api/server/middleware/checkBan.js @@ -0,0 +1,92 @@ +const Keyv = require('keyv'); +const uap = require('ua-parser-js'); +const { getLogStores } = require('../../cache'); +const denyRequest = require('./denyRequest'); +const { isEnabled, removePorts } = require('../utils'); + +const banCache = new Keyv({ namespace: 'bans', ttl: 0 }); +const message = 'Your account has been temporarily banned due to violations of our service.'; + +/** + * Respond to the request if the user is banned. + * + * @async + * @function + * @param {Object} req - Express Request object. + * @param {Object} res - Express Response object. + * @param {String} errorMessage - Error message to be displayed in case of /api/ask or /api/edit request. + * + * @returns {Promise} - Returns a Promise which when resolved sends a response status of 403 with a specific message if request is not of api/ask or api/edit types. If it is, calls `denyRequest()` function. + */ +const banResponse = async (req, res) => { + const ua = uap(req.headers['user-agent']); + const { baseUrl } = req; + if (!ua.browser.name) { + return res.status(403).json({ message }); + } else if (baseUrl === '/api/ask' || baseUrl === '/api/edit') { + return await denyRequest(req, res, { type: 'ban' }); + } + + return res.status(403).json({ message }); +}; + +/** + * Checks if the source IP or user is banned or not. + * + * @async + * @function + * @param {Object} req - Express request object. + * @param {Object} res - Express response object. + * @param {Function} next - Next middleware function. + * + * @returns {Promise} - Returns a Promise which when resolved calls next middleware if user or source IP is not banned. Otherwise calls `banResponse()` and sets ban details in `banCache`. + */ +const checkBan = async (req, res, next = () => {}) => { + const { BAN_VIOLATIONS } = process.env ?? {}; + + if (!isEnabled(BAN_VIOLATIONS)) { + return next(); + } + + req.ip = removePorts(req); + const userId = req.user?.id ?? req.user?._id ?? null; + + const cachedIPBan = await banCache.get(req.ip); + const cachedUserBan = await banCache.get(userId); + const cachedBan = cachedIPBan || cachedUserBan; + + if (cachedBan) { + req.banned = true; + return await banResponse(req, res); + } + + const banLogs = getLogStores('ban'); + const duration = banLogs.opts.ttl; + + if (duration <= 0) { + return next(); + } + + const ipBan = await banLogs.get(req.ip); + const userBan = await banLogs.get(userId); + const isBanned = ipBan || userBan; + + if (!isBanned) { + return next(); + } + + const timeLeft = Number(isBanned.expiresAt) - Date.now(); + + if (timeLeft <= 0) { + await banLogs.delete(req.ip); + await banLogs.delete(userId); + return next(); + } + + banCache.set(req.ip, isBanned, timeLeft); + banCache.set(userId, isBanned, timeLeft); + req.banned = true; + return await banResponse(req, res); +}; + +module.exports = checkBan; diff --git a/api/server/middleware/concurrentLimiter.js b/api/server/middleware/concurrentLimiter.js new file mode 100644 index 0000000000..d110b1b86f --- /dev/null +++ b/api/server/middleware/concurrentLimiter.js @@ -0,0 +1,81 @@ +const Keyv = require('keyv'); +const { logViolation } = require('../../cache'); + +const denyRequest = require('./denyRequest'); + +// Serve cache from memory so no need to clear it on startup/exit +const pendingReqCache = new Keyv({ namespace: 'pendingRequests' }); + +/** + * Middleware to limit concurrent requests for a user. + * + * This middleware checks if a user has exceeded a specified concurrent request limit. + * If the user exceeds the limit, an error is returned. If the user is within the limit, + * their request count is incremented. After the request is processed, the count is decremented. + * If the `pendingReqCache` store is not available, the middleware will skip its logic. + * + * @function + * @param {Object} req - Express request object containing user information. + * @param {Object} res - Express response object. + * @param {function} next - Express next middleware function. + * @throws {Error} Throws an error if the user exceeds the concurrent request limit. + */ +const concurrentLimiter = async (req, res, next) => { + if (!pendingReqCache) { + return next(); + } + + if (Object.keys(req?.body ?? {}).length === 1 && req?.body?.abortKey) { + return next(); + } + + const { CONCURRENT_MESSAGE_MAX = 1, CONCURRENT_VIOLATION_SCORE: score } = process.env; + const limit = Math.max(CONCURRENT_MESSAGE_MAX, 1); + const type = 'concurrent'; + + const userId = req.user?.id ?? req.user?._id ?? null; + const pendingRequests = (await pendingReqCache.get(userId)) ?? 0; + + if (pendingRequests >= limit) { + const errorMessage = { + type, + limit, + pendingRequests, + }; + + await logViolation(req, res, type, errorMessage, score); + return await denyRequest(req, res, errorMessage); + } else { + await pendingReqCache.set(userId, pendingRequests + 1); + } + + // Ensure the requests are removed from the store once the request is done + const cleanUp = async () => { + if (!pendingReqCache) { + return; + } + + const currentRequests = await pendingReqCache.get(userId); + + if (currentRequests && currentRequests >= 1) { + await pendingReqCache.set(userId, currentRequests - 1); + } else { + await pendingReqCache.delete(userId); + } + }; + + if (pendingRequests < limit) { + res.on('finish', cleanUp); + res.on('close', cleanUp); + } + + next(); +}; + +// if cache is not served from memory, clear it on exit +// process.on('exit', async () => { +// console.log('Clearing all pending requests before exiting...'); +// await pendingReqCache.clear(); +// }); + +module.exports = concurrentLimiter; diff --git a/api/server/middleware/denyRequest.js b/api/server/middleware/denyRequest.js new file mode 100644 index 0000000000..64ca86c63d --- /dev/null +++ b/api/server/middleware/denyRequest.js @@ -0,0 +1,58 @@ +const crypto = require('crypto'); +const { sendMessage, sendError } = require('../utils'); +const { getResponseSender } = require('../routes/endpoints/schemas'); +const { saveMessage } = require('../../models'); + +/** + * Denies a request by sending an error message and optionally saves the user's message. + * + * @async + * @function + * @param {Object} req - Express request object. + * @param {Object} req.body - The body of the request. + * @param {string} [req.body.messageId] - The ID of the message. + * @param {string} [req.body.conversationId] - The ID of the conversation. + * @param {string} [req.body.parentMessageId] - The ID of the parent message. + * @param {string} req.body.text - The text of the message. + * @param {Object} res - Express response object. + * @param {string} errorMessage - The error message to be sent. + * @returns {Promise} A promise that resolves with the error response. + * @throws {Error} Throws an error if there's an issue saving the message or sending the error. + */ +const denyRequest = async (req, res, errorMessage) => { + let responseText = errorMessage; + if (typeof errorMessage === 'object') { + responseText = JSON.stringify(errorMessage); + } + + const { messageId, conversationId: _convoId, parentMessageId, text } = req.body; + const conversationId = _convoId ?? crypto.randomUUID(); + + const userMessage = { + sender: 'User', + messageId: messageId ?? crypto.randomUUID(), + parentMessageId, + conversationId, + isCreatedByUser: true, + text, + }; + sendMessage(res, { message: userMessage, created: true }); + + const shouldSaveMessage = + _convoId && parentMessageId && parentMessageId !== '00000000-0000-0000-0000-000000000000'; + + if (shouldSaveMessage) { + await saveMessage(userMessage); + } + + return await sendError(res, { + sender: getResponseSender(req.body), + messageId: crypto.randomUUID(), + conversationId, + parentMessageId: userMessage.messageId, + text: responseText, + shouldSaveMessage, + }); +}; + +module.exports = denyRequest; diff --git a/api/server/middleware/index.js b/api/server/middleware/index.js index eb1f538704..553f2c663a 100644 --- a/api/server/middleware/index.js +++ b/api/server/middleware/index.js @@ -1,22 +1,30 @@ const abortMiddleware = require('./abortMiddleware'); +const checkBan = require('./checkBan'); +const uaParser = require('./uaParser'); const setHeaders = require('./setHeaders'); const loginLimiter = require('./loginLimiter'); const requireJwtAuth = require('./requireJwtAuth'); const registerLimiter = require('./registerLimiter'); +const messageLimiters = require('./messageLimiters'); const requireLocalAuth = require('./requireLocalAuth'); const validateEndpoint = require('./validateEndpoint'); +const concurrentLimiter = require('./concurrentLimiter'); const validateMessageReq = require('./validateMessageReq'); const buildEndpointOption = require('./buildEndpointOption'); const validateRegistration = require('./validateRegistration'); module.exports = { ...abortMiddleware, + ...messageLimiters, + checkBan, + uaParser, setHeaders, loginLimiter, requireJwtAuth, registerLimiter, requireLocalAuth, validateEndpoint, + concurrentLimiter, validateMessageReq, buildEndpointOption, validateRegistration, diff --git a/api/server/middleware/loginLimiter.js b/api/server/middleware/loginLimiter.js index bca07d0a7a..bdc95e2878 100644 --- a/api/server/middleware/loginLimiter.js +++ b/api/server/middleware/loginLimiter.js @@ -1,16 +1,30 @@ const rateLimit = require('express-rate-limit'); -const windowMs = (process.env?.LOGIN_WINDOW ?? 5) * 60 * 1000; // default: 5 minutes -const max = process.env?.LOGIN_MAX ?? 7; // default: limit each IP to 7 requests per windowMs +const { logViolation } = require('../../cache'); +const { removePorts } = require('../utils'); + +const { LOGIN_WINDOW = 5, LOGIN_MAX = 7, LOGIN_VIOLATION_SCORE: score } = process.env; +const windowMs = LOGIN_WINDOW * 60 * 1000; +const max = LOGIN_MAX; const windowInMinutes = windowMs / 60000; +const message = `Too many login attempts, please try again after ${windowInMinutes} minutes.`; + +const handler = async (req, res) => { + const type = 'logins'; + const errorMessage = { + type, + max, + windowInMinutes, + }; + + await logViolation(req, res, type, errorMessage, score); + return res.status(429).json({ message }); +}; const loginLimiter = rateLimit({ windowMs, max, - message: `Too many login attempts from this IP, please try again after ${windowInMinutes} minutes.`, - keyGenerator: function (req) { - // Strip out the port number from the IP address - return req.ip.replace(/:\d+[^:]*$/, ''); - }, + handler, + keyGenerator: removePorts, }); module.exports = loginLimiter; diff --git a/api/server/middleware/messageLimiters.js b/api/server/middleware/messageLimiters.js new file mode 100644 index 0000000000..63bac7e181 --- /dev/null +++ b/api/server/middleware/messageLimiters.js @@ -0,0 +1,67 @@ +const rateLimit = require('express-rate-limit'); +const { logViolation } = require('../../cache'); +const denyRequest = require('./denyRequest'); + +const { + MESSAGE_IP_MAX = 40, + MESSAGE_IP_WINDOW = 1, + MESSAGE_USER_MAX = 40, + MESSAGE_USER_WINDOW = 1, +} = process.env; + +const ipWindowMs = MESSAGE_IP_WINDOW * 60 * 1000; +const ipMax = MESSAGE_IP_MAX; +const ipWindowInMinutes = ipWindowMs / 60000; + +const userWindowMs = MESSAGE_USER_WINDOW * 60 * 1000; +const userMax = MESSAGE_USER_MAX; +const userWindowInMinutes = userWindowMs / 60000; + +/** + * Creates either an IP/User message request rate limiter for excessive requests + * that properly logs and denies the violation. + * + * @param {boolean} [ip=true] - Whether to create an IP limiter or a user limiter. + * @returns {function} A rate limiter function. + * + */ +const createHandler = (ip = true) => { + return async (req, res) => { + const type = 'message_limit'; + const errorMessage = { + type, + max: ip ? ipMax : userMax, + limiter: ip ? 'ip' : 'user', + windowInMinutes: ip ? ipWindowInMinutes : userWindowInMinutes, + }; + + await logViolation(req, res, type, errorMessage); + return await denyRequest(req, res, errorMessage); + }; +}; + +/** + * Message request rate limiter by IP + */ +const messageIpLimiter = rateLimit({ + windowMs: ipWindowMs, + max: ipMax, + handler: createHandler(), +}); + +/** + * Message request rate limiter by userId + */ +const messageUserLimiter = rateLimit({ + windowMs: userWindowMs, + max: userMax, + handler: createHandler(false), + keyGenerator: function (req) { + return req.user?.id; // Use the user ID or NULL if not available + }, +}); + +module.exports = { + messageIpLimiter, + messageUserLimiter, +}; diff --git a/api/server/middleware/registerLimiter.js b/api/server/middleware/registerLimiter.js index df2d3d1ca8..e19e261cbe 100644 --- a/api/server/middleware/registerLimiter.js +++ b/api/server/middleware/registerLimiter.js @@ -1,16 +1,30 @@ const rateLimit = require('express-rate-limit'); -const windowMs = (process.env?.REGISTER_WINDOW ?? 60) * 60 * 1000; // default: 1 hour -const max = process.env?.REGISTER_MAX ?? 5; // default: limit each IP to 5 registrations per windowMs +const { logViolation } = require('../../cache'); +const { removePorts } = require('../utils'); + +const { REGISTER_WINDOW = 60, REGISTER_MAX = 5, REGISTRATION_VIOLATION_SCORE: score } = process.env; +const windowMs = REGISTER_WINDOW * 60 * 1000; +const max = REGISTER_MAX; const windowInMinutes = windowMs / 60000; +const message = `Too many accounts created, please try again after ${windowInMinutes} minutes`; + +const handler = async (req, res) => { + const type = 'registrations'; + const errorMessage = { + type, + max, + windowInMinutes, + }; + + await logViolation(req, res, type, errorMessage, score); + return res.status(429).json({ message }); +}; const registerLimiter = rateLimit({ windowMs, max, - message: `Too many accounts created from this IP, please try again after ${windowInMinutes} minutes`, - keyGenerator: function (req) { - // Strip out the port number from the IP address - return req.ip.replace(/:\d+[^:]*$/, ''); - }, + handler, + keyGenerator: removePorts, }); module.exports = registerLimiter; diff --git a/api/server/middleware/uaParser.js b/api/server/middleware/uaParser.js new file mode 100644 index 0000000000..f5b726dd3a --- /dev/null +++ b/api/server/middleware/uaParser.js @@ -0,0 +1,31 @@ +const uap = require('ua-parser-js'); +const { handleError } = require('../utils'); +const { logViolation } = require('../../cache'); + +/** + * Middleware to parse User-Agent header and check if it's from a recognized browser. + * If the User-Agent is not recognized as a browser, logs a violation and sends an error response. + * + * @function + * @async + * @param {Object} req - Express request object. + * @param {Object} res - Express response object. + * @param {Function} next - Express next middleware function. + * @returns {void} Sends an error response if the User-Agent is not recognized as a browser. + * + * @example + * app.use(uaParser); + */ +async function uaParser(req, res, next) { + const { NON_BROWSER_VIOLATION_SCORE: score = 20 } = process.env; + const ua = uap(req.headers['user-agent']); + + if (!ua.browser.name) { + const type = 'non_browser'; + await logViolation(req, res, type, { type }, score); + return handleError(res, { message: 'Illegal request' }); + } + next(); +} + +module.exports = uaParser; diff --git a/api/server/routes/ask/anthropic.js b/api/server/routes/ask/anthropic.js index 637fc090aa..673fd185db 100644 --- a/api/server/routes/ask/anthropic.js +++ b/api/server/routes/ask/anthropic.js @@ -7,138 +7,125 @@ const { createAbortController, handleAbortError, setHeaders, - requireJwtAuth, validateEndpoint, buildEndpointOption, } = require('../../middleware'); const { saveMessage, getConvoTitle, saveConvo, getConvo } = require('../../../models'); const { sendMessage, createOnProgress } = require('../../utils'); -router.post('/abort', requireJwtAuth, handleAbort()); +router.post('/abort', handleAbort()); -router.post( - '/', - requireJwtAuth, - validateEndpoint, - buildEndpointOption, - setHeaders, - async (req, res) => { - let { - text, - endpointOption, - conversationId, - parentMessageId = null, - overrideParentMessageId = null, - } = req.body; - console.log('ask log'); - console.dir({ text, conversationId, endpointOption }, { depth: null }); - let userMessage; - let userMessageId; - let responseMessageId; - let lastSavedTimestamp = 0; - let saveDelay = 100; +router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res) => { + let { + text, + endpointOption, + conversationId, + parentMessageId = null, + overrideParentMessageId = null, + } = req.body; + console.log('ask log'); + console.dir({ text, conversationId, endpointOption }, { depth: null }); + let userMessage; + let userMessageId; + let responseMessageId; + let lastSavedTimestamp = 0; + let saveDelay = 100; - const getIds = (data) => { - userMessage = data.userMessage; - userMessageId = data.userMessage.messageId; - responseMessageId = data.responseMessageId; - if (!conversationId) { - conversationId = data.conversationId; - } - }; - - const { onProgress: progressCallback, getPartialText } = createOnProgress({ - onProgress: ({ text: partialText }) => { - const currentTimestamp = Date.now(); - - if (currentTimestamp - lastSavedTimestamp > saveDelay) { - lastSavedTimestamp = currentTimestamp; - saveMessage({ - messageId: responseMessageId, - sender: getResponseSender(endpointOption), - conversationId, - parentMessageId: overrideParentMessageId ?? userMessageId, - text: partialText, - unfinished: true, - cancelled: false, - error: false, - }); - } - - if (saveDelay < 500) { - saveDelay = 500; - } - }, - }); - try { - const getAbortData = () => ({ - conversationId, - messageId: responseMessageId, - sender: getResponseSender(endpointOption), - parentMessageId: overrideParentMessageId ?? userMessageId, - text: getPartialText(), - userMessage, - }); - - const { abortController, onStart } = createAbortController( - res, - req, - endpointOption, - getAbortData, - ); - - const { client } = await initializeClient(req, endpointOption); - - let response = await client.sendMessage(text, { - getIds, - // debug: true, - user: req.user.id, - conversationId, - parentMessageId, - overrideParentMessageId, - ...endpointOption, - onProgress: progressCallback.call(null, { - res, - text, - parentMessageId: overrideParentMessageId ?? userMessageId, - }), - onStart, - abortController, - }); - - if (overrideParentMessageId) { - response.parentMessageId = overrideParentMessageId; - } - - await saveConvo(req.user.id, { - ...endpointOption, - ...endpointOption.modelOptions, - conversationId, - endpoint: 'anthropic', - }); - - await saveMessage(response); - sendMessage(res, { - title: await getConvoTitle(req.user.id, conversationId), - final: true, - conversation: await getConvo(req.user.id, conversationId), - requestMessage: userMessage, - responseMessage: response, - }); - res.end(); - - // TODO: add anthropic titling - } catch (error) { - const partialText = getPartialText(); - handleAbortError(res, req, error, { - partialText, - conversationId, - sender: getResponseSender(endpointOption), - messageId: responseMessageId, - parentMessageId: userMessageId ?? parentMessageId, - }); + const getIds = (data) => { + userMessage = data.userMessage; + userMessageId = data.userMessage.messageId; + responseMessageId = data.responseMessageId; + if (!conversationId) { + conversationId = data.conversationId; } - }, -); + }; + + const { onProgress: progressCallback, getPartialText } = createOnProgress({ + onProgress: ({ text: partialText }) => { + const currentTimestamp = Date.now(); + + if (currentTimestamp - lastSavedTimestamp > saveDelay) { + lastSavedTimestamp = currentTimestamp; + saveMessage({ + messageId: responseMessageId, + sender: getResponseSender(endpointOption), + conversationId, + parentMessageId: overrideParentMessageId ?? userMessageId, + text: partialText, + unfinished: true, + cancelled: false, + error: false, + }); + } + + if (saveDelay < 500) { + saveDelay = 500; + } + }, + }); + try { + const getAbortData = () => ({ + conversationId, + messageId: responseMessageId, + sender: getResponseSender(endpointOption), + parentMessageId: overrideParentMessageId ?? userMessageId, + text: getPartialText(), + userMessage, + }); + + const { abortController, onStart } = createAbortController(req, res, getAbortData); + + const { client } = await initializeClient(req, endpointOption); + + let response = await client.sendMessage(text, { + getIds, + // debug: true, + user: req.user.id, + conversationId, + parentMessageId, + overrideParentMessageId, + ...endpointOption, + onProgress: progressCallback.call(null, { + res, + text, + parentMessageId: overrideParentMessageId ?? userMessageId, + }), + onStart, + abortController, + }); + + if (overrideParentMessageId) { + response.parentMessageId = overrideParentMessageId; + } + + await saveConvo(req.user.id, { + ...endpointOption, + ...endpointOption.modelOptions, + conversationId, + endpoint: 'anthropic', + }); + + await saveMessage(response); + sendMessage(res, { + title: await getConvoTitle(req.user.id, conversationId), + final: true, + conversation: await getConvo(req.user.id, conversationId), + requestMessage: userMessage, + responseMessage: response, + }); + res.end(); + + // TODO: add anthropic titling + } catch (error) { + const partialText = getPartialText(); + handleAbortError(res, req, error, { + partialText, + conversationId, + sender: getResponseSender(endpointOption), + messageId: responseMessageId, + parentMessageId: userMessageId ?? parentMessageId, + }); + } +}); module.exports = router; diff --git a/api/server/routes/ask/askChatGPTBrowser.js b/api/server/routes/ask/askChatGPTBrowser.js index b590314f90..1c916265ca 100644 --- a/api/server/routes/ask/askChatGPTBrowser.js +++ b/api/server/routes/ask/askChatGPTBrowser.js @@ -4,9 +4,9 @@ const router = express.Router(); const { browserClient } = require('../../../app/'); const { saveMessage, getConvoTitle, saveConvo, getConvo } = require('../../../models'); const { handleError, sendMessage, createOnProgress, handleText } = require('../../utils'); -const { requireJwtAuth, setHeaders } = require('../../middleware'); +const { setHeaders } = require('../../middleware'); -router.post('/', requireJwtAuth, setHeaders, async (req, res) => { +router.post('/', setHeaders, async (req, res) => { const { endpoint, text, diff --git a/api/server/routes/ask/bingAI.js b/api/server/routes/ask/bingAI.js index 10bf55442e..f3047c2857 100644 --- a/api/server/routes/ask/bingAI.js +++ b/api/server/routes/ask/bingAI.js @@ -4,9 +4,9 @@ const router = express.Router(); const { titleConvoBing, askBing } = require('../../../app'); const { saveMessage, getConvoTitle, saveConvo, getConvo } = require('../../../models'); const { handleError, sendMessage, createOnProgress, handleText } = require('../../utils'); -const { requireJwtAuth, setHeaders } = require('../../middleware'); +const { setHeaders } = require('../../middleware'); -router.post('/', requireJwtAuth, setHeaders, async (req, res) => { +router.post('/', setHeaders, async (req, res) => { const { endpoint, text, diff --git a/api/server/routes/ask/google.js b/api/server/routes/ask/google.js index aa77c3129f..5742120b0d 100644 --- a/api/server/routes/ask/google.js +++ b/api/server/routes/ask/google.js @@ -5,9 +5,9 @@ const { GoogleClient } = require('../../../app'); const { saveMessage, getConvoTitle, saveConvo, getConvo } = require('../../../models'); const { handleError, sendMessage, createOnProgress } = require('../../utils'); const { getUserKey, checkUserKeyExpiry } = require('../../services/UserService'); -const { requireJwtAuth, setHeaders } = require('../../middleware'); +const { setHeaders } = require('../../middleware'); -router.post('/', requireJwtAuth, setHeaders, async (req, res) => { +router.post('/', setHeaders, async (req, res) => { const { endpoint, text, parentMessageId, conversationId: oldConversationId } = req.body; if (text.length === 0) { return handleError(res, { text: 'Prompt empty or too short' }); diff --git a/api/server/routes/ask/gptPlugins.js b/api/server/routes/ask/gptPlugins.js index 890228cefc..330f9404da 100644 --- a/api/server/routes/ask/gptPlugins.js +++ b/api/server/routes/ask/gptPlugins.js @@ -11,218 +11,205 @@ const { createAbortController, handleAbortError, setHeaders, - requireJwtAuth, validateEndpoint, buildEndpointOption, } = require('../../middleware'); -router.post('/abort', requireJwtAuth, handleAbort()); +router.post('/abort', handleAbort()); -router.post( - '/', - requireJwtAuth, - validateEndpoint, - buildEndpointOption, - setHeaders, - async (req, res) => { - let { - text, - endpointOption, - conversationId, - parentMessageId = null, - overrideParentMessageId = null, - } = req.body; - console.log('ask log'); - console.dir({ text, conversationId, endpointOption }, { depth: null }); - let metadata; - let userMessage; - let userMessageId; - let responseMessageId; - let lastSavedTimestamp = 0; - let saveDelay = 100; - const newConvo = !conversationId; - const user = req.user.id; +router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res) => { + let { + text, + endpointOption, + conversationId, + parentMessageId = null, + overrideParentMessageId = null, + } = req.body; + console.log('ask log'); + console.dir({ text, conversationId, endpointOption }, { depth: null }); + let metadata; + let userMessage; + let userMessageId; + let responseMessageId; + let lastSavedTimestamp = 0; + let saveDelay = 100; + const newConvo = !conversationId; + const user = req.user.id; - const plugins = []; + const plugins = []; - const addMetadata = (data) => (metadata = data); - const getIds = (data) => { - userMessage = data.userMessage; - userMessageId = userMessage.messageId; - responseMessageId = data.responseMessageId; - if (!conversationId) { - conversationId = data.conversationId; - } - }; + const addMetadata = (data) => (metadata = data); + const getIds = (data) => { + userMessage = data.userMessage; + userMessageId = userMessage.messageId; + responseMessageId = data.responseMessageId; + if (!conversationId) { + conversationId = data.conversationId; + } + }; - let streaming = null; - let timer = null; + let streaming = null; + let timer = null; - const { - onProgress: progressCallback, - sendIntermediateMessage, - getPartialText, - } = createOnProgress({ - onProgress: ({ text: partialText }) => { - const currentTimestamp = Date.now(); + const { + onProgress: progressCallback, + sendIntermediateMessage, + getPartialText, + } = createOnProgress({ + onProgress: ({ text: partialText }) => { + const currentTimestamp = Date.now(); - if (timer) { - clearTimeout(timer); - } - - if (currentTimestamp - lastSavedTimestamp > saveDelay) { - lastSavedTimestamp = currentTimestamp; - saveMessage({ - messageId: responseMessageId, - sender: getResponseSender(endpointOption), - conversationId, - parentMessageId: overrideParentMessageId || userMessageId, - text: partialText, - model: endpointOption.modelOptions.model, - unfinished: true, - cancelled: false, - error: false, - plugins, - }); - } - - if (saveDelay < 500) { - saveDelay = 500; - } - - streaming = new Promise((resolve) => { - timer = setTimeout(() => { - resolve(); - }, 250); - }); - }, - }); - - const pluginMap = new Map(); - const onAgentAction = async (action, runId) => { - pluginMap.set(runId, action.tool); - sendIntermediateMessage(res, { plugins }); - }; - - const onToolStart = async (tool, input, runId, parentRunId) => { - const pluginName = pluginMap.get(parentRunId); - const latestPlugin = { - runId, - loading: true, - inputs: [input], - latest: pluginName, - outputs: null, - }; - - if (streaming) { - await streaming; - } - const extraTokens = ':::plugin:::\n'; - plugins.push(latestPlugin); - sendIntermediateMessage(res, { plugins }, extraTokens); - }; - - const onToolEnd = async (output, runId) => { - if (streaming) { - await streaming; + if (timer) { + clearTimeout(timer); } - const pluginIndex = plugins.findIndex((plugin) => plugin.runId === runId); - - if (pluginIndex !== -1) { - plugins[pluginIndex].loading = false; - plugins[pluginIndex].outputs = output; - } - }; - - const onChainEnd = () => { - saveMessage(userMessage); - sendIntermediateMessage(res, { plugins }); - }; - - const getAbortData = () => ({ - sender: getResponseSender(endpointOption), - conversationId, - messageId: responseMessageId, - parentMessageId: overrideParentMessageId ?? userMessageId, - text: getPartialText(), - plugins: plugins.map((p) => ({ ...p, loading: false })), - userMessage, - }); - const { abortController, onStart } = createAbortController( - res, - req, - endpointOption, - getAbortData, - ); - - try { - endpointOption.tools = await validateTools(user, endpointOption.tools); - const { client } = await initializeClient(req, endpointOption); - - let response = await client.sendMessage(text, { - user, - conversationId, - parentMessageId, - overrideParentMessageId, - getIds, - onAgentAction, - onChainEnd, - onToolStart, - onToolEnd, - onStart, - addMetadata, - getPartialText, - ...endpointOption, - onProgress: progressCallback.call(null, { - res, - text, + if (currentTimestamp - lastSavedTimestamp > saveDelay) { + lastSavedTimestamp = currentTimestamp; + saveMessage({ + messageId: responseMessageId, + sender: getResponseSender(endpointOption), + conversationId, parentMessageId: overrideParentMessageId || userMessageId, + text: partialText, + model: endpointOption.modelOptions.model, + unfinished: true, + cancelled: false, + error: false, plugins, - }), - abortController, - }); - - if (overrideParentMessageId) { - response.parentMessageId = overrideParentMessageId; - } - - if (metadata) { - response = { ...response, ...metadata }; - } - - console.log('CLIENT RESPONSE'); - console.dir(response, { depth: null }); - response.plugins = plugins.map((p) => ({ ...p, loading: false })); - await saveMessage(response); - - sendMessage(res, { - title: await getConvoTitle(req.user.id, conversationId), - final: true, - conversation: await getConvo(req.user.id, conversationId), - requestMessage: userMessage, - responseMessage: response, - }); - res.end(); - - if (parentMessageId == '00000000-0000-0000-0000-000000000000' && newConvo) { - addTitle(req, { - text, - response, - client, }); } - } catch (error) { - const partialText = getPartialText(); - handleAbortError(res, req, error, { - partialText, - conversationId, - sender: getResponseSender(endpointOption), - messageId: responseMessageId, - parentMessageId: userMessageId ?? parentMessageId, + + if (saveDelay < 500) { + saveDelay = 500; + } + + streaming = new Promise((resolve) => { + timer = setTimeout(() => { + resolve(); + }, 250); + }); + }, + }); + + const pluginMap = new Map(); + const onAgentAction = async (action, runId) => { + pluginMap.set(runId, action.tool); + sendIntermediateMessage(res, { plugins }); + }; + + const onToolStart = async (tool, input, runId, parentRunId) => { + const pluginName = pluginMap.get(parentRunId); + const latestPlugin = { + runId, + loading: true, + inputs: [input], + latest: pluginName, + outputs: null, + }; + + if (streaming) { + await streaming; + } + const extraTokens = ':::plugin:::\n'; + plugins.push(latestPlugin); + sendIntermediateMessage(res, { plugins }, extraTokens); + }; + + const onToolEnd = async (output, runId) => { + if (streaming) { + await streaming; + } + + const pluginIndex = plugins.findIndex((plugin) => plugin.runId === runId); + + if (pluginIndex !== -1) { + plugins[pluginIndex].loading = false; + plugins[pluginIndex].outputs = output; + } + }; + + const onChainEnd = () => { + saveMessage(userMessage); + sendIntermediateMessage(res, { plugins }); + }; + + const getAbortData = () => ({ + sender: getResponseSender(endpointOption), + conversationId, + messageId: responseMessageId, + parentMessageId: overrideParentMessageId ?? userMessageId, + text: getPartialText(), + plugins: plugins.map((p) => ({ ...p, loading: false })), + userMessage, + }); + const { abortController, onStart } = createAbortController(req, res, getAbortData); + + try { + endpointOption.tools = await validateTools(user, endpointOption.tools); + const { client } = await initializeClient(req, endpointOption); + + let response = await client.sendMessage(text, { + user, + conversationId, + parentMessageId, + overrideParentMessageId, + getIds, + onAgentAction, + onChainEnd, + onToolStart, + onToolEnd, + onStart, + addMetadata, + getPartialText, + ...endpointOption, + onProgress: progressCallback.call(null, { + res, + text, + parentMessageId: overrideParentMessageId || userMessageId, + plugins, + }), + abortController, + }); + + if (overrideParentMessageId) { + response.parentMessageId = overrideParentMessageId; + } + + if (metadata) { + response = { ...response, ...metadata }; + } + + console.log('CLIENT RESPONSE'); + console.dir(response, { depth: null }); + response.plugins = plugins.map((p) => ({ ...p, loading: false })); + await saveMessage(response); + + sendMessage(res, { + title: await getConvoTitle(req.user.id, conversationId), + final: true, + conversation: await getConvo(req.user.id, conversationId), + requestMessage: userMessage, + responseMessage: response, + }); + res.end(); + + if (parentMessageId == '00000000-0000-0000-0000-000000000000' && newConvo) { + addTitle(req, { + text, + response, + client, }); } - }, -); + } catch (error) { + const partialText = getPartialText(); + handleAbortError(res, req, error, { + partialText, + conversationId, + sender: getResponseSender(endpointOption), + messageId: responseMessageId, + parentMessageId: userMessageId ?? parentMessageId, + }); + } +}); module.exports = router; diff --git a/api/server/routes/ask/index.js b/api/server/routes/ask/index.js index 77da50f68a..d87daa6a8c 100644 --- a/api/server/routes/ask/index.js +++ b/api/server/routes/ask/index.js @@ -6,6 +6,33 @@ const bingAI = require('./bingAI'); const gptPlugins = require('./gptPlugins'); const askChatGPTBrowser = require('./askChatGPTBrowser'); const anthropic = require('./anthropic'); +const { + uaParser, + checkBan, + requireJwtAuth, + concurrentLimiter, + messageIpLimiter, + messageUserLimiter, +} = require('../../middleware'); +const { isEnabled } = require('../../utils'); + +const { LIMIT_CONCURRENT_MESSAGES, LIMIT_MESSAGE_IP, LIMIT_MESSAGE_USER } = process.env ?? {}; + +router.use(requireJwtAuth); +router.use(checkBan); +router.use(uaParser); + +if (isEnabled(LIMIT_CONCURRENT_MESSAGES)) { + router.use(concurrentLimiter); +} + +if (isEnabled(LIMIT_MESSAGE_IP)) { + router.use(messageIpLimiter); +} + +if (isEnabled(LIMIT_MESSAGE_USER)) { + router.use(messageUserLimiter); +} router.use(['/azureOpenAI', '/openAI'], openAI); router.use('/google', google); diff --git a/api/server/routes/ask/openAI.js b/api/server/routes/ask/openAI.js index 5e5f445080..fb662809d8 100644 --- a/api/server/routes/ask/openAI.js +++ b/api/server/routes/ask/openAI.js @@ -9,151 +9,138 @@ const { createAbortController, handleAbortError, setHeaders, - requireJwtAuth, validateEndpoint, buildEndpointOption, } = require('../../middleware'); -router.post('/abort', requireJwtAuth, handleAbort()); +router.post('/abort', handleAbort()); -router.post( - '/', - requireJwtAuth, - validateEndpoint, - buildEndpointOption, - setHeaders, - async (req, res) => { - let { - text, - endpointOption, - conversationId, - parentMessageId = null, - overrideParentMessageId = null, - } = req.body; - console.log('ask log'); - console.dir({ text, conversationId, endpointOption }, { depth: null }); - let metadata; - let userMessage; - let userMessageId; - let responseMessageId; - let lastSavedTimestamp = 0; - let saveDelay = 100; - const newConvo = !conversationId; - const user = req.user.id; +router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res) => { + let { + text, + endpointOption, + conversationId, + parentMessageId = null, + overrideParentMessageId = null, + } = req.body; + console.log('ask log'); + console.dir({ text, conversationId, endpointOption }, { depth: null }); + let metadata; + let userMessage; + let userMessageId; + let responseMessageId; + let lastSavedTimestamp = 0; + let saveDelay = 100; + const newConvo = !conversationId; + const user = req.user.id; - const addMetadata = (data) => (metadata = data); + const addMetadata = (data) => (metadata = data); - const getIds = (data) => { - userMessage = data.userMessage; - userMessageId = userMessage.messageId; - responseMessageId = data.responseMessageId; - if (!conversationId) { - conversationId = data.conversationId; - } - }; + const getIds = (data) => { + userMessage = data.userMessage; + userMessageId = userMessage.messageId; + responseMessageId = data.responseMessageId; + if (!conversationId) { + conversationId = data.conversationId; + } + }; - const { onProgress: progressCallback, getPartialText } = createOnProgress({ - onProgress: ({ text: partialText }) => { - const currentTimestamp = Date.now(); + const { onProgress: progressCallback, getPartialText } = createOnProgress({ + onProgress: ({ text: partialText }) => { + const currentTimestamp = Date.now(); - if (currentTimestamp - lastSavedTimestamp > saveDelay) { - lastSavedTimestamp = currentTimestamp; - saveMessage({ - messageId: responseMessageId, - sender: getResponseSender(endpointOption), - conversationId, - parentMessageId: overrideParentMessageId ?? userMessageId, - text: partialText, - model: endpointOption.modelOptions.model, - unfinished: true, - cancelled: false, - error: false, - }); - } - - if (saveDelay < 500) { - saveDelay = 500; - } - }, - }); - - const getAbortData = () => ({ - sender: getResponseSender(endpointOption), - conversationId, - messageId: responseMessageId, - parentMessageId: overrideParentMessageId ?? userMessageId, - text: getPartialText(), - userMessage, - }); - - const { abortController, onStart } = createAbortController( - res, - req, - endpointOption, - getAbortData, - ); - - try { - const { client } = await initializeClient(req, endpointOption); - - let response = await client.sendMessage(text, { - user, - parentMessageId, - conversationId, - overrideParentMessageId, - getIds, - onStart, - addMetadata, - abortController, - onProgress: progressCallback.call(null, { - res, - text, - parentMessageId: overrideParentMessageId || userMessageId, - }), - }); - - if (overrideParentMessageId) { - response.parentMessageId = overrideParentMessageId; - } - - if (metadata) { - response = { ...response, ...metadata }; - } - - console.log( - 'promptTokens, completionTokens:', - response.promptTokens, - response.completionTokens, - ); - await saveMessage(response); - - sendMessage(res, { - title: await getConvoTitle(req.user.id, conversationId), - final: true, - conversation: await getConvo(req.user.id, conversationId), - requestMessage: userMessage, - responseMessage: response, - }); - res.end(); - - if (parentMessageId == '00000000-0000-0000-0000-000000000000' && newConvo) { - addTitle(req, { - text, - response, - client, + if (currentTimestamp - lastSavedTimestamp > saveDelay) { + lastSavedTimestamp = currentTimestamp; + saveMessage({ + messageId: responseMessageId, + sender: getResponseSender(endpointOption), + conversationId, + parentMessageId: overrideParentMessageId ?? userMessageId, + text: partialText, + model: endpointOption.modelOptions.model, + unfinished: true, + cancelled: false, + error: false, }); } - } catch (error) { - const partialText = getPartialText(); - handleAbortError(res, req, error, { - partialText, - conversationId, - sender: getResponseSender(endpointOption), - messageId: responseMessageId, - parentMessageId: userMessageId ?? parentMessageId, + + if (saveDelay < 500) { + saveDelay = 500; + } + }, + }); + + const getAbortData = () => ({ + sender: getResponseSender(endpointOption), + conversationId, + messageId: responseMessageId, + parentMessageId: overrideParentMessageId ?? userMessageId, + text: getPartialText(), + userMessage, + }); + + const { abortController, onStart } = createAbortController(req, res, getAbortData); + + try { + const { client } = await initializeClient(req, endpointOption); + + let response = await client.sendMessage(text, { + user, + parentMessageId, + conversationId, + overrideParentMessageId, + getIds, + onStart, + addMetadata, + abortController, + onProgress: progressCallback.call(null, { + res, + text, + parentMessageId: overrideParentMessageId || userMessageId, + }), + }); + + if (overrideParentMessageId) { + response.parentMessageId = overrideParentMessageId; + } + + if (metadata) { + response = { ...response, ...metadata }; + } + + console.log( + 'promptTokens, completionTokens:', + response.promptTokens, + response.completionTokens, + ); + await saveMessage(response); + + sendMessage(res, { + title: await getConvoTitle(req.user.id, conversationId), + final: true, + conversation: await getConvo(req.user.id, conversationId), + requestMessage: userMessage, + responseMessage: response, + }); + res.end(); + + if (parentMessageId == '00000000-0000-0000-0000-000000000000' && newConvo) { + addTitle(req, { + text, + response, + client, }); } - }, -); + } catch (error) { + const partialText = getPartialText(); + handleAbortError(res, req, error, { + partialText, + conversationId, + sender: getResponseSender(endpointOption), + messageId: responseMessageId, + parentMessageId: userMessageId ?? parentMessageId, + }); + } +}); module.exports = router; diff --git a/api/server/routes/auth.js b/api/server/routes/auth.js index 1ccbcb34b5..862a098fa5 100644 --- a/api/server/routes/auth.js +++ b/api/server/routes/auth.js @@ -8,6 +8,7 @@ const { const { loginController } = require('../controllers/auth/LoginController'); const { logoutController } = require('../controllers/auth/LogoutController'); const { + checkBan, loginLimiter, registerLimiter, requireJwtAuth, @@ -19,9 +20,9 @@ const router = express.Router(); //Local router.post('/logout', requireJwtAuth, logoutController); -router.post('/login', loginLimiter, requireLocalAuth, loginController); +router.post('/login', loginLimiter, checkBan, requireLocalAuth, loginController); router.post('/refresh', refreshController); -router.post('/register', registerLimiter, validateRegistration, registrationController); +router.post('/register', registerLimiter, checkBan, validateRegistration, registrationController); router.post('/requestPasswordReset', resetPasswordRequestController); router.post('/resetPassword', resetPasswordController); diff --git a/api/server/routes/convos.js b/api/server/routes/convos.js index 66b3ffc0ac..d4b919d309 100644 --- a/api/server/routes/convos.js +++ b/api/server/routes/convos.js @@ -4,12 +4,14 @@ const { getConvo, saveConvo } = require('../../models'); const { getConvosByPage, deleteConvos } = require('../../models/Conversation'); const requireJwtAuth = require('../middleware/requireJwtAuth'); -router.get('/', requireJwtAuth, async (req, res) => { +router.use(requireJwtAuth); + +router.get('/', async (req, res) => { const pageNumber = req.query.pageNumber || 1; res.status(200).send(await getConvosByPage(req.user.id, pageNumber)); }); -router.get('/:conversationId', requireJwtAuth, async (req, res) => { +router.get('/:conversationId', async (req, res) => { const { conversationId } = req.params; const convo = await getConvo(req.user.id, conversationId); @@ -20,7 +22,7 @@ router.get('/:conversationId', requireJwtAuth, async (req, res) => { } }); -router.post('/clear', requireJwtAuth, async (req, res) => { +router.post('/clear', async (req, res) => { let filter = {}; const { conversationId, source } = req.body.arg; if (conversationId) { @@ -43,7 +45,7 @@ router.post('/clear', requireJwtAuth, async (req, res) => { } }); -router.post('/update', requireJwtAuth, async (req, res) => { +router.post('/update', async (req, res) => { const update = req.body.arg; try { diff --git a/api/server/routes/edit/anthropic.js b/api/server/routes/edit/anthropic.js index 0fe67fb56a..5695d67ccb 100644 --- a/api/server/routes/edit/anthropic.js +++ b/api/server/routes/edit/anthropic.js @@ -7,140 +7,127 @@ const { createAbortController, handleAbortError, setHeaders, - requireJwtAuth, validateEndpoint, buildEndpointOption, } = require('../../middleware'); const { saveMessage, getConvoTitle, getConvo } = require('../../../models'); const { sendMessage, createOnProgress } = require('../../utils'); -router.post('/abort', requireJwtAuth, handleAbort()); +router.post('/abort', handleAbort()); -router.post( - '/', - requireJwtAuth, - validateEndpoint, - buildEndpointOption, - setHeaders, - async (req, res) => { - let { - text, - generation, - endpointOption, - conversationId, - responseMessageId, - isContinued = false, - parentMessageId = null, - overrideParentMessageId = null, - } = req.body; - console.log('edit log'); - console.dir({ text, generation, isContinued, conversationId, endpointOption }, { depth: null }); - let metadata; - let userMessage; - let lastSavedTimestamp = 0; - let saveDelay = 100; - const userMessageId = parentMessageId; +router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res) => { + let { + text, + generation, + endpointOption, + conversationId, + responseMessageId, + isContinued = false, + parentMessageId = null, + overrideParentMessageId = null, + } = req.body; + console.log('edit log'); + console.dir({ text, generation, isContinued, conversationId, endpointOption }, { depth: null }); + let metadata; + let userMessage; + let lastSavedTimestamp = 0; + let saveDelay = 100; + const userMessageId = parentMessageId; - const addMetadata = (data) => (metadata = data); - const getIds = (data) => { - userMessage = data.userMessage; - responseMessageId = data.responseMessageId; - }; + const addMetadata = (data) => (metadata = data); + const getIds = (data) => { + userMessage = data.userMessage; + responseMessageId = data.responseMessageId; + }; - const { onProgress: progressCallback, getPartialText } = createOnProgress({ - generation, - onProgress: ({ text: partialText }) => { - const currentTimestamp = Date.now(); - if (currentTimestamp - lastSavedTimestamp > saveDelay) { - lastSavedTimestamp = currentTimestamp; - saveMessage({ - messageId: responseMessageId, - sender: getResponseSender(endpointOption), - conversationId, - parentMessageId: overrideParentMessageId ?? userMessageId, - text: partialText, - unfinished: true, - cancelled: false, - isEdited: true, - error: false, - }); - } - - if (saveDelay < 500) { - saveDelay = 500; - } - }, - }); - try { - const getAbortData = () => ({ - conversationId, - messageId: responseMessageId, - sender: getResponseSender(endpointOption), - parentMessageId: overrideParentMessageId ?? userMessageId, - text: getPartialText(), - userMessage, - }); - - const { abortController, onStart } = createAbortController( - res, - req, - endpointOption, - getAbortData, - ); - - const { client } = await initializeClient(req, endpointOption); - - let response = await client.sendMessage(text, { - user: req.user.id, - generation, - isContinued, - isEdited: true, - conversationId, - parentMessageId, - responseMessageId, - overrideParentMessageId, - ...endpointOption, - onProgress: progressCallback.call(null, { - res, - text, + const { onProgress: progressCallback, getPartialText } = createOnProgress({ + generation, + onProgress: ({ text: partialText }) => { + const currentTimestamp = Date.now(); + if (currentTimestamp - lastSavedTimestamp > saveDelay) { + lastSavedTimestamp = currentTimestamp; + saveMessage({ + messageId: responseMessageId, + sender: getResponseSender(endpointOption), + conversationId, parentMessageId: overrideParentMessageId ?? userMessageId, - }), - getIds, - onStart, - addMetadata, - abortController, - }); - - if (metadata) { - response = { ...response, ...metadata }; + text: partialText, + unfinished: true, + cancelled: false, + isEdited: true, + error: false, + }); } - if (overrideParentMessageId) { - response.parentMessageId = overrideParentMessageId; + if (saveDelay < 500) { + saveDelay = 500; } + }, + }); + try { + const getAbortData = () => ({ + conversationId, + messageId: responseMessageId, + sender: getResponseSender(endpointOption), + parentMessageId: overrideParentMessageId ?? userMessageId, + text: getPartialText(), + userMessage, + }); - await saveMessage(response); - sendMessage(res, { - title: await getConvoTitle(req.user.id, conversationId), - final: true, - conversation: await getConvo(req.user.id, conversationId), - requestMessage: userMessage, - responseMessage: response, - }); - res.end(); + const { abortController, onStart } = createAbortController(req, res, getAbortData); - // TODO: add anthropic titling - } catch (error) { - const partialText = getPartialText(); - handleAbortError(res, req, error, { - partialText, - conversationId, - sender: getResponseSender(endpointOption), - messageId: responseMessageId, - parentMessageId: userMessageId ?? parentMessageId, - }); + const { client } = await initializeClient(req, endpointOption); + + let response = await client.sendMessage(text, { + user: req.user.id, + generation, + isContinued, + isEdited: true, + conversationId, + parentMessageId, + responseMessageId, + overrideParentMessageId, + ...endpointOption, + onProgress: progressCallback.call(null, { + res, + text, + parentMessageId: overrideParentMessageId ?? userMessageId, + }), + getIds, + onStart, + addMetadata, + abortController, + }); + + if (metadata) { + response = { ...response, ...metadata }; } - }, -); + + if (overrideParentMessageId) { + response.parentMessageId = overrideParentMessageId; + } + + await saveMessage(response); + sendMessage(res, { + title: await getConvoTitle(req.user.id, conversationId), + final: true, + conversation: await getConvo(req.user.id, conversationId), + requestMessage: userMessage, + responseMessage: response, + }); + res.end(); + + // TODO: add anthropic titling + } catch (error) { + const partialText = getPartialText(); + handleAbortError(res, req, error, { + partialText, + conversationId, + sender: getResponseSender(endpointOption), + messageId: responseMessageId, + parentMessageId: userMessageId ?? parentMessageId, + }); + } +}); module.exports = router; diff --git a/api/server/routes/edit/gptPlugins.js b/api/server/routes/edit/gptPlugins.js index c05592323f..b180c844f9 100644 --- a/api/server/routes/edit/gptPlugins.js +++ b/api/server/routes/edit/gptPlugins.js @@ -10,183 +10,170 @@ const { createAbortController, handleAbortError, setHeaders, - requireJwtAuth, validateEndpoint, buildEndpointOption, } = require('../../middleware'); -router.post('/abort', requireJwtAuth, handleAbort()); +router.post('/abort', handleAbort()); -router.post( - '/', - requireJwtAuth, - validateEndpoint, - buildEndpointOption, - setHeaders, - async (req, res) => { - let { - text, - generation, - endpointOption, - conversationId, - responseMessageId, - isContinued = false, - parentMessageId = null, - overrideParentMessageId = null, - } = req.body; - console.log('edit log'); - console.dir({ text, generation, isContinued, conversationId, endpointOption }, { depth: null }); - let metadata; - let userMessage; - let lastSavedTimestamp = 0; - let saveDelay = 100; - const userMessageId = parentMessageId; - const user = req.user.id; +router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res) => { + let { + text, + generation, + endpointOption, + conversationId, + responseMessageId, + isContinued = false, + parentMessageId = null, + overrideParentMessageId = null, + } = req.body; + console.log('edit log'); + console.dir({ text, generation, isContinued, conversationId, endpointOption }, { depth: null }); + let metadata; + let userMessage; + let lastSavedTimestamp = 0; + let saveDelay = 100; + const userMessageId = parentMessageId; + const user = req.user.id; - const plugin = { - loading: true, - inputs: [], - latest: null, - outputs: null, - }; + const plugin = { + loading: true, + inputs: [], + latest: null, + outputs: null, + }; - const addMetadata = (data) => (metadata = data); - const getIds = (data) => { - userMessage = data.userMessage; - responseMessageId = data.responseMessageId; - }; + const addMetadata = (data) => (metadata = data); + const getIds = (data) => { + userMessage = data.userMessage; + responseMessageId = data.responseMessageId; + }; - const { - onProgress: progressCallback, - sendIntermediateMessage, - getPartialText, - } = createOnProgress({ - generation, - onProgress: ({ text: partialText }) => { - const currentTimestamp = Date.now(); + const { + onProgress: progressCallback, + sendIntermediateMessage, + getPartialText, + } = createOnProgress({ + generation, + onProgress: ({ text: partialText }) => { + const currentTimestamp = Date.now(); - if (plugin.loading === true) { - plugin.loading = false; - } - - if (currentTimestamp - lastSavedTimestamp > saveDelay) { - lastSavedTimestamp = currentTimestamp; - saveMessage({ - messageId: responseMessageId, - sender: getResponseSender(endpointOption), - conversationId, - parentMessageId: overrideParentMessageId || userMessageId, - text: partialText, - model: endpointOption.modelOptions.model, - unfinished: true, - cancelled: false, - isEdited: true, - error: false, - }); - } - - if (saveDelay < 500) { - saveDelay = 500; - } - }, - }); - - const onAgentAction = (action, start = false) => { - const formattedAction = formatAction(action); - plugin.inputs.push(formattedAction); - plugin.latest = formattedAction.plugin; - if (!start) { - saveMessage(userMessage); + if (plugin.loading === true) { + plugin.loading = false; } - sendIntermediateMessage(res, { plugin }); - // console.log('PLUGIN ACTION', formattedAction); - }; - const onChainEnd = (data) => { - let { intermediateSteps: steps } = data; - plugin.outputs = steps && steps[0].action ? formatSteps(steps) : 'An error occurred.'; - plugin.loading = false; - saveMessage(userMessage); - sendIntermediateMessage(res, { plugin }); - // console.log('CHAIN END', plugin.outputs); - }; - - const getAbortData = () => ({ - sender: getResponseSender(endpointOption), - conversationId, - messageId: responseMessageId, - parentMessageId: overrideParentMessageId ?? userMessageId, - text: getPartialText(), - plugin: { ...plugin, loading: false }, - userMessage, - }); - const { abortController, onStart } = createAbortController( - res, - req, - endpointOption, - getAbortData, - ); - - try { - endpointOption.tools = await validateTools(user, endpointOption.tools); - const { client } = await initializeClient(req, endpointOption); - - let response = await client.sendMessage(text, { - user, - generation, - isContinued, - isEdited: true, - conversationId, - parentMessageId, - responseMessageId, - overrideParentMessageId, - getIds, - onAgentAction, - onChainEnd, - onStart, - addMetadata, - ...endpointOption, - onProgress: progressCallback.call(null, { - res, - text, - plugin, + if (currentTimestamp - lastSavedTimestamp > saveDelay) { + lastSavedTimestamp = currentTimestamp; + saveMessage({ + messageId: responseMessageId, + sender: getResponseSender(endpointOption), + conversationId, parentMessageId: overrideParentMessageId || userMessageId, - }), - abortController, - }); - - if (overrideParentMessageId) { - response.parentMessageId = overrideParentMessageId; + text: partialText, + model: endpointOption.modelOptions.model, + unfinished: true, + cancelled: false, + isEdited: true, + error: false, + }); } - if (metadata) { - response = { ...response, ...metadata }; + if (saveDelay < 500) { + saveDelay = 500; } + }, + }); - console.log('CLIENT RESPONSE'); - console.dir(response, { depth: null }); - response.plugin = { ...plugin, loading: false }; - await saveMessage(response); - - sendMessage(res, { - title: await getConvoTitle(req.user.id, conversationId), - final: true, - conversation: await getConvo(req.user.id, conversationId), - requestMessage: userMessage, - responseMessage: response, - }); - res.end(); - } catch (error) { - const partialText = getPartialText(); - handleAbortError(res, req, error, { - partialText, - conversationId, - sender: getResponseSender(endpointOption), - messageId: responseMessageId, - parentMessageId: userMessageId ?? parentMessageId, - }); + const onAgentAction = (action, start = false) => { + const formattedAction = formatAction(action); + plugin.inputs.push(formattedAction); + plugin.latest = formattedAction.plugin; + if (!start) { + saveMessage(userMessage); } - }, -); + sendIntermediateMessage(res, { plugin }); + // console.log('PLUGIN ACTION', formattedAction); + }; + + const onChainEnd = (data) => { + let { intermediateSteps: steps } = data; + plugin.outputs = steps && steps[0].action ? formatSteps(steps) : 'An error occurred.'; + plugin.loading = false; + saveMessage(userMessage); + sendIntermediateMessage(res, { plugin }); + // console.log('CHAIN END', plugin.outputs); + }; + + const getAbortData = () => ({ + sender: getResponseSender(endpointOption), + conversationId, + messageId: responseMessageId, + parentMessageId: overrideParentMessageId ?? userMessageId, + text: getPartialText(), + plugin: { ...plugin, loading: false }, + userMessage, + }); + const { abortController, onStart } = createAbortController(req, res, getAbortData); + + try { + endpointOption.tools = await validateTools(user, endpointOption.tools); + const { client } = await initializeClient(req, endpointOption); + + let response = await client.sendMessage(text, { + user, + generation, + isContinued, + isEdited: true, + conversationId, + parentMessageId, + responseMessageId, + overrideParentMessageId, + getIds, + onAgentAction, + onChainEnd, + onStart, + addMetadata, + ...endpointOption, + onProgress: progressCallback.call(null, { + res, + text, + plugin, + parentMessageId: overrideParentMessageId || userMessageId, + }), + abortController, + }); + + if (overrideParentMessageId) { + response.parentMessageId = overrideParentMessageId; + } + + if (metadata) { + response = { ...response, ...metadata }; + } + + console.log('CLIENT RESPONSE'); + console.dir(response, { depth: null }); + response.plugin = { ...plugin, loading: false }; + await saveMessage(response); + + sendMessage(res, { + title: await getConvoTitle(req.user.id, conversationId), + final: true, + conversation: await getConvo(req.user.id, conversationId), + requestMessage: userMessage, + responseMessage: response, + }); + res.end(); + } catch (error) { + const partialText = getPartialText(); + handleAbortError(res, req, error, { + partialText, + conversationId, + sender: getResponseSender(endpointOption), + messageId: responseMessageId, + parentMessageId: userMessageId ?? parentMessageId, + }); + } +}); module.exports = router; diff --git a/api/server/routes/edit/index.js b/api/server/routes/edit/index.js index 7eda18b8a1..dcf5ff553b 100644 --- a/api/server/routes/edit/index.js +++ b/api/server/routes/edit/index.js @@ -3,11 +3,36 @@ const router = express.Router(); const openAI = require('./openAI'); const gptPlugins = require('./gptPlugins'); const anthropic = require('./anthropic'); -// const google = require('./google'); +const { + checkBan, + uaParser, + requireJwtAuth, + concurrentLimiter, + messageIpLimiter, + messageUserLimiter, +} = require('../../middleware'); +const { isEnabled } = require('../../utils'); + +const { LIMIT_CONCURRENT_MESSAGES, LIMIT_MESSAGE_IP, LIMIT_MESSAGE_USER } = process.env ?? {}; + +router.use(requireJwtAuth); +router.use(checkBan); +router.use(uaParser); + +if (isEnabled(LIMIT_CONCURRENT_MESSAGES)) { + router.use(concurrentLimiter); +} + +if (isEnabled(LIMIT_MESSAGE_IP)) { + router.use(messageIpLimiter); +} + +if (isEnabled(LIMIT_MESSAGE_USER)) { + router.use(messageUserLimiter); +} router.use(['/azureOpenAI', '/openAI'], openAI); router.use('/gptPlugins', gptPlugins); router.use('/anthropic', anthropic); -// router.use('/google', google); module.exports = router; diff --git a/api/server/routes/edit/openAI.js b/api/server/routes/edit/openAI.js index 1a9dc3ff22..8af7ee2061 100644 --- a/api/server/routes/edit/openAI.js +++ b/api/server/routes/edit/openAI.js @@ -9,140 +9,127 @@ const { createAbortController, handleAbortError, setHeaders, - requireJwtAuth, validateEndpoint, buildEndpointOption, } = require('../../middleware'); -router.post('/abort', requireJwtAuth, handleAbort()); +router.post('/abort', handleAbort()); -router.post( - '/', - requireJwtAuth, - validateEndpoint, - buildEndpointOption, - setHeaders, - async (req, res) => { - let { - text, - generation, - endpointOption, - conversationId, - responseMessageId, - isContinued = false, - parentMessageId = null, - overrideParentMessageId = null, - } = req.body; - console.log('edit log'); - console.dir({ text, generation, isContinued, conversationId, endpointOption }, { depth: null }); - let metadata; - let userMessage; - let lastSavedTimestamp = 0; - let saveDelay = 100; - const userMessageId = parentMessageId; +router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res) => { + let { + text, + generation, + endpointOption, + conversationId, + responseMessageId, + isContinued = false, + parentMessageId = null, + overrideParentMessageId = null, + } = req.body; + console.log('edit log'); + console.dir({ text, generation, isContinued, conversationId, endpointOption }, { depth: null }); + let metadata; + let userMessage; + let lastSavedTimestamp = 0; + let saveDelay = 100; + const userMessageId = parentMessageId; - const addMetadata = (data) => (metadata = data); - const getIds = (data) => { - userMessage = data.userMessage; - responseMessageId = data.responseMessageId; - }; + const addMetadata = (data) => (metadata = data); + const getIds = (data) => { + userMessage = data.userMessage; + responseMessageId = data.responseMessageId; + }; - const { onProgress: progressCallback, getPartialText } = createOnProgress({ - generation, - onProgress: ({ text: partialText }) => { - const currentTimestamp = Date.now(); + const { onProgress: progressCallback, getPartialText } = createOnProgress({ + generation, + onProgress: ({ text: partialText }) => { + const currentTimestamp = Date.now(); - if (currentTimestamp - lastSavedTimestamp > saveDelay) { - lastSavedTimestamp = currentTimestamp; - saveMessage({ - messageId: responseMessageId, - sender: getResponseSender(endpointOption), - conversationId, - parentMessageId: overrideParentMessageId || userMessageId, - text: partialText, - model: endpointOption.modelOptions.model, - unfinished: true, - cancelled: false, - isEdited: true, - error: false, - }); - } - - if (saveDelay < 500) { - saveDelay = 500; - } - }, - }); - - const getAbortData = () => ({ - sender: getResponseSender(endpointOption), - conversationId, - messageId: responseMessageId, - parentMessageId: overrideParentMessageId ?? userMessageId, - text: getPartialText(), - userMessage, - }); - - const { abortController, onStart } = createAbortController( - res, - req, - endpointOption, - getAbortData, - ); - - try { - const { client } = await initializeClient(req, endpointOption); - - let response = await client.sendMessage(text, { - user: req.user.id, - generation, - isContinued, - isEdited: true, - conversationId, - parentMessageId, - responseMessageId, - overrideParentMessageId, - getIds, - onStart, - addMetadata, - abortController, - onProgress: progressCallback.call(null, { - res, - text, + if (currentTimestamp - lastSavedTimestamp > saveDelay) { + lastSavedTimestamp = currentTimestamp; + saveMessage({ + messageId: responseMessageId, + sender: getResponseSender(endpointOption), + conversationId, parentMessageId: overrideParentMessageId || userMessageId, - }), - }); - - if (metadata) { - response = { ...response, ...metadata }; + text: partialText, + model: endpointOption.modelOptions.model, + unfinished: true, + cancelled: false, + isEdited: true, + error: false, + }); } - console.log( - 'promptTokens, completionTokens:', - response.promptTokens, - response.completionTokens, - ); - await saveMessage(response); + if (saveDelay < 500) { + saveDelay = 500; + } + }, + }); - sendMessage(res, { - title: await getConvoTitle(req.user.id, conversationId), - final: true, - conversation: await getConvo(req.user.id, conversationId), - requestMessage: userMessage, - responseMessage: response, - }); - res.end(); - } catch (error) { - const partialText = getPartialText(); - handleAbortError(res, req, error, { - partialText, - conversationId, - sender: getResponseSender(endpointOption), - messageId: responseMessageId, - parentMessageId: userMessageId ?? parentMessageId, - }); + const getAbortData = () => ({ + sender: getResponseSender(endpointOption), + conversationId, + messageId: responseMessageId, + parentMessageId: overrideParentMessageId ?? userMessageId, + text: getPartialText(), + userMessage, + }); + + const { abortController, onStart } = createAbortController(req, res, getAbortData); + + try { + const { client } = await initializeClient(req, endpointOption); + + let response = await client.sendMessage(text, { + user: req.user.id, + generation, + isContinued, + isEdited: true, + conversationId, + parentMessageId, + responseMessageId, + overrideParentMessageId, + getIds, + onStart, + addMetadata, + abortController, + onProgress: progressCallback.call(null, { + res, + text, + parentMessageId: overrideParentMessageId || userMessageId, + }), + }); + + if (metadata) { + response = { ...response, ...metadata }; } - }, -); + + console.log( + 'promptTokens, completionTokens:', + response.promptTokens, + response.completionTokens, + ); + await saveMessage(response); + + sendMessage(res, { + title: await getConvoTitle(req.user.id, conversationId), + final: true, + conversation: await getConvo(req.user.id, conversationId), + requestMessage: userMessage, + responseMessage: response, + }); + res.end(); + } catch (error) { + const partialText = getPartialText(); + handleAbortError(res, req, error, { + partialText, + conversationId, + sender: getResponseSender(endpointOption), + messageId: responseMessageId, + parentMessageId: userMessageId ?? parentMessageId, + }); + } +}); module.exports = router; diff --git a/api/server/routes/oauth.js b/api/server/routes/oauth.js index 556603e9ec..f64930c751 100644 --- a/api/server/routes/oauth.js +++ b/api/server/routes/oauth.js @@ -3,8 +3,24 @@ const express = require('express'); const router = express.Router(); const config = require('../../../config/loader'); const { setAuthTokens } = require('../services/AuthService'); +const { loginLimiter, checkBan } = require('../middleware'); const domains = config.domains; +router.use(loginLimiter); + +const oauthHandler = async (req, res) => { + try { + await checkBan(req, res); + if (req.banned) { + return; + } + await setAuthTokens(req.user._id, res); + res.redirect(domains.client); + } catch (err) { + console.error('Error in setting authentication tokens:', err); + } +}; + /** * Google Routes */ @@ -24,14 +40,7 @@ router.get( session: false, scope: ['openid', 'profile', 'email'], }), - async (req, res) => { - try { - await setAuthTokens(req.user._id, res); - res.redirect(domains.client); - } catch (err) { - console.error('Error in setting authentication tokens:', err); - } - }, + oauthHandler, ); router.get( @@ -52,14 +61,7 @@ router.get( scope: ['public_profile'], profileFields: ['id', 'email', 'name'], }), - async (req, res) => { - try { - await setAuthTokens(req.user._id, res); - res.redirect(domains.client); - } catch (err) { - console.error('Error in setting authentication tokens:', err); - } - }, + oauthHandler, ); router.get( @@ -76,14 +78,7 @@ router.get( failureMessage: true, session: false, }), - async (req, res) => { - try { - await setAuthTokens(req.user._id, res); - res.redirect(domains.client); - } catch (err) { - console.error('Error in setting authentication tokens:', err); - } - }, + oauthHandler, ); router.get( @@ -102,14 +97,7 @@ router.get( session: false, scope: ['user:email', 'read:user'], }), - async (req, res) => { - try { - await setAuthTokens(req.user._id, res); - res.redirect(domains.client); - } catch (err) { - console.error('Error in setting authentication tokens:', err); - } - }, + oauthHandler, ); router.get( '/discord', @@ -127,14 +115,7 @@ router.get( session: false, scope: ['identify', 'email'], }), - async (req, res) => { - try { - await setAuthTokens(req.user._id, res); - res.redirect(domains.client); - } catch (err) { - console.error('Error in setting authentication tokens:', err); - } - }, + oauthHandler, ); module.exports = router; diff --git a/api/server/utils/handleText.js b/api/server/utils/handleText.js index 4715610d79..3ae18e98c5 100644 --- a/api/server/utils/handleText.js +++ b/api/server/utils/handleText.js @@ -1,22 +1,11 @@ const partialRight = require('lodash/partialRight'); -const citationRegex = /\[\^\d+?\^]/g; const { getCitations, citeText } = require('./citations'); +const { sendMessage } = require('./streamResponse'); const cursor = ''; +const citationRegex = /\[\^\d+?\^]/g; const addSpaceIfNeeded = (text) => (text.length > 0 && !text.endsWith(' ') ? text + ' ' : text); -const handleError = (res, message) => { - res.write(`event: error\ndata: ${JSON.stringify(message)}\n\n`); - res.end(); -}; - -const sendMessage = (res, message, event = 'message') => { - if (message.length === 0) { - return; - } - res.write(`event: ${event}\ndata: ${JSON.stringify(message)}\n\n`); -}; - const createOnProgress = ({ generation = '', onProgress: _onProgress }) => { let i = 0; let code = ''; @@ -148,10 +137,27 @@ function formatAction(action) { return formattedAction; } +/** + * Checks if the given string value is truthy by comparing it to the string 'true' (case-insensitive). + * + * @function + * @param {string|null|undefined} value - The string value to check. + * @returns {boolean} Returns `true` if the value is a case-insensitive match for the string 'true', otherwise returns `false`. + * @example + * + * isEnabled("True"); // returns true + * isEnabled("TRUE"); // returns true + * isEnabled("false"); // returns false + * isEnabled(null); // returns false + * isEnabled(); // returns false + */ +function isEnabled(value) { + return value?.toLowerCase()?.trim() === 'true'; +} + module.exports = { - handleError, - sendMessage, createOnProgress, + isEnabled, handleText, formatSteps, formatAction, diff --git a/api/server/utils/index.js b/api/server/utils/index.js index e76d5b4365..ba21583f51 100644 --- a/api/server/utils/index.js +++ b/api/server/utils/index.js @@ -1,11 +1,17 @@ -const cryptoUtils = require('./crypto'); +const streamResponse = require('./streamResponse'); +const removePorts = require('./removePorts'); const handleText = require('./handleText'); +const cryptoUtils = require('./crypto'); const citations = require('./citations'); const sendEmail = require('./sendEmail'); +const math = require('./math'); module.exports = { + ...streamResponse, ...cryptoUtils, ...handleText, ...citations, + removePorts, sendEmail, + math, }; diff --git a/api/server/utils/math.js b/api/server/utils/math.js new file mode 100644 index 0000000000..12c12c8ccd --- /dev/null +++ b/api/server/utils/math.js @@ -0,0 +1,48 @@ +/** + * Evaluates a mathematical expression provided as a string and returns the result. + * + * If the input is already a number, it returns the number as is. + * If the input is not a string or contains invalid characters, an error is thrown. + * If the evaluated result is not a number, an error is thrown. + * + * @param {string|number} str - The mathematical expression to evaluate, or a number. + * @param {number} [fallbackValue] - The default value to return if the input is not a string or number, or if the evaluated result is not a number. + * + * @returns {number} The result of the evaluated expression or the input number. + * + * @throws {Error} Throws an error if the input is not a string or number, contains invalid characters, or does not evaluate to a number. + */ +function math(str, fallbackValue) { + const fallback = typeof fallbackValue !== 'undefined' && typeof fallbackValue === 'number'; + if (typeof str !== 'string' && typeof str === 'number') { + return str; + } else if (typeof str !== 'string') { + if (fallback) { + return fallbackValue; + } + throw new Error(`str is ${typeof str}, but should be a string`); + } + + const validStr = /^[+\-\d.\s*/%()]+$/.test(str); + + if (!validStr) { + if (fallback) { + return fallbackValue; + } + throw new Error('Invalid characters in string'); + } + + const value = eval(str); + + if (typeof value !== 'number') { + if (fallback) { + return fallbackValue; + } + console.error('str', str); + throw new Error(`str did not evaluate to a number but to a ${typeof value}`); + } + + return value; +} + +module.exports = math; diff --git a/api/server/utils/removePorts.js b/api/server/utils/removePorts.js new file mode 100644 index 0000000000..db3e5e1db8 --- /dev/null +++ b/api/server/utils/removePorts.js @@ -0,0 +1 @@ +module.exports = (req) => req.ip.replace(/:\d+[^:]*$/, ''); diff --git a/api/server/utils/streamResponse.js b/api/server/utils/streamResponse.js new file mode 100644 index 0000000000..26cb0c238d --- /dev/null +++ b/api/server/utils/streamResponse.js @@ -0,0 +1,63 @@ +const crypto = require('crypto'); +const { saveMessage } = require('../../models'); + +/** + * Sends error data in Server Sent Events format and ends the response. + * @param {object} res - The server response. + * @param {string} message - The error message. + */ +const handleError = (res, message) => { + res.write(`event: error\ndata: ${JSON.stringify(message)}\n\n`); + res.end(); +}; + +/** + * Sends message data in Server Sent Events format. + * @param {object} res - - The server response. + * @param {string} message - The message to be sent. + * @param {string} event - [Optional] The type of event. Default is 'message'. + */ +const sendMessage = (res, message, event = 'message') => { + if (message.length === 0) { + return; + } + res.write(`event: ${event}\ndata: ${JSON.stringify(message)}\n\n`); +}; + +/** + * Processes an error with provided options, saves the error message and sends a corresponding SSE response + * @async + * @param {object} res - The server response. + * @param {object} options - The options for handling the error containing message properties. + * @param {function} callback - [Optional] The callback function to be executed. + */ +const sendError = async (res, options, callback) => { + const { sender, conversationId, messageId, parentMessageId, text, shouldSaveMessage } = options; + const errorMessage = { + sender, + messageId: messageId ?? crypto.randomUUID(), + conversationId, + parentMessageId, + unfinished: false, + cancelled: false, + error: true, + final: true, + text, + isCreatedByUser: false, + }; + if (callback && typeof callback === 'function') { + await callback(); + } + + if (shouldSaveMessage) { + await saveMessage(errorMessage); + } + + handleError(res, errorMessage); +}; + +module.exports = { + handleError, + sendMessage, + sendError, +}; diff --git a/api/test/.env.test.example b/api/test/.env.test.example index e7a3fc48e9..16730f672d 100644 --- a/api/test/.env.test.example +++ b/api/test/.env.test.example @@ -7,3 +7,7 @@ CREDS_IV=cd02538f4be2fa37aba9420b5924389f # For testing the ChatAgent OPENAI_API_KEY=your-api-key + +BAN_VIOLATIONS=true +BAN_DURATION=7200000 +BAN_INTERVAL=20 diff --git a/api/test/__mocks__/KeyvMongo.js b/api/test/__mocks__/KeyvMongo.js new file mode 100644 index 0000000000..f88bc144be --- /dev/null +++ b/api/test/__mocks__/KeyvMongo.js @@ -0,0 +1,30 @@ +const mockGet = jest.fn(); +const mockSet = jest.fn(); + +jest.mock('@keyv/mongo', () => { + const EventEmitter = require('events'); + class KeyvMongo extends EventEmitter { + constructor(url = 'mongodb://127.0.0.1:27017', options) { + super(); + this.ttlSupport = false; + url = url ?? {}; + if (typeof url === 'string') { + url = { url }; + } + if (url.uri) { + url = { url: url.uri, ...url }; + } + this.opts = { + url, + collection: 'keyv', + ...url, + ...options, + }; + } + + get = mockGet; + set = mockSet; + } + + return KeyvMongo; +}); diff --git a/client/package.json b/client/package.json index df34c5a264..a335e25e4e 100644 --- a/client/package.json +++ b/client/package.json @@ -9,6 +9,7 @@ "dev": "cross-env NODE_ENV=development dotenv -e ../.env -- vite", "preview-prod": "cross-env NODE_ENV=development dotenv -e ../.env -- vite preview", "test": "cross-env NODE_ENV=test jest --watch", + "b:test": "NODE_ENV=test bun jest --watch", "test:ci": "cross-env NODE_ENV=test jest --ci", "b:build": "NODE_ENV=production bun vite build", "b:dev": "NODE_ENV=development bun vite" diff --git a/client/src/common/types.ts b/client/src/common/types.ts index 63635d84b7..a2ab5c88db 100644 --- a/client/src/common/types.ts +++ b/client/src/common/types.ts @@ -1,4 +1,11 @@ -import type { TConversation, TMessage, TPreset, TMutation } from 'librechat-data-provider'; +import type { + TConversation, + TMessage, + TPreset, + TMutation, + TLoginUser, + TUser, +} from 'librechat-data-provider'; export type TSetOption = (param: number | string) => (newValue: number | string | boolean) => void; export type TSetExample = ( @@ -146,3 +153,28 @@ export type TDialogProps = { open: boolean; onOpenChange: (open: boolean) => void; }; + +export type TResError = { + response: { data: { message: string } }; + message: string; +}; + +export type TAuthContext = { + user: TUser | undefined; + token: string | undefined; + isAuthenticated: boolean; + error: string | undefined; + login: (data: TLoginUser) => void; + logout: () => void; +}; + +export type TUserContext = { + user?: TUser | undefined; + token: string | undefined; + isAuthenticated: boolean; + redirect?: string; +}; + +export type TAuthConfig = { + loginRedirect: string; +}; diff --git a/client/src/components/Auth/Login.tsx b/client/src/components/Auth/Login.tsx index 6e2616c24c..f75530b32c 100644 --- a/client/src/components/Auth/Login.tsx +++ b/client/src/components/Auth/Login.tsx @@ -5,6 +5,7 @@ import { useNavigate } from 'react-router-dom'; import { useLocalize } from '~/hooks'; import { useGetStartupConfig } from 'librechat-data-provider'; import { GoogleIcon, FacebookIcon, OpenIDIcon, GithubIcon, DiscordIcon } from '~/components'; +import { getLoginError } from '~/utils'; function Login() { const { login, error, isAuthenticated } = useAuthContext(); @@ -30,9 +31,7 @@ function Login() { className="relative mt-4 rounded border border-red-400 bg-red-100 px-4 py-3 text-red-700" role="alert" > - {error?.includes('429') - ? localize('com_auth_error_login_rl') - : localize('com_auth_error_login')} + {localize(getLoginError(error))} )} diff --git a/client/src/components/Messages/Content/MessageContent.tsx b/client/src/components/Messages/Content/MessageContent.tsx index 9da0a831ba..879dd60980 100644 --- a/client/src/components/Messages/Content/MessageContent.tsx +++ b/client/src/components/Messages/Content/MessageContent.tsx @@ -1,20 +1,28 @@ import { Fragment } from 'react'; import type { TResPlugin } from 'librechat-data-provider'; import type { TMessageContent, TText, TDisplayProps } from '~/common'; -import { cn, getError } from '~/utils'; +import { useAuthContext } from '~/hooks'; +import { cn, getMessageError } from '~/utils'; import EditMessage from './EditMessage'; import Container from './Container'; import Markdown from './Markdown'; import Plugin from './Plugin'; -// Error Message Component -const ErrorMessage = ({ text }: TText) => ( - -
- {getError(text)} -
-
-); +const ErrorMessage = ({ text }: TText) => { + const { logout } = useAuthContext(); + + if (text.includes('ban')) { + logout(); + return null; + } + return ( + +
+ {getMessageError(text)} +
+
+ ); +}; // Display Message Component const DisplayMessage = ({ text, isCreatedByUser, message, showCursor }: TDisplayProps) => ( diff --git a/client/src/hooks/AuthContext.tsx b/client/src/hooks/AuthContext.tsx index 3c6c8aff80..6f67df8d95 100644 --- a/client/src/hooks/AuthContext.tsx +++ b/client/src/hooks/AuthContext.tsx @@ -1,7 +1,7 @@ import { + useMemo, useState, useEffect, - useMemo, ReactNode, useCallback, createContext, @@ -17,33 +17,14 @@ import { useRefreshTokenMutation, TLoginUser, } from 'librechat-data-provider'; +import { TAuthConfig, TUserContext, TAuthContext, TResError } from '~/common'; import { useNavigate } from 'react-router-dom'; +import useTimeout from './useTimeout'; -export type TAuthContext = { - user: TUser | undefined; - token: string | undefined; - isAuthenticated: boolean; - error: string | undefined; - login: (data: TLoginUser) => void; - logout: () => void; -}; - -export type TUserContext = { - user?: TUser | undefined; - token: string | undefined; - isAuthenticated: boolean; - redirect?: string; -}; - -export type TAuthConfig = { - loginRedirect: string; -}; -//@ts-ignore - index expression is not of type number -window['errorTimeout'] = undefined; const AuthContext = createContext(undefined); const AuthContextProvider = ({ - authConfig, + // authConfig, children, }: { authConfig?: TAuthConfig; @@ -61,16 +42,7 @@ const AuthContextProvider = ({ const userQuery = useGetUserQuery({ enabled: !!token }); const refreshToken = useRefreshTokenMutation(); - // This seems to prevent the error flashing issue - const doSetError = (error: string | undefined) => { - if (error) { - console.log(error); - // set timeout to ensure we don't get a flash of the error message - window['errorTimeout'] = setTimeout(() => { - setError(error); - }, 400); - } - }; + const doSetError = useTimeout({ callback: (error) => setError(error as string | undefined) }); const setUserContext = useCallback( (userContext: TUserContext) => { @@ -89,19 +61,15 @@ const AuthContextProvider = ({ [navigate], ); - const getCookieValue = (key: string) => { - const keyValue = document.cookie.match('(^|;) ?' + key + '=([^;]*)(;|$)'); - return keyValue ? keyValue[2] : null; - }; - const login = (data: TLoginUser) => { loginUser.mutate(data, { onSuccess: (data: TLoginResponse) => { const { user, token } = data; setUserContext({ token, isAuthenticated: true, user, redirect: '/chat/new' }); }, - onError: (error) => { - doSetError((error as Error).message); + onError: (error: TResError | unknown) => { + const resError = error as TResError; + doSetError(resError.message); navigate('/login', { replace: true }); }, }); @@ -119,6 +87,12 @@ const AuthContextProvider = ({ }, onError: (error) => { doSetError((error as Error).message); + setUserContext({ + token: undefined, + isAuthenticated: false, + user: undefined, + redirect: '/login', + }); }, }); }, [setUserContext, logoutUser]); diff --git a/client/src/hooks/index.ts b/client/src/hooks/index.ts index 746f44cddd..25f2755a91 100644 --- a/client/src/hooks/index.ts +++ b/client/src/hooks/index.ts @@ -2,6 +2,7 @@ export * from './AuthContext'; export * from './ThemeContext'; export * from './ScreenshotContext'; export * from './ApiErrorBoundaryContext'; +export { default as useTimeout } from './useTimeout'; export { default as useUserKey } from './useUserKey'; export { default as useDebounce } from './useDebounce'; export { default as useLocalize } from './useLocalize'; diff --git a/client/src/hooks/useTimeout.tsx b/client/src/hooks/useTimeout.tsx new file mode 100644 index 0000000000..e058e9ca8b --- /dev/null +++ b/client/src/hooks/useTimeout.tsx @@ -0,0 +1,39 @@ +import { useEffect, useRef } from 'react'; + +type TUseTimeoutParams = { + callback: (error: string | number | boolean | null) => void; + delay?: number | undefined; +}; +type TTimeout = ReturnType | null; + +function useTimeout({ callback, delay = 400 }: TUseTimeoutParams) { + const timeout = useRef(null); + + const callOnTimeout = (value: string | undefined) => { + // Clear existing timeout + if (timeout.current !== null) { + clearTimeout(timeout.current); + } + + // Set new timeout + if (value) { + console.log(value); + timeout.current = setTimeout(() => { + callback(value); + }, delay); + } + }; + + // Clear timeout when the component unmounts + useEffect(() => { + return () => { + if (timeout.current !== null) { + clearTimeout(timeout.current); + } + }; + }, []); + + return callOnTimeout; +} + +export default useTimeout; diff --git a/client/src/localization/languages/Eng.tsx b/client/src/localization/languages/Eng.tsx index 937bcd676a..f8bd245317 100644 --- a/client/src/localization/languages/Eng.tsx +++ b/client/src/localization/languages/Eng.tsx @@ -52,7 +52,11 @@ export default { com_auth_error_login: 'Unable to login with the information provided. Please check your credentials and try again.', com_auth_error_login_rl: - 'Too many login attempts from this IP in a short amount of time. Please try again later.', + 'Too many login attempts in a short amount of time. Please try again later.', + com_auth_error_login_ban: + 'Your account has been temporarily banned due to violations of our service.', + com_auth_error_login_server: + 'There was an internal server error. Please wait a few moments and try again.', com_auth_no_account: 'Don\'t have an account?', com_auth_sign_up: 'Sign up', com_auth_sign_in: 'Sign in', diff --git a/client/src/utils/getError.ts b/client/src/utils/getError.ts deleted file mode 100644 index e41cc39519..0000000000 --- a/client/src/utils/getError.ts +++ /dev/null @@ -1,28 +0,0 @@ -const isJson = (str: string) => { - try { - JSON.parse(str); - } catch (e) { - return false; - } - return true; -}; - -const getError = (text: string) => { - const errorMessage = text.length > 512 ? text.slice(0, 512) + '...' : text; - const match = text.match(/\{[^{}]*\}/); - const jsonString = match ? match[0] : ''; - if (isJson(jsonString)) { - const json = JSON.parse(jsonString); - if (json.code === 'invalid_api_key') { - return 'Invalid API key. Please check your API key and try again. You can do this by clicking on the model logo in the left corner of the textbox and selecting "Set Token" for the current selected endpoint. Thank you for your understanding.'; - } else if (json.type === 'insufficient_quota') { - return 'We apologize for any inconvenience caused. The default API key has reached its limit. To continue using this service, please set up your own API key. You can do this by clicking on the model logo in the left corner of the textbox and selecting "Set Token" for the current selected endpoint. Thank you for your understanding.'; - } else { - return `Something went wrong. Here's the specific error message we encountered: ${errorMessage}`; - } - } else { - return `Something went wrong. Here's the specific error message we encountered: ${errorMessage}`; - } -}; - -export default getError; diff --git a/client/src/utils/getLoginError.ts b/client/src/utils/getLoginError.ts new file mode 100644 index 0000000000..6bd3c1ba8d --- /dev/null +++ b/client/src/utils/getLoginError.ts @@ -0,0 +1,18 @@ +const getLoginError = (errorText: string) => { + const defaultError = 'com_auth_error_login'; + if (!errorText) { + return defaultError; + } + + if (errorText?.includes('429')) { + return 'com_auth_error_login_rl'; + } else if (errorText?.includes('403')) { + return 'com_auth_error_login_ban'; + } else if (errorText?.includes('500')) { + return 'com_auth_error_login_server'; + } else { + return defaultError; + } +}; + +export default getLoginError; diff --git a/client/src/utils/getMessageError.ts b/client/src/utils/getMessageError.ts new file mode 100644 index 0000000000..4d2be10e49 --- /dev/null +++ b/client/src/utils/getMessageError.ts @@ -0,0 +1,62 @@ +const isJson = (str: string) => { + try { + JSON.parse(str); + } catch (e) { + return false; + } + return true; +}; + +type TConcurrent = { + limit: number; +}; + +type TMessageLimit = { + max: number; + windowInMinutes: number; +}; + +const errorMessages = { + ban: 'Your account has been temporarily banned due to violations of our service.', + invalid_api_key: + 'Invalid API key. Please check your API key and try again. You can do this by clicking on the model logo in the left corner of the textbox and selecting "Set Token" for the current selected endpoint. Thank you for your understanding.', + insufficient_quota: + 'We apologize for any inconvenience caused. The default API key has reached its limit. To continue using this service, please set up your own API key. You can do this by clicking on the model logo in the left corner of the textbox and selecting "Set Token" for the current selected endpoint. Thank you for your understanding.', + concurrent: (json: TConcurrent) => { + const { limit } = json; + const plural = limit > 1 ? 's' : ''; + return `Only ${limit} message${plural} at a time. Please allow any other responses to complete before sending another message, or wait one minute.`; + }, + message_limit: (json: TMessageLimit) => { + const { max, windowInMinutes } = json; + const plural = max > 1 ? 's' : ''; + return `You hit the message limit. You have a cap of ${max} message${plural} per ${ + windowInMinutes > 1 ? `${windowInMinutes} minutes` : 'minute' + }.`; + }, +}; + +const getMessageError = (text: string) => { + const errorMessage = text.length > 512 ? text.slice(0, 512) + '...' : text; + const match = text.match(/\{[^{}]*\}/); + const jsonString = match ? match[0] : ''; + const defaultResponse = `Something went wrong. Here's the specific error message we encountered: ${errorMessage}`; + + if (!isJson(jsonString)) { + return defaultResponse; + } + + const json = JSON.parse(jsonString); + const errorKey = json.code || json.type; + const keyExists = errorKey && errorMessages[errorKey]; + + if (keyExists && typeof errorMessages[errorKey] === 'function') { + return errorMessages[errorKey](json); + } else if (keyExists) { + return errorMessages[errorKey]; + } else { + return defaultResponse; + } +}; + +export default getMessageError; diff --git a/client/src/utils/index.ts b/client/src/utils/index.ts index dff114016a..b8a199b2f2 100644 --- a/client/src/utils/index.ts +++ b/client/src/utils/index.ts @@ -2,10 +2,11 @@ import { clsx } from 'clsx'; import { twMerge } from 'tailwind-merge'; export * from './languages'; -export { default as getError } from './getError'; export { default as buildTree } from './buildTree'; +export { default as getLoginError } from './getLoginError'; export { default as cleanupPreset } from './cleanupPreset'; export { default as validateIframe } from './validateIframe'; +export { default as getMessageError } from './getMessageError'; export { default as getLocalStorageItems } from './getLocalStorageItems'; export { default as getDefaultConversation } from './getDefaultConversation'; diff --git a/docs/features/mod_system.md b/docs/features/mod_system.md new file mode 100644 index 0000000000..107c61cd35 --- /dev/null +++ b/docs/features/mod_system.md @@ -0,0 +1,67 @@ +## Automated Moderation System (optional) +The Automated Moderation System uses a scoring mechanism to track user violations. As users commit actions like excessive logins, registrations, or messaging, they accumulate violation scores. Upon reaching a set threshold, the user and their IP are temporarily banned. This system ensures platform security by monitoring and penalizing rapid or suspicious activities. + +In production, you should have Cloudflare or some other DDoS protection in place to really protect the server from excessive requests, but these changes will largely protect you from the single or several bad actors targeting your deployed instance for proxying. + +### Notes + +- Uses Caching for basic security and violation logging (bans, concurrent messages, exceeding rate limits) + - In the near future, I will add **Redis** support for production instances, which can be easily injected into the current caching setup +- Exceeding any of the rate limiters (login/registration/messaging) is considered a violation, default score is 1 +- Non-browser origin is a violation +- Default score for each violation is configurable +- Enabling any of the limiters and/or bans enables caching/logging +- Violation logs can be found in the data folder, which is created when logging begins: `librechat/data` + - **Only violations are logged** + - `violations.json` keeps track of the total count for each violation per user + - `logs.json` records each individual violation per user +- Ban logs are stored in MongoDB under the `logs` collection. They are transient as they only exist for the ban duration + - If you would like to remove a ban manually, you would have to remove them from the database manually and restart the server + - **Redis** support is also planned for this. + +### Rate Limiters + +The project's current rate limiters are as follows (see below under setup for default values): + +- Login and registration rate limiting +- [optional] Concurrent Message limiting (only X messages at a time per user) +- [optional] Message limiting (how often a user can send a message, configurable by IP and User) + +### Setup + +The following are all of the related env variables to make use of and configure the mod system. Note this is also found in the [/.env.example](/.env.example) file, to be set in your own `.env` file. + +```bash +BAN_VIOLATIONS=true # Whether or not to enable banning users for violations (they will still be logged) +BAN_DURATION=1000 * 60 * 60 * 2 # how long the user and associated IP are banned for +BAN_INTERVAL=20 # a user will be banned everytime their score reaches/crosses over the interval threshold + +# The score for each violation + +LOGIN_VIOLATION_SCORE=1 +REGISTRATION_VIOLATION_SCORE=1 +CONCURRENT_VIOLATION_SCORE=1 +MESSAGE_VIOLATION_SCORE=1 +NON_BROWSER_VIOLATION_SCORE=20 + +# Login and registration rate limiting. + +LOGIN_MAX=7 # The max amount of logins allowed per IP per LOGIN_WINDOW +LOGIN_WINDOW=5 # in minutes, determines the window of time for LOGIN_MAX logins +REGISTER_MAX=5 # The max amount of registrations allowed per IP per REGISTER_WINDOW +REGISTER_WINDOW=60 # in minutes, determines the window of time for REGISTER_MAX registrations + +# Message rate limiting (per user & IP) + +LIMIT_CONCURRENT_MESSAGES=true # Whether to limit the amount of messages a user can send per request +CONCURRENT_MESSAGE_MAX=2 # The max amount of messages a user can send per request + +LIMIT_MESSAGE_IP=true # Whether to limit the amount of messages an IP can send per MESSAGE_IP_WINDOW +MESSAGE_IP_MAX=40 # The max amount of messages an IP can send per MESSAGE_IP_WINDOW +MESSAGE_IP_WINDOW=1 # in minutes, determines the window of time for MESSAGE_IP_MAX messages + +# Note: You can utilize both limiters, but default is to limit by IP only. +LIMIT_MESSAGE_USER=false # Whether to limit the amount of messages an IP can send per MESSAGE_USER_WINDOW +MESSAGE_USER_MAX=40 # The max amount of messages an IP can send per MESSAGE_USER_WINDOW +MESSAGE_USER_WINDOW=1 # in minutes, determines the window of time for MESSAGE_USER_MAX messages +``` \ No newline at end of file diff --git a/docs/features/third-party.md b/docs/features/third_party.md similarity index 100% rename from docs/features/third-party.md rename to docs/features/third_party.md diff --git a/docs/install/user_auth_system.md b/docs/install/user_auth_system.md index 92ad4d2b11..b605f44de2 100644 --- a/docs/install/user_auth_system.md +++ b/docs/install/user_auth_system.md @@ -9,15 +9,29 @@ In order for the auth system to function properly, there are some environment va In /.env, you will need to set the following variables: ```bash -# Change this to a secure string +# Change the secrets to a secure, random string JWT_SECRET=secret +JWT_REFRESH_SECRET=refresh_secret + # Set the expiration delay for the secure cookie with the JWT token -# Delay is in millisecond e.g. 7 days is 1000*60*60*24*7 -SESSION_EXPIRY=1000 * 60 * 60 * 24 * 7 +# Delay is in milliseconds e.g. 7 days is 1000*60*60*24*7 + +# Recommended session expiry is 15 minutes. Make it longer if you want the user to be able to revist the page without logging in for a longer duration of time. + +# Recommended refresh token expiry is 7 days +SESSION_EXPIRY=1000 * 60 * 15 +REFRESH_TOKEN_EXPIRY=(1000 * 60 * 60 * 24) * 7 + DOMAIN_SERVER=http://localhost:3080 DOMAIN_CLIENT=http://localhost:3080 ``` +## Automated Moderation System (optional) + +The Automated Moderation System uses a scoring mechanism to track user violations. As users commit actions like excessive logins, registrations, or messaging, they accumulate violation scores. Upon reaching a set threshold, the user and their IP are temporarily banned. This system ensures platform security by monitoring and penalizing rapid or suspicious activities. + +To set up the mod system, review [the setup guide](../features/mod_system.md). + *Please Note: If you are wanting this to work in development mode, you will need to create a file called `.env.development` in the root directory and set `DOMAIN_CLIENT` to `http://localhost:3090` or whatever port is provided by vite when runnning `npm run frontend-dev`* Important: When you run the app for the first time, you need to create a new account by clicking on "Sign up" on the login page. The first account you make will be the admin account. The admin account doesn't have any special features right now, but it might be useful if you want to make an admin dashboard to manage other users later. diff --git a/mkdocs.yml b/mkdocs.yml index b036ecbae4..b10915e604 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -102,7 +102,8 @@ nav: - Azure Cognitive Search: 'features/plugins/azure_cognitive_search.md' - Make Your Own Plugin: 'features/plugins/make_your_own.md' - Using official ChatGPT Plugins: 'features/plugins/chatgpt_plugins_openapi.md' - - Third-Party Tools: 'features/third-party.md' + - Automated Moderation: 'features/mod_system.md' + - Third-Party Tools: 'features/third_party.md' - Proxy: 'features/proxy.md' - Bing Jailbreak: 'features/bing_jailbreak.md' - Cloud Deployment: diff --git a/package-lock.json b/package-lock.json index 849219d91e..4336e74460 100644 --- a/package-lock.json +++ b/package-lock.json @@ -72,7 +72,7 @@ "joi": "^17.9.2", "js-yaml": "^4.1.0", "jsonwebtoken": "^9.0.0", - "keyv": "^4.5.2", + "keyv": "^4.5.3", "keyv-file": "^0.2.0", "langchain": "^0.0.144", "lodash": "^4.17.21", @@ -91,6 +91,7 @@ "pino": "^8.12.1", "sanitize": "^2.1.2", "sharp": "^0.32.5", + "ua-parser-js": "^1.0.36", "zod": "^3.22.2" }, "devDependencies": { @@ -25069,6 +25070,28 @@ "node": ">=14.17" } }, + "node_modules/ua-parser-js": { + "version": "1.0.36", + "resolved": "https://registry.npmjs.org/ua-parser-js/-/ua-parser-js-1.0.36.tgz", + "integrity": "sha512-znuyCIXzl8ciS3+y3fHJI/2OhQIXbXw9MWC/o3qwyR+RGppjZHrM27CGFSKCJXi2Kctiz537iOu2KnXs1lMQhw==", + "funding": [ + { + "type": "opencollective", + "url": "https://opencollective.com/ua-parser-js" + }, + { + "type": "paypal", + "url": "https://paypal.me/faisalman" + }, + { + "type": "github", + "url": "https://github.com/sponsors/faisalman" + } + ], + "engines": { + "node": "*" + } + }, "node_modules/uglify-js": { "version": "3.17.4", "resolved": "https://registry.npmjs.org/uglify-js/-/uglify-js-3.17.4.tgz", diff --git a/package.json b/package.json index bff7ba80ed..22397f172c 100644 --- a/package.json +++ b/package.json @@ -49,8 +49,8 @@ "b:data-provider": "cd packages/data-provider && bun run b:build", "b:client": "bun run b:data-provider && cd client && bun run b:build", "b:client:dev": "cd client && bun run b:dev", - "b:test:client": "cd client && bun run test", - "b:test:api": "cd api && bun run test" + "b:test:client": "cd client && bun run b:test", + "b:test:api": "cd api && bun run b:test" }, "repository": { "type": "git", @@ -92,7 +92,7 @@ "nodemonConfig": { "ignore": [ "api/data/", - "data", + "data/", "client/", "admin/", "packages/" diff --git a/packages/data-provider/package.json b/packages/data-provider/package.json index aa69af2e26..5db9e26d35 100644 --- a/packages/data-provider/package.json +++ b/packages/data-provider/package.json @@ -1,6 +1,6 @@ { "name": "librechat-data-provider", - "version": "0.1.7", + "version": "0.1.8", "description": "data services for librechat apps", "main": "dist/index.js", "module": "dist/index.es.js",