"""Debug eval: for each replan, show GT frame paused on left while generated video
plays through the full window on right, with heatmaps + keypoints overlaid.
Then GT advances by executing actions. Quick iteration: 1 episode, 20 steps."""

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "4"

import sys
sys.path.insert(0, "/data/cameron/para_videopolicy")
sys.path.insert(0, os.path.dirname(__file__))

import numpy as np
import torch
import torch.nn.functional as F
import cv2
import imageio
import subprocess
from pathlib import Path
from PIL import Image

from svd.models import UNetSpatioTemporalConditionModel
from svd.pipelines import StableVideoDiffusionPipeline
from diffusers import AutoencoderKLTemporalDecoder
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from train_svd_para_joint import ParaHeadsOnUNet, PARA_OUT_SIZE, N_HEIGHT_BINS, N_GRIPPER_BINS

from libero.libero import benchmark as bm, get_libero_path
from libero.libero.envs import OffScreenRenderEnv
from robosuite.utils.camera_utils import (
    get_camera_transform_matrix, get_camera_extrinsic_matrix, get_camera_intrinsic_matrix)
from utils import recover_3d_from_direct_keypoint_and_height

import h5py

# --- Run configuration -------------------------------------------------------------
IMAGE_SIZE = 448  # agentview render resolution (square) and size of each overlay panel
SVD_H, SVD_W = 320, 576  # resolution the conditioning frame is resized to for the SVD pipeline
DEVICE = "cuda"
N_WINDOW = 7  # frames generated per replan; at most this many actions are executed per replan
CKPT_DIR = "output_svd_para_joint/checkpoint-2000"  # fine-tuned unet/ + para_checkpoint.pt (PARA heads + stats)
SVD_BASE = "checkpoints/stable-video-diffusion-img2vid-xt-1-1"  # base weights for the pipeline parts other than the UNet
BENCHMARK = "libero_spatial"  # LIBERO benchmark suite name
TASK_ID = 0  # task index in BENCHMARK; the first state of its demo_0 is the initial state
MAX_REPLANS = 20  # quick iteration: number of replan cycles (one video generation each)

def get_cam_params(sim, camera, size):
    """Return camera matrices for `camera` rendered at `size` x `size` pixels.

    Returns:
        (w2c, cam_pose, cam_K):
        - w2c: robosuite camera transform matrix for a size x size image,
        - cam_pose: camera extrinsic matrix,
        - cam_K: 3x3 intrinsic matrix in pixel units for a size x size image.
    """
    w2c = get_camera_transform_matrix(sim, camera, size, size)
    cam_pose = get_camera_extrinsic_matrix(sim, camera)
    # Already in pixel units for the requested resolution. Normalizing by `size` and
    # scaling back would be a no-op apart from float rounding, so the matrix is used as-is.
    cam_K = get_camera_intrinsic_matrix(sim, camera, size, size)
    return w2c, cam_pose, cam_K

def preprocess(rgb_obs):
    """Flip an env observation upright, resize it, and ImageNet-normalize it.

    Args:
        rgb_obs: HxWx3 image from the env, assumed uint8 in [0, 255]. It is flipped
            vertically (presumably the renderer's bottom-up row order — confirm).

    Returns:
        (tensor, rgb):
        - tensor: (1, 3, IMAGE_SIZE, IMAGE_SIZE) float32 CPU tensor, normalized with
          ImageNet mean/std;
        - rgb: the upright (IMAGE_SIZE, IMAGE_SIZE, 3) float32 image in [0, 1].
    """
    # One contiguous float32 copy (flipud alone is a negative-stride view).
    rgb = np.ascontiguousarray(np.flipud(rgb_obs), dtype=np.float32) / 255.0
    # Check both spatial dims so non-square inputs are also brought to IMAGE_SIZE.
    if rgb.shape[:2] != (IMAGE_SIZE, IMAGE_SIZE):
        rgb = cv2.resize(rgb, (IMAGE_SIZE, IMAGE_SIZE))
    # float32 stats keep the arithmetic in float32 instead of upcasting to float64.
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    tensor = torch.from_numpy((rgb - mean) / std).permute(2, 0, 1).float().unsqueeze(0)
    return tensor, rgb

