Skip to content

Commit d8179b1

Browse files
committed
offload_to_disk_path
1 parent 0bf55a9 commit d8179b1

File tree

2 files changed

+24
-41
lines changed

2 files changed

+24
-41
lines changed

‎src/diffusers/hooks/group_offloading.py‎

Lines changed: 22 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,7 @@ def __init__(
6161
record_stream: Optional[bool] = False,
6262
low_cpu_mem_usage: bool = False,
6363
onload_self: bool = True,
64-
offload_to_disk: bool = False,
65-
offload_path: Optional[str] = None,
64+
offload_to_disk_path: Optional[str] = None,
6665
) -> None:
6766
self.modules = modules
6867
self.offload_device = offload_device
@@ -77,14 +76,11 @@ def __init__(
7776
self.onload_self = onload_self
7877
self.low_cpu_mem_usage = low_cpu_mem_usage
7978

80-
self.offload_to_disk = offload_to_disk
81-
self.offload_path = offload_path
79+
self.offload_to_disk_path = offload_to_disk_path
8280
self._is_offloaded_to_disk = False
8381

84-
if self.offload_to_disk:
85-
if self.offload_path is None:
86-
raise ValueError("`offload_path` must be set when `offload_to_disk=True`.")
87-
self.safetensors_file_path = os.path.join(self.offload_path, f"group_{id(self)}.safetensors")
82+
if self.offload_to_disk_path:
83+
self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{id(self)}.safetensors")
8884

8985
all_tensors = []
9086
for module in self.modules:
@@ -150,7 +146,7 @@ def onload_(self):
150146
context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream)
151147
current_stream = torch_accelerator_module.current_stream() if self.record_stream else None
152148

