"""Train the autoregressive transformer policy (model_autoregressive.py) on LIBERO.

Minimal sibling to train.py — avoids tangling with the existing 7-way model_type
branching. Teacher-forced single-step prediction over a fixed-length history window
(default 8 frames; set with --history_len).

Usage:
  cd /data/cameron/para/libero
  CUDA_VISIBLE_DEVICES=9 \
  DINO_REPO_DIR=/data/cameron/keygrip/dinov3 \
  DINO_WEIGHTS_PATH=/data/cameron/keygrip/dinov3/weights/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth \
  python train_ar.py \
    --cache_root /data/libero/parsed_libero \
    --benchmark libero_spatial --task_id 0 \
    --run_name ar_libero_spatial_t0_h8_v0 \
    --history_len 8 --batch_size 2 --lr 1e-4 --epochs 50
"""
import argparse
import io
import os
import sys
import time
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import wandb

# Make sibling modules (data_ar, model_autoregressive) importable. abspath first so
# that running `python train_ar.py` (where dirname(__file__) == "") still inserts the
# script's real directory rather than a cwd-relative "" entry.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from data_ar import HistoryTrajectoryDataset, target_xy_to_grid_idx, grid_idx_to_pixel
from model_autoregressive import ARTransformerPolicy, IMAGE_SIZE

# Visualization import is heavy (matplotlib) — defer until first use.
_plt = None


def _lazy_plt():
    """Return ``matplotlib.pyplot``, importing it on the first call only.

    The first call selects the non-interactive "Agg" backend before importing
    pyplot. The module is then cached in the module-level ``_plt``, so later
    calls return it immediately.
    """
    global _plt
    if _plt is not None:
        return _plt
    import matplotlib
    matplotlib.use("Agg")  # must be set before pyplot is imported
    import matplotlib.pyplot as pyplot_module
    _plt = pyplot_module
    return _plt


def overlay_pred_gt(img_chw_normalized, history_eef_xy, target_xy, pred_xy, image_size, title):
    """Render a frame with EEF history, GT next point and predicted next point.

    Args:
        img_chw_normalized: (3, H, W) torch tensor, ImageNet-normalized. It is
            denormalized with the ImageNet mean/std and clipped to [0, 1] for display.
        history_eef_xy: (Hh, 2) array of past EEF pixel coordinates (x, y). It is
            passed through ``np.asarray``. May be empty, in which case no history
            or current-EEF marker is drawn.
        target_xy: (2,) array, GT next-step pixel.
        pred_xy: (2,) array, predicted next-step pixel.
        image_size: side length in pixels used for the axis limits.
        title: axes title.

    Returns:
        (h, w, 3) RGB numpy array decoded (via imageio) from the rendered PNG. Its
        size is set by the figure (5x5 in at 100 dpi, tight bbox), not by
        ``image_size``. Any alpha channel is dropped.
    """
    plt = _lazy_plt()
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(3, 1, 1)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(3, 1, 1)
    img = img_chw_normalized.detach().cpu().numpy() * std + mean
    img = np.clip(img.transpose(1, 2, 0), 0, 1)
    history_eef_xy = np.asarray(history_eef_xy)

    fig, ax = plt.subplots(1, 1, figsize=(5, 5), dpi=100)
    buf = io.BytesIO()
    try:
        ax.imshow(img)
        if len(history_eef_xy) > 1:
            ax.plot(history_eef_xy[:, 0], history_eef_xy[:, 1], '-', color="white",
                    linewidth=1.2, alpha=0.7, label="history")
        if len(history_eef_xy) > 0:
            ax.scatter(history_eef_xy[-1, 0], history_eef_xy[-1, 1], s=70, c="white",
                       marker="o", label="current EEF")
        ax.scatter(target_xy[0], target_xy[1], s=110, c="lime", marker="*", label="GT next")
        ax.scatter(pred_xy[0], pred_xy[1], s=110, c="red", marker="x", label="pred next")
        ax.set_xlim(0, image_size)
        ax.set_ylim(image_size, 0)  # inverted y: row 0 at the top, like image coordinates
        ax.set_title(title, fontsize=9)
        ax.legend(fontsize=7, loc="upper right")
        ax.axis("off")
        fig.tight_layout()
        fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0.1)
    finally:
        # Always release the figure: pyplot keeps open figures alive until closed.
        plt.close(fig)
    buf.seek(0)
    import imageio.v2 as imageio
    arr = imageio.imread(buf)
    if arr.ndim == 3 and arr.shape[-1] == 4:
        arr = arr[..., :3]  # drop alpha
    return arr


