
[Performance] Should we optimize the masked_scatter operator with a C++/CUDA fast path? #78052

@DrRyanHuang

Description


The current masked_scatter implementation is quite coarse-grained. Composing it out of primitive operators was fine as an early implementation, but in production, or on tensors with large shapes, it becomes a serious performance bottleneck:

  • Huge memory overhead: execution creates many useless intermediate tensors (e.g. zeros_like_x, mask_prefix, and the advanced-indexing intermediate produced by value.flatten()[mask_prefix]), which can easily push peak GPU memory over the limit (OOM).

  • Poor execution efficiency: the chain of small operator calls (add, cast, cumsum, clip, flatten, where) incurs heavy kernel-launch overhead and leaves the computation memory-bandwidth bound, so the GPU's parallel throughput goes unused.
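For reference, the semantics being implemented are simple: copy the elements of value, read in flattened order, into the positions of x where mask is True. A minimal NumPy sketch (the helper name masked_scatter_ref is hypothetical, and NumPy stands in for the GPU tensors):

```python
import numpy as np

def masked_scatter_ref(x, mask, value):
    # Copy elements of `value`, read in flattened (row-major) order,
    # into the positions of `x` where `mask` is True; all other
    # positions keep the original value of `x`.
    out = x.copy()
    out[mask] = value.ravel()[: int(mask.sum())]
    return out

x = np.zeros((2, 3))
mask = np.array([[True, False, True], [False, True, False]])
value = np.array([1.0, 2.0, 3.0, 4.0])
result = masked_scatter_ref(x, mask, value)
# result: [[1., 0., 2.], [0., 3., 0.]]
```

The composite Paddle implementation below realizes these same semantics, but through a long chain of intermediate tensors.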

```python
assert x.dtype == value.dtype, (
    f'x and value must have the same dtype, but got x dtype is {x.dtype}, value dtype is {value.dtype}'
)
assert mask.dtype == paddle.bool
zeros_like_x = paddle.zeros_like(x, dtype=int)
mask = paddle.add(paddle.cast(mask, dtype="int"), zeros_like_x)
mask_prefix = paddle.clip(mask.cumsum() - 1, min=0)
if in_dynamic_mode() and mask_prefix.numel() != 0:
    assert mask_prefix[-1] <= value.numel(), (
        f'mask true nums must be <= value size, but got mask true nums is {mask_prefix[-1].item()}, value size is {value.numel().item()}'
    )
value = value.flatten()[mask_prefix].reshape(mask.shape)
mask = paddle.logical_not(mask.astype(bool))
return paddle.where(mask, x, value)
```
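By contrast, a fused fast path would only need a single prefix-sum scan over the mask plus one elementwise pass, with no zeros_like / clip / gather-reshape intermediates. A minimal Python sketch emulating what a single CUDA kernel could compute per element (the function name and structure are illustrative, not Paddle API):

```python
import numpy as np

def masked_scatter_fused(x, mask, value):
    # Sketch of a fused fast path: one scan to build the exclusive
    # prefix sum of the mask, then one elementwise pass. No other
    # intermediate tensors are materialized.
    xf, mf, vf = x.ravel(), mask.ravel(), value.ravel()
    # Exclusive prefix sum: for each True position, its index into `value`.
    # A real implementation would fuse this scan on device (e.g. cub).
    idx = np.cumsum(mf) - mf
    out = np.empty_like(xf)
    for i in range(xf.size):  # each iteration ~ one CUDA thread
        out[i] = vf[idx[i]] if mf[i] else xf[i]
    return out.reshape(x.shape)

x = np.zeros((2, 3))
mask = np.array([[True, False, True], [False, True, False]])
value = np.array([1.0, 2.0, 3.0, 4.0])
fused = masked_scatter_fused(x, mask, value)
# fused: [[1., 0., 2.], [0., 3., 0.]]
```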

The forward pass introduces this chain of small operators (note the index_elementwise_get kernel):

[profiler timeline screenshot of the forward pass]

And in the backward pass, the backward operator index_elementwise_get_grad even introduces a Memcpy H2D plus a large pipeline bubble:

[profiler timeline screenshot of the backward pass]
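The backward pass is also simple to express as a single device-side pass, with no host-to-device copy: the gradient w.r.t. x is grad_out where the mask is False (zero elsewhere), and the gradient w.r.t. value gathers grad_out at the True positions in order, with zeros for unused tail elements. A hedged NumPy sketch of what a fused backward kernel would compute (names are illustrative):

```python
import numpy as np

def masked_scatter_backward(grad_out, mask, value_numel):
    # grad wrt x: masked positions were overwritten by `value`,
    # so their gradient is zero; the rest pass grad_out through.
    grad_x = np.where(mask, 0.0, grad_out)
    # grad wrt value: the j-th consumed element of `value` receives
    # grad_out at the j-th True position; unused tail gets zero.
    grad_value = np.zeros(value_numel)
    grad_value[: int(mask.sum())] = grad_out[mask]
    return grad_x, grad_value

mask = np.array([[True, False, True], [False, True, False]])
grad_out = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
grad_x, grad_value = masked_scatter_backward(grad_out, mask, value_numel=4)
# grad_x: [[0., 2., 0.], [4., 0., 6.]], grad_value: [1., 3., 5., 0.]
```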

Labels

PFCC (Paddle Framework Contributor Club), https://github.com/PaddlePaddle/community/tree/master/pfcc
