Skip to content

Commit b1be343

Browse files
authored
Update golang ollama plugin to support images (#505)
1 parent 9235758 commit b1be343

File tree

1 file changed

+118
-49
lines changed

1 file changed

+118
-49
lines changed

‎go/plugins/ollama/ollama.go‎

Lines changed: 118 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -18,35 +18,58 @@ import (
1818
"bufio"
1919
"bytes"
2020
"context"
21+
"encoding/base64"
2122
"encoding/json"
2223
"errors"
2324
"fmt"
2425
"io"
2526
"net/http"
27+
"slices"
2628
"strings"
29+
"sync"
2730
"time"
2831

2932
"github.com/firebase/genkit/go/ai"
33+
"github.com/firebase/genkit/go/plugins/internal/uri"
3034
)
3135

3236
const provider = "ollama"
3337

38+
var mediaSupportedModels = []string{"llava"}
3439
var roleMapping = map[ai.Role]string{
3540
ai.RoleUser: "user",
3641
ai.RoleModel: "assistant",
3742
ai.RoleSystem: "system",
3843
}
44+
var state struct {
45+
mu sync.Mutex
46+
initted bool
47+
serverAddress string
48+
}
3949

40-
func defineModel(model ModelDefinition, serverAddress string) {
50+
func DefineModel(model ModelDefinition, caps *ai.ModelCapabilities) *ai.Model {
51+
state.mu.Lock()
52+
defer state.mu.Unlock()
53+
if !state.initted {
54+
panic("ollama.Init not called")
55+
}
56+
var mc ai.ModelCapabilities
57+
if caps != nil {
58+
mc = *caps
59+
} else {
60+
mc = ai.ModelCapabilities{
61+
Multiturn: true,
62+
SystemRole: true,
63+
Media: slices.Contains(mediaSupportedModels, model.Name),
64+
}
65+
}
4166
meta := &ai.ModelMetadata{
42-
Label: "Ollama - " + model.Name,
43-
Supports: ai.ModelCapabilities{
44-
Multiturn: model.Type == "chat",
45-
SystemRole: model.Type != "chat",
46-
},
67+
Label: "Ollama - " + model.Name,
68+
Supports: mc,
4769
}
48-
g := &generator{model: model, serverAddress: serverAddress}
49-
ai.DefineModel(provider, model.Name, meta, g.generate)
70+
g := &generator{model: model, serverAddress: state.serverAddress}
71+
return ai.DefineModel(provider, model.Name, meta, g.generate)
72+
5073
}
5174

5275
// Model returns the [ai.Model] with the given name.
@@ -69,22 +92,15 @@ type Config struct {
6992
Models []ModelDefinition
7093
}
7194

72-
// Init registers all the actions in this package with ai.
73-
func Init(ctx context.Context, cfg Config) error {
74-
for _, model := range cfg.Models {
75-
defineModel(model, cfg.ServerAddress)
76-
}
77-
return nil
78-
}
79-
8095
type generator struct {
8196
model ModelDefinition
8297
serverAddress string
8398
}
8499

85100
type ollamaMessage struct {
86-
Role string // json:"role"
87-
Content string // json:"content"
101+
Role string `json:"role"`
102+
Content string `json:"content"`
103+
Images []string `json:"images,omitempty"`
88104
}
89105

90106
// Ollama has two API endpoints, one with a chat interface and another with a generate response interface.
@@ -108,10 +124,11 @@ type ollamaChatRequest struct {
108124
}
109125

110126
type ollamaGenerateRequest struct {
111-
System string `json:"system,omitempty"` // Optional System field
112-
Model string `json:"model"`
113-
Prompt string `json:"prompt"`
114-
Stream bool `json:"stream"`
127+
System string `json:"system,omitempty"`
128+
Images []string `json:"images,omitempty"`
129+
Model string `json:"model"`
130+
Prompt string `json:"prompt"`
131+
Stream bool `json:"stream"`
115132
}
116133

117134
// TODO: Add optional parameters (images, format, options, etc.) based on your use case
@@ -130,17 +147,35 @@ type ollamaGenerateResponse struct {
130147
Response string `json:"response"`
131148
}
132149

150+
// Note: Since Ollama models are locally hosted, the plugin doesn't initialize any default models.
151+
// The user has to explicitly decide which model to pull down.
152+
func Init(ctx context.Context, serverAddress string) (err error) {
153+
state.mu.Lock()
154+
defer state.mu.Unlock()
155+
if state.initted {
156+
panic("ollama.Init already called")
157+
}
158+
state.serverAddress = serverAddress
159+
state.initted = true
160+
return nil
161+
}
162+
133163
// Generate makes a request to the Ollama API and processes the response.
134164
func (g *generator) generate(ctx context.Context, input *ai.GenerateRequest, cb func(context.Context, *ai.GenerateResponseChunk) error) (*ai.GenerateResponse, error) {
135165

136166
stream := cb != nil
137167
var payload any
138168
isChatModel := g.model.Type == "chat"
139169
if !isChatModel {
170+
images, err := concatImages(input, []ai.Role{ai.RoleUser, ai.RoleModel})
171+
if err != nil {
172+
return nil, fmt.Errorf("failed to grab image parts: %v", err)
173+
}
140174
payload = ollamaGenerateRequest{
141175
Model: g.model.Name,
142-
Prompt: concatMessages(input, []ai.Role{ai.Role("user"), ai.Role("model"), ai.Role("tool")}),
143-
System: concatMessages(input, []ai.Role{ai.Role("system")}),
176+
Prompt: concatMessages(input, []ai.Role{ai.RoleUser, ai.RoleModel, ai.RoleTool}),
177+
System: concatMessages(input, []ai.Role{ai.RoleSystem}),
178+
Images: images,
144179
Stream: stream,
145180
}
146181
} else {
@@ -149,7 +184,7 @@ func (g *generator) generate(ctx context.Context, input *ai.GenerateRequest, cb
149184
for _, m := range input.Messages {
150185
message, err := convertParts(m.Role, m.Content)
151186
if err != nil {
152-
return nil, fmt.Errorf("error converting message parts: %v", err)
187+
return nil, fmt.Errorf("failed to convert message parts: %v", err)
153188
}
154189
messages = append(messages, message)
155190
}
@@ -191,7 +226,7 @@ func (g *generator) generate(ctx context.Context, input *ai.GenerateRequest, cb
191226
}
192227
var response *ai.GenerateResponse
193228
if isChatModel {
194-
response, err = translateResponse(body)
229+
response, err = translateChatResponse(body)
195230
} else {
196231
response, err = translateGenerateResponse(body)
197232
}
@@ -201,26 +236,24 @@ func (g *generator) generate(ctx context.Context, input *ai.GenerateRequest, cb
201236
}
202237
return response, nil
203238
} else {
204-
// Handle streaming response here
205239
var chunks []*ai.GenerateResponseChunk
206-
scanner := bufio.NewScanner(resp.Body) // Create a scanner to read lines
240+
scanner := bufio.NewScanner(resp.Body)
207241
for scanner.Scan() {
208242
line := scanner.Text()
209243
var chunk *ai.GenerateResponseChunk
210244
if isChatModel {
211-
chunk, err = translateChunk(line)
245+
chunk, err = translateChatChunk(line)
212246
} else {
213247
chunk, err = translateGenerateChunk(line)
214248
}
215249
if err != nil {
216-
// Handle parsing error (log, maybe send an error candidate?)
217-
return nil, fmt.Errorf("error translating chunk: %v", err)
250+
return nil, fmt.Errorf("failed to translate chunk: %v", err)
218251
}
219252
chunks = append(chunks, chunk)
220253
cb(ctx, chunk)
221254
}
222255
if err := scanner.Err(); err != nil {
223-
return nil, fmt.Errorf("error reading stream: %v", err)
256+
return nil, fmt.Errorf("failed to read stream: %v", err)
224257
}
225258
// Create a final response with the merged chunks
226259
finalResponse := &ai.GenerateResponse{
@@ -229,7 +262,7 @@ func (g *generator) generate(ctx context.Context, input *ai.GenerateRequest, cb
229262
{
230263
FinishReason: ai.FinishReason("stop"),
231264
Message: &ai.Message{
232-
Role: ai.RoleModel, // Assuming the response is from the model
265+
Role: ai.RoleModel,
233266
},
234267
},
235268
},
@@ -243,29 +276,35 @@ func (g *generator) generate(ctx context.Context, input *ai.GenerateRequest, cb
243276
}
244277
}
245278

246-
// convertParts serializes a slice of *ai.Part into an ollamaMessage (represents Ollama message type)
247279
func convertParts(role ai.Role, parts []*ai.Part) (*ollamaMessage, error) {
248-
// Initialize the message with the correct role from the mapping
249280
message := &ollamaMessage{
250281
Role: roleMapping[role],
251282
}
252-
// Concatenate content from all parts
283+
var contentBuilder strings.Builder
253284
for _, part := range parts {
254285
if part.IsText() {
255-
message.Content += part.Text
286+
contentBuilder.WriteString(part.Text)
287+
} else if part.IsMedia() {
288+
_, data, err := uri.Data(part)
289+
if err != nil {
290+
return nil, err
291+
}
292+
base64Encoded := base64.StdEncoding.EncodeToString(data)
293+
message.Images = append(message.Images, base64Encoded)
256294
} else {
257295
return nil, errors.New("unknown content type")
258296
}
259297
}
298+
message.Content = contentBuilder.String()
260299
return message, nil
261300
}
262301

263-
// translateResponse deserializes a JSON response from the Ollama API into a GenerateResponse.
264-
func translateResponse(responseData []byte) (*ai.GenerateResponse, error) {
302+
// translateChatResponse translates Ollama chat response into a genkit response.
303+
func translateChatResponse(responseData []byte) (*ai.GenerateResponse, error) {
265304
var response ollamaChatResponse
266305

267306
if err := json.Unmarshal(responseData, &response); err != nil {
268-
return nil, fmt.Errorf("error parsing response JSON: %v", err)
307+
return nil, fmt.Errorf("failed to parse response JSON: %v", err)
269308
}
270309
generateResponse := &ai.GenerateResponse{}
271310
aiCandidate := &ai.Candidate{
@@ -280,18 +319,18 @@ func translateResponse(responseData []byte) (*ai.GenerateResponse, error) {
280319
return generateResponse, nil
281320
}
282321

283-
// translateGenerateResponse deserializes a JSON response from the Ollama API into a GenerateResponse.
322+
// translateResponse translates Ollama generate response into a genkit response.
284323
func translateGenerateResponse(responseData []byte) (*ai.GenerateResponse, error) {
285324
var response ollamaGenerateResponse
286325

287326
if err := json.Unmarshal(responseData, &response); err != nil {
288-
return nil, fmt.Errorf("error parsing response JSON: %v", err)
327+
return nil, fmt.Errorf("failed to parse response JSON: %v", err)
289328
}
290329
generateResponse := &ai.GenerateResponse{}
291330
aiCandidate := &ai.Candidate{
292331
FinishReason: ai.FinishReason("stop"),
293332
Message: &ai.Message{
294-
Role: ai.Role("model"),
333+
Role: ai.RoleModel,
295334
},
296335
}
297336
aiPart := ai.NewTextPart(response.Response)
@@ -301,11 +340,11 @@ func translateGenerateResponse(responseData []byte) (*ai.GenerateResponse, error
301340
return generateResponse, nil
302341
}
303342

304-
func translateChunk(input string) (*ai.GenerateResponseChunk, error) {
343+
func translateChatChunk(input string) (*ai.GenerateResponseChunk, error) {
305344
var response ollamaChatResponse
306345

307346
if err := json.Unmarshal([]byte(input), &response); err != nil {
308-
return nil, fmt.Errorf("error parsing response JSON: %v", err)
347+
return nil, fmt.Errorf("failed to parse response JSON: %v", err)
309348
}
310349
chunk := &ai.GenerateResponseChunk{}
311350
aiPart := ai.NewTextPart(response.Message.Content)
@@ -317,7 +356,7 @@ func translateGenerateChunk(input string) (*ai.GenerateResponseChunk, error) {
317356
var response ollamaGenerateResponse
318357

319358
if err := json.Unmarshal([]byte(input), &response); err != nil {
320-
return nil, fmt.Errorf("error parsing response JSON: %v", err)
359+
return nil, fmt.Errorf("failed to parse response JSON: %v", err)
321360
}
322361
chunk := &ai.GenerateResponseChunk{}
323362
aiPart := ai.NewTextPart(response.Response)
@@ -331,16 +370,46 @@ func concatMessages(input *ai.GenerateRequest, roles []ai.Role) string {
331370
for _, role := range roles {
332371
roleSet[role] = true // Create a set for faster lookup
333372
}
334-
335373
var sb strings.Builder
374+
for _, message := range input.Messages {
375+
// Check if the message role is in the allowed set
376+
if !roleSet[message.Role] {
377+
continue
378+
}
379+
for _, part := range message.Content {
380+
if !part.IsText() {
381+
continue
382+
}
383+
sb.WriteString(part.Text)
384+
}
385+
}
386+
return sb.String()
387+
}
388+
389+
// concatImages grabs the images from genkit message parts
390+
func concatImages(input *ai.GenerateRequest, roleFilter []ai.Role) ([]string, error) {
391+
roleSet := make(map[ai.Role]bool)
392+
for _, role := range roleFilter {
393+
roleSet[role] = true
394+
}
395+
396+
var images []string
336397

337398
for _, message := range input.Messages {
338399
// Check if the message role is in the allowed set
339400
if roleSet[message.Role] {
340401
for _, part := range message.Content {
341-
sb.WriteString(part.Text)
402+
if !part.IsMedia() {
403+
continue
404+
}
405+
_, data, err := uri.Data(part)
406+
if err != nil {
407+
return nil, err
408+
}
409+
base64Encoded := base64.StdEncoding.EncodeToString(data)
410+
images = append(images, base64Encoded)
342411
}
343412
}
344413
}
345-
return sb.String()
414+
return images, nil
346415
}

0 commit comments

Comments
 (0)