Skip to content

Commit 87c0a40

Browse files
Merge pull request #353 from mmathew23/fix/whisper
Fix gradient checkpointing layer caller kwargs
2 parents c9c7693 + 6ab47cf commit 87c0a40

File tree

1 file changed

+77
-5
lines changed

1 file changed

+77
-5
lines changed

‎unsloth_zoo/compiler.py‎

Lines changed: 77 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1562,12 +1562,76 @@ def patch_gradient_checkpointing(module, source):
15621562
.replace("ARGS", args).replace("$", spaces)
15631563
forward = forward.replace(forward[span[0] : span[1]], replacer)
15641564

1565+
# Confirm no equal signs seen - might be "attention_mask=causal_mask_mapping" vs "attention_mask=attention_mask"
1566+
if '=' in args:
1567+
return None
15651568
# Also fix init
15661569
spaces = init.find("def")
15671570
init = init + "\n" + (spaces + 4) * " " + "self.gradient_checkpointing = False\n\n"
15681571

1569-
# Confirm no equal signs seen - might be "attention_mask=causal_mask_mapping" vs "attention_mask=attention_mask"
1570-
if "=" in init: return None
1572+
return init, forward
1573+
pass
1574+
1575+
def strip_kw_from_module_calls(src: str, modulelist_item: str) -> str:
    """Turn keyword arguments into positional ones in per-layer calls.

    Looks in ``src`` for loops of the form ``for layer in <modulelist_item>:``
    or ``for i, layer in enumerate(<modulelist_item>):``. For every indented
    statement ``out = layer(...)`` that calls one of those loop variables, the
    ``name=`` prefix of each *top-level* argument is removed; the argument
    values, their order and the surrounding whitespace are kept.

    Only top-level arguments are touched: keyword arguments of nested calls,
    comparisons such as ``a == b`` and text inside string literals are left
    alone.

    Args:
        src: Source code to rewrite (typically a ``forward`` method).
        modulelist_item: Expression iterated over, e.g. ``"self.layers"``.

    Returns:
        The rewritten source, or ``src`` unchanged when no matching loop
        (or no matching call) is found.
    """
    item = re.escape(modulelist_item)
    for_pattern = re.compile(
        rf"for (?:[^\s,]+,\s*)?(?P<layer>\w+)\s+in\s+"
        rf"(?:enumerate\({item}\)|{item})\s*:",
        re.MULTILINE,
    )
    layer_vars = {m.group("layer") for m in for_pattern.finditer(src)}
    if not layer_vars:
        return src

    # Leading whitespace / comment lines are captured so they survive; the
    # `(?!=)` lookahead keeps `a == b` from being mistaken for `a=...`.
    kw_prefix_pattern = re.compile(
        r"\A((?:\s|#[^\n]*\n)*)[A-Za-z_]\w*\s*=(?!=)\s*"
    )

    def split_top_level_args(text: str, start: int):
        """Split the argument list starting at ``text[start]`` (just after ``(``).

        Returns ``(pieces, close)`` where ``pieces`` are the top-level
        comma-separated arguments and ``close`` is the index of the matching
        ``)``. Returns ``None`` if the brackets never balance.
        """
        pieces = []
        depth = 0
        quote = None
        piece_start = start
        i = start
        end = len(text)
        while i < end:
            ch = text[i]
            if quote is not None:
                if ch == "\\":
                    # Skip the escaped character inside a string literal.
                    i += 2
                    continue
                if ch == quote:
                    quote = None
            elif ch in "\"'":
                quote = ch
            elif ch == "#":
                # Comments may contain brackets or commas; jump to line end.
                newline = text.find("\n", i)
                if newline == -1:
                    return None
                i = newline
                continue
            elif ch in "([{":
                depth += 1
            elif ch in ")]}":
                if depth == 0:
                    if ch != ")":
                        return None
                    pieces.append(text[piece_start:i])
                    return pieces, i
                depth -= 1
            elif ch == "," and depth == 0:
                pieces.append(text[piece_start:i])
                piece_start = i + 1
            i += 1
        return None

    # Sorted only to make the rewrite order deterministic.
    for layer in sorted(layer_vars):
        call_start = re.compile(
            rf"(^[ \t]+)(\w+)\s*=\s*{re.escape(layer)}\(",
            re.MULTILINE,
        )
        chunks = []
        pos = 0
        while True:
            m = call_start.search(src, pos)
            if m is None:
                break
            scanned = split_top_level_args(src, m.end())
            if scanned is None:
                # Unbalanced call: keep the text as-is and move on.
                chunks.append(src[pos:m.end()])
                pos = m.end()
                continue
            pieces, close = scanned
            new_args = ",".join(
                kw_prefix_pattern.sub(r"\1", piece, count=1) for piece in pieces
            )
            indent, outvar = m.group(1), m.group(2)
            chunks.append(src[pos:m.start()])
            chunks.append(f"{indent}{outvar} = {layer}({new_args})")
            pos = close + 1
        chunks.append(src[pos:])
        src = "".join(chunks)

    return src
1616+
1617+
def patch_gradient_checkpointing_layer_caller(module, source):
1618+
# All Unsloth Zoo code licensed under LGPLv3
1619+
try: init = inspect.getsource(source.__init__)
1620+
except: return None
1621+
if "nn.ModuleList" not in init: return None
1622+
try: forward = inspect.getsource(source.forward)
1623+
except: return None
1624+
if "_gradient_checkpointing_func" in forward: return None
1625+
1626+
modulelist_items = re.findall(r"(self\.[^\s]{1,}) = .*?nn\.ModuleList\(", init)
1627+
if len(modulelist_items) != 1: return None
1628+
modulelist_item = modulelist_items[0]
1629+
1630+
forward = strip_kw_from_module_calls(forward, modulelist_item)
1631+
spaces = init.find("def")
1632+
if 'self.gradient_checkpointing =' not in init:
1633+
init = init + "\n" + (spaces + 4) * " " + "self.gradient_checkpointing = False\n\n"
1634+
15711635
return init, forward
15721636
pass
15731637

@@ -2036,6 +2100,10 @@ def compile_mamba_ssm(UNSLOTH_ENABLE_LOGGING=False):
20362100
"Gemma3nTextModel",
20372101
]
20382102

2103+
# Module class names that are still patched, via
# patch_gradient_checkpointing_layer_caller, when the model source contains
# "(GradientCheckpointingLayer)"; other modules are skipped in that case.
FIX_GC_LAYER_CALLER_MODULES = [
2104+
"WhisperDecoder",
2105+
]
2106+
20392107

20402108
def unsloth_compile_transformers(
20412109
model_type : str = "llama",
@@ -2602,9 +2670,13 @@ def replaced_tqdm(*args, **kwargs):
26022670
for module in other_classes:
26032671
source = eval(f"{model_location}.{module}")
26042672
if "(GradientCheckpointingLayer)" in full_source:
2605-
# Uses GC layers which is in new transformers - no need to patch
2606-
continue
2607-
output = patch_gradient_checkpointing(module, source)
2673+
if module in FIX_GC_LAYER_CALLER_MODULES:
2674+
output = patch_gradient_checkpointing_layer_caller(module, source)
2675+
else:
2676+
# Uses GC layers which is in new transformers - no need to patch
2677+
continue
2678+
else:
2679+
output = patch_gradient_checkpointing(module, source)
26082680
if output is None: continue
26092681

26102682
init, forward = output

0 commit comments

Comments
 (0)