def make_overlay(frame_np, vol_logits_t, pred_px, pred_py, label=""):
    """Heatmap overlay with predicted keypoint on a frame.

    The input array is never modified; drawing happens on a copy.

    Args:
        frame_np: HxWx3 RGB image, either uint8 or floating point. Float frames whose
            max is <= 1 are treated as [0, 1] and scaled by 255. Other float frames are
            treated as [0, 255]. The frame is resized to IMAGE_SIZE x IMAGE_SIZE if needed.
        vol_logits_t: logits reshapeable to (N_HEIGHT_BINS, PARA_OUT_SIZE, PARA_OUT_SIZE),
            or None to skip the heatmap. A single softmax over the whole volume is taken.
            The max over height bins is min-max normalized, JET-colored, and 50/50 blended.
        pred_px, pred_py: keypoint in IMAGE_SIZE pixel coordinates. Skipped if pred_px is None.
        label: text drawn in the top-left corner (skipped when empty).

    Returns:
        uint8 RGB image of shape (IMAGE_SIZE, IMAGE_SIZE, 3).
    """
    frame_np = np.asarray(frame_np)
    # Decide by dtype, not by value range: a dark uint8 frame (max 0 or 1) must not be
    # mistaken for a [0, 1] float frame and rescaled.
    if np.issubdtype(frame_np.dtype, np.floating):
        scale = 255.0 if frame_np.max() <= 1.0 else 1.0
        frame_np = np.clip(frame_np * scale, 0, 255).astype(np.uint8)
    # Fresh C-contiguous copy: cv2 draws in place and needs a contiguous buffer.
    frame_np = np.array(frame_np, order="C")
    if frame_np.shape[:2] != (IMAGE_SIZE, IMAGE_SIZE):
        frame_np = cv2.resize(frame_np, (IMAGE_SIZE, IMAGE_SIZE))

    if vol_logits_t is not None:
        # .float(): cv2.resize does not accept float16 arrays.
        vp = F.softmax(vol_logits_t.reshape(-1).float(), dim=0).view(
            N_HEIGHT_BINS, PARA_OUT_SIZE, PARA_OUT_SIZE)
        heat = vp.max(dim=0)[0].cpu().numpy()
        hn = (heat - heat.min()) / (heat.max() - heat.min() + 1e-8)
        hu = cv2.resize(hn, (IMAGE_SIZE, IMAGE_SIZE))
        hc = cv2.applyColorMap((hu * 255).astype(np.uint8), cv2.COLORMAP_JET)
        hc = cv2.cvtColor(hc, cv2.COLOR_BGR2RGB)
        frame_np = (frame_np * 0.5 + hc * 0.5).astype(np.uint8)

    if pred_px is not None:
        # (255, 0, 0) is red because the frame is RGB.
        cv2.circle(frame_np, (int(pred_px), int(pred_py)), 8, (255, 0, 0), 3)

    if label:
        cv2.putText(frame_np, label, (5, 25), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)
    return frame_np


def _lookup_body(sim, name):
    """Return the MuJoCo body id for `name`, or None (after printing a note) if absent."""
    try:
        return sim.model.body_name2id(name)
    except (ValueError, KeyError):
        # Unknown names raise here. The exact exception type depends on the MuJoCo
        # binding in use (presumably ValueError — TODO confirm).
        print(f"  (scene cleanup) body '{name}' not found; skipping")
        return None


