"""Verify Posed DROID camera extrinsics — v2 with contour overlays and convention comparison.

Renders with BOTH cam2base and base2cam interpretations side-by-side and
overlays the rendered robot silhouette (contour outline plus a light
semi-transparent fill) on the real frame for clearer alignment checking.

Usage:
    MUJOCO_GL=egl python verify_silhouette_v2.py
"""

import json
import os
import numpy as np
import pandas as pd
import imageio.v3 as iio
import cv2
import mujoco
from scipy.spatial.transform import Rotation as R

# Root of the local DROID dataset; parquet and video paths are built relative to it.
DROID_ROOT = "/data/cameron/droid"
# JSON manifest whose top-level "episodes" entry is returned by load_manifest().
MANIFEST_PATH = os.path.join(DROID_ROOT, "manifest_posed_ext2.json")
# MuJoCo XML of the robot; render_robot() reads it and injects the posed camera.
FRANKA_XML = "/data/cameron/para/droid_testing/franka_panda/panda_nohand.xml"
# Directory that main() creates (if needed) and writes the comparison PNGs into.
OUTPUT_DIR = "/data/cameron/para_droid_pretrain/posed_droid/silhouette_verification_v2"
# Width and height, in pixels, used for every rendered image and panel.
IMG_W, IMG_H = 640, 360  # 2x real resolution for clearer visualization


def load_manifest():
    """Load the episode list from the posed-extrinsics manifest.

    Returns:
        The value stored under the top-level ``"episodes"`` key of the JSON
        file at ``MANIFEST_PATH``.

    Raises:
        OSError: If the manifest file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If the manifest has no ``"episodes"`` key.
    """
    # JSON is UTF-8 by specification, so do not depend on the locale's
    # default text encoding when decoding the manifest.
    with open(MANIFEST_PATH, encoding="utf-8") as f:
        return json.load(f)["episodes"]


def load_episode_data(ep_idx):
    """Read one episode's joint positions and both exterior camera videos.

    Episodes are grouped on disk in chunks of 1000, so the chunk directory is
    ``ep_idx // 1000``. All paths are resolved relative to ``DROID_ROOT``.

    Returns:
        A ``(joints, frames1, frames2)`` tuple: the stacked
        ``observation.state.joint_position`` column of the episode parquet,
        followed by the decoded ``exterior_1_left`` and ``exterior_2_left``
        videos (each fully decoded into memory via the ``pyav`` plugin).
    """
    chunk = ep_idx // 1000

    def _under_root(relative_path):
        return os.path.join(DROID_ROOT, relative_path)

    # Parquet first, then the two videos, so a missing file surfaces in the
    # same order as the files are listed here.
    table = pd.read_parquet(
        _under_root(f"data/chunk-{chunk:03d}/episode_{ep_idx:06d}.parquet")
    )
    joint_positions = np.stack(table["observation.state.joint_position"].values)

    exterior_1 = iio.imread(
        _under_root(f"videos/chunk-{chunk:03d}/observation.images.exterior_1_left/episode_{ep_idx:06d}.mp4"),
        plugin="pyav",
    )
    exterior_2 = iio.imread(
        _under_root(f"videos/chunk-{chunk:03d}/observation.images.exterior_2_left/episode_{ep_idx:06d}.mp4"),
        plugin="pyav",
    )
    return joint_positions, exterior_1, exterior_2