153-
if self.offload_to_disk:
149+
if self.offload_to_disk_path:
154150
if self.stream is not None:
155151
# Wait for previous Host->Device transfer to complete
156152
self.stream.synchronize()
@@ -219,7 +215,7 @@ def onload_(self):
219215
@torch.compiler.disable()
220216
def offload_(self):
221217
r"""Offloads the group of modules to the offload_device."""
222-
if self.offload_to_disk:
218+
if self.offload_to_disk_path:
223219
if not self._is_offloaded_to_disk:
224220
os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
225221
tensors_to_save = {
@@ -419,8 +415,7 @@ def apply_group_offloading(
419415
onload_device: torch.device,
420416
offload_device: torch.device = torch.device("cpu"),
421417
offload_type: str = "block_level",
422-
offload_to_disk: bool = False,
423-
offload_path: Optional[str] = None,
418+
offload_to_disk_path: Optional[str] = None,
424419
num_blocks_per_group: Optional[int] = None,
425420
non_blocking: bool = False,
426421
use_stream: bool = False,
@@ -464,11 +459,8 @@ def apply_group_offloading(
464459
offload_type (`str`, defaults to "block_level"):
465460
The type of offloading to be applied. Can be one of "block_level" or "leaf_level". Default is
466461
"block_level".
467-
offload_to_disk (`bool`, defaults to `False`):
468-
If `True`, offload model parameters to disk instead of CPU RAM. This is useful when CPU memory is limited.
469-
Requires `offload_path` to be set.
470-
offload_path (`str`, *optional*):
471-
The path to the directory where offloaded parameters will be stored when `offload_to_disk` is `True`.
462+
offload_to_disk_path (`str`, *optional*):
463+
The path to the directory where offloaded parameters will be stored.
472464
num_blocks_per_group (`int`, *optional*):
473465
The number of blocks per group when using offload_type="block_level". This is required when using
474466
offload_type="block_level".
@@ -486,6 +478,8 @@ def apply_group_offloading(
486478
option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
487479
the CPU memory is a bottleneck but may counteract the benefits of using streams.
488480
481+
(TODO: add a usage example passing `offload_to_disk_path` to `apply_group_offloading` once the API stabilizes)
482+
489483
Example:
490484
```python
491485
>>> from diffusers import CogVideoXTransformer3DModel
@@ -514,8 +508,6 @@ def apply_group_offloading(
514508
stream = torch.Stream()
515509
else:
516510
raise ValueError("Using streams for data transfer requires a CUDA device, or an Intel XPU device.")
517-
if offload_to_disk and offload_path is None:
518-
raise ValueError("`offload_path` must be set when `offload_to_disk=True`.")
519511

520512
_raise_error_if_accelerate_model_or_sequential_hook_present(module)
521513

@@ -528,8 +520,7 @@ def apply_group_offloading(
528520
num_blocks_per_group=num_blocks_per_group,
529521
offload_device=offload_device,
530522
onload_device=onload_device,
531-
offload_to_disk=offload_to_disk,
532-
offload_path=offload_path,
523+
offload_to_disk_path=offload_to_disk_path,
533524
non_blocking=non_blocking,
534525
stream=stream,
535526
record_stream=record_stream,
@@ -540,8 +531,7 @@ def apply_group_offloading(
540531
module=module,
541532
offload_device=offload_device,
542533
onload_device=onload_device,
543-
offload_to_disk=offload_to_disk,
544-
offload_path=offload_path,
534+
offload_to_disk_path=offload_to_disk_path,
545535
non_blocking=non_blocking,
546536
stream=stream,
547537
record_stream=record_stream,
@@ -555,8 +545,7 @@ def _apply_group_offloading_block_level(
555545
module: torch.nn.Module,
556546
num_blocks_per_group: int,
557547
offload_device: torch.device,
558-
offload_to_disk: bool,
559-
offload_path: Optional[str],
548+
offload_to_disk_path: Optional[str],
560549
onload_device: torch.device,
561550
non_blocking: bool,
562551
stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
@@ -572,6 +561,7 @@ def _apply_group_offloading_block_level(
572561
The module to which group offloading is applied.
573562
offload_device (`torch.device`):
574563
The device to which the group of modules are offloaded. This should typically be the CPU.
564+
offload_to_disk_path (`str`, *optional*):
    The path to the directory where offloaded parameters will be stored as safetensors files. If `None`,
    parameters are offloaded to the `offload_device` instead of disk.
575565
onload_device (`torch.device`):
576566
The device to which the group of modules are onloaded.
577567
non_blocking (`bool`):
@@ -611,8 +601,7 @@ def _apply_group_offloading_block_level(
611601
modules=current_modules,
612602
offload_device=offload_device,
613603
onload_device=onload_device,
614-
offload_to_disk=offload_to_disk,
615-
offload_path=offload_path,
604+
offload_to_disk_path=offload_to_disk_path,
616605
offload_leader=current_modules[-1],
617606
onload_leader=current_modules[0],
618607
non_blocking=non_blocking,
@@ -645,8 +634,7 @@ def _apply_group_offloading_block_level(
645634
modules=unmatched_modules,
646635
offload_device=offload_device,
647636
onload_device=onload_device,
648-
offload_to_disk=offload_to_disk,
649-
offload_path=offload_path,
637+
offload_to_disk_path=offload_to_disk_path,
650638
offload_leader=module,
651639
onload_leader=module,
652640
parameters=parameters,
@@ -666,8 +654,7 @@ def _apply_group_offloading_leaf_level(
666654
module: torch.nn.Module,
667655
offload_device: torch.device,
668656
onload_device: torch.device,
669-
offload_to_disk: bool,
670-
offload_path: Optional[str],
657+
offload_to_disk_path: Optional[str],
671658
non_blocking: bool,
672659
stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
673660
record_stream: Optional[bool] = False,
@@ -686,6 +673,7 @@ def _apply_group_offloading_leaf_level(
686673
The device to which the group of modules are offloaded. This should typically be the CPU.
687674
onload_device (`torch.device`):
688675
The device to which the group of modules are onloaded.
676+
offload_to_disk_path (`str`, *optional*):
    The path to the directory where offloaded parameters will be stored as safetensors files. If `None`,
    parameters are offloaded to the `offload_device` instead of disk.
689677
non_blocking (`bool`):
690678
If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
691679
and data transfer.
@@ -711,8 +699,7 @@ def _apply_group_offloading_leaf_level(
711699
modules=[submodule],
712700
offload_device=offload_device,
713701
onload_device=onload_device,
714-
offload_to_disk=offload_to_disk,
715-
offload_path=offload_path,
702+
offload_to_disk_path=offload_to_disk_path,
716703
offload_leader=submodule,
717704
onload_leader=submodule,
718705
non_blocking=non_blocking,
@@ -759,8 +746,7 @@ def _apply_group_offloading_leaf_level(
759746
onload_device=onload_device,
760747
offload_leader=parent_module,
761748
onload_leader=parent_module,
762-
offload_to_disk=offload_to_disk,
763-
offload_path=offload_path,
749+
offload_to_disk_path=offload_to_disk_path,
764750
parameters=parameters,
765751
buffers=buffers,
766752
non_blocking=non_blocking,
@@ -779,8 +765,7 @@ def _apply_group_offloading_leaf_level(
779765
modules=[],
780766
offload_device=offload_device,
781767
onload_device=onload_device,
782-
offload_to_disk=offload_to_disk,
783-
offload_path=offload_path,
768+
offload_to_disk_path=offload_to_disk_path,
784769
offload_leader=module,
785770
onload_leader=module,
786771
parameters=None,

‎src/diffusers/models/modeling_utils.py‎

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -543,8 +543,7 @@ def enable_group_offload(
543543
onload_device: torch.device,
544544
offload_device: torch.device = torch.device("cpu"),
545545
offload_type: str = "block_level",
546-
offload_to_disk: bool = False,
547-
offload_path: Optional[str] = None,
546+
offload_to_disk_path: Optional[str] = None,
548547
num_blocks_per_group: Optional[int] = None,
549548
non_blocking: bool = False,
550549
use_stream: bool = False,
@@ -599,8 +598,7 @@ def enable_group_offload(
599598
use_stream=use_stream,
600599
record_stream=record_stream,
601600
low_cpu_mem_usage=low_cpu_mem_usage,
602-
offload_to_disk=offload_to_disk,
603-
offload_path=offload_path,
601+
offload_to_disk_path=offload_to_disk_path,
604602
)
605603

606604
def save_pretrained(

0 commit comments

Comments
 (0)