"""Project DROID EEF keypoint into camera image for sanity checking.

Analogous to libero/debug_libero_projection.py but for the DROID dataset.

For each selected frame, draws on the exterior camera image:
  - Green filled circle: EEF 3D position projected to pixel
  - Cyan ring + yellow line: same XY at robot base height (z=0), showing height
  - RGB axis lines: EEF local rotation frame (x=red, y=green, z=blue)
  - Text: episode/frame/camera header, camera intrinsics (incl. vertical FOV),
    EEF world position, and the task's language instruction

Usage:
    MUJOCO_GL=egl python droid_testing/debug_droid_projection.py --episode 10
    MUJOCO_GL=egl python droid_testing/debug_droid_projection.py --episode 10 --camera ext1 --fy 130
"""

import argparse
import os
import sys
from pathlib import Path

import cv2
import numpy as np
from scipy.spatial.transform import Rotation

from load_droid_episode import load_episode

# ---------- defaults ----------
IMG_W = 320  # expected frame width in pixels
IMG_H = 180  # expected frame height in pixels
DEFAULT_FY = 130.0  # focal length in pixels, ZED 2 wide mode at 320×180 (also used as fx)
BASE_Z = 0.0  # Franka base is at world origin in DROID; z of the drawn "base plane"
AXIS_LEN = 0.08  # meters — length of the drawn EEF rotation axes


# ---------- projection helpers ----------

def build_projection(ext6d, fx, fy, cx, cy):
    """Build the world->camera transform and pinhole intrinsics.

    ``ext6d`` holds the camera position (first three entries) followed by its
    orientation as "xyz" Euler angles (entries 3..5), both expressed in the
    robot base frame.

    Returns:
        ``(R_cam_base, t_cam, K)`` such that a base-frame point ``p`` maps to
        camera coordinates via ``R_cam_base @ p + t_cam``; ``K`` is the 3x3
        intrinsic matrix.
    """
    cam_pos = ext6d[:3]
    cam_euler = ext6d[3:6]
    # The camera->base rotation is orthonormal, so its inverse is its transpose.
    rot_cam_from_base = Rotation.from_euler("xyz", cam_euler).as_matrix().T
    # Translation that sends the camera's own position to the camera origin.
    trans = -(rot_cam_from_base @ cam_pos)
    intrinsics = np.array(
        [
            [fx, 0, cx],
            [0, fy, cy],
            [0, 0, 1],
        ],
        dtype=np.float64,
    )
    return rot_cam_from_base, trans, intrinsics


def project(p_world, R_cam_base, t_cam, K):
    """Map a 3D world point to a pixel location.

    Returns:
        ``(u, v)`` as Python floats, or ``None`` when the point lies at or
        behind the camera plane (camera-frame depth <= 0).
    """
    p_cam = R_cam_base @ p_world + t_cam
    depth = p_cam[2]
    if depth <= 0:
        return None
    hx, hy, hz = K @ p_cam
    # Perspective divide by the homogeneous coordinate.
    return float(hx / hz), float(hy / hz)


def px_int(uv):
    """Round a float pixel ``(u, v)`` to an int tuple; ``None`` passes through."""
    return None if uv is None else (int(round(uv[0])), int(round(uv[1])))


def in_frame(uv_int, w, h):
    """Return True when integer pixel ``uv_int`` lies inside a ``w`` x ``h`` image."""
    if uv_int is None:
        return False
    col, row = uv_int[0], uv_int[1]
    return 0 <= col < w and 0 <= row < h


# ---------- main rendering ----------

