"""Generate sample DROID verification video with keypoint, robot silhouette, and ground projection.

For each frame, renders:
  - Robot silhouette (MuJoCo render, green semi-transparent overlay + contour)
  - EEF keypoint (green circle) projected via camera intrinsics/extrinsics
  - Base-plane projection (cyan ring) showing the height visualization
  - Yellow height line connecting EEF to base-plane projection
  - EEF rotation axes (RGB)

Usage:
    MUJOCO_GL=egl python generate_sample_video.py
    MUJOCO_GL=egl python generate_sample_video.py --episodes 2 9 --camera ext2
"""

import argparse
import json
import os
import numpy as np
import pandas as pd
import imageio.v3 as iio
import cv2
import mujoco
from scipy.spatial.transform import Rotation as R

# Dataset root; per-episode parquet and video paths are built underneath it in main().
DROID_ROOT = "/data/cameron/droid"
# Episode manifest; main() reads its "episodes" list (ep_idx, posed_ep_id, posed_<camera>).
MANIFEST_PATH = os.path.join(DROID_ROOT, "manifest_posed_ext2.json")
# JSON keyed by posed_ep_id; only each entry's "relative_path" is read here.
POSED_JSON_PATH = "/data/cameron/para_droid_pretrain/posed_droid/pnp_cam2base_multiview.json"
# Intrinsics database searched by find_intrinsics() via substring match on the relative path.
INTRINSICS_PATH = "/data/cameron/random/droid_replay/submodules/extrinsics_keypoint/intrinsics.json"
# MuJoCo XML of the robot; the "assets" directory next to it is loaded into ASSETS.
FRANKA_XML = "/data/cameron/para/droid_testing/franka_panda/panda_nohand.xml"
# Destination directory for the annotated videos and PNG strips.
OUTPUT_DIR = "/data/cameron/para_droid_pretrain/posed_droid/sample_videos"

# main() treats the stored intrinsics as being expressed at this resolution
# (it rescales them by VID/NATIVE and derives fovy from NATIVE_H).
NATIVE_W, NATIVE_H = 1280, 720
# Resolution the intrinsics are rescaled to; the loaded video frames are
# presumably this size — TODO confirm against the dataset.
VID_W, VID_H = 320, 180
# z value the EEF position is dropped to for the base-plane (height) marker.
BASE_Z = 0.0
# Length of each drawn EEF axis, in the same units as the EEF position.
AXIS_LEN = 0.08
SCALE = 3  # upscale for output


# ── Asset loading ──────────────────────────────────────────────────────────────

def load_franka_assets():
    """Read every regular file in the ``assets`` directory next to ``FRANKA_XML``.

    Returns:
        dict mapping ``"assets/<filename>"`` to that file's raw bytes, the
        form handed to ``mujoco.MjModel.from_xml_string`` as its assets.

    Raises:
        OSError: if the assets directory is missing or a file cannot be read.
    """
    xml_dir = os.path.dirname(os.path.abspath(FRANKA_XML))
    assets_dir = os.path.join(xml_dir, "assets")
    assets = {}
    for fn in os.listdir(assets_dir):
        path = os.path.join(assets_dir, fn)
        # listdir also yields sub-directories; open() on one would raise and,
        # because this runs at import time, make the whole module unusable.
        if not os.path.isfile(path):
            continue
        with open(path, "rb") as f:
            assets[os.path.join("assets", fn)] = f.read()
    return assets

# Loaded once at import time and reused for every model build in render_silhouette().
ASSETS = load_franka_assets()


# ── Projection helpers ────────────────────────────────────────────────────────

def build_projection(ext6d, fx, fy, cx, cy):
    """Invert a cam2base pose and assemble the pinhole intrinsics matrix.

    ``ext6d`` is read as ``[x, y, z, a, b, c]`` where the last three values
    are "xyz" Euler angles. Returns ``(R_cam_base, t_cam, K)`` such that
    ``R_cam_base @ p_base + t_cam`` expresses a base-frame point in the
    camera frame.
    """
    cam_origin_in_base = np.array(ext6d[:3])
    # The inverse of a rotation matrix is its transpose.
    R_cam_base = R.from_euler("xyz", ext6d[3:6]).as_matrix().T
    t_cam = -R_cam_base @ cam_origin_in_base
    K = np.array(
        [
            [fx, 0, cx],
            [0, fy, cy],
            [0, 0, 1],
        ],
        dtype=np.float64,
    )
    return R_cam_base, t_cam, K


