Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions go/ai/prompt.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package ai

import (
"context"

"github.com/firebase/genkit/go/core"
)

// PromptRequest is a request to execute a prompt template and
// pass the result to a [Generator].
type PromptRequest struct {
// Input fields for the prompt. If not nil this should be a struct
// or pointer to a struct that matches the prompt's input schema.
Variables any `json:"variables,omitempty"`
// Number of candidates to return; if 0, will be taken
// from the prompt config; if still 0, will use 1.
Candidates int `json:"candidates,omitempty"`
// Generator Configuration. If nil will be taken from the prompt config.
Config *GenerationCommonConfig `json:"config,omitempty"`
// Context to pass to model, if any.
Context []any `json:"context,omitempty"`
// The model to use. This overrides any model specified by the prompt.
Model string `json:"model,omitempty"`
}

// Prompt is the interface used to execute a prompt template and
// pass the result to a [Generator].
type Prompt interface {
Generate(context.Context, *PromptRequest, func(context.Context, *Candidate) error) (*GenerateResponse, error)
}

// RegisterPrompt registers a prompt in the global registry.
func RegisterPrompt(provider, name string, prompt Prompt) {
metadata := map[string]any{
"type": "prompt",
"prompt": prompt,
}
core.RegisterAction(provider,
core.NewStreamingAction(name, core.ActionTypePrompt, metadata, prompt.Generate))
}
3 changes: 1 addition & 2 deletions go/plugins/dotprompt/dotprompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,7 @@ type Prompt struct {
// A hash of the prompt contents.
hash string

// A Generator to use. If not nil, this is used by the
// [genkit.Action] returned by [Prompt.Action] to execute the prompt.
// A Generator to use. If not nil, this is used to execute the prompt.
generator ai.Generator
}

Expand Down
98 changes: 36 additions & 62 deletions go/plugins/dotprompt/genkit.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,37 +21,25 @@ import (
"strings"

"github.com/firebase/genkit/go/ai"
"github.com/firebase/genkit/go/core"
"github.com/firebase/genkit/go/core/tracing"
)

// ActionInput is the input type of a prompt action.
// This should have all the fields of GenerateRequest other than
// Messages, Tools, and Output.
type ActionInput struct {
// Input variables to substitute in the template.
// TODO(ianlancetaylor) Not sure variables is the right name here.
Variables map[string]any `json:"variables,omitempty"`
// Number of candidates to return; if 0, 1 is used.
Candidates int `json:"candidates,omitempty"`
// Configuration.
Config *ai.GenerationCommonConfig `json:"config,omitempty"`
// The model to use. This overrides any model in the prompt.
Model string `json:"model,omitempty"`
}

