@@ -411,23 +411,15 @@ def has_display() -> bool:
 # -------------------------------------- TAICHI SPECIALIZATION --------------------------------------
 
 TI_PROG_WEAKREF: weakref.ReferenceType | None = None
-TI_DATA_CACHE: OrderedDict[int, "FieldMetadata"] = OrderedDict()
+TI_MAPPING_KEY_CACHE: OrderedDict[int, Any] = OrderedDict()
 MAX_CACHE_SIZE = 1000
 
 
-@dataclass
-class FieldMetadata:
-    ndim: int
-    shape: tuple[int, ...]
-    dtype: ti._lib.core.DataTypeCxx
-    mapping_key: Any
-
-
 def _ensure_compiled(self, *args):
     # Note that the field is enough to determine the key because all the other arguments depend on it.
     # This may not be the case anymore if the output is no longer dynamically allocated at some point.
-    ti_data_meta = TI_DATA_CACHE[id(args[0])]
-    key = ti_data_meta.mapping_key
+    cache_id = id(args[0])
+    key = TI_MAPPING_KEY_CACHE.get(cache_id)
     if key is None:
         extracted = []
         for arg, kernel_arg in zip(args, self.mapper.arguments):
@@ -439,7 +431,7 @@ def _ensure_compiled(self, *args):
             subkey = (arg.dtype, len(arg.shape), needs_grad, anno.boundary)
             extracted.append(subkey)
         key = tuple(extracted)
-        ti_data_meta.mapping_key = key
+        TI_MAPPING_KEY_CACHE[cache_id] = key
     instance_id = self.mapper.mapping.get(key)
     if instance_id is None:
         key = ti.lang.kernel_impl.Kernel.ensure_compiled(self, *args)
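For context, the new cache follows a bounded, id-keyed memoization pattern. A minimal standalone sketch, with illustrative names (`make_key` stands in for the per-argument key extraction above):

```python
from collections import OrderedDict
from typing import Any, Callable

MAX_CACHE_SIZE = 1000
_KEY_CACHE: OrderedDict[int, Any] = OrderedDict()

def get_cached_key(field: Any, make_key: Callable[[Any], Any]) -> Any:
    cache_id = id(field)  # the field identity fully determines the key
    key = _KEY_CACHE.get(cache_id)
    if key is None:
        key = make_key(field)
        if len(_KEY_CACHE) >= MAX_CACHE_SIZE:
            _KEY_CACHE.popitem(last=False)  # evict the oldest entry (FIFO)
        _KEY_CACHE[cache_id] = key
    return key
```

Keying on `id()` is only safe here because the cache is flushed when the Taichi runtime is destroyed (see `_destroy_callback` below), and FIFO eviction keeps it bounded at `MAX_CACHE_SIZE` entries.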
@@ -518,7 +510,7 @@ def _launch_kernel(self, t_kernel, compiled_kernel_data, *args):
 
 def _destroy_callback(ref: weakref.ReferenceType):
     global TI_PROG_WEAKREF
-    TI_DATA_CACHE.clear()
+    TI_MAPPING_KEY_CACHE.clear()
     for kernel in TO_EXT_ARR_FAST_MAP.values():
         kernel._primal.mapper.mapping.clear()
     TI_PROG_WEAKREF = None
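`_destroy_callback` is installed on a weak reference to the Taichi runtime, so every id-keyed cache is dropped the moment the runtime is garbage-collected. A self-contained sketch of the pattern, with hypothetical names:

```python
import weakref

_CACHE: dict[int, object] = {}
_WATCHED: weakref.ReferenceType | None = None

def _on_destroy(ref: weakref.ReferenceType) -> None:
    # Fired by the GC when the watched object dies; ids may be reused by
    # new objects afterwards, so flush everything keyed on them.
    global _WATCHED
    _CACHE.clear()
    _WATCHED = None

def watch(runtime: object) -> None:
    global _WATCHED
    if _WATCHED is None:
        _WATCHED = weakref.ref(runtime, _on_destroy)
```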
@@ -539,31 +531,6 @@ def _destroy_callback(ref: weakref.ReferenceType):
     TO_EXT_ARR_FAST_MAP[data_type] = func
 
 
-def _get_ti_metadata(value: ti.Field | ti.Ndarray) -> FieldMetadata:
-    global TI_PROG_WEAKREF
-
-    # Keep track of taichi runtime to automatically clear cache if destroyed
-    if TI_PROG_WEAKREF is None:
-        TI_PROG_WEAKREF = weakref.ref(impl.get_runtime().prog, _destroy_callback)
-
-    # Get metadata
-    ti_data_id = id(value)
-    ti_data_meta = TI_DATA_CACHE.get(ti_data_id)
-    if ti_data_meta is None:
-        if isinstance(value, ti.MatrixField):
-            ndim = value.ndim
-        elif isinstance(value, ti.Ndarray):
-            ndim = len(value.element_shape)
-        else:
-            ndim = 0
-        ti_data_meta = FieldMetadata(ndim, value.shape, value.dtype, None)
-        if len(TI_DATA_CACHE) == MAX_CACHE_SIZE:
-            TI_DATA_CACHE.popitem(last=False)
-        TI_DATA_CACHE[ti_data_id] = ti_data_meta
-
-    return ti_data_meta
-
-
 def ti_to_python(
     value: ti.Field | ti.Ndarray,
     transpose: bool = False,
@@ -600,67 +567,73 @@ def ti_to_python(
     elif copy is None:
         copy = False
 
-    # Extract metadata if necessary
-    if transpose or not use_zerocopy:
-        ti_data_meta = _get_ti_metadata(value)
-
     # Leverage zero-copy if enabled
+    batch_shape = value.shape
     if use_zerocopy:
         try:
-            out = value._tc if to_torch or gs.backend != gs.cpu else value._np
+            if to_torch or gs.backend != gs.cpu:
+                out = value._T_tc if transpose else value._tc
+            else:
+                out = value._T_np if transpose else value._np
         except AttributeError:
-            out = value._tc = torch.utils.dlpack.from_dlpack(value.to_dlpack())
+            value._tc = torch.utils.dlpack.from_dlpack(value.to_dlpack())
+            value._T_tc = value._tc.movedim(batch_ndim - 1, 0) if (batch_ndim := len(batch_shape)) > 1 else value._tc
+            if to_torch:
+                out = value._T_tc if transpose else value._tc
             if gs.backend == gs.cpu:
                 value._np = value._tc.numpy()
+                value._T_np = value._T_tc.numpy()
                 if not to_torch:
-                    out = value._np
+                    out = value._T_np if transpose else value._np
         if copy:
             if to_torch:
                 out = out.clone()
             else:
                 out = tensor_to_array(out)
-    else:
-        # Extract value as a whole.
-        # Note that this is usually much faster than using a custom kernel to extract a slice.
-        # The implementation is based on `taichi.lang.(ScalarField | MatrixField).to_torch`.
-        is_metal = gs.device.type == "mps"
-        out_dtype = _to_torch_type_fast(ti_data_meta.dtype) if to_torch else _to_numpy_type_fast(ti_data_meta.dtype)
-        if issubclass(data_type, (ti.ScalarField, ti.ScalarNdarray)):
-            if to_torch:
-                out = torch.zeros(ti_data_meta.shape, dtype=out_dtype, device="cpu" if is_metal else gs.device)
-            else:
-                out = np.zeros(ti_data_meta.shape, dtype=out_dtype)
-            TO_EXT_ARR_FAST_MAP[data_type](value, out)
-        elif issubclass(data_type, ti.MatrixField):
-            as_vector = value.m == 1
-            shape_ext = (value.n,) if as_vector else (value.n, value.m)
-            if to_torch:
-                out = torch.empty(
-                    ti_data_meta.shape + shape_ext, dtype=out_dtype, device="cpu" if is_metal else gs.device
-                )
-            else:
-                out = np.zeros(ti_data_meta.shape + shape_ext, dtype=out_dtype)
-            TO_EXT_ARR_FAST_MAP[data_type](value, out, as_vector)
-        elif issubclass(data_type, (ti.VectorNdarray, ti.MatrixNdarray)):
-            layout_is_aos = 1
-            as_vector = issubclass(data_type, ti.VectorNdarray)
-            shape_ext = (value.n,) if as_vector else (value.n, value.m)
-            if to_torch:
-                out = torch.empty(
-                    ti_data_meta.shape + shape_ext, dtype=out_dtype, device="cpu" if is_metal else gs.device
-                )
-            else:
-                out = np.zeros(ti_data_meta.shape + shape_ext, dtype=out_dtype)
-            TO_EXT_ARR_FAST_MAP[ti.MatrixNdarray](value, out, layout_is_aos, as_vector)
+        return out
+
+    # Keep track of the taichi runtime to automatically clear the cache if it is destroyed
+    global TI_PROG_WEAKREF
+    if TI_PROG_WEAKREF is None:
+        TI_PROG_WEAKREF = weakref.ref(impl.get_runtime().prog, _destroy_callback)
+
+    # Extract value as a whole.
+    # Note that this is usually much faster than using a custom kernel to extract a slice.
+    # The implementation is based on `taichi.lang.(ScalarField | MatrixField).to_torch`.
+    is_metal = gs.device.type == "mps"
+    out_dtype = _to_torch_type_fast(value.dtype) if to_torch else _to_numpy_type_fast(value.dtype)
+    if issubclass(data_type, (ti.ScalarField, ti.ScalarNdarray)):
+        if to_torch:
+            out = torch.zeros(batch_shape, dtype=out_dtype, device="cpu" if is_metal else gs.device)
+        else:
+            out = np.zeros(batch_shape, dtype=out_dtype)
+        TO_EXT_ARR_FAST_MAP[data_type](value, out)
+    elif issubclass(data_type, ti.MatrixField):
+        as_vector = value.m == 1
+        shape_ext = (value.n,) if as_vector else (value.n, value.m)
+        if to_torch:
+            out = torch.empty(batch_shape + shape_ext, dtype=out_dtype, device="cpu" if is_metal else gs.device)
+        else:
+            out = np.zeros(batch_shape + shape_ext, dtype=out_dtype)
+        TO_EXT_ARR_FAST_MAP[data_type](value, out, as_vector)
+    elif issubclass(data_type, (ti.VectorNdarray, ti.MatrixNdarray)):
+        layout_is_aos = 1
+        as_vector = issubclass(data_type, ti.VectorNdarray)
+        shape_ext = (value.n,) if as_vector else (value.n, value.m)
+        if to_torch:
+            out = torch.empty(batch_shape + shape_ext, dtype=out_dtype, device="cpu" if is_metal else gs.device)
         else:
-            gs.raise_exception(f"Unsupported type '{type(value)}'.")
-        if to_torch and is_metal:
-            out = out.to(gs.device)
+            out = np.zeros(batch_shape + shape_ext, dtype=out_dtype)
+        TO_EXT_ARR_FAST_MAP[ti.MatrixNdarray](value, out, layout_is_aos, as_vector)
+    else:
+        gs.raise_exception(f"Unsupported type '{type(value)}'.")
+    if to_torch and is_metal:
+        out = out.to(gs.device)
 
     # Transpose if necessary and requested.
     # Note that it is worth transposing here before slicing, as it preserves row-major memory alignment in case of
     # advanced masking, which would spare computation later on if expected by the user.
-    if transpose and (batch_ndim := len(ti_data_meta.shape)) > 1:
+    if transpose and (batch_ndim := len(batch_shape)) > 1:
         if to_torch:
             out = out.movedim(batch_ndim - 1, 0)
         else:
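The lazy `except AttributeError` branch above caches both the zero-copy torch alias (`_tc`) and its transposed view (`_T_tc`) directly on the Taichi object. A standalone sketch of that conversion, assuming only the `to_dlpack()` method the patch itself relies on:

```python
import torch

def make_torch_views(value, batch_ndim: int):
    # DLPack import aliases the Taichi buffer: no copy is made
    base = torch.utils.dlpack.from_dlpack(value.to_dlpack())
    # movedim returns a view, so the transposed variant is also free;
    # it moves the trailing batch dim to the front
    transposed = base.movedim(batch_ndim - 1, 0) if batch_ndim > 1 else base
    return base, transposed
```

Because both variants are views over the same memory, writes through either remain visible to Taichi kernels.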
@@ -766,14 +739,23 @@ def ti_to_torch(
         copy (bool, optional): Whether to enforce returning a copy no matter what. None to avoid copy if possible
             without raising an exception if not.
     """
-    # FIXME: Ideally one should detect if slicing would require a copy to avoid enforcing copy here
-    tensor = ti_to_python(value, transpose, copy=copy, to_torch=True)
+    # Try the efficient shortcut first and only fall back to standard branching if necessary.
+    # FIXME: Ideally one should detect if slicing would require a copy to avoid enforcing copy here.
+    if gs.use_zerocopy:
+        try:
+            tensor = value._T_tc if transpose else value._tc
+            if copy:
+                tensor = tensor.clone()
+        except AttributeError:
+            tensor = ti_to_python(value, transpose, copy=copy, to_torch=True)
+    else:
+        tensor = ti_to_python(value, transpose, copy=copy, to_torch=True)
+
     if row_mask is None and col_mask is None:
         return tensor
 
-    ti_data_meta = _get_ti_metadata(value)
     raise_if_fancy = copy is False
-    if len(ti_data_meta.shape) < 2:
+    if len(value.shape) < 2:
         if row_mask is not None and col_mask is not None:
             gs.raise_exception("Cannot specify both row and column masks for tensor with 1D batch.")
         mask = indices_to_mask(
@@ -808,9 +790,8 @@ def ti_to_numpy(
     if row_mask is None and col_mask is None:
         return tensor
 
-    ti_data_meta = _get_ti_metadata(value)
     raise_if_fancy = copy is False
-    if len(ti_data_meta.shape) < 2:
+    if len(value.shape) < 2:
         if row_mask is not None and col_mask is not None:
             gs.raise_exception("Cannot specify both row and column masks for tensor with 1D batch.")
         mask = indices_to_mask(
@@ -902,9 +883,9 @@ def broadcast_tensor(
     expected_ndim = len(expected_shape)
 
     # Expand current tensor shape with extra dims of size 1 if necessary before expanding to expected shape
-    if tensor_ndim < 2:
-        tensor_ = torch.atleast_1d(tensor_)
-    elif tensor_ndim < expected_ndim:
+    if tensor_ndim == 0:
+        tensor_ = tensor_[None]
+    elif 2 <= tensor_ndim < expected_ndim:
         # Try expanding the first dimensions in priority
         for dims_valid in tuple(combinations(range(expected_ndim), tensor_ndim))[::-1]:
             curr_idx = 0
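As a sketch of what the `combinations(...)[::-1]` loop enumerates (illustrative helper, not part of the patch): reversing the lexicographic order tries placements that leave the leading dimensions free first, i.e. size 1 and hence expandable, which matches numpy-style broadcasting:

```python
from itertools import combinations

def candidate_shapes(tensor_shape: tuple[int, ...], expected_ndim: int):
    tensor_ndim = len(tensor_shape)
    # Each combination picks which target dims the existing dims occupy;
    # the remaining slots are padded with size-1 dims before expanding.
    for dims_valid in tuple(combinations(range(expected_ndim), tensor_ndim))[::-1]:
        shape = [1] * expected_ndim
        for src, dst in enumerate(dims_valid):
            shape[dst] = tensor_shape[src]
        yield tuple(shape)

# A (3, 4) tensor against a 3D target:
# list(candidate_shapes((3, 4), 3)) == [(1, 3, 4), (3, 1, 4), (3, 4, 1)]
```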
@@ -1005,11 +986,14 @@ def get_indexed_shape(tensor_shape, indices):
 
 
 def assign_indexed_tensor(
-    out: torch.Tensor,
+    tensor: torch.Tensor,
     indices: tuple[int | slice | torch.Tensor, ...],
-    in_: np.typing.ArrayLike,
-    dtype: torch.dtype,
+    value: np.typing.ArrayLike,
     dim_names: tuple[str, ...] | list[str] | None = None,
 ) -> None:
-    indexed_shape = get_indexed_shape(out.shape, indices) if indices else out.shape
-    out[indices] = broadcast_tensor(in_, dtype, indexed_shape, dim_names)
+    try:
+        tensor[indices] = value
+    except (TypeError, RuntimeError):
+        # Try extended broadcasting as a fallback to avoid slowing down the hot path
+        indexed_shape = get_indexed_shape(tensor.shape, indices) if indices else tensor.shape
+        tensor[indices] = broadcast_tensor(value, tensor.dtype, indexed_shape, dim_names)
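The rewrite above is a classic EAFP fast path: attempt the direct assignment, which torch handles natively for well-aligned shapes, and only pay for `get_indexed_shape` plus `broadcast_tensor` when torch rejects the value. A minimal illustration of the same pattern, with plain torch broadcasting standing in for the fallback:

```python
import torch

out = torch.zeros((4, 3))
value = torch.arange(4, dtype=out.dtype)  # one entry per row

try:
    out[:] = value  # (4,) does not right-align against (4, 3): RuntimeError
except (TypeError, RuntimeError):
    # Fallback: align the mismatched dim explicitly, then let torch expand
    out[:] = value[:, None]
```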