"""Verify silhouette alignment using real per-episode intrinsics.

Usage:
    MUJOCO_GL=egl python verify_with_intrinsics.py
"""

import json
import os
import numpy as np
import pandas as pd
import imageio.v3 as iio
import cv2
import mujoco
from scipy.spatial.transform import Rotation as R

# ---------------------------------------------------------------------------
# Input / output locations
# ---------------------------------------------------------------------------
DROID_ROOT = "/data/cameron/droid"
MANIFEST_PATH = os.path.join(DROID_ROOT, "manifest_posed_ext2.json")
POSED_JSON_PATH = (
    "/data/cameron/para_droid_pretrain/posed_droid/pnp_cam2base_multiview.json"
)
INTRINSICS_PATH = (
    "/data/cameron/random/droid_replay/submodules/extrinsics_keypoint/intrinsics.json"
)
FRANKA_XML = "/data/cameron/para/droid_testing/franka_panda/panda_nohand.xml"
OUTPUT_DIR = "/data/cameron/para_droid_pretrain/posed_droid/silhouette_intrinsics"

# ---------------------------------------------------------------------------
# Image sizes in pixels
# ---------------------------------------------------------------------------
# Resolution of the decoded episode videos.
VID_W = 320
VID_H = 180
# MuJoCo render size: twice the video resolution so the overlays are easier
# to inspect. NOTE(review): must fit the model's offscreen framebuffer —
# confirm against panda_nohand.xml's <visual><global offwidth/offheight>.
RENDER_W = 640
RENDER_H = 360
# DROID native camera resolution. The per-episode fy is converted to a
# vertical FOV against NATIVE_H, so the intrinsics are assumed to be
# expressed at this resolution.
NATIVE_W = 1280
NATIVE_H = 720


def load_franka_assets(xml_path=None):
    """Read every regular file in the ``assets/`` directory next to an MJCF file.

    Args:
        xml_path: Path to the MJCF file whose sibling ``assets/`` directory is
            loaded. Defaults to the module-level ``FRANKA_XML``.

    Returns:
        dict mapping ``"assets/<filename>"`` to that file's raw bytes, in the
        form accepted as the ``assets`` argument of
        ``mujoco.MjModel.from_xml_string``.
    """
    if xml_path is None:
        xml_path = FRANKA_XML
    assets_dir = os.path.join(os.path.dirname(os.path.abspath(xml_path)), "assets")

    assets = {}
    for fn in sorted(os.listdir(assets_dir)):
        full_path = os.path.join(assets_dir, fn)
        # Subdirectories (or other non-regular entries) cannot be opened as files.
        if not os.path.isfile(full_path):
            continue
        with open(full_path, "rb") as f:
            # MuJoCo asset names use "/" regardless of OS, so do not use
            # os.path.join here (it would produce "\\" on Windows).
            assets["assets/" + fn] = f.read()
    return assets

# Raw bytes of every file in the Franka model's assets/ directory. They are read
# once at import time and passed to MjModel.from_xml_string in
# render_with_intrinsics.
ASSETS = load_franka_assets()


def find_intrinsics(intrinsics_db, relative_path, cam_name="exterior_image_2_left"):
    """Find the intrinsics of one camera for an episode.

    An entry matches if ``relative_path`` occurs as a substring of its key.
    The first key that matches and also has an entry for ``cam_name`` wins.

    The relative_path looks like: TRI/success/2023-12-13/Wed_Dec_13_15:55:58_2023
    The intrinsics keys look like:
    gs://.../<relative_path>/recordings/MP4--gs://.../<relative_path>/trajectory.h5

    Args:
        intrinsics_db: mapping from key string to a dict of per-camera intrinsics.
        relative_path: episode path to look for inside the keys.
        cam_name: camera entry to return from the matching value.

    Returns:
        The stored intrinsics for ``cam_name`` (unpacked by the caller as
        ``[fx, cx, fy, cy]``), or None if ``relative_path`` is empty or no key
        matches.
    """
    # An empty string is a substring of every key. Without this guard, an
    # episode with no known path would get the first entry's intrinsics.
    if not relative_path:
        return None
    for key, val in intrinsics_db.items():
        if relative_path in key and cam_name in val:
            return val[cam_name]
    return None


