"""Build a quality-filtered manifest of DROID episodes.

Estimates the quality of each episode's camera extrinsics by projecting the
end-effector (EEF) trajectory into the chosen camera and measuring the fraction
of frames whose projection lands inside the image. Episodes with a low in-frame
fraction likely have miscalibrated extrinsics.

When Posed DROID annotations (KarlP/droid) become available, this script can be
extended to use those instead — the manifest format is the same.

Usage:
    python build_droid_manifest.py --data_root /data/cameron/droid --camera ext2
    python build_droid_manifest.py --data_root /data/cameron/droid --camera ext2 --min_in_frame 0.75
"""

import argparse
import json
import time
import numpy as np
from pathlib import Path
from scipy.spatial.transform import Rotation as ScipyR

# Estimated ZED 2 intrinsics at the 320x180 video resolution. Only the focal
# length is estimated; the principal point is taken to be the image centre.
DEFAULT_FY = 130.0
IMG_W, IMG_H = 320, 180


def compute_in_frame_fraction(eef_positions, extrinsics_6d, fy=DEFAULT_FY,
                              fx=None, img_w=IMG_W, img_h=IMG_H):
    """Return the fraction of EEF positions that project inside the image.

    Projects every position through a pinhole camera whose principal point is
    the image centre. The projection ``u = fx*x/z + cx, v = fy*y/z + cy``
    implies a camera frame with z pointing forward (OpenCV-style).

    Args:
        eef_positions: (N, 3) array-like of end-effector positions, expressed
            in the same frame as the camera extrinsics.
        extrinsics_6d: 6-vector. Entries [0:3] are the camera position and
            entries [3:6] are extrinsic "xyz" Euler angles in radians (SciPy's
            default) giving the camera orientation, both relative to the frame
            of ``eef_positions``.
        fy: Vertical focal length in pixels.
        fx: Horizontal focal length in pixels; defaults to ``fy``.
        img_w: Image width in pixels.
        img_h: Image height in pixels.

    Returns:
        A float in [0, 1], or 0.0 for an empty trajectory. Points at or behind
        the camera plane (z <= 0 in camera coordinates) count as out of frame.
    """
    points = np.asarray(eef_positions, dtype=np.float64).reshape(-1, 3)
    if points.shape[0] == 0:
        return 0.0

    extrinsics = np.asarray(extrinsics_6d, dtype=np.float64)
    # Extrinsics give the camera pose in the base frame (camera -> base);
    # invert it to map base-frame points into camera coordinates.
    R_bc = ScipyR.from_euler("xyz", extrinsics[3:6]).as_matrix()
    R_cb = R_bc.T
    t_cb = -R_cb @ extrinsics[:3]

    if fx is None:
        fx = fy
    cx, cy = img_w / 2.0, img_h / 2.0

    # Row-vector form of p_cam = R_cb @ p + t_cb, applied to all points at once.
    p_cam = points @ R_cb.T + t_cb
    x, y, z = p_cam[:, 0], p_cam[:, 1], p_cam[:, 2]

    in_front = z > 0
    # Use a dummy depth for points behind the camera so the division is always
    # defined; those points are excluded by `in_front` below.
    safe_z = np.where(in_front, z, 1.0)
    u = fx * x / safe_z + cx
    v = fy * y / safe_z + cy

    visible = in_front & (u >= 0) & (u < img_w) & (v >= 0) & (v < img_h)
    return np.count_nonzero(visible) / points.shape[0]


def _read_episode(pq_path, ext_col):
    """Load the columns needed to score one episode from its parquet file.

    The extrinsics and cartesian-state columns are required; ``building`` and
    ``collector_id`` are optional metadata. Requesting a column that does not
    exist may raise (engine-dependent), so on failure the read is retried with
    the required columns only. An error from the retry propagates to the caller.
    """
    import pandas as pd

    required = [ext_col, "observation.state.cartesian_position"]
    try:
        return pd.read_parquet(pq_path, columns=required + ["building", "collector_id"])
    except Exception:
        return pd.read_parquet(pq_path, columns=required)


