"""eval_uva_para.py — Evaluate UVA+PARA in LIBERO simulation.

Like eval.py but uses the UVA MAR backbone instead of DINO. Additionally renders
predicted video frames from the UVA model alongside heatmap overlays.

Usage:
    python libero/eval_uva_para.py \
        --checkpoint libero/checkpoints/uva_para_libero/best.pth \
        --benchmark libero_spatial --task_id 0 --n_episodes 20 --save_video
"""

import argparse
import json
import os
import sys
import tempfile
from pathlib import Path
from types import SimpleNamespace

import cv2
import h5py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from tqdm import tqdm
from scipy.spatial.transform import Rotation as ScipyR

sys.path.insert(0, str(Path(__file__).parent))
UVA_ROOT = Path(__file__).resolve().parent.parent / "video_training" / "unified_video_action"
sys.path.insert(0, str(UVA_ROOT))

from simple_uva.vae import AutoencoderKL
from simple_uva.model import mar_base_video_only
from train_uva_para import (
    ParaHeads, build_vae, extract_pred_2d_and_height,
    N_FRAMES, UVA_IMG_SIZE, PARA_OUT_SIZE, LATENT_SCALE, DECODER_DIM,
    N_HEIGHT_BINS, N_GRIPPER_BINS, N_ROT_BINS,
)
from utils import recover_3d_from_direct_keypoint_and_height

from libero.libero import benchmark as bm, get_libero_path
from libero.libero.envs import OffScreenRenderEnv
from robosuite.utils.camera_utils import (
    get_camera_transform_matrix,
    get_camera_extrinsic_matrix,
    get_camera_intrinsic_matrix,
    project_points_from_world_to_camera,
)

IMAGE_SIZE = UVA_IMG_SIZE  # eval at 256x256 (matching UVA)


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def preprocess_obs(rgb_obs, image_size=IMAGE_SIZE):
    """Convert an HxWx3 uint8 frame to a (1, 3, H, W) float tensor in [-1, 1].

    Frames arrive vertically flipped from the offscreen renderer (hence the
    flipud). Normalization is the UVA convention ([-1, 1]), not ImageNet stats.
    """
    flipped = np.flipud(rgb_obs.astype(np.float32) / 255.0).copy()
    resized = cv2.resize(flipped, (image_size, image_size), interpolation=cv2.INTER_LINEAR)
    scaled = resized * 2.0 - 1.0  # [0, 1] -> [-1, 1]
    chw = np.ascontiguousarray(scaled.transpose(2, 0, 1))
    return torch.from_numpy(chw).float().unsqueeze(0)


def get_camera_params(sim, camera_name, image_size=IMAGE_SIZE):
    """Fetch projection matrices for a sim camera.

    Args:
        sim: MuJoCo simulation handle.
        camera_name: name of the camera to query.
        image_size: render resolution (square) in pixels.

    Returns:
        (world_to_camera, camera_pose, cam_K): the 4x4 world->pixel transform,
        the 4x4 camera extrinsic (pose) matrix, and the 3x3 pixel-space
        intrinsic matrix.
    """
    world_to_camera = get_camera_transform_matrix(sim, camera_name, image_size, image_size)
    camera_pose = get_camera_extrinsic_matrix(sim, camera_name)
    # The original code normalized the intrinsics by image_size and immediately
    # multiplied them back — a dead (and slightly lossy) round trip. Return the
    # pixel-space intrinsics directly.
    cam_K = get_camera_intrinsic_matrix(sim, camera_name, image_size, image_size)
    return world_to_camera, camera_pose, cam_K


def eef_to_start_kp(eef_pos, world_to_camera, image_size=IMAGE_SIZE):
    """Project the end-effector world position to a (u, v) pixel keypoint.

    The robosuite projector returns (row, col); this swaps to (col, row) so the
    result is (u, v) in image coordinates.
    """
    points = eef_pos.reshape(1, 3).astype(np.float64)
    row_col = project_points_from_world_to_camera(
        points=points,
        world_to_camera_transform=world_to_camera,
        camera_height=image_size,
        camera_width=image_size,
    )[0]
    return torch.tensor([float(row_col[1]), float(row_col[0])], dtype=torch.float32)


