Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 57 additions & 6 deletions go/plugins/vertexai/vertexai.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package vertexai

import (
"context"
"fmt"

"cloud.google.com/go/vertexai/genai"
"github.com/google/genkit/go/ai"
Expand All @@ -39,7 +40,7 @@ func (g *generator) Generate(ctx context.Context, input *ai.GenerateRequest, cb

// Translate from a ai.GenerateRequest to a genai request.
gm.SetCandidateCount(int32(input.Candidates))
if c, ok := input.Config.(*ai.GenerationCommonConfig); ok {
if c, ok := input.Config.(*ai.GenerationCommonConfig); ok && c != nil {
gm.SetMaxOutputTokens(int32(c.MaxOutputTokens))
gm.StopSequences = c.StopSequences
gm.SetTemperature(float32(c.Temperature))
Expand All @@ -65,7 +66,40 @@ func (g *generator) Generate(ctx context.Context, input *ai.GenerateRequest, cb
if len(messages) > 0 {
parts = convertParts(messages[0].Content)
}
//TODO: convert input.Tools and append to gm.Tools

// Convert input.Tools and append to gm.Tools.
for _, t := range input.Tools {
schema := &genai.Schema{
Type: genai.TypeObject,
Properties: make(map[string]*genai.Schema),
}
for k, v := range t.InputSchema {
typ := genai.TypeUnspecified
switch v {
case "string":
typ = genai.TypeString
case "float64":
typ = genai.TypeNumber
case "int":
typ = genai.TypeInteger
case "bool":
typ = genai.TypeBoolean
default:
return nil, fmt.Errorf("schema value %q not supported", v)
}
schema.Properties[k] = &genai.Schema{Type: typ}
}

fd := &genai.FunctionDeclaration{
Name: t.Name,
Parameters: schema,
}

gm.Tools = append(gm.Tools, &genai.Tool{
FunctionDeclarations: []*genai.FunctionDeclaration{fd},
})
}
// TODO: gm.ToolConfig?

// Send out the actual request.
resp, err := cs.SendMessage(ctx, parts...)
Expand Down Expand Up @@ -103,10 +137,13 @@ func (g *generator) Generate(ctx context.Context, input *ai.GenerateRequest, cb
p = ai.NewTextPart(string(part))
case genai.Blob:
p = ai.NewBlobPart(part.MIMEType, string(part.Data))
case genai.FunctionResponse:
p = ai.NewBlobPart("TODO", string(part.Name))
case genai.FunctionCall:
p = ai.NewToolRequestPart(&ai.ToolRequest{
Name: part.Name,
Input: part.Args,
})
default:
panic("unknown part type")
panic(fmt.Sprintf("unknown part #%v", part))
}
m.Content = append(m.Content, p)
}
Expand Down Expand Up @@ -159,7 +196,21 @@ func convertPart(p *ai.Part) genai.Part {
switch {
case p.IsText():
return genai.Text(p.Text())
default:
case p.IsBlob():
return genai.Blob{MIMEType: p.ContentType(), Data: []byte(p.Text())}
case p.IsToolResponse():
toolResp := p.ToolResponse()
return genai.FunctionResponse{
Name: toolResp.Name,
Response: toolResp.Output,
}
case p.IsToolRequest():
toolReq := p.ToolRequest()
return genai.FunctionCall{
Name: toolReq.Name,
Args: toolReq.Input,
}
default:
panic("unknown part type in a request")
}
}
76 changes: 76 additions & 0 deletions go/plugins/vertexai/vertexai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package vertexai_test
import (
"context"
"flag"
"math"
"strings"
"testing"

Expand Down Expand Up @@ -57,4 +58,79 @@ func TestGenerator(t *testing.T) {
if !strings.Contains(out, "France") {
t.Errorf("got \"%s\", expecting it would contain \"France\"", out)
}
if resp.Request != req {
t.Error("Request field not set properly")
}
}

func TestGeneratorTool(t *testing.T) {
if *projectID == "" {
t.Skip("no -projectid provided")
}
ctx := context.Background()
g, err := vertexai.NewGenerator(ctx, "gemini-1.0-pro", *projectID, *location)
if err != nil {
t.Fatal(err)
}
req := &ai.GenerateRequest{
Candidates: 1,
Messages: []*ai.Message{
&ai.Message{
Content: []*ai.Part{ai.NewTextPart("what is 3.5 squared? Use the tool provided.")},
Role: ai.RoleUser,
},
},
Tools: []*ai.ToolDefinition{
&ai.ToolDefinition{
Name: "exponentiation",
InputSchema: map[string]any{"base": "float64", "exponent": "int"},
OutputSchema: map[string]any{"output": "float64"},
},
},
}

resp, err := g.Generate(ctx, req, nil)
if err != nil {
t.Fatal(err)
}
p := resp.Candidates[0].Message.Content[0]
if !p.IsToolRequest() {
t.Fatalf("tool not requested")
}
toolReq := p.ToolRequest()
if toolReq.Name != "exponentiation" {
t.Errorf("tool name is %q, want \"exponentiation\"", toolReq.Name)
}
if toolReq.Input["base"] != 3.5 {
t.Errorf("base is %f, want 3.5", toolReq.Input["base"])
}
if toolReq.Input["exponent"] != 2 && toolReq.Input["exponent"] != 2.0 {
// Note: 2.0 is wrong given the schema, but Gemini returns a float anyway.
t.Errorf("exponent is %f, want 2", toolReq.Input["exponent"])
}

// Update our conversation with the tool request the model made and our tool response.
// (Our "tool" is just math.Pow.)
req.Messages = append(req.Messages,
resp.Candidates[0].Message,
&ai.Message{
Content: []*ai.Part{ai.NewToolResponsePart(&ai.ToolResponse{
Name: "exponentiation",
Output: map[string]any{"output": math.Pow(3.5, 2)},
})},
Role: ai.RoleTool,
},
)

// Issue our request again.
resp, err = g.Generate(ctx, req, nil)
if err != nil {
t.Fatal(err)
}

// Check final response.
out := resp.Candidates[0].Message.Content[0].Text()
if !strings.Contains(out, "12.25") {
t.Errorf("got %s, expecting it to contain \"12.25\"", out)
}
}