Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 25 additions & 5 deletions go/ai/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -372,9 +372,29 @@ func Generate(ctx context.Context, r *registry.Registry, opts ...GenerateOption)
modelName = genOpts.ModelName
}

var dynamicTools []Tool
tools := make([]string, len(genOpts.Tools))
for i, tool := range genOpts.Tools {
tools[i] = tool.Name()
toolNames := make(map[string]bool)
for i, toolRef := range genOpts.Tools {
name := toolRef.Name()
// Redundant duplicate tool check with GenerateWithRequest otherwise we will panic when we register the dynamic tools.
if toolNames[name] {
return nil, core.NewError(core.INVALID_ARGUMENT, "ai.Generate: duplicate tool %q", name)
}
toolNames[name] = true
tools[i] = name
// Dynamic tools wouldn't have been registered by this point.
if LookupTool(r, name) == nil {
if tool, ok := toolRef.(Tool); ok {
dynamicTools = append(dynamicTools, tool)
}
}
}
if len(dynamicTools) > 0 {
r = r.NewChild()
for _, tool := range dynamicTools {
tool.Register(r)
}
}

messages := []*Message{}
Expand Down Expand Up @@ -596,7 +616,7 @@ func handleToolRequests(ctx context.Context, r *registry.Registry, req *ModelReq

output, err := tool.RunRaw(ctx, toolReq.Input)
if err != nil {
var tie *ToolInterruptError
var tie *toolInterruptError
if errors.As(err, &tie) {
logger.FromContext(ctx).Debug("tool %q triggered an interrupt: %v", toolReq.Name, tie.Metadata)

Expand Down Expand Up @@ -636,7 +656,7 @@ func handleToolRequests(ctx context.Context, r *registry.Registry, req *ModelReq
for range toolCount {
res := <-resultChan
if res.err != nil {
var tie *ToolInterruptError
var tie *toolInterruptError
if errors.As(res.err, &tie) {
hasInterrupts = true
continue
Expand Down Expand Up @@ -878,7 +898,7 @@ func handleResumedToolRequest(ctx context.Context, r *registry.Registry, genOpts

output, err := tool.RunRaw(resumedCtx, restartPart.ToolRequest.Input)
if err != nil {
var tie *ToolInterruptError
var tie *toolInterruptError
if errors.As(err, &tie) {
logger.FromContext(ctx).Debug("tool %q triggered an interrupt: %v", restartPart.ToolRequest.Name, tie.Metadata)

Expand Down
118 changes: 118 additions & 0 deletions go/ai/generate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,124 @@ func TestGenerate(t *testing.T) {
t.Errorf("got text %q, want %q", res.Text(), expectedText)
}
})

t.Run("registers dynamic tools", func(t *testing.T) {
// Create a tool that is NOT registered in the global registry
dynamicTool := NewTool("dynamicTestTool", "a tool that is dynamically registered",
func(ctx *ToolContext, input struct {
Message string
}) (string, error) {
return "Dynamic: " + input.Message, nil
},
)

// Verify the tool is not in the global registry
if LookupTool(r, "dynamicTestTool") != nil {
t.Fatal("dynamicTestTool should not be registered in global registry")
}

// Create a model that will call the dynamic tool then provide a final response
roundCount := 0
info := &ModelInfo{
Supports: &ModelSupports{
Multiturn: true,
Tools: true,
},
}
toolCallModel := DefineModel(r, "test", "toolcall", info,
func(ctx context.Context, gr *ModelRequest, msc ModelStreamCallback) (*ModelResponse, error) {
roundCount++
if roundCount == 1 {
// First response: call the dynamic tool
return &ModelResponse{
Request: gr,
Message: &Message{
Role: RoleModel,
Content: []*Part{
NewToolRequestPart(&ToolRequest{
Name: "dynamicTestTool",
Input: map[string]any{"Message": "Hello from dynamic tool"},
}),
},
},
}, nil
}
// Second response: provide final answer based on tool response
var toolResult string
for _, msg := range gr.Messages {
if msg.Role == RoleTool {
for _, part := range msg.Content {
if part.ToolResponse != nil {
toolResult = part.ToolResponse.Output.(string)
}
}
}
}
return &ModelResponse{
Request: gr,
Message: &Message{
Role: RoleModel,
Content: []*Part{
NewTextPart(toolResult),
},
},
}, nil
})

// Use Generate with the dynamic tool - this should trigger the dynamic registration
res, err := Generate(context.Background(), r,
WithModel(toolCallModel),
WithPrompt("call the dynamic tool"),
WithTools(dynamicTool),
)
if err != nil {
t.Fatal(err)
}

// The tool should have been called and returned a response
expectedText := "Dynamic: Hello from dynamic tool"
if res.Text() != expectedText {
t.Errorf("expected text %q, got %q", expectedText, res.Text())
}

// Verify two rounds were executed: tool call + final response
if roundCount != 2 {
t.Errorf("expected 2 rounds, got %d", roundCount)
}

// Verify the tool is still not in the global registry (it was registered in a child)
if LookupTool(r, "dynamicTestTool") != nil {
t.Error("dynamicTestTool should not be registered in global registry after generation")
}
})

t.Run("handles duplicate dynamic tools", func(t *testing.T) {
// Create two tools with the same name
dynamicTool1 := NewTool("duplicateTool", "first tool",
func(ctx *ToolContext, input any) (string, error) {
return "tool1", nil
},
)
dynamicTool2 := NewTool("duplicateTool", "second tool",
func(ctx *ToolContext, input any) (string, error) {
return "tool2", nil
},
)

// Using both tools should result in an error
_, err := Generate(context.Background(), r,
WithModel(echoModel),
WithPrompt("test duplicate tools"),
WithTools(dynamicTool1, dynamicTool2),
)

if err == nil {
t.Fatal("expected error for duplicate tool names")
}
if !strings.Contains(err.Error(), "duplicate tool \"duplicateTool\"") {
t.Errorf("unexpected error message: %v", err)
}
})
}

func TestModelVersion(t *testing.T) {
Expand Down
76 changes: 40 additions & 36 deletions go/ai/tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,19 +60,21 @@ type Tool interface {
Definition() *ToolDefinition
// RunRaw runs this tool using the provided raw input.
RunRaw(ctx context.Context, input any) (any, error)
// Register sets the tracing state on the action and registers it with the registry.
Register(r *registry.Registry)
// Respond constructs a *Part with a ToolResponse for a given interrupted tool request.
Respond(toolReq *Part, outputData any, opts *RespondOptions) *Part
// Restart constructs a *Part with a new ToolRequest to re-trigger a tool,
// potentially with new input and resumedMetadata.
// potentially with new input and metadata.
Restart(toolReq *Part, opts *RestartOptions) *Part
}

// ToolInterruptError represents an intentional interruption of tool execution.
type ToolInterruptError struct {
// toolInterruptError represents an intentional interruption of tool execution.
type toolInterruptError struct {
Metadata map[string]any
}

func (e *ToolInterruptError) Error() string {
func (e *toolInterruptError) Error() string {
return "tool execution interrupted"
}

Expand Down Expand Up @@ -112,58 +114,54 @@ type ToolContext struct {
OriginalInput any
}

// DefineTool defines a tool function with interrupt capability
// DefineTool defines a tool.
func DefineTool[In, Out any](r *registry.Registry, name, description string,
fn func(ctx *ToolContext, input In) (Out, error)) Tool {
wrappedFn := func(ctx context.Context, input In) (Out, error) {
toolCtx := &ToolContext{
Context: ctx,
Interrupt: func(opts *InterruptOptions) error {
return &ToolInterruptError{
Metadata: opts.Metadata,
}
},
Resumed: resumedCtxKey.FromContext(ctx),
OriginalInput: origInputCtxKey.FromContext(ctx),
}
return fn(toolCtx, input)
}

metadata := map[string]any{
"type": "tool",
"name": name,
"description": description,
}
metadata, wrappedFn := implementTool(name, description, fn)
toolAction := core.DefineAction(r, "", name, core.ActionTypeTool, metadata, wrappedFn)

return &tool{Action: toolAction}
}

// DefineToolWithInputSchema defines a tool function with a custom input schema and interrupt capability.
// The input schema allows specifying a JSON Schema for validating tool inputs.
// DefineToolWithInputSchema defines a tool function with a custom input schema.
func DefineToolWithInputSchema[Out any](r *registry.Registry, name, description string,
inputSchema *jsonschema.Schema,
fn func(ctx *ToolContext, input any) (Out, error)) Tool {
metadata := make(map[string]any)
metadata["type"] = "tool"
metadata["name"] = name
metadata["description"] = description
metadata, wrappedFn := implementTool(name, description, fn)
toolAction := core.DefineActionWithInputSchema(r, "", name, core.ActionTypeTool, metadata, inputSchema, wrappedFn)
return &tool{Action: toolAction}
}

wrappedFn := func(ctx context.Context, input any) (Out, error) {
// NewTool creates a tool but does not register it in the registry. It can be passed directly to [Generate].
func NewTool[In, Out any](name, description string,
fn func(ctx *ToolContext, input In) (Out, error)) Tool {
metadata, wrappedFn := implementTool(name, description, fn)
metadata["dynamic"] = true
toolAction := core.NewAction("", name, core.ActionTypeTool, metadata, wrappedFn)
return &tool{Action: toolAction}
}

// implementTool creates the metadata and wrapped function common to both DefineTool and NewTool.
func implementTool[In, Out any](name, description string, fn func(ctx *ToolContext, input In) (Out, error)) (map[string]any, func(context.Context, In) (Out, error)) {
metadata := map[string]any{
"type": core.ActionTypeTool,
"name": name,
"description": description,
}
wrappedFn := func(ctx context.Context, input In) (Out, error) {
toolCtx := &ToolContext{
Context: ctx,
Interrupt: func(opts *InterruptOptions) error {
return &ToolInterruptError{
return &toolInterruptError{
Metadata: opts.Metadata,
}
},
Resumed: resumedCtxKey.FromContext(ctx),
OriginalInput: origInputCtxKey.FromContext(ctx),
}
return fn(toolCtx, input)
}

toolAction := core.DefineActionWithInputSchema(r, "", name, core.ActionTypeTool, metadata, inputSchema, wrappedFn)

return &tool{Action: toolAction}
return metadata, wrappedFn
}

// Name returns the name of the tool.
Expand Down Expand Up @@ -193,6 +191,12 @@ func (t *tool) RunRaw(ctx context.Context, input any) (any, error) {
return runAction(ctx, t.Definition(), t.Action, input)
}

// Register sets the tracing state on the action and registers it with the registry.
func (t *tool) Register(r *registry.Registry) {
t.Action.SetTracingState(r.TracingState())
r.RegisterAction(fmt.Sprintf("/%s/%s", core.ActionTypeTool, t.Action.Name()), t.Action)
}

// runAction runs the given action with the provided raw input and returns the output in raw format.
func runAction(ctx context.Context, def *ToolDefinition, action core.Action, input any) (any, error) {
mi, err := json.Marshal(input)
Expand Down
40 changes: 37 additions & 3 deletions go/core/action.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ type Action interface {
RunJSON(ctx context.Context, input json.RawMessage, cb func(context.Context, json.RawMessage) error) (json.RawMessage, error)
// Desc returns a descriptor of the action.
Desc() ActionDesc
// SetTracingState sets the tracing state on the action.
SetTracingState(tstate *tracing.State)
}

// An ActionType is the kind of an action.
Expand Down Expand Up @@ -106,6 +108,23 @@ func DefineAction[In, Out any](
})
}

// NewAction creates a new non-streaming Action without registering it.
func NewAction[In, Out any](
provider, name string,
atype ActionType,
metadata map[string]any,
fn Func[In, Out],
) *ActionDef[In, Out, struct{}] {
fullName := name
if provider != "" {
fullName = provider + "/" + name
}
return newAction(nil, fullName, atype, metadata, nil,
func(ctx context.Context, in In, cb noStream) (Out, error) {
return fn(ctx, in)
})
}

// DefineStreamingAction creates a new streaming action and registers it.
func DefineStreamingAction[In, Out, Stream any](
r *registry.Registry,
Expand Down Expand Up @@ -155,6 +174,7 @@ func defineAction[In, Out, Stream any](
}

// newAction creates a new Action with the given name and arguments.
// If registry is nil, tracing state is left nil to be set later.
// If inputSchema is nil, it is inferred from In.
func newAction[In, Out, Stream any](
r *registry.Registry,
Expand All @@ -164,23 +184,31 @@ func newAction[In, Out, Stream any](
inputSchema *jsonschema.Schema,
fn StreamingFunc[In, Out, Stream],
) *ActionDef[In, Out, Stream] {
var i In
var o Out
if inputSchema == nil {
var i In
if reflect.ValueOf(i).Kind() != reflect.Invalid {
inputSchema = base.InferJSONSchema(i)
}
}

var o Out
var outputSchema *jsonschema.Schema
if reflect.ValueOf(o).Kind() != reflect.Invalid {
outputSchema = base.InferJSONSchema(o)
}

var description string
if desc, ok := metadata["description"].(string); ok {
description = desc
}

var tstate *tracing.State
if r != nil {
tstate = r.TracingState()
}

return &ActionDef[In, Out, Stream]{
tstate: r.TracingState(),
tstate: tstate,
fn: func(ctx context.Context, input In, cb StreamCallback[Stream]) (Out, error) {
tracing.SetCustomMetadataAttr(ctx, "subtype", string(atype))
return fn(ctx, input, cb)
Expand All @@ -200,6 +228,12 @@ func newAction[In, Out, Stream any](
// Name returns the Action's Name.
func (a *ActionDef[In, Out, Stream]) Name() string { return a.desc.Name }

// SetTracingState sets the tracing state on the action. This is used when an action
// created without a registry needs to have its tracing state set later.
func (a *ActionDef[In, Out, Stream]) SetTracingState(tstate *tracing.State) {
a.tstate = tstate
}

// Run executes the Action's function in a new trace span.
func (a *ActionDef[In, Out, Stream]) Run(ctx context.Context, input In, cb StreamCallback[Stream]) (output Out, err error) {
logger.FromContext(ctx).Debug("Action.Run",
Expand Down
Loading
Loading