@@ -39,7 +39,7 @@ import (
3939const provider = "ollama"
4040
4141var (
42- mediaSupportedModels = []string {"llava" }
42+ mediaSupportedModels = []string {"llava" , "bakllava" , "llava-llama3" , "llava:13b" , "llava:7b" , "llava:latest" }
4343 roleMapping = map [ai.Role ]string {
4444 ai .RoleUser : "user" ,
4545 ai .RoleModel : "assistant" ,
@@ -54,6 +54,7 @@ func (o *Ollama) DefineModel(g *genkit.Genkit, model ModelDefinition, info *ai.M
5454 panic ("ollama.Init not called" )
5555 }
5656 var mi ai.ModelInfo
57+
5758 if info != nil {
5859 mi = * info
5960 } else {
@@ -120,8 +121,10 @@ keep_alive: controls how long the model will stay loaded into memory following t
120121*/
121122type ollamaChatRequest struct { // JSON-tagged request payload; presumably the chat-model body assembled in generate — confirm against the full file
122123 Messages []* ollamaMessage `json:"messages"` // conversation messages; convertParts fills each message's Content and, when media parts exist, its Images
124+ Images []string `json:"images,omitempty"` // base64-encoded image data collected by concatImages; dropped from the JSON when empty. NOTE(review): convertParts already attaches images per message — confirm the endpoint also reads this top-level field
123125 Model string `json:"model"` // model name; generate passes g.model.Name
124126 Stream bool `json:"stream"` // streaming flag; generate derives it from whether a callback was supplied (cb != nil)
127+ Format string `json:"format,omitempty"` // dropped from the JSON when empty; not assigned anywhere in the visible hunks — TODO confirm where it is populated
125128}
126129
127130type ollamaModelRequest struct {
@@ -130,6 +133,7 @@ type ollamaModelRequest struct {
130133 Model string `json:"model"`
131134 Prompt string `json:"prompt"`
132135 Stream bool `json:"stream"`
136+ Format string `json:"format,omitempty"`
133137}
134138
135139// TODO: Add optional parameters (images, format, options, etc.) based on your use case
@@ -181,11 +185,21 @@ func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb fun
181185 stream := cb != nil
182186 var payload any
183187 isChatModel := g .model .Type == "chat"
184- if ! isChatModel {
185- images , err := concatImages (input , []ai.Role {ai .RoleUser , ai .RoleModel })
188+
189+ // Check if this is an image model
190+ hasMediaSupport := slices .Contains (mediaSupportedModels , g .model .Name )
191+
192+ // Extract images if the model supports them
193+ var images []string
194+ var err error
195+ if hasMediaSupport {
196+ images , err = concatImages (input , []ai.Role {ai .RoleUser , ai .RoleModel })
186197 if err != nil {
187198 return nil , fmt .Errorf ("failed to grab image parts: %v" , err )
188199 }
200+ }
201+
202+ if ! isChatModel {
189203 payload = ollamaModelRequest {
190204 Model : g .model .Name ,
191205 Prompt : concatMessages (input , []ai.Role {ai .RoleUser , ai .RoleModel , ai .RoleTool }),
@@ -207,6 +221,7 @@ func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb fun
207221 Messages : messages ,
208222 Model : g .model .Name ,
209223 Stream : stream ,
224+ Images : images ,
210225 }
211226 }
212227 client := & http.Client {Timeout : 30 * time .Second }
@@ -293,21 +308,27 @@ func convertParts(role ai.Role, parts []*ai.Part) (*ollamaMessage, error) {
293308 Role : roleMapping [role ],
294309 }
295310 var contentBuilder strings.Builder
311+ var images []string
312+
296313 for _ , part := range parts {
297314 if part .IsText () {
298315 contentBuilder .WriteString (part .Text )
299316 } else if part .IsMedia () {
300317 _ , data , err := uri .Data (part )
301318 if err != nil {
302- return nil , err
319+ return nil , fmt . Errorf ( "failed to extract media data: %v" , err )
303320 }
304321 base64Encoded := base64 .StdEncoding .EncodeToString (data )
305- message . Images = append (message . Images , base64Encoded )
322+ images = append (images , base64Encoded )
306323 } else {
307- return nil , errors .New ("unknown content type" )
324+ return nil , errors .New ("unsupported content type" )
308325 }
309326 }
327+
310328 message .Content = contentBuilder .String ()
329+ if len (images ) > 0 {
330+ message .Images = images
331+ }
311332 return message , nil
312333}
313334
@@ -414,10 +435,18 @@ func concatImages(input *ai.ModelRequest, roleFilter []ai.Role) ([]string, error
414435 if ! part .IsMedia () {
415436 continue
416437 }
417- _ , data , err := uri .Data (part )
438+
439+ // Get the media type and data
440+ mediaType , data , err := uri .Data (part )
418441 if err != nil {
419- return nil , err
442+ return nil , fmt . Errorf ( "failed to extract image data: %v" , err )
420443 }
444+
445+ // Only include image media types
446+ if ! strings .HasPrefix (mediaType , "image/" ) {
447+ continue
448+ }
449+
421450 base64Encoded := base64 .StdEncoding .EncodeToString (data )
422451 images = append (images , base64Encoded )
423452 }
0 commit comments