
Commit 7e46e47

Added optimized conv2d grad
1 parent 1376b90 commit 7e46e47
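
The new op/conv2d_gradfix module added here provides drop-in replacements for F.conv2d and F.conv_transpose2d whose gradients route through custom autograd Functions, including a context manager for skipping weight gradients during double backprop. A minimal usage sketch (the tensor shapes and the gradient call below are illustrative assumptions, not part of this commit):

# Minimal usage sketch; shapes and the double-backward call are illustrative.
import torch
from op import conv2d_gradfix

input = torch.randn(2, 3, 16, 16, requires_grad=True)
weight = torch.randn(8, 3, 3, 3, requires_grad=True)

# Drop-in replacement for torch.nn.functional.conv2d; it falls back to the
# native op automatically when the custom path cannot be used (see
# could_use_op in op/conv2d_gradfix.py below).
out = conv2d_gradfix.conv2d(input, weight, padding=1)

# Weight gradients can be skipped while differentiating through a backward
# pass (e.g. for a gradient-penalty term); this only takes effect when the
# custom op path is active.
with conv2d_gradfix.no_weight_gradients():
    (grad,) = torch.autograd.grad(out.sum(), input, create_graph=True)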

3 files changed: +299 -26 lines

model.py

Lines changed: 42 additions & 5 deletions
@@ -8,7 +8,7 @@
 from torch.nn import functional as F
 from torch.autograd import Function
 
-from op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d
+from op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d, conv2d_gradfix
 
 
 class PixelNorm(nn.Module):
@@ -112,7 +112,7 @@ def __init__(
         self.bias = None
 
     def forward(self, input):
-        out = F.conv2d(
+        out = conv2d_gradfix.conv2d(
             input,
             self.weight * self.scale,
             bias=self.bias,
@@ -177,6 +177,7 @@ def __init__(
         upsample=False,
         downsample=False,
         blur_kernel=[1, 3, 3, 1],
+        fused=True,
     ):
         super().__init__()
 
@@ -214,6 +215,7 @@ def __init__(
         self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
 
         self.demodulate = demodulate
+        self.fused = fused
 
     def __repr__(self):
         return (
@@ -224,6 +226,35 @@ def __repr__(self):
     def forward(self, input, style):
         batch, in_channel, height, width = input.shape
 
+        if not self.fused:
+            weight = self.scale * self.weight.squeeze(0)
+            style = self.modulation(style)
+
+            if self.demodulate:
+                w = weight.unsqueeze(0) * style.view(batch, 1, in_channel, 1, 1)
+                dcoefs = (w.square().sum((2, 3, 4)) + 1e-8).rsqrt()
+
+            input = input * style.reshape(batch, in_channel, 1, 1)
+
+            if self.upsample:
+                weight = weight.transpose(0, 1)
+                out = conv2d_gradfix.conv_transpose2d(
+                    input, weight, padding=0, stride=2
+                )
+                out = self.blur(out)
+
+            elif self.downsample:
+                input = self.blur(input)
+                out = conv2d_gradfix.conv2d(input, weight, padding=0, stride=2)
+
+            else:
+                out = conv2d_gradfix.conv2d(input, weight, padding=self.padding)
+
+            if self.demodulate:
+                out = out * dcoefs.view(batch, -1, 1, 1)
+
+            return out
+
         style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
         weight = self.scale * self.weight * style
 
@@ -243,7 +274,9 @@ def forward(self, input, style):
             weight = weight.transpose(1, 2).reshape(
                 batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
             )
-            out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
+            out = conv2d_gradfix.conv_transpose2d(
+                input, weight, padding=0, stride=2, groups=batch
+            )
             _, _, height, width = out.shape
             out = out.view(batch, self.out_channel, height, width)
             out = self.blur(out)
@@ -252,13 +285,17 @@ def forward(self, input, style):
             input = self.blur(input)
             _, _, height, width = input.shape
             input = input.view(1, batch * in_channel, height, width)
-            out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
+            out = conv2d_gradfix.conv2d(
+                input, weight, padding=0, stride=2, groups=batch
+            )
             _, _, height, width = out.shape
             out = out.view(batch, self.out_channel, height, width)
 
         else:
             input = input.view(1, batch * in_channel, height, width)
-            out = F.conv2d(input, weight, padding=self.padding, groups=batch)
+            out = conv2d_gradfix.conv2d(
+                input, weight, padding=self.padding, groups=batch
+            )
             _, _, height, width = out.shape
             out = out.view(batch, self.out_channel, height, width)
 
op/conv2d_gradfix.py (new file)

Lines changed: 227 additions & 0 deletions
import contextlib
import warnings

import torch
from torch import autograd
from torch.nn import functional as F

# Module-level switches: `enabled` turns the custom op off entirely, and
# `weight_gradients_disabled` lets the custom backward skip weight gradients.
enabled = True
weight_gradients_disabled = False


# Temporarily disable weight gradients inside the custom backward.
@contextlib.contextmanager
def no_weight_gradients():
    global weight_gradients_disabled

    old = weight_gradients_disabled
    weight_gradients_disabled = True
    yield
    weight_gradients_disabled = old


# Drop-in replacements for F.conv2d / F.conv_transpose2d.
def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    if could_use_op(input):
        return conv2d_gradfix(
            transpose=False,
            weight_shape=weight.shape,
            stride=stride,
            padding=padding,
            output_padding=0,
            dilation=dilation,
            groups=groups,
        ).apply(input, weight, bias)

    return F.conv2d(
        input=input,
        weight=weight,
        bias=bias,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=groups,
    )


def conv_transpose2d(
    input,
    weight,
    bias=None,
    stride=1,
    padding=0,
    output_padding=0,
    groups=1,
    dilation=1,
):
    if could_use_op(input):
        return conv2d_gradfix(
            transpose=True,
            weight_shape=weight.shape,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            groups=groups,
            dilation=dilation,
        ).apply(input, weight, bias)

    return F.conv_transpose2d(
        input=input,
        weight=weight,
        bias=bias,
        stride=stride,
        padding=padding,
        output_padding=output_padding,
        dilation=dilation,
        groups=groups,
    )


# The cuDNN-backed path needs a CUDA tensor and a tested PyTorch version.
def could_use_op(input):
    if (not enabled) or (not torch.backends.cudnn.enabled):
        return False

    if input.device.type != "cuda":
        return False

    if any(torch.__version__.startswith(x) for x in ["1.7.", "1.8."]):
        return True

    warnings.warn(
        f"conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d()."
    )

    return False


def ensure_tuple(xs, ndim):
    xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim

    return xs


conv2d_gradfix_cache = dict()


# Build (and cache) an autograd.Function specialized to one conv configuration.
def conv2d_gradfix(
    transpose, weight_shape, stride, padding, output_padding, dilation, groups
):
    ndim = 2
    weight_shape = tuple(weight_shape)
    stride = ensure_tuple(stride, ndim)
    padding = ensure_tuple(padding, ndim)
    output_padding = ensure_tuple(output_padding, ndim)
    dilation = ensure_tuple(dilation, ndim)

    key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
    if key in conv2d_gradfix_cache:
        return conv2d_gradfix_cache[key]

    common_kwargs = dict(
        stride=stride, padding=padding, dilation=dilation, groups=groups
    )

    # The output_padding that makes the transposed conv used for grad_input
    # reproduce the original input shape; it solves
    # out = (in - 1) * stride - 2 * padding + dilation * (kernel - 1) + 1 + output_padding
    # for output_padding.
    def calc_output_padding(input_shape, output_shape):
        if transpose:
            return [0, 0]

        return [
            input_shape[i + 2]
            - (output_shape[i + 2] - 1) * stride[i]
            - (1 - 2 * padding[i])
            - dilation[i] * (weight_shape[i + 2] - 1)
            for i in range(ndim)
        ]

    class Conv2d(autograd.Function):
        @staticmethod
        def forward(ctx, input, weight, bias):
            if not transpose:
                out = F.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)

            else:
                out = F.conv_transpose2d(
                    input=input,
                    weight=weight,
                    bias=bias,
                    output_padding=output_padding,
                    **common_kwargs,
                )

            ctx.save_for_backward(input, weight)

            return out

        @staticmethod
        def backward(ctx, grad_output):
            input, weight = ctx.saved_tensors
            grad_input, grad_weight, grad_bias = None, None, None

            if ctx.needs_input_grad[0]:
                p = calc_output_padding(
                    input_shape=input.shape, output_shape=grad_output.shape
                )
                grad_input = conv2d_gradfix(
                    transpose=(not transpose),
                    weight_shape=weight_shape,
                    output_padding=p,
                    **common_kwargs,
                ).apply(grad_output, weight, None)

            if ctx.needs_input_grad[1] and not weight_gradients_disabled:
                grad_weight = Conv2dGradWeight.apply(grad_output, input)

            if ctx.needs_input_grad[2]:
                grad_bias = grad_output.sum((0, 2, 3))

            return grad_input, grad_weight, grad_bias

    # Weight gradient via cuDNN's dedicated backward-weight kernel; wrapped in
    # a Function so it is itself differentiable (enabling double backprop).
    class Conv2dGradWeight(autograd.Function):
        @staticmethod
        def forward(ctx, grad_output, input):
            op = torch._C._jit_get_operation(
                "aten::cudnn_convolution_backward_weight"
                if not transpose
                else "aten::cudnn_convolution_transpose_backward_weight"
            )
            flags = [
                torch.backends.cudnn.benchmark,
                torch.backends.cudnn.deterministic,
                torch.backends.cudnn.allow_tf32,
            ]
            grad_weight = op(
                weight_shape,
                grad_output,
                input,
                padding,
                stride,
                dilation,
                groups,
                *flags,
            )
            ctx.save_for_backward(grad_output, input)

            return grad_weight

        @staticmethod
        def backward(ctx, grad_grad_weight):
            grad_output, input = ctx.saved_tensors
            grad_grad_output, grad_grad_input = None, None

            if ctx.needs_input_grad[0]:
                grad_grad_output = Conv2d.apply(input, grad_grad_weight, None)

            if ctx.needs_input_grad[1]:
                p = calc_output_padding(
                    input_shape=input.shape, output_shape=grad_output.shape
                )
                grad_grad_input = conv2d_gradfix(
                    transpose=(not transpose),
                    weight_shape=weight_shape,
                    output_padding=p,
                    **common_kwargs,
                ).apply(grad_output, grad_grad_weight, None)

            return grad_grad_output, grad_grad_input

    conv2d_gradfix_cache[key] = Conv2d

    return Conv2d
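
One way to sanity-check the Functions above is to compare first- and second-order gradients against the native ops. A sketch, assuming a CUDA device and PyTorch 1.7/1.8 so the custom path is actually exercised (anywhere else, could_use_op() makes both sides fall back to the same F.conv2d call):

# Gradient-consistency sketch; assumes CUDA + PyTorch 1.7/1.8. Tolerances are
# loose to absorb cuDNN accumulation-order differences.
import torch
from torch.nn import functional as F
from op import conv2d_gradfix

def first_and_second_order(conv):
    torch.manual_seed(0)
    input = torch.randn(2, 3, 8, 8, device="cuda", requires_grad=True)
    weight = torch.randn(4, 3, 3, 3, device="cuda", requires_grad=True)
    out = conv(input, weight, padding=1)
    # First-order gradient w.r.t. the input, kept differentiable...
    (g_in,) = torch.autograd.grad(out.square().sum(), input, create_graph=True)
    # ...so a second backward (as used by gradient penalties) also runs.
    (g_w,) = torch.autograd.grad(g_in.square().sum(), weight)
    return g_in, g_w

g_in_ref, g_w_ref = first_and_second_order(F.conv2d)
g_in_fix, g_w_fix = first_and_second_order(conv2d_gradfix.conv2d)
assert torch.allclose(g_in_ref, g_in_fix, atol=1e-4)
assert torch.allclose(g_w_ref, g_w_fix, atol=1e-4)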
