@@ -61,8 +61,7 @@ def __init__(
6161 record_stream : Optional [bool ] = False ,
6262 low_cpu_mem_usage : bool = False ,
6363 onload_self : bool = True ,
64- offload_to_disk : bool = False ,
65- offload_path : Optional [str ] = None ,
64+ offload_to_disk_path : Optional [str ] = None ,
6665 ) -> None :
6766 self .modules = modules
6867 self .offload_device = offload_device
@@ -77,14 +76,11 @@ def __init__(
7776 self .onload_self = onload_self
7877 self .low_cpu_mem_usage = low_cpu_mem_usage
7978
80- self .offload_to_disk = offload_to_disk
81- self .offload_path = offload_path
79+ self .offload_to_disk_path = offload_to_disk_path
8280 self ._is_offloaded_to_disk = False
8381
84- if self .offload_to_disk :
85- if self .offload_path is None :
86- raise ValueError ("`offload_path` must be set when `offload_to_disk=True`." )
87- self .safetensors_file_path = os .path .join (self .offload_path , f"group_{ id (self )} .safetensors" )
82+ if self .offload_to_disk_path :
83+ self .safetensors_file_path = os .path .join (self .offload_to_disk_path , f"group_{ id (self )} .safetensors" )
8884
8985 all_tensors = []
9086 for module in self .modules :
@@ -150,7 +146,7 @@ def onload_(self):
150146 context = nullcontext () if self .stream is None else torch_accelerator_module .stream (self .stream )
151147 current_stream = torch_accelerator_module .current_stream () if self .record_stream else None
152148
153- if self .offload_to_disk :
149+ if self .offload_to_disk_path :
154150 if self .stream is not None :
155151 # Wait for previous Host->Device transfer to complete
156152 self .stream .synchronize ()
@@ -219,7 +215,7 @@ def onload_(self):
219215 @torch .compiler .disable ()
220216 def offload_ (self ):
221217 r"""Offloads the group of modules to the offload_device."""
222- if self .offload_to_disk :
218+ if self .offload_to_disk_path :
223219 if not self ._is_offloaded_to_disk :
224220 os .makedirs (os .path .dirname (self .safetensors_file_path ), exist_ok = True )
225221 tensors_to_save = {
@@ -419,8 +415,7 @@ def apply_group_offloading(
419415 onload_device : torch .device ,
420416 offload_device : torch .device = torch .device ("cpu" ),
421417 offload_type : str = "block_level" ,
422- offload_to_disk : bool = False ,
423- offload_path : Optional [str ] = None ,
418+ offload_to_disk_path : Optional [str ] = None ,
424419 num_blocks_per_group : Optional [int ] = None ,
425420 non_blocking : bool = False ,
426421 use_stream : bool = False ,
@@ -464,11 +459,8 @@ def apply_group_offloading(
464459 offload_type (`str`, defaults to "block_level"):
465460 The type of offloading to be applied. Can be one of "block_level" or "leaf_level". Default is
466461 "block_level".
467- offload_to_disk (`bool`, defaults to `False`):
468- If `True`, offload model parameters to disk instead of CPU RAM. This is useful when CPU memory is limited.
469- Requires `offload_path` to be set.
470- offload_path (`str`, *optional*):
471- The path to the directory where offloaded parameters will be stored when `offload_to_disk` is `True`.
462+ offload_to_disk_path (`str`, *optional*):
463+ The path to the directory where offloaded parameters will be stored.
472464 num_blocks_per_group (`int`, *optional*):
473465 The number of blocks per group when using offload_type="block_level". This is required when using
474466 offload_type="block_level".
@@ -486,6 +478,8 @@ def apply_group_offloading(
486478 option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
487479 the CPU memory is a bottleneck but may counteract the benefits of using streams.
488480
481+ (TODO: include example with `offload_to_disk_path`)
482+
489483 Example:
490484 ```python
491485 >>> from diffusers import CogVideoXTransformer3DModel
@@ -514,8 +508,6 @@ def apply_group_offloading(
514508 stream = torch .Stream ()
515509 else :
516510 raise ValueError ("Using streams for data transfer requires a CUDA device, or an Intel XPU device." )
517- if offload_to_disk and offload_path is None :
518- raise ValueError ("`offload_path` must be set when `offload_to_disk=True`." )
519511
520512 _raise_error_if_accelerate_model_or_sequential_hook_present (module )
521513
@@ -528,8 +520,7 @@ def apply_group_offloading(
528520 num_blocks_per_group = num_blocks_per_group ,
529521 offload_device = offload_device ,
530522 onload_device = onload_device ,
531- offload_to_disk = offload_to_disk ,
532- offload_path = offload_path ,
523+ offload_to_disk_path = offload_to_disk_path ,
533524 non_blocking = non_blocking ,
534525 stream = stream ,
535526 record_stream = record_stream ,
@@ -540,8 +531,7 @@ def apply_group_offloading(
540531 module = module ,
541532 offload_device = offload_device ,
542533 onload_device = onload_device ,
543- offload_to_disk = offload_to_disk ,
544- offload_path = offload_path ,
534+ offload_to_disk_path = offload_to_disk_path ,
545535 non_blocking = non_blocking ,
546536 stream = stream ,
547537 record_stream = record_stream ,
@@ -555,8 +545,7 @@ def _apply_group_offloading_block_level(
555545 module : torch .nn .Module ,
556546 num_blocks_per_group : int ,
557547 offload_device : torch .device ,
558- offload_to_disk : bool ,
559- offload_path : Optional [str ],
548+ offload_to_disk_path : Optional [str ],
560549 onload_device : torch .device ,
561550 non_blocking : bool ,
562551 stream : Union [torch .cuda .Stream , torch .Stream , None ] = None ,
@@ -572,6 +561,7 @@ def _apply_group_offloading_block_level(
572561 The module to which group offloading is applied.
573562 offload_device (`torch.device`):
574563 The device to which the group of modules are offloaded. This should typically be the CPU.
564 offload_to_disk_path (`str`, *optional*):
565 The path to the directory where parameters will be offloaded to disk instead of the offload device.
575565 onload_device (`torch.device`):
576566 The device to which the group of modules are onloaded.
577567 non_blocking (`bool`):
@@ -611,8 +601,7 @@ def _apply_group_offloading_block_level(
611601 modules = current_modules ,
612602 offload_device = offload_device ,
613603 onload_device = onload_device ,
614- offload_to_disk = offload_to_disk ,
615- offload_path = offload_path ,
604+ offload_to_disk_path = offload_to_disk_path ,
616605 offload_leader = current_modules [- 1 ],
617606 onload_leader = current_modules [0 ],
618607 non_blocking = non_blocking ,
@@ -645,8 +634,7 @@ def _apply_group_offloading_block_level(
645634 modules = unmatched_modules ,
646635 offload_device = offload_device ,
647636 onload_device = onload_device ,
648- offload_to_disk = offload_to_disk ,
649- offload_path = offload_path ,
637+ offload_to_disk_path = offload_to_disk_path ,
650638 offload_leader = module ,
651639 onload_leader = module ,
652640 parameters = parameters ,
@@ -666,8 +654,7 @@ def _apply_group_offloading_leaf_level(
666654 module : torch .nn .Module ,
667655 offload_device : torch .device ,
668656 onload_device : torch .device ,
669- offload_to_disk : bool ,
670- offload_path : Optional [str ],
657+ offload_to_disk_path : Optional [str ],
671658 non_blocking : bool ,
672659 stream : Union [torch .cuda .Stream , torch .Stream , None ] = None ,
673660 record_stream : Optional [bool ] = False ,
@@ -686,6 +673,7 @@ def _apply_group_offloading_leaf_level(
686673 The device to which the group of modules are offloaded. This should typically be the CPU.
687674 onload_device (`torch.device`):
688675 The device to which the group of modules are onloaded.
676 offload_to_disk_path (`str`, *optional*):
677 The path to the directory where parameters will be offloaded to disk instead of the offload device.
689677 non_blocking (`bool`):
690678 If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
691679 and data transfer.
@@ -711,8 +699,7 @@ def _apply_group_offloading_leaf_level(
711699 modules = [submodule ],
712700 offload_device = offload_device ,
713701 onload_device = onload_device ,
714- offload_to_disk = offload_to_disk ,
715- offload_path = offload_path ,
702+ offload_to_disk_path = offload_to_disk_path ,
716703 offload_leader = submodule ,
717704 onload_leader = submodule ,
718705 non_blocking = non_blocking ,
@@ -759,8 +746,7 @@ def _apply_group_offloading_leaf_level(
759746 onload_device = onload_device ,
760747 offload_leader = parent_module ,
761748 onload_leader = parent_module ,
762- offload_to_disk = offload_to_disk ,
763- offload_path = offload_path ,
749+ offload_to_disk_path = offload_to_disk_path ,
764750 parameters = parameters ,
765751 buffers = buffers ,
766752 non_blocking = non_blocking ,
@@ -779,8 +765,7 @@ def _apply_group_offloading_leaf_level(
779765 modules = [],
780766 offload_device = offload_device ,
781767 onload_device = onload_device ,
782- offload_to_disk = offload_to_disk ,
783- offload_path = offload_path ,
768+ offload_to_disk_path = offload_to_disk_path ,
784769 offload_leader = module ,
785770 onload_leader = module ,
786771 parameters = None ,
0 commit comments