"""Render panel (b) for fig2:

Single observer view of the LIBERO scene showing TWO camera frustums at
different poses, each with a ray unprojected from the camera through the
same GT 3D target. Annotates depth (different per camera) and a height
projection from the target down to the table/robot base (invariant across
the two views).

Output:
  /data/cameron/penpot/figures/extracted/fig2v3/two_frustums.png
  /data/cameron/penpot/figures/extracted/fig2v3/two_frustums_meta.json
    (contains projected 2D coords of key points so the SVG can overlay labels)

Usage:
  MUJOCO_GL=osmesa CUDA_VISIBLE_DEVICES=5 python render_fig2b_two_frustums.py
"""

import json
import os
import sys

import cv2
import h5py
import numpy as np
from scipy.spatial.transform import Rotation as R

from libero.libero.envs import OffScreenRenderEnv
from libero.libero import benchmark as bm_lib, get_libero_path
from robosuite.utils.camera_utils import (
    get_camera_transform_matrix,
    get_camera_extrinsic_matrix,
    get_camera_intrinsic_matrix,
    project_points_from_world_to_camera,
)

# Side length (pixels) of the square render; also passed to the intrinsics,
# projection and frustum helpers below.
IMAGE_SIZE = 896  # render at 2× resolution for crisper output

# Depth at which each policy camera's image plane (frustum base) is drawn.
IMAGE_PLANE_DEPTH_A = 0.18
IMAGE_PLANE_DEPTH_B = 0.22

# World-frame z treated as the table surface for the height projection.
TABLE_Z = 0.85

# Output locations: the rendered panel and the JSON with 2D label anchors.
OUT_DIR = "/data/cameron/penpot/figures/extracted/fig2v3"
OUT_PNG = OUT_DIR + "/two_frustums.png"
OUT_META = OUT_DIR + "/two_frustums_meta.json"


# ──────────────────────────────────────────────────────────────────────
# Helpers (minimal copies from generate_method_visualization.py)
# ──────────────────────────────────────────────────────────────────────

def compute_frustum_corners(camera_pose, cam_K, image_size, depth):
    """Return the four image-plane corners at `depth`, in world coordinates.

    Args:
        camera_pose: 4x4 camera-to-world matrix (rotation in [:3, :3],
            camera position in [:3, 3]).
        cam_K: 3x3 intrinsic matrix.
        image_size: side length of the (square) image in pixels.
        depth: camera-frame z at which the image plane is placed.

    Returns:
        (4, 3) array ordered as pixel corners (0, 0), (size, 0),
        (size, size), (0, size).
    """
    side = float(image_size)
    pixel_corners = np.array(
        [
            [0.0, 0.0, 1.0],
            [side, 0.0, 1.0],
            [side, side, 1.0],
            [0.0, side, 1.0],
        ],
        dtype=np.float64,
    )
    # Back-project all corners at once: each row becomes K^-1 @ pixel.
    rays_cam = pixel_corners @ np.linalg.inv(cam_K).T
    # Rescale every ray so its camera-frame z equals `depth`.
    rays_cam = rays_cam / rays_cam[:, 2:3] * depth
    rotation = camera_pose[:3, :3]
    origin = camera_pose[:3, 3]
    return rays_cam @ rotation.T + origin


def project_3d_to_2d(points_3d, world_to_camera, image_size, allow_out=True):
    """Project world-frame points to integer pixel coordinates ``(u, v)``.

    The projection is done directly with the 4x4 `world_to_camera` matrix
    instead of robosuite's ``project_points_from_world_to_camera``: that
    helper clamps its result to the image bounds, which made the
    `allow_out` check below a no-op and snapped off-frame endpoints onto
    the image border (so lines towards them were drawn in the wrong
    direction).

    Args:
        points_3d: array-like of shape (N, 3) (a single (3,) point is also
            accepted) in world coordinates.
        world_to_camera: 4x4 matrix mapping homogeneous world points to
            homogeneous pixel coordinates ``(u*z, v*z, z, ...)`` — the form
            this file obtains from ``get_camera_transform_matrix``.
        image_size: side length of the square image in pixels; only used
            for the bounds check when `allow_out` is False.
        allow_out: if True, pixels outside the image are returned
            unclipped; if False, they are replaced by None.

    Returns:
        List with one entry per input point: a ``(u, v)`` tuple of Python
        ints, or None when the point is rejected. Points with non-positive
        or non-finite projective depth (at/behind the camera plane) are
        always None because they have no meaningful pixel location.
    """
    pts = np.asarray(points_3d, dtype=np.float64).reshape(-1, 3)
    homogeneous = np.concatenate([pts, np.ones((len(pts), 1))], axis=1)
    projected = homogeneous @ np.asarray(world_to_camera, dtype=np.float64).T

    out = []
    for x, y, z in projected[:, :3]:
        # Guard against division by ~0 and mirrored behind-camera results.
        if not np.isfinite([x, y, z]).all() or z <= 1e-9:
            out.append(None)
            continue
        u = int(round(float(x / z)))
        v = int(round(float(y / z)))
        if allow_out or (0 <= u < image_size and 0 <= v < image_size):
            out.append((u, v))
        else:
            out.append(None)
    return out