def _score_episode(ep, data_root, cam_key, ext_col, min_frames, min_in_frame, fy):
    """Classify one episode and build its manifest record.

    Returns:
        ``(status, record)``. ``status`` is one of ``"missing"``,
        ``"read_error"``, ``"too_short"``, ``"low_quality"`` or ``"accepted"``.
        ``record`` is the manifest entry for accepted episodes, else None.
    """
    chunk = f"chunk-{ep // 1000:03d}"
    ep_str = f"episode_{ep:06d}"
    pq_path = data_root / "data" / chunk / f"{ep_str}.parquet"
    vid_path = data_root / "videos" / chunk / cam_key / f"{ep_str}.mp4"

    if not pq_path.exists() or not vid_path.exists():
        return "missing", None

    try:
        df = _read_episode(pq_path, ext_col)
    except Exception:
        # File exists but could not be parsed; tracked separately from absent files.
        return "read_error", None

    num_frames = len(df)
    if num_frames < min_frames:
        return "too_short", None

    # Only the first frame's extrinsics are used for the whole episode.
    ext = np.array(df[ext_col].iloc[0], dtype=np.float64)
    cart = np.stack(df["observation.state.cartesian_position"].values).astype(np.float64)
    in_frame_frac = compute_in_frame_fraction(cart[:, :3], ext, fy=fy)

    if in_frame_frac < min_in_frame:
        return "low_quality", None

    def first_str(col):
        # Optional metadata may be absent if the fallback read was used.
        return str(df[col].iloc[0]) if col in df.columns else ""

    return "accepted", {
        "ep_idx": ep,
        "num_frames": num_frames,
        "in_frame_frac": round(in_frame_frac, 3),
        "building": first_str("building"),
        "collector_id": first_str("collector_id"),
    }


def _summarize(episodes):
    """Return aggregate statistics over accepted episodes, or None if there are none."""
    if not episodes:
        return None
    fracs = [e["in_frame_frac"] for e in episodes]
    frame_counts = [e["num_frames"] for e in episodes]
    return {
        "total_accepted": len(episodes),
        "total_frames": sum(frame_counts),
        "mean_in_frame": round(float(np.mean(fracs)), 3),
        "median_in_frame": round(float(np.median(fracs)), 3),
        "mean_episode_length": round(float(np.mean(frame_counts)), 1),
    }


def main():
    """Scan DROID episodes, score their extrinsics, and write a JSON manifest."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_root", type=str, default="/data/cameron/droid")
    parser.add_argument("--camera", type=str, default="ext2", choices=["ext1", "ext2"])
    parser.add_argument("--min_in_frame", type=float, default=0.75,
                        help="Minimum fraction of EEF in-frame to include episode")
    parser.add_argument("--min_frames", type=int, default=32,
                        help="Minimum episode length")
    parser.add_argument("--output", type=str, default="",
                        help="Output manifest path (default: data_root/manifest_{camera}.json)")
    parser.add_argument("--num_episodes", type=int, default=95600,
                        help="Number of episode indices to scan, starting from 0")
    parser.add_argument("--fy", type=float, default=DEFAULT_FY,
                        help="Vertical focal length in pixels used for projection")
    args = parser.parse_args()

    data_root = Path(args.data_root)
    side = "1" if args.camera == "ext1" else "2"
    ext_col = f"camera_extrinsics.exterior_{side}_left"
    cam_key = f"observation.images.exterior_{side}_left"
    output = args.output or str(data_root / f"manifest_{args.camera}.json")
    total_episodes = args.num_episodes

    results = {
        "camera": args.camera,
        "min_in_frame": args.min_in_frame,
        "min_frames": args.min_frames,
        "fy": args.fy,
        "episodes": [],  # list of {ep_idx, num_frames, in_frame_frac, building, collector_id}
    }
    stats = {"total": 0, "missing": 0, "read_error": 0, "too_short": 0,
             "low_quality": 0, "accepted": 0}

    t0 = time.time()
    for ep in range(total_episodes):
        status, record = _score_episode(ep, data_root, cam_key, ext_col,
                                        args.min_frames, args.min_in_frame, args.fy)
        stats["total"] += 1
        stats[status] += 1
        if record is not None:
            results["episodes"].append(record)

        # Report progress every 5000 episodes, whatever this episode's outcome.
        if stats["total"] % 5000 == 0:
            elapsed = time.time() - t0
            rate = stats["total"] / max(elapsed, 1e-9)
            eta = (total_episodes - stats["total"]) / rate
            print(f"  [{stats['total']}/{total_episodes}] accepted={stats['accepted']} "
                  f"low_q={stats['low_quality']} ({rate:.0f} eps/s, ETA {eta/60:.1f}min)")

    elapsed = time.time() - t0
    results["stats"] = stats
    summary = _summarize(results["episodes"])
    if summary is not None:
        results["summary"] = summary

    with open(output, "w") as f:
        json.dump(results, f, indent=2)

    accepted_pct = 100.0 * stats["accepted"] / stats["total"] if stats["total"] else 0.0
    print(f"\nDone in {elapsed:.0f}s")
    print(f"Stats: {json.dumps(stats, indent=2)}")
    print(f"Accepted: {stats['accepted']} episodes ({accepted_pct:.1f}%)")
    if summary is not None:
        print(f"Total frames: {summary['total_frames']:,}")
        print(f"Mean in-frame: {summary['mean_in_frame']:.3f}")
    print(f"Saved: {output}")


if __name__ == "__main__":
    # Build the manifest only when executed as a script, not when imported.
    main()
