@@ -902,6 +902,7 @@ def data_tuple_to_kwargs(self, data):
902902 assert not has_duplicates (self .ds_fields ), 'dataset fields must not have duplicate field names'
903903
904904 return dict (zip (self .ds_fields , data ))
905+
905906 @contextmanager
906907 def wandb_tracker (self , project , run = None , hps = None ):
907908 assert self .use_wandb_tracking , '`use_wandb_tracking` must be set to True on SemanticTransformerTrainer'
@@ -919,6 +920,7 @@ def wandb_tracker(self, project, run = None, hps = None):
919920 yield
920921
921922 self .accelerator .end_training ()
923+
922924 def train_step (self ):
923925 device = self .device
924926
@@ -1197,6 +1199,7 @@ def wandb_tracker(self, project, run = None, hps = None):
11971199 yield
11981200
11991201 self .accelerator .end_training ()
1202+
12001203 @property
12011204 def device (self ):
12021205 return self .accelerator .device
@@ -1475,6 +1478,7 @@ def print(self, msg):
14751478
14761479 def generate (self , * args , ** kwargs ):
14771480 return self .train_wrapper .generate (* args , ** kwargs )
1481+
14781482 @contextmanager
14791483 def wandb_tracker (self , project , run = None , hps = None ):
14801484 assert self .use_wandb_tracking , '`use_wandb_tracking` must be set to True on FineTransformerTrainer'
@@ -1491,7 +1495,8 @@ def wandb_tracker(self, project, run = None, hps = None):
14911495
14921496 yield
14931497
1494- self .accelerator .end_training ()
1498+ self .accelerator .end_training ()
1499+
14951500 @property
14961501 def device (self ):
14971502 return self .accelerator .device
0 commit comments