Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions go/ai/embedder.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,9 @@ func Embed(ctx context.Context, r *registry.Registry, opts ...EmbedderOption) (*
}
}

if embedOpts.Embedder == nil {
return nil, fmt.Errorf("ai.Embed: embedder must be set")
}
e, ok := embedOpts.Embedder.(Embedder)
if !ok {
e = LookupEmbedder(r, embedOpts.Embedder.Name())
Expand Down
48 changes: 44 additions & 4 deletions go/ai/evaluator.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,32 @@ type Evaluator interface {
Evaluate(ctx context.Context, req *EvaluatorRequest) (*EvaluatorResponse, error)
}

// EvaluatorArg is the interface for evaluator arguments. It can either be the evaluator action itself or a reference to be looked up.
type EvaluatorArg interface {
Name() string
}

// EvaluatorRef is a struct to hold evaluator name and configuration.
type EvaluatorRef struct {
name string
config any
}

// NewEvaluatorRef creates a new EvaluatorRef with the given name and configuration.
func NewEvaluatorRef(name string, config any) EvaluatorRef {
return EvaluatorRef{name: name, config: config}
}

// Name returns the name of the evaluator.
func (e EvaluatorRef) Name() string {
return e.name
}

// Config returns the configuration to use by default for this evaluator.
func (e EvaluatorRef) Config() any {
return e.config
}

// evaluator is an action with functions specific to evaluating a dataset.
type evaluator struct {
core.ActionDef[*EvaluatorRequest, *EvaluatorResponse, struct{}]
Expand Down Expand Up @@ -273,20 +299,34 @@ func (e evaluator) Evaluate(ctx context.Context, req *EvaluatorRequest) (*Evalua
}

// Evaluate calls the retrivers with provided options.
func Evaluate(ctx context.Context, r Evaluator, opts ...EvaluatorOption) (*EvaluatorResponse, error) {
func Evaluate(ctx context.Context, r *registry.Registry, opts ...EvaluatorOption) (*EvaluatorResponse, error) {
evalOpts := &evaluatorOptions{}
for _, opt := range opts {
err := opt.applyEvaluator(evalOpts)
if err != nil {
if err := opt.applyEvaluator(evalOpts); err != nil {
return nil, err
}
}

if evalOpts.Evaluator == nil {
return nil, fmt.Errorf("ai.Evaluate: evaluator must be set")
}
e, ok := evalOpts.Evaluator.(Evaluator)
if !ok {
e = LookupEvaluator(r, evalOpts.Evaluator.Name())
}
if e == nil {
return nil, fmt.Errorf("ai.Evaluate: evaluator not found: %s", evalOpts.Evaluator.Name())
}

if evalRef, ok := evalOpts.Evaluator.(EvaluatorRef); ok && evalOpts.Config == nil {
evalOpts.Config = evalRef.Config()
}

req := &EvaluatorRequest{
Dataset: evalOpts.Dataset,
EvaluationId: evalOpts.ID,
Options: evalOpts.Config,
}

return r.Evaluate(ctx, req)
return e.Evaluate(ctx, req)
}
3 changes: 2 additions & 1 deletion go/ai/evaluator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,8 @@ func TestEvaluate(t *testing.T) {

evalAction := DefineEvaluator(r, "test/testEvaluator", &evalOpts, testEvalFunc)

resp, err := Evaluate(context.Background(), evalAction,
resp, err := Evaluate(context.Background(), r,
WithEvaluator(evalAction),
WithDataset(dataset...),
WithID("testrun"),
WithConfig("test-options"))
Expand Down
24 changes: 22 additions & 2 deletions go/ai/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -626,8 +626,9 @@ func WithDocs(docs ...*Document) DocumentOption {
// evaluatorOptions are options for providing a dataset to evaluate.
type evaluatorOptions struct {
configOptions
Dataset []*Example // Dataset to evaluate.
ID string // ID of the evaluation.
Dataset []*Example // Dataset to evaluate.
ID string // ID of the evaluation.
Evaluator EvaluatorArg // Evaluator to use.
}

// EvaluatorOption is an option for providing a dataset to evaluate.
Expand Down Expand Up @@ -656,6 +657,13 @@ func (o *evaluatorOptions) applyEvaluator(evalOpts *evaluatorOptions) error {
evalOpts.ID = o.ID
}

if o.Evaluator != nil {
if evalOpts.Evaluator != nil {
return errors.New("cannot set evaluator more than once (WithEvaluator or WithEvaluatorName)")
}
evalOpts.Evaluator = o.Evaluator
}

return nil
}

Expand All @@ -669,6 +677,18 @@ func WithID(ID string) EvaluatorOption {
return &evaluatorOptions{ID: ID}
}

// WithEvaluator sets either a [Evaluator] or a [EvaluatorRef] that may contain a config.
// Passing [WithConfig] will take precedence over the config in WithEvaluator.
func WithEvaluator(evaluator EvaluatorArg) EvaluatorOption {
return &evaluatorOptions{Evaluator: evaluator}
}

// WithEvaluatorName sets the evaluator name to call for document evaluation.
// The evaluator name will be resolved to a [Evaluator] and may error if the reference is invalid.
func WithEvaluatorName(name string) EvaluatorOption {
return &evaluatorOptions{Evaluator: NewEvaluatorRef(name, nil)}
}

// embedderOptions holds configuration and input for an embedder request.
type embedderOptions struct {
configOptions
Expand Down
3 changes: 3 additions & 0 deletions go/ai/retriever.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,9 @@ func Retrieve(ctx context.Context, r *registry.Registry, opts ...RetrieverOption
return nil, errors.New("ai.Retrieve: only supports a single document as input")
}

if retOpts.Retriever == nil {
return nil, fmt.Errorf("ai.Retrieve: retriever must be set")
}
ret, ok := retOpts.Retriever.(Retriever)
if !ok {
ret = LookupRetriever(r, retOpts.Retriever.Name())
Expand Down
Loading