
Commit ce5aa6a

Allow dtype_s and dtype_o of toy envs to be set for the underlying state space and observation space, respectively (action_space is currently given the same dtype as the state space); partially fix some test cases.
1 parent 49551d2 commit ce5aa6a

File tree

mdp_playground/envs/rl_toy_env.py
tests/test_mdp_playground.py

2 files changed: +44 -29 lines changed
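For orientation, a minimal sketch (not part of the commit) of how the two new config keys introduced here might be supplied; per the diff, "dtype_s" controls the underlying state space (and the action space, which currently reuses it) and "dtype_o" controls the externally visible observation space:

# Hypothetical config fragment; merge into an existing RLToyEnv config dict.
# The key names "dtype_s" and "dtype_o" come from this commit; the chosen dtypes are arbitrary examples.
import numpy as np

config_overrides = {
    "dtype_s": np.int64,    # underlying state space / feature_space dtype (defaults per this commit: int64 for discrete and grid, float32 for continuous)
    "dtype_o": np.float32,  # observation space dtype (default: float32 with image_representations, else same as dtype_s)
}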

mdp_playground/envs/rl_toy_env.py

Lines changed: 35 additions & 21 deletions
@@ -176,6 +176,8 @@ class RLToyEnv(gym.Env):
         The externally visible observation space for the enviroment.
     action_space : Gym.Space
         The externally visible action space for the enviroment.
+    feature_space : Gym.Space
+        In case of continuous and grid environments, this is the underlying state space. ##TODO Unify this across all types of environments.
     rewardable_sequences : dict
         holds the rewardable sequences. The keys are tuples of rewardable sequences and values are the rewards handed out. When make_denser is True for discrete environments, this dict also holds the rewardable partial sequences.
@@ -519,7 +521,6 @@ def __init__(self, **config):
         elif config["state_space_type"] == "grid":
             assert "grid_shape" in config
             self.grid_shape = config["grid_shape"]
-            self.grid_np_data_type = np.int64
         else:
             raise ValueError("Unknown state_space_type")

@@ -546,9 +547,9 @@ def __init__(self, **config):
         else:
             self.repeats_in_sequences = config["repeats_in_sequences"]

-        self.dtype = np.float32 if "dtype" not in config else config["dtype"]

         if config["state_space_type"] == "discrete":
+            self.dtype_s = np.int64 if "dtype_s" not in config else config["dtype_s"]
             if self.irrelevant_features:
                 assert (
                     len(config["action_space_size"]) == 2
@@ -570,6 +571,7 @@ def __init__(self, **config):
                 )
         # assert (np.array(self.state_space_size) % np.array(self.diameter) == 0).all(), "state_space_size should be a multiple of the diameter to allow for the generation of regularly connected MDPs."
         elif config["state_space_type"] == "continuous":
+            self.dtype_s = np.float32 if "dtype_s" not in config else config["dtype_s"]
             self.action_space_dim = self.state_space_dim
             if self.irrelevant_features:
                 assert (
@@ -580,10 +582,18 @@ def __init__(self, **config):
                 config["relevant_indices"] = range(self.state_space_dim)
             # config["irrelevant_indices"] = list(set(range(len(config["state_space_dim"]))) - set(config["relevant_indices"]))
         elif config["state_space_type"] == "grid":
+            self.dtype_s = np.int64 if "dtype_s" not in config else config["dtype_s"]
             # Repeat the grid for the irrelevant part as well
             if self.irrelevant_features:
                 self.grid_shape = self.grid_shape * 2

+        # Set the dtype for the observation space:
+        if self.image_representations:
+            self.dtype_o = np.float32 if "dtype_o" not in config else config["dtype_o"]
+        else:
+            self.dtype_o = self.dtype_s if "dtype_o" not in config else config["dtype_o"]
+
+
         if ("init_state_dist" in config) and ("relevant_init_state_dist" not in config):
             config["relevant_init_state_dist"] = config["init_state_dist"]

@@ -614,7 +624,7 @@ def __init__(self, **config):
                 assert self.sequence_length == 1
                 if "target_point" in config:
                     self.target_point = np.array(
-                        config["target_point"], dtype=self.dtype
+                        config["target_point"], dtype=self.dtype_s
                     )
                     assert self.target_point.shape == (
                         len(config["relevant_indices"]),
@@ -640,6 +650,7 @@ def __init__(self, **config):
                 DiscreteExtended(
                     self.state_space_size[0],
                     seed=self.seed_dict["relevant_state_space"],
+                    # dtype=self.dtype_o,  # Gymnasium seems to hardcode as np.int64
                 )
             ]  # #seed #hardcoded, many time below as well
             self.action_spaces = [
@@ -671,7 +682,7 @@ def __init__(self, **config):
             # self.action_spaces[i] = DiscreteExtended(self.action_space_size[i],
             #     seed=self.seed_dict["irrelevant_action_space"]) #seed

-            if self.image_representations:
+            if self.image_representations:  # for discrete envs
                 # underlying_obs_space = MultiDiscreteExtended(self.state_space_size, seed=self.seed_dict["state_space"]) #seed
                 self.observation_space = ImageMultiDiscrete(
                     self.state_space_size,
@@ -714,7 +725,7 @@ def __init__(self, **config):
                 self.state_space_max,
                 shape=(self.state_space_dim,),
                 seed=self.seed_dict["state_space"],
-                dtype=self.dtype,
+                dtype=self.dtype_s,
             )  # #seed
             # hack #TODO # low and high are 1st 2 and required arguments
             # for instantiating BoxExtended
@@ -729,7 +740,7 @@ def __init__(self, **config):
                 self.action_space_max,
                 shape=(self.action_space_dim,),
                 seed=self.seed_dict["action_space"],
-                dtype=self.dtype,
+                dtype=self.dtype_s,
             )  # #seed
             # hack #TODO

@@ -754,7 +765,7 @@ def __init__(self, **config):
                 0 * underlying_space_maxes,
                 underlying_space_maxes,
                 seed=self.seed_dict["state_space"],
-                dtype=self.dtype,
+                dtype=self.dtype_s,
             )  # #seed

             lows = np.array([-1] * len(self.grid_shape))
@@ -893,7 +904,7 @@ def init_terminal_states(self):
                 # print("Term state lows, highs:", lows, highs)
                 self.term_spaces.append(
                     BoxExtended(
-                        low=lows, high=highs, seed=self.seed_, dtype=self.dtype
+                        low=lows, high=highs, seed=self.seed_, dtype=self.dtype_s
                     )
                 )  # #seed #hack #TODO
                 self.logger.debug(
@@ -931,7 +942,7 @@ def init_terminal_states(self):
                 highs = term_state  # #hardcoded
                 self.term_spaces.append(
                     BoxExtended(
-                        low=lows, high=highs, seed=self.seed_, dtype=self.grid_np_data_type
+                        low=lows, high=highs, seed=self.seed_, dtype=self.dtype_s
                     )
                 )  # #seed #hack #TODO

@@ -1657,7 +1668,7 @@ def transition_function(self, state, action):
                 # for a "wall", but would need to take care of multiple
                 # reflections near a corner/edge.
                 # Resets all higher order derivatives to 0
-                zero_state = np.array([0.0] * (self.state_space_dim), dtype=self.dtype)
+                zero_state = np.array([0.0] * (self.state_space_dim), dtype=self.dtype_s)
                 # #####IMP to have copy() otherwise it's the same array
                 # (in memory) at every position in the list:
                 self.state_derivatives = [
@@ -1666,7 +1677,7 @@ def transition_function(self, state, action):
             self.state_derivatives[0] = next_state

             if self.config["reward_function"] == "move_to_a_point":
-                next_state_rel = np.array(next_state, dtype=self.dtype)[
+                next_state_rel = np.array(next_state, dtype=self.dtype_s)[
                     self.config["relevant_indices"]
                 ]
                 dist_ = np.linalg.norm(next_state_rel - self.target_point)
@@ -1678,7 +1689,7 @@ def transition_function(self, state, action):
             # Need to check that dtype is int because Gym doesn't
             if (
                 self.action_space.contains(action)
-                and np.array(action).dtype == self.grid_np_data_type
+                and np.array(action).dtype == self.dtype_s
             ):
                 if self.transition_noise:
                     # self._np_random.choice only works for 1-D arrays
@@ -1820,7 +1831,7 @@ def reward_function(self, state, action):
                 # of the formulae and see that programmatic results match: should
                 # also have a unit version of 4. for dist_of_pt_from_line() and
                 # an integration version here for total_deviation calc.?.
-                data_ = np.array(state_considered, dtype=self.dtype)[
+                data_ = np.array(state_considered, dtype=self.dtype_s)[
                     1 + delay : self.augmented_state_length,
                     self.config["relevant_indices"],
                 ]
@@ -1863,10 +1874,10 @@ def reward_function(self, state, action):
                 # that. #TODO Generate it randomly to have random Rs?
                 if self.make_denser:
                     old_relevant_state = np.array(
-                        state_considered, dtype=self.dtype
+                        state_considered, dtype=self.dtype_s
                     )[-2, self.config["relevant_indices"]]
                     new_relevant_state = np.array(
-                        state_considered, dtype=self.dtype
+                        state_considered, dtype=self.dtype_s
                     )[-1, self.config["relevant_indices"]]
                     reward = -np.linalg.norm(new_relevant_state - self.target_point)
                     # Should allow other powers of the distance from target_point,
@@ -1879,7 +1890,7 @@ def reward_function(self, state, action):
                 # TODO also make_denser, sparse rewards only at target
                 else:  # sparse reward
                     new_relevant_state = np.array(
-                        state_considered, dtype=self.dtype
+                        state_considered, dtype=self.dtype_s
                     )[-1, self.config["relevant_indices"]]
                     if (
                         np.linalg.norm(new_relevant_state - self.target_point)
@@ -1890,7 +1901,7 @@ def reward_function(self, state, action):
                 # stay in the radius and earn more reward.

             reward -= self.action_loss_weight * np.linalg.norm(
-                np.array(action, dtype=self.dtype)
+                np.array(action, dtype=self.dtype_s)
             )

         elif self.config["state_space_type"] == "grid":
@@ -2044,8 +2055,8 @@ def step(self, action, imaginary_rollout=False):
         if self.image_representations:
             next_obs = self.observation_space.get_concatenated_image(next_state)

-        self.curr_state = next_state
-        self.curr_obs = next_obs
+        self.curr_state = self.dtype_s(next_state)
+        self.curr_obs = self.dtype_o(next_obs)

         # #### TODO curr_state is external state, while we need to check relevant state for terminality! Done - by using augmented_state now instead of curr_state!
         self.done = (
@@ -2199,7 +2210,7 @@ def reset(self, seed=None):

             # if not self.use_custom_mdp:
             # init the state derivatives needed for continuous spaces
-            zero_state = np.array([0.0] * (self.state_space_dim), dtype=self.dtype)
+            zero_state = np.array([0.0] * (self.state_space_dim), dtype=self.dtype_s)
             self.state_derivatives = [
                 zero_state.copy() for i in range(self.dynamics_order + 1)
             ]  # #####IMP to have copy()
@@ -2217,7 +2228,7 @@ def reset(self, seed=None):
             while True:  # Be careful about infinite loops
                 term_space_was_sampled = False
                 # curr_state is an np.array while curr_state_relevant is a list
-                self.curr_state = self.feature_space.sample().astype(int)  # #random
+                self.curr_state = self.feature_space.sample().astype(self.dtype_s)  # #random
                 self.curr_state_relevant = list(self.curr_state[[0, 1]])  # #hardcoded
                 if self.is_terminal_state(self.curr_state_relevant):
                     self.logger.debug(
@@ -2241,6 +2252,9 @@ def reset(self, seed=None):
         else:
             self.curr_obs = self.curr_state

+        self.curr_state = self.dtype_s(self.curr_state)
+        self.curr_obs = self.dtype_o(self.curr_obs)
+
         self.logger.info("RESET called. curr_state reset to: " + str(self.curr_state))
         self.reached_terminal = False

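Taken together, the dtype plumbing above amounts to the following defaulting rules. This is an illustrative standalone sketch of that logic, not repository code (the function name resolve_dtypes is mine):

import numpy as np

def resolve_dtypes(config):
    # Mirrors the defaults added in this commit: dtype_s is int64 for discrete/grid
    # state spaces and float32 for continuous ones; dtype_o is float32 when
    # image_representations is on, otherwise it falls back to dtype_s.
    if config["state_space_type"] in ("discrete", "grid"):
        dtype_s = config.get("dtype_s", np.int64)
    else:  # "continuous"
        dtype_s = config.get("dtype_s", np.float32)

    if config.get("image_representations", False):
        dtype_o = config.get("dtype_o", np.float32)
    else:
        dtype_o = config.get("dtype_o", dtype_s)
    return dtype_s, dtype_o

# Example: a grid env with image observations keeps integer states but float32 observations.
print(resolve_dtypes({"state_space_type": "grid", "image_representations": True}))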
tests/test_mdp_playground.py

Lines changed: 9 additions & 8 deletions
@@ -173,7 +173,7 @@ def test_continuous_dynamics_move_along_a_line(self):
        # Test 5: R noise - same as Test 1 above except with reward noise and with only 5 steps
        # instead of 20.
        print("\nTest 5: \033[32;1;4mTEST_CONTINUOUS_DYNAMICS_R_NOISE\033[0m")
-        config["reward_noise"] = lambda a: a.normal(0, 0.5)
+        config["reward_noise"] = lambda s, a, rng: rng.normal(0, 0.5)
        config["delay"] = 0
        env = RLToyEnv(**config)
        state = env.get_augmented_state()["curr_state"].copy()  # env.reset()[0]
@@ -303,7 +303,7 @@ def test_continuous_dynamics_move_along_a_line(self):

        # Test P noise
        print("\nTest 9: \033[32;1;4mTEST_CONTINUOUS_DYNAMICS_P_NOISE\033[0m")
-        config["transition_noise"] = lambda a: a.normal([0] * 7, [0.5] * 7)
+        config["transition_noise"] = lambda s, a, rng: rng.normal([0] * 7, [0.5] * 7)
        # Reset seed to have states far away from state maxes so that it is easier to
        # test stuff below, but in the end, the state is clipped to [-5, 5] anyway
        # while testing, so this wasn't really needed.
@@ -1243,9 +1243,10 @@ def test_discrete_dynamics(self):
        config["generate_random_mdp"] = True
        env = RLToyEnv(**config)
        state = env.get_augmented_state()["curr_state"]
-        self.assertEqual(
-            type(state), int, "Type of discrete state should be int."
-        )  # TODO Move this and the test_continuous_dynamics type checks to separate unit tests
+        if type(state) != int:
+            self.assertEqual(
+                state.dtype, env.observation_space.dtype, "Type of discrete state should be: " + str(env.observation_space.dtype)
+            )  # TODO Move this and the test_continuous_dynamics type checks to separate unit tests

        action = 2
        next_state, reward, done, trunc, info = env.step(action)
@@ -1482,7 +1483,7 @@ def test_discrete_r_noise(self):
        config["delay"] = 0
        config["sequence_length"] = 1
        config["reward_scale"] = 1.0
-        config["reward_noise"] = lambda a: a.normal(0, 0.5)
+        config["reward_noise"] = lambda s, a, rng: rng.normal(0, 0.5)

        config["generate_random_mdp"] = True
        config["log_level"] = logging.INFO
@@ -1545,7 +1546,7 @@ def test_discrete_multiple_meta_features(self):
        config["reward_scale"] = 2.5
        config["reward_shift"] = -1.75
        # config["transition_noise"] = 0.1
-        config["reward_noise"] = lambda a: a.normal(0, 0.5)
+        config["reward_noise"] = lambda s, a, rng: rng.normal(0, 0.5)

        config["generate_random_mdp"] = True
        env = RLToyEnv(**config)
@@ -1804,7 +1805,7 @@ def test_discrete_image_representations(self):
        config["reward_scale"] = 2.5
        config["reward_shift"] = -1.75
        # config["transition_noise"] = 0.1
-        config["reward_noise"] = lambda a: a.normal(0, 0.5)
+        config["reward_noise"] = lambda s, a, rng: rng.normal(0, 0.5)

        config["generate_random_mdp"] = True
