"""Closed-loop libero 2view rollout viz → video.

Per replan step, records: BEV RGB, wrist RGB, F_bev, F_wrist, the volume logits, and the
predicted N_WINDOW 3D positions (projected into both views).
After the rollout, computes a joint 3-component PCA basis over all F_bev/F_wrist features from
all frames, then renders each frame as a 1×5 panel and stitches the frames into an mp4.

Layout per frame (left→right):
  [BEV + pred trajectory (color gradient)] [wrist + pred trajectory] [F_bev PCA] [F_wrist PCA]
  [BEV + volume heatmap for predicted timestep 0 (max over Z)]

Usage:
    python viz_rollout_2view_video.py \\
        --checkpoint /data/cameron/para/libero/checkpoints/libero_2view_v0/latest.pth \\
        --cam_theta 14 --cam_phi 45 --max_steps 300 \\
        --out_path /tmp/rollout_2v_th14_ph45.mp4
"""
import os, sys, argparse
sys.path.insert(0, "/data/cameron/para/libero")
sys.path.insert(0, "/data/cameron/LIBERO")

import numpy as np
import torch
import cv2
import h5py
from pathlib import Path
from scipy.spatial.transform import Rotation as ScipyR

os.environ.setdefault("DINO_REPO_DIR",     "/data/cameron/keygrip/dinov3")
os.environ.setdefault("DINO_WEIGHTS_PATH", "/data/cameron/keygrip/dinov3/weights/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth")
os.environ.setdefault("MUJOCO_GL", "osmesa")
os.environ.setdefault("PYOPENGL_PLATFORM", "osmesa")

from libero.libero import benchmark as bm, get_libero_path
from libero.libero.envs import OffScreenRenderEnv
from robosuite.utils.camera_utils import (
    get_camera_extrinsic_matrix, get_camera_intrinsic_matrix, get_camera_transform_matrix,
)
from model_dino_volume_query_2view import (DinoVolumeQuery2View, PRED_SIZE,
                                            build_bev_world_xyz_table)
from utils import recover_3d_from_direct_keypoint_and_height
from eval import preprocess_obs, eef_to_start_kp
from eval_libero_2view_ood import (
    shift_init_state, hide_distractors_in_state, apply_clean_scene,
    servo_to_pos, reposition_camera, decode_actions,
)

# Fallback camera/model image size in pixels, used when the checkpoint has no "image_size" entry.
LIBERO_IMG = 448


def project_world_to_pixel(xyz, K_pixel, extrinsic):
    """Project a single world-frame point into a pinhole camera image.

    Args:
        xyz: world-frame point, indexable as (x, y, z).
        K_pixel: (3, 3) camera intrinsics in pixel units.
        extrinsic: (4, 4) camera-to-world pose; it is inverted here to map
            world coordinates into the camera frame.

    Returns:
        ``(u, v)`` pixel coordinates as Python floats, or ``None`` when the
        point's camera-frame depth is <= 1e-3 (behind / at the camera).
    """
    x, y, z = xyz[0], xyz[1], xyz[2]
    cam_pt = np.linalg.inv(extrinsic) @ np.array([x, y, z, 1.0])
    depth = cam_pt[2]
    if depth <= 1e-3:
        return None
    u, v, _ = K_pixel @ (cam_pt[:3] / depth)
    return (float(u), float(v))


def feature_pca_image(F, pca_components, mean, hw_out):
    """Render a (d, H, W) feature map as an RGB image via a fixed PCA projection.

    Each pixel's feature vector is mean-subtracted and projected onto the rows
    of ``pca_components``. The projection is min-max normalized with a SINGLE
    min/max taken over all pixels and all three channels of this one frame
    (not per axis, and not across frames), scaled to uint8, and resized to
    ``hw_out`` x ``hw_out`` with nearest-neighbour interpolation.

    Args:
        F: (d, H, W) feature map, either a torch tensor (any device, with or
            without grad) or a numpy array.
        pca_components: (3, d) PCA basis; one row per output color channel.
        mean: (d,) feature mean subtracted before projection.
        hw_out: side length of the square output image in pixels.

    Returns:
        (hw_out, hw_out, 3) uint8 image.
    """
    # Accept numpy input as well as tensors that live on GPU or require grad,
    # all of which a bare ``.numpy()`` call would reject.
    if isinstance(F, torch.Tensor):
        F = F.detach().cpu().numpy()
    F = np.asarray(F)
    d, H, W = F.shape
    flat = F.reshape(d, -1).T - mean[None, :]
    proj = flat @ pca_components.T  # (H*W, 3)
    img = proj.reshape(H, W, 3)
    # One global min/max for this frame; epsilon avoids division by zero on constant maps.
    lo, hi = img.min(), img.max()
    img = ((img - lo) / (hi - lo + 1e-8) * 255).astype(np.uint8)
    if (H, W) != (hw_out, hw_out):
        # Resizing to the same size is an identity copy, so it is only done when needed.
        img = cv2.resize(img, (hw_out, hw_out), interpolation=cv2.INTER_NEAREST)
    return img


