@@ -1038,8 +1038,8 @@ def __init__(
10381038 ):
10391039 super ().__init__ ()
10401040 check_one_trainer ()
1041+ self .use_wandb_tracking = use_wandb_tracking
10411042 if use_wandb_tracking :
1042- self .use_wandb_tracking = use_wandb_tracking
10431043 accelerate_kwargs .update (log_with = 'wandb' )
10441044 init_process_kwargs = InitProcessGroupKwargs (timeout = timedelta (seconds = init_process_group_timeout_seconds ))
10451045
@@ -1140,7 +1140,7 @@ def __init__(
11401140 self .valid_dl_iter = cycle (self .valid_dl )
11411141
11421142 self .save_model_every = save_model_every
1143- self .save_results_every = save_results_every
1143+ self .save_results_every = save_results_every
11441144
11451145 self .results_folder = Path (results_folder )
11461146
@@ -1152,7 +1152,7 @@ def __init__(
11521152 hps = {"num_train_steps" : num_train_steps , "data_max_length" : data_max_length , "learning_rate" : lr }
11531153 self .tracker_hps = hps
11541154
1155- self .accelerator .init_trackers ("coarse" , config = hps )
1155+ self .accelerator .init_trackers ("coarse" , config = hps )
11561156
11571157 self .train_wrapper .to (self .device )
11581158 self .average_valid_loss_over_grad_accum_every = average_valid_loss_over_grad_accum_every
@@ -1338,8 +1338,8 @@ def __init__(
13381338 ):
13391339 super ().__init__ ()
13401340 check_one_trainer ()
1341+ self .use_wandb_tracking = use_wandb_tracking
13411342 if use_wandb_tracking :
1342- self .use_wandb_tracking = use_wandb_tracking
13431343 accelerate_kwargs .update (log_with = 'wandb' )
13441344 init_process_kwargs = InitProcessGroupKwargs (timeout = timedelta (seconds = init_process_group_timeout_seconds ))
13451345
@@ -1435,7 +1435,7 @@ def __init__(
14351435 self .valid_dl_iter = cycle (self .valid_dl )
14361436
14371437 self .save_model_every = save_model_every
1438- self .save_results_every = save_results_every
1438+ self .save_results_every = save_results_every
14391439
14401440 self .results_folder = Path (results_folder )
14411441
@@ -1448,7 +1448,7 @@ def __init__(
14481448 hps = {"num_train_steps" : num_train_steps , "data_max_length" : data_max_length , "learning_rate" : lr }
14491449 self .tracker_hps = hps
14501450
1451- self .accelerator .init_trackers ("fine" , config = hps )
1451+ self .accelerator .init_trackers ("fine" , config = hps )
14521452
14531453 self .train_wrapper .to (self .device )
14541454 self .average_valid_loss_over_grad_accum_every = average_valid_loss_over_grad_accum_every
0 commit comments