Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions go/ai/embedder.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,16 @@ func DefineEmbedder(provider, name string, embed func(context.Context, *EmbedReq
return embedder{core.DefineAction(provider, name, core.ActionTypeEmbedder, nil, embed)}
}

// LookupEmbedder looks up an [Embedder] registered by [DefineEmbedder].
// It returns nil if the Embedder was not defined.
func LookupEmbedder(provider, name string) Embedder {
action := core.LookupActionFor[*EmbedRequest, []float32, struct{}](core.ActionTypeEmbedder, provider, name)
if action == nil {
return nil
}
return embedder{action}
}

type embedder struct {
embedAction *core.Action[*EmbedRequest, []float32, struct{}]
}
Expand Down
27 changes: 10 additions & 17 deletions go/ai/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,16 @@ func DefineGenerator(provider, name string, metadata *GeneratorMetadata, generat
return generator{a}
}

// LookupGenerator looks up a [Generator] registered by [DefineGenerator].
// It returns nil if the Generator was not defined.
func LookupGenerator(provider, name string) Generator {
action := core.LookupActionFor[*GenerateRequest, *GenerateResponse, *Candidate](core.ActionTypeModel, provider, name)
if action == nil {
return nil
}
return generator{action}
}

type generator struct {
generateAction *core.Action[*GenerateRequest, *GenerateResponse, *Candidate]
}
Expand Down Expand Up @@ -114,23 +124,6 @@ func Generate(ctx context.Context, g Generator, req *GenerateRequest, cb Generat
}
}

// generatorActionType is the instantiated core.Action type registered
// by RegisterGenerator.
type generatorActionType = core.Action[*GenerateRequest, *GenerateResponse, *Candidate]

// LookupGenerator looks up a [Generator] registered by [DefineGenerator].
func LookupGenerator(provider, name string) (Generator, error) {
action := core.LookupAction(core.ActionTypeModel, provider, name)
if action == nil {
return nil, fmt.Errorf("LookupGenerator: no generator action named %q/%q", provider, name)
}
actionInst, ok := action.(*generatorActionType)
if !ok {
return nil, fmt.Errorf("LookupGenerator: generator action %q has type %T, want %T", name, action, &generatorActionType{})
}
return generator{actionInst}, nil
}

