"""Compute median gripper rotation across entire dataset."""
import sys
import os
from pathlib import Path
import numpy as np
from scipy.spatial.transform import Rotation as R

sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../.."))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import argparse

def load_gripper_quats(dataset_dir):
    """Collect gripper orientations from every episode under *dataset_dir*.

    Scans each ``episode_*`` sub-directory for ``<digits>.png`` frames and,
    for every frame that has a matching ``<frame:06d>_gripper_pose.npy`` file,
    converts the upper-left 3x3 block of the stored pose to a quaternion.

    Args:
        dataset_dir: Directory containing the ``episode_*`` sub-directories.

    Returns:
        A tuple ``(quats, num_episodes)`` where ``quats`` is an ``(N, 4)``
        array of xyzw quaternions (``N`` may be 0) and ``num_episodes`` is the
        number of episode directories that were scanned.

    Raises:
        FileNotFoundError: If *dataset_dir* does not exist.
    """
    episode_dirs = sorted(
        d for d in Path(dataset_dir).iterdir()
        if d.is_dir() and d.name.startswith("episode_")
    )

    quats = []
    for episode_dir in episode_dirs:
        frame_files = sorted(f for f in episode_dir.glob("*.png") if f.stem.isdigit())
        for frame_file in frame_files:
            frame_str = f"{int(frame_file.stem):06d}"
            pose_path = episode_dir / f"{frame_str}_gripper_pose.npy"
            if not pose_path.exists():
                continue
            gripper_pose = np.load(pose_path)
            # NOTE(review): assumes the pose is a homogeneous transform (or at
            # least 3x3) whose upper-left block is the rotation -- confirm.
            quats.append(R.from_matrix(gripper_pose[:3, :3]).as_quat())  # xyzw format

    # reshape keeps a well-defined (0, 4) shape when no poses were found.
    return np.array(quats, dtype=float).reshape(-1, 4), len(episode_dirs)


def mean_rotation_matrix(quats):
    """Return the 3x3 matrix of the mean rotation of xyzw quaternions *quats*.

    The mean is computed with ``Rotation.mean()``, which does not depend on
    the sign of each quaternion. A plain component-wise average is wrong here
    because ``q`` and ``-q`` describe the same rotation, so samples with mixed
    signs would cancel each other out.

    Args:
        quats: Array-like of shape ``(N, 4)`` (or a single ``(4,)`` quaternion)
            in xyzw order.

    Returns:
        A ``(3, 3)`` rotation matrix.

    Raises:
        ValueError: If *quats* is empty.
    """
    quats = np.asarray(quats, dtype=float)
    if quats.size == 0:
        raise ValueError("cannot average an empty set of rotations")
    return R.from_quat(quats.reshape(-1, 4)).mean().as_matrix()


def main(argv=None):
    """Parse command-line arguments, average the dataset rotations and print them.

    Args:
        argv: Optional argument list; defaults to ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(description="Compute median gripper rotation from dataset")
    parser.add_argument("--dataset_dir", "-d", default="scratch/parsed_propercup_train", type=str, help="Dataset directory")
    args = parser.parse_args(argv)

    dataset_dir = Path(args.dataset_dir)
    if not dataset_dir.is_dir():
        parser.error(f"dataset directory not found: {dataset_dir}")

    all_quats, num_episodes = load_gripper_quats(dataset_dir)
    if len(all_quats) == 0:
        parser.error(f"no gripper poses found under {dataset_dir}")

    median_rot = mean_rotation_matrix(all_quats)

    print(f"Computed mean rotation from {len(all_quats)} gripper poses across {num_episodes} episodes")
    print("Rotation matrix:")
    print(median_rot)
    print("\nAs numpy array (for hardcoding):")
    # tolist() yields a plain nested list at full precision, so the printed
    # line can be pasted directly; repr(ndarray) would emit "array(...)".
    print(f"median_dataset_rotation = np.array({median_rot.tolist()!r})")


if __name__ == "__main__":
    main()
