Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 57 additions & 34 deletions go/ai/prompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,31 @@ func convertToPartPointers(parts []dotprompt.Part) ([]*Part, error) {
return result, nil
}

// convertDotpromptMessages converts []dotprompt.Message to []*Message
func convertDotpromptMessages(msgs []dotprompt.Message) ([]*Message, error) {
result := make([]*Message, 0, len(msgs))
for _, msg := range msgs {
parts, err := convertToPartPointers(msg.Content)
if err != nil {
return nil, err
}
// Filter out nil parts
filteredParts := make([]*Part, 0, len(parts))
for _, p := range parts {
if p != nil {
filteredParts = append(filteredParts, p)
}
}
if len(filteredParts) > 0 {
result = append(result, &Message{
Role: Role(msg.Role),
Content: filteredParts,
})
}
}
return result, nil
}

// LoadPromptDir loads prompts and partials from the input directory for the given namespace.
func LoadPromptDir(r api.Registry, dir string, namespace string) {
useDefaultDir := false
Expand Down Expand Up @@ -708,49 +733,47 @@ func LoadPrompt(r api.Registry, dir, filename, namespace string) Prompt {

key := promptKey(name, variant, namespace)

dpMessages, err := dotprompt.ToMessages(parsedPrompt.Template, &dotprompt.DataArgument{})
// Defer template rendering to execution time.
// See: https://github.com/firebase/genkit/issues/3924
templateText := parsedPrompt.Template
compiledTemplate, err := dp.Compile(templateText, &dotprompt.PromptMetadata{
Input: dotprompt.PromptMetadataInput{
Default: opts.DefaultInput,
},
})
if err != nil {
slog.Error("Failed to convert prompt template to messages", "file", sourceFile, "error", err)
slog.Error("Failed to compile prompt template", "file", sourceFile, "error", err)
return nil
}

var systemText string
var nonSystemMessages []*Message
for _, dpMsg := range dpMessages {
parts, err := convertToPartPointers(dpMsg.Content)
promptOpts := []PromptOption{opts}
promptOpts = append(promptOpts, WithMessagesFn(func(ctx context.Context, input any) ([]*Message, error) {
inputMap, err := buildVariables(input)
if err != nil {
slog.Error("Failed to convert message parts", "file", sourceFile, "error", err)
return nil
return nil, err
}

role := Role(dpMsg.Role)
if role == RoleSystem {
var textParts []string
for _, part := range parts {
if part.IsText() {
textParts = append(textParts, part.Text)
}
}

if len(textParts) > 0 {
systemText = strings.Join(textParts, " ")
}
} else {
nonSystemMessages = append(nonSystemMessages, &Message{Role: role, Content: parts})
// Prepare the data context for rendering
dataContext := map[string]any{}
actionCtx := core.FromContext(ctx)
maps.Copy(dataContext, actionCtx)

// Render with actual input values at execution time
rendered, err := compiledTemplate(&dotprompt.DataArgument{
Input: inputMap,
Context: dataContext,
}, &dotprompt.PromptMetadata{
Input: dotprompt.PromptMetadataInput{
Default: opts.DefaultInput,
},
})
if err != nil {
return nil, fmt.Errorf("failed to render template: %w", err)
}
}

promptOpts := []PromptOption{opts}

if systemText != "" {
promptOpts = append(promptOpts, WithSystem(systemText))
}

if len(nonSystemMessages) > 0 {
promptOpts = append(promptOpts, WithMessages(nonSystemMessages...))
} else if systemText == "" {
promptOpts = append(promptOpts, WithPrompt(parsedPrompt.Template))
}
// Convert dotprompt messages to ai messages
return convertDotpromptMessages(rendered.Messages)
}))

prompt := DefinePrompt(r, key, promptOpts...)

Expand Down
172 changes: 172 additions & 0 deletions go/ai/prompt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1500,3 +1500,175 @@ func TestWithOutputSchemaName_DefinePrompt_Missing(t *testing.T) {
t.Errorf("Expected error 'schema \"MissingSchema\" not found', got: %v", err)
}
}

// TestLoadPromptTemplateVariableSubstitution tests that template variables are
// properly substituted with actual input values at execution time.
// This is a regression test for https://github.com/firebase/genkit/issues/3924
func TestLoadPromptTemplateVariableSubstitution(t *testing.T) {
t.Run("single role", func(t *testing.T) {
tempDir := t.TempDir()

mockPromptFile := filepath.Join(tempDir, "greeting.prompt")
mockPromptContent := `---
model: test/chat
description: A greeting prompt with variables
---
Hello {{name}}, welcome to {{place}}!
`

if err := os.WriteFile(mockPromptFile, []byte(mockPromptContent), 0644); err != nil {
t.Fatalf("Failed to create mock prompt file: %v", err)
}

prompt := LoadPrompt(registry.New(), tempDir, "greeting.prompt", "template-var-test")

// Test with first set of input values
actionOpts1, err := prompt.Render(context.Background(), map[string]any{
"name": "Alice",
"place": "Wonderland",
})
if err != nil {
t.Fatalf("Failed to render prompt with first input: %v", err)
}

if len(actionOpts1.Messages) != 1 {
t.Fatalf("Expected 1 message, got %d", len(actionOpts1.Messages))
}

text1 := actionOpts1.Messages[0].Content[0].Text
if !strings.Contains(text1, "Alice") {
t.Errorf("Expected message to contain 'Alice', got: %s", text1)
}
if !strings.Contains(text1, "Wonderland") {
t.Errorf("Expected message to contain 'Wonderland', got: %s", text1)
}

// Test with second set of input values (different from first)
actionOpts2, err := prompt.Render(context.Background(), map[string]any{
"name": "Bob",
"place": "Paradise",
})
if err != nil {
t.Fatalf("Failed to render prompt with second input: %v", err)
}

if len(actionOpts2.Messages) != 1 {
t.Fatalf("Expected 1 message, got %d", len(actionOpts2.Messages))
}

text2 := actionOpts2.Messages[0].Content[0].Text
if !strings.Contains(text2, "Bob") {
t.Errorf("Expected message to contain 'Bob', got: %s", text2)
}
if !strings.Contains(text2, "Paradise") {
t.Errorf("Expected message to contain 'Paradise', got: %s", text2)
}

// Critical: Ensure the second render did NOT use the first input values
if strings.Contains(text2, "Alice") {
t.Errorf("BUG: Second render contains 'Alice' from first input! Got: %s", text2)
}
if strings.Contains(text2, "Wonderland") {
t.Errorf("BUG: Second render contains 'Wonderland' from first input! Got: %s", text2)
}
})

t.Run("multi role", func(t *testing.T) {
tempDir := t.TempDir()

mockPromptFile := filepath.Join(tempDir, "multi_role.prompt")
mockPromptContent := `---
model: test/chat
description: A multi-role prompt with variables
---
<<<dotprompt:role:system>>>
You are a {{personality}} assistant.

<<<dotprompt:role:user>>>
Hello {{name}}, please help me with {{task}}.
`

if err := os.WriteFile(mockPromptFile, []byte(mockPromptContent), 0644); err != nil {
t.Fatalf("Failed to create mock prompt file: %v", err)
}

prompt := LoadPrompt(registry.New(), tempDir, "multi_role.prompt", "multi-role-var-test")

// Test with first set of input values
actionOpts1, err := prompt.Render(context.Background(), map[string]any{
"personality": "helpful",
"name": "Alice",
"task": "coding",
})
if err != nil {
t.Fatalf("Failed to render prompt with first input: %v", err)
}

if len(actionOpts1.Messages) != 2 {
t.Fatalf("Expected 2 messages, got %d", len(actionOpts1.Messages))
}

// Check system message
systemMsg := actionOpts1.Messages[0]
if systemMsg.Role != RoleSystem {
t.Errorf("Expected first message role to be 'system', got '%s'", systemMsg.Role)
}
systemText := systemMsg.Content[0].Text
if !strings.Contains(systemText, "helpful") {
t.Errorf("Expected system message to contain 'helpful', got: %s", systemText)
}

// Check user message
userMsg := actionOpts1.Messages[1]
if userMsg.Role != RoleUser {
t.Errorf("Expected second message role to be 'user', got '%s'", userMsg.Role)
}
userText := userMsg.Content[0].Text
if !strings.Contains(userText, "Alice") {
t.Errorf("Expected user message to contain 'Alice', got: %s", userText)
}
if !strings.Contains(userText, "coding") {
t.Errorf("Expected user message to contain 'coding', got: %s", userText)
}

// Test with second set of input values (different from first)
actionOpts2, err := prompt.Render(context.Background(), map[string]any{
"personality": "professional",
"name": "Bob",
"task": "writing",
})
if err != nil {
t.Fatalf("Failed to render prompt with second input: %v", err)
}

if len(actionOpts2.Messages) != 2 {
t.Fatalf("Expected 2 messages, got %d", len(actionOpts2.Messages))
}

// Check system message with new values
systemMsg2 := actionOpts2.Messages[0]
systemText2 := systemMsg2.Content[0].Text
if !strings.Contains(systemText2, "professional") {
t.Errorf("Expected system message to contain 'professional', got: %s", systemText2)
}
if strings.Contains(systemText2, "helpful") {
t.Errorf("BUG: Second render system message contains 'helpful' from first input! Got: %s", systemText2)
}

// Check user message with new values
userMsg2 := actionOpts2.Messages[1]
userText2 := userMsg2.Content[0].Text
if !strings.Contains(userText2, "Bob") {
t.Errorf("Expected user message to contain 'Bob', got: %s", userText2)
}
if !strings.Contains(userText2, "writing") {
t.Errorf("Expected user message to contain 'writing', got: %s", userText2)
}
if strings.Contains(userText2, "Alice") {
t.Errorf("BUG: Second render user message contains 'Alice' from first input! Got: %s", userText2)
}
if strings.Contains(userText2, "coding") {
t.Errorf("BUG: Second render user message contains 'coding' from first input! Got: %s", userText2)
}
})
}
Loading