def extrinsics_to_cam_pose(ext_6d, convention="cam2base"):
    """Convert 6D extrinsics to a camera position and MuJoCo-style rotation.

    Args:
        ext_6d: Sequence whose first three entries are a translation and whose
            next three are Euler angles, passed to scipy as an ``"xyz"``
            sequence in radians.
        convention: How to interpret the transform described by ``ext_6d``.
            ``"cam2base"``: T maps camera→base (the named convention), so the
            translation is the camera position and the rotation columns are
            the camera axes in the base frame.
            ``"base2cam"``: T maps base→camera, so it is inverted first.

    Returns:
        ``(cam_pos, cam_rot)``: a length-3 float array and a 3x3 matrix whose
        columns are the camera's X/Y/Z axes in the base frame, using MuJoCo's
        camera axes (X=right, Y=up, Z=backward).

    Raises:
        ValueError: If ``convention`` is not one of the two supported names.
            Previously any other string was silently treated as ``"base2cam"``.
    """
    if convention not in ("cam2base", "base2cam"):
        raise ValueError(
            f"unknown convention {convention!r}; expected 'cam2base' or 'base2cam'"
        )

    # np.array (not asarray) so the returned position never aliases the input.
    cam_pos = np.array(ext_6d[:3], dtype=float)
    cam2base_rot = R.from_euler("xyz", ext_6d[3:6]).as_matrix()

    if convention == "base2cam":
        # T_cam2base = T_base2cam^-1  =>  rotation R^T, translation -R^T @ t
        cam2base_rot = cam2base_rot.T
        cam_pos = -cam2base_rot @ cam_pos

    # OpenCV cam: X=right, Y=down, Z=forward
    # MuJoCo cam: X=right, Y=up, Z=backward
    opencv_to_mj = np.diag([1.0, -1.0, -1.0])
    return cam_pos, cam2base_rot @ opencv_to_mj


def load_franka_assets(xml_path):
    """Read every regular file in the ``assets`` folder next to *xml_path*.

    Only the top level of the folder is scanned; sub-directories are skipped.

    Returns:
        A dict mapping ``os.path.join("assets", <file name>)`` to that file's
        raw bytes, or an empty dict when there is no ``assets`` directory
        beside the XML file.
    """
    assets_root = os.path.join(os.path.dirname(os.path.abspath(xml_path)), "assets")
    if not os.path.isdir(assets_root):
        return {}

    def _read_bytes(path):
        with open(path, "rb") as handle:
            return handle.read()

    candidates = (
        (name, os.path.join(assets_root, name)) for name in os.listdir(assets_root)
    )
    return {
        os.path.join("assets", name): _read_bytes(path)
        for name, path in candidates
        if os.path.isfile(path)
    }


# Asset files are read once at import time; render_robot() passes this mapping
# to mujoco.MjModel.from_xml_string on every call.
FRANKA_ASSETS = load_franka_assets(FRANKA_XML)


def render_robot(xml_path, joints, cam_pos, cam_rot, width, height, fovy=50.0):
    """Render the robot from a posed camera and return RGB plus a binary mask.

    A ``<camera name="posed">`` element is spliced into the model XML right
    after the model's ``top`` light, the first seven joint values are written
    into ``qpos``, and the scene is rendered twice (colour, then depth).

    Args:
        xml_path: Path to the MuJoCo XML model.
        joints: Joint values; only ``joints[:7]`` is used.
        cam_pos: Camera position (three components).
        cam_rot: 3x3 matrix whose first two columns are the camera X and Y
            axes (written to the camera's ``xyaxes`` attribute).
        width: Render width in pixels.
        height: Render height in pixels.
        fovy: Vertical field of view written to the camera's ``fovy`` attribute.

    Returns:
        ``(rgb, mask)``: the colour render and a ``uint8`` mask that is 1 where
        the depth is below 99% of the frame's maximum depth.

    Raises:
        ValueError: If the XML does not contain the light element used as the
            insertion anchor, or the posed camera cannot be found afterwards.
            Previously this case silently rendered from the default free camera.
    """
    with open(xml_path) as f:
        xml = f.read()

    cam_x = cam_rot[:, 0]
    cam_y = cam_rot[:, 1]
    cam_str = (
        f'<camera name="posed" pos="{cam_pos[0]} {cam_pos[1]} {cam_pos[2]}" '
        f'xyaxes="{cam_x[0]} {cam_x[1]} {cam_x[2]} {cam_y[0]} {cam_y[1]} {cam_y[2]}" '
        f'fovy="{fovy}"/>'
    )

    anchor = '<light name="top" pos="0 0 2" mode="trackcom"/>'
    if anchor not in xml:
        # str.replace would be a silent no-op and the camera would never exist.
        raise ValueError(
            f"cannot insert posed camera: anchor {anchor!r} not found in {xml_path}"
        )
    # Replace only the first occurrence so a repeated anchor cannot produce
    # two cameras with the same name.
    xml = xml.replace(anchor, f'{anchor}\n    {cam_str}', 1)

    model = mujoco.MjModel.from_xml_string(xml, FRANKA_ASSETS)
    data = mujoco.MjData(model)
    data.qpos[:7] = joints[:7]
    mujoco.mj_forward(model, data)

    cam_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_CAMERA, "posed")
    if cam_id < 0:
        # -1 means "not found"; passing it on would select the free camera.
        raise ValueError("camera 'posed' was not created in the compiled model")

    renderer = mujoco.Renderer(model, height=height, width=width)
    try:
        renderer.update_scene(data, camera=cam_id)
        rgb = renderer.render().copy()

        renderer.enable_depth_rendering()
        renderer.update_scene(data, camera=cam_id)
        depth = renderer.render().copy()
        renderer.disable_depth_rendering()
    finally:
        # Release the rendering context even when a render call fails.
        renderer.close()

    # Heuristic: anything closer than 99% of the farthest depth counts as robot.
    # NOTE(review): assumes the farthest depth in the frame is empty background.
    mask = (depth < depth.max() * 0.99).astype(np.uint8)
    return rgb, mask