def draw_line_3d(frame, p1, p2, w2c, image_size, color, thickness=2):
    """Draw the world-space segment p1→p2 onto `frame` (modified in place).

    Both endpoints are projected with `w2c`; nothing is drawn if either
    projection comes back as None. `color` is passed straight to OpenCV.
    """
    endpoints = np.array([p1, p2])
    start_px, end_px = project_3d_to_2d(endpoints, w2c, image_size, allow_out=True)
    if start_px is not None and end_px is not None:
        cv2.line(frame, start_px, end_px, color, thickness, cv2.LINE_AA)


def draw_circle_3d(frame, pt, w2c, image_size, color, radius=8, thickness=-1):
    """Draw a circle on `frame` (in place) centred on the projection of `pt`.

    `radius` and `thickness` are in pixels and forwarded to ``cv2.circle``
    (the default thickness of -1 produces a filled disc). Nothing is drawn
    if the projection comes back as None.
    """
    (center_px,) = project_3d_to_2d(np.array([pt]), w2c, image_size, allow_out=True)
    if center_px is not None:
        cv2.circle(frame, center_px, radius, color, thickness, cv2.LINE_AA)


def draw_frustum(frame, cam_origin, corners, w2c, image_size, color, thickness=3):
    """Draw a camera frustum wireframe onto `frame` (modified in place).

    Draws one edge from `cam_origin` to every entry of `corners`, then the
    closed quadrilateral through the first four corners in order.
    """
    def segment(start, end):
        draw_line_3d(frame, start, end, w2c, image_size, color, thickness)

    # Side edges: camera centre to each image-plane corner.
    for corner in corners:
        segment(cam_origin, corner)

    # Image-plane outline: 0→1→2→3→0.
    for i, j in ((0, 1), (1, 2), (2, 3), (3, 0)):
        segment(corners[i], corners[j])


# ──────────────────────────────────────────────────────────────────────
# Camera helpers
# ──────────────────────────────────────────────────────────────────────

def spherical_camera(target, azimuth_deg, elevation_deg, distance):
    """Return a look-at camera pose on a sphere centred on `target`.

    The camera is placed `distance` away from `target` at the given azimuth
    (angle in the world XY plane) and elevation (angle above that plane),
    both in degrees, and oriented so its -Z axis points at `target`, using
    world +Z as the up hint.

    Args:
        target: (3,) numpy array, the point the camera looks at.
        azimuth_deg: azimuth angle in degrees.
        elevation_deg: elevation angle in degrees.
        distance: radius of the sphere.

    Returns:
        (pos, quat): the camera position and its orientation quaternion in
        scalar-first (w, x, y, z) order.
    """
    az = np.radians(azimuth_deg)
    el = np.radians(elevation_deg)
    planar = distance * np.cos(el)  # radius of the offset projected onto XY
    offset = np.array([
        planar * np.cos(az),
        planar * np.sin(az),
        distance * np.sin(el),
    ])
    pos = target + offset

    # Orthonormal look-at basis.
    to_target = target - pos
    forward = to_target / np.linalg.norm(to_target)
    right = np.cross(forward, np.array([0.0, 0.0, 1.0]))
    if np.linalg.norm(right) < 1e-6:
        # Viewing direction is (anti)parallel to world up; pick a fixed right.
        right = np.array([1.0, 0.0, 0.0])
    right = right / np.linalg.norm(right)
    up = np.cross(right, forward)
    up = up / np.linalg.norm(up)

    # Camera axes as columns: x = right, y = up, z = -forward.
    basis = np.column_stack((right, up, -forward))
    # SciPy returns scalar-last (x, y, z, w); reorder to (w, x, y, z).
    qx, qy, qz, qw = R.from_matrix(basis).as_quat()
    return pos, np.array([qw, qx, qy, qz])


