Skip to content

Commit 9f06243

Browse files
authored
[BUG FIX] Revert removal of 'zero_velocity' optional argument to 'set_pos', 'set_quat'. (#2047)
* Revert removal of 'zero_velocity' optional argument to 'set_pos', 'set_quat'. * Rename 'n_equalities_candidate' in 'n_candidate_equalities'. * Add support of default values in data_oriented. Fix typing.
1 parent 4fb6fdf commit 9f06243

File tree

5 files changed

+44
-48
lines changed

5 files changed

+44
-48
lines changed

‎genesis/engine/entities/rigid_entity/rigid_entity.py‎

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1998,7 +1998,7 @@ def get_links_invweight(self, links_idx_local=None, envs_idx=None, *, unsafe=Fal
19981998
return self._solver.get_links_invweight(links_idx, envs_idx, unsafe=unsafe)
19991999

20002000
@gs.assert_built
2001-
def set_pos(self, pos, envs_idx=None, *, relative=False, unsafe=False):
2001+
def set_pos(self, pos, envs_idx=None, *, zero_velocity=True, relative=False, unsafe=False):
20022002
"""
20032003
Set position of the entity's base link.
20042004
@@ -2020,13 +2020,14 @@ def set_pos(self, pos, envs_idx=None, *, relative=False, unsafe=False):
20202020
if _pos is not pos:
20212021
gs.logger.debug(ALLOCATE_TENSOR_WARNING)
20222022
pos = _pos
2023-
self._solver.set_dofs_velocity(None, self._dofs_idx, envs_idx, skip_forward=True, unsafe=unsafe)
2023+
if zero_velocity:
2024+
self._solver.set_dofs_velocity(None, self._dofs_idx, envs_idx, skip_forward=True, unsafe=unsafe)
20242025
self._solver.set_base_links_pos(
20252026
pos.unsqueeze(-2), self._base_links_idx_, envs_idx, relative=relative, unsafe=unsafe
20262027
)
20272028

20282029
@gs.assert_built
2029-
def set_quat(self, quat, envs_idx=None, *, relative=False, unsafe=False):
2030+
def set_quat(self, quat, envs_idx=None, *, zero_velocity=True, relative=False, unsafe=False):
20302031
"""
20312032
Set quaternion of the entity's base link.
20322033
@@ -2048,7 +2049,8 @@ def set_quat(self, quat, envs_idx=None, *, relative=False, unsafe=False):
20482049
if _quat is not quat:
20492050
gs.logger.debug(ALLOCATE_TENSOR_WARNING)
20502051
quat = _quat
2051-
self._solver.set_dofs_velocity(None, self._dofs_idx, envs_idx, skip_forward=True, unsafe=unsafe)
2052+
if zero_velocity:
2053+
self._solver.set_dofs_velocity(None, self._dofs_idx, envs_idx, skip_forward=True, unsafe=unsafe)
20522054
self._solver.set_base_links_quat(
20532055
quat.unsqueeze(-2), self._base_links_idx_, envs_idx, relative=relative, unsafe=unsafe
20542056
)

‎genesis/engine/solvers/rigid/constraint_solver_decomp.py‎

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def __init__(self, rigid_solver: "RigidSolver"):
4141
4 * rigid_solver.collider._collider_info.max_contact_pairs[None]
4242
+ sum(joint.type in (gs.JOINT_TYPE.REVOLUTE, gs.JOINT_TYPE.PRISMATIC) for joint in self._solver.joints)
4343
+ self._solver.n_dofs
44-
+ self._solver.n_equalities_candidate * 6
44+
+ self._solver.n_candidate_equalities_ * 6
4545
)
4646
self.len_constraints_ = max(1, self.len_constraints)
4747

@@ -351,7 +351,7 @@ def add_weld_constraint(self, link1_idx, link2_idx, envs_idx=None, *, unsafe=Fal
351351
if overflow:
352352
gs.logger.warning(
353353
"Ignoring dynamically registered weld constraint to avoid exceeding max number of equality constraints"
354-
f"({self.rigid_global_info.n_equalities_candidate.to_numpy()}). Please increase the value of "
354+
f"({self.rigid_global_info.n_candidate_equalities.to_numpy()}). Please increase the value of "
355355
"RigidSolver's option 'max_dynamic_constraints'."
356356
)
357357

@@ -2190,7 +2190,7 @@ def kernel_add_weld_constraint(
21902190
for i_b_ in ti.ndrange(envs_idx.shape[0]):
21912191
i_b = envs_idx[i_b_]
21922192
i_e = constraint_state.ti_n_equalities[i_b]
2193-
if i_e == rigid_global_info.n_equalities_candidate[None]:
2193+
if i_e == rigid_global_info.n_candidate_equalities[None]:
21942194
overflow = True
21952195
else:
21962196
shared_pos = links_state.pos[link1_idx, i_b]

‎genesis/engine/solvers/rigid/rigid_solver_decomp.py‎

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -224,8 +224,7 @@ def build(self):
224224
self.n_entities_ = max(1, self.n_entities)
225225
self.n_free_verts_ = max(1, self.n_free_verts)
226226
self.n_fixed_verts_ = max(1, self.n_fixed_verts)
227-
228-
self.n_equalities_candidate = max(1, self.n_equalities + self._options.max_dynamic_constraints)
227+
self.n_candidate_equalities_ = max(1, self.n_equalities + self._options.max_dynamic_constraints)
229228

230229
# FIXME: AvatarSolver should not inherit from RigidSolver, not to mention that it is completely broken...
231230
is_rigid_solver = type(self) is RigidSolver
@@ -250,16 +249,7 @@ def build(self):
250249
self._static_rigid_sim_config = array_class.StructRigidSimStaticConfig(
251250
para_level=self.sim._para_level,
252251
requires_grad=self.sim.options.requires_grad,
253-
use_hibernation=False,
254-
batch_links_info=False,
255-
batch_dofs_info=False,
256-
batch_joints_info=False,
257-
enable_mujoco_compatibility=False,
258-
enable_multi_contact=True,
259252
enable_collision=self._enable_collision,
260-
enable_joint_limit=False,
261-
box_box_detection=True,
262-
sparse_solve=False,
263253
integrator=gs.integrator.approximate_implicitfast,
264254
solver_type=gs.constraint_solver.CG,
265255
)

‎genesis/utils/array_class.py‎

Lines changed: 32 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import math
22
import dataclasses
33
from functools import partial
4+
from typing_extensions import dataclass_transform # Made it into standard lib from Python 3.12
45

56
import gstaichi as ti
67
import numpy as np
@@ -24,35 +25,38 @@ def maybe_shape(shape, is_on):
2425
return shape if is_on else ()
2526

2627

28+
@dataclass_transform(eq_default=True, order_default=True, kw_only_default=False, frozen_default=True)
2729
class AutoInitMeta(type):
2830
def __new__(cls, name, bases, namespace):
29-
field_names = namespace["__annotations__"].keys()
31+
names = tuple(namespace["__annotations__"].keys())
32+
defaults = {k: namespace[k] for k in names if k in namespace}
3033

3134
def __init__(self, *args, **kwargs):
32-
assigned = {}
35+
# Initialize assigned arguments from defaults
36+
assigned = defaults.copy()
3337

3438
# Assign positional arguments
35-
if len(args) > len(field_names):
36-
raise TypeError(f"{name}() takes {len(field_names)} positional arguments but {len(args)} were given")
37-
for field, value in zip(field_names, args):
38-
assigned[field] = value
39+
if len(args) > len(names):
40+
raise TypeError(f"{name}() takes {len(names)} positional arguments but {len(args)} were given")
41+
for key, value in zip(names, args):
42+
assigned[key] = value
3943

4044
# Assign keyword arguments
4145
for key, value in kwargs.items():
42-
if key not in field_names:
46+
if key not in names:
4347
raise TypeError(f"{name}() got unexpected keyword argument '{key}'")
44-
if key in assigned:
48+
if key in names[: len(args)]:
4549
raise TypeError(f"{name}() got multiple values for argument '{key}'")
4650
assigned[key] = value
4751

4852
# Check for missing arguments
49-
for field in field_names:
50-
if field not in assigned:
51-
raise TypeError(f"{name}() missing required argument: '{field}'")
53+
for key in names:
54+
if key not in assigned:
55+
raise TypeError(f"{name}() missing required argument: '{key}'")
5256

5357
# Set attributes
54-
for field, value in assigned.items():
55-
setattr(self, field, value)
58+
for key, value in assigned.items():
59+
setattr(self, key, value)
5660

5761
namespace["__init__"] = __init__
5862

@@ -100,7 +104,7 @@ class StructRigidGlobalInfo(metaclass=BASE_METACLASS):
100104
noslip_iterations: V_ANNOTATION
101105
noslip_tolerance: V_ANNOTATION
102106
n_equalities: V_ANNOTATION
103-
n_equalities_candidate: V_ANNOTATION
107+
n_candidate_equalities: V_ANNOTATION
104108
hibernation_thresh_acc: V_ANNOTATION
105109
hibernation_thresh_vel: V_ANNOTATION
106110
EPS: V_ANNOTATION
@@ -142,7 +146,7 @@ def get_rigid_global_info(solver):
142146
noslip_iterations=V_SCALAR_FROM(dtype=gs.ti_int, value=solver._options.noslip_iterations),
143147
noslip_tolerance=V_SCALAR_FROM(dtype=gs.ti_float, value=solver._options.noslip_tolerance),
144148
n_equalities=V_SCALAR_FROM(dtype=gs.ti_int, value=solver._n_equalities),
145-
n_equalities_candidate=V_SCALAR_FROM(dtype=gs.ti_int, value=solver.n_equalities_candidate),
149+
n_candidate_equalities=V_SCALAR_FROM(dtype=gs.ti_int, value=solver.n_candidate_equalities_),
146150
hibernation_thresh_acc=V_SCALAR_FROM(dtype=gs.ti_float, value=solver._hibernation_thresh_acc),
147151
hibernation_thresh_vel=V_SCALAR_FROM(dtype=gs.ti_float, value=solver._hibernation_thresh_vel),
148152
EPS=V_SCALAR_FROM(dtype=gs.ti_float, value=gs.EPS),
@@ -1658,7 +1662,7 @@ class StructEqualitiesInfo(metaclass=BASE_METACLASS):
16581662

16591663

16601664
def get_equalities_info(solver):
1661-
shape = (solver.n_equalities_candidate, solver._B)
1665+
shape = (solver.n_candidate_equalities_, solver._B)
16621666

16631667
return StructEqualitiesInfo(
16641668
eq_obj1id=V(dtype=gs.ti_int, shape=shape),
@@ -1725,19 +1729,19 @@ def get_entities_state(solver):
17251729
@ti.data_oriented
17261730
class StructRigidSimStaticConfig(metaclass=AutoInitMeta):
17271731
para_level: int
1728-
requires_grad: bool
1729-
use_hibernation: bool
1730-
batch_links_info: bool
1731-
batch_dofs_info: bool
1732-
batch_joints_info: bool
1733-
enable_mujoco_compatibility: bool
1734-
enable_multi_contact: bool
17351732
enable_collision: bool
1736-
enable_joint_limit: bool
1737-
box_box_detection: bool
1738-
sparse_solve: bool
1739-
integrator: int
1740-
solver_type: int
1733+
use_hibernation: bool = False
1734+
batch_links_info: bool = False
1735+
batch_dofs_info: bool = False
1736+
batch_joints_info: bool = False
1737+
enable_mujoco_compatibility: bool = False
1738+
enable_multi_contact: bool = False
1739+
enable_joint_limit: bool = False
1740+
box_box_detection: bool = True
1741+
sparse_solve: bool = False
1742+
integrator: int = gs.integrator.approximate_implicitfast
1743+
solver_type: int = gs.constraint_solver.CG
1744+
requires_grad: bool = False
17411745

17421746

17431747
# =========================================== DataManager ===========================================

‎tests/test_rigid_physics.py‎

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2508,8 +2508,8 @@ def test_urdf_capsule(tmp_path, show_viewer, tol):
25082508
for _ in range(40):
25092509
scene.step()
25102510
geom_verts = tensor_to_array(geom.get_verts())
2511-
assert np.linalg.norm(geom_verts - np.zeros(3), axis=-1, ord=np.inf).min() < 1e-3
2512-
assert np.linalg.norm(geom_verts - np.array((0.0, 0.0, 0.14)), axis=-1, ord=np.inf).min() < 1e-3
2511+
assert np.linalg.norm(geom_verts - (0.0, 0.0, 0.0), axis=-1, ord=np.inf).min() < 1e-3
2512+
assert np.linalg.norm(geom_verts - (0.0, 0.0, 0.14), axis=-1, ord=np.inf).min() < 1e-3
25132513

25142514

25152515
@pytest.mark.required

0 commit comments

Comments
 (0)