// conformOutput appends a message to the request indicating conformance to the expected schema.
func conformOutput(req *GenerateRequest) error {
if req.Output != nil && req.Output.Format == OutputFormatJSON && len(req.Messages) > 0 {
Expand Down
7 changes: 7 additions & 0 deletions go/core/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,13 @@ func LookupAction(typ ActionType, provider, name string) action {
return globalRegistry.lookupAction(key)
}

// LookupActionFor returns the action for the given key in the global registry,
// or nil if there is none.
// It panics if the action is of the wrong type.
func LookupActionFor[In, Out, Stream any](typ ActionType, provider, name string) *Action[In, Out, Stream] {
return LookupAction(typ, provider, name).(*Action[In, Out, Stream])
}

// listActions returns a list of descriptions of all registered actions.
// The list is sorted by action name.
func (r *registry) listActions() []actionDesc {
Expand Down
7 changes: 4 additions & 3 deletions go/plugins/dotprompt/genkit.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package dotprompt
import (
"context"
"errors"
"fmt"
"reflect"
"strings"

Expand Down Expand Up @@ -156,9 +157,9 @@ func (p *Prompt) Generate(ctx context.Context, pr *ai.PromptRequest, cb func(con
return nil, errors.New("dotprompt model not in provider/name format")
}

generator, err = ai.LookupGenerator(provider, name)
if err != nil {
return nil, err
generator := ai.LookupGenerator(provider, name)
if generator == nil {
return nil, fmt.Errorf("no generator named %q for provider %q", name, provider)
}
}

Expand Down
43 changes: 0 additions & 43 deletions go/plugins/googleai/embed.go

This file was deleted.

119 changes: 100 additions & 19 deletions go/plugins/googleai/googleai.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@ package googleai

import (
"context"
"errors"
"fmt"
"path"
"slices"

"github.com/firebase/genkit/go/ai"
"github.com/firebase/genkit/go/plugins/internal/uri"
Expand All @@ -25,8 +28,103 @@ import (
"google.golang.org/api/option"
)

func newClient(ctx context.Context, apiKey string) (*genai.Client, error) {
return genai.NewClient(ctx, option.WithAPIKey(apiKey))
const provider = "google-genai"

// Config configures the plugin.
type Config struct {
// API key. Required.
APIKey string
// Generative models to provide.
// If empty, a complete list will be obtained from the service.
Models []string
// Embedding models to provide.
// If empty, a complete list will be obtained from the service.
Embedders []string
}

func Init(ctx context.Context, cfg Config) (err error) {
defer func() {
if err != nil {
err = fmt.Errorf("googleai.Init: %w", err)
}
}()

if cfg.APIKey == "" {
return errors.New("missing API key")
}

client, err := genai.NewClient(ctx, option.WithAPIKey(cfg.APIKey))
if err != nil {
return err
}

needModels := len(cfg.Models) == 0
needEmbedders := len(cfg.Embedders) == 0
if needModels || needEmbedders {
iter := client.ListModels(ctx)
for {
mi, err := iter.Next()
if err == iterator.Done {
break
}
if err != nil {
return err
}
// Model names are of the form "models/name".
name := path.Base(mi.Name)
if needModels && slices.Contains(mi.SupportedGenerationMethods, "generateContent") {
cfg.Models = append(cfg.Models, name)
}
if needEmbedders && slices.Contains(mi.SupportedGenerationMethods, "embedContent") {
cfg.Embedders = append(cfg.Embedders, name)
}
}
}
for _, name := range cfg.Models {
defineModel(name, client)
}
for _, name := range cfg.Embedders {
defineEmbedder(name, client)
}
return nil
}

func defineModel(name string, client *genai.Client) {
meta := &ai.GeneratorMetadata{
Label: "Google AI - " + name,
Supports: ai.GeneratorCapabilities{
Multiturn: true,
},
}
g := generator{model: name, client: client}
ai.DefineGenerator(provider, name, meta, g.Generate)
}

func defineEmbedder(name string, client *genai.Client) {
ai.DefineEmbedder(provider, name, func(ctx context.Context, input *ai.EmbedRequest) ([]float32, error) {
em := client.EmbeddingModel(name)
parts, err := convertParts(input.Document.Content)
if err != nil {
return nil, err
}
res, err := em.EmbedContent(ctx, parts...)
if err != nil {
return nil, err
}
return res.Embedding.Values, nil
})
}

// Generator returns the generator with the given name.
// It returns nil if the generator was not configured.
func Generator(name string) ai.Generator {
return ai.LookupGenerator(provider, name)
}

// Embedder returns the embedder with the given name.
// It returns nil if the embedder was not configured.
func Embedder(name string) ai.Embedder {
return ai.LookupEmbedder(provider, name)
}

type generator struct {
Expand Down Expand Up @@ -195,23 +293,6 @@ func translateResponse(resp *genai.GenerateContentResponse) *ai.GenerateResponse
return r
}

// NewGenerator returns an [ai.Generator] which sends a request to
// the google AI model and returns the response.
func NewGenerator(ctx context.Context, model, apiKey string) (ai.Generator, error) {
client, err := newClient(ctx, apiKey)
if err != nil {
return nil, err
}
meta := &ai.GeneratorMetadata{
Label: "Google AI - " + model,
Supports: ai.GeneratorCapabilities{
Multiturn: true,
},
}
g := generator{model: model, client: client}
return ai.DefineGenerator("google-genai", model, meta, g.Generate), nil
}

// convertParts converts a slice of *ai.Part to a slice of genai.Part.
func convertParts(parts []*ai.Part) ([]genai.Part, error) {
res := make([]genai.Part, 0, len(parts))
Expand Down
Loading