Tiled MLP Implementation #350
Conversation
This feature adds the option to enable tiled MLP from the Arctic long-sequence-training paper: https://arxiv.org/abs/2506.13996.
Summary of Changes

Hello @mmathew23, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request introduces a significant optimization for handling long sequences in neural networks by implementing a "tiled MLP" mechanism. Inspired by recent research, this feature allows multi-layer perceptrons to process inputs in smaller, memory-manageable chunks, thereby reducing peak GPU memory usage. This enables training with larger sequence lengths or batch sizes that would otherwise be constrained by memory limitations, ultimately improving the scalability and efficiency of model training.
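For orientation, here is a minimal sketch of the tiling idea, assuming a standard position-wise MLP. The PR itself implements this with a custom `torch.autograd.Function` rather than `torch.utils.checkpoint`, so treat this only as an illustration of the memory/compute trade-off, not as the PR's code:

```python
import torch
from torch.utils.checkpoint import checkpoint

def tiled_mlp_forward(mlp, hidden_states, num_shards=4):
    # An MLP acts position-wise, so the (batch, seq) dims can be flattened
    # and processed in tiles. Checkpointing each tile means only one tile's
    # intermediate activations are live at a time, reducing peak memory at
    # the cost of recomputing the MLP during the backward pass.
    B, S, H = hidden_states.shape
    flat = hidden_states.reshape(B * S, H)
    outputs = [
        checkpoint(mlp, chunk, use_reentrant=False)
        for chunk in flat.chunk(num_shards, dim=0)
    ]
    return torch.cat(outputs, dim=0).reshape(B, S, H)
```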
Code Review
This pull request introduces a Tiled MLP implementation, a technique to reduce memory usage for long sequence training, based on the Arctic paper. The implementation involves a custom torch.autograd.Function for the tiled forward and backward passes, and monkey-patching utilities to apply this to existing models.
My review identified a critical bug in the logic that splits tensors into chunks, which could lead to crashes or incorrect behavior. I've also pointed out a couple of places where error handling can be made more specific by avoiding bare except clauses. Addressing these points will improve the robustness and correctness of this new feature.
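Since the review mentions monkey-patching utilities, here is a hedged sketch of how such patching could look. The name `tiled_mlp_forward` refers to the illustration above, and the `model.model.layers[*].mlp` path assumes a Llama-style architecture; the PR's actual utilities may be more general:

```python
import functools

def patch_model_with_tiled_mlp(model, num_shards=4):
    # Wrap each decoder layer's MLP forward with the tiled version.
    for layer in model.model.layers:
        original_forward = layer.mlp.forward
        layer.mlp.forward = functools.partial(
            tiled_mlp_forward, original_forward, num_shards=num_shards
        )
```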
```python
qlen_chunk_size, remainder = divmod(B*S, min(max(1, num_shards), B*S))
split_sizes = [qlen_chunk_size]*num_shards
if remainder != 0: split_sizes.append(remainder)
```
The current logic for calculating `split_sizes` has several critical issues:
- It can cause a `ZeroDivisionError` if the input tensor `x` is empty (`B*S == 0`).
- If `num_shards` is greater than `B*S`, it creates a `split_sizes` list whose sum is not equal to `B*S`, which will cause `torch.split` to fail.
- In normal cases, it can create `num_shards + 1` splits instead of the requested `num_shards` when there is a remainder.

The suggested change fixes these issues by correctly handling the edge cases and properly distributing the elements into the effective number of splits.
```diff
- qlen_chunk_size, remainder = divmod(B*S, min(max(1, num_shards), B*S))
- split_sizes = [qlen_chunk_size]*num_shards
- if remainder != 0: split_sizes.append(remainder)
+ effective_num_shards = min(max(1, num_shards), B*S) if B*S > 0 else 0
+ q, r = divmod(B*S, effective_num_shards) if effective_num_shards > 0 else (0, 0)
+ split_sizes = [q + 1] * r + [q] * (effective_num_shards - r)
```
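As a quick sanity check of the suggested logic, with illustrative values not taken from the PR:

```python
# Illustrative values: B*S = 10, num_shards = 4
BS, num_shards = 10, 4
effective_num_shards = min(max(1, num_shards), BS) if BS > 0 else 0
q, r = divmod(BS, effective_num_shards) if effective_num_shards > 0 else (0, 0)
split_sizes = [q + 1] * r + [q] * (effective_num_shards - r)
assert split_sizes == [3, 3, 2, 2] and sum(split_sizes) == BS
```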
Addressed in a subsequent commit. `num_shards` is a request for that many chunks, but if `B*S` is not evenly divisible by it, the result is `num_shards + 1` chunks.
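For concreteness (illustrative values, not from the PR), the original logic behaves as described:

```python
# B*S = 10, num_shards = 4 with the original logic:
qlen_chunk_size, remainder = divmod(10, min(max(1, 4), 10))  # -> (2, 2)
split_sizes = [qlen_chunk_size] * 4                          # [2, 2, 2, 2]
if remainder != 0: split_sizes.append(remainder)             # [2, 2, 2, 2, 2]
assert len(split_sizes) == 5  # num_shards + 1 chunks
```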
```python
flat_qlen = bsz*qlen
try:
    intermediate_size = mlp_module.config.intermediate_size
except:
```
Per the review summary above, this bare `except:` clause should be narrowed to the specific exception expected here, such as `AttributeError` for a missing config attribute.
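A hedged sketch of the narrower handling; the fallback shown is hypothetical, since the original `except` body is not visible in this excerpt:

```python
try:
    intermediate_size = mlp_module.config.intermediate_size
except AttributeError:
    # Hypothetical fallback: infer the size from a projection layer.
    intermediate_size = mlp_module.gate_proj.out_features
```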
```python
if len(patch_options_strs) > 1:
    try:
        target_gb = float(patch_options_strs[-1])
    except:
```
Likewise, this bare `except:` around the `float(...)` conversion should catch `ValueError` explicitly.
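A minimal sketch of the narrower variant; the fallback value is an assumption, since the original `except` body is not shown in this excerpt:

```python
if len(patch_options_strs) > 1:
    try:
        target_gb = float(patch_options_strs[-1])
    except ValueError:
        target_gb = None  # assumed fallback; the PR's actual default is not shown
```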