-
Notifications
You must be signed in to change notification settings - Fork 173
MoE: Cast routing_weights dtype correctly #349
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
MoE: Cast routing_weights dtype correctly #349
Conversation
Summary of ChangesHello @mmathew23, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request implements a crucial fix for Mixture of Experts (MoE) models by correcting a dtype mismatch during the compilation process. Previously, Highlights
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request introduces a mechanism to patch methods in compiled classes, which is then used to fix an issue with routing_weights dtype casting in MoE models. The changes are logical and address the described problem. I've identified a few areas for improvement, mainly concerning error handling and type hint correctness. Specifically, using more specific exception handling instead of broad except Exception would improve robustness, and correcting a return type hint will enhance code clarity and maintainability.
| except Exception as e: | ||
| if os.environ.get("UNSLOTH_LOGGING_ENABLED", "0") == "1": | ||
| print(f"Unsloth: Failed to replace method {method_name} in {module} with error = {str(e)}") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using a broad except Exception can hide unexpected errors and make debugging more difficult. It's better to catch more specific exceptions that you expect to occur here. For instance, getattr can raise AttributeError, and inspect.getsource can raise TypeError or OSError. Catching these specific exceptions would provide more precise error handling.
| except Exception as e: | |
| if os.environ.get("UNSLOTH_LOGGING_ENABLED", "0") == "1": | |
| print(f"Unsloth: Failed to replace method {method_name} in {module} with error = {str(e)}") | |
| except (AttributeError, TypeError, OSError) as e: | |
| if os.environ.get("UNSLOTH_LOGGING_ENABLED", "0") == "1": | |
| print(f"Unsloth: Failed to replace method {method_name} in {module} with error = {str(e)}") |
| except Exception as e: | ||
| print(f"Unsloth: Failed casting routing_weights to router_logits dtype in {module} with error = {str(e)}") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Similar to other parts of the code, this try...except block uses a broad except Exception. This can make debugging difficult by catching and silencing unexpected errors. Consider catching more specific exceptions that might be raised by create_standalone_class. If the goal is to catch any failure, it would be beneficial to log the full traceback when logging is enabled to get more context on the failure.
add correct type hint Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Update compiled to cast routing_weight tensors to router_logits.dtype instead of hidden_states.dtype.
This will patch the forward method, and any other methods that cast routing_weights to hidden_states.dtype.
Qwen3 Moe Test:
https://colab.research.google.com/drive/1z9KopzSJylyM3g8KJ3bnvvpo97LISHTm?usp=sharing