Commit afa1af6
[MISC] Improve zero-copy efficiency. (#2054)
1 parent a6e1cc9 commit afa1af6

File tree

4 files changed (+93, -109 lines)


genesis/engine/solvers/rigid/collider_decomp.py

Lines changed: 4 additions & 4 deletions
@@ -302,17 +302,17 @@ def reset(self, envs_idx: npt.NDArray[np.int32] | None = None, cache_only: bool
         envs_idx = slice(None) if envs_idx is None else envs_idx
         if not cache_only:
             first_time = ti_to_torch(self._collider_state.first_time, copy=False)
-            if isinstance(envs_idx, torch.Tensor):
-                first_time.scatter_(0, envs_idx, True)
-            else:
-                first_time[envs_idx] = True
+            first_time[envs_idx] = True

             i_va_ws = ti_to_torch(self._collider_state.contact_cache.i_va_ws, copy=False)
             normal = ti_to_torch(self._collider_state.contact_cache.normal, copy=False)
             if isinstance(envs_idx, torch.Tensor):
                 max_possible_pairs = normal.shape[0]
                 i_va_ws.scatter_(2, envs_idx[None, None].expand((2, max_possible_pairs, -1)), -1)
                 normal.scatter_(1, envs_idx[None, :, None].expand((max_possible_pairs, -1, 3)), 0.0)
+            elif envs_idx is None:
+                i_va_ws.fill_(-1)
+                normal.zero_()
             else:
                 i_va_ws[:, :, envs_idx] = -1
                 normal[:, envs_idx] = 0.0
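
Why the scatter_ branch could go away: a minimal sketch (with illustrative shapes, not the solver's real ones) of the fact that in-place indexed assignment on a torch tensor accepts a slice, a list, or an index tensor alike and writes straight through the zero-copy view, while fill_/zero_ cover the "reset everything" case without building index tensors.

import torch

first_time = torch.zeros(8, dtype=torch.bool)         # stands in for the zero-copy state view
for envs_idx in (slice(None), torch.tensor([1, 3]), [0, 5]):
    first_time[envs_idx] = True                        # one assignment path covers every index kind
first_time.zero_()                                     # whole-buffer reset, no index tensor needed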

genesis/engine/solvers/rigid/rigid_solver_decomp.py

Lines changed: 8 additions & 8 deletions
@@ -1440,7 +1440,7 @@ def set_qpos(self, qpos, qs_idx=None, envs_idx=None, *, skip_forward=False):
         if gs.use_zerocopy:
             mask = (0, *indices_to_mask(qs_idx)) if self.n_envs == 0 else indices_to_mask(envs_idx, qs_idx)
             data = ti_to_torch(self._rigid_global_info.qpos, transpose=True, copy=False)
-            assign_indexed_tensor(data, mask, qpos, gs.tc_float)
+            assign_indexed_tensor(data, mask, qpos)
             if mask and isinstance(mask[0], torch.Tensor):
                 envs_idx = mask[0].reshape((-1,))
             else:
@@ -1556,7 +1556,7 @@ def _set_dofs_info(self, tensor_list, dofs_idx, name, envs_idx=None):
             data = ti_to_torch(getattr(self.dofs_info, name), transpose=True, copy=False)
             num_values = len(tensor_list)
             for j, mask_j in enumerate(((*mask, ..., j) for j in range(num_values)) if num_values > 1 else (mask,)):
-                assign_indexed_tensor(data, mask_j, tensor_list[j], gs.tc_float)
+                assign_indexed_tensor(data, mask_j, tensor_list[j])
             return

         tensor_list = list(tensor_list)
@@ -1638,7 +1638,7 @@ def set_dofs_velocity(self, velocity, dofs_idx=None, envs_idx=None, *, skip_forw
             if velocity is None:
                 vel[mask] = 0.0
             else:
-                assign_indexed_tensor(vel, mask, velocity, gs.tc_float)
+                assign_indexed_tensor(vel, mask, velocity)
             if mask and isinstance(mask[0], torch.Tensor):
                 envs_idx = mask[0].reshape((-1,))
             elif not isinstance(envs_idx, torch.Tensor):
@@ -1709,7 +1709,7 @@ def control_dofs_force(self, force, dofs_idx=None, envs_idx=None):
             ctrl_mode = ti_to_torch(self.dofs_state.ctrl_mode, transpose=True, copy=False)
             ctrl_mode[mask] = gs.CTRL_MODE.FORCE
             ctrl_force = ti_to_torch(self.dofs_state.ctrl_force, transpose=True, copy=False)
-            assign_indexed_tensor(ctrl_force, mask, force, gs.tc_float)
+            assign_indexed_tensor(ctrl_force, mask, force)
             return

         force, dofs_idx, envs_idx = self._sanitize_io_variables(
@@ -1728,7 +1728,7 @@ def control_dofs_velocity(self, velocity, dofs_idx=None, envs_idx=None):
             ctrl_pos = ti_to_torch(self.dofs_state.ctrl_pos, transpose=True, copy=False)
             ctrl_pos[mask] = 0.0
             ctrl_vel = ti_to_torch(self.dofs_state.ctrl_vel, transpose=True, copy=False)
-            assign_indexed_tensor(ctrl_vel, mask, velocity, gs.tc_float)
+            assign_indexed_tensor(ctrl_vel, mask, velocity)
             return

         velocity, dofs_idx, envs_idx = self._sanitize_io_variables(
@@ -1745,7 +1745,7 @@ def control_dofs_position(self, position, dofs_idx=None, envs_idx=None):
             ctrl_mode = ti_to_torch(self.dofs_state.ctrl_mode, transpose=True, copy=False)
             ctrl_mode[mask] = gs.CTRL_MODE.POSITION
             ctrl_pos = ti_to_torch(self.dofs_state.ctrl_pos, transpose=True, copy=False)
-            assign_indexed_tensor(ctrl_pos, mask, position, gs.tc_float)
+            assign_indexed_tensor(ctrl_pos, mask, position)
             ctrl_vel = ti_to_torch(self.dofs_state.ctrl_vel, transpose=True, copy=False)
             ctrl_vel[mask] = 0.0
             return
@@ -1764,9 +1764,9 @@ def control_dofs_position_velocity(self, position, velocity, dofs_idx=None, envs
             ctrl_mode = ti_to_torch(self.dofs_state.ctrl_mode, transpose=True, copy=False)
             ctrl_mode[mask] = gs.CTRL_MODE.POSITION
             ctrl_pos = ti_to_torch(self.dofs_state.ctrl_pos, transpose=True, copy=False)
-            assign_indexed_tensor(ctrl_pos, mask, position, gs.tc_float)
+            assign_indexed_tensor(ctrl_pos, mask, position)
             ctrl_vel = ti_to_torch(self.dofs_state.ctrl_vel, transpose=True, copy=False)
-            assign_indexed_tensor(ctrl_vel, mask, velocity, gs.tc_float)
+            assign_indexed_tensor(ctrl_vel, mask, velocity)
             return

         position, dofs_idx, _ = self._sanitize_io_variables(
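
Every change in this file drops the explicit gs.tc_float argument: in-place indexed assignment casts the assigned value to the destination tensor's dtype, and the fallback inside assign_indexed_tensor (see misc.py below) now reads tensor.dtype, so passing the dtype separately was redundant. A minimal sketch with illustrative shapes:

import torch

ctrl_force = torch.zeros(3, 4)                         # destination fixes the dtype (float32 here)
ctrl_force[0] = torch.arange(4, dtype=torch.float64)   # value is cast to float32 by the assignment itself
assert ctrl_force.dtype == torch.float32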

genesis/ext/pyrender/interaction/viewer_interaction.py

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 from typing import TYPE_CHECKING, cast
-from typing_extensions import override
+from typing_extensions import override  # Made it into standard lib from Python 3.12
 from threading import Lock as threading_Lock

 import numpy as np
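
As the new comment notes, override is available from the standard library starting with Python 3.12; a hedged sketch of how the typing_extensions import could eventually be guarded by interpreter version (not something this commit does):

import sys

if sys.version_info >= (3, 12):
    from typing import override          # stdlib from 3.12 onward
else:
    from typing_extensions import override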

genesis/utils/misc.py

Lines changed: 80 additions & 96 deletions
@@ -411,23 +411,15 @@ def has_display() -> bool:
 # -------------------------------------- TAICHI SPECIALIZATION --------------------------------------

 TI_PROG_WEAKREF: weakref.ReferenceType | None = None
-TI_DATA_CACHE: OrderedDict[int, "FieldMetadata"] = OrderedDict()
+TI_MAPPING_KEY_CACHE: OrderedDict[int, Any] = OrderedDict()
 MAX_CACHE_SIZE = 1000


-@dataclass
-class FieldMetadata:
-    ndim: int
-    shape: tuple[int, ...]
-    dtype: ti._lib.core.DataTypeCxx
-    mapping_key: Any
-
-
 def _ensure_compiled(self, *args):
     # Note that the field is enough to determine the key because all the other arguments depends on it.
     # This may not be the case anymore if the output is no longer dynamically allocated at some point.
-    ti_data_meta = TI_DATA_CACHE[id(args[0])]
-    key = ti_data_meta.mapping_key
+    cache_id = id(args[0])
+    key = TI_MAPPING_KEY_CACHE.get(cache_id)
     if key is None:
         extracted = []
         for arg, kernel_arg in zip(args, self.mapper.arguments):
@@ -439,7 +431,7 @@ def _ensure_compiled(self, *args):
             subkey = (arg.dtype, len(arg.shape), needs_grad, anno.boundary)
             extracted.append(subkey)
         key = tuple(extracted)
-        ti_data_meta.mapping_key = key
+        TI_MAPPING_KEY_CACHE[cache_id] = key
     instance_id = self.mapper.mapping.get(key)
     if instance_id is None:
         key = ti.lang.kernel_impl.Kernel.ensure_compiled(self, *args)
@@ -518,7 +510,7 @@ def _launch_kernel(self, t_kernel, compiled_kernel_data, *args):

 def _destroy_callback(ref: weakref.ReferenceType):
     global TI_PROG_WEAKREF
-    TI_DATA_CACHE.clear()
+    TI_MAPPING_KEY_CACHE.clear()
     for kernel in TO_EXT_ARR_FAST_MAP.values():
         kernel._primal.mapper.mapping.clear()
     TI_PROG_WEAKREF = None
@@ -539,31 +531,6 @@ def _destroy_callback(ref: weakref.ReferenceType):
     TO_EXT_ARR_FAST_MAP[data_type] = func


-def _get_ti_metadata(value: ti.Field | ti.Ndarray) -> FieldMetadata:
-    global TI_PROG_WEAKREF
-
-    # Keep track of taichi runtime to automatically clear cache if destroyed
-    if TI_PROG_WEAKREF is None:
-        TI_PROG_WEAKREF = weakref.ref(impl.get_runtime().prog, _destroy_callback)
-
-    # Get metadata
-    ti_data_id = id(value)
-    ti_data_meta = TI_DATA_CACHE.get(ti_data_id)
-    if ti_data_meta is None:
-        if isinstance(value, ti.MatrixField):
-            ndim = value.ndim
-        elif isinstance(value, ti.Ndarray):
-            ndim = len(value.element_shape)
-        else:
-            ndim = 0
-        ti_data_meta = FieldMetadata(ndim, value.shape, value.dtype, None)
-        if len(TI_DATA_CACHE) == MAX_CACHE_SIZE:
-            TI_DATA_CACHE.popitem(last=False)
-        TI_DATA_CACHE[ti_data_id] = ti_data_meta
-
-    return ti_data_meta
-
-
 def ti_to_python(
     value: ti.Field | ti.Ndarray,
     transpose: bool = False,
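
The FieldMetadata record is gone: shape and dtype are now read directly from the Taichi object, so the only thing worth caching per object is the kernel-specialization key, stored in an id()-keyed OrderedDict. A minimal sketch of that lookup-or-compute pattern (get_mapping_key and compute_key are hypothetical names, not part of the codebase):

from collections import OrderedDict
from typing import Any, Callable

MAPPING_KEY_CACHE: "OrderedDict[int, Any]" = OrderedDict()

def get_mapping_key(field: Any, compute_key: Callable[[Any], Any]) -> Any:
    cache_id = id(field)                      # the Taichi object identity is enough to key the cache
    key = MAPPING_KEY_CACHE.get(cache_id)
    if key is None:                           # compute once, reuse on every subsequent kernel launch
        key = compute_key(field)
        MAPPING_KEY_CACHE[cache_id] = key
    return key
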
@@ -600,67 +567,73 @@
     elif copy is None:
         copy = False

-    # Extract metadata if necessary
-    if transpose or not use_zerocopy:
-        ti_data_meta = _get_ti_metadata(value)
-
     # Leverage zero-copy if enabled
+    batch_shape = value.shape
     if use_zerocopy:
         try:
-            out = value._tc if to_torch or gs.backend != gs.cpu else value._np
+            if to_torch or gs.backend != gs.cpu:
+                out = value._T_tc if transpose else value._tc
+            else:
+                out = value._T_np if transpose else value._np
         except AttributeError:
-            out = value._tc = torch.utils.dlpack.from_dlpack(value.to_dlpack())
+            value._tc = torch.utils.dlpack.from_dlpack(value.to_dlpack())
+            value._T_tc = value._tc.movedim(batch_ndim - 1, 0) if (batch_ndim := len(batch_shape)) > 1 else value._tc
+            if to_torch:
+                out = value._T_tc if transpose else value._tc
             if gs.backend == gs.cpu:
                 value._np = value._tc.numpy()
+                value._T_np = value._T_tc.numpy()
                 if not to_torch:
-                    out = value._np
+                    out = value._T_np if transpose else value._np
         if copy:
             if to_torch:
                 out = out.clone()
             else:
                 out = tensor_to_array(out)
-    else:
-        # Extract value as a whole.
-        # Note that this is usually much faster than using a custom kernel to extract a slice.
-        # The implementation is based on `taichi.lang.(ScalarField | MatrixField).to_torch`.
-        is_metal = gs.device.type == "mps"
-        out_dtype = _to_torch_type_fast(ti_data_meta.dtype) if to_torch else _to_numpy_type_fast(ti_data_meta.dtype)
-        if issubclass(data_type, (ti.ScalarField, ti.ScalarNdarray)):
-            if to_torch:
-                out = torch.zeros(ti_data_meta.shape, dtype=out_dtype, device="cpu" if is_metal else gs.device)
-            else:
-                out = np.zeros(ti_data_meta.shape, dtype=out_dtype)
-            TO_EXT_ARR_FAST_MAP[data_type](value, out)
-        elif issubclass(data_type, ti.MatrixField):
-            as_vector = value.m == 1
-            shape_ext = (value.n,) if as_vector else (value.n, value.m)
-            if to_torch:
-                out = torch.empty(
-                    ti_data_meta.shape + shape_ext, dtype=out_dtype, device="cpu" if is_metal else gs.device
-                )
-            else:
-                out = np.zeros(ti_data_meta.shape + shape_ext, dtype=out_dtype)
-            TO_EXT_ARR_FAST_MAP[data_type](value, out, as_vector)
-        elif issubclass(data_type, (ti.VectorNdarray, ti.MatrixNdarray)):
-            layout_is_aos = 1
-            as_vector = issubclass(data_type, ti.VectorNdarray)
-            shape_ext = (value.n,) if as_vector else (value.n, value.m)
-            if to_torch:
-                out = torch.empty(
-                    ti_data_meta.shape + shape_ext, dtype=out_dtype, device="cpu" if is_metal else gs.device
-                )
-            else:
-                out = np.zeros(ti_data_meta.shape + shape_ext, dtype=out_dtype)
-            TO_EXT_ARR_FAST_MAP[ti.MatrixNdarray](value, out, layout_is_aos, as_vector)
+        return out
+
+    # Keep track of taichi runtime to automatically clear cache if destroyed
+    global TI_PROG_WEAKREF
+    if TI_PROG_WEAKREF is None:
+        TI_PROG_WEAKREF = weakref.ref(impl.get_runtime().prog, _destroy_callback)
+
+    # Extract value as a whole.
+    # Note that this is usually much faster than using a custom kernel to extract a slice.
+    # The implementation is based on `taichi.lang.(ScalarField | MatrixField).to_torch`.
+    is_metal = gs.device.type == "mps"
+    out_dtype = _to_torch_type_fast(value.dtype) if to_torch else _to_numpy_type_fast(value.dtype)
+    if issubclass(data_type, (ti.ScalarField, ti.ScalarNdarray)):
+        if to_torch:
+            out = torch.zeros(batch_shape, dtype=out_dtype, device="cpu" if is_metal else gs.device)
+        else:
+            out = np.zeros(batch_shape, dtype=out_dtype)
+        TO_EXT_ARR_FAST_MAP[data_type](value, out)
+    elif issubclass(data_type, ti.MatrixField):
+        as_vector = value.m == 1
+        shape_ext = (value.n,) if as_vector else (value.n, value.m)
+        if to_torch:
+            out = torch.empty(batch_shape + shape_ext, dtype=out_dtype, device="cpu" if is_metal else gs.device)
+        else:
+            out = np.zeros(batch_shape + shape_ext, dtype=out_dtype)
+        TO_EXT_ARR_FAST_MAP[data_type](value, out, as_vector)
+    elif issubclass(data_type, (ti.VectorNdarray, ti.MatrixNdarray)):
+        layout_is_aos = 1
+        as_vector = issubclass(data_type, ti.VectorNdarray)
+        shape_ext = (value.n,) if as_vector else (value.n, value.m)
+        if to_torch:
+            out = torch.empty(batch_shape + shape_ext, dtype=out_dtype, device="cpu" if is_metal else gs.device)
         else:
-            gs.raise_exception(f"Unsupported type '{type(value)}'.")
-        if to_torch and is_metal:
-            out = out.to(gs.device)
+            out = np.zeros(batch_shape + shape_ext, dtype=out_dtype)
+        TO_EXT_ARR_FAST_MAP[ti.MatrixNdarray](value, out, layout_is_aos, as_vector)
+    else:
+        gs.raise_exception(f"Unsupported type '{type(value)}'.")
+    if to_torch and is_metal:
+        out = out.to(gs.device)

     # Transpose if necessary and requested.
     # Note that it is worth transposing here before slicing, as it preserve row-major memory alignment in case of
     # advanced masking, which would spare computation later on if expected from the user.
-    if transpose and (batch_ndim := len(ti_data_meta.shape)) > 1:
+    if transpose and (batch_ndim := len(batch_shape)) > 1:
         if to_torch:
             out = out.movedim(batch_ndim - 1, 0)
         else:
@@ -766,14 +739,23 @@ def ti_to_torch(
         copy (bool, optional): Wether to enforce returning a copy no matter what. None to avoid copy if possible
             without raising an exception if not.
     """
-    # FIXME: Ideally one should detect if slicing would require a copy to avoid enforcing copy here
-    tensor = ti_to_python(value, transpose, copy=copy, to_torch=True)
+    # Try efficient shortcut first and only fallback to standard branching if necessary.
+    # FIXME: Ideally one should detect if slicing would require a copy to avoid enforcing copy here.
+    if gs.use_zerocopy:
+        try:
+            tensor = value._T_tc if transpose else value._tc
+            if copy:
+                tensor = tensor.clone()
+        except AttributeError:
+            tensor = ti_to_python(value, transpose, copy=copy, to_torch=True)
+    else:
+        tensor = ti_to_python(value, transpose, copy=copy, to_torch=True)
+
     if row_mask is None and col_mask is None:
         return tensor

-    ti_data_meta = _get_ti_metadata(value)
     raise_if_fancy = copy is False
-    if len(ti_data_meta.shape) < 2:
+    if len(value.shape) < 2:
         if row_mask is not None and col_mask is not None:
             gs.raise_exception("Cannot specify both row and column masks for tensor with 1D batch.")
         mask = indices_to_mask(
@@ -808,9 +790,8 @@ def ti_to_numpy(
     if row_mask is None and col_mask is None:
         return tensor

-    ti_data_meta = _get_ti_metadata(value)
     raise_if_fancy = copy is False
-    if len(ti_data_meta.shape) < 2:
+    if len(value.shape) < 2:
         if row_mask is not None and col_mask is not None:
             gs.raise_exception("Cannot specify both row and column masks for tensor with 1D batch.")
         mask = indices_to_mask(
@@ -902,9 +883,9 @@ def broadcast_tensor(
     expected_ndim = len(expected_shape)

     # Expand current tensor shape with extra dims of size 1 if necessary before expanding to expected shape
-    if tensor_ndim < 2:
-        tensor_ = torch.atleast_1d(tensor_)
-    elif tensor_ndim < expected_ndim:
+    if tensor_ndim == 0:
+        tensor_ = tensor_[None]
+    elif 2 <= tensor_ndim < expected_ndim:
         # Try expanding first dimensions if priority
         for dims_valid in tuple(combinations(range(expected_ndim), tensor_ndim))[::-1]:
             curr_idx = 0
@@ -1005,11 +986,14 @@ def get_indexed_shape(tensor_shape, indices):


 def assign_indexed_tensor(
-    out: torch.Tensor,
+    tensor: torch.Tensor,
     indices: tuple[int | slice | torch.Tensor, ...],
-    in_: np.typing.ArrayLike,
-    dtype: torch.dtype,
+    value: np.typing.ArrayLike,
     dim_names: tuple[str, ...] | list[str] | None = None,
 ) -> None:
-    indexed_shape = get_indexed_shape(out.shape, indices) if indices else out.shape
-    out[indices] = broadcast_tensor(in_, dtype, indexed_shape, dim_names)
+    try:
+        tensor[indices] = value
+    except (TypeError, RuntimeError):
+        # Try extended broadcasting as a fallback to avoid slowing down the hot path
+        indexed_shape = get_indexed_shape(tensor.shape, indices) if indices else tensor.shape
+        tensor[indices] = broadcast_tensor(value, tensor.dtype, indexed_shape, dim_names)
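
Two ideas carry most of the efficiency gain above. First, ti_to_python now caches both the DLPack-backed tensor (_tc) and its batch-transposed counterpart (_T_tc); because movedim returns a view, the transposed variant shares storage and costs no copy, which is what lets ti_to_torch hand it back directly. Second, assign_indexed_tensor tries the plain assignment first and only falls back to broadcast_tensor when that raises. A minimal sketch of the view behaviour, with a plain tensor standing in for the Taichi field:

import torch

buf = torch.zeros(5, 3)              # stands in for the DLPack-backed zero-copy tensor (_tc)
buf_t = buf.movedim(1, 0)            # batch-transposed view (_T_tc); no data is copied
buf[2, 1] = 7.0
assert buf_t[1, 2] == 7.0            # both names alias the same storage
assert buf_t.data_ptr() == buf.data_ptr()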

0 commit comments
