Source code for mani_skill.agents.registration

from dataclasses import dataclass
from typing import Optional, Type, TypeVar

from mani_skill import logger
from mani_skill.agents.base_agent import BaseAgent
from mani_skill.utils import assets

# Type variable restricted to BaseAgent subclasses, so the registration
# decorator's inner function can be annotated as returning exactly the
# class it was given (Type[T] -> Type[T]).
T = TypeVar("T", bound=BaseAgent)
@dataclass
class AgentSpec:
    """Registry entry describing a single registered agent.

    One of these is stored in ``REGISTERED_AGENTS`` for each agent uid by the
    ``register_agent`` decorator.
    """

    # The registered agent class (expected to be a BaseAgent subclass).
    agent_cls: type[BaseAgent]
    # Ids of the asset groups associated with this agent; the same list is
    # also recorded in ``assets.DATA_GROUPS`` under the agent's uid.
    asset_download_ids: list[str]
# Global registry: agent uid (string) -> AgentSpec. Entries are added (and,
# with override=True, replaced) by the register_agent decorator.
REGISTERED_AGENTS: dict[str, AgentSpec] = dict()
def register_agent(
    asset_download_ids: Optional[list[str]] = None, override: bool = False
):
    """A decorator to register agents into ManiSkill so they can be used easily by string uid.

    The decorated class is keyed by its ``uid`` class attribute.

    Args:
        asset_download_ids (list[str] | None): ids of the asset groups associated
            with this agent. The list is stored on the agent's ``AgentSpec`` and in
            ``assets.DATA_GROUPS`` under the agent's uid. Defaults to a new empty
            list for every call.
        override (bool): whether to override the agent if it is already registered.
            If False and the uid is already taken, a warning is logged and the
            existing registration is kept.

    Returns:
        A class decorator that records the class in ``REGISTERED_AGENTS`` and
        returns the class unchanged.
    """
    # Build a fresh list per call. A literal ``[]`` default would be created once
    # and shared (and stored) by every registration that relies on the default.
    if asset_download_ids is None:
        asset_download_ids = []

    def _register_agent(agent_cls: Type[T]) -> Type[T]:
        uid = agent_cls.uid
        if uid in REGISTERED_AGENTS:
            if not override:
                logger.warning(f"Agent {uid} is already registered. Skip registration.")
                return agent_cls
            logger.warning(f"Overriding registered agent {uid}")
            # Pop before re-inserting so the overriding entry moves to the end of
            # the registry's insertion order, as before.
            REGISTERED_AGENTS.pop(uid)
        REGISTERED_AGENTS[uid] = AgentSpec(
            agent_cls=agent_cls, asset_download_ids=asset_download_ids
        )
        assets.DATA_GROUPS[uid] = asset_download_ids
        return agent_cls

    return _register_agent