"""
Debug alignment between MuJoCo render and real DROID images.
Creates clear visualizations with colored robot silhouettes overlaid on real images.

Usage:
    MUJOCO_GL=egl python droid_testing/debug_alignment.py --episode 0
"""

import argparse
import re
import tempfile
from pathlib import Path

import cv2
import mujoco
import numpy as np
from scipy.spatial.transform import Rotation

from load_droid_episode import load_episode

# Path to the Franka Panda MJCF scene shipped next to this script.
FRANKA_SCENE_XML = str(Path(__file__).parent / "franka_panda" / "scene.xml")
# Render/real image resolution (width, height) in pixels.
IMG_W, IMG_H = 320, 180
SCALE = 3  # upscale for visibility


def build_model_with_camera(cam_pos, cam_quat_wxyz, fovy_deg):
    """Build MuJoCo model with embedded fixed camera, no floor/sky.

    Args:
        cam_pos: camera position in the world frame (3 floats).
        cam_quat_wxyz: camera orientation quaternion, (w, x, y, z) order.
        fovy_deg: vertical field of view of the camera, in degrees.

    Returns:
        Tuple of (MjModel, MjData) for the modified scene.
    """
    scene_dir = Path(FRANKA_SCENE_XML).parent

    with open(FRANKA_SCENE_XML) as f:
        xml = f.read()

    pos_str = " ".join(f"{v}" for v in cam_pos)
    quat_str = " ".join(f"{v}" for v in cam_quat_wxyz)

    camera_xml = f"""
    <body name="droid_cam_body" pos="{pos_str}" quat="{quat_str}">
      <camera name="droid_cam" fovy="{fovy_deg:.4f}" mode="fixed"/>
    </body>
"""
    xml = xml.replace("</worldbody>", camera_xml + "  </worldbody>")
    # Strip visual clutter (floor, skybox, haze) so the render is a clean
    # robot-only silhouette suitable for masking.
    xml = re.sub(r'<geom name="floor"[^/]*/>', '', xml)
    xml = re.sub(r'<texture type="skybox"[^/]*/>', '', xml)
    xml = re.sub(r'<rgba haze="[^"]*"/>', '', xml)

    # Write the edited XML to a uniquely-named temp file *inside* the scene
    # directory (relative mesh/asset paths must still resolve). A unique name
    # prevents concurrent runs from clobbering each other's scene file.
    with tempfile.NamedTemporaryFile(
        "w", suffix=".xml", prefix="_tmp_debug_scene_", dir=scene_dir, delete=False
    ) as tf:
        tf.write(xml)
        tmp_path = Path(tf.name)
    try:
        model = mujoco.MjModel.from_xml_path(str(tmp_path))
    finally:
        tmp_path.unlink(missing_ok=True)

    return model, mujoco.MjData(model)


def set_joints(model, data, joint_pos_7, gripper_pos=0.0):
    """Write the 7 arm joint angles and gripper finger width into qpos, then run FK."""
    for i, angle in enumerate(joint_pos_7[:7], start=1):
        jid = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_JOINT, f"joint{i}")
        data.qpos[model.jnt_qposadr[jid]] = angle
    # Map gripper command (presumably normalized 0..1 — confirm against the
    # episode loader) to finger travel, clamped to the 0–0.04 m joint range.
    finger_width = np.clip(float(gripper_pos) * 0.04, 0.0, 0.04)
    for finger_name in ("finger_joint1", "finger_joint2"):
        jid = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_JOINT, finger_name)
        data.qpos[model.jnt_qposadr[jid]] = finger_width
    mujoco.mj_forward(model, data)


def extrinsics_to_mj_camera(ext6d, fy, height):
    """DROID [x,y,z,rx,ry,rz] -> MuJoCo camera params.

    Returns (position, quaternion in wxyz order, vertical FOV in degrees).
    """
    pos = ext6d[:3]
    rot_cv = Rotation.from_euler("xyz", ext6d[3:6]).as_matrix()
    # OpenCV cameras look down +Z, MuJoCo cameras look down -Z:
    # flip the Y and Z basis vectors to convert conventions.
    flip_yz = np.diag([1.0, -1.0, -1.0])
    xyzw = Rotation.from_matrix(rot_cv @ flip_yz).as_quat()
    quat_wxyz = np.roll(xyzw, 1)  # scipy gives (x,y,z,w); MuJoCo wants (w,x,y,z)
    fovy = np.degrees(2.0 * np.arctan(height / (2.0 * fy)))
    return pos, quat_wxyz, fovy


def render(model, data, width, height):
    """Render one frame from the embedded "droid_cam" fixed camera.

    Returns a copy of the rendered RGB image, safe to mutate.
    """
    renderer = mujoco.Renderer(model, height=height, width=width)
    try:
        cam_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_CAMERA, "droid_cam")
        renderer.update_scene(data, camera=cam_id)
        rgb = renderer.render().copy()
    finally:
        # Release the GL context even if scene update or rendering raises;
        # the original code leaked it on error.
        renderer.close()
    return rgb


def make_overlay(real, rendered, scale=SCALE):
    """Create clear overlay: real image with colored robot silhouette.

    Returns (overlay, render_vis): the upscaled real image with a
    semi-transparent red robot mask plus a green contour, and the
    upscaled render by itself.
    """
    real_size = (real.shape[1] * scale, real.shape[0] * scale)
    rend_size = (rendered.shape[1] * scale, rendered.shape[0] * scale)
    real_up = cv2.resize(real, real_size, interpolation=cv2.INTER_NEAREST)
    rend_up = cv2.resize(rendered, rend_size, interpolation=cv2.INTER_NEAREST)

    # Anything brighter than near-black in the render counts as robot.
    mask = cv2.cvtColor(rend_up, cv2.COLOR_RGB2GRAY) > 5

    # Blend a semi-transparent red tint into the masked pixels.
    overlay = real_up.copy()
    red_tint = np.zeros_like(overlay)
    red_tint[..., 0] = 255  # R channel
    red_tint[..., 1] = 50
    red_tint[..., 2] = 50

    alpha = 0.4
    blended = (1 - alpha) * overlay[mask].astype(float) + alpha * red_tint[mask].astype(float)
    overlay[mask] = blended.astype(np.uint8)

    # Outline the silhouette with a thick green contour for visibility.
    contours, _ = cv2.findContours(
        mask.astype(np.uint8) * 255, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )
    cv2.drawContours(overlay, contours, -1, (0, 255, 0), 2)

    return overlay, rend_up.copy()


def project_point_to_image(point_3d, T_base_cam, fx, fy, cx, cy):
    """Project a 3D point (in base frame) to image coordinates.

    T_base_cam is the 4x4 camera-to-base rigid transform; its inverse is
    applied, followed by a pinhole projection with intrinsics (fx, fy, cx, cy).
    Returns integer pixel coordinates (u, v), or None when the point is at
    or behind the camera plane.
    """
    rot = T_base_cam[:3, :3]
    trans = T_base_cam[:3, 3]
    # Invert the rigid transform: p_cam = R^T @ p + (-R^T @ t)
    rot_inv = rot.T
    p_cam = rot_inv @ point_3d + (-rot_inv @ trans)

    depth = p_cam[2]
    if depth <= 0:
        return None
    u = fx * p_cam[0] / depth + cx
    v = fy * p_cam[1] / depth + cy
    return int(round(u)), int(round(v))


def main():
    """Entry point: save per-camera overlay grids and an FOV sweep image."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--episode", type=int, default=0)
    parser.add_argument("--fy", type=float, default=150.0)
    parser.add_argument("--output", type=str, default="droid_testing/output_debug")
    args = parser.parse_args()

    # Pinhole intrinsics: assume square pixels and a centered principal point.
    fx = fy = args.fy
    cx, cy = IMG_W / 2, IMG_H / 2

    print(f"Loading episode {args.episode}...")
    ep = load_episode(args.episode)
    T = ep["num_frames"]
    print(f"  Frames: {T}, Instruction: {ep['language_instruction']}")

    # Sample the first, middle, and last frames of the trajectory.
    frames_to_show = [0, T // 2, T - 1]
    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)

    for cam_name in ["ext1", "ext2"]:
        ext = ep[f"{cam_name}_extrinsics"]
        images = ep[f"{cam_name}_images"]
        print(f"\n=== {cam_name} ===")
        print(f"  Extrinsics: {ext}")

        cam_pos, cam_quat, fovy = extrinsics_to_mj_camera(ext, fy, IMG_H)
        print(f"  Camera pos: {cam_pos}, fovy: {fovy:.1f}°")

        # Homogeneous camera->base transform for the projection sanity check.
        T_base_cam = np.eye(4)
        T_base_cam[:3, :3] = Rotation.from_euler("xyz", ext[3:6]).as_matrix()
        T_base_cam[:3, 3] = ext[:3]

        # Project robot base to image as sanity check
        base_px = project_point_to_image(np.array([0, 0, 0]), T_base_cam, fx, fy, cx, cy)
        print(f"  Robot base projects to pixel: {base_px}")

        model, data = build_model_with_camera(cam_pos, cam_quat, fovy)

        all_rows = []
        for idx in frames_to_show:
            if idx >= T:  # skip out-of-range indices (empty episode)
                continue
            set_joints(model, data, ep["joint_positions"][idx], ep["gripper_positions"][idx])
            rendered = render(model, data, IMG_W, IMG_H)
            real = images[idx]

            overlay, render_vis = make_overlay(real, rendered)

            # Label each overlay with camera name and frame index
            h, w = overlay.shape[:2]
            cv2.putText(overlay, f"{cam_name} f={idx}", (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 0), 2)

            # Show projected robot base on real image
            real_up = cv2.resize(real, (w, h), interpolation=cv2.INTER_NEAREST)
            if base_px is not None:
                # base_px is in native resolution; the display is upscaled by SCALE.
                bx, by = base_px[0] * SCALE, base_px[1] * SCALE
                cv2.circle(real_up, (bx, by), 8, (0, 0, 255), 2)
                cv2.putText(real_up, "base", (bx+10, by), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1)

            # One row per frame: [real | render | overlay]
            row = np.concatenate([real_up, render_vis, overlay], axis=1)
            all_rows.append(row)

        if all_rows:
            grid = np.concatenate(all_rows, axis=0)
            # Convert RGB to BGR for cv2
            cv2.imwrite(str(output_dir / f"debug_{cam_name}.png"),
                         cv2.cvtColor(grid, cv2.COLOR_RGB2BGR))
            print(f"  Saved: debug_{cam_name}.png")

    # Also create a FOV sweep for ext1 with frame 0, to eyeball the best
    # focal length. (The original version also rebuilt T_base_cam here,
    # but it was never used — removed as dead code.)
    print("\n=== FOV sweep (ext1, frame 0) ===")
    ext = ep["ext1_extrinsics"]
    images = ep["ext1_images"]

    fov_strips = []
    for test_fy in [100, 125, 150, 175, 200]:
        cam_pos, cam_quat, fovy = extrinsics_to_mj_camera(ext, test_fy, IMG_H)
        model, data = build_model_with_camera(cam_pos, cam_quat, fovy)
        set_joints(model, data, ep["joint_positions"][0], ep["gripper_positions"][0])
        rendered = render(model, data, IMG_W, IMG_H)
        overlay, _ = make_overlay(images[0], rendered)
        cv2.putText(overlay, f"fy={test_fy}", (10, 30),
                     cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 0), 2)
        fov_strips.append(overlay)

    if fov_strips:
        fov_grid = np.concatenate(fov_strips, axis=0)
        cv2.imwrite(str(output_dir / "fov_sweep.png"),
                     cv2.cvtColor(fov_grid, cv2.COLOR_RGB2BGR))
        print("  Saved: fov_sweep.png")


if __name__ == "__main__":
    main()
