Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions js/ai/src/generate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ import {
type GenerationCommonConfigSchema,
type MessageData,
type ModelArgument,
type ModelMiddleware,
type ModelMiddlewareArgument,
type Part,
type ToolRequestPart,
type ToolResponsePart,
Expand Down Expand Up @@ -171,7 +171,7 @@ export interface GenerateOptions<
*/
streamingCallback?: StreamingCallback<GenerateResponseChunk>;
/** Middleware to be used with this model call. */
use?: ModelMiddleware[];
use?: ModelMiddlewareArgument[];
/** Additional context (data, like e.g. auth) to be passed down to tools, prompts and other sub actions. */
context?: ActionContext;
/** Abort signal for the generate request. */
Expand Down
53 changes: 36 additions & 17 deletions js/ai/src/generate/action.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

import {
ActionRunOptions,
GenkitError,
StreamingCallback,
defineAction,
Expand Down Expand Up @@ -42,6 +43,8 @@ import {
GenerateResponseChunkSchema,
GenerateResponseSchema,
MessageData,
ModelMiddlewareArgument,
ModelMiddlewareWithOptions,
resolveModel,
type GenerateActionOptions,
type GenerateActionOutputConfig,
Expand Down Expand Up @@ -85,7 +88,7 @@ export function defineGenerateAction(registry: Registry): GenerateAction {
outputSchema: GenerateResponseSchema,
streamSchema: GenerateResponseChunkSchema,
},
async (request, { streamingRequested, sendChunk }) => {
async (request, { streamingRequested, sendChunk, context }) => {
const generateFn = (
sendChunk?: StreamingCallback<GenerateResponseChunk>
) =>
Expand All @@ -96,6 +99,7 @@ export function defineGenerateAction(registry: Registry): GenerateAction {
// Generate util action does not support middleware. Maybe when we add named/registered middleware....
middleware: [],
streamingCallback: sendChunk,
context,
});
return streamingRequested
? generateFn((c: GenerateResponseChunk) =>
Expand All @@ -113,18 +117,18 @@ export async function generateHelper(
registry: Registry,
options: {
rawRequest: GenerateActionOptions;
middleware?: ModelMiddleware[];
middleware?: ModelMiddlewareArgument[];
currentTurn?: number;
messageIndex?: number;
abortSignal?: AbortSignal;
streamingCallback?: StreamingCallback<GenerateResponseChunk>;
context?: Record<string, any>;
}
): Promise<GenerateResponseData> {
const currentTurn = options.currentTurn ?? 0;
const messageIndex = options.messageIndex ?? 0;
// do tracing
return await runInNewSpan(
registry,
{
metadata: {
name: options.rawRequest.stepName || 'generate',
Expand All @@ -143,6 +147,7 @@ export async function generateHelper(
messageIndex,
abortSignal: options.abortSignal,
streamingCallback: options.streamingCallback,
context: options.context,
});
metadata.output = JSON.stringify(output);
return output;
Expand Down Expand Up @@ -247,13 +252,15 @@ async function generate(
messageIndex,
abortSignal,
streamingCallback,
context,
}: {
rawRequest: GenerateActionOptions;
middleware: ModelMiddleware[] | undefined;
middleware: ModelMiddlewareArgument[] | undefined;
currentTurn: number;
messageIndex: number;
abortSignal?: AbortSignal;
streamingCallback?: StreamingCallback<GenerateResponseChunk>;
context?: Record<string, any>;
}
): Promise<GenerateResponseData> {
const { model, tools, resources, format } = await resolveParameters(
Expand Down Expand Up @@ -320,29 +327,41 @@ async function generate(
}

var response: GenerateResponse;
const sendChunk =
streamingCallback &&
(((chunk: GenerateResponseChunkData) =>
streamingCallback &&
streamingCallback(makeChunk('model', chunk))) as any);
const dispatch = async (
index: number,
req: z.infer<typeof GenerateRequestSchema>
req: z.infer<typeof GenerateRequestSchema>,
actionOpts: ActionRunOptions<any>
) => {
if (!middleware || index === middleware.length) {
// end of the chain, call the original model action
return await model(req, {
abortSignal,
onChunk:
streamingCallback &&
(((chunk: GenerateResponseChunkData) =>
streamingCallback &&
streamingCallback(makeChunk('model', chunk))) as any),
});
return await model(req, actionOpts);
}

const currentMiddleware = middleware[index];
return currentMiddleware(req, async (modifiedReq) =>
dispatch(index + 1, modifiedReq || req)
);
if (currentMiddleware.length === 3) {
return (currentMiddleware as ModelMiddlewareWithOptions)(
req,
actionOpts,
async (modifiedReq, opts) =>
dispatch(index + 1, modifiedReq || req, opts || actionOpts)
);
} else {
return (currentMiddleware as ModelMiddleware)(req, async (modifiedReq) =>
dispatch(index + 1, modifiedReq || req, actionOpts)
);
}
};

const modelResponse = await dispatch(0, request);
const modelResponse = await dispatch(0, request, {
abortSignal,
context,
onChunk: sendChunk,
});

if (model.__action.actionType === 'background-model') {
response = new GenerateResponse(
Expand Down
17 changes: 14 additions & 3 deletions js/ai/src/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import {
ActionFnArg,
BackgroundAction,
GenkitError,
MiddlewareWithOptions,
Operation,
OperationSchema,
action,
Expand Down Expand Up @@ -108,6 +109,16 @@ export type ModelMiddleware = SimpleMiddleware<
z.infer<typeof GenerateResponseSchema>
>;

export type ModelMiddlewareWithOptions = MiddlewareWithOptions<
z.infer<typeof GenerateRequestSchema>,
z.infer<typeof GenerateResponseSchema>,
z.infer<typeof GenerateResponseChunkSchema>
>;

export type ModelMiddlewareArgument =
| ModelMiddleware
| ModelMiddlewareWithOptions;

export type DefineModelOptions<
CustomOptionsSchema extends z.ZodTypeAny = z.ZodTypeAny,
> = {
Expand All @@ -121,7 +132,7 @@ export type DefineModelOptions<
/** Descriptive name for this model e.g. 'Google AI - Gemini Pro'. */
label?: string;
/** Middleware to be used with this model. */
use?: ModelMiddleware[];
use?: ModelMiddlewareArgument[];
};

export function model<CustomOptionsSchema extends z.ZodTypeAny = z.ZodTypeAny>(
Expand Down Expand Up @@ -324,11 +335,11 @@ export function backgroundModel<
}

function getModelMiddleware(options: {
use?: ModelMiddleware[];
use?: ModelMiddlewareArgument[];
name: string;
supports?: ModelInfo['supports'];
}) {
const middleware: ModelMiddleware[] = options.use || [];
const middleware: ModelMiddlewareArgument[] = options.use || [];
if (!options?.supports?.context) middleware.push(augmentWithContext());
const constratedSimulator = simulateConstrainedGeneration();
middleware.push((req, next) => {
Expand Down
156 changes: 156 additions & 0 deletions js/ai/tests/generate/generate_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import {
defineModel,
type ModelAction,
type ModelMiddleware,
type ModelMiddlewareWithOptions,
} from '../../src/model.js';
import { defineResource } from '../../src/resource.js';
import { defineTool } from '../../src/tool.js';
Expand Down Expand Up @@ -804,4 +805,159 @@ describe('generate', () => {
},
]);
});

it('middleware can intercept streaming callback', async () => {
const registry = new Registry();
const echoModel = defineModel(
registry,
{
apiVersion: 'v2',
name: 'echoModel',
supports: { tools: true },
},
async (_, { sendChunk }) => {
if (sendChunk) {
sendChunk({ content: [{ text: 'chunk1' }] });
sendChunk({ content: [{ text: 'chunk2' }] });
}
return {
message: {
role: 'model',
content: [{ text: 'done' }],
},
finishReason: 'stop',
};
}
);

const interceptMiddleware: ModelMiddlewareWithOptions = async (
req,
opts,
next
) => {
const originalOnChunk = opts!.onChunk;
return next(req, {
...opts,
onChunk: (chunk) => {
if (originalOnChunk) {
const text = chunk.content?.[0]?.text;
originalOnChunk({
...chunk,
content: [{ text: `intercepted: ${text}` }],
});
}
},
});
};

const { response, stream } = generateStream(registry, {
model: echoModel,
prompt: 'test',
use: [interceptMiddleware],
});

const streamed: any[] = [];
for await (const chunk of stream) {
streamed.push(chunk.content[0].text);
}

assert.deepStrictEqual(streamed, [
'intercepted: chunk1',
'intercepted: chunk2',
]);
await response;
});

it('middleware can modify context', async () => {
const registry = new Registry();
const checkContextModel = defineModel(
registry,
{
apiVersion: 'v2',
name: 'checkContextModel',
supports: { context: true },
},
async (request, { context }) => {
return {
message: {
role: 'model',
content: [{ text: `Context: ${context?.myValue}` }],
},
finishReason: 'stop',
};
}
);

const contextMiddleware: ModelMiddlewareWithOptions = async (
req,
opts,
next
) => {
return next(req, {
...opts,
context: {
...opts?.context,
myValue: 'foo',
},
});
};

const response = await generate(registry, {
model: checkContextModel,
prompt: 'test',
use: [contextMiddleware],
});

assert.strictEqual(response.text, 'Context: foo');
});

it('middleware can chain option modifications', async () => {
const registry = new Registry();
const checkContextModel = defineModel(
registry,
{
apiVersion: 'v2',
name: 'checkContextModel',
supports: { context: true },
},
async (request, { context }) => {
return {
message: {
role: 'model',
content: [{ text: `Context: ${JSON.stringify(context)}` }],
},
finishReason: 'stop',
};
}
);

const middleware1: ModelMiddlewareWithOptions = async (req, opts, next) => {
return next(req, {
...opts,
context: {
...opts?.context,
val: [...(opts?.context?.val ?? []), 'A'],
},
});
};

const middleware2: ModelMiddlewareWithOptions = async (req, opts, next) => {
return next(req, {
...opts,
context: {
...opts?.context,
val: [...(opts?.context?.val ?? []), 'B'],
},
});
};

const response = await generate(registry, {
model: checkContextModel,
prompt: 'test',
use: [middleware1, middleware2],
});

const context = JSON.parse(response.text.substring('Context: '.length));
assert.deepStrictEqual(context.val, ['A', 'B']);
});
});
2 changes: 2 additions & 0 deletions js/genkit/src/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ export {
type ModelArgument,
type ModelInfo,
type ModelMiddleware,
type ModelMiddlewareArgument,
type ModelMiddlewareWithOptions,
type ModelReference,
type ModelRequest,
type ModelResponseChunkData,
Expand Down