Skip to content
1 change: 1 addition & 0 deletions go/ai/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ type ModelInfo struct {
Label string `json:"label,omitempty"`
Supports *ModelInfoSupports `json:"supports,omitempty"`
Versions []string `json:"versions,omitempty"`
ConfigSchema map[string]any `json:"configSchema,omitempty"`
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs to come from genkit-tools otherwise it will get overwritten on next generation.

}

type ModelInfoSupports struct {
Expand Down
5 changes: 3 additions & 2 deletions go/ai/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,10 @@ func DefineModel(
}
metadataMap["supports"] = supports
metadataMap["versions"] = info.Versions

if info.ConfigSchema != nil {
metadataMap["customOptions"] = info.ConfigSchema
}
generate = core.ChainMiddleware(ValidateSupport(name, info))(generate)

return (*modelActionDef)(core.DefineStreamingAction(r, provider, name, atype.Model, map[string]any{"model": metadataMap}, generate))
}

Expand Down
193 changes: 172 additions & 21 deletions go/plugins/googleai/googleai.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ package googleai

import (
"context"
"encoding/json"
"fmt"
"os"
"strings"
Expand All @@ -17,9 +18,11 @@ import (
"github.com/firebase/genkit/go/ai"
"github.com/firebase/genkit/go/genkit"
"github.com/firebase/genkit/go/internal"
"github.com/firebase/genkit/go/internal/base"
"github.com/firebase/genkit/go/plugins/internal/gemini"
"github.com/firebase/genkit/go/plugins/internal/uri"
"github.com/google/generative-ai-go/genai"
"github.com/invopop/jsonschema"
"google.golang.org/api/iterator"
"google.golang.org/api/option"
)
Expand All @@ -36,6 +39,57 @@ var state struct {
initted bool
}

type HarmCategory int32

const (
// HarmCategoryUnspecified means category is unspecified.
HarmCategoryUnspecified HarmCategory = 0
// HarmCategoryDerogatory means negative or harmful comments targeting identity and/or protected attribute.
HarmCategoryDerogatory HarmCategory = 1
// HarmCategoryToxicity means content that is rude, disrespectful, or profane.
HarmCategoryToxicity HarmCategory = 2
// HarmCategoryViolence means describes scenarios depicting violence against an individual or group, or
// general descriptions of gore.
HarmCategoryViolence HarmCategory = 3
// HarmCategorySexual means contains references to sexual acts or other lewd content.
HarmCategorySexual HarmCategory = 4
// HarmCategoryMedical means promotes unchecked medical advice.
HarmCategoryMedical HarmCategory = 5
// HarmCategoryDangerous means dangerous content that promotes, facilitates, or encourages harmful acts.
HarmCategoryDangerous HarmCategory = 6
// HarmCategoryHarassment means harasment content.
HarmCategoryHarassment HarmCategory = 7
// HarmCategoryHateSpeech means hate speech and content.
HarmCategoryHateSpeech HarmCategory = 8
// HarmCategorySexuallyExplicit means sexually explicit content.
HarmCategorySexuallyExplicit HarmCategory = 9
// HarmCategoryDangerousContent means dangerous content.
HarmCategoryDangerousContent HarmCategory = 10
)

// HarmBlockThreshold specifies block at and beyond a specified harm probability.
type HarmBlockThreshold int32

const (
// HarmBlockUnspecified means threshold is unspecified.
HarmBlockUnspecified HarmBlockThreshold = 0
// HarmBlockLowAndAbove means content with NEGLIGIBLE will be allowed.
HarmBlockLowAndAbove HarmBlockThreshold = 1
// HarmBlockMediumAndAbove means content with NEGLIGIBLE and LOW will be allowed.
HarmBlockMediumAndAbove HarmBlockThreshold = 2
// HarmBlockOnlyHigh means content with NEGLIGIBLE, LOW, and MEDIUM will be allowed.
HarmBlockOnlyHigh HarmBlockThreshold = 3
// HarmBlockNone means all content will be allowed.
HarmBlockNone HarmBlockThreshold = 4
)

type SafetySetting struct {
// Required. The category for this setting.
Category HarmCategory
// Required. Controls the probability threshold at which harm is blocked.
Threshold HarmBlockThreshold
}

var (
supportedModels = map[string]ai.ModelInfo{
"gemini-1.5-flash": {
Expand Down Expand Up @@ -89,6 +143,12 @@ var (
}
)

// GenerationGoogleAIConfig extends GenerationCommonConfig with Google AI specific settings.
type GenerationGoogleAIConfig struct {
ai.GenerationCommonConfig
SafetySettings []*SafetySetting
}

// Config is the configuration for the plugin.
type Config struct {
// The API key to access the service.
Expand Down Expand Up @@ -179,9 +239,10 @@ func DefineModel(g *genkit.Genkit, name string, info *ai.ModelInfo) (ai.Model, e
// requires state.mu
func defineModel(g *genkit.Genkit, name string, info ai.ModelInfo) ai.Model {
meta := &ai.ModelInfo{
Label: labelPrefix + " - " + name,
Supports: info.Supports,
Versions: info.Versions,
Label: labelPrefix + " - " + name,
Supports: info.Supports,
Versions: info.Versions,
ConfigSchema: convertConfigSchemaToMap(&GenerationGoogleAIConfig{}),
}
return genkit.DefineModel(g, provider, name, meta, func(
ctx context.Context,
Expand All @@ -192,6 +253,16 @@ func defineModel(g *genkit.Genkit, name string, info ai.ModelInfo) ai.Model {
})
}

func convertConfigSchemaToMap(config any) map[string]any {
r := jsonschema.Reflector{
DoNotReference: true, // Prevent $ref usage
ExpandedStruct: true, // Include all fields directly
}
schema := r.Reflect(config)
result := base.SchemaAsMap(schema)
return result
}

// IsDefinedModel reports whether the named [Model] is defined by this plugin.
func IsDefinedModel(g *genkit.Genkit, name string) bool {
return genkit.IsDefinedModel(g, provider, name)
Expand Down Expand Up @@ -338,26 +409,90 @@ func generate(
return r, nil
}

func newModel(client *genai.Client, model string, input *ai.ModelRequest) (*genai.GenerativeModel, error) {
gm := client.GenerativeModel(model)
gm.SetCandidateCount(1)
if c, ok := input.Config.(*ai.GenerationCommonConfig); ok && c != nil {
if c.MaxOutputTokens != 0 {
gm.SetMaxOutputTokens(int32(c.MaxOutputTokens))
}
if len(c.StopSequences) > 0 {
gm.StopSequences = c.StopSequences
}
if c.Temperature != 0 {
gm.SetTemperature(float32(c.Temperature))
}
if c.TopK != 0 {
gm.SetTopK(int32(c.TopK))
}
if c.TopP != 0 {
gm.SetTopP(float32(c.TopP))
func mapToStruct(m map[string]any, v any) error {
jsonData, err := json.Marshal(m)
if err != nil {
return err
}
return json.Unmarshal(jsonData, v)
}

// applyGenerationConfig applies the common generation configuration to the model
func applyGenerationConfig(gm *genai.GenerativeModel, c GenerationGoogleAIConfig) {
if c.MaxOutputTokens != 0 {
gm.SetMaxOutputTokens(int32(c.MaxOutputTokens))
}
if len(c.StopSequences) > 0 {
gm.StopSequences = c.StopSequences
}
if c.Temperature != 0 {
gm.SetTemperature(float32(c.Temperature))
}
if c.TopK != 0 {
gm.SetTopK(int32(c.TopK))
}
if c.TopP != 0 {
gm.SetTopP(float32(c.TopP))
}
if len(c.SafetySettings) > 0 {
gm.SafetySettings = convertSafetySettings(c.SafetySettings)
}
}

// extractConfigFromInput converts any supported config type to GoogleAIConfig
func extractConfigFromInput(input *ai.ModelRequest) (GenerationGoogleAIConfig, error) {
var result GenerationGoogleAIConfig
switch config := input.Config.(type) {
case GenerationGoogleAIConfig:
return config, nil
case *GenerationGoogleAIConfig:
return *config, nil
case ai.GenerationCommonConfig:
result.MaxOutputTokens = config.MaxOutputTokens
result.StopSequences = config.StopSequences
result.Temperature = config.Temperature
result.TopK = config.TopK
result.TopP = config.TopP
result.Version = config.Version
return result, nil
case *ai.GenerationCommonConfig:
if config != nil {
result.MaxOutputTokens = config.MaxOutputTokens
result.StopSequences = config.StopSequences
result.Temperature = config.Temperature
result.TopK = config.TopK
result.TopP = config.TopP
result.Version = config.Version
}
return result, nil
case map[string]any:
// Todo: this will silently fail if extra parameters are passed, may want to expose errors
if err := mapToStruct(config, &result); err == nil {
return result, nil
} else {
return result, err
}
case nil:
// Empty but valid config
return result, nil
default:
return result, fmt.Errorf("unexpected config type: %T", input.Config)
}
}

func newModel(client *genai.Client, model string, input *ai.ModelRequest) (*genai.GenerativeModel, error) {
c, err := extractConfigFromInput(input)
if err != nil {
return nil, err
}

specifiedModel := model
if c.Version != "" {
specifiedModel = c.Version
}
gm := client.GenerativeModel(specifiedModel)
gm.SetCandidateCount(1)
applyGenerationConfig(gm, c)
for _, m := range input.Messages {
systemParts, err := convertParts(m.Content)
if err != nil {
Expand Down Expand Up @@ -658,3 +793,19 @@ func convertPart(p *ai.Part) (genai.Part, error) {
}

//copy:stop

// convertSafetySettings converts local SafetySetting to genai.SafetySetting
func convertSafetySettings(settings []*SafetySetting) []*genai.SafetySetting {
if len(settings) == 0 {
return nil
}

result := make([]*genai.SafetySetting, len(settings))
for i, s := range settings {
result[i] = &genai.SafetySetting{
Category: genai.HarmCategory(s.Category),
Threshold: genai.HarmBlockThreshold(s.Threshold),
}
}
return result
}
15 changes: 12 additions & 3 deletions go/samples/basic-gemini/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,18 @@ func main() {

resp, err := genkit.Generate(ctx, g,
ai.WithModel(m),
ai.WithConfig(&ai.GenerationCommonConfig{
Temperature: 1,
Version: "gemini-2.0-flash-001",
ai.WithConfig(&googleai.GenerationGoogleAIConfig{
GenerationCommonConfig: ai.GenerationCommonConfig{
Temperature: 1.0,
MaxOutputTokens: 256,
},
// Set custom safety settings - reduce restriction on harmfulness
SafetySettings: []*googleai.SafetySetting{
{
Category: googleai.HarmCategoryHarassment,
Threshold: googleai.HarmBlockMediumAndAbove,
},
},
}),
ai.WithTextPrompt(fmt.Sprintf(`Tell silly short jokes about %s`, input)))
if err != nil {
Expand Down
Loading