File tree Expand file tree Collapse file tree 1 file changed +26
-0
lines changed Expand file tree Collapse file tree 1 file changed +26
-0
lines changed Original file line number Diff line number Diff line change 25
25
26
26
import ray
27
27
from ray import tune
28
+ from ray .rllib .agents .callbacks import MemoryTrackingCallbacks
28
29
29
30
# import configparser
30
31
import pprint
@@ -304,6 +305,31 @@ def main(args):
304
305
res_dir = args .framework_dir + "/_ray_results_" + str (args .config_num )
305
306
print ("## Results dir: {}" .format (res_dir ))
306
307
308
+ callbacks = []
309
+
310
+ mem_callback = MemoryTrackingCallbacks ()
311
+ mdpp_on_episode_end = tune_config ["callbacks" ]["on_episode_end" ]
312
+
313
+ def combined_on_episode_end (info ):
314
+ """
315
+ Old Ray: callbacks were just a dict of functions getting an info
316
+ dict.
317
+ New Ray: callbacks are objects and their functions have a more
318
+ explicit signature.
319
+ When we pass this dict, Ray passes the old info-dict. For using the
320
+ MemoryTrackingCallbacks, we have to fake its signature, which works
321
+ fine as it only needs the "episode" value anyways.
322
+
323
+ .. _See: https://docs.ray.io/en/releases-1.6.0/_modules/ray/rllib\
324
+ /agents/callbacks.html
325
+ """
326
+ mem_callback .on_episode_end (worker = None , base_env = None ,
327
+ policies = None , episode = info ["episode" ],
328
+ env_index = None )
329
+ mdpp_on_episode_end (info )
330
+
331
+ tune_config ["callbacks" ]["on_episode_end" ] = combined_on_episode_end
332
+
307
333
if args .wandb is not None :
308
334
from ray .tune .integration .wandb import WandbLoggerCallback
309
335
if not os .path .isfile (WANDB_API_KEY_FILE ):
You can’t perform that action at this time.
0 commit comments