Skip to content

Commit 2154a74

Browse files
committed
needs to be zero centered GP
1 parent 6740714 commit 2154a74

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

‎audiolm_pytorch/soundstream.py‎

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def hinge_gen_loss(fake):
6767
def leaky_relu(p = 0.1):
6868
return nn.LeakyReLU(p)
6969

70-
def gradient_penalty(wave, output, weight = 10):
70+
def gradient_penalty(wave, output, weight = 10, center = 0.):
7171
batch_size, device = wave.shape[0], wave.device
7272

7373
gradients = torch_grad(
@@ -80,7 +80,7 @@ def gradient_penalty(wave, output, weight = 10):
8080
)[0]
8181

8282
gradients = rearrange(gradients, 'b ... -> b (...)')
83-
return weight * ((vector_norm(gradients, dim = 1) - 1) ** 2).mean()
83+
return weight * ((vector_norm(gradients, dim = 1) - center) ** 2).mean()
8484

8585
# better sequential
8686

‎audiolm_pytorch/version.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '2.3.0'
1+
__version__ = '2.3.1'

0 commit comments

Comments
 (0)