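"""Evaluate the SVD + global action regression baseline in the LIBERO simulator.

Each predicted end-effector position is reached with a short position servo
("teleport" mode) using zero rotation, followed by one gripper step.
Same eval protocol as eval_joint.py."""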

import os
# Default to GPU 4, but respect a CUDA_VISIBLE_DEVICES value the caller already
# exported. This has to happen before torch initializes CUDA.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "4")

import sys
# Make project-local modules (e.g. `svd`, `train_svd_global_action_regressor`)
# importable: a hard-coded project root plus this script's own directory,
# resolved to an absolute path so it does not depend on the working directory.
sys.path.insert(0, "/data/cameron/para_videopolicy")
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

import argparse
import json
import numpy as np
import torch
import torch.nn.functional as F
import cv2
import imageio
import subprocess
from pathlib import Path
from PIL import Image

from svd.models import UNetSpatioTemporalConditionModel
from svd.pipelines import StableVideoDiffusionPipeline
from train_svd_global_action_regressor import GlobalActionHead, N_GRIPPER_BINS

from libero.libero import benchmark as bm, get_libero_path
from libero.libero.envs import OffScreenRenderEnv

import h5py

# Side length, in pixels, of the square LIBERO camera renders.
IMAGE_SIZE = 448
# Height and width of the frames generated by SVD.
SVD_H = 320
SVD_W = 576
# Torch device string for the models and tensors.
DEVICE = "cuda"
# Frames generated per SVD call; also caps the actions executed per replan.
N_WINDOW = 7


def preprocess_obs(rgb_obs, image_size):
    """Convert a raw camera image into model-ready inputs.

    The image is flipped vertically, scaled to [0, 1], resized to
    ``image_size`` x ``image_size`` when its height or width differs, and
    normalized with the ImageNet per-channel mean/std.

    Args:
        rgb_obs: HxWx3 image array; assumed uint8 in [0, 255] (TODO confirm
            for callers other than the LIBERO camera observation).
        image_size: side length of the square output.

    Returns:
        (tensor, rgb): ``tensor`` is the normalized image as a float32 torch
        tensor of shape (1, 3, image_size, image_size); ``rgb`` is the flipped
        (and possibly resized) float32 HxWx3 array in [0, 1], unnormalized.
    """
    # One contiguous float32 copy; flipud alone returns a negative-stride view.
    rgb = np.ascontiguousarray(np.flipud(rgb_obs), dtype=np.float32) / 255.0
    # Compare both dimensions so non-square inputs are resized too.
    if rgb.shape[:2] != (image_size, image_size):
        rgb = cv2.resize(rgb, (image_size, image_size))
    # float32 statistics keep the normalization in float32 (no float64 upcast).
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    rgb_norm = (rgb - mean) / std
    return torch.from_numpy(rgb_norm).permute(2, 0, 1).float().unsqueeze(0), rgb


def _apply_clean_scene(env):
    """Declutter the LIBERO scene in place (best effort).

    Sets ``body_pos`` of the wooden cabinet and flat stove bodies to
    (0, 0, -5), calls ``sim.forward()``, then sets the alpha channel of every
    geom belonging to a few distractor bodies to 0. Body names that are not
    present in the current scene are skipped silently.
    """
    sim = env.env.sim
    for fname in ["wooden_cabinet_1_main", "flat_stove_1_main"]:
        try:
            bid = sim.model.body_name2id(fname)
            sim.model.body_pos[bid] = np.array([0, 0, -5.0])
        except Exception:
            # Deliberate best effort: the body does not exist in this task.
            pass
    sim.forward()
    dist_bodies = set()
    for dn in ["akita_black_bowl_2_main", "cookies_1_main", "glazed_rim_porcelain_ramekin_1_main"]:
        try:
            dist_bodies.add(sim.model.body_name2id(dn))
        except Exception:
            # Deliberate best effort: the body does not exist in this task.
            pass
    for gid in range(sim.model.ngeom):
        if sim.model.geom_bodyid[gid] in dist_bodies:
            sim.model.geom_rgba[gid][3] = 0.0


def _compose_frame(gt_rgb, gen_frame_pil, replan_idx, t, target_pos):
    """Build one side-by-side visualization frame.

    Left: the simulator camera image (flipped vertically), labelled with the
    replan/step index and the commanded target position. Right: the
    SVD-generated frame for the same step. Both halves are IMAGE_SIZE square.
    """
    left = np.flipud(gt_rgb).copy()
    if left.shape[:2] != (IMAGE_SIZE, IMAGE_SIZE):
        left = cv2.resize(left, (IMAGE_SIZE, IMAGE_SIZE))
    cv2.putText(left, f"GT replan={replan_idx} t={t}", (5, 25),
                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)
    cv2.putText(left, f"target=({target_pos[0]:.2f},{target_pos[1]:.2f},{target_pos[2]:.2f})",
                (5, 50), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 0), 1)

    # cv2.resize returns a fresh array, so it is safe to draw on directly.
    right = cv2.resize(np.array(gen_frame_pil), (IMAGE_SIZE, IMAGE_SIZE))
    cv2.putText(right, f"Gen replan={replan_idx} t={t}", (5, 25),
                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)

    return np.concatenate([left, right], axis=1)


def _save_video(frames, out_dir, ep):
    """Write ``frames`` to ``<out_dir>/episode_<ep>.mp4`` and return its path.

    Frames are first encoded with imageio to a temporary file, then
    re-encoded to H.264 with ffmpeg. If ffmpeg is not installed or exits with
    an error, the imageio encoding is kept under the final name instead of
    being deleted.
    """
    tmp_path = os.path.join(out_dir, f"episode_{ep}_raw.mp4")
    vid_path = os.path.join(out_dir, f"episode_{ep}.mp4")
    imageio.mimwrite(tmp_path, frames, fps=4, quality=8)
    try:
        proc = subprocess.run(["ffmpeg", "-y", "-i", tmp_path, "-c:v", "libx264",
                               "-preset", "ultrafast", "-crf", "23",
                               "-movflags", "+faststart", vid_path],
                              capture_output=True)
        transcoded = proc.returncode == 0
    except FileNotFoundError:
        # ffmpeg binary is not on PATH.
        transcoded = False
    if transcoded:
        os.remove(tmp_path)
    else:
        # Keep the imageio encoding rather than losing the episode video.
        os.replace(tmp_path, vid_path)
        print("  WARNING: ffmpeg H.264 transcode failed; kept imageio encoding")
    print(f"  Saved: {vid_path}")
    return vid_path


def main():
    """Run the SVD + global action regression policy in LIBERO and report success.

    For each episode, reset the env, declutter the scene and load a demo
    initial state. Then repeat three steps:

    1. Generate an ``N_WINDOW``-frame video from the current camera image
       with SVD.
    2. Regress end-effector positions and gripper bins from UNet up-block
       features captured by forward hooks.
    3. Servo to each predicted position (zero rotation), followed by one
       gripper step.

    A GT-vs-generated video per episode is written to ``--out_dir``, together
    with per-episode results in ``results.json``.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--checkpoint", type=str, required=True)
    p.add_argument("--svd_base", type=str,
                   default="checkpoints/stable-video-diffusion-img2vid-xt-1-1")
    p.add_argument("--benchmark", type=str, default="libero_spatial")
    p.add_argument("--task_id", type=int, default=0)
    p.add_argument("--n_episodes", type=int, default=20)
    p.add_argument("--max_steps", type=int, default=600)
    p.add_argument("--out_dir", type=str, default="eval_global_action_output")
    p.add_argument("--camera", type=str, default="agentview")
    args = p.parse_args()

    os.makedirs(args.out_dir, exist_ok=True)
    device = torch.device(DEVICE)

    # --- Load model ---
    print("Loading model...")
    ckpt_dir = Path(args.checkpoint)
    unet = UNetSpatioTemporalConditionModel.from_pretrained(
        str(ckpt_dir / "unet"), torch_dtype=torch.float16).to(device)
    unet.eval()

    ckpt = torch.load(ckpt_dir / "action_checkpoint.pt", map_location=device)
    action_head = GlobalActionHead().to(device)
    action_head.load_state_dict(ckpt["action_head"])
    action_head.eval()

    # Min/max statistics used to map normalized predictions back to env units.
    stats = ckpt["stats"]
    min_pos = torch.tensor(stats["min_pos"], device=device, dtype=torch.float32)
    max_pos = torch.tensor(stats["max_pos"], device=device, dtype=torch.float32)
    min_g, max_g = stats["min_gripper"], stats["max_gripper"]
    # Gripper bin index -> value in [0, 1]; loop-invariant, so built once.
    grip_centers = torch.linspace(0, 1, N_GRIPPER_BINS, device=device)

    # Feature hooks. Every UNet forward pass overwrites the entry, so after the
    # pipeline returns, `captured` holds the features of the LAST UNet call.
    # NOTE(review): if the pipeline runs classifier-free guidance, the batch
    # dimension likely stacks unconditional and conditional halves; confirm
    # that slicing [:n_t] below selects the same features used in training.
    captured = {}
    def hook(name):
        def fn(mod, inp, out):
            captured[name] = (out[0] if isinstance(out, tuple) else out).detach().float()
        return fn
    unet.up_blocks[1].register_forward_hook(hook("ub1"))
    unet.up_blocks[2].register_forward_hook(hook("ub2"))

    # Pipeline
    pipe = StableVideoDiffusionPipeline.from_pretrained(
        args.svd_base, unet=unet, torch_dtype=torch.float16, variant="fp16")
    pipe.to(device)

    # --- LIBERO env ---
    bench = bm.get_benchmark_dict()[args.benchmark]()
    task = bench.get_task(args.task_id)
    task_name = task.name
    print(f"Task: {task_name}")

    demo_path = os.path.join(get_libero_path("datasets"), bench.get_task_demonstration(args.task_id))
    with h5py.File(demo_path, "r") as f:
        # NOTE(review): this is a lexicographic sort ("demo_10" < "demo_2").
        # It is kept unchanged so the episode -> init-state mapping stays
        # comparable with other evals.
        demo_keys = sorted([k for k in f["data"].keys() if k.startswith("demo_")])
        init_states = [f[f"data/{k}/states"][0] for k in demo_keys]

    bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)
    env = OffScreenRenderEnv(
        bddl_file_name=bddl_file,
        camera_heights=IMAGE_SIZE, camera_widths=IMAGE_SIZE,
        camera_names=[args.camera],
    )
    env.seed(0)

    def rollout(obs):
        """Run one closed-loop episode from ``obs``; return (success, steps, frames)."""
        success = False
        step_idx = 0
        current_gripper_cmd = 0.0
        frames = []
        replan_idx = 0

        while step_idx < args.max_steps and not success:
            rgb_obs = obs[f"{args.camera}_image"]
            _, rgb_float = preprocess_obs(rgb_obs, IMAGE_SIZE)

            # Generate video; the hooks capture UNet features as a side effect.
            rgb_pil = Image.fromarray((rgb_float * 255).astype(np.uint8)).resize((SVD_W, SVD_H))
            captured.clear()
            with torch.inference_mode():
                gen_frames_pil = pipe(rgb_pil, height=SVD_H, width=SVD_W,
                                      num_frames=N_WINDOW, decode_chunk_size=4,
                                      num_inference_steps=25).frames[0]

            if "ub1" not in captured or "ub2" not in captured:
                print("  WARNING: UNet feature hooks did not fire; ending episode")
                break
            feat1 = captured["ub1"].clone()
            feat2 = captured["ub2"].clone()
            n_t = min(feat1.shape[0], N_WINDOW, len(gen_frames_pil))
            if n_t == 0:
                # Without actions the env never steps, so the loop would never end.
                print("  WARNING: no frames/features to act on; ending episode")
                break

            with torch.no_grad():
                pos_pred, grip_logits = action_head(feat1[:n_t], feat2[:n_t])

            # Denormalize predictions.
            positions = (pos_pred * (max_pos - min_pos) + min_pos).cpu().numpy()  # (n_t, 3)
            grip_bins = grip_logits.argmax(dim=1)
            gripper_vals = (grip_centers[grip_bins] * (max_g - min_g) + min_g).cpu().numpy()

            # Interleaved: execute action t, then record a visualization frame.
            for t in range(n_t):
                target_pos = positions[t]
                new_gripper = 1.0 if gripper_vals[t] > 0 else -1.0

                # Servo toward the target: up to 25 clipped delta-position steps
                # with zero rotation, holding the previous gripper command.
                try:
                    for _ in range(25):
                        cur_pos = np.array(obs["robot0_eef_pos"], dtype=np.float64)
                        delta = target_pos - cur_pos
                        if np.linalg.norm(delta) < 0.005:
                            break
                        servo_action = np.zeros(7, dtype=np.float32)
                        # Errors of 0.05 (presumably metres) or more saturate the action.
                        servo_action[:3] = np.clip(delta / 0.05, -1.0, 1.0)
                        servo_action[6] = current_gripper_cmd
                        obs, _, done, _ = env.step(servo_action)
                        step_idx += 1
                        if done:
                            success = True
                            break
                        if step_idx >= args.max_steps:
                            break
                except ValueError:
                    # NOTE(review): any ValueError from env.step counts as success
                    # (original protocol). Confirm it is only raised when stepping
                    # an already-finished episode.
                    success = True

                # Then apply the newly predicted gripper command for one step.
                if not success and step_idx < args.max_steps:
                    try:
                        grip_action = np.zeros(7, dtype=np.float32)
                        grip_action[6] = new_gripper
                        obs, _, done, _ = env.step(grip_action)
                        step_idx += 1
                        current_gripper_cmd = new_gripper
                        if done:
                            success = True
                    except ValueError:
                        success = True

                frames.append(_compose_frame(obs[f"{args.camera}_image"], gen_frames_pil[t],
                                             replan_idx, t, target_pos))

                if success or step_idx >= args.max_steps:
                    break

            replan_idx += 1
            torch.cuda.empty_cache()

        return success, step_idx, frames

    n_eps = min(args.n_episodes, len(init_states))
    if n_eps == 0:
        print("WARNING: no episodes to evaluate (check --n_episodes and the demo file)")
    results = []

    try:
        for ep in range(n_eps):
            print(f"\n--- Episode {ep+1}/{n_eps} ---")
            env.reset()
            _apply_clean_scene(env)

            obs = env.set_init_state(init_states[ep].copy())
            # A few zero actions before the policy starts (presumably to let the sim settle).
            for _ in range(5):
                obs, _, _, _ = env.step(np.zeros(7, dtype=np.float32))

            success, step_idx, frames = rollout(obs)

            results.append({"episode": ep, "success": success, "steps": step_idx})
            print(f"  {'SUCCESS' if success else 'FAILED'} in {step_idx} steps")

            if frames:
                _save_video(frames, args.out_dir, ep)

        # Summary (guard against division by zero when no episodes ran).
        n_success = sum(r["success"] for r in results)
        rate = 100 * n_success / n_eps if n_eps else 0.0
        print(f"\n{'='*50}")
        print(f"Success rate: {n_success}/{n_eps} ({rate:.1f}%)")
        print(f"Results: {results}")

        with open(os.path.join(args.out_dir, "results.json"), "w") as f:
            json.dump(results, f, indent=2)
    finally:
        # Release the simulator even if an episode raises.
        env.close()


if __name__ == "__main__":
    main()
