"""YAM bimanual silhouette overlay renderer.

Renders both arms' silhouettes with MuJoCo from a single one-arm MJCF: for each
arm, a posed camera is injected into that MJCF, a fresh model is compiled, and
qpos is set from that arm's joints. For the right arm, the camera pose is
brought into right_arm_base coordinates via inv(T_left_from_right).

Output: a side-by-side panel (raw | overlay) per requested frame.
"""
import os
# Default MuJoCo to headless EGL rendering unless the caller already chose a GL
# backend. This runs before `import mujoco` below because the backend is picked
# up when mujoco is imported.
os.environ.setdefault("MUJOCO_GL", "egl")

import argparse
import json
import pickle
from pathlib import Path

import cv2
import mujoco
import numpy as np

# Use the combined arm+gripper XML that raiden itself uses — this is the
# canonical kinematic model for this YAM + linear-4310 gripper rig.
from raiden._xml_paths import get_yam_4310_linear_xml_path

# Path to the combined single-arm YAM + linear-4310 gripper MJCF, resolved by raiden.
YAM_XML = get_yam_4310_linear_xml_path()
# Root of processed recordings. Layout used below:
#   <DATA_ROOT>/<task>/<episode>/lowdim/<frame:010d>.pkl
#   <DATA_ROOT>/<task>/<episode>/rgb/<cam>/<frame:010d>.{jpg,png,jpeg}
DATA_ROOT = Path("/home/robot-lab/data/processed")


def load_assets(xml_path):
    """Return the in-memory MuJoCo assets dict needed to compile *xml_path*.

    The combined YAM XML is expected to reference its mesh files by absolute
    path, so MuJoCo loads them directly from disk and no in-memory assets are
    required. This therefore always returns a new, empty dict; it is kept so
    code that expects an assets mapping keeps working.

    Args:
        xml_path: path to the MJCF. Accepted for interface compatibility; it is
            not read (avoids file I/O at module import time).

    Returns:
        An empty ``dict``.
    """
    return {}


# Mesh-asset mapping for the YAM model. load_assets always returns an empty
# dict, and render_arm_mask compiles the model without passing any assets.
YAM_ASSETS = load_assets(YAM_XML)


def fovy_from_K(K, H):
    """Vertical field of view, in degrees, of a pinhole camera.

    Only the vertical focal length ``K[1, 1]`` (pixels) and the image height
    *H* (pixels) are used: ``fovy = 2 * atan(H / (2 * fy))``. The principal
    point and ``fx`` are ignored.
    """
    focal_y = float(K[1, 1])
    half_angle = np.arctan(H / (2 * focal_y))
    return float(np.degrees(half_angle) * 2)


def build_xml_with_camera(xml_path, cam_pos, cam_rot_cv, fovy, W, H):
    """Return the MJCF text at *xml_path* with a posed camera and offscreen size injected.

    Args:
        xml_path: path to the MJCF file to read.
        cam_pos: camera position (3,) in the model's world frame.
        cam_rot_cv: (3, 3) camera-to-world rotation in the OpenCV convention
            (columns are the camera's x-right, y-down, z-forward axes).
        fovy: vertical field of view in degrees.
        W, H: offscreen framebuffer width/height in pixels.

    Returns:
        The modified XML string. It contains exactly one camera named
        ``"posed"``, added inside the last ``<worldbody>``, and a
        ``<visual><global offwidth offheight/></visual>`` section appended just
        before ``</mujoco>``. Being last, this section overrides any smaller
        offscreen size declared earlier in the file.

    Raises:
        ValueError: if the XML has no ``</worldbody>`` or ``</mujoco>`` tag to
            anchor the injections (previously this silently produced an XML
            without the camera / buffer size).
    """
    # OpenCV cameras look along +Z with +Y down; MuJoCo cameras look along -Z
    # with +Y up. Flipping the Y and Z axes converts one to the other.
    opencv_to_mj = np.diag([1, -1, -1])
    cam_rot_mj = np.asarray(cam_rot_cv) @ opencv_to_mj
    cam_x, cam_y = cam_rot_mj[:, 0], cam_rot_mj[:, 1]
    cam_str = (
        f'<camera name="posed" pos="{cam_pos[0]:.6f} {cam_pos[1]:.6f} {cam_pos[2]:.6f}" '
        f'xyaxes="{cam_x[0]:.6f} {cam_x[1]:.6f} {cam_x[2]:.6f} '
        f'{cam_y[0]:.6f} {cam_y[1]:.6f} {cam_y[2]:.6f}" '
        f'fovy="{fovy:.4f}"/>'
    )
    visual_str = f'<visual><global offwidth="{W}" offheight="{H}"/></visual>'
    xml = Path(xml_path).read_text()

    # Insert the camera exactly once. MJCF allows repeated <worldbody> sections,
    # and replacing every closing tag would create duplicate "posed" cameras,
    # which fails model compilation.
    wb_end = xml.rfind("</worldbody>")
    if wb_end < 0:
        raise ValueError(f"{xml_path}: no </worldbody> to attach the posed camera to")
    xml = f"{xml[:wb_end]}  {cam_str}\n  {xml[wb_end:]}"

    # Anchor on the root closing tag rather than on a literal "<asset>", which
    # may be absent. Without the offscreen size, mujoco.Renderer rejects
    # frames larger than the default buffer.
    root_end = xml.rfind("</mujoco>")
    if root_end < 0:
        raise ValueError(f"{xml_path}: no </mujoco> closing tag found")
    xml = f"{xml[:root_end]}  {visual_str}\n{xml[root_end:]}"
    return xml


