diff --git a/packages/api/src/agents/__tests__/initialize.test.ts b/packages/api/src/agents/__tests__/initialize.test.ts index 128ff76a00..adb6689bc5 100644 --- a/packages/api/src/agents/__tests__/initialize.test.ts +++ b/packages/api/src/agents/__tests__/initialize.test.ts @@ -338,6 +338,29 @@ describe('initializeAgent — provider web_search precedence', () => { jest.clearAllMocks(); }); + async function initializeGoogleMixedToolAgent(model: string, provider = Providers.GOOGLE) { + const { agent, req, res, loadTools, db } = createMocks({ + provider, + model, + providerTools: [nativeGoogleSearchTool], + loadedToolDefinitions: [mcpToolDefinition], + }); + agent.tools = ['mcp_lookup']; + + return initializeAgent( + { + req, + res, + agent, + loadTools, + endpointOption: { endpoint: EModelEndpoint.agents }, + allowedProviders: new Set([provider]), + isInitialAgent: true, + }, + db, + ); + } + it('keeps Anthropic native web_search when LibreChat search is not selected', async () => { const { agent, req, res, loadTools, db } = createMocks({ provider: Providers.ANTHROPIC, @@ -438,31 +461,126 @@ describe('initializeAgent — provider web_search precedence', () => { expect(result.tools).toEqual([nativeGoogleSearchTool]); expect(countGoogleSearchTools(result.tools)).toBe(1); expect(result.toolDefinitions).toContain(mcpToolDefinition); + expect(result.model_parameters).toEqual( + expect.objectContaining({ + includeServerSideToolInvocations: true, + }), + ); }); - it('rejects Google native search with external tools for unsupported Gemini models', async () => { + it('includes the mixed-tool flag for Vertex AI native search with external tools', async () => { const { agent, req, res, loadTools, db } = createMocks({ - provider: Providers.GOOGLE, - model: 'gemini-2.5-flash', + provider: Providers.VERTEXAI, + model: 'gemini-3.5-flash', providerTools: [nativeGoogleSearchTool], loadedToolDefinitions: [mcpToolDefinition], }); agent.tools = ['mcp_lookup']; - await expect( - initializeAgent( - { - req, - res, - agent, - loadTools, - endpointOption: { endpoint: EModelEndpoint.agents }, - allowedProviders: new Set([Providers.GOOGLE]), - isInitialAgent: true, - }, - db, - ), - ).rejects.toThrow(/google_tool_conflict/); + const result = await initializeAgent( + { + req, + res, + agent, + loadTools, + endpointOption: { endpoint: EModelEndpoint.agents }, + allowedProviders: new Set([Providers.VERTEXAI]), + isInitialAgent: true, + }, + db, + ); + + expect(result.tools).toEqual([nativeGoogleSearchTool]); + expect(result.toolDefinitions).toContain(mcpToolDefinition); + expect(result.model_parameters).toEqual( + expect.objectContaining({ + includeServerSideToolInvocations: true, + }), + ); + }); + + it.each([ + 'gemini-3-flash-preview', + 'gemini-3-pro-preview', + 'gemini-3.1-pro-preview', + 'gemini-3.1-pro-preview-customtools', + 'gemini-3.1-flash-lite', + 'gemini-3.1-flash-lite-preview', + 'gemini-3.5-flash', + 'google/gemini-3.5-flash-latest', + 'models/gemini-3.10-pro-preview', + 'gemini-4-pro-preview', + ])('allows Google mixed tools for supported Gemini text model %s', async (model) => { + const result = await initializeGoogleMixedToolAgent(model); + + expect(result.tools).toEqual([nativeGoogleSearchTool]); + expect(result.toolDefinitions).toContain(mcpToolDefinition); + expect(result.model_parameters).toEqual( + expect.objectContaining({ + includeServerSideToolInvocations: true, + }), + ); + }); + + it('sets the mixed-tool flag when the skill catalog adds the external tool', async () => { + const { agent, req, res, loadTools, db } = createMocks({ + provider: Providers.GOOGLE, + model: 'gemini-3.5-flash', + providerTools: [nativeGoogleSearchTool], + }); + const { Types } = await import('mongoose'); + const skillId = new Types.ObjectId(); + const author = { + toString: () => req.user?.id, + } as unknown as import('mongoose').Types.ObjectId; + + const result = await initializeAgent( + { + req, + res, + agent, + loadTools, + endpointOption: { endpoint: EModelEndpoint.agents }, + allowedProviders: new Set([Providers.GOOGLE]), + isInitialAgent: true, + accessibleSkillIds: [skillId], + }, + { + ...db, + listSkillsByAccess: jest.fn().mockResolvedValue({ + skills: [ + { + _id: skillId, + name: 'research-helper', + description: 'Research current information.', + author, + }, + ], + has_more: false, + after: null, + }), + }, + ); + + expect(result.tools).toEqual([nativeGoogleSearchTool]); + expect(result.toolDefinitions?.map((toolDefinition) => toolDefinition.name)).toContain('skill'); + expect(result.model_parameters).toEqual( + expect.objectContaining({ + includeServerSideToolInvocations: true, + }), + ); + }); + + it.each([ + 'gemini-2.5-flash', + 'gemini-3', + 'gemini-3.1', + 'gemini-3-pro-image-preview', + 'gemini-3.1-flash-image', + 'gemini-3.5-flash-live', + 'gemini-4-pro-tts', + ])('rejects Google mixed tools for unsupported Gemini model %s', async (model) => { + await expect(initializeGoogleMixedToolAgent(model)).rejects.toThrow(/google_tool_conflict/); }); it('prefers LibreChat web_search when Google native search is also enabled', async () => { @@ -1681,6 +1799,9 @@ describe('initializeAgent — execute_code capability expansion', () => { ).resolves.toEqual( expect.objectContaining({ tools: [{ googleSearch: {} }], + model_parameters: expect.objectContaining({ + includeServerSideToolInvocations: true, + }), toolDefinitions: expect.arrayContaining([ expect.objectContaining({ name: 'bash_tool' }), expect.objectContaining({ name: 'read_file' }), @@ -1722,6 +1843,9 @@ describe('initializeAgent — execute_code capability expansion', () => { ).resolves.toEqual( expect.objectContaining({ tools: [structuredTool, providerTool], + model_parameters: expect.objectContaining({ + includeServerSideToolInvocations: true, + }), }), ); }); diff --git a/packages/api/src/agents/initialize.ts b/packages/api/src/agents/initialize.ts index 36cdb5e53f..122e5c3462 100644 --- a/packages/api/src/agents/initialize.ts +++ b/packages/api/src/agents/initialize.ts @@ -24,15 +24,8 @@ import type { GenericTool, LCToolRegistry, ToolMap, LCTool } from '@librechat/ag import type { Response as ServerResponse } from 'express'; import type { IMongoFile } from '@librechat/data-schemas'; import type { InitializeResultBase, ServerRequest, EndpointDbMethods } from '~/types'; -import { - optionalChainWithEmptyCheck, - extractLibreChatParams, - getModelMaxTokens, - getThreadData, -} from '~/utils'; -import { filterFilesByEndpointConfig } from '~/files'; -import { generateArtifactsPrompt } from '~/prompts'; -import { getProviderConfig } from '~/endpoints'; +import type { ResolvedManualSkill, ResolvedAlwaysApplySkill } from './skills'; +import type { TFilterFilesByAgentAccess } from './resources'; import { injectSkillCatalog, resolveManualSkills, @@ -40,14 +33,21 @@ import { unionPrimeAllowedTools, MAX_PRIMED_SKILLS_PER_TURN, } from './skills'; +import { + optionalChainWithEmptyCheck, + extractLibreChatParams, + getModelMaxTokens, + getThreadData, +} from '~/utils'; import { registerCodeExecutionTools, registerFileAuthoringTools, isFileAuthoringToolDefinition, } from './tools'; +import { filterFilesByEndpointConfig } from '~/files'; +import { generateArtifactsPrompt } from '~/prompts'; +import { getProviderConfig } from '~/endpoints'; import { primeResources } from './resources'; -import type { ResolvedManualSkill, ResolvedAlwaysApplySkill } from './skills'; -import type { TFilterFilesByAgentAccess } from './resources'; /** * Fraction of context budget reserved as headroom when no explicit maxContextTokens is set. @@ -56,6 +56,15 @@ import type { TFilterFilesByAgentAccess } from './resources'; */ const DEFAULT_RESERVE_RATIO = 0.05; const temporalSpecialVarRegex = /{{\s*(current_date|current_datetime|iso_datetime)\s*}}/i; +const geminiModelVersionRegex = /^gemini-(\d+)(?:\.(\d+))?(?:-|$)/; +const googleToolCombinationTextModels = [ + 'gemini-3-flash-preview', + 'gemini-3-pro-preview', + 'gemini-3.1-flash-lite', + 'gemini-3.1-pro-preview', +]; +const googleToolCombinationExcludedModalityRegex = + /(?:^|-)image(?:-|$)|(?:^|-)live(?:-|$)|(?:^|-)tts(?:-|$)/; function hasTemporalSpecialVars(text: string): boolean { return temporalSpecialVarRegex.test(text); @@ -96,12 +105,71 @@ function hasGoogleSearchTool(tool: unknown): boolean { return 'googleSearch' in tool || 'googleSearchRetrieval' in tool; } +function normalizeGoogleModelName(model: string): string { + const normalized = model.trim().toLowerCase(); + return normalized.split('/').pop() ?? normalized; +} + +function isKnownGoogleToolCombinationTextModel(model: string): boolean { + return googleToolCombinationTextModels.some( + (knownModel) => model === knownModel || model.startsWith(`${knownModel}-`), + ); +} + +function isGemini35OrLater(model: string): boolean { + const match = geminiModelVersionRegex.exec(model); + if (!match) { + return false; + } + const major = Number(match[1]); + const minor = Number(match[2] ?? '0'); + return major > 3 || (major === 3 && minor >= 5); +} + function supportsGoogleToolCombination(model: unknown): boolean { if (typeof model !== 'string') { return false; } - const normalized = model.toLowerCase().split('/').pop() ?? model.toLowerCase(); - return normalized.startsWith('gemini-3'); + const normalized = normalizeGoogleModelName(model); + if (googleToolCombinationExcludedModalityRegex.test(normalized)) { + return false; + } + return isKnownGoogleToolCombinationTextModel(normalized) || isGemini35OrLater(normalized); +} + +function isGoogleToolCombinationProvider(provider?: string): boolean { + return provider === Providers.GOOGLE || provider === Providers.VERTEXAI; +} + +function shouldIncludeGoogleServerSideToolInvocations({ + provider, + hasProviderTools, + hasAgentTools, +}: { + provider?: string; + hasProviderTools: boolean; + hasAgentTools: boolean; +}): boolean { + return isGoogleToolCombinationProvider(provider) && hasProviderTools && hasAgentTools; +} + +function assertGoogleToolCombinationSupport(model: unknown): void { + if (!supportsGoogleToolCombination(model)) { + throw new Error(`{ "type": "${ErrorTypes.GOOGLE_TOOL_CONFLICT}"}`); + } +} + +function enableGoogleServerSideToolInvocations({ + agent, + llmConfig, +}: { + agent: Agent; + llmConfig: Record; +}): void { + llmConfig.includeServerSideToolInvocations = true; + if (agent.model_parameters) { + (agent.model_parameters as Record).includeServerSideToolInvocations = true; + } } function resolveProviderToolConflicts({ @@ -965,13 +1033,16 @@ export async function initializeAgent( ? (providerTools as GenericTool[]) : (structuredTools ?? []); - if ( - (agent.provider === Providers.GOOGLE || agent.provider === Providers.VERTEXAI) && - hasProviderTools && - hasAgentTools - ) { - if (!supportsGoogleToolCombination(llmConfig.model)) { - throw new Error(`{ "type": "${ErrorTypes.GOOGLE_TOOL_CONFLICT}"}`); + if (isGoogleToolCombinationProvider(agent.provider) && hasProviderTools && hasAgentTools) { + assertGoogleToolCombinationSupport(llmConfig.model); + if ( + shouldIncludeGoogleServerSideToolInvocations({ + provider: agent.provider, + hasProviderTools, + hasAgentTools, + }) + ) { + enableGoogleServerSideToolInvocations({ agent, llmConfig }); } if (structuredTools?.length) { tools = structuredTools.concat(providerTools as GenericTool[]); @@ -1043,6 +1114,21 @@ export async function initializeAgent( activeSkillNames = skillResult.activeSkillNames; } + const hasFinalAgentTools = + (structuredTools?.length ?? 0) > 0 || (toolDefinitions?.length ?? 0) > 0; + if (isGoogleToolCombinationProvider(agent.provider) && hasProviderTools && hasFinalAgentTools) { + assertGoogleToolCombinationSupport(llmConfig.model); + if ( + shouldIncludeGoogleServerSideToolInvocations({ + provider: agent.provider, + hasProviderTools, + hasAgentTools: hasFinalAgentTools, + }) + ) { + enableGoogleServerSideToolInvocations({ agent, llmConfig }); + } + } + const agentMaxContextNum = Number(agentMaxContextTokens) || DEFAULT_MAX_CONTEXT_TOKENS; const maxOutputTokensNum = Number(maxOutputTokens) || 0; const baseContextTokens = Math.max(0, agentMaxContextNum - maxOutputTokensNum); diff --git a/packages/api/src/endpoints/google/initialize.spec.ts b/packages/api/src/endpoints/google/initialize.spec.ts new file mode 100644 index 0000000000..50de1cac26 --- /dev/null +++ b/packages/api/src/endpoints/google/initialize.spec.ts @@ -0,0 +1,130 @@ +import { Providers } from '@librechat/agents'; +import { AuthKeys, EModelEndpoint } from 'librechat-data-provider'; +import type { EndpointDbMethods, ServerRequest } from '~/types'; + +const mockGetGoogleConfig = jest.fn( + (_credentials?: unknown, _options?: unknown, _acceptRawApiKey?: unknown) => ({ + provider: Providers.VERTEXAI, + llmConfig: { model: 'gemini-2.5-flash' }, + }), +); +const mockIsEnabled = jest.fn(); +const mockLoadServiceKey = jest.fn(); +const mockCheckUserKeyExpiry = jest.fn(); + +jest.mock('./llm', () => ({ + getGoogleConfig: (credentials: unknown, options: unknown, acceptRawApiKey?: unknown) => + mockGetGoogleConfig(credentials, options, acceptRawApiKey), +})); + +jest.mock('~/utils', () => ({ + isEnabled: (value: unknown) => mockIsEnabled(value), + loadServiceKey: (keyPath: unknown) => mockLoadServiceKey(keyPath), + checkUserKeyExpiry: (expiresAt: unknown, endpoint: unknown) => + mockCheckUserKeyExpiry(expiresAt, endpoint), +})); + +import { initializeGoogle } from './initialize'; + +function createDb(): EndpointDbMethods { + return { + getUserKey: jest.fn().mockResolvedValue('user-google-key'), + getUserKeyValues: jest.fn().mockResolvedValue({}), + }; +} + +function createReq(): ServerRequest { + return { + body: {}, + config: {}, + user: { id: 'user-1' }, + } as ServerRequest; +} + +function getGoogleConfigCall(): [Record, Record] { + expect(mockGetGoogleConfig).toHaveBeenCalled(); + return mockGetGoogleConfig.mock.calls[0] as unknown as [ + Record, + Record, + ]; +} + +describe('initializeGoogle', () => { + const originalEnv = process.env; + + beforeEach(() => { + jest.clearAllMocks(); + process.env = { ...originalEnv }; + delete process.env.GOOGLE_KEY; + delete process.env.GOOGLE_REVERSE_PROXY; + delete process.env.GOOGLE_AUTH_HEADER; + delete process.env.GOOGLE_SERVICE_KEY_FILE; + delete process.env.VERTEX_PROJECT_ID; + delete process.env.GOOGLE_CLOUD_PROJECT; + delete process.env.GCLOUD_PROJECT; + delete process.env.GOOGLE_PROJECT_ID; + }); + + afterAll(() => { + process.env = originalEnv; + }); + + it('forces Vertex AI ADC config and ignores GOOGLE_KEY for vertexai endpoint', async () => { + process.env.GOOGLE_KEY = 'test-api-key'; + process.env.VERTEX_PROJECT_ID = 'fiery-catwalk-385918'; + mockLoadServiceKey.mockResolvedValue(null); + + const db = createDb(); + + await initializeGoogle({ + req: createReq(), + endpoint: Providers.VERTEXAI, + model_parameters: { model: 'gemini-2.5-flash' }, + db, + }); + + expect(mockLoadServiceKey).toHaveBeenCalledTimes(1); + expect(db.getUserKey).not.toHaveBeenCalled(); + expect(mockCheckUserKeyExpiry).not.toHaveBeenCalled(); + + const [credentials, options] = getGoogleConfigCall(); + expect(credentials).toEqual({ + [AuthKeys.GOOGLE_SERVICE_KEY]: {}, + }); + expect(credentials).not.toHaveProperty(AuthKeys.GOOGLE_API_KEY); + expect(options).toEqual( + expect.objectContaining({ + forceVertex: true, + projectId: 'fiery-catwalk-385918', + modelOptions: { model: 'gemini-2.5-flash' }, + }), + ); + }); + + it('keeps Google API-key config for the google endpoint', async () => { + process.env.GOOGLE_KEY = 'test-api-key'; + process.env.VERTEX_PROJECT_ID = 'fiery-catwalk-385918'; + + await initializeGoogle({ + req: createReq(), + endpoint: EModelEndpoint.google, + model_parameters: { model: 'gemini-2.5-flash' }, + db: createDb(), + }); + + expect(mockLoadServiceKey).not.toHaveBeenCalled(); + + const [credentials, options] = getGoogleConfigCall(); + expect(credentials).toEqual({ + [AuthKeys.GOOGLE_SERVICE_KEY]: {}, + [AuthKeys.GOOGLE_API_KEY]: 'test-api-key', + }); + expect(options).toEqual( + expect.objectContaining({ + forceVertex: false, + projectId: undefined, + modelOptions: { model: 'gemini-2.5-flash' }, + }), + ); + }); +}); diff --git a/packages/api/src/endpoints/google/initialize.ts b/packages/api/src/endpoints/google/initialize.ts index dcf750bb87..1050a974eb 100644 --- a/packages/api/src/endpoints/google/initialize.ts +++ b/packages/api/src/endpoints/google/initialize.ts @@ -1,4 +1,5 @@ import path from 'path'; +import { Providers } from '@librechat/agents'; import { EModelEndpoint, AuthKeys } from 'librechat-data-provider'; import type { BaseInitializeParams, @@ -23,14 +24,15 @@ export async function initializeGoogle({ model_parameters, db, }: BaseInitializeParams): Promise { - void endpoint; const appConfig = req.config; const { GOOGLE_KEY, GOOGLE_REVERSE_PROXY, GOOGLE_AUTH_HEADER, PROXY } = process.env; const isUserProvided = GOOGLE_KEY === 'user_provided'; + const isVertexEndpoint = endpoint === Providers.VERTEXAI; + const useUserProvidedGoogleKey = !isVertexEndpoint && isUserProvided; const { key: expiresAt } = req.body; let userKey = null; - if (expiresAt && isUserProvided) { + if (expiresAt && useUserProvidedGoogleKey) { checkUserKeyExpiry(expiresAt, EModelEndpoint.google); userKey = await db.getUserKey({ userId: req.user?.id ?? '', name: EModelEndpoint.google }); } @@ -39,9 +41,10 @@ export async function initializeGoogle({ /** Check if GOOGLE_KEY is provided at all (including 'user_provided') */ const isGoogleKeyProvided = - (GOOGLE_KEY && GOOGLE_KEY.trim() !== '') || (isUserProvided && userKey != null); + !isVertexEndpoint && + ((GOOGLE_KEY && GOOGLE_KEY.trim() !== '') || (useUserProvidedGoogleKey && userKey != null)); - if (!isGoogleKeyProvided && loadServiceKey) { + if ((isVertexEndpoint || !isGoogleKeyProvided) && loadServiceKey) { /** Only attempt to load service key if GOOGLE_KEY is not provided */ try { const serviceKeyPath = @@ -56,11 +59,11 @@ export async function initializeGoogle({ } } - const credentials: GoogleCredentials = isUserProvided + const credentials: GoogleCredentials = useUserProvidedGoogleKey ? (userKey as GoogleCredentials) : { [AuthKeys.GOOGLE_SERVICE_KEY]: serviceKey, - [AuthKeys.GOOGLE_API_KEY]: GOOGLE_KEY, + ...(!isVertexEndpoint && { [AuthKeys.GOOGLE_API_KEY]: GOOGLE_KEY }), }; let clientOptions: GoogleConfigOptions = {}; @@ -84,6 +87,13 @@ export async function initializeGoogle({ authHeader: isEnabled(GOOGLE_AUTH_HEADER) ?? undefined, proxy: PROXY ?? undefined, modelOptions: model_parameters ?? {}, + forceVertex: isVertexEndpoint, + projectId: isVertexEndpoint + ? (process.env.VERTEX_PROJECT_ID ?? + process.env.GOOGLE_CLOUD_PROJECT ?? + process.env.GCLOUD_PROJECT ?? + process.env.GOOGLE_PROJECT_ID) + : undefined, ...clientOptions, }; diff --git a/packages/api/src/endpoints/google/llm.spec.ts b/packages/api/src/endpoints/google/llm.spec.ts index 0a65ab7a80..c8f6cffd0e 100644 --- a/packages/api/src/endpoints/google/llm.spec.ts +++ b/packages/api/src/endpoints/google/llm.spec.ts @@ -64,6 +64,23 @@ describe('getGoogleConfig', () => { expect(result.llmConfig).toHaveProperty('apiKey', 'raw-api-key-string'); }); + it('should not let project id force Vertex AI without the force flag', () => { + const credentials = { + [AuthKeys.GOOGLE_API_KEY]: 'test-api-key', + }; + + const result = getGoogleConfig(credentials, { + projectId: 'fiery-catwalk-385918', + modelOptions: { + model: 'gemini-2.5-flash', + }, + }); + + expect(result.provider).toBe(Providers.GOOGLE); + expect(result.llmConfig).toHaveProperty('apiKey', 'test-api-key'); + expect(result.llmConfig).not.toHaveProperty('authOptions'); + }); + it('should handle model options including temperature and topP/topK', () => { const credentials = { [AuthKeys.GOOGLE_API_KEY]: 'test-api-key', @@ -318,6 +335,53 @@ describe('getGoogleConfig', () => { expect(result.llmConfig).toHaveProperty('location', 'us-central1'); }); + it('should force Vertex AI ADC config with a project id even when an API key is present', () => { + const credentials = { + [AuthKeys.GOOGLE_API_KEY]: 'test-api-key', + }; + + const result = getGoogleConfig(credentials, { + forceVertex: true, + projectId: 'fiery-catwalk-385918', + modelOptions: { + model: 'gemini-2.5-flash', + }, + }); + + expect(result.provider).toBe(Providers.VERTEXAI); + expect(result.llmConfig).not.toHaveProperty('apiKey'); + expect((result.llmConfig as Record).authOptions).toEqual({ + projectId: 'fiery-catwalk-385918', + }); + }); + + it('should force Vertex AI service-account config when an API key is also present', () => { + const credentials = { + [AuthKeys.GOOGLE_API_KEY]: 'test-api-key', + [AuthKeys.GOOGLE_SERVICE_KEY]: { + project_id: 'test-project', + client_email: 'test@test-project.iam.gserviceaccount.com', + private_key: 'test-private-key', + }, + }; + + const result = getGoogleConfig(credentials, { + forceVertex: true, + modelOptions: { + model: 'gemini-2.5-flash', + }, + }); + + expect(result.provider).toBe(Providers.VERTEXAI); + expect(result.llmConfig).not.toHaveProperty('apiKey'); + expect((result.llmConfig as Record).authOptions).toMatchObject({ + projectId: 'test-project', + credentials: expect.objectContaining({ + project_id: 'test-project', + }), + }); + }); + it('should use GOOGLE_LOC env variable for Vertex AI location', () => { process.env.GOOGLE_LOC = 'europe-west1'; diff --git a/packages/api/src/endpoints/google/llm.ts b/packages/api/src/endpoints/google/llm.ts index 98802dcd15..7ba91b9c9d 100644 --- a/packages/api/src/endpoints/google/llm.ts +++ b/packages/api/src/endpoints/google/llm.ts @@ -264,6 +264,10 @@ function applyVertexMultiRegionEndpoint(config: VertexAIClientOptions & { endpoi } } +function hasServiceKeyCredentials(serviceKey: Record): boolean { + return Object.keys(serviceKey).length > 0; +} + export function getSafetySettings( model?: string, ): Array<{ category: string; threshold: string }> | undefined { @@ -337,7 +341,12 @@ export function getGoogleConfig( typeof serviceKeyRaw === 'string' ? JSON.parse(serviceKeyRaw) : (serviceKeyRaw ?? {}); const apiKey = creds[AuthKeys.GOOGLE_API_KEY] ?? null; - const project_id = !apiKey ? (serviceKey?.project_id ?? null) : null; + let project_id = null; + if (options.forceVertex === true) { + project_id = options.projectId ?? serviceKey?.project_id ?? null; + } else if (!apiKey) { + project_id = serviceKey?.project_id ?? null; + } const reverseProxyUrl = options.reverseProxyUrl; const authHeader = options.authHeader; @@ -373,7 +382,7 @@ export function getGoogleConfig( let provider; - if (project_id) { + if (options.forceVertex === true || project_id) { provider = Providers.VERTEXAI; } else { provider = Providers.GOOGLE; @@ -381,10 +390,13 @@ export function getGoogleConfig( // If we have a GCP project => Vertex AI if (provider === Providers.VERTEXAI) { - (llmConfig as VertexAIClientOptions).authOptions = { - credentials: { ...serviceKey }, - projectId: project_id, - }; + (llmConfig as VertexAIClientOptions).authOptions = removeNullishValues( + { + ...(hasServiceKeyCredentials(serviceKey) && { credentials: { ...serviceKey } }), + projectId: project_id, + }, + true, + ); const location = process.env.GOOGLE_LOC || 'us-central1'; (llmConfig as VertexAIClientOptions).location = location; } else if (apiKey && provider === Providers.GOOGLE) { diff --git a/packages/api/src/types/google.ts b/packages/api/src/types/google.ts index c6a53f961a..bf26f23a47 100644 --- a/packages/api/src/types/google.ts +++ b/packages/api/src/types/google.ts @@ -27,4 +27,8 @@ export interface GoogleConfigOptions { streamRate?: number; /** Model to use for title generation */ titleModel?: string; + /** Force Vertex AI auth semantics even when a Google API key is configured */ + forceVertex?: boolean; + /** GCP project id for Vertex AI ADC/service-account authentication */ + projectId?: string; }