# Source code for mani_skill.agents.multi_agent
from typing import Generic, TypeVar
import torch
from gymnasium import spaces
from mani_skill.agents.base_agent import BaseAgent
from mani_skill.sensors.base_sensor import BaseSensorConfig
# Type parameter for the ``agents`` container; referenced by ``Generic[T]`` below.
T = TypeVar("T")


class MultiAgent(BaseAgent, Generic[T]):
    """Wrap several :class:`BaseAgent` instances so they can be driven as one agent.

    Every per-agent quantity (proprioception, control mode, controller, action
    space, controller state) is exposed as a dict keyed by ``"{agent.uid}-{i}"``,
    where ``i`` is the agent's index in the ``agents`` list given to ``__init__``.
    The same keys are stored in ``self.agents_dict``.
    """

    agents: T

    def __init__(self, agents: list[BaseAgent]):
        """Build the wrapper from a non-empty list of agents.

        Args:
            agents: the agents to wrap. The first agent's ``scene`` becomes the
                wrapper's ``scene``.

        Raises:
            ValueError: if ``agents`` is empty.

        Note:
            ``BaseAgent.__init__`` is not called here. The sensor configs of the
            wrapped agents are renamed in place (their ``uid`` gets prefixed), so
            the agents passed in are modified by this constructor.
        """
        if not agents:
            raise ValueError("MultiAgent requires at least one agent")
        self.agents = agents
        # "{uid}-{index}" -> agent; the key format shared by every dict this class returns.
        self.agents_dict: dict[str, BaseAgent] = dict()
        self.scene = agents[0].scene
        # Sensor configs of all wrapped agents, flattened into one list.
        self.sensor_configs = []
        # Renamed sensor-config uid -> index of the agent that owns it.
        self._sensor_config_agent_map: dict[str, int] = dict()
        for i, agent in enumerate(self.agents):
            agent_key = f"{agent.uid}-{i}"
            for sensor_config in agent._sensor_configs:
                # Prefix with the per-agent key (which contains the unique index i) so
                # sensor uids from different agents cannot collide.
                sensor_config.uid = f"{agent_key}-{sensor_config.uid}"
                self.sensor_configs.append(sensor_config)
                self._sensor_config_agent_map[sensor_config.uid] = i
            self.agents_dict[agent_key] = agent

    def get_proprioception(self):
        """Return each agent's proprioception, keyed like ``agents_dict``."""
        return {
            uid: agent.get_proprioception() for uid, agent in self.agents_dict.items()
        }

    @property
    def _sensor_configs(self) -> list[BaseSensorConfig]:
        """Returns a list of sensor configs for this agent. It contains the sensor configs of the individual agents."""
        return self.sensor_configs

    @property
    def control_mode(self):
        """Get the currently activated controller uid of each robot, keyed like ``agents_dict``."""
        return {uid: agent.control_mode for uid, agent in self.agents_dict.items()}

    def set_control_mode(self, control_mode: "list[str] | dict[str, str] | None" = None):
        """Set the control mode of every wrapped agent via its own ``set_control_mode``.

        Args:
            control_mode: one of
                - ``None``: each agent's ``set_control_mode()`` is called with no
                  argument, so each agent falls back to its default;
                - a sequence with exactly one control mode per agent, in
                  ``self.agents`` order;
                - a dict keyed like ``agents_dict``, such as the value returned
                  by the ``control_mode`` property.

        Raises:
            KeyError: if a dict is given that lacks an entry for some agent.
            AssertionError: if a sequence of the wrong length is given.
        """
        if control_mode is None:
            for agent in self.agents:
                agent.set_control_mode()
            return
        if isinstance(control_mode, dict):
            # Without this branch, a dict of matching length would pass the length
            # check below and its *keys* would be handed to the agents as modes.
            for uid, agent in self.agents_dict.items():
                agent.set_control_mode(control_mode[uid])
            return
        assert len(control_mode) == len(
            self.agents
        ), "For task with multiple agents, setting control mode on the MultiAgent object requires a control mode for each agent"
        for cm, agent in zip(control_mode, self.agents):
            agent.set_control_mode(cm)

    @property
    def controller(self):
        """Each agent's controller, keyed like ``agents_dict``."""
        return {uid: agent.controller for uid, agent in self.agents_dict.items()}

    @property
    def action_space(self):
        """A ``spaces.Dict`` of each agent's ``action_space``, keyed like ``agents_dict``."""
        return spaces.Dict(
            {uid: agent.action_space for uid, agent in self.agents_dict.items()}
        )

    @property
    def single_action_space(self):
        """A ``spaces.Dict`` of each agent's ``single_action_space``, keyed like ``agents_dict``."""
        return spaces.Dict(
            {uid: agent.single_action_space for uid, agent in self.agents_dict.items()}
        )

    def set_action(self, action):
        """
        Set the agent's action which is to be executed in the next environment timestep.

        ``action`` must be indexable by every key of ``agents_dict``; each agent
        receives ``action[key]``.
        """
        for uid, agent in self.agents_dict.items():
            agent.set_action(action[uid])

    def before_simulation_step(self):
        """Call ``before_simulation_step`` on each agent's controller."""
        for agent in self.agents:
            agent.controller.before_simulation_step()

    def get_controller_state(self):
        """
        Get the state of the controller of each agent, keyed like ``agents_dict``.
        """
        return {
            uid: agent.get_controller_state() for uid, agent in self.agents_dict.items()
        }

    def set_controller_state(self, state: dict):
        """Restore each agent's controller state from ``state[key]`` (keys as in ``agents_dict``)."""
        for uid, agent in self.agents_dict.items():
            agent.set_controller_state(state[uid])

    # -------------------------------------------------------------------------- #
    # Other
    # -------------------------------------------------------------------------- #
    def reset(self, init_qpos=None):
        """
        Reset the robots to a rest position or a given q-position.

        Args:
            init_qpos: optional container keyed like ``agents_dict`` (presumably a
                dict of per-agent q-positions; confirm against callers). Agents
                whose key is present are reset with ``init_qpos=init_qpos[key]``.
                All other agents are reset with ``agent.reset()`` and no explicit
                q-position.
        """
        for uid, agent in self.agents_dict.items():
            if init_qpos is not None and uid in init_qpos:
                # Pass this agent's own entry, not the key itself.
                agent.reset(init_qpos=init_qpos[uid])
            else:
                agent.reset()