def render_with_intrinsics(xml_path, joints, cam_pos, cam_rot, width, height, fovy):
    """Render the robot through an injected camera; return (rgb, silhouette mask).

    A camera named ``"posed"`` is spliced into the MJCF right after the
    ``<light name="top" pos="0 0 2" mode="trackcom"/>`` element. It is placed
    at ``cam_pos``. Its x/y axes are the first two columns of ``cam_rot``,
    given to MuJoCo via ``xyaxes``. Its vertical field of view is ``fovy``.
    Only ``fovy`` is passed to MuJoCo, so a principal point (cx, cy) away from
    the image centre is not modelled here.

    Args:
        xml_path: MJCF file to load; it must contain the anchor light element.
        joints: joint positions; the first 7 are written to ``qpos[:7]``.
        cam_pos: camera position, indexed as ``cam_pos[0]``..``cam_pos[2]``.
        cam_rot: 3x3 rotation matrix whose columns 0 and 1 are the camera x/y axes.
        width: output image width in pixels.
        height: output image height in pixels.
        fovy: vertical field of view in degrees.

    Returns:
        ``(rgb, mask)``: the rendered colour image, and a uint8 mask. The mask
        is 1 where rendered depth is below 99% of the frame's maximum depth and
        0 elsewhere. Presumably these are the robot pixels, assuming the
        background renders at the maximum depth.

    Raises:
        ValueError: if the camera could not be injected into the model.
    """
    with open(xml_path) as f:
        xml = f.read()

    anchor = '<light name="top" pos="0 0 2" mode="trackcom"/>'
    # Without the anchor, str.replace is a silent no-op. The "posed" camera
    # would then be missing and the render would come from the wrong viewpoint.
    if anchor not in xml:
        raise ValueError(
            f"{xml_path}: anchor element {anchor!r} not found; "
            "cannot inject the 'posed' camera"
        )

    cam_x = cam_rot[:, 0]
    cam_y = cam_rot[:, 1]
    cam_str = (
        f'<camera name="posed" pos="{cam_pos[0]} {cam_pos[1]} {cam_pos[2]}" '
        f'xyaxes="{cam_x[0]} {cam_x[1]} {cam_x[2]} {cam_y[0]} {cam_y[1]} {cam_y[2]}" '
        f'fovy="{fovy}"/>'
    )
    # Replace only the first occurrence so a duplicated anchor cannot create
    # two cameras with the same name.
    xml = xml.replace(anchor, f"{anchor}\n    {cam_str}", 1)

    model = mujoco.MjModel.from_xml_string(xml, ASSETS)
    data = mujoco.MjData(model)
    data.qpos[:7] = joints[:7]
    mujoco.mj_forward(model, data)

    cam_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_CAMERA, "posed")
    if cam_id < 0:
        raise ValueError(f"{xml_path}: injected camera 'posed' not found in model")

    renderer = mujoco.Renderer(model, height=height, width=width)
    try:
        renderer.update_scene(data, camera=cam_id)
        rgb = renderer.render().copy()

        renderer.enable_depth_rendering()
        renderer.update_scene(data, camera=cam_id)
        depth = renderer.render().copy()
        renderer.disable_depth_rendering()
    finally:
        # Release the rendering context even if a render call fails.
        renderer.close()

    mask = (depth < depth.max() * 0.99).astype(np.uint8)
    return rgb, mask


def contour_overlay(real_frame, mask, color=(0, 255, 0), thickness=2):
    """Outline ``mask`` on ``real_frame`` and tint its interior.

    ``real_frame`` is first resized to the mask's (height, width). Two copies
    are drawn:
    - one with the external contours of ``mask`` stroked in ``color``;
    - one with those contours filled solid in ``color``.
    They are then blended 70/30, and the blended image is returned.
    """
    out_h, out_w = mask.shape[:2]
    base = cv2.resize(real_frame, (out_w, out_h))
    outlines, _hierarchy = cv2.findContours(
        mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )

    # Copy with contour strokes of the requested thickness.
    stroked = base.copy()
    cv2.drawContours(stroked, outlines, -1, color, thickness)

    # Copy with every external contour painted solid, one polygon at a time.
    tinted = base.copy()
    for outline in outlines:
        cv2.fillPoly(tinted, [outline], color)

    return cv2.addWeighted(stroked, 0.7, tinted, 0.3, 0)


