Skip to content

Commit ed986b6

Browse files
authored
Merge branch 'dev/memtrack' into dev/jan
2 parents 8ad2e48 + e8685fa commit ed986b6

File tree

1 file changed

+26
-0
lines changed

1 file changed

+26
-0
lines changed

‎mdp_playground/scripts/run_experiments.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
import ray
2727
from ray import tune
28+
from ray.rllib.agents.callbacks import MemoryTrackingCallbacks
2829

2930
# import configparser
3031
import pprint
@@ -304,6 +305,31 @@ def main(args):
304305
res_dir = args.framework_dir + "/_ray_results_" + str(args.config_num)
305306
print("## Results dir: {}".format(res_dir))
306307

308+
callbacks = []
309+
310+
mem_callback = MemoryTrackingCallbacks()
311+
mdpp_on_episode_end = tune_config["callbacks"]["on_episode_end"]
312+
313+
def combined_on_episode_end(info):
314+
"""
315+
Old Ray: callbacks were just a dict of functions getting an info
316+
dict.
317+
New Ray: callbacks are objects and their functions have a more
318+
explicit signature.
319+
When we pass this dict, Ray passes the old info-dict. For using the
320+
MemoryTrackingCallbacks, we have to fake its signature, which works
321+
fine as it only needs the "episode" value anyways.
322+
323+
.. _See: https://docs.ray.io/en/releases-1.6.0/_modules/ray/rllib\
324+
/agents/callbacks.html
325+
"""
326+
mem_callback.on_episode_end(worker=None, base_env=None,
327+
policies=None, episode=info["episode"],
328+
env_index=None)
329+
mdpp_on_episode_end(info)
330+
331+
tune_config["callbacks"]["on_episode_end"] = combined_on_episode_end
332+
307333
if args.wandb is not None:
308334
from ray.tune.integration.wandb import WandbLoggerCallback
309335
if not os.path.isfile(WANDB_API_KEY_FILE):

0 commit comments

Comments
 (0)