Skip to content

Commit 445da74

Browse files
MAJOR: Added render() to RLToyEnv; increase compatibility with Gymnasium v1.0.0
1 parent 071dffe commit 445da74

File tree

5 files changed

+164
-63
lines changed

5 files changed

+164
-63
lines changed

‎example.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,6 @@
3232
from mdp_playground.envs import RLToyEnv
3333
import numpy as np
3434

35-
display_images = True
36-
3735
def display_image(obs, mode="RGB"):
3836
# Display the image observation associated with the next state
3937
from PIL import Image
@@ -411,6 +409,8 @@ def atari_wrapper_example():
411409

412410
from mdp_playground.envs import GymEnvWrapper
413411
import gymnasium as gym
412+
import ale_py
413+
gym.register_envs(ale_py) # optional, helpful for IDEs or pre-commit
414414

415415
ae = gym.make("QbertNoFrameskip-v4")
416416
env = GymEnvWrapper(ae, **config)

‎mdp_playground/envs/gym_env_wrapper.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,11 @@
1111
from PIL.Image import FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM
1212
import logging
1313

14+
# Needed from Gymnasium v1.0.0 onwards
15+
import ale_py
16+
gym.register_envs(ale_py) # optional, helpful for IDEs or pre-commit
17+
18+
1419
# def get_gym_wrapper(base_class):
1520

1621

‎mdp_playground/envs/rl_toy_env.py

Lines changed: 145 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424

2525

