# Source code for mani_skill.utils.wrappers.visual_encoders

from typing import Literal

import gymnasium as gym
import torch

from mani_skill.envs.sapien_env import BaseEnv
from mani_skill.utils import common


class VisualEncoderWrapper(gym.ObservationWrapper):
    """Replace raw camera images in observations with pretrained visual embeddings.

    Currently only the R3M encoder with a ResNet-18 backbone is supported. For
    every camera in ``obs["sensor_data"]`` the ``"rgb"`` image is resized to
    224x224 and passed through the R3M model; the per-camera embeddings are
    concatenated along the last dimension and stored under ``obs["embedding"]``.
    The ``"sensor_data"`` and ``"sensor_param"`` entries are removed.

    Args:
        env: Environment to wrap. ``env.unwrapped`` is used as a ``BaseEnv`` and
            its ``obs_mode`` must be ``"rgbd"`` or ``"rgb"``.
        encoder: Name of the visual encoder. Only ``"r3m"`` is supported.
        encoder_config: Reserved for encoder-specific options; currently unused.

    Raises:
        ValueError: If ``encoder`` is not ``"r3m"``, or if the environment's
            ``obs_mode`` is not ``"rgbd"`` or ``"rgb"``.
    """

    def __init__(self, env, encoder: Literal["r3m"], encoder_config=None):
        # `encoder_config` defaults to None rather than a shared mutable dict.
        self.base_env: BaseEnv = env.unwrapped
        # Explicit raises instead of asserts so validation survives `python -O`;
        # otherwise an unsupported encoder would leave `self.model` unset and
        # fail later with an unrelated AttributeError.
        if encoder != "r3m":
            raise ValueError("Only encoder='r3m' is supported at the moment")
        if self.base_env.obs_mode not in ["rgbd", "rgb"]:
            raise ValueError("r3m encoder requires obs_mode to be set to rgbd or rgb")

        # Imported lazily so these optional dependencies are only required when
        # the wrapper is actually used.
        import torchvision.transforms as T
        from r3m import load_r3m

        self.model = load_r3m("resnet18")  # resnet18, resnet34
        self.model.eval()
        self.model.to(self.base_env.device)
        # Resize only; the HWC -> CHW permute is done in `observation`.
        self.transforms = T.Compose([T.Resize((224, 224), antialias=True)])
        self.single_image_embedding_size = 512  # for resnet18

        # Push the env's initial raw observation through this wrapper so the
        # observation space describes the embedding rather than raw images.
        self.base_env.update_obs_space(
            common.to_numpy(
                self.observation(common.to_tensor(self.base_env._init_raw_obs))
            )
        )
        super().__init__(env)

    @torch.no_grad()
    def observation(self, obs: dict):
        """Replace the camera images in ``obs`` with concatenated R3M embeddings.

        ``obs`` is mutated in place: ``"sensor_data"`` and ``"sensor_param"``
        are removed and ``"embedding"`` is added.

        Args:
            obs: Observation dict with a ``"sensor_data"`` mapping (camera name
                -> dict holding an ``"rgb"`` tensor) and a ``"sensor_param"``
                entry. Each ``"rgb"`` tensor is indexed as channels-last
                ``(N, H, W, 3)``; values are presumably raw ``[0, 255]`` pixels,
                as R3M expects — confirm against the sensor output.

        Returns:
            The same dict, whose ``"embedding"`` entry is the per-camera
            embeddings concatenated along the last dimension.
        """
        image_obs = obs.pop("sensor_data")
        del obs["sensor_param"]

        embeddings = []
        for camera_data in image_obs.values():
            # (N, H, W, 3) -> (N, 3, H, W), then resize to (N, 3, 224, 224).
            images = self.transforms(camera_data["rgb"].permute(0, 3, 1, 2))
            images = images.to(self.base_env.device)
            # Gradients are already disabled by @torch.no_grad(), so no detach.
            embeddings.append(self.model(images))  # (N, single_image_embedding_size)

        obs["embedding"] = torch.cat(embeddings, dim=-1)
        return obs