def render_projection_frame(
    rgb_frame,      # (H, W, 3) uint8 RGB
    eef_pos,        # (3,) world position
    eef_euler,      # (3,) euler xyz — EEF orientation in base frame
    ext6d,          # (6,) DROID camera extrinsics
    fx, fy, cx, cy,
    frame_idx,
    cam_name,
    ep_idx,
    language="",
):
    """Draw the EEF projection overlay on a single frame.

    Draws (colors are RGB, matching the input image):
      - green dot + "eef" label: EEF position projected into the image,
      - cyan ring + yellow line: the same XY at z=BASE_Z, to visualize height,
      - red/green/blue lines: EEF rotation axes x/y/z, each AXIS_LEN long,
      - text: EEF position, language and a visibility warning (top rows);
        episode/frame/camera header and intrinsics (bottom rows).

    Geometry whose projection falls outside the image or behind the camera is
    skipped; if the EEF itself is not visible only the text is drawn.

    Returns:
        A new (H, W, 3) uint8 RGB array; ``rgb_frame`` is not modified.
    """
    h, w = rgb_frame.shape[:2]
    vis = rgb_frame.copy()
    # Accept lists/tuples too; the base-plane copy below needs array semantics.
    eef_pos = np.asarray(eef_pos, dtype=np.float64)

    R_cam_base, t_cam, K = build_projection(ext6d, fx, fy, cx, cy)
    eef_rot = Rotation.from_euler("xyz", eef_euler).as_matrix()

    # --- 1. Project EEF keypoint ---
    eef_px = project(eef_pos, R_cam_base, t_cam, K)
    eef_pxi = px_int(eef_px)
    status = None  # visibility warning, drawn in its own row below the top text

    if in_frame(eef_pxi, w, h):
        u, v = eef_pxi

        # --- 2. Base-plane projection (same XY, z=BASE_Z) ---
        eef_base = eef_pos.copy()
        eef_base[2] = BASE_Z
        base_pxi = px_int(project(eef_base, R_cam_base, t_cam, K))
        if in_frame(base_pxi, w, h):
            ub, vb = base_pxi
            # Yellow height line first so the markers are drawn over it.
            cv2.line(vis, (u, v), (ub, vb), (255, 255, 0), 2, cv2.LINE_AA)
            cv2.circle(vis, (ub, vb), 5, (0, 255, 255), 2)  # cyan ring
            cv2.putText(
                vis,
                f"base z={BASE_Z:.2f}",
                (ub + 7, vb + 12),
                cv2.FONT_HERSHEY_SIMPLEX, 0.35, (0, 255, 255), 1, cv2.LINE_AA,
            )

        # --- 3. EEF rotation axes (columns of the rotation matrix) ---
        axis_colors = ((255, 0, 0), (0, 255, 0), (0, 0, 255))  # x, y, z
        for i, color in enumerate(axis_colors):
            endpoint = eef_pos + eef_rot[:, i] * AXIS_LEN
            ax_pxi = px_int(project(endpoint, R_cam_base, t_cam, K))
            if in_frame(ax_pxi, w, h):
                cv2.line(vis, (u, v), ax_pxi, color, 2, cv2.LINE_AA)
                cv2.circle(vis, ax_pxi, 3, color, -1)

        # EEF marker last so it stays visible on top of the lines.
        cv2.circle(vis, (u, v), 5, (0, 255, 0), -1)
        cv2.putText(vis, "eef", (u + 7, v - 7),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.40, (0, 255, 0), 1, cv2.LINE_AA)
    elif eef_px is None:
        status = "eef out of frame (behind camera)"
    else:
        status = f"eef out of frame: ({eef_px[0]:.0f}, {eef_px[1]:.0f})"

    # --- 4. Text overlays ---
    # Bottom rows: episode / frame / camera header, then camera intrinsics.
    fovy = np.degrees(2 * np.arctan(h / (2 * fy)))
    header = f"ep={ep_idx} frame={frame_idx} cam={cam_name}"
    _text_shadow(vis, header, (8, h - 24), 0.38)

    intr_text = f"fx={fx:.0f} fy={fy:.0f} cx={cx:.0f} cy={cy:.0f} fovy={fovy:.1f}"
    _text_shadow(vis, intr_text, (8, h - 10), 0.33)

    # Top rows: EEF position, optional task language, optional warning — each
    # on its own baseline so they never overlap.
    eef_text = f"eef=[{eef_pos[0]:.3f}, {eef_pos[1]:.3f}, {eef_pos[2]:.3f}]"
    _text_shadow(vis, eef_text, (8, 14), 0.33)

    next_y = 28
    if language:
        _text_shadow(vis, language[:60], (8, next_y), 0.30, color=(200, 200, 255))
        next_y += 14
    if status is not None:
        _text_shadow(vis, status, (8, next_y), 0.35, color=(255, 180, 0))

    return vis


def _text_shadow(img, text, org, scale, color=(255, 255, 255)):
    """Write ``text`` onto ``img`` in place over a thicker black outline."""
    font = cv2.FONT_HERSHEY_SIMPLEX
    # Outline pass (black, thickness 2) first, then the colored text (thickness 1).
    for fill, thickness in (((0, 0, 0), 2), (color, 1)):
        cv2.putText(img, text, org, font, scale, fill, thickness, cv2.LINE_AA)


# ---------- CLI ----------

