Skip to content

Commit 12706a0

Browse files
authored
feat(go): Added ModelArg interface and ModelRef (#2487)
1 parent b3b1595 commit 12706a0

File tree

6 files changed

+290
-46
lines changed

6 files changed

+290
-46
lines changed

‎go/ai/generate.go‎

Lines changed: 52 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,17 @@ type (
4141
Generate(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error)
4242
}
4343

44+
// ModelArg is the interface for model arguments.
45+
ModelArg interface {
46+
Name() string
47+
}
48+
49+
// ModelRef is a struct to hold model name and configuration.
50+
ModelRef struct {
51+
name string
52+
config any
53+
}
54+
4455
// ToolConfig handles configuration around tool calls during generation.
4556
ToolConfig struct {
4657
MaxTurns int // Maximum number of tool call iterations before erroring.
@@ -87,6 +98,7 @@ func DefineGenerateAction(ctx context.Context, r *registry.Registry) *generateAc
8798

8899
// DefineModel registers the given generate function as an action, and returns a [Model] that runs it.
89100
func DefineModel(r *registry.Registry, provider, name string, info *ModelInfo, fn ModelFunc) Model {
101+
90102
if info == nil {
91103
// Always make sure there's at least minimal metadata.
92104
info = &ModelInfo{
@@ -113,6 +125,17 @@ func DefineModel(r *registry.Registry, provider, name string, info *ModelInfo, f
113125
metadata["label"] = info.Label
114126
}
115127

128+
if info.ConfigSchema != nil {
129+
metadata["customOptions"] = info.ConfigSchema
130+
// Make sure "model" exists in metadata
131+
if metadata["model"] == nil {
132+
metadata["model"] = make(map[string]any)
133+
}
134+
// Add customOptios to the model metadata
135+
modelMeta := metadata["model"].(map[string]any)
136+
modelMeta["customOptions"] = info.ConfigSchema
137+
}
138+
116139
// Create the middleware list
117140
middlewares := []ModelMiddleware{
118141
simulateSystemPrompt(info, nil),
@@ -162,7 +185,9 @@ func LookupModelByName(r *registry.Registry, modelName string) (Model, error) {
162185
// GenerateWithRequest is the central generation implementation for ai.Generate(), prompt.Execute(), and the GenerateAction direct call.
163186
func GenerateWithRequest(ctx context.Context, r *registry.Registry, opts *GenerateActionOptions, mw []ModelMiddleware, cb ModelStreamCallback) (*ModelResponse, error) {
164187
if opts.Model == "" {
165-
opts.Model = r.LookupValue(registry.DefaultModelKey).(string)
188+
if defaultModel, ok := r.LookupValue(registry.DefaultModelKey).(string); ok && defaultModel != "" {
189+
opts.Model = defaultModel
190+
}
166191
if opts.Model == "" {
167192
return nil, errors.New("ai.GenerateWithRequest: model is required")
168193
}
@@ -209,7 +234,6 @@ func GenerateWithRequest(ctx context.Context, r *registry.Registry, opts *Genera
209234
output.Format = string(OutputFormatJSON)
210235
}
211236
}
212-
213237
req := &ModelRequest{
214238
Messages: opts.Messages,
215239
Config: opts.Config,
@@ -280,9 +304,11 @@ func Generate(ctx context.Context, r *registry.Registry, opts ...GenerateOption)
280304
}
281305
}
282306

283-
modelName := genOpts.ModelName
284-
if modelName == "" && genOpts.Model != nil {
307+
var modelName string
308+
if genOpts.Model != nil {
285309
modelName = genOpts.Model.Name()
310+
} else {
311+
modelName = genOpts.ModelName
286312
}
287313

288314
tools := make([]string, len(genOpts.Tools))
@@ -316,6 +342,13 @@ func Generate(ctx context.Context, r *registry.Registry, opts ...GenerateOption)
316342
messages = append(messages, NewUserTextMessage(prompt))
317343
}
318344

345+
// Apply Model config if no Generate config.
346+
modelArg := genOpts.Model
347+
if modelRef, ok := modelArg.(ModelRef); ok {
348+
if genOpts.Config == nil {
349+
genOpts.Config = modelRef.Config()
350+
}
351+
}
319352
actionOpts := &GenerateActionOptions{
320353
Model: modelName,
321354
Messages: messages,
@@ -626,3 +659,18 @@ func (m *Message) Text() string {
626659
}
627660
return sb.String()
628661
}
662+
663+
// NewModelRef creates a new ModelRef with the given name and configuration.
664+
func NewModelRef(name string, config any) ModelRef {
665+
return ModelRef{name: name, config: config}
666+
}
667+
668+
// Name returns the name of the ModelRef.
669+
func (m ModelRef) Name() string {
670+
return m.name
671+
}
672+
673+
// ModelConfig returns the configuration of a ModelRef.
674+
func (m ModelRef) Config() any {
675+
return m.config
676+
}

‎go/ai/option.go‎

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ type messagesFn = func(context.Context, any) ([]*Message, error)
3434

3535
// commonOptions are common options for model generation, prompt definition, and prompt execution.
3636
type commonOptions struct {
37-
ModelName string // Name of the model to use.
38-
Model Model // Model to use.
37+
Model ModelArg // Resolvable reference to a model to use with optional embedded config.
38+
ModelName string // Name of model to use
3939
MessagesFn messagesFn // Messages function. If this is set, Messages should be an empty.
4040
Config any // Model configuration. If nil will be taken from the prompt config.
4141
Tools []ToolRef // References to tools to use.
@@ -67,6 +67,7 @@ func (o *commonOptions) applyCommon(opts *commonOptions) error {
6767
return errors.New("cannot set model more than once (either WithModel or WithModelName)")
6868
}
6969
opts.Model = o.Model
70+
return nil
7071
}
7172

7273
if o.ModelName != "" {
@@ -164,8 +165,8 @@ func WithConfig(config any) CommonOption {
164165
return &commonOptions{Config: config}
165166
}
166167

167-
// WithModel sets the model to call for generation.
168-
func WithModel(model Model) CommonOption {
168+
// WithModel sets a resolvable model reference to use for generation.
169+
func WithModel(model ModelArg) CommonOption {
169170
return &commonOptions{Model: model}
170171
}
171172

‎go/ai/prompt.go‎

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,13 @@ func DefinePrompt(r *registry.Registry, name string, opts ...PromptOption) (*Pro
5050
return nil, err
5151
}
5252
}
53-
53+
// Apply Model config if no Prompt config.
54+
modelArg := pOpts.Model
55+
if modelRef, ok := modelArg.(ModelRef); ok {
56+
if pOpts.Config == nil {
57+
pOpts.Config = modelRef.Config()
58+
}
59+
}
5460
p := &Prompt{
5561
registry: r,
5662
promptOptions: *pOpts,
@@ -111,6 +117,13 @@ func (p *Prompt) Execute(ctx context.Context, opts ...PromptGenerateOption) (*Mo
111117
return nil, err
112118
}
113119
}
120+
// Apply Model config if no Prompt Generate config.
121+
modelArg := genOpts.Model
122+
if modelRef, ok := modelArg.(ModelRef); ok {
123+
if genOpts.Config == nil {
124+
genOpts.Config = modelRef.Config()
125+
}
126+
}
114127

115128
p.MessagesFn = mergeMessagesFn(p.MessagesFn, genOpts.MessagesFn)
116129

0 commit comments

Comments
 (0)