def render_arm_mask(joints7, T_cam2base, K, W, H):
    """Render one arm + gripper (placed at the model origin) from a given camera pose.

    Args:
        joints7: sequence of up to 7 values. The first 6 are arm joint
            positions, written to qpos[0:6]. The 7th is the normalized gripper
            opening (clipped to [0, 1]).
        T_cam2base: homogeneous camera-to-base transform, OpenCV camera
            convention; rotation is read from [:3, :3], translation from [:3, 3].
        K: camera intrinsics; only ``K[1, 1]`` (fy) is used, to derive fovy.
        W, H: output image width/height in pixels.

    Returns:
        (mask, depth): ``mask`` is an (H, W) uint8 array that is 1 where the
        robot is rendered and 0 elsewhere. ``depth`` is the depth image returned
        by ``mujoco.Renderer``; background pixels sit at the far clipping plane.
    """
    cam_pos = T_cam2base[:3, 3]
    cam_rot = T_cam2base[:3, :3]
    fovy = fovy_from_K(K, H)
    xml = build_xml_with_camera(YAM_XML, cam_pos, cam_rot, fovy, W, H)
    # The combined XML references meshes by absolute path on disk, so no assets dict needed.
    model = mujoco.MjModel.from_xml_string(xml)
    data = mujoco.MjData(model)
    # Arm joints: qpos[0..5] from joints7[0..5].
    arm_n = min(6, len(joints7))
    data.qpos[:arm_n] = joints7[:arm_n]
    # Gripper (joint7 = qpos[6], joint8 = qpos[7]). raiden's gripper command is
    # normalized ~[0, 1] (FOLLOWER_HOME_POS has gripper=1.0 = fully open). The MJCF
    # finger sliders have range [0, 0.0475] m and are coupled by an <equality>,
    # but mj_forward doesn't enforce equality constraints — so we drive both
    # sliders manually and scale to the slider's physical range.
    if len(joints7) >= 7 and model.nq >= 8:
        slider_m = float(np.clip(joints7[6], 0.0, 1.0)) * 0.0475
        data.qpos[6] = slider_m
        data.qpos[7] = slider_m
    mujoco.mj_forward(model, data)
    cam_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_CAMERA, "posed")
    if cam_id < 0:
        # Without this check, camera=-1 would silently render from the free camera.
        raise ValueError("injected camera 'posed' not found in compiled model")
    renderer = mujoco.Renderer(model, height=H, width=W)
    try:
        renderer.enable_depth_rendering()
        renderer.update_scene(data, camera=cam_id)
        depth = renderer.render().copy()
    finally:
        # Release the GL context even if rendering fails.
        renderer.close()
    # mujoco.Renderer linearizes depth using far = zfar * extent, so empty
    # (background) pixels come back at that far plane. Using the model's far
    # plane instead of depth.max() keeps the mask correct when the robot fills
    # the whole frame (there would be no background pixels to define "max").
    far = float(model.vis.map.zfar * model.stat.extent)
    mask = (depth < far * 0.999).astype(np.uint8)
    return mask, depth


