Source code for mani_skill.utils.wrappers.flatten
import copy
import gymnasium as gym
import gymnasium.spaces.utils
import numpy as np
import torch
from gymnasium.vector.utils import batch_space
from mani_skill.envs.sapien_env import BaseEnv
from mani_skill.utils import common
class FlattenRGBDObservationWrapper(gym.ObservationWrapper):
    """
    Flattens the rgbd mode observations into a dictionary holding the flattened state and the camera images.

    Args:
        rgb (bool): Whether to include rgb images in the observation
        depth (bool): Whether to include depth images in the observation
        state (bool): Whether to include state data in the observation
        sep_depth (bool): Whether to separate depth and rgb images in the observation. Default is True.

    The returned observation has a "state" key when ``state`` is True. It has a "rgb" and/or "depth" key
    depending on the rgb/depth flags and on whether the cameras actually produce that modality. If
    sep_depth is False and both rgb and depth are included, they are merged into a single "rgbd" key instead.
    """

    def __init__(self, env, rgb=True, depth=True, state=True, sep_depth=True) -> None:
        self.base_env: BaseEnv = env.unwrapped
        super().__init__(env)
        # ``rgb`` must be stored: otherwise ``include_rgb`` is undefined whenever the cameras
        # do provide rgb data, and ``observation`` fails with AttributeError.
        self.include_rgb = rgb
        self.include_depth = depth
        self.sep_depth = sep_depth
        self.include_state = state
        # Check whether rgb/depth data exists in the first camera's sensor data only;
        # NOTE(review): assumes every camera exposes the same modalities.
        first_cam = next(iter(self.base_env._init_raw_obs["sensor_data"].values()))
        if "depth" not in first_cam:
            self.include_depth = False
        if "rgb" not in first_cam:
            self.include_rgb = False
        new_obs = self.observation(self.base_env._init_raw_obs)
        self.base_env.update_obs_space(new_obs)

    def observation(self, observation: dict) -> dict:
        """Convert a raw observation dict into the flattened rgb/depth/state dict.

        Per-camera images are concatenated along the last dimension. All remaining
        (non-sensor) entries are flattened into a single "state" tensor.
        """
        # Work on a shallow copy so the pop/del below do not mutate the caller's dict
        # (in ``__init__`` that dict is the env's own ``_init_raw_obs``).
        observation = dict(observation)
        sensor_data = observation.pop("sensor_data")
        del observation["sensor_param"]
        rgb_images = []
        depth_images = []
        for cam_data in sensor_data.values():
            if self.include_rgb:
                rgb_images.append(cam_data["rgb"])
            if self.include_depth:
                depth_images.append(cam_data["depth"])
        if len(rgb_images) > 0:
            rgb_images = torch.concat(rgb_images, dim=-1)
        if len(depth_images) > 0:
            depth_images = torch.concat(depth_images, dim=-1)
        # flatten the rest of the data which should just be state data
        observation = common.flatten_state_dict(
            observation, use_torch=True, device=self.base_env.device
        )
        ret = dict()
        if self.include_state:
            ret["state"] = observation
        if self.include_rgb and not self.include_depth:
            ret["rgb"] = rgb_images
        elif self.include_rgb and self.include_depth:
            if self.sep_depth:
                ret["rgb"] = rgb_images
                ret["depth"] = depth_images
            else:
                ret["rgbd"] = torch.concat([rgb_images, depth_images], dim=-1)
        elif self.include_depth and not self.include_rgb:
            ret["depth"] = depth_images
        return ret
class FlattenObservationWrapper(gym.ObservationWrapper):
    """
    Flattens the observations into a single vector
    """

    def __init__(self, env) -> None:
        super().__init__(env)
        # Build the observation space from exactly what ``observation`` returns
        # (torch-flattened), rather than from a separate numpy flatten of the raw
        # observation. A numpy flatten can disagree with the torch path and fails on
        # tensors that numpy cannot consume (e.g. GPU tensors).
        self.base_env.update_obs_space(self.observation(self.base_env._init_raw_obs))

    @property
    def base_env(self) -> BaseEnv:
        """The innermost (unwrapped) ManiSkill environment."""
        return self.env.unwrapped

    def observation(self, observation):
        """Flatten a (possibly nested) observation dict into a single torch tensor."""
        return common.flatten_state_dict(observation, use_torch=True)
class FlattenActionSpaceWrapper(gym.ActionWrapper):
    """
    Flattens the action space. The original action space must be spaces.Dict whose entries
    are Box spaces (possibly multi-dimensional); nested Dict spaces are not supported.
    """

    def __init__(self, env) -> None:
        super().__init__(env)
        # Private copy of the original Dict space; ``action`` uses it to unflatten.
        self._orig_single_action_space = copy.deepcopy(
            self.base_env.single_action_space
        )
        self.single_action_space = gymnasium.spaces.utils.flatten_space(
            self.base_env.single_action_space
        )
        if self.base_env.num_envs > 1:
            self.action_space = batch_space(
                self.single_action_space, n=self.base_env.num_envs
            )
        else:
            self.action_space = self.single_action_space

    @property
    def base_env(self) -> BaseEnv:
        """The innermost (unwrapped) ManiSkill environment."""
        return self.env.unwrapped

    def action(self, action):
        """Split a flat action back into a dict keyed like the original action space.

        Args:
            action: flat action indexed as ``action[:, i]`` (leading batch dimension).
                When ``num_envs == 1``, an unbatched action whose shape equals
                ``single_action_space.shape`` is first passed through ``common.batch``.
                NOTE(review): that helper is assumed to add a leading batch dimension.

        Returns:
            dict mapping each key of the original action space to its slice of ``action``.
            Sub-spaces with more than one dimension are reshaped to ``(batch, *space.shape)``.
        """
        if (
            self.base_env.num_envs == 1
            and action.shape == self.single_action_space.shape
        ):
            action = common.batch(action)
        # TODO (stao): This code only supports flat dictionary at the moment
        unflattened_action = dict()
        start = 0
        for k, space in self._orig_single_action_space.items():
            # ``flatdim`` is prod(space.shape) for a Box, matching how ``flatten_space``
            # laid it out. Using ``space.shape[0]`` under-counted multi-dimensional Boxes
            # and misaligned every following key.
            end = start + gymnasium.spaces.utils.flatdim(space)
            chunk = action[:, start:end]
            if len(space.shape) > 1:
                chunk = chunk.reshape(chunk.shape[0], *space.shape)
            unflattened_action[k] = chunk
            start = end
        return unflattened_action