Skip to content
65 changes: 60 additions & 5 deletions js/genkit/src/plugin.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,13 @@
* limitations under the License.
*/

import type { Action, ActionMetadata, BackgroundAction } from '@genkit-ai/core';
import { type ModelAction } from '@genkit-ai/ai/model';
import {
GenkitError,
type Action,
type ActionMetadata,
type BackgroundAction,
} from '@genkit-ai/core';
import type { Genkit } from './genkit.js';
import type { ActionType } from './registry.js';
export { embedder, embedderActionMetadata } from '@genkit-ai/ai/embedder';
Expand All @@ -26,7 +32,6 @@ export {
} from '@genkit-ai/ai/model';
export { reranker } from '@genkit-ai/ai/reranker';
export { indexer, retriever } from '@genkit-ai/ai/retriever';

export interface PluginProvider {
name: string;
initializer: () => void | Promise<void>;
Expand All @@ -45,6 +50,9 @@ export interface GenkitPluginV2 {
name: string
) => ResolvableAction | undefined | Promise<ResolvableAction | undefined>;
list?: () => ActionMetadata[] | Promise<ActionMetadata[]>;

// A shortcut for resolving a model.
model(name: string): Promise<ModelAction>;
}

export type GenkitPlugin = (genkit: Genkit) => PluginProvider;
Expand Down Expand Up @@ -85,10 +93,57 @@ export function genkitPlugin<T extends PluginInit>(
});
}

export class GenkitPluginV2Instance implements Required<GenkitPluginV2> {
readonly version = 'v2';
readonly name: string;

private plugin: Omit<GenkitPluginV2, 'version' | 'model'>;

constructor(plugin: Omit<GenkitPluginV2, 'version' | 'model'>) {
this.name = plugin.name;
this.plugin = plugin;
}

init(): ResolvableAction[] | Promise<ResolvableAction[]> {
if (!this.plugin.init) {
return [];
}
return this.plugin.init();
}

list(): ActionMetadata[] | Promise<ActionMetadata[]> {
if (!this.plugin.list) {
return [];
}
return this.plugin.list();
}

resolve(
actionType: ActionType,
name: string
): ResolvableAction | undefined | Promise<ResolvableAction | undefined> {
if (!this.plugin.resolve) {
return undefined;
}
return this.plugin.resolve(actionType, name);
}

async model(name: string): Promise<ModelAction> {
const model = await this.resolve('model', name);
if (!model) {
throw new GenkitError({
message: `Failed to resolve model ${name} for plugin ${this.name}`,
status: 'NOT_FOUND',
});
}
return model as ModelAction;
}
}

export function genkitPluginV2(
options: Omit<GenkitPluginV2, 'version'>
): GenkitPluginV2 {
return { ...options, version: 'v2' };
options: Omit<GenkitPluginV2, 'version' | 'model'>
): GenkitPluginV2Instance {
return new GenkitPluginV2Instance(options);
}

export function isPluginV2(plugin: unknown): plugin is GenkitPluginV2 {
Expand Down
15 changes: 15 additions & 0 deletions js/genkit/tests/plugins_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ const v2Plugin = genkitPluginV2({
resolve(actionType, name) {
switch (actionType) {
case 'model':
if (name === 'not-found') {
return undefined;
}
return model({ name }, async () => {
return {};
});
Expand Down Expand Up @@ -241,4 +244,16 @@ describe('session', () => {
])
);
});

it('resolves model using model resolve helper', async () => {
const act = await v2Plugin.model('foo-model');
assert.ok(act);
assert.strictEqual(act.__action.name, 'foo-model');
assert.strictEqual(act.__action.actionType, 'model');

await assert.rejects(v2Plugin.model('not-found'), {
name: 'GenkitError',
status: 'NOT_FOUND',
});
});
});
2 changes: 1 addition & 1 deletion js/plugins/compat-oai/src/deepseek/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import {
z,
} from 'genkit';
import { logger } from 'genkit/logging';
import { GenkitPluginV2 } from 'genkit/plugin';
import { type GenkitPluginV2 } from 'genkit/plugin';
import { ActionType } from 'genkit/registry';
import OpenAI from 'openai';
import { openAICompatible, PluginOptions } from '../index.js';
Expand Down
2 changes: 1 addition & 1 deletion js/plugins/compat-oai/src/openai/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import {
ModelReference,
z,
} from 'genkit';
import { GenkitPluginV2, ResolvableAction } from 'genkit/plugin';
import { ResolvableAction, type GenkitPluginV2 } from 'genkit/plugin';
import { ActionType } from 'genkit/registry';
import OpenAI from 'openai';
import {
Expand Down
2 changes: 1 addition & 1 deletion js/plugins/compat-oai/src/xai/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import {
z,
} from 'genkit';
import { logger } from 'genkit/logging';
import { GenkitPluginV2, ResolvableAction } from 'genkit/plugin';
import { ResolvableAction, type GenkitPluginV2 } from 'genkit/plugin';
import { ActionType } from 'genkit/registry';
import OpenAI from 'openai';
import {
Expand Down
47 changes: 47 additions & 0 deletions js/testapps/compat-oai/src/direct.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/**
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import { openAI } from '@genkit-ai/compat-oai/openai';

(async () => {
const oai = openAI();
const gpt4o = await oai.model('gpt-4o');
const response = await gpt4o({
messages: [
{
role: 'user',
content: [{ text: 'what is a gablorken of 4!' }],
},
],
tools: [
{
name: 'gablorken',
description: 'calculates a gablorken',
inputSchema: {
type: 'object',
properties: {
value: {
type: 'number',
description: 'the value to calculate gablorken for',
},
},
},
},
],
});

console.log(JSON.stringify(response.message, undefined, 2));
})();