@@ -176,6 +176,8 @@ class RLToyEnv(gym.Env):
        The externally visible observation space for the environment.
    action_space : Gym.Space
        The externally visible action space for the environment.
+   feature_space : Gym.Space
+       For continuous and grid environments, this is the underlying state space. ##TODO Unify this across all types of environments.
    rewardable_sequences : dict
        Holds the rewardable sequences. The keys are tuples of rewardable sequences and the values are the rewards handed out. When make_denser is True for discrete environments, this dict also holds the rewardable partial sequences.

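For illustration, a minimal sketch of the rewardable_sequences layout described above; the state tuples and reward values here are made up:

# Hypothetical example of the dict documented above: keys are tuples of
# rewardable (partial) sequences of states, values are the rewards handed out.
rewardable_sequences = {
    (0, 2, 5): 1.0,  # a full rewardable sequence
    (0, 2): 0.5,     # a partial sequence, present when make_denser is True
}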
@@ -519,7 +521,6 @@ def __init__(self, **config):
        elif config["state_space_type"] == "grid":
            assert "grid_shape" in config
            self.grid_shape = config["grid_shape"]
-            self.grid_np_data_type = np.int64
        else:
            raise ValueError("Unknown state_space_type")
@@ -546,9 +547,9 @@ def __init__(self, **config):
        else:
            self.repeats_in_sequences = config["repeats_in_sequences"]

-        self.dtype = np.float32 if "dtype" not in config else config["dtype"]

        if config["state_space_type"] == "discrete":
+            self.dtype_s = np.int64 if "dtype_s" not in config else config["dtype_s"]
            if self.irrelevant_features:
                assert (
                    len(config["action_space_size"]) == 2
@@ -570,6 +571,7 @@ def __init__(self, **config):
            )
            # assert (np.array(self.state_space_size) % np.array(self.diameter) == 0).all(), "state_space_size should be a multiple of the diameter to allow for the generation of regularly connected MDPs."
        elif config["state_space_type"] == "continuous":
+            self.dtype_s = np.float32 if "dtype_s" not in config else config["dtype_s"]
            self.action_space_dim = self.state_space_dim
            if self.irrelevant_features:
                assert (
@@ -580,10 +582,18 @@ def __init__(self, **config):
            config["relevant_indices"] = range(self.state_space_dim)
            # config["irrelevant_indices"] = list(set(range(len(config["state_space_dim"]))) - set(config["relevant_indices"]))
        elif config["state_space_type"] == "grid":
+            self.dtype_s = np.int64 if "dtype_s" not in config else config["dtype_s"]
            # Repeat the grid for the irrelevant part as well
            if self.irrelevant_features:
                self.grid_shape = self.grid_shape * 2

+        # Set the dtype for the observation space:
+        if self.image_representations:
+            self.dtype_o = np.float32 if "dtype_o" not in config else config["dtype_o"]
+        else:
+            self.dtype_o = self.dtype_s if "dtype_o" not in config else config["dtype_o"]
+
+
        if ("init_state_dist" in config) and ("relevant_init_state_dist" not in config):
            config["relevant_init_state_dist"] = config["init_state_dist"]
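To illustrate the new options, a hypothetical config snippet; only "dtype_s" and "dtype_o" come from this diff, the remaining keys and values are illustrative:

import numpy as np

config = {
    "state_space_type": "continuous",
    "image_representations": False,
    "dtype_s": np.float64,  # dtype of the underlying state; defaults to np.float32 for continuous envs
    "dtype_o": np.float64,  # dtype of observations; defaults to dtype_s when not using image representations
}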
@@ -614,7 +624,7 @@ def __init__(self, **config):
        assert self.sequence_length == 1
        if "target_point" in config:
            self.target_point = np.array(
-                config["target_point"], dtype=self.dtype
+                config["target_point"], dtype=self.dtype_s
            )
            assert self.target_point.shape == (
                len(config["relevant_indices"]),
@@ -640,6 +650,7 @@ def __init__(self, **config):
            DiscreteExtended(
                self.state_space_size[0],
                seed=self.seed_dict["relevant_state_space"],
+                # dtype=self.dtype_o,  # Gymnasium seems to hardcode this as np.int64
            )
        ]  # #seed #hardcoded, many times below as well
        self.action_spaces = [
@@ -671,7 +682,7 @@ def __init__(self, **config):
        # self.action_spaces[i] = DiscreteExtended(self.action_space_size[i],
        #                             seed=self.seed_dict["irrelevant_action_space"]) #seed

-        if self.image_representations:
+        if self.image_representations:  # for discrete envs
            # underlying_obs_space = MultiDiscreteExtended(self.state_space_size, seed=self.seed_dict["state_space"]) #seed
            self.observation_space = ImageMultiDiscrete(
                self.state_space_size,
@@ -714,7 +725,7 @@ def __init__(self, **config):
                self.state_space_max,
                shape=(self.state_space_dim,),
                seed=self.seed_dict["state_space"],
-                dtype=self.dtype,
+                dtype=self.dtype_s,
            )  # #seed
            # hack #TODO # low and high are the first two (and required) arguments
            # for instantiating BoxExtended
@@ -729,7 +740,7 @@ def __init__(self, **config):
                self.action_space_max,
                shape=(self.action_space_dim,),
                seed=self.seed_dict["action_space"],
-                dtype=self.dtype,
+                dtype=self.dtype_s,
            )  # #seed
            # hack #TODO
@@ -754,7 +765,7 @@ def __init__(self, **config):
            0 * underlying_space_maxes,
            underlying_space_maxes,
            seed=self.seed_dict["state_space"],
-            dtype=self.dtype,
+            dtype=self.dtype_s,
        )  # #seed

        lows = np.array([-1] * len(self.grid_shape))
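As a quick aside, a small sketch of why the dtype passed to these Box-like spaces matters: samples come back in the configured dtype. Shown with gymnasium's standard Box; the bounds are made up:

import numpy as np
from gymnasium.spaces import Box

# A Box space samples arrays of its configured dtype, so passing an integer
# dtype_s yields integer samples for grid-like feature spaces.
int_box = Box(low=0, high=7, shape=(2,), dtype=np.int64)
float_box = Box(low=0.0, high=7.0, shape=(2,), dtype=np.float32)
print(int_box.sample().dtype)    # int64
print(float_box.sample().dtype)  # float32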
@@ -893,7 +904,7 @@ def init_terminal_states(self):
            # print("Term state lows, highs:", lows, highs)
            self.term_spaces.append(
                BoxExtended(
-                    low=lows, high=highs, seed=self.seed_, dtype=self.dtype
+                    low=lows, high=highs, seed=self.seed_, dtype=self.dtype_s
                )
            )  # #seed #hack #TODO
            self.logger.debug(
@@ -931,7 +942,7 @@ def init_terminal_states(self):
                highs = term_state  # #hardcoded
                self.term_spaces.append(
                    BoxExtended(
-                        low=lows, high=highs, seed=self.seed_, dtype=self.grid_np_data_type
+                        low=lows, high=highs, seed=self.seed_, dtype=self.dtype_s
                    )
                )  # #seed #hack #TODO
@@ -1657,7 +1668,7 @@ def transition_function(self, state, action):
            # for a "wall", but would need to take care of multiple
            # reflections near a corner/edge.
            # Resets all higher order derivatives to 0
-            zero_state = np.array([0.0] * (self.state_space_dim), dtype=self.dtype)
+            zero_state = np.array([0.0] * (self.state_space_dim), dtype=self.dtype_s)
            # #####IMP to have copy() otherwise it's the same array
            # (in memory) at every position in the list:
            self.state_derivatives = [
@@ -1666,7 +1677,7 @@ def transition_function(self, state, action):
            self.state_derivatives[0] = next_state

        if self.config["reward_function"] == "move_to_a_point":
-            next_state_rel = np.array(next_state, dtype=self.dtype)[
+            next_state_rel = np.array(next_state, dtype=self.dtype_s)[
                self.config["relevant_indices"]
            ]
            dist_ = np.linalg.norm(next_state_rel - self.target_point)
@@ -1678,7 +1689,7 @@ def transition_function(self, state, action):
            # Need to check that dtype is int because Gym doesn't
            if (
                self.action_space.contains(action)
-                and np.array(action).dtype == self.grid_np_data_type
+                and np.array(action).dtype == self.dtype_s
            ):
                if self.transition_noise:
                    # self._np_random.choice only works for 1-D arrays
@@ -1820,7 +1831,7 @@ def reward_function(self, state, action):
        # of the formulae and see that programmatic results match: should
        # also have a unit version of 4. for dist_of_pt_from_line() and
        # an integration version here for total_deviation calc.?
-        data_ = np.array(state_considered, dtype=self.dtype)[
+        data_ = np.array(state_considered, dtype=self.dtype_s)[
            1 + delay : self.augmented_state_length,
            self.config["relevant_indices"],
        ]
@@ -1863,10 +1874,10 @@ def reward_function(self, state, action):
                # that. #TODO Generate it randomly to have random Rs?
                if self.make_denser:
                    old_relevant_state = np.array(
-                        state_considered, dtype=self.dtype
+                        state_considered, dtype=self.dtype_s
                    )[-2, self.config["relevant_indices"]]
                    new_relevant_state = np.array(
-                        state_considered, dtype=self.dtype
+                        state_considered, dtype=self.dtype_s
                    )[-1, self.config["relevant_indices"]]
                    reward = -np.linalg.norm(new_relevant_state - self.target_point)
                    # Should allow other powers of the distance from target_point,
@@ -1879,7 +1890,7 @@ def reward_function(self, state, action):
                # TODO also make_denser, sparse rewards only at target
                else:  # sparse reward
                    new_relevant_state = np.array(
-                        state_considered, dtype=self.dtype
+                        state_considered, dtype=self.dtype_s
                    )[-1, self.config["relevant_indices"]]
                    if (
                        np.linalg.norm(new_relevant_state - self.target_point)
@@ -1890,7 +1901,7 @@ def reward_function(self, state, action):
                # stay in the radius and earn more reward.

            reward -= self.action_loss_weight * np.linalg.norm(
-                np.array(action, dtype=self.dtype)
+                np.array(action, dtype=self.dtype_s)
            )

        elif self.config["state_space_type"] == "grid":
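For intuition, a worked numpy example of the dense "move_to_a_point" reward computed above; all numbers are made up:

import numpy as np

# Dense reward: negative distance to the target, minus a penalty
# proportional to the norm of the action, as in reward_function above.
target_point = np.array([0.0, 0.0], dtype=np.float32)
new_relevant_state = np.array([3.0, 4.0], dtype=np.float32)
action = np.array([0.5, -0.5], dtype=np.float32)
action_loss_weight = 0.1  # hypothetical value for the env attribute

reward = -np.linalg.norm(new_relevant_state - target_point)  # -5.0
reward -= action_loss_weight * np.linalg.norm(action)  # -5.0 - 0.1 * sqrt(0.5)
print(reward)  # approximately -5.0707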
@@ -2044,8 +2055,8 @@ def step(self, action, imaginary_rollout=False):
        if self.image_representations:
            next_obs = self.observation_space.get_concatenated_image(next_state)

-        self.curr_state = next_state
-        self.curr_obs = next_obs
+        self.curr_state = self.dtype_s(next_state)
+        self.curr_obs = self.dtype_o(next_obs)

        # #### TODO curr_state is external state, while we need to check relevant state for terminality! Done - by using augmented_state now instead of curr_state!
        self.done = (
@@ -2199,7 +2210,7 @@ def reset(self, seed=None):

        # if not self.use_custom_mdp:
        # init the state derivatives needed for continuous spaces
-        zero_state = np.array([0.0] * (self.state_space_dim), dtype=self.dtype)
+        zero_state = np.array([0.0] * (self.state_space_dim), dtype=self.dtype_s)
        self.state_derivatives = [
            zero_state.copy() for i in range(self.dynamics_order + 1)
        ]  # #####IMP to have copy()
@@ -2217,7 +2228,7 @@ def reset(self, seed=None):
        while True:  # Be careful about infinite loops
            term_space_was_sampled = False
            # curr_state is an np.array while curr_state_relevant is a list
-            self.curr_state = self.feature_space.sample().astype(int)  # #random
+            self.curr_state = self.feature_space.sample().astype(self.dtype_s)  # #random
            self.curr_state_relevant = list(self.curr_state[[0, 1]])  # #hardcoded
            if self.is_terminal_state(self.curr_state_relevant):
                self.logger.debug(
@@ -2241,6 +2252,9 @@ def reset(self, seed=None):
        else:
            self.curr_obs = self.curr_state

+        self.curr_state = self.dtype_s(self.curr_state)
+        self.curr_obs = self.dtype_o(self.curr_obs)
+
        self.logger.info("RESET called. curr_state reset to: " + str(self.curr_state))
        self.reached_terminal = False