def project(p_world, R_cam_base, t_cam, K):
    """Pinhole-project a 3D point into the image.

    The point is moved into the camera frame with ``R_cam_base``/``t_cam`` and
    then multiplied by ``K``. Returns ``(u, v)`` as Python floats, or ``None``
    when the camera-frame depth is zero or negative.
    """
    p_cam = R_cam_base @ p_world + t_cam
    depth = p_cam[2]
    if depth <= 0:
        return None
    homog = K @ p_cam
    u = homog[0] / homog[2]
    v = homog[1] / homog[2]
    return float(u), float(v)


def px_int(uv):
    """Round a ``(u, v)`` pair to integer pixel coordinates; ``None`` passes through."""
    return None if uv is None else (int(round(uv[0])), int(round(uv[1])))


def in_frame(uv_int, w, h):
    """Return whether an integer pixel lies inside a ``w`` x ``h`` image.

    ``None`` (no projection available) counts as outside.
    """
    if uv_int is None:
        return False
    inside_x = 0 <= uv_int[0] < w
    return inside_x and 0 <= uv_int[1] < h


# ── MuJoCo silhouette rendering ──────────────────────────────────────────────

def render_silhouette(joints, cam_pos, cam_rot, width, height, fovy):
    """Render the robot from a posed camera and return a binary foreground mask.

    A camera named ``"posed"`` is injected into the ``FRANKA_XML`` text right
    after the model's top light, the model is compiled with ``ASSETS``, the
    first seven joint values are applied, and a depth image is rendered.

    Args:
        joints: joint positions; ``joints[:7]`` is written to ``qpos[:7]``.
        cam_pos: (3,) camera position, written to the camera's ``pos``.
        cam_rot: (3, 3) matrix whose first two columns become the camera's
            ``xyaxes`` (so it must already be in MuJoCo's camera convention).
        width: render width in pixels.
        height: render height in pixels.
        fovy: vertical field of view written to the camera's ``fovy``
            (the caller in this file computes it in degrees).

    Returns:
        ``(height, width)`` uint8 array that is 1 where the rendered depth is
        below 99% of the image's maximum depth and 0 elsewhere.

    Raises:
        ValueError: if the light tag used as the injection anchor is not in
            the XML, so the posed camera cannot be added.
    """
    with open(FRANKA_XML) as f:
        xml = f.read()

    light_tag = '<light name="top" pos="0 0 2" mode="trackcom"/>'
    # Without the anchor, str.replace is a no-op, the camera lookup returns -1
    # and the scene would silently be rendered from the default free camera.
    if light_tag not in xml:
        raise ValueError("camera anchor <light name=\"top\" .../> not found in FRANKA_XML")

    cam_x = cam_rot[:, 0]
    cam_y = cam_rot[:, 1]
    cam_str = (
        f'<camera name="posed" pos="{cam_pos[0]} {cam_pos[1]} {cam_pos[2]}" '
        f'xyaxes="{cam_x[0]} {cam_x[1]} {cam_x[2]} {cam_y[0]} {cam_y[1]} {cam_y[2]}" '
        f'fovy="{fovy}"/>'
    )
    xml = xml.replace(light_tag, f'{light_tag}\n    {cam_str}')

    model = mujoco.MjModel.from_xml_string(xml, ASSETS)
    data = mujoco.MjData(model)
    data.qpos[:7] = joints[:7]
    mujoco.mj_forward(model, data)

    cam_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_CAMERA, "posed")

    renderer = mujoco.Renderer(model, height=height, width=width)
    try:
        renderer.enable_depth_rendering()
        renderer.update_scene(data, camera=cam_id)
        depth = renderer.render().copy()
        renderer.disable_depth_rendering()
    finally:
        # Always free the GL context: the caller swallows exceptions per
        # frame, so a leak here would accumulate across the whole video.
        renderer.close()

    # Pixels noticeably closer than the farthest depth are treated as robot.
    mask = (depth < depth.max() * 0.99).astype(np.uint8)
    return mask


# ── Combined frame rendering ─────────────────────────────────────────────────

