"""Train voxel-token AR policy (variants B and C — abs xyz / rel xyz).

Mirrors train_ar_v2.py with the voxel model swapped in. Each window is encoded by DINO
once and supervises every valid target step; per-target losses are back-propagated one
at a time to bound memory.

Usage:
  cd /data/cameron/para/libero
  CUDA_VISIBLE_DEVICES=9 \
  DINO_REPO_DIR=/data/cameron/keygrip/dinov3 \
  DINO_WEIGHTS_PATH=/data/cameron/keygrip/dinov3/weights/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth \
  python train_voxel_ar.py \
    --variant abs \
    --cache_root /data/libero/parsed_libero --benchmark libero_spatial --task_id 0 \
    --window_len 20 --history_len 8 --voxel_xy 28 --voxel_z 16 \
    --batch_size 2 --lr 1e-4 --epochs 2 \
    --run_name voxel_abs_setting_i
"""
import argparse, os, sys, time
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import wandb

sys.path.insert(0, os.path.dirname(__file__))

from data_ar import WindowTrajectoryDataset, target_xy_to_grid_idx, grid_idx_to_pixel
from model_voxel_ar import VoxelARPolicyAbs, VoxelARPolicyRel, IMAGE_SIZE, N_HEIGHT_BINS, N_ROT_BINS

# Pragmatic LIBERO defaults.
# End-effector z range mapped onto the height bins (presumably metres in the
# world frame — confirm against the dataset cache).
MIN_HEIGHT = 0.85
MAX_HEIGHT = 1.55
# Per-axis Euler-angle range mapped onto the rotation bins (~±pi, so
# presumably radians).
MIN_ROT = [-3.14159] * 3
MAX_ROT = [3.14159] * 3
# Relative weights of the per-target loss terms (xy, height, gripper, rotation).
W_XY = 1.0
W_HEIGHT = 0.5
W_GRIPPER = 5.0
W_ROT = 0.5


def discretize_height(h, min_h=MIN_HEIGHT, max_h=MAX_HEIGHT, n_bins=N_HEIGHT_BINS):
    """Map a tensor of heights to integer bin indices in ``[0, n_bins - 1]``.

    Heights are normalized linearly against ``[min_h, max_h]`` (with a 1e-8
    guard on the span), clamped to ``[0, 1]``, scaled by ``n_bins - 1`` and
    truncated with ``.long()``.

    NOTE(review): because of the truncation, the last bin is only reached by
    inputs whose normalized value clamps to exactly 1; an input equal to
    ``max_h`` lands in bin ``n_bins - 2``. Kept as-is so training targets stay
    consistent with however the model decodes height bins.

    Returns:
        An int64 tensor with the same shape as ``h``.
    """
    last_bin = n_bins - 1
    span = max_h - min_h + 1e-8
    unit = ((h - min_h) / span).clamp(0, 1)
    bins = (unit * last_bin).long()
    return bins.clamp(0, last_bin)


def discretize_rotation(euler, n_bins=N_ROT_BINS):
    """Map Euler angles to per-axis integer bin indices in ``[0, n_bins - 1]``.

    Each axis is normalized against ``[MIN_ROT, MAX_ROT]`` (float32 bound
    tensors created on ``euler``'s device), clamped to ``[0, 1]``, scaled by
    ``n_bins - 1`` and truncated with ``.long()``. The last dimension of
    ``euler`` must broadcast against the 3-element bound lists.

    NOTE(review): linear binning of angles has a wrap-around discontinuity —
    values just below +pi and just above -pi land in opposite end bins.

    Returns:
        An int64 tensor with the broadcast shape of ``euler``.
    """
    lo = torch.tensor(MIN_ROT, dtype=torch.float32, device=euler.device)
    hi = torch.tensor(MAX_ROT, dtype=torch.float32, device=euler.device)
    top = n_bins - 1
    frac = ((euler - lo) / (hi - lo + 1e-8)).clamp(0, 1)
    return (frac * top).long().clamp(0, top)