def log_vis(model, dataset, device, step, n_samples=3, split="val"):
    """Log a horizontal strip of prediction-vs-GT overlays to wandb.

    Picks up to ``n_samples`` evenly spaced, distinct indices from ``dataset``. It
    runs the model on each sample individually (no grad) and logs the
    concatenated overlays as ``vis/{split}_overlay`` at ``step``. The model's
    train/eval mode is restored afterwards, even if rendering raises. Nothing is
    logged for an empty dataset or ``n_samples <= 0``.

    Args:
        model: policy called as ``model(imgs, eef)`` returning grid logits; must
            expose ``grid_size``.
        dataset: indexable dataset whose items are dicts with ``history_imgs``,
            ``history_eef_xy``, ``target_eef_xy``, ``demo_idx`` and ``start_t``.
        device: device the model lives on.
        step: wandb step to log at.
        n_samples: maximum number of examples to render.
        split: label used in tile titles and the wandb key.
    """
    if n_samples <= 0 or len(dataset) == 0:
        return
    # Evenly spaced indices; unique() avoids rendering the same example twice
    # when the dataset has fewer than n_samples items.
    idxs = np.unique(np.linspace(0, len(dataset) - 1, n_samples).astype(int))

    was_training = model.training
    model.eval()
    tiles = []
    try:
        with torch.no_grad():
            for i in idxs:
                sample = dataset[int(i)]
                imgs = sample["history_imgs"].unsqueeze(0).to(device)
                eef = sample["history_eef_xy"].unsqueeze(0).to(device)
                tgt = sample["target_eef_xy"].numpy()
                logits = model(imgs, eef)                                   # (1, grid^2)
                pred_idx = logits.argmax(dim=-1)                            # (1,)
                pred_xy = grid_idx_to_pixel(pred_idx, IMAGE_SIZE, model.grid_size).cpu().numpy()[0]
                tiles.append(overlay_pred_gt(
                    sample["history_imgs"][-1],
                    sample["history_eef_xy"].numpy(),
                    tgt, pred_xy, IMAGE_SIZE,
                    title=f"{split} demo={int(sample['demo_idx'])} t={int(sample['start_t'])}",
                ))
    finally:
        model.train(was_training)

    # Zero-pad every tile to a common (H, W) so they can be concatenated side by side.
    max_h = max(t.shape[0] for t in tiles)
    max_w = max(t.shape[1] for t in tiles)
    pad_tiles = []
    for t in tiles:
        h, w = t.shape[:2]
        pad = np.zeros((max_h, max_w, 3), dtype=t.dtype)
        pad[:h, :w] = t
        pad_tiles.append(pad)
    strip = np.concatenate(pad_tiles, axis=1)
    wandb.log({f"vis/{split}_overlay": wandb.Image(strip)}, step=step)


