#!/usr/bin/env python3
"""Visualize GT EEF projection for a panda dataset episode.

Shows on each frame:
  - Robot model alpha-blended over RGB image
  - Green filled circle:  EEF (virtual_gripper_keypoint) projected to pixel
  - Cyan open circle:     same XY at robot base Z (base-plane projection)
  - Yellow line:          connecting EEF to its base-plane drop
  - Red/green/blue lines: EEF local X/Y/Z rotation axes (0.08 m)
  - Header text with dataset info

Usage:
  # Full episode video (default):
  python vis_dataset_gt.py
  python vis_dataset_gt.py --dataset scratch/parsed_panda_dummy --episode 1

  # Single frame PNG:
  python vis_dataset_gt.py --frame 3
  python vis_dataset_gt.py --frame 3 --out scratch/my_vis.png
"""
import sys, os
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

import argparse
import numpy as np
import cv2
import mujoco
from pathlib import Path

from ExoConfigs.panda_exo import PANDA_BASE_ONLY_CONFIG, VIRTUAL_GRIPPER_BODY_NAME
from exo_utils import get_body_pose_in_world, render_from_camera_pose

N_ARM_JOINTS  = 7        # Panda arm joint count (qpos[0:7]); gripper fingers follow
GRIPPER_POS_MAX = 0.04   # metres per finger
AXIS_LEN      = 0.08     # metres, length of rotation axes drawn

# Hard-coded intrinsics from simple_dataset_record_panda.py (full-res 1920×1080)
# Used only for the robot overlay render; rescaled to the actual image size
# in process_frame before being passed to render_from_camera_pose.
CAM_K_FULL = np.array([
    [1.58847596e03, 0.0, 9.59500000e02],
    [0.0,           1.58847596e03, 5.39500000e02],
    [0.0,           0.0,           1.0],
], dtype=np.float64)
CAM_K_FULL_W, CAM_K_FULL_H = 1920, 1080  # resolution CAM_K_FULL is expressed in


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def project(p_world: np.ndarray, w2c: np.ndarray, K: np.ndarray):
    """Pinhole-project a world-space 3D point into pixel coordinates.

    Transforms ``p_world`` into the camera frame via the (4,4) world-to-camera
    matrix ``w2c``, then applies the pixel-space intrinsics ``K``.
    Returns an ``(u, v)`` int tuple, or None when the point lies at or
    behind the camera plane (non-positive depth).
    """
    homog = np.concatenate([p_world, [1.0]])
    cam = w2c[:3] @ homog
    depth = cam[2]
    if depth <= 0:
        return None
    fx, fy = K[0, 0], K[1, 1]
    cx, cy = K[0, 2], K[1, 2]
    u = fx * cam[0] / depth + cx
    v = fy * cam[1] / depth + cy
    return int(round(u)), int(round(v))


def in_frame(pt, w, h):
    """Return True iff ``pt`` is a valid pixel coordinate inside a w×h image."""
    if pt is None:
        return False
    u, v = pt
    return 0 <= u < w and 0 <= v < h


