"""Train the v2 two-stage AR policy with multi-target supervision.

For each W-frame window, one DINO call extracts patches; then one ARHead call per valid
target step t in [H, W) (up to K = W - H calls) supervises next-EEF prediction. Net effect:
~10× more gradient per DINO call vs. train_ar.py (v1) (K = 12 at the default W=20, H=8).

Usage:
  cd /data/cameron/para/libero
  CUDA_VISIBLE_DEVICES=9 \
  DINO_REPO_DIR=/data/cameron/keygrip/dinov3 \
  DINO_WEIGHTS_PATH=/data/cameron/keygrip/dinov3/weights/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth \
  python train_ar_v2.py \
    --cache_root /data/libero/parsed_libero --benchmark libero_spatial --task_id 0 \
    --window_len 20 --history_len 8 --batch_size 2 --lr 1e-4 --epochs 10 \
    --run_name ar_v2_libero_spatial_t0_w20h8
"""
import argparse, io, os, sys, time
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import wandb

sys.path.insert(0, os.path.dirname(__file__))

from data_ar import WindowTrajectoryDataset, target_xy_to_grid_idx, grid_idx_to_pixel
from model_autoregressive_v2 import ARTransformerPolicyV2, IMAGE_SIZE, N_HEIGHT_BINS, N_ROT_BINS

# Discretization ranges: pragmatic LIBERO defaults chosen to cover the spatial / goal suites.
# NOTE(review): these are used as-is — this script defines no --compute_stats flag that would
# override them from dataset statistics.
MIN_HEIGHT, MAX_HEIGHT = 0.85, 1.55  # EEF z range mapped onto the height bins
MIN_ROT = [-3.14159] * 3             # per-axis Euler lower bound (≈ -pi)
MAX_ROT = [3.14159] * 3              # per-axis Euler upper bound (≈ +pi)

# Relative weights of the per-target loss terms (rough match to the existing PARA train.py).
W_XY, W_HEIGHT, W_GRIPPER, W_ROT = 1.0, 0.5, 5.0, 0.5


def discretize_height(h, min_h=MIN_HEIGHT, max_h=MAX_HEIGHT, n_bins=N_HEIGHT_BINS):
    """Map a tensor of heights to integer bin indices in ``[0, n_bins - 1]``.

    ``h`` is min-max normalized into [0, 1] (out-of-range values are clamped), scaled by
    ``n_bins - 1`` and truncated, so the top bin is only reached at the upper end of the range.
    Operates elementwise and returns an int64 tensor with the same shape as ``h``.
    """
    span = max_h - min_h + 1e-8  # epsilon guards against a zero-width range
    unit = torch.clamp((h - min_h) / span, 0, 1)
    idx = (unit * (n_bins - 1)).long()
    return torch.clamp(idx, 0, n_bins - 1)


def discretize_rotation(euler, min_r=None, max_r=None, n_bins=N_ROT_BINS):
    """Bin Euler angles independently per axis.

    euler: (..., 3); returns (..., 3) int64 bin indices in ``[0, n_bins - 1]``.
    ``min_r`` / ``max_r`` are per-axis bounds; when omitted they fall back to the module-level
    ``MIN_ROT`` / ``MAX_ROT`` (looked up at call time). Angles outside the bounds are clamped.
    """
    lo = torch.tensor(MIN_ROT if min_r is None else min_r, device=euler.device, dtype=torch.float32)
    hi = torch.tensor(MAX_ROT if max_r is None else max_r, device=euler.device, dtype=torch.float32)
    unit = torch.clamp((euler - lo) / (hi - lo + 1e-8), 0, 1)
    return torch.clamp((unit * (n_bins - 1)).long(), 0, n_bins - 1)

_plt = None  # cached matplotlib.pyplot module, populated on first _lazy_plt() call


def _lazy_plt():
    """Return ``matplotlib.pyplot``, importing it with the headless Agg backend on first use.

    Deferring the import keeps matplotlib off the start-up path. The module object is cached in
    the module-level ``_plt``, so later calls are a plain global lookup.
    """
    global _plt
    if _plt is not None:
        return _plt
    import matplotlib
    matplotlib.use("Agg")  # pick the non-interactive backend before pyplot is imported
    import matplotlib.pyplot as pyplot
    _plt = pyplot
    return _plt


