"""
Final overlay visualization: render Franka in MuJoCo, overlay on real DROID images,
and mark projected joint keypoints for alignment verification.

Usage:
    MUJOCO_GL=egl python droid_testing/final_overlay.py --episode 0
"""

import argparse
import re
import sys
import tempfile
from pathlib import Path

import cv2
import mujoco
import numpy as np
from scipy.spatial.transform import Rotation

from load_droid_episode import load_episode

# Directory containing the Franka Panda MJCF scene (scene.xml is loaded from here).
FRANKA_DIR = Path(__file__).parent / "franka_panda"

# Render / projection resolution in pixels, and the integer upscale factor
# applied to every output panel.
W = 320
H = 180
SCALE = 3

# Bodies whose world positions are projected as keypoints, each paired with the
# RGB colour it is drawn in (images are kept in RGB until they are written out).
_KEYPOINT_STYLE = (
    ("link0", (255, 50, 50)),    # red - base
    ("link2", (50, 255, 50)),    # green - shoulder
    ("link4", (50, 50, 255)),    # blue - elbow
    ("link6", (255, 255, 50)),   # yellow - wrist
    ("hand", (255, 50, 255)),    # magenta - hand/EEF
)
BODY_NAMES = [body for body, _ in _KEYPOINT_STYLE]
BODY_COLORS = [rgb for _, rgb in _KEYPOINT_STYLE]


def build_render_model(cam_pos, cam_quat_wxyz, fovy_deg):
    """Build a MuJoCo model of the Franka scene with an added fixed camera and no floor/sky.

    ``FRANKA_DIR / "scene.xml"`` is patched as text:

    * a body ``cb`` holding camera ``dc`` is inserted before ``</worldbody>``;
    * the ``floor`` geom, the skybox texture and the ``<rgba haze=.../>``
      element are removed, so only the robot appears in the render.

    Args:
        cam_pos: camera position (3 values), written verbatim into the XML.
        cam_quat_wxyz: camera orientation quaternion in MuJoCo's (w, x, y, z) order.
        fovy_deg: vertical field of view in degrees (formatted with 4 decimals).

    Returns:
        ``(model, data)``: the compiled ``mujoco.MjModel`` and a fresh ``MjData``.
    """
    xml = (FRANKA_DIR / "scene.xml").read_text(encoding="utf-8")

    pos_str = " ".join(f"{v}" for v in cam_pos)
    quat_str = " ".join(f"{v}" for v in cam_quat_wxyz)
    cam_xml = (
        f'<body name="cb" pos="{pos_str}" quat="{quat_str}">'
        f'<camera name="dc" fovy="{fovy_deg:.4f}" mode="fixed"/></body>'
    )
    xml = xml.replace("</worldbody>", cam_xml + "</worldbody>")
    # Strip floor, skybox and haze so that only the robot is rendered.
    xml = re.sub(r'<geom name="floor"[^/]*/>', "", xml)
    xml = re.sub(r'<texture type="skybox"[^/]*/>', "", xml)
    xml = re.sub(r'<rgba haze="[^"]*"/>', "", xml)

    # The patched XML is written next to scene.xml so that relative paths inside
    # it resolve against the same directory. A unique file name (instead of a
    # fixed "_tmp_final.xml") keeps concurrent runs from overwriting or deleting
    # each other's file.
    tmp = tempfile.NamedTemporaryFile(
        "w",
        encoding="utf-8",
        dir=FRANKA_DIR,
        prefix="_tmp_final_",
        suffix=".xml",
        delete=False,
    )
    tmp_path = Path(tmp.name)
    try:
        # Close the file before MuJoCo opens it (required on some platforms).
        with tmp:
            tmp.write(xml)
        model = mujoco.MjModel.from_xml_path(str(tmp_path))
    finally:
        tmp_path.unlink(missing_ok=True)
    return model, mujoco.MjData(model)


def build_fk_model():
    """Load the unmodified Franka scene for forward kinematics.

    No camera is added and the XML is not patched.

    Returns:
        ``(model, None)``: the ``mujoco.MjModel`` compiled from
        ``FRANKA_DIR / "scene.xml"``. The second element is always ``None``;
        callers create their own ``MjData``.
    """
    scene_path = FRANKA_DIR / "scene.xml"
    fk_model = mujoco.MjModel.from_xml_path(str(scene_path))
    return fk_model, None


