Skip to content
62 changes: 22 additions & 40 deletions go/ai/prompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ import (
"encoding/json"
"errors"
"fmt"
"io/fs"
"log/slog"
"maps"
"os"
"path/filepath"
"path"
"reflect"
"strings"

Expand Down Expand Up @@ -517,70 +517,54 @@ func convertToPartPointers(parts []dotprompt.Part) ([]*Part, error) {
return result, nil
}

// LoadPromptDir loads prompts and partials from the input directory for the given namespace.
func LoadPromptDir(r api.Registry, dir string, namespace string) {
useDefaultDir := false
if dir == "" {
dir = "./prompts"
useDefaultDir = true
// LoadPromptDirFromFS loads prompts and partials from a filesystem for the given namespace.
// The fsys parameter should be an fs.FS implementation (e.g., embed.FS or os.DirFS).
// The dir parameter specifies the directory within the filesystem where prompts are located.
func LoadPromptDirFromFS(r api.Registry, fsys fs.FS, dir, namespace string) {
if fsys == nil {
panic(errors.New("no prompt filesystem provided"))
}

path, err := filepath.Abs(dir)
if err != nil {
if !useDefaultDir {
panic(fmt.Errorf("failed to resolve prompt directory %q: %w", dir, err))
}
slog.Debug("default prompt directory not found, skipping loading .prompt files", "dir", dir)
return
if _, err := fs.Stat(fsys, dir); err != nil {
panic(fmt.Errorf("failed to access prompt directory %q in filesystem: %w", dir, err))
}

if _, err := os.Stat(path); os.IsNotExist(err) {
if !useDefaultDir {
panic(fmt.Errorf("failed to resolve prompt directory %q: %w", dir, err))
}
slog.Debug("Default prompt directory not found, skipping loading .prompt files", "dir", dir)
return
}

loadPromptDir(r, path, namespace)
}

// loadPromptDir recursively loads prompts and partials from the directory.
func loadPromptDir(r api.Registry, dir string, namespace string) {
entries, err := os.ReadDir(dir)
entries, err := fs.ReadDir(fsys, dir)
if err != nil {
panic(fmt.Errorf("failed to read prompt directory structure: %w", err))
}

for _, entry := range entries {
filename := entry.Name()
path := filepath.Join(dir, filename)
filePath := path.Join(dir, filename)
if entry.IsDir() {
loadPromptDir(r, path, namespace)
LoadPromptDirFromFS(r, fsys, filePath, namespace)
} else if strings.HasSuffix(filename, ".prompt") {
if strings.HasPrefix(filename, "_") {
partialName := strings.TrimSuffix(filename[1:], ".prompt")
source, err := os.ReadFile(path)
source, err := fs.ReadFile(fsys, filePath)
if err != nil {
slog.Error("Failed to read partial file", "error", err)
continue
}
r.RegisterPartial(partialName, string(source))
slog.Debug("Registered Dotprompt partial", "name", partialName, "file", path)
slog.Debug("Registered Dotprompt partial", "name", partialName, "file", filePath)
} else {
LoadPrompt(r, dir, filename, namespace)
LoadPromptFromFS(r, fsys, dir, filename, namespace)
}
}
}
}

// LoadPrompt loads a single prompt into the registry.
func LoadPrompt(r api.Registry, dir, filename, namespace string) Prompt {
// LoadPromptFromFS loads a single prompt from a filesystem into the registry.
// The fsys parameter should be an fs.FS implementation (e.g., embed.FS or os.DirFS).
// The dir parameter specifies the directory within the filesystem where the prompt is located.
func LoadPromptFromFS(r api.Registry, fsys fs.FS, dir, filename, namespace string) Prompt {
name := strings.TrimSuffix(filename, ".prompt")
name, variant, _ := strings.Cut(name, ".")

sourceFile := filepath.Join(dir, filename)
source, err := os.ReadFile(sourceFile)
sourceFile := path.Join(dir, filename)
source, err := fs.ReadFile(fsys, sourceFile)
if err != nil {
slog.Error("Failed to read prompt file", "file", sourceFile, "error", err)
return nil
Expand Down Expand Up @@ -696,12 +680,10 @@ func LoadPrompt(r api.Registry, dir, filename, namespace string) Prompt {

promptOpts := []PromptOption{opts}

// Add system prompt if found
if systemText != "" {
promptOpts = append(promptOpts, WithSystem(systemText))
}

// If there are non-system messages, use WithMessages, otherwise use WithPrompt for template
if len(nonSystemMessages) > 0 {
promptOpts = append(promptOpts, WithMessages(nonSystemMessages...))
} else if systemText == "" {
Expand Down
Loading
Loading