"""Evaluate SVD+PARA joint model in LIBERO simulator.
Teleport mode with zero rotation. Generates videos with heatmap overlays."""

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "4"

import sys
sys.path.insert(0, "/data/cameron/para_videopolicy")
sys.path.insert(0, os.path.dirname(__file__))

import argparse
import json
import numpy as np
import torch
import torch.nn.functional as F
import cv2
import imageio
from pathlib import Path
from PIL import Image
from tqdm import tqdm

from svd.models import UNetSpatioTemporalConditionModel
from svd.pipelines import StableVideoDiffusionPipeline
from diffusers import AutoencoderKLTemporalDecoder
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from train_svd_para_joint import (ParaHeadsOnUNet, PARA_OUT_SIZE, N_HEIGHT_BINS,
                                  N_GRIPPER_BINS)

from libero.libero import benchmark as bm, get_libero_path
from libero.libero.envs import OffScreenRenderEnv
from robosuite.utils.camera_utils import (
    get_camera_transform_matrix,
    get_camera_extrinsic_matrix,
    get_camera_intrinsic_matrix,
)
from utils import recover_3d_from_direct_keypoint_and_height


def get_camera_params(sim, camera_name, image_size):
    """Return camera matrices for projection and 3D recovery."""
    world_to_camera = get_camera_transform_matrix(sim, camera_name, image_size, image_size)
    camera_pose = get_camera_extrinsic_matrix(sim, camera_name)  # camera→world extrinsics
    # Pixel-space intrinsics for the square image_size x image_size render.
    cam_K = get_camera_intrinsic_matrix(sim, camera_name, image_size, image_size)
    return world_to_camera, camera_pose, cam_K
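# world_to_camera projects world points into pixel coordinates; camera_pose
# (extrinsics) and cam_K (intrinsics) are the inputs that
# recover_3d_from_direct_keypoint_and_height uses to lift a predicted pixel +
# height back to a 3D world position.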

IMAGE_SIZE = 448           # square env render size (also the pixel space of the PARA 2D predictions)
SVD_H, SVD_W = 320, 576    # resolution fed to the video diffusion pipeline
DEVICE = "cuda"
N_WINDOW = 7               # frames generated (and actions decoded) per replan


def preprocess_obs(rgb_obs, image_size):
    """Process raw env observation to model input."""
    rgb = np.flipud(rgb_obs).copy().astype(np.float32) / 255.0
    if rgb.shape[0] != image_size:
        rgb = cv2.resize(rgb, (image_size, image_size), interpolation=cv2.INTER_LINEAR)
    # ImageNet normalize
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    rgb_norm = (rgb - mean) / std
    return torch.from_numpy(rgb_norm).permute(2, 0, 1).float().unsqueeze(0), rgb


def extract_actions(vol_logits, para_heads, feat1, feat2, min_h, max_h, min_g, max_g,
                    cam_pose, cam_K, coord_scale):
    """Extract 3D target positions and gripper from PARA predictions."""
    BT, Nh, H, W = vol_logits.shape

    # Argmax over volume to get 2D + height
    pred_2d = torch.zeros(BT, 2, device=vol_logits.device)
    pred_h_bins = torch.zeros(BT, device=vol_logits.device, dtype=torch.long)
    for i in range(BT):
        vol_i = vol_logits[i]  # (Nh, H, W)
        max_over_h, _ = vol_i.max(dim=0)  # (H, W)
        flat_idx = max_over_h.view(-1).argmax()
        py = flat_idx // W
        px = flat_idx % W
        pred_2d[i, 0] = px.float()
        pred_2d[i, 1] = py.float()
        pred_h_bins[i] = vol_i[:, py, px].argmax()
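    # Note: this two-stage decode (max over height, argmax over the 2D map,
    # then argmax over height at that pixel) selects the same voxel as a joint
    # argmax over the full (Nh, H, W) volume, ties aside.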

    bin_centers = torch.linspace(0, 1, Nh, device=vol_logits.device)
    pred_height = bin_centers[pred_h_bins] * (max_h - min_h) + min_h
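    # i.e., height bin b in 0..Nh-1 decodes linearly to
    # min_h + b / (Nh - 1) * (max_h - min_h).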

    # Query the gripper head at the predicted pixels
    query = pred_2d.clone()
    _, grip_logits, _ = para_heads(feat1, feat2, query_pixels=query)

    # Decode gripper
    if grip_logits is not None:
        grip_bins = grip_logits.argmax(dim=1)
        grip_centers = torch.linspace(0, 1, N_GRIPPER_BINS, device=vol_logits.device)
        gripper_values = grip_centers[grip_bins] * (max_g - min_g) + min_g
    else:
        gripper_values = torch.zeros(BT, device=vol_logits.device)

    # Convert to 3D positions
    pred_2d_img = pred_2d / coord_scale  # back to IMAGE_SIZE space
    targets_3d = []
    for i in range(BT):
        kp = pred_2d_img[i].cpu().numpy()
        h = pred_height[i].item()
        pos_3d = recover_3d_from_direct_keypoint_and_height(
            kp, h, cam_pose, cam_K
        )
        targets_3d.append(pos_3d)

    return np.array(targets_3d), gripper_values.cpu().numpy(), pred_2d_img.cpu().numpy()


def make_heatmap_overlay(rgb_frame, vol_logits_t, pred_px, pred_py, gt_text="", size=448):
    """Create heatmap overlay on a frame."""
    vol_probs = F.softmax(vol_logits_t.reshape(-1), dim=0).view(N_HEIGHT_BINS, PARA_OUT_SIZE, PARA_OUT_SIZE)
    heat = vol_probs.max(dim=0)[0].cpu().numpy()
    heat_norm = (heat - heat.min()) / (heat.max() - heat.min() + 1e-8)
    heat_up = cv2.resize(heat_norm, (size, size))
    heat_color = cv2.applyColorMap((heat_up * 255).astype(np.uint8), cv2.COLORMAP_JET)
    heat_color = cv2.cvtColor(heat_color, cv2.COLOR_BGR2RGB)

    frame_resized = cv2.resize(rgb_frame, (size, size)) if rgb_frame.shape[0] != size else rgb_frame.copy()
    overlay = (frame_resized * 0.5 + heat_color * 0.5).astype(np.uint8)

    cv2.circle(overlay, (int(pred_px), int(pred_py)), 8, (255, 0, 0), 3)
    if gt_text:
        cv2.putText(overlay, gt_text, (5, 25), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)

    return overlay


def main():
    p = argparse.ArgumentParser()
    p.add_argument("--checkpoint", type=str,
                   default="output_svd_para_joint/checkpoint-2000")
    p.add_argument("--svd_base", type=str,
                   default="checkpoints/stable-video-diffusion-img2vid-xt-1-1")
    p.add_argument("--benchmark", type=str, default="libero_spatial")
    p.add_argument("--task_id", type=int, default=0)
    p.add_argument("--n_episodes", type=int, default=5)
    p.add_argument("--max_steps", type=int, default=600)
    p.add_argument("--out_dir", type=str, default="eval_joint_output")
    p.add_argument("--camera", type=str, default="agentview")
    p.add_argument("--clean_scene", action="store_true", default=True,
                   help="Remove distractors and furniture (match OOD training data)")
    p.add_argument("--shift_dx", type=float, default=0.0)
    p.add_argument("--shift_dy", type=float, default=0.0)
    args = p.parse_args()

    os.makedirs(args.out_dir, exist_ok=True)
    device = torch.device(DEVICE)
    coord_scale = PARA_OUT_SIZE / IMAGE_SIZE
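    # coord_scale converts pixel coordinates from IMAGE_SIZE space into the
    # PARA_OUT_SIZE grid that the volume/gripper heads predict over.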

    # --- Load model ---
    print("Loading model...")
    ckpt_dir = Path(args.checkpoint)
    unet = UNetSpatioTemporalConditionModel.from_pretrained(
        str(ckpt_dir / "unet"), torch_dtype=torch.float16).to(device)
    unet.eval()

    para_ckpt = torch.load(ckpt_dir / "para_checkpoint.pt", map_location=device)
    para_heads = ParaHeadsOnUNet().to(device)
    para_heads.load_state_dict(para_ckpt["para_heads"])
    para_heads.eval()

    stats = para_ckpt["stats"]
    min_h, max_h = stats["min_height"], stats["max_height"]
    min_g, max_g = stats["min_gripper"], stats["max_gripper"]

    # Feature capture hooks
    captured = {}
    def hook(name):
        def fn(mod, inp, out):
            captured[name] = (out[0] if isinstance(out, tuple) else out).detach().float()
        return fn
    unet.up_blocks[1].register_forward_hook(hook("ub1"))
    unet.up_blocks[2].register_forward_hook(hook("ub2"))

    # Load other SVD components for video generation
    vae = AutoencoderKLTemporalDecoder.from_pretrained(
        args.svd_base, subfolder="vae", torch_dtype=torch.float16).to(device)
    vae.eval()
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(
        args.svd_base, subfolder="image_encoder", torch_dtype=torch.float16).to(device)
    image_encoder.eval()
    feature_extractor_clip = CLIPImageProcessor.from_pretrained(
        args.svd_base, subfolder="feature_extractor")

    # Video generation pipeline (reuse the components loaded above instead of
    # loading a second copy of the weights)
    pipe = StableVideoDiffusionPipeline.from_pretrained(
        args.svd_base, unet=unet, vae=vae, image_encoder=image_encoder,
        feature_extractor=feature_extractor_clip,
        torch_dtype=torch.float16, variant="fp16")
    pipe.to(device)
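    # Generating frames through pipe() runs the fine-tuned UNet, so it also
    # fires the feature hooks registered above.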

    print(f"  GPU memory: {torch.cuda.memory_allocated()/1e9:.2f} GB")

    # --- LIBERO env setup ---
    bench = bm.get_benchmark_dict()[args.benchmark]()
    task = bench.get_task(args.task_id)
    task_name = task.name
    print(f"Task: {task_name}")

    import h5py
    demo_path = os.path.join(get_libero_path("datasets"), bench.get_task_demonstration(args.task_id))
    with h5py.File(demo_path, "r") as f:
        demo_keys = sorted([k for k in f["data"].keys() if k.startswith("demo_")])
        init_states = [f[f"data/{k}/states"][0] for k in demo_keys]

    bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)
    env = OffScreenRenderEnv(
        bddl_file_name=bddl_file,
        camera_heights=IMAGE_SIZE, camera_widths=IMAGE_SIZE,
        camera_names=[args.camera],
    )
    env.seed(0)

    def apply_clean_scene(env):
        """Remove distractors and furniture to match OOD training data."""
        sim = env.env.sim
        # Teleport furniture bodies far below the floor so they are out of view.
        for fname in ["wooden_cabinet_1_main", "flat_stove_1_main"]:
            try:
                bid = sim.model.body_name2id(fname)
                sim.model.body_pos[bid] = np.array([0, 0, -5.0])
            except Exception:
                pass
        sim.forward()
        # Make the distractor objects' geoms fully transparent.
        dist_bodies = set()
        for dn in ["akita_black_bowl_2_main", "cookies_1_main", "glazed_rim_porcelain_ramekin_1_main"]:
            try:
                dist_bodies.add(sim.model.body_name2id(dn))
            except Exception:
                pass
        for gid in range(sim.model.ngeom):
            if sim.model.geom_bodyid[gid] in dist_bodies:
                sim.model.geom_rgba[gid][3] = 0.0

    n_eps = min(args.n_episodes, len(init_states))
    results = []

    for ep in range(n_eps):
        print(f"\n--- Episode {ep+1}/{n_eps} ---")
        env.reset()

        # Apply clean scene after reset
        if args.clean_scene:
            apply_clean_scene(env)

        init_state = init_states[ep].copy()

        # Shift objects if requested. The indices below are hard-coded offsets
        # into this task's flattened MuJoCo sim state (the +1 appears to skip
        # the leading time entry of the state vector), so they only apply to
        # the default object layout of this benchmark task.
        if args.shift_dx != 0 or args.shift_dy != 0:
            for qps in [9, 37]:          # objects translated by (dx, dy)
                si = qps + 1
                init_state[si] += args.shift_dx
                init_state[si + 1] += args.shift_dy
            for qps in [16, 23, 30]:     # objects parked far outside the workspace
                si = qps + 1
                init_state[si:si + 3] = [10.0, 10.0, 0.9]

        obs = env.set_init_state(init_state)

        # Let sim settle
        for _ in range(5):
            obs, _, _, _ = env.step(np.zeros(7, dtype=np.float32))
        success = False
        step_idx = 0
        current_gripper_cmd = 0.0
        frames = []
        replan_idx = 0

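        # Control loop: each "replan" generates N_WINDOW future frames with the
        # video model, decodes a 3D waypoint and gripper command per frame from
        # the PARA heads, then servos the end-effector through those waypoints
        # before observing again and replanning.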
        while step_idx < args.max_steps and not success:
            # Get observation
            rgb_obs = obs[f"{args.camera}_image"]
            img_tensor, rgb_float = preprocess_obs(rgb_obs, IMAGE_SIZE)
            img_tensor = img_tensor.to(device)

            # Camera matrices (proper intrinsics + extrinsics)
            world_to_cam, cam_pose, cam_K = get_camera_params(
                env.sim, args.camera, IMAGE_SIZE)

            # Generate video + extract features
            rgb_pil = Image.fromarray((rgb_float * 255).astype(np.uint8)).resize((SVD_W, SVD_H))

            captured.clear()
            with torch.inference_mode():
                gen_frames_pil = pipe(rgb_pil, height=SVD_H, width=SVD_W,
                                    num_frames=N_WINDOW, decode_chunk_size=4,
                                    num_inference_steps=25).frames[0]

            # Run PARA on captured features (clone to escape inference_mode)
            if "ub1" in captured and "ub2" in captured:
                feat1 = captured["ub1"].clone()
                feat2 = captured["ub2"].clone()
                with torch.no_grad():
                    vol_logits, _, _ = para_heads(feat1, feat2)

                # Extract ALL timesteps' actions
                n_t = min(vol_logits.shape[0], N_WINDOW, len(gen_frames_pil))
                targets_3d, gripper_vals, pred_2d_img = extract_actions(
                    vol_logits[:n_t], para_heads, feat1[:n_t].clone(), feat2[:n_t].clone(),
                    min_h, max_h, min_g, max_g, cam_pose, cam_K, coord_scale
                )

                # Interleaved: execute action t, then visualize GT vs Gen for each t
                for t in range(n_t):
                    target_pos = targets_3d[t]
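                    # Binarize the decoded gripper value; the sign threshold
                    # assumes the training gripper range straddles zero.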
                    new_gripper = 1.0 if gripper_vals[t] > 0 else -1.0

                    # Execute action first
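                    # Position servo: delta / 0.05 normalizes the remaining
                    # error so that |action| = 1 corresponds to roughly a 5 cm
                    # commanded step, assuming the env's default OSC position
                    # controller scaling.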
                    try:
                        for _ in range(25):
                            cur_pos = np.array(obs["robot0_eef_pos"], dtype=np.float64)
                            delta = target_pos - cur_pos
                            if np.linalg.norm(delta) < 0.005:
                                break
                            delta_clipped = np.clip(delta / 0.05, -1.0, 1.0)
                            servo_action = np.zeros(7, dtype=np.float32)
                            servo_action[:3] = delta_clipped
                            servo_action[6] = current_gripper_cmd
                            obs, _, done, _ = env.step(servo_action)
                            step_idx += 1
                            if done:
                                success = True
                                break
                            if step_idx >= args.max_steps:
                                break
                    except ValueError:
                        # env.step() raises once the episode has already
                        # terminated; treat it as task completion.
                        success = True

                    if not success and step_idx < args.max_steps:
                        try:
                            grip_action = np.zeros(7, dtype=np.float32)
                            grip_action[6] = new_gripper
                            obs, _, done, _ = env.step(grip_action)
                            step_idx += 1
                            current_gripper_cmd = new_gripper
                            if done:
                                success = True
                        except ValueError:
                            success = True

                    # Visualize AFTER execution: GT (left) vs Gen (right)
                    gt_rgb = obs[f"{args.camera}_image"]
                    gt_frame = np.flipud(gt_rgb).copy()
                    if gt_frame.shape[0] != IMAGE_SIZE:
                        gt_frame = cv2.resize(gt_frame, (IMAGE_SIZE, IMAGE_SIZE))
                    gen_frame = cv2.resize(np.array(gen_frames_pil[t]), (IMAGE_SIZE, IMAGE_SIZE))
                    pred_px = int(pred_2d_img[t, 0])
                    pred_py = int(pred_2d_img[t, 1])
                    left = make_heatmap_overlay(gt_frame, vol_logits[t], pred_px, pred_py,
                                               f"GT replan={replan_idx} t={t}")
                    right = make_heatmap_overlay(gen_frame, vol_logits[t], pred_px, pred_py,
                                                f"Gen replan={replan_idx} t={t}")
                    combined = np.concatenate([left, right], axis=1)
                    frames.append(combined)

                    if success or step_idx >= args.max_steps:
                        break
            else:
                break

            replan_idx += 1
            torch.cuda.empty_cache()

        results.append({"episode": ep, "success": success, "steps": step_idx})
        print(f"  {'SUCCESS' if success else 'FAILED'} in {step_idx} steps")

        # Save episode video (H.264 encoded)
        if frames:
            import subprocess
            tmp_path = f"{args.out_dir}/episode_{ep}_raw.mp4"
            vid_path = f"{args.out_dir}/episode_{ep}.mp4"
            imageio.mimwrite(tmp_path, frames, fps=4, quality=8)
            subprocess.run(["ffmpeg", "-y", "-i", tmp_path, "-c:v", "libx264",
                           "-preset", "ultrafast", "-crf", "23",
                           "-movflags", "+faststart", vid_path],
                          capture_output=True)
            os.remove(tmp_path)
            print(f"  Saved: {vid_path}")

    # Summary
    n_success = sum(r["success"] for r in results)
    print(f"\n{'='*50}")
    print(f"Success rate: {n_success}/{n_eps} ({100*n_success/n_eps:.1f}%)")
    print(f"Results: {results}")

    with open(f"{args.out_dir}/results.json", "w") as f:
        json.dump(results, f, indent=2)

    env.close()


if __name__ == "__main__":
    main()