def set_joints(model, data, jp7, gp=0.0):
    """Write arm and finger joint positions into ``data.qpos`` and run ``mj_forward``.

    Args:
        model: ``mujoco.MjModel`` containing joints ``joint1``..``joint7`` and
            ``finger_joint1`` / ``finger_joint2``.
        data: ``mujoco.MjData`` for ``model``; its ``qpos`` is overwritten in place.
        jp7: indexable holding at least 7 arm joint values, in ``joint1``..``joint7`` order.
        gp: gripper value. It is scaled by 0.04, clipped to ``[0, 0.04]`` and
            written to both finger joints.

    Raises:
        ValueError: if one of the required joint names is missing from ``model``.
    """
    def qpos_address(joint_name):
        # mj_name2id returns -1 for an unknown name. Indexing jnt_qposadr with -1
        # would silently target the last joint, so fail loudly instead.
        jid = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_JOINT, joint_name)
        if jid < 0:
            raise ValueError(f"joint {joint_name!r} not found in model")
        return model.jnt_qposadr[jid]

    for i in range(7):
        data.qpos[qpos_address(f"joint{i + 1}")] = jp7[i]

    # gp=0 -> fingers at 0.0, gp=1 -> fingers at 0.04.
    # NOTE(review): confirm the dataset's gripper convention (is 1 open or
    # closed?) matches the finger-joint direction of this model.
    finger_pos = np.clip(float(gp) * 0.04, 0.0, 0.04)
    for name in ("finger_joint1", "finger_joint2"):
        data.qpos[qpos_address(name)] = finger_pos

    mujoco.mj_forward(model, data)


def extrinsics_to_mj(ext6d, fy, height=None):
    """Convert a 6-D OpenCV-style camera pose into MuJoCo camera parameters.

    Args:
        ext6d: ``[x, y, z, rx, ry, rz]``. This is the camera position followed
            by an extrinsic-xyz Euler rotation in radians (SciPy ``"xyz"``),
            expressed in an OpenCV camera frame (+z forward, +y down).
        fy: vertical focal length in pixels; must be positive.
        height: image height in pixels used for the field of view. Defaults to
            the module-level ``H``.

    Returns:
        ``(pos, quat_wxyz, fovy_deg)``: the camera position (``ext6d[:3]``), the
        MuJoCo-frame orientation as a (w, x, y, z) quaternion, and the vertical
        field of view in degrees.

    Raises:
        ValueError: if ``fy`` is not positive.

    Only ``fy`` is used: the resulting camera has no separate ``fx`` and no
    principal-point offset.
    """
    if not fy > 0:
        raise ValueError(f"fy must be positive, got {fy!r}")
    if height is None:
        height = H

    pos = ext6d[:3]
    R_cv = Rotation.from_euler("xyz", ext6d[3:6]).as_matrix()
    # MuJoCo cameras look along their local -z with +y up. OpenCV cameras look
    # along +z with +y down, so flip the camera's y and z axes.
    R_mj = R_cv @ np.diag([1.0, -1.0, -1.0])
    qx, qy, qz, qw = Rotation.from_matrix(R_mj).as_quat()  # SciPy is scalar-last
    fovy = 2.0 * np.arctan(height / (2.0 * fy)) * 180.0 / np.pi
    return pos, np.array([qw, qx, qy, qz]), fovy


def project_opencv(p_world, ext6d, fx, fy, cx, cy):
    """Project a 3-D world point to integer pixel coordinates with a pinhole model.

    Args:
        p_world: 3-vector in the world frame.
        ext6d: ``[x, y, z, rx, ry, rz]``. The camera position in the world
            frame, followed by an extrinsic-xyz Euler rotation (radians) taking
            camera axes to world axes.
        fx, fy, cx, cy: pinhole intrinsics in pixels.

    Returns:
        ``(u, v)`` rounded to ints, or ``None`` if the point's camera-frame
        depth is not positive. Points outside the image are still returned.
    """
    cam_to_world = Rotation.from_euler("xyz", ext6d[3:6]).as_matrix()
    world_to_cam = cam_to_world.T
    offset = -world_to_cam @ ext6d[:3]
    x_c, y_c, z_c = world_to_cam @ p_world + offset

    if z_c <= 0:
        return None  # behind (or on) the image plane

    u = fx * x_c / z_c + cx
    v = fy * y_c / z_c + cy
    return int(round(u)), int(round(v))


