"""
Generate a PARA method rollout video: animate all N_WINDOW prediction timesteps,
then replay the demo until the robot reaches the predicted 3D waypoints. Repeat
for each cycle.

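For each prediction cycle:
  1. Animate the heatmap on the frustum, revealing one window timestep at a time:
     t=0 → keypoint 0, t=1 → keypoint 1, ..., t=N_WINDOW-1 → keypoint N_WINDOW-1
     (N_WINDOW is inferred from the checkpoint)
  2. Execute: replay demo states up to the frame whose end-effector is nearest the
     last of the N_WINDOW lifted 3D keypoints
  3. Repeat from the new position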

Usage:
    export PYTHONPATH=/data/cameron/LIBERO:/data/cameron/para_normalized_losses/libero:$PYTHONPATH
    export DINO_REPO_DIR=/data/cameron/keygrip/dinov3
    export DINO_WEIGHTS_PATH=/data/cameron/keygrip/dinov3/weights/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth
    python ood_libero/generate_method_rollout.py
"""

import argparse
import os
import sys
import numpy as np
import torch
import torch.nn.functional as F
import cv2
from scipy.spatial.transform import Rotation as R

from libero.libero.envs import OffScreenRenderEnv
from libero.libero import benchmark as bm_lib, get_libero_path
from robosuite.utils.camera_utils import (
    get_camera_transform_matrix,
    get_camera_extrinsic_matrix,
    get_camera_intrinsic_matrix,
    project_points_from_world_to_camera,
)
import h5py

sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'libero'))
import model as model_module
from model import TrajectoryHeatmapPredictor, N_HEIGHT_BINS, PRED_SIZE
from utils import recover_3d_from_direct_keypoint_and_height

from generate_method_visualization import (
    compute_observer_camera, compute_frustum_corners,
    draw_frustum, draw_camera_icon, render_floating_image,
    project_3d_to_2d, add_text, smoothstep,
)

# Side length (px) of the rendered agentview frames, the model input and the output video.
IMAGE_SIZE = 448

# Per-channel RGB normalization statistics (ImageNet), used by preprocess_obs.
IMAGENET_MEAN = np.asarray((0.485, 0.456, 0.406), dtype=np.float32)
IMAGENET_STD = np.asarray((0.229, 0.224, 0.225), dtype=np.float32)

# Output video frame rate.
FPS = 30

# Number of video frames each window timestep is held for during the predict phase
# (8 frames at 30 FPS is roughly 0.27 s per timestep reveal).
FRAMES_PER_WINDOW_STEP = 8


def preprocess_obs(rgb_obs):
    """Convert an HWC uint8-range RGB image into a normalized 1xCxHxW float tensor.

    Pixels are scaled to [0, 1] and then standardized with IMAGENET_MEAN /
    IMAGENET_STD (applied per channel along the last axis).
    """
    scaled = rgb_obs.astype(np.float32) / 255.0
    standardized = (scaled - IMAGENET_MEAN) / IMAGENET_STD
    chw = standardized.transpose(2, 0, 1)
    return torch.from_numpy(chw).float()[None]


def make_heatmap_overlay(rgb, vol_probs, pred_size):
    """Overlay a single timestep's heatmap on an RGB image and mark its peak.

    Args:
        rgb: RGB image of shape (IMAGE_SIZE, IMAGE_SIZE, 3) to blend with.
        vol_probs: torch tensor of probabilities for one timestep. The max over
            dim 0 (presumably the height bins — confirm against the model) gives
            a (pred_size, pred_size) 2D heatmap.
        pred_size: side length of the 2D heatmap grid.

    Returns:
        A uint8 RGB image: 40% ``rgb`` + 60% plasma-colormapped heatmap, with a
        green cross/ring at the heatmap's argmax cell.
    """
    heat_2d = vol_probs.max(dim=0)[0].cpu().numpy()
    heat_up = cv2.resize(heat_2d, (IMAGE_SIZE, IMAGE_SIZE))

    # Min-max normalize to [0, 1]. The range must be (max - min); dividing by max
    # alone after subtracting min does not map the peak to 1.
    heat_min, heat_max = heat_up.min(), heat_up.max()
    heat_norm = (heat_up - heat_min) / (heat_max - heat_min + 1e-8)
    # Gamma < 1 brightens low-probability regions so the distribution stays visible.
    heat_boosted = np.power(heat_norm, 0.4)

    heat_color = cv2.applyColorMap((heat_boosted * 255).astype(np.uint8), cv2.COLORMAP_PLASMA)
    heat_color_rgb = cv2.cvtColor(heat_color, cv2.COLOR_BGR2RGB)
    overlay = np.clip(
        rgb.astype(np.float32) * 0.4 + heat_color_rgb.astype(np.float32) * 0.6,
        0, 255).astype(np.uint8)

    # Mark the predicted pixel: centre of the argmax cell, scaled to full resolution.
    py, px = divmod(int(heat_2d.argmax()), pred_size)
    scale = IMAGE_SIZE / pred_size
    pu, pv = int((px + 0.5) * scale), int((py + 0.5) * scale)
    cv2.drawMarker(overlay, (pu, pv), (0, 255, 0), cv2.MARKER_CROSS, 14, 2, cv2.LINE_AA)
    cv2.circle(overlay, (pu, pv), 6, (0, 255, 0), 1, cv2.LINE_AA)
    return overlay


def draw_trajectory_trail(frame, trail_3d, w2c, image_size, current_idx=-1):
    """Draw an accumulated 3D trajectory onto ``frame`` (in place).

    Points are projected with ``project_3d_to_2d``. Consecutive visible points
    are joined by thin green lines, and every visible point gets a dot. Older
    points fade linearly (4% per step, floored at 25% brightness). The point at
    ``current_idx`` is highlighted; with ``current_idx == -1`` the newest point
    is highlighted. Points that project to ``None`` are skipped. An empty trail
    draws nothing.
    """
    if len(trail_3d) == 0:
        return
    projected = project_3d_to_2d(np.array(trail_3d), w2c, image_size)
    count = len(projected)
    highlight = count - 1 if current_idx == -1 else current_idx

    def fade(idx):
        # Brightness factor: 1.0 for the newest point, dimmer for older ones.
        return max(0.25, 1.0 - (count - 1 - idx) * 0.04)

    # Segments between consecutive points
    for idx, (start, end) in enumerate(zip(projected, projected[1:])):
        if start is None or end is None:
            continue
        a = fade(idx)
        cv2.line(frame, start, end, (int(60 * a), int(200 * a), int(60 * a)), 1, cv2.LINE_AA)

    # Point markers
    for idx, point in enumerate(projected):
        if point is None:
            continue
        if idx == highlight:
            cv2.circle(frame, point, 3, (0, 255, 100), -1, cv2.LINE_AA)
            cv2.circle(frame, point, 5, (0, 255, 100), 1, cv2.LINE_AA)
        else:
            a = fade(idx)
            cv2.circle(frame, point, 2, (int(50 * a), int(160 * a), int(50 * a)), -1, cv2.LINE_AA)


def draw_window_keypoints(frame, window_pts_3d, w2c, image_size, n_revealed):
    """Draw the first ``n_revealed`` keypoints of the current window onto ``frame``.

    The revealed keypoints are drawn as a mini-trajectory: cyan-green segments
    between consecutive visible points, small dots for earlier points, and a
    filled dot with a ring for the point at index ``n_revealed - 1``. Points that
    project to ``None`` are skipped. Nothing is drawn when ``n_revealed < 1``.
    """
    if n_revealed < 1:
        return
    projected = project_3d_to_2d(np.array(window_pts_3d[:n_revealed]), w2c, image_size)
    newest = n_revealed - 1

    for start, end in zip(projected, projected[1:]):
        if start is not None and end is not None:
            cv2.line(frame, start, end, (0, 255, 180), 1, cv2.LINE_AA)

    for idx, point in enumerate(projected):
        if point is None:
            continue
        if idx != newest:
            # Earlier revealed keypoint: small solid dot
            cv2.circle(frame, point, 2, (0, 200, 80), -1, cv2.LINE_AA)
            continue
        # Most recently revealed keypoint: bright dot plus ring
        cv2.circle(frame, point, 3, (0, 255, 100), -1, cv2.LINE_AA)
        cv2.circle(frame, point, 5, (100, 255, 180), 1, cv2.LINE_AA)


def _load_model(checkpoint_path, device):
    """Load a PARA checkpoint onto ``device`` and return ``(model, n_window)``.

    Side effect: overwrites the ``model`` module's MIN/MAX_HEIGHT and
    MIN/MAX_GRIPPER globals (and MIN/MAX_ROT when the checkpoint has a
    ``min_rot`` entry) with the values stored in the checkpoint. For absent keys
    it keeps the module's current values.
    """
    ckpt = torch.load(checkpoint_path, map_location=device)
    model_module.MIN_HEIGHT = float(ckpt.get("min_height", model_module.MIN_HEIGHT))
    model_module.MAX_HEIGHT = float(ckpt.get("max_height", model_module.MAX_HEIGHT))
    model_module.MIN_GRIPPER = float(ckpt.get("min_gripper", model_module.MIN_GRIPPER))
    model_module.MAX_GRIPPER = float(ckpt.get("max_gripper", model_module.MAX_GRIPPER))
    if "min_rot" in ckpt:
        model_module.MIN_ROT = ckpt["min_rot"] if isinstance(ckpt["min_rot"], list) else ckpt["min_rot"].tolist()
        model_module.MAX_ROT = ckpt["max_rot"] if isinstance(ckpt["max_rot"], list) else ckpt["max_rot"].tolist()

    state_dict = ckpt["model_state_dict"]
    # The window length is inferred from the volume head's output channels,
    # which hold N_HEIGHT_BINS channels per predicted timestep.
    n_window = state_dict["volume_head.weight"].shape[0] // N_HEIGHT_BINS
    print(f"  N_WINDOW={n_window}")

    model = TrajectoryHeatmapPredictor(n_window=n_window)
    model.load_state_dict(state_dict, strict=False)
    return model.to(device).eval(), n_window


def _hide_distractors(sim):
    """Declutter the scene for visualization (best effort), then call ``sim.forward()``.

    Cabinet and stove bodies are moved to z=-5. Geoms of a few distractor
    objects are made fully transparent. Bodies absent from this task's scene
    are skipped.
    """
    for name in ("wooden_cabinet_1_main", "flat_stove_1_main"):
        try:
            body_id = sim.model.body_name2id(name)
        except (ValueError, KeyError):
            continue  # body not present in this scene
        sim.model.body_pos[body_id] = [0, 0, -5.0]

    for name in ("akita_black_bowl_2_main", "cookies_1_main",
                 "glazed_rim_porcelain_ramekin_1_main"):
        try:
            body_id = sim.model.body_name2id(name)
        except (ValueError, KeyError):
            continue  # body not present in this scene
        for gid in range(sim.model.ngeom):
            if sim.model.geom_bodyid[gid] == body_id:
                sim.model.geom_rgba[gid][3] = 0.0  # alpha -> invisible
    sim.forward()


def _draw_policy_camera(frame, w2c, policy_cam_pos, frustum_corners,
                        floating_img, image_alpha, frustum_alpha):
    """Draw the policy camera icon and frustum on ``frame`` (in place).

    ``floating_img`` is blended onto the quad spanned by ``frustum_corners``.
    The image is skipped when a corner does not project, or when the projected
    quad is too small (area <= 500 px²).
    """
    fc_2d = project_3d_to_2d(frustum_corners, w2c, IMAGE_SIZE)
    if all(c is not None for c in fc_2d):
        pts = np.array(fc_2d, dtype=np.float32)
        if cv2.contourArea(pts.astype(np.int32)) > 500:
            render_floating_image(frame, floating_img, fc_2d, alpha=image_alpha)
    draw_camera_icon(frame, policy_cam_pos, w2c, IMAGE_SIZE)
    draw_frustum(frame, policy_cam_pos, frustum_corners, w2c, IMAGE_SIZE, frustum_alpha)


def _reencode_h264(path):
    """Re-encode the video at ``path`` in place to H.264 with ffmpeg.

    Returns True on success. On failure (ffmpeg missing or non-zero exit) it
    keeps the original file, removes any partial output and returns False.
    """
    import subprocess

    root, ext = os.path.splitext(path)
    h264_path = f"{root}_h264{ext}"
    cmd = ["ffmpeg", "-y", "-i", path, "-c:v", "libx264", "-preset", "ultrafast",
           "-crf", "23", "-movflags", "+faststart", h264_path]
    try:
        # Argument list + shell=False: paths need no quoting and cannot inject commands.
        result = subprocess.run(cmd, stderr=subprocess.DEVNULL)
    except FileNotFoundError:
        print("ffmpeg not found; keeping mp4v-encoded video")
        return False
    if result.returncode != 0:
        if os.path.exists(h264_path):
            os.remove(h264_path)
        print(f"ffmpeg failed (exit {result.returncode}); keeping mp4v-encoded video")
        return False
    os.replace(h264_path, path)
    print(f"Re-encoded to H.264: {path}")
    return True


def _run_rollout(env, model, n_window, all_states, device, args):
    """Plan predict/execute cycles along the demo and write the rollout video.

    Requires an open, reset env; the caller is responsible for ``env.close()``.
    """
    sim = env.env.sim
    n_demo = len(all_states)

    # ── Clean scene ──
    _hide_distractors(sim)

    # ── Camera setup ──
    cam_name = "agentview"
    cam_id = sim.model.camera_name2id(cam_name)
    policy_cam_pos = sim.model.cam_pos[cam_id].copy()
    policy_cam_quat = sim.model.cam_quat[cam_id].copy()

    env.set_init_state(all_states[0])
    sim.forward()
    camera_pose = get_camera_extrinsic_matrix(sim, cam_name)
    # Pixel-space intrinsics for the IMAGE_SIZE x IMAGE_SIZE policy image.
    cam_K = get_camera_intrinsic_matrix(sim, cam_name, IMAGE_SIZE, IMAGE_SIZE)

    # Observer camera. MuJoCo quaternions are (w, x, y, z); scipy expects (x, y, z, w).
    default_rot = R.from_quat([policy_cam_quat[1], policy_cam_quat[2],
                               policy_cam_quat[3], policy_cam_quat[0]])
    # MuJoCo cameras look along their local -z axis.
    forward_dir = -default_rot.as_matrix()[:, 2]
    # Aim at where the policy camera's optical axis crosses z=0.85
    # (presumably the table surface height).
    t_table = (0.85 - policy_cam_pos[2]) / forward_dir[2]
    scene_center = policy_cam_pos + t_table * forward_dir
    observer_pos, observer_quat = compute_observer_camera(
        policy_cam_pos, policy_cam_quat, scene_center)

    frustum_corners = compute_frustum_corners(
        camera_pose, cam_K, IMAGE_SIZE, args.image_plane_depth)

    min_h, max_h = model_module.MIN_HEIGHT, model_module.MAX_HEIGHT
    scale = IMAGE_SIZE / PRED_SIZE

    # ── Helpers ──
    def render_from(demo_fi, cam_p, cam_q):
        """Render demo state ``demo_fi`` from the given camera pose.

        Returns (bgr image, world-to-pixel matrix, rgb image, eef position).
        """
        sim.model.cam_pos[cam_id] = cam_p
        sim.model.cam_quat[cam_id] = cam_q
        env.set_init_state(all_states[demo_fi])
        sim.forward()
        obs = env.env._get_observations()
        # The rendered image is vertically flipped; flip it upright.
        rgb = np.flipud(np.asarray(obs["agentview_image"]).copy())
        bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
        w2c = get_camera_transform_matrix(sim, cam_name, IMAGE_SIZE, IMAGE_SIZE)
        eef = np.array(obs["robot0_eef_pos"], dtype=np.float64)
        return bgr, w2c, rgb, eef

    def extract_keypoint_3d(vol_logits_t):
        """Lift one timestep's volume logits to a 3D point.

        Takes the argmax pixel of the max-over-height probability map and the
        argmax height bin at that pixel, then converts the bin to a metric
        height in [min_h, max_h]. Returns (point or None, probability volume).
        """
        vol_probs = F.softmax(vol_logits_t.reshape(-1), dim=0).reshape(vol_logits_t.shape)
        heat_2d = vol_probs.max(dim=0)[0].cpu().numpy()
        py, px = divmod(int(heat_2d.argmax()), PRED_SIZE)
        px_full = (px + 0.5) * scale
        py_full = (py + 0.5) * scale
        h_bin = vol_logits_t[:, py, px].argmax().item()
        height = (h_bin / max(N_HEIGHT_BINS - 1, 1)) * (max_h - min_h) + min_h
        pt = recover_3d_from_direct_keypoint_and_height(
            np.array([px_full, py_full], dtype=np.float64),
            height, camera_pose, cam_K)
        return pt, vol_probs

    # ── Compute cycle boundaries dynamically ──
    # Each cycle: predict at the current frame, find the demo frame whose EEF is
    # nearest to the last window keypoint, and execute until that frame, so the
    # robot visibly moves through the predicted keypoints before the next cycle.
    print("  Computing EEF positions for all demo frames...")
    all_eefs = np.array([render_from(di, policy_cam_pos, policy_cam_quat)[3]
                         for di in range(n_demo)])

    print("Running PARA inference and computing cycle boundaries...")
    cycle_data = []   # (heatmaps[n_window], keypoints[n_window], kf_start, kf_end)

    current_frame = 3  # start a few frames in
    while current_frame < n_demo - 5 and len(cycle_data) < args.n_cycles:
        kf = current_frame
        _, policy_w2c, policy_rgb, eef = render_from(kf, policy_cam_pos, policy_cam_quat)

        img_tensor = preprocess_obs(policy_rgb).to(device)
        pix_rc = project_points_from_world_to_camera(
            eef.reshape(1, 3), policy_w2c, IMAGE_SIZE, IMAGE_SIZE)[0]
        # The projection returns (row, col); the model takes (x, y) = (col, row).
        start_kp = torch.tensor([float(pix_rc[1]), float(pix_rc[0])],
                                dtype=torch.float32).to(device)
        with torch.no_grad():
            volume_logits, _, _, _ = model(img_tensor, start_kp)

        heatmaps = []
        keypoints = []
        for t in range(n_window):
            pt_3d, vol_probs_t = extract_keypoint_3d(volume_logits[0, t])
            if pt_3d is None:
                pt_3d = eef.copy()  # lifting failed: fall back to the current EEF
            keypoints.append(pt_3d.copy())
            heatmaps.append(cv2.cvtColor(
                make_heatmap_overlay(policy_rgb, vol_probs_t, PRED_SIZE),
                cv2.COLOR_RGB2BGR))

        # Demo frame nearest to the LAST window keypoint, searching forward only.
        dists = np.linalg.norm(all_eefs[kf:] - keypoints[-1], axis=1)
        nearest_offset = int(np.argmin(dists))
        kf_end = min(kf + max(nearest_offset, 4), n_demo - 1)  # at least 4 frames forward

        print(f"  Cycle {len(cycle_data)}: frames {kf}→{kf_end} "
              f"({kf_end - kf} demo frames), keypoints: "
              + ", ".join(f"[{kp[0]:.2f},{kp[1]:.2f},{kp[2]:.2f}]" for kp in keypoints))
        cycle_data.append((heatmaps, keypoints, kf, kf_end))
        current_frame = kf_end

    n_actual_cycles = len(cycle_data)
    print(f"  Total cycles: {n_actual_cycles}")

    # ── Generate video ──
    frames_predict = n_window * FRAMES_PER_WINDOW_STEP
    frames_per_cycle = frames_predict + args.frames_execute
    total_frames = n_actual_cycles * frames_per_cycle
    print(f"Generating {total_frames} frames "
          f"({n_actual_cycles} cycles × {frames_per_cycle} frames/cycle)...")

    os.makedirs(args.output_dir, exist_ok=True)
    output_path = os.path.join(args.output_dir, "para_method_rollout.mp4")
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    writer = cv2.VideoWriter(output_path, fourcc, FPS, (IMAGE_SIZE, IMAGE_SIZE))
    if not writer.isOpened():
        raise RuntimeError(f"Could not open video writer for {output_path}")

    full_trajectory = []  # predicted keypoints accumulated across all cycles
    frame_count = 0
    try:
        for cycle_idx, (heatmaps, keypoints, kf_start, kf_end) in enumerate(cycle_data):
            # ─── Predict phase: reveal window steps one by one ───
            # The robot has not moved yet, so the background is identical for
            # every predict frame of this cycle: render it once.
            base_bgr, obs_w2c, _, _ = render_from(kf_start, observer_pos, observer_quat)
            for win_t in range(n_window):
                full_trajectory.append(keypoints[win_t])
                frame = base_bgr.copy()
                _draw_policy_camera(frame, obs_w2c, policy_cam_pos, frustum_corners,
                                    heatmaps[win_t], image_alpha=0.7, frustum_alpha=0.5)
                # Trail of all keypoints so far (earlier cycles plus this window's
                # revealed keypoints).
                draw_trajectory_trail(frame, full_trajectory, obs_w2c, IMAGE_SIZE)
                draw_window_keypoints(frame, keypoints, obs_w2c, IMAGE_SIZE, win_t + 1)
                add_text(frame, f"Cycle {cycle_idx+1}/{n_actual_cycles}: "
                         f"predict t={win_t+1}/{n_window}",
                         position="bottom", font_scale=0.5)
                # Nothing changes within a window step: hold the composed frame.
                for _ in range(FRAMES_PER_WINDOW_STEP):
                    writer.write(frame)
                frame_count += FRAMES_PER_WINDOW_STEP

            # ─── Execute phase: replay demo frames, robot moves through keypoints ───
            demo_frames = np.linspace(kf_start, kf_end, args.frames_execute, dtype=int)
            for demo_fi in demo_frames:
                frame, obs_w2c, _, _ = render_from(demo_fi, observer_pos, observer_quat)
                # Dimmed frustum showing the last heatmap while executing
                _draw_policy_camera(frame, obs_w2c, policy_cam_pos, frustum_corners,
                                    heatmaps[-1], image_alpha=0.3, frustum_alpha=0.2)
                draw_trajectory_trail(frame, full_trajectory, obs_w2c, IMAGE_SIZE)
                draw_window_keypoints(frame, keypoints, obs_w2c, IMAGE_SIZE, n_window)
                add_text(frame, f"Cycle {cycle_idx+1}/{n_actual_cycles}: execute",
                         position="bottom", font_scale=0.5)
                writer.write(frame)
                frame_count += 1

            print(f"  Cycle {cycle_idx+1}/{n_actual_cycles} written "
                  f"({frame_count}/{total_frames} frames)")
    finally:
        writer.release()
    print(f"Saved raw video: {output_path}")

    _reencode_h264(output_path)


def main():
    """CLI entry point: load the model, task demo and LIBERO env, then render the rollout."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_dir", type=str,
                        default="/data/cameron/para/.agents/reports/project_site/media/")
    parser.add_argument("--checkpoint", type=str,
                        default="/data/cameron/para_normalized_losses/libero/checkpoints/para_v2_exp4_n64/best.pth")
    parser.add_argument("--device", type=str, default=None)
    parser.add_argument("--n_cycles", type=int, default=8,
                        help="Number of predict-execute cycles")
    parser.add_argument("--demo_idx", type=int, default=0)
    parser.add_argument("--image_plane_depth", type=float, default=0.45)
    parser.add_argument("--frames_execute", type=int, default=14,
                        help="Frames for execution phase per cycle")
    args = parser.parse_args()

    device = torch.device(args.device if args.device else
                          "cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    # ── Load model ──
    print("Loading PARA model...")
    model, n_window = _load_model(args.checkpoint, device)

    # ── Resolve task and load the full demo ──
    benchmark = bm_lib.get_benchmark_dict()["libero_spatial"]()
    task = benchmark.get_task(0)
    bddl_file = os.path.join(get_libero_path("bddl_files"),
                             task.problem_folder, task.bddl_file)
    demo_path = os.path.join(get_libero_path("datasets"),
                             benchmark.get_task_demonstration(0))
    with h5py.File(demo_path, "r") as f:
        all_states = np.array(f[f"data/demo_{args.demo_idx}/states"])
    print(f"  Demo: {len(all_states)} frames, max {args.n_cycles} cycles")

    # ── Initialize environment ──
    print("Initializing LIBERO environment...")
    env = OffScreenRenderEnv(
        bddl_file_name=bddl_file,
        camera_heights=IMAGE_SIZE, camera_widths=IMAGE_SIZE,
        camera_names=["agentview"],
    )
    try:
        env.seed(0)
        env.reset()
        _run_rollout(env, model, n_window, all_states, device, args)
    finally:
        env.close()  # always release the renderer, even on error
    print("Done!")


if __name__ == "__main__":
    main()
