"""Visualize trajectory predictions on the real dataset (2D: GT vs predicted trajectory).

Uses the volume model: volume_logits + gripper_logits; extracts pred 2D/height/gripper
and plots GT trajectory vs predicted trajectory for a given episode and start frame.
"""

import argparse
import os
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch

sys.path.insert(0, os.path.dirname(__file__))

from data import RealTrajectoryDataset, N_WINDOW
from model import TrajectoryHeatmapPredictor, N_HEIGHT_BINS, N_GRIPPER_BINS
import model as model_module

# Default checkpoint: volume_dino_tracks (volume + per-pixel gripper heads)
IMAGE_SIZE = 448  # square input resolution (px) expected by the model and dataset
_SCRIPT_DIR = Path(__file__).resolve().parent
# Default: the "best" checkpoint saved next to this script during training.
DEFAULT_CHECKPOINT_PATH = str(_SCRIPT_DIR / "checkpoints" / "best.pth")


def _denorm_rgb(rgb_bchw: torch.Tensor) -> np.ndarray:
    """(1,3,H,W) normalized -> (H,W,3) float [0,1]."""
    mean = torch.tensor([0.485, 0.456, 0.406], device=rgb_bchw.device).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=rgb_bchw.device).view(1, 3, 1, 1)
    rgb_denorm = (rgb_bchw * std + mean).detach().cpu().numpy()[0]
    return np.clip(rgb_denorm.transpose(1, 2, 0), 0, 1)


def extract_pred_2d_and_height_from_volume(volume_logits):
    """Decode a (B, N_WINDOW, N_HEIGHT_BINS, H, W) volume into trajectory predictions.

    For each timestep: pick the best pixel (argmax over the spatial plane after
    max-pooling heights), then the best height bin at that pixel, mapped to a
    continuous height in [MIN_HEIGHT, MAX_HEIGHT].
    Returns (pred_2d (B, N, 2) as (x, y), pred_height (B, N)).
    """
    B, N, Nh, H, W = volume_logits.shape
    device = volume_logits.device
    pred_2d = torch.zeros(B, N, 2, device=device, dtype=torch.float32)
    pred_height_bins = torch.zeros(B, N, device=device, dtype=torch.long)
    batch_range = torch.arange(B, device=device)
    for step in range(N):
        step_vol = volume_logits[:, step]  # (B, Nh, H, W)
        # Collapse heights with a max so the spatial argmax sees the best bin per pixel.
        plane = step_vol.max(dim=1).values
        best_flat = plane.reshape(B, -1).argmax(dim=1)
        ys = best_flat // W
        xs = best_flat % W
        pred_2d[:, step, 0] = xs.float()
        pred_2d[:, step, 1] = ys.float()
        # Height bin with the highest logit at the chosen pixel.
        pred_height_bins[:, step] = step_vol[batch_range, :, ys, xs].argmax(dim=1)
    # Map bin index -> normalized [0, 1] -> metric height.
    centers = torch.linspace(0.0, 1.0, N_HEIGHT_BINS, device=device)
    lo = model_module.MIN_HEIGHT
    hi = model_module.MAX_HEIGHT
    pred_height = centers[pred_height_bins] * (hi - lo) + lo
    return pred_2d, pred_height


def extract_gripper_logits_at_pixels(gripper_logits, pixel_2d):
    """Gather per-pixel gripper logits at the given trajectory pixels.

    gripper_logits: (B, N, Ng, H, W); pixel_2d: (B, N, 2) in (x, y) order.
    Coordinates are rounded down and clamped into the image bounds.
    Returns logits of shape (B, N, Ng).
    """
    B, N, Ng, H, W = gripper_logits.shape
    dev = gripper_logits.device
    xs = pixel_2d[..., 0].long().clamp(0, W - 1)
    ys = pixel_2d[..., 1].long().clamp(0, H - 1)
    b_idx = torch.arange(B, device=dev).unsqueeze(1).expand(B, N)
    t_idx = torch.arange(N, device=dev).unsqueeze(0).expand(B, N)
    # Advanced indexing broadcasts (B, N) index tensors; the ':' keeps the Ng axis.
    return gripper_logits[b_idx, t_idx, :, ys, xs]


