Skip to content

Commit 8fec38a

Browse files
authored
feat(go): Add tool support for ollama models (#2796)
1 parent 2023053 commit 8fec38a

File tree

2 files changed

+252
-15
lines changed

2 files changed

+252
-15
lines changed

‎go/plugins/ollama/ollama.go‎

Lines changed: 131 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,20 @@ const provider = "ollama"
4040

4141
var (
4242
mediaSupportedModels = []string{"llava", "bakllava", "llava-llama3", "llava:13b", "llava:7b", "llava:latest"}
43-
roleMapping = map[ai.Role]string{
43+
toolSupportedModels = []string{
44+
"qwq", "mistral-small3.1", "llama3.3", "llama3.2", "llama3.1", "mistral",
45+
"qwen2.5", "qwen2.5-coder", "qwen2", "mistral-nemo", "mixtral", "smollm2",
46+
"mistral-small", "command-r", "hermes3", "mistral-large", "command-r-plus",
47+
"phi4-mini", "granite3.1-dense", "granite3-dense", "granite3.2", "athene-v2",
48+
"nemotron-mini", "nemotron", "llama3-groq-tool-use", "aya-expanse", "granite3-moe",
49+
"granite3.2-vision", "granite3.1-moe", "cogito", "command-r7b", "firefunction-v2",
50+
"granite3.3", "command-a", "command-r7b-arabic",
51+
}
52+
roleMapping = map[ai.Role]string{
4453
ai.RoleUser: "user",
4554
ai.RoleModel: "assistant",
4655
ai.RoleSystem: "system",
56+
ai.RoleTool: "tool",
4757
}
4858
)
4959

@@ -58,12 +68,15 @@ func (o *Ollama) DefineModel(g *genkit.Genkit, model ModelDefinition, info *ai.M
5868
if info != nil {
5969
mi = *info
6070
} else {
71+
// Check if the model supports tools (must be a chat model and in the supported list)
72+
supportsTools := model.Type == "chat" && slices.Contains(toolSupportedModels, model.Name)
6173
mi = ai.ModelInfo{
6274
Label: model.Name,
6375
Supports: &ai.ModelSupports{
6476
Multiturn: true,
6577
SystemRole: true,
6678
Media: slices.Contains(mediaSupportedModels, model.Name),
79+
Tools: supportsTools,
6780
},
6881
Versions: []string{},
6982
}
@@ -100,9 +113,10 @@ type generator struct {
100113
}
101114

102115
type ollamaMessage struct {
103-
Role string `json:"role"`
104-
Content string `json:"content"`
105-
Images []string `json:"images,omitempty"`
116+
Role string `json:"role"`
117+
Content string `json:"content,omitempty"`
118+
Images []string `json:"images,omitempty"`
119+
ToolCalls []ollamaToolCall `json:"tool_calls,omitempty"`
106120
}
107121

108122
// Ollama has two API endpoints, one with a chat interface and another with a generate response interface.
@@ -125,6 +139,7 @@ type ollamaChatRequest struct {
125139
Model string `json:"model"`
126140
Stream bool `json:"stream"`
127141
Format string `json:"format,omitempty"`
142+
Tools []ollamaTool `json:"tools,omitempty"`
128143
}
129144

130145
type ollamaModelRequest struct {
@@ -136,13 +151,38 @@ type ollamaModelRequest struct {
136151
Format string `json:"format,omitempty"`
137152
}
138153

154+
// Tool definition from Ollama API
155+
type ollamaTool struct {
156+
Type string `json:"type"`
157+
Function ollamaFunction `json:"function"`
158+
}
159+
160+
// Function definition for Ollama API
161+
type ollamaFunction struct {
162+
Name string `json:"name"`
163+
Description string `json:"description"`
164+
Parameters map[string]any `json:"parameters"`
165+
}
166+
167+
// Tool Call from Ollama API
168+
type ollamaToolCall struct {
169+
Function ollamaFunctionCall `json:"function"`
170+
}
171+
172+
// Function Call for Ollama API
173+
type ollamaFunctionCall struct {
174+
Name string `json:"name"`
175+
Arguments any `json:"arguments"`
176+
}
177+
139178
// TODO: Add optional parameters (images, format, options, etc.) based on your use case
140179
type ollamaChatResponse struct {
141180
Model string `json:"model"`
142181
CreatedAt string `json:"created_at"`
143182
Message struct {
144-
Role string `json:"role"`
145-
Content string `json:"content"`
183+
Role string `json:"role"`
184+
Content string `json:"content"`
185+
ToolCalls []ollamaToolCall `json:"tool_calls,omitempty"`
146186
} `json:"message"`
147187
}
148188

@@ -217,34 +257,47 @@ func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb fun
217257
}
218258
messages = append(messages, message)
219259
}
220-
payload = ollamaChatRequest{
260+
chatReq := ollamaChatRequest{
221261
Messages: messages,
222262
Model: g.model.Name,
223263
Stream: stream,
224264
Images: images,
225265
}
266+
if len(input.Tools) > 0 {
267+
tools, err := convertTools(input.Tools)
268+
if err != nil {
269+
return nil, fmt.Errorf("failed to convert tools: %v", err)
270+
}
271+
chatReq.Tools = tools
272+
}
273+
payload = chatReq
226274
}
275+
227276
client := &http.Client{Timeout: 30 * time.Second}
228277
payloadBytes, err := json.Marshal(payload)
229278
if err != nil {
230279
return nil, err
231280
}
281+
232282
// Determine the correct endpoint
233283
endpoint := g.serverAddress + "/api/chat"
234284
if !isChatModel {
235285
endpoint = g.serverAddress + "/api/generate"
236286
}
287+
237288
req, err := http.NewRequest("POST", endpoint, bytes.NewReader(payloadBytes))
238289
if err != nil {
239290
return nil, fmt.Errorf("failed to create request: %v", err)
240291
}
241292
req.Header.Set("Content-Type", "application/json")
242293
req = req.WithContext(ctx)
294+
243295
resp, err := client.Do(req)
244296
if err != nil {
245297
return nil, fmt.Errorf("failed to send request: %v", err)
246298
}
247299
defer resp.Body.Close()
300+
248301
if cb == nil {
249302
// Existing behavior for non-streaming responses
250303
var err error
@@ -255,6 +308,7 @@ func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb fun
255308
if resp.StatusCode != http.StatusOK {
256309
return nil, fmt.Errorf("server returned non-200 status: %d, body: %s", resp.StatusCode, body)
257310
}
311+
258312
var response *ai.ModelResponse
259313
if isChatModel {
260314
response, err = translateChatResponse(body)
@@ -269,8 +323,12 @@ func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb fun
269323
} else {
270324
var chunks []*ai.ModelResponseChunk
271325
scanner := bufio.NewScanner(resp.Body)
326+
chunkCount := 0
327+
272328
for scanner.Scan() {
273329
line := scanner.Text()
330+
chunkCount++
331+
274332
var chunk *ai.ModelResponseChunk
275333
if isChatModel {
276334
chunk, err = translateChatChunk(line)
@@ -283,9 +341,11 @@ func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb fun
283341
chunks = append(chunks, chunk)
284342
cb(ctx, chunk)
285343
}
344+
286345
if err := scanner.Err(); err != nil {
287346
return nil, fmt.Errorf("reading response stream: %v", err)
288347
}
348+
289349
// Create a final response with the merged chunks
290350
finalResponse := &ai.ModelResponse{
291351
Request: input,
@@ -303,13 +363,29 @@ func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb fun
303363
}
304364
}
305365

366+
// convertTools converts Genkit tool definitions to Ollama tool format
367+
func convertTools(tools []*ai.ToolDefinition) ([]ollamaTool, error) {
368+
ollamaTools := make([]ollamaTool, 0, len(tools))
369+
for _, tool := range tools {
370+
ollamaTools = append(ollamaTools, ollamaTool{
371+
Type: "function",
372+
Function: ollamaFunction{
373+
Name: tool.Name,
374+
Description: tool.Description,
375+
Parameters: tool.InputSchema,
376+
},
377+
})
378+
}
379+
return ollamaTools, nil
380+
}
381+
306382
func convertParts(role ai.Role, parts []*ai.Part) (*ollamaMessage, error) {
307383
message := &ollamaMessage{
308384
Role: roleMapping[role],
309385
}
310386
var contentBuilder strings.Builder
387+
var toolCalls []ollamaToolCall
311388
var images []string
312-
313389
for _, part := range parts {
314390
if part.IsText() {
315391
contentBuilder.WriteString(part.Text)
@@ -320,12 +396,30 @@ func convertParts(role ai.Role, parts []*ai.Part) (*ollamaMessage, error) {
320396
}
321397
base64Encoded := base64.StdEncoding.EncodeToString(data)
322398
images = append(images, base64Encoded)
399+
} else if part.IsToolRequest() {
400+
toolReq := part.ToolRequest
401+
toolCalls = append(toolCalls, ollamaToolCall{
402+
Function: ollamaFunctionCall{
403+
Name: toolReq.Name,
404+
Arguments: toolReq.Input,
405+
},
406+
})
407+
} else if part.IsToolResponse() {
408+
toolResp := part.ToolResponse
409+
outputJSON, err := json.Marshal(toolResp.Output)
410+
if err != nil {
411+
return nil, fmt.Errorf("failed to marshal tool response: %v", err)
412+
}
413+
contentBuilder.WriteString(string(outputJSON))
323414
} else {
324415
return nil, errors.New("unsupported content type")
325416
}
326417
}
327418

328419
message.Content = contentBuilder.String()
420+
if len(toolCalls) > 0 {
421+
message.ToolCalls = toolCalls
422+
}
329423
if len(images) > 0 {
330424
message.Images = images
331425
}
@@ -342,17 +436,27 @@ func translateChatResponse(responseData []byte) (*ai.ModelResponse, error) {
342436
modelResponse := &ai.ModelResponse{
343437
FinishReason: ai.FinishReason("stop"),
344438
Message: &ai.Message{
345-
Role: ai.Role(response.Message.Role),
439+
Role: ai.RoleModel,
346440
},
347441
}
348-
349-
aiPart := ai.NewTextPart(response.Message.Content)
350-
modelResponse.Message.Content = append(modelResponse.Message.Content, aiPart)
442+
if len(response.Message.ToolCalls) > 0 {
443+
for _, toolCall := range response.Message.ToolCalls {
444+
toolRequest := &ai.ToolRequest{
445+
Name: toolCall.Function.Name,
446+
Input: toolCall.Function.Arguments,
447+
}
448+
toolPart := ai.NewToolRequestPart(toolRequest)
449+
modelResponse.Message.Content = append(modelResponse.Message.Content, toolPart)
450+
}
451+
} else if response.Message.Content != "" {
452+
aiPart := ai.NewTextPart(response.Message.Content)
453+
modelResponse.Message.Content = append(modelResponse.Message.Content, aiPart)
454+
}
351455

352456
return modelResponse, nil
353457
}
354458

355-
// translateResponse translates Ollama generate response into a genkit response.
459+
// translateModelResponse translates Ollama generate response into a genkit response.
356460
func translateModelResponse(responseData []byte) (*ai.ModelResponse, error) {
357461
var response ollamaModelResponse
358462

@@ -380,8 +484,20 @@ func translateChatChunk(input string) (*ai.ModelResponseChunk, error) {
380484
return nil, fmt.Errorf("failed to parse response JSON: %v", err)
381485
}
382486
chunk := &ai.ModelResponseChunk{}
383-
aiPart := ai.NewTextPart(response.Message.Content)
384-
chunk.Content = append(chunk.Content, aiPart)
487+
if len(response.Message.ToolCalls) > 0 {
488+
for _, toolCall := range response.Message.ToolCalls {
489+
toolRequest := &ai.ToolRequest{
490+
Name: toolCall.Function.Name,
491+
Input: toolCall.Function.Arguments,
492+
}
493+
toolPart := ai.NewToolRequestPart(toolRequest)
494+
chunk.Content = append(chunk.Content, toolPart)
495+
}
496+
} else if response.Message.Content != "" {
497+
aiPart := ai.NewTextPart(response.Message.Content)
498+
chunk.Content = append(chunk.Content, aiPart)
499+
}
500+
385501
return chunk, nil
386502
}
387503

0 commit comments

Comments
 (0)