Open
Labels
PFCC: Paddle Framework Contributor Club, https://github.com/PaddlePaddle/community/tree/master/pfcc
Description
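The current implementation of masked_scatter is too crude. Writing it as a composition of existing operators was fine as an early implementation, but in production, or when handling Tensors with large shapes, it causes very severe performance bottlenecks:
- Very high GPU memory overhead: execution creates many throwaway intermediate Tensors (such as zeros_like_x, mask_prefix, and the advanced-indexing result produced by value.flatten()[mask_prefix]), which can easily push peak memory past capacity (OOM).
- Low execution efficiency: it chains a series of small operator calls (add, cast, cumsum, clip, flatten, where, and so on), which incurs substantial kernel launch overhead and leaves the computation memory-bandwidth bound, so it cannot exploit the GPU's parallelism.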
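As a rough back-of-envelope illustration of the memory point, the sketch below adds up the temporaries that the composite implementation materializes for an input of `numel` elements. It is only an estimate under stated assumptions: the helper `intermediate_bytes` is hypothetical, integer temporaries are assumed to be 8 bytes each, the mask is assumed to already have the shape of x, and the figure is the sum of all temporaries rather than a measured peak (the allocator may free some of them early).

```python
def intermediate_bytes(numel, value_itemsize, index_itemsize=8):
    """Sum the sizes of temporaries created by the composite masked_scatter.

    Hypothetical helper for illustration; assumes 8-byte integer temporaries
    and a mask that already has the same shape as x.
    """
    sizes = {
        "zeros_like_x": numel * index_itemsize,     # paddle.zeros_like(x, dtype=int)
        "mask_as_int": numel * index_itemsize,      # cast(mask) + zeros_like_x
        "cumsum": numel * index_itemsize,           # mask.cumsum()
        "cumsum_minus_1": numel * index_itemsize,   # ... - 1
        "mask_prefix": numel * index_itemsize,      # clip(..., min=0)
        "gathered_value": numel * value_itemsize,   # value.flatten()[mask_prefix]
        "mask_as_bool": numel * 1,                  # mask.astype(bool)
        "inverted_mask": numel * 1,                 # logical_not(...)
    }
    return sizes


# Example: a float32 tensor with 2**28 elements (1 GiB of payload).
sizes = intermediate_bytes(numel=2**28, value_itemsize=4)
total_gib = sum(sizes.values()) / 2**30
print(f"temporaries: {total_gib:.1f} GiB for a 1.0 GiB input")
```

Under these assumptions the temporaries add up to 46 bytes per element, against 4 bytes per element for the float32 input itself. A fused kernel would need none of them beyond the output.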
Paddle/python/paddle/tensor/manipulation.py
Lines 5590 to 5605 in 74cd273
    assert x.dtype == value.dtype, (
        f'x and value must have the same dtype, but got x dtype is {x.dtype}, value dtype is {value.dtype}'
    )
    assert mask.dtype == paddle.bool
    zeros_like_x = paddle.zeros_like(x, dtype=int)
    mask = paddle.add(paddle.cast(mask, dtype="int"), zeros_like_x)
    mask_prefix = paddle.clip(mask.cumsum() - 1, min=0)
    if in_dynamic_mode() and mask_prefix.numel() != 0:
        assert mask_prefix[-1] <= value.numel(), (
            f'mask true nums must be <= value size, but got mask true nums is {mask_prefix[-1].item()}, value size is {value.numel().item()}'
        )
    value = value.flatten()[mask_prefix].reshape(mask.shape)
    mask = paddle.logical_not(mask.astype(bool))
    return paddle.where(mask, x, value)
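To make the semantics of the snippet concrete, here is a minimal NumPy sketch (not Paddle code, and not a proposed kernel). `masked_scatter_composite` mirrors the quoted steps one by one, while `masked_scatter_direct` writes the first k entries of value into the k selected positions in a single pass, which is roughly what a fused kernel would compute. Both function names are hypothetical and exist only for this illustration.

```python
import numpy as np


def masked_scatter_composite(x, mask, value):
    """NumPy mirror of the quoted composite steps (illustration only)."""
    # Broadcast the mask to x's shape by adding an all-zero integer array.
    mask_int = mask.astype(np.int64) + np.zeros_like(x, dtype=np.int64)
    # Running count of True entries, shifted to a 0-based index into value.
    mask_prefix = np.clip(np.cumsum(mask_int) - 1, 0, None)
    # Gather one candidate value per element of x, whether selected or not.
    gathered = value.reshape(-1)[mask_prefix].reshape(mask_int.shape)
    # Keep x where the mask is False, take the gathered value where it is True.
    return np.where(~mask_int.astype(bool), x, gathered)


def masked_scatter_direct(x, mask, value):
    """Single-pass reference: copy x, then overwrite the selected positions."""
    full_mask = np.broadcast_to(mask, x.shape)
    n_true = int(full_mask.sum())
    if n_true > value.size:
        raise ValueError("number of True entries in mask exceeds value size")
    out = x.copy()
    # Boolean assignment fills the True positions in row-major order.
    out[full_mask] = value.reshape(-1)[:n_true]
    return out


x = np.zeros((2, 3), dtype=np.float32)
mask = np.array([[True, False, True], [False, True, False]])
value = np.arange(1, 5, dtype=np.float32)  # [1, 2, 3, 4]

result = masked_scatter_direct(x, mask, value)
assert np.array_equal(result, masked_scatter_composite(x, mask, value))
print(result)
```

The two functions agree on this input, but the composite version builds several arrays the size of x (including a gathered copy of value for every element, selected or not), whereas the direct version only reads value for the selected positions.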
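In the forward pass, this introduces the following small operators (note the index_elementwise_get kernel):
In the backward pass, the backward operator index_elementwise_get_grad even introduces a host-to-device copy (Memcpy H2D) together with a large idle bubble.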
