Skip to content

Commit b3a09fd

Browse files
Added timesteps_total to in run_experiments.py; some testing code
1 parent 1645f4e commit b3a09fd

File tree

2 files changed

+170
-0
lines changed

2 files changed

+170
-0
lines changed

‎misc/generate_mdpp_plots.py

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
# Examples:
2+
# py generate_mdpp_plots.py -f expt_list.txt
3+
# py generate_mdpp_plots.py --dir-name 13699485 --exp-name dqn_del # --show-plots # dir_name and exp_name
4+
# Setup to analyse an MDP Playground experiment
5+
from mdp_playground.analysis import MDPP_Analysis
6+
7+
import yaml
8+
import argparse
9+
10+
from collections import Counter
11+
12+
# Based on https://stackoverflow.com/a/71751051/11063709, to allow keys to have a list of values
13+
# in case duplicate keys are present in the YAML.
14+
def parse_preserving_duplicates(src):
15+
class PreserveDuplicatesLoader(yaml.loader.Loader):
16+
pass
17+
18+
def map_constructor(loader, node, deep=False):
19+
keys = [loader.construct_object(node, deep=deep) for node, _ in node.value]
20+
vals = [loader.construct_object(node, deep=deep) for _, node in node.value]
21+
key_count = Counter(keys)
22+
data = {}
23+
for key, val in zip(keys, vals):
24+
if key_count[key] > 1:
25+
if key not in data:
26+
data[key] = []
27+
data[key].append(val)
28+
else:
29+
data[key] = [val]
30+
return data
31+
32+
PreserveDuplicatesLoader.add_constructor(yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, map_constructor)
33+
return yaml.load(src, PreserveDuplicatesLoader)
34+
35+
36+
def generate_plots(exp_name, dir_name, show_plots=False, options=''):
37+
print("Generating plots for " + str(dir_name) + ": " + exp_name + " with the following addnl. options: " + options)
38+
39+
# Set dir_name to the location where the CSV files from running an experiment were saved
40+
dir_name = str(dir_name) # e.g. 13699485
41+
# Set exp_name to the name that was given to the experiment when running it
42+
# exp_name = 'dqn_del'
43+
# Set the following to True to show plots that you generate below
44+
# show_plots = True
45+
# Set the following to True to save PDFs of plots that you generate below
46+
save_fig = True
47+
err_bar = 'bootstrap' # 't_dist', 'std'
48+
bonferroni = True
49+
if 'normalise_episodic_reward' in options:
50+
normalise_episodic_reward = True
51+
else:
52+
normalise_episodic_reward = False
53+
if 'eval' in options:
54+
load_eval = True
55+
else:
56+
load_eval = False
57+
if 'auto_y_scale' in options:
58+
common_y_scale = False
59+
else:
60+
common_y_scale = True
61+
62+
# Data loading
63+
mdpp_analysis = MDPP_Analysis()
64+
train_stats, eval_stats, train_curves, eval_curves, train_aucs, eval_aucs = mdpp_analysis.load_data(dir_name, exp_name, load_eval=load_eval, normalise_episodic_reward=normalise_episodic_reward)
65+
66+
# 1-D: Plots showing reward after total timesteps when varying a single meta-feature
67+
# Plots across n runs: Training: with std dev across the runs
68+
mdpp_analysis.plot_1d_dimensions(train_aucs, save_fig, bonferroni=bonferroni, err_bar=err_bar, show_plots=show_plots, common_y_scale=common_y_scale)
69+
70+
if 'ep_len' in options:
71+
mdpp_analysis.plot_1d_dimensions(train_aucs, save_fig, bonferroni=bonferroni, err_bar=err_bar, show_plots=show_plots, metric_num=-1)
72+
73+
# 2-D heatmap plots across n runs: Training runs: with std dev across the runs
74+
# There seems to be a bug with matplotlib - x and y axes tick labels are not correctly set even though we pass them. Please feel free to look into the code and suggest a correction if you find it.
75+
if 'plot_2d' in options:
76+
mdpp_analysis.plot_2d_heatmap(train_aucs, save_fig, show_plots=show_plots, common_y_scale=common_y_scale)
77+
78+
if 'ep_len' in options:
79+
mdpp_analysis.plot_2d_heatmap(train_aucs, save_fig, show_plots=show_plots, common_y_scale=common_y_scale, metric_num=-1)
80+
81+
# Plot learning curves: Training: Each curve corresponds to a different seed for the agent
82+
if 'learn_curves' in options:
83+
mdpp_analysis.plot_learning_curves(train_curves, save_fig, show_plots=show_plots, common_y_scale=common_y_scale)
84+
85+
if 'eval' in options:
86+
mdpp_analysis.plot_1d_dimensions(eval_aucs, save_fig, bonferroni=bonferroni, err_bar=err_bar, show_plots=show_plots, common_y_scale=common_y_scale, train=False)
87+
88+
if 'ep_len' in options:
89+
mdpp_analysis.plot_1d_dimensions(eval_aucs, save_fig, bonferroni=bonferroni, err_bar=err_bar, show_plots=show_plots, metric_num=-1, train=False)
90+
91+
if 'plot_2d' in options:
92+
mdpp_analysis.plot_2d_heatmap(eval_aucs, save_fig, show_plots=show_plots, common_y_scale=common_y_scale, train=False)
93+
94+
if 'ep_len' in options:
95+
mdpp_analysis.plot_2d_heatmap(eval_aucs, save_fig, show_plots=show_plots, common_y_scale=common_y_scale, metric_num=-1, train=False)
96+
97+
# Plot learning curves: Training: Each curve corresponds to a different seed for the agent
98+
if 'learn_curves' in options:
99+
mdpp_analysis.plot_learning_curves(eval_curves, save_fig, show_plots=show_plots, common_y_scale=common_y_scale, train=False)
100+
101+
102+
if __name__ == "__main__":
103+
104+
105+
parser = argparse.ArgumentParser(description="Process Latex .bib files")
106+
107+
parser.add_argument(
108+
"--exp-file", "-f", type=str, help="Expt. identifiers and names listed in a YAML file, i.e., dir_name: exp_name",
109+
)
110+
111+
parser.add_argument(
112+
"--dir-name", "-d", type=str, help="dir name where expt. CSVs are stored"
113+
)
114+
115+
parser.add_argument(
116+
"--exp-name", "-e", type=str, help="expt name, corresponds to the names of the CSV stats files and the <config>.py file used for the expt."
117+
)
118+
119+
parser.add_argument(
120+
"--show-plots", "-p", action='store_true', dest='show_plots', help="Toggle displaying plots", default=False,
121+
)
122+
123+
parser.add_argument(
124+
"--num-expts", "-n", type=int, help="First n expts in the list are plotted"
125+
)
126+
127+
args = parser.parse_args()
128+
129+
# print(args)
130+
131+
if args.exp_file is not None:
132+
with open(args.exp_file) as f:
133+
yaml_dict = parse_preserving_duplicates(f) # yaml.safe_load(f)
134+
135+
print("List of expts.:", yaml_dict)
136+
137+
i = 0
138+
for dir_name in yaml_dict:
139+
if len(yaml_dict[dir_name]) > 1:
140+
print("More than 1 expt. for the same expt_id:", dir_name, ". The expts.:", yaml_dict[dir_name])
141+
for j in range(len(yaml_dict[dir_name])):
142+
i += 1
143+
print("\nExpt. no.:", i , "from the list.")
144+
145+
exp_name = yaml_dict[dir_name][j].split(' ')[0]
146+
options = ' '.join(yaml_dict[dir_name][j].split(' ')[1:]) if ' ' in yaml_dict[dir_name][j] else ''
147+
# if 'learn_curves' in options:
148+
# if 'breakout_r' in exp_name:
149+
generate_plots(dir_name=dir_name, exp_name=exp_name, show_plots=args.show_plots, options=options)
150+
151+
# Need to break out of 2 for loops
152+
if args.num_expts is not None and i == args.num_expts:
153+
break
154+
155+
if args.num_expts is not None and i == args.num_expts:
156+
break
157+
158+
159+
else:
160+
dict_args = vars(args)
161+
del dict_args['exp_file']
162+
dict_args['dir_name'] = dict_args['dir_name'].split(' ')[0]
163+
dict_args['options'] = ' '.join(dict_args['dir_name'].split(' ')[1:]) if ' ' in dict_args['dir_name'] else ''
164+
# print(dict_args)
165+
generate_plots(**dict_args)
166+
167+

‎misc/test_expt_list.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
/home/rajanr/mdpp_15786014: dqn_seq_del learn_curves normalise_episodic_reward plot_2d auto_y_scale
2+
/home/rajanr/mdpp_15786214: rainbow_seq_del normalise_episodic_reward plot_2d auto_y_scale
3+
/home/rajanr/mdpp_15786316: a3c_seq_del normalise_episodic_reward plot_2d auto_y_scale

0 commit comments

Comments
 (0)