"""Verify Posed DROID camera extrinsics by rendering Franka silhouette overlays.

For each sampled episode, renders the MuJoCo Franka at the recorded joint positions
from the posed camera viewpoint, and overlays the silhouette on the real video frame.
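Produces side-by-side (real | MuJoCo | overlay) comparison images for the start/middle/end
frames of each sampled episode, plus a vertical-FOV sweep for one exact-match episode.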

Usage:
    MUJOCO_GL=egl python verify_silhouette.py
"""

import json
import os
import numpy as np
import pandas as pd
import imageio.v3 as iio
import cv2
import mujoco
from scipy.spatial.transform import Rotation as R

# ── Paths ──────────────────────────────────────────────────────────────────────
# Root of the DROID dataset: parquet files under data/, videos under videos/.
DROID_ROOT = "/data/cameron/droid"
# Manifest of episodes matched to posed extrinsics (read by load_manifest()).
MANIFEST_PATH = os.path.join(DROID_ROOT, "manifest_posed_ext2.json")
# MuJoCo Franka Panda model; "nohand" presumably means no gripper — confirm.
FRANKA_XML = "/data/cameron/para/droid_testing/franka_panda/panda_nohand.xml"
# Destination for the comparison PNGs.
OUTPUT_DIR = "/data/cameron/para_droid_pretrain/posed_droid/silhouette_verification"

# ── Render size ────────────────────────────────────────────────────────────────
# Only the output resolution is fixed here; the vertical FOV is passed per render
# call (``fovy``) because the true intrinsics are not known. Real frames of a
# different size are resized to this resolution before overlaying.
# NOTE(review): an earlier comment assumed RealSense D435i optics
# (69.4° x 42.5° FOV at 1280x720). The exterior videos are stereo "*_left"
# streams (DROID reports ZED 2 cameras), so confirm which FOV actually applies.
IMG_W = 320
IMG_H = 180


def load_manifest():
    """Return the ``"episodes"`` list stored in the manifest at ``MANIFEST_PATH``.

    Raises whatever ``open``/``json.load`` raise for a missing or malformed
    file, and ``KeyError`` when the manifest has no ``"episodes"`` key.
    """
    with open(MANIFEST_PATH) as manifest_file:
        return json.load(manifest_file)["episodes"]


def load_episode_data(ep_idx):
    """Read the joint trajectory and both exterior videos of one episode.

    Episodes are sharded 1000 per chunk directory under ``DROID_ROOT``.

    Returns:
        ``(joints, frames_ext1, frames_ext2, df)`` where ``joints`` is the
        ``observation.state.joint_position`` column stacked with ``np.stack``
        (presumably (T, 7)), the two frame arrays come from ``iio.imread`` with
        the pyav plugin, and ``df`` is the raw parquet table.
    """
    chunk = ep_idx // 1000
    table_path, ext1_path, ext2_path = (
        os.path.join(DROID_ROOT, f"data/chunk-{chunk:03d}/episode_{ep_idx:06d}.parquet"),
        os.path.join(DROID_ROOT, f"videos/chunk-{chunk:03d}/observation.images.exterior_1_left/episode_{ep_idx:06d}.mp4"),
        os.path.join(DROID_ROOT, f"videos/chunk-{chunk:03d}/observation.images.exterior_2_left/episode_{ep_idx:06d}.mp4"),
    )

    episode_table = pd.read_parquet(table_path)
    joint_column = episode_table["observation.state.joint_position"]
    joint_positions = np.stack(joint_column.values)

    # Decoded in order: exterior 1 first, then exterior 2.
    ext1_frames, ext2_frames = (
        iio.imread(video_path, plugin="pyav") for video_path in (ext1_path, ext2_path)
    )

    return joint_positions, ext1_frames, ext2_frames, episode_table


def cam2base_to_mujoco_camera(extrinsics_6d):
    """Turn a 6-DoF cam2base pose into a MuJoCo camera position and rotation.

    Args:
        extrinsics_6d: ``[x, y, z, rx, ry, rz]``. The first three values are the
            camera origin in the robot base frame. The last three are Euler
            angles read by SciPy with sequence ``"xyz"`` (extrinsic, radians).
            The rotation maps OpenCV camera axes (X right, Y down, Z forward)
            into the base frame.

    Returns:
        ``(cam_pos, cam_rot_world)``. ``cam_pos`` is the translation as an
        array. ``cam_rot_world`` is a 3x3 matrix whose columns are the MuJoCo
        camera's X (right), Y (up) and Z (backward) axes expressed in the base
        frame. MuJoCo cameras look along -Z.
    """
    translation = extrinsics_6d[:3]
    euler_xyz = extrinsics_6d[3:6]

    opencv_axes_in_base = R.from_euler("xyz", euler_xyz).as_matrix()

    # OpenCV -> MuJoCo camera basis: keep X, negate Y and Z.
    # diag(1, -1, -1) is its own inverse, so right-multiplying converts
    # "OpenCV camera -> base" into "MuJoCo camera -> base".
    flip_yz = np.diag([1, -1, -1])
    mujoco_axes_in_base = opencv_axes_in_base @ flip_yz

    return np.array(translation), mujoco_axes_in_base


def render_robot_silhouette(model, data, joints, cam_pos, cam_rot, width, height, fovy=45.0):
    """Render the robot from an explicit camera pose by overriding the scene camera.

    Unlike ``render_robot_with_model_camera`` this reuses an already-compiled
    model. The scene is built with a free camera, and the resulting OpenGL
    camera pair is then overwritten with the requested pose and vertical FOV.

    Args:
        model: compiled ``mujoco.MjModel``.
        data: matching ``mujoco.MjData``. It is mutated: ``qpos[:7]`` is set to
            ``joints[:7]`` and ``mj_forward`` is run.
        joints: joint positions; only the first 7 entries are used.
        cam_pos: camera position in the world frame (3 values).
        cam_rot: 3x3 rotation whose columns are the MuJoCo camera X (right),
            Y (up) and Z (backward) axes in the world frame, as returned by
            ``cam2base_to_mujoco_camera``.
        width, height: output resolution in pixels.
        fovy: vertical field of view in degrees.

    Returns:
        ``(rgb, depth)``: copies of the colour render and the depth render.
    """
    data.qpos[:7] = joints[:7]
    mujoco.mj_forward(model, data)

    cam_pos = np.asarray(cam_pos, dtype=float)
    cam_rot = np.asarray(cam_rot, dtype=float)
    # MuJoCo cameras look along their local -Z axis, so the view direction is
    # the negated third column of cam_rot; +Y is the camera's up vector.
    forward = -cam_rot[:, 2]
    up = cam_rot[:, 1]
    half_fov = np.deg2rad(fovy) / 2.0

    free_cam = mujoco.MjvCamera()
    free_cam.type = mujoco.mjtCamera.mjCAMERA_FREE

    renderer = mujoco.Renderer(model, height=height, width=width)
    try:
        renderer.update_scene(data, camera=free_cam, scene_option=mujoco.MjvOption())
        scene = renderer._scene
        # The scene stores a left/right OpenGL camera pair. Mono rendering uses
        # their average (mjv_averageCamera), so both must carry the same pose.
        for idx in (0, 1):
            gl_cam = scene.camera[idx]
            gl_cam.pos[:] = cam_pos
            gl_cam.forward[:] = forward
            gl_cam.up[:] = up
            # A symmetric vertical frustum at the near plane yields the requested
            # fovy; the horizontal extent follows from the viewport aspect ratio.
            half_height = gl_cam.frustum_near * np.tan(half_fov)
            gl_cam.frustum_center = 0.0
            gl_cam.frustum_bottom = -half_height
            gl_cam.frustum_top = half_height

        rgb = renderer.render().copy()
        renderer.enable_depth_rendering()
        depth = renderer.render().copy()
        renderer.disable_depth_rendering()
    finally:
        # Release the GL context even if rendering fails.
        renderer.close()

    return rgb, depth


def render_robot_with_model_camera(xml_path, joints, cam_pos, cam_rot, width, height, fovy=42.5):
    """Render the robot through a camera injected into a copy of the model XML.

    A ``<camera name="posed">`` element built from ``cam_pos``/``cam_rot``/``fovy``
    is inserted as the last child of ``<worldbody>``. Every file in
    ``<xml dir>/assets`` is passed to the compiler as an in-memory asset under
    the key ``assets/<name>``. ``qpos[:7]`` is set from ``joints`` and the scene
    is rendered once for RGB and once for depth.

    Args:
        xml_path: path to the MuJoCo XML model.
        joints: joint positions; only the first 7 entries are used.
        cam_pos: camera position in the world frame (3 values).
        cam_rot: 3x3 rotation whose columns are the MuJoCo camera X (right),
            Y (up), Z (backward) axes in the world frame.
        width, height: output resolution in pixels.
        fovy: vertical field of view in degrees (MuJoCo ``camera/@fovy``).

    Returns:
        ``(rgb, depth)`` copies of the two renders.

    Raises:
        ValueError: if the XML has no ``</worldbody>`` tag, or the injected
            camera is not present after compilation. Without this check
            ``mj_name2id`` would return -1 and the default free camera would
            be rendered silently.
    """
    with open(xml_path) as f:
        xml = f.read()

    cam_rot = np.asarray(cam_rot)
    cam_x = cam_rot[:, 0]  # right
    cam_y = cam_rot[:, 1]  # up

    def _fmt(values):
        # str(float) is the shortest round-trip repr, so no precision is lost.
        return " ".join(str(float(v)) for v in values)

    cam_str = (
        f'<camera name="posed" pos="{_fmt(cam_pos)}" '
        f'xyaxes="{_fmt(cam_x)} {_fmt(cam_y)}" '
        f'fovy="{fovy}"/>'
    )

    # Anchor on the closing worldbody tag rather than a specific child element,
    # so the injection does not depend on the exact contents of the model file.
    close_tag = "</worldbody>"
    insert_at = xml.rfind(close_tag)
    if insert_at < 0:
        raise ValueError(f"{xml_path}: no {close_tag} tag to attach the posed camera to")
    xml = xml[:insert_at] + "  " + cam_str + "\n  " + xml[insert_at:]

    # The XML is compiled from a string, so mesh files are supplied in memory.
    xml_dir = os.path.dirname(os.path.abspath(xml_path))
    asset_dir = os.path.join(xml_dir, "assets")
    assets = {}
    if os.path.isdir(asset_dir):
        for fn in os.listdir(asset_dir):
            fpath = os.path.join(asset_dir, fn)
            if os.path.isfile(fpath):
                with open(fpath, "rb") as f:
                    assets[os.path.join("assets", fn)] = f.read()

    model = mujoco.MjModel.from_xml_string(xml, assets)
    data = mujoco.MjData(model)

    data.qpos[:7] = joints[:7]
    mujoco.mj_forward(model, data)

    cam_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_CAMERA, "posed")
    if cam_id < 0:
        raise ValueError(f"{xml_path}: injected camera 'posed' not found after compilation")

    renderer = mujoco.Renderer(model, height=height, width=width)
    try:
        renderer.update_scene(data, camera=cam_id)
        rgb = renderer.render().copy()

        renderer.enable_depth_rendering()
        renderer.update_scene(data, camera=cam_id)
        depth = renderer.render().copy()
        renderer.disable_depth_rendering()
    finally:
        # Release the GL context even if rendering fails.
        renderer.close()

    return rgb, depth


def create_overlay(real_frame, robot_rgb, robot_depth, alpha=0.5):
    """Tint the rendered robot's pixels green on top of the real frame.

    Args:
        real_frame: real camera image (RGB per the caller). It is resized to the
            render's height/width when they differ.
        robot_rgb: rendered image; only its height/width are used here.
        robot_depth: rendered depth map with the same height/width as ``robot_rgb``.
        alpha: blend weight of the green tint (0 = untouched, 1 = solid green).

    Returns:
        ``(tinted, mask)``: the blended uint8 image with a 1-px green outline
        around the silhouette, and a uint8 mask where 1 marks robot pixels.
    """
    h, w = robot_rgb.shape[:2]
    if real_frame.shape[:2] != (h, w):
        real_frame = cv2.resize(real_frame, (w, h))

    # Assumes empty background renders at the maximum depth value in the image.
    # Pixels more than 1% closer are treated as robot.
    # NOTE(review): this misses the robot if it covers the entire frame.
    mask = (robot_depth < robot_depth.max() * 0.99).astype(np.uint8)

    tinted = real_frame.copy()
    robot_px = mask > 0
    if robot_px.any():
        tinted[robot_px] = (
            tinted[robot_px] * (1 - alpha) +
            np.array([0, 255, 0], dtype=np.float32) * alpha
        ).astype(np.uint8)

        # findContours returns (contours, hierarchy) in OpenCV 4 and
        # (image, contours, hierarchy) in OpenCV 3; [-2] works for both.
        contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
        cv2.drawContours(tinted, contours, -1, (0, 255, 0), 1)

    return tinted, mask


def make_comparison_grid(real_frame, robot_rgb, overlay, title=""):
    """Place the real frame, the MuJoCo render and the overlay side by side.

    The real frame is resized to the render's size. Each panel gets a small
    white caption in its top-left corner. A non-empty ``title`` adds a 25-px
    black header strip above the row.
    """
    height, width = robot_rgb.shape[:2]
    captioned = (
        (cv2.resize(real_frame, (width, height)), "Real"),
        (robot_rgb, "MuJoCo"),
        (overlay, "Overlay"),
    )

    labelled = []
    for panel, caption in captioned:
        canvas = panel.copy()
        cv2.putText(canvas, caption, (5, 15), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1)
        labelled.append(canvas)

    row = np.hstack(labelled)
    if not title:
        return row

    header = np.zeros((25, row.shape[1], 3), dtype=np.uint8)
    cv2.putText(header, title, (5, 18), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
    return np.vstack([header, row])


def _sample_episodes(episodes):
    """Split episodes into up to 2 exact, 2 medium and 1 borderline match by ``match_dist``."""
    episodes_sorted = sorted(episodes, key=lambda e: e["match_dist"])
    exact = [e for e in episodes_sorted if e["match_dist"] == 0.0][:2]
    medium = [e for e in episodes_sorted if 0.02 < e["match_dist"] < 0.06][:2]
    borderline = [e for e in episodes_sorted if 0.08 < e["match_dist"] < 0.11][:1]
    return exact, medium, borderline


def _frame_indices(n_frames):
    """Start/middle/end indices with duplicates dropped (order kept); empty if there are no frames."""
    if n_frames <= 0:
        return []
    return list(dict.fromkeys([0, n_frames // 2, n_frames - 1]))


def _save_rgb(path, image):
    """Write an RGB image with OpenCV, which expects BGR channel order."""
    cv2.imwrite(path, cv2.cvtColor(image, cv2.COLOR_RGB2BGR))


def _verify_episode(ep_info, fovy):
    """Render start/mid/end overlays for both exterior cameras and save one PNG per camera."""
    ep_idx = ep_info["ep_idx"]
    print(f"\n{'='*60}")
    print(f"Episode {ep_idx} (match_dist={ep_info['match_dist']:.4f})")
    print(f"  Posed ID: {ep_info['posed_ep_id']}")

    try:
        joints, frames_ext1, frames_ext2, _ = load_episode_data(ep_idx)
    except Exception as e:  # best-effort: one unreadable episode must not abort the run
        print(f"  SKIP: {e}")
        return

    n_frames = len(joints)
    frame_indices = _frame_indices(n_frames)
    print(f"  Frames: {n_frames}, checking indices: {frame_indices}")

    for cam_name, cam_key, frames in (
        ("ext1", "posed_ext1", frames_ext1),
        ("ext2", "posed_ext2", frames_ext2),
    ):
        ext_6d = ep_info[cam_key]
        cam_pos, cam_rot = cam2base_to_mujoco_camera(ext_6d)
        print(f"  Camera {cam_name}: pos={cam_pos}, ext={ext_6d[:3]}")

        all_grids = []
        for fi in frame_indices:
            # Indices come from len(joints); the video may be shorter.
            if fi >= len(frames):
                continue
            real_frame = frames[fi]  # (H, W, 3) RGB
            try:
                robot_rgb, robot_depth = render_robot_with_model_camera(
                    FRANKA_XML, joints[fi], cam_pos, cam_rot,
                    width=IMG_W, height=IMG_H, fovy=fovy,
                )
            except Exception as e:
                print(f"    Render failed frame {fi}: {e}")
                continue

            overlay, _ = create_overlay(real_frame, robot_rgb, robot_depth)
            all_grids.append(make_comparison_grid(
                real_frame, robot_rgb, overlay,
                title=f"ep{ep_idx} {cam_name} frame={fi}/{n_frames} dist={ep_info['match_dist']:.3f} fov={fovy}",
            ))

        if all_grids:
            # Stack vertically: start/mid/end.
            out_path = os.path.join(OUTPUT_DIR, f"ep{ep_idx:06d}_{cam_name}.png")
            _save_rgb(out_path, np.concatenate(all_grids, axis=0))
            print(f"    Saved: {out_path}")


def _fov_sweep(ep_info, fovys):
    """Render the ext2 middle frame of one episode at several vertical FOVs and save one PNG."""
    ep_idx = ep_info["ep_idx"]
    cam_name = "ext2"
    print(f"\n{'='*60}")
    print(f"FOV sweep on ep {ep_idx}")

    try:
        joints, _, frames, _ = load_episode_data(ep_idx)
    except Exception as e:
        print(f"  SKIP FOV sweep: {e}")
        return

    # The middle index must be valid for both the joint trajectory and the video.
    n_usable = min(len(joints), len(frames))
    if n_usable == 0:
        print("  SKIP FOV sweep: episode has no frames")
        return
    mid = n_usable // 2

    cam_pos, cam_rot = cam2base_to_mujoco_camera(ep_info["posed_ext2"])

    fov_grids = []
    for fovy in fovys:
        try:
            robot_rgb, robot_depth = render_robot_with_model_camera(
                FRANKA_XML, joints[mid], cam_pos, cam_rot,
                width=IMG_W, height=IMG_H, fovy=fovy,
            )
        except Exception as e:
            print(f"  Render failed at FOV={fovy}: {e}")
            continue
        overlay, _ = create_overlay(frames[mid], robot_rgb, robot_depth)
        fov_grids.append(make_comparison_grid(
            frames[mid], robot_rgb, overlay,
            title=f"FOV={fovy}deg ep{ep_idx} {cam_name} mid-frame",
        ))

    if not fov_grids:
        return
    out_path = os.path.join(OUTPUT_DIR, f"fov_sweep_ep{ep_idx:06d}_{cam_name}.png")
    _save_rgb(out_path, np.concatenate(fov_grids, axis=0))
    print(f"  FOV sweep saved: {out_path}")


def main():
    """Render silhouette overlays for a sample of episodes, plus one FOV sweep.

    Episodes are sampled by ``match_dist``: up to 2 exact, 2 medium and
    1 borderline. Both exterior cameras are rendered at each episode's
    start/middle/end frames with a fixed 50° vertical FOV, and one PNG is saved
    per camera. The first exact-match episode also gets a vertical-FOV sweep
    on its ext2 middle frame.
    """
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    print("Loading manifest...")
    episodes = load_manifest()
    print(f"Total matched episodes: {len(episodes)}")

    exact, medium, borderline = _sample_episodes(episodes)
    samples = exact + medium + borderline

    print(f"\nSampling {len(samples)} episodes:")
    for s in samples:
        print(f"  ep_idx={s['ep_idx']}, match_dist={s['match_dist']:.4f}, posed_id={s['posed_ep_id']}")

    for ep_info in samples:
        _verify_episode(ep_info, fovy=50.0)

    # Calibration aid: compare several vertical FOVs on the cleanest match.
    if exact:
        _fov_sweep(exact[0], fovys=[42.0, 50.0, 58.0, 69.0])

    print(f"\nAll outputs in: {OUTPUT_DIR}")
    print("Done!")


if __name__ == "__main__":
    # Run the verification only when executed as a script, not on import.
    main()
