11import math
22import dataclasses
33from functools import partial
4+ from typing_extensions import dataclass_transform # Made it into standard lib from Python 3.12
45
56import gstaichi as ti
67import numpy as np
@@ -24,35 +25,38 @@ def maybe_shape(shape, is_on):
2425 return shape if is_on else ()
2526
2627
28+ @dataclass_transform (eq_default = True , order_default = True , kw_only_default = False , frozen_default = True )
2729class AutoInitMeta (type ):
2830 def __new__ (cls , name , bases , namespace ):
29- field_names = namespace ["__annotations__" ].keys ()
31+ names = tuple (namespace ["__annotations__" ].keys ())
32+ defaults = {k : namespace [k ] for k in names if k in namespace }
3033
3134 def __init__ (self , * args , ** kwargs ):
32- assigned = {}
35+ # Initialize assigned arguments from defaults
36+ assigned = defaults .copy ()
3337
3438 # Assign positional arguments
35- if len (args ) > len (field_names ):
36- raise TypeError (f"{ name } () takes { len (field_names )} positional arguments but { len (args )} were given" )
37- for field , value in zip (field_names , args ):
38- assigned [field ] = value
39+ if len (args ) > len (names ):
40+ raise TypeError (f"{ name } () takes { len (names )} positional arguments but { len (args )} were given" )
41+ for key , value in zip (names , args ):
42+ assigned [key ] = value
3943
4044 # Assign keyword arguments
4145 for key , value in kwargs .items ():
42- if key not in field_names :
46+ if key not in names :
4347 raise TypeError (f"{ name } () got unexpected keyword argument '{ key } '" )
44- if key in assigned :
48+ if key in names [: len ( args )] :
4549 raise TypeError (f"{ name } () got multiple values for argument '{ key } '" )
4650 assigned [key ] = value
4751
4852 # Check for missing arguments
49- for field in field_names :
50- if field not in assigned :
51- raise TypeError (f"{ name } () missing required argument: '{ field } '" )
53+ for key in names :
54+ if key not in assigned :
55+ raise TypeError (f"{ name } () missing required argument: '{ key } '" )
5256
5357 # Set attributes
54- for field , value in assigned .items ():
55- setattr (self , field , value )
58+ for key , value in assigned .items ():
59+ setattr (self , key , value )
5660
5761 namespace ["__init__" ] = __init__
5862
@@ -100,7 +104,7 @@ class StructRigidGlobalInfo(metaclass=BASE_METACLASS):
100104 noslip_iterations : V_ANNOTATION
101105 noslip_tolerance : V_ANNOTATION
102106 n_equalities : V_ANNOTATION
103- n_equalities_candidate : V_ANNOTATION
107+ n_candidate_equalities : V_ANNOTATION
104108 hibernation_thresh_acc : V_ANNOTATION
105109 hibernation_thresh_vel : V_ANNOTATION
106110 EPS : V_ANNOTATION
@@ -142,7 +146,7 @@ def get_rigid_global_info(solver):
142146 noslip_iterations = V_SCALAR_FROM (dtype = gs .ti_int , value = solver ._options .noslip_iterations ),
143147 noslip_tolerance = V_SCALAR_FROM (dtype = gs .ti_float , value = solver ._options .noslip_tolerance ),
144148 n_equalities = V_SCALAR_FROM (dtype = gs .ti_int , value = solver ._n_equalities ),
145- n_equalities_candidate = V_SCALAR_FROM (dtype = gs .ti_int , value = solver .n_equalities_candidate ),
149+ n_candidate_equalities = V_SCALAR_FROM (dtype = gs .ti_int , value = solver .n_candidate_equalities_ ),
146150 hibernation_thresh_acc = V_SCALAR_FROM (dtype = gs .ti_float , value = solver ._hibernation_thresh_acc ),
147151 hibernation_thresh_vel = V_SCALAR_FROM (dtype = gs .ti_float , value = solver ._hibernation_thresh_vel ),
148152 EPS = V_SCALAR_FROM (dtype = gs .ti_float , value = gs .EPS ),
@@ -1658,7 +1662,7 @@ class StructEqualitiesInfo(metaclass=BASE_METACLASS):
16581662
16591663
16601664def get_equalities_info (solver ):
1661- shape = (solver .n_equalities_candidate , solver ._B )
1665+ shape = (solver .n_candidate_equalities_ , solver ._B )
16621666
16631667 return StructEqualitiesInfo (
16641668 eq_obj1id = V (dtype = gs .ti_int , shape = shape ),
@@ -1725,19 +1729,19 @@ def get_entities_state(solver):
17251729@ti .data_oriented
17261730class StructRigidSimStaticConfig (metaclass = AutoInitMeta ):
17271731 para_level : int
1728- requires_grad : bool
1729- use_hibernation : bool
1730- batch_links_info : bool
1731- batch_dofs_info : bool
1732- batch_joints_info : bool
1733- enable_mujoco_compatibility : bool
1734- enable_multi_contact : bool
17351732 enable_collision : bool
1736- enable_joint_limit : bool
1737- box_box_detection : bool
1738- sparse_solve : bool
1739- integrator : int
1740- solver_type : int
1733+ use_hibernation : bool = False
1734+ batch_links_info : bool = False
1735+ batch_dofs_info : bool = False
1736+ batch_joints_info : bool = False
1737+ enable_mujoco_compatibility : bool = False
1738+ enable_multi_contact : bool = False
1739+ enable_joint_limit : bool = False
1740+ box_box_detection : bool = True
1741+ sparse_solve : bool = False
1742+ integrator : int = gs .integrator .approximate_implicitfast
1743+ solver_type : int = gs .constraint_solver .CG
1744+ requires_grad : bool = False
17411745
17421746
17431747# =========================================== DataManager ===========================================
0 commit comments