def _hide_distractors(sim):
    """Best-effort scene cleanup before the rollout.

    Sets the cabinet and stove body positions to (0, 0, -5), relative to their parent,
    presumably to move them out of view. Then sets alpha = 0 on every geom that belongs
    to a few distractor objects. Bodies missing from the model are skipped.
    """
    for body_name in ("wooden_cabinet_1_main", "flat_stove_1_main"):
        body_id = _lookup_body(sim, body_name)
        if body_id is not None:
            sim.model.body_pos[body_id] = np.array([0, 0, -5.0])
    sim.forward()
    for body_name in ("akita_black_bowl_2_main", "cookies_1_main",
                      "glazed_rim_porcelain_ramekin_1_main"):
        body_id = _lookup_body(sim, body_name)
        if body_id is None:
            continue
        for gid in range(sim.model.ngeom):
            if sim.model.geom_bodyid[gid] == body_id:
                sim.model.geom_rgba[gid][3] = 0.0


def _servo_to(env, obs, target, gripper, max_steps=25, tol=0.005, step_scale=0.05):
    """Drive the end-effector toward `target` with clipped proportional delta actions.

    Each step sends clip((target - eef_pos) / step_scale, -1, 1) on action dims 0-2 and
    `gripper` on dim 6. It stops after `max_steps` steps or once within `tol` of the target.

    Returns:
        (obs, done): the last observation successfully returned by the env, and True if
        the env reported done or `env.step` raised ValueError. The ValueError case is
        treated as the episode having ended (presumably stepping a terminated env).
    """
    for _ in range(max_steps):
        delta = target - np.array(obs["robot0_eef_pos"], dtype=np.float64)
        if np.linalg.norm(delta) < tol:
            break
        action = np.zeros(7, dtype=np.float32)
        action[:3] = np.clip(delta / step_scale, -1.0, 1.0)
        action[6] = gripper
        try:
            obs, _, done, _ = env.step(action)
        except ValueError:
            return obs, True
        if done:
            return obs, True
    return obs, False


def _step_gripper(env, obs, gripper):
    """Send one zero-motion action carrying `gripper` on dim 6.

    Returns (obs, done). A ValueError from `env.step` is treated as the episode having
    ended, and the previous `obs` is returned in that case.
    """
    action = np.zeros(7, dtype=np.float32)
    action[6] = gripper
    try:
        obs, _, done, _ = env.step(action)
    except ValueError:
        return obs, True
    return obs, bool(done)


def _save_video(frames, out_dir, fps=4):
    """Write `frames` to `<out_dir>/debug_rollout.mp4`, re-encoded to H.264 by ffmpeg.

    imageio first writes a raw MP4, and ffmpeg (libx264, +faststart) re-encodes it. The
    raw file is deleted only if the re-encode succeeds. Otherwise it is kept and reported,
    so the rollout is not lost.

    Returns:
        Path of the video that was kept, or None if `frames` is empty.
    """
    if not frames:
        print("\nNo frames recorded; nothing to save.")
        return None
    os.makedirs(out_dir, exist_ok=True)
    tmp_path = os.path.join(out_dir, "debug_rollout_raw.mp4")
    h264_path = os.path.join(out_dir, "debug_rollout.mp4")
    imageio.mimwrite(tmp_path, frames, fps=fps, quality=8)
    cmd = ["ffmpeg", "-y", "-i", tmp_path, "-c:v", "libx264", "-preset", "ultrafast",
           "-crf", "23", "-movflags", "+faststart", h264_path]
    try:
        result = subprocess.run(cmd, capture_output=True)
    except FileNotFoundError:
        print(f"\nffmpeg not found; kept raw video: {tmp_path} ({len(frames)} frames)")
        return tmp_path
    if result.returncode != 0:
        stderr_tail = result.stderr.decode(errors="replace")[-2000:]
        print(f"\nffmpeg failed (exit {result.returncode}); kept raw video: {tmp_path}\n"
              f"{stderr_tail}")
        return tmp_path
    os.remove(tmp_path)
    print(f"\nSaved: {h264_path} ({len(frames)} frames)")
    return h264_path