def decode_gripper_bins(bin_logits):
    """Argmax-decode gripper bin logits (..., N_GRIPPER_BINS) to continuous values.

    Each bin maps to an evenly spaced center in [0, 1], then scales to the
    [MIN_GRIPPER, MAX_GRIPPER] range loaded from the checkpoint.
    """
    lo = model_module.MIN_GRIPPER
    hi = model_module.MAX_GRIPPER
    centers = torch.linspace(0.0, 1.0, N_GRIPPER_BINS, device=bin_logits.device)
    unit = centers[bin_logits.argmax(dim=-1)]
    return unit * (hi - lo) + lo


def run_volume_model(model, rgb, start_keypoint_2d, device):
    """Run the volume model and decode its outputs.

    Returns (pred_2d, pred_height, pred_gripper, volume_logits); the raw volume
    logits are returned so callers can render heatmaps. `device` is kept for
    interface parity; inputs are assumed to already live on the model's device.
    """
    with torch.no_grad():
        volume_logits, gripper_logits = model(
            rgb,
            gt_target_heatmap=None,
            training=False,
            start_keypoint_2d=start_keypoint_2d,
            current_height=None,
            current_gripper=None,
        )
    # Decode pixel trajectory + height from the volume, then sample gripper
    # logits along the predicted pixels and decode them to continuous values.
    pred_2d, pred_height = extract_pred_2d_and_height_from_volume(volume_logits)
    logits_along_traj = extract_gripper_logits_at_pixels(gripper_logits, pred_2d)
    pred_gripper = decode_gripper_bins(logits_along_traj)
    return pred_2d, pred_height, pred_gripper, volume_logits


MAX_EPISODES_PER_ROW = 5


def visualize_gt_vs_pred_2d(model, dataset, indices, title, save_path=None):
    """Plot GT vs predicted 2D trajectories, one episode per column.

    Layout: up to MAX_EPISODES_PER_ROW episodes per block; each block is two
    rows (top = GT, bottom = Pred) and blocks stack vertically. Saves to
    `save_path` when given, otherwise shows the figure interactively.
    """
    device = next(model.parameters()).device
    model.eval()

    n_samples = len(indices)
    n_cols = min(MAX_EPISODES_PER_ROW, n_samples)
    n_row_blocks = (n_samples + n_cols - 1) // n_cols  # ceil(n_samples / n_cols)
    n_rows = n_row_blocks * 2  # each block = one GT row + one Pred row

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows))
    # plt.subplots squeezes singleton dimensions (bare Axes or 1-D array);
    # normalize to a 2-D (n_rows, n_cols) grid so indexing below is uniform.
    axes = np.asarray(axes).reshape(n_rows, n_cols)

    for i, idx in enumerate(indices):
        sample = dataset[idx]
        rgb = sample["rgb"].unsqueeze(0).to(device)
        trajectory_2d = sample["trajectory_2d"].cpu().numpy()  # (N_WINDOW, 2)
        start_keypoint_2d = torch.tensor(trajectory_2d[0], device=device, dtype=torch.float32)

        pred_2d, _, _, _ = run_volume_model(model, rgb, start_keypoint_2d, device)
        pred_2d_np = pred_2d[0].cpu().numpy()  # (N_WINDOW, 2)
        rgb_vis = _denorm_rgb(rgb)

        block = i // n_cols
        col = i % n_cols
        row_gt = block * 2
        row_pred = block * 2 + 1

        ax_gt = axes[row_gt, col]
        ax_gt.imshow(rgb_vis)
        ax_gt.plot(trajectory_2d[:, 0], trajectory_2d[:, 1], "w-", linewidth=2, alpha=0.9, label="GT")
        ax_gt.scatter(trajectory_2d[:, 0], trajectory_2d[:, 1], c="white", s=60, marker="o", edgecolors="black", linewidths=1.5, zorder=10)
        ax_gt.set_title(f"Ep {i}", fontsize=11, fontweight="bold")
        ax_gt.axis("off")
        if col == 0:
            # BUG FIX: set_ylabel is hidden by axis("off"); draw the row label
            # as figure text anchored to the axes instead.
            ax_gt.text(-0.04, 0.5, "GT", transform=ax_gt.transAxes, rotation=90,
                       va="center", ha="right", fontsize=10, fontweight="bold")
        ax_gt.legend(loc="upper right", fontsize=8)

        ax_pred = axes[row_pred, col]
        ax_pred.imshow(rgb_vis)
        ax_pred.plot(pred_2d_np[:, 0], pred_2d_np[:, 1], color="lime", linewidth=2, alpha=0.9, label="Pred")
        ax_pred.scatter(pred_2d_np[:, 0], pred_2d_np[:, 1], c="lime", s=60, marker="x", linewidths=2, zorder=10)
        # Mean per-timestep Euclidean pixel error between GT and prediction.
        mean_px_err = float(np.mean(np.linalg.norm(trajectory_2d - pred_2d_np, axis=1)))
        ax_pred.set_title(f"Ep {i} (err: {mean_px_err:.1f}px)", fontsize=11, fontweight="bold")
        ax_pred.axis("off")
        if col == 0:
            ax_pred.text(-0.04, 0.5, "Pred", transform=ax_pred.transAxes, rotation=90,
                         va="center", ha="right", fontsize=10, fontweight="bold")
        ax_pred.legend(loc="upper right", fontsize=8)

    # Hide empty cells in the last block
    for j in range(n_samples, n_cols * n_row_blocks):
        block = j // n_cols
        col = j % n_cols
        axes[block * 2, col].axis("off")
        axes[block * 2 + 1, col].axis("off")

    plt.suptitle(title, fontsize=14, fontweight="bold")
    plt.subplots_adjust(hspace=0.08, wspace=0.04, top=0.94)
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches="tight")
        print(f"Saved: {save_path}")
    else:
        plt.show()
    plt.close()


