"""Match LeRobot DROID episodes to Posed DROID calibration data.

Builds a manifest that maps each LeRobot episode index to its improved
camera extrinsics from the Posed DROID JSON. Uses collector_id + camera
extrinsic positions as a fingerprint for matching.

Usage:
    python build_posed_manifest.py
"""

import json
import time
import numpy as np
import pandas as pd
import pyarrow.parquet as pq
from pathlib import Path
from collections import defaultdict
from scipy.spatial import cKDTree

# Root of the LeRobot DROID dataset; main() reads per-episode parquet files
# from <root>/data/chunk-XXX/ and checks for videos under <root>/videos/chunk-XXX/.
DROID_ROOT = Path("/data/cameron/droid")
# Posed DROID calibration JSON: maps episode ids to per-camera 6-element extrinsics.
POSED_JSON = Path("/data/cameron/para_droid_pretrain/posed_droid/pnp_cam2base_multiview.json")
# Destination of the manifest JSON written at the end of main().
OUTPUT = Path("/data/cameron/droid/manifest_posed_ext2.json")
# Number of LeRobot episode indices scanned (episodes 0 .. TOTAL_EPISODES - 1).
TOTAL_EPISODES = 95600
MATCH_THRESHOLD = 0.20  # max combined position distance (meters) for a valid match


def _load_posed_index(posed_json):
    """Load the Posed DROID JSON and build per-collector spatial indices.

    Episode ids are split on "+" and the second field is used as the
    collector id. Only entries with at least two 6-element camera lists are
    kept; the first two cameras (sorted by serial) contribute their first
    three values each to a 6-D fingerprint.

    Returns:
        dict mapping collector id -> {"positions": list of 6-D arrays,
        "entries": list of per-episode dicts, "tree": cKDTree over the
        fingerprints, "tree_swap": cKDTree over the fingerprints with the
        two cameras' halves exchanged}.
    """
    print("Loading Posed DROID JSON...")
    with open(posed_json, encoding="utf-8") as f:
        posed = json.load(f)

    posed_by_cid = {}
    for ep_id, meta in posed.items():
        parts = ep_id.split("+")
        if len(parts) < 2:
            continue
        cid = parts[1]
        cams = {k: np.array(v, dtype=np.float64)
                for k, v in meta.items()
                if k != "relative_path" and isinstance(v, list) and len(v) == 6}
        if len(cams) < 2:
            continue
        serials = sorted(cams)

        group = posed_by_cid.setdefault(cid, {"positions": [], "entries": []})
        # Positions of both cameras (sorted by serial) as one 6-D vector.
        group["positions"].append(
            np.concatenate([cams[serials[0]][:3], cams[serials[1]][:3]]))
        group["entries"].append({
            "ep_id": ep_id,
            "serials": serials,
            "extrinsics": {s: cams[s].tolist() for s in serials},
            "path": meta.get("relative_path", ""),
        })

    print("Building spatial indices...")
    for group in posed_by_cid.values():
        positions = np.array(group["positions"])
        group["tree"] = cKDTree(positions)
        # Swapped halves let a query match when ext1/ext2 are assigned in
        # the opposite order from the sorted serials.
        positions_swap = np.concatenate(
            [positions[:, 3:6], positions[:, 0:3]], axis=1)
        group["tree_swap"] = cKDTree(positions_swap)

    print(f"  {len(posed_by_cid)} collectors, "
          f"{sum(len(v['entries']) for v in posed_by_cid.values())} posed episodes")
    return posed_by_cid


