# Source code for mani_skill.utils.wrappers.cached_reset

from dataclasses import asdict, dataclass
from typing import Optional, Union

import dacite
import gymnasium as gym
import torch

from mani_skill.envs.sapien_env import BaseEnv
from mani_skill.utils import common, tree
from mani_skill.utils.structs.types import Device


@dataclass
class CachedResetsConfig:
    """Options controlling how reset states are cached.

    Every field is optional; a value of ``None`` selects the fallback
    described on the field.
    """

    num_resets: Optional[int] = None
    """How many reset states to cache. When None, `num_envs` reset states are cached."""
    device: Optional[Device] = None
    """Device the cached reset states are stored on. When None, the base environment's device is used."""
    seed: Optional[int] = None
    """Seed used when generating the cached reset states."""
class CachedResetWrapper(gym.Wrapper):
    """
    Cached reset wrapper for ManiSkill3 environments. Caching resets allows you to skip slower
    parts of the reset function call and boost environment FPS as a result.

    Args:
        env: The environment to wrap.
        reset_to_env_states: A dictionary with keys "env_states" and optionally "obs".
            "env_states" is a dictionary of environment states to reset to. "obs" contains the
            corresponding observations generated at those env states. If reset_to_env_states is
            not provided, the wrapper will sample reset states from the environment using the
            given seed.
        config: A dictionary or a `CachedResetsConfig` object that contains the configuration
            for the cached resets.

    Raises:
        ValueError: If reset states have to be sampled and the configured ``num_resets`` is
            smaller than 1.
    """

    def __init__(
        self,
        env: gym.Env,
        reset_to_env_states: Optional[dict] = None,
        config: Union[CachedResetsConfig, dict] = CachedResetsConfig(),
    ):
        super().__init__(env)

        # setup config
        self.num_envs = self.base_env.num_envs
        if isinstance(config, CachedResetsConfig):
            # Round-trip through a dict so the (shared) default instance is never mutated
            # and dict / dataclass inputs are validated by the same strict dacite path.
            config = asdict(config)
        self.cached_resets_config = dacite.from_dict(
            data_class=CachedResetsConfig,
            data=config,
            config=dacite.Config(strict=True),
        )
        cached_data_device = self.cached_resets_config.device
        if cached_data_device is None:
            cached_data_device = self.base_env.device

        # cache env_states and optionally obs
        self._num_cached_resets = 0
        if reset_to_env_states is not None:
            self._cached_resets_env_states = reset_to_env_states["env_states"]
            self._cached_resets_obs_buffer = reset_to_env_states.get("obs", None)
            self._num_cached_resets = tree.shape(
                self._cached_resets_env_states, first_only=True
            )[0]
        else:
            if self.cached_resets_config.num_resets is None:
                self.cached_resets_config.num_resets = self.num_envs
            num_resets = self.cached_resets_config.num_resets
            if num_resets < 1:
                # Nothing would be cached and reset() could never sample a state.
                raise ValueError(
                    f"num_resets must be at least 1 to sample cached reset states, got {num_resets}"
                )
            self._cached_resets_env_states = []
            self._cached_resets_obs_buffer = []
            while self._num_cached_resets < num_resets:
                # Number of states still needed, capped at what one batched reset can produce.
                batch_size = min(num_resets - self._num_cached_resets, self.num_envs)
                obs, _ = self.env.reset(
                    # Only the first reset is seeded; later resets continue from that state.
                    seed=self.cached_resets_config.seed
                    if self._num_cached_resets == 0
                    else None,
                    options=dict(
                        env_idx=torch.arange(
                            0, batch_size, device=self.base_env.device
                        )
                    ),
                )
                state = self.env.get_wrapper_attr("get_state_dict")()
                if batch_size < self.num_envs:
                    # Partial final batch: keep only the leading entries that were reset.
                    obs = tree.slice(obs, slice(0, batch_size))
                    state = tree.slice(state, slice(0, batch_size))
                self._cached_resets_obs_buffer.append(
                    common.to_tensor(obs, device=self.cached_resets_config.device)
                )
                self._cached_resets_env_states.append(
                    common.to_tensor(state, device=self.cached_resets_config.device)
                )
                # Count what was actually stored. Adding num_envs here would overshoot on a
                # partial final batch and let reset() sample indices past the end of the cache.
                self._num_cached_resets += batch_size
            self._cached_resets_env_states = tree.cat(self._cached_resets_env_states)
            self._cached_resets_obs_buffer = tree.cat(self._cached_resets_obs_buffer)

        self._cached_resets_env_states = common.to_tensor(
            self._cached_resets_env_states, device=cached_data_device
        )
        if self._cached_resets_obs_buffer is not None:
            self._cached_resets_obs_buffer = common.to_tensor(
                self._cached_resets_obs_buffer, device=cached_data_device
            )

    @property
    def base_env(self) -> BaseEnv:
        """The unwrapped environment underneath this wrapper."""
        return self.env.unwrapped

    def reset(
        self,
        *args,
        seed: Optional[Union[int, list[int]]] = None,
        options: Optional[dict] = None,
        **kwargs
    ):
        """Reset the wrapped environment to randomly sampled cached states.

        Args:
            seed: Forwarded unchanged to the wrapped environment's ``reset``.
            options: Reset options. If it contains "env_idx", one cached state is sampled per
                listed index; otherwise ``num_envs`` states are sampled. The sampled states
                (and cached observations, if any) are passed on under "reset_to_env_states".
                The caller's dictionary is not modified.
            *args, **kwargs: Accepted for signature compatibility; not forwarded.

        Returns:
            The ``(obs, info)`` tuple returned by the wrapped environment's ``reset``.
        """
        # Work on a shallow copy so the injected key does not leak into the caller's dict.
        options = dict(options) if options is not None else dict()
        env_idx = options.get("env_idx")
        if self._cached_resets_env_states is not None:
            # NOTE(review): indices are created on the env device while the cache may live on
            # `config.device`; confirm tree.slice / the env handle a device mismatch.
            sampled_ids = torch.randint(
                0,
                self._num_cached_resets,
                size=(len(env_idx) if env_idx is not None else self.num_envs,),
                device=self.base_env.device,
            )
            options["reset_to_env_states"] = dict(
                env_states=tree.slice(self._cached_resets_env_states, sampled_ids),
            )
            if self._cached_resets_obs_buffer is not None:
                options["reset_to_env_states"]["obs"] = tree.slice(
                    self._cached_resets_obs_buffer, sampled_ids
                )
        obs, info = self.env.reset(seed=seed, options=options)
        return obs, info