
Commit 819002c

Actually added timesteps_total to run_experiments.py; some testing code
1 parent b3a09fd commit 819002c

File tree

3 files changed: +66 −12 lines

  mdp_playground/scripts/run_experiments.py
  tests/test_analysis_code.py
  tests/test_run_experiments.py

mdp_playground/scripts/run_experiments.py

Lines changed: 16 additions & 2 deletions

@@ -2,6 +2,9 @@
 
 Takes a configuration file, experiment name and config number to run as
 optional arguments.
+
+e.g.: python mdp_playground/scripts/run_experiments.py -a 0 -n 0 -c \
+default_config.py -e default_config
 """
 
 from __future__ import absolute_import
@@ -129,7 +132,7 @@ def main(args):
         "training.",
     )
     parser.add_argument(
-        "-t",
+        "-d",
         "--framework-dir",
         dest="framework_dir",
         action="store",
@@ -139,6 +142,15 @@ def main(args):
         "framework (e.g. Ray Rllib, Stable Baselines 3). This "
         "name will be passed to the framework.",
     )
+    parser.add_argument(
+        "-t",
+        "--timesteps-total",
+        dest="timesteps_total",
+        action="store",
+        default=None,
+        type=int,
+        help="Total number of env steps to run expt for."
+    )
     # parser.add_argument('-t', '--tune-hps', dest='tune_hps', action='store',
     #                     default=False, type=bool,
     #                     help='Used for tuning the hyperparameters that can be '
@@ -227,7 +239,9 @@ def main(args):
     )
     pp.pprint(tune_config)
 
-    if "timesteps_total" in dir(config):
+    if args.timesteps_total is not None:
+        timesteps_total = args.timesteps_total
+    elif "timesteps_total" in dir(config):
         timesteps_total = config.timesteps_total
     else:
         timesteps_total = tune_config["timesteps_total"]
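
For context, a minimal self-contained sketch of the precedence this last hunk introduces: the new -t/--timesteps-total flag overrides the experiment config, which in turn overrides tune_config. The argparse call mirrors the diff; the config and tune_config stand-ins below are hypothetical, not the repo's actual objects:

import argparse

parser = argparse.ArgumentParser()
# Same flag as added in the diff above.
parser.add_argument(
    "-t",
    "--timesteps-total",
    dest="timesteps_total",
    action="store",
    default=None,
    type=int,
    help="Total number of env steps to run expt for.",
)
args = parser.parse_args(["-t", "20000"])  # as if passed on the command line

class _Config:  # hypothetical stand-in for the loaded experiment config module
    timesteps_total = 10000

config = _Config()
tune_config = {"timesteps_total": 5000}  # hypothetical fallback value

# Precedence exactly as in the hunk: CLI flag > config file > tune_config.
if args.timesteps_total is not None:
    timesteps_total = args.timesteps_total
elif "timesteps_total" in dir(config):
    timesteps_total = config.timesteps_total
else:
    timesteps_total = tune_config["timesteps_total"]

print(timesteps_total)  # -> 20000, since the CLI flag wins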

tests/test_analysis_code.py

Lines changed: 22 additions & 2 deletions

@@ -1,9 +1,8 @@
-import sys
+import sys, os
 from datetime import datetime
 import logging
 import copy
 import numpy as np
-from mdp_playground.envs.rl_toy_env import RLToyEnv
 import unittest
 import pytest
 
@@ -16,6 +15,24 @@
 
 
 class TestAnalysisCode(unittest.TestCase):
+
+    # ###TODO Enable once scipy can be upgraded
+    # def test_generate_plots(self):
+
+    #     exit_code = os.system(
+    #         sys.executable
+    #         + " misc/generate_mdpp_plots.py -f misc/test_expt_list.txt -n 1"
+    #     )
+    #     assert exit_code == 0
+
+    #     from glob import glob
+    #     plot_list = glob("*.pdf")
+    #     plot_list_exp = ['rainbow_seq_del_train_final_reward_delay_episode_reward_mean_1d.pdf', 'rainbow_seq_del_train_final_reward_sequence_length_episode_reward_mean_1d.pdf', 'dqn_seq_del_train_learning_curves_episode_reward_mean.pdf', 'rainbow_seq_del_train_final_reward_mean_heat_map_episode_reward_mean.pdf', 'dqn_seq_del_train_final_reward_mean_heat_map_episode_reward_mean.pdf', 'dqn_seq_del_train_final_reward_sequence_length_episode_reward_mean_1d.pdf', 'rainbow_seq_del_train_final_reward_std_heat_map_episode_reward_mean.pdf', 'dqn_seq_del_train_final_reward_delay_episode_reward_mean_1d.pdf', 'dqn_seq_del_train_final_reward_std_heat_map_episode_reward_mean.pdf']
+
+    #     import collections
+    #     assert collections.Counter(plot_list) == collections.Counter(plot_list_exp), "Unexpected PDF file found when generating plots. Found:" + str(plot_list)
+
+
     @pytest.mark.skip(
         reason="CAVE dependencies throw ImportError: cannot import name 'StatusType'"
     )
@@ -63,3 +80,6 @@ def test_mdpp_to_cave(self):
         for i in range(2):
             l = fh.readline()
         assert l.strip() == results_json_line_2
+
+if __name__ == "__main__":
+    unittest.main()
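
The disabled test_generate_plots above checks the set of generated PDFs against an expected list, ignoring order, via collections.Counter. A minimal standalone illustration of that comparison idiom, with placeholder file names rather than the repo's expected plot names:

import collections

# Placeholder lists; the real test fills these from glob("*.pdf") and the
# hard-coded expected names shown in the diff above.
plot_list = ["b.pdf", "a.pdf", "a.pdf"]
plot_list_exp = ["a.pdf", "a.pdf", "b.pdf"]

# Counter equality is order-insensitive but, unlike a set comparison,
# still catches missing or duplicated files.
assert collections.Counter(plot_list) == collections.Counter(plot_list_exp), (
    "Unexpected PDF file found when generating plots. Found: " + str(plot_list)
)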

tests/test_run_experiments.py

Lines changed: 28 additions & 8 deletions

@@ -3,6 +3,8 @@
 import sys
 import os
 
+import numpy as np
+
 import logging
 
 from datetime import datetime
@@ -30,7 +32,7 @@ def test_dqn_test_expt(self):
 
         exit_code = os.system(
             sys.executable
-            + " mdp_playground/scripts/run_experiments.py -a 0 -n 0 -c experiments/dqn_test_expt.py -e dqn_test_expt"
+            + " run_experiments.py -n 0 -c experiments/dqn_test_expt.py -e dqn_test_expt"
         )
         assert exit_code == 0
 
@@ -43,7 +45,6 @@ def test_dqn_test_expt(self):
             experiments, load_eval=False, exp_type="grid"
         )
 
-        import numpy as np
 
         final_metrics = np.squeeze(list_exp_data[0]["train_stats"])
         np.testing.assert_allclose(
@@ -56,13 +57,32 @@ def test_dqn_test_expt(self):
         exit_code = os.system("rm dqn_test_expt_0*.csv")
         assert exit_code == 0
 
-    def test_default_config(self):
+    # A similar thing is tested above. These tests are time-consuming, so we'd rather have only a few of them.
+    # def test_default_config(self):
 
-        exit_code = os.system(
-            sys.executable
-            + " mdp_playground/scripts/run_experiments.py -a 0 -n 0 -c default_config.py -e default_config"
-        )
-        assert exit_code == 0
+    #     exit_code = os.system(
+    #         sys.executable
+    #         + " mdp_playground/scripts/run_experiments.py -n 0 -c default_config.py -e default_config"
+    #     )
+    #     assert exit_code == 0
+
+    # ###TODO Enable once branches are merged
+    # def test_10_random_expts(self):
+
+    #     from glob import glob
+    #     expt_list = glob("experiments/*.py")
+
+    #     # sel_expt_list = np.random.randint(0, len(expt_list), 10)
+    #     expt_list = np.random.permutation(expt_list)
+    #     for i in range(2):
+    #         conf_file = expt_list[i]
+    #         exp_name = conf_file.split('/')[-1].split('.')[0]
+
+    #         exit_code = os.system(
+    #             sys.executable
+    #             + " run_experiments.py -n 0 -c " + conf_file + " -e " + exp_name + " -t 2000"
+    #         )
+    #         assert exit_code == 0
 
 
 if __name__ == "__main__":
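
These tests shell out to the runner via os.system(sys.executable + ...). As an editor's sketch, not code from the diff, the same smoke test could use subprocess.run, which avoids manual string concatenation and quoting; the script path and flags mirror test_dqn_test_expt, and -t 2000 caps total env steps via the flag added in this commit (as in the disabled random-expts test):

import subprocess
import sys

# Run the experiment script as a subprocess and check its exit status.
result = subprocess.run(
    [
        sys.executable,
        "run_experiments.py",
        "-n", "0",
        "-c", "experiments/dqn_test_expt.py",
        "-e", "dqn_test_expt",
        "-t", "2000",  # total env steps, via the new --timesteps-total flag
    ]
)
assert result.returncode == 0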