def _match_episode(ep, posed_by_cid):
    """Match one LeRobot episode against the posed index.

    Returns:
        (status, entry) where status is one of "matched", "no_cid",
        "too_far" or "missing" (the keys of the stats dict in main), and
        entry is the manifest record for a match or None otherwise.
    """
    chunk = f"chunk-{ep // 1000:03d}"
    pq_path = DROID_ROOT / "data" / chunk / f"episode_{ep:06d}.parquet"
    vid_path = (DROID_ROOT / "videos" / chunk
                / "observation.images.exterior_2_left" / f"episode_{ep:06d}.mp4")

    if not pq_path.exists() or not vid_path.exists():
        return "missing", None

    try:
        df = pd.read_parquet(str(pq_path), columns=[
            "collector_id",
            "camera_extrinsics.exterior_1_left",
            "camera_extrinsics.exterior_2_left",
        ])
    except Exception:
        # Deliberate best-effort: any unreadable file is counted as missing
        # rather than aborting the whole scan.
        return "missing", None

    if df.empty:
        # A zero-row file has no first row to fingerprint; .iloc[0] would
        # raise IndexError and kill the run.
        return "missing", None

    cid = df["collector_id"].iloc[0]
    if cid not in posed_by_cid:
        return "no_cid", None
    group = posed_by_cid[cid]

    lr_ext1 = np.array(df["camera_extrinsics.exterior_1_left"].iloc[0], dtype=np.float64)
    lr_ext2 = np.array(df["camera_extrinsics.exterior_2_left"].iloc[0], dtype=np.float64)
    query = np.concatenate([lr_ext1[:3], lr_ext2[:3]])

    # Query both orderings (ext1<->ext2 assignment may differ).
    d_same, idx_same = group["tree"].query(query)
    d_swap, idx_swap = group["tree_swap"].query(query)
    swapped = d_swap < d_same  # ties go to the unswapped ordering
    best_dist, best_idx = (d_swap, idx_swap) if swapped else (d_same, idx_same)

    if best_dist > MATCH_THRESHOLD:
        return "too_far", None

    match = group["entries"][int(best_idx)]
    serials = match["serials"]
    if swapped:
        ext1_serial, ext2_serial = serials[1], serials[0]
    else:
        ext1_serial, ext2_serial = serials[0], serials[1]

    num_frames = pq.read_metadata(str(pq_path)).num_rows

    return "matched", {
        "ep_idx": ep,
        "num_frames": num_frames,
        "posed_ep_id": match["ep_id"],
        # Plain float so the value is JSON-serialisable without relying on
        # numpy scalar subclassing.
        "match_dist": round(float(best_dist), 4),
        # Improved extrinsics: [x,y,z,rx,ry,rz] for each camera
        "posed_ext1": match["extrinsics"][ext1_serial],
        "posed_ext2": match["extrinsics"][ext2_serial],
        "ext1_serial": ext1_serial,
        "ext2_serial": ext2_serial,
    }


def _summarize(manifest_episodes):
    """Return aggregate statistics over a non-empty list of manifest records."""
    dists = [e["match_dist"] for e in manifest_episodes]
    frame_counts = [e["num_frames"] for e in manifest_episodes]
    return {
        "total_matched": len(manifest_episodes),
        "total_frames": sum(frame_counts),
        "mean_match_dist": round(float(np.mean(dists)), 4),
        "median_match_dist": round(float(np.median(dists)), 4),
        "mean_episode_length": round(float(np.mean(frame_counts)), 1),
    }


def main():
    """Build the LeRobot -> Posed DROID manifest and write it to OUTPUT.

    Loads the posed calibration JSON, indexes it per collector, scans
    episodes 0..TOTAL_EPISODES-1, keeps those whose nearest posed entry is
    within MATCH_THRESHOLD, and dumps the matches plus run statistics as
    JSON. Progress is printed every 10000 episodes scanned.
    """
    t0 = time.time()

    posed_by_cid = _load_posed_index(POSED_JSON)

    print(f"\nMatching {TOTAL_EPISODES} LeRobot episodes...")
    manifest_episodes = []
    stats = {"total": 0, "matched": 0, "no_cid": 0, "too_far": 0, "missing": 0}

    for ep in range(TOTAL_EPISODES):
        status, entry = _match_episode(ep, posed_by_cid)
        stats["total"] += 1
        stats[status] += 1
        if entry is not None:
            manifest_episodes.append(entry)

        # Report on every 10000th episode scanned, whether or not that
        # particular episode matched.
        if stats["total"] % 10000 == 0:
            elapsed = max(time.time() - t0, 1e-9)
            rate = stats["total"] / elapsed
            eta = (TOTAL_EPISODES - stats["total"]) / rate
            print(f"  [{stats['total']}/{TOTAL_EPISODES}] matched={stats['matched']} "
                  f"({rate:.0f} eps/s, ETA {eta:.0f}s)")

    elapsed = time.time() - t0

    result = {
        "source": "posed_droid",
        "posed_json": str(POSED_JSON),
        "match_threshold": MATCH_THRESHOLD,
        "stats": stats,
        "episodes": manifest_episodes,
    }
    if manifest_episodes:
        result["summary"] = _summarize(manifest_episodes)

    with open(OUTPUT, "w", encoding="utf-8") as f:
        json.dump(result, f)

    print(f"\nDone in {elapsed:.0f}s")
    print(f"Stats: {json.dumps(stats, indent=2)}")
    print(f"Matched: {stats['matched']} / {stats['total']} ({stats['matched']/max(stats['total'],1)*100:.1f}%)")
    if result.get("summary"):
        print(f"Total frames: {result['summary']['total_frames']:,}")
        print(f"Mean match distance: {result['summary']['mean_match_dist']:.4f}m")
    print(f"Saved: {OUTPUT}")


# Run the manifest build only when executed as a script, not on import.
if __name__ == "__main__":
    main()
