@@ -26,7 +26,7 @@ class while customizing it with any plugins.
2626
2727from genkit .aio import Channel
2828from genkit .blocks .document import Document
29- from genkit .blocks .embedding import EmbedRequest , EmbedResponse
29+ from genkit .blocks .embedding import EmbedderRef
3030from genkit .blocks .generate import (
3131 StreamingCallback as ModelStreamingCallback ,
3232 generate_action ,
@@ -39,6 +39,7 @@ class while customizing it with any plugins.
3939from genkit .blocks .prompt import PromptConfig , to_generate_action_options
4040from genkit .core .action import ActionRunContext
4141from genkit .core .action .types import ActionKind
42+ from genkit .core .typing import EmbedRequest , EmbedResponse
4243from genkit .types import (
4344 DocumentData ,
4445 GenerationCommonConfig ,
@@ -295,10 +296,12 @@ def generate_stream(
295296
296297 async def embed (
297298 self ,
298- embedder : str | None = None ,
299+ embedder : str | EmbedderRef | None = None ,
299300 documents : list [Document ] | None = None ,
300301 options : dict [str , Any ] | None = None ,
301302 ) -> EmbedResponse :
303+ embedder_name : str
304+ embedder_config : dict [str , Any ] = {}
302305 """Calculates embeddings for documents.
303306
304307 Args:
@@ -309,9 +312,22 @@ async def embed(
309312 Returns:
310313 The generated response with embeddings.
311314 """
312- embed_action = self .registry .lookup_action (ActionKind .EMBEDDER , embedder )
315+ if isinstance (embedder , EmbedderRef ):
316+ embedder_name = embedder .name
317+ embedder_config = embedder .config or {}
318+ if embedder .version :
319+ embedder_config ['version' ] = embedder .version # Handle version from ref
320+ elif isinstance (embedder , str ):
321+ embedder_name = embedder
322+ else :
323+ # Handle case where embedder is None
324+ raise ValueError ('Embedder must be specified as a string name or an EmbedderRef.' )
313325
314- return (await embed_action .arun (EmbedRequest (input = documents , options = options ))).response
326+ # Merge options passed to embed() with config from EmbedderRef
327+ final_options = {** (embedder_config or {}), ** (options or {})}
328+ embed_action = self .registry .lookup_action (ActionKind .EMBEDDER , embedder_name )
329+
330+ return (await embed_action .arun (EmbedRequest (input = documents , options = final_options ))).response
315331
316332 async def retrieve (
317333 self ,
@@ -335,3 +351,30 @@ async def retrieve(
335351 retrieve_action = self .registry .lookup_action (ActionKind .RETRIEVER , retriever )
336352
337353 return (await retrieve_action .arun (RetrieverRequest (query = query , options = options ))).response
354+
355+ async def embed (
356+ self ,
357+ embedder : str | EmbedderRef | None = None ,
358+ documents : list [Document ] | None = None ,
359+ options : dict [str , Any ] | None = None ,
360+ ) -> EmbedResponse :
361+ embedder_name : str
362+ embedder_config : dict [str , Any ] = {}
363+
364+ if isinstance (embedder , EmbedderRef ):
365+ embedder_name = embedder .name
366+ embedder_config = embedder .config or {}
367+ if embedder .version :
368+ embedder_config ['version' ] = embedder .version # Handle version from ref
369+ elif isinstance (embedder , str ):
370+ embedder_name = embedder
371+ else :
372+ # Handle case where embedder is None
373+ raise ValueError ('Embedder must be specified as a string name or an EmbedderRef.' )
374+
375+ # Merge options passed to embed() with config from EmbedderRef
376+ final_options = {** (embedder_config or {}), ** (options or {})}
377+
378+ embed_action = self .registry .lookup_action (ActionKind .EMBEDDER , embedder_name )
379+
380+ return (await embed_action .arun (EmbedRequest (input = documents , options = final_options ))).response
0 commit comments