Skip to content

Commit 0820133

Browse files
authored
Merge pull request #16 from Neverbolt/parameter-descriptions
Adds the possibility to define help text for parameters
2 parents 7436a5d + 544f3b0 commit 0820133

File tree

5 files changed

+60
-22
lines changed

5 files changed

+60
-22
lines changed

‎usecases/usecase/usecase.py‎

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from dataclasses import dataclass
44
from typing import Dict, Type
55

6-
from utils.configurable import get_parameters, ParameterDefinitions, build_parser, get_arguments
6+
from utils.configurable import ParameterDefinitions, build_parser, get_arguments, get_class_parameters
77

88

99
class UseCase(abc.ABC):
@@ -66,7 +66,7 @@ def use_case(name: str, desc: str):
6666
def inner(cls: Type[UseCase]):
6767
if name in use_cases:
6868
raise IndexError(f"Use case with name {name} already exists")
69-
use_cases[name] = _WrappedUseCase(name, desc, cls, get_parameters(cls.__init__, name))
69+
use_cases[name] = _WrappedUseCase(name, desc, cls, get_class_parameters(cls, name))
7070

7171
return cls
7272

‎utils/configurable.py‎

Lines changed: 47 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import argparse
2+
import dataclasses
23
import inspect
34
import os
45
from dataclasses import dataclass
@@ -12,6 +13,16 @@
1213
load_dotenv()
1314

1415

16+
def parameter(*, desc: str, default=dataclasses.MISSING, init: bool = True, repr: bool = True, hash=None,
17+
compare: bool = True, metadata: Dict = None, kw_only: bool = dataclasses.MISSING) -> dataclasses.Field:
18+
if metadata is None:
19+
metadata = dict()
20+
metadata["desc"] = desc
21+
22+
return dataclasses.field(default=default, default_factory=dataclasses.MISSING, init=init, repr=repr, hash=hash,
23+
compare=compare, metadata=metadata, kw_only=kw_only)
24+
25+
1526
def get_default(key, default):
1627
return os.getenv(key, os.getenv(key.upper(), os.getenv(key.replace(".", "_"), os.getenv(key.replace(".", "_").upper(), default))))
1728

@@ -24,12 +35,14 @@ class ParameterDefinition:
2435
name: str
2536
type: Type
2637
default: Any
38+
description: str
2739

2840
def parser(self, basename: str, parser: argparse.ArgumentParser):
2941
name = f"{basename}{self.name}"
3042
default = get_default(name, self.default)
3143

32-
parser.add_argument(f"--{name}", type=self.type, default=default, required=default is None)
44+
parser.add_argument(f"--{name}", type=self.type, default=default, required=default is None,
45+
help=self.description)
3346

3447
def get(self, basename: str, args: argparse.Namespace):
3548
return getattr(args, f"{basename}{self.name}")
@@ -62,7 +75,18 @@ def get(self, basename: str, args: argparse.Namespace):
6275
return parameter
6376

6477

65-
def get_parameters(fun, basename: str) -> ParameterDefinitions:
78+
def get_class_parameters(cls, name: str = None, fields: Dict[str, dataclasses.Field] = None) -> ParameterDefinitions:
79+
if name is None:
80+
name = cls.__name__
81+
if fields is None and hasattr(cls, "__dataclass_fields__"):
82+
fields = cls.__dataclass_fields__
83+
return get_parameters(cls.__init__, name, fields)
84+
85+
86+
def get_parameters(fun, basename: str, fields: Dict[str, dataclasses.Field] = None) -> ParameterDefinitions:
87+
if fields is None:
88+
fields = dict()
89+
6690
sig = inspect.signature(fun)
6791
params: ParameterDefinitions = {}
6892
for name, param in sig.parameters.items():
@@ -73,13 +97,27 @@ def get_parameters(fun, basename: str) -> ParameterDefinitions:
7397
raise ValueError(f"Parameter {name} of {basename}.{fun.__name__} must have a type annotation")
7498