2626
class RLToyEnv(gym.Env):
27+
metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 4}
28+
2729
"""
2830
The base toy environment in MDP Playground. It is parameterised by a config dict and can be instantiated to be an MDP with any of the possible dimensions from the accompanying research paper. The class extends OpenAI Gym's environment gym.Env.
2931
@@ -428,61 +430,65 @@ def __init__(self, **config):
428430
self.image_representations = False
429431
else:
430432
self.image_representations = config["image_representations"]
431-
if "image_transforms" in config:
432-
assert config["state_space_type"] == "discrete", (
433-
"Image " "transforms are only applicable to discrete envs."
434-
)
435-
self.image_transforms = config["image_transforms"]
436-
else:
437-
self.image_transforms = "none"
438433

439-
if "image_width" in config:
440-
self.image_width = config["image_width"]
441-
else:
442-
self.image_width = 100
434+
# Moved these out of the image_representations block when adding render()
435+
# because they are needed for the render() method even if image_representations
436+
# is False.
437+
if "image_transforms" in config:
438+
assert config["state_space_type"] == "discrete", (
439+
"Image " "transforms are only applicable to discrete envs."
440+
)
441+
self.image_transforms = config["image_transforms"]
442+
else:
443+
self.image_transforms = "none"
443444

444-
if "image_height" in config:
445-
self.image_height = config["image_height"]
446-
else:
447-
self.image_height = 100
445+
if "image_width" in config:
446+
self.image_width = config["image_width"]
447+
else:
448+
self.image_width = 100
448449

449-
# The following transforms are only applicable in discrete envs:
450-
if config["state_space_type"] == "discrete":
451-
if "image_sh_quant" not in config:
452-
if "shift" in self.image_transforms:
453-
warnings.warn(
454-
"Setting image shift quantisation to the \
455-
default of 1, since no config value was provided for it."
456-
)
457-
self.image_sh_quant = 1
458-
else:
459-
self.image_sh_quant = None
450+
if "image_height" in config:
451+
self.image_height = config["image_height"]
452+
else:
453+
self.image_height = 100
454+
455+
# The following transforms are only applicable in discrete envs:
456+
if config["state_space_type"] == "discrete":
457+
if "image_sh_quant" not in config:
458+
if "shift" in self.image_transforms:
459+
warnings.warn(
460+
"Setting image shift quantisation to the \
461+
default of 1, since no config value was provided for it."
462+
)
463+
self.image_sh_quant = 1
460464
else:
461-
self.image_sh_quant = config["image_sh_quant"]
465+
self.image_sh_quant = None
466+
else:
467+
self.image_sh_quant = config["image_sh_quant"]
462468

463-
if "image_ro_quant" not in config:
464-
if "rotate" in self.image_transforms:
465-
warnings.warn(
466-
"Setting image rotate quantisation to the \
467-
default of 1, since no config value was provided for it."
468-
)
469-
self.image_ro_quant = 1
470-
else:
471-
self.image_ro_quant = None
469+
if "image_ro_quant" not in config:
470+
if "rotate" in self.image_transforms:
471+
warnings.warn(
472+
"Setting image rotate quantisation to the \
473+
default of 1, since no config value was provided for it."
474+
)
475+
self.image_ro_quant = 1
472476
else:
473-
self.image_ro_quant = config["image_ro_quant"]
477+
self.image_ro_quant = None
478+
else:
479+
self.image_ro_quant = config["image_ro_quant"]
474480

475-
if "image_scale_range" not in config:
476-
if "scale" in self.image_transforms:
477-
warnings.warn(
478-
"Setting image scale range to the default \
479-
of (0.5, 1.5), since no config value was provided for it."
480-
)
481-
self.image_scale_range = (0.5, 1.5)
482-
else:
483-
self.image_scale_range = None
481+
if "image_scale_range" not in config:
482+
if "scale" in self.image_transforms:
483+
warnings.warn(
484+
"Setting image scale range to the default \
485+
of (0.5, 1.5), since no config value was provided for it."
486+
)
487+
self.image_scale_range = (0.5, 1.5)
484488
else:
485-
self.image_scale_range = config["image_scale_range"]
489+
self.image_scale_range = None
490+
else:
491+
self.image_scale_range = config["image_scale_range"]
486492

487493
# Defaults for the individual environment types:
488494
if config["state_space_type"] == "discrete":
@@ -827,6 +833,15 @@ def __init__(self, **config):
827833
+ ", "
828834
+ str(len(self.augmented_state))
829835
)
836+
837+
# Needed for rendering with pygame for use with Gymnasium.Env's render() method:
838+
render_mode = config.get("render_mode", None)
839+
assert render_mode is None or render_mode in self.metadata["render_modes"]
840+
self.render_mode = render_mode
841+
842+
self.window = None
843+
self.clock = None
844+
830845
self.logger.debug(
831846
"MDP Playground toy env instantiated with config: " + str(self.config)
832847
)
@@ -1639,7 +1654,8 @@ def transition_function(self, state, action):
16391654
/ factorial_array[j]
16401655
)
16411656
# print('self.state_derivatives:', self.state_derivatives)
1642-
next_state = self.state_derivatives[0]
1657+
# copy to avoid modifying the original state which may be used by external code, e.g. to print the state
1658+
next_state = self.state_derivatives[0].copy()
16431659

16441660
else: # if action is from outside allowed action_space
16451661
next_state = state
@@ -1684,7 +1700,8 @@ def transition_function(self, state, action):
16841700
self.state_derivatives = [
16851701
zero_state.copy() for i in range(self.dynamics_order + 1)
16861702
]
1687-
self.state_derivatives[0] = next_state
1703+
# copy to avoid modifying the original state which may be used by external code, e.g. to print the state
1704+
self.state_derivatives[0] = next_state.copy()
16881705

16891706
if self.config["reward_function"] == "move_to_a_point":
16901707
next_state_rel = np.array(next_state, dtype=self.dtype_s)[
@@ -2126,7 +2143,7 @@ def get_augmented_state(self):
21262143

21272144
return augmented_state_dict
21282145

2129-
def reset(self, seed=None):
2146+
def reset(self, seed=None, options=None):
21302147
"""Resets the environment for the beginning of an episode and samples a start state from rho_0. For discrete environments uses the defined rho_0 directly. For continuous environments, samples a state and resamples until a non-terminal state is sampled.
21312148
21322149
Returns
@@ -2225,7 +2242,8 @@ def reset(self, seed=None):
22252242
zero_state.copy() for i in range(self.dynamics_order + 1)
22262243
] # #####IMP to have copy()
22272244
# otherwise it's the same array (in memory) at every position in the list
2228-
self.state_derivatives[0] = self.curr_state
2245+
# copy to avoid modifying the original state which may be used by external code, e.g. to print the state
2246+
self.state_derivatives[0] = self.curr_state.copy()
22292247