def decode_window_actions(volume_logits, para_heads, feats, camera_pose, cam_K,
                          current_eef_pos, current_eef_quat,
                          min_h, max_h, min_g, max_g, min_r, max_r,
                          image_size=IMAGE_SIZE, max_delta=0.05):
    """Decode N_FRAMES predicted timesteps into 7-DoF OSC_POSE delta actions.

    Args:
        volume_logits: (1, T, N_HEIGHT_BINS, S, S) keypoint/height logits.
        para_heads: PARA head module; only predict_at_pixels(feats, pixels) is used.
        feats: decoder features consumed by para_heads.
        camera_pose: 4x4 camera extrinsic matrix.
        cam_K: 3x3 pixel-space camera intrinsic matrix.
        current_eef_pos: (3,) current end-effector world position.
        current_eef_quat: (4,) current end-effector quaternion (scipy xyzw order).
        min_h, max_h, min_g, max_g, min_r, max_r: de-quantization ranges from the checkpoint.
        image_size: rendered image resolution in pixels.
        max_delta: cap on the commanded position delta per step (meters).

    Returns:
        (actions, pred_3d_targets): list of 7-vector np.float32 actions
        [dx, dy, dz, rx, ry, rz, gripper] and the predicted 3-D targets.
    """
    # OSC_POSE controller scaling: a command of +/-1 maps to these physical deltas.
    OSC_POS_SCALE = 0.05
    OSC_ROT_SCALE = 0.5

    n_window = volume_logits.shape[1]
    pred_size = volume_logits.shape[-1]
    scale = image_size / pred_size  # heatmap cell -> full-resolution pixels

    # Per-timestep argmax over the (height x spatial) volume -> 2-D keypoint cell.
    pred_px_list = []
    for t in range(n_window):
        vol_t = volume_logits[0, t]
        max_over_h = vol_t.max(dim=0)[0]
        flat_idx = max_over_h.reshape(-1).argmax().item()
        py = flat_idx // pred_size
        px = flat_idx % pred_size
        pred_px_list.append((px, py))

    # Query the gripper/rotation heads at the chosen pixels in one batched call.
    pred_pixels = torch.tensor(
        [[px, py] for px, py in pred_px_list], dtype=torch.float32, device=feats.device
    ).unsqueeze(0)
    with torch.no_grad():
        gripper_logits, rotation_logits = para_heads.predict_at_pixels(feats, pred_pixels)

    actions = []
    pred_3d_targets = []
    # NOTE(review): ref_pos stays fixed at the window-start eef pose, so every
    # delta in the window is relative to the start pose rather than chained from
    # the previous target — confirm this matches the controller's expectation.
    ref_pos = current_eef_pos.copy()

    for t, (px, py) in enumerate(pred_px_list):
        vol_t = volume_logits[0, t]
        # Map heatmap-cell coordinates to full-resolution pixel centers.
        px_full = (px + 0.5) * scale
        py_full = (py + 0.5) * scale

        # De-quantize the height bin, then lift the 2-D keypoint to 3-D.
        h_bin = vol_t[:, py, px].argmax().item()
        height = (h_bin / max(N_HEIGHT_BINS - 1, 1)) * (max_h - min_h) + min_h
        pred_3d = recover_3d_from_direct_keypoint_and_height(
            np.array([px_full, py_full], dtype=np.float64), height, camera_pose, cam_K,
        )
        if pred_3d is None:
            # Unprojectable keypoint: fall back to the previous target (or hold pose).
            pred_3d = pred_3d_targets[-1] if pred_3d_targets else ref_pos.copy()
        pred_3d_targets.append(pred_3d)

        # Position delta, capped at max_delta then normalized to [-1, 1] commands.
        delta_pos = pred_3d - ref_pos
        norm = np.linalg.norm(delta_pos)
        if norm > max_delta:
            delta_pos = delta_pos / norm * max_delta
        delta_norm = np.clip(delta_pos / OSC_POS_SCALE, -1.0, 1.0)

        # De-quantize per-axis Euler angles; convert the absolute predicted
        # orientation into a delta rotation from the current orientation.
        euler_pred = np.array([
            (rotation_logits[0, t, axis, :].argmax().item() / max(N_ROT_BINS - 1, 1))
            * (max_r[axis] - min_r[axis]) + min_r[axis]
            for axis in range(3)
        ])
        R_pred = ScipyR.from_euler('xyz', euler_pred)
        R_current = ScipyR.from_quat(current_eef_quat)
        R_delta = R_pred * R_current.inv()
        delta_rot_norm = np.clip(R_delta.as_rotvec() / OSC_ROT_SCALE, -1.0, 1.0)

        # De-quantize the gripper bin and map it to a [-1, 1] command.
        # Fix: floor the denominator so a degenerate checkpoint range
        # (min_g == max_g) cannot divide by zero / emit NaN commands.
        g_bin = gripper_logits[0, t, :].argmax().item()
        gripper_val = (g_bin / max(N_GRIPPER_BINS - 1, 1)) * (max_g - min_g) + min_g
        g_range = max(max_g - min_g, 1e-8)
        gripper_cmd = float(np.clip((gripper_val - min_g) / g_range * 2 - 1, -1, 1))

        action = np.zeros(7, dtype=np.float32)
        action[:3] = delta_norm
        action[3:6] = delta_rot_norm
        action[6] = gripper_cmd
        actions.append(action)

    return actions, pred_3d_targets