def contour_overlay(real_frame, mask, color=(0, 255, 0), thickness=2):
    """Overlay the mask's outer contours on the real frame.

    The real frame is resized to the mask's size. The external contours of
    *mask* are then drawn as an outline of the given *thickness* and blended
    (70% outline image, 30% filled image) with a copy in which the same
    contours are filled with *color*.

    Returns:
        The blended image, at the mask's resolution.
    """
    mask_h, mask_w = mask.shape[:2]
    base = cv2.resize(real_frame, (mask_w, mask_h))
    outlines, _hierarchy = cv2.findContours(
        mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )

    # Solid-filled copy: supplies the semi-transparent tint after blending.
    filled = base.copy()
    for outline in outlines:
        cv2.fillPoly(filled, [outline], color)

    # Outline-only copy: keeps the silhouette edge crisp.
    outlined = base.copy()
    cv2.drawContours(outlined, outlines, -1, color, thickness)

    return cv2.addWeighted(outlined, 0.7, filled, 0.3, 0)


def _error_panel(message, font_scale=0.4):
    """Return a black IMG_W x IMG_H panel with *message* drawn in its top-left corner."""
    panel = np.zeros((IMG_H, IMG_W, 3), dtype=np.uint8)
    cv2.putText(panel, message, (5, 20),
                cv2.FONT_HERSHEY_SIMPLEX, font_scale, (255, 255, 255), 1)
    return panel


