Skip to content

Commit 1fb4f33

Browse files
author
Orr Paradise
committed
Bugfix: use_wandb_tracking is not saved in SemanticTransformerTrainer
1 parent 8119f4f commit 1fb4f33

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

audiolm_pytorch/trainer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -747,8 +747,8 @@ def __init__(
747747
check_one_trainer()
748748

749749
init_process_kwargs = InitProcessGroupKwargs(timeout = timedelta(seconds = init_process_group_timeout_seconds))
750+
self.use_wandb_tracking = use_wandb_tracking
750751
if use_wandb_tracking:
751-
self.use_wandb_tracking = use_wandb_tracking
752752
accelerate_kwargs.update(log_with = 'wandb')
753753
self.accelerator = Accelerator(
754754
kwargs_handlers = [DEFAULT_DDP_KWARGS, init_process_kwargs],
@@ -840,7 +840,7 @@ def __init__(
840840
self.valid_dl_iter = cycle(self.valid_dl)
841841

842842
self.save_model_every = save_model_every
843-
self.save_results_every = save_results_every
843+
self.save_results_every = save_results_every
844844

845845
self.results_folder = Path(results_folder)
846846

@@ -849,7 +849,7 @@ def __init__(
849849

850850
self.accelerator.wait_for_everyone()
851851
self.results_folder.mkdir(parents = True, exist_ok = True)
852-
852+
853853
hps = {"num_train_steps": num_train_steps, "data_max_length": data_max_length, "learning_rate": lr}
854854
self.tracker_hps = hps
855855

0 commit comments

Comments (0)