@mmathew23
Collaborator

The Whisper model is failing with a double backward error.

After transformers 4.51 there was a gradient checkpointing refactor in the transformers code base, and WhisperDecoder now calls decoder_layer with some arguments passed as keywords. This PR adds a patch to the compiler to account for this module, and lets us maintain a list of modules that use the modern gradient checkpointing style and need the same fix.

Whisper Notebook that works with fix branch:
https://colab.research.google.com/drive/1GYlOwXo_4zjA5gT0jzvCNDl9NKEuzEfT?usp=sharing

@gemini-code-assist
Contributor

Summary of Changes

Hello @mmathew23, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request resolves a 'double backward' error affecting the Whisper model when used with 'transformers' versions after 4.51. The issue stems from a refactor of the gradient checkpointing mechanism in 'transformers', after which certain modules began passing arguments as keywords to their internal layers. The changes introduce a targeted compiler patch that strips these keyword arguments during compilation, so the Whisper model can use gradient checkpointing without runtime failures.

Highlights

  • Fixes Whisper Model Double Backward: Addresses a 'double backward' failure in the Whisper model caused by a 'transformers' gradient checkpointing refactor post-4.51, which led to incorrect argument passing.
  • Introduces Keyword Argument Stripping: Adds a new utility function, 'strip_kw_from_module_calls', which uses regular expressions to identify and remove keyword arguments from module calls within source code.
  • New Gradient Checkpointing Patching Logic: Implements 'patch_gradient_checkpointing_layer_caller' to specifically handle modules that pass arguments as keywords to their decoder layers, ensuring compatibility with modern gradient checkpointing styles.
  • Targets Specific Modules: Introduces 'FIX_GC_LAYER_CALLER_MODULES' list, initially containing 'WhisperDecoder', to apply the new keyword argument stripping patch only where necessary.
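The keyword-stripping idea in the highlights above can be illustrated with a minimal sketch. This is not the PR's 'strip_kw_from_module_calls'; the function name `strip_kw_sketch`, its signature, and its regexes are hypothetical, and it assumes the argument list contains no nested parentheses and that the keywords already appear in positional order.

```python
import re

def strip_kw_sketch(source: str, callee: str) -> str:
    """Rewrite `callee(a, key=value)` as `callee(a, value)`.

    Hypothetical helper, not the PR's implementation. Assumes no nested
    parentheses in the argument list and that keyword order already
    matches the callee's positional order.
    """
    # Match `callee(...)` where the argument list has no inner parentheses.
    call = re.compile(rf"\b{re.escape(callee)}\(([^()]*)\)")
    # Match a leading `name=` but not a comparison such as `a == b`.
    kw_prefix = re.compile(r"^\s*\w+\s*=(?!=)\s*")

    def rewrite(match: re.Match) -> str:
        # Drop empty pieces so a trailing comma does not add an argument.
        args = [a for a in match.group(1).split(",") if a.strip()]
        positional = [kw_prefix.sub("", a).strip() for a in args]
        return f"{callee}(" + ", ".join(positional) + ")"

    return call.sub(rewrite, source)

before = "out = decoder_layer(hidden_states, attention_mask=mask, use_cache=use_cache)"
print(strip_kw_sketch(before, "decoder_layer"))
```

A source-level rewrite like this only makes sense when the positional order is known to match, which is presumably why the patch is restricted to an explicit list of modules.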

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a fix for gradient checkpointing with Whisper models by stripping keyword arguments from layer calls. This is achieved by adding a new patching function, patch_gradient_checkpointing_layer_caller, and a helper, strip_kw_from_module_calls. The changes are well-contained, and the logic to select the new patching mechanism for specific modules like WhisperDecoder is correctly implemented. My review includes a couple of suggestions to improve the robustness of the new code.

call_pattern = re.compile(
rf"""
(^[ \t]+)
(\w+)\s*=\s*
Contributor

high

The regex pattern to capture the output variable (\w+) is not robust enough to handle tuple unpacking (e.g., (out1, out2) = layer(...)) or chained assignments. This could cause the patching to fail for models that use such patterns.

Using ([^=]+?) instead will non-greedily match any character except = one or more times, making it capable of handling more complex assignment targets.

Suggested change
- (\w+)\s*=\s*
+ ([^=]+?)\s*=\s*
Collaborator Author

That's a good suggestion if we need it in the future, but for the current use case it's not needed.
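The trade-off discussed in this thread can be seen in a small sketch. The patterns below are simplified stand-ins, not the PR's regex: the quoted pattern is truncated in the review, so the `layer\(` tail, the sample lines, and the helper `assignment_target` are all assumptions made for illustration.

```python
import re

# Simplified stand-ins for the two capture styles discussed in the review.
STRICT = re.compile(r"^[ \t]+(\w+)\s*=\s*layer\(", re.MULTILINE)
LOOSE = re.compile(r"^[ \t]+([^=]+?)\s*=\s*layer\(", re.MULTILINE)

def assignment_target(pattern, code):
    """Return the captured assignment target, or None when nothing matches."""
    match = pattern.search(code)
    return match.group(1) if match else None

single = "    hidden = layer(x)"
unpacked = "    hidden, attn = layer(x)"
two_lines = "    foo()\n    out = layer(x)"

for name, code in [("single", single), ("unpacked", unpacked), ("two_lines", two_lines)]:
    print(name, repr(assignment_target(STRICT, code)), repr(assignment_target(LOOSE, code)))
```

The `\w+` form misses the tuple-unpacking line, while `[^=]+?` captures it. Because `[^=]` also matches newlines, the looser form can reach back across a preceding line that contains no `=`, as the `two_lines` case shows, so it is broader rather than strictly safer.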

Comment on lines +1619 to +1620
try: init = inspect.getsource(source.__init__)
except: return None
Contributor

medium

Using a bare except: is generally discouraged as it can catch and silence unexpected errors, making debugging harder. It's better to specify the exceptions you expect to handle.

inspect.getsource() is documented to raise TypeError or OSError. Catching these specific exceptions makes the code more robust and clearer about its intent.

This also applies to the try...except block for getting the forward source on lines 1622-1623.

Suggested change
- try: init = inspect.getsource(source.__init__)
- except: return None
+ try: init = inspect.getsource(source.__init__)
+ except (TypeError, OSError): return None
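As a standalone illustration of the suggestion above, here is a minimal sketch of the narrowed handler. The helper name `get_source_or_none` is hypothetical and is not part of the PR; only `inspect.getsource` and the two exception types come from the review comment.

```python
import inspect
import json

def get_source_or_none(obj):
    """Return the source of `obj`, or None when it cannot be retrieved.

    Only the two exceptions that inspect.getsource is documented to raise
    are handled; anything else propagates so real bugs stay visible.
    """
    try:
        return inspect.getsource(obj)
    except (TypeError, OSError):
        return None

# Built-ins implemented in C have no Python source, so getsource raises TypeError.
print(get_source_or_none(len))

# A pure-Python function usually has retrievable source, if its .py file is available.
source = get_source_or_none(json.dumps)
print(source is not None)
```

Unlike a bare `except:`, this form does not swallow `KeyboardInterrupt` or unrelated errors raised while the lookup runs.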
@shimmyshimmer shimmyshimmer merged commit 87c0a40 into unslothai:main Nov 13, 2025