mirror of
https://github.com/danny-avila/LibreChat.git
synced 2026-06-09 17:31:19 +00:00
🛠️ fix: Enable Gemini Mixed Tool Config (#13538)
* fix: enable Gemini mixed tool config
* fix: apply Gemini mixed tool flag after skills
* style: match initialize formatting
* style: wrap final tool check
* fix: Respect Vertex auth mode
* style: Sort Agent Initialize Imports
* fix: Tighten Gemini Mixed Tool Gate
This commit is contained in:
parent
bfb6b224d2
commit
ed4546c5dc
7 changed files with 479 additions and 49 deletions
|
|
@ -338,6 +338,29 @@ describe('initializeAgent — provider web_search precedence', () => {
|
|||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
/**
 * Test helper: runs `initializeAgent` for an agent that combines the native Google
 * Search provider tool with one loaded tool definition (`mcpToolDefinition`).
 *
 * @param model - Model name passed to `createMocks`.
 * @param provider - Provider passed to `createMocks` and used as the only allowed provider.
 * @returns The promise returned by `initializeAgent`.
 */
async function initializeGoogleMixedToolAgent(model: string, provider = Providers.GOOGLE) {
  const { agent, req, res, loadTools, db } = createMocks({
    provider,
    model,
    providerTools: [nativeGoogleSearchTool],
    loadedToolDefinitions: [mcpToolDefinition],
  });
  agent.tools = ['mcp_lookup'];

  return initializeAgent(
    {
      req,
      res,
      agent,
      loadTools,
      endpointOption: { endpoint: EModelEndpoint.agents },
      allowedProviders: new Set([provider]),
      isInitialAgent: true,
    },
    db,
  );
}
|
||||
|
||||
it('keeps Anthropic native web_search when LibreChat search is not selected', async () => {
|
||||
const { agent, req, res, loadTools, db } = createMocks({
|
||||
provider: Providers.ANTHROPIC,
|
||||
|
|
@ -438,31 +461,126 @@ describe('initializeAgent — provider web_search precedence', () => {
|
|||
expect(result.tools).toEqual([nativeGoogleSearchTool]);
|
||||
expect(countGoogleSearchTools(result.tools)).toBe(1);
|
||||
expect(result.toolDefinitions).toContain(mcpToolDefinition);
|
||||
expect(result.model_parameters).toEqual(
|
||||
expect.objectContaining({
|
||||
includeServerSideToolInvocations: true,
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
// Vertex AI with a supported Gemini model: native search plus an external tool is
// accepted and the mixed-tool flag is expected on the resulting model parameters.
it('includes the mixed-tool flag for Vertex AI native search with external tools', async () => {
  const { agent, req, res, loadTools, db } = createMocks({
    provider: Providers.VERTEXAI,
    model: 'gemini-3.5-flash',
    providerTools: [nativeGoogleSearchTool],
    loadedToolDefinitions: [mcpToolDefinition],
  });
  agent.tools = ['mcp_lookup'];

  const result = await initializeAgent(
    {
      req,
      res,
      agent,
      loadTools,
      endpointOption: { endpoint: EModelEndpoint.agents },
      allowedProviders: new Set([Providers.VERTEXAI]),
      isInitialAgent: true,
    },
    db,
  );

  expect(result.tools).toEqual([nativeGoogleSearchTool]);
  expect(result.toolDefinitions).toContain(mcpToolDefinition);
  expect(result.model_parameters).toEqual(
    expect.objectContaining({
      includeServerSideToolInvocations: true,
    }),
  );
});
|
||||
|
||||
// Model names that must pass the mixed-tool gate, including prefixed
// (`google/…`, `models/…`) and multi-digit minor version forms.
it.each([
  'gemini-3-flash-preview',
  'gemini-3-pro-preview',
  'gemini-3.1-pro-preview',
  'gemini-3.1-pro-preview-customtools',
  'gemini-3.1-flash-lite',
  'gemini-3.1-flash-lite-preview',
  'gemini-3.5-flash',
  'google/gemini-3.5-flash-latest',
  'models/gemini-3.10-pro-preview',
  'gemini-4-pro-preview',
])('allows Google mixed tools for supported Gemini text model %s', async (model) => {
  const result = await initializeGoogleMixedToolAgent(model);

  expect(result.tools).toEqual([nativeGoogleSearchTool]);
  expect(result.toolDefinitions).toContain(mcpToolDefinition);
  expect(result.model_parameters).toEqual(
    expect.objectContaining({
      includeServerSideToolInvocations: true,
    }),
  );
});
|
||||
|
||||
// No agent tools are configured here; the test expects a `skill` tool definition in
// the result once an accessible skill is listed, and the mixed-tool flag to be set.
it('sets the mixed-tool flag when the skill catalog adds the external tool', async () => {
  const { agent, req, res, loadTools, db } = createMocks({
    provider: Providers.GOOGLE,
    model: 'gemini-3.5-flash',
    providerTools: [nativeGoogleSearchTool],
  });
  const { Types } = await import('mongoose');
  const skillId = new Types.ObjectId();
  // Stand-in author whose string form is the requesting user's id.
  const author = {
    toString: () => req.user?.id,
  } as unknown as import('mongoose').Types.ObjectId;

  const result = await initializeAgent(
    {
      req,
      res,
      agent,
      loadTools,
      endpointOption: { endpoint: EModelEndpoint.agents },
      allowedProviders: new Set([Providers.GOOGLE]),
      isInitialAgent: true,
      accessibleSkillIds: [skillId],
    },
    {
      ...db,
      listSkillsByAccess: jest.fn().mockResolvedValue({
        skills: [
          {
            _id: skillId,
            name: 'research-helper',
            description: 'Research current information.',
            author,
          },
        ],
        has_more: false,
        after: null,
      }),
    },
  );

  expect(result.tools).toEqual([nativeGoogleSearchTool]);
  expect(result.toolDefinitions?.map((toolDefinition) => toolDefinition.name)).toContain('skill');
  expect(result.model_parameters).toEqual(
    expect.objectContaining({
      includeServerSideToolInvocations: true,
    }),
  );
});
|
||||
|
||||
// Model names that must be rejected: pre-3.x, bare `gemini-3` / `gemini-3.1`,
// and names carrying an `image`, `live` or `tts` segment.
it.each([
  'gemini-2.5-flash',
  'gemini-3',
  'gemini-3.1',
  'gemini-3-pro-image-preview',
  'gemini-3.1-flash-image',
  'gemini-3.5-flash-live',
  'gemini-4-pro-tts',
])('rejects Google mixed tools for unsupported Gemini model %s', async (model) => {
  await expect(initializeGoogleMixedToolAgent(model)).rejects.toThrow(/google_tool_conflict/);
});
|
||||
|
||||
it('prefers LibreChat web_search when Google native search is also enabled', async () => {
|
||||
|
|
@ -1681,6 +1799,9 @@ describe('initializeAgent — execute_code capability expansion', () => {
|
|||
).resolves.toEqual(
|
||||
expect.objectContaining({
|
||||
tools: [{ googleSearch: {} }],
|
||||
model_parameters: expect.objectContaining({
|
||||
includeServerSideToolInvocations: true,
|
||||
}),
|
||||
toolDefinitions: expect.arrayContaining([
|
||||
expect.objectContaining({ name: 'bash_tool' }),
|
||||
expect.objectContaining({ name: 'read_file' }),
|
||||
|
|
@ -1722,6 +1843,9 @@ describe('initializeAgent — execute_code capability expansion', () => {
|
|||
).resolves.toEqual(
|
||||
expect.objectContaining({
|
||||
tools: [structuredTool, providerTool],
|
||||
model_parameters: expect.objectContaining({
|
||||
includeServerSideToolInvocations: true,
|
||||
}),
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
|
|
|||
|
|
@ -24,15 +24,8 @@ import type { GenericTool, LCToolRegistry, ToolMap, LCTool } from '@librechat/ag
|
|||
import type { Response as ServerResponse } from 'express';
|
||||
import type { IMongoFile } from '@librechat/data-schemas';
|
||||
import type { InitializeResultBase, ServerRequest, EndpointDbMethods } from '~/types';
|
||||
import {
|
||||
optionalChainWithEmptyCheck,
|
||||
extractLibreChatParams,
|
||||
getModelMaxTokens,
|
||||
getThreadData,
|
||||
} from '~/utils';
|
||||
import { filterFilesByEndpointConfig } from '~/files';
|
||||
import { generateArtifactsPrompt } from '~/prompts';
|
||||
import { getProviderConfig } from '~/endpoints';
|
||||
import type { ResolvedManualSkill, ResolvedAlwaysApplySkill } from './skills';
|
||||
import type { TFilterFilesByAgentAccess } from './resources';
|
||||
import {
|
||||
injectSkillCatalog,
|
||||
resolveManualSkills,
|
||||
|
|
@ -40,14 +33,21 @@ import {
|
|||
unionPrimeAllowedTools,
|
||||
MAX_PRIMED_SKILLS_PER_TURN,
|
||||
} from './skills';
|
||||
import {
|
||||
optionalChainWithEmptyCheck,
|
||||
extractLibreChatParams,
|
||||
getModelMaxTokens,
|
||||
getThreadData,
|
||||
} from '~/utils';
|
||||
import {
|
||||
registerCodeExecutionTools,
|
||||
registerFileAuthoringTools,
|
||||
isFileAuthoringToolDefinition,
|
||||
} from './tools';
|
||||
import { filterFilesByEndpointConfig } from '~/files';
|
||||
import { generateArtifactsPrompt } from '~/prompts';
|
||||
import { getProviderConfig } from '~/endpoints';
|
||||
import { primeResources } from './resources';
|
||||
import type { ResolvedManualSkill, ResolvedAlwaysApplySkill } from './skills';
|
||||
import type { TFilterFilesByAgentAccess } from './resources';
|
||||
|
||||
/**
|
||||
* Fraction of context budget reserved as headroom when no explicit maxContextTokens is set.
|
||||
|
|
@ -56,6 +56,15 @@ import type { TFilterFilesByAgentAccess } from './resources';
|
|||
*/
|
||||
const DEFAULT_RESERVE_RATIO = 0.05;
/**
 * Matches the `{{current_date}}`, `{{current_datetime}}` and `{{iso_datetime}}`
 * template variables (case-insensitive, optional whitespace inside the braces).
 */
const temporalSpecialVarRegex = /{{\s*(current_date|current_datetime|iso_datetime)\s*}}/i;
/**
 * Captures the major and optional minor version of a name shaped like
 * `gemini-<major>[.<minor>]`, followed by `-` or the end of the string.
 */
const geminiModelVersionRegex = /^gemini-(\d+)(?:\.(\d+))?(?:-|$)/;
/**
 * Gemini 3 / 3.1 text models listed explicitly for tool combination; they fall
 * below the 3.5 threshold used by the version-based check.
 */
const googleToolCombinationTextModels = [
  'gemini-3-flash-preview',
  'gemini-3-pro-preview',
  'gemini-3.1-flash-lite',
  'gemini-3.1-pro-preview',
];
/** Matches names containing an `image`, `live` or `tts` segment delimited by `-` or a string boundary. */
const googleToolCombinationExcludedModalityRegex =
  /(?:^|-)image(?:-|$)|(?:^|-)live(?:-|$)|(?:^|-)tts(?:-|$)/;
|
||||
|
||||
function hasTemporalSpecialVars(text: string): boolean {
|
||||
return temporalSpecialVarRegex.test(text);
|
||||
|
|
@ -96,12 +105,71 @@ function hasGoogleSearchTool(tool: unknown): boolean {
|
|||
return 'googleSearch' in tool || 'googleSearchRetrieval' in tool;
|
||||
}
|
||||
|
||||
/**
 * Normalizes a Google model identifier for comparison: trims surrounding whitespace,
 * lowercases it, and keeps only the segment after the last `/`
 * (e.g. `models/Gemini-3-Pro-Preview` becomes `gemini-3-pro-preview`).
 *
 * @param model - Raw model identifier.
 * @returns The lowercased final path segment of the identifier.
 */
function normalizeGoogleModelName(model: string): string {
  const normalized = model.trim().toLowerCase();
  return normalized.split('/').pop() ?? normalized;
}
|
||||
|
||||
/**
 * Checks whether a model name equals one of the entries in
 * `googleToolCombinationTextModels` or extends an entry with a `-` suffix
 * (e.g. `gemini-3.1-pro-preview-customtools`).
 *
 * @param model - Model name to compare; no normalization is applied here.
 * @returns `true` when the name matches a listed model or one of its suffixed variants.
 */
function isKnownGoogleToolCombinationTextModel(model: string): boolean {
  return googleToolCombinationTextModels.some(
    (knownModel) => model === knownModel || model.startsWith(`${knownModel}-`),
  );
}
|
||||
|
||||
/**
 * Reports whether a model name carries a Gemini version of 3.5 or later.
 *
 * The version is read with `geminiModelVersionRegex`; a missing minor component is
 * treated as `0`, and names the regex does not match yield `false`.
 *
 * @param model - Model name to inspect.
 * @returns `true` when the major version exceeds 3, or equals 3 with a minor of at least 5.
 */
function isGemini35OrLater(model: string): boolean {
  const match = geminiModelVersionRegex.exec(model);
  if (!match) {
    return false;
  }
  const major = Number(match[1]);
  const minor = Number(match[2] ?? '0');
  // Numeric comparison, so `3.10` counts as minor 10 rather than 1.
  return major > 3 || (major === 3 && minor >= 5);
}
|
||||
|
||||
/**
 * Decides whether a model may combine Google provider tools with agent tools.
 *
 * Non-string values are rejected. The name is normalized, names matching the
 * excluded-modality regex (`image`, `live`, `tts` segments) are rejected, and the
 * remainder must be either a listed tool-combination text model or Gemini 3.5+.
 *
 * @param model - Model value taken from the LLM config; may be of any type.
 * @returns `true` when the model passes the gate.
 */
function supportsGoogleToolCombination(model: unknown): boolean {
  if (typeof model !== 'string') {
    return false;
  }
  const normalized = normalizeGoogleModelName(model);
  if (googleToolCombinationExcludedModalityRegex.test(normalized)) {
    return false;
  }
  return isKnownGoogleToolCombinationTextModel(normalized) || isGemini35OrLater(normalized);
}
|
||||
|
||||
/**
 * Reports whether the provider is one the Google tool-combination checks apply to.
 *
 * @param provider - Provider identifier; may be undefined.
 * @returns `true` for `Providers.GOOGLE` and `Providers.VERTEXAI`, otherwise `false`.
 */
function isGoogleToolCombinationProvider(provider?: string): boolean {
  return provider === Providers.GOOGLE || provider === Providers.VERTEXAI;
}
|
||||
|
||||
/**
 * Decides whether the server-side tool invocation flag should be enabled.
 *
 * @param params.provider - Provider identifier of the agent.
 * @param params.hasProviderTools - Whether provider tools are present.
 * @param params.hasAgentTools - Whether agent tools are present.
 * @returns `true` only when the provider is Google or Vertex AI and both kinds of
 *   tools are present.
 */
function shouldIncludeGoogleServerSideToolInvocations({
  provider,
  hasProviderTools,
  hasAgentTools,
}: {
  provider?: string;
  hasProviderTools: boolean;
  hasAgentTools: boolean;
}): boolean {
  return isGoogleToolCombinationProvider(provider) && hasProviderTools && hasAgentTools;
}
|
||||
|
||||
/**
 * Throws when the model does not pass `supportsGoogleToolCombination`.
 *
 * @param model - Model value to validate; may be of any type.
 * @throws Error whose message is a JSON-style string carrying
 *   `ErrorTypes.GOOGLE_TOOL_CONFLICT` as its `type`.
 */
function assertGoogleToolCombinationSupport(model: unknown): void {
  if (!supportsGoogleToolCombination(model)) {
    throw new Error(`{ "type": "${ErrorTypes.GOOGLE_TOOL_CONFLICT}"}`);
  }
}
|
||||
|
||||
/**
 * Sets `includeServerSideToolInvocations = true` on the LLM config and, when the
 * agent already has a `model_parameters` object, on that object as well.
 * Both objects are mutated in place.
 *
 * @param params.agent - Agent whose `model_parameters` may receive the flag.
 * @param params.llmConfig - LLM config object that always receives the flag.
 */
function enableGoogleServerSideToolInvocations({
  agent,
  llmConfig,
}: {
  agent: Agent;
  llmConfig: Record<string, unknown>;
}): void {
  llmConfig.includeServerSideToolInvocations = true;
  // No `model_parameters` object is created when the agent has none.
  if (agent.model_parameters) {
    (agent.model_parameters as Record<string, unknown>).includeServerSideToolInvocations = true;
  }
}
|
||||
|
||||
function resolveProviderToolConflicts({
|
||||
|
|
@ -965,13 +1033,16 @@ export async function initializeAgent(
|
|||
? (providerTools as GenericTool[])
|
||||
: (structuredTools ?? []);
|
||||
|
||||
if (
|
||||
(agent.provider === Providers.GOOGLE || agent.provider === Providers.VERTEXAI) &&
|
||||
hasProviderTools &&
|
||||
hasAgentTools
|
||||
) {
|
||||
if (!supportsGoogleToolCombination(llmConfig.model)) {
|
||||
throw new Error(`{ "type": "${ErrorTypes.GOOGLE_TOOL_CONFLICT}"}`);
|
||||
if (isGoogleToolCombinationProvider(agent.provider) && hasProviderTools && hasAgentTools) {
|
||||
assertGoogleToolCombinationSupport(llmConfig.model);
|
||||
if (
|
||||
shouldIncludeGoogleServerSideToolInvocations({
|
||||
provider: agent.provider,
|
||||
hasProviderTools,
|
||||
hasAgentTools,
|
||||
})
|
||||
) {
|
||||
enableGoogleServerSideToolInvocations({ agent, llmConfig });
|
||||
}
|
||||
if (structuredTools?.length) {
|
||||
tools = structuredTools.concat(providerTools as GenericTool[]);
|
||||
|
|
@ -1043,6 +1114,21 @@ export async function initializeAgent(
|
|||
activeSkillNames = skillResult.activeSkillNames;
|
||||
}
|
||||
|
||||
const hasFinalAgentTools =
|
||||
(structuredTools?.length ?? 0) > 0 || (toolDefinitions?.length ?? 0) > 0;
|
||||
if (isGoogleToolCombinationProvider(agent.provider) && hasProviderTools && hasFinalAgentTools) {
|
||||
assertGoogleToolCombinationSupport(llmConfig.model);
|
||||
if (
|
||||
shouldIncludeGoogleServerSideToolInvocations({
|
||||
provider: agent.provider,
|
||||
hasProviderTools,
|
||||
hasAgentTools: hasFinalAgentTools,
|
||||
})
|
||||
) {
|
||||
enableGoogleServerSideToolInvocations({ agent, llmConfig });
|
||||
}
|
||||
}
|
||||
|
||||
const agentMaxContextNum = Number(agentMaxContextTokens) || DEFAULT_MAX_CONTEXT_TOKENS;
|
||||
const maxOutputTokensNum = Number(maxOutputTokens) || 0;
|
||||
const baseContextTokens = Math.max(0, agentMaxContextNum - maxOutputTokensNum);
|
||||
|
|
|
|||
130
packages/api/src/endpoints/google/initialize.spec.ts
Normal file
130
packages/api/src/endpoints/google/initialize.spec.ts
Normal file
|
|
@ -0,0 +1,130 @@
|
|||
import { Providers } from '@librechat/agents';
|
||||
import { AuthKeys, EModelEndpoint } from 'librechat-data-provider';
|
||||
import type { EndpointDbMethods, ServerRequest } from '~/types';
|
||||
|
||||
/** Stand-in for `getGoogleConfig`; always reports a Vertex AI provider with a fixed model. */
const mockGetGoogleConfig = jest.fn(
  (_credentials?: unknown, _options?: unknown, _acceptRawApiKey?: unknown) => ({
    provider: Providers.VERTEXAI,
    llmConfig: { model: 'gemini-2.5-flash' },
  }),
);
const mockIsEnabled = jest.fn();
const mockLoadServiceKey = jest.fn();
const mockCheckUserKeyExpiry = jest.fn();

// The factories wrap each mock in an arrow function so the `mock*` constants are
// only read when the mocked export is called, not when the factory runs.
jest.mock('./llm', () => ({
  getGoogleConfig: (credentials: unknown, options: unknown, acceptRawApiKey?: unknown) =>
    mockGetGoogleConfig(credentials, options, acceptRawApiKey),
}));

jest.mock('~/utils', () => ({
  isEnabled: (value: unknown) => mockIsEnabled(value),
  loadServiceKey: (keyPath: unknown) => mockLoadServiceKey(keyPath),
  checkUserKeyExpiry: (expiresAt: unknown, endpoint: unknown) =>
    mockCheckUserKeyExpiry(expiresAt, endpoint),
}));
|
||||
|
||||
import { initializeGoogle } from './initialize';
|
||||
|
||||
/**
 * Builds mock database methods: `getUserKey` resolves to `'user-google-key'` and
 * `getUserKeyValues` resolves to an empty object.
 */
function createDb(): EndpointDbMethods {
  return {
    getUserKey: jest.fn().mockResolvedValue('user-google-key'),
    getUserKeyValues: jest.fn().mockResolvedValue({}),
  };
}
|
||||
|
||||
/**
 * Builds a minimal request stub with an empty `body`, an empty `config` and a
 * user whose id is `'user-1'`, cast to `ServerRequest`.
 */
function createReq(): ServerRequest {
  return {
    body: {},
    config: {},
    user: { id: 'user-1' },
  } as ServerRequest;
}
|
||||
|
||||
/**
 * Asserts that `mockGetGoogleConfig` was called and returns the arguments of its
 * first call, typed as a `[credentials, options]` pair of plain records.
 */
function getGoogleConfigCall(): [Record<string, unknown>, Record<string, unknown>] {
  expect(mockGetGoogleConfig).toHaveBeenCalled();
  return mockGetGoogleConfig.mock.calls[0] as unknown as [
    Record<string, unknown>,
    Record<string, unknown>,
  ];
}
|
||||
|
||||
describe('initializeGoogle', () => {
  const originalEnv = process.env;

  beforeEach(() => {
    jest.clearAllMocks();
    // Work on a copy of the environment with every Google/Vertex variable removed,
    // so each test sets only the variables it needs.
    process.env = { ...originalEnv };
    delete process.env.GOOGLE_KEY;
    delete process.env.GOOGLE_REVERSE_PROXY;
    delete process.env.GOOGLE_AUTH_HEADER;
    delete process.env.GOOGLE_SERVICE_KEY_FILE;
    delete process.env.VERTEX_PROJECT_ID;
    delete process.env.GOOGLE_CLOUD_PROJECT;
    delete process.env.GCLOUD_PROJECT;
    delete process.env.GOOGLE_PROJECT_ID;
  });

  afterAll(() => {
    process.env = originalEnv;
  });

  it('forces Vertex AI ADC config and ignores GOOGLE_KEY for vertexai endpoint', async () => {
    process.env.GOOGLE_KEY = 'test-api-key';
    process.env.VERTEX_PROJECT_ID = 'fiery-catwalk-385918';
    mockLoadServiceKey.mockResolvedValue(null);

    const db = createDb();

    await initializeGoogle({
      req: createReq(),
      endpoint: Providers.VERTEXAI,
      model_parameters: { model: 'gemini-2.5-flash' },
      db,
    });

    expect(mockLoadServiceKey).toHaveBeenCalledTimes(1);
    expect(db.getUserKey).not.toHaveBeenCalled();
    expect(mockCheckUserKeyExpiry).not.toHaveBeenCalled();

    const [credentials, options] = getGoogleConfigCall();
    expect(credentials).toEqual({
      [AuthKeys.GOOGLE_SERVICE_KEY]: {},
    });
    expect(credentials).not.toHaveProperty(AuthKeys.GOOGLE_API_KEY);
    expect(options).toEqual(
      expect.objectContaining({
        forceVertex: true,
        projectId: 'fiery-catwalk-385918',
        modelOptions: { model: 'gemini-2.5-flash' },
      }),
    );
  });

  it('keeps Google API-key config for the google endpoint', async () => {
    process.env.GOOGLE_KEY = 'test-api-key';
    process.env.VERTEX_PROJECT_ID = 'fiery-catwalk-385918';

    await initializeGoogle({
      req: createReq(),
      endpoint: EModelEndpoint.google,
      model_parameters: { model: 'gemini-2.5-flash' },
      db: createDb(),
    });

    expect(mockLoadServiceKey).not.toHaveBeenCalled();

    const [credentials, options] = getGoogleConfigCall();
    expect(credentials).toEqual({
      [AuthKeys.GOOGLE_SERVICE_KEY]: {},
      [AuthKeys.GOOGLE_API_KEY]: 'test-api-key',
    });
    expect(options).toEqual(
      expect.objectContaining({
        forceVertex: false,
        projectId: undefined,
        modelOptions: { model: 'gemini-2.5-flash' },
      }),
    );
  });
});
|
||||
|
|
@ -1,4 +1,5 @@
|
|||
import path from 'path';
|
||||
import { Providers } from '@librechat/agents';
|
||||
import { EModelEndpoint, AuthKeys } from 'librechat-data-provider';
|
||||
import type {
|
||||
BaseInitializeParams,
|
||||
|
|
@ -23,14 +24,15 @@ export async function initializeGoogle({
|
|||
model_parameters,
|
||||
db,
|
||||
}: BaseInitializeParams): Promise<InitializeResultBase> {
|
||||
void endpoint;
|
||||
const appConfig = req.config;
|
||||
const { GOOGLE_KEY, GOOGLE_REVERSE_PROXY, GOOGLE_AUTH_HEADER, PROXY } = process.env;
|
||||
const isUserProvided = GOOGLE_KEY === 'user_provided';
|
||||
const isVertexEndpoint = endpoint === Providers.VERTEXAI;
|
||||
const useUserProvidedGoogleKey = !isVertexEndpoint && isUserProvided;
|
||||
const { key: expiresAt } = req.body;
|
||||
|
||||
let userKey = null;
|
||||
if (expiresAt && isUserProvided) {
|
||||
if (expiresAt && useUserProvidedGoogleKey) {
|
||||
checkUserKeyExpiry(expiresAt, EModelEndpoint.google);
|
||||
userKey = await db.getUserKey({ userId: req.user?.id ?? '', name: EModelEndpoint.google });
|
||||
}
|
||||
|
|
@ -39,9 +41,10 @@ export async function initializeGoogle({
|
|||
|
||||
/** Check if GOOGLE_KEY is provided at all (including 'user_provided') */
|
||||
const isGoogleKeyProvided =
|
||||
(GOOGLE_KEY && GOOGLE_KEY.trim() !== '') || (isUserProvided && userKey != null);
|
||||
!isVertexEndpoint &&
|
||||
((GOOGLE_KEY && GOOGLE_KEY.trim() !== '') || (useUserProvidedGoogleKey && userKey != null));
|
||||
|
||||
if (!isGoogleKeyProvided && loadServiceKey) {
|
||||
if ((isVertexEndpoint || !isGoogleKeyProvided) && loadServiceKey) {
|
||||
/** Only attempt to load service key if GOOGLE_KEY is not provided */
|
||||
try {
|
||||
const serviceKeyPath =
|
||||
|
|
@ -56,11 +59,11 @@ export async function initializeGoogle({
|
|||
}
|
||||
}
|
||||
|
||||
const credentials: GoogleCredentials = isUserProvided
|
||||
const credentials: GoogleCredentials = useUserProvidedGoogleKey
|
||||
? (userKey as GoogleCredentials)
|
||||
: {
|
||||
[AuthKeys.GOOGLE_SERVICE_KEY]: serviceKey,
|
||||
[AuthKeys.GOOGLE_API_KEY]: GOOGLE_KEY,
|
||||
...(!isVertexEndpoint && { [AuthKeys.GOOGLE_API_KEY]: GOOGLE_KEY }),
|
||||
};
|
||||
|
||||
let clientOptions: GoogleConfigOptions = {};
|
||||
|
|
@ -84,6 +87,13 @@ export async function initializeGoogle({
|
|||
authHeader: isEnabled(GOOGLE_AUTH_HEADER) ?? undefined,
|
||||
proxy: PROXY ?? undefined,
|
||||
modelOptions: model_parameters ?? {},
|
||||
forceVertex: isVertexEndpoint,
|
||||
projectId: isVertexEndpoint
|
||||
? (process.env.VERTEX_PROJECT_ID ??
|
||||
process.env.GOOGLE_CLOUD_PROJECT ??
|
||||
process.env.GCLOUD_PROJECT ??
|
||||
process.env.GOOGLE_PROJECT_ID)
|
||||
: undefined,
|
||||
...clientOptions,
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -64,6 +64,23 @@ describe('getGoogleConfig', () => {
|
|||
expect(result.llmConfig).toHaveProperty('apiKey', 'raw-api-key-string');
|
||||
});
|
||||
|
||||
// A project id alone, without `forceVertex`, must leave an API-key setup on the Google provider.
it('should not let project id force Vertex AI without the force flag', () => {
  const credentials = {
    [AuthKeys.GOOGLE_API_KEY]: 'test-api-key',
  };

  const result = getGoogleConfig(credentials, {
    projectId: 'fiery-catwalk-385918',
    modelOptions: {
      model: 'gemini-2.5-flash',
    },
  });

  expect(result.provider).toBe(Providers.GOOGLE);
  expect(result.llmConfig).toHaveProperty('apiKey', 'test-api-key');
  expect(result.llmConfig).not.toHaveProperty('authOptions');
});
|
||||
|
||||
it('should handle model options including temperature and topP/topK', () => {
|
||||
const credentials = {
|
||||
[AuthKeys.GOOGLE_API_KEY]: 'test-api-key',
|
||||
|
|
@ -318,6 +335,53 @@ describe('getGoogleConfig', () => {
|
|||
expect(result.llmConfig).toHaveProperty('location', 'us-central1');
|
||||
});
|
||||
|
||||
// With `forceVertex` and no service key, `authOptions` is expected to hold only the project id.
it('should force Vertex AI ADC config with a project id even when an API key is present', () => {
  const credentials = {
    [AuthKeys.GOOGLE_API_KEY]: 'test-api-key',
  };

  const result = getGoogleConfig(credentials, {
    forceVertex: true,
    projectId: 'fiery-catwalk-385918',
    modelOptions: {
      model: 'gemini-2.5-flash',
    },
  });

  expect(result.provider).toBe(Providers.VERTEXAI);
  expect(result.llmConfig).not.toHaveProperty('apiKey');
  expect((result.llmConfig as Record<string, unknown>).authOptions).toEqual({
    projectId: 'fiery-catwalk-385918',
  });
});
|
||||
|
||||
// With `forceVertex` and a service key, the key's project id and credentials are
// expected in `authOptions`, and no API key on the LLM config.
it('should force Vertex AI service-account config when an API key is also present', () => {
  const credentials = {
    [AuthKeys.GOOGLE_API_KEY]: 'test-api-key',
    [AuthKeys.GOOGLE_SERVICE_KEY]: {
      project_id: 'test-project',
      client_email: 'test@test-project.iam.gserviceaccount.com',
      private_key: 'test-private-key',
    },
  };

  const result = getGoogleConfig(credentials, {
    forceVertex: true,
    modelOptions: {
      model: 'gemini-2.5-flash',
    },
  });

  expect(result.provider).toBe(Providers.VERTEXAI);
  expect(result.llmConfig).not.toHaveProperty('apiKey');
  expect((result.llmConfig as Record<string, unknown>).authOptions).toMatchObject({
    projectId: 'test-project',
    credentials: expect.objectContaining({
      project_id: 'test-project',
    }),
  });
});
|
||||
|
||||
it('should use GOOGLE_LOC env variable for Vertex AI location', () => {
|
||||
process.env.GOOGLE_LOC = 'europe-west1';
|
||||
|
||||
|
|
|
|||
|
|
@ -264,6 +264,10 @@ function applyVertexMultiRegionEndpoint(config: VertexAIClientOptions & { endpoi
|
|||
}
|
||||
}
|
||||
|
||||
/**
 * Reports whether a parsed service key object carries any fields.
 *
 * @param serviceKey - Parsed service key record.
 * @returns `true` when the object has at least one own enumerable key.
 */
function hasServiceKeyCredentials(serviceKey: Record<string, unknown>): boolean {
  return Object.keys(serviceKey).length > 0;
}
|
||||
|
||||
export function getSafetySettings(
|
||||
model?: string,
|
||||
): Array<{ category: string; threshold: string }> | undefined {
|
||||
|
|
@ -337,7 +341,12 @@ export function getGoogleConfig(
|
|||
typeof serviceKeyRaw === 'string' ? JSON.parse(serviceKeyRaw) : (serviceKeyRaw ?? {});
|
||||
|
||||
const apiKey = creds[AuthKeys.GOOGLE_API_KEY] ?? null;
|
||||
const project_id = !apiKey ? (serviceKey?.project_id ?? null) : null;
|
||||
let project_id = null;
|
||||
if (options.forceVertex === true) {
|
||||
project_id = options.projectId ?? serviceKey?.project_id ?? null;
|
||||
} else if (!apiKey) {
|
||||
project_id = serviceKey?.project_id ?? null;
|
||||
}
|
||||
|
||||
const reverseProxyUrl = options.reverseProxyUrl;
|
||||
const authHeader = options.authHeader;
|
||||
|
|
@ -373,7 +382,7 @@ export function getGoogleConfig(
|
|||
|
||||
let provider;
|
||||
|
||||
if (project_id) {
|
||||
if (options.forceVertex === true || project_id) {
|
||||
provider = Providers.VERTEXAI;
|
||||
} else {
|
||||
provider = Providers.GOOGLE;
|
||||
|
|
@ -381,10 +390,13 @@ export function getGoogleConfig(
|
|||
|
||||
// If we have a GCP project => Vertex AI
|
||||
if (provider === Providers.VERTEXAI) {
|
||||
(llmConfig as VertexAIClientOptions).authOptions = {
|
||||
credentials: { ...serviceKey },
|
||||
projectId: project_id,
|
||||
};
|
||||
(llmConfig as VertexAIClientOptions).authOptions = removeNullishValues(
|
||||
{
|
||||
...(hasServiceKeyCredentials(serviceKey) && { credentials: { ...serviceKey } }),
|
||||
projectId: project_id,
|
||||
},
|
||||
true,
|
||||
);
|
||||
const location = process.env.GOOGLE_LOC || 'us-central1';
|
||||
(llmConfig as VertexAIClientOptions).location = location;
|
||||
} else if (apiKey && provider === Providers.GOOGLE) {
|
||||
|
|
|
|||
|
|
@ -27,4 +27,8 @@ export interface GoogleConfigOptions {
|
|||
streamRate?: number;
|
||||
/** Model to use for title generation */
|
||||
titleModel?: string;
|
||||
/** Force Vertex AI auth semantics even when a Google API key is configured */
|
||||
forceVertex?: boolean;
|
||||
/** GCP project id for Vertex AI ADC/service-account authentication */
|
||||
projectId?: string;
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue