"""
Render Franka Panda in MuJoCo using DROID camera extrinsics & joint states,
then overlay on real images to verify alignment.

Usage:
    MUJOCO_GL=egl python droid_testing/render_overlay.py --episode 0 --camera ext1
    MUJOCO_GL=egl python droid_testing/render_overlay.py --episode 0 --camera ext2
"""

import argparse
import numpy as np
import mujoco
import cv2
from pathlib import Path
from scipy.spatial.transform import Rotation

from load_droid_episode import load_episode

# --- Franka MuJoCo model path ---
# Expects a franka_panda/scene.xml next to this script.
FRANKA_SCENE_XML = str(Path(__file__).parent / "franka_panda" / "scene.xml")

# --- Default camera intrinsics estimate for ZED 2 at 320x180 ---
# ZED 2 wide mode at 720p: fx≈530, scaled to 320px width → fx≈132
# Validated by visual alignment across multiple episodes
DEFAULT_FX = 130.0  # horizontal focal length, px (parsed/passed but not used by rendering below)
DEFAULT_FY = 130.0  # vertical focal length, px — determines the MuJoCo fovy
DEFAULT_CX = 160.0  # principal point x, px — drives the post-render pixel shift
DEFAULT_CY = 90.0   # principal point y, px — drives the post-render pixel shift
IMG_W, IMG_H = 320, 180  # render/overlay resolution (must match the real images)


def build_mujoco_model_with_camera(cam_pos, cam_quat_wxyz, fovy_deg):
    """
    Build MuJoCo model with a fixed camera at the given pose.

    We modify the scene XML to include a camera body with exact pos/quat,
    which gives us full 6DOF control (unlike the free camera's azimuth/elevation).

    Args:
        cam_pos: (3,) camera position in the robot base frame.
        cam_quat_wxyz: (4,) camera orientation, MuJoCo (w, x, y, z) order.
        fovy_deg: vertical field of view in degrees.

    Returns:
        (model, data): the compiled mujoco.MjModel and a fresh mujoco.MjData.
    """
    import os
    import re
    import tempfile

    scene_dir = Path(FRANKA_SCENE_XML).parent

    # Read the original scene XML
    with open(FRANKA_SCENE_XML) as f:
        xml = f.read()

    # Insert a camera body into the worldbody
    pos_str = f"{cam_pos[0]} {cam_pos[1]} {cam_pos[2]}"
    quat_str = f"{cam_quat_wxyz[0]} {cam_quat_wxyz[1]} {cam_quat_wxyz[2]} {cam_quat_wxyz[3]}"

    camera_xml = f"""
    <body name="droid_cam_body" pos="{pos_str}" quat="{quat_str}">
      <camera name="droid_cam" fovy="{fovy_deg:.4f}" mode="fixed"/>
    </body>
"""
    xml = xml.replace("</worldbody>", camera_xml + "  </worldbody>")

    # Remove floor, skybox, and haze for clean robot-only rendering.
    # [^>]* (not [^/]*) so tags whose attribute values contain '/' (e.g.
    # texture file paths) still match.
    xml = re.sub(r'<geom name="floor"[^>]*/>', '', xml)
    xml = re.sub(r'<texture type="skybox"[^>]*/>', '', xml)
    xml = re.sub(r'<rgba haze="[^"]*"/>', '', xml)

    # Write to a uniquely named temp file in the same directory so relative
    # <include>/asset paths resolve, and concurrent runs don't clobber each
    # other (a fixed "_tmp_droid_scene.xml" name would race).
    fd, tmp_name = tempfile.mkstemp(
        prefix="_tmp_droid_scene_", suffix=".xml", dir=scene_dir
    )
    tmp_path = Path(tmp_name)
    try:
        with os.fdopen(fd, "w") as f:
            f.write(xml)
        model = mujoco.MjModel.from_xml_path(str(tmp_path))
    finally:
        tmp_path.unlink(missing_ok=True)

    data = mujoco.MjData(model)
    return model, data


def set_franka_joints(model, data, joint_positions_7, gripper_pos=0.0):
    """Write the 7 Franka arm joint angles and the finger opening into qpos.

    gripper_pos is mapped linearly onto [0, 0.04] m of finger travel
    (presumably a normalized [0, 1] command — confirm against the loader).
    """
    def qpos_slot(joint_name):
        # Resolve a named joint to its index in the qpos vector.
        joint_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_JOINT, joint_name)
        return model.jnt_qposadr[joint_id]

    for k in range(1, 8):
        data.qpos[qpos_slot(f"joint{k}")] = joint_positions_7[k - 1]

    width = np.clip(float(gripper_pos) * 0.04, 0.0, 0.04)
    data.qpos[qpos_slot("finger_joint1")] = width
    data.qpos[qpos_slot("finger_joint2")] = width

    # Propagate qpos into derived quantities (body/camera poses) for rendering.
    mujoco.mj_forward(model, data)