// BuildVariables returns a map for [ActionInput.Variables] based
// on a pointer to a struct value. The struct value should have
// buildVariables returns a map holding prompt field values based
// on a struct or a pointer to a struct. The struct value should have
// JSON tags that correspond to the Prompt's input schema.
// Only exported fields of the struct will be used.
func (p *Prompt) BuildVariables(input any) (map[string]any, error) {
v := reflect.ValueOf(input).Elem()
func (p *Prompt) buildVariables(variables any) (map[string]any, error) {
if variables == nil {
return nil, nil
}

v := reflect.Indirect(reflect.ValueOf(variables))
if v.Kind() != reflect.Struct {
return nil, errors.New("BuildVariables: not a pointer to a struct")
return nil, errors.New("dotprompt: fields not a struct or pointer to a struct")
}
vt := v.Type()

// TODO(ianlancetaylor): Verify the struct with p.Frontmatter.Schema.
// TODO(ianlancetaylor): Verify the struct with p.Config.InputSchema.

m := make(map[string]any)

Expand Down Expand Up @@ -89,55 +77,43 @@ fieldLoop:
}

// buildRequest prepares an [ai.GenerateRequest] based on the prompt,
// using the input variables and other information in the [ActionInput].
func (p *Prompt) buildRequest(input *ActionInput) (*ai.GenerateRequest, error) {
// using the input variables and other information in the [ai.PromptRequest].
func (p *Prompt) buildRequest(pr *ai.PromptRequest) (*ai.GenerateRequest, error) {
req := &ai.GenerateRequest{}

var err error
if req.Messages, err = p.RenderMessages(input.Variables); err != nil {
m, err := p.buildVariables(pr.Variables)
if err != nil {
return nil, err
}
if req.Messages, err = p.RenderMessages(m); err != nil {
return nil, err
}

req.Candidates = input.Candidates
req.Candidates = pr.Candidates
if req.Candidates == 0 {
req.Candidates = p.Candidates
}
if req.Candidates == 0 {
req.Candidates = 1
}

req.Config = p.GenerationConfig
req.Config = pr.Config
if req.Config == nil {
req.Config = p.GenerationConfig
}

req.Context = pr.Context

req.Output = &ai.GenerateRequestOutput{
Format: p.OutputFormat,
Schema: p.OutputSchema,
}

req.Tools = p.Tools

return req, nil
}

// Action returns a [core.Action] that executes the prompt.
// The returned Action will take an [ActionInput] that provides
// variables to substitute into the template text.
// It will then pass the rendered text to an AI generator,
// and return whatever the generator computes.
func (p *Prompt) Action() (*core.Action[*ActionInput, *ai.GenerateResponse, struct{}], error) {
if p.Name == "" {
return nil, errors.New("dotprompt: missing name")
}
name := p.Name
if p.Variant != "" {
name += "." + p.Variant
}

a := core.NewAction(name, core.ActionTypePrompt, nil, p.Execute)
a.Metadata = map[string]any{
"type": "prompt",
"prompt": p,
}
return a, nil
}

// Register registers an action to execute a prompt.
func (p *Prompt) Register() error {
name := p.Name
Expand All @@ -148,34 +124,32 @@ func (p *Prompt) Register() error {
name += "." + p.Variant
}

action, err := p.Action()
if err != nil {
return err
}
ai.RegisterPrompt("dotprompt", name, p)

core.RegisterAction("dotprompt", action)
return nil
}

// Execute executes a prompt. It does variable substitution and
// Generate executes a prompt. It does variable substitution and
// passes the rendered template to the AI generator specified by
// the prompt.
func (p *Prompt) Execute(ctx context.Context, input *ActionInput) (*ai.GenerateResponse, error) {
//
// This implements the [ai.Prompt] interface.
func (p *Prompt) Generate(ctx context.Context, pr *ai.PromptRequest, cb func(context.Context, *ai.Candidate) error) (*ai.GenerateResponse, error) {
tracing.SetCustomMetadataAttr(ctx, "subtype", "prompt")

genReq, err := p.buildRequest(input)
genReq, err := p.buildRequest(pr)
if err != nil {
return nil, err
}

generator := p.generator
if generator == nil {
model := p.Model
if input.Model != "" {
model = input.Model
if pr.Model != "" {
model = pr.Model
}
if model == "" {
return nil, errors.New("dotprompt action: model not specified")
return nil, errors.New("dotprompt execution: model not specified")
}
provider, name, found := strings.Cut(model, "/")
if !found {
Expand All @@ -188,7 +162,7 @@ func (p *Prompt) Execute(ctx context.Context, input *ActionInput) (*ai.GenerateR
}
}

resp, err := ai.Generate(ctx, generator, genReq, nil)
resp, err := ai.Generate(ctx, generator, genReq, cb)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion go/plugins/dotprompt/genkit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func TestExecute(t *testing.T) {
t.Fatal(err)
}
p.generator = testGenerator{}
resp, err := p.Execute(context.Background(), &ActionInput{})
resp, err := p.Generate(context.Background(), &ai.PromptRequest{}, nil)
if err != nil {
t.Fatal(err)
}
Expand Down
6 changes: 4 additions & 2 deletions go/plugins/dotprompt/render.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@ import (
"github.com/firebase/genkit/go/ai"
)

// RenderText returns the Prompt as a string.
// It may contain only a single, text, message.
// RenderText executes the prompt's template and returns the result
// as a string. The result may contain only a single, text, message.
// This just runs the template; it does not call a generator.
func (p *Prompt) RenderText(variables map[string]any) (string, error) {
msgs, err := p.RenderMessages(variables)
if err != nil {
Expand All @@ -47,6 +48,7 @@ func (p *Prompt) RenderText(variables map[string]any) (string, error) {
}

// RenderMessages executes the prompt's template and converts it into messages.
// This just runs the template; it does not call a generator.
func (p *Prompt) RenderMessages(variables map[string]any) ([]*ai.Message, error) {
if p.VariableDefaults != nil {
nv := make(map[string]any)
Expand Down
36 changes: 18 additions & 18 deletions go/samples/coffee-shop/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,12 +121,12 @@ func main() {
}

simpleGreetingFlow := genkit.DefineFlow("simpleGreeting", func(ctx context.Context, input *simpleGreetingInput, _ genkit.NoStream) (string, error) {
vars, err := simpleGreetingPrompt.BuildVariables(input)
if err != nil {
return "", err
}
ai := &dotprompt.ActionInput{Variables: vars}
resp, err := simpleGreetingPrompt.Execute(ctx, ai)
resp, err := simpleGreetingPrompt.Generate(ctx,
&ai.PromptRequest{
Variables: input,
},
nil,
)
if err != nil {
return "", err
}
Expand All @@ -149,12 +149,12 @@ func main() {
}

greetingWithHistoryFlow := genkit.DefineFlow("greetingWithHistory", func(ctx context.Context, input *customerTimeAndHistoryInput, _ genkit.NoStream) (string, error) {
vars, err := greetingWithHistoryPrompt.BuildVariables(input)
if err != nil {
return "", err
}
ai := &dotprompt.ActionInput{Variables: vars}
resp, err := greetingWithHistoryPrompt.Execute(ctx, ai)
resp, err := greetingWithHistoryPrompt.Generate(ctx,
&ai.PromptRequest{
Variables: input,
},
nil,
)
if err != nil {
return "", err
}
Expand Down Expand Up @@ -194,12 +194,12 @@ func main() {
}

genkit.DefineFlow("simpleStructuredGreeting", func(ctx context.Context, input *simpleGreetingInput, _ genkit.NoStream) (string, error) {
vars, err := simpleGreetingPrompt.BuildVariables(input)
if err != nil {
return "", err
}
ai := &dotprompt.ActionInput{Variables: vars}
resp, err := simpleStructuredGreetingPrompt.Execute(ctx, ai)
resp, err := simpleStructuredGreetingPrompt.Generate(ctx,
&ai.PromptRequest{
Variables: input,
},
nil,
)
if err != nil {
return "", err
}
Expand Down
12 changes: 6 additions & 6 deletions go/samples/rag/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,12 +134,12 @@ func main() {
Context: sb.String(),
}

vars, err := simpleQaPrompt.BuildVariables(promptInput)
if err != nil {
return "", err
}
ai := &dotprompt.ActionInput{Variables: vars}
resp, err := simpleQaPrompt.Execute(ctx, ai)
resp, err := simpleQaPrompt.Generate(ctx,
&ai.PromptRequest{
Variables: promptInput,
},
nil,
)
if err != nil {
return "", err
}
Expand Down