Skip to content
6 changes: 6 additions & 0 deletions js/ai/src/model-types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,12 @@ export const GenerationCommonConfigSchema = z
'Set of character sequences (up to 5) that will stop output generation.'
)
.optional(),
apiKey: z
.string()
.describe(
'API Key to use for the model call, overrides API key provided in plugin config.'
)
.optional(),
})
.passthrough();

Expand Down
1 change: 1 addition & 0 deletions js/genkit/src/embedder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
export {
EmbedderInfoSchema,
embedderRef,
type EmbedRequest,
type EmbedderAction,
type EmbedderArgument,
type EmbedderInfo,
Expand Down
34 changes: 29 additions & 5 deletions js/plugins/compat-oai/src/audio.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,15 @@ import type {
import { GenerationCommonConfigSchema, Message, modelRef, z } from 'genkit';
import type { ModelAction, ModelInfo } from 'genkit/model';
import { model } from 'genkit/plugin';
import type OpenAI from 'openai';
import OpenAI from 'openai';
import { Response } from 'openai/core.mjs';
import type {
SpeechCreateParams,
Transcription,
TranscriptionCreateParams,
} from 'openai/resources/audio/index.mjs';
import { PluginOptions } from './index.js';
import { maybeCreateRequestScopedOpenAIClient } from './utils.js';

export type SpeechRequestBuilder = (
req: GenerateRequest,
Expand Down Expand Up @@ -185,10 +187,16 @@ export function defineCompatOpenAISpeechModel<
client: OpenAI;
modelRef?: ModelReference<CustomOptions>;
requestBuilder?: SpeechRequestBuilder;
pluginOptions: PluginOptions;
}): ModelAction {
const { name, client, modelRef, requestBuilder } = params;
const {
name,
client: defaultClient,
pluginOptions,
modelRef,
requestBuilder,
} = params;
const modelName = name.substring(name.indexOf('/') + 1);

return model(
{
name,
Expand All @@ -197,6 +205,11 @@ export function defineCompatOpenAISpeechModel<
},
async (request, { abortSignal }) => {
const ttsRequest = toTTSRequest(modelName!, request, requestBuilder);
const client = maybeCreateRequestScopedOpenAIClient(
pluginOptions,
request,
defaultClient
);
const result = await client.audio.speech.create(ttsRequest, {
signal: abortSignal,
});
Expand Down Expand Up @@ -338,11 +351,17 @@ export function defineCompatOpenAITranscriptionModel<
>(params: {
name: string;
client: OpenAI;
pluginOptions?: PluginOptions;
modelRef?: ModelReference<CustomOptions>;
requestBuilder?: TranscriptionRequestBuilder;
}): ModelAction {
const { name, client, modelRef, requestBuilder } = params;

const {
name,
pluginOptions,
client: defaultClient,
modelRef,
requestBuilder,
} = params;
return model(
{
name,
Expand All @@ -353,6 +372,11 @@ export function defineCompatOpenAITranscriptionModel<
const modelName = name.substring(name.indexOf('/') + 1);

const params = toSttRequest(modelName!, request, requestBuilder);
const client = maybeCreateRequestScopedOpenAIClient(
pluginOptions,
request,
defaultClient
);
// Explicitly setting stream to false ensures we use the non-streaming overload
const result = await client.audio.transcriptions.create(
{
Expand Down
43 changes: 22 additions & 21 deletions js/plugins/compat-oai/src/deepseek/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,26 +36,25 @@ import {

export type DeepSeekPluginOptions = Omit<PluginOptions, 'name' | 'baseURL'>;

const resolver = async (
client: OpenAI,
actionType: ActionType,
actionName: string
) => {
if (actionType === 'model') {
const modelRef = deepSeekModelRef({
name: actionName,
});
return defineCompatOpenAIModel({
name: modelRef.name,
client,
modelRef,
requestBuilder: deepSeekRequestBuilder,
});
} else {
logger.warn('Only model actions are supported by the DeepSeek plugin');
return undefined;
}
};
function createResolver(pluginOptions: PluginOptions) {
return async (client: OpenAI, actionType: ActionType, actionName: string) => {
if (actionType === 'model') {
const modelRef = deepSeekModelRef({
name: actionName,
});
return defineCompatOpenAIModel({
name: modelRef.name,
client,
pluginOptions,
modelRef,
requestBuilder: deepSeekRequestBuilder,
});
} else {
logger.warn('Only model actions are supported by the DeepSeek plugin');
return undefined;
}
};
}

const listActions = async (client: OpenAI): Promise<ActionMetadata[]> => {
return await client.models.list().then((response) =>
Expand Down Expand Up @@ -87,6 +86,7 @@ export function deepSeekPlugin(
'Please pass in the API key or set the DEEPSEEK_API_KEY environment variable.',
});
}
const pluginOptions = { name: 'deepseek', ...options };
return openAICompatible({
name: 'deepseek',
baseURL: 'https://api.deepseek.com',
Expand All @@ -97,12 +97,13 @@ export function deepSeekPlugin(
defineCompatOpenAIModel({
name: modelRef.name,
client,
pluginOptions,
modelRef,
requestBuilder: deepSeekRequestBuilder,
})
);
},
resolver,
resolver: createResolver(pluginOptions),
listActions,
});
}
Expand Down
11 changes: 9 additions & 2 deletions js/plugins/compat-oai/src/embedder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import type { EmbedderAction, EmbedderReference } from 'genkit';
import { embedder } from 'genkit/plugin';
import OpenAI from 'openai';
import { PluginOptions } from './index.js';
import { maybeCreateRequestScopedOpenAIClient } from './utils.js';

/**
* Method to define a new Genkit Embedder that is compatibale with the Open AI
Expand All @@ -37,11 +39,11 @@ import OpenAI from 'openai';
export function defineCompatOpenAIEmbedder(params: {
name: string;
client: OpenAI;
pluginOptions?: PluginOptions;
embedderRef?: EmbedderReference;
}): EmbedderAction {
const { name, client, embedderRef } = params;
const { name, client: defaultClient, pluginOptions, embedderRef } = params;
const modelName = name.substring(name.indexOf('/') + 1);

return embedder(
{
name,
Expand All @@ -50,6 +52,11 @@ export function defineCompatOpenAIEmbedder(params: {
},
async (req) => {
const { encodingFormat: encoding_format, ...restOfConfig } = req.options;
const client = maybeCreateRequestScopedOpenAIClient(
pluginOptions,
req,
defaultClient
);
const embeddings = await client.embeddings.create({
model: modelName!,
input: req.input.map((d) => d.text),
Expand Down
16 changes: 15 additions & 1 deletion js/plugins/compat-oai/src/image.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ import type {
ImageGenerateParams,
ImagesResponse,
} from 'openai/resources/images.mjs';
import { PluginOptions } from './index.js';
import { maybeCreateRequestScopedOpenAIClient } from './utils.js';

export type ImageRequestBuilder = (
req: GenerateRequest,
Expand Down Expand Up @@ -122,10 +124,17 @@ export function defineCompatOpenAIImageModel<
>(params: {
name: string;
client: OpenAI;
pluginOptions?: PluginOptions;
modelRef?: ModelReference<CustomOptions>;
requestBuilder?: ImageRequestBuilder;
}): ModelAction<CustomOptions> {
const { name, client, modelRef, requestBuilder } = params;
const {
name,
client: defaultClient,
pluginOptions,
modelRef,
requestBuilder,
} = params;
const modelName = name.substring(name.indexOf('/') + 1);

return model(
Expand All @@ -135,6 +144,11 @@ export function defineCompatOpenAIImageModel<
configSchema: modelRef?.configSchema,
},
async (request, { abortSignal }) => {
const client = maybeCreateRequestScopedOpenAIClient(
pluginOptions,
request,
defaultClient
);
const result = await client.images.generate(
toImageGenerateParams(modelName!, request, requestBuilder),
{ signal: abortSignal }
Expand Down
41 changes: 26 additions & 15 deletions js/plugins/compat-oai/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import { ActionMetadata } from 'genkit';
import { ResolvableAction, genkitPluginV2 } from 'genkit/plugin';
import { ActionType } from 'genkit/registry';
import { OpenAI, type ClientOptions } from 'openai';
import OpenAI, { type ClientOptions } from 'openai';
import { compatOaiModelRef, defineCompatOpenAIModel } from './model.js';

export {
Expand Down Expand Up @@ -45,7 +45,8 @@ export {
type ModelRequestBuilder,
} from './model.js';

export interface PluginOptions extends Partial<ClientOptions> {
export interface PluginOptions extends Partial<Omit<ClientOptions, 'apiKey'>> {
apiKey?: ClientOptions['apiKey'] | false;
name: string;
initializer?: (client: OpenAI) => Promise<ResolvableAction[]>;
resolver?: (
Expand Down Expand Up @@ -110,24 +111,33 @@ export interface PluginOptions extends Partial<ClientOptions> {
*/
export const openAICompatible = (options: PluginOptions) => {
let listActionsCache;
var client: OpenAI;
function createClient() {
if (client) return client;
const { apiKey, ...restofOptions } = options;
client = new OpenAI({
...restofOptions,
apiKey: apiKey === false ? 'placeholder' : apiKey,
});
return client;
}
return genkitPluginV2({
name: options.name,
async init() {
if (!options.initializer) {
return [];
}
const client = new OpenAI(options);
return await options.initializer(client);
return await options.initializer(createClient());
},
async resolve(actionType: ActionType, actionName: string) {
const client = new OpenAI(options);
if (options.resolver) {
return await options.resolver(client, actionType, actionName);
return await options.resolver(createClient(), actionType, actionName);
} else {
if (actionType === 'model') {
return defineCompatOpenAIModel({
name: actionName,
client,
client: createClient(),
pluginOptions: options,
modelRef: compatOaiModelRef({
name: actionName,
}),
Expand All @@ -136,14 +146,15 @@ export const openAICompatible = (options: PluginOptions) => {
return undefined;
}
},
list: options.listActions
? async () => {
if (listActionsCache) return listActionsCache;
const client = new OpenAI(options);
listActionsCache = await options.listActions!(client);
return listActionsCache;
}
: undefined,
list:
// Don't attempt to list models if apiKey set to false
options.listActions && options.apiKey !== false
? async () => {
if (listActionsCache) return listActionsCache;
listActionsCache = await options.listActions!(createClient());
return listActionsCache;
}
: undefined,
});
};

Expand Down
Loading