def main():
    """Train a voxel-token AR policy from the command line.

    Steps:
      1. Parse/validate CLI args, create ``<script_dir>/checkpoints/<run_name>``
         and start a wandb run.
      2. Load one task's ``WindowTrajectoryDataset`` and split it into
         train/val with a fixed generator seed (42).
      3. Build ``VoxelARPolicyAbs`` (``--variant abs``) or
         ``VoxelARPolicyRel`` (``--variant rel``) with ``freeze_backbone=True``
         and optimize the parameters that still require grad with AdamW.
      4. Per epoch: train, evaluate on the val split, save ``latest.pth``
         and, when the val loss improves, ``best.pth``.

    Training ends after ``--epochs`` epochs, or earlier once ``--max_steps``
    (if non-zero) optimizer steps have been taken.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--variant", type=str, required=True, choices=["abs", "rel"])
    p.add_argument("--cache_root", type=str, required=True)
    p.add_argument("--benchmark", type=str, default="libero_spatial")
    p.add_argument("--task_id", type=int, default=0)
    p.add_argument("--max_demos", type=int, default=0)
    p.add_argument("--window_len", type=int, default=20)
    p.add_argument("--history_len", type=int, default=8)
    p.add_argument("--voxel_xy", type=int, default=28, help="voxel grid xy resolution (default 28 = lite)")
    p.add_argument("--voxel_z",  type=int, default=16, help="voxel grid z resolution")
    p.add_argument("--frame_stride", type=int, default=1)
    p.add_argument("--grid_size", type=int, default=56, help="output classification grid")
    p.add_argument("--batch_size", type=int, default=2)
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--epochs", type=int, default=2)
    p.add_argument("--max_steps", type=int, default=0)
    p.add_argument("--val_split", type=float, default=0.05)
    p.add_argument("--num_workers", type=int, default=8)
    p.add_argument("--log_scalars_every", type=int, default=20)
    p.add_argument("--run_name", type=str, default="voxel_run")
    p.add_argument("--wandb_project", type=str, default="para_libero")
    p.add_argument("--wandb_mode", type=str, default="online", choices=["online", "offline", "disabled"])
    args = p.parse_args()

    # Validate via parser.error (exits with a usage message) rather than
    # `assert`, which is stripped under `python -O`. history_len must be >= 1:
    # each target t reads patches[:, t - 1], and with H == 0 that is index -1,
    # which silently selects the LAST frame of the window (a future frame).
    if args.history_len < 1:
        p.error("--history_len must be >= 1")
    if args.window_len <= args.history_len:
        p.error("--window_len must be greater than --history_len")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Pinned host memory only pays off for host->GPU copies.
    pin_memory = device.type == "cuda"

    script_dir = Path(__file__).parent
    ckpt_dir = script_dir / "checkpoints" / args.run_name
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    wandb.init(project=args.wandb_project, name=args.run_name, mode=args.wandb_mode, config=vars(args))

    print("\nLoading dataset...")
    full = WindowTrajectoryDataset(
        cache_root=args.cache_root, benchmark_name=args.benchmark, task_ids=[args.task_id],
        image_size=IMAGE_SIZE, window_len=args.window_len, frame_stride=args.frame_stride,
        max_demos=args.max_demos,
    )
    n = len(full)
    # At least one window is always held out for validation, so fewer than two
    # windows would leave nothing (or a negative count) for training.
    if n < 2:
        raise RuntimeError(
            f"Need at least 2 windows for a train/val split, got {n} "
            f"(benchmark={args.benchmark}, task_id={args.task_id})"
        )
    n_val = max(1, int(n * args.val_split))
    n_tr = n - n_val
    train_ds, val_ds = torch.utils.data.random_split(
        full, [n_tr, n_val], generator=torch.Generator().manual_seed(42)
    )
    print(f"  Train: {n_tr}  Val: {n_val}")
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.num_workers, pin_memory=pin_memory)
    val_loader   = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False,
                              num_workers=args.num_workers, pin_memory=pin_memory)

    print(f"\nBuilding voxel model (variant={args.variant}, V={args.voxel_xy}×{args.voxel_xy}×{args.voxel_z})...")
    Cls = VoxelARPolicyAbs if args.variant == "abs" else VoxelARPolicyRel
    model = Cls(
        target_size=IMAGE_SIZE, history_len=args.history_len, grid_size=args.grid_size,
        voxel_xy=args.voxel_xy, voxel_z=args.voxel_z,
        freeze_backbone=True, min_height=MIN_HEIGHT, max_height=MAX_HEIGHT,
    ).to(device)
    trainable = [prm for prm in model.parameters() if prm.requires_grad]
    n_train = sum(prm.numel() for prm in trainable)
    print(f"Trainable: {n_train:,}")
    wandb.config.update({"trainable_params": n_train})

    optimizer = optim.AdamW(trainable, lr=args.lr, weight_decay=1e-4)

    H = args.history_len; W = args.window_len; G = args.grid_size
    target_steps = list(range(H, W))
    # The rel variant anchors the voxel grid at the window's start pose.
    use_anchor = args.variant == "rel"

    def run_window(batch, train_mode):
        """Compute the multi-target loss for one batch of windows.

        Every frame is encoded once; then each target step t in [H, W) that
        has at least one valid sample is predicted from the H preceding
        frames. In ``train_mode`` each target's loss, divided by K, is
        back-propagated immediately (the caller owns zero_grad/step).

        Returns:
            ``(mean_loss, mean_pixel_error, K)`` averaged over the K valid
            targets, or ``(None, None, 0)`` if no target step is valid.
        """
        imgs    = batch["window_imgs"].to(device, non_blocking=True)
        eef_xy  = batch["window_eef_xy"].to(device, non_blocking=True)
        eef_pos = batch["window_eef_pos"].to(device, non_blocking=True)
        eef_eul = batch["window_eef_euler"].to(device, non_blocking=True)
        grip    = batch["window_gripper"].to(device, non_blocking=True)
        valid   = batch["valid_mask"].to(device, non_blocking=True)
        eef_start = batch["window_eef_start"].to(device, non_blocking=True)
        cam_K   = batch["cam_K"].to(device, non_blocking=True)
        cam_E   = batch["cam_extrinsic"].to(device, non_blocking=True)

        if train_mode:
            patches = model.patch_encoder(imgs)                                # (B, W, Np, D)
        else:
            with torch.no_grad():
                patches = model.patch_encoder(imgs)

        # One device->host transfer for all target steps instead of one per t.
        step_has_valid = valid[:, H:W].any(dim=0).tolist()
        valid_targets = [t for t, ok in zip(target_steps, step_has_valid) if ok]
        K = len(valid_targets)
        if K == 0:
            return None, None, 0

        anchor = eef_start if use_anchor else None
        loss_sum = 0.0
        pix_err_sum = 0.0
        for i, t in enumerate(valid_targets):
            hist_p = patches[:, t - H : t]                                     # (B, H, Np, D)
            hist_e = eef_xy[:, t - H : t]
            current_patches = patches[:, t - 1]                                 # (B, Np, D) — most recent past
            # Voxel features for the most recent past frame (the "current" frame for prediction)
            voxel_feats, _ = model.voxel_builder(current_patches, cam_K, cam_E, anchor)
            out = model.ar_head(hist_p, hist_e, voxel_feats, IMAGE_SIZE)

            target_xy = eef_xy[:, t]
            target_h_bin = discretize_height(eef_pos[:, t, 2])
            target_g_bin = (grip[:, t] > 0).float()
            target_rot_bin = discretize_rotation(eef_eul[:, t])
            tgt_idx = target_xy_to_grid_idx(target_xy, IMAGE_SIZE, G)
            mask_f = valid[:, t].float()
            denom = mask_f.sum().clamp_min(1.0)

            # Per-sample losses, masked and averaged over the valid samples only.
            ce_xy = (F.cross_entropy(out["xy_logits"], tgt_idx, reduction="none") * mask_f).sum() / denom
            ce_h  = (F.cross_entropy(out["height_logits"], target_h_bin, reduction="none") * mask_f).sum() / denom
            bce_g = (F.binary_cross_entropy_with_logits(out["gripper_logit"], target_g_bin, reduction="none") * mask_f).sum() / denom
            rot_loss_axes = [
                (F.cross_entropy(out["rotation_logits"][:, axis], target_rot_bin[:, axis], reduction="none") * mask_f).sum() / denom
                for axis in range(3)
            ]
            ce_r = torch.stack(rot_loss_axes).mean()
            term_loss = W_XY * ce_xy + W_HEIGHT * ce_h + W_GRIPPER * bce_g + W_ROT * ce_r
            if train_mode:
                # Backward per target so only one target's head graph is alive at a
                # time; keep the shared encoder graph until the last target.
                (term_loss / K).backward(retain_graph=(i < K - 1))
            loss_sum += term_loss.item()
            with torch.no_grad():
                pred_xy = grid_idx_to_pixel(out["xy_logits"].argmax(dim=-1), IMAGE_SIZE, G)
                pix_err_sum += ((pred_xy - target_xy).norm(dim=-1) * mask_f).sum().item() / denom.item()
        return loss_sum / K, pix_err_sum / K, K

    best_val = float("inf"); global_step = 0; t_start = time.time()
    for epoch in range(args.epochs):
        model.train()
        pbar = tqdm(train_loader, desc=f"Epoch {epoch} train", leave=False)
        tr_losses = []
        for batch in pbar:
            optimizer.zero_grad()
            loss, pix_err, k_used = run_window(batch, True)
            if loss is None:
                continue
            optimizer.step()
            tr_losses.append(loss)
            pbar.set_postfix(loss=f"{loss:.3f}", px=f"{pix_err:.1f}", K=k_used)
            global_step += 1
            if global_step % args.log_scalars_every == 0:
                wandb.log({"train/loss": loss, "train/pixel_error": pix_err, "train/K": k_used, "epoch": epoch}, step=global_step)
            if args.max_steps and global_step >= args.max_steps:
                break

        model.eval()
        val_losses = []; val_pix = []
        with torch.no_grad():
            for batch in val_loader:
                loss, pix_err, _ = run_window(batch, False)
                if loss is None:
                    continue
                val_losses.append(loss); val_pix.append(pix_err)
        tr_loss = float(np.mean(tr_losses)) if tr_losses else float("inf")
        val_loss = float(np.mean(val_losses)) if val_losses else float("inf")
        val_err = float(np.mean(val_pix)) if val_pix else float("inf")
        print(f"Epoch {epoch}: train_loss={tr_loss:.4f}  val_loss={val_loss:.4f}  val_px_err={val_err:.1f}px")
        wandb.log({"epoch_end/train_loss": tr_loss, "epoch_end/val_loss": val_loss,
                   "epoch_end/val_pixel_error": val_err, "epoch": epoch,
                   "elapsed_min": (time.time() - t_start) / 60.0}, step=global_step)

        ckpt = {"epoch": epoch, "global_step": global_step,
                "model_state_dict": model.state_dict(), "optimizer_state_dict": optimizer.state_dict(),
                "val_loss": val_loss, "args": vars(args), "variant": args.variant}
        torch.save(ckpt, ckpt_dir / "latest.pth")
        if val_loss < best_val:
            best_val = val_loss
            torch.save(ckpt, ckpt_dir / "best.pth")
            print(f"  ✓ best (val_loss={val_loss:.4f})")
        if args.max_steps and global_step >= args.max_steps:
            print("Hit max_steps; stopping.")
            break

    wandb.finish()
    print(f"\nDone. Checkpoints: {ckpt_dir}")


if __name__ == "__main__":
    # Entry point only when run as a script; main() returns None, so the
    # process exits with status 0 on success.
    sys.exit(main())
