# Source code for mani_skill.utils.wrappers.frame_stack

"""Wrapper that stacks frames. Adapted from the gymnasium package to support GPU-vectorized environments."""
from collections import deque

import gymnasium as gym
import numpy as np
import torch

from mani_skill.envs.sapien_env import BaseEnv


class FrameStack(gym.ObservationWrapper):
    """Observation wrapper that stacks the observations in a rolling manner.

    For example, if the number of stacks is 4, then the returned observation
    contains the most recent 4 observations, ordered oldest to newest. Frames
    are stacked along a new dimension inserted at index 1, so the leading
    (batch) dimension of each observation stays first. For environment
    'PickCube-v1' with ``B`` parallel environments, the original observation
    has shape ``[B, 42]``, so stacking 4 observations yields shape
    ``[B, 4, 42]``.

    Dict observations are supported too: every tensor leaf of the (possibly
    nested) dictionary is stacked in the same way and the dictionary
    structure is preserved.

    Note:
        - After :meth:`reset` is called, the frame buffer will be filled with
          the initial observation. I.e. the observation returned by
          :meth:`reset` will consist of ``num_stack`` many identical frames.
        - Partial resets (``options["env_idx"]`` shorter than the number of
          parallel environments) are rejected with ``RuntimeError``.
    """

    def __init__(self, env: gym.Env, num_stack: int, lz4_compress: bool = False):
        """Observation wrapper that stacks the observations in a rolling manner.

        Args:
            env (Env): The environment to apply the wrapper
            num_stack (int): The number of frames to stack; must be >= 1
            lz4_compress (bool): Kept for signature compatibility with
                gymnasium's ``FrameStack``; it is stored but frames are
                currently kept uncompressed.

        Raises:
            ValueError: If ``num_stack`` is smaller than 1.
        """
        if num_stack < 1:
            raise ValueError(f"num_stack must be >= 1, got {num_stack}")
        gym.ObservationWrapper.__init__(self, env)
        self.num_stack = num_stack
        # A bounded deque discards the oldest frame automatically on append.
        self.frames = deque(maxlen=num_stack)
        self.lz4_compress = lz4_compress
        init_obs = self.base_env._init_raw_obs
        self.frames.extend([init_obs] * self.num_stack)
        self.use_dict = isinstance(init_obs, dict)
        # Hand the stacked initial observation to the base env so it knows the
        # new (stacked) observation format.
        new_obs = self.observation(init_obs)
        self.base_env.update_obs_space(new_obs)

    @property
    def base_env(self) -> BaseEnv:
        """The innermost (unwrapped) environment."""
        return self.env.unwrapped

    def observation(self, observation):
        """Return the current frame buffer stacked along dimension 1.

        Args:
            observation: Ignored; the result is built entirely from
                ``self.frames``.

        Returns:
            For tensor observations, a tensor of shape ``[B, num_stack, ...]``
            where each buffered frame has shape ``[B, ...]``. For dict
            observations, a dict with the same (possibly nested) structure
            whose tensor leaves are stacked the same way.
        """
        assert len(self.frames) == self.num_stack, (len(self.frames), self.num_stack)
        frames = list(self.frames)
        if self.use_dict:
            return {
                k: self._stack_frames([x[k] for x in frames])
                for k in self.observation_space.keys()
            }
        return self._stack_frames(frames)
        # NOTE: gymnasium's FrameStack returns LazyFrames (where lz4_compress
        # would apply) instead of materialised tensors. That variant is not
        # used here: it is unclear yet if it is faster or saves much memory.

    def _stack_frames(self, frames):
        """Stack a list of per-frame values oldest-first, recursing into dicts."""
        first = frames[0]
        if isinstance(first, dict):
            # Nested dict observation: stack each leaf independently.
            return {k: self._stack_frames([f[k] for f in frames]) for k in first}
        # [num_stack, B, ...] -> [B, num_stack, ...]
        return torch.stack(frames, dim=0).transpose(0, 1)

    def step(self, action):
        """Steps through the environment, appending the observation to the frame buffer.

        Args:
            action: The action to step through the environment with

        Returns:
            Stacked observations, reward, terminated, truncated, and information from the environment
        """
        observation, reward, terminated, truncated, info = self.env.step(action)
        self.frames.append(observation)
        return self.observation(None), reward, terminated, truncated, info

    def reset(self, seed=None, options=None):
        """Reset the environment and refill the frame buffer with the initial observation.

        Args:
            seed: The seed for the environment reset
            options: The options for the environment reset

        Returns:
            The stacked observations and the reset info

        Raises:
            RuntimeError: If ``options["env_idx"]`` requests a partial reset
                (fewer indices than ``base_env.num_envs``).
        """
        if (
            isinstance(options, dict)
            and "env_idx" in options
            and len(options["env_idx"]) < self.base_env.num_envs
        ):
            raise RuntimeError(
                "partial environment reset is currently not supported for the FrameStack wrapper at this moment for GPU parallelized environments"
            )
        # NOTE (stao): we can support partial reset easily using tensordicts
        obs, info = self.env.reset(seed=seed, options=options)
        self.frames.extend([obs] * self.num_stack)
        return self.observation(None), info