def get_frustum(sim, cam_id, cam_name, pos, quat, cam_K, image_size, depth):
    """Move a sim camera to (`pos`, `quat`) and return its pose and frustum.

    The camera identified by `cam_id` / `cam_name` is overwritten in the
    model and is NOT restored afterwards; callers must reset it themselves
    if they need the previous pose.

    Returns:
        (camera_pose, corners): the extrinsic matrix reported by robosuite
        for the new pose, and the world-frame image-plane corners at
        `depth` from `compute_frustum_corners`.
    """
    model = sim.model
    model.cam_pos[cam_id] = pos
    model.cam_quat[cam_id] = quat
    sim.forward()  # propagate the new model pose before querying it
    camera_pose = get_camera_extrinsic_matrix(sim, cam_name)
    return camera_pose, compute_frustum_corners(camera_pose, cam_K, image_size, depth)


# ──────────────────────────────────────────────────────────────────────
# Main
# ──────────────────────────────────────────────────────────────────────

def _first_grasp_frame(gripper_actions, default):
    """Return the first index where the gripper command flips from <0 to >0.

    Falls back to `default` when no such transition exists.
    """
    for t in range(1, len(gripper_actions)):
        if gripper_actions[t - 1] < 0 and gripper_actions[t] > 0:
            return t
    return default


def _clean_scene(sim):
    """Best-effort removal of clutter from the scene (edits the model in place).

    Two named bodies are moved to z = -5, and every geom of three other
    named bodies is made fully transparent. A body that cannot be processed
    is reported and skipped rather than aborting the render.
    """
    for name in ["wooden_cabinet_1_main", "flat_stove_1_main"]:
        try:
            bid = sim.model.body_name2id(name)
            sim.model.body_pos[bid] = np.array([0, 0, -5.0])
        except Exception as exc:
            print(f"   (skipped moving {name}: {exc})")
    for name in [
        "akita_black_bowl_2_main", "cookies_1_main",
        "glazed_rim_porcelain_ramekin_1_main",
    ]:
        try:
            bid = sim.model.body_name2id(name)
            for gid in range(sim.model.ngeom):
                if sim.model.geom_bodyid[gid] == bid:
                    sim.model.geom_rgba[gid][3] = 0.0
        except Exception as exc:
            print(f"   (skipped hiding {name}: {exc})")
    sim.forward()