def render_frame(
    rgb_frame,    # (H, W, 3) uint8 RGB at VID resolution
    joints,       # (7,) joint positions
    eef_pos,      # (3,) EEF world position
    eef_euler,    # (3,) EEF euler xyz
    ext6d,        # (6,) camera extrinsics
    fx, fy, cx, cy,  # intrinsics at VID resolution
    cam_pos, cam_rot, fovy,  # MuJoCo camera params
    frame_idx, n_frames, ep_idx, cam_name, language="",
):
    """Return an annotated copy of ``rgb_frame``.

    Layers are drawn in this order: robot silhouette tint and contour, EEF
    marker, base-plane marker with height line and label, EEF rotation axes,
    and finally the text header. A silhouette failure is reported as on-image
    text instead of aborting the frame. The input frame is not modified.
    """
    h, w = rgb_frame.shape[:2]
    vis = rgb_frame.copy()

    # --- 1. Robot silhouette overlay ---
    try:
        mask = render_silhouette(joints, cam_pos, cam_rot, w, h, fovy)
        robot_px = mask > 0

        # Blend 60% image with 40% green on the robot pixels only.
        tinted = vis.copy()
        blended = (
            tinted[robot_px].astype(np.float32) * 0.6 +
            np.array([0, 200, 0], dtype=np.float32) * 0.4
        )
        tinted[robot_px] = blended.astype(np.uint8)
        vis = tinted

        # Outline of the mask's outer boundary.
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        cv2.drawContours(vis, contours, -1, (0, 255, 0), 1, cv2.LINE_AA)
    except Exception as e:
        _text_shadow(vis, f"silhouette err: {str(e)[:40]}", (4, h - 36), 0.28, (255, 100, 100))

    # --- 2. EEF keypoint projection ---
    R_cam_base, t_cam, K = build_projection(ext6d, fx, fy, cx, cy)
    eef_rot = R.from_euler("xyz", eef_euler).as_matrix()

    def to_pixel(point):
        # World point -> integer pixel, or None when behind the camera.
        return px_int(project(point, R_cam_base, t_cam, K))

    eef_pxi = to_pixel(eef_pos)

    if in_frame(eef_pxi, w, h):
        # Filled green dot with a thin white rim.
        cv2.circle(vis, eef_pxi, 4, (0, 255, 0), -1, cv2.LINE_AA)
        cv2.circle(vis, eef_pxi, 5, (255, 255, 255), 1, cv2.LINE_AA)

        # --- 3. Base-plane projection (height visualization) ---
        ground_pt = eef_pos.copy()
        ground_pt[2] = BASE_Z
        ground_pxi = to_pixel(ground_pt)

        if in_frame(ground_pxi, w, h):
            cv2.circle(vis, ground_pxi, 4, (0, 255, 255), 2, cv2.LINE_AA)  # cyan ring
            # Yellow height line
            cv2.line(vis, eef_pxi, ground_pxi, (255, 255, 0), 1, cv2.LINE_AA)
            label_org = (ground_pxi[0] + 6, ground_pxi[1] + 4)
            _text_shadow(vis, f"h={eef_pos[2]:.2f}", label_org, 0.28, (0, 255, 255))

        # --- 4. EEF rotation axes ---
        # Rows of the transpose are the rotation's columns, i.e. the x/y/z axes.
        axis_colors = ((255, 0, 0), (0, 255, 0), (0, 0, 255))
        for axis_dir, color in zip(eef_rot.T, axis_colors):
            tip_pxi = to_pixel(eef_pos + axis_dir * AXIS_LEN)
            if in_frame(tip_pxi, w, h):
                cv2.line(vis, eef_pxi, tip_pxi, color, 1, cv2.LINE_AA)

    # --- 5. Text overlays ---
    header = f"ep={ep_idx} f={frame_idx}/{n_frames} {cam_name}"
    _text_shadow(vis, header, (4, 12), 0.30)
    _text_shadow(vis, f"eef=[{eef_pos[0]:.2f},{eef_pos[1]:.2f},{eef_pos[2]:.2f}]", (4, 24), 0.28)
    if language:
        _text_shadow(vis, language[:50], (4, 36), 0.25, (200, 200, 255))

    return vis


