Skip to content

Commit 2e9899d

Browse files
authored
[MISC] Avoid atomic_max causing significant slowdown. (#2022)
1 parent 855c494 commit 2e9899d

File tree

2 files changed

+10
-11
lines changed

2 files changed

+10
-11
lines changed

‎genesis/engine/solvers/rigid/collider_decomp.py‎

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -413,10 +413,9 @@ def get_contacts(self, as_tensor: bool = True, to_torch: bool = True, keep_batch
413413

414414
n_envs = self._solver.n_envs
415415
if gs.use_zerocopy:
416+
n_contacts = ti_to_torch(self._collider_state.n_contacts, copy=False)
416417
if as_tensor or n_envs == 0:
417-
n_contacts_max = ti_to_torch(self._collider_state.n_contacts_max, copy=False).item()
418-
else:
419-
n_contacts = ti_to_torch(self._collider_state.n_contacts, copy=False)
418+
n_contacts_max = (n_contacts if n_envs == 0 else n_contacts.max()).item()
420419

421420
for key, data in self._contacts_info.items():
422421
if n_envs == 0:
@@ -625,8 +624,6 @@ def kernel_collider_clear(
625624
else:
626625
collider_state.n_contacts[i_b] = 0
627626

628-
collider_state.n_contacts_max[None] = 0
629-
630627

631628
@ti.kernel(fastcache=gs.use_fastcache)
632629
def collider_kernel_get_contacts(
@@ -637,7 +634,14 @@ def collider_kernel_get_contacts(
637634
collider_state: array_class.ColliderState,
638635
):
639636
_B = collider_state.active_buffer.shape[1]
640-
n_contacts_max = collider_state.n_contacts_max[None]
637+
638+
# TODO: Better implementation from gstaichi for this kind of reduction.
639+
n_contacts_max = gs.ti_int(0)
640+
ti.loop_config(serialize=True)
641+
for i_b in range(_B):
642+
n_contacts = collider_state.n_contacts[i_b]
643+
if n_contacts > n_contacts_max:
644+
n_contacts_max = n_contacts
641645

642646
ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
643647
for i_b in range(_B):
@@ -1265,8 +1269,6 @@ def func_collision_clear(
12651269
else:
12661270
collider_state.n_contacts[i_b] = 0
12671271

1268-
collider_state.n_contacts_max[None] = 0
1269-
12701272

12711273
@ti.kernel(fastcache=gs.use_fastcache)
12721274
def func_broad_phase(
@@ -2138,7 +2140,6 @@ def func_add_contact(
21382140
collider_state.contact_data.link_b[i_c, i_b] = geoms_info.link_idx[i_gb]
21392141

21402142
collider_state.n_contacts[i_b] = i_c + 1
2141-
ti.atomic_max(collider_state.n_contacts_max[None], i_c + 1)
21422143
else:
21432144
errno[None] = 2
21442145

‎genesis/utils/array_class.py‎

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -494,7 +494,6 @@ class StructColliderState(metaclass=BASE_METACLASS):
494494
xyz_max_min: V_ANNOTATION
495495
prism: V_ANNOTATION
496496
n_contacts: V_ANNOTATION
497-
n_contacts_max: V_ANNOTATION
498497
n_contacts_hibernated: V_ANNOTATION
499498
first_time: V_ANNOTATION
500499
contact_cache: StructContactCache
@@ -547,7 +546,6 @@ def get_collider_state(
547546
xyz_max_min=V(dtype=gs.ti_float, shape=(6, _B)),
548547
prism=V_VEC(3, dtype=gs.ti_float, shape=(6, _B)),
549548
n_contacts=V(dtype=gs.ti_int, shape=(_B,)),
550-
n_contacts_max=V(dtype=gs.ti_int, shape=()),
551549
n_contacts_hibernated=V(dtype=gs.ti_int, shape=(_B,)),
552550
first_time=V(dtype=gs.ti_int, shape=(_B,)),
553551
contact_cache=get_contact_cache(solver),

0 commit comments

Comments
 (0)