def compute_metrics(model, dataset, indices=None):
    """Average pixel / height / gripper errors over dataset samples.

    indices: optional subset of sample indices; defaults to the whole dataset.
    Returns a dict with the sample count, mean pixel error (px), mean height
    error (mm), and mean absolute gripper error.
    """
    device = next(model.parameters()).device
    model.eval()
    if indices is None:
        indices = list(range(len(dataset)))

    pixel_sum = 0.0
    height_sum = 0.0
    gripper_sum = 0.0
    count = 0
    with torch.no_grad():
        for sample_idx in indices:
            sample = dataset[sample_idx]
            rgb = sample["rgb"].unsqueeze(0).to(device)
            trajectory_2d = sample["trajectory_2d"]
            trajectory_3d = sample["trajectory_3d"]
            trajectory_gripper = sample["trajectory_gripper"]
            start_keypoint_2d = trajectory_2d[0].to(device)

            pred_2d, pred_height, pred_gripper, _ = run_volume_model(
                model, rgb, start_keypoint_2d, device
            )
            # Per-timestep errors, averaged within the window first.
            per_t_pixel = []
            per_t_height = []
            per_t_gripper = []
            for t in range(N_WINDOW):
                per_t_pixel.append(torch.norm(pred_2d[0, t].cpu() - trajectory_2d[t]).item())
                per_t_height.append(torch.abs(pred_height[0, t].cpu() - trajectory_3d[t, 2]).item())
                per_t_gripper.append(torch.abs(pred_gripper[0, t].cpu() - trajectory_gripper[t]).item())
            pixel_sum += np.mean(per_t_pixel)
            height_sum += np.mean(per_t_height)
            gripper_sum += np.mean(per_t_gripper)
            count += 1

    return {
        "n_samples": count,
        "avg_pixel_error_px": pixel_sum / count if count else 0,
        "avg_height_error_mm": (height_sum / count * 1000) if count else 0,
        "avg_gripper_abs_error": (gripper_sum / count) if count else 0,
    }


