Source code for mani_skill.utils.registration

from __future__ import annotations

import json
import sys
from copy import deepcopy
from functools import partial
from typing import TYPE_CHECKING, Optional, Type

import gymnasium as gym
import torch
from gymnasium.envs.registration import EnvSpec as GymEnvSpec
from gymnasium.envs.registration import WrapperSpec

from mani_skill import logger
from mani_skill.utils import assets, download_asset
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv

if TYPE_CHECKING:
    from mani_skill.envs.sapien_env import BaseEnv


[docs]class EnvSpec: def __init__( self, uid: str, cls: Type[BaseEnv], max_episode_steps=None, asset_download_ids: Optional[list[str]] = [], default_kwargs: dict = None, ): """A specification for a ManiSkill environment."""
[docs] self.uid = uid
[docs] self.cls = cls
[docs] self.max_episode_steps = max_episode_steps
[docs] self.asset_download_ids = asset_download_ids
[docs] self.default_kwargs = {} if default_kwargs is None else default_kwargs
[docs] def make(self, **kwargs): _kwargs = self.default_kwargs.copy() _kwargs.update(kwargs) # check if all assets necessary are downloaded assets_to_download = [] for asset_id in self.asset_download_ids or []: is_data_group = asset_id in assets.DATA_GROUPS if is_data_group: found_data_group_assets = True for ( data_source_id ) in assets.expand_data_group_into_individual_data_source_ids(asset_id): if not assets.is_data_source_downloaded(data_source_id): assets_to_download.append(data_source_id) found_data_group_assets = False if not found_data_group_assets: print( f"Environment {self.uid} requires a set of assets in group {asset_id}. At least 1 of those assets could not be found" ) else: if not assets.is_data_source_downloaded(asset_id): assets_to_download.append(asset_id) data_source = assets.DATA_SOURCES[asset_id] print( f"Could not find asset {asset_id} at {data_source.output_dir / data_source.target_path}" ) if len(assets_to_download) > 0: if len(assets_to_download) <= 5: asset_download_msg = ", ".join(assets_to_download) else: asset_download_msg = f"{assets_to_download[:5]} (and {len(assets_to_download) - 10} more)" response = download_asset.prompt_yes_no( f"Environment {self.uid} requires asset(s) {asset_download_msg} which could not be found. Would you like to download them now?" ) if response: for asset_id in assets_to_download: download_asset.download(assets.DATA_SOURCES[asset_id]) else: print("Exiting as assets are not found or downloaded") exit() return self.cls(**_kwargs)
@property
[docs] def gym_spec(self): """Return a gym EnvSpec for this env""" entry_point = self.cls.__module__ + ":" + self.cls.__name__ return GymEnvSpec( self.uid, entry_point, max_episode_steps=self.max_episode_steps, kwargs=self.default_kwargs, )
[docs]REGISTERED_ENVS: dict[str, EnvSpec] = {}
[docs]def register( name: str, cls: Type[BaseEnv], max_episode_steps=None, asset_download_ids: list[str] = [], default_kwargs: dict = None, ): """Register a ManiSkill environment.""" # hacky way to avoid circular import errors when users inherit a task in ManiSkill and try to register it themselves from mani_skill.envs.sapien_env import BaseEnv if name in REGISTERED_ENVS: logger.warn(f"Env {name} already registered") if not issubclass(cls, BaseEnv): raise TypeError(f"Env {name} must inherit from BaseEnv") for asset_id in asset_download_ids: if asset_id not in assets.DATA_SOURCES and asset_id not in assets.DATA_GROUPS: raise KeyError(f"Asset {asset_id} not found in data sources or groups") REGISTERED_ENVS[name] = EnvSpec( name, cls, max_episode_steps=max_episode_steps, asset_download_ids=asset_download_ids, default_kwargs=default_kwargs, ) assets.DATA_GROUPS[name] = asset_download_ids
[docs]class TimeLimitWrapper(gym.Wrapper): """like the standard gymnasium timelimit wrapper but fixes truncated variable to be a torch tensor and batched""" def __init__(self, env: gym.Env, max_episode_steps: int): super().__init__(env) prev_frame_locals = sys._getframe(1).f_locals frame = sys._getframe(1) # check for user supplied max_episode_steps during gym.make calls if frame.f_code.co_name == "make" and "max_episode_steps" in prev_frame_locals: if prev_frame_locals["max_episode_steps"] is not None: max_episode_steps = prev_frame_locals["max_episode_steps"] # do some wrapper surgery to remove the previous timelimit wrapper # with gymnasium 0.29.1, this will remove the timelimit wrapper and nothing else. curr_env = env found_env = False while curr_env is not None: if isinstance(curr_env, gym.wrappers.TimeLimit): self.env = curr_env.env found_env = True break if hasattr(curr_env, "env"): curr_env = curr_env.env else: break if not found_env: self.env = env
[docs] self._max_episode_steps = max_episode_steps
@property
[docs] def base_env(self) -> BaseEnv: return self.env.unwrapped
[docs] def step(self, action): observation, reward, terminated, truncated, info = self.env.step(action) if self._max_episode_steps is not None: truncated = self.base_env.elapsed_steps >= self._max_episode_steps else: truncated = torch.zeros( (self.base_env.num_envs,), dtype=torch.bool, device=self.base_env.device ) return observation, reward, terminated, truncated, info
[docs]def make(env_id, **kwargs): """Instantiate a ManiSkill environment. Args: env_id (str): Environment ID. as_gym (bool, optional): Add TimeLimit wrapper as gym. **kwargs: Keyword arguments to pass to the environment. """ if env_id not in REGISTERED_ENVS: raise KeyError("Env {} not found in registry".format(env_id)) env_spec = REGISTERED_ENVS[env_id] env = env_spec.make(**kwargs) return env
[docs]def make_vec(env_id, **kwargs): env = gym.make(env_id, **kwargs) env = ManiSkillVectorEnv(env) return env
[docs]def register_env( uid: str, max_episode_steps=None, override=False, asset_download_ids: list[str] = [], **kwargs, ): """A decorator to register ManiSkill environments. Args: uid (str): unique id of the environment. max_episode_steps (int): maximum number of steps in an episode. asset_download_ids (list[str]): asset download ids the environment depends on. When environments are created this list is checked to see if the user has all assets downloaded and if not, prompt the user if they wish to download them. override (bool): whether to override the environment if it is already registered. Notes: - `max_episode_steps` is processed differently from other keyword arguments in gym. `gym.make` wraps the env with `gym.wrappers.TimeLimit` to limit the maximum number of steps. - `gym.EnvSpec` uses kwargs instead of **kwargs! """ try: json.dumps(kwargs) except TypeError: raise RuntimeError( f"You cannot register_env with non json dumpable kwargs, e.g. classes or types. If you really need to do this, it is recommended to create a mapping of string to the unjsonable data and to pass the string in the kwarg and during env creation find the data you need" ) def _register_env(cls): if uid in REGISTERED_ENVS: if override: from gymnasium.envs.registration import registry logger.warn(f"Override registered env {uid}") REGISTERED_ENVS.pop(uid) registry.pop(uid) else: logger.warn(f"Env {uid} is already registered. Skip registration.") return cls # Register for ManiSkill register( uid, cls, max_episode_steps=max_episode_steps, asset_download_ids=asset_download_ids, default_kwargs=deepcopy(kwargs), ) # Register for gym gym.register( uid, entry_point=partial(make, env_id=uid), # NOTE (stao): gym.make_vec is now deprecated and should not be used with ManiSkill. It does not work for gym versions > 1. vector_entry_point=partial(make_vec, env_id=uid), max_episode_steps=max_episode_steps, disable_env_checker=True, # Temporary solution as we allow empty observation spaces kwargs=deepcopy(kwargs), additional_wrappers=[ WrapperSpec( "MSTimeLimit", entry_point="mani_skill.utils.registration:TimeLimitWrapper", kwargs=dict(max_episode_steps=max_episode_steps), ) ], ) return cls return _register_env