Skip to content

Commit 7d4e73c

Browse files
feat(py): add embedder reference support matching JS SDK (#3922)
Co-authored-by: Mengqin Shen <mengqin@google.com>
1 parent 3bdb6fe commit 7d4e73c

File tree

9 files changed

+403
-91
lines changed

9 files changed

+403
-91
lines changed

‎py/packages/genkit/src/genkit/ai/_aio.py‎

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ class while customizing it with any plugins.
2626

2727
from genkit.aio import Channel
2828
from genkit.blocks.document import Document
29-
from genkit.blocks.embedding import EmbedRequest, EmbedResponse
29+
from genkit.blocks.embedding import EmbedderRef
3030
from genkit.blocks.generate import (
3131
StreamingCallback as ModelStreamingCallback,
3232
generate_action,
@@ -39,6 +39,7 @@ class while customizing it with any plugins.
3939
from genkit.blocks.prompt import PromptConfig, to_generate_action_options
4040
from genkit.core.action import ActionRunContext
4141
from genkit.core.action.types import ActionKind
42+
from genkit.core.typing import EmbedRequest, EmbedResponse
4243
from genkit.types import (
4344
DocumentData,
4445
GenerationCommonConfig,
@@ -295,10 +296,12 @@ def generate_stream(
295296

296297
async def embed(
297298
self,
298-
embedder: str | None = None,
299+
embedder: str | EmbedderRef | None = None,
299300
documents: list[Document] | None = None,
300301
options: dict[str, Any] | None = None,
301302
) -> EmbedResponse:
303+
embedder_name: str
304+
embedder_config: dict[str, Any] = {}
302305
"""Calculates embeddings for documents.
303306
304307
Args:
@@ -309,9 +312,22 @@ async def embed(
309312
Returns:
310313
The generated response with embeddings.
311314
"""
312-
embed_action = self.registry.lookup_action(ActionKind.EMBEDDER, embedder)
315+
if isinstance(embedder, EmbedderRef):
316+
embedder_name = embedder.name
317+
embedder_config = embedder.config or {}
318+
if embedder.version:
319+
embedder_config['version'] = embedder.version # Handle version from ref
320+
elif isinstance(embedder, str):
321+
embedder_name = embedder
322+
else:
323+
# Handle case where embedder is None
324+
raise ValueError('Embedder must be specified as a string name or an EmbedderRef.')
313325

314-
return (await embed_action.arun(EmbedRequest(input=documents, options=options))).response
326+
# Merge options passed to embed() with config from EmbedderRef
327+
final_options = {**(embedder_config or {}), **(options or {})}
328+
embed_action = self.registry.lookup_action(ActionKind.EMBEDDER, embedder_name)
329+
330+
return (await embed_action.arun(EmbedRequest(input=documents, options=final_options))).response
315331

316332
async def retrieve(
317333
self,
@@ -335,3 +351,30 @@ async def retrieve(
335351
retrieve_action = self.registry.lookup_action(ActionKind.RETRIEVER, retriever)
336352

337353
return (await retrieve_action.arun(RetrieverRequest(query=query, options=options))).response
354+
355+
async def embed(
    self,
    embedder: str | EmbedderRef | None = None,
    documents: list[Document] | None = None,
    options: dict[str, Any] | None = None,
) -> EmbedResponse:
    """Calculates embeddings for documents.

    NOTE(review): an `embed` method with the same signature appears earlier
    in this file; if both live on the same class, this later definition is
    the one that takes effect. Confirm and remove the duplicate.

    Args:
        embedder: Either the registered embedder name, or an `EmbedderRef`
            carrying the name plus optional `config` and `version`.
        documents: Documents forwarded as the `input` of the embed request.
        options: Per-call options. On key conflicts these override the
            config (and version) taken from an `EmbedderRef`.

    Returns:
        The `response` of the embedder action's run result.

    Raises:
        ValueError: If `embedder` is neither a string nor an `EmbedderRef`
            (including when it is None).
    """
    embedder_name: str
    embedder_config: dict[str, Any] = {}

    if isinstance(embedder, EmbedderRef):
        embedder_name = embedder.name
        # Work on a shallow copy: writing 'version' straight into
        # `embedder.config` would mutate the caller's reference object and
        # leak into every later use of the same ref.
        embedder_config = {**(embedder.config or {})}
        if embedder.version:
            embedder_config['version'] = embedder.version
    elif isinstance(embedder, str):
        embedder_name = embedder
    else:
        # Covers None as well as any unsupported type.
        raise ValueError('Embedder must be specified as a string name or an EmbedderRef.')

    # Options passed to embed() take precedence over config from the ref.
    final_options = {**embedder_config, **(options or {})}

    embed_action = self.registry.lookup_action(ActionKind.EMBEDDER, embedder_name)

    return (await embed_action.arun(EmbedRequest(input=documents, options=final_options))).response

‎py/packages/genkit/src/genkit/ai/_registry.py‎

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
import structlog
4848
from pydantic import BaseModel
4949

50-
from genkit.blocks.embedding import EmbedderFn
50+
from genkit.blocks.embedding import EmbedderFn, EmbedderOptions
5151
from genkit.blocks.evaluator import BatchEvaluatorFn, EvaluatorFn
5252
from genkit.blocks.formats.types import FormatDef
5353
from genkit.blocks.model import ModelFn, ModelMiddleware
@@ -458,8 +458,7 @@ def define_embedder(
458458
self,
459459
name: str,
460460
fn: EmbedderFn,
461-
config_schema: BaseModel | dict[str, Any] | None = None,
462-
metadata: dict[str, Any] | None = None,
461+
options: EmbedderOptions | None = None,
463462
description: str | None = None,
464463
) -> Action:
465464
"""Define a custom embedder action.
@@ -471,13 +470,20 @@ def define_embedder(
471470
metadata: Optional metadata for the model.
472471
description: Optional description for the embedder.
473472
"""
474-
embedder_meta: dict[str, Any] = metadata if metadata else {}
473+
embedder_meta: dict[str, Any] = {}
474+
if options:
475+
if options.label:
476+
embedder_meta['embedder']['label'] = options.label
477+
if options.dimensions:
478+
embedder_meta['embedder']['dimensions'] = options.dimensions
479+
if options.supports:
480+
embedder_meta['embedder']['supports'] = options.supports.model_dump(exclude_none=True, by_alias=True)
481+
if options.config_schema:
482+
embedder_meta['embedder']['customOptions'] = to_json_schema(options.config_schema)
483+
475484
if 'embedder' not in embedder_meta:
476485
embedder_meta['embedder'] = {}
477486

478-
if config_schema:
479-
embedder_meta['embedder']['customOptions'] = to_json_schema(config_schema)
480-
481487
embedder_description = get_func_description(fn, description)
482488
return self.registry.register_action(
483489
name=name,

‎py/packages/genkit/src/genkit/blocks/embedding.py‎

Lines changed: 55 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,29 +16,76 @@
1616

1717
"""Embedding actions."""
1818

19-
from collections.abc import Callable
19+
from collections.abc import Awaitable, Callable
2020
from typing import Any
2121

22-
from genkit.ai import ActionKind
22+
from pydantic import BaseModel, ConfigDict, Field
23+
2324
from genkit.core.action import ActionMetadata
25+
from genkit.core.action.types import ActionKind
2426
from genkit.core.schema import to_json_schema
2527
from genkit.core.typing import EmbedRequest, EmbedResponse
2628

27-
# type EmbedderFn = Callable[[EmbedRequest], EmbedResponse]
29+
30+
class EmbedderSupports(BaseModel):
    """Capabilities that an embedder declares support for."""

    # Unknown fields are rejected; fields may be populated by name.
    model_config = ConfigDict(populate_by_name=True, extra='forbid')

    # NOTE(review): presumably the accepted input kinds (e.g. 'text') -- confirm.
    input: list[str] | None = None
    # NOTE(review): presumably whether multiple languages are handled -- confirm.
    multilingual: bool | None = None
37+
38+
39+
class EmbedderOptions(BaseModel):
    """Options describing an embedder: schema, label, capabilities, size."""

    # Unknown fields are rejected; fields may be populated by name or alias.
    model_config = ConfigDict(populate_by_name=True, extra='forbid')

    # Also accepted under the camelCase alias 'configSchema'.
    config_schema: dict[str, Any] | None = Field(default=None, alias='configSchema')
    # Human-readable label for the embedder.
    label: str | None = None
    # Declared capabilities, if any.
    supports: EmbedderSupports | None = None
    # NOTE(review): presumably the length of the produced vectors -- confirm.
    dimensions: int | None = None
48+
49+
50+
class EmbedderRef(BaseModel):
    """A named reference to an embedder, with optional config and version."""

    # Unknown fields are rejected; fields may be populated by name.
    model_config = ConfigDict(populate_by_name=True, extra='forbid')

    # Name of the embedder being referenced.
    name: str
    # Optional configuration carried along with the reference.
    config: Any | None = None
    # Optional version string for the referenced embedder.
    version: str | None = None
58+
59+
2860
EmbedderFn = Callable[[EmbedRequest], EmbedResponse]
2961

3062

3163
def embedder_action_metadata(
3264
name: str,
33-
info: dict[str, Any] | None = None,
34-
config_schema: Any | None = None,
65+
options: EmbedderOptions | None = None,
3566
) -> ActionMetadata:
36-
"""Generates an ActionMetadata for embedders."""
37-
info = info if info is not None else {}
67+
options = options if options is not None else EmbedderOptions()
68+
embedder_metadata_dict = {'embedder': {}}
69+
70+
if options.label:
71+
embedder_metadata_dict['embedder']['label'] = options.label
72+
73+
embedder_metadata_dict['embedder']['dimensions'] = options.dimensions
74+
75+
if options.supports:
76+
embedder_metadata_dict['embedder']['supports'] = options.supports.model_dump(exclude_none=True, by_alias=True)
77+
78+
embedder_metadata_dict['embedder']['customOptions'] = options.config_schema if options.config_schema else None
79+
3880
return ActionMetadata(
3981
kind=ActionKind.EMBEDDER,
4082
name=name,
4183
input_json_schema=to_json_schema(EmbedRequest),
4284
output_json_schema=to_json_schema(EmbedResponse),
43-
metadata={'embedder': {**info, 'customOptions': to_json_schema(config_schema) if config_schema else None}},
85+
metadata=embedder_metadata_dict,
4486
)
87+
88+
89+
def create_embedder_ref(
    name: str,
    config: dict[str, Any] | None = None,
    version: str | None = None,
) -> EmbedderRef:
    """Build an `EmbedderRef` from a name plus optional config and version.

    Args:
        name: Name of the embedder to reference.
        config: Optional configuration to attach to the reference.
        version: Optional version to attach to the reference.

    Returns:
        The newly constructed `EmbedderRef`.
    """
    ref = EmbedderRef(name=name, config=config, version=version)
    return ref

0 commit comments

Comments
 (0)