# Source code for mani_skill.envs.utils.randomization.batched_rng
"""
Code implementation for a batched random number generator. The goal is to enable seeding a batched random number generator with a batch of seeds to ensure randomization
in CPU simulators and GPU simulators are the same
"""
from typing import Union
import numpy as np
from mani_skill.utils import common
class BatchedRNG(np.random.RandomState):
    """A batch of random number generators exposed through one RandomState-like object.

    One generator per batch element is stored in ``self.rngs``. Attribute
    access that does not belong to the wrapper itself is forwarded to every
    element:

    * callable attributes (e.g. ``rand``, ``randint``) become a function that
      calls the method on each element with the same arguments and returns the
      per-element results combined with ``np.array``;
    * non-callable attributes are returned as ``np.array`` of each element's
      value.

    A scalar index returns the underlying generator itself. An iterable index
    returns a new ``BatchedRNG`` that shares (does not copy) the selected
    generators.

    NOTE: ``np.random.RandomState.__init__`` is never called, so the inherited
    base-class generator state is unused. Presumably the class subclasses
    RandomState only so that ``isinstance(x, np.random.RandomState)`` checks
    accept it. TODO confirm.
    """

    # Non-dunder names resolved on the wrapper itself rather than forwarded to
    # the per-element generators. Dunder names are always resolved normally
    # (see __getattribute__).
    _PASSTHROUGH_ATTRS = frozenset({"rngs", "batch_size", "from_seeds", "from_rngs"})

    def __init__(self, rngs: list):
        """Wrap ``rngs``, a list holding one generator per batch element.

        The list is stored by reference, not copied.
        """
        self.rngs = rngs
        self.batch_size = len(rngs)

    @classmethod
    def from_seeds(cls, seeds: list[int], backend: str = "numpy:random_state"):
        """Create a batch with one freshly seeded generator per entry of ``seeds``.

        Args:
            seeds: One seed per batch element.
            backend: Generator implementation. Only ``"numpy:random_state"``
                (``np.random.RandomState``) is supported.

        Raises:
            ValueError: If ``backend`` is not supported.
        """
        if backend == "numpy:random_state":
            return cls(rngs=[np.random.RandomState(seed) for seed in seeds])
        raise ValueError(f"Unknown batched RNG backend: {backend}")

    @classmethod
    def from_rngs(cls, rngs: list):
        """Create a batch from already-constructed generators."""
        return cls(rngs=rngs)

    def __getitem__(self, idx: Union[int, list[int], np.ndarray]):
        """Select generators by position.

        ``idx`` is first converted with ``common.to_numpy``. If the result is
        iterable, a new ``BatchedRNG`` over the selected (shared) generators is
        returned. Otherwise the single generator at ``idx`` is returned.
        """
        idx = common.to_numpy(idx)
        if np.iterable(idx):
            return BatchedRNG.from_rngs([self.rngs[i] for i in idx])
        return self.rngs[idx]

    def __setitem__(
        self,
        idx: Union[int, list[int], np.ndarray],
        value: Union[np.random.RandomState, list[np.random.RandomState]],
    ):
        """Replace generators by position.

        A scalar index takes a single generator. An iterable index takes a
        sequence (or a ``BatchedRNG``) holding one generator per index.

        Raises:
            ValueError: If an iterable index and ``value`` differ in length.
        """
        idx = common.to_numpy(idx)
        if not np.iterable(idx):
            self.rngs[idx] = value
            return
        indices = list(idx)
        new_rngs = list(value.rngs) if isinstance(value, BatchedRNG) else list(value)
        # zip() would silently ignore any surplus on either side, leaving the
        # batch partially updated, so the lengths are checked explicitly.
        if len(indices) != len(new_rngs):
            raise ValueError(
                f"Cannot assign {len(new_rngs)} RNGs to {len(indices)} indices"
            )
        for i, new_rng in zip(indices, new_rngs):
            self.rngs[i] = new_rng

    def __reduce__(self):
        """Support ``pickle`` and ``copy`` by rebuilding from the wrapped generators."""
        return (type(self), (self.rngs,))

    def __getattribute__(self, item):
        """Resolve wrapper attributes normally and forward everything else to each element.

        Dunder names are always resolved normally, so Python protocols
        (pickling, copying, ``__class__``, ``__dict__``, ``__repr__``, ...)
        act on the wrapper instead of on the individual generators.
        """
        if item in type(self)._PASSTHROUGH_ATTRS or (
            item.startswith("__") and item.endswith("__")
        ):
            return object.__getattribute__(self, item)
        # The first element decides whether the attribute is a method. This
        # raises AttributeError if the generators do not have it.
        if callable(getattr(self.rngs[0], item)):

            def method(*args, **kwargs):
                # self.rngs is read at call time, as in the original.
                return np.array(
                    [getattr(rng, item)(*args, **kwargs) for rng in self.rngs]
                )

            return method
        return np.array([getattr(rng, item) for rng in self.rngs])