Source code for mani_skill.vector.wrappers.gymnasium

from __future__ import annotations

from typing import TYPE_CHECKING, Optional, Tuple, Union

import gymnasium as gym
import torch
from gymnasium.vector import VectorEnv

from mani_skill.utils.common import torch_clone_dict
from mani_skill.utils.structs.types import Array

if TYPE_CHECKING:
    from gymnasium import Env

    from mani_skill.envs.sapien_env import BaseEnv


[docs]class ManiSkillVectorEnv(VectorEnv): """ Gymnasium Vector Env implementation for ManiSkill environments running on the GPU for parallel simulation and optionally parallel rendering Args: env: The environment created via gym.make / after wrappers are applied. If a string is given, we use gym.make(env) to create an environment num_envs: The number of parallel environments. This is only used if the env argument is a string env_kwargs: Environment kwargs to pass to gym.make. This is only used if the env argument is a string auto_reset (bool): Whether this wrapper will auto reset the environment (following the same API/conventions as Gymnasium). Default is True (recommended as most ML/RL libraries use auto reset) ignore_terminations (bool): Whether this wrapper ignores terminations when deciding when to auto reset. Terminations can be caused by the task reaching a success or fail state as defined in a task's evaluation function. Default is False, meaning there is early stop in episode rollouts. If set to True, this would generally for situations where you may want to model a task as infinite horizon where a task stops only due to the timelimit. record_metrics (bool): If True, the returned info objects will contain the metrics: return, length, success_once, success_at_end, fail_once, fail_at_end. success/fail metrics are recorded only when the environment has success/fail criteria. success/fail_at_end are recorded only when ignore_terminations is True. """ def __init__( self, env: Union[Env, str], num_envs: int = 1, auto_reset: bool = True, ignore_terminations: bool = False, record_metrics: bool = False, **kwargs, ): if isinstance(env, str): self._env = gym.make(env, num_envs=num_envs, **kwargs) else: self._env = env num_envs = self.base_env.num_envs
[docs] self.num_envs = num_envs
[docs] self.auto_reset = auto_reset
[docs] self.ignore_terminations = ignore_terminations
[docs] self.record_metrics = record_metrics
[docs] self.spec = self._env.spec
from mani_skill.utils.gym_utils import IS_GYMNASIUM_1 if IS_GYMNASIUM_1: self.single_observation_space = self._env.get_wrapper_attr( "single_observation_space" ) self.single_action_space = self._env.get_wrapper_attr("single_action_space") self.action_space = self._env.get_wrapper_attr("action_space") self.observation_space = self._env.get_wrapper_attr("observation_space") self.metadata = self._env.metadata # hardcoded SAME STEP reset mode for now. Trying to support others with backwards compatability with gym < 1.0 # might be too much of a hassle. self.metadata.update(autoreset_mode=gym.vector.AutoresetMode.SAME_STEP) else: super().__init__( num_envs, self._env.get_wrapper_attr("single_observation_space"), self._env.get_wrapper_attr("single_action_space"), ) if not self.ignore_terminations and auto_reset: assert ( self.base_env.reconfiguration_freq == 0 or self.base_env.num_envs == 1 ), "With partial resets, environment cannot be reconfigured automatically" if self.record_metrics: self.success_once = torch.zeros( self.num_envs, device=self.base_env.device, dtype=torch.bool ) self.fail_once = torch.zeros( self.num_envs, device=self.base_env.device, dtype=torch.bool ) self.returns = torch.zeros( self.num_envs, device=self.base_env.device, dtype=torch.float32 ) @property
[docs] def device(self): return self.base_env.device
@property
[docs] def base_env(self) -> BaseEnv: return self._env.unwrapped # type: ignore
@property
[docs] def unwrapped(self): return self.base_env
[docs] def reset( self, *, seed: Optional[Union[int, list[int]]] = None, options: Optional[dict] = None, ): obs, info = self._env.reset(seed=seed, options=options) # type: ignore if options is not None and "env_idx" in options: env_idx = options["env_idx"] mask = torch.zeros( self.num_envs, dtype=torch.bool, device=self.base_env.device ) mask[env_idx] = True if self.record_metrics: self.success_once[mask] = False self.fail_once[mask] = False self.returns[mask] = 0 else: if self.record_metrics: self.success_once[:] = False self.fail_once[:] = False self.returns[:] = 0 return obs, info
[docs] def step( # pyright: ignore[reportIncompatibleMethodOverride] self, actions: Union[Array, dict] ) -> Tuple[Array, Array, Array, Array, dict]: obs, rew, terminations, truncations, infos = self._env.step(actions) episode_info: Optional[dict] = None if self.record_metrics: episode_info = dict() self.returns += rew if "success" in infos: self.success_once = self.success_once | infos["success"] episode_info["success_once"] = self.success_once.clone() if "fail" in infos: self.fail_once = self.fail_once | infos["fail"] episode_info["fail_once"] = self.fail_once.clone() episode_info["return"] = self.returns.clone() episode_info["episode_len"] = self.base_env.elapsed_steps.clone() episode_info["reward"] = ( episode_info["return"] / episode_info["episode_len"] ) if isinstance(terminations, bool): terminations = torch.tensor([terminations], device=self.device) if self.ignore_terminations: terminations[:] = False if episode_info: if "success" in infos: episode_info["success_at_end"] = infos["success"].clone() if "fail" in infos: episode_info["fail_at_end"] = infos["fail"].clone() if self.record_metrics: infos["episode"] = episode_info dones = torch.logical_or( terminations, truncations # pyright: ignore[reportArgumentType] ) if dones.any() and self.auto_reset: final_obs = torch_clone_dict(obs) env_idx = torch.arange(0, self.num_envs, device=self.device)[dones] final_info = torch_clone_dict(infos) obs, infos = self.reset(options=dict(env_idx=env_idx)) # gymnasium calls it final observation but it really is just o_{t+1} or the true next observation infos["final_observation"] = final_obs infos["final_info"] = final_info # NOTE (stao): that adding masks like below is a bit redundant and not necessary # but this is to follow the standard gymnasium API infos["_final_info"] = dones infos["_final_observation"] = dones infos["_elapsed_steps"] = dones # NOTE (stao): Unlike gymnasium, the code here does not add masks for every key in the info object. return ( obs, rew, terminations, truncations, infos, ) # pyright: ignore[reportReturnType]
[docs] def close_extras(self, **kwargs): self._env.close()
[docs] def call(self, name: str, *args, **kwargs): function = getattr(self._env, name) return function(*args, **kwargs)
[docs] def get_attr(self, name: str): raise RuntimeError( "To get an attribute get it from the .env property of this object" )
[docs] def render(self): return self.base_env.render()