def _parse_args():
    """Parse command-line options for AR-policy training."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--cache_root", type=str, required=True)
    parser.add_argument("--benchmark", type=str, default="libero_spatial")
    parser.add_argument("--task_id", type=int, default=0)
    parser.add_argument("--max_demos", type=int, default=0, help="0 = all")
    parser.add_argument("--history_len", type=int, default=8)
    parser.add_argument("--frame_stride", type=int, default=1)
    parser.add_argument("--grid_size", type=int, default=56)
    parser.add_argument("--batch_size", type=int, default=2)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument("--max_steps", type=int, default=0, help="Stop after N steps (0 = no cap)")
    parser.add_argument("--val_split", type=float, default=0.05)
    parser.add_argument("--num_workers", type=int, default=8)
    parser.add_argument("--vis_every_steps", type=int, default=200)
    parser.add_argument("--log_scalars_every", type=int, default=10)
    parser.add_argument("--run_name", type=str, default="ar_libero_spatial_t0_h8_v0")
    parser.add_argument("--wandb_project", type=str, default="para_libero")
    parser.add_argument("--wandb_mode", type=str, default="online", choices=["online", "offline", "disabled"])
    # freeze_backbone defaults to True, so --freeze_backbone only makes that explicit;
    # --no_freeze_backbone is the switch that passes freeze_backbone=False to the model.
    parser.add_argument("--freeze_backbone", action="store_true", default=True)
    parser.add_argument("--no_freeze_backbone", action="store_false", dest="freeze_backbone")
    return parser.parse_args()


def _batch_to_device(batch, device):
    """Move one collated batch to ``device``; return ``(imgs, eef, target)``."""
    imgs = batch["history_imgs"].to(device, non_blocking=True)     # (B, H, 3, 448, 448)
    eef = batch["history_eef_xy"].to(device, non_blocking=True)    # (B, H, 2)
    target = batch["target_eef_xy"].to(device, non_blocking=True)  # (B, 2)
    return imgs, eef, target


def _evaluate(model, loader, device, grid_size):
    """Return sample-weighted ``(mean CE loss, mean pixel error)`` over ``loader``.

    Sums per-sample values and divides by the sample count, so a short final
    batch does not skew the averages. Leaves the model in eval mode. Returns
    ``(nan, nan)`` if the loader yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    err_sum = 0.0
    count = 0
    with torch.no_grad():
        for batch in loader:
            imgs, eef, target = _batch_to_device(batch, device)
            tgt_idx = target_xy_to_grid_idx(target, IMAGE_SIZE, grid_size)
            logits = model(imgs, eef)
            loss_sum += F.cross_entropy(logits, tgt_idx, reduction="sum").item()
            pred_xy = grid_idx_to_pixel(logits.argmax(dim=-1), IMAGE_SIZE, grid_size)
            err_sum += (pred_xy - target).norm(dim=-1).sum().item()
            count += target.shape[0]
    if count == 0:
        return float("nan"), float("nan")
    return loss_sum / count, err_sum / count


