Skip to content

Commit 3411bf7

Browse files
Added default reward_function for cont. envs; remove bug in ImageContinuous;
1 parent ce5aa6a commit 3411bf7

File tree

3 files changed

+24
-12
lines changed

3 files changed

+24
-12
lines changed

‎example.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def discrete_environment_image_representations_example():
127127
augmented_state_dict = env.get_augmented_state()
128128
next_state = augmented_state_dict["curr_state"] # Underlying MDP state holds
129129
# the current discrete state.
130-
print("sars', done =", state, action, reward, next_state, done)
130+
print("sars', done, image shape =", state, action, reward, next_state, done, next_state_image.shape)
131131

132132
env.close()
133133

@@ -175,7 +175,7 @@ def discrete_environment_diameter_image_representations_example():
175175
augmented_state_dict = env.get_augmented_state()
176176
next_state = augmented_state_dict["curr_state"] # Underlying MDP state holds
177177
# the current discrete state.
178-
print("sars', done =", state, action, reward, next_state, done)
178+
print("sars', done, shape =", state, action, reward, next_state, done, next_state_image.shape)
179179

180180
env.close()
181181

@@ -262,7 +262,7 @@ def continuous_environment_example_move_to_a_point_irrelevant_image():
262262
augmented_state_dict = env.get_augmented_state()
263263
next_state = augmented_state_dict["curr_state"].copy() # Underlying MDP state holds
264264
# the current continuous state.
265-
print("sars', done =", state, action, reward, next_state, done)
265+
print("sars', done, image shape =", state, action, reward, next_state, done, next_state_image.shape)
266266

267267
env.close()
268268

@@ -388,7 +388,7 @@ def grid_environment_image_representations_example():
388388
action = actions[i]
389389
next_obs, reward, done, trunc, info = env.step(action)
390390
next_state = env.get_augmented_state()["augmented_state"][-1]
391-
print("sars', done =", state, action, reward, next_state, done)
391+
print("sars', done, image shape =", state, action, reward, next_state, done, next_obs.shape)
392392
state = next_state
393393

394394
env.reset()[0]

‎mdp_playground/envs/rl_toy_env.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,7 @@ def __init__(self, **config):
345345
# if config["state_space_type"] == "discrete":
346346
# assert "init_state_dist" in config
347347

348+
# Common defaults for all types of environments:
348349
if "terminal_state_density" not in config:
349350
self.terminal_state_density = 0.25
350351
else:
@@ -483,6 +484,7 @@ def __init__(self, **config):
483484
else:
484485
self.image_scale_range = config["image_scale_range"]
485486

487+
# Defaults for the individual environment types:
486488
if config["state_space_type"] == "discrete":
487489
if "reward_dist" not in config:
488490
self.reward_dist = None
@@ -498,6 +500,11 @@ def __init__(self, **config):
498500
# if not self.use_custom_mdp:
499501
self.state_space_dim = config["state_space_dim"]
500502

503+
# ##TODO Do something to dismbiguate the Python function reward_function from the
504+
# choice of reward_function below.
505+
if "reward_function" not in config:
506+
config["reward_function"] = "move_to_a_point"
507+
501508
if "transition_dynamics_order" not in config:
502509
self.dynamics_order = 1
503510
else:
@@ -548,8 +555,9 @@ def __init__(self, **config):
548555
self.repeats_in_sequences = config["repeats_in_sequences"]
549556

550557

558+
# ##TODO Move these to the individual env types' defaults section above?
551559
if config["state_space_type"] == "discrete":
552-
self.dtype_s = np.int64 if "dtype_s" not in config else config["dtype_s"]
560+
self.dtype_s = np.int32 if "dtype_s" not in config else config["dtype_s"]
553561
if self.irrelevant_features:
554562
assert (
555563
len(config["action_space_size"]) == 2
@@ -589,7 +597,7 @@ def __init__(self, **config):
589597

590598
# Set the dtype for the observation space:
591599
if self.image_representations:
592-
self.dtype_o = np.float32 if "dtype_o" not in config else config["dtype_o"]
600+
self.dtype_o = np.uint8 if "dtype_o" not in config else config["dtype_o"]
593601
else:
594602
self.dtype_o = self.dtype_s if "dtype_o" not in config else config["dtype_o"]
595603

‎mdp_playground/spaces/image_continuous.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def __init__(
2424
term_spaces=None,
2525
width=100,
2626
height=100,
27+
num_channels=3,
2728
circle_radius=5,
2829
target_point=None,
2930
relevant_indices=[0, 1],
@@ -43,6 +44,8 @@ def __init__(
4344
The width of the image
4445
height : int
4546
The height of the image
47+
num_channels : int
48+
The number of channels in the image ###TODO: Support for 1 channel; unify with ImageMultiDiscrete
4649
circle_radius : int
4750
The radius of the circle which represents the agent and target point
4851
target_point : np.array
@@ -60,6 +63,7 @@ def __init__(
6063
assert (self.feature_space.low != -np.inf).any()
6164
self.width = width
6265
self.height = height
66+
self.num_channels = num_channels
6367
# Warn if resolution is too low?
6468
self.circle_radius = circle_radius
6569
self.target_point = target_point
@@ -99,7 +103,7 @@ def __init__(
99103

100104
# Shape has 1 appended for Ray Rllib to be compatible IIRC
101105
super(ImageContinuous, self).__init__(
102-
shape=(width, height, 1), dtype=dtype, low=0, high=255
106+
shape=(width, height, num_channels), dtype=dtype, low=0, high=255
103107
)
104108
super(ImageContinuous, self).seed(seed=seed)
105109

@@ -117,10 +121,10 @@ def generate_image(self, position, relevant=True):
117121
118122
"""
119123
# Use RGB
120-
image_ = Image.new("RGB", (self.width, self.height), color=self.bg_colour)
121-
# Use L for black and white 8-bit pixels instead of RGB in case not
122-
# using custom images
123-
# image_ = Image.new("L", (self.width, self.height))
124+
if self.num_channels == 3:
125+
image_ = Image.new("RGB", (self.width, self.height), color=self.bg_colour)
126+
elif self.num_channels == 1:
127+
image_ = Image.new("L", (self.width, self.height), color=self.bg_colour)
124128
draw = ImageDraw.Draw(image_)
125129

126130
# Draw in decreasing order of importance:
@@ -239,7 +243,7 @@ def contains(self, x):
239243
if x.shape == (
240244
self.width,
241245
self.height,
242-
1,
246+
self.num_channels,
243247
): # TODO compare each pixel for all possible images?
244248
return True
245249

0 commit comments

Comments
 (0)