"""Generate a simple teaser image from the real 2D heatmap model.

Outputs a 1x3 grid:
  - Left: RGB input
  - Middle: RGB with the predicted 2D trajectory overlaid
  - Right: Predicted height per timestep (horizontal bar chart)
"""

import argparse
import os
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F

# Add current directory to path for imports
sys.path.insert(0, os.path.dirname(__file__))

from data import RealTrajectoryDataset, N_WINDOW
from model import TrajectoryHeatmapPredictor

# Resolution handed to both the model (``target_size``) and the dataset (``image_size``).
IMAGE_SIZE: int = 448
# Checkpoint used when ``--checkpoint`` is not given; relative paths resolve against the CWD.
DEFAULT_CHECKPOINT_PATH: str = "real_dino_tracks/checkpoints/real_tracks/best.pth"


def _denorm_rgb(rgb_bchw: torch.Tensor) -> np.ndarray:
    """(1,3,H,W) normalized -> (H,W,3) float [0,1]."""
    mean = torch.tensor([0.485, 0.456, 0.406], device=rgb_bchw.device).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=rgb_bchw.device).view(1, 3, 1, 1)
    rgb_denorm = (rgb_bchw * std + mean).detach().cpu().numpy()[0]
    return np.clip(rgb_denorm.transpose(1, 2, 0), 0, 1)


def _heatmap_argmax_xy(logits: torch.Tensor) -> np.ndarray:
    """Return the (x, y) pixel of the highest logit in each map of a (T, H, W) stack.

    Softmax is monotonic, so the argmax of the raw logits equals the argmax of the
    per-timestep probability map; no normalization is needed.

    Args:
        logits: Tensor of shape (T, H, W).

    Returns:
        float32 array of shape (T, 2) holding (x, y) = (column, row) indices.
    """
    n_steps, height, width = logits.shape
    # Explicit H*W (not -1) so an empty T still reshapes cleanly.
    flat_idx = logits.reshape(n_steps, height * width).argmax(dim=1).cpu().numpy()
    rows, cols = np.unravel_index(flat_idx, (height, width))
    return np.stack([cols, rows], axis=1).astype(np.float32)


@torch.no_grad()
def predict_trajectory_xy(
    model: TrajectoryHeatmapPredictor, sample: dict, device: torch.device
) -> tuple[np.ndarray, np.ndarray, np.ndarray, str]:
    """Run the model on one dataset sample and decode its 2D trajectory and heights.

    The model is conditioned on the sample's first ground-truth timestep: its 2D
    keypoint, its height (``trajectory_3d[0, 2]``) and its gripper value.

    Args:
        model: Trained heatmap predictor, already on ``device``.
        sample: Dataset item with tensors under ``"rgb"``, ``"trajectory_2d"``,
            ``"trajectory_3d"`` and ``"trajectory_gripper"``; ``"episode_id"`` is optional.
        device: Device to run inference on.

    Returns:
        ``(rgb_vis, pred_xy, pred_height_m, episode_id)``:
        ``rgb_vis`` is the de-normalized input as (H, W, 3) floats in [0, 1];
        ``pred_xy`` is a (T, 2) float32 array of per-timestep heatmap-argmax (x, y)
        pixel coordinates, where T is the number of maps the model produced;
        ``pred_height_m`` is the model's height output for this sample as float32;
        ``episode_id`` is a string label (``"sample"`` if the item has none).
    """
    rgb = sample["rgb"].unsqueeze(0).to(device)  # (1,3,H,W)
    traj2d_gt = sample["trajectory_2d"].cpu().numpy().astype(np.float32)  # (N,2)
    traj3d = sample["trajectory_3d"].cpu().numpy().astype(np.float32)  # (N,3)
    traj_grip = sample["trajectory_gripper"].cpu().numpy().astype(np.float32)  # (N,)
    episode_id = str(sample.get("episode_id", "sample"))

    start_keypoint_2d = torch.tensor(traj2d_gt[0], device=device)  # (2,)
    current_height = torch.tensor(float(traj3d[0, 2]), dtype=torch.float32, device=device)
    current_gripper = torch.tensor(float(traj_grip[0]), dtype=torch.float32, device=device)

    pred_logits, pred_height, _ = model(
        rgb,
        gt_target_heatmap=None,
        training=False,
        start_keypoint_2d=start_keypoint_2d,
        current_height=current_height,
        current_gripper=current_gripper,
    )  # (1,N,H,W)

    # Decode every map the model returned instead of assuming N_WINDOW, so a
    # mismatched checkpoint is neither silently truncated nor an IndexError.
    pred_xy = _heatmap_argmax_xy(pred_logits[0])

    # NOTE(review): the plotting code treats this as one height per timestep — confirm against the model.
    pred_height_m = pred_height[0].detach().cpu().numpy().astype(np.float32)
    rgb_vis = _denorm_rgb(rgb)
    return rgb_vis, pred_xy, pred_height_m, episode_id


