Skip to content

Commit 1645f4e

Browse files
Fix failing test
1 parent 2c4524f commit 1645f4e

File tree

2 files changed

+20
-16
lines changed

2 files changed

+20
-16
lines changed

‎mdp_playground/spaces/image_multi_discrete.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -291,8 +291,8 @@ def sample(self):
291291
def __repr__(self):
292292
return (
293293
"{} with multi-discrete space of shape: {} and "
294-
"images of resolution: {} x {} and dtype: {}".format(
295-
self.__class__, self.state_space_sizes, self.width, self.height, self.dtype
294+
"images of resolution: {} and dtype: {}".format(
295+
self.__class__, self.state_space_sizes, self.shape, self.dtype
296296
)
297297
)
298298

‎mdp_playground/spaces/test_image_multi_discrete.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,30 +18,34 @@ def test_image_multi_discrete(self):
1818
ds4 = [ds4.n]
1919
print(ds4)
2020
imd = ImageMultiDiscrete(ds4, transforms="shift")
21+
print("imd.dtype, shape", imd.dtype, imd.shape)
22+
print("imd", imd)
2123
from PIL import Image
2224

2325
# img1 = Image.fromarray(imd.disjoint_states[0][1], 'L')
2426
img1 = Image.fromarray(np.squeeze(imd.get_concatenated_image(3)), "L")
2527
if render:
2628
img1.show()
2729

28-
imd = ImageMultiDiscrete(
29-
ds4,
30-
transforms="scale,shift,rotate,flip",
31-
use_custom_images="textures",
32-
cust_path="/home/rajanr/textures",
33-
)
34-
img1 = Image.fromarray(np.squeeze(imd.get_concatenated_image(2)), "RGB")
30+
if render:
31+
imd = ImageMultiDiscrete(
32+
ds4,
33+
transforms="scale,shift,rotate,flip",
34+
use_custom_images="textures",
35+
cust_path="/home/rajanr/textures",
36+
)
37+
img1 = Image.fromarray(np.squeeze(imd.get_concatenated_image(2)), "RGB")
3538
if render:
3639
img1.show()
3740

38-
imd = ImageMultiDiscrete(
39-
ds4,
40-
transforms="scale,shift,rotate,flip",
41-
use_custom_images="images",
42-
cust_path="/home/rajanr/textures",
43-
)
44-
img1 = Image.fromarray(np.squeeze(imd.get_concatenated_image(1)), "RGB")
41+
if render:
42+
imd = ImageMultiDiscrete(
43+
ds4,
44+
transforms="scale,shift,rotate,flip",
45+
use_custom_images="images",
46+
cust_path="/home/rajanr/textures",
47+
)
48+
img1 = Image.fromarray(np.squeeze(imd.get_concatenated_image(1)), "RGB")
4549
if render:
4650
img1.show()
4751

0 commit comments

Comments
 (0)