@@ -207,9 +207,9 @@ def sample_affine(p, size, height, width, device="cpu"):
207207 # print('90 rotate', G, rotate_mat(-math.pi / 2 * param), sep='\n')
208208
209209 # integer translate
210- param = uniform_sample (size , - 0.125 , 0.125 )
211- param_height = torch .round (param * height ) / height
212- param_width = torch .round (param * width ) / width
210+ param = uniform_sample (( 2 , size ) , - 0.125 , 0.125 )
211+ param_height = torch .round (param [ 0 ] * height )
212+ param_width = torch .round (param [ 1 ] * width )
213213 Gc = translate_mat (param_width , param_height , device = device )
214214 G = random_mat_apply (p , Gc , G , eye , device = device )
215215 # print('integer translate', G, translate_mat(param_width, param_height), sep='\n')
@@ -241,8 +241,8 @@ def sample_affine(p, size, height, width, device="cpu"):
241241 # print('post-rotate', G, rotate_mat(-param), sep='\n')
242242
243243 # fractional translate
244- param = normal_sample (size , std = 0.125 )
245- Gc = translate_mat (param , param , device = device )
244+ param = normal_sample (( 2 , size ) , std = 0.125 )
245+ Gc = translate_mat (param [ 1 ] * width , param [ 0 ] * height , device = device )
246246 G = random_mat_apply (p , Gc , G , eye , device = device )
247247 # print('fractional translate', G, translate_mat(param, param), sep='\n')
248248
@@ -365,7 +365,7 @@ def forward(ctx, grad_output, input, grid):
365365
366366 @staticmethod
367367 def backward (ctx , grad_grad_input , grad_grad_grid ):
368- grid , = ctx .saved_tensors
368+ ( grid ,) = ctx .saved_tensors
369369 grad_grad_output = None
370370
371371 if ctx .needs_input_grad [0 ]:
0 commit comments