
Conversation

@mmathew23
Collaborator

This feature adds the option to enable tiled MLP from the Arctic Long Sequence Training paper: https://arxiv.org/abs/2506.13996.

@gemini-code-assist
Contributor

Summary of Changes

Hello @mmathew23, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a significant optimization for handling long sequences in neural networks by implementing a "tiled MLP" mechanism. Inspired by recent research, this feature allows Multi-Layer Perceptrons to process inputs in smaller, memory-manageable chunks, thereby reducing peak GPU memory usage. This enables training with larger sequence lengths or batch sizes that would otherwise be constrained by memory limitations, ultimately improving the scalability and efficiency of model training.

Highlights

  • Tiled MLP Core Logic: Implemented `TiledMLP` as a `torch.autograd.Function` to manage the splitting and reassembly of MLP inputs and outputs, enabling memory-efficient processing for long sequences (a minimal sketch of the idea follows this list).
  • Dynamic Tiling Strategies: Introduced two modes for tiling: an `arctic` mode based on sequence length and hidden dimension, and a `target_gb` mode that dynamically adjusts shard sizes to fit within a specified GPU memory limit.
  • Automated MLP Patching: Added utility functions `patch_mlp` and `patch_tiled_mlp` to automatically identify and replace the `forward` method of MLP modules within a given model with the new tiled implementation.
  • RNG State Management: Incorporated logic to correctly preserve and restore random number generator states during the tiled forward and backward passes, ensuring consistent behavior, especially when dropout layers are present.
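
The sketch below illustrates the tiling idea only; it is not the PR's implementation. It uses `torch.utils.checkpoint` as a stand-in for the custom `torch.autograd.Function`, so each sequence chunk's MLP intermediates are freed after that chunk's forward and recomputed during backward; `checkpoint`'s default `preserve_rng_state=True` also covers the dropout/RNG concern mentioned above. The function name and the `num_shards` default are illustrative.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

def tiled_mlp_forward(mlp: nn.Module, x: torch.Tensor, num_shards: int = 4) -> torch.Tensor:
    # x has shape [batch, seq, hidden]; split the sequence dimension into
    # up to num_shards chunks and checkpoint each one, so only a single
    # chunk's intermediate activations are live at any time.
    chunks = x.chunk(max(1, num_shards), dim=1)
    outs = [checkpoint(mlp, chunk, use_reentrant=False) for chunk in chunks]
    return torch.cat(outs, dim=1)

# Quick smoke test with a simple two-layer MLP (illustrative module, not from the PR).
if __name__ == "__main__":
    mlp = nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64))
    x = torch.randn(2, 128, 64, requires_grad=True)
    y = tiled_mlp_forward(mlp, x, num_shards=4)
    y.sum().backward()  # gradients flow to both x and the MLP weights
```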

Contributor

@gemini-code-assist (bot) left a comment


Code Review

This pull request introduces a Tiled MLP implementation, a technique to reduce memory usage for long sequence training, based on the Arctic paper. The implementation involves a custom torch.autograd.Function for the tiled forward and backward passes, and monkey-patching utilities to apply this to existing models.

My review identified a critical bug in the logic that splits tensors into chunks, which could lead to crashes or incorrect behavior. I've also pointed out a couple of places where error handling can be made more specific by avoiding bare except clauses. Addressing these points will improve the robustness and correctness of this new feature.
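
As a rough illustration of the monkey-patching shape described above (the class-name matching rule and the checkpoint-per-chunk forward are assumptions for this sketch, not the PR's actual `patch_mlp`/`patch_tiled_mlp` logic):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

def patch_mlp_sketch(model: nn.Module, num_shards: int = 4) -> None:
    # Walk the model and, for every module whose class name ends in "MLP",
    # replace its forward with a version that checkpoints each sequence chunk.
    for module in model.modules():
        if type(module).__name__.endswith("MLP"):
            orig_forward = module.forward  # keep the original bound method

            def tiled_forward(hidden_states, _orig=orig_forward, _n=num_shards):
                chunks = hidden_states.chunk(max(1, _n), dim=1)
                outs = [checkpoint(_orig, c, use_reentrant=False) for c in chunks]
                return torch.cat(outs, dim=1)

            module.forward = tiled_forward  # instance-level monkey patch
```

Binding `_orig` and `_n` as default arguments avoids the usual late-binding closure pitfall when patching many modules in one loop.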

Comment on lines +93 to +95
qlen_chunk_size, remainder = divmod(B*S, min(max(1, num_shards), B*S))
split_sizes = [qlen_chunk_size]*num_shards
if remainder != 0: split_sizes.append(remainder)
Contributor


Severity: critical

The current logic for calculating split_sizes has several critical issues:

  1. It can cause a ZeroDivisionError if the input tensor x is empty (B*S == 0).
  2. If num_shards is greater than B*S, it creates a split_sizes list whose sum is not equal to B*S, which will cause torch.split to fail.
  3. In normal cases, it can create num_shards + 1 splits instead of the requested num_shards when there is a remainder.

The suggested change fixes these issues by correctly handling the edge cases and properly distributing the elements into the effective number of splits.

Suggested change
- qlen_chunk_size, remainder = divmod(B*S, min(max(1, num_shards), B*S))
- split_sizes = [qlen_chunk_size]*num_shards
- if remainder != 0: split_sizes.append(remainder)
+ effective_num_shards = min(max(1, num_shards), B*S) if B*S > 0 else 0
+ q, r = divmod(B*S, effective_num_shards) if effective_num_shards > 0 else (0, 0)
+ split_sizes = [q + 1] * r + [q] * (effective_num_shards - r)
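
A quick, self-contained check of the suggested formula (the helper name is just for illustration): the sizes always sum to `B*S`, never exceed `num_shards` entries, and the empty-input case yields no splits.

```python
def suggested_split_sizes(total: int, num_shards: int) -> list:
    # Mirrors the suggestion above: at most num_shards splits, sizes differ
    # by at most one, and they always sum to `total`.
    effective = min(max(1, num_shards), total) if total > 0 else 0
    q, r = divmod(total, effective) if effective > 0 else (0, 0)
    return [q + 1] * r + [q] * (effective - r)

assert suggested_split_sizes(10, 4) == [3, 3, 2, 2]   # remainder spread across shards
assert suggested_split_sizes(2, 8) == [1, 1]          # num_shards > B*S
assert suggested_split_sizes(0, 4) == []              # empty input, no ZeroDivisionError
```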
Collaborator Author


Addressed in a subsequent commit. num_shards is a request for that many chunks, but if the total isn't evenly divisible it will result in num_shards + 1 chunks.
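
For clarity, here is a small sketch of the behavior described in this reply (illustrative only, not the actual follow-up commit): equal-sized chunks for the requested shards plus one extra chunk for any remainder, with the empty and over-sharded cases guarded.

```python
def requested_shard_split_sizes(total: int, num_shards: int) -> list:
    # num_shards is treated as a request: an uneven total yields
    # num_shards + 1 splits, but they still sum to `total` so torch.split works.
    if total == 0:
        return []
    shards = min(max(1, num_shards), total)
    chunk, remainder = divmod(total, shards)
    sizes = [chunk] * shards
    if remainder:
        sizes.append(remainder)
    return sizes

assert requested_shard_split_sizes(10, 4) == [2, 2, 2, 2, 2]  # 4 requested -> 5 chunks
assert sum(requested_shard_split_sizes(10, 4)) == 10
```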

flat_qlen = bsz*qlen
try:
intermediate_size = mlp_module.config.intermediate_size
except:
Contributor


Severity: medium

Using a bare except: can mask other errors. It's better to catch the specific exception you expect, which seems to be AttributeError in this case (if mlp_module doesn't have a config attribute or config doesn't have intermediate_size).

Suggested change
- except:
+ except AttributeError:
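
An alternative worth noting (purely illustrative; the helper name is not from the PR) is to avoid the try/except entirely with a `getattr` chain and let the caller fall back when the attribute is missing:

```python
import torch.nn as nn

def get_intermediate_size(mlp_module: nn.Module):
    # Returns None when the module has no `config` or the config has no
    # `intermediate_size`, leaving the caller to compute its fallback estimate.
    return getattr(getattr(mlp_module, "config", None), "intermediate_size", None)
```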
if len(patch_options_strs) > 1:
try:
target_gb = float(patch_options_strs[-1])
except:
Contributor


Severity: medium

Using a bare except: can hide unexpected errors. It's good practice to catch specific exceptions. In this case, float() can raise a ValueError, so it's better to catch that explicitly.

Suggested change
- except:
+ except ValueError:
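
The same idea as a tiny self-contained sketch (the function name and default value are illustrative, not from the PR): only a malformed numeric string triggers the fallback; unrelated errors still propagate.

```python
def parse_target_gb(option: str, default: float = 4.0) -> float:
    # float() raises ValueError on a malformed string such as "abc";
    # catching only that keeps genuine bugs visible.
    try:
        return float(option)
    except ValueError:
        return default

assert parse_target_gb("8") == 8.0
assert parse_target_gb("not-a-number") == 4.0
```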
@mmathew23 merged commit c9c7693 into unslothai:main on Nov 12, 2025