7599
default = param.default if param.default != inspect.Parameter.empty else None
76-
77-
if hasattr(param.annotation, "__parameters__"):
78-
params[name] = ComplexParameterDefinition(name, param.annotation, default, get_parameters(param.annotation, f"{basename}.{fun.__name__}"))
79-
elif param.annotation in (str, int, bool):
80-
params[name] = ParameterDefinition(name, param.annotation, default)
100+
description = None
101+
type = param.annotation
102+
103+
field = None
104+
if isinstance(default, dataclasses.Field):
105+
field = default
106+
default = field.default
107+
elif name in fields:
108+
field = fields[name]
109+
110+
if field is not None:
111+
description = field.metadata.get("desc", None)
112+
if field.type is not None:
113+
type = field.type
114+
115+
if hasattr(type, "__parameters__"):
116+
params[name] = ComplexParameterDefinition(name, type, default, description, get_class_parameters(type, f"{basename}.{fun.__name__}"))
117+
elif type in (str, int, bool):
118+
params[name] = ParameterDefinition(name, type, default, description)
81119
else:
82-
raise ValueError(f"Parameter {name} of {basename}.{fun.__name__} must have str, int, bool, or a __parameters__ class as type, not {param.annotation}")
120+
raise ValueError(f"Parameter {name} of {basename}.{fun.__name__} must have str, int, bool, or a __parameters__ class as type, not {type}")
83121

84122
return params
85123

@@ -106,7 +144,7 @@ def inner(cls) -> Configurable:
106144
cls.name = service_name
107145
cls.description = service_desc
108146
cls.__service__ = True
109-
cls.__parameters__ = get_parameters(cls.__init__, cls.__name__)
147+
cls.__parameters__ = get_class_parameters(cls)
110148

111149
return cls
112150

‎utils/db_storage/db_storage.py‎

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import sqlite3
22

3-
from utils.configurable import configurable
3+
from utils.configurable import configurable, parameter
44

55

66
@configurable("db_storage", "Stores the results of the experiments in a SQLite database")
77
class DbStorage:
8-
def __init__(self, connection_string: str = ":memory:"):
8+
def __init__(self, connection_string: str = parameter(desc="sqlite3 database connection string for logs", default=":memory:")):
99
self.connection_string = connection_string
1010

1111
def init(self):

‎utils/openai/openai_llm.py‎

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import tiktoken
88

9-
from utils.configurable import configurable
9+
from utils.configurable import configurable, parameter
1010
from utils.llm_util import LLMResult, LLM
1111

1212

@@ -20,13 +20,13 @@ class OpenAIConnection(LLM):
2020
If you really must use it, you can import it directly from the utils.openai.openai_llm module, which will later on
2121
show you, that you did not specialize yet.
2222
"""
23-
api_key: str
24-
model: str
25-
context_size: int
26-
api_url: str = "https://api.openai.com"
27-
api_timeout: int = 240
28-
api_backoff: int = 60
29-
api_retries: int = 3
23+
api_key: str = parameter(desc="OpenAI API Key")
24+
model: str = parameter(desc="OpenAI model name")
25+
context_size: int = parameter(desc="Maximum context size for the model, only used internally for things like trimming to the context size")
26+
api_url: str = parameter(desc="URL of the OpenAI API", default="https://api.openai.com")
27+
api_timeout: int = parameter(desc="Timeout for the API request", default=240)
28+
api_backoff: int = parameter(desc="Backoff time in seconds when running into rate-limits", default=60)
29+
api_retries: int = parameter(desc="Number of retries when running into rate-limits", default=3)
3030

3131
def get_response(self, prompt, *, retry: int = 0, **kwargs) -> LLMResult:
3232
if retry >= self.api_retries:

‎wintermute.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ def main():
1010
for name, use_case in use_cases.items():
1111
use_case.build_parser(subparser.add_parser(
1212
name=use_case.name,
13-
description=use_case.description
13+
help=use_case.description
1414
))
1515

1616
parsed = parser.parse_args(sys.argv[1:])

0 commit comments

Comments
 (0)