@@ -18,35 +18,58 @@ import (
1818 "bufio"
1919 "bytes"
2020 "context"
21+ "encoding/base64"
2122 "encoding/json"
2223 "errors"
2324 "fmt"
2425 "io"
2526 "net/http"
27+ "slices"
2628 "strings"
29+ "sync"
2730 "time"
2831
2932 "github.com/firebase/genkit/go/ai"
33+ "github.com/firebase/genkit/go/plugins/internal/uri"
3034)
3135
3236const provider = "ollama"
3337
38+ var mediaSupportedModels = []string {"llava" }
3439var roleMapping = map [ai.Role ]string {
3540 ai .RoleUser : "user" ,
3641 ai .RoleModel : "assistant" ,
3742 ai .RoleSystem : "system" ,
3843}
44+ var state struct {
45+ mu sync.Mutex
46+ initted bool
47+ serverAddress string
48+ }
3949
40- func defineModel (model ModelDefinition , serverAddress string ) {
50+ func DefineModel (model ModelDefinition , caps * ai.ModelCapabilities ) * ai.Model {
51+ state .mu .Lock ()
52+ defer state .mu .Unlock ()
53+ if ! state .initted {
54+ panic ("ollama.Init not called" )
55+ }
56+ var mc ai.ModelCapabilities
57+ if caps != nil {
58+ mc = * caps
59+ } else {
60+ mc = ai.ModelCapabilities {
61+ Multiturn : true ,
62+ SystemRole : true ,
63+ Media : slices .Contains (mediaSupportedModels , model .Name ),
64+ }
65+ }
4166 meta := & ai.ModelMetadata {
42- Label : "Ollama - " + model .Name ,
43- Supports : ai.ModelCapabilities {
44- Multiturn : model .Type == "chat" ,
45- SystemRole : model .Type != "chat" ,
46- },
67+ Label : "Ollama - " + model .Name ,
68+ Supports : mc ,
4769 }
48- g := & generator {model : model , serverAddress : serverAddress }
49- ai .DefineModel (provider , model .Name , meta , g .generate )
70+ g := & generator {model : model , serverAddress : state .serverAddress }
71+ return ai .DefineModel (provider , model .Name , meta , g .generate )
72+
5073}
5174
5275// Model returns the [ai.Model] with the given name.
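
With this change the plugin is configured in two explicit steps: Init records the server address once, and DefineModel registers each locally pulled model. A minimal usage sketch under assumed conditions (the import path github.com/firebase/genkit/go/plugins/ollama, a local server on the default port, and a model already pulled with `ollama pull llava`; none of these details are part of the diff itself):

    package main

    import (
        "context"
        "log"

        "github.com/firebase/genkit/go/plugins/ollama"
    )

    func main() {
        ctx := context.Background()
        // Record the address of the locally running Ollama server.
        if err := ollama.Init(ctx, "http://localhost:11434"); err != nil {
            log.Fatal(err)
        }
        // Register a pulled model. Passing nil capabilities lets the plugin pick
        // defaults; Media is only enabled for names in mediaSupportedModels ("llava").
        model := ollama.DefineModel(ollama.ModelDefinition{Name: "llava", Type: "generate"}, nil)
        _ = model // hand the *ai.Model to Genkit's generate helpers as usual
    }
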
@@ -69,22 +92,15 @@ type Config struct {
 	Models []ModelDefinition
 }

-// Init registers all the actions in this package with ai.
-func Init(ctx context.Context, cfg Config) error {
-	for _, model := range cfg.Models {
-		defineModel(model, cfg.ServerAddress)
-	}
-	return nil
-}
-
 type generator struct {
 	model         ModelDefinition
 	serverAddress string
 }

 type ollamaMessage struct {
-	Role    string // json:"role"
-	Content string // json:"content"
+	Role    string   `json:"role"`
+	Content string   `json:"content"`
+	Images  []string `json:"images,omitempty"`
 }

 // Ollama has two API endpoints, one with a chat interface and another with a generate response interface.
@@ -108,10 +124,11 @@ type ollamaChatRequest struct {
 }

 type ollamaGenerateRequest struct {
-	System string `json:"system,omitempty"` // Optional System field
-	Model  string `json:"model"`
-	Prompt string `json:"prompt"`
-	Stream bool   `json:"stream"`
+	System string   `json:"system,omitempty"`
+	Images []string `json:"images,omitempty"`
+	Model  string   `json:"model"`
+	Prompt string   `json:"prompt"`
+	Stream bool     `json:"stream"`
 }

 // TODO: Add optional parameters (images, format, options, etc.) based on your use case
@@ -130,17 +147,35 @@ type ollamaGenerateResponse struct {
 	Response string `json:"response"`
 }

+// Note: Since Ollama models are locally hosted, the plugin doesn't initialize any default models.
+// The user has to explicitly decide which model to pull down.
+func Init(ctx context.Context, serverAddress string) (err error) {
+	state.mu.Lock()
+	defer state.mu.Unlock()
+	if state.initted {
+		panic("ollama.Init already called")
+	}
+	state.serverAddress = serverAddress
+	state.initted = true
+	return nil
+}
+
 // Generate makes a request to the Ollama API and processes the response.
 func (g *generator) generate(ctx context.Context, input *ai.GenerateRequest, cb func(context.Context, *ai.GenerateResponseChunk) error) (*ai.GenerateResponse, error) {

 	stream := cb != nil
 	var payload any
 	isChatModel := g.model.Type == "chat"
 	if !isChatModel {
+		images, err := concatImages(input, []ai.Role{ai.RoleUser, ai.RoleModel})
+		if err != nil {
+			return nil, fmt.Errorf("failed to grab image parts: %v", err)
+		}
 		payload = ollamaGenerateRequest{
 			Model:  g.model.Name,
-			Prompt: concatMessages(input, []ai.Role{ai.Role("user"), ai.Role("model"), ai.Role("tool")}),
-			System: concatMessages(input, []ai.Role{ai.Role("system")}),
+			Prompt: concatMessages(input, []ai.Role{ai.RoleUser, ai.RoleModel, ai.RoleTool}),
+			System: concatMessages(input, []ai.Role{ai.RoleSystem}),
+			Images: images,
 			Stream: stream,
 		}
 	} else {
@@ -149,7 +184,7 @@ func (g *generator) generate(ctx context.Context, input *ai.GenerateRequest, cb
 		for _, m := range input.Messages {
 			message, err := convertParts(m.Role, m.Content)
 			if err != nil {
-				return nil, fmt.Errorf("error converting message parts: %v", err)
+				return nil, fmt.Errorf("failed to convert message parts: %v", err)
 			}
 			messages = append(messages, message)
 		}
@@ -191,7 +226,7 @@ func (g *generator) generate(ctx context.Context, input *ai.GenerateRequest, cb
 		}
 		var response *ai.GenerateResponse
 		if isChatModel {
-			response, err = translateResponse(body)
+			response, err = translateChatResponse(body)
 		} else {
 			response, err = translateGenerateResponse(body)
 		}
@@ -201,26 +236,24 @@ func (g *generator) generate(ctx context.Context, input *ai.GenerateRequest, cb
 		}
 		return response, nil
 	} else {
-		// Handle streaming response here
 		var chunks []*ai.GenerateResponseChunk
-		scanner := bufio.NewScanner(resp.Body) // Create a scanner to read lines
+		scanner := bufio.NewScanner(resp.Body)
 		for scanner.Scan() {
 			line := scanner.Text()
 			var chunk *ai.GenerateResponseChunk
 			if isChatModel {
-				chunk, err = translateChunk(line)
+				chunk, err = translateChatChunk(line)
 			} else {
 				chunk, err = translateGenerateChunk(line)
 			}
 			if err != nil {
-				// Handle parsing error (log, maybe send an error candidate?)
-				return nil, fmt.Errorf("error translating chunk: %v", err)
+				return nil, fmt.Errorf("failed to translate chunk: %v", err)
 			}
 			chunks = append(chunks, chunk)
 			cb(ctx, chunk)
 		}
 		if err := scanner.Err(); err != nil {
-			return nil, fmt.Errorf("error reading stream: %v", err)
+			return nil, fmt.Errorf("failed to read stream: %v", err)
 		}
 		// Create a final response with the merged chunks
 		finalResponse := &ai.GenerateResponse{
@@ -229,7 +262,7 @@ func (g *generator) generate(ctx context.Context, input *ai.GenerateRequest, cb
 				{
 					FinishReason: ai.FinishReason("stop"),
 					Message: &ai.Message{
-						Role: ai.RoleModel, // Assuming the response is from the model
+						Role: ai.RoleModel,
 					},
 				},
 			},
@@ -243,29 +276,35 @@ func (g *generator) generate(ctx context.Context, input *ai.GenerateRequest, cb
 	}
 }

-// convertParts serializes a slice of *ai.Part into an ollamaMessage (represents Ollama message type)
 func convertParts(role ai.Role, parts []*ai.Part) (*ollamaMessage, error) {
-	// Initialize the message with the correct role from the mapping
 	message := &ollamaMessage{
 		Role: roleMapping[role],
 	}
-	// Concatenate content from all parts
+	var contentBuilder strings.Builder
 	for _, part := range parts {
 		if part.IsText() {
-			message.Content += part.Text
+			contentBuilder.WriteString(part.Text)
+		} else if part.IsMedia() {
+			_, data, err := uri.Data(part)
+			if err != nil {
+				return nil, err
+			}
+			base64Encoded := base64.StdEncoding.EncodeToString(data)
+			message.Images = append(message.Images, base64Encoded)
 		} else {
 			return nil, errors.New("unknown content type")
 		}
 	}
+	message.Content = contentBuilder.String()
 	return message, nil
 }

-// translateResponse deserializes a JSON response from the Ollama API into a GenerateResponse.
-func translateResponse(responseData []byte) (*ai.GenerateResponse, error) {
+// translateChatResponse translates an Ollama chat response into a Genkit response.
+func translateChatResponse(responseData []byte) (*ai.GenerateResponse, error) {
 	var response ollamaChatResponse

 	if err := json.Unmarshal(responseData, &response); err != nil {
-		return nil, fmt.Errorf("error parsing response JSON: %v", err)
+		return nil, fmt.Errorf("failed to parse response JSON: %v", err)
 	}
 	generateResponse := &ai.GenerateResponse{}
 	aiCandidate := &ai.Candidate{
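
convertParts now maps one Genkit message onto Ollama's chat message shape: text parts are concatenated into Content via the strings.Builder, and media parts are decoded with uri.Data and re-encoded as base64 entries in Images. A hedged in-package sketch of the expected behavior (ai.NewMediaPart and the tiny data URI are illustrative assumptions, not taken from this diff):

    parts := []*ai.Part{
        ai.NewTextPart("What is in this picture? "),
        ai.NewTextPart("Answer briefly."),
        // Assumed helper: a media part carrying a (placeholder) PNG data URI.
        ai.NewMediaPart("image/png", "data:image/png;base64,iVBORw0KGgo..."),
    }
    msg, err := convertParts(ai.RoleUser, parts)
    if err != nil {
        // a part that is neither text nor media would land here
    }
    // Expected: msg.Role == "user",
    // msg.Content == "What is in this picture? Answer briefly.",
    // and len(msg.Images) == 1 (base64 of the decoded PNG bytes).
    _ = msg
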
@@ -280,18 +319,18 @@ func translateResponse(responseData []byte) (*ai.GenerateResponse, error) {
 	return generateResponse, nil
 }

-// translateGenerateResponse deserializes a JSON response from the Ollama API into a GenerateResponse.
+// translateGenerateResponse translates an Ollama generate response into a Genkit response.
 func translateGenerateResponse(responseData []byte) (*ai.GenerateResponse, error) {
 	var response ollamaGenerateResponse

 	if err := json.Unmarshal(responseData, &response); err != nil {
-		return nil, fmt.Errorf("error parsing response JSON: %v", err)
+		return nil, fmt.Errorf("failed to parse response JSON: %v", err)
 	}
 	generateResponse := &ai.GenerateResponse{}
 	aiCandidate := &ai.Candidate{
 		FinishReason: ai.FinishReason("stop"),
 		Message: &ai.Message{
-			Role: ai.Role("model"),
+			Role: ai.RoleModel,
 		},
 	}
 	aiPart := ai.NewTextPart(response.Response)
@@ -301,11 +340,11 @@ func translateGenerateResponse(responseData []byte) (*ai.GenerateResponse, error
 	generateResponse.Candidates = append(generateResponse.Candidates, aiCandidate)
 	return generateResponse, nil
 }
-func translateChunk(input string) (*ai.GenerateResponseChunk, error) {
+func translateChatChunk(input string) (*ai.GenerateResponseChunk, error) {
 	var response ollamaChatResponse

 	if err := json.Unmarshal([]byte(input), &response); err != nil {
-		return nil, fmt.Errorf("error parsing response JSON: %v", err)
+		return nil, fmt.Errorf("failed to parse response JSON: %v", err)
 	}
 	chunk := &ai.GenerateResponseChunk{}
 	aiPart := ai.NewTextPart(response.Message.Content)
@@ -317,7 +356,7 @@ func translateGenerateChunk(input string) (*ai.GenerateResponseChunk, error) {
 	var response ollamaGenerateResponse

 	if err := json.Unmarshal([]byte(input), &response); err != nil {
-		return nil, fmt.Errorf("error parsing response JSON: %v", err)
+		return nil, fmt.Errorf("failed to parse response JSON: %v", err)
 	}
 	chunk := &ai.GenerateResponseChunk{}
 	aiPart := ai.NewTextPart(response.Response)
@@ -331,16 +370,46 @@ func concatMessages(input *ai.GenerateRequest, roles []ai.Role) string {
 	for _, role := range roles {
 		roleSet[role] = true // Create a set for faster lookup
 	}
-
 	var sb strings.Builder
+	for _, message := range input.Messages {
+		// Check if the message role is in the allowed set
+		if !roleSet[message.Role] {
+			continue
+		}
+		for _, part := range message.Content {
+			if !part.IsText() {
+				continue
+			}
+			sb.WriteString(part.Text)
+		}
+	}
+	return sb.String()
+}
+
+// concatImages grabs the images from genkit message parts
+func concatImages(input *ai.GenerateRequest, roleFilter []ai.Role) ([]string, error) {
+	roleSet := make(map[ai.Role]bool)
+	for _, role := range roleFilter {
+		roleSet[role] = true
+	}
+
+	var images []string

 	for _, message := range input.Messages {
 		// Check if the message role is in the allowed set
 		if roleSet[message.Role] {
 			for _, part := range message.Content {
-				sb.WriteString(part.Text)
+				if !part.IsMedia() {
+					continue
+				}
+				_, data, err := uri.Data(part)
+				if err != nil {
+					return nil, err
+				}
+				base64Encoded := base64.StdEncoding.EncodeToString(data)
+				images = append(images, base64Encoded)
 			}
 		}
 	}
-	return sb.String()
+	return images, nil
 }
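
Taken together, the non-chat path flattens a whole request: concatMessages joins the text parts of the selected roles into a single string, while concatImages gathers the media parts of the user/model roles and base64-encodes them. A rough in-package sketch of the expected results for a single multimodal turn (the message contents and data URI are placeholders, and ai.NewMediaPart is assumed to exist with this shape):

    req := &ai.GenerateRequest{
        Messages: []*ai.Message{
            {Role: ai.RoleSystem, Content: []*ai.Part{ai.NewTextPart("You describe images.")}},
            {Role: ai.RoleUser, Content: []*ai.Part{
                ai.NewTextPart("What is shown here?"),
                ai.NewMediaPart("image/png", "data:image/png;base64,iVBORw0KGgo..."),
            }},
        },
    }
    // Expected:
    //   concatMessages(req, []ai.Role{ai.RoleUser, ai.RoleModel, ai.RoleTool}) == "What is shown here?"
    //   concatMessages(req, []ai.Role{ai.RoleSystem})                          == "You describe images."
    //   concatImages(req, []ai.Role{ai.RoleUser, ai.RoleModel}) returns one base64 string (the decoded PNG bytes).
    _ = req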