
Commit be05887

Added native implementation for FusedLeakyReLU
1 parent 488d960 commit be05887

File tree

1 file changed: +14 -4 lines changed


op/fused_act.py

Lines changed: 14 additions & 4 deletions
@@ -8,10 +8,10 @@

 module_path = os.path.dirname(__file__)
 fused = load(
-    'fused',
+    "fused",
     sources=[
-        os.path.join(module_path, 'fused_bias_act.cpp'),
-        os.path.join(module_path, 'fused_bias_act_kernel.cu'),
+        os.path.join(module_path, "fused_bias_act.cpp"),
+        os.path.join(module_path, "fused_bias_act_kernel.cu"),
     ],
 )

@@ -83,4 +83,14 @@ def forward(self, input):


 def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
-    return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
+    if input.device.type == "cpu":
+        rest_dim = [1] * (input.ndim - bias.ndim - 1)
+        return (
+            F.leaky_relu(
+                input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2
+            )
+            * scale
+        )
+
+    else:
+        return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
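
For reference, the new CPU branch is a broadcast bias add followed by a standard leaky ReLU and a sqrt(2) gain. Below is a minimal, self-contained sketch of that computation; the standalone function name and the test tensors are illustrative and not part of the commit, and unlike the committed branch, which hard-codes negative_slope=0.2, the sketch forwards the argument:

import torch
import torch.nn.functional as F

def fused_leaky_relu_cpu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
    # Reshape bias to (1, C, 1, ..., 1) so it broadcasts over batch and spatial dims.
    rest_dim = [1] * (input.ndim - bias.ndim - 1)
    out = input + bias.view(1, bias.shape[0], *rest_dim)
    # Plain leaky ReLU followed by the same gain the fused CUDA kernel applies.
    return F.leaky_relu(out, negative_slope=negative_slope) * scale

# Illustrative usage: a (batch, channels, height, width) activation with a per-channel bias.
x = torch.randn(2, 8, 4, 4)
b = torch.randn(8)
y = fused_leaky_relu_cpu(x, b)
print(y.shape)  # torch.Size([2, 8, 4, 4])

This fallback lets the function run on CPU tensors without the compiled extension's kernel; GPU tensors still take the FusedLeakyReLUFunction.apply path in the else branch.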