def droid_extrinsics_to_camera_matrix(extrinsics_6d):
    """
    Convert DROID [x,y,z,rx,ry,rz] to 4x4 T_base_cam.

    DROID camera extrinsics represent the camera pose in robot base frame:
    the euler angles (scipy extrinsic "xyz" order) give the camera-to-base
    rotation, and [x, y, z] is the camera position in the base frame.
    """
    rotation = Rotation.from_euler("xyz", extrinsics_6d[3:6]).as_matrix()
    translation = np.asarray(extrinsics_6d[:3], dtype=float).reshape(3, 1)
    bottom = np.array([[0.0, 0.0, 0.0, 1.0]])
    # Assemble [[R, t], [0, 1]] in one shot instead of mutating an identity.
    return np.vstack([np.hstack([rotation, translation]), bottom])


def droid_extrinsics_to_mujoco_camera(extrinsics_6d, fy, height):
    """
    Convert DROID extrinsics to MuJoCo camera parameters.

    Returns: cam_pos (3,), cam_quat_wxyz (4,), fovy_deg (float)
    """
    # Pose in base frame, read directly from the 6D vector
    # (translation + extrinsic-"xyz" euler rotation).
    cam_pos = np.asarray(extrinsics_6d[:3], dtype=float)
    R_base_cam = Rotation.from_euler("xyz", extrinsics_6d[3:6]).as_matrix()

    # DROID camera: OpenCV convention (Z forward, X right, Y down)
    # MuJoCo camera: looks along body's -Z, with Y up
    # Convert: R_base_mjbody = R_base_cam @ Flip, where Flip = diag(1,-1,-1)
    R_base_mj = R_base_cam @ np.diag([1.0, -1.0, -1.0])

    # scipy quat: (x,y,z,w) → MuJoCo: (w,x,y,z); a right-roll moves w first.
    cam_quat_wxyz = np.roll(Rotation.from_matrix(R_base_mj).as_quat(), 1)

    # Vertical FOV from the vertical focal length and image height.
    fovy_deg = 2.0 * np.arctan(height / (2.0 * fy)) * 180.0 / np.pi

    return cam_pos, cam_quat_wxyz, fovy_deg


def render_frame(model, data, cam_name, width, height):
    """Render one RGB frame from a named camera.

    The Renderer is closed in a finally block so a failure inside
    update_scene()/render() cannot leak the renderer's GL resources
    when this is called once per frame.
    """
    renderer = mujoco.Renderer(model, height=height, width=width)
    try:
        cam_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_CAMERA, cam_name)
        renderer.update_scene(data, camera=cam_id)
        # Copy: the renderer owns its pixel buffer and we close it below.
        rgb = renderer.render().copy()
    finally:
        renderer.close()
    return rgb


def overlay_images(real_img, rendered_img, alpha=0.5, mask_threshold=10):
    """Overlay rendered robot on real image with alpha blending + contour.

    Pixels where the (grayscale) render exceeds mask_threshold are treated
    as robot; those pixels are alpha-blended and outlined in green.
    """
    gray = cv2.cvtColor(rendered_img, cv2.COLOR_RGB2GRAY)
    robot_mask = gray > mask_threshold

    # Blend the full images, then select blended pixels only under the mask.
    blended = (
        (1 - alpha) * real_img.astype(float)
        + alpha * rendered_img.astype(float)
    ).astype(np.uint8)
    overlay = np.where(robot_mask[..., None], blended, real_img)

    # Trace the robot silhouette for a crisp alignment cue.
    contours, _ = cv2.findContours(
        robot_mask.astype(np.uint8) * 255, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )
    cv2.drawContours(overlay, contours, -1, (0, 255, 0), 1)
    return overlay