def overlay_pred_gt(img_chw_normalized, hist_eef_xy, target_xy, pred_xy, image_size, title):
    """Render one frame with EEF history, GT target and prediction, and return it as an array.

    img_chw_normalized: channel-first image tensor normalized with ImageNet mean/std; it is
        un-normalized and clipped to [0, 1] for display.
    hist_eef_xy: array of history points indexed as ``[:, 0]`` (x) / ``[:, 1]`` (y); the last
        row is drawn as the "current" point and, if there is more than one row, a polyline is drawn.
    target_xy / pred_xy: (x, y) pairs drawn as a lime star (GT) and a red cross (pred).
    image_size: axis extent; the y axis is flipped so row 0 is at the top.
    Returns the PNG-rendered figure decoded by imageio, with any alpha channel dropped.
    """
    plt = _lazy_plt()
    # Undo the ImageNet normalization so the frame is shown with its natural colours.
    imagenet_mean = np.asarray([0.485, 0.456, 0.406], dtype=np.float32)[:, None, None]
    imagenet_std = np.asarray([0.229, 0.224, 0.225], dtype=np.float32)[:, None, None]
    chw = img_chw_normalized.detach().cpu().numpy() * imagenet_std + imagenet_mean
    hwc = np.clip(np.transpose(chw, (1, 2, 0)), 0, 1)

    fig, ax = plt.subplots(1, 1, figsize=(5, 5), dpi=100)
    ax.imshow(hwc)
    if len(hist_eef_xy) > 1:
        ax.plot(hist_eef_xy[:, 0], hist_eef_xy[:, 1], '-', color="white", linewidth=1.2, alpha=0.7, label="history")
    ax.scatter(hist_eef_xy[-1, 0], hist_eef_xy[-1, 1], s=70, c="white", marker="o", label="current")
    ax.scatter(target_xy[0], target_xy[1], s=110, c="lime", marker="*", label="GT")
    ax.scatter(pred_xy[0], pred_xy[1], s=110, c="red", marker="x", label="pred")
    ax.set_xlim(0, image_size)
    ax.set_ylim(image_size, 0)  # image convention: y grows downward
    ax.set_title(title, fontsize=9)
    ax.legend(fontsize=7, loc="upper right")
    ax.axis("off")

    fig.tight_layout()
    png = io.BytesIO()
    fig.savefig(png, format="png", bbox_inches="tight", pad_inches=0.1)
    plt.close(fig)
    png.seek(0)

    import imageio.v2 as imageio
    rendered = imageio.imread(png)
    has_alpha = rendered.ndim == 3 and rendered.shape[-1] == 4
    return rendered[..., :3] if has_alpha else rendered


