Fix gradient checkpointing layer caller kwargs #353
Summary of Changes
Hello @mmathew23, I'm Gemini Code Assist! Here's a summary to help you and other reviewers quickly get up to speed.
This pull request resolves a critical "double backward" error affecting the Whisper model when used with transformers versions after 4.51. The issue stemmed from a refactor of the gradient checkpointing mechanism in transformers, after which certain modules began passing arguments as keywords to their internal layers. The changes introduce a targeted compiler patch that strips these keyword arguments during compilation, so the Whisper model can use gradient checkpointing without runtime failures.
Code Review
This pull request introduces a fix for gradient checkpointing with Whisper models by stripping keyword arguments from layer calls. This is achieved by adding a new patching function, patch_gradient_checkpointing_layer_caller, and a helper, strip_kw_from_module_calls. The changes are well-contained, and the logic to select the new patching mechanism for specific modules like WhisperDecoder is correctly implemented. My review includes a couple of suggestions to improve the robustness of the new code.
call_pattern = re.compile(
    rf"""
    (^[ \t]+)
    (\w+)\s*=\s*
The regex pattern to capture the output variable (\w+) is not robust enough to handle tuple unpacking (e.g., (out1, out2) = layer(...)) or chained assignments. This could cause the patching to fail for models that use such patterns.
Using ([^=]+?) instead will non-greedily match any character except = one or more times, making it capable of handling more complex assignment targets.
- (\w+)\s*=\s*
+ ([^=]+?)\s*=\s*
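The difference between the two capture groups can be seen in a small sketch. The patterns below are hypothetical, simplified stand-ins: only the assignment-target group comes from the reviewed code, while the trailing `(\w+)\(` callee part and the sample lines are assumptions made for illustration.

```python
import re

# Hypothetical, simplified patterns. Only group 2 (the assignment target)
# mirrors the reviewed regex; the callee part is assumed for this demo.
narrow = re.compile(r"^([ \t]+)(\w+)\s*=\s*(\w+)\(", re.MULTILINE)
broad = re.compile(r"^([ \t]+)([^=]+?)\s*=\s*(\w+)\(", re.MULTILINE)

simple_line = "    hidden = layer(x, mask=mask)\n"
tuple_line = "    (out1, out2) = layer(x, mask=mask)\n"


def assignment_target(pattern, text):
    """Return the captured assignment target, or None if nothing matches."""
    match = pattern.search(text)
    return match.group(2) if match else None


for name, pattern in (("narrow", narrow), ("broad", broad)):
    print(name, repr(assignment_target(pattern, simple_line)),
          repr(assignment_target(pattern, tuple_line)))
```

One trade-off to keep in mind: `[^=]` also matches newlines, so on multi-line input the broader group can start on an earlier indented line and run forward to the next `=`. That may be a reason to keep the narrower `\w+` until tuple targets are actually needed.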
That's a good suggestion if we need it in the future, but for the current use case it's not needed.
try: init = inspect.getsource(source.__init__)
except: return None
Using a bare except: is generally discouraged as it can catch and silence unexpected errors, making debugging harder. It's better to specify the exceptions you expect to handle.
inspect.getsource() is documented to raise TypeError or OSError. Catching these specific exceptions makes the code more robust and clearer about its intent.
This also applies to the try...except block for getting the forward source on lines 1622-1623.
- try: init = inspect.getsource(source.__init__)
- except: return None
+ try: init = inspect.getsource(source.__init__)
+ except (TypeError, OSError): return None
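As a standalone illustration of the suggested pattern, here is a minimal sketch. The helper name `get_source_or_none` is hypothetical and does not appear in the PR; it only shows that catching `TypeError` and `OSError` covers the documented failure modes of `inspect.getsource` while letting unrelated errors propagate.

```python
import inspect


def get_source_or_none(obj):
    """Return the source text of obj, or None if it cannot be retrieved.

    Hypothetical helper for illustration; not part of the PR.
    """
    try:
        return inspect.getsource(obj)
    except (TypeError, OSError):
        # TypeError: obj is a built-in or not a supported object type.
        # OSError: the source code could not be located.
        return None


# Built-in functions have no retrievable Python source.
print(get_source_or_none(len))
```

Unlike a bare `except:`, this form does not swallow `KeyboardInterrupt`, `SystemExit`, or genuine bugs raised inside the `try` body.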
The Whisper model is failing with a double-backward error.
The cause is a gradient checkpointing refactor in the transformers code base after 4.51: WhisperDecoder now calls decoder_layer with some arguments passed as keywords. This PR adds a compiler patch that accounts for this module, and it lets us maintain a list of modules that use the modern gradient checkpointing style and need the same fix.
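To make the idea of stripping keywords from a layer call concrete, here is a rough sketch of one way such a source rewrite could look. It is not the PR's implementation (the diff fragments in the review suggest a regex-based approach); the function name `strip_keywords_from_calls` is hypothetical, it uses `ast` instead, and it assumes the keywords already appear in the same order as the callee's positional parameters.

```python
import ast


def strip_keywords_from_calls(source: str, callee: str) -> str:
    """Rewrite `callee(a, k=v)` as `callee(a, v)` in the given source.

    Hypothetical helper for illustration. It assumes keyword order matches
    the callee's positional parameter order. Requires Python 3.9+ for
    ast.unparse.
    """

    class _Rewriter(ast.NodeTransformer):
        def visit_Call(self, node):
            self.generic_visit(node)
            is_target = isinstance(node.func, ast.Name) and node.func.id == callee
            # Leave calls with **kwargs alone (kw.arg is None): order unknown.
            if is_target and all(kw.arg is not None for kw in node.keywords):
                node.args = node.args + [kw.value for kw in node.keywords]
                node.keywords = []
            return node

    tree = _Rewriter().visit(ast.parse(source))
    return ast.unparse(ast.fix_missing_locations(tree))


before = "out = decoder_layer(h, attention_mask=mask, use_cache=True)"
print(strip_keywords_from_calls(before, "decoder_layer"))
```

Calls to other functions are left unchanged, so a rewrite like this can be restricted to the specific layer callers that need it.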
Whisper Notebook that works with fix branch:
https://colab.research.google.com/drive/1GYlOwXo_4zjA5gT0jzvCNDl9NKEuzEfT?usp=sharing