Skip to content
5 changes: 2 additions & 3 deletions go/ai/embedder.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import (
"fmt"

"github.com/firebase/genkit/go/core"
"github.com/firebase/genkit/go/internal/atype"
"github.com/firebase/genkit/go/internal/registry"
)

Expand All @@ -43,13 +42,13 @@ func DefineEmbedder(
provider, name string,
embed func(context.Context, *EmbedRequest) (*EmbedResponse, error),
) Embedder {
return (*embedder)(core.DefineAction(r, provider, name, atype.Embedder, nil, embed))
return (*embedder)(core.DefineAction(r, provider, name, core.ActionTypeEmbedder, nil, embed))
}

// LookupEmbedder looks up an [Embedder] registered by [DefineEmbedder].
// It returns nil if the embedder was not defined.
func LookupEmbedder(r *registry.Registry, provider, name string) Embedder {
action := core.LookupActionFor[*EmbedRequest, *EmbedResponse, struct{}](r, atype.Embedder, provider, name)
action := core.LookupActionFor[*EmbedRequest, *EmbedResponse, struct{}](r, core.ActionTypeEmbedder, provider, name)
if action == nil {
return nil
}
Expand Down
7 changes: 3 additions & 4 deletions go/ai/evaluator.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import (
"github.com/firebase/genkit/go/core"
"github.com/firebase/genkit/go/core/logger"
"github.com/firebase/genkit/go/core/tracing"
"github.com/firebase/genkit/go/internal/atype"
"github.com/firebase/genkit/go/internal/registry"
"github.com/google/uuid"
"go.opentelemetry.io/otel/trace"
Expand Down Expand Up @@ -134,7 +133,7 @@ func DefineEvaluator(r *registry.Registry, provider, name string, options *Evalu
metadataMap["evaluatorDisplayName"] = options.DisplayName
metadataMap["evaluatorDefinition"] = options.Definition

actionDef := (*evaluator)(core.DefineAction(r, provider, name, atype.Evaluator, map[string]any{"evaluator": metadataMap}, func(ctx context.Context, req *EvaluatorRequest) (output *EvaluatorResponse, err error) {
actionDef := (*evaluator)(core.DefineAction(r, provider, name, core.ActionTypeEvaluator, map[string]any{"evaluator": metadataMap}, func(ctx context.Context, req *EvaluatorRequest) (output *EvaluatorResponse, err error) {
var evalResponses []EvaluationResult
for _, datapoint := range req.Dataset {
if datapoint.TestCaseId == "" {
Expand Down Expand Up @@ -193,13 +192,13 @@ func DefineBatchEvaluator(r *registry.Registry, provider, name string, options *
metadataMap["evaluatorDisplayName"] = options.DisplayName
metadataMap["evaluatorDefinition"] = options.Definition

return (*evaluator)(core.DefineAction(r, provider, name, atype.Evaluator, map[string]any{"evaluator": metadataMap}, batchEval)), nil
return (*evaluator)(core.DefineAction(r, provider, name, core.ActionTypeEvaluator, map[string]any{"evaluator": metadataMap}, batchEval)), nil
}

// LookupEvaluator looks up an [Evaluator] registered by [DefineEvaluator].
// It returns nil if the evaluator was not defined.
func LookupEvaluator(r *registry.Registry, provider, name string) Evaluator {
return (*evaluator)(core.LookupActionFor[*EvaluatorRequest, *EvaluatorResponse, struct{}](r, atype.Evaluator, provider, name))
return (*evaluator)(core.LookupActionFor[*EvaluatorRequest, *EvaluatorResponse, struct{}](r, core.ActionTypeEvaluator, provider, name))
}

// Evaluate calls the retrivers with provided options.
Expand Down
7 changes: 3 additions & 4 deletions go/ai/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ import (
"github.com/firebase/genkit/go/core"
"github.com/firebase/genkit/go/core/logger"
"github.com/firebase/genkit/go/core/tracing"
"github.com/firebase/genkit/go/internal/atype"
"github.com/firebase/genkit/go/internal/base"
"github.com/firebase/genkit/go/internal/registry"
)
Expand Down Expand Up @@ -76,7 +75,7 @@ type (

// DefineGenerateAction defines a utility generate action.
func DefineGenerateAction(ctx context.Context, r *registry.Registry) *generateAction {
return (*generateAction)(core.DefineStreamingAction(r, "", "generate", atype.Util, nil,
return (*generateAction)(core.DefineStreamingAction(r, "", "generate", core.ActionTypeUtil, nil,
func(ctx context.Context, actionOpts *GenerateActionOptions, cb ModelStreamCallback) (resp *ModelResponse, err error) {
logger.FromContext(ctx).Debug("GenerateAction",
"input", fmt.Sprintf("%#v", actionOpts))
Expand Down Expand Up @@ -137,13 +136,13 @@ func DefineModel(r *registry.Registry, provider, name string, info *ModelInfo, f
}
fn = core.ChainMiddleware(mws...)(fn)

return (*model)(core.DefineStreamingAction(r, provider, name, atype.Model, metadata, fn))
return (*model)(core.DefineStreamingAction(r, provider, name, core.ActionTypeModel, metadata, fn))
}

// LookupModel looks up a [Model] registered by [DefineModel].
// It returns nil if the model was not defined.
func LookupModel(r *registry.Registry, provider, name string) Model {
action := core.LookupActionFor[*ModelRequest, *ModelResponse, *ModelResponseChunk](r, atype.Model, provider, name)
action := core.LookupActionFor[*ModelRequest, *ModelResponse, *ModelResponseChunk](r, core.ActionTypeModel, provider, name)
if action == nil {
return nil
}
Expand Down
5 changes: 2 additions & 3 deletions go/ai/prompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ import (

"github.com/firebase/genkit/go/core"
"github.com/firebase/genkit/go/core/logger"
"github.com/firebase/genkit/go/internal/atype"
"github.com/firebase/genkit/go/internal/base"
"github.com/firebase/genkit/go/internal/registry"
"github.com/google/dotprompt/go/dotprompt"
Expand Down Expand Up @@ -93,15 +92,15 @@ func DefinePrompt(r *registry.Registry, name string, opts ...PromptOption) (*Pro
}
maps.Copy(meta, promptMeta)

p.action = *core.DefineActionWithInputSchema(r, "", name, atype.ExecutablePrompt, meta, p.InputSchema, p.buildRequest)
p.action = *core.DefineActionWithInputSchema(r, "", name, core.ActionTypeExecutablePrompt, meta, p.InputSchema, p.buildRequest)

return p, nil
}

// LookupPrompt looks up a [Prompt] registered by [DefinePrompt].
// It returns nil if the prompt was not defined.
func LookupPrompt(r *registry.Registry, name string) *Prompt {
action := core.LookupActionFor[any, *GenerateActionOptions, struct{}](r, atype.ExecutablePrompt, "", name)
action := core.LookupActionFor[any, *GenerateActionOptions, struct{}](r, core.ActionTypeExecutablePrompt, "", name)
if action == nil {
return nil
}
Expand Down
5 changes: 2 additions & 3 deletions go/ai/retriever.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import (
"fmt"

"github.com/firebase/genkit/go/core"
"github.com/firebase/genkit/go/internal/atype"
"github.com/firebase/genkit/go/internal/registry"
)

Expand All @@ -41,13 +40,13 @@ type retriever core.ActionDef[*RetrieverRequest, *RetrieverResponse, struct{}]
// DefineRetriever registers the given retrieve function as an action, and returns a
// [Retriever] that runs it.
func DefineRetriever(r *registry.Registry, provider, name string, fn RetrieverFunc) Retriever {
return (*retriever)(core.DefineAction(r, provider, name, atype.Retriever, nil, fn))
return (*retriever)(core.DefineAction(r, provider, name, core.ActionTypeRetriever, nil, fn))
}

// LookupRetriever looks up a [Retriever] registered by [DefineRetriever].
// It returns nil if the retriever was not defined.
func LookupRetriever(r *registry.Registry, provider, name string) Retriever {
return (*retriever)(core.LookupActionFor[*RetrieverRequest, *RetrieverResponse, struct{}](r, atype.Retriever, provider, name))
return (*retriever)(core.LookupActionFor[*RetrieverRequest, *RetrieverResponse, struct{}](r, core.ActionTypeRetriever, provider, name))
}

// Retrieve runs the given [Retriever].
Expand Down
58 changes: 19 additions & 39 deletions go/ai/tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ import (
"fmt"

"github.com/firebase/genkit/go/core"
"github.com/firebase/genkit/go/internal/action"
"github.com/firebase/genkit/go/internal/atype"
"github.com/firebase/genkit/go/internal/base"
"github.com/firebase/genkit/go/internal/registry"
)
Expand All @@ -42,13 +40,11 @@ func (t ToolName) Name() string {
return (string)(t)
}

// A ToolDef is an implementation of a single tool.
type ToolDef[In, Out any] core.ActionDef[In, Out, struct{}]

// tool is genericless version of ToolDef. It's required to make [LookupTool] possible.
// tool is the internal implementation of the Tool interface.
// It holds the underlying core action and allows looking up tools
// by name without knowing their specific input/output types.
type tool struct {
// action is the underlying internal action. It's needed for the descriptor.
action action.Action
core.Action
}

// Tool represents an instance of a tool.
Expand Down Expand Up @@ -85,8 +81,11 @@ type ToolContext struct {
}

// DefineTool defines a tool function with interrupt capability
func DefineTool[In, Out any](r *registry.Registry, name, description string,
fn func(ctx *ToolContext, input In) (Out, error)) Tool {
func DefineTool[In, Out any](
r *registry.Registry,
name, description string,
fn func(ctx *ToolContext, input In) (Out, error),
) Tool {
metadata := make(map[string]any)
metadata["type"] = "tool"
metadata["name"] = name
Expand All @@ -104,35 +103,22 @@ func DefineTool[In, Out any](r *registry.Registry, name, description string,
return fn(toolCtx, input)
}

toolAction := core.DefineAction(r, "", name, atype.Tool, metadata, wrappedFn)

return &tool{action: toolAction}
}
toolAction := core.DefineAction(r, "", name, core.ActionTypeTool, metadata, wrappedFn)

// Name returns the name of the tool.
func (ta *tool) Name() string {
return ta.Definition().Name
return &tool{Action: toolAction}
}

// Name returns the name of the tool.
func (t *ToolDef[In, Out]) Name() string {
return t.Definition().Name
}

// Definition returns [ToolDefinition] for for this tool.
func (t *ToolDef[In, Out]) Definition() *ToolDefinition {
return definition((*core.ActionDef[In, Out, struct{}])(t).Desc())
func (t *tool) Name() string {
return t.Action.Name()
}

// Definition returns [ToolDefinition] for for this tool.
func (t *tool) Definition() *ToolDefinition {
return definition(t.action.Desc())
}

func definition(desc action.Desc) *ToolDefinition {
desc := t.Action.Desc()
td := &ToolDefinition{
Name: desc.Metadata["name"].(string),
Description: desc.Metadata["description"].(string),
Name: desc.Name,
Description: desc.Description,
}
if desc.InputSchema != nil {
td.InputSchema = base.SchemaAsMap(desc.InputSchema)
Expand All @@ -146,13 +132,7 @@ func definition(desc action.Desc) *ToolDefinition {
// RunRaw runs this tool using the provided raw map format data (JSON parsed
// as map[string]any).
func (t *tool) RunRaw(ctx context.Context, input any) (any, error) {
return runAction(ctx, t.Definition(), t.action, input)
}

// RunRaw runs this tool using the provided raw map format data (JSON parsed
// as map[string]any).
func (t *ToolDef[In, Out]) RunRaw(ctx context.Context, input any) (any, error) {
return runAction(ctx, t.Definition(), (*core.ActionDef[In, Out, struct{}])(t), input)
return runAction(ctx, t.Definition(), t.Action, input)
}

// runAction runs the given action with the provided raw input and returns the output in raw format.
Expand Down Expand Up @@ -180,9 +160,9 @@ func LookupTool(r *registry.Registry, name string) Tool {
return nil
}

action := r.LookupAction(fmt.Sprintf("/%s/%s", atype.Tool, name))
action := r.LookupAction(fmt.Sprintf("/%s/%s", core.ActionTypeTool, name))
if action == nil {
return nil
}
return &tool{action: action}
return &tool{Action: action.(core.Action)}
}
Loading
Loading