@@ -41,6 +41,17 @@ type (
4141 Generate (ctx context.Context , req * ModelRequest , cb ModelStreamCallback ) (* ModelResponse , error )
4242 }
4343
44+ // ModelArg is the interface for model arguments.
45+ ModelArg interface {
46+ Name () string
47+ }
48+
49+ // ModelRef is a struct to hold model name and configuration.
50+ ModelRef struct {
51+ name string
52+ config any
53+ }
54+
4455 // ToolConfig handles configuration around tool calls during generation.
4556 ToolConfig struct {
4657 MaxTurns int // Maximum number of tool call iterations before erroring.
@@ -87,6 +98,7 @@ func DefineGenerateAction(ctx context.Context, r *registry.Registry) *generateAc
8798
8899// DefineModel registers the given generate function as an action, and returns a [Model] that runs it.
89100func DefineModel (r * registry.Registry , provider , name string , info * ModelInfo , fn ModelFunc ) Model {
101+
90102 if info == nil {
91103 // Always make sure there's at least minimal metadata.
92104 info = & ModelInfo {
@@ -113,6 +125,17 @@ func DefineModel(r *registry.Registry, provider, name string, info *ModelInfo, f
113125 metadata ["label" ] = info .Label
114126 }
115127
128+ if info .ConfigSchema != nil {
129+ metadata ["customOptions" ] = info .ConfigSchema
130+ // Make sure "model" exists in metadata
131+ if metadata ["model" ] == nil {
132+ metadata ["model" ] = make (map [string ]any )
133+ }
134+ // Add customOptios to the model metadata
135+ modelMeta := metadata ["model" ].(map [string ]any )
136+ modelMeta ["customOptions" ] = info .ConfigSchema
137+ }
138+
116139 // Create the middleware list
117140 middlewares := []ModelMiddleware {
118141 simulateSystemPrompt (info , nil ),
@@ -162,7 +185,9 @@ func LookupModelByName(r *registry.Registry, modelName string) (Model, error) {
162185// GenerateWithRequest is the central generation implementation for ai.Generate(), prompt.Execute(), and the GenerateAction direct call.
163186func GenerateWithRequest (ctx context.Context , r * registry.Registry , opts * GenerateActionOptions , mw []ModelMiddleware , cb ModelStreamCallback ) (* ModelResponse , error ) {
164187 if opts .Model == "" {
165- opts .Model = r .LookupValue (registry .DefaultModelKey ).(string )
188+ if defaultModel , ok := r .LookupValue (registry .DefaultModelKey ).(string ); ok && defaultModel != "" {
189+ opts .Model = defaultModel
190+ }
166191 if opts .Model == "" {
167192 return nil , errors .New ("ai.GenerateWithRequest: model is required" )
168193 }
@@ -209,7 +234,6 @@ func GenerateWithRequest(ctx context.Context, r *registry.Registry, opts *Genera
209234 output .Format = string (OutputFormatJSON )
210235 }
211236 }
212-
213237 req := & ModelRequest {
214238 Messages : opts .Messages ,
215239 Config : opts .Config ,
@@ -280,9 +304,11 @@ func Generate(ctx context.Context, r *registry.Registry, opts ...GenerateOption)
280304 }
281305 }
282306
283- modelName := genOpts . ModelName
284- if modelName == "" && genOpts .Model != nil {
307+ var modelName string
308+ if genOpts .Model != nil {
285309 modelName = genOpts .Model .Name ()
310+ } else {
311+ modelName = genOpts .ModelName
286312 }
287313
288314 tools := make ([]string , len (genOpts .Tools ))
@@ -316,6 +342,13 @@ func Generate(ctx context.Context, r *registry.Registry, opts ...GenerateOption)
316342 messages = append (messages , NewUserTextMessage (prompt ))
317343 }
318344
345+ // Apply Model config if no Generate config.
346+ modelArg := genOpts .Model
347+ if modelRef , ok := modelArg .(ModelRef ); ok {
348+ if genOpts .Config == nil {
349+ genOpts .Config = modelRef .Config ()
350+ }
351+ }
319352 actionOpts := & GenerateActionOptions {
320353 Model : modelName ,
321354 Messages : messages ,
@@ -626,3 +659,18 @@ func (m *Message) Text() string {
626659 }
627660 return sb .String ()
628661}
662+
663+ // NewModelRef creates a new ModelRef with the given name and configuration.
664+ func NewModelRef (name string , config any ) ModelRef {
665+ return ModelRef {name : name , config : config }
666+ }
667+
668+ // Name returns the name of the ModelRef.
669+ func (m ModelRef ) Name () string {
670+ return m .name
671+ }
672+
673+ // ModelConfig returns the configuration of a ModelRef.
674+ func (m ModelRef ) Config () any {
675+ return m .config
676+ }
0 commit comments