def _text_shadow(img, text, org, scale, color=(255, 255, 255)):
    """Draw ``text`` on ``img`` in place: a thick black pass, then a thin ``color`` pass on top."""
    passes = (((0, 0, 0), 2), (color, 1))
    for ink, thickness in passes:
        cv2.putText(img, text, org, cv2.FONT_HERSHEY_SIMPLEX, scale, ink, thickness, cv2.LINE_AA)


# ── Data loading helpers ──────────────────────────────────────────────────────

def find_intrinsics(intrinsics_db, relative_path, cam_key):
    """Find a camera's intrinsics by substring-matching an episode path.

    Args:
        intrinsics_db: mapping from a path-like key to a per-camera mapping.
        relative_path: episode path; an entry matches when this string is
            contained in its key.
        cam_key: camera name to look up inside the matching entry.

    Returns:
        ``val[cam_key]`` of the first matching entry that has ``cam_key``
        (the caller in this file unpacks it as ``[fx, cx, fy, cy]``), or
        ``None`` when nothing matches or ``relative_path`` is empty.
    """
    # "" is a substring of every key, so an unknown episode path would
    # otherwise silently return the first entry's (unrelated) intrinsics.
    if not relative_path:
        return None
    for key, val in intrinsics_db.items():
        if relative_path in key and cam_key in val:
            return val[cam_key]  # [fx, cx, fy, cy]
    return None


