
Commit bef283a

Updated SWAGAN
1 parent a2f3891 commit bef283a

1 file changed

swagan.py

Lines changed: 71 additions & 19 deletions
@@ -9,7 +9,18 @@
 from torch.autograd import Function
 
 from op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d, conv2d_gradfix
-from model import ModulatedConv2d, StyledConv, ConstantInput, PixelNorm, Upsample, Downsample, Blur, EqualLinear, ConvLayer
+from model import (
+    ModulatedConv2d,
+    StyledConv,
+    ConstantInput,
+    PixelNorm,
+    Upsample,
+    Downsample,
+    Blur,
+    EqualLinear,
+    ConvLayer,
+)
+
 
 def get_haar_wavelet(in_channels):
     haar_wav_l = 1 / (2 ** 0.5) * torch.ones(1, 2)
@@ -20,47 +31,88 @@ def get_haar_wavelet(in_channels):
     haar_wav_lh = haar_wav_h.T * haar_wav_l
     haar_wav_hl = haar_wav_l.T * haar_wav_h
     haar_wav_hh = haar_wav_h.T * haar_wav_h
-
+
     return haar_wav_ll, haar_wav_lh, haar_wav_hl, haar_wav_hh
 
 
+def dwt_init(x):
+    x01 = x[:, :, 0::2, :] / 2
+    x02 = x[:, :, 1::2, :] / 2
+    x1 = x01[:, :, :, 0::2]
+    x2 = x02[:, :, :, 0::2]
+    x3 = x01[:, :, :, 1::2]
+    x4 = x02[:, :, :, 1::2]
+    x_LL = x1 + x2 + x3 + x4
+    x_HL = -x1 - x2 + x3 + x4
+    x_LH = -x1 + x2 - x3 + x4
+    x_HH = x1 - x2 - x3 + x4
+
+    return torch.cat((x_LL, x_HL, x_LH, x_HH), 1)
+
+
+def iwt_init(x):
+    r = 2
+    in_batch, in_channel, in_height, in_width = x.size()
+    # print([in_batch, in_channel, in_height, in_width])
+    out_batch, out_channel, out_height, out_width = (
+        in_batch,
+        int(in_channel / (r ** 2)),
+        r * in_height,
+        r * in_width,
+    )
+    x1 = x[:, 0:out_channel, :, :] / 2
+    x2 = x[:, out_channel : out_channel * 2, :, :] / 2
+    x3 = x[:, out_channel * 2 : out_channel * 3, :, :] / 2
+    x4 = x[:, out_channel * 3 : out_channel * 4, :, :] / 2
+
+    h = torch.zeros([out_batch, out_channel, out_height, out_width]).float().cuda()
+
+    h[:, :, 0::2, 0::2] = x1 - x2 - x3 + x4
+    h[:, :, 1::2, 0::2] = x1 - x2 + x3 - x4
+    h[:, :, 0::2, 1::2] = x1 + x2 - x3 - x4
+    h[:, :, 1::2, 1::2] = x1 + x2 + x3 + x4
+
+    return h
+
+
 class HaarTransform(nn.Module):
     def __init__(self, in_channels):
         super().__init__()
-
+
         ll, lh, hl, hh = get_haar_wavelet(in_channels)
-
-        self.register_buffer('ll', ll)
-        self.register_buffer('lh', lh)
-        self.register_buffer('hl', hl)
-        self.register_buffer('hh', hh)
-
+
+        self.register_buffer("ll", ll)
+        self.register_buffer("lh", lh)
+        self.register_buffer("hl", hl)
+        self.register_buffer("hh", hh)
+
     def forward(self, input):
         ll = upfirdn2d(input, self.ll, down=2)
         lh = upfirdn2d(input, self.lh, down=2)
         hl = upfirdn2d(input, self.hl, down=2)
         hh = upfirdn2d(input, self.hh, down=2)
-
+
         return torch.cat((ll, lh, hl, hh), 1)
-
+
+
 class InverseHaarTransform(nn.Module):
     def __init__(self, in_channels):
         super().__init__()
-
+
         ll, lh, hl, hh = get_haar_wavelet(in_channels)
 
-        self.register_buffer('ll', ll)
-        self.register_buffer('lh', -lh)
-        self.register_buffer('hl', -hl)
-        self.register_buffer('hh', hh)
-
+        self.register_buffer("ll", ll)
+        self.register_buffer("lh", -lh)
+        self.register_buffer("hl", -hl)
+        self.register_buffer("hh", hh)
+
     def forward(self, input):
         ll, lh, hl, hh = input.chunk(4, 1)
         ll = upfirdn2d(ll, self.ll, up=2, pad=(1, 0, 1, 0))
        lh = upfirdn2d(lh, self.lh, up=2, pad=(1, 0, 1, 0))
        hl = upfirdn2d(hl, self.hl, up=2, pad=(1, 0, 1, 0))
        hh = upfirdn2d(hh, self.hh, up=2, pad=(1, 0, 1, 0))
-
+
         return ll + lh + hl + hh
 
 
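Note on the new helpers: dwt_init and iwt_init implement the 2x2 Haar analysis and synthesis directly with strided slicing, alongside the existing upfirdn2d-based HaarTransform / InverseHaarTransform modules. The sketch below is not part of the commit; it checks the round trip and assumes swagan.py (with its compiled op extensions) imports cleanly and that a CUDA device is available, since iwt_init allocates its output with .cuda().

import torch

from swagan import dwt_init, iwt_init

# NCHW batch of RGB images on the GPU (iwt_init builds its output on CUDA).
x = torch.randn(2, 3, 64, 64, device="cuda")

# Analysis: channels grow 4x (LL, HL, LH, HH), spatial size halves.
freq = dwt_init(x)        # -> (2, 12, 32, 32)

# Synthesis: channels shrink 4x, spatial size doubles back.
recon = iwt_init(freq)    # -> (2, 3, 64, 64)

print(freq.shape, recon.shape)
print(torch.allclose(recon, x, atol=1e-6))  # expected True

Because each subband is a signed average over a 2x2 neighborhood, the synthesis recovers the input up to floating-point rounding.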

@@ -299,7 +351,7 @@ def __init__(self, out_channel, downsample=True, blur_kernel=[1, 3, 3, 1]):
         self.downsample = Downsample(blur_kernel)
         self.dwt = HaarTransform(3)
 
-        self.conv = ConvLayer(3 * 4, out_channel, 1)
+        self.conv = ConvLayer(3 * 4, out_channel, 3)
 
     def forward(self, input, skip=None):
         if self.downsample:
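Apart from the new helper functions, this hunk is the one functional change in the commit: the projection over the 12 Haar channels in what appears to be the FromRGB-style block now uses a 3x3 kernel instead of 1x1, so each feature also mixes spatially neighboring wavelet coefficients. A stand-in illustration with plain torch.nn.Conv2d (the repo's ConvLayer wraps an equalized-lr convolution; assuming it pads by kernel_size // 2 as in model.py, the spatial size is unchanged; out_channel=64 is made up for the example):

import torch
import torch.nn as nn

wavelet = torch.randn(2, 12, 128, 128)  # e.g. HaarTransform output for a 256x256 input

proj_1x1 = nn.Conv2d(12, 64, kernel_size=1)             # before this commit
proj_3x3 = nn.Conv2d(12, 64, kernel_size=3, padding=1)  # after this commit

print(proj_1x1(wavelet).shape)  # torch.Size([2, 64, 128, 128])
print(proj_3x3(wavelet).shape)  # torch.Size([2, 64, 128, 128])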
