Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go/ai/embedder.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import (
)

// Embedder is the interface used to convert a document to a
// multidimensional vector. A [Retriever] will use a value of this type.
// multidimensional vector. A [DocumentStore] will use a value of this type.
type Embedder interface {
Embed(context.Context, *EmbedRequest) ([]float32, error)
}
Expand Down
24 changes: 12 additions & 12 deletions go/ai/retriever.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ import (
"github.com/firebase/genkit/go/core"
)

// Retriever supports adding documents to a database, and
// DocumentStore supports adding documents to a database, and
// retrieving documents from the database that are similar to a query document.
// Vector databases will implement this interface.
type Retriever interface {
type DocumentStore interface {
// Add a document to the database.
Index(context.Context, *IndexerRequest) error
// Retrieve matching documents from the database.
Expand All @@ -49,30 +49,30 @@ type RetrieverResponse struct {
Documents []*Document `json:"documents"`
}

// DefineRetriever takes index and retrieve functions that access a vector DB
// and returns a new Retriever that wraps them in registered actions.
func DefineRetriever(
// DefineDocumentStore takes index and retrieve functions that access a document store
// and returns a new [DocumentStore] that wraps them in registered actions.
func DefineDocumentStore(
name string,
index func(context.Context, *IndexerRequest) error,
retrieve func(context.Context, *RetrieverRequest) (*RetrieverResponse, error),
) Retriever {
) DocumentStore {
ia := core.DefineAction(name, core.ActionTypeIndexer, nil, func(ctx context.Context, req *IndexerRequest) (struct{}, error) {
return struct{}{}, index(ctx, req)
})
ra := core.DefineAction(name, core.ActionTypeRetriever, nil, retrieve)
return &retriever{ia, ra}
return &docStore{ia, ra}
}

type retriever struct {
type docStore struct {
index *core.Action[*IndexerRequest, struct{}, struct{}]
retrieve *core.Action[*RetrieverRequest, *RetrieverResponse, struct{}]
}

func (r *retriever) Index(ctx context.Context, req *IndexerRequest) error {
_, err := r.index.Run(ctx, req, nil)
func (ds *docStore) Index(ctx context.Context, req *IndexerRequest) error {
_, err := ds.index.Run(ctx, req, nil)
return err
}

func (r *retriever) Retrieve(ctx context.Context, req *RetrieverRequest) (*RetrieverResponse, error) {
return r.retrieve.Run(ctx, req, nil)
func (ds *docStore) Retrieve(ctx context.Context, req *RetrieverRequest) (*RetrieverResponse, error) {
return ds.retrieve.Run(ctx, req, nil)
}
52 changes: 26 additions & 26 deletions go/plugins/localvec/localvec.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,17 @@ import (
// document store with genkit, and also return it.
// This document store may only be used by a single goroutine at a time.
// This is based on js/plugins/dev-local-vectorstore/src/index.ts.
func New(ctx context.Context, dir, name string, embedder ai.Embedder, embedderOptions any) (ai.Retriever, error) {
r, err := newRetriever(ctx, dir, name, embedder, embedderOptions)
func New(ctx context.Context, dir, name string, embedder ai.Embedder, embedderOptions any) (ai.DocumentStore, error) {
r, err := newDocStore(ctx, dir, name, embedder, embedderOptions)
if err != nil {
return nil, err
}
return ai.DefineRetriever("devLocalVectorStore/"+name, r.Index, r.Retrieve), nil
return ai.DefineDocumentStore("devLocalVectorStore/"+name, r.Index, r.Retrieve), nil
}

// retriever implements the [ai.Retriever] interface
// docStore implements the [ai.DocumentStore] interface
// for a local vector database.
type retriever struct {
type docStore struct {
filename string
embedder ai.Embedder
embedderOptions any
Expand All @@ -61,8 +61,8 @@ type dbValue struct {
Embedding []float32 `json:"embedding"`
}

// newRetriever returns a new ai.Retriever to register.
func newRetriever(ctx context.Context, dir, name string, embedder ai.Embedder, embedderOptions any) (ai.Retriever, error) {
// newDocStore returns a new ai.DocumentStore to register.
func newDocStore(ctx context.Context, dir, name string, embedder ai.Embedder, embedderOptions any) (ai.DocumentStore, error) {
if err := os.MkdirAll(dir, 0o755); err != nil {
return nil, err
}
Expand All @@ -82,23 +82,23 @@ func newRetriever(ctx context.Context, dir, name string, embedder ai.Embedder, e
}
}

r := &retriever{
ds := &docStore{
filename: filename,
embedder: embedder,
embedderOptions: embedderOptions,
data: data,
}
return r, nil
return ds, nil
}

// Index implements the genkit [ai.Retriever.Index] method.
func (r *retriever) Index(ctx context.Context, req *ai.IndexerRequest) error {
// Index implements the genkit [ai.DocumentStore.Index] method.
func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error {
for _, doc := range req.Documents {
ereq := &ai.EmbedRequest{
Document: doc,
Options: r.embedderOptions,
Options: ds.embedderOptions,
}
vals, err := r.embedder.Embed(ctx, ereq)
vals, err := ds.embedder.Embed(ctx, ereq)
if err != nil {
return fmt.Errorf("localvec index embedding failed: %v", err)
}
Expand All @@ -108,16 +108,16 @@ func (r *retriever) Index(ctx context.Context, req *ai.IndexerRequest) error {
return err
}

if _, ok := r.data[id]; ok {
if _, ok := ds.data[id]; ok {
logger.FromContext(ctx).Debug("localvec skipping document because already present", "id", id)
continue
}

if r.data == nil {
r.data = make(map[string]dbValue)
if ds.data == nil {
ds.data = make(map[string]dbValue)
}

r.data[id] = dbValue{
ds.data[id] = dbValue{
Doc: doc,
Embedding: vals,
}
Expand All @@ -126,19 +126,19 @@ func (r *retriever) Index(ctx context.Context, req *ai.IndexerRequest) error {
// Update the file every time we add documents.
// We use a temporary file to avoid losing the original
// file, in case of a crash.
tmpname := r.filename + ".tmp"
tmpname := ds.filename + ".tmp"
f, err := os.Create(tmpname)
if err != nil {
return err
}
encoder := json.NewEncoder(f)
if err := encoder.Encode(r.data); err != nil {
if err := encoder.Encode(ds.data); err != nil {
return err
}
if err := f.Close(); err != nil {
return err
}
if err := os.Rename(tmpname, r.filename); err != nil {
if err := os.Rename(tmpname, ds.filename); err != nil {
return err
}

Expand All @@ -152,15 +152,15 @@ type RetrieverOptions struct {
K int `json:"k,omitempty"` // number of entries to return
}

// Retrieve implements the genkit [ai.Retriever.Retrieve] method.
func (r *retriever) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai.RetrieverResponse, error) {
// Retrieve implements the genkit [ai.DocumentStore.Retrieve] method.
func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai.RetrieverResponse, error) {
// Use the embedder to convert the document we want to
// retrieve into a vector.
ereq := &ai.EmbedRequest{
Document: req.Document,
Options: r.embedderOptions,
Options: ds.embedderOptions,
}
vals, err := r.embedder.Embed(ctx, ereq)
vals, err := ds.embedder.Embed(ctx, ereq)
if err != nil {
return nil, fmt.Errorf("localvec retrieve embedding failed: %v", err)
}
Expand All @@ -169,8 +169,8 @@ func (r *retriever) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai
score float64
doc *ai.Document
}
scoredDocs := make([]scoredDoc, 0, len(r.data))
for _, dbv := range r.data {
scoredDocs := make([]scoredDoc, 0, len(ds.data))
for _, dbv := range ds.data {
score := similarity(vals, dbv.Embedding)
scoredDocs = append(scoredDocs, scoredDoc{
score: score,
Expand Down
18 changes: 9 additions & 9 deletions go/plugins/localvec/localvec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,15 @@ func TestLocalVec(t *testing.T) {
embedder.Register(d2, v2)
embedder.Register(d3, v3)

r, err := newRetriever(ctx, t.TempDir(), "testLocalVec", embedder, nil)
ds, err := newDocStore(ctx, t.TempDir(), "testLocalVec", embedder, nil)
if err != nil {
t.Fatal(err)
}

indexerReq := &ai.IndexerRequest{
Documents: []*ai.Document{d1, d2, d3},
}
err = r.Index(ctx, indexerReq)
err = ds.Index(ctx, indexerReq)
if err != nil {
t.Fatalf("Index operation failed: %v", err)
}
Expand All @@ -72,7 +72,7 @@ func TestLocalVec(t *testing.T) {
Document: d1,
Options: retrieverOptions,
}
retrieverResp, err := r.Retrieve(ctx, retrieverReq)
retrieverResp, err := ds.Retrieve(ctx, retrieverReq)
if err != nil {
t.Fatalf("Retrieve operation failed: %v", err)
}
Expand Down Expand Up @@ -113,15 +113,15 @@ func TestPersistentIndexing(t *testing.T) {

tDir := t.TempDir()

r, err := newRetriever(ctx, tDir, "testLocalVec", embedder, nil)
ds, err := newDocStore(ctx, tDir, "testLocalVec", embedder, nil)
if err != nil {
t.Fatal(err)
}

indexerReq := &ai.IndexerRequest{
Documents: []*ai.Document{d1, d2},
}
err = r.Index(ctx, indexerReq)
err = ds.Index(ctx, indexerReq)
if err != nil {
t.Fatalf("Index operation failed: %v", err)
}
Expand All @@ -134,7 +134,7 @@ func TestPersistentIndexing(t *testing.T) {
Document: d1,
Options: retrieverOptions,
}
retrieverResp, err := r.Retrieve(ctx, retrieverReq)
retrieverResp, err := ds.Retrieve(ctx, retrieverReq)
if err != nil {
t.Fatalf("Retrieve operation failed: %v", err)
}
Expand All @@ -144,15 +144,15 @@ func TestPersistentIndexing(t *testing.T) {
t.Errorf("got %d results, expected 2", len(docs))
}

rAnother, err := newRetriever(ctx, tDir, "testLocalVec", embedder, nil)
dsAnother, err := newDocStore(ctx, tDir, "testLocalVec", embedder, nil)
if err != nil {
t.Fatal(err)
}

indexerReq = &ai.IndexerRequest{
Documents: []*ai.Document{d3},
}
err = rAnother.Index(ctx, indexerReq)
err = dsAnother.Index(ctx, indexerReq)
if err != nil {
t.Fatalf("Index operation failed: %v", err)
}
Expand All @@ -165,7 +165,7 @@ func TestPersistentIndexing(t *testing.T) {
Document: d1,
Options: retrieverOptions,
}
retrieverResp, err = rAnother.Retrieve(ctx, retrieverReq)
retrieverResp, err = dsAnother.Retrieve(ctx, retrieverReq)
if err != nil {
t.Fatalf("Retrieve operation failed: %v", err)
}
Expand Down
Loading