Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions py/packages/genkit/src/genkit/ai/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def resolver(kind, name, plugin=plugin):
return plugin.resolve_action(self, kind, name)

self.registry.register_action_resolver(plugin.plugin_name(), resolver)
self.registry.register_list_actions_resolver(plugin.plugin_name(), plugin.list_actions)
else:
raise ValueError(f'Invalid {plugin=} provided to Genkit: must be of type `genkit.ai.Plugin`')

Expand Down
12 changes: 12 additions & 0 deletions py/packages/genkit/src/genkit/ai/_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,15 @@ def initialize(self, ai: GenkitRegistry) -> None:
None, initialization is done by side-effect on the registry.
"""
pass

def list_actions(self) -> list[dict[str, str]]:
"""Generate a list of available actions or models.

Returns:
list of actions dicts with the following shape:
{
'name': str,
'kind': ActionKind,
}
"""
return []
11 changes: 10 additions & 1 deletion py/packages/genkit/src/genkit/ai/_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,7 @@ def define_embedder(
self,
name: str,
fn: EmbedderFn,
config_schema: BaseModel | dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
description: str | None = None,
) -> Action:
Expand All @@ -464,15 +465,23 @@ def define_embedder(
Args:
name: Name of the model.
fn: Function implementing the embedder behavior.
config_schema: Optional schema for embedder configuration.
metadata: Optional metadata for the model.
description: Optional description for the embedder.
"""
embedder_meta: dict[str, Any] = metadata if metadata else {}
if 'embedder' not in embedder_meta:
embedder_meta['embedder'] = {}

if config_schema:
embedder_meta['embedder']['customOptions'] = to_json_schema(config_schema)

embedder_description = get_func_description(fn, description)
return self.registry.register_action(
name=name,
kind=ActionKind.EMBEDDER,
fn=fn,
metadata=metadata,
metadata=embedder_meta,
description=embedder_description,
)

Expand Down
20 changes: 20 additions & 0 deletions py/packages/genkit/src/genkit/blocks/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,28 @@
"""Embedding actions."""

from collections.abc import Callable
from typing import Any

from genkit.ai import ActionKind
from genkit.core.action import ActionMetadata
from genkit.core.schema import to_json_schema
from genkit.core.typing import EmbedRequest, EmbedResponse

# type EmbedderFn = Callable[[EmbedRequest], EmbedResponse]
EmbedderFn = Callable[[EmbedRequest], EmbedResponse]


def embedder_action_metadata(
name: str,
info: dict[str, Any] | None = None,
config_schema: Any | None = None,
) -> ActionMetadata:
"""Generates an ActionMetadata for embedders."""
info = info if info is not None else {}
return ActionMetadata(
kind=ActionKind.EMBEDDER,
name=name,
input_json_schema=to_json_schema(EmbedRequest),
output_json_schema=to_json_schema(EmbedResponse),
metadata={'embedder': {**info, 'customOptions': to_json_schema(config_schema) if config_schema else None}},
)
20 changes: 19 additions & 1 deletion py/packages/genkit/src/genkit/blocks/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,10 @@ def my_model(request: GenerateRequest) -> GenerateResponse:

from pydantic import BaseModel, Field

