Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions usecases/usecase/usecase.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from dataclasses import dataclass
from typing import Dict, Type

from utils.configurable import get_parameters, ParameterDefinitions, build_parser, get_arguments
from utils.configurable import ParameterDefinitions, build_parser, get_arguments, get_class_parameters


class UseCase(abc.ABC):
Expand Down Expand Up @@ -66,7 +66,7 @@ def use_case(name: str, desc: str):
def inner(cls: Type[UseCase]):
    """Register *cls* as a named use case, deriving CLI parameters from the class."""
    if name in use_cases:
        raise IndexError(f"Use case with name {name} already exists")
    wrapped = _WrappedUseCase(name, desc, cls, get_class_parameters(cls, name))
    use_cases[name] = wrapped

    return cls

Expand Down
56 changes: 47 additions & 9 deletions utils/configurable.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
import dataclasses
import inspect
import os
from dataclasses import dataclass
Expand All @@ -12,6 +13,16 @@
load_dotenv()


def parameter(*, desc: str, default=dataclasses.MISSING, init: bool = True, repr: bool = True, hash=None,
compare: bool = True, metadata: Dict = None, kw_only: bool = dataclasses.MISSING) -> dataclasses.Field:
if metadata is None:
metadata = dict()
metadata["desc"] = desc

return dataclasses.field(default=default, default_factory=dataclasses.MISSING, init=init, repr=repr, hash=hash,
compare=compare, metadata=metadata, kw_only=kw_only)


def get_default(key, default):
    """Resolve *key* from the environment, trying several spellings in order.

    Lookup order: the key verbatim, uppercased, with dots replaced by
    underscores, and the dot-replaced form uppercased. Returns *default*
    when none of the candidates is set.
    """
    candidates = (
        key,
        key.upper(),
        key.replace(".", "_"),
        key.replace(".", "_").upper(),
    )
    for candidate in candidates:
        value = os.getenv(candidate)
        if value is not None:
            return value
    return default

Expand All @@ -24,12 +35,14 @@ class ParameterDefinition:
name: str
type: Type
default: Any
description: str

def parser(self, basename: str, parser: argparse.ArgumentParser):
    """Register this parameter as a ``--<basename><name>`` option on *parser*."""
    option = f"{basename}{self.name}"
    # Environment variables (via get_default) take precedence over the declared default;
    # a parameter with no resolvable default must be supplied on the command line.
    resolved = get_default(option, self.default)

    parser.add_argument(
        f"--{option}",
        type=self.type,
        default=resolved,
        required=resolved is None,
        help=self.description,
    )

def get(self, basename: str, args: argparse.Namespace):
    """Read this parameter's parsed value back out of *args*."""
    attribute = f"{basename}{self.name}"
    return getattr(args, attribute)
Expand Down Expand Up @@ -62,7 +75,18 @@ def get(self, basename: str, args: argparse.Namespace):
return parameter


def get_parameters(fun, basename: str) -> ParameterDefinitions:
def get_class_parameters(cls, name: str = None, fields: Dict[str, dataclasses.Field] = None) -> ParameterDefinitions:
    """Derive parameter definitions for *cls* from its ``__init__`` signature.

    Falls back to the class name when *name* is omitted, and to the class's
    dataclass fields (when it is a dataclass) when *fields* is omitted.
    """
    effective_name = name if name is not None else cls.__name__
    effective_fields = fields if fields is not None else getattr(cls, "__dataclass_fields__", None)
    return get_parameters(cls.__init__, effective_name, effective_fields)


def get_parameters(fun, basename: str, fields: Dict[str, dataclasses.Field] = None) -> ParameterDefinitions:
if fields is None:
fields = dict()

sig = inspect.signature(fun)
params: ParameterDefinitions = {}
for name, param in sig.parameters.items():
Expand All @@ -73,13 +97,27 @@ def get_parameters(fun, basename: str) -> ParameterDefinitions:
raise ValueError(f"Parameter {name} of {basename}.{fun.__name__} must have a type annotation")

default = param.default if param.default != inspect.Parameter.empty else None

if hasattr(param.annotation, "__parameters__"):
params[name] = ComplexParameterDefinition(name, param.annotation, default, get_parameters(param.annotation, f"{basename}.{fun.__name__}"))
elif param.annotation in (str, int, bool):
params[name] = ParameterDefinition(name, param.annotation, default)
description = None
type = param.annotation

field = None
if isinstance(default, dataclasses.Field):
field = default
default = field.default
elif name in fields:
field = fields[name]

if field is not None:
description = field.metadata.get("desc", None)
if field.type is not None:
type = field.type

if hasattr(type, "__parameters__"):
params[name] = ComplexParameterDefinition(name, type, default, description, get_class_parameters(type, f"{basename}.{fun.__name__}"))
elif type in (str, int, bool):
params[name] = ParameterDefinition(name, type, default, description)
else:
raise ValueError(f"Parameter {name} of {basename}.{fun.__name__} must have str, int, bool, or a __parameters__ class as type, not {param.annotation}")
raise ValueError(f"Parameter {name} of {basename}.{fun.__name__} must have str, int, bool, or a __parameters__ class as type, not {type}")

return params

Expand All @@ -106,7 +144,7 @@ def inner(cls) -> Configurable:
cls.name = service_name
cls.description = service_desc
cls.__service__ = True
cls.__parameters__ = get_parameters(cls.__init__, cls.__name__)
cls.__parameters__ = get_class_parameters(cls)

return cls

Expand Down
4 changes: 2 additions & 2 deletions utils/db_storage/db_storage.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import sqlite3

from utils.configurable import configurable
from utils.configurable import configurable, parameter


@configurable("db_storage", "Stores the results of the experiments in a SQLite database")
class DbStorage:
def __init__(self, connection_string: str = parameter(desc="sqlite3 database connection string for logs", default=":memory:")):
    """Remember the connection string; the sqlite3 connection itself is opened later in init()."""
    self.connection_string = connection_string

def init(self):
Expand Down
16 changes: 8 additions & 8 deletions utils/openai/openai_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import tiktoken

from utils.configurable import configurable
from utils.configurable import configurable, parameter
from utils.llm_util import LLMResult, LLM


Expand All @@ -20,13 +20,13 @@ class OpenAIConnection(LLM):
If you really must use it, you can import it directly from the utils.openai.openai_llm module, which will later on
show you, that you did not specialize yet.
"""
api_key: str
model: str
context_size: int
api_url: str = "https://api.openai.com"
api_timeout: int = 240
api_backoff: int = 60
api_retries: int = 3
api_key: str = parameter(desc="OpenAI API Key")
model: str = parameter(desc="OpenAI model name")
context_size: int = parameter(desc="Maximum context size for the model, only used internally for things like trimming to the context size")
api_url: str = parameter(desc="URL of the OpenAI API", default="https://api.openai.com")
api_timeout: int = parameter(desc="Timeout for the API request", default=240)
api_backoff: int = parameter(desc="Backoff time in seconds when running into rate-limits", default=60)
api_retries: int = parameter(desc="Number of retries when running into rate-limits", default=3)

def get_response(self, prompt, *, retry: int = 0, **kwargs) -> LLMResult:
if retry >= self.api_retries:
Expand Down
2 changes: 1 addition & 1 deletion wintermute.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def main():
for name, use_case in use_cases.items():
use_case.build_parser(subparser.add_parser(
name=use_case.name,
description=use_case.description
help=use_case.description
))

parsed = parser.parse_args(sys.argv[1:])
Expand Down