"""Source code for mani_skill.agents.multi_agent."""

from typing import Generic, Optional, TypeVar

import torch
from gymnasium import spaces

from mani_skill.agents.base_agent import BaseAgent
from mani_skill.sensors.base_sensor import BaseSensorConfig

# Type parameter for the static type of ``MultiAgent.agents`` so subclasses
# can annotate the concrete container of agents they wrap.
T = TypeVar("T")


class MultiAgent(BaseAgent, Generic[T]):
    """Wrap several agents so they can be driven through one agent interface.

    Each wrapped agent is addressed by the key ``f"{agent.uid}-{i}"``, where
    ``i`` is its position in the list given to the constructor. Dict-valued
    outputs (control mode, controllers, action spaces, proprioception,
    controller state) use these keys. Dict-valued inputs (actions, controller
    state, initial qpos) are looked up by these keys.
    """

    # The wrapped agents, in constructor order.
    agents: T

    def __init__(self, agents: list[BaseAgent]):
        """Collect the given agents and merge their sensor configurations.

        ``BaseAgent.__init__`` is not called. The scene is taken from the
        first agent.

        Args:
            agents: the agents to wrap. Must be non-empty.

        Raises:
            ValueError: if ``agents`` is empty.
        """
        if not agents:
            raise ValueError("MultiAgent requires at least one agent")
        self.agents = agents
        # Agent key (f"{uid}-{index}") -> agent, in constructor order.
        self.agents_dict: dict[str, BaseAgent] = dict()
        self.scene = agents[0].scene
        # Sensor configs of all agents, with uids made unique per agent.
        self.sensor_configs = []
        # Prefixed sensor-config uid -> index of the owning agent in self.agents.
        self._sensor_config_agent_map: dict[str, int] = dict()
        for i, agent in enumerate(self.agents):
            agent_key = f"{agent.uid}-{i}"
            for sensor_config in agent._sensor_configs:
                # Prefix the uid so sensors of identical robots don't collide.
                # Note: this mutates the config object returned by the agent.
                sensor_config.uid = f"{agent_key}-{sensor_config.uid}"
                self.sensor_configs.append(sensor_config)
                self._sensor_config_agent_map[sensor_config.uid] = i
            self.agents_dict[agent_key] = agent

    def get_proprioception(self):
        """Return each agent's proprioception, keyed by agent key."""
        return {
            uid: agent.get_proprioception() for uid, agent in self.agents_dict.items()
        }

    @property
    def _sensor_configs(self) -> list[BaseSensorConfig]:
        """Returns a list of sensor configs for this agent. It contains the sensor configs of the individual agents."""
        return self.sensor_configs

    @property
    def control_mode(self):
        """Get the currently activated controller uid of each robot"""
        return {uid: agent.control_mode for uid, agent in self.agents_dict.items()}

    def set_control_mode(self, control_mode: Optional[list[str]] = None):
        """Set the controller, drive properties, and reset for each agent.

        Args:
            control_mode: one control mode per agent, in constructor order.
                If ``None``, ``set_control_mode()`` is called on every agent
                with no arguments, which applies each agent's defaults.

        Raises:
            AssertionError: if the number of control modes differs from the
                number of agents.
        """
        if control_mode is None:
            for agent in self.agents:
                agent.set_control_mode()
            return
        assert len(control_mode) == len(
            self.agents
        ), "For task with multiple agents, setting control mode on the MultiAgent object requires a control mode for each agent"
        for cm, agent in zip(control_mode, self.agents):
            agent.set_control_mode(cm)

    @property
    def controller(self):
        """Each agent's active controller, keyed by agent key."""
        return {uid: agent.controller for uid, agent in self.agents_dict.items()}

    @property
    def action_space(self):
        """Dict action space with one entry per agent key."""
        return spaces.Dict(
            {uid: agent.action_space for uid, agent in self.agents_dict.items()}
        )

    @property
    def single_action_space(self):
        """Dict of each agent's ``single_action_space``, keyed by agent key."""
        return spaces.Dict(
            {uid: agent.single_action_space for uid, agent in self.agents_dict.items()}
        )

    def set_action(self, action):
        """Set the agent's action which is to be executed in the next environment timestep.

        Args:
            action: mapping from agent key to that agent's action. Every
                agent key must be present.
        """
        for uid, agent in self.agents_dict.items():
            agent.set_action(action[uid])

    def before_simulation_step(self):
        """Forward the pre-simulation-step hook to every agent's controller."""
        for agent in self.agents:
            agent.controller.before_simulation_step()

    # -------------------------------------------------------------------------- #
    # Observations
    # -------------------------------------------------------------------------- #
    def get_controller_state(self):
        """Get the state of each agent's controller, keyed by agent key."""
        return {
            uid: agent.get_controller_state() for uid, agent in self.agents_dict.items()
        }

    def set_controller_state(self, state: dict):
        """Restore controller states.

        Args:
            state: mapping from agent key to a controller state, e.g. as
                returned by ``get_controller_state``.
        """
        for uid, agent in self.agents_dict.items():
            agent.set_controller_state(state[uid])

    # -------------------------------------------------------------------------- #
    # Other
    # -------------------------------------------------------------------------- #
    def reset(self, init_qpos: Optional[dict] = None):
        """Reset each robot to a rest position or a given q-position.

        Args:
            init_qpos: optional mapping from agent key to the initial qpos
                for that agent. Agents whose key is absent (or every agent,
                when this is ``None``) get ``agent.reset()`` with no arguments.
        """
        for uid, agent in self.agents_dict.items():
            if init_qpos is not None and uid in init_qpos:
                # Pass this agent's own qpos entry, not the key itself.
                agent.reset(init_qpos=init_qpos[uid])
            else:
                agent.reset()