Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 76 additions & 1 deletion unsloth_zoo/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,7 +716,16 @@ def create_standalone_class(
disable = False,
add_loss_kwargs = False,
new_init = None,
new_methods = None,
) -> str:
"""
new_methods: dict[str, str] = {
"method_name": "method_source",
}
method_name must be a valid attribute of the module class, and
method_source is the replacement source code for that method. It is applied as an
exact string replacement, so indentation and whitespace must be handled ahead of time!
"""
# All Unsloth Zoo code licensed under LGPLv3
# Create optimized standalone forward function
f = eval(f"{model_location}.{module}")
Expand Down Expand Up @@ -783,6 +792,16 @@ def create_standalone_class(
if new_init is not None:
full_class = full_class.replace(old_init, new_init)

# New methods as well
if new_methods is not None and isinstance(new_methods, dict):
for method_name, method_source in new_methods.items():
try:
old_method_source = inspect.getsource(getattr(f, method_name))
full_class = full_class.replace(old_method_source, method_source)
except Exception as e:
if os.environ.get("UNSLOTH_LOGGING_ENABLED", "0") == "1":
print(f"Unsloth: Failed to replace method {method_name} in {module} with error = {str(e)}")
Comment on lines +801 to +803
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using a broad except Exception can hide unexpected errors and make debugging more difficult. It's better to catch more specific exceptions that you expect to occur here. For instance, getattr can raise AttributeError, and inspect.getsource can raise TypeError or OSError. Catching these specific exceptions would provide more precise error handling.

Suggested change
except Exception as e:
if os.environ.get("UNSLOTH_LOGGING_ENABLED", "0") == "1":
print(f"Unsloth: Failed to replace method {method_name} in {module} with error = {str(e)}")
except (AttributeError, TypeError, OSError) as e:
if os.environ.get("UNSLOTH_LOGGING_ENABLED", "0") == "1":
print(f"Unsloth: Failed to replace method {method_name} in {module} with error = {str(e)}")

# Combine all into file
source = source + full_class

Expand Down Expand Up @@ -1591,6 +1610,27 @@ def patch_finfo_attention_mask_dtype_mismatch(module, source):
return source
pass

# Matches `routing_weights = routing_weights.to(hidden_states.dtype)` (with
# optional whitespace) and captures everything around `hidden_states` so that
# only the tensor whose dtype is read gets swapped.
MOE_ROUTING_WEIGHTS_CAST_PATTERN = r"(\brouting_weights\s*=\s*routing_weights\.to\(\s*)hidden_states(\.dtype\s*\))"
MOE_ROUTING_WEIGHTS_CAST_REPLACE = r"\1router_logits\2"

def patch_moe_routing_weights_cast(module_cls: Any, source: str) -> Tuple[str, Dict[str, str]]:
    """
    Rewrite `routing_weights.to(hidden_states.dtype)` casts so that they use
    `router_logits.dtype` instead.

    module_cls: class whose own attributes (`module_cls.__dict__`) are scanned.
        Plain functions, staticmethods and classmethods are inspected;
        every other attribute is ignored.
    source: source text to patch with the same substitution.

    Returns a tuple `(patched_source, new_route_sources)` where
    `new_route_sources` maps method name -> patched method source, and only
    contains methods in which at least one substitution was made.

    Attributes whose source cannot be retrieved (builtins wrapped in a
    staticmethod, dynamically generated functions, ...) are skipped instead
    of aborting the whole patch.
    """
    new_route_sources = {}
    for method_name, obj in module_cls.__dict__.items():
        if isinstance(obj, (staticmethod, classmethod)):
            # Unwrap the descriptor to reach the underlying function
            func = obj.__func__
        elif isinstance(obj, types.FunctionType):
            func = obj
        else:
            continue

        try:
            method_source = inspect.getsource(func)
        except (OSError, TypeError) as e:
            # inspect.getsource raises OSError when the source is unavailable
            # and TypeError for builtins. One such attribute must not stop
            # the remaining methods from being patched.
            if os.environ.get("UNSLOTH_LOGGING_ENABLED", "0") == "1":
                print(f"Unsloth: Could not get source of {method_name} with error = {str(e)}")
            continue

        method_source, replaced_count = re.subn(
            MOE_ROUTING_WEIGHTS_CAST_PATTERN,
            MOE_ROUTING_WEIGHTS_CAST_REPLACE,
            method_source,
        )
        if replaced_count > 0:
            new_route_sources[method_name] = method_source

    patched_source = re.sub(
        MOE_ROUTING_WEIGHTS_CAST_PATTERN,
        MOE_ROUTING_WEIGHTS_CAST_REPLACE,
        source,
    )
    return patched_source, new_route_sources
pass

# Torch.compiling makes things slower - rather just leave it as addmm
COMPILED_LORA_FORWARD = """
torch_addmm = torch.addmm
Expand Down Expand Up @@ -2181,6 +2221,8 @@ def replaced_tqdm(*args, **kwargs):
gradient_checkpointed_modules = []
scaled_dot_product_attention_modules = []
full_attention_modules = []
router_logit_cast_modules = []

for module in torch_modules:
source = eval(f"modeling_file.{module}")
try: source = inspect.getsource(source)
Expand All @@ -2197,6 +2239,8 @@ def replaced_tqdm(*args, **kwargs):
pass
else:
full_attention_modules.append(module)
elif "routing_weights.to" in source:
router_logit_cast_modules.append(module)
pass
removal = set(
scaled_dot_product_attention_modules + \
Expand Down Expand Up @@ -2582,7 +2626,6 @@ def replaced_tqdm(*args, **kwargs):
pass
pass

# torch.finfo fix for transformers > 4.52.4 affect qwen2vl, qwen25vl, and glm4vl
for module in other_classes:
if module in all_standalone_classes:
source = all_standalone_classes[module]
Expand All @@ -2592,7 +2635,10 @@ def replaced_tqdm(*args, **kwargs):
source = inspect.getsource(module_cls.forward)
else:
continue
# torch.finfo fix for transformers > 4.52.4 affect qwen2vl, qwen25vl, and glm4vl
# Note: check if this is still valid for today's transformers
new_source = patch_finfo_attention_mask_dtype_mismatch(module, source)

if new_source != source:
try:
new_module = create_standalone_class(
Expand All @@ -2611,6 +2657,35 @@ def replaced_tqdm(*args, **kwargs):
pass
pass

if len(router_logit_cast_modules) > 0:
for module in router_logit_cast_modules:
module_cls = eval(f"{model_location}.{module}")
if hasattr(module_cls, "forward"):
source = inspect.getsource(module_cls.forward)
else:
continue

# MOE routing weights cast fix takes effect in v5
new_source, new_methods = patch_moe_routing_weights_cast(module_cls, source)
if new_source != source or len(new_methods) > 0:
try:
new_module = create_standalone_class(
module,
model_location,
functions,
fullgraph = False,
disable = True,
forward_source = new_source,
new_methods = new_methods,
)
all_standalone_classes[module] = new_module
print(f"Unsloth: Patched {module} by casting routing_weights to router_logits dtype")
except Exception as e:
print(f"Unsloth: Failed casting routing_weights to router_logits dtype in {module} with error = {str(e)}")
Comment on lines +2683 to +2684
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Similar to other parts of the code, this try...except block uses a broad except Exception. This can make debugging difficult by catching and silencing unexpected errors. Consider catching more specific exceptions that might be raised by create_standalone_class. If the goal is to catch any failure, it would be beneficial to log the full traceback when logging is enabled to get more context on the failure.

pass
pass
pass

# Manually replace hand written parts
if manual_replacements:
for module in compiler_replacements:
Expand Down