Skip to content

Commit b379498

Browse files
committed
cleanup
1 parent 6c031a4 commit b379498

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

‎audiolm_pytorch/trainer.py‎

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -902,6 +902,7 @@ def data_tuple_to_kwargs(self, data):
902902
assert not has_duplicates(self.ds_fields), 'dataset fields must not have duplicate field names'
903903

904904
return dict(zip(self.ds_fields, data))
905+
905906
@contextmanager
906907
def wandb_tracker(self, project, run = None, hps = None):
907908
assert self.use_wandb_tracking, '`use_wandb_tracking` must be set to True on SemanticTransformerTrainer'
@@ -919,6 +920,7 @@ def wandb_tracker(self, project, run = None, hps = None):
919920
yield
920921

921922
self.accelerator.end_training()
923+
922924
def train_step(self):
923925
device = self.device
924926

@@ -1197,6 +1199,7 @@ def wandb_tracker(self, project, run = None, hps = None):
11971199
yield
11981200

11991201
self.accelerator.end_training()
1202+
12001203
@property
12011204
def device(self):
12021205
return self.accelerator.device
@@ -1475,6 +1478,7 @@ def print(self, msg):
14751478

14761479
def generate(self, *args, **kwargs):
14771480
return self.train_wrapper.generate(*args, **kwargs)
1481+
14781482
@contextmanager
14791483
def wandb_tracker(self, project, run = None, hps = None):
14801484
assert self.use_wandb_tracking, '`use_wandb_tracking` must be set to True on FineTransformerTrainer'
@@ -1491,7 +1495,8 @@ def wandb_tracker(self, project, run = None, hps = None):
14911495

14921496
yield
14931497

1494-
self.accelerator.end_training()
1498+
self.accelerator.end_training()
1499+
14951500
@property
14961501
def device(self):
14971502
return self.accelerator.device

0 commit comments

Comments
 (0)