def render_eval_frame(rgb_obs, volume_logits, current_eef_pos,
                      world_to_camera, image_size, step_idx, success,
                      predicted_frames=None):
    """Render an eval frame with heatmap overlay and optional predicted-video grid.

    Args:
        rgb_obs: raw HxWx3 uint8 sim frame (vertically flipped, as rendered).
        volume_logits: (1, T, N_HEIGHT_BINS, S, S) keypoint logits; only t=0 is drawn.
        current_eef_pos: (3,) end-effector world position, drawn as a white dot.
        world_to_camera: 4x4 world->pixel transform for the eef marker.
        image_size: output resolution per panel in pixels.
        step_idx: env step index for the text label.
        success: None while running (no suffix); True/False appends
            "SUCCESS"/"running" on the final frame.
        predicted_frames: optional (T, 3, H, W) tensor in [-1, 1]; when given, a
            2x2 grid of the first four predicted frames is appended on the right.

    Returns:
        (H, W, 3) uint8 RGB image, or (H, 2W, 3) when predicted_frames is given.
    """
    pred_size = volume_logits.shape[-1]
    scale = image_size / pred_size

    frame = rgb_obs.astype(np.float32) / 255.0
    frame = np.flipud(frame).copy()  # sim frames arrive vertically flipped
    frame = cv2.resize(frame, (image_size, image_size), interpolation=cv2.INTER_LINEAR)

    # Softmax over the whole t=0 volume, collapsed over height bins.
    vol_t = volume_logits[0, 0]
    vol_probs = F.softmax(vol_t.reshape(-1), dim=0).reshape(vol_t.shape)
    heat_small = vol_probs.max(dim=0)[0].cpu().numpy()
    heat = cv2.resize(heat_small, (image_size, image_size), interpolation=cv2.INTER_LINEAR)
    # Fix: proper min-max normalization — the denominator must be (max - min),
    # not just max, or the overlay never reaches full intensity when min > 0.
    heat = (heat - heat.min()) / (heat.max() - heat.min() + 1e-8)
    heat_rgb = np.zeros_like(frame)
    heat_rgb[..., 0] = heat  # heatmap rendered into the red channel
    overlay = np.clip(frame * 0.55 + heat_rgb * 0.45, 0, 1)
    vis = (overlay * 255.0).astype(np.uint8)

    # Green cross at the argmax keypoint (heatmap cell -> full-res pixel center).
    flat_idx = heat_small.argmax()
    py, px = flat_idx // pred_size, flat_idx % pred_size
    px_full, py_full = int((px + 0.5) * scale), int((py + 0.5) * scale)
    cv2.drawMarker(vis, (px_full, py_full), (0, 255, 0), cv2.MARKER_CROSS, 18, 2, cv2.LINE_AA)

    # White dot at the projected current eef position (projector returns row, col).
    pix_rc = project_points_from_world_to_camera(
        current_eef_pos.reshape(1, 3).astype(np.float64),
        world_to_camera, image_size, image_size,
    )[0]
    u, v = int(round(float(pix_rc[1]))), int(round(float(pix_rc[0])))
    if 0 <= u < image_size and 0 <= v < image_size:
        cv2.circle(vis, (u, v), 6, (255, 255, 255), -1)

    # Outlined label: thick light pass underneath a thin dark pass.
    label = f"step {step_idx}"
    if success is not None:
        label += "  SUCCESS" if success else "  running"
    cv2.putText(vis, label, (10, 22), cv2.FONT_HERSHEY_SIMPLEX, 0.45, (255, 255, 255), 2, cv2.LINE_AA)
    cv2.putText(vis, label, (10, 22), cv2.FONT_HERSHEY_SIMPLEX, 0.45, (20, 20, 20), 1, cv2.LINE_AA)

    if predicted_frames is not None:
        # 2x2 grid of predicted frames, each panel image_size//2 square.
        half = image_size // 2
        grid = np.zeros((image_size, image_size, 3), dtype=np.uint8)
        for i in range(min(N_FRAMES, 4)):
            pf = predicted_frames[i]  # (3, H, W) tensor in [-1, 1]
            pf_np = ((pf.cpu().clamp(-1, 1) + 1.0) / 2.0 * 255).permute(1, 2, 0).numpy().astype(np.uint8)
            pf_resized = cv2.resize(pf_np, (half, half))
            r, c = i // 2, i % 2
            grid[r * half:(r + 1) * half, c * half:(c + 1) * half] = pf_resized
            cv2.putText(grid, f"pred t={i}", (c * half + 4, r * half + 14),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.3, (255, 255, 255), 1, cv2.LINE_AA)
        return np.concatenate([vis, grid], axis=1)

    return vis