def main():
    """Run one debug rollout and save a side-by-side GT | generated-video MP4.

    For each of MAX_REPLANS replans:
      1. Generate an N_WINDOW-frame video from the current agentview frame.
      2. Run the PARA heads on UNet up-block features captured by forward hooks.
      3. For each window step t:
         - decode a pixel + height bin (and a gripper bin queried at that pixel);
         - lift the pixel to a 3D end-effector target;
         - servo toward the target, then apply the gripper command;
         - record the post-execution GT frame next to generated frame t.
    The episode stops early when the env reports done.
    """
    device = torch.device(DEVICE)
    coord_scale = PARA_OUT_SIZE / IMAGE_SIZE

    # Load model
    print("Loading model...")
    unet = UNetSpatioTemporalConditionModel.from_pretrained(
        f"{CKPT_DIR}/unet", torch_dtype=torch.float16).to(device)
    unet.eval()

    para_ckpt = torch.load(f"{CKPT_DIR}/para_checkpoint.pt", map_location=device)
    para_heads = ParaHeadsOnUNet().to(device)
    para_heads.load_state_dict(para_ckpt["para_heads"])
    para_heads.eval()

    stats = para_ckpt["stats"]
    min_h, max_h = stats["min_height"], stats["max_height"]
    min_g, max_g = stats["min_gripper"], stats["max_gripper"]

    # Bin centers in normalized [0, 1] space, de-normalized with the checkpoint stats.
    # Loop-invariant, so built once.
    height_centers = torch.linspace(0, 1, N_HEIGHT_BINS, device=device)
    grip_centers = torch.linspace(0, 1, N_GRIPPER_BINS, device=device)

    # Hooks keep the latest output of two UNet up-blocks. After a pipeline call this is
    # the output of the last UNet forward pass (the final denoising step).
    captured = {}

    def hook(name):
        def fn(mod, inp, out):
            captured[name] = (out[0] if isinstance(out, tuple) else out).detach().float()
        return fn

    unet.up_blocks[1].register_forward_hook(hook("ub1"))
    unet.up_blocks[2].register_forward_hook(hook("ub2"))

    # Pipeline
    pipe = StableVideoDiffusionPipeline.from_pretrained(
        SVD_BASE, unet=unet, torch_dtype=torch.float16, variant="fp16")
    pipe.to(device)

    # Env
    bench = bm.get_benchmark_dict()[BENCHMARK]()
    task = bench.get_task(TASK_ID)
    bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)
    demo_path = os.path.join(get_libero_path("datasets"), bench.get_task_demonstration(TASK_ID))

    with h5py.File(demo_path, "r") as f:
        init_state = f["data/demo_0/states"][0]

    env = OffScreenRenderEnv(bddl_file_name=bddl_file,
                             camera_heights=IMAGE_SIZE, camera_widths=IMAGE_SIZE,
                             camera_names=["agentview"])
    all_frames = []
    try:  # make sure the env is closed even if the rollout raises
        env.seed(0)
        env.reset()
        _hide_distractors(env.env.sim)

        obs = env.set_init_state(init_state)
        for _ in range(5):
            obs, _, _, _ = env.step(np.zeros(7, dtype=np.float32))

        print(f"Task: {task.name}")
        print(f"Running {MAX_REPLANS} replan steps...")

        current_gripper = 0.0
        for replan in range(MAX_REPLANS):
            # Only the upright [0, 1] image is needed; the normalized tensor is unused.
            _, rgb_float = preprocess(obs["agentview_image"])
            _, cam_pose, cam_K = get_cam_params(env.sim, "agentview", IMAGE_SIZE)

            # Generate video
            rgb_pil = Image.fromarray((rgb_float * 255).astype(np.uint8)).resize((SVD_W, SVD_H))
            captured.clear()
            with torch.inference_mode():
                gen_pil = pipe(rgb_pil, height=SVD_H, width=SVD_W,
                               num_frames=N_WINDOW, decode_chunk_size=4,
                               num_inference_steps=25).frames[0]

            # clone(): presumably turns the inference-mode tensors captured by the hooks
            # into normal tensors usable outside inference_mode.
            feat1 = captured["ub1"].clone()
            feat2 = captured["ub2"].clone()
            # NOTE(review): under classifier-free guidance the UNet batch holds two copies
            # of every frame. In the diffusers SVD pipeline the order is
            # [unconditional, conditional], so rows [:N_WINDOW] would come from the
            # unconditional pass. Keep only the conditional half. Confirm against
            # svd.pipelines.
            if feat1.shape[0] == 2 * N_WINDOW:
                feat1, feat2 = feat1[N_WINDOW:], feat2[N_WINDOW:]

            with torch.no_grad():
                vol_logits, _, _ = para_heads(feat1, feat2)

            # For each timestep: execute the predicted action, then show GT vs gen frame t.
            n_t = min(vol_logits.shape[0], N_WINDOW, len(gen_pil))
            episode_done = False
            for t in range(n_t):
                vl = vol_logits[t]  # (Nh, P, P)
                # Pixel = argmax of the max-over-height map; height = argmax of that column.
                max_over_h, _ = vl.max(dim=0)
                flat_idx = max_over_h.view(-1).argmax()
                py_para = flat_idx // PARA_OUT_SIZE
                px_para = flat_idx % PARA_OUT_SIZE
                h_bin = vl[:, py_para, px_para].argmax()
                height = height_centers[h_bin].item() * (max_h - min_h) + min_h

                px_img = px_para.item() / coord_scale
                py_img = py_para.item() / coord_scale

                # Gripper: query the heads at the chosen pixel. No autograd graph needed.
                query = torch.tensor([[px_para.item(), py_para.item()]],
                                     device=device, dtype=torch.float32)
                with torch.no_grad():
                    _, grip_logits, _ = para_heads(feat1[t:t+1].clone(), feat2[t:t+1].clone(),
                                                   query_pixels=query)
                grip_val = grip_centers[grip_logits[0].argmax()].item()
                grip_val = grip_val * (max_g - min_g) + min_g

                # 3D target
                pos_3d = recover_3d_from_direct_keypoint_and_height(
                    np.array([px_img, py_img]), height, cam_pose, cam_K)

                # Execute this action FIRST, then visualize.
                new_grip = 1.0 if grip_val > 0 else -1.0
                if pos_3d is not None:
                    # Move with the previous gripper command, then switch the gripper.
                    obs, episode_done = _servo_to(env, obs, pos_3d, current_gripper)
                    if not episode_done:
                        obs, episode_done = _step_gripper(env, obs, new_grip)
                        current_gripper = new_grip

                # Visualization AFTER execution: GT (left) | Gen frame t (right)
                gt_frame = np.flipud(obs["agentview_image"]).copy()
                if gt_frame.shape[0] != IMAGE_SIZE:
                    gt_frame = cv2.resize(gt_frame, (IMAGE_SIZE, IMAGE_SIZE))
                gen_np = cv2.resize(np.array(gen_pil[t]), (IMAGE_SIZE, IMAGE_SIZE))

                left = make_overlay(gt_frame, vl, px_img, py_img,
                                    f"GT replan={replan} t={t}")
                right = make_overlay(gen_np, vl, px_img, py_img,
                                     f"Gen replan={replan} t={t}")
                all_frames.append(np.concatenate([left, right], axis=1))

                if episode_done:
                    break

                print(f"    replan={replan} t={t}: target={pos_3d}, grip={grip_val:.2f}")

            if episode_done:
                print(f"  Episode ended at replan {replan} t={t}!")
                end_frame = np.zeros((IMAGE_SIZE, IMAGE_SIZE * 2, 3), dtype=np.uint8)
                cv2.putText(end_frame, "DONE", (IMAGE_SIZE // 2, IMAGE_SIZE // 2),
                            cv2.FONT_HERSHEY_SIMPLEX, 2.0, (0, 255, 0), 4)
                all_frames.extend([end_frame] * 5)
                break

            print(f"  Replan {replan}: executed {n_t} steps")
            torch.cuda.empty_cache()
    finally:
        env.close()

    # Save video (H.264)
    _save_video(all_frames, "eval_debug_output")


if __name__ == "__main__":
    main()
