Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions js/plugins/googleai/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
"compile": "tsup-node",
"build:clean": "rm -rf ./lib",
"build": "npm-run-all build:clean check compile",
"build:watch": "tsup-node --watch"
"build:watch": "tsup-node --watch",
"test": "tsx --test ./tests/*_test.ts"
},
"repository": {
"type": "git",
Expand All @@ -32,7 +33,7 @@
"dependencies": {
"@genkit-ai/ai": "workspace:*",
"@genkit-ai/core": "workspace:*",
"@google/generative-ai": "^0.6.0",
"@google/generative-ai": "^0.10.0",
"google-auth-library": "^9.6.3",
"node-fetch": "^3.3.2",
"zod": "^3.22.4"
Expand Down
125 changes: 114 additions & 11 deletions js/plugins/googleai/src/gemini.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,26 @@ import {
modelRef,
ModelReference,
Part,
ToolDefinitionSchema,
ToolRequestPart,
ToolResponsePart,
} from '@genkit-ai/ai/model';
import {
downloadRequestMedia,
simulateSystemPrompt,
} from '@genkit-ai/ai/model/middleware';
import { GENKIT_CLIENT_HEADER } from '@genkit-ai/core';
import {
FunctionCallPart,
FunctionDeclaration,
FunctionDeclarationSchemaType,
FunctionResponsePart,
GenerateContentCandidate as GeminiCandidate,
Content as GeminiMessage,
Part as GeminiPart,
GenerateContentResponse,
GoogleGenerativeAI,
InlineDataPart,
RequestOptions,
StartChatParams,
} from '@google/generative-ai';
Expand Down Expand Up @@ -72,7 +80,7 @@ export const geminiPro = modelRef({
supports: {
multiturn: true,
media: false,
tools: false,
tools: true,
systemRole: true,
},
},
Expand Down Expand Up @@ -117,7 +125,7 @@ export const geminiUltra = modelRef({
supports: {
multiturn: true,
media: false,
tools: false,
tools: true,
systemRole: true,
},
},
Expand Down Expand Up @@ -171,7 +179,44 @@ function toGeminiRole(
}
}

function toInlineData(part: MediaPart): GeminiPart {
/**
 * Recursively converts a JSON-schema-style node into the structure used for
 * the `parameters` of a Gemini function declaration.
 *
 * - `object` nodes have every entry of `properties` converted in turn and
 *   carry `required` through unchanged.
 * - `array` nodes have `items` converted.
 * - any other node has its `type` upper-cased and looked up on
 *   `FunctionDeclarationSchemaType`.
 *
 * @param property Schema node to convert; may be missing.
 * @returns The converted node, or `null` when `property` is falsy.
 */
function convertSchemaProperty(property) {
  if (!property) {
    return null;
  }
  if (property.type === 'object') {
    const nestedProperties = {};
    // An object schema may omit `properties` altogether (for example a tool
    // that takes no arguments); treat that as an empty property map instead
    // of crashing on `Object.keys(undefined)`.
    for (const [key, value] of Object.entries(property.properties ?? {})) {
      nestedProperties[key] = convertSchemaProperty(value);
    }
    return {
      type: FunctionDeclarationSchemaType.OBJECT,
      properties: nestedProperties,
      required: property.required,
    };
  } else if (property.type === 'array') {
    return {
      type: FunctionDeclarationSchemaType.ARRAY,
      items: convertSchemaProperty(property.items),
    };
  } else {
    return {
      type: FunctionDeclarationSchemaType[property.type.toUpperCase()],
    };
  }
}

/**
 * Builds the Gemini function declaration for a tool definition.
 *
 * The tool's name has every '/' replaced by '__', its description is passed
 * through, and its input schema is converted with `convertSchemaProperty`.
 *
 * @param tool Tool definition to describe to Gemini.
 * @returns The corresponding function declaration.
 */
function toGeminiTool(
  tool: z.infer<typeof ToolDefinitionSchema>
): FunctionDeclaration {
  // Gemini throws on '/' in tool name, so slashes are swapped for '__'.
  const geminiName = tool.name.replace(/\//g, '__');
  const { description } = tool;
  const parameters = convertSchemaProperty(tool.inputSchema);
  return { name: geminiName, description, parameters };
}

function toInlineData(part: MediaPart): InlineDataPart {
const dataUrl = part.media.url;
const b64Data = dataUrl.substring(dataUrl.indexOf(',')! + 1);
const contentType =
Expand All @@ -180,14 +225,14 @@ function toInlineData(part: MediaPart): GeminiPart {
return { inlineData: { mimeType: contentType, data: b64Data } };
}

function fromInlineData(inlinePart: GeminiPart): MediaPart {
function fromInlineData(inlinePart: InlineDataPart): MediaPart {
// Check if the required properties exist
if (
!inlinePart.inlineData ||
!inlinePart.inlineData.hasOwnProperty('mimeType') ||
!inlinePart.inlineData.hasOwnProperty('data')
) {
throw new Error('Invalid GeminiPart: missing required properties');
throw new Error('Invalid InlineDataPart: missing required properties');
}
const { mimeType, data } = inlinePart.inlineData;
// Combine data and mimeType into a data URL
Expand All @@ -200,19 +245,74 @@ function fromInlineData(inlinePart: GeminiPart): MediaPart {
};
}

/**
 * Turns a tool request part into a Gemini function-call part.
 *
 * The request's `name` becomes the call name and its `input` becomes `args`.
 *
 * @param part Tool request part to convert.
 * @returns The function-call part.
 * @throws Error when the part, its `toolRequest`, or the request's `input`
 *   is missing or falsy.
 */
function toFunctionCall(part: ToolRequestPart): FunctionCallPart {
  const request = part?.toolRequest;
  if (request?.input) {
    const { name, input: args } = request;
    return { functionCall: { name, args } };
  }
  throw Error('Invalid ToolRequestPart: input was missing.');
}

/**
 * Turns a Gemini function-call part into a tool request part.
 *
 * The call's `name` is kept as-is and its `args` become the request `input`.
 *
 * @param part Function-call part to convert.
 * @returns The tool request part.
 * @throws Error when `part.functionCall` is missing.
 */
function fromFunctionCall(part: FunctionCallPart): ToolRequestPart {
  const call = part.functionCall;
  if (call) {
    const { name, args: input } = call;
    return { toolRequest: { name, input } };
  }
  throw Error('Invalid FunctionCallPart');
}

/**
 * Turns a tool response part into a Gemini function-response part.
 *
 * The tool's output is wrapped as `{ name, content }` under `response`, so
 * any output value can be carried, including falsy ones such as `false`,
 * `0` or `''`.
 *
 * @param part Tool response part to convert.
 * @returns The function-response part.
 * @throws Error when the part, its `toolResponse`, or the response's
 *   `output` is missing (`null` or `undefined`).
 */
function toFunctionResponse(part: ToolResponsePart): FunctionResponsePart {
  const output = part?.toolResponse?.output;
  // Only null/undefined mean "missing"; a plain truthiness check would
  // wrongly reject tools that legitimately return false, 0 or ''.
  if (output == null) {
    throw Error('Invalid ToolResponsePart: output was missing.');
  }
  return {
    functionResponse: {
      name: part.toolResponse.name,
      response: {
        name: part.toolResponse.name,
        content: output,
      },
    },
  };
}

/**
 * Turns a Gemini function-response part into a tool response part.
 *
 * Every '__' in the function name is turned back into '/', and the raw
 * `response` value becomes the tool response `output`.
 *
 * @param part Function-response part to convert.
 * @returns The tool response part.
 * @throws Error when `part.functionResponse` is missing.
 */
function fromFunctionResponse(part: FunctionResponsePart): ToolResponsePart {
  const fnResponse = part.functionResponse;
  if (fnResponse) {
    // restore slashes
    const restoredName = fnResponse.name.replace(/__/g, '/');
    return { toolResponse: { name: restoredName, output: fnResponse.response } };
  }
  throw new Error('Invalid FunctionResponsePart.');
}

/**
 * Converts a message part into the matching Gemini part.
 *
 * Checks are made in order: text (any defined value, including the empty
 * string), media, tool request, tool response. The stale
 * "Only text and media parts are supported currently." throw that sat ahead
 * of the tool branches has been dropped so those branches are reachable.
 *
 * @param part Part to convert.
 * @returns The Gemini part produced by the first matching branch.
 * @throws Error when the part matches none of the supported kinds.
 */
function toGeminiPart(part: Part): GeminiPart {
  if (part.text !== undefined) return { text: part.text };
  if (part.media) return toInlineData(part);
  if (part.toolRequest) return toFunctionCall(part);
  if (part.toolResponse) return toFunctionResponse(part);
  throw new Error('Unsupported Part type');
}

/**
 * Converts a Gemini part into the matching message part.
 *
 * Checks are made in order: text (any defined value, including the empty
 * string), inline data, function call, function response. The stale
 * "Only support text for the moment." throw that sat ahead of the function
 * branches has been dropped so those branches are reachable.
 *
 * @param part Gemini part to convert.
 * @returns The part produced by the first matching branch.
 * @throws Error when the part matches none of the supported kinds.
 */
function fromGeminiPart(part: GeminiPart): Part {
  if (part.text !== undefined) return { text: part.text };
  if (part.inlineData) return fromInlineData(part);
  if (part.functionCall) return fromFunctionCall(part);
  if (part.functionResponse) return fromFunctionResponse(part);
  throw new Error('Unsupported GeminiPart type');
}

function toGeminiMessage(
export function toGeminiMessage(
message: MessageData,
model?: ModelReference<z.ZodTypeAny>
): GeminiMessage {
Expand All @@ -222,7 +322,7 @@ function toGeminiMessage(
};
}

function toGeminiSystemInstruction(message: MessageData): GeminiMessage {
export function toGeminiSystemInstruction(message: MessageData): GeminiMessage {
return {
role: 'user',
parts: message.content.map(toGeminiPart),
Expand All @@ -246,9 +346,9 @@ function fromGeminiFinishReason(
}
}

function fromGeminiCandidate(candidate: GeminiCandidate): CandidateData {
export function fromGeminiCandidate(candidate: GeminiCandidate): CandidateData {
return {
index: candidate.index,
index: candidate.index || 0, // reasonable default?
message: {
role: 'model',
content: (candidate.content?.parts || []).map(fromGeminiPart),
Expand Down Expand Up @@ -335,6 +435,9 @@ export function googleAIModel(

const chatRequest = {
systemInstruction,
tools: request.tools?.length
? [{ functionDeclarations: request.tools?.map(toGeminiTool) }]
: [],
history: messages
.slice(0, -1)
.map((message) => toGeminiMessage(message, model)),
Expand Down
Loading