def _load_model(checkpoint_path: str, device: torch.device) -> TrajectoryHeatmapPredictor:
    """Build the predictor and restore its weights and normalization ranges from a checkpoint.

    Args:
        checkpoint_path: Path to a saved dict holding ``"model_state_dict"`` and, optionally,
            ``"min_height"``/``"max_height"``, ``"min_gripper"``/``"max_gripper"`` and ``"epoch"``.
        device: Device the checkpoint is mapped to and the model is moved to.

    Returns:
        The model on ``device`` in eval mode.
    """
    # The module object (not just the class) is needed so its module-level range
    # constants can be overwritten below.
    import model as model_module

    print(f"\nLoading model from {checkpoint_path}...")
    net = TrajectoryHeatmapPredictor(target_size=IMAGE_SIZE, n_window=N_WINDOW, freeze_backbone=False)
    # NOTE: torch.load unpickles the file; only point this at trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Load dataset-wide normalization ranges into the model module.
    # Without this, MIN_HEIGHT/MAX_HEIGHT can remain at their file defaults (often equal),
    # which forces predicted height to be constant.
    if "min_height" in checkpoint and "max_height" in checkpoint:
        model_module.MIN_HEIGHT = float(checkpoint["min_height"])
        model_module.MAX_HEIGHT = float(checkpoint["max_height"])
        print(f"✓ Loaded height range from checkpoint: [{model_module.MIN_HEIGHT:.6f}, {model_module.MAX_HEIGHT:.6f}] m")
    else:
        print("⚠ Checkpoint has no min_height/max_height; keeping model module defaults (predicted height may be constant)")
    if "min_gripper" in checkpoint and "max_gripper" in checkpoint:
        model_module.MIN_GRIPPER = float(checkpoint["min_gripper"])
        model_module.MAX_GRIPPER = float(checkpoint["max_gripper"])
        print(f"✓ Loaded gripper range from checkpoint: [{model_module.MIN_GRIPPER:.3f}, {model_module.MAX_GRIPPER:.3f}]")
    else:
        print("⚠ Checkpoint has no min_gripper/max_gripper; keeping model module defaults")

    net.load_state_dict(checkpoint["model_state_dict"])
    net = net.to(device).eval()
    print(f"✓ Loaded model from epoch {checkpoint.get('epoch', '?')}")
    return net


