Conversation

@mmathew23
Collaborator

The Llama MLP kernels produce NaNs with extremely long context lengths. This happens when num_elements is greater than 2**31, in which case offsets need to be calculated with tl.int64 instead of int32. This PR routes to int64 kernels when num_elements is large enough.
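
A minimal sketch of how an int64 variant can compute offsets without overflowing 32-bit arithmetic (the kernel name and the SwiGLU body below are assumptions based on the surrounding discussion, not the PR's actual code):

import triton
import triton.language as tl

@triton.jit
def _swiglu_fwd_kernel_int64(e_ptr, g_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Illustrative kernel; name and exact element-wise op are assumed, not taken from the PR.
    # Cast the program id to int64 so block_start and the offsets are computed
    # in 64-bit arithmetic even when n_elements exceeds 2**31.
    pid = tl.program_id(0).to(tl.int64)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    e = tl.load(e_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
    g = tl.load(g_ptr + offsets, mask=mask, other=0.0).to(tl.float32)

    # SwiGLU forward: silu(e) * g
    tl.store(out_ptr + offsets, e * tl.sigmoid(e) * g, mask=mask)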

device = gate.device
out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = device)
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
if n_elements <= (2**31) - 1024:
Contributor

Why -1024? Is it maybe hd?

Collaborator Author

Yes, I forgot to account for hd. The idea is that I wanted to add a buffer just to be safe.

Collaborator Author

Wait, actually it is 1024, i.e. the BLOCK_SIZE.

batch_seq_len, hd = e.shape
n_elements = e.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
if n_elements <= (2**31) - 1024:
Contributor

Maybe move (2**31) to a global var
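
One way that suggestion could look (the constant names below are hypothetical, not from the PR):

MAX_INT32 = 2**31                     # hypothetical name: offsets must stay below this to fit in int32
INT32_SAFE_LIMIT = MAX_INT32 - 1024   # hypothetical name: keep a one-BLOCK_SIZE buffer below the limit

def use_int64_kernel(n_elements: int) -> bool:
    # Route to the int64 kernel once n_elements gets too close to the int32 limit.
    return n_elements > INT32_SAFE_LIMIT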

e,
g,
n_elements,
BLOCK_SIZE: tl.constexpr,
Contributor

There is actually a way to use just one kernel and dispatch inside it, but for now this is fine - we can refactor later.
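
One possible shape of that single-kernel refactor, using a constexpr flag to choose the offset width at compile time (a sketch of the idea only; names and kernel body are illustrative, not code from this PR):

import triton
import triton.language as tl

@triton.jit
def _swiglu_fwd_kernel(e_ptr, g_ptr, out_ptr, n_elements,
                       BLOCK_SIZE: tl.constexpr, USE_INT64: tl.constexpr):
    # Illustrative single-kernel dispatch; names and element-wise op are assumed.
    pid = tl.program_id(0)
    if USE_INT64:
        # Resolved at compile time: Triton specializes the kernel per constexpr
        # value, so the int32 path pays nothing for this branch.
        pid = pid.to(tl.int64)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    e = tl.load(e_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
    g = tl.load(g_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
    tl.store(out_ptr + offsets, e * tl.sigmoid(e) * g, mask=mask)

The launcher would then pass something like USE_INT64 = n_elements > (2**31) - BLOCK_SIZE at launch time.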

mmathew23 force-pushed the tiled/contextlen branch 2 times, most recently from c008eca to 262ada3 on November 19, 2025 17:24
mmathew23 marked this pull request as ready for review on November 19, 2025 22:16
@mmathew23
Collaborator Author

> Why -1024? Is it maybe hd?

So the idea is that offsets cannot exceed 2**31 - 1, which means n_elements <= 2**31. I want to add a buffer before that point, and since we process in BLOCK_SIZE blocks rather than hidden_dim blocks, subtracting BLOCK_SIZE seemed better than subtracting hd. Plus we get the added benefit that the behavior stays consistent across models.
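
A minimal numeric check of that bound (illustrative only, assuming BLOCK_SIZE = 1024):

BLOCK_SIZE = 1024
n_elements = 2**31 - BLOCK_SIZE        # the largest value still routed to the int32 kernel

# The last block starts at the highest multiple of BLOCK_SIZE below n_elements,
# and its largest offset adds BLOCK_SIZE - 1 on top of that start.
last_block_start = (n_elements - 1) // BLOCK_SIZE * BLOCK_SIZE
largest_offset = last_block_start + BLOCK_SIZE - 1

assert largest_offset <= 2**31 - 1     # still representable with int32 offsets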

I've updated the PR to reflect your comments and finalized it. Let me know if there's anything else to address.

danielhanchen merged commit ac82560 into unslothai:main on Nov 20, 2025
1 check passed