def render_frame(model, data, w, h):
    """Render ``data`` from the fixed camera ``dc`` and return the RGB image.

    Args:
        model: ``mujoco.MjModel`` that contains a camera named ``dc`` (as built
            by ``build_render_model``).
        data: ``mujoco.MjData`` holding the state to render.
        w, h: output width and height in pixels.

    Returns:
        A copy of the array returned by ``mujoco.Renderer.render()``.

    Raises:
        ValueError: if ``model`` has no camera named ``dc``.
    """
    cam_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_CAMERA, "dc")
    if cam_id < 0:
        # An id of -1 would make update_scene silently use the free camera.
        raise ValueError('camera "dc" not found in model')

    renderer = mujoco.Renderer(model, height=h, width=w)
    try:
        renderer.update_scene(data, camera=cam_id)
        return renderer.render().copy()
    finally:
        # Release the rendering context even if update_scene/render raises.
        renderer.close()


def draw_keypoints(img, body_positions, ext6d, fx, fy, cx, cy, scale=1):
    """Draw projected body keypoints (circle plus name label) on ``img`` in place.

    Args:
        img: image array to draw on; modified in place.
        body_positions: world-frame 3-D positions, paired by order with
            ``BODY_NAMES`` and ``BODY_COLORS``. Extra entries on either side are
            ignored (``zip``).
        ext6d, fx, fy, cx, cy: camera pose and intrinsics passed to
            ``project_opencv``. The intrinsics refer to the image *before* it
            was upscaled by ``scale``.
        scale: integer factor mapping projected pixels onto ``img``.

    Points behind the camera, or whose scaled position falls outside ``img``,
    are skipped.
    """
    img_h, img_w = img.shape[:2]
    # Marker geometry grows with the upscale factor.
    radius = 5 * scale // 2
    label_dx = 4 * scale // 2
    font_scale = 0.35 * scale / 2
    thickness = max(1, scale // 2)

    for pos_world, name, color in zip(body_positions, BODY_NAMES, BODY_COLORS):
        px = project_opencv(pos_world, ext6d, fx, fy, cx, cy)
        if px is None:
            continue
        sx, sy = px[0] * scale, px[1] * scale
        # Bound against the actual image instead of the global W/H, so the
        # helper also works for images of other sizes.
        if not (0 <= sx < img_w and 0 <= sy < img_h):
            continue
        cv2.circle(img, (sx, sy), radius, color, 2)
        cv2.putText(img, name, (sx + label_dx, sy),
                    cv2.FONT_HERSHEY_SIMPLEX, font_scale, color, thickness)


def _select_frames(num_frames):
    """Return up to five evenly spread, de-duplicated frame indices in ``[0, num_frames)``."""
    candidates = [0, num_frames // 4, num_frames // 2, 3 * num_frames // 4, num_frames - 1]
    # dict.fromkeys keeps first-seen order while dropping duplicates (small episodes).
    return list(dict.fromkeys(i for i in candidates if 0 <= i < num_frames))


def _fk_body_positions(fk_model, fk_data, jp, gp):
    """Run FK for (jp, gp) and return copies of the world positions of BODY_NAMES."""
    set_joints(fk_model, fk_data, jp, gp)
    positions = []
    for bname in BODY_NAMES:
        bid = mujoco.mj_name2id(fk_model, mujoco.mjtObj.mjOBJ_BODY, bname)
        if bid < 0:
            # xpos[-1] would silently return the last body's position.
            raise ValueError(f"body {bname!r} not found in FK model")
        positions.append(fk_data.xpos[bid].copy())
    return positions


def _compose_row(real, rendered, body_positions, ext, fidx, fx, fy, cx, cy, alpha=0.45):
    """Build one ``[real | render | overlay]`` row at ``SCALE``x resolution."""
    size = (W * SCALE, H * SCALE)

    # Robot mask: the render background is presumably near-black because the
    # floor/sky were stripped from the render model.
    mask = cv2.cvtColor(rendered, cv2.COLOR_RGB2GRAY) > 3

    # Panel 1: real image with keypoints and a frame label.
    real_base = cv2.resize(real, size, interpolation=cv2.INTER_LINEAR)
    real_up = real_base.copy()
    draw_keypoints(real_up, body_positions, ext, fx, fy, cx, cy, scale=SCALE)
    cv2.putText(real_up, f"Real f={fidx}", (5, 20 * SCALE // 3),
                cv2.FONT_HERSHEY_SIMPLEX, 0.5 * SCALE / 3, (255, 255, 255), max(1, SCALE // 3))

    # Panel 2: MuJoCo render with keypoints. A clean copy is kept for blending.
    rend_clean = cv2.resize(rendered, size, interpolation=cv2.INTER_LINEAR)
    rend_up = rend_clean.copy()
    draw_keypoints(rend_up, body_positions, ext, fx, fy, cx, cy, scale=SCALE)

    # Panel 3: real image with a semi-transparent render blended in under the mask.
    overlay_up = real_base
    mask_up = cv2.resize(mask.astype(np.uint8) * 255, size,
                         interpolation=cv2.INTER_NEAREST) > 127
    overlay_up[mask_up] = (
        (1 - alpha) * overlay_up[mask_up].astype(float)
        + alpha * rend_clean[mask_up].astype(float)
    ).astype(np.uint8)

    # Green outline of the rendered robot silhouette.
    contours, _ = cv2.findContours(
        mask_up.astype(np.uint8) * 255, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )
    cv2.drawContours(overlay_up, contours, -1, (0, 255, 0), 2)
    draw_keypoints(overlay_up, body_positions, ext, fx, fy, cx, cy, scale=SCALE)

    return np.concatenate([real_up, rend_up, overlay_up], axis=1)


def main():
    """Render Franka overlays for both exterior cameras and save one PNG grid per camera.

    For ``ext1`` and ``ext2``, up to five frames spread over the episode each
    get a row ``[real | MuJoCo render | overlay]`` with projected body
    keypoints. The rows are stacked and written to ``<output>/overlay_<cam>.png``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--episode", type=int, default=0)
    parser.add_argument("--fy", type=float, default=150.0)
    parser.add_argument("--output", type=str, default="droid_testing/output_final")
    args = parser.parse_args()

    # Square pixels with the principal point at the image centre.
    fx = fy = args.fy
    cx, cy = W / 2, H / 2

    print(f"Loading episode {args.episode}...")
    ep = load_episode(args.episode)
    T = ep["num_frames"]
    print(f"  Frames: {T}")

    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)

    frames_to_show = _select_frames(T)
    if not frames_to_show:
        print("  Episode has no frames; nothing to render.")
        return

    # Separate, unmodified model used only to get body positions via FK.
    fk_model, _ = build_fk_model()
    fk_data = mujoco.MjData(fk_model)

    for cam_name in ["ext1", "ext2"]:
        ext = ep[f"{cam_name}_extrinsics"]
        images = ep[f"{cam_name}_images"]
        print(f"\n=== {cam_name}: extrinsics={ext[:3]} ===")

        cam_pos, cam_quat, fovy = extrinsics_to_mj(ext, fy)
        render_model, render_data = build_render_model(cam_pos, cam_quat, fovy)

        rows = []
        for fidx in frames_to_show:
            jp = ep["joint_positions"][fidx]
            gp = ep["gripper_positions"][fidx]
            body_positions = _fk_body_positions(fk_model, fk_data, jp, gp)

            set_joints(render_model, render_data, jp, gp)
            rendered = render_frame(render_model, render_data, W, H)

            rows.append(_compose_row(images[fidx].copy(), rendered, body_positions,
                                     ext, fidx, fx, fy, cx, cy))

        grid = np.concatenate(rows, axis=0)
        out_path = output_dir / f"overlay_{cam_name}.png"
        # cv2.imwrite reports failure via its return value, not an exception.
        if cv2.imwrite(str(out_path), cv2.cvtColor(grid, cv2.COLOR_RGB2BGR)):
            print(f"  Saved: {out_path}")
        else:
            print(f"  ERROR: could not write {out_path}", file=sys.stderr)

    print("\nDone!")

# Script entry point: run main() only when executed directly, not when imported.
if __name__ == "__main__":
    main()