22302248
self.augmented_state = [
22312249
[np.nan] * self.state_space_dim
@@ -2316,6 +2334,82 @@ def seed(self, seed=None):
23162334
)
23172335
return self.seed_
23182336

2337+
def render(self,):
2338+
'''
2339+
Renders the environment using pygame if render_mode is "human" and returns the rendered
2340+
image if render_mode is "rgb_array".
2341+
2342+
Based on https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/
2343+
'''
2344+
2345+
import pygame
2346+
2347+
# Init stuff on first call. For non-image_representations based envs, it makes sense
2348+
# to only instantiate the render_space here and not in __init__ because it's only needed
2349+
# if render() is called.
2350+
if self.window is None:
2351+
if self.image_representations:
2352+
self.render_space = self.observation_space
2353+
else:
2354+
if self.config["state_space_type"] == "discrete":
2355+
self.render_space = ImageMultiDiscrete(
2356+
self.state_space_size,
2357+
width=self.image_width,
2358+
height=self.image_height,
2359+
transforms=self.image_transforms,
2360+
sh_quant=self.image_sh_quant,
2361+
scale_range=self.image_scale_range,
2362+
ro_quant=self.image_ro_quant,
2363+
circle_radius=20,
2364+
seed=self.seed_dict["image_representations"],
2365+
) # #seed
2366+
elif self.config["state_space_type"] == "continuous":
2367+
self.render_space = ImageContinuous(
2368+
self.feature_space,
2369+
width=self.image_width,
2370+
height=self.image_height,
2371+
term_spaces=self.term_spaces,
2372+
target_point=self.target_point,
2373+
circle_radius=5,
2374+
seed=self.seed_dict["image_representations"],
2375+
) # #seed
2376+
elif self.config["state_space_type"] == "grid":
2377+
target_pt = list_to_float_np_array(self.target_point)
2378+
self.render_space = ImageContinuous(
2379+
self.feature_space,
2380+
width=self.image_width,
2381+
height=self.image_height,
2382+
term_spaces=self.term_spaces,
2383+
target_point=target_pt,
2384+
circle_radius=5,
2385+
grid_shape=self.grid_shape,
2386+
seed=self.seed_dict["image_representations"],
2387+
) # #seed
2388+
2389+
2390+
if self.window is None and self.render_mode == "human":
2391+
pygame.init()
2392+
pygame.display.init()
2393+
self.window = pygame.display.set_mode(
2394+
(self.image_width, self.image_height)
2395+
)
2396+
if self.clock is None and self.render_mode == "human":
2397+
self.clock = pygame.time.Clock()
2398+
2399+
# ##TODO There are repeated calculations here in calling get_concatenated_image
2400+
# that can be taken from storing variables in step() or reset().
2401+
if self.render_mode == "human":
2402+
rgb_array = self.render_space.get_concatenated_image(self.curr_state)
2403+
pygame_surface = pygame.surfarray.make_surface(rgb_array)
2404+
self.window.blit(pygame_surface, pygame_surface.get_rect())
2405+
pygame.event.pump()
2406+
pygame.display.update()
2407+
2408+
# We need to ensure that human-rendering occurs at the predefined framerate.
2409+
# The following line will automatically add a delay to keep the framerate stable.
2410+
self.clock.tick(self.metadata["render_fps"])
2411+
elif self.render_mode == "rgb_array":
2412+
return self.render_space.get_concatenated_image(self.curr_state)
23192413

