Skip to content

Commit be12117

Browse files
authored
fix(go/ai): genkit ignores dotprompt-defined roles (#3780)
1 parent 33f1f6a commit be12117

File tree

2 files changed

+109
-1
lines changed

2 files changed

+109
-1
lines changed

‎go/ai/prompt.go‎

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -624,7 +624,54 @@ func LoadPrompt(r api.Registry, dir, filename, namespace string) Prompt {
624624
}
625625

626626
key := promptKey(name, variant, namespace)
627-
prompt := DefinePrompt(r, key, opts, WithPrompt(parsedPrompt.Template))
627+
628+
dpMessages, err := dotprompt.ToMessages(parsedPrompt.Template, &dotprompt.DataArgument{})
629+
if err != nil {
630+
slog.Error("Failed to convert prompt template to messages", "file", sourceFile, "error", err)
631+
return nil
632+
}
633+
634+
var systemText string
635+
var nonSystemMessages []*Message
636+
for _, dpMsg := range dpMessages {
637+
parts, err := convertToPartPointers(dpMsg.Content)
638+
if err != nil {
639+
slog.Error("Failed to convert message parts", "file", sourceFile, "error", err)
640+
return nil
641+
}
642+
643+
role := Role(dpMsg.Role)
644+
if role == RoleSystem {
645+
var textParts []string
646+
for _, part := range parts {
647+
if part.IsText() {
648+
textParts = append(textParts, part.Text)
649+
}
650+
}
651+
652+
if len(textParts) > 0 {
653+
systemText = strings.Join(textParts, " ")
654+
}
655+
} else {
656+
nonSystemMessages = append(nonSystemMessages, &Message{Role: role, Content: parts})
657+
}
658+
}
659+
660+
promptOpts := []PromptOption{opts}
661+
662+
// Add system prompt if found
663+
if systemText != "" {
664+
promptOpts = append(promptOpts, WithSystem(systemText))
665+
}
666+
667+
// If there are non-system messages, use WithMessages, otherwise use WithPrompt for template
668+
if len(nonSystemMessages) > 0 {
669+
promptOpts = append(promptOpts, WithMessages(nonSystemMessages...))
670+
} else if systemText == "" {
671+
promptOpts = append(promptOpts, WithPrompt(parsedPrompt.Template))
672+
}
673+
674+
prompt := DefinePrompt(r, key, promptOpts...)
628675

629676
slog.Debug("Registered Dotprompt", "name", key, "file", sourceFile)
630677

‎go/ai/prompt_test.go‎

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1148,3 +1148,64 @@ Hello!
11481148
t.Fatalf("Failed to execute prompt: %v", err)
11491149
}
11501150
}
1151+
1152+
func TestMultiMessagesRenderPrompt(t *testing.T) {
1153+
tempDir := t.TempDir()
1154+
1155+
mockPromptFile := filepath.Join(tempDir, "example.prompt")
1156+
mockPromptContent := `---
1157+
model: test/chat
1158+
description: A test prompt
1159+
---
1160+
<<<dotprompt:role:system>>>
1161+
You are a pirate!
1162+
1163+
<<<dotprompt:role:user>>>
1164+
Hello!
1165+
`
1166+
1167+
if err := os.WriteFile(mockPromptFile, []byte(mockPromptContent), 0644); err != nil {
1168+
t.Fatalf("Failed to create mock prompt file: %v", err)
1169+
}
1170+
1171+
prompt := LoadPrompt(registry.New(), tempDir, "example.prompt", "multi-namespace-roles")
1172+
1173+
actionOpts, err := prompt.Render(context.Background(), map[string]any{})
1174+
if err != nil {
1175+
t.Fatalf("Failed to execute prompt: %v", err)
1176+
}
1177+
1178+
// Check that actionOpts is not nil
1179+
if actionOpts == nil {
1180+
t.Fatal("Expected actionOpts to be non-nil")
1181+
}
1182+
1183+
// Check that we have exactly 2 messages (system and user)
1184+
if len(actionOpts.Messages) != 2 {
1185+
t.Fatalf("Expected 2 messages, got %d", len(actionOpts.Messages))
1186+
}
1187+
1188+
// Check first message (system role)
1189+
systemMsg := actionOpts.Messages[0]
1190+
if systemMsg.Role != RoleSystem {
1191+
t.Errorf("Expected first message role to be 'system', got '%s'", systemMsg.Role)
1192+
}
1193+
if len(systemMsg.Content) == 0 {
1194+
t.Fatal("Expected system message to have content")
1195+
}
1196+
if strings.TrimSpace(systemMsg.Content[0].Text) != "You are a pirate!" {
1197+
t.Errorf("Expected system message text to be 'You are a pirate!', got '%s'", systemMsg.Content[0].Text)
1198+
}
1199+
1200+
// Check second message (user role)
1201+
userMsg := actionOpts.Messages[1]
1202+
if userMsg.Role != RoleUser {
1203+
t.Errorf("Expected second message role to be 'user', got '%s'", userMsg.Role)
1204+
}
1205+
if len(userMsg.Content) == 0 {
1206+
t.Fatal("Expected user message to have content")
1207+
}
1208+
if strings.TrimSpace(userMsg.Content[0].Text) != "Hello!" {
1209+
t.Errorf("Expected user message text to be 'Hello!', got '%s'", userMsg.Content[0].Text)
1210+
}
1211+
}

0 commit comments

Comments
 (0)