Skip to content

Commit 7408367

Browse files
committed
Fixed translation augmentation bug
1 parent f8dc9ad commit 7408367

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

‎non_leaking.py‎

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -207,9 +207,9 @@ def sample_affine(p, size, height, width, device="cpu"):
207207
# print('90 rotate', G, rotate_mat(-math.pi / 2 * param), sep='\n')
208208

209209
# integer translate
210-
param = uniform_sample(size, -0.125, 0.125)
211-
param_height = torch.round(param * height) / height
212-
param_width = torch.round(param * width) / width
210+
param = uniform_sample((2, size), -0.125, 0.125)
211+
param_height = torch.round(param[0] * height)
212+
param_width = torch.round(param[1] * width)
213213
Gc = translate_mat(param_width, param_height, device=device)
214214
G = random_mat_apply(p, Gc, G, eye, device=device)
215215
# print('integer translate', G, translate_mat(param_width, param_height), sep='\n')
@@ -241,8 +241,8 @@ def sample_affine(p, size, height, width, device="cpu"):
241241
# print('post-rotate', G, rotate_mat(-param), sep='\n')
242242

243243
# fractional translate
244-
param = normal_sample(size, std=0.125)
245-
Gc = translate_mat(param, param, device=device)
244+
param = normal_sample((2, size), std=0.125)
245+
Gc = translate_mat(param[1] * width, param[0] * height, device=device)
246246
G = random_mat_apply(p, Gc, G, eye, device=device)
247247
# print('fractional translate', G, translate_mat(param, param), sep='\n')
248248

@@ -365,7 +365,7 @@ def forward(ctx, grad_output, input, grid):
365365

366366
@staticmethod
367367
def backward(ctx, grad_grad_input, grad_grad_grid):
368-
grid, = ctx.saved_tensors
368+
(grid,) = ctx.saved_tensors
369369
grad_grad_output = None
370370

371371
if ctx.needs_input_grad[0]:

0 commit comments

Comments
 (0)