Skip to content

Commit bf10054

Browse files
committed
fix(go): defer template rendering in LoadPrompt to execution time
Previously, LoadPrompt called ToMessages with an empty DataArgument at load time, causing template variables to be replaced with empty values. This meant all subsequent Execute() calls would use prompts with empty template variable values. This change defers template rendering to execution time by using WithMessagesFn. The closure captures the raw template text and compiles/renders it with actual input values when Execute() or Render() is called. The fix properly handles: 1. Template variable substitution with actual input values 2. Multi-role messages (<<<dotprompt:role:XXX>>> markers) 3. History insertion (<<<dotprompt:history>>> markers) Added convertDotpromptMessages helper to convert dotprompt.Message to ai.Message format. Added regression test TestLoadPromptTemplateVariableSubstitution to verify template variables are correctly substituted with different input values on multiple calls. Fixes #3924
1 parent 9dcde54 commit bf10054

File tree

2 files changed

+239
-36
lines changed

2 files changed

+239
-36
lines changed

‎go/ai/prompt.go‎

Lines changed: 67 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -517,6 +517,31 @@ func convertToPartPointers(parts []dotprompt.Part) ([]*Part, error) {
517517
return result, nil
518518
}
519519

520+
// convertDotpromptMessages converts []dotprompt.Message to []*Message
521+
func convertDotpromptMessages(msgs []dotprompt.Message) ([]*Message, error) {
522+
result := make([]*Message, 0, len(msgs))
523+
for _, msg := range msgs {
524+
parts, err := convertToPartPointers(msg.Content)
525+
if err != nil {
526+
return nil, err
527+
}
528+
// Filter out nil parts
529+
filteredParts := make([]*Part, 0, len(parts))
530+
for _, p := range parts {
531+
if p != nil {
532+
filteredParts = append(filteredParts, p)
533+
}
534+
}
535+
if len(filteredParts) > 0 {
536+
result = append(result, &Message{
537+
Role: Role(msg.Role),
538+
Content: filteredParts,
539+
})
540+
}
541+
}
542+
return result, nil
543+
}
544+
520545
// LoadPromptDir loads prompts and partials from the input directory for the given namespace.
521546
func LoadPromptDir(r api.Registry, dir string, namespace string) {
522547
useDefaultDir := false
@@ -662,51 +687,57 @@ func LoadPrompt(r api.Registry, dir, filename, namespace string) Prompt {
662687

663688
key := promptKey(name, variant, namespace)
664689

665-
dpMessages, err := dotprompt.ToMessages(parsedPrompt.Template, &dotprompt.DataArgument{})
690+
// Store the raw template text to defer rendering until Execute() is called.
691+
// This ensures template variables are properly substituted with actual input values.
692+
// Previously, ToMessages was called with empty DataArgument which caused template
693+
// variables to be replaced with empty values at load time.
694+
// See: https://github.com/firebase/genkit/issues/3924
695+
templateText := parsedPrompt.Template
696+
697+
promptOpts := []PromptOption{opts}
698+
699+
// Use WithMessagesFn to defer template rendering until execution time.
700+
// This approach properly handles:
701+
// 1. Template variable substitution with actual input values
702+
// 2. Multi-role messages (<<<dotprompt:role:XXX>>> markers)
703+
// 3. History insertion (<<<dotprompt:history>>> markers)
704+
compiledTemplate, err := dp.Compile(templateText, &dotprompt.PromptMetadata{
705+
Input: dotprompt.PromptMetadataInput{
706+
Default: opts.DefaultInput,
707+
},
708+
})
666709
if err != nil {
667-
slog.Error("Failed to convert prompt template to messages", "file", sourceFile, "error", err)
710+
slog.Error("Failed to compile prompt template", "file", sourceFile, "error", err)
668711
return nil
669712
}
670713

671-
var systemText string
672-
var nonSystemMessages []*Message
673-
for _, dpMsg := range dpMessages {
674-
parts, err := convertToPartPointers(dpMsg.Content)
714+
promptOpts = append(promptOpts, WithMessagesFn(func(ctx context.Context, input any) ([]*Message, error) {
715+
inputMap, err := buildVariables(input)
675716
if err != nil {
676-
slog.Error("Failed to convert message parts", "file", sourceFile, "error", err)
677-
return nil
717+
return nil, err
678718
}
679719

680-
role := Role(dpMsg.Role)
681-
if role == RoleSystem {
682-
var textParts []string
683-
for _, part := range parts {
684-
if part.IsText() {
685-
textParts = append(textParts, part.Text)
686-
}
687-
}
688-
689-
if len(textParts) > 0 {
690-
systemText = strings.Join(textParts, " ")
691-
}
692-
} else {
693-
nonSystemMessages = append(nonSystemMessages, &Message{Role: role, Content: parts})
720+
// Prepare the data context for rendering
721+
dataContext := map[string]any{}
722+
actionCtx := core.FromContext(ctx)
723+
maps.Copy(dataContext, actionCtx)
724+
725+
// Render with actual input values at execution time
726+
rendered, err := compiledTemplate(&dotprompt.DataArgument{
727+
Input: inputMap,
728+
Context: dataContext,
729+
}, &dotprompt.PromptMetadata{
730+
Input: dotprompt.PromptMetadataInput{
731+
Default: opts.DefaultInput,
732+
},
733+
})
734+
if err != nil {
735+
return nil, fmt.Errorf("failed to render template: %w", err)
694736
}
695-
}
696-
697-
promptOpts := []PromptOption{opts}
698737

699-
// Add system prompt if found
700-
if systemText != "" {
701-
promptOpts = append(promptOpts, WithSystem(systemText))
702-
}
703-
704-
// If there are non-system messages, use WithMessages, otherwise use WithPrompt for template
705-
if len(nonSystemMessages) > 0 {
706-
promptOpts = append(promptOpts, WithMessages(nonSystemMessages...))
707-
} else if systemText == "" {
708-
promptOpts = append(promptOpts, WithPrompt(parsedPrompt.Template))
709-
}
738+
// Convert dotprompt messages to ai messages
739+
return convertDotpromptMessages(rendered.Messages)
740+
}))
710741

711742
prompt := DefinePrompt(r, key, promptOpts...)
712743

‎go/ai/prompt_test.go‎

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1274,3 +1274,175 @@ Hello!
12741274
t.Errorf("Expected user message text to be 'Hello!', got '%s'", userMsg.Content[0].Text)
12751275
}
12761276
}
1277+
1278+
// TestLoadPromptTemplateVariableSubstitution tests that template variables are
1279+
// properly substituted with actual input values at execution time.
1280+
// This is a regression test for https://github.com/firebase/genkit/issues/3924
1281+
func TestLoadPromptTemplateVariableSubstitution(t *testing.T) {
1282+
t.Run("single role", func(t *testing.T) {
1283+
tempDir := t.TempDir()
1284+
1285+
mockPromptFile := filepath.Join(tempDir, "greeting.prompt")
1286+
mockPromptContent := `---
1287+
model: test/chat
1288+
description: A greeting prompt with variables
1289+
---
1290+
Hello {{name}}, welcome to {{place}}!
1291+
`
1292+
1293+
if err := os.WriteFile(mockPromptFile, []byte(mockPromptContent), 0644); err != nil {
1294+
t.Fatalf("Failed to create mock prompt file: %v", err)
1295+
}
1296+
1297+
prompt := LoadPrompt(registry.New(), tempDir, "greeting.prompt", "template-var-test")
1298+
1299+
// Test with first set of input values
1300+
actionOpts1, err := prompt.Render(context.Background(), map[string]any{
1301+
"name": "Alice",
1302+
"place": "Wonderland",
1303+
})
1304+
if err != nil {
1305+
t.Fatalf("Failed to render prompt with first input: %v", err)
1306+
}
1307+
1308+
if len(actionOpts1.Messages) != 1 {
1309+
t.Fatalf("Expected 1 message, got %d", len(actionOpts1.Messages))
1310+
}
1311+
1312+
text1 := actionOpts1.Messages[0].Content[0].Text
1313+
if !strings.Contains(text1, "Alice") {
1314+
t.Errorf("Expected message to contain 'Alice', got: %s", text1)
1315+
}
1316+
if !strings.Contains(text1, "Wonderland") {
1317+
t.Errorf("Expected message to contain 'Wonderland', got: %s", text1)
1318+
}
1319+
1320+
// Test with second set of input values (different from first)
1321+
actionOpts2, err := prompt.Render(context.Background(), map[string]any{
1322+
"name": "Bob",
1323+
"place": "Paradise",
1324+
})
1325+
if err != nil {
1326+
t.Fatalf("Failed to render prompt with second input: %v", err)
1327+
}
1328+
1329+
if len(actionOpts2.Messages) != 1 {
1330+
t.Fatalf("Expected 1 message, got %d", len(actionOpts2.Messages))
1331+
}
1332+
1333+
text2 := actionOpts2.Messages[0].Content[0].Text
1334+
if !strings.Contains(text2, "Bob") {
1335+
t.Errorf("Expected message to contain 'Bob', got: %s", text2)
1336+
}
1337+
if !strings.Contains(text2, "Paradise") {
1338+
t.Errorf("Expected message to contain 'Paradise', got: %s", text2)
1339+
}
1340+
1341+
// Critical: Ensure the second render did NOT use the first input values
1342+
if strings.Contains(text2, "Alice") {
1343+
t.Errorf("BUG: Second render contains 'Alice' from first input! Got: %s", text2)
1344+
}
1345+
if strings.Contains(text2, "Wonderland") {
1346+
t.Errorf("BUG: Second render contains 'Wonderland' from first input! Got: %s", text2)
1347+
}
1348+
})
1349+
1350+
t.Run("multi role", func(t *testing.T) {
1351+
tempDir := t.TempDir()
1352+
1353+
mockPromptFile := filepath.Join(tempDir, "multi_role.prompt")
1354+
mockPromptContent := `---
1355+
model: test/chat
1356+
description: A multi-role prompt with variables
1357+
---
1358+
<<<dotprompt:role:system>>>
1359+
You are a {{personality}} assistant.
1360+
1361+
<<<dotprompt:role:user>>>
1362+
Hello {{name}}, please help me with {{task}}.
1363+
`
1364+
1365+
if err := os.WriteFile(mockPromptFile, []byte(mockPromptContent), 0644); err != nil {
1366+
t.Fatalf("Failed to create mock prompt file: %v", err)
1367+
}
1368+
1369+
prompt := LoadPrompt(registry.New(), tempDir, "multi_role.prompt", "multi-role-var-test")
1370+
1371+
// Test with first set of input values
1372+
actionOpts1, err := prompt.Render(context.Background(), map[string]any{
1373+
"personality": "helpful",
1374+
"name": "Alice",
1375+
"task": "coding",
1376+
})
1377+
if err != nil {
1378+
t.Fatalf("Failed to render prompt with first input: %v", err)
1379+
}
1380+
1381+
if len(actionOpts1.Messages) != 2 {
1382+
t.Fatalf("Expected 2 messages, got %d", len(actionOpts1.Messages))
1383+
}
1384+
1385+
// Check system message
1386+
systemMsg := actionOpts1.Messages[0]
1387+
if systemMsg.Role != RoleSystem {
1388+
t.Errorf("Expected first message role to be 'system', got '%s'", systemMsg.Role)
1389+
}
1390+
systemText := systemMsg.Content[0].Text
1391+
if !strings.Contains(systemText, "helpful") {
1392+
t.Errorf("Expected system message to contain 'helpful', got: %s", systemText)
1393+
}
1394+
1395+
// Check user message
1396+
userMsg := actionOpts1.Messages[1]
1397+
if userMsg.Role != RoleUser {
1398+
t.Errorf("Expected second message role to be 'user', got '%s'", userMsg.Role)
1399+
}
1400+
userText := userMsg.Content[0].Text
1401+
if !strings.Contains(userText, "Alice") {
1402+
t.Errorf("Expected user message to contain 'Alice', got: %s", userText)
1403+
}
1404+
if !strings.Contains(userText, "coding") {
1405+
t.Errorf("Expected user message to contain 'coding', got: %s", userText)
1406+
}
1407+
1408+
// Test with second set of input values (different from first)
1409+
actionOpts2, err := prompt.Render(context.Background(), map[string]any{
1410+
"personality": "professional",
1411+
"name": "Bob",
1412+
"task": "writing",
1413+
})
1414+
if err != nil {
1415+
t.Fatalf("Failed to render prompt with second input: %v", err)
1416+
}
1417+
1418+
if len(actionOpts2.Messages) != 2 {
1419+
t.Fatalf("Expected 2 messages, got %d", len(actionOpts2.Messages))
1420+
}
1421+
1422+
// Check system message with new values
1423+
systemMsg2 := actionOpts2.Messages[0]
1424+
systemText2 := systemMsg2.Content[0].Text
1425+
if !strings.Contains(systemText2, "professional") {
1426+
t.Errorf("Expected system message to contain 'professional', got: %s", systemText2)
1427+
}
1428+
if strings.Contains(systemText2, "helpful") {
1429+
t.Errorf("BUG: Second render system message contains 'helpful' from first input! Got: %s", systemText2)
1430+
}
1431+
1432+
// Check user message with new values
1433+
userMsg2 := actionOpts2.Messages[1]
1434+
userText2 := userMsg2.Content[0].Text
1435+
if !strings.Contains(userText2, "Bob") {
1436+
t.Errorf("Expected user message to contain 'Bob', got: %s", userText2)
1437+
}
1438+
if !strings.Contains(userText2, "writing") {
1439+
t.Errorf("Expected user message to contain 'writing', got: %s", userText2)
1440+
}
1441+
if strings.Contains(userText2, "Alice") {
1442+
t.Errorf("BUG: Second render user message contains 'Alice' from first input! Got: %s", userText2)
1443+
}
1444+
if strings.Contains(userText2, "coding") {
1445+
t.Errorf("BUG: Second render user message contains 'coding' from first input! Got: %s", userText2)
1446+
}
1447+
})
1448+
}

0 commit comments

Comments
 (0)