Skip to content

Commit 071dffe

Browse files
Removed self.P and self.R since they caused issues with copy.deepcopy()
1 parent 3411bf7 commit 071dffe

File tree

2 files changed

+9
-7
lines changed

2 files changed

+9
-7
lines changed

mdp_playground/envs/rl_toy_env.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -557,7 +557,8 @@ def __init__(self, **config):
557557

558558
# ##TODO Move these to the individual env types' defaults section above?
559559
if config["state_space_type"] == "discrete":
560-
self.dtype_s = np.int32 if "dtype_s" not in config else config["dtype_s"]
560+
# I think Gymnasium wants int64, but Dreamer-V3 (~June 2024) code may prefer int32
561+
self.dtype_s = np.int64 if "dtype_s" not in config else config["dtype_s"]
561562
if self.irrelevant_features:
562563
assert (
563564
len(config["action_space_size"]) == 2
@@ -1226,7 +1227,8 @@ def init_transition_function(self):
12261227
# fixed parameterisation for cont. envs. right now.
12271228
pass
12281229

1229-
self.P = lambda s, a: self.transition_function(s, a)
1230+
# ####IMP Keep this commented out as it causes problems with deepcopy().
1231+
# self.P = lambda s, a: self.transition_function(s, a)
12301232

12311233
def init_reward_function(self):
12321234
"""Initialises reward function, R by selecting random sequences to be rewardable for discrete environments. For continuous environments, we have fixed available options for the reward function."""
@@ -1550,7 +1552,7 @@ def get_rews(rng, r_dict):
15501552
elif self.config["state_space_type"] == "grid":
15511553
... # ###TODO Make sequences compatible with grid
15521554

1553-
self.R = lambda s, a: self.reward_function(s, a)
1555+
# self.R = lambda s, a: self.reward_function(s, a)
15541556

15551557
def transition_function(self, state, action):
15561558
"""The transition function, P.
@@ -2009,7 +2011,7 @@ def step(self, action, imaginary_rollout=False):
20092011
# ### TODO Decide whether to give reward before or after transition ("after" would mean taking next state into account and seems more logical to me) - make it a dimension? - R(s) or R(s, a) or R(s, a, s')? I'd say give it after and store the old state in the augmented_state to be able to let the R have any of the above possible forms. That would also solve the problem of implicit 1-step delay with giving it before. _And_ would not give any reward for already being in a rewarding state in the 1st step but _would_ give a reward if 1 moved to a rewardable state - even if called with R(s, a) because s' is stored in the augmented_state! #####IMP
20102012

20112013
# ###TODO P uses last state while R uses augmented state; for cont. env, P does know underlying state_derivatives - we don't want this to be the case for the imaginary rollout scenario;
2012-
next_state = self.P(state, action)
2014+
next_state = self.transition_function(state, action)
20132015

20142016
# if imaginary_rollout:
20152017
# pass
@@ -2025,7 +2027,7 @@ def step(self, action, imaginary_rollout=False):
20252027

20262028
self.total_transitions_episode += 1
20272029

2028-
self.reward = self.R(self.augmented_state, action)
2030+
self.reward = self.reward_function(self.augmented_state, action)
20292031

20302032
# #irrelevant dimensions part
20312033
if self.config["state_space_type"] == "discrete":

mdp_playground/spaces/image_continuous.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def __init__(
101101
"Image observations are " "supported only " "for 1- or 2-D feature spaces."
102102
)
103103

104-
# Shape has 1 appended for Ray Rllib to be compatible IIRC
104+
# Shape needs 3rd dimension for Ray Rllib to be compatible IIRC
105105
super(ImageContinuous, self).__init__(
106106
shape=(width, height, num_channels), dtype=dtype, low=0, high=255
107107
)
@@ -244,7 +244,7 @@ def contains(self, x):
244244
self.width,
245245
self.height,
246246
self.num_channels,
247-
): # TODO compare each pixel for all possible images?
247+
): # TODO compare each pixel for all possible image observations? Hard to implement.
248248
return True
249249

250250
def to_jsonable(self, sample_n):

0 commit comments

Comments (0)