@@ -40,10 +40,20 @@ const provider = "ollama"
4040
4141var (
4242 mediaSupportedModels = []string {"llava" , "bakllava" , "llava-llama3" , "llava:13b" , "llava:7b" , "llava:latest" }
43- roleMapping = map [ai.Role ]string {
43+ toolSupportedModels = []string {
44+ "qwq" , "mistral-small3.1" , "llama3.3" , "llama3.2" , "llama3.1" , "mistral" ,
45+ "qwen2.5" , "qwen2.5-coder" , "qwen2" , "mistral-nemo" , "mixtral" , "smollm2" ,
46+ "mistral-small" , "command-r" , "hermes3" , "mistral-large" , "command-r-plus" ,
47+ "phi4-mini" , "granite3.1-dense" , "granite3-dense" , "granite3.2" , "athene-v2" ,
48+ "nemotron-mini" , "nemotron" , "llama3-groq-tool-use" , "aya-expanse" , "granite3-moe" ,
49+ "granite3.2-vision" , "granite3.1-moe" , "cogito" , "command-r7b" , "firefunction-v2" ,
50+ "granite3.3" , "command-a" , "command-r7b-arabic" ,
51+ }
52+ roleMapping = map [ai.Role ]string {
4453 ai .RoleUser : "user" ,
4554 ai .RoleModel : "assistant" ,
4655 ai .RoleSystem : "system" ,
56+ ai .RoleTool : "tool" ,
4757 }
4858)
4959
@@ -58,12 +68,15 @@ func (o *Ollama) DefineModel(g *genkit.Genkit, model ModelDefinition, info *ai.M
5868 if info != nil {
5969 mi = * info
6070 } else {
71+ // Check if the model supports tools (must be a chat model and in the supported list)
72+ supportsTools := model .Type == "chat" && slices .Contains (toolSupportedModels , model .Name )
6173 mi = ai.ModelInfo {
6274 Label : model .Name ,
6375 Supports : & ai.ModelSupports {
6476 Multiturn : true ,
6577 SystemRole : true ,
6678 Media : slices .Contains (mediaSupportedModels , model .Name ),
79+ Tools : supportsTools ,
6780 },
6881 Versions : []string {},
6982 }
@@ -100,9 +113,10 @@ type generator struct {
100113}
101114
102115type ollamaMessage struct {
103- Role string `json:"role"`
104- Content string `json:"content"`
105- Images []string `json:"images,omitempty"`
116+ Role string `json:"role"`
117+ Content string `json:"content,omitempty"`
118+ Images []string `json:"images,omitempty"`
119+ ToolCalls []ollamaToolCall `json:"tool_calls,omitempty"`
106120}
107121
108122// Ollama has two API endpoints, one with a chat interface and another with a generate response interface.
@@ -125,6 +139,7 @@ type ollamaChatRequest struct {
125139 Model string `json:"model"`
126140 Stream bool `json:"stream"`
127141 Format string `json:"format,omitempty"`
142+ Tools []ollamaTool `json:"tools,omitempty"`
128143}
129144
130145type ollamaModelRequest struct {
@@ -136,13 +151,38 @@ type ollamaModelRequest struct {
136151 Format string `json:"format,omitempty"`
137152}
138153
154+ // Tool definition from Ollama API
155+ type ollamaTool struct {
156+ Type string `json:"type"`
157+ Function ollamaFunction `json:"function"`
158+ }
159+
160+ // Function definition for Ollama API
161+ type ollamaFunction struct {
162+ Name string `json:"name"`
163+ Description string `json:"description"`
164+ Parameters map [string ]any `json:"parameters"`
165+ }
166+
167+ // Tool Call from Ollama API
168+ type ollamaToolCall struct {
169+ Function ollamaFunctionCall `json:"function"`
170+ }
171+
172+ // Function Call for Ollama API
173+ type ollamaFunctionCall struct {
174+ Name string `json:"name"`
175+ Arguments any `json:"arguments"`
176+ }
177+
139178// TODO: Add optional parameters (images, format, options, etc.) based on your use case
140179type ollamaChatResponse struct {
141180 Model string `json:"model"`
142181 CreatedAt string `json:"created_at"`
143182 Message struct {
144- Role string `json:"role"`
145- Content string `json:"content"`
183+ Role string `json:"role"`
184+ Content string `json:"content"`
185+ ToolCalls []ollamaToolCall `json:"tool_calls,omitempty"`
146186 } `json:"message"`
147187}
148188
@@ -217,34 +257,47 @@ func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb fun
217257 }
218258 messages = append (messages , message )
219259 }
220- payload = ollamaChatRequest {
260+ chatReq : = ollamaChatRequest {
221261 Messages : messages ,
222262 Model : g .model .Name ,
223263 Stream : stream ,
224264 Images : images ,
225265 }
266+ if len (input .Tools ) > 0 {
267+ tools , err := convertTools (input .Tools )
268+ if err != nil {
269+ return nil , fmt .Errorf ("failed to convert tools: %v" , err )
270+ }
271+ chatReq .Tools = tools
272+ }
273+ payload = chatReq
226274 }
275+
227276 client := & http.Client {Timeout : 30 * time .Second }
228277 payloadBytes , err := json .Marshal (payload )
229278 if err != nil {
230279 return nil , err
231280 }
281+
232282 // Determine the correct endpoint
233283 endpoint := g .serverAddress + "/api/chat"
234284 if ! isChatModel {
235285 endpoint = g .serverAddress + "/api/generate"
236286 }
287+
237288 req , err := http .NewRequest ("POST" , endpoint , bytes .NewReader (payloadBytes ))
238289 if err != nil {
239290 return nil , fmt .Errorf ("failed to create request: %v" , err )
240291 }
241292 req .Header .Set ("Content-Type" , "application/json" )
242293 req = req .WithContext (ctx )
294+
243295 resp , err := client .Do (req )
244296 if err != nil {
245297 return nil , fmt .Errorf ("failed to send request: %v" , err )
246298 }
247299 defer resp .Body .Close ()
300+
248301 if cb == nil {
249302 // Existing behavior for non-streaming responses
250303 var err error
@@ -255,6 +308,7 @@ func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb fun
255308 if resp .StatusCode != http .StatusOK {
256309 return nil , fmt .Errorf ("server returned non-200 status: %d, body: %s" , resp .StatusCode , body )
257310 }
311+
258312 var response * ai.ModelResponse
259313 if isChatModel {
260314 response , err = translateChatResponse (body )
@@ -269,8 +323,12 @@ func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb fun
269323 } else {
270324 var chunks []* ai.ModelResponseChunk
271325 scanner := bufio .NewScanner (resp .Body )
326+ chunkCount := 0
327+
272328 for scanner .Scan () {
273329 line := scanner .Text ()
330+ chunkCount ++
331+
274332 var chunk * ai.ModelResponseChunk
275333 if isChatModel {
276334 chunk , err = translateChatChunk (line )
@@ -283,9 +341,11 @@ func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb fun
283341 chunks = append (chunks , chunk )
284342 cb (ctx , chunk )
285343 }
344+
286345 if err := scanner .Err (); err != nil {
287346 return nil , fmt .Errorf ("reading response stream: %v" , err )
288347 }
348+
289349 // Create a final response with the merged chunks
290350 finalResponse := & ai.ModelResponse {
291351 Request : input ,
@@ -303,13 +363,29 @@ func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb fun
303363 }
304364}
305365
366+ // convertTools converts Genkit tool definitions to Ollama tool format
367+ func convertTools (tools []* ai.ToolDefinition ) ([]ollamaTool , error ) {
368+ ollamaTools := make ([]ollamaTool , 0 , len (tools ))
369+ for _ , tool := range tools {
370+ ollamaTools = append (ollamaTools , ollamaTool {
371+ Type : "function" ,
372+ Function : ollamaFunction {
373+ Name : tool .Name ,
374+ Description : tool .Description ,
375+ Parameters : tool .InputSchema ,
376+ },
377+ })
378+ }
379+ return ollamaTools , nil
380+ }
381+
306382func convertParts (role ai.Role , parts []* ai.Part ) (* ollamaMessage , error ) {
307383 message := & ollamaMessage {
308384 Role : roleMapping [role ],
309385 }
310386 var contentBuilder strings.Builder
387+ var toolCalls []ollamaToolCall
311388 var images []string
312-
313389 for _ , part := range parts {
314390 if part .IsText () {
315391 contentBuilder .WriteString (part .Text )
@@ -320,12 +396,30 @@ func convertParts(role ai.Role, parts []*ai.Part) (*ollamaMessage, error) {
320396 }
321397 base64Encoded := base64 .StdEncoding .EncodeToString (data )
322398 images = append (images , base64Encoded )
399+ } else if part .IsToolRequest () {
400+ toolReq := part .ToolRequest
401+ toolCalls = append (toolCalls , ollamaToolCall {
402+ Function : ollamaFunctionCall {
403+ Name : toolReq .Name ,
404+ Arguments : toolReq .Input ,
405+ },
406+ })
407+ } else if part .IsToolResponse () {
408+ toolResp := part .ToolResponse
409+ outputJSON , err := json .Marshal (toolResp .Output )
410+ if err != nil {
411+ return nil , fmt .Errorf ("failed to marshal tool response: %v" , err )
412+ }
413+ contentBuilder .WriteString (string (outputJSON ))
323414 } else {
324415 return nil , errors .New ("unsupported content type" )
325416 }
326417 }
327418
328419 message .Content = contentBuilder .String ()
420+ if len (toolCalls ) > 0 {
421+ message .ToolCalls = toolCalls
422+ }
329423 if len (images ) > 0 {
330424 message .Images = images
331425 }
@@ -342,17 +436,27 @@ func translateChatResponse(responseData []byte) (*ai.ModelResponse, error) {
342436 modelResponse := & ai.ModelResponse {
343437 FinishReason : ai .FinishReason ("stop" ),
344438 Message : & ai.Message {
345- Role : ai .Role ( response . Message . Role ) ,
439+ Role : ai .RoleModel ,
346440 },
347441 }
348-
349- aiPart := ai .NewTextPart (response .Message .Content )
350- modelResponse .Message .Content = append (modelResponse .Message .Content , aiPart )
442+ if len (response .Message .ToolCalls ) > 0 {
443+ for _ , toolCall := range response .Message .ToolCalls {
444+ toolRequest := & ai.ToolRequest {
445+ Name : toolCall .Function .Name ,
446+ Input : toolCall .Function .Arguments ,
447+ }
448+ toolPart := ai .NewToolRequestPart (toolRequest )
449+ modelResponse .Message .Content = append (modelResponse .Message .Content , toolPart )
450+ }
451+ } else if response .Message .Content != "" {
452+ aiPart := ai .NewTextPart (response .Message .Content )
453+ modelResponse .Message .Content = append (modelResponse .Message .Content , aiPart )
454+ }
351455
352456 return modelResponse , nil
353457}
354458
355- // translateResponse translates Ollama generate response into a genkit response.
459+ // translateModelResponse translates Ollama generate response into a genkit response.
356460func translateModelResponse (responseData []byte ) (* ai.ModelResponse , error ) {
357461 var response ollamaModelResponse
358462
@@ -380,8 +484,20 @@ func translateChatChunk(input string) (*ai.ModelResponseChunk, error) {
380484 return nil , fmt .Errorf ("failed to parse response JSON: %v" , err )
381485 }
382486 chunk := & ai.ModelResponseChunk {}
383- aiPart := ai .NewTextPart (response .Message .Content )
384- chunk .Content = append (chunk .Content , aiPart )
487+ if len (response .Message .ToolCalls ) > 0 {
488+ for _ , toolCall := range response .Message .ToolCalls {
489+ toolRequest := & ai.ToolRequest {
490+ Name : toolCall .Function .Name ,
491+ Input : toolCall .Function .Arguments ,
492+ }
493+ toolPart := ai .NewToolRequestPart (toolRequest )
494+ chunk .Content = append (chunk .Content , toolPart )
495+ }
496+ } else if response .Message .Content != "" {
497+ aiPart := ai .NewTextPart (response .Message .Content )
498+ chunk .Content = append (chunk .Content , aiPart )
499+ }
500+
385501 return chunk , nil
386502}
387503
0 commit comments