Skip to content

Commit 44912eb

Browse files
authored
feat(go): Add ollama vision support (#2795)
1 parent 1ed2757 commit 44912eb

File tree

2 files changed

+174
-8
lines changed

2 files changed

+174
-8
lines changed

‎go/plugins/ollama/ollama.go‎

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ import (
3939
const provider = "ollama"
4040

4141
var (
42-
mediaSupportedModels = []string{"llava"}
42+
mediaSupportedModels = []string{"llava", "bakllava", "llava-llama3", "llava:13b", "llava:7b", "llava:latest"}
4343
roleMapping = map[ai.Role]string{
4444
ai.RoleUser: "user",
4545
ai.RoleModel: "assistant",
@@ -54,6 +54,7 @@ func (o *Ollama) DefineModel(g *genkit.Genkit, model ModelDefinition, info *ai.M
5454
panic("ollama.Init not called")
5555
}
5656
var mi ai.ModelInfo
57+
5758
if info != nil {
5859
mi = *info
5960
} else {
@@ -120,8 +121,10 @@ keep_alive: controls how long the model will stay loaded into memory following t
120121
*/
121122
type ollamaChatRequest struct {
122123
Messages []*ollamaMessage `json:"messages"`
124+
Images []string `json:"images,omitempty"`
123125
Model string `json:"model"`
124126
Stream bool `json:"stream"`
127+
Format string `json:"format,omitempty"`
125128
}
126129

127130
type ollamaModelRequest struct {
@@ -130,6 +133,7 @@ type ollamaModelRequest struct {
130133
Model string `json:"model"`
131134
Prompt string `json:"prompt"`
132135
Stream bool `json:"stream"`
136+
Format string `json:"format,omitempty"`
133137
}
134138

135139
// TODO: Add optional parameters (images, format, options, etc.) based on your use case
@@ -181,11 +185,21 @@ func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb fun
181185
stream := cb != nil
182186
var payload any
183187
isChatModel := g.model.Type == "chat"
184-
if !isChatModel {
185-
images, err := concatImages(input, []ai.Role{ai.RoleUser, ai.RoleModel})
188+
189+
// Check if this is an image model
190+
hasMediaSupport := slices.Contains(mediaSupportedModels, g.model.Name)
191+
192+
// Extract images if the model supports them
193+
var images []string
194+
var err error
195+
if hasMediaSupport {
196+
images, err = concatImages(input, []ai.Role{ai.RoleUser, ai.RoleModel})
186197
if err != nil {
187198
return nil, fmt.Errorf("failed to grab image parts: %v", err)
188199
}
200+
}
201+
202+
if !isChatModel {
189203
payload = ollamaModelRequest{
190204
Model: g.model.Name,
191205
Prompt: concatMessages(input, []ai.Role{ai.RoleUser, ai.RoleModel, ai.RoleTool}),
@@ -207,6 +221,7 @@ func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb fun
207221
Messages: messages,
208222
Model: g.model.Name,
209223
Stream: stream,
224+
Images: images,
210225
}
211226
}
212227
client := &http.Client{Timeout: 30 * time.Second}
@@ -293,21 +308,27 @@ func convertParts(role ai.Role, parts []*ai.Part) (*ollamaMessage, error) {
293308
Role: roleMapping[role],
294309
}
295310
var contentBuilder strings.Builder
311+
var images []string
312+
296313
for _, part := range parts {
297314
if part.IsText() {
298315
contentBuilder.WriteString(part.Text)
299316
} else if part.IsMedia() {
300317
_, data, err := uri.Data(part)
301318
if err != nil {
302-
return nil, err
319+
return nil, fmt.Errorf("failed to extract media data: %v", err)
303320
}
304321
base64Encoded := base64.StdEncoding.EncodeToString(data)
305-
message.Images = append(message.Images, base64Encoded)
322+
images = append(images, base64Encoded)
306323
} else {
307-
return nil, errors.New("unknown content type")
324+
return nil, errors.New("unsupported content type")
308325
}
309326
}
327+
310328
message.Content = contentBuilder.String()
329+
if len(images) > 0 {
330+
message.Images = images
331+
}
311332
return message, nil
312333
}
313334

@@ -414,10 +435,18 @@ func concatImages(input *ai.ModelRequest, roleFilter []ai.Role) ([]string, error
414435
if !part.IsMedia() {
415436
continue
416437
}
417-
_, data, err := uri.Data(part)
438+
439+
// Get the media type and data
440+
mediaType, data, err := uri.Data(part)
418441
if err != nil {
419-
return nil, err
442+
return nil, fmt.Errorf("failed to extract image data: %v", err)
420443
}
444+
445+
// Only include image media types
446+
if !strings.HasPrefix(mediaType, "image/") {
447+
continue
448+
}
449+
421450
base64Encoded := base64.StdEncoding.EncodeToString(data)
422451
images = append(images, base64Encoded)
423452
}

‎go/samples/ollama-vision/main.go‎

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package main
16+
17+
import (
18+
"context"
19+
"encoding/base64"
20+
"fmt"
21+
"log"
22+
"net/http"
23+
"os"
24+
"path/filepath"
25+
"strings"
26+
27+
"github.com/firebase/genkit/go/ai"
28+
"github.com/firebase/genkit/go/genkit"
29+
"github.com/firebase/genkit/go/plugins/ollama"
30+
)
31+
// getContentTypeFromExtension returns the image MIME type implied by the
// extension of filename.
//
// The extension is taken with filepath.Ext and compared case-insensitively,
// so "photo.JPG" and "photo.jpg" both map to "image/jpeg". A name with an
// unrecognized or missing extension yields "image/png".
func getContentTypeFromExtension(filename string) string {
	ext := strings.ToLower(filepath.Ext(filename))
	switch ext {
	case ".jpg", ".jpeg":
		return "image/jpeg"
	case ".png":
		return "image/png"
	case ".gif":
		return "image/gif"
	case ".webp":
		return "image/webp"
	case ".bmp":
		return "image/bmp"
	case ".svg":
		return "image/svg+xml"
	default:
		return "image/png" // Default fallback
	}
}
52+
53+
func main() {
54+
// Get the image path from command line argument or use a default path
55+
imagePath := "test.png"
56+
if len(os.Args) > 1 {
57+
imagePath = os.Args[1]
58+
}
59+
60+
// Check if image exists
61+
if _, err := os.Stat(imagePath); os.IsNotExist(err) {
62+
log.Fatalf("Image file not found: %s", imagePath)
63+
}
64+
65+
// Read the image file
66+
imageData, err := os.ReadFile(imagePath)
67+
if err != nil {
68+
log.Fatalf("Failed to read image file: %v", err)
69+
}
70+
71+
// Detect content type (MIME type) from the file's binary signature
72+
contentType := http.DetectContentType(imageData)
73+
74+
// If content type is generic/unknown, try to infer from file extension
75+
if contentType == "application/octet-stream" {
76+
contentType = getContentTypeFromExtension(imagePath)
77+
}
78+
79+
// Encode image to base64
80+
base64Image := base64.StdEncoding.EncodeToString(imageData)
81+
dataURI := fmt.Sprintf("data:%s;base64,%s", contentType, base64Image)
82+
83+
// Create a new Genkit instance
84+
g, err := genkit.Init(context.Background())
85+
if err != nil {
86+
log.Fatalf("Failed to initialize Genkit: %v", err)
87+
}
88+
89+
// Initialize the Ollama plugin
90+
ollamaPlugin := &ollama.Ollama{
91+
ServerAddress: "http://localhost:11434", // Default Ollama server address
92+
}
93+
94+
// Initialize the plugin
95+
err = ollamaPlugin.Init(context.Background(), g)
96+
if err != nil {
97+
log.Fatalf("Failed to initialize Ollama plugin: %v", err)
98+
}
99+
100+
// Define a model that supports images (llava is one of the supported models)
101+
modelName := "llava"
102+
model := ollamaPlugin.DefineModel(g, ollama.ModelDefinition{
103+
Name: modelName,
104+
Type: "chat",
105+
}, nil)
106+
107+
// Create a context
108+
ctx := context.Background()
109+
110+
// Create a request with text and image
111+
request := &ai.ModelRequest{
112+
Messages: []*ai.Message{
113+
{
114+
Role: ai.RoleUser,
115+
Content: []*ai.Part{
116+
ai.NewTextPart("Describe what you see in this image:"),
117+
ai.NewMediaPart(contentType, dataURI),
118+
},
119+
},
120+
},
121+
}
122+
123+
// Call the model
124+
fmt.Printf("Sending request to %s model...\n", modelName)
125+
response, err := model.Generate(ctx, request, nil)
126+
if err != nil {
127+
log.Fatalf("Error generating response: %v", err)
128+
}
129+
130+
// Print the response
131+
fmt.Println("\nModel Response:")
132+
for _, part := range response.Message.Content {
133+
if part.IsText() {
134+
fmt.Println(part.Text)
135+
}
136+
}
137+
}

0 commit comments

Comments
 (0)