99from torch .autograd import Function
1010
1111from op import FusedLeakyReLU , fused_leaky_relu , upfirdn2d , conv2d_gradfix
12- from model import ModulatedConv2d , StyledConv , ConstantInput , PixelNorm , Upsample , Downsample , Blur , EqualLinear , ConvLayer
12+ from model import (
13+ ModulatedConv2d ,
14+ StyledConv ,
15+ ConstantInput ,
16+ PixelNorm ,
17+ Upsample ,
18+ Downsample ,
19+ Blur ,
20+ EqualLinear ,
21+ ConvLayer ,
22+ )
23+
1324
1425def get_haar_wavelet (in_channels ):
1526 haar_wav_l = 1 / (2 ** 0.5 ) * torch .ones (1 , 2 )
@@ -20,47 +31,88 @@ def get_haar_wavelet(in_channels):
2031 haar_wav_lh = haar_wav_h .T * haar_wav_l
2132 haar_wav_hl = haar_wav_l .T * haar_wav_h
2233 haar_wav_hh = haar_wav_h .T * haar_wav_h
23-
34+
2435 return haar_wav_ll , haar_wav_lh , haar_wav_hl , haar_wav_hh
2536
2637
38+ def dwt_init (x ):
39+ x01 = x [:, :, 0 ::2 , :] / 2
40+ x02 = x [:, :, 1 ::2 , :] / 2
41+ x1 = x01 [:, :, :, 0 ::2 ]
42+ x2 = x02 [:, :, :, 0 ::2 ]
43+ x3 = x01 [:, :, :, 1 ::2 ]
44+ x4 = x02 [:, :, :, 1 ::2 ]
45+ x_LL = x1 + x2 + x3 + x4
46+ x_HL = - x1 - x2 + x3 + x4
47+ x_LH = - x1 + x2 - x3 + x4
48+ x_HH = x1 - x2 - x3 + x4
49+
50+ return torch .cat ((x_LL , x_HL , x_LH , x_HH ), 1 )
51+
52+
53+ def iwt_init (x ):
54+ r = 2
55+ in_batch , in_channel , in_height , in_width = x .size ()
56+ # print([in_batch, in_channel, in_height, in_width])
57+ out_batch , out_channel , out_height , out_width = (
58+ in_batch ,
59+ int (in_channel / (r ** 2 )),
60+ r * in_height ,
61+ r * in_width ,
62+ )
63+ x1 = x [:, 0 :out_channel , :, :] / 2
64+ x2 = x [:, out_channel : out_channel * 2 , :, :] / 2
65+ x3 = x [:, out_channel * 2 : out_channel * 3 , :, :] / 2
66+ x4 = x [:, out_channel * 3 : out_channel * 4 , :, :] / 2
67+
68+ h = torch .zeros ([out_batch , out_channel , out_height , out_width ]).float ().cuda ()
69+
70+ h [:, :, 0 ::2 , 0 ::2 ] = x1 - x2 - x3 + x4
71+ h [:, :, 1 ::2 , 0 ::2 ] = x1 - x2 + x3 - x4
72+ h [:, :, 0 ::2 , 1 ::2 ] = x1 + x2 - x3 - x4
73+ h [:, :, 1 ::2 , 1 ::2 ] = x1 + x2 + x3 + x4
74+
75+ return h
76+
77+
2778class HaarTransform (nn .Module ):
2879 def __init__ (self , in_channels ):
2980 super ().__init__ ()
30-
81+
3182 ll , lh , hl , hh = get_haar_wavelet (in_channels )
32-
33- self .register_buffer ('ll' , ll )
34- self .register_buffer ('lh' , lh )
35- self .register_buffer ('hl' , hl )
36- self .register_buffer ('hh' , hh )
37-
83+
84+ self .register_buffer ("ll" , ll )
85+ self .register_buffer ("lh" , lh )
86+ self .register_buffer ("hl" , hl )
87+ self .register_buffer ("hh" , hh )
88+
3889 def forward (self , input ):
3990 ll = upfirdn2d (input , self .ll , down = 2 )
4091 lh = upfirdn2d (input , self .lh , down = 2 )
4192 hl = upfirdn2d (input , self .hl , down = 2 )
4293 hh = upfirdn2d (input , self .hh , down = 2 )
43-
94+
4495 return torch .cat ((ll , lh , hl , hh ), 1 )
45-
96+
97+
4698class InverseHaarTransform (nn .Module ):
4799 def __init__ (self , in_channels ):
48100 super ().__init__ ()
49-
101+
50102 ll , lh , hl , hh = get_haar_wavelet (in_channels )
51103
52- self .register_buffer ('ll' , ll )
53- self .register_buffer ('lh' , - lh )
54- self .register_buffer ('hl' , - hl )
55- self .register_buffer ('hh' , hh )
56-
104+ self .register_buffer ("ll" , ll )
105+ self .register_buffer ("lh" , - lh )
106+ self .register_buffer ("hl" , - hl )
107+ self .register_buffer ("hh" , hh )
108+
57109 def forward (self , input ):
58110 ll , lh , hl , hh = input .chunk (4 , 1 )
59111 ll = upfirdn2d (ll , self .ll , up = 2 , pad = (1 , 0 , 1 , 0 ))
60112 lh = upfirdn2d (lh , self .lh , up = 2 , pad = (1 , 0 , 1 , 0 ))
61113 hl = upfirdn2d (hl , self .hl , up = 2 , pad = (1 , 0 , 1 , 0 ))
62114 hh = upfirdn2d (hh , self .hh , up = 2 , pad = (1 , 0 , 1 , 0 ))
63-
115+
64116 return ll + lh + hl + hh
65117
66118
@@ -299,7 +351,7 @@ def __init__(self, out_channel, downsample=True, blur_kernel=[1, 3, 3, 1]):
299351 self .downsample = Downsample (blur_kernel )
300352 self .dwt = HaarTransform (3 )
301353
302- self .conv = ConvLayer (3 * 4 , out_channel , 1 )
354+ self .conv = ConvLayer (3 * 4 , out_channel , 3 )
303355
304356 def forward (self , input , skip = None ):
305357 if self .downsample :
0 commit comments