Skip to content

Commit c353bb7

Browse files
committed
Cleaner and error free alloc_conf
1 parent f690a5a commit c353bb7

File tree

1 file changed

+31
-26
lines changed

1 file changed

+31
-26
lines changed

‎unsloth_zoo/__init__.py‎

Lines changed: 31 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -82,32 +82,37 @@
8282
ALLOW_PREQUANTIZED_MODELS,
8383
)
8484

85-
# Reduce VRAM usage by reducing fragmentation
86-
# And optimize pinning of memory
87-
# TODO(billishyahao): need to add hip related optimization...
88-
if (DEVICE_TYPE in ("cuda", "hip")) and (os.environ.get("UNSLOTH_VLLM_STANDBY", "0")=="0"):
89-
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = \
90-
"expandable_segments:True,"\
91-
"roundup_power2_divisions:[32:256,64:128,256:64,>:32]"
92-
os.environ["PYTORCH_HIP_ALLOC_CONF"] = "expandable_segments:True"
93-
elif (DEVICE_TYPE in ("cuda", "hip")) and (os.environ.get("UNSLOTH_VLLM_STANDBY", "0")=="1") and \
94-
("expandable_segments:True" in os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "")):
95-
warnings.warn(
96-
"Unsloth: `UNSLOTH_VLLM_STANDBY` is on, but requires `expandable_segments` to be off.\n"\
97-
"We will remove `expandable_segments`.",
98-
stacklevel = 2,
99-
)
100-
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = re.sub(
101-
r"expandable\_segments\:True\,?",
102-
"",
103-
os.environ["PYTORCH_CUDA_ALLOC_CONF"],
104-
)
105-
os.environ["PYTORCH_HIP_ALLOC_CONF"] = re.sub(
106-
r"expandable\_segments\:True\,?",
107-
"",
108-
os.environ["PYTORCH_HIP_ALLOC_CONF"],
109-
)
110-
pass
85+
# Optimize VRAM usage by reducing fragmentation and improving memory pinning.
86+
# TODO(billishyahao): Add HIP-specific optimizations if needed.
87+
def _set_memory_optimizations():
88+
standby = os.environ.get("UNSLOTH_VLLM_STANDBY", "0") == "1"
89+
cuda_enabled = DEVICE_TYPE == "cuda"
90+
hip_enabled = DEVICE_TYPE == "hip"
91+
92+
if (cuda_enabled or hip_enabled) and not standby:
93+
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = (
94+
"expandable_segments:True,"
95+
"roundup_power2_divisions:[32:256,64:128,256:64,>:32]"
96+
)
97+
os.environ["PYTORCH_HIP_ALLOC_CONF"] = "expandable_segments:True"
98+
elif (cuda_enabled or hip_enabled) and standby:
99+
# Remove expandable_segments if UNSLOTH_VLLM_STANDBY is enabled
100+
def _remove_expandable_segments(key):
101+
conf = os.environ.get(key, "")
102+
if "expandable_segments:True" in conf:
103+
os.environ[key] = re.sub(
104+
r"expandable_segments:True,?", "", conf
105+
)
106+
warnings.warn(
107+
"Unsloth: `UNSLOTH_VLLM_STANDBY` is on, but requires `expandable_segments` to be off.\n"\
108+
"We will remove `expandable_segments`.",
109+
stacklevel = 2,
110+
)
111+
_remove_expandable_segments("PYTORCH_CUDA_ALLOC_CONF")
112+
_remove_expandable_segments("PYTORCH_HIP_ALLOC_CONF")
113+
_set_memory_optimizations()
114+
del _set_memory_optimizations
115+
111116
# We support Pytorch 2
112117
# Fixes https://github.com/unslothai/unsloth/issues/38
113118
torch_version = str(re.match(r"[0-9\.]{3,}", str(torch.__version__)).group(0)).split(".")

0 commit comments

Comments
 (0)