def _plot_teaser(
    rgb_vis: np.ndarray, pred_xy: np.ndarray, pred_height_m: np.ndarray, title: str
) -> plt.Figure:
    """Draw the 1x3 teaser: RGB input, trajectory overlay, and per-timestep height bars.

    Args:
        rgb_vis: (H, W, 3) image with values in [0, 1].
        pred_xy: (T, 2) array of (x, y) pixel coordinates; row 0 is highlighted.
        pred_height_m: Predicted heights in metres, one per timestep (plotted in mm).
        title: Figure suptitle.

    Returns:
        The figure; the caller is responsible for saving and closing it.
    """
    fig, (ax_rgb, ax_traj, ax_height) = plt.subplots(1, 3, figsize=(15, 5))

    ax_rgb.imshow(rgb_vis)
    ax_rgb.set_title("RGB input", fontsize=12, fontweight="bold")
    ax_rgb.axis("off")

    ax_traj.imshow(rgb_vis)
    ax_traj.plot(pred_xy[:, 0], pred_xy[:, 1], "-", color="lime", linewidth=2, alpha=0.9)
    ax_traj.scatter(pred_xy[:, 0], pred_xy[:, 1], s=18, c="lime", edgecolors="black", linewidths=0.5)
    # The first predicted point is drawn larger and in cyan to mark where the trajectory starts.
    ax_traj.scatter(pred_xy[0, 0], pred_xy[0, 1], s=60, c="cyan", edgecolors="black", linewidths=0.8)
    ax_traj.set_title("Predicted trajectory", fontsize=12, fontweight="bold")
    ax_traj.axis("off")

    # Height bar chart (horizontal): time on the Y axis, t=0 at the top.
    height_mm = np.asarray(pred_height_m).reshape(-1) * 1000.0
    n_steps = height_mm.shape[0]
    t_idx = np.arange(n_steps)
    ax_height.barh(t_idx, height_mm, color="#4C78A8", alpha=0.9)
    ax_height.set_title("Predicted height (mm)", fontsize=12, fontweight="bold")
    ax_height.set_ylabel("t")
    ax_height.set_xlabel("mm")
    ax_height.grid(True, axis="x", alpha=0.25)
    ax_height.set_ylim(-0.6, n_steps - 0.4)
    ax_height.set_yticks(t_idx)
    ax_height.invert_yaxis()
    # Zoom the x-axis onto the predicted range (padding at least 1 mm) so small
    # differences between timesteps stay visible; bars start at 0, so they may be
    # clipped at the left edge. set_xlim rejects NaN/Inf, so skip if nothing is finite.
    finite_mm = height_mm[np.isfinite(height_mm)]
    if finite_mm.size:
        h_min = float(finite_mm.min())
        h_max = float(finite_mm.max())
        pad = max(1.0, 0.1 * (h_max - h_min + 1e-6))
        ax_height.set_xlim(h_min - pad, h_max + pad)

    fig.suptitle(title, fontsize=11)
    fig.tight_layout()
    return fig


def main():
    """Parse CLI arguments, run the model on one sample, and save (optionally show) the teaser."""
    parser = argparse.ArgumentParser(description="Generate teaser image (1 sample, 1 episode)")
    parser.add_argument("--checkpoint", type=str, default=DEFAULT_CHECKPOINT_PATH, help="Path to model checkpoint")
    parser.add_argument("--dataset_root", type=str, default="scratch/parsed_moredata_pickplace_home", help="Dataset root")
    parser.add_argument("--episode", type=str, default=None, help="Optional episode name (e.g. episode_001)")
    parser.add_argument("--sample_idx", type=int, default=0, help="Sample index within the loaded episode(s)")
    parser.add_argument(
        "--save_path",
        type=str,
        default="scratch/real_teaser.png",
        help="Where to save the teaser image",
    )
    parser.add_argument("--dont_show", action="store_true", help="Don't open a window; just save")
    args = parser.parse_args()

    device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model = _load_model(args.checkpoint, device)

    print(f"\nLoading dataset: {args.dataset_root}")
    dataset = RealTrajectoryDataset(
        dataset_root=args.dataset_root,
        image_size=IMAGE_SIZE,
        episode=args.episode,
        # Without an explicit episode, only load the first one to keep startup fast.
        max_episodes=1 if args.episode is None else None,
    )
    if len(dataset) == 0:
        raise RuntimeError("Dataset returned 0 samples.")
    # Out-of-range indices are clamped rather than rejected; report the substitution.
    idx = int(np.clip(args.sample_idx, 0, len(dataset) - 1))
    if idx != args.sample_idx:
        print(f"⚠ sample_idx={args.sample_idx} is out of range [0, {len(dataset) - 1}]; using {idx}")
    sample = dataset[idx]

    rgb_vis, pred_xy, pred_height_m, episode_id = predict_trajectory_xy(model, sample, device)

    fig = _plot_teaser(rgb_vis, pred_xy, pred_height_m, title=f"{episode_id} (idx={idx})")

    save_path = Path(args.save_path)
    save_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(save_path, dpi=200, bbox_inches="tight")
    print(f"✓ Saved teaser image to {save_path}")

    if not args.dont_show:
        plt.show()
    plt.close(fig)


if __name__ == "__main__":
    main()