def main():
    """Train ARTransformerPolicy on a single LIBERO task.

    Each epoch runs teacher-forced training with a grid-cell cross-entropy loss,
    then a validation pass. It writes ``latest.pth`` every epoch and ``best.pth``
    whenever validation loss improves, under ``<script_dir>/checkpoints/<run_name>/``.
    Scalars and overlay visualizations go to wandb.

    Raises:
        ValueError: if the dataset is too small (or ``--val_split`` too large) to
            leave at least one training sample.
    """
    args = _parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    script_dir = Path(__file__).parent
    ckpt_dir = script_dir / "checkpoints" / args.run_name
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    wandb.init(
        project=args.wandb_project,
        name=args.run_name,
        mode=args.wandb_mode,
        config=vars(args),
    )

    print("\nLoading dataset...")
    full = HistoryTrajectoryDataset(
        cache_root=args.cache_root,
        benchmark_name=args.benchmark,
        task_ids=[args.task_id],
        image_size=IMAGE_SIZE,
        history_len=args.history_len,
        frame_stride=args.frame_stride,
        max_demos=args.max_demos,
    )
    n = len(full)
    n_val = max(1, int(n * args.val_split))
    n_tr = n - n_val
    if n_tr < 1:
        # Otherwise random_split fails cryptically (n == 0) or every epoch has no
        # training batches and nothing is ever checkpointed.
        raise ValueError(
            f"Dataset has {n} samples; val_split={args.val_split} leaves no training samples."
        )
    train_ds, val_ds = torch.utils.data.random_split(
        full, [n_tr, n_val], generator=torch.Generator().manual_seed(42)
    )
    print(f"  Train: {n_tr}  Val: {n_val}")

    pin = device.type == "cuda"  # pinned host memory only helps host->GPU copies
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.num_workers, pin_memory=pin)
    val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False,
                            num_workers=args.num_workers, pin_memory=pin)

    print("\nBuilding model...")
    model = ARTransformerPolicy(
        target_size=IMAGE_SIZE,
        history_len=args.history_len,
        grid_size=args.grid_size,
        freeze_backbone=args.freeze_backbone,
    ).to(device)
    n_train = sum(p.numel() for p in model.parameters() if p.requires_grad)
    n_total = sum(p.numel() for p in model.parameters())
    print(f"Trainable params: {n_train:,} / {n_total:,} ({100*n_train/n_total:.2f}%)")
    wandb.config.update({"trainable_params": n_train, "total_params": n_total})

    optimizer = optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=args.lr, weight_decay=1e-4,
    )

    best_val_loss = float("inf")
    global_step = 0
    t_start = time.time()
    for epoch in range(args.epochs):
        model.train()
        pbar = tqdm(train_loader, desc=f"Epoch {epoch} train", leave=False)
        train_loss_sum = 0.0
        train_count = 0
        for batch in pbar:
            imgs, eef, target = _batch_to_device(batch, device)
            tgt_idx = target_xy_to_grid_idx(target, IMAGE_SIZE, args.grid_size)  # (B,)

            logits = model(imgs, eef)                                         # (B, grid^2)
            loss = F.cross_entropy(logits, tgt_idx)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_val = loss.item()  # single host sync per step
            bsz = target.shape[0]
            train_loss_sum += loss_val * bsz  # weight by batch size for the epoch mean
            train_count += bsz
            pbar.set_postfix(loss=f"{loss_val:.4f}")
            global_step += 1
            if global_step % args.log_scalars_every == 0:
                # Distance-based diagnostic: pred pixel vs GT pixel.
                with torch.no_grad():
                    pred_xy = grid_idx_to_pixel(logits.argmax(dim=-1), IMAGE_SIZE, args.grid_size)
                    pix_err = (pred_xy - target).norm(dim=-1).mean().item()
                wandb.log({
                    "train/loss": loss_val,
                    "train/pixel_error": pix_err,
                    "epoch": epoch,
                }, step=global_step)
            if args.vis_every_steps > 0 and global_step % args.vis_every_steps == 0:
                log_vis(model, val_ds, device, global_step, n_samples=3, split="val")
                log_vis(model, train_ds, device, global_step, n_samples=3, split="train")
            if args.max_steps and global_step >= args.max_steps:
                break

        if train_count == 0:
            print("⚠ no training batches this epoch")
            continue

        # ---------- validation ---------- #
        val_loss, val_err = _evaluate(model, val_loader, device, args.grid_size)
        train_loss = train_loss_sum / train_count
        print(f"Epoch {epoch}: train_loss={train_loss:.4f}  val_loss={val_loss:.4f}  val_px_err={val_err:.1f}px")

        wandb.log({
            "epoch_end/train_loss": train_loss,
            "epoch_end/val_loss":   val_loss,
            "epoch_end/val_pixel_error": val_err,
            "epoch": epoch,
            "elapsed_min": (time.time() - t_start) / 60.0,
        }, step=global_step)

        ckpt = {
            "epoch": epoch,
            "global_step": global_step,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "val_loss": val_loss,
            "args": vars(args),
        }
        torch.save(ckpt, ckpt_dir / "latest.pth")
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(ckpt, ckpt_dir / "best.pth")
            print(f"  ✓ best (val_loss={val_loss:.4f})")

        if args.max_steps and global_step >= args.max_steps:
            print(f"Hit max_steps={args.max_steps}; stopping.")
            break

    wandb.finish()
    print("\nDone.")
    print(f"Checkpoints: {ckpt_dir}")


if __name__ == "__main__":
    # Run training only when executed as a script, not when imported.
    main()
