# Source code for mani_skill.agents.registration
from dataclasses import dataclass
from typing import Optional, Type, TypeVar

from mani_skill import logger
from mani_skill.agents.base_agent import BaseAgent
from mani_skill.utils import assets
# Type variable bound to BaseAgent so the registration decorator can declare that
# it returns the same (sub)class it was given.
T = TypeVar("T", bound=BaseAgent)
@dataclass
class AgentSpec:
    """Registry record for one agent registered via ``register_agent``.

    Attributes:
        agent_cls: The registered agent class (a ``BaseAgent`` subclass).
        asset_download_ids: Ids of the assets associated with the agent.
    """

    agent_cls: type[BaseAgent]
    asset_download_ids: list[str]
# Global registry mapping an agent class's ``uid`` to its ``AgentSpec``;
# populated by ``register_agent``.
REGISTERED_AGENTS: dict[str, AgentSpec] = {}
def register_agent(
    asset_download_ids: Optional[list[str]] = None, override: bool = False
):
    """A decorator to register agents into ManiSkill so they can be used easily by string uid.

    The decorated class is keyed by its ``uid`` class attribute. On successful
    registration an ``AgentSpec`` is stored in ``REGISTERED_AGENTS`` and the asset
    download ids are stored in ``assets.DATA_GROUPS`` under the same uid.

    Args:
        asset_download_ids (list[str] | None): ids of the assets associated with the
            agent. Defaults to an empty list. A copy of the given ids is stored.
        override (bool): whether to override the agent if it is already registered.
            If False and the uid is already registered, a warning is logged and the
            existing registration is kept unchanged.

    Returns:
        A class decorator that registers the class and returns it unchanged.
    """

    def _register_agent(agent_cls: Type[T]) -> Type[T]:
        uid = agent_cls.uid
        if uid in REGISTERED_AGENTS:
            if not override:
                logger.warning(f"Agent {uid} is already registered. Skip registration.")
                return agent_cls
            logger.warning(f"Overriding registered agent {uid}")
            # Popping first re-inserts the new entry at the end of the dict's
            # insertion order (kept to match the original behavior).
            REGISTERED_AGENTS.pop(uid)
        # Build a fresh list per registration: the previous ``= []`` default was a
        # single list object shared by every agent registered without ids.
        download_ids = [] if asset_download_ids is None else list(asset_download_ids)
        REGISTERED_AGENTS[uid] = AgentSpec(
            agent_cls=agent_cls, asset_download_ids=download_ids
        )
        assets.DATA_GROUPS[uid] = download_ids
        return agent_cls

    return _register_agent