Skip to content

Commit bce221e

Browse files
authored
fix(go/ai): fixed format resolution + prompt variants metadata (#3931)
1 parent b0ba327 commit bce221e

File tree

3 files changed

+57
-51
lines changed

3 files changed

+57
-51
lines changed

‎go/ai/formatter.go‎

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,7 @@
1515
package ai
1616

1717
import (
18-
"fmt"
19-
18+
"github.com/firebase/genkit/go/core"
2019
"github.com/firebase/genkit/go/core/api"
2120
)
2221

@@ -95,27 +94,18 @@ func DefineFormat(r api.Registry, name string, formatter Formatter) {
9594
// resolveFormat returns a [Formatter], either a default one or one from the registry.
9695
func resolveFormat(reg api.Registry, schema map[string]any, format string) (Formatter, error) {
9796
var formatter any
98-
99-
// If schema is set but no explicit format is set we default to json.
100-
if schema != nil && format == "" {
101-
formatter = reg.LookupValue("/format/" + OutputFormatJSON)
102-
}
103-
104-
// If format is not set we default to text
10597
if format == "" {
106-
formatter = reg.LookupValue("/format/" + OutputFormatText)
107-
}
108-
109-
// Lookup format in registry
110-
if format != "" {
111-
formatter = reg.LookupValue("/format/" + format)
98+
if schema != nil {
99+
format = OutputFormatJSON
100+
} else {
101+
format = OutputFormatText
102+
}
112103
}
113-
104+
formatter = reg.LookupValue("/format/" + format)
114105
if f, ok := formatter.(Formatter); ok {
115106
return f, nil
116107
}
117-
118-
return nil, fmt.Errorf("output format %q is invalid", format)
108+
return nil, core.NewError(core.INVALID_ARGUMENT, "output format %q is invalid", format)
119109
}
120110

121111
// injectInstructions looks through the messages and injects formatting directives

‎go/ai/prompt.go‎

Lines changed: 37 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -78,43 +78,40 @@ func DefinePrompt(r api.Registry, name string, opts ...PromptOption) Prompt {
7878
pOpts.Config = modelRef.Config()
7979
}
8080

81-
meta := p.Metadata
82-
if meta == nil {
83-
meta = map[string]any{}
84-
}
85-
8681
var tools []string
8782
for _, value := range pOpts.commonGenOptions.Tools {
8883
tools = append(tools, value.Name())
8984
}
9085

91-
inputMeta := map[string]any{}
92-
if p.InputSchema != nil {
93-
inputMeta["schema"] = p.InputSchema
86+
metadata := p.Metadata
87+
if metadata == nil {
88+
metadata = map[string]any{}
9489
}
90+
metadata["type"] = api.ActionTypeExecutablePrompt
9591

96-
outputMeta := map[string]any{}
97-
if p.OutputSchema != nil {
98-
outputMeta["schema"] = p.OutputSchema
92+
baseName := name
93+
if idx := strings.LastIndex(name, "."); idx != -1 {
94+
baseName = name[:idx]
9995
}
10096

101-
promptMeta := map[string]any{
102-
"type": api.ActionTypeExecutablePrompt,
103-
"prompt": map[string]any{
104-
"name": name,
105-
"description": p.Description,
106-
"model": modelName,
107-
"config": p.Config,
108-
"input": inputMeta,
109-
"output": outputMeta,
110-
"defaultInput": p.DefaultInput,
111-
"tools": tools,
112-
"maxTurns": p.MaxTurns,
113-
},
97+
promptMetadata := map[string]any{
98+
"name": baseName,
99+
"description": p.Description,
100+
"model": modelName,
101+
"config": p.Config,
102+
"input": map[string]any{"schema": p.InputSchema},
103+
"output": map[string]any{"schema": p.OutputSchema},
104+
"defaultInput": p.DefaultInput,
105+
"tools": tools,
106+
"maxTurns": p.MaxTurns,
107+
}
108+
if m, ok := metadata["prompt"].(map[string]any); ok {
109+
maps.Copy(m, promptMetadata)
110+
} else {
111+
metadata["prompt"] = promptMetadata
114112
}
115-
maps.Copy(meta, promptMeta)
116113

117-
p.ActionDef = *core.DefineAction(r, name, api.ActionTypeExecutablePrompt, meta, p.InputSchema, p.buildRequest)
114+
p.ActionDef = *core.DefineAction(r, name, api.ActionTypeExecutablePrompt, metadata, p.InputSchema, p.buildRequest)
118115

119116
return p
120117
}
@@ -641,16 +638,23 @@ func LoadPrompt(r api.Registry, dir, filename, namespace string) Prompt {
641638
toolRefs[i] = ToolName(tool)
642639
}
643640

644-
promptMetadata := map[string]any{
645-
"template": parsedPrompt.Template,
641+
promptOptMetadata := metadata.Metadata
642+
if promptOptMetadata == nil {
643+
promptOptMetadata = make(map[string]any)
646644
}
647-
maps.Copy(promptMetadata, metadata.Metadata)
648645

649-
promptOptMetadata := map[string]any{
650-
"type": "prompt",
651-
"prompt": promptMetadata,
646+
var promptMetadata map[string]any
647+
if m, ok := promptOptMetadata["prompt"].(map[string]any); ok {
648+
promptMetadata = m
649+
} else {
650+
promptMetadata = make(map[string]any)
651+
}
652+
promptMetadata["template"] = parsedPrompt.Template
653+
if variant != "" {
654+
promptMetadata["variant"] = variant
652655
}
653-
maps.Copy(promptOptMetadata, metadata.Metadata)
656+
promptOptMetadata["prompt"] = promptMetadata
657+
promptOptMetadata["type"] = api.ActionTypeExecutablePrompt
654658

655659
opts := &promptOptions{
656660
commonGenOptions: commonGenOptions{

‎go/ai/prompt_test.go‎

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1082,6 +1082,18 @@ Hello, {{name}}!
10821082
if prompt == nil {
10831083
t.Fatalf("Prompt was not registered")
10841084
}
1085+
1086+
// Verify that the metadata name does NOT include the variant
1087+
promptMetadata, ok := prompt.(api.Action).Desc().Metadata["prompt"].(map[string]any)
1088+
if !ok {
1089+
t.Fatalf("Expected Metadata['prompt'] to be a map")
1090+
}
1091+
if promptMetadata["name"] != "test-namespace/example" {
1092+
t.Errorf("Expected metadata name 'test-namespace/example', got '%s'", promptMetadata["name"])
1093+
}
1094+
if promptMetadata["variant"] != "variant" {
1095+
t.Errorf("Expected variant 'variant', got '%s'", promptMetadata["variant"])
1096+
}
10851097
}
10861098

10871099
func TestLoadPromptFolder(t *testing.T) {

0 commit comments

Comments
 (0)