Skip to content

Conversation

@nil0x9
Copy link
Contributor

@nil0x9 nil0x9 commented Dec 15, 2025

Currently when one passes loss_cfg.loss_reduction other than "token" on a ascend/npu device, a Runtime Error (device mismatch) is expected in this line:

loss = (loss * loss_weight).sum()

The root cause of this error is that, in ascend npu device, cu_seq_lens tensors are required to be on cpu. In func build_batches_loss_kwargs, the devuce ofloss_weight is inherited from num_grad_tokens -> boundaries -> cu_seq_lens -- and hence the problem.

@nil0x9 nil0x9 marked this pull request as ready for review December 15, 2025 17:38
@nil0x9 nil0x9 force-pushed the linty/fix-npu-loss-weight-device-mismatch branch from 8613e2d to f1294d3 Compare December 16, 2025 15:40
@nil0x9 nil0x9 force-pushed the linty/fix-npu-loss-weight-device-mismatch branch from f1294d3 to dacb7ef Compare December 24, 2025 08:03
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

2 participants