Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 13 additions & 10 deletions js/plugins/google-genai/src/googleai/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ import {
* @returns A promise that resolves to an array of Model objects.
*/
export async function listModels(
apiKey: string,
apiKey: string | undefined,
clientOptions?: ClientOptions
): Promise<Model[]> {
const url = getGoogleAIUrl({
Expand Down Expand Up @@ -75,7 +75,7 @@ export async function listModels(
* @throws {Error} If the API request fails or the response cannot be parsed.
*/
export async function generateContent(
apiKey: string,
apiKey: string | undefined,
model: string,
generateContentRequest: GenerateContentRequest,
clientOptions?: ClientOptions
Expand Down Expand Up @@ -108,7 +108,7 @@ export async function generateContent(
* @throws {Error} If the API request fails.
*/
export async function generateContentStream(
apiKey: string,
apiKey: string | undefined,
model: string,
generateContentRequest: GenerateContentRequest,
clientOptions?: ClientOptions
Expand Down Expand Up @@ -140,7 +140,7 @@ export async function generateContentStream(
* @throws {Error} If the API request fails or the response cannot be parsed.
*/
export async function embedContent(
apiKey: string,
apiKey: string | undefined,
model: string,
embedContentRequest: EmbedContentRequest,
clientOptions?: ClientOptions
Expand All @@ -162,7 +162,7 @@ export async function embedContent(
}

export async function imagenPredict(
apiKey: string,
apiKey: string | undefined,
model: string,
imagenPredictRequest: ImagenPredictRequest,
clientOptions?: ClientOptions
Expand All @@ -185,7 +185,7 @@ export async function imagenPredict(
}

export async function veoPredict(
apiKey: string,
apiKey: string | undefined,
model: string,
veoPredictRequest: VeoPredictRequest,
clientOptions?: ClientOptions
Expand All @@ -208,7 +208,7 @@ export async function veoPredict(
}

export async function veoCheckOperation(
apiKey: string,
apiKey: string | undefined,
operation: string,
clientOptions?: ClientOptions
): Promise<VeoOperation> {
Expand Down Expand Up @@ -265,7 +265,7 @@ export function getGoogleAIUrl(params: {

function getFetchOptions(params: {
method: 'POST' | 'GET';
apiKey: string;
apiKey: string | undefined;
body?: string;
clientOptions?: ClientOptions;
}) {
Expand Down Expand Up @@ -310,7 +310,7 @@ function getAbortSignal(
* @returns {HeadersInit} An object containing the headers to be included in the request.
*/
function getHeaders(
apiKey: string,
apiKey?: string,
clientOptions?: ClientOptions
): HeadersInit {
let customHeaders = {};
Expand All @@ -322,10 +322,13 @@ function getHeaders(
const headers: HeadersInit = {
...customHeaders,
'Content-Type': 'application/json',
'x-goog-api-key': apiKey,
'x-goog-api-client': getGenkitClientHeader(),
};

if (apiKey) {
headers['x-goog-api-key'] = apiKey;
}

return headers;
}

Expand Down
29 changes: 20 additions & 9 deletions js/plugins/google-genai/src/googleai/embedder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import { embedderRef } from 'genkit/embedder';
import { embedder as pluginEmbedder } from 'genkit/plugin';
import { embedContent } from './client.js';
import {
ClientOptions,
EmbedContentRequest,
GoogleAIPluginOptions,
Model,
Expand Down Expand Up @@ -132,6 +133,11 @@ export function defineEmbedder(
): EmbedderAction {
checkApiKey(pluginOptions?.apiKey);
const ref = model(name);
const clientOptions: ClientOptions = {
apiVersion: pluginOptions?.apiVersion,
baseUrl: pluginOptions?.baseUrl,
customHeaders: pluginOptions?.customHeaders,
};

return pluginEmbedder(
{
Expand All @@ -148,15 +154,20 @@ export function defineEmbedder(

const embeddings = await Promise.all(
request.input.map(async (doc) => {
const response = await embedContent(embedApiKey, embedVersion, {
taskType: request.options?.taskType,
title: request.options?.title,
content: {
role: '',
parts: [{ text: doc.text }],
},
outputDimensionality: request.options?.outputDimensionality,
} as EmbedContentRequest);
const response = await embedContent(
embedApiKey,
embedVersion,
{
taskType: request.options?.taskType,
title: request.options?.title,
content: {
role: '',
parts: [{ text: doc.text }],
},
outputDimensionality: request.options?.outputDimensionality,
} as EmbedContentRequest,
clientOptions
);
const values = response.embedding.values;
return { embedding: values };
})
Expand Down
1 change: 1 addition & 0 deletions js/plugins/google-genai/src/googleai/gemini.ts
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,7 @@ export function defineModel(
const clientOptions: ClientOptions = {
apiVersion: pluginOptions?.apiVersion,
baseUrl: pluginOptions?.baseUrl,
customHeaders: pluginOptions?.customHeaders,
};

const middleware: ModelMiddleware[] = [];
Expand Down
1 change: 1 addition & 0 deletions js/plugins/google-genai/src/googleai/imagen.ts
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ export function defineModel(
const clientOptions: ClientOptions = {
apiVersion: pluginOptions?.apiVersion,
baseUrl: pluginOptions?.baseUrl,
customHeaders: pluginOptions?.customHeaders,
};

return pluginModel(
Expand Down
4 changes: 4 additions & 0 deletions js/plugins/google-genai/src/googleai/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ export interface GoogleAIPluginOptions {
experimental_debugTraces?: boolean;
/** Use `responseSchema` field instead of `responseJsonSchema`. */
legacyResponseSchema?: boolean;
/**
* Additional headers to send along with the request.
*/
customHeaders?: Record<string, string>;
}

/**
Expand Down
4 changes: 2 additions & 2 deletions js/plugins/google-genai/src/googleai/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ export function checkApiKey(
export function calculateApiKey(
pluginApiKey: string | false | undefined,
requestApiKey: string | undefined
): string {
): string | undefined {
let apiKey: string | undefined;

// Don't get the key from the environment if pluginApiKey is false
Expand All @@ -113,7 +113,7 @@ export function calculateApiKey(
apiKey = requestApiKey || apiKey;

if (pluginApiKey === false && !requestApiKey) {
throw API_KEY_FALSE_ERROR;
return undefined;
}

if (!apiKey) {
Expand Down
1 change: 1 addition & 0 deletions js/plugins/google-genai/src/googleai/veo.ts
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ export function defineModel(
const clientOptions: ClientOptions = {
apiVersion: pluginOptions?.apiVersion,
baseUrl: pluginOptions?.baseUrl,
customHeaders: pluginOptions?.customHeaders,
};

return pluginBackgroundModel({
Expand Down
19 changes: 6 additions & 13 deletions js/plugins/google-genai/tests/googleai/embedder_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
*/

import * as assert from 'assert';
import { Document, GenkitError } from 'genkit';
import { Document } from 'genkit';
import { afterEach, beforeEach, describe, it } from 'node:test';
import * as sinon from 'sinon';
import {
Expand Down Expand Up @@ -144,23 +144,16 @@ describe('defineGoogleAIEmbedder', () => {
});

it('throws if apiKey is false in pluginOptions and not provided in call options', async () => {
mockFetchResponse({ embedding: { values: [] } });
const embedder = defineEmbedder('text-embedding-004', {
apiKey: false,
});
await assert.rejects(
embedder.run({
assert.ok(
await embedder.run({
input: [new Document({ content: [{ text: 'test' }] })],
}),
(err: GenkitError) => {
assert.strictEqual(err.status, 'INVALID_ARGUMENT');
assert.match(
err.message,
/GoogleAI plugin was initialized with \{apiKey: false\}/
);
return true;
}
})
);
sinon.assert.notCalled(fetchStub);
sinon.assert.calledOnce(fetchStub);
});

it('uses API key from call options if apiKey is false in pluginOptions', async () => {
Expand Down
10 changes: 4 additions & 6 deletions js/plugins/google-genai/tests/googleai/gemini_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -139,13 +139,11 @@ describe('Google AI Gemini', () => {
);
});

it('throws if apiKey is false and not in call config', async () => {
it('works if apiKey is false and not in call config', async () => {
mockFetchResponse(defaultApiResponse);
const model = defineModel('gemini-2.0-flash', { apiKey: false });
await assert.rejects(
model.run(minimalRequest),
/GoogleAI plugin was initialized with \{apiKey: false\}/
);
sinon.assert.notCalled(fetchStub);
assert.ok(await model.run(minimalRequest));
sinon.assert.calledOnce(fetchStub);
});

it('uses API key from call config if apiKey is false', async () => {
Expand Down
19 changes: 9 additions & 10 deletions js/plugins/google-genai/tests/googleai/imagen_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,7 @@ import {
ImagenPredictResponse,
ImagenPrediction,
} from '../../src/googleai/types.js';
import {
API_KEY_FALSE_ERROR,
MISSING_API_KEY_ERROR,
} from '../../src/googleai/utils.js';
import { MISSING_API_KEY_ERROR } from '../../src/googleai/utils.js';

const { toImagenParameters, fromImagenPrediction } = TEST_ONLY;

Expand Down Expand Up @@ -394,20 +391,22 @@ describe('Google AI Imagen', () => {
assert.strictEqual(fetchArgs[1].headers['x-goog-api-key'], requestApiKey);
});

it('apiKey false at init, missing in request - throws error', async () => {
it('works with apiKey false at init, missing in request', async () => {
mockFetchResponse({
predictions: [{ bytesBase64Encoded: 'jkl', mimeType: 'image/png' }],
});
const modelRunner = captureModelRunner({ apiKey: false });

await assert.rejects(
modelRunner(
assert.ok(
await modelRunner(
{
messages: [{ role: 'user', content: [{ text: 'A car' }] }],
config: {},
},
{}
),
API_KEY_FALSE_ERROR
)
);
sinon.assert.notCalled(fetchStub);
sinon.assert.calledOnce(fetchStub);
});

it('defineImagenModel throws if key not found in env or args', async () => {
Expand Down
8 changes: 2 additions & 6 deletions js/plugins/google-genai/tests/googleai/index_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -588,7 +588,7 @@ describe('GoogleAI Plugin', () => {
);
});

it('should return empty array if API key is missing for listActions', async () => {
it('should still call list if API key is missing for listActions', async () => {
delete process.env.GOOGLE_API_KEY;
delete process.env.GEMINI_API_KEY;
delete process.env.GOOGLE_GENAI_API_KEY;
Expand All @@ -600,11 +600,7 @@ describe('GoogleAI Plugin', () => {
[],
'Should return empty array if API key is not found'
);
assert.strictEqual(
fetchMock.mock.callCount(),
0,
'Fetch should not be called'
);
assert.strictEqual(fetchMock.mock.callCount(), 1);
});

it('should use listActions cache', async () => {
Expand Down
8 changes: 2 additions & 6 deletions js/plugins/google-genai/tests/googleai/utils_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import assert from 'node:assert';
import { afterEach, beforeEach, describe, it } from 'node:test';
import process from 'process';
import {
API_KEY_FALSE_ERROR,
MISSING_API_KEY_ERROR,
calculateApiKey,
checkApiKey,
Expand Down Expand Up @@ -276,11 +275,8 @@ describe('API Key Utils', () => {
assert.strictEqual(calculateApiKey(undefined, undefined), 'env_key');
});

it('throws API_KEY_FALSE_ERROR if apiKey1 is false and apiKey2 is undefined', () => {
assert.throws(
() => calculateApiKey(false, undefined),
API_KEY_FALSE_ERROR
);
it('returns undefined if apiKey1 is false and apiKey2 is undefined', () => {
assert.strictEqual(calculateApiKey(false, undefined), undefined);
});

it('throws MISSING_API_KEY_ERROR if apiKey1 and apiKey2 are undefined and no env var', () => {
Expand Down