def render_bimanual(joints14, T_cam2world, T_left_from_right, K, W, H):
    """Render silhouette masks for both arms as seen by a single camera.

    Args:
        joints14: two 7-value blocks (6 arm joints + gripper each): indices
            0-6 are the left arm, 7-13 the right arm.
        T_cam2world: camera pose used directly for the left arm, whose model is
            rendered at the origin (i.e. the world frame is the left-arm base).
        T_left_from_right: transform relating the right-arm base frame to the
            left-arm base frame; its inverse re-expresses the camera pose in the
            right-arm base frame.
        K, W, H: intrinsics and image size, passed through to render_arm_mask.

    Returns:
        (mask_l, mask_r): the left and right arm masks; depths are discarded.
    """
    left_joints = joints14[:7]
    right_joints = joints14[7:14]

    left_mask = render_arm_mask(left_joints, T_cam2world, K, W, H)[0]

    T_right_from_left = np.linalg.inv(T_left_from_right)
    right_cam_pose = T_right_from_left @ T_cam2world
    right_mask = render_arm_mask(right_joints, right_cam_pose, K, W, H)[0]

    return left_mask, right_mask


def overlay_contours(img, mask_l, mask_r):
    """Overlay translucent fills and solid outlines of both arm masks on *img*.

    Left arm = green, right arm = red (BGR colors). Each fill is alpha-blended
    at 25% onto the running result, so an earlier arm's overlay is not faded
    out by the next arm's blend. Outlines (external contours, 2 px) are drawn
    last at full opacity, so no blend washes them out. ``None`` or all-zero
    masks are skipped. *img* itself is not modified.

    Args:
        img: (H, W, 3) BGR image.
        mask_l, mask_r: (H, W) masks (any dtype; nonzero = arm) or None.

    Returns:
        A new (H, W, 3) image with the overlays.
    """
    out = img.copy()
    outlines = []
    for mask, color in ((mask_l, (0, 255, 0)), (mask_r, (0, 0, 255))):
        if mask is None:
            continue
        region = np.asarray(mask) > 0
        if not region.any():
            continue
        # Fill straight from the mask (respects holes, e.g. between the gripper
        # fingers) and blend against the current output, not the raw image.
        fill = out.copy()
        fill[region] = color
        out = cv2.addWeighted(out, 0.75, fill, 0.25, 0)
        # findContours requires a single-channel uint8 image.
        contours, _ = cv2.findContours(
            region.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )
        outlines.append((contours, color))
    for contours, color in outlines:
        cv2.drawContours(out, contours, -1, color, 2)
    return out


def find_scene_cam(cams):
    """Pick the scene camera from an iterable of camera names.

    Preference order (matching is case-insensitive):
      1. the first name containing "scene" and not containing "wrist";
      2. otherwise the first name not containing "wrist";
      3. otherwise the first name, or None if *cams* is empty.
    """
    names = list(cams)
    non_wrist = [name for name in names if "wrist" not in name.lower()]
    scene = [name for name in non_wrist if "scene" in name.lower()]
    if scene:
        return scene[0]
    if non_wrist:
        return non_wrist[0]
    return names[0] if names else None


def list_episodes(task_dir):
    """Return the episode directories directly under *task_dir*, sorted by path.

    An entry counts as an episode when it is a directory that contains a
    ``lowdim`` subdirectory.
    """
    def _is_episode(entry):
        return entry.is_dir() and (entry / "lowdim").is_dir()

    return sorted(filter(_is_episode, task_dir.iterdir()))


def count_frames(ep_dir):
    """Return ``(count, paths)`` for the ``<ep_dir>/lowdim/*.pkl`` files.

    ``paths`` is sorted by path, so zero-padded frame names come out in frame
    order, and ``count == len(paths)``. A missing ``lowdim`` directory yields
    ``(0, [])``.
    """
    frame_files = sorted((ep_dir / "lowdim").glob("*.pkl"))
    count = len(frame_files)
    return count, frame_files


def find_rgb_path(ep_dir, cam, frame_int):
    """Locate the RGB image for frame *frame_int* of camera *cam*.

    Looks for ``<ep_dir>/rgb/<cam>/<frame_int:010d>`` with the extensions
    ``.jpg``, ``.png``, ``.jpeg``, in that order. Returns the first path that
    exists, or None if there is none.
    """
    cam_dir = ep_dir / "rgb" / cam
    stem = format(frame_int, "010d")
    candidates = (cam_dir / (stem + ext) for ext in (".jpg", ".png", ".jpeg"))
    return next((path for path in candidates if path.exists()), None)


