@@ -323,8 +323,8 @@ def forward(ctx, X, weight, weight_scale):
     assert block_size is not None, "block_size is not set"
     if triton.cdiv(m, block_size[0]) != p or triton.cdiv(n, block_size[1]) != q:
         if triton.cdiv(m, block_size[0]) == q and triton.cdiv(n, block_size[1]) == p:
-            # weights are tranposed during backward pass for training :)
-            # We tranpose weight scale to counter that. Note that transposing weight would cause issues with matmul with input X
+            # weights are transposed during backward pass for training :)
+            # We transpose weight scale to counter that. Note that transposing weight would cause issues with matmul with input X
             weight_scale = weight_scale.T
         else:
             raise ValueError(f"Weight shape {weight.shape} and scales shape {weight_scale.shape} is not compatible with block size {block_size}")
@@ -437,8 +437,8 @@ def forward(ctx, X, weight, weight_scale, bias=None):
 
     if triton.cdiv(m, bs_n) != p or triton.cdiv(n, bs_k) != q:
         if triton.cdiv(m, bs_n) == q and triton.cdiv(n, bs_k) == p:
-            # weights are tranposed during backward pass for training :)
-            # We tranpose weight scale to counter that. Note that transposing weight would cause issues with matmul with input X
+            # weights are transposed during backward pass for training :)
+            # We transpose weight scale to counter that. Note that transposing weight would cause issues with matmul with input X
             weight_scale = weight_scale.T
         else:
             raise ValueError(f"Weight shape {weight.shape} and scales shape {weight_scale.shape} is not compatible with block size {block_size}")
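For context, here is a minimal standalone sketch of the shape check both hunks touch. `align_weight_scale` is a hypothetical helper, not a function from this repo, and the local `cdiv` stands in for `triton.cdiv` (ceiling division); the `m, n` / `p, q` names follow the surrounding code. The idea: each scale entry covers one `block_size`-sized tile of the weight, so if the weight arrives transposed (as it does in the backward pass), the scale grid must be transposed too.

```python
import torch

def align_weight_scale(weight, weight_scale, block_size):
    # Hypothetical helper mirroring the check above: the block-wise scale grid
    # must have one entry per block of the weight, i.e. shape
    # (cdiv(m, bs0), cdiv(n, bs1)). If the grid matches the transposed weight
    # instead, transpose the scale rather than the weight (transposing the
    # weight would break the matmul with the input X).
    cdiv = lambda a, b: (a + b - 1) // b  # stand-in for triton.cdiv
    m, n = weight.shape
    p, q = weight_scale.shape
    if cdiv(m, block_size[0]) == p and cdiv(n, block_size[1]) == q:
        return weight_scale  # scale grid already lines up with the weight
    if cdiv(m, block_size[0]) == q and cdiv(n, block_size[1]) == p:
        # Weight was transposed for the backward pass; mirror that in the
        # scale so each scale entry still covers the same weight block.
        return weight_scale.T
    raise ValueError(
        f"Weight shape {weight.shape} and scales shape {weight_scale.shape} "
        f"is not compatible with block size {block_size}"
    )

w = torch.randn(256, 512)         # weight: (m, n)
s = torch.ones(4, 2)              # scale laid out for w.T: (cdiv(512,128), cdiv(256,128))
s = align_weight_scale(w, s, (128, 128))
print(s.shape)                    # torch.Size([2, 4]) -- back in w's block layout
```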