def draw_trajectory(img, pix_array, color_start=(0, 0, 255), color_end=(0, 255, 0)):
    """Overlay a predicted pixel trajectory on ``img`` (drawn in place; ``img`` is returned).

    Point k of the trajectory becomes a filled radius-6 circle with a black
    outline, colored by linear interpolation from ``color_start`` (first
    point) to ``color_end`` (last point). Consecutive points are joined by a
    1-px segment in the later point's color. Points outside the image are
    skipped, as is any segment that has an out-of-bounds endpoint.
    """
    height, width = img.shape[0], img.shape[1]
    n_pts = len(pix_array)
    denom = max(n_pts - 1, 1)

    def inside(px, py):
        return 0 <= px < width and 0 <= py < height

    for idx, (u, v) in enumerate(pix_array):
        if not inside(u, v):
            continue
        frac = idx / denom
        color = tuple(
            int(color_start[c] * (1 - frac) + color_end[c] * frac) for c in range(3)
        )
        center = (int(u), int(v))
        cv2.circle(img, center, 6, color, -1)
        cv2.circle(img, center, 6, (0, 0, 0), 1)
        if idx == 0:
            continue
        prev_u, prev_v = pix_array[idx - 1]
        if inside(prev_u, prev_v):
            cv2.line(img, (int(prev_u), int(prev_v)), center, color, 1)
    return img


def _parse_args():
    """Parse the command-line options of the rollout visualizer."""
    p = argparse.ArgumentParser()
    p.add_argument("--checkpoint", type=str, required=True)
    p.add_argument("--benchmark", default="libero_spatial")
    p.add_argument("--task_id", type=int, default=0)
    p.add_argument("--ep_idx", type=int, default=0)
    p.add_argument("--max_steps", type=int, default=300)
    p.add_argument("--shift_dx", type=float, default=0.0)
    p.add_argument("--shift_dy", type=float, default=0.0)
    # These toggles default to on. A store_true flag with default=True can never
    # be switched off, so each one gets an explicit --no_* negation that writes
    # to the same destination.
    p.add_argument("--clean_scene", action="store_true", default=True)
    p.add_argument("--no_clean_scene", dest="clean_scene", action="store_false")
    p.add_argument("--zero_rotation", action="store_true", default=True)
    p.add_argument("--no_zero_rotation", dest="zero_rotation", action="store_false")
    p.add_argument("--teleport", action="store_true", default=True)
    p.add_argument("--no_teleport", dest="teleport", action="store_false")
    p.add_argument("--cam_theta", type=float, default=0.0)
    p.add_argument("--cam_phi", type=float, default=0.0)
    p.add_argument("--out_path", type=str, default="/tmp/rollout_2v.mp4")
    p.add_argument("--fps", type=int, default=4)
    p.add_argument("--actions_per_replan", type=int, default=None,
                   help="How many actions to execute between model forwards "
                        "(default = full N_WINDOW from the checkpoint)")
    return p.parse_args()


def _to_display_rgb(img):
    """Return a vertically flipped uint8 copy of a raw env camera frame for display.

    Non-uint8 frames are treated as [0, 1] floats. They are scaled by 255 and
    clipped before the uint8 cast, so out-of-range values cannot wrap around.
    """
    if img.dtype != np.uint8:
        img = (img * 255).clip(0, 255).astype(np.uint8)
    return np.flipud(img).copy()