def log_vis(model, dataset, device, step, history_len, grid_size, n_samples=3, split="val"):
    """Log a horizontal strip of prediction-vs-GT overlays for a few samples to wandb.

    ``n_samples`` evenly spaced windows are drawn from ``dataset``. For each one, the model
    predicts the EEF position at the last *valid* target step t in [history_len, W) from the
    preceding ``history_len`` frames. The overlay is drawn on frame t-1, the newest history frame.
    Windows with no valid target step are skipped (the training loss masks padded targets the
    same way via ``valid_mask``), and nothing is logged if no sample qualifies.

    The model runs in eval mode under ``torch.no_grad()``. Its previous train/eval mode is
    restored afterwards, even if rendering or logging raises.
    """
    if len(dataset) == 0:
        return
    was_training = model.training
    model.eval()
    try:
        tiles = []
        with torch.no_grad():
            for i in np.linspace(0, len(dataset) - 1, n_samples).astype(int):
                sample = dataset[int(i)]
                W = sample["window_imgs"].shape[0]
                valid = sample["valid_mask"]
                # Visualize the most demanding *real* target: the last valid step in the
                # window, never a padded frame (and never t < history_len).
                valid_steps = [t for t in range(history_len, W) if bool(valid[t])]
                if not valid_steps:
                    continue
                t = valid_steps[-1]

                imgs = sample["window_imgs"].unsqueeze(0).to(device)
                eef = sample["window_eef_xy"].unsqueeze(0).to(device)
                hist_imgs = imgs[:, t - history_len : t]
                hist_eef = eef[:, t - history_len : t]
                target_xy = sample["window_eef_xy"][t].numpy()
                out_d = model(hist_imgs, hist_eef)
                pred_xy = grid_idx_to_pixel(
                    out_d["xy_logits"].argmax(dim=-1), IMAGE_SIZE, grid_size
                ).cpu().numpy()[0]
                tiles.append(overlay_pred_gt(
                    imgs[0, t - 1], hist_eef[0].cpu().numpy(),
                    target_xy, pred_xy, IMAGE_SIZE,
                    title=f"{split} demo={int(sample['demo_idx'])} t={int(sample['start_t']) + t}",
                ))
        if not tiles:
            return
        # Tiles can differ slightly in size (bbox_inches="tight"); zero-pad to a common size.
        max_h = max(tile.shape[0] for tile in tiles)
        max_w = max(tile.shape[1] for tile in tiles)
        padded = []
        for tile in tiles:
            h, w = tile.shape[:2]
            canvas = np.zeros((max_h, max_w, 3), dtype=tile.dtype)
            canvas[:h, :w] = tile
            padded.append(canvas)
        strip = np.concatenate(padded, axis=1)
        wandb.log({f"vis/{split}_overlay": wandb.Image(strip)}, step=step)
    finally:
        model.train(was_training)


