Add an int64 path for mlp kernels #3614
Conversation
unsloth/kernels/geglu.py
Outdated
device = gate.device
out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = device)
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
if n_elements <= (2**31) - 1024:
Why -1024? Is it maybe hd?
Yes, I forgot to account for hd. The idea is that I wanted to add a buffer just to be safe.
Wait, actually it is 1024, i.e. the BLOCK_SIZE.
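To make the intent of the subtraction explicit, here is a minimal sketch (not the PR's code) that ties the buffer to the block size instead of a hard-coded 1024:

```python
BLOCK_SIZE = 1024  # assumed cap on the Triton block size in these kernels

def fits_int32(n_elements: int, block_size: int = BLOCK_SIZE) -> bool:
    # Offsets in the last block can run up to block_size - 1 past n_elements
    # before the mask is applied, so keep that much headroom below 2**31.
    return n_elements <= (2**31) - block_size
```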
unsloth/kernels/geglu.py
Outdated
batch_seq_len, hd = e.shape
n_elements = e.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
if n_elements <= (2**31) - 1024:
Maybe move (2**31) to a global var
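A sketch of that suggestion, with a hypothetical constant name and location (the PR may pick different ones):

```python
MAX_INT32_ELEMENTS = 2**31  # largest element count before int32 offsets overflow
BLOCK_SIZE = 1024

def use_int64_path(n_elements: int) -> bool:
    # Same check as in geglu.py / swiglu.py, but with the literal hoisted
    # into a shared named constant.
    return n_elements > MAX_INT32_ELEMENTS - BLOCK_SIZE
```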
unsloth/kernels/swiglu.py
Outdated
e,
g,
n_elements,
BLOCK_SIZE: tl.constexpr,
There is actually a way to use one kernel only and dispatch, but for now this is fine; we can refactor later.
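For reference, a hedged sketch of that single-kernel idea: one Triton kernel with a constexpr flag, so Triton compiles int32 and int64 specializations from the same source. The kernel name and signature are illustrative, not the PR's code.

```python
import triton
import triton.language as tl

@triton.jit
def _fg_kernel_sketch(e, g, h, n_elements, BLOCK_SIZE: tl.constexpr,
                      USE_INT64: tl.constexpr):
    pid = tl.program_id(0)
    # Resolved at compile time: the int64 cast only exists in the int64 variant.
    if USE_INT64:
        block_start = pid.to(tl.int64) * BLOCK_SIZE
    else:
        block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    e_row = tl.load(e + offsets, mask=mask, other=0).to(tl.float32)
    g_row = tl.load(g + offsets, mask=mask, other=0)
    # SwiGLU: f = e * sigmoid(e); h = f * g
    f_row = e_row * tl.sigmoid(e_row)
    tl.store(h + offsets, f_row.to(g_row.dtype) * g_row, mask=mask)
```

The host would then launch the same kernel with `USE_INT64 = n_elements > (2**31) - BLOCK_SIZE`, and Triton would specialize each variant at compile time.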
c008eca to 262ada3 (compare)
262ada3 to 833d91f (compare)
So the idea is that offsets cannot be more than 2**31.
I've updated the PR to reflect your comments and finalized it. Let me know if there's anything else to address.
The Llama MLP kernels produce NaNs with extremely long context lengths. This happens when n_elements is greater than 2**31. In these cases it's best to calculate offsets with tl.int64 instead of int32. This PR routes to the int64 kernels when n_elements is big enough.
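For a sense of scale, here is a rough back-of-the-envelope check of when the int64 path would kick in, assuming an MLP intermediate size of 14336 (as in Llama-3-8B); the exact model dimensions are an assumption, not stated in the PR.

```python
INTERMEDIATE_SIZE = 14336   # assumed Llama-3-8B MLP hidden dim
INT32_LIMIT = 2**31

tokens = 200_000            # batch * seq_len
n_elements = tokens * INTERMEDIATE_SIZE
print(n_elements > INT32_LIMIT)           # True -> int32 offsets would overflow
print(INT32_LIMIT // INTERMEDIATE_SIZE)   # ~149,796 tokens is the crossover point
```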