def _fit_joint_pca(frames_record, n_components=3):
    """Fit one PCA basis over every F_bev and F_wrist pixel of every recorded frame.

    Returns ``(mean, components, proj_min, proj_max)``:
      - the (d,) feature mean;
      - the top ``n_components`` principal directions as rows of an (n_components, d) matrix;
      - the per-component min and max of all projections.
    Every frame is normalized with the same min/max, so colors are comparable
    across frames and across the two views.
    """
    d = frames_record[0]["F_bev"].shape[0]
    rows = []
    for fr in frames_record:
        for key in ("F_bev", "F_wrist"):
            rows.append(fr[key].reshape(d, -1).T.numpy())
    all_F = np.concatenate(rows, axis=0)
    mean_F = all_F.mean(0)
    centered = all_F - mean_F
    _, _, Vt = np.linalg.svd(centered, full_matrices=False)
    components = Vt[:n_components]
    proj_all = centered @ components.T
    return mean_F, components, proj_all.min(0), proj_all.max(0)


def _pca_to_rgb(F, pca):
    """Project a (d, H, W) CPU feature tensor onto the shared PCA basis and map it to uint8."""
    mean_F, components, g_min, g_max = pca
    d, H, W = F.shape
    flat = F.reshape(d, -1).T.numpy() - mean_F[None, :]
    proj = flat @ components.T
    proj = (proj - g_min[None]) / (g_max - g_min + 1e-8)[None]
    # Use this tensor's own H, W: the BEV and wrist feature grids need not match.
    return (proj.reshape(H, W, -1) * 255).clip(0, 255).astype(np.uint8)