# ---------------------------------------------------------------------------
# Main eval
# ---------------------------------------------------------------------------

def run_eval(args) -> dict:
    """Run closed-loop LIBERO evaluation of a UVA+PARA checkpoint.

    Per control window: encode the current frame with the VAE, sample future
    video latents with the MAR backbone, run the PARA heads on the decoded
    tokens, then execute the decoded window of OSC_POSE actions in the env.
    Optionally writes per-episode MP4s; always writes a JSON results summary.

    Returns:
        Results dict (benchmark/task metadata, per-episode successes and step
        counts, success rate). Also saved to out_dir as JSON.
    """
    device = torch.device(
        "cuda" if torch.cuda.is_available() else
        "mps" if torch.backends.mps.is_available() else "cpu"
    )
    print(f"Device: {device}")

    # --- Load checkpoint ---
    ckpt_path = Path(args.checkpoint)
    if not ckpt_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
    ckpt = torch.load(ckpt_path, map_location="cpu")

    # De-quantization ranges saved at training time (permissive fallbacks if
    # the checkpoint predates these keys).
    min_h = float(ckpt.get("min_height", 0.0))
    max_h = float(ckpt.get("max_height", 1.0))
    min_g = float(ckpt.get("min_gripper", -1.0))
    max_g = float(ckpt.get("max_gripper", 1.0))
    min_r = np.array(ckpt.get("min_rot", [-3.14159] * 3), dtype=np.float64)
    max_r = np.array(ckpt.get("max_rot", [3.14159] * 3), dtype=np.float64)
    print(f"Height:  [{min_h:.4f}, {max_h:.4f}]")
    print(f"Gripper: [{min_g:.4f}, {max_g:.4f}]")
    print(f"Rot:     {min_r.tolist()} .. {max_r.tolist()}")

    # --- Build models ---
    vae = build_vae(args.vae_ckpt, device)
    mar = mar_base_video_only(
        img_size=UVA_IMG_SIZE, vae_stride=16, patch_size=1, vae_embed_dim=16,
        num_sampling_steps="100", diffloss_d=6, diffloss_w=1024,
    ).to(device)
    # strict=False: checkpoint may omit buffers/heads not needed at eval time.
    mar.load_state_dict(ckpt["mar_state_dict"], strict=False)
    mar.eval()

    para_heads = ParaHeads(decoder_dim=DECODER_DIM, para_out_size=PARA_OUT_SIZE).to(device)
    para_heads.load_state_dict(ckpt["para_heads_state_dict"])
    para_heads.eval()
    print(f"Loaded UVA+PARA from {ckpt_path}")

    # --- LIBERO env ---
    bench = bm.get_benchmark_dict()[args.benchmark]()
    task = bench.get_task(args.task_id)
    task_name = task.name
    print(f"Task: [{args.benchmark}] {task_name}")

    # Episode initial states come from the first state of each recorded demo,
    # so eval conditions match the demonstrations.
    demo_path = os.path.join(get_libero_path("datasets"), bench.get_task_demonstration(args.task_id))
    with h5py.File(demo_path, "r") as f:
        demo_keys = sorted([k for k in f["data"].keys() if k.startswith("demo_")])
        init_states = [f[f"data/{k}/states"][0] for k in demo_keys]
    n_episodes = min(args.n_episodes, len(init_states))

    bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)
    env = OffScreenRenderEnv(
        bddl_file_name=bddl_file,
        camera_heights=IMAGE_SIZE,
        camera_widths=IMAGE_SIZE,
        camera_names=[args.camera],
    )
    env.seed(args.seed)
    env.reset()
    print("Environment ready.")

    successes = []
    step_counts = []

    for ep_idx in tqdm(range(n_episodes), desc="Episodes"):
        env.reset()
        obs = env.set_init_state(init_states[ep_idx])
        # Let the scene settle after restoring the demo state.
        for _ in range(5):
            obs, _, _, _ = env.step(np.zeros(7, dtype=np.float32))

        world_to_camera, camera_pose, cam_K = get_camera_params(env.sim, args.camera, IMAGE_SIZE)

        done = False
        success = False
        frames = [] if args.save_video else None
        step_idx = 0
        predicted_video_cache = None

        while step_idx < args.max_steps and not done:
            current_eef_pos = np.array(obs["robot0_eef_pos"], dtype=np.float64)
            current_eef_quat = np.array(obs["robot0_eef_quat"], dtype=np.float64)
            rgb_obs = obs[f"{args.camera}_image"]

            img_tensor = preprocess_obs(rgb_obs, IMAGE_SIZE).to(device)

            with torch.no_grad():
                # Encode current frame (stochastic posterior sample, scaled to
                # the latent range the MAR backbone was trained on).
                posterior = vae.encode(img_tensor.float())
                z0 = posterior.sample() * LATENT_SCALE

                # Step 1: Sample future video tokens from MAR (the model's prediction)
                cond_z = z0.unsqueeze(1).expand(1, N_FRAMES, -1, -1, -1)
                sampled_video_latents, _ = mar.sample_tokens(
                    bsz=1, cond=cond_z, num_iter=args.num_iter, cfg=1.0, temperature=0.95,
                )  # (T, C_vae, H_lat, W_lat)

                # Step 2: Use sampled tokens as x_tokens for PARA forward pass
                # This matches training where forward_decode_tokens sees real future frames
                sampled_tokens = mar.patchify(sampled_video_latents)  # (T, S, C_token)
                sampled_tokens = sampled_tokens.unsqueeze(0)  # (1, T, S, C)
                cond_tokens = mar.patchify(z0).unsqueeze(1).expand(-1, N_FRAMES, -1, -1)  # (1, T, S, C)

                dec_tokens = mar.forward_decode_tokens(sampled_tokens, cond_tokens, mask=None)
                volume_logits, feats, _, _ = para_heads(dec_tokens)

                # Step 3: Decode sampled video for visualization
                # NOTE(review): this decodes every control window; the
                # --video_pred_every CLI flag is currently unused.
                if args.save_video:
                    predicted_video_cache = vae.decode(sampled_video_latents / LATENT_SCALE)  # (T, 3, H, W)

            # Decode actions
            window_actions, _ = decode_window_actions(
                volume_logits, para_heads, feats, camera_pose, cam_K,
                current_eef_pos, current_eef_quat,
                min_h, max_h, min_g, max_g, min_r, max_r,
                image_size=IMAGE_SIZE,
            )

            # Execute the whole predicted window open-loop before re-planning.
            for t, action in enumerate(window_actions):
                if frames is not None:
                    # NOTE(review): the conditional is redundant — both branches
                    # yield predicted_video_cache.
                    pred_frames = predicted_video_cache if predicted_video_cache is not None else None
                    frames.append(render_eval_frame(
                        obs[f"{args.camera}_image"], volume_logits, current_eef_pos,
                        world_to_camera, IMAGE_SIZE, step_idx, success=None,
                        predicted_frames=pred_frames,
                    ))

                obs, _, done, _ = env.step(action)
                step_idx += 1
                if done:
                    # NOTE(review): env.step's done flag is treated as task
                    # success — confirm OffScreenRenderEnv only sets done on
                    # success and not on horizon/timeout.
                    success = True
                    break
                if step_idx >= args.max_steps:
                    break

        if frames:
            # Replace the last frame with one carrying the final success label.
            final_frame = render_eval_frame(
                obs[f"{args.camera}_image"], volume_logits,
                np.array(obs["robot0_eef_pos"], dtype=np.float64),
                world_to_camera, IMAGE_SIZE, step_idx, success=success,
                predicted_frames=predicted_video_cache,
            )
            frames[-1] = final_frame

            video_dir = Path(args.out_dir) / "videos"
            video_dir.mkdir(parents=True, exist_ok=True)
            video_path = video_dir / f"ep{ep_idx:03d}_{'success' if success else 'fail'}.mp4"
            h, w = frames[0].shape[:2]
            writer = cv2.VideoWriter(str(video_path), cv2.VideoWriter_fourcc(*"mp4v"), args.video_fps, (w, h))
            for f in frames:
                # frames are RGB; VideoWriter expects BGR.
                writer.write(cv2.cvtColor(f, cv2.COLOR_RGB2BGR))
            writer.release()

        successes.append(float(success))
        # NOTE(review): step_idx already equals the number of executed steps,
        # so step_idx + 1 overcounts by one — confirm intended.
        step_counts.append(step_idx + 1)
        tqdm.write(f"  Ep {ep_idx+1:3d}: {'SUCCESS' if success else 'FAILURE'}  steps={step_idx+1}")

    env.close()

    success_rate = float(np.mean(successes))
    avg_steps = float(np.mean(step_counts))
    print(f"\n{'=' * 52}")
    print(f"  Benchmark:    {args.benchmark}")
    print(f"  Task {args.task_id}:      {task_name}")
    print(f"  Episodes:     {n_episodes}")
    print(f"  Successes:    {int(sum(successes))} / {n_episodes}")
    print(f"  Success Rate: {success_rate * 100:.1f}%")
    print(f"  Avg steps:    {avg_steps:.1f} / {args.max_steps}")
    print(f"{'=' * 52}")

    results = {
        "benchmark": args.benchmark, "task_id": args.task_id, "task_name": task_name,
        "checkpoint": str(ckpt_path), "n_episodes": n_episodes,
        "success_rate": success_rate, "successes": successes,
        "step_counts": step_counts, "avg_steps": avg_steps, "max_steps": args.max_steps,
    }
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    out_path = out_dir / f"eval_uva_para_{args.benchmark}_task{args.task_id}.json"
    with open(out_path, "w") as f:
        json.dump(results, f, indent=2)
    print(f"Results saved -> {out_path}")
    return results