def main():
    """Parse CLI args, train the v2 AR policy on one LIBERO task, and checkpoint every epoch.

    Each epoch proceeds as follows:
    - Take one optimizer step per training batch; ``run_window`` accumulates gradients over
      all valid target steps of the batch's windows.
    - Log scalars and overlays to wandb periodically.
    - Run a full validation pass.
    - Write ``latest.pth`` (always) and ``best.pth`` (when val loss improves) to
      ``checkpoints/<run_name>/`` next to this script.

    Training stops early once ``--max_steps`` optimizer steps have been taken (0 disables the
    limit). Invalid argument combinations are reported through ``parser.error``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--cache_root", type=str, required=True)
    parser.add_argument("--benchmark", type=str, default="libero_spatial")
    parser.add_argument("--task_id", type=int, default=0)
    parser.add_argument("--max_demos", type=int, default=0)
    parser.add_argument("--window_len", type=int, default=20)
    parser.add_argument("--history_len", type=int, default=8)
    parser.add_argument("--frame_stride", type=int, default=1)
    parser.add_argument("--grid_size", type=int, default=56)
    parser.add_argument("--batch_size", type=int, default=2)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--max_steps", type=int, default=0)
    parser.add_argument("--val_split", type=float, default=0.05)
    parser.add_argument("--num_workers", type=int, default=8)
    parser.add_argument("--vis_every_steps", type=int, default=200)
    parser.add_argument("--log_scalars_every", type=int, default=10)
    parser.add_argument("--run_name", type=str, default="ar_v2_libero_spatial_t0_w20h8")
    parser.add_argument("--wandb_project", type=str, default="para_libero")
    parser.add_argument("--wandb_mode", type=str, default="online", choices=["online", "offline", "disabled"])
    args = parser.parse_args()
    # Validate via parser.error rather than assert: asserts are stripped under `python -O`.
    if args.history_len < 1:
        parser.error("--history_len must be >= 1 (the last history step is the anchor)")
    if args.window_len <= args.history_len:
        parser.error("W must exceed H for multi-target supervision")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    script_dir = Path(__file__).parent
    ckpt_dir = script_dir / "checkpoints" / args.run_name
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    wandb.init(project=args.wandb_project, name=args.run_name, mode=args.wandb_mode, config=vars(args))

    print("\nLoading dataset...")
    full = WindowTrajectoryDataset(
        cache_root=args.cache_root,
        benchmark_name=args.benchmark,
        task_ids=[args.task_id],
        image_size=IMAGE_SIZE,
        window_len=args.window_len,
        frame_stride=args.frame_stride,
        max_demos=args.max_demos,
    )
    n = len(full)
    if n < 2:
        raise SystemExit(f"Need at least 2 windows for a train/val split, got {n}")
    # Keep at least one window in each split (tiny datasets, or --val_split >= 1).
    n_val = min(max(1, int(n * args.val_split)), n - 1)
    n_tr = n - n_val
    train_ds, val_ds = torch.utils.data.random_split(
        full, [n_tr, n_val], generator=torch.Generator().manual_seed(42)
    )
    print(f"  Train: {n_tr}  Val: {n_val}")

    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.num_workers, pin_memory=True)
    val_loader   = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False,
                              num_workers=args.num_workers, pin_memory=True)

    print("\nBuilding v2 model...")
    model = ARTransformerPolicyV2(
        target_size=IMAGE_SIZE, history_len=args.history_len,
        grid_size=args.grid_size, freeze_backbone=True,
    ).to(device)
    n_train = sum(p.numel() for p in model.parameters() if p.requires_grad)
    n_total = sum(p.numel() for p in model.parameters())
    print(f"Trainable: {n_train:,} / {n_total:,}")
    wandb.config.update({"trainable_params": n_train, "total_params": n_total})

    optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()),
                            lr=args.lr, weight_decay=1e-4)

    H = args.history_len; W = args.window_len; G = args.grid_size
    target_steps = list(range(H, W))  # 12 targets at W=20, H=8

    def run_window(batch, train_mode):
        """One window forward — DINO once, K iterative ARHead calls with per-target backward
        (so only one ARHead graph in memory at a time). Gradients accumulate into the optimizer's
        parameters; caller is responsible for optimizer.step() + zero_grad().

        Returns (mean_loss_value, mean_pix_err, K_used), or (None, None, 0) when no target
        step in the batch is valid.
        """
        imgs    = batch["window_imgs"].to(device, non_blocking=True)            # (B, W, 3, 448, 448)
        eef_xy  = batch["window_eef_xy"].to(device, non_blocking=True)          # (B, W, 2)
        eef_pos = batch["window_eef_pos"].to(device, non_blocking=True)         # (B, W, 3)
        eef_eul = batch["window_eef_euler"].to(device, non_blocking=True)       # (B, W, 3)
        grip    = batch["window_gripper"].to(device, non_blocking=True)         # (B, W)
        valid   = batch["valid_mask"].to(device, non_blocking=True)             # (B, W)

        if train_mode:
            patches = model.patch_encoder(imgs)
        else:
            with torch.no_grad():
                patches = model.patch_encoder(imgs)

        # Only supervise target steps where at least one window in the batch has real data.
        valid_targets = [t for t in target_steps if valid[:, t].any()]
        K = len(valid_targets)
        if K == 0:
            return None, None, 0

        loss_value_sum = 0.0
        pix_err_sum = 0.0
        for i, t in enumerate(valid_targets):
            hist_p = patches[:, t - H : t]
            hist_e = eef_xy[:, t - H : t]
            anchor = hist_e[:, -1]
            target_xy = eef_xy[:, t]
            target_h_bin   = discretize_height(eef_pos[:, t, 2])              # (B,)
            target_g_bin   = (grip[:, t] > 0).float()                           # (B,) BCE target
            target_rot_bin = discretize_rotation(eef_eul[:, t])                 # (B, 3)
            tgt_idx = target_xy_to_grid_idx(target_xy, IMAGE_SIZE, G)
            t_valid = valid[:, t]
            mask_f = t_valid.float()
            denom = mask_f.sum().clamp_min(1.0)

            out = model.ar_head(hist_p, hist_e, anchor, IMAGE_SIZE)
            # Each term is a masked mean over the batch (padded windows contribute 0).
            ce_xy = (F.cross_entropy(out["xy_logits"], tgt_idx, reduction="none") * mask_f).sum() / denom
            ce_h  = (F.cross_entropy(out["height_logits"], target_h_bin, reduction="none") * mask_f).sum() / denom
            bce_g = (F.binary_cross_entropy_with_logits(out["gripper_logit"], target_g_bin, reduction="none") * mask_f).sum() / denom
            # rotation: per-axis CE, averaged across 3 axes
            rot_loss_axes = []
            for axis in range(3):
                rot_loss_axes.append(
                    (F.cross_entropy(out["rotation_logits"][:, axis], target_rot_bin[:, axis], reduction="none") * mask_f).sum() / denom
                )
            ce_r = torch.stack(rot_loss_axes).mean()

            term_loss = (W_XY * ce_xy + W_HEIGHT * ce_h + W_GRIPPER * bce_g + W_ROT * ce_r)
            scaled = term_loss / K

            if train_mode:
                # The shared patch graph (if any) must survive until the last target's backward.
                scaled.backward(retain_graph=(i < K - 1))

            loss_value_sum += term_loss.item()
            with torch.no_grad():
                pred_xy = grid_idx_to_pixel(out["xy_logits"].argmax(dim=-1), IMAGE_SIZE, G)
                pix_err_sum += ((pred_xy - target_xy).norm(dim=-1) * mask_f).sum().item() / denom.item()

        return loss_value_sum / K, pix_err_sum / K, K

    best_val_loss = float("inf"); global_step = 0; t_start = time.time()
    for epoch in range(args.epochs):
        model.train()
        pbar = tqdm(train_loader, desc=f"Epoch {epoch} train", leave=False)
        tr_losses = []
        for batch in pbar:
            optimizer.zero_grad()
            loss_val, pix_err, n_terms = run_window(batch, train_mode=True)
            if loss_val is None: continue
            optimizer.step()
            tr_losses.append(loss_val)
            pbar.set_postfix(loss=f"{loss_val:.4f}", px=f"{pix_err:.1f}", K=n_terms)
            global_step += 1
            if global_step % args.log_scalars_every == 0:
                wandb.log({"train/loss": loss_val, "train/pixel_error": pix_err,
                           "train/n_targets": n_terms, "epoch": epoch}, step=global_step)
            if args.vis_every_steps > 0 and global_step % args.vis_every_steps == 0:
                log_vis(model, val_ds, device, global_step, H, G, n_samples=3, split="val")
                log_vis(model, train_ds, device, global_step, H, G, n_samples=3, split="train")
            if args.max_steps and global_step >= args.max_steps:
                break

        if not tr_losses:
            print(f"⚠ epoch {epoch}: no train batches"); continue

        # validation
        model.eval()
        val_losses = []; val_pix = []
        with torch.no_grad():
            for batch in val_loader:
                l, pe, _ = run_window(batch, train_mode=False)
                if l is None: continue
                val_losses.append(l); val_pix.append(pe)
        tr_loss = float(np.mean(tr_losses)); val_loss = float(np.mean(val_losses)) if val_losses else float("inf")
        val_err = float(np.mean(val_pix)) if val_pix else float("inf")
        print(f"Epoch {epoch}: train_loss={tr_loss:.4f}  val_loss={val_loss:.4f}  val_px_err={val_err:.1f}px")
        wandb.log({"epoch_end/train_loss": tr_loss, "epoch_end/val_loss": val_loss,
                   "epoch_end/val_pixel_error": val_err, "epoch": epoch,
                   "elapsed_min": (time.time() - t_start) / 60.0}, step=global_step)

        ckpt = {"epoch": epoch, "global_step": global_step,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "val_loss": val_loss, "args": vars(args)}
        torch.save(ckpt, ckpt_dir / "latest.pth")
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(ckpt, ckpt_dir / "best.pth")
            print(f"  ✓ best (val_loss={val_loss:.4f})")
        if args.max_steps and global_step >= args.max_steps:
            print("Hit max_steps; stopping."); break

    wandb.finish(); print(f"\nDone. Checkpoints: {ckpt_dir}")


if __name__ == "__main__":
    main()