def render_one(task, ep, cam=None, frame=None, out_path=None, label=None):
    """Build a side-by-side (raw | overlay) panel for one frame of one episode.

    Args:
        task: task directory name under DATA_ROOT.
        ep: episode directory name under the task.
        cam: camera name. Defaults to find_scene_cam() over the frame's
            intrinsics keys.
        frame: integer frame index. Defaults to ``n_frames // 2``, which assumes
            the lowdim pickles are named by contiguous zero-padded indices.
        out_path: if given, the panel is written there (parent dirs created).
        label: text drawn on the panel. Defaults to a task/ep/cam/frame summary.

    Returns:
        ``(panel, label)`` on success. ``(None, error_message)`` on an expected
        failure: no frames, missing pickle, no cameras, missing/undecodable
        RGB, missing extrinsics/intrinsics, or a failed image write.
    """
    ep_dir = DATA_ROOT / task / ep
    n_frames, _ = count_frames(ep_dir)
    if n_frames == 0:
        return None, f"no frames in {ep_dir}"
    if frame is None:
        frame = n_frames // 2
    pkl_path = ep_dir / "lowdim" / f"{frame:010d}.pkl"
    if not pkl_path.is_file():
        # Report an explicit/out-of-range frame the same way as other failures.
        return None, f"missing lowdim pickle {pkl_path}"
    # NOTE(review): pickle.load can execute arbitrary code; only point this at
    # trusted, locally recorded data.
    with open(pkl_path, "rb") as f:
        fd = pickle.load(f)
    if cam is None:
        # Read camera list from the frame's intrinsics dict — works whether or not
        # metadata_shared.json exists at the task root.
        cam = find_scene_cam(fd["intrinsics"].keys())
        if cam is None:
            # Otherwise ep_dir / "rgb" / None would raise TypeError below.
            return None, f"no cameras in intrinsics of {pkl_path}"
    rgb_path = find_rgb_path(ep_dir, cam, frame)
    if rgb_path is None:
        return None, f"missing rgb for cam {cam}"
    joints = np.array(fd["joints"])
    T_lfr = np.array(fd["T_left_from_right"])
    extr = fd["extrinsics"].get(cam)
    intr = fd["intrinsics"].get(cam)
    if extr is None or intr is None:
        return None, f"no extr/intr for cam {cam}"
    K = np.array(intr)
    T_cam2world = np.array(extr)
    img = cv2.imread(str(rgb_path))
    if img is None:
        # cv2.imread signals unreadable/corrupt files by returning None.
        return None, f"could not decode image {rgb_path}"
    H, W = img.shape[:2]
    mask_l, mask_r = render_bimanual(joints, T_cam2world, T_lfr, K, W, H)
    overlay = overlay_contours(img, mask_l, mask_r)
    # Side-by-side panel: raw on the left, overlay on the right, thin divider.
    panel = np.concatenate([img, overlay], axis=1)
    cv2.line(panel, (W, 0), (W, H), (40, 40, 40), 1)
    lbl = label or f"{task}/{ep} cam={cam} f={frame}/{n_frames}"
    cv2.putText(panel, lbl, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2, cv2.LINE_AA)
    if out_path:
        out_file = Path(out_path)
        out_file.parent.mkdir(parents=True, exist_ok=True)
        # cv2.imwrite returns False instead of raising (e.g. unknown extension).
        if not cv2.imwrite(str(out_file), panel):
            return None, f"failed to write {out_file}"
    return panel, lbl


def main(argv=None):
    """Command-line entry point: render one panel and report the outcome.

    Args:
        argv: argument list to parse. ``None`` (the default) means
            ``sys.argv[1:]``, so existing ``main()`` callers are unaffected.

    Returns:
        0 on success, 1 if rendering failed (the reason is printed).
    """
    p = argparse.ArgumentParser()
    p.add_argument("--task", required=True)
    p.add_argument("--ep", default="0000")
    p.add_argument("--cam", default=None, help="defaults to the scene cam listed in the frame's intrinsics")
    p.add_argument("--frame", type=int, default=None, help="defaults to middle frame")
    p.add_argument("--out", required=True)
    args = p.parse_args(argv)
    panel, label = render_one(args.task, args.ep, args.cam, args.frame, args.out)
    if panel is None:
        print(f"FAIL: {label}")
        return 1
    print(f"OK: saved {args.out}  ({label})")
    return 0


if __name__ == "__main__":
    # Use main()'s return value (0 = OK, 1 = render failure) as the process exit status.
    raise SystemExit(main())