23202414
def dist_of_pt_from_line(pt, ptA, ptB):
23212415
"""Returns shortest distance of a point from a line defined by 2 points - ptA and ptB.

‎mdp_playground/spaces/image_continuous.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,9 @@ def get_concatenated_image(self, obs):
212212
# image to have >=3 dims
213213

214214
def convert_to_pixel(self, position):
215-
""" """
215+
"""
216+
Convert a continuous position to a pixel position in the image
217+
"""
216218
# It's implicit that both relevant and irrelevant sub-spaces have the
217219
# same max and min here:
218220
max = self.feature_space.high[self.relevant_indices]

‎tests/test_gym_env_wrapper.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def test_r_delay(self):
4444

4545
ae = gym.make("BeamRiderNoFrameskip-v4")
4646
aew = GymEnvWrapper(ae, **config)
47-
ob = aew.reset()
47+
ob, _ = aew.reset()
4848
print("observation_space.shape:", ob.shape)
4949
# print(ob)
5050
total_reward = 0.0
@@ -83,7 +83,7 @@ def test_r_shift(self):
8383

8484
ae = gym.make("BeamRiderNoFrameskip-v4")
8585
aew = GymEnvWrapper(ae, **config)
86-
ob = aew.reset()
86+
ob, _ = aew.reset()
8787
print("observation_space.shape:", ob.shape)
8888
# print(ob)
8989
total_reward = 0.0
@@ -123,7 +123,7 @@ def test_r_scale(self):
123123

124124
ae = gym.make("BeamRiderNoFrameskip-v4")
125125
aew = GymEnvWrapper(ae, **config)
126-
ob = aew.reset()
126+
ob, _ = aew.reset()
127127
print("observation_space.shape:", ob.shape)
128128
# print(ob)
129129
total_reward = 0.0
@@ -164,7 +164,7 @@ def test_r_scale(self):
164164

165165
# ae = gym.make("BeamRiderNoFrameskip-v4")
166166
# aew = GymEnvWrapper(ae, **config)
167-
# ob = aew.reset()
167+
# ob, _ = aew.reset()
168168
# print("observation_space.shape:", ob.shape)
169169
# # print(ob)
170170
# total_reward = 0.0
@@ -211,7 +211,7 @@ def test_r_scale(self):
211211
# game = "".join([g.capitalize() for g in game.split("_")])
212212
# ae = gym.make("{}NoFrameskip-v4".format(game))
213213
# aew = GymEnvWrapper(ae, **config)
214-
# ob = aew.reset()
214+
# ob, _ = aew.reset()
215215
# print("observation_space.shape:", ob.shape)
216216
# # print(ob)
217217
# total_reward = 0.0
@@ -253,7 +253,7 @@ def test_r_delay_p_noise_r_noise(self):
253253

254254
ae = gym.make("BeamRiderNoFrameskip-v4")
255255
aew = GymEnvWrapper(ae, **config)
256-
ob = aew.reset()
256+
ob, _ = aew.reset()
257257
print("observation_space.shape:", ob.shape)
258258
# print(ob)
259259
total_reward = 0.0
@@ -316,7 +316,7 @@ def test_discrete_irr_features(self):
316316

317317
ae = gym.make("BeamRiderNoFrameskip-v4")
318318
aew = GymEnvWrapper(ae, **config)
319-
ob = aew.reset()
319+
ob, _ = aew.reset()
320320
print("type(observation_space):", type(ob))
321321
# print(ob)
322322
total_reward = 0.0
@@ -364,7 +364,7 @@ def test_image_transforms(self):
364364

365365
ae = gym.make("BeamRiderNoFrameskip-v4")
366366
aew = GymEnvWrapper(ae, **config)
367-
ob = aew.reset()
367+
ob, _ = aew.reset()
368368
print("observation_space.shape:", ob.shape)
369369
assert ob.shape == (100, 100, 3), "Observation shape of the env was unexpected."
370370
# print(ob)
@@ -420,7 +420,7 @@ def test_cont_irr_features(self):
420420
# register_env("HalfCheetahWrapper-v3", lambda config: HalfCheetahWrapperV3(**config))
421421

422422
hc3w = GymEnvWrapper(hc3, **config)
423-
ob = hc3w.reset()
423+
ob, _ = hc3w.reset()
424424
print("obs shape, type(observation_space):", ob.shape, type(ob))
425425
print("initial obs: ", ob)
426426
assert (

0 commit comments

Comments
 (0)