@@ -1562,12 +1562,76 @@ def patch_gradient_checkpointing(module, source):
15621562 .replace ("ARGS" , args ).replace ("$" , spaces )
15631563 forward = forward .replace (forward [span [0 ] : span [1 ]], replacer )
15641564
1565+ # Confirm no equal signs seen - might be "attention_mask=causal_mask_mapping" vs "attention_mask=attention_mask"
1566+ if '=' in args :
1567+ return None
15651568 # Also fix init
15661569 spaces = init .find ("def" )
15671570 init = init + "\n " + (spaces + 4 ) * " " + "self.gradient_checkpointing = False\n \n "
15681571
1569- # Confirm no equal signs seen - might be "attention_mask=causal_mask_mapping" vs "attention_mask=attention_mask"
1570- if "=" in init : return None
1572+ return init , forward
1573+ pass
1574+
def strip_kw_from_module_calls(src: str, modulelist_item: str) -> str:
    """Turn keyword arguments of per-layer calls in ``src`` into positional ones.

    The function first looks for loops of the form
    ``for layer in <modulelist_item>:`` or
    ``for i, layer in enumerate(<modulelist_item>):`` and collects the loop
    variable names.  For every indented assignment ``out = layer(...)`` that
    calls one of those variables, each ``name=`` prefix at the start of the
    argument list or right after a comma is removed, so ``layer(x, mask=m)``
    becomes ``layer(x, m)``.

    Args:
        src: Source text to rewrite.
        modulelist_item: Expression iterated over by the loop, e.g.
            ``"self.layers"``.  It is matched literally (regex-escaped).

    Returns:
        The rewritten source, or ``src`` unchanged when no matching loop
        variable is found.

    Note:
        The argument list is matched up to the first ``)``, so a call whose
        arguments contain nested parentheses is only rewritten up to that
        point.
    """
    item = re.escape(modulelist_item)
    for_pattern = re.compile(
        rf"for (?:[^\s,]+,\s*)?(?P<layer>\w+)\s+in\s+"
        rf"(?:enumerate\({item}\)|{item})\s*:",
        re.MULTILINE,
    )
    layer_vars = {m.group("layer") for m in for_pattern.finditer(src)}
    if not layer_vars:
        return src

    # ``(?!=)`` keeps comparisons such as ``a == b`` intact: without it the
    # first ``=`` of ``==`` is treated as a keyword separator and the
    # argument is rewritten to the invalid ``= b``.
    kw_at_start_pattern = re.compile(
        r'(^|,)(\s*)([A-Za-z_]\w*)\s*=(?!=)\s*',
        re.MULTILINE,
    )

    def strip_kw_names(args: str) -> str:
        # Keep the leading comma / whitespace, drop only the ``name=`` part.
        return kw_at_start_pattern.sub(r'\1\2', args)

    # Sorted so the rewrite order does not depend on set iteration order.
    for layer in sorted(layer_vars):
        layer_re = re.escape(layer)
        call_pattern = re.compile(
            rf"""
            (^[ \t]+)
            (\w+)\s*=\s*
            {layer_re}
            \(
            (
                [^)]*?
            )
            \)
            """,
            re.MULTILINE | re.DOTALL | re.VERBOSE,
        )

        def replace_call(m: re.Match) -> str:
            indent, outvar, args = m.group(1), m.group(2), m.group(3)
            new_args = strip_kw_names(args)
            return f"{indent}{outvar} = {layer}({new_args})"

        src = call_pattern.sub(replace_call, src)

    return src
1616+
def patch_gradient_checkpointing_layer_caller(module, source):
    """Prepare patched ``__init__`` / ``forward`` sources for a layer-calling class.

    The ``forward`` source has keyword arguments of its per-layer calls
    converted to positional ones (via ``strip_kw_from_module_calls``), and the
    ``__init__`` source gets ``self.gradient_checkpointing = False`` appended
    when it does not already assign that attribute.

    Args:
        module: Name of the class being patched (not used by this function).
        source: The class object whose ``__init__`` and ``forward`` sources
            are read with ``inspect.getsource``.

    Returns:
        ``(init, forward)`` source strings, or ``None`` when the class is not
        patched: its source cannot be retrieved, ``__init__`` does not mention
        ``nn.ModuleList``, ``forward`` already references
        ``_gradient_checkpointing_func``, or ``__init__`` does not assign
        exactly one ``nn.ModuleList`` attribute.
    """
    # All Unsloth Zoo code licensed under LGPLv3
    # ``except Exception`` (not a bare ``except``) so KeyboardInterrupt and
    # SystemExit still propagate while any lookup failure means "skip".
    try:
        init = inspect.getsource(source.__init__)
    except Exception:
        return None
    if "nn.ModuleList" not in init:
        return None

    try:
        forward = inspect.getsource(source.forward)
    except Exception:
        return None
    if "_gradient_checkpointing_func" in forward:
        return None

    # Attribute(s) assigned an nn.ModuleList in __init__; exactly one is required.
    modulelist_items = re.findall(r"(self\.[^\s]{1,}) = .*?nn\.ModuleList\(", init)
    if len(modulelist_items) != 1:
        return None
    modulelist_item = modulelist_items[0]

    forward = strip_kw_from_module_calls(forward, modulelist_item)

    if "self.gradient_checkpointing =" not in init:
        # Indent the appended statement one level (4 spaces) deeper than ``def``.
        spaces = init.find("def")
        init = init + "\n" + (spaces + 4) * " " + "self.gradient_checkpointing = False\n\n"

    return init, forward
15721636pass
15731637
@@ -2036,6 +2100,10 @@ def compile_mamba_ssm(UNSLOTH_ENABLE_LOGGING=False):
20362100 "Gemma3nTextModel" ,
20372101]
20382102
# Classes that are still patched through patch_gradient_checkpointing_layer_caller
# even when the model source uses GradientCheckpointingLayer.
FIX_GC_LAYER_CALLER_MODULES = ["WhisperDecoder"]
2106+
20392107
20402108def unsloth_compile_transformers (
20412109 model_type : str = "llama" ,
@@ -2602,9 +2670,13 @@ def replaced_tqdm(*args, **kwargs):
26022670 for module in other_classes :
26032671 source = eval (f"{ model_location } .{ module } " )
26042672 if "(GradientCheckpointingLayer)" in full_source :
2605- # Uses GC layers which is in new transformers - no need to patch
2606- continue
2607- output = patch_gradient_checkpointing (module , source )
2673+ if module in FIX_GC_LAYER_CALLER_MODULES :
2674+ output = patch_gradient_checkpointing_layer_caller (module , source )
2675+ else :
2676+ # Uses GC layers which is in new transformers - no need to patch
2677+ continue
2678+ else :
2679+ output = patch_gradient_checkpointing (module , source )
26082680 if output is None : continue
26092681
26102682 init , forward = output
0 commit comments