# Source code for mani_skill.trajectory.merge_trajectory

import argparse
from pathlib import Path
import h5py
from mani_skill.utils.logging_utils import logger

from mani_skill.utils.io_utils import dump_json, load_json


def _json_sidecar_path(h5_path) -> str:
    """Return the path of the JSON file that accompanies ``h5_path``.

    Only the final suffix is swapped for ``.json``, so a ``.h5`` substring
    elsewhere in the path (for example in a directory name) is left untouched,
    and a path that does not end in ``.h5`` still maps to a distinct file.
    """
    return str(Path(h5_path).with_suffix(".json"))


def merge_trajectories(output_path: str, traj_paths: list, recompute_id: bool = True):
    """Merge multiple trajectory H5/JSON file pairs into a single pair.

    For every JSON key other than ``"episodes"`` the first value seen is kept,
    and a warning is logged whenever a later file disagrees with it. The
    ``"episodes"`` lists are concatenated in input order, and the H5 entry
    ``traj_<episode_id>`` of each episode is copied into the output H5 file.

    Args:
        output_path (str): Path of the output H5 file. The merged JSON is
            written next to it, with the final suffix replaced by ``.json``.
        traj_paths (list): Paths of the input H5 files. Each one must have a
            JSON file with the same name but a ``.json`` suffix.
        recompute_id (bool): If True, episodes are renumbered 0, 1, 2, ... in
            merge order. If False, the original episode IDs are kept.

    Raises:
        AssertionError: If the target trajectory name already exists in the
            output H5 file (an episode ID conflict when ``recompute_id`` is
            False).
    """
    logger.info(f"Merging {output_path}")

    merged_json_path = _json_sidecar_path(output_path)
    merged_json_data = {"episodes": []}
    cnt = 0

    # The context manager closes the output file even if an input is missing
    # or a copy fails part-way through.
    with h5py.File(output_path, "w") as merged_h5_file:
        for traj_path in traj_paths:
            traj_path = str(traj_path)
            logger.info(f"Merging {traj_path}")

            with h5py.File(traj_path, "r") as h5_file:
                json_data = load_json(_json_sidecar_path(traj_path))

                # For keys other than episodes, keep the first data
                # and check if there is any conflict with other data.
                for key, value in json_data.items():
                    if key == "episodes":
                        continue
                    if key not in merged_json_data:
                        merged_json_data[key] = value
                    elif merged_json_data[key] != value:
                        logger.warning(
                            f"Conflict detected for key {key} in {traj_path}: {merged_json_data[key]} != {value}"
                        )

                # Merge episodes
                for ep in json_data["episodes"]:
                    traj_id = f"traj_{ep['episode_id']}"
                    new_traj_id = f"traj_{cnt}" if recompute_id else traj_id

                    # Raised explicitly (rather than via ``assert``) so the
                    # check is not stripped when Python runs with -O.
                    if new_traj_id in merged_h5_file:
                        raise AssertionError(new_traj_id)

                    # Copy h5 data
                    h5_file.copy(traj_id, merged_h5_file, new_traj_id)

                    # Copy json data
                    if recompute_id:
                        ep["episode_id"] = cnt
                    merged_json_data["episodes"].append(ep)
                    cnt += 1

    # The JSON is written only after the merged H5 file has been closed.
    dump_json(merged_json_path, merged_json_data, indent=2)
def main():
    """Command-line entry point: find trajectory files and merge them.

    Each directory given with ``-i/--input-dirs`` is searched recursively for
    files matching ``-p/--pattern``. The matches from each directory are
    sorted, collected in the order the directories were given, and passed to
    ``merge_trajectories`` together with ``-o/--output-path``. The parent
    directory of the output path is created if it does not exist.
    """
    parser = argparse.ArgumentParser()
    # Both options are mandatory: without them the code below would fail on
    # ``None`` with an unhelpful TypeError instead of a usage message.
    parser.add_argument(
        "-i",
        "--input-dirs",
        nargs="+",
        required=True,
        help="directories searched recursively for trajectory files",
    )
    parser.add_argument(
        "-o",
        "--output-path",
        type=str,
        required=True,
        help="path of the merged output H5 file",
    )
    parser.add_argument(
        "-p",
        "--pattern",
        type=str,
        default="trajectory.h5",
        help="glob pattern of the trajectory files to merge (default: trajectory.h5)",
    )
    args = parser.parse_args()

    traj_paths = []
    for input_dir in args.input_dirs:
        # Sort within each directory so the merge order is deterministic.
        traj_paths.extend(sorted(Path(input_dir).rglob(args.pattern)))

    output_dir = Path(args.output_path).parent
    output_dir.mkdir(exist_ok=True, parents=True)

    merge_trajectories(args.output_path, traj_paths)
# Run the command-line entry point only when this file is executed as a script.
if __name__ == "__main__": main()