def main():
    """CLI entry point: load checkpoint + dataset, print metrics, plot GT vs pred.

    Selects either a single (episode, start_frame) sample or an even sweep of
    `--num_samples` samples, then saves the figure under `--save_dir`.
    """
    parser = argparse.ArgumentParser(description="Visualize GT vs predicted 2D trajectory (volume model)")
    parser.add_argument("--checkpoint", type=str, default=DEFAULT_CHECKPOINT_PATH, help="Path to model checkpoint")
    parser.add_argument("--dataset_root", type=str, default="scratch/parsed_school_cap", help="Root of real dataset")
    parser.add_argument("--episode", type=str, default=None, help="Episode name (e.g. episode_001)")
    parser.add_argument("--start_frame", type=int, default=None, help="Start frame index (e.g. 0). With --episode, show only this sample.")
    parser.add_argument("--max_episodes", type=int, default=1, help="If no episode/start_frame, load first K episodes")
    parser.add_argument("--num_samples", type=int, default=6, help="Number of samples to visualize when not using episode+start_frame")
    parser.add_argument("--save_dir", type=str, default="scratch/real_eval_vis", help="Directory to save figures")
    parser.add_argument("--metrics_only", action="store_true", help="Only print metrics, no plot")
    parser.add_argument("--debug", action="store_true", help="Drop into pdb before viz")
    args = parser.parse_args()

    # Prefer Apple MPS, then CUDA, then CPU.
    device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    print(f"\nLoading model from {args.checkpoint}...")
    model = TrajectoryHeatmapPredictor(target_size=IMAGE_SIZE, n_window=N_WINDOW, freeze_backbone=False)
    checkpoint = torch.load(args.checkpoint, map_location=device)
    # Restore the height/gripper calibration ranges stored with the checkpoint
    # so decoding matches training-time normalization.
    if "min_height" in checkpoint and "max_height" in checkpoint:
        model_module.MIN_HEIGHT = float(checkpoint["min_height"])
        model_module.MAX_HEIGHT = float(checkpoint["max_height"])
        print(f"Height range: [{model_module.MIN_HEIGHT*1000:.2f}, {model_module.MAX_HEIGHT*1000:.2f}] mm")
    if "min_gripper" in checkpoint and "max_gripper" in checkpoint:
        model_module.MIN_GRIPPER = float(checkpoint["min_gripper"])
        model_module.MAX_GRIPPER = float(checkpoint["max_gripper"])
        print(f"Gripper range: [{model_module.MIN_GRIPPER:.3f}, {model_module.MAX_GRIPPER:.3f}]")
    model.load_state_dict(checkpoint["model_state_dict"], strict=True)
    model = model.to(device)
    model.eval()
    print(f"Loaded epoch {checkpoint.get('epoch', '?')}")

    print(f"\nLoading dataset: {args.dataset_root}")
    dataset = RealTrajectoryDataset(
        dataset_root=args.dataset_root,
        image_size=IMAGE_SIZE,
        episode=args.episode,
        max_episodes=None if args.episode else args.max_episodes,
    )
    print(f"Total samples: {len(dataset)}")

    if args.episode and args.start_frame is not None:
        # Locate the single sample matching (episode, start_frame).
        idx = None
        for i, (ep_dir, frame_idx) in enumerate(dataset.samples):
            if ep_dir.name == args.episode and frame_idx == args.start_frame:
                idx = i
                break
        if idx is None:
            raise ValueError(f"No sample with episode={args.episode} and start_frame={args.start_frame}")
        indices = [idx]
        title = f"GT vs Pred 2D — {args.episode} frame {args.start_frame}"
    else:
        # Evenly spaced sweep across the dataset.
        num_samples = min(args.num_samples, len(dataset))
        indices = np.linspace(0, len(dataset) - 1, num_samples, dtype=int).tolist()
        title = "GT vs Pred 2D — sample sweep"

    print("\nComputing metrics...")
    metrics = compute_metrics(model, dataset, indices)
    print(f"Samples: {metrics['n_samples']}")
    print(f"Avg Pixel Error: {metrics['avg_pixel_error_px']:.2f} px")
    print(f"Avg Height Error: {metrics['avg_height_error_mm']:.3f} mm")
    print(f"Avg Gripper Abs Error: {metrics['avg_gripper_abs_error']:.4f}")

    if not args.metrics_only:
        if args.debug:
            import pdb; pdb.set_trace()
        save_dir = Path(args.save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)
        save_path = save_dir / "gt_vs_pred_2d.png"
        if args.episode and args.start_frame is not None:
            save_path = save_dir / f"gt_vs_pred_{args.episode}_frame{args.start_frame}.png"
        # BUG FIX: save_path was computed but never passed, so --save_dir was
        # silently ignored and the figure only displayed interactively.
        visualize_gt_vs_pred_2d(model, dataset, indices, title, save_path=str(save_path))


# Script entry point: run the CLI when executed directly (not on import).
if __name__ == "__main__":
    main()