from genkit.core.action import ActionRunContext
from genkit.ai import ActionKind
from genkit.core.action import ActionMetadata, ActionRunContext
from genkit.core.extract import extract_json
from genkit.core.schema import to_json_schema
from genkit.core.typing import (
Candidate,
DocumentPart,
Expand Down Expand Up @@ -423,3 +425,19 @@ def get_part_counts(parts: list[Part]) -> PartCounts:
part_counts.audio += 1 if is_audio else 0

return part_counts


def model_action_metadata(
name: str,
info: dict[str, Any] | None = None,
config_schema: Any | None = None,
) -> ActionMetadata:
"""Generates an ActionMetadata for models."""
info = info if info is not None else {}
return ActionMetadata(
kind=ActionKind.MODEL,
name=name,
input_json_schema=to_json_schema(GenerateRequest),
output_json_schema=to_json_schema(GenerateResponse),
metadata={'model': {**info, 'customOptions': to_json_schema(config_schema) if config_schema else None}},
)
2 changes: 2 additions & 0 deletions py/packages/genkit/src/genkit/core/action/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from ._action import (
Action,
ActionMetadata,
ActionRunContext,
)
from ._key import (
Expand All @@ -28,6 +29,7 @@

__all__ = [
Action.__name__,
ActionMetadata.__name__,
ActionRunContext.__name__,
create_action_key.__name__,
parse_action_key.__name__,
Expand Down
16 changes: 15 additions & 1 deletion py/packages/genkit/src/genkit/core/action/_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@
from functools import cached_property
from typing import Any

from pydantic import TypeAdapter
from pydantic import BaseModel, TypeAdapter

from genkit.aio import Channel, ensure_async
from genkit.core.error import GenkitError
Expand Down Expand Up @@ -457,6 +457,20 @@ def _initialize_io_schemas(
self._metadata[ActionMetadataKey.OUTPUT_KEY] = self._output_schema


class ActionMetadata(BaseModel):
"""Metadata for actions."""

kind: ActionKind
name: str
description: str | None = None
input_schema: Any | None = None
input_json_schema: dict[str, Any] | None = None
output_schema: Any | None = None
output_json_schema: dict[str, Any] | None = None
stream_schema: Any | None = None
metadata: dict[str, Any] | None = None


_SyncTracingWrapper = Callable[[Any | None, ActionRunContext], ActionResponse]
_AsyncTracingWrapper = Callable[[Any | None, ActionRunContext], ActionResponse]

Expand Down
1 change: 1 addition & 0 deletions py/packages/genkit/src/genkit/core/reflection.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def do_GET(self) -> None: # noqa: N802
self.send_header('content-type', 'application/json')
self.end_headers()
actions = registry.list_serializable_actions()
actions = registry.list_actions(actions)
self.wfile.write(bytes(json.dumps(actions), encoding))
else:
self.send_response(404)
Expand Down
66 changes: 65 additions & 1 deletion py/packages/genkit/src/genkit/core/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,19 @@
from collections.abc import Callable
from typing import Any

import structlog

from genkit.core.action import (
Action,
ActionMetadata,
create_action_key,
parse_action_key,
parse_plugin_name_from_action_name,
)
from genkit.core.action.types import ActionKind, ActionName, ActionResolver

logger = structlog.get_logger(__name__)

# An action store is a nested dictionary mapping ActionKind to a dictionary of
# action names and their corresponding Action instances.
#
Expand Down Expand Up @@ -75,14 +80,15 @@ class Registry:
def __init__(self):
"""Initialize an empty Registry instance."""
self._action_resolvers: dict[str, ActionResolver] = {}
self._list_actions_resolvers: dict[str, Callable] = {}
self._entries: ActionStore = {}
self._value_by_kind_and_name: dict[str, dict[str, Any]] = {}
self._lock = threading.RLock()

# TODO: Figure out how to set this.
self.api_stability: str = 'stable'

def register_action_resolver(self, plugin_name: str, resolver: ActionResolver):
def register_action_resolver(self, plugin_name: str, resolver: ActionResolver) -> None:
"""Registers an ActionResolver function for a given plugin.

Args:
Expand All @@ -97,6 +103,21 @@ def register_action_resolver(self, plugin_name: str, resolver: ActionResolver):
raise ValueError(f'Plugin {plugin_name} already registered')
self._action_resolvers[plugin_name] = resolver

def register_list_actions_resolver(self, plugin_name: str, resolver: Callable) -> None:
"""Registers an Callable function to list available actions or models.

Args:
plugin_name: The name of the plugin.
resolver: The Callable function to list models.

Raises:
ValueError: If a resolver is already registered for the plugin.
"""
with self._lock:
if plugin_name in self._list_actions_resolvers:
raise ValueError(f'Plugin {plugin_name} already registered')
self._list_actions_resolvers[plugin_name] = resolver

def register_action(
self,
kind: ActionKind,
Expand Down Expand Up @@ -212,6 +233,49 @@ def list_serializable_actions(self, allowed_kinds: set[ActionKind] | None = None
}
return actions

def list_actions(
self,
actions: dict[str, Action] | None = None,
allowed_kinds: set[ActionKind] | None = None,
) -> dict[str, Action] | None:
"""Add actions or models.

Args:
actions: dictionary of serializable actions.
allowed_kinds: The types of actions to list. If None, all actions
are listed.

Returns:
A dictionary of serializable Actions updated.
"""
if actions is None:
actions = {}

for plugin_name in self._list_actions_resolvers:
actions_lister = self._list_actions_resolvers[plugin_name]

# TODO: Set all the list_actions plugins' methods as cached_properties.
if isinstance(actions_lister, list):
actions_list = actions_lister
else:
actions_list = actions_lister()

for _action in actions_list:
kind = _action.kind
if allowed_kinds is not None and kind not in allowed_kinds:
continue
key = create_action_key(kind, _action.name)

if key not in actions:
actions[key] = {
'key': key,
'name': _action.name,
'inputSchema': _action.input_json_schema,
'outputSchema': _action.output_json_schema,
'metadata': _action.metadata,
}
return actions

def register_value(self, kind: str, name: str, value: Any):
"""Registers a value with a given kind and name.

Expand Down
34 changes: 34 additions & 0 deletions py/packages/genkit/tests/genkit/blocks/embedding_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

"""Tests for the action module."""

from genkit.blocks.embedding import embedder_action_metadata
from genkit.core.action import ActionMetadata


def test_embedder_action_metadata():
"""Test for embedder_action_metadata."""
action_metadata = embedder_action_metadata(
name='test_model',
info={'label': 'test_label'},
config_schema=None,
)

assert isinstance(action_metadata, ActionMetadata)
assert action_metadata.input_json_schema is not None
assert action_metadata.output_json_schema is not None
assert action_metadata.metadata == {'embedder': {'customOptions': None, 'label': 'test_label'}}
16 changes: 16 additions & 0 deletions py/packages/genkit/tests/genkit/blocks/model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
PartCounts,
get_basic_usage_stats,
get_part_counts,
model_action_metadata,
)
from genkit.core.action import ActionMetadata
from genkit.core.typing import (
Candidate,
GenerateRequest,
Expand Down Expand Up @@ -453,3 +455,17 @@ def test_response_wrapper_interrupts() -> None:
tool_request=ToolRequest(name='tool2', input={'bcd': 4}), metadata={'interrupt': {'banana': 'yes'}}
)
]


def test_model_action_metadata():
"""Test for model_action_metadata."""
action_metadata = model_action_metadata(
name='test_model',
info={'label': 'test_label'},
config_schema=None,
)

assert isinstance(action_metadata, ActionMetadata)
assert action_metadata.input_json_schema is not None
assert action_metadata.output_json_schema is not None
assert action_metadata.metadata == {'model': {'customOptions': None, 'label': 'test_label'}}
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ async def test_health_check(asgi_client):
async def test_list_actions(asgi_client, mock_registry):
"""Test that the actions list endpoint returns registered actions."""
mock_registry.list_serializable_actions.return_value = {'action1': {'name': 'Action 1'}}
mock_registry.list_actions.return_value = {'action1': {'name': 'Action 1'}}
response = await asgi_client.get('/api/actions')
assert response.status_code == 200
assert response.json() == {'action1': {'name': 'Action 1'}}
Expand Down
Loading
Loading