Skip to content

Commit b86533a

Browse files
authored
feat(go): Add ollama embeddings support (#841)
1 parent 50cdf5c commit b86533a

File tree

2 files changed

+221
-0
lines changed

2 files changed

+221
-0
lines changed

‎go/plugins/ollama/embed.go‎

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
// Copyright 2024 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package ollama
16+
17+
import (
18+
"bytes"
19+
"context"
20+
"encoding/json"
21+
"fmt"
22+
"net/http"
23+
"strings"
24+
25+
"github.com/firebase/genkit/go/ai"
26+
)
27+
28+
// EmbedOptions carries the per-request options understood by the Ollama
// embedder. It is passed as a *EmbedOptions in the Options field of an
// ai.EmbedRequest.
type EmbedOptions struct {
	// Model is the name of the embedding model to use. embed rejects a
	// request whose Model is empty.
	Model string `json:"model"`
}
31+
32+
// ollamaEmbedRequest is the JSON body that embed posts to the server's
// /api/embed endpoint.
type ollamaEmbedRequest struct {
	// Model is the name of the embedding model.
	Model string `json:"model"`
	// Input holds either a single string or a []string.
	Input interface{} `json:"input"` // todo: using interface{} to handle both string and []string, figure out better solution
	// Options is left out of the encoded JSON when empty.
	Options map[string]interface{} `json:"options,omitempty"`
}
37+
38+
// ollamaEmbedResponse is the part of the /api/embed reply that embed reads.
type ollamaEmbedResponse struct {
	// Embeddings is decoded from the "embeddings" JSON array; each inner
	// slice is one embedding vector.
	Embeddings [][]float32 `json:"embeddings"`
}
41+
42+
func embed(ctx context.Context, serverAddress string, req *ai.EmbedRequest) (*ai.EmbedResponse, error) {
43+
options, ok := req.Options.(*EmbedOptions)
44+
if !ok && req.Options != nil {
45+
return nil, fmt.Errorf("invalid options type: expected *EmbedOptions")
46+
}
47+
if options == nil || options.Model == "" {
48+
return nil, fmt.Errorf("invalid embedding model: model must be specified")
49+
}
50+
51+
if serverAddress == "" {
52+
return nil, fmt.Errorf("invalid server address: address cannot be empty")
53+
}
54+
55+
ollamaReq := newOllamaEmbedRequest(options.Model, req.Documents)
56+
57+
jsonData, err := json.Marshal(ollamaReq)
58+
if err != nil {
59+
return nil, fmt.Errorf("failed to marshal embed request: %w", err)
60+
}
61+
62+
resp, err := sendEmbedRequest(ctx, serverAddress, jsonData)
63+
if err != nil {
64+
return nil, err
65+
}
66+
defer resp.Body.Close()
67+
68+
if resp.StatusCode != http.StatusOK {
69+
return nil, fmt.Errorf("ollama embed request failed with status code %d", resp.StatusCode)
70+
}
71+
72+
var ollamaResp ollamaEmbedResponse
73+
if err := json.NewDecoder(resp.Body).Decode(&ollamaResp); err != nil {
74+
return nil, fmt.Errorf("failed to decode embed response: %w", err)
75+
}
76+
77+
return newEmbedResponse(ollamaResp.Embeddings), nil
78+
}
79+
80+
// sendEmbedRequest POSTs jsonData as an application/json body to the
// /api/embed endpoint of the server at serverAddress.
//
// Trailing slashes on serverAddress are dropped before the endpoint path is
// appended, so "http://host:11434" and "http://host:11434/" address the same
// URL instead of the latter producing "//api/embed".
//
// The request is bound to ctx. On success the caller owns the response and
// must close its Body.
func sendEmbedRequest(ctx context.Context, serverAddress string, jsonData []byte) (*http.Response, error) {
	endpoint := strings.TrimRight(serverAddress, "/") + "/api/embed"

	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(jsonData))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")

	client := &http.Client{}
	return client.Do(httpReq)
}
89+
90+
func newOllamaEmbedRequest(model string, documents []*ai.Document) ollamaEmbedRequest {
91+
var input interface{}
92+
if len(documents) == 1 {
93+
input = concatenateText(documents[0])
94+
} else {
95+
texts := make([]string, len(documents))
96+
for i, doc := range documents {
97+
texts[i] = concatenateText(doc)
98+
}
99+
input = texts
100+
}
101+
102+
return ollamaEmbedRequest{
103+
Model: model,
104+
Input: input,
105+
}
106+
}
107+
108+
func newEmbedResponse(embeddings [][]float32) *ai.EmbedResponse {
109+
resp := &ai.EmbedResponse{
110+
Embeddings: make([]*ai.DocumentEmbedding, len(embeddings)),
111+
}
112+
for i, embedding := range embeddings {
113+
resp.Embeddings[i] = &ai.DocumentEmbedding{Embedding: embedding}
114+
}
115+
return resp
116+
}
117+
118+
func concatenateText(doc *ai.Document) string {
119+
var builder strings.Builder
120+
for _, part := range doc.Content {
121+
builder.WriteString(part.Text)
122+
}
123+
result := builder.String()
124+
return result
125+
}
126+
127+
// DefineEmbedder defines an embedder with a given server address.
128+
func DefineEmbedder(serverAddress string, model string) ai.Embedder {
129+
state.mu.Lock()
130+
defer state.mu.Unlock()
131+
if !state.initted {
132+
panic("ollama.Init not called")
133+
}
134+
return ai.DefineEmbedder(provider, serverAddress, func(ctx context.Context, req *ai.EmbedRequest) (*ai.EmbedResponse, error) {
135+
if req.Options == nil {
136+
req.Options = &EmbedOptions{Model: model}
137+
}
138+
if req.Options.(*EmbedOptions).Model == "" {
139+
req.Options.(*EmbedOptions).Model = model
140+
}
141+
return embed(ctx, serverAddress, req)
142+
})
143+
}
144+
145+
// IsDefinedEmbedder reports whether the embedder with the given server address is defined by this plugin.
146+
func IsDefinedEmbedder(serverAddress string) bool {
147+
isDefined := ai.IsDefinedEmbedder(provider, serverAddress)
148+
return isDefined
149+
}
150+
151+
// Embedder returns the [ai.Embedder] with the given server address.
152+
// It returns nil if the embedder was not defined.
153+
func Embedder(serverAddress string) ai.Embedder {
154+
return ai.LookupEmbedder(provider, serverAddress)
155+
}