if __name__ == "__main__":
    # CLI entry point: build the flag parser and launch the evaluation loop.
    ap = argparse.ArgumentParser(description="Evaluate UVA+PARA in LIBERO simulation")
    ap.add_argument("--checkpoint", type=str, required=True)
    ap.add_argument("--vae_ckpt", type=str, default="pretrained_models/vae/kl16.ckpt")
    ap.add_argument("--benchmark", type=str, default="libero_spatial")
    ap.add_argument("--task_id", type=int, default=0)
    ap.add_argument("--camera", type=str, default="agentview")
    ap.add_argument("--n_episodes", type=int, default=20)
    ap.add_argument("--max_steps", type=int, default=300)
    ap.add_argument("--seed", type=int, default=0)
    ap.add_argument("--out_dir", type=str, default="libero/out/eval_uva_para")
    ap.add_argument("--save_video", action="store_true")
    ap.add_argument("--video_fps", type=int, default=10)
    ap.add_argument("--num_iter", type=int, default=1,
                    help="MAR unmasking iterations (1 = single-shot prediction, fast)")
    # NOTE(review): this flag is accepted but not currently read by run_eval.
    ap.add_argument("--video_pred_every", type=int, default=20,
                    help="Generate predicted video every N env steps (slow)")
    run_eval(ap.parse_args())
