-
Notifications
You must be signed in to change notification settings - Fork 174
MoE: Cast routing_weights dtype correctly #349
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -716,7 +716,16 @@ def create_standalone_class( | |
| disable = False, | ||
| add_loss_kwargs = False, | ||
| new_init = None, | ||
| new_methods = None, | ||
| ) -> str: | ||
| """ | ||
| new_methods: dict[str, str] = { | ||
| "method_name": "method_source", | ||
| } | ||
| method_name needs to be a valid attribute of the module class and | ||
| method_source is the source code of the method it will be an exact string | ||
| replacement so indentation and whitespace should be handled ahead of time! | ||
| """ | ||
| # All Unsloth Zoo code licensed under LGPLv3 | ||
| # Create optimized standalone forward function | ||
| f = eval(f"{model_location}.{module}") | ||
|
|
@@ -783,6 +792,16 @@ def create_standalone_class( | |
| if new_init is not None: | ||
| full_class = full_class.replace(old_init, new_init) | ||
|
|
||
| # New methods as well | ||
| if new_methods is not None and isinstance(new_methods, dict): | ||
| for method_name, method_source in new_methods.items(): | ||
| try: | ||
| old_method_source = inspect.getsource(getattr(f, method_name)) | ||
| full_class = full_class.replace(old_method_source, method_source) | ||
| except Exception as e: | ||
| if os.environ.get("UNSLOTH_LOGGING_ENABLED", "0") == "1": | ||
| print(f"Unsloth: Failed to replace method {method_name} in {module} with error = {str(e)}") | ||
|
|
||
| # Combine all into file | ||
| source = source + full_class | ||
|
|
||
|
|
@@ -1591,6 +1610,27 @@ def patch_finfo_attention_mask_dtype_mismatch(module, source): | |
| return source | ||
| pass | ||
|
|
||
# Matches `routing_weights = routing_weights.to(hidden_states.dtype)` (whitespace
# tolerant). Groups 1 and 2 capture the text on either side of `hidden_states`
# so the replacement swaps only that name for `router_logits`.
MOE_ROUTING_WEIGHTS_CAST_PATTERN = r"(\brouting_weights\s*=\s*routing_weights\.to\(\s*)hidden_states(\.dtype\s*\))"
MOE_ROUTING_WEIGHTS_CAST_REPLACE = r"\1router_logits\2"


def patch_moe_routing_weights_cast(module_cls: Any, source: str) -> Tuple[str, Dict[str, str]]:
    """Rewrite `routing_weights.to(hidden_states.dtype)` casts to use `router_logits.dtype`.

    Args:
        module_cls: Class whose own `__dict__` is scanned for plain functions,
            staticmethods and classmethods containing the cast.
        source: Source text to patch with the same substitution.

    Returns:
        A tuple `(patched_source, new_methods)`. `patched_source` is `source`
        with every matching cast rewritten. `new_methods` maps each method name
        whose source changed to its rewritten source; methods without a match
        are omitted.
    """
    new_route_sources = {}
    for method_name, obj in module_cls.__dict__.items():
        if isinstance(obj, (staticmethod, classmethod)):
            func = obj.__func__
        elif isinstance(obj, types.FunctionType):
            func = obj
        else:
            continue

        try:
            method_source = inspect.getsource(func)
        except (OSError, TypeError):
            # Source is not retrievable for this callable (for example it was
            # created dynamically). Skip it so the remaining methods and the
            # main `source` still get patched instead of aborting entirely.
            continue

        method_source, replaced_count = re.subn(
            MOE_ROUTING_WEIGHTS_CAST_PATTERN,
            MOE_ROUTING_WEIGHTS_CAST_REPLACE,
            method_source,
        )
        if replaced_count > 0:
            new_route_sources[method_name] = method_source

    patched_source = re.sub(
        MOE_ROUTING_WEIGHTS_CAST_PATTERN,
        MOE_ROUTING_WEIGHTS_CAST_REPLACE,
        source,
    )
    return patched_source, new_route_sources
pass
|
|
||
| # Torch.compiling makes things slower - rather just leave it as addmm | ||
| COMPILED_LORA_FORWARD = """ | ||
| torch_addmm = torch.addmm | ||
|
|
@@ -2181,6 +2221,8 @@ def replaced_tqdm(*args, **kwargs): | |
| gradient_checkpointed_modules = [] | ||
| scaled_dot_product_attention_modules = [] | ||
| full_attention_modules = [] | ||
| router_logit_cast_modules = [] | ||
|
|
||
| for module in torch_modules: | ||
| source = eval(f"modeling_file.{module}") | ||
| try: source = inspect.getsource(source) | ||
|
|
@@ -2197,6 +2239,8 @@ def replaced_tqdm(*args, **kwargs): | |
| pass | ||
| else: | ||
| full_attention_modules.append(module) | ||
| elif "routing_weights.to" in source: | ||
| router_logit_cast_modules.append(module) | ||
| pass | ||
| removal = set( | ||
| scaled_dot_product_attention_modules + \ | ||
|
|
@@ -2582,7 +2626,6 @@ def replaced_tqdm(*args, **kwargs): | |
| pass | ||
| pass | ||
|
|
||
| # torch.finfo fix for transformers > 4.52.4 affect qwen2vl, qwen25vl, and glm4vl | ||
| for module in other_classes: | ||
| if module in all_standalone_classes: | ||
| source = all_standalone_classes[module] | ||
|
|
@@ -2592,7 +2635,10 @@ def replaced_tqdm(*args, **kwargs): | |
| source = inspect.getsource(module_cls.forward) | ||
| else: | ||
| continue | ||
| # torch.finfo fix for transformers > 4.52.4 affect qwen2vl, qwen25vl, and glm4vl | ||
| # Note: check if this is still valid for todays transformers | ||
| new_source = patch_finfo_attention_mask_dtype_mismatch(module, source) | ||
|
|
||
| if new_source != source: | ||
| try: | ||
| new_module = create_standalone_class( | ||
|
|
@@ -2611,6 +2657,35 @@ def replaced_tqdm(*args, **kwargs): | |
| pass | ||
| pass | ||
|
|
||
| if len(router_logit_cast_modules) > 0: | ||
| for module in router_logit_cast_modules: | ||
| module_cls = eval(f"{model_location}.{module}") | ||
| if hasattr(module_cls, "forward"): | ||
| source = inspect.getsource(module_cls.forward) | ||
| else: | ||
| continue | ||
|
|
||
| # MOE routing weights cast fix takes effect in v5 | ||
| new_source, new_methods = patch_moe_routing_weights_cast(module_cls, source) | ||
| if new_source != source or len(new_methods) > 0: | ||
| try: | ||
| new_module = create_standalone_class( | ||
| module, | ||
| model_location, | ||
| functions, | ||
| fullgraph = False, | ||
| disable = True, | ||
| forward_source = new_source, | ||
| new_methods = new_methods, | ||
| ) | ||
| all_standalone_classes[module] = new_module | ||
| print(f"Unsloth: Patched {module} by casting routing_weights to router_logits dtype") | ||
| except Exception as e: | ||
| print(f"Unsloth: Failed casting routing_weights to router_logits dtype in {module} with error = {str(e)}") | ||
|
Comment on lines
+2683
to
+2684
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Similar to other parts of the code, this |
||
| pass | ||
| pass | ||
| pass | ||
|
|
||
| # Manually replace hand written parts | ||
| if manual_replacements: | ||
| for module in compiler_replacements: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using a broad `except Exception` can hide unexpected errors and make debugging more difficult. It's better to catch more specific exceptions that you expect to occur here. For instance, `getattr` can raise `AttributeError`, and `inspect.getsource` can raise `TypeError` or `OSError`. Catching these specific exceptions would provide more precise error handling.