File tree Expand file tree Collapse file tree 1 file changed +14
-4
lines changed Expand file tree Collapse file tree 1 file changed +14
-4
lines changed Original file line number Diff line number Diff line change 88
99module_path = os .path .dirname (__file__ )
1010fused = load (
11- ' fused' ,
11+ " fused" ,
1212 sources = [
13- os .path .join (module_path , ' fused_bias_act.cpp' ),
14- os .path .join (module_path , ' fused_bias_act_kernel.cu' ),
13+ os .path .join (module_path , " fused_bias_act.cpp" ),
14+ os .path .join (module_path , " fused_bias_act_kernel.cu" ),
1515 ],
1616)
1717
@@ -83,4 +83,14 @@ def forward(self, input):
8383
8484
8585def fused_leaky_relu (input , bias , negative_slope = 0.2 , scale = 2 ** 0.5 ):
86- return FusedLeakyReLUFunction .apply (input , bias , negative_slope , scale )
86+ if input .device .type == "cpu" :
87+ rest_dim = [1 ] * (input .ndim - bias .ndim - 1 )
88+ return (
89+ F .leaky_relu (
90+ input + bias .view (1 , bias .shape [0 ], * rest_dim ), negative_slope = 0.2
91+ )
92+ * scale
93+ )
94+
95+ else :
96+ return FusedLeakyReLUFunction .apply (input , bias , negative_slope , scale )
You can’t perform that action at this time.
0 commit comments