|
24 | 24 |
|
25 | 25 |
|
26 | 26 | class RLToyEnv(gym.Env):
|
| 27 | + metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 4} |
| 28 | + |
27 | 29 | """
|
28 | 30 | The base toy environment in MDP Playground. It is parameterised by a config dict and can be instantiated to be an MDP with any of the possible dimensions from the accompanying research paper. The class extends OpenAI Gym's environment gym.Env.
|
29 | 31 |
|
@@ -428,61 +430,65 @@ def __init__(self, **config):
|
428 | 430 | self.image_representations = False
|
429 | 431 | else:
|
430 | 432 | self.image_representations = config["image_representations"]
|
431 |
| - if "image_transforms" in config: |
432 |
| - assert config["state_space_type"] == "discrete", ( |
433 |
| - "Image " "transforms are only applicable to discrete envs." |
434 |
| - ) |
435 |
| - self.image_transforms = config["image_transforms"] |
436 |
| - else: |
437 |
| - self.image_transforms = "none" |
438 | 433 |
|
439 |
| - if "image_width" in config: |
440 |
| - self.image_width = config["image_width"] |
441 |
| - else: |
442 |
| - self.image_width = 100 |
| 434 | + # Moved these out of the image_representations block when adding render() |
| 435 | + # because they are needed for the render() method even if image_representations |
| 436 | + # is False. |
| 437 | + if "image_transforms" in config: |
| 438 | + assert config["state_space_type"] == "discrete", ( |
| 439 | + "Image " "transforms are only applicable to discrete envs." |
| 440 | + ) |
| 441 | + self.image_transforms = config["image_transforms"] |
| 442 | + else: |
| 443 | + self.image_transforms = "none" |
443 | 444 |
|
444 |
| - if "image_height" in config: |
445 |
| - self.image_height = config["image_height"] |
446 |
| - else: |
447 |
| - self.image_height = 100 |
| 445 | + if "image_width" in config: |
| 446 | + self.image_width = config["image_width"] |
| 447 | + else: |
| 448 | + self.image_width = 100 |
448 | 449 |
|
449 |
| - # The following transforms are only applicable in discrete envs: |
450 |
| - if config["state_space_type"] == "discrete": |
451 |
| - if "image_sh_quant" not in config: |
452 |
| - if "shift" in self.image_transforms: |
453 |
| - warnings.warn( |
454 |
| - "Setting image shift quantisation to the \ |
455 |
| - default of 1, since no config value was provided for it." |
456 |
| - ) |
457 |
| - self.image_sh_quant = 1 |
458 |
| - else: |
459 |
| - self.image_sh_quant = None |
| 450 | + if "image_height" in config: |
| 451 | + self.image_height = config["image_height"] |
| 452 | + else: |
| 453 | + self.image_height = 100 |
| 454 | + |
| 455 | + # The following transforms are only applicable in discrete envs: |
| 456 | + if config["state_space_type"] == "discrete": |
| 457 | + if "image_sh_quant" not in config: |
| 458 | + if "shift" in self.image_transforms: |
| 459 | + warnings.warn( |
| 460 | + "Setting image shift quantisation to the \ |
| 461 | + default of 1, since no config value was provided for it." |
| 462 | + ) |
| 463 | + self.image_sh_quant = 1 |
460 | 464 | else:
|
461 |
| - self.image_sh_quant = config["image_sh_quant"] |
| 465 | + self.image_sh_quant = None |
| 466 | + else: |
| 467 | + self.image_sh_quant = config["image_sh_quant"] |
462 | 468 |
|
463 |
| - if "image_ro_quant" not in config: |
464 |
| - if "rotate" in self.image_transforms: |
465 |
| - warnings.warn( |
466 |
| - "Setting image rotate quantisation to the \ |
467 |
| - default of 1, since no config value was provided for it." |
468 |
| - ) |
469 |
| - self.image_ro_quant = 1 |
470 |
| - else: |
471 |
| - self.image_ro_quant = None |
| 469 | + if "image_ro_quant" not in config: |
| 470 | + if "rotate" in self.image_transforms: |
| 471 | + warnings.warn( |
| 472 | + "Setting image rotate quantisation to the \ |
| 473 | + default of 1, since no config value was provided for it." |
| 474 | + ) |
| 475 | + self.image_ro_quant = 1 |
472 | 476 | else:
|
473 |
| - self.image_ro_quant = config["image_ro_quant"] |
| 477 | + self.image_ro_quant = None |
| 478 | + else: |
| 479 | + self.image_ro_quant = config["image_ro_quant"] |
474 | 480 |
|
475 |
| - if "image_scale_range" not in config: |
476 |
| - if "scale" in self.image_transforms: |
477 |
| - warnings.warn( |
478 |
| - "Setting image scale range to the default \ |
479 |
| - of (0.5, 1.5), since no config value was provided for it." |
480 |
| - ) |
481 |
| - self.image_scale_range = (0.5, 1.5) |
482 |
| - else: |
483 |
| - self.image_scale_range = None |
| 481 | + if "image_scale_range" not in config: |
| 482 | + if "scale" in self.image_transforms: |
| 483 | + warnings.warn( |
| 484 | + "Setting image scale range to the default \ |
| 485 | + of (0.5, 1.5), since no config value was provided for it." |
| 486 | + ) |
| 487 | + self.image_scale_range = (0.5, 1.5) |
484 | 488 | else:
|
485 |
| - self.image_scale_range = config["image_scale_range"] |
| 489 | + self.image_scale_range = None |
| 490 | + else: |
| 491 | + self.image_scale_range = config["image_scale_range"] |
486 | 492 |
|
487 | 493 | # Defaults for the individual environment types:
|
488 | 494 | if config["state_space_type"] == "discrete":
|
@@ -827,6 +833,15 @@ def __init__(self, **config):
|
827 | 833 | + ", "
|
828 | 834 | + str(len(self.augmented_state))
|
829 | 835 | )
|
| 836 | + |
| 837 | + # Needed for rendering with pygame for use with Gymnasium.Env's render() method: |
| 838 | + render_mode = config.get("render_mode", None) |
| 839 | + assert render_mode is None or render_mode in self.metadata["render_modes"] |
| 840 | + self.render_mode = render_mode |
| 841 | + |
| 842 | + self.window = None |
| 843 | + self.clock = None |
| 844 | + |
830 | 845 | self.logger.debug(
|
831 | 846 | "MDP Playground toy env instantiated with config: " + str(self.config)
|
832 | 847 | )
|
@@ -1639,7 +1654,8 @@ def transition_function(self, state, action):
|
1639 | 1654 | / factorial_array[j]
|
1640 | 1655 | )
|
1641 | 1656 | # print('self.state_derivatives:', self.state_derivatives)
|
1642 |
| - next_state = self.state_derivatives[0] |
| 1657 | + # copy to avoid modifying the original state which may be used by external code, e.g. to print the state |
| 1658 | + next_state = self.state_derivatives[0].copy() |
1643 | 1659 |
|
1644 | 1660 | else: # if action is from outside allowed action_space
|
1645 | 1661 | next_state = state
|
@@ -1684,7 +1700,8 @@ def transition_function(self, state, action):
|
1684 | 1700 | self.state_derivatives = [
|
1685 | 1701 | zero_state.copy() for i in range(self.dynamics_order + 1)
|
1686 | 1702 | ]
|
1687 |
| - self.state_derivatives[0] = next_state |
| 1703 | + # copy to avoid modifying the original state which may be used by external code, e.g. to print the state |
| 1704 | + self.state_derivatives[0] = next_state.copy() |
1688 | 1705 |
|
1689 | 1706 | if self.config["reward_function"] == "move_to_a_point":
|
1690 | 1707 | next_state_rel = np.array(next_state, dtype=self.dtype_s)[
|
@@ -2126,7 +2143,7 @@ def get_augmented_state(self):
|
2126 | 2143 |
|
2127 | 2144 | return augmented_state_dict
|
2128 | 2145 |
|
2129 |
| - def reset(self, seed=None): |
| 2146 | + def reset(self, seed=None, options=None): |
2130 | 2147 | """Resets the environment for the beginning of an episode and samples a start state from rho_0. For discrete environments uses the defined rho_0 directly. For continuous environments, samples a state and resamples until a non-terminal state is sampled.
|
2131 | 2148 |
|
2132 | 2149 | Returns
|
@@ -2225,7 +2242,8 @@ def reset(self, seed=None):
|
2225 | 2242 | zero_state.copy() for i in range(self.dynamics_order + 1)
|
2226 | 2243 | ] # #####IMP to have copy()
|
2227 | 2244 | # otherwise it's the same array (in memory) at every position in the list
|
2228 |
| - self.state_derivatives[0] = self.curr_state |
| 2245 | + # copy to avoid modifying the original state which may be used by external code, e.g. to print the state |
| 2246 | + self.state_derivatives[0] = self.curr_state.copy() |
2229 | 2247 |
|
2230 | 2248 | self.augmented_state = [
|
2231 | 2249 | [np.nan] * self.state_space_dim
|
@@ -2316,6 +2334,82 @@ def seed(self, seed=None):
|
2316 | 2334 | )
|
2317 | 2335 | return self.seed_
|
2318 | 2336 |
|
def render(self):
    """Render the current state of the environment.

    The behaviour depends on ``self.render_mode``:

    * ``"human"``: draws the image for ``self.curr_state`` in a pygame window,
      throttled to ``self.metadata["render_fps"]`` frames per second, and
      returns ``None``.
    * ``"rgb_array"``: returns the image produced by
      ``self.render_space.get_concatenated_image(self.curr_state)``.
    * any other value (e.g. ``None``): nothing is drawn and ``None`` is returned.

    The image space used for rendering (``self.render_space``) is created lazily
    on the first call, because it is only needed if ``render()`` is used. When
    ``self.image_representations`` is True the observation space is reused.

    Based on https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/
    """
    # Build the render space exactly once. This must not be tied to
    # self.window: in "rgb_array" mode no window is ever created, so a
    # window-based guard would rebuild (and re-seed) the space on every call.
    if getattr(self, "render_space", None) is None:
        if self.image_representations:
            self.render_space = self.observation_space
        else:
            state_space_type = self.config["state_space_type"]
            if state_space_type == "discrete":
                self.render_space = ImageMultiDiscrete(
                    self.state_space_size,
                    width=self.image_width,
                    height=self.image_height,
                    transforms=self.image_transforms,
                    sh_quant=self.image_sh_quant,
                    scale_range=self.image_scale_range,
                    ro_quant=self.image_ro_quant,
                    circle_radius=20,
                    seed=self.seed_dict["image_representations"],
                )
            elif state_space_type == "continuous":
                self.render_space = ImageContinuous(
                    self.feature_space,
                    width=self.image_width,
                    height=self.image_height,
                    term_spaces=self.term_spaces,
                    target_point=self.target_point,
                    circle_radius=5,
                    seed=self.seed_dict["image_representations"],
                )
            elif state_space_type == "grid":
                target_pt = list_to_float_np_array(self.target_point)
                self.render_space = ImageContinuous(
                    self.feature_space,
                    width=self.image_width,
                    height=self.image_height,
                    term_spaces=self.term_spaces,
                    target_point=target_pt,
                    circle_radius=5,
                    grid_shape=self.grid_shape,
                    seed=self.seed_dict["image_representations"],
                )

    if self.render_mode == "human":
        # Imported here so that "rgb_array" rendering does not require pygame.
        import pygame

        if self.window is None:
            pygame.init()
            pygame.display.init()
            self.window = pygame.display.set_mode(
                (self.image_width, self.image_height)
            )
        if self.clock is None:
            self.clock = pygame.time.Clock()

        # TODO: get_concatenated_image repeats work that could be cached from
        # step() or reset().
        # NOTE(review): pygame.surfarray indexes arrays as [x][y]; confirm the
        # array returned here is laid out accordingly and fits the window.
        rgb_array = self.render_space.get_concatenated_image(self.curr_state)
        pygame_surface = pygame.surfarray.make_surface(rgb_array)
        self.window.blit(pygame_surface, pygame_surface.get_rect())
        pygame.event.pump()
        pygame.display.update()

        # Keep human rendering at the predefined framerate; tick() adds the
        # delay needed to hold it stable.
        self.clock.tick(self.metadata["render_fps"])
        return None

    if self.render_mode == "rgb_array":
        return self.render_space.get_concatenated_image(self.curr_state)

    return None
2319 | 2413 |
|
2320 | 2414 | def dist_of_pt_from_line(pt, ptA, ptB):
|
2321 | 2415 | """Returns shortest distance of a point from a line defined by 2 points - ptA and ptB.
|
|
0 commit comments