def draw_text(img, text, pos, scale=0.45, color=(255,255,255), thickness=1):
    """Draw text with a dark halo so it stays legible on any background.

    Renders a slightly thicker dark pass first, then the foreground color
    on top of it. Modifies ``img`` in place.
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    # Outline pass (one pixel thicker), then the actual text.
    cv2.putText(img, text, pos, font, scale, (20,20,20), thickness + 1, cv2.LINE_AA)
    cv2.putText(img, text, pos, font, scale, color, thickness, cv2.LINE_AA)


def process_frame(ep_dir: Path, ts: str, frame_i: int, episode: int,
                  model, data, base_z: float) -> np.ndarray:
    """Render a single annotated frame. Returns a BGR numpy array.

    Loads the RGB image, joint state, world-to-camera pose and normalized
    intrinsics for timestamp ``ts``, runs FK on the MuJoCo model, and draws
    the EEF projection, its base-plane drop, local rotation axes, and info
    text over an alpha-blended robot render.

    Args:
        ep_dir:  Episode directory containing ``{ts}.png`` / ``{ts}*.npy``.
        ts:      Zero-padded timestamp string identifying the frame's files.
        frame_i: Frame index (used only in the header text).
        episode: Episode number (used only in the header text).
        model:   MuJoCo model, built once by the caller.
        data:    MuJoCo data; ``qpos`` is mutated in place before FK.
        base_z:  World Z of the robot base, for the base-plane projection.

    Raises:
        FileNotFoundError: if the frame image cannot be read.
    """
    img_path = ep_dir / f"{ts}.png"
    img_bgr = cv2.imread(str(img_path))
    if img_bgr is None:
        # cv2.imread returns None (no exception) on a missing/corrupt file;
        # fail loudly instead of crashing on img_bgr.shape below.
        raise FileNotFoundError(f"Could not read frame image: {img_path}")
    joint_state  = np.load(ep_dir / f"{ts}.npy")
    camera_pose  = np.load(ep_dir / f"{ts}_camera_pose.npy")  # (4,4) w2c
    cam_K_norm   = np.load(ep_dir / f"{ts}_cam_K.npy")        # (3,3) normalized

    h, w = img_bgr.shape[:2]

    # De-normalize intrinsics to pixel units for this image resolution.
    K = cam_K_norm.copy()
    K[0] *= w
    K[1] *= h

    # FK: write arm joints (and gripper fingers, when both model and data
    # provide them) into qpos, then recompute kinematics.
    n_arm = min(N_ARM_JOINTS, data.qpos.size)
    data.qpos[:n_arm] = joint_state[:n_arm]
    if data.qpos.size >= N_ARM_JOINTS + 2 and joint_state.size > N_ARM_JOINTS:
        # joint_state[7] is assumed normalized [0, 1]; scale to metres/finger.
        gw_m = joint_state[7] * GRIPPER_POS_MAX
        data.qpos[N_ARM_JOINTS]     = gw_m
        data.qpos[N_ARM_JOINTS + 1] = gw_m
    mujoco.mj_forward(model, data)

    try:
        eef_pose = get_body_pose_in_world(model, data, VIRTUAL_GRIPPER_BODY_NAME)
    except ValueError:
        # Body missing from the model: fall back to identity so the frame
        # still renders (EEF markers will land at the world origin).
        eef_pose = np.eye(4)
    eef_pos = eef_pose[:3, 3]
    eef_rot = eef_pose[:3, :3]

    pt_eef = project(eef_pos, camera_pose, K)

    # Same XY as the EEF but dropped to the robot base height.
    eef_base = eef_pos.copy(); eef_base[2] = base_z
    pt_base  = project(eef_base, camera_pose, K)

    axis_endpoints = {
        "x": eef_pos + eef_rot[:, 0] * AXIS_LEN,
        "y": eef_pos + eef_rot[:, 1] * AXIS_LEN,
        "z": eef_pos + eef_rot[:, 2] * AXIS_LEN,
    }
    axis_colors = {"x": (0, 0, 255), "y": (0, 255, 0), "z": (255, 0, 0)}  # BGR

    # Render the robot with the full-res intrinsics rescaled to this image
    # size, then 50/50 alpha-blend over the RGB frame.
    K_render = CAM_K_FULL.copy()
    K_render[0] *= w / CAM_K_FULL_W
    K_render[1] *= h / CAM_K_FULL_H
    try:
        rendered_rgb = render_from_camera_pose(model, data, camera_pose, K_render, h, w)
        img_rgb  = img_bgr[:, :, ::-1]
        blended  = (img_rgb.astype(float) * 0.5 + rendered_rgb.astype(float) * 0.5).astype(np.uint8)
        vis = np.ascontiguousarray(blended[:, :, ::-1])  # back to BGR for cv2 drawing
    except Exception as e:
        # Best effort: offscreen rendering can fail (e.g. no GL context);
        # annotate the raw image instead of aborting the whole episode.
        print(f"[warn] frame {ts} render failed: {e}")
        vis = img_bgr.copy()

    if in_frame(pt_eef, w, h):
        u, v = pt_eef
        if in_frame(pt_base, w, h):
            ug, vg = pt_base
            cv2.line(vis, (u, v), (ug, vg), (0, 255, 255), 2, cv2.LINE_AA)
            cv2.circle(vis, (ug, vg), 7, (0, 255, 255), 2)
            draw_text(vis, f"base z={base_z:.3f}", (ug + 8, vg + 12), color=(0,255,255))
        for name, ep in axis_endpoints.items():
            pt_ax = project(ep, camera_pose, K)
            if in_frame(pt_ax, w, h):
                cv2.line(vis, (u, v), pt_ax, axis_colors[name], 2, cv2.LINE_AA)
                cv2.circle(vis, pt_ax, 3, axis_colors[name], -1)
        # EEF marker drawn last so it sits on top of the axis lines.
        cv2.circle(vis, (u, v), 7, (0, 255, 0), -1)
        draw_text(vis, "eef", (u + 8, v - 8), color=(0,255,0))
    else:
        draw_text(vis, f"eef proj out of frame: {pt_eef}", (10, 22), color=(0,165,255))

    header = (f"ep={episode:03d} frame={frame_i} ts={ts}  "
              f"eef=({eef_pos[0]:.3f},{eef_pos[1]:.3f},{eef_pos[2]:.3f})  uv={pt_eef}")
    draw_text(vis, header, (10, h - 24), scale=0.4)
    intr = f"fx={K[0,0]:.1f} fy={K[1,1]:.1f} cx={K[0,2]:.1f} cy={K[1,2]:.1f}  {w}x{h}"
    draw_text(vis, intr, (10, h - 8), scale=0.4, color=(200,200,200))

    return vis


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main():
    """CLI entry point: render one annotated frame as PNG, or a full episode as MP4.

    Raises:
        FileNotFoundError: if the episode directory does not exist.
        RuntimeError:      if no frames are found or the video writer fails to open.
        ValueError:        if --frame is out of range for the episode.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", default="scratch/parsed_panda_dummy")
    parser.add_argument("--episode", type=int, default=1)
    parser.add_argument("--frame",   type=int, default=-1,
                        help="Single frame index to render as PNG (-1 = full episode video)")
    parser.add_argument("--fps",     type=int, default=5,
                        help="Video framerate (default: 5, matches recording fps)")
    parser.add_argument("--out", default="",
                        help="Output path (.png for single frame, .mp4 for video)")
    args = parser.parse_args()

    dataset_root = Path(args.dataset)
    ep_dir = dataset_root / f"episode_{args.episode:03d}"
    if not ep_dir.exists():
        raise FileNotFoundError(f"Episode dir not found: {ep_dir}")

    # Frame timestamps are the numeric prefix of each PNG filename.
    frames = sorted(set(int(p.stem.split("_")[0]) for p in ep_dir.glob("*.png")))
    if not frames:
        raise RuntimeError(f"No frames found in {ep_dir}")

    # Build MuJoCo model once; reused (with mutated qpos) for every frame.
    robot_config = PANDA_BASE_ONLY_CONFIG
    model = mujoco.MjModel.from_xml_string(robot_config.xml)
    data  = mujoco.MjData(model)
    base_body_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_BODY, "link0")
    base_z = float(data.xpos[base_body_id][2]) if base_body_id >= 0 else 0.0

    single_frame = args.frame >= 0

    if single_frame:
        if args.frame >= len(frames):
            raise ValueError(f"Frame {args.frame} out of range ({len(frames)} frames)")
        ts = f"{frames[args.frame]:06d}"
        vis = process_frame(ep_dir, ts, args.frame, args.episode, model, data, base_z)
        out_path = args.out or f"scratch/vis_gt_ep{args.episode:03d}_fr{args.frame:03d}.png"
        Path(out_path).parent.mkdir(parents=True, exist_ok=True)
        cv2.imwrite(out_path, vis)
        print(f"Saved: {out_path}")
    else:
        out_path = args.out or f"scratch/vis_gt_ep{args.episode:03d}.mp4"
        Path(out_path).parent.mkdir(parents=True, exist_ok=True)

        # Determine frame size from first frame
        first_vis = process_frame(ep_dir, f"{frames[0]:06d}", 0, args.episode, model, data, base_z)
        fh, fw = first_vis.shape[:2]

        writer = cv2.VideoWriter(
            out_path,
            cv2.VideoWriter_fourcc(*"mp4v"),
            float(args.fps),
            (fw, fh),
        )
        if not writer.isOpened():
            raise RuntimeError(f"Failed to open video writer: {out_path}")

        try:
            writer.write(first_vis)
            for i, frame_idx in enumerate(frames[1:], start=1):
                ts = f"{frame_idx:06d}"
                vis = process_frame(ep_dir, ts, i, args.episode, model, data, base_z)
                writer.write(vis)
                print(f"\r  frame {i+1}/{len(frames)}", end="", flush=True)
        finally:
            # Always release: without this, an exception mid-episode leaks the
            # handle and leaves an unfinalized (often unplayable) mp4.
            writer.release()
        print(f"\nSaved: {out_path}  ({len(frames)} frames @ {args.fps} fps)")


if __name__ == "__main__":
    main()