def _overlay_panel(real_frame, jpos, cam_pose, fovy, label, color=(0, 255, 0)):
    """Render the robot from *cam_pose* and draw its labelled silhouette on *real_frame*."""
    cam_pos, cam_rot = cam_pose
    _, mask = render_robot(FRANKA_XML, jpos, cam_pos, cam_rot,
                           IMG_W, IMG_H, fovy=fovy)
    panel = contour_overlay(real_frame, mask, color=color)
    cv2.putText(panel, label, (5, 20),
                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
    return panel


def main():
    """Write silhouette-overlay comparison images for a few manifest episodes.

    For up to four exact-match episodes (one per collector) this writes, into
    ``OUTPUT_DIR``:

    * ``ep{idx}_ext2_f{frame}_convention.png`` for the first, middle and last
      frame: the real ext2 frame next to the cam2base and base2cam overlays;
    * ``ep{idx}_ext2_fov_sweep.png``: a grid with one row per field of view and
      one column per convention, rendered at the middle frame;
    * ``ep{idx}_ext1_overlay.png``: cam2base overlays for the ext1 camera.

    Render failures are reported (as a placeholder panel and/or a printed
    message) instead of being silently dropped, so grid positions always match
    their labels.
    """
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    episodes = load_manifest()

    # Pick up to 4 exact-match episodes, one per collector, for diversity.
    max_episodes = 4
    exact = [e for e in episodes if e["match_dist"] == 0.0]
    # The collector is the first "+"-separated part of posed_ep_id.
    seen_collectors = set()
    diverse_exact = []
    for e in exact:
        collector = e["posed_ep_id"].split("+")[0]
        if collector in seen_collectors:
            continue
        seen_collectors.add(collector)
        diverse_exact.append(e)
        if len(diverse_exact) >= max_episodes:
            break

    print(f"Sampled {len(diverse_exact)} diverse exact-match episodes:")
    for e in diverse_exact:
        print(f"  ep={e['ep_idx']}, collector={e['posed_ep_id'].split('+')[0]}, id={e['posed_ep_id']}")

    fovys = [42.0, 50.0, 58.0, 69.0]
    default_fovy = 50.0
    # Overlay colour per convention; dict order is also the left-to-right panel order.
    conv_colors = {"cam2base": (0, 255, 0), "base2cam": (255, 100, 0)}

    for ep_info in diverse_exact:
        ep_idx = ep_info["ep_idx"]
        print(f"\n{'='*60}")
        print(f"Episode {ep_idx} — {ep_info['posed_ep_id']}")

        try:
            joints, frames1, frames2 = load_episode_data(ep_idx)
        except Exception as e:
            print(f"  SKIP: {e}")
            continue

        n = min(len(joints), len(frames2))
        if n == 0:
            # Nothing to index into; first/mid/last would otherwise hit index -1.
            print("  SKIP: episode has no frames")
            continue
        mid = n // 2
        # First, middle and last frame; the set removes repeats on very short episodes.
        frame_idxs = sorted({0, mid, n - 1})

        # Process ext2 camera (usually the better viewpoint for seeing the robot)
        ext_6d = ep_info["posed_ext2"]
        poses = {
            conv: extrinsics_to_cam_pose(ext_6d, convention=conv)
            for conv in conv_colors
        }

        # Per frame: real image | cam2base overlay | base2cam overlay (default FOV)
        for fi in frame_idxs:
            real = frames2[fi]

            real_resized = cv2.resize(real, (IMG_W, IMG_H))
            cv2.putText(real_resized, f"Real frame {fi}/{n}", (5, 20),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
            panels = [real_resized]

            for conv, color in conv_colors.items():
                try:
                    panels.append(_overlay_panel(real, joints[fi], poses[conv],
                                                 default_fovy, f"{conv} fov=50", color))
                except Exception as e:
                    panels.append(_error_panel(f"{conv}: {str(e)[:40]}"))

            row = np.concatenate(panels, axis=1)
            out_path = os.path.join(OUTPUT_DIR, f"ep{ep_idx:06d}_ext2_f{fi:04d}_convention.png")
            cv2.imwrite(out_path, cv2.cvtColor(row, cv2.COLOR_RGB2BGR))

        # FOV sweep at mid-frame for both conventions (cam2base | base2cam per fov)
        grid_rows = []
        for fovy in fovys:
            pair = []
            for conv, color in conv_colors.items():
                label = f"{conv} fov={fovy}"
                try:
                    pair.append(_overlay_panel(frames2[mid], joints[mid], poses[conv],
                                               fovy, label, color))
                except Exception as e:
                    # Keep the slot filled so every row stays two panels wide and
                    # each panel remains in the column its label belongs to.
                    print(f"  FOV sweep {label} failed: {e}")
                    pair.append(_error_panel(f"{label}: {str(e)[:40]}"))
            grid_rows.append(np.concatenate(pair, axis=1))

        if grid_rows:
            grid = np.concatenate(grid_rows, axis=0)
            out_path = os.path.join(OUTPUT_DIR, f"ep{ep_idx:06d}_ext2_fov_sweep.png")
            cv2.imwrite(out_path, cv2.cvtColor(grid, cv2.COLOR_RGB2BGR))
            print(f"  FOV sweep: {out_path}")

        # Also do ext1 (cam2base only), stacking the frames vertically
        ext1_pose = extrinsics_to_cam_pose(ep_info["posed_ext1"], convention="cam2base")
        n_ext1 = min(len(joints), len(frames1))
        panels_ext1 = []
        for fi in frame_idxs:
            if fi >= n_ext1:
                continue
            try:
                panels_ext1.append(_overlay_panel(frames1[fi], joints[fi], ext1_pose,
                                                  default_fovy, f"ext1 cam2base f={fi}"))
            except Exception as e:
                print(f"  ext1 frame {fi} failed: {e}")
        if panels_ext1:
            stack = np.concatenate(panels_ext1, axis=0)
            out_path = os.path.join(OUTPUT_DIR, f"ep{ep_idx:06d}_ext1_overlay.png")
            cv2.imwrite(out_path, cv2.cvtColor(stack, cv2.COLOR_RGB2BGR))
            print(f"  ext1: {out_path}")

    print(f"\nOutputs: {OUTPUT_DIR}")


# Run the verification only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