def _render_panel(env, benchmark):
    """Render the observer view with both frustums and write PNG + metadata.

    Expects `env` to be seeded and reset. Temporarily repurposes the
    "agentview" camera for the two policy cameras and the observer, and
    restores its original pose before returning.

    Raises:
        RuntimeError: if OpenCV reports that the PNG could not be written.
    """
    sim = env.env.sim

    print("[2/5] load demo + clean scene…")
    demo_path = os.path.join(get_libero_path("datasets"), benchmark.get_task_demonstration(0))
    with h5py.File(demo_path, "r") as f:
        states = np.array(f["data/demo_0/states"])
        actions = np.array(f["data/demo_0/actions"])
    viz_state = states[20].copy()

    # Find grasp frame (first gripper close transition); action column 6 is
    # read as the gripper command.
    grasp_frame = _first_grasp_frame(actions[:, 6], default=20)
    print(f"   grasp frame = {grasp_frame}")

    # Apply state + clean scene
    env.set_init_state(viz_state)
    sim.forward()
    _clean_scene(sim)

    # Re-apply viz state so distractor cleanup sticks
    env.set_init_state(viz_state)
    sim.forward()
    env.env._get_observations()

    # Extract GT target (end-effector position at the grasp frame)
    env.set_init_state(states[grasp_frame])
    sim.forward()
    obs_grasp = env.env._get_observations()
    grasp_eef = np.array(obs_grasp["robot0_eef_pos"], dtype=np.float64)
    print(f"   grasp_eef = {grasp_eef}")

    env.set_init_state(viz_state)
    sim.forward()

    # Camera intrinsics (same for both policy cameras and observer, since
    # all three reuse the one "agentview" camera)
    cam_name = "agentview"
    cam_id = sim.model.camera_name2id(cam_name)
    cam_K = get_camera_intrinsic_matrix(sim, cam_name, IMAGE_SIZE, IMAGE_SIZE)

    # Save original agentview pose for restoring later
    orig_cam_pos = sim.model.cam_pos[cam_id].copy()
    orig_cam_quat = sim.model.cam_quat[cam_id].copy()

    print("[3/5] build two policy cameras + observer…")
    # Both policy cameras look directly at the grasp point, so each
    # camera-to-target distance below equals the sphere radius used here.
    look_at = grasp_eef.copy()

    # Camera A: close, low elevation (the shorter of the two depths)
    cam_A_pos, cam_A_quat = spherical_camera(
        target=look_at,
        azimuth_deg=-135,
        elevation_deg=18,
        distance=0.38,
    )
    # Camera B: roughly the opposite azimuth, much higher elevation, and
    # farther from the target than camera A.
    cam_B_pos, cam_B_quat = spherical_camera(
        target=look_at,
        azimuth_deg=55,
        elevation_deg=55,
        distance=0.50,
    )

    depth_A = float(np.linalg.norm(grasp_eef - cam_A_pos))
    depth_B = float(np.linalg.norm(grasp_eef - cam_B_pos))
    height_z = float(grasp_eef[2] - TABLE_Z)
    print(f"   depth_A = {depth_A:.3f} m")
    print(f"   depth_B = {depth_B:.3f} m")
    print(f"   height  = {height_z:.3f} m")

    # Compute frustums for each. get_frustum leaves the sim camera at the
    # queried pose; it is overwritten with the observer pose below.
    _, frustum_A = get_frustum(sim, cam_id, cam_name, cam_A_pos, cam_A_quat,
                               cam_K, IMAGE_SIZE, IMAGE_PLANE_DEPTH_A)
    _, frustum_B = get_frustum(sim, cam_id, cam_name, cam_B_pos, cam_B_quat,
                               cam_K, IMAGE_SIZE, IMAGE_PLANE_DEPTH_B)

    # ── Observer camera: aimed at the centroid of both cameras + target ──
    midpoint = (cam_A_pos + cam_B_pos + grasp_eef) / 3.0
    observer_pos, observer_quat = spherical_camera(
        target=midpoint,
        azimuth_deg=-50,
        elevation_deg=30,
        distance=1.5,
    )

    print("[4/5] render observer view…")
    sim.model.cam_pos[cam_id] = observer_pos
    sim.model.cam_quat[cam_id] = observer_quat
    env.set_init_state(viz_state)
    sim.forward()
    obs_render = env.env._get_observations()
    # The observation image is flipped vertically before drawing so that it
    # matches the pixel convention of the projection matrix below.
    rgb = np.flipud(np.asarray(obs_render["agentview_image"]).copy())
    frame = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
    w2c = get_camera_transform_matrix(sim, cam_name, IMAGE_SIZE, IMAGE_SIZE)

    print("[5/5] draw frustums + rays + target…")
    # BGR colors
    COL_A = (235, 130, 40)     # blue-ish   (RGB #2882eb)
    COL_B = (28, 100, 235)     # orange-ish (RGB #eb641c)
    COL_TARGET = (60, 200, 60) # 3D target — green
    COL_HEIGHT = (40, 200, 80) # height projection line — bright green
    COL_BASE = (50, 45, 215)   # base-on-table keypoint — red (RGB #d72d32)

    # Frustums
    draw_frustum(frame, cam_A_pos, frustum_A, w2c, IMAGE_SIZE, COL_A, thickness=4)
    draw_frustum(frame, cam_B_pos, frustum_B, w2c, IMAGE_SIZE, COL_B, thickness=4)

    # Rays from each camera to the target
    draw_line_3d(frame, cam_A_pos, grasp_eef, w2c, IMAGE_SIZE, COL_A, thickness=5)
    draw_line_3d(frame, cam_B_pos, grasp_eef, w2c, IMAGE_SIZE, COL_B, thickness=5)

    # Height projection line from target straight down to table (along -Z)
    target_on_table = grasp_eef.copy()
    target_on_table[2] = TABLE_Z
    draw_line_3d(frame, grasp_eef, target_on_table, w2c, IMAGE_SIZE, COL_HEIGHT, thickness=5)

    # Camera position markers (filled disc with a white ring)
    for pt, col in [(cam_A_pos, COL_A), (cam_B_pos, COL_B)]:
        draw_circle_3d(frame, pt, w2c, IMAGE_SIZE, col, radius=14, thickness=-1)
        draw_circle_3d(frame, pt, w2c, IMAGE_SIZE, (255, 255, 255), radius=15, thickness=3)

    # GT 3D target — smaller, outline-only so the bowl is visible through it
    draw_circle_3d(frame, grasp_eef, w2c, IMAGE_SIZE, (255, 255, 255), radius=13, thickness=2)
    draw_circle_3d(frame, grasp_eef, w2c, IMAGE_SIZE, COL_TARGET, radius=11, thickness=3)

    # Base-on-table tick — RED outline to clearly differentiate from the green target
    draw_circle_3d(frame, target_on_table, w2c, IMAGE_SIZE, (255, 255, 255), radius=10, thickness=2)
    draw_circle_3d(frame, target_on_table, w2c, IMAGE_SIZE, COL_BASE, radius=8, thickness=3)

    # ── Save PNG + metadata ──
    os.makedirs(OUT_DIR, exist_ok=True)
    # cv2.imwrite signals failure through its return value; do not report
    # success (or write metadata for a missing image) if it failed.
    if not cv2.imwrite(OUT_PNG, frame):
        raise RuntimeError(f"cv2.imwrite failed to write {OUT_PNG}")
    print(f"   wrote {OUT_PNG}")

    # Project key points into the observer view so the SVG layer can drop
    # depth/height labels in the right place
    key_3d = [cam_A_pos, cam_B_pos, grasp_eef, target_on_table]
    key_2d = project_3d_to_2d(np.array(key_3d), w2c, IMAGE_SIZE, allow_out=True)

    def mid(a, b):
        return (int(round((a[0] + b[0]) / 2)), int(round((a[1] + b[1]) / 2)))

    meta = {
        "image_size": IMAGE_SIZE,
        "depth_A": depth_A,
        "depth_B": depth_B,
        "height_z": height_z,
        "cam_A_2d": list(key_2d[0]) if key_2d[0] else None,
        "cam_B_2d": list(key_2d[1]) if key_2d[1] else None,
        "target_2d": list(key_2d[2]) if key_2d[2] else None,
        "target_on_table_2d": list(key_2d[3]) if key_2d[3] else None,
        "ray_A_mid_2d": mid(key_2d[0], key_2d[2]) if (key_2d[0] and key_2d[2]) else None,
        "ray_B_mid_2d": mid(key_2d[1], key_2d[2]) if (key_2d[1] and key_2d[2]) else None,
        "height_mid_2d": mid(key_2d[2], key_2d[3]) if (key_2d[2] and key_2d[3]) else None,
    }
    with open(OUT_META, "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2)
    print(f"   wrote {OUT_META}")

    # Restore camera for a clean shutdown
    sim.model.cam_pos[cam_id] = orig_cam_pos
    sim.model.cam_quat[cam_id] = orig_cam_quat


def main():
    """Build the LIBERO scene, render the two-frustum panel, and save it.

    Writes OUT_PNG and OUT_META. The environment is always closed, even if
    rendering or saving raises.
    """
    print("[1/5] init LIBERO…")
    benchmark = bm_lib.get_benchmark_dict()["libero_spatial"]()
    task = benchmark.get_task(0)
    bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)
    env = OffScreenRenderEnv(
        bddl_file_name=bddl_file,
        camera_heights=IMAGE_SIZE,
        camera_widths=IMAGE_SIZE,
        camera_names=["agentview"],
    )
    try:
        env.seed(0)
        env.reset()
        _render_panel(env, benchmark)
    finally:
        # Release the offscreen renderer on failure as well as on success.
        env.close()
    print("done.")


# Run the renderer only when this file is executed as a script, so that
# importing the module (e.g. to reuse the helpers) has no side effects.
if __name__ == "__main__":
    main()