def main():
    """CLI entry point: render an annotated video and a 3-frame PNG strip per episode.

    For each requested episode this loads the manifest entry, intrinsics,
    parquet state and camera video, renders every frame through
    ``render_frame``, writes an upscaled mp4 to ``OUTPUT_DIR`` and stacks the
    first, middle and last annotated frames vertically into a PNG.

    Episodes are skipped with a message (instead of crashing or producing
    wrong output) when they are missing from the manifest, lack the selected
    camera's extrinsics, have no usable intrinsics, or have zero frames.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--episodes", type=int, nargs="+", default=[9, 4, 2])
    parser.add_argument("--camera", type=str, default="ext2", choices=["ext1", "ext2"])
    parser.add_argument("--fps", type=int, default=15)
    parser.add_argument("--max-frames", type=int, default=0, help="0 = all frames")
    args = parser.parse_args()

    os.makedirs(OUTPUT_DIR, exist_ok=True)

    cam_name = args.camera
    cam_num = "2" if cam_name == "ext2" else "1"
    cam_intr_key = f"exterior_image_{cam_num}_left"
    cam_vid_key = f"observation.images.exterior_{cam_num}_left"
    ext_key = f"posed_{cam_name}"

    print("Loading manifest + posed JSON + intrinsics...")
    with open(MANIFEST_PATH) as f:
        manifest = json.load(f)
    ep_lookup = {e["ep_idx"]: e for e in manifest["episodes"]}

    with open(POSED_JSON_PATH) as f:
        posed_json = json.load(f)

    with open(INTRINSICS_PATH) as f:
        intrinsics_db = json.load(f)

    # Intrinsics are rescaled from the native resolution to the video resolution.
    sx, sy = VID_W / NATIVE_W, VID_H / NATIVE_H
    out_w, out_h = VID_W * SCALE, VID_H * SCALE
    # The stored extrinsics follow the OpenCV camera convention (x right,
    # y down, z forward); flipping y and z gives MuJoCo's (y up, looking along -z).
    opencv_to_mj = np.diag([1, -1, -1])

    for ep_idx in args.episodes:
        if ep_idx not in ep_lookup:
            print(f"Episode {ep_idx} not in manifest, skipping")
            continue

        ep_info = ep_lookup[ep_idx]
        posed_ep_id = ep_info["posed_ep_id"]
        posed_entry = posed_json.get(posed_ep_id, {})
        relative_path = posed_entry.get("relative_path", "")

        print(f"\n{'='*60}")
        print(f"Episode {ep_idx} — {posed_ep_id}")
        print(f"  Path: {relative_path}")

        # Check before the expensive parquet/video loads rather than hitting
        # a KeyError afterwards.
        if ext_key not in ep_info:
            print(f"  No '{ext_key}' extrinsics in manifest, skipping")
            continue

        # Load intrinsics. An empty path is a substring of every database key
        # and would match an arbitrary entry, so treat it as "not found".
        intr = (
            find_intrinsics(intrinsics_db, relative_path, cam_intr_key)
            if relative_path else None
        )
        if intr is None:
            print("  No intrinsics found, skipping")
            continue

        fx_native, cx_native, fy_native, cy_native = intr
        fx, fy = fx_native * sx, fy_native * sy
        cx, cy = cx_native * sx, cy_native * sy
        fovy = 2 * np.degrees(np.arctan(NATIVE_H / (2 * fy_native)))
        print(f"  Intrinsics (native): fx={fx_native:.1f} fy={fy_native:.1f}")
        print(f"  Intrinsics ({VID_W}x{VID_H}): fx={fx:.1f} fy={fy:.1f} cx={cx:.1f} cy={cy:.1f}")
        print(f"  FOV: {fovy:.1f}°")

        # Load episode data
        chunk = ep_idx // 1000
        parquet = os.path.join(DROID_ROOT, f"data/chunk-{chunk:03d}/episode_{ep_idx:06d}.parquet")
        vid_path = os.path.join(DROID_ROOT, f"videos/chunk-{chunk:03d}/{cam_vid_key}/episode_{ep_idx:06d}.mp4")

        df = pd.read_parquet(parquet)
        joints = np.stack(df["observation.state.joint_position"].values)
        cart_pos = np.stack(df["observation.state.cartesian_position"].values)  # (T, 6) [x,y,z,r,p,y]
        language = str(df["language_instruction"].iloc[0]) if "language_instruction" in df.columns else ""

        frames = iio.imread(vid_path, plugin="pyav")
        n = min(len(joints), len(frames), len(cart_pos))
        if args.max_frames > 0:
            n = min(n, args.max_frames)
        print(f"  Frames: {n}, Language: {language}")
        if n == 0:
            # The strip below indexes frames 0 and n-1, which would fail here.
            print("  No frames to render, skipping")
            continue

        # Camera params
        ext6d = np.array(ep_info[ext_key])
        cam_pos = np.array(ext6d[:3])
        cam_rot = R.from_euler("xyz", ext6d[3:6]).as_matrix() @ opencv_to_mj

        # First / middle / last frame are also saved as a PNG strip. They are
        # captured during the video pass rather than rendered a second time.
        strip_indices = [0, n // 2, n - 1]
        strip_cache = {}

        # Render video
        out_path = os.path.join(OUTPUT_DIR, f"ep{ep_idx:06d}_{cam_name}.mp4")
        writer = cv2.VideoWriter(
            out_path,
            cv2.VideoWriter_fourcc(*"mp4v"),
            float(args.fps),
            (out_w, out_h),
        )
        writer_ok = writer.isOpened()
        if not writer_ok:
            print(f"  WARNING: could not open video writer for {out_path}")

        try:
            for fi in range(n):
                if fi % 50 == 0:
                    print(f"  Frame {fi}/{n}...")

                eef_pos = np.array(cart_pos[fi][:3], dtype=np.float64)
                eef_euler = np.array(cart_pos[fi][3:6], dtype=np.float64)

                vis = render_frame(
                    frames[fi], joints[fi], eef_pos, eef_euler,
                    ext6d, fx, fy, cx, cy,
                    cam_pos, cam_rot, fovy,
                    fi, n, ep_idx, cam_name, language,
                )

                vis_up = cv2.resize(vis, (out_w, out_h), interpolation=cv2.INTER_LANCZOS4)
                if writer_ok:
                    writer.write(cv2.cvtColor(vis_up, cv2.COLOR_RGB2BGR))
                if fi in strip_indices:
                    strip_cache[fi] = vis_up
        finally:
            # Release even if a frame fails so the partial file is finalized
            # and the encoder handle is not leaked.
            writer.release()
        if writer_ok:
            print(f"  Video saved: {out_path}")

        # Indices may repeat for very short episodes; repeated panels are intended.
        strip = np.concatenate([strip_cache[fi] for fi in strip_indices], axis=0)
        strip_path = os.path.join(OUTPUT_DIR, f"ep{ep_idx:06d}_{cam_name}_strip.png")
        cv2.imwrite(strip_path, cv2.cvtColor(strip, cv2.COLOR_RGB2BGR))
        print(f"  Strip saved: {strip_path}")

    print(f"\nAll outputs: {OUTPUT_DIR}")


# Run the CLI only when this file is executed as a script.
if __name__ == "__main__":
    main()