# Flips the camera y and z axes: OpenCV cameras have y down and look along +z,
# while MuJoCo cameras have y up and look along -z.
_OPENCV_TO_MJ = np.diag([1, -1, -1])


def _frame_indices(n):
    """Return the distinct first/middle/last frame indices for ``n`` frames (none if n <= 0)."""
    if n <= 0:
        return []
    return sorted({0, n // 2, n - 1})


def _fovy_deg(fy, image_h=NATIVE_H):
    """Vertical field of view in degrees for focal length ``fy`` at image height ``image_h``."""
    return 2 * np.degrees(np.arctan(image_h / (2 * fy)))


def _posed_camera(ext_6d):
    """Convert a posed 6-D extrinsic ``[x, y, z, a, b, c]`` to MuJoCo ``(cam_pos, cam_rot)``.

    The angles are interpreted with scipy's ``"xyz"`` Euler sequence. The pose
    is treated as an OpenCV-convention camera pose in the robot base frame
    (presumably cam2base, per the posed JSON filename; confirm).
    """
    ext_6d = np.asarray(ext_6d)
    rot = R.from_euler("xyz", ext_6d[3:6]).as_matrix()
    return np.array(ext_6d[:3]), rot @ _OPENCV_TO_MJ


def _label(img, text):
    """Draw ``text`` in white at the top-left of ``img`` (in place) and return ``img``."""
    cv2.putText(img, text, (5, 25), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
    return img


def _save_grid(rows, path):
    """Stack RGB image rows vertically, write them as BGR to ``path``; return success."""
    grid = np.concatenate(rows, axis=0)
    ok = cv2.imwrite(path, cv2.cvtColor(grid, cv2.COLOR_RGB2BGR))
    if not ok:
        print(f"  WARNING: failed to write {path}")
    return ok


def _ext2_rows(frames, joints, cam_pos, cam_rot, fovy):
    """Build ext2 rows: real | MuJoCo render | overlay at ``fovy`` | overlay at the old fovy=50 guess."""
    n = min(len(joints), len(frames))
    rows = []
    for fi in _frame_indices(n):
        real = frames[fi]
        jpos = joints[fi]

        rgb_correct, mask_correct = render_with_intrinsics(
            FRANKA_XML, jpos, cam_pos, cam_rot, RENDER_W, RENDER_H, fovy=fovy)
        overlay_correct = _label(
            contour_overlay(real, mask_correct, color=(0, 255, 0)),
            f"CORRECT fovy={fovy:.0f} f={fi}")

        # The old guess is rendered only for comparison; its RGB image is unused.
        _, mask_guess = render_with_intrinsics(
            FRANKA_XML, jpos, cam_pos, cam_rot, RENDER_W, RENDER_H, fovy=50.0)
        overlay_guess = _label(
            contour_overlay(real, mask_guess, color=(255, 100, 0)),
            f"GUESS fovy=50 f={fi}")

        real_r = _label(cv2.resize(real, (RENDER_W, RENDER_H)), f"Real f={fi}/{n}")
        _label(rgb_correct, f"MuJoCo fovy={fovy:.0f}")

        rows.append(np.concatenate(
            [real_r, rgb_correct, overlay_correct, overlay_guess], axis=1))
    return rows


def _ext1_rows(frames, joints, cam_pos, cam_rot, fovy):
    """Build ext1 rows: real | MuJoCo render | overlay at ``fovy``."""
    n = min(len(joints), len(frames))
    rows = []
    for fi in _frame_indices(n):
        rgb, mask = render_with_intrinsics(
            FRANKA_XML, joints[fi], cam_pos, cam_rot, RENDER_W, RENDER_H, fovy=fovy)
        overlay = _label(
            contour_overlay(frames[fi], mask, color=(0, 255, 0)),
            f"ext1 fovy={fovy:.0f} f={fi}")
        real = _label(
            cv2.resize(frames[fi], (RENDER_W, RENDER_H)), f"Real ext1 f={fi}")
        rows.append(np.concatenate([real, rgb, overlay], axis=1))
    return rows


def main():
    """Render silhouettes for a few test episodes and save comparison grids.

    For each test episode:
    - ext2 is rendered both with the FOV derived from its intrinsics and with
      the old fovy=50 guess.
    - ext1, if it has intrinsics, is rendered with its own derived FOV.
    The two cameras are handled independently. Grids are written as PNGs to
    OUTPUT_DIR.
    """
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    print("Loading data...")
    with open(MANIFEST_PATH) as f:
        episodes = json.load(f)["episodes"]
    with open(POSED_JSON_PATH) as f:
        posed_json = json.load(f)
    with open(INTRINSICS_PATH) as f:
        intrinsics_db = json.load(f)
    print(f"  Intrinsics entries: {len(intrinsics_db)}")

    # Test episodes
    test_ep_idxs = {2, 4, 9, 10}
    test_eps = [e for e in episodes if e["ep_idx"] in test_ep_idxs]

    for ep_info in test_eps:
        ep_idx = ep_info["ep_idx"]
        posed_ep_id = ep_info["posed_ep_id"]

        relative_path = posed_json.get(posed_ep_id, {}).get("relative_path", "")
        print(f"\n{'='*60}")
        print(f"Episode {ep_idx} — {posed_ep_id}")
        print(f"  Relative path: {relative_path}")
        if not relative_path:
            # Intrinsics are matched by substring; "" would match every key.
            print("  NO RELATIVE PATH in posed JSON, skipping")
            continue

        intr_ext2 = find_intrinsics(intrinsics_db, relative_path, "exterior_image_2_left")
        intr_ext1 = find_intrinsics(intrinsics_db, relative_path, "exterior_image_1_left")
        if intr_ext2 is None and intr_ext1 is None:
            print("  NO INTRINSICS FOUND for ext2 or ext1, skipping")
            continue
        if intr_ext2 is None:
            print("  NO INTRINSICS FOUND for ext2, rendering ext1 only")

        # Load episode data
        chunk = ep_idx // 1000
        parquet = os.path.join(DROID_ROOT, f"data/chunk-{chunk:03d}/episode_{ep_idx:06d}.parquet")
        vid_ext2 = os.path.join(DROID_ROOT, f"videos/chunk-{chunk:03d}/observation.images.exterior_2_left/episode_{ep_idx:06d}.mp4")
        vid_ext1 = os.path.join(DROID_ROOT, f"videos/chunk-{chunk:03d}/observation.images.exterior_1_left/episode_{ep_idx:06d}.mp4")

        df = pd.read_parquet(parquet)
        if df.empty:
            # np.stack below cannot stack zero arrays.
            print("  Empty episode, skipping")
            continue
        joints = np.stack(df["observation.state.joint_position"].values)

        if intr_ext2 is not None:
            fx, cx, fy, cy = intr_ext2
            # Only fy enters the render; cx/cy are printed but not modelled.
            fovy = _fovy_deg(fy)
            print(f"  Intrinsics ext2: fx={fx:.1f}, cx={cx:.1f}, fy={fy:.1f}, cy={cy:.1f}")
            print(f"  Computed fovy: {fovy:.1f}°")

            frames_ext2 = iio.imread(vid_ext2, plugin="pyav")
            cam_pos, cam_rot = _posed_camera(ep_info["posed_ext2"])
            rows = _ext2_rows(frames_ext2, joints, cam_pos, cam_rot, fovy)
            if rows:
                out = os.path.join(OUTPUT_DIR, f"ep{ep_idx:06d}_ext2_intrinsics.png")
                if _save_grid(rows, out):
                    print(f"  Saved: {out}")

        if intr_ext1 is not None:
            _fx1, _cx1, fy1, _cy1 = intr_ext1
            fovy1 = _fovy_deg(fy1)
            # NOTE(review): assumes every manifest episode has "posed_ext1".
            cam_pos1, cam_rot1 = _posed_camera(ep_info["posed_ext1"])

            frames_ext1 = iio.imread(vid_ext1, plugin="pyav")
            rows1 = _ext1_rows(frames_ext1, joints, cam_pos1, cam_rot1, fovy1)
            if rows1:
                out1 = os.path.join(OUTPUT_DIR, f"ep{ep_idx:06d}_ext1_intrinsics.png")
                if _save_grid(rows1, out1):
                    print(f"  ext1 saved: {out1}")

    print(f"\nAll outputs: {OUTPUT_DIR}")


if __name__ == "__main__":
    main()