‎go/plugins/ollama/embed_test.go‎

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
// Copyright 2024 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package ollama
16+
17+
import (
18+
"context"
19+
"encoding/json"
20+
"net/http"
21+
"net/http/httptest"
22+
"strings"
23+
"testing"
24+
25+
"github.com/firebase/genkit/go/ai"
26+
)
27+
28+
func TestEmbedValidRequest(t *testing.T) {
29+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
30+
w.WriteHeader(http.StatusOK)
31+
json.NewEncoder(w).Encode(ollamaEmbedResponse{
32+
Embeddings: [][]float32{{0.1, 0.2, 0.3}},
33+
})
34+
}))
35+
defer server.Close()
36+
37+
req := &ai.EmbedRequest{
38+
Documents: []*ai.Document{
39+
ai.DocumentFromText("test", nil),
40+
},
41+
Options: &EmbedOptions{Model: "all-minilm"},
42+
}
43+
44+
resp, err := embed(context.Background(), server.URL, req)
45+
if err != nil {
46+
t.Fatalf("expected no error, got %v", err)
47+
}
48+
49+
if len(resp.Embeddings) != 1 {
50+
t.Fatalf("expected 1 embedding, got %d", len(resp.Embeddings))
51+
}
52+
}
53+
54+
func TestEmbedInvalidServerAddress(t *testing.T) {
55+
req := &ai.EmbedRequest{
56+
Documents: []*ai.Document{
57+
ai.DocumentFromText("test", nil),
58+
},
59+
Options: &EmbedOptions{Model: "all-minilm"},
60+
}
61+
62+
_, err := embed(context.Background(), "", req)
63+
if err == nil || !strings.Contains(err.Error(), "invalid server address") {
64+
t.Fatalf("expected invalid server address error, got %v", err)
65+
}
66+
}

0 commit comments

Comments
 (0)