def _ensure_parent_dir(path):
    """Create the parent directory of ``path`` if the path names one.

    ``os.path.dirname("out.png")`` is ``""`` and ``os.makedirs("")`` raises
    ``FileNotFoundError``, so bare filenames must be skipped.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)


def _resolve_frame_indices(requested, num_frames):
    """Turn requested frame indices into valid indices in ``[0, num_frames)``.

    ``None`` selects the first, middle and last frames. Negative values count
    from the end (Python style) and everything is clamped into range, so the
    frame index printed on the overlay is the frame actually drawn.
    """
    if requested is None:
        return [0, num_frames // 2, num_frames - 1]
    resolved = []
    for f in requested:
        if f < 0:
            f += num_frames
        resolved.append(min(max(f, 0), num_frames - 1))
    return resolved


def _render_episode_frame(ep, images, ext, fidx, intrinsics, cam, ep_idx, lang):
    """Render the projection overlay for frame ``fidx`` from the episode's EEF pose."""
    # cartesian_positions rows: first 3 entries position, entries 3..5 euler xyz.
    pose = ep["cartesian_positions"][fidx]
    eef_pos = np.array(pose[:3], dtype=np.float64)
    eef_euler = np.array(pose[3:6], dtype=np.float64)
    fx, fy, cx, cy = intrinsics
    return render_projection_frame(
        images[fidx], eef_pos, eef_euler, ext,
        fx, fy, cx, cy,
        fidx, cam, ep_idx, lang,
    )


def main():
    """CLI entry point: write a vertical strip of annotated frames, optionally a video."""
    parser = argparse.ArgumentParser(description="DROID camera / EEF projection debug")
    parser.add_argument("--episode", type=int, default=10)
    parser.add_argument("--camera", type=str, default="ext2", choices=["ext1", "ext2"])
    parser.add_argument("--fy", type=float, default=DEFAULT_FY)
    parser.add_argument("--frames", type=int, nargs="+", default=None,
                        help="Frame indices (default: first, middle, last)")
    parser.add_argument("--scale", type=int, default=3, help="Upscale factor for output")
    parser.add_argument("--out-image", type=str, default="")
    parser.add_argument("--out-video", type=str, default="", help="Render full episode as mp4")
    parser.add_argument("--fps", type=int, default=15)
    args = parser.parse_args()
    if args.scale < 1:
        parser.error("--scale must be >= 1")
    if args.fps <= 0:
        parser.error("--fps must be > 0")

    print(f"Loading episode {args.episode}...")
    ep = load_episode(args.episode)
    T = ep["num_frames"]
    lang = ep.get("language_instruction", "") or ""
    print(f"  Frames: {T}, Task: {lang}")
    if T <= 0:
        raise RuntimeError(f"Episode {args.episode} has no frames")

    cam = args.camera
    images = ep[f"{cam}_images"]
    ext = ep[f"{cam}_extrinsics"]
    print(f"  Camera: {cam}, extrinsics: {ext}")

    # Principal point at the image center. Use the real frame size so the
    # intrinsics and output size stay consistent even if frames are not
    # IMG_W x IMG_H (identical to the constants for 320x180 frames).
    img_h, img_w = images[0].shape[:2]
    fx = fy = args.fy
    cx, cy = img_w / 2.0, img_h / 2.0
    intrinsics = (fx, fy, cx, cy)
    S = args.scale
    out_size = (img_w * S, img_h * S)  # cv2 sizes are (width, height)

    # ---------- single-frame / strip output ----------
    panels = []
    for fidx in _resolve_frame_indices(args.frames, T):
        vis = _render_episode_frame(ep, images, ext, fidx, intrinsics, cam, args.episode, lang)
        panels.append(cv2.resize(vis, out_size, interpolation=cv2.INTER_LANCZOS4))

    strip = np.concatenate(panels, axis=0)

    out_path = args.out_image or f"droid_testing/output/debug_proj_ep{args.episode}_{cam}.png"
    _ensure_parent_dir(out_path)
    # cv2.imwrite reports failure via its return value rather than raising.
    if not cv2.imwrite(out_path, cv2.cvtColor(strip, cv2.COLOR_RGB2BGR)):
        raise RuntimeError(f"Failed to write image: {out_path}")
    print(f"Wrote image: {out_path}")

    # ---------- optional full-episode video ----------
    if args.out_video:
        _ensure_parent_dir(args.out_video)
        writer = cv2.VideoWriter(
            args.out_video,
            cv2.VideoWriter_fourcc(*"mp4v"),
            float(args.fps),
            out_size,
        )
        if not writer.isOpened():
            raise RuntimeError(f"Failed to open video writer: {args.out_video}")

        try:
            for fidx in range(T):
                vis = _render_episode_frame(
                    ep, images, ext, fidx, intrinsics, cam, args.episode, lang,
                )
                vis_up = cv2.resize(vis, out_size, interpolation=cv2.INTER_LANCZOS4)
                writer.write(cv2.cvtColor(vis_up, cv2.COLOR_RGB2BGR))
        finally:
            # Release even on error so the partial file is finalized/closed.
            writer.release()
        print(f"Wrote video: {args.out_video}")


if __name__ == "__main__":
    main()
