Source code for mani_skill.utils.tree
import torch
# NOTE (stao): when tensordict is used we should replace all of this
def slice(x, i):
    """
    Apply the index ``i`` to every leaf of a (possibly nested) dict tree.

    Args:
        x: either a leaf supporting ``leaf[i]`` (e.g. a ``torch.Tensor``) or a
            dict whose values are themselves trees.
        i: any index accepted by the leaves' ``__getitem__``.

    Returns:
        ``x[i]`` when ``x`` is a leaf; otherwise a new dict with the same keys
        and nesting as ``x`` whose leaves have each been indexed with ``i``.
    """
    if not isinstance(x, dict):
        return x[i]
    return {key: slice(subtree, i) for key, subtree in x.items()}
def cat(x: list):
    """
    Concatenate a list of structurally identical trees along dimension 0.

    Args:
        x: a non-empty list of trees. Either every element is a leaf accepted
            by ``torch.cat``, or every element is a dict. The keys are taken
            from the first element.

    Returns:
        A single leaf ``torch.cat(x, dim=0)``, or a dict with the first tree's
        keys whose values are the recursive concatenation of the matching
        subtrees.

    Raises:
        IndexError: if ``x`` is empty (its first element is inspected).
        KeyError: if a later tree lacks a key present in the first one.
    """
    head = x[0]
    if not isinstance(head, dict):
        return torch.cat(x, dim=0)
    merged = {}
    for key in head:
        merged[key] = cat([tree[key] for tree in x])
    return merged
def replace(x, i, y):
    """
    Assign ``y`` into position ``i`` of every leaf of ``x``, mutating ``x`` in place.

    Args:
        x: a leaf supporting item assignment (e.g. a ``torch.Tensor``) or a
            dict of such trees. Its structure drives the traversal.
        i: the index written to in each leaf (``leaf[i] = ...``).
        y: a value with the same nesting as ``x``. At a leaf it is the value
            assigned. At a dict level it must contain every key of ``x``;
            extra keys in ``y`` are ignored.

    Returns:
        None. The leaves of ``x`` are modified in place.
    """
    if not isinstance(x, dict):
        x[i] = y
        return
    for key, subtree in x.items():
        replace(subtree, i, y[key])
def shape(x, first_only=False):
    """
    Get the shape of leaf items in a tree. If first_only is True, return the shape of the first item only.

    Args:
        x: a leaf exposing ``.shape`` (e.g. a ``torch.Tensor``) or a, possibly
            nested, dict of such leaves.
        first_only: if True, descend into the first value (in dict iteration
            order) at every level and return that single leaf's shape instead
            of a nested dict of shapes.

    Returns:
        ``x.shape`` when ``x`` is a leaf. When ``first_only`` is True, the
        shape of the first leaf reached. Otherwise, a dict mirroring the
        structure of ``x`` with each leaf replaced by its shape.

    Raises:
        ValueError: if ``first_only`` is True and an empty dict is encountered,
            since it has no first item to take the shape of.
    """
    if not isinstance(x, dict):
        return x.shape
    if first_only:
        # Check explicitly instead of letting next() raise StopIteration.
        # A StopIteration escaping a regular function can silently end an
        # enclosing iterator such as map(), or turn into a confusing
        # RuntimeError inside a generator (PEP 479).
        if not x:
            raise ValueError(
                "shape(first_only=True) encountered an empty dict; there is no first item"
            )
        return shape(next(iter(x.values())), first_only)
    return {k: shape(v, first_only) for k, v in x.items()}