def _vol_heatmap_overlay(bev_rgb_img, vol_logits, image_size, t_to_show=0, alpha=0.55):
    """Blend a JET heatmap of one timestep's volume distribution over the BEV image.

    ``vol_logits`` is indexed as (T, Z, H, W). For timestep ``t_to_show`` the
    logits are softmaxed jointly over (Z, H, W) and max-projected over Z. The
    result is scaled so its peak is 1, upsampled to ``image_size``, colored
    and alpha-blended over ``bev_rgb_img``. The arg-max voxel's (H, W) cell is
    marked with a cross.
    """
    vol = vol_logits[t_to_show]                                   # (Z, H, W)
    Z, H, W = vol.shape
    probs = torch.softmax(vol.reshape(-1), dim=-1).reshape(Z, H, W)
    heat2d = probs.max(dim=0)[0].numpy()                          # (H, W) max prob over Z
    if heat2d.max() > 0:
        heat2d = heat2d / heat2d.max()
    heat2d = cv2.resize(heat2d, (image_size, image_size), interpolation=cv2.INTER_CUBIC)
    heat_u8 = (heat2d * 255).clip(0, 255).astype(np.uint8)
    heat_color = cv2.cvtColor(cv2.applyColorMap(heat_u8, cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB)
    blended = (bev_rgb_img.astype(np.float32) * (1 - alpha)
               + heat_color.astype(np.float32) * alpha).clip(0, 255).astype(np.uint8)
    # Peak voxel -> grid cell center -> BEV pixel. Softmax is monotonic, so the
    # argmax of the logits equals the argmax of the probabilities.
    flat_idx = vol.reshape(-1).argmax().item()
    _, yx = divmod(flat_idx, H * W)
    y_g, x_g = divmod(yx, W)
    # Separate x/y scales so a non-square grid still maps to the right pixel.
    scale_x = image_size / W
    scale_y = image_size / H
    peak = (int((x_g + 0.5) * scale_x), int((y_g + 0.5) * scale_y))
    cv2.drawMarker(blended, peak, (255, 255, 255),
                   markerType=cv2.MARKER_CROSS, markerSize=24, thickness=3)
    cv2.drawMarker(blended, peak, (0, 0, 0),
                   markerType=cv2.MARKER_CROSS, markerSize=24, thickness=1)
    return blended


def _put_label(img, text, org, font_scale):
    """Draw ``text`` with a thick white underlay and a thin dark overlay so it reads on any background."""
    cv2.putText(img, text, org, cv2.FONT_HERSHEY_SIMPLEX, font_scale, (255, 255, 255), 2, cv2.LINE_AA)
    cv2.putText(img, text, org, cv2.FONT_HERSHEY_SIMPLEX, font_scale, (20, 20, 20), 1, cv2.LINE_AA)


def _write_video(frames_record, out_path, fps, image_size, pca):
    """Render every recorded replan step as one 5-panel frame and write them to ``out_path``.

    Panels (left→right): BEV + predicted trajectory, wrist + predicted
    trajectory, F_bev PCA, F_wrist PCA, BEV volume heatmap (t=0).

    Returns:
        False if the cv2.VideoWriter could not be opened, True otherwise.
    """
    out_path.parent.mkdir(parents=True, exist_ok=True)
    panel_h = panel_w = image_size
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    writer = cv2.VideoWriter(str(out_path), fourcc, fps, (panel_w * 5, panel_h))
    if not writer.isOpened():
        return False
    n_frames = len(frames_record)
    try:
        for i, fr in enumerate(frames_record):
            bev = draw_trajectory(fr["bev_rgb"].copy(), fr["pred_pix_bev"])
            wrist = draw_trajectory(fr["wrist_rgb"].copy(), fr["pred_pix_wrist"])
            F_bev_pca = cv2.resize(_pca_to_rgb(fr["F_bev"], pca), (panel_w, panel_h),
                                   interpolation=cv2.INTER_NEAREST)
            F_wrist_pca = cv2.resize(_pca_to_rgb(fr["F_wrist"], pca), (panel_w, panel_h),
                                     interpolation=cv2.INTER_NEAREST)
            heatmap_panel = _vol_heatmap_overlay(fr["bev_rgb"], fr["vol_logits"], image_size, t_to_show=0)
            for img, lbl in [(bev, "BEV + pred traj"), (wrist, "Wrist + pred traj"),
                             (F_bev_pca, "F_bev PCA"), (F_wrist_pca, "F_wrist PCA"),
                             (heatmap_panel, "Vol heatmap (t=0, max-Z)")]:
                _put_label(img, lbl, (10, 25), 0.65)
            step_lbl = f"replan {i+1}/{n_frames}  step {fr['step']}"
            _put_label(bev, step_lbl, (10, panel_h - 15), 0.55)
            panel = np.concatenate([bev, wrist, F_bev_pca, F_wrist_pca, heatmap_panel], axis=1)
            writer.write(cv2.cvtColor(panel, cv2.COLOR_RGB2BGR))
    finally:
        # Finalize the container even if rendering fails part-way.
        writer.release()
    return True


def main():
    """Run one closed-loop LIBERO rollout with the 2-view model and save a visualization video.

    At every replan step, records:
      - the BEV and wrist frames;
      - both feature maps;
      - the volume logits;
      - the predicted trajectory, projected into both views.
    After the episode, a joint PCA basis is fit over all recorded features, and
    every step is rendered as a 5-panel frame of the output mp4.
    """
    args = _parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE: weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
    ckpt = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
    min_h, max_h = float(ckpt["min_height"]), float(ckpt["max_height"])
    min_g, max_g = float(ckpt["min_grip"]), float(ckpt["max_grip"])
    rot_pca_mean = np.asarray(ckpt["rot_pca_mean"], dtype=np.float64)
    rot_pca_axis = np.asarray(ckpt["rot_pca_axis"], dtype=np.float64)
    rot_pca_min = float(ckpt["rot_pca_min"])
    rot_pca_max = float(ckpt["rot_pca_max"])
    n_rot_bins = int(ckpt["n_rot_bins"])
    n_window = int(ckpt.get("n_window", 8))
    image_size = int(ckpt.get("image_size", LIBERO_IMG))

    # By default execute the whole predicted window, as the --help text promises.
    actions_per_replan = n_window if args.actions_per_replan is None else args.actions_per_replan
    if actions_per_replan < 1:
        # Executing no actions per replan would never advance step_idx (infinite loop).
        raise SystemExit(f"--actions_per_replan must be >= 1 (got {actions_per_replan})")

    model = DinoVolumeQuery2View(
        n_window=n_window, n_height_bins=32, n_gripper_bins=32, n_rot_bins=n_rot_bins,
        image_size=image_size, pred_size=PRED_SIZE, rotation_mode='1d_pca',
    ).to(device).eval()
    # strict=False tolerates key mismatches; surface them instead of hiding them.
    incompat = model.load_state_dict(ckpt["model_state_dict"], strict=False)
    if incompat.missing_keys or incompat.unexpected_keys:
        print(f"WARN: checkpoint/model key mismatch: {len(incompat.missing_keys)} missing, "
              f"{len(incompat.unexpected_keys)} unexpected (loaded with strict=False)")

    bench = bm.get_benchmark_dict()[args.benchmark]()
    task = bench.get_task(args.task_id)
    demo_path = os.path.join(get_libero_path("datasets"), bench.get_task_demonstration(args.task_id))
    with h5py.File(demo_path, "r") as f:
        # NOTE(review): lexicographic order, so "demo_10" precedes "demo_2" — confirm
        # this is the --ep_idx indexing intended.
        demo_keys = sorted([k for k in f["data"].keys() if k.startswith("demo_")])
        init_states = [f[f"data/{k}/states"][0] for k in demo_keys]
    bddl = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)
    env = OffScreenRenderEnv(
        bddl_file_name=bddl,
        camera_heights=image_size, camera_widths=image_size,
        camera_names=["agentview", "robot0_eye_in_hand"],
    )
    try:
        env.seed(0)
        env.reset()
        if args.clean_scene:
            apply_clean_scene(env.env.sim)

        init_state = init_states[args.ep_idx].copy()
        if args.shift_dx != 0 or args.shift_dy != 0:
            init_state = shift_init_state(init_state, args.shift_dx, args.shift_dy)
            init_state = hide_distractors_in_state(init_state)
        obs = env.set_init_state(init_state)
        for _ in range(5):
            obs, _, _, _ = env.step(np.zeros(7, dtype=np.float32))
        if args.cam_theta != 0.0 or args.cam_phi != 0.0:
            reposition_camera(env.env.sim, "agentview", args.cam_theta, args.cam_phi)
            obs, _, _, _ = env.step(np.zeros(7, dtype=np.float32))

        # The BEV camera is static from here on, so its geometry is computed once.
        cur_bev_K = get_camera_intrinsic_matrix(env.env.sim, "agentview", image_size, image_size).astype(np.float32)
        cur_bev_K_norm = cur_bev_K.copy()
        cur_bev_K_norm[0] /= image_size
        cur_bev_K_norm[1] /= image_size
        cur_bev_ext = get_camera_extrinsic_matrix(env.env.sim, "agentview").astype(np.float32)
        bev_xyz_table = build_bev_world_xyz_table(
            torch.tensor(cur_bev_K_norm, dtype=torch.float32, device=device),
            torch.tensor(cur_bev_ext, dtype=torch.float32, device=device),
            32, min_h, max_h, PRED_SIZE, PRED_SIZE, image_size, device,
        )
        K_bev_pixel = cur_bev_K_norm.copy().astype(np.float64)
        K_bev_pixel[0] *= image_size
        K_bev_pixel[1] *= image_size
        camera_pose_bev = cur_bev_ext.astype(np.float64)

        frames_record = []
        step_idx = 0
        done = False
        success = False
        while step_idx < args.max_steps and not done:
            current_eef_pos = np.array(obs["robot0_eef_pos"], dtype=np.float64)
            current_eef_quat = np.array(obs["robot0_eef_quat"], dtype=np.float64)
            rgb_bev_obs = obs["agentview_image"]
            rgb_wrist_obs = obs["robot0_eye_in_hand_image"]
            bev_rgb = _to_display_rgb(rgb_bev_obs)
            wrist_rgb = _to_display_rgb(rgb_wrist_obs)

            img_bev = preprocess_obs(rgb_bev_obs, image_size).to(device)
            img_wrist = preprocess_obs(rgb_wrist_obs, image_size).to(device)
            world_to_cam_bev = get_camera_transform_matrix(env.env.sim, "agentview", image_size, image_size)
            start_pix = eef_to_start_kp(current_eef_pos, world_to_cam_bev, image_size).to(device)
            if start_pix.dim() == 1:
                start_pix = start_pix.unsqueeze(0)
            # The wrist camera moves with the robot, so its geometry is re-read every replan.
            wrist_ext_np = get_camera_extrinsic_matrix(env.env.sim, "robot0_eye_in_hand").astype(np.float32)
            wrist_K_np = get_camera_intrinsic_matrix(
                env.env.sim, "robot0_eye_in_hand", image_size, image_size).astype(np.float32)
            wrist_K_norm = wrist_K_np.copy()
            wrist_K_norm[0] /= image_size
            wrist_K_norm[1] /= image_size
            wrist_K_t = torch.tensor(wrist_K_norm, dtype=torch.float32, device=device).unsqueeze(0)
            wrist_ext_t = torch.tensor(wrist_ext_np, dtype=torch.float32, device=device).unsqueeze(0)

            with torch.no_grad():
                out = model(img_bev, img_wrist, start_pix, bev_xyz_table, wrist_K_t, wrist_ext_t)
                F_bev = out["pixel_feats"][0].cpu()             # (d, H, W)
                F_wrist = out["pixel_feats_wrist"][0].cpu()
                vol_logits = out["volume_logits"][0].cpu()      # (T, Z, H, W)

            window_actions, pred_3d_list = decode_actions(
                out, camera_pose_bev, K_bev_pixel,
                current_eef_pos, current_eef_quat,
                min_h, max_h, min_g, max_g,
                rot_pca_mean, rot_pca_axis, rot_pca_min, rot_pca_max,
                image_size=image_size,
            )

            # Project the predicted 3D waypoints into both views. Points behind a
            # camera get an off-image sentinel that draw_trajectory skips.
            wrist_K_pix = wrist_K_norm.astype(np.float64).copy()
            wrist_K_pix[0] *= image_size
            wrist_K_pix[1] *= image_size
            wrist_ext_64 = wrist_ext_np.astype(np.float64)
            off_image = (-100, -100)
            pred_pix_bev, pred_pix_wrist = [], []
            for p3d in pred_3d_list:
                pb = project_world_to_pixel(p3d, K_bev_pixel, camera_pose_bev)
                pw = project_world_to_pixel(p3d, wrist_K_pix, wrist_ext_64)
                pred_pix_bev.append(off_image if pb is None else pb)
                pred_pix_wrist.append(off_image if pw is None else pw)

            frames_record.append({
                "bev_rgb": bev_rgb, "wrist_rgb": wrist_rgb,
                "F_bev": F_bev, "F_wrist": F_wrist,
                "vol_logits": vol_logits,
                "pred_pix_bev": np.array(pred_pix_bev), "pred_pix_wrist": np.array(pred_pix_wrist),
                "step": step_idx, "eef_pos": current_eef_pos,
            })

            n_to_exec = min(actions_per_replan, len(window_actions))
            for t, action in enumerate(window_actions[:n_to_exec]):
                if args.zero_rotation:
                    action[3:6] = 0.0
                if args.teleport:
                    target = pred_3d_list[t].astype(np.float64)
                    gripper_cmd = float(action[6])
                    obs, n_servo, ep_done = servo_to_pos(env, target, gripper_cmd, max_servo=25)
                    step_idx += n_servo
                    if ep_done or (hasattr(env.env, '_check_success') and env.env._check_success()):
                        success = True
                        done = True
                        break
                else:
                    obs, _, done, _ = env.step(action)
                    step_idx += 1
                    if done:
                        success = True
                        break
                if step_idx >= args.max_steps:
                    break
            if step_idx >= args.max_steps:
                break
    finally:
        # Release the simulator / offscreen renderer even if the rollout fails.
        env.close()

    print(f"Episode {args.ep_idx} done. success={success} step_idx={step_idx} n_replans={len(frames_record)}")
    if not frames_record:
        raise SystemExit("FAIL: no replan steps were recorded (is --max_steps > 0?); nothing to render")

    pca = _fit_joint_pca(frames_record)
    out_path = Path(args.out_path)
    if not _write_video(frames_record, out_path, args.fps, image_size, pca):
        # Non-zero exit status so callers and scripts can detect the failure.
        raise SystemExit(f"FAIL: VideoWriter open failed for {out_path}")
    print(f"Video saved → {out_path}  (frames={len(frames_record)}, fps={args.fps}, success={success})")


if __name__ == "__main__":
    main()