def visualize_episode(episode_data, camera="ext1", output_dir="droid_testing/output",
                      fx=DEFAULT_FX, fy=DEFAULT_FY, cx=DEFAULT_CX, cy=DEFAULT_CY,
                      frame_indices=None):
    """Render and overlay for selected frames.

    Saves per-frame real/render/overlay PNGs plus a side-by-side comparison
    strip into output_dir.

    Args:
        episode_data: dict from load_episode(); must provide "num_frames",
            "joint_positions", "gripper_positions" and the per-camera
            "<cam>_images" / "<cam>_extrinsics" entries.
        camera: "ext1" selects the first external camera; any other value
            selects ext2.
        output_dir: output directory (created if missing).
        fx, fy, cx, cy: pinhole intrinsics at 320x180. NOTE: fx is accepted
            for interface symmetry but unused — MuJoCo's fovy encodes only
            the vertical focal length, so horizontal scale follows fy.
        frame_indices: frames to render; defaults to first/middle/last.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    T = episode_data["num_frames"]
    if frame_indices is None:
        frame_indices = [0, T // 2, T - 1]

    if camera == "ext1":
        images = episode_data["ext1_images"]
        extrinsics = episode_data["ext1_extrinsics"]
    else:
        images = episode_data["ext2_images"]
        extrinsics = episode_data["ext2_extrinsics"]

    print(f"Camera: {camera}")
    print(f"  Extrinsics: {extrinsics}")

    # Convert to MuJoCo camera
    cam_pos, cam_quat, fovy = droid_extrinsics_to_mujoco_camera(extrinsics, fy, IMG_H)
    print(f"  MuJoCo cam pos: {cam_pos}")
    print(f"  MuJoCo cam quat (wxyz): {cam_quat}")
    print(f"  FOVy: {fovy:.1f}°")

    # Build model with embedded camera
    model, data = build_mujoco_model_with_camera(cam_pos, cam_quat, fovy)

    # Keep the BGR frames in memory so the comparison strip doesn't need to
    # re-read the files we just wrote (cv2.imread returns None on failure,
    # which would crash np.concatenate below).
    saved_frames = []  # list of (idx, real_bgr, render_bgr, overlay_bgr)

    for idx in frame_indices:
        if idx >= T:
            continue

        print(f"\n  Frame {idx}/{T}...")
        joint_pos = episode_data["joint_positions"][idx]
        gripper_pos = episode_data["gripper_positions"][idx]
        set_franka_joints(model, data, joint_pos, gripper_pos)

        # Render
        rendered = render_frame(model, data, "droid_cam", IMG_W, IMG_H)

        # Principal point shift (MuJoCo always centers at W/2, H/2)
        dx = int(round(cx - IMG_W / 2))
        dy = int(round(cy - IMG_H / 2))
        if dx != 0 or dy != 0:
            M = np.float32([[1, 0, -dx], [0, 1, -dy]])
            rendered = cv2.warpAffine(rendered, M, (IMG_W, IMG_H))

        real_img = images[idx]
        overlay = overlay_images(real_img, rendered, alpha=0.4)

        prefix = f"frame_{idx:04d}_{camera}"
        real_bgr = cv2.cvtColor(real_img, cv2.COLOR_RGB2BGR)
        render_bgr = cv2.cvtColor(rendered, cv2.COLOR_RGB2BGR)
        overlay_bgr = cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR)
        cv2.imwrite(str(output_dir / f"{prefix}_real.png"), real_bgr)
        cv2.imwrite(str(output_dir / f"{prefix}_render.png"), render_bgr)
        cv2.imwrite(str(output_dir / f"{prefix}_overlay.png"), overlay_bgr)
        saved_frames.append((idx, real_bgr, render_bgr, overlay_bgr))
        print(f"    Saved: {prefix}_*.png")

    # Comparison strip: one labeled real|render|overlay row per frame.
    strips = []
    for idx, real, render, over in saved_frames:
        # np.concatenate copies, so putText below won't touch the saved frames.
        strip = np.concatenate([real, render, over], axis=1)
        cv2.putText(strip, f"Frame {idx}", (5, 15),
                     cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1)
        cv2.putText(strip, "Real", (5, IMG_H - 5),
                     cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1)
        cv2.putText(strip, "MuJoCo", (IMG_W + 5, IMG_H - 5),
                     cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1)
        cv2.putText(strip, "Overlay", (2 * IMG_W + 5, IMG_H - 5),
                     cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1)
        strips.append(strip)

    if strips:
        comparison = np.concatenate(strips, axis=0)
        cv2.imwrite(str(output_dir / f"comparison_{camera}.png"), comparison)
        print(f"\nSaved comparison: {output_dir}/comparison_{camera}.png")


def main():
    """CLI entry point: parse args, load one DROID episode, render overlays."""
    parser = argparse.ArgumentParser(description="DROID MuJoCo overlay rendering")
    parser.add_argument("--episode", type=int, default=0)
    parser.add_argument("--camera", type=str, default="ext1", choices=["ext1", "ext2"])
    # The four pinhole intrinsics share a type and only differ in default.
    for flag, default in (("fx", DEFAULT_FX), ("fy", DEFAULT_FY),
                          ("cx", DEFAULT_CX), ("cy", DEFAULT_CY)):
        parser.add_argument(f"--{flag}", type=float, default=default)
    parser.add_argument("--output", type=str, default="droid_testing/output")
    parser.add_argument("--frames", type=int, nargs="+", default=None)
    args = parser.parse_args()

    print(f"Loading episode {args.episode}...")
    episode = load_episode(args.episode)
    print(f"  Loaded {episode['num_frames']} frames")
    print(f"  Instruction: {episode['language_instruction']}")

    visualize_episode(
        episode,
        camera=args.camera,
        output_dir=args.output,
        fx=args.fx,
        fy=args.fy,
        cx=args.cx,
        cy=args.cy,
        frame_indices=args.frames,
    )


# Script entry point — run with MUJOCO_GL=egl (see module docstring usage).
if __name__ == "__main__":
    main()
