Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 8 additions & 18 deletions go/ai/formatter.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@
package ai

import (
"fmt"

"github.com/firebase/genkit/go/core"
"github.com/firebase/genkit/go/core/api"
)

Expand Down Expand Up @@ -71,27 +70,18 @@ func DefineFormat(r api.Registry, name string, formatter Formatter) {
// resolveFormat returns a [Formatter], either a default one or one from the registry.
func resolveFormat(reg api.Registry, schema map[string]any, format string) (Formatter, error) {
var formatter any

// If schema is set but no explicit format is set we default to json.
if schema != nil && format == "" {
formatter = reg.LookupValue("/format/" + OutputFormatJSON)
}

// If format is not set we default to text
if format == "" {
formatter = reg.LookupValue("/format/" + OutputFormatText)
}

// Lookup format in registry
if format != "" {
formatter = reg.LookupValue("/format/" + format)
if schema != nil {
format = OutputFormatJSON
} else {
format = OutputFormatText
}
}

formatter = reg.LookupValue("/format/" + format)
if f, ok := formatter.(Formatter); ok {
return f, nil
}

return nil, fmt.Errorf("output format %q is invalid", format)
return nil, core.NewError(core.INVALID_ARGUMENT, "output format %q is invalid", format)
}

// injectInstructions looks through the messages and injects formatting directives
Expand Down
70 changes: 37 additions & 33 deletions go/ai/prompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,43 +78,40 @@ func DefinePrompt(r api.Registry, name string, opts ...PromptOption) Prompt {
pOpts.Config = modelRef.Config()
}

meta := p.Metadata
if meta == nil {
meta = map[string]any{}
}

var tools []string
for _, value := range pOpts.commonGenOptions.Tools {
tools = append(tools, value.Name())
}

inputMeta := map[string]any{}
if p.InputSchema != nil {
inputMeta["schema"] = p.InputSchema
metadata := p.Metadata
if metadata == nil {
metadata = map[string]any{}
}
metadata["type"] = api.ActionTypeExecutablePrompt

outputMeta := map[string]any{}
if p.OutputSchema != nil {
outputMeta["schema"] = p.OutputSchema
baseName := name
if idx := strings.LastIndex(name, "."); idx != -1 {
baseName = name[:idx]
}

promptMeta := map[string]any{
"type": api.ActionTypeExecutablePrompt,
"prompt": map[string]any{
"name": name,
"description": p.Description,
"model": modelName,
"config": p.Config,
"input": inputMeta,
"output": outputMeta,
"defaultInput": p.DefaultInput,
"tools": tools,
"maxTurns": p.MaxTurns,
},
promptMetadata := map[string]any{
"name": baseName,
"description": p.Description,
"model": modelName,
"config": p.Config,
"input": map[string]any{"schema": p.InputSchema},
"output": map[string]any{"schema": p.OutputSchema},
"defaultInput": p.DefaultInput,
"tools": tools,
"maxTurns": p.MaxTurns,
}
if m, ok := metadata["prompt"].(map[string]any); ok {
maps.Copy(m, promptMetadata)
} else {
metadata["prompt"] = promptMetadata
}
maps.Copy(meta, promptMeta)

p.ActionDef = *core.DefineAction(r, name, api.ActionTypeExecutablePrompt, meta, p.InputSchema, p.buildRequest)
p.ActionDef = *core.DefineAction(r, name, api.ActionTypeExecutablePrompt, metadata, p.InputSchema, p.buildRequest)

return p
}
Expand Down Expand Up @@ -641,16 +638,23 @@ func LoadPrompt(r api.Registry, dir, filename, namespace string) Prompt {
toolRefs[i] = ToolName(tool)
}

promptMetadata := map[string]any{
"template": parsedPrompt.Template,
promptOptMetadata := metadata.Metadata
if promptOptMetadata == nil {
promptOptMetadata = make(map[string]any)
}
maps.Copy(promptMetadata, metadata.Metadata)

promptOptMetadata := map[string]any{
"type": "prompt",
"prompt": promptMetadata,
var promptMetadata map[string]any
if m, ok := promptOptMetadata["prompt"].(map[string]any); ok {
promptMetadata = m
} else {
promptMetadata = make(map[string]any)
}
promptMetadata["template"] = parsedPrompt.Template
if variant != "" {
promptMetadata["variant"] = variant
}
maps.Copy(promptOptMetadata, metadata.Metadata)
promptOptMetadata["prompt"] = promptMetadata
promptOptMetadata["type"] = api.ActionTypeExecutablePrompt

opts := &promptOptions{
commonGenOptions: commonGenOptions{
Expand Down
12 changes: 12 additions & 0 deletions go/ai/prompt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1082,6 +1082,18 @@ Hello, {{name}}!
if prompt == nil {
t.Fatalf("Prompt was not registered")
}

// Verify that the metadata name does NOT include the variant
promptMetadata, ok := prompt.(api.Action).Desc().Metadata["prompt"].(map[string]any)
if !ok {
t.Fatalf("Expected Metadata['prompt'] to be a map")
}
if promptMetadata["name"] != "test-namespace/example" {
t.Errorf("Expected metadata name 'test-namespace/example', got '%s'", promptMetadata["name"])
}
if promptMetadata["variant"] != "variant" {
t.Errorf("Expected variant 'variant', got '%s'", promptMetadata["variant"])
}
}

func TestLoadPromptFolder(t *testing.T) {
Expand Down
Loading