Skip to content

Commit ce6690e

Browse files
author
Orr Paradise
committed
Bugfix: use_wandb_tracking is not saved in Fine/CoarseTransformerTrainer
1 parent 1fb4f33 commit ce6690e

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

‎audiolm_pytorch/trainer.py‎

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1038,8 +1038,8 @@ def __init__(
10381038
):
10391039
super().__init__()
10401040
check_one_trainer()
1041+
self.use_wandb_tracking = use_wandb_tracking
10411042
if use_wandb_tracking:
1042-
self.use_wandb_tracking = use_wandb_tracking
10431043
accelerate_kwargs.update(log_with = 'wandb')
10441044
init_process_kwargs = InitProcessGroupKwargs(timeout = timedelta(seconds = init_process_group_timeout_seconds))
10451045

@@ -1140,7 +1140,7 @@ def __init__(
11401140
self.valid_dl_iter = cycle(self.valid_dl)
11411141

11421142
self.save_model_every = save_model_every
1143-
self.save_results_every = save_results_every
1143+
self.save_results_every = save_results_every
11441144

11451145
self.results_folder = Path(results_folder)
11461146

@@ -1152,7 +1152,7 @@ def __init__(
11521152
hps = {"num_train_steps": num_train_steps, "data_max_length": data_max_length, "learning_rate": lr}
11531153
self.tracker_hps = hps
11541154

1155-
self.accelerator.init_trackers("coarse", config=hps)
1155+
self.accelerator.init_trackers("coarse", config=hps)
11561156

11571157
self.train_wrapper.to(self.device)
11581158
self.average_valid_loss_over_grad_accum_every = average_valid_loss_over_grad_accum_every
@@ -1338,8 +1338,8 @@ def __init__(
13381338
):
13391339
super().__init__()
13401340
check_one_trainer()
1341+
self.use_wandb_tracking = use_wandb_tracking
13411342
if use_wandb_tracking:
1342-
self.use_wandb_tracking = use_wandb_tracking
13431343
accelerate_kwargs.update(log_with = 'wandb')
13441344
init_process_kwargs = InitProcessGroupKwargs(timeout = timedelta(seconds = init_process_group_timeout_seconds))
13451345

@@ -1435,7 +1435,7 @@ def __init__(
14351435
self.valid_dl_iter = cycle(self.valid_dl)
14361436

14371437
self.save_model_every = save_model_every
1438-
self.save_results_every = save_results_every
1438+
self.save_results_every = save_results_every
14391439

14401440
self.results_folder = Path(results_folder)
14411441

@@ -1448,7 +1448,7 @@ def __init__(
14481448
hps = {"num_train_steps": num_train_steps, "data_max_length": data_max_length, "learning_rate": lr}
14491449
self.tracker_hps = hps
14501450

1451-
self.accelerator.init_trackers("fine", config=hps)
1451+
self.accelerator.init_trackers("fine", config=hps)
14521452

14531453
self.train_wrapper.to(self.device)
14541454
self.average_valid_loss_over_grad_accum_every = average_valid_loss_over_grad_accum_every

0 commit comments

Comments
 (0)