11import argparse
2+ import dataclasses
23import inspect
34import os
45from dataclasses import dataclass
1213load_dotenv ()
1314
1415
16+ def parameter (* , desc : str , default = dataclasses .MISSING , init : bool = True , repr : bool = True , hash = None ,
17+ compare : bool = True , metadata : Dict = None , kw_only : bool = dataclasses .MISSING ) -> dataclasses .Field :
18+ if metadata is None :
19+ metadata = dict ()
20+ metadata ["desc" ] = desc
21+
22+ return dataclasses .field (default = default , default_factory = dataclasses .MISSING , init = init , repr = repr , hash = hash ,
23+ compare = compare , metadata = metadata , kw_only = kw_only )
24+
25+
1526def get_default (key , default ):
1627 return os .getenv (key , os .getenv (key .upper (), os .getenv (key .replace ("." , "_" ), os .getenv (key .replace ("." , "_" ).upper (), default ))))
1728
@@ -24,12 +35,14 @@ class ParameterDefinition:
2435 name : str
2536 type : Type
2637 default : Any
38+ description : str
2739
2840 def parser (self , basename : str , parser : argparse .ArgumentParser ):
2941 name = f"{ basename } { self .name } "
3042 default = get_default (name , self .default )
3143
32- parser .add_argument (f"--{ name } " , type = self .type , default = default , required = default is None )
44+ parser .add_argument (f"--{ name } " , type = self .type , default = default , required = default is None ,
45+ help = self .description )
3346
3447 def get (self , basename : str , args : argparse .Namespace ):
3548 return getattr (args , f"{ basename } { self .name } " )
@@ -62,7 +75,18 @@ def get(self, basename: str, args: argparse.Namespace):
6275 return parameter
6376
6477
65- def get_parameters (fun , basename : str ) -> ParameterDefinitions :
78+ def get_class_parameters (cls , name : str = None , fields : Dict [str , dataclasses .Field ] = None ) -> ParameterDefinitions :
79+ if name is None :
80+ name = cls .__name__
81+ if fields is None and hasattr (cls , "__dataclass_fields__" ):
82+ fields = cls .__dataclass_fields__
83+ return get_parameters (cls .__init__ , name , fields )
84+
85+
86+ def get_parameters (fun , basename : str , fields : Dict [str , dataclasses .Field ] = None ) -> ParameterDefinitions :
87+ if fields is None :
88+ fields = dict ()
89+
6690 sig = inspect .signature (fun )
6791 params : ParameterDefinitions = {}
6892 for name , param in sig .parameters .items ():
@@ -73,13 +97,27 @@ def get_parameters(fun, basename: str) -> ParameterDefinitions:
7397 raise ValueError (f"Parameter { name } of { basename } .{ fun .__name__ } must have a type annotation" )
7498
7599 default = param .default if param .default != inspect .Parameter .empty else None
76-
77- if hasattr (param .annotation , "__parameters__" ):
78- params [name ] = ComplexParameterDefinition (name , param .annotation , default , get_parameters (param .annotation , f"{ basename } .{ fun .__name__ } " ))
79- elif param .annotation in (str , int , bool ):
80- params [name ] = ParameterDefinition (name , param .annotation , default )
100+ description = None
101+ type = param .annotation
102+
103+ field = None
104+ if isinstance (default , dataclasses .Field ):
105+ field = default
106+ default = field .default
107+ elif name in fields :
108+ field = fields [name ]
109+
110+ if field is not None :
111+ description = field .metadata .get ("desc" , None )
112+ if field .type is not None :
113+ type = field .type
114+
115+ if hasattr (type , "__parameters__" ):
116+ params [name ] = ComplexParameterDefinition (name , type , default , description , get_class_parameters (type , f"{ basename } .{ fun .__name__ } " ))
117+ elif type in (str , int , bool ):
118+ params [name ] = ParameterDefinition (name , type , default , description )
81119 else :
82- raise ValueError (f"Parameter { name } of { basename } .{ fun .__name__ } must have str, int, bool, or a __parameters__ class as type, not { param . annotation } " )
120+ raise ValueError (f"Parameter { name } of { basename } .{ fun .__name__ } must have str, int, bool, or a __parameters__ class as type, not { type } " )
83121
84122 return params
85123
@@ -106,7 +144,7 @@ def inner(cls) -> Configurable:
106144 cls .name = service_name
107145 cls .description = service_desc
108146 cls .__service__ = True
109- cls .__parameters__ = get_parameters (cls . __init__ , cls . __name__ )
147+ cls .__parameters__ = get_class_parameters (cls )
110148
111149 return cls
112150
0 commit comments