"""Train UVA (MAR video model) + PARA action heads jointly on LIBERO.

Joint training: video diffusion loss + PARA volume/gripper/rotation CE losses.
UVA decoder tokens are upsampled to 64x64 and used as pixel-aligned features for the
PARA heads (same head architecture as the DINO-feature pipeline in model.py, but with
a UVA backbone).

Usage:
    CUDA_VISIBLE_DEVICES=5 python libero/train_uva_para.py \
        --cache_root /data/libero/parsed_libero \
        --uva_checkpoint video_training/unified_video_action/checkpoints/simple_uva_libero_stride3_latest.pt \
        --log_wandb --run_name uva_para_libero_spatial
"""

import argparse
import io
import json
import random
import sys
import tempfile
from pathlib import Path
from types import SimpleNamespace

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from einops import rearrange
from tqdm import tqdm

sys.path.insert(0, str(Path(__file__).parent))
UVA_ROOT = Path(__file__).resolve().parent.parent / "video_training" / "unified_video_action"
sys.path.insert(0, str(UVA_ROOT))

from simple_uva.vae import AutoencoderKL
from simple_uva.model import mar_base_video_only
from data import CachedTrajectoryDataset

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
N_FRAMES = 4            # UVA video frames
UVA_IMG_SIZE = 256      # UVA operates at 256x256
PRERENDER_SIZE = 448    # parsed libero images are 448x448
PARA_OUT_SIZE = 64      # PARA head output spatial size
LATENT_SCALE = 0.2325
N_HEIGHT_BINS = 32
N_GRIPPER_BINS = 32
N_ROT_BINS = 32
DECODER_DIM = 768       # mar_base decoder_embed_dim
MAR_GRID = 16           # latent grid: 256/16=16

# Dataset stats — updated at runtime from dataset
MIN_HEIGHT = 0.0
MAX_HEIGHT = 1.0
MIN_GRIPPER = -1.0
MAX_GRIPPER = 1.0
MIN_ROT = [-3.14159, -3.14159, -3.14159]
MAX_ROT = [3.14159, 3.14159, 3.14159]

GRIPPER_LOSS_WEIGHT = 5.0
ROTATION_LOSS_WEIGHT = 0.5
PARA_LOSS_WEIGHT = 1.0


# ---------------------------------------------------------------------------
# PARA Heads (on MAR decoder tokens)
# ---------------------------------------------------------------------------

class ParaHeads(nn.Module):
    """Volume + gripper + rotation heads on MAR decoder tokens.

    Takes (B, T, S, C) decoder tokens, reshapes to spatial grid,
    upsamples 16->64 with convs, then:
      - volume_head: 1x1 conv -> (B, T, N_HEIGHT_BINS, 64, 64)
      - gripper/rotation MLPs indexed at query pixel (teacher forcing in train)
    Same architecture as the DINO feature processing in model.py.
    """

    def __init__(self, decoder_dim=DECODER_DIM, para_out_size=PARA_OUT_SIZE,
                 n_height_bins=N_HEIGHT_BINS, n_gripper_bins=N_GRIPPER_BINS,
                 n_rot_bins=N_ROT_BINS):
        super().__init__()
        D = decoder_dim
        self.para_out_size = para_out_size
        self.n_height_bins = n_height_bins
        self.n_gripper_bins = n_gripper_bins
        self.n_rot_bins = n_rot_bins

        # Upsample 16x16 -> 64x64 with conv refinement
        self.feature_net = nn.Sequential(
            nn.Conv2d(D, D, 3, padding=1), nn.GELU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(D, D, 3, padding=1), nn.GELU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(D, D, 3, padding=1), nn.GELU(),
        )

        self.volume_head = nn.Conv2d(D, n_height_bins, 1)

        self.gripper_mlp = nn.Sequential(
            nn.LayerNorm(D), nn.Linear(D, D), nn.GELU(), nn.Linear(D, n_gripper_bins)
        )
        self.rotation_mlp = nn.Sequential(
            nn.LayerNorm(D), nn.Linear(D, D), nn.GELU(), nn.Linear(D, 3 * n_rot_bins)
        )

    def forward(self, dec_tokens, query_pixels=None):
        """
        Args:
            dec_tokens: (B, T, S, C) where S = MAR_GRID^2 = 256
            query_pixels: (B, T, 2) in PARA_OUT_SIZE coords [x, y]. Optional.
        Returns:
            volume_logits: (B, T, N_HEIGHT_BINS, out_size, out_size)
            feats: (B, T, D, out_size, out_size)
            gripper_logits: (B, T, N_GRIPPER_BINS) or None
            rotation_logits: (B, T, 3, N_ROT_BINS) or None
        """
        B, T, S, C = dec_tokens.shape
        H_lat = W_lat = int(round(S ** 0.5))

        x = dec_tokens.reshape(B * T, H_lat, W_lat, C).permute(0, 3, 1, 2)
        feats = self.feature_net(x)  # (B*T, C, 64, 64)
        vol = self.volume_head(feats)  # (B*T, N_HEIGHT_BINS, 64, 64)

        P = self.para_out_size
        feats_5d = feats.view(B, T, C, P, P)
        volume_logits = vol.view(B, T, self.n_height_bins, P, P)

        gripper_logits = rotation_logits = None
        if query_pixels is not None:
            px = query_pixels[..., 0].long().clamp(0, P - 1)  # (B, T)
            py = query_pixels[..., 1].long().clamp(0, P - 1)
            # Detach features for gripper/rotation (same as model.py)
            feats_det = feats_5d.detach()
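            # Advanced indexing below mixes index arrays with a slice; because the index
            # arrays are separated by ':', the broadcast (B, T) dims come first and the
            # channel dim follows, so the indexed result is (B, T, C).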
            batch_idx = torch.arange(B, device=feats.device).view(B, 1).expand(B, T)
            time_idx = torch.arange(T, device=feats.device).view(1, T).expand(B, T)
            indexed = feats_det[batch_idx, time_idx, :, py, px]  # (B, T, C)
            flat = indexed.reshape(B * T, C)
            gripper_logits = self.gripper_mlp(flat).reshape(B, T, self.n_gripper_bins)
            rotation_logits = self.rotation_mlp(flat).reshape(B, T, 3, self.n_rot_bins)

        return volume_logits, feats_5d, gripper_logits, rotation_logits

    def predict_at_pixels(self, feats_5d, query_pixels):
        """Inference: predict gripper/rotation at given pixel locations.

        Args:
            feats_5d: (B, T, D, P, P)
            query_pixels: (B, T, 2) in PARA_OUT_SIZE coords
        Returns:
            gripper_logits: (B, T, N_GRIPPER_BINS)
            rotation_logits: (B, T, 3, N_ROT_BINS)
        """
        B, T = query_pixels.shape[:2]
        P = feats_5d.shape[-1]
        C = feats_5d.shape[2]
        px = query_pixels[..., 0].long().clamp(0, P - 1)
        py = query_pixels[..., 1].long().clamp(0, P - 1)
        batch_idx = torch.arange(B, device=feats_5d.device).view(B, 1).expand(B, T)
        time_idx = torch.arange(T, device=feats_5d.device).view(1, T).expand(B, T)
        indexed = feats_5d[batch_idx, time_idx, :, py, px]  # (B, T, C)
        flat = indexed.reshape(B * T, C)
        gripper_logits = self.gripper_mlp(flat).reshape(B, T, self.n_gripper_bins)
        rotation_logits = self.rotation_mlp(flat).reshape(B, T, 3, self.n_rot_bins)
        return gripper_logits, rotation_logits


# ---------------------------------------------------------------------------
# Loss functions (from train.py)
# ---------------------------------------------------------------------------

def discretize_height(height_values, min_h, max_h, n_bins=N_HEIGHT_BINS):
    normalized = (height_values - min_h) / (max_h - min_h + 1e-8)
    normalized = normalized.clamp(0.0, 1.0)
    return (normalized * (n_bins - 1)).long().clamp(0, n_bins - 1)


def discretize_gripper(gripper_values, min_g, max_g):
    normalized = (gripper_values - min_g) / (max_g - min_g + 1e-8)
    normalized = normalized.clamp(0.0, 1.0)
    return (normalized * (N_GRIPPER_BINS - 1)).long().clamp(0, N_GRIPPER_BINS - 1)


def discretize_rotation(euler_values, min_r, max_r):
    normalized = (euler_values - min_r) / (max_r - min_r + 1e-8)
    normalized = normalized.clamp(0.0, 1.0)
    return (normalized * (N_ROT_BINS - 1)).long().clamp(0, N_ROT_BINS - 1)


def compute_volume_loss(pred_volume_logits, trajectory_2d, target_height_bins):
    """CE over flattened (Nh*H*W) per timestep."""
    B, N, Nh, H, W = pred_volume_logits.shape
    px = trajectory_2d[:, :, 0].long().clamp(0, W - 1)
    py = trajectory_2d[:, :, 1].long().clamp(0, H - 1)
    h_bin = target_height_bins.clamp(0, Nh - 1)
    losses = []
    for t in range(N):
        logits_flat = pred_volume_logits[:, t].reshape(B, -1)
        target_idx = (h_bin[:, t] * (H * W) + py[:, t] * W + px[:, t]).long()
        losses.append(F.cross_entropy(logits_flat, target_idx, reduction='mean'))
    return torch.stack(losses).mean()


def compute_gripper_loss(pred_gripper_logits, target_gripper, min_g, max_g):
    target_bins = discretize_gripper(target_gripper, min_g, max_g)
    B, N, Ng = pred_gripper_logits.shape
    return F.cross_entropy(pred_gripper_logits.reshape(B * N, Ng), target_bins.reshape(B * N))


def compute_rotation_loss(pred_rotation_logits, target_euler, min_r, max_r):
    target_bins = discretize_rotation(target_euler, min_r, max_r)
    B, N, _, Nr = pred_rotation_logits.shape
    losses = []
    for axis in range(3):
        logits_axis = pred_rotation_logits[:, :, axis, :].reshape(B * N, Nr)
        target_axis = target_bins[:, :, axis].reshape(B * N)
        losses.append(F.cross_entropy(logits_axis, target_axis))
    return torch.stack(losses).mean()


def extract_pred_2d_and_height(volume_logits, min_h, max_h):
    """From volume (B, T, Nh, H, W) -> pred_2d (B, T, 2), pred_height (B, T)."""
    B, N, Nh, H, W = volume_logits.shape
    device = volume_logits.device
    pred_2d = torch.zeros(B, N, 2, device=device)
    pred_height_bins = torch.zeros(B, N, device=device, dtype=torch.long)
    for t in range(N):
        vol_t = volume_logits[:, t]
        max_over_h, _ = vol_t.max(dim=1)
        flat_idx = max_over_h.view(B, -1).argmax(dim=1)
        py = flat_idx // W
        px = flat_idx % W
        pred_2d[:, t, 0] = px.float()
        pred_2d[:, t, 1] = py.float()
        pred_height_bins[:, t] = vol_t[torch.arange(B, device=device), :, py, px].argmax(dim=1)
    bin_centers = torch.linspace(0.0, 1.0, Nh, device=device)
    pred_height = bin_centers[pred_height_bins] * (max_h - min_h) + min_h
    return pred_2d, pred_height
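

# Hypothetical inference-time composition (a sketch, not called by the training loop
# below): the volume head picks the query pixel, then gripper/rotation are decoded at
# that pixel with predict_at_pixels. The helper name is illustrative only.
def predict_actions_from_tokens(para_heads, dec_tokens, min_h, max_h):
    """Sketch: decoder tokens -> (pred_2d, pred_height, gripper_logits, rotation_logits)."""
    volume_logits, feats, _, _ = para_heads(dec_tokens)  # no query pixels at inference
    pred_2d, pred_height = extract_pred_2d_and_height(volume_logits, min_h, max_h)
    gripper_logits, rotation_logits = para_heads.predict_at_pixels(feats, pred_2d)
    return pred_2d, pred_height, gripper_logits, rotation_logits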


# ---------------------------------------------------------------------------
# Visualization
# ---------------------------------------------------------------------------

def build_max_along_ray_heatmaps(volume_logits):
    """volume_logits (B, T, Nh, H, W) -> max-over-height heatmap (B, T, H, W)."""
    B, T, Nh, H, W = volume_logits.shape
    vol_probs = F.softmax(volume_logits.reshape(B, T, -1), dim=2).view(B, T, Nh, H, W)
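    # Note: the softmax is over the full Nh*H*W volume (matching the CE target in
    # compute_volume_loss), so each returned heatmap value is the max joint probability
    # across height bins at that pixel, not a per-image-normalized distribution.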
    return vol_probs.max(dim=2)[0]


# ---------------------------------------------------------------------------
# VAE / MAR helpers
# ---------------------------------------------------------------------------

def build_vae(vae_ckpt, device):
    ckpt_path = Path(vae_ckpt)
    if not ckpt_path.is_absolute():
        ckpt_path = UVA_ROOT / vae_ckpt
    ddconfig = SimpleNamespace(vae_embed_dim=16, ch_mult=[1, 1, 2, 2, 4])
    vae = AutoencoderKL(autoencoder_path=str(ckpt_path) if ckpt_path.exists() else None, ddconfig=ddconfig)
    vae.to(device).eval()
    for p in vae.parameters():
        p.requires_grad = False
    return vae


def load_mar_checkpoint(model, ckpt_path, device):
    """Load pretrained UVA MAR weights (strict=False to allow new PARA heads)."""
    ckpt_path = Path(ckpt_path)
    if not ckpt_path.exists() and not ckpt_path.is_absolute():
        # Try relative to UVA_ROOT
        ckpt_path = UVA_ROOT / ckpt_path
    if not ckpt_path.exists():
        print(f"UVA checkpoint not found: {ckpt_path}; training from scratch")
        return

    try:
        import dill
        payload = torch.load(ckpt_path, map_location=device, pickle_module=dill)
    except Exception:
        payload = torch.load(ckpt_path, map_location=device, weights_only=False)

    # Handle various checkpoint formats
    if "state_dicts" in payload:
        sd = payload["state_dicts"].get("ema_model") or payload["state_dicts"].get("model")
        model_sd = {k[6:]: v for k, v in sd.items() if k.startswith("model.")}
    elif "model" in payload:
        model_sd = payload["model"]
    else:
        raise KeyError(f"Unrecognized checkpoint format: {list(payload.keys())}")

    current = model.state_dict()
    loadable = {k: v for k, v in model_sd.items() if k in current and current[k].shape == v.shape}
    current.update(loadable)
    model.load_state_dict(current, strict=False)
    print(f"Loaded {len(loadable)}/{len(model_sd)} MAR keys from {ckpt_path}")


def frames_to_video_tensor(rgb_frames_raw, n_frames, uva_img_size=UVA_IMG_SIZE):
    """Convert (B, N_WINDOW, H, W, 3) raw frames [0,1] to (B, 3, n_frames, uva_size, uva_size) in [-1,1]."""
    B = rgb_frames_raw.shape[0]
    frames = rgb_frames_raw[:, :n_frames]  # (B, n_frames, H, W, 3)
    frames = frames.permute(0, 1, 4, 2, 3)  # (B, n_frames, 3, H, W)
    frames = frames.reshape(B * n_frames, 3, frames.shape[3], frames.shape[4])
    frames = F.interpolate(frames, size=(uva_img_size, uva_img_size), mode='bilinear', align_corners=False)
    frames = frames.reshape(B, n_frames, 3, uva_img_size, uva_img_size)
    frames = frames.permute(0, 2, 1, 3, 4)  # (B, 3, n_frames, H, W)
    return frames * 2.0 - 1.0


# ---------------------------------------------------------------------------
# Dataset stats
# ---------------------------------------------------------------------------

def compute_dataset_stats(dataset, sample_limit=500, seed=42):
    """Scan dataset for height/gripper/rotation range."""
    rng = random.Random(seed)
    n = min(sample_limit, len(dataset))
    indices = rng.sample(range(len(dataset)), n)

    all_heights, all_grippers, all_eulers = [], [], []
    for idx in tqdm(indices, desc="Computing stats", leave=False):
        try:
            sample = dataset[idx]
        except Exception:
            continue
        t3d = sample['trajectory_3d'].numpy()
        all_heights.extend(t3d[:, 2].tolist())
        all_grippers.extend(sample['trajectory_gripper'].numpy().tolist())
        all_eulers.append(sample['trajectory_euler'].numpy())

    heights = np.array(all_heights)
    grippers = np.array(all_grippers)
    eulers = np.concatenate(all_eulers, axis=0)
    return {
        "min_height": float(heights.min()), "max_height": float(heights.max()),
        "min_gripper": float(grippers.min()), "max_gripper": float(grippers.max()),
        "min_rot": eulers.min(axis=0).tolist(), "max_rot": eulers.max(axis=0).tolist(),
    }


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main():
    p = argparse.ArgumentParser(description="Train UVA + PARA on LIBERO")
    p.add_argument("--cache_root", type=str, default="/data/libero/parsed_libero")
    p.add_argument("--benchmark", type=str, default="libero_spatial")
    p.add_argument("--task_ids", type=str, default="all")
    p.add_argument("--uva_checkpoint", type=str,
                    default="checkpoints/simple_uva_libero_stride3_latest.pt",
                    help="Pretrained UVA MAR checkpoint")
    p.add_argument("--vae_ckpt", type=str, default="pretrained_models/vae/kl16.ckpt")
    p.add_argument("--batch_size", type=int, default=4)
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--epochs", type=int, default=500)
    p.add_argument("--frame_stride", type=int, default=3)
    p.add_argument("--workers", type=int, default=8)
    p.add_argument("--device", type=str, default="cuda")
    p.add_argument("--run_name", type=str, default="uva_para_libero")
    p.add_argument("--log_wandb", action="store_true")
    p.add_argument("--vis_every", type=int, default=100, help="Steps between vis updates")
    p.add_argument("--checkpoint_every", type=int, default=1000, help="Save checkpoint every N steps")
    p.add_argument("--num_iter", type=int, default=64, help="Diffusion sampling steps for vis")
    p.add_argument("--val_split", type=float, default=0.05)
    p.add_argument("--video_loss_weight", type=float, default=1.0)
    p.add_argument("--para_loss_weight", type=float, default=1.0,
                    help="Weight for total PARA loss (volume + gripper + rotation)")
    p.add_argument("--freeze_mar", action="store_true",
                    help="Freeze MAR backbone, only train PARA heads")
    p.add_argument("--resume", type=str, default="", help="Path to resume checkpoint")
    args = p.parse_args()

    device = torch.device(args.device)
    script_dir = Path(__file__).parent
    ckpt_dir = script_dir / "checkpoints" / args.run_name
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    if args.log_wandb:
        import wandb
        wandb.init(project="uva_para_libero", config=vars(args), name=args.run_name, mode="online")

    # --- Build models ---
    vae = build_vae(args.vae_ckpt, device)
    mar = mar_base_video_only(
        img_size=UVA_IMG_SIZE, vae_stride=16, patch_size=1, vae_embed_dim=16,
        num_sampling_steps="100", diffloss_d=6, diffloss_w=1024,
    ).to(device)
    load_mar_checkpoint(mar, args.uva_checkpoint, device)

    para_heads = ParaHeads(
        decoder_dim=DECODER_DIM, para_out_size=PARA_OUT_SIZE,
    ).to(device)
    print(f"ParaHeads: {sum(p.numel() for p in para_heads.parameters()):,} params")

    # --- Dataset ---
    task_ids = None
    if args.task_ids and args.task_ids.strip().lower() != "all":
        task_ids = [int(x) for x in args.task_ids.split(",")]

    dataset = CachedTrajectoryDataset(
        cache_root=args.cache_root,
        benchmark_name=args.benchmark,
        task_ids=task_ids,
        image_size=PRERENDER_SIZE,  # keep at 448, we resize to 256 for UVA
        n_window=N_FRAMES,          # only need 4 frames for UVA
        frame_stride=args.frame_stride,
    )
    print(f"Dataset: {len(dataset)} samples")

    # Train/val split
    val_size = max(1, int(len(dataset) * args.val_split))
    train_size = len(dataset) - val_size
    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42)
    )
    print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}")

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.workers, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False,
                            num_workers=args.workers, pin_memory=True)

    # --- Dataset stats ---
    stats_path = ckpt_dir / "dataset_stats.json"
    if stats_path.exists():
        with open(stats_path) as f:
            stats = json.load(f)
        print(f"Loaded stats from {stats_path}")
    else:
        stats = compute_dataset_stats(dataset)
        with open(stats_path, "w") as f:
            json.dump(stats, f, indent=2)
        print(f"Saved stats to {stats_path}")

    global MIN_HEIGHT, MAX_HEIGHT, MIN_GRIPPER, MAX_GRIPPER, MIN_ROT, MAX_ROT
    MIN_HEIGHT = stats["min_height"]
    MAX_HEIGHT = stats["max_height"]
    MIN_GRIPPER = stats["min_gripper"]
    MAX_GRIPPER = stats["max_gripper"]
    MIN_ROT = stats["min_rot"]
    MAX_ROT = stats["max_rot"]
    min_r_t = torch.tensor(MIN_ROT, dtype=torch.float32, device=device)
    max_r_t = torch.tensor(MAX_ROT, dtype=torch.float32, device=device)
    print(f"Height: [{MIN_HEIGHT:.4f}, {MAX_HEIGHT:.4f}]")
    print(f"Gripper: [{MIN_GRIPPER:.4f}, {MAX_GRIPPER:.4f}]")
    print(f"Rotation: {MIN_ROT} .. {MAX_ROT}")

    # Coord scale: trajectory_2d is in PRERENDER_SIZE (448) space, PARA output is 64
    coord_scale = PARA_OUT_SIZE / PRERENDER_SIZE
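    # e.g. coord_scale = 64 / 448 = 1/7, so a GT pixel at (224, 336) in 448-space
    # lands at (32, 48) on the 64x64 PARA grid.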

    # --- Optimizer ---
    if args.freeze_mar:
        for p_param in mar.parameters():
            p_param.requires_grad = False
        mar.eval()
        all_params = list(para_heads.parameters())
        print("Frozen MAR backbone — only training PARA heads")
    else:
        all_params = list(mar.parameters()) + list(para_heads.parameters())
    opt = torch.optim.AdamW(all_params, lr=args.lr, weight_decay=1e-4)

    start_epoch = 0
    if args.resume and Path(args.resume).exists():
        ckpt = torch.load(args.resume, map_location=device)
        mar.load_state_dict(ckpt["mar_state_dict"])
        para_heads.load_state_dict(ckpt["para_heads_state_dict"])
        opt.load_state_dict(ckpt["optimizer_state_dict"])
        start_epoch = ckpt.get("epoch", 0) + 1
        print(f"Resumed from {args.resume} at epoch {start_epoch}")

    # --- Training loop ---
    global_step = 0
    best_val_loss = float('inf')

    for epoch in tqdm(range(start_epoch, args.epochs), desc="Epochs"):
        if not args.freeze_mar:
            mar.train()
        para_heads.train()
        epoch_losses = {"total": 0, "video": 0, "volume": 0, "gripper": 0, "rotation": 0}
        n_batches = 0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch}", leave=False)
        for batch in pbar:
            # --- Prepare video tensor ---
            rgb_frames_raw = batch['rgb_frames_raw'].to(device)  # (B, N_FRAMES, 448, 448, 3)
            video = frames_to_video_tensor(rgb_frames_raw, N_FRAMES, UVA_IMG_SIZE)  # (B, 3, T, 256, 256)
            B, C, T, H, W = video.shape

            # Trajectory supervision
            traj_2d = batch['trajectory_2d'].to(device)[:, :N_FRAMES]  # (B, T, 2) in 448 space
            traj_3d = batch['trajectory_3d'].to(device)[:, :N_FRAMES]
            traj_gripper = batch['trajectory_gripper'].to(device)[:, :N_FRAMES]
            traj_euler = batch['trajectory_euler'].to(device)[:, :N_FRAMES]

            # Scale trajectory to PARA output space
            traj_para = traj_2d * coord_scale
            traj_para = traj_para.clamp(0, PARA_OUT_SIZE - 1.001)

            target_height = traj_3d[:, :, 2]
            target_height_bins = discretize_height(target_height, MIN_HEIGHT, MAX_HEIGHT)

            # --- Encode video through frozen VAE ---
            frames_flat = rearrange(video, "b c t h w -> (b t) c h w")
            with torch.no_grad():
                posterior = vae.encode(frames_flat.float())
                z = posterior.sample() * LATENT_SCALE
            z = rearrange(z, "(b t) c h w -> b t c h w", b=B)
            z_flat = rearrange(z, "b t c h w -> (b t) c h w")
            x_tokens = mar.patchify(z_flat)
            x_tokens = rearrange(x_tokens, "(b t) s c -> b t s c", b=B)
            cond_tokens = x_tokens[:, :1].expand(-1, T, -1, -1)
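            # Shape sketch (assuming vae_embed_dim=16, vae_stride=16, patch_size=1 as
            # configured above): z is (B, T, 16, 16, 16), patchify gives S = 16*16 = 256
            # tokens per frame, and cond_tokens repeats the first frame's tokens across
            # all T frames as conditioning.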

            # --- Video diffusion loss ---
            if not args.freeze_mar:
                video_loss = mar.compute_loss(x_tokens, cond_tokens)
            else:
                video_loss = torch.tensor(0.0, device=device)

            # --- PARA loss (decoder tokens, no masking) ---
            dec_tokens = mar.forward_decode_tokens(x_tokens, cond_tokens, mask=None)  # (B, T, S, C)
            volume_logits, feats, gripper_logits, rotation_logits = para_heads(
                dec_tokens, query_pixels=traj_para
            )

            volume_loss = compute_volume_loss(volume_logits, traj_para, target_height_bins)
            gripper_loss = compute_gripper_loss(gripper_logits, traj_gripper, MIN_GRIPPER, MAX_GRIPPER)
            rotation_loss = compute_rotation_loss(rotation_logits, traj_euler, min_r_t, max_r_t)

            para_loss = volume_loss + GRIPPER_LOSS_WEIGHT * gripper_loss + ROTATION_LOSS_WEIGHT * rotation_loss
            total_loss = args.video_loss_weight * video_loss + args.para_loss_weight * para_loss

            opt.zero_grad()
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(all_params, 1.0)
            opt.step()

            # --- Logging ---
            epoch_losses["total"] += total_loss.item()
            epoch_losses["video"] += video_loss.item()
            epoch_losses["volume"] += volume_loss.item()
            epoch_losses["gripper"] += gripper_loss.item()
            epoch_losses["rotation"] += rotation_loss.item()
            n_batches += 1
            pbar.set_postfix(
                vid=f"{video_loss.item():.3f}",
                vol=f"{volume_loss.item():.3f}",
                grip=f"{gripper_loss.item():.3f}",
                rot=f"{rotation_loss.item():.3f}",
            )

            if args.log_wandb:
                import wandb
                wandb.log({
                    "train_step/total_loss": total_loss.item(),
                    "train_step/video_loss": video_loss.item(),
                    "train_step/volume_loss": volume_loss.item(),
                    "train_step/gripper_loss": gripper_loss.item(),
                    "train_step/rotation_loss": rotation_loss.item(),
                }, step=global_step)

            # --- Visualization ---
            if global_step % args.vis_every == 0 and args.log_wandb:
                import wandb
                mar.eval()
                para_heads.eval()
                with torch.no_grad():
                    # Heatmap visualization
                    heatmaps = build_max_along_ray_heatmaps(volume_logits[:1])  # (1, T, 64, 64)
                    heatmaps_np = heatmaps[0].cpu().numpy()

                    # Input frame for overlay
                    input_frame = rgb_frames_raw[0, 0].cpu().numpy()  # (448, 448, 3)
                    input_frame_small = cv2.resize(input_frame, (PARA_OUT_SIZE, PARA_OUT_SIZE))

                    import matplotlib
                    matplotlib.use("Agg")
                    import matplotlib.pyplot as plt

                    fig, axes = plt.subplots(2, N_FRAMES, figsize=(3 * N_FRAMES, 6))
                    for t in range(N_FRAMES):
                        # Top row: heatmap overlay on input frame
                        ax = axes[0, t]
                        heat = heatmaps_np[t]
                        heat_norm = (heat - heat.min()) / (heat.max() - heat.min() + 1e-8)
                        overlay = input_frame_small * 0.5 + np.stack([heat_norm, np.zeros_like(heat_norm), np.zeros_like(heat_norm)], axis=-1) * 0.5
                        ax.imshow(np.clip(overlay, 0, 1))
                        ax.scatter(
                            traj_para[0, t, 0].cpu().item(),
                            traj_para[0, t, 1].cpu().item(),
                            c="cyan", s=40, marker="x", linewidths=2,
                        )
                        ax.set_title(f"t={t} heatmap")
                        ax.axis("off")

                        # Bottom row: GT frame for reference (UVA-predicted video is logged separately below)
                        gt_frame = rgb_frames_raw[0, t].cpu().numpy()
                        gt_small = cv2.resize(gt_frame, (PARA_OUT_SIZE, PARA_OUT_SIZE))
                        axes[1, t].imshow(gt_small)
                        axes[1, t].set_title(f"t={t} GT frame")
                        axes[1, t].axis("off")

                    plt.tight_layout()
                    buf = io.BytesIO()
                    plt.savefig(buf, format="png", dpi=100)
                    buf.seek(0)
                    from PIL import Image
                    wandb.log({"vis/heatmap_and_frames": wandb.Image(Image.open(buf))}, step=global_step)
                    plt.close("all")

                    # Predicted video (sample from UVA)
                    if global_step % (args.vis_every * 5) == 0:
                        import torchvision
                        first_frame = video[:1, :, 0]  # (1, 3, 256, 256)
                        posterior0 = vae.encode(first_frame.float())
                        z0 = posterior0.sample() * LATENT_SCALE
                        cond = z0.unsqueeze(1).expand(1, N_FRAMES, -1, -1, -1)
                        tokens, _ = mar.sample_tokens(
                            bsz=1, cond=cond, num_iter=args.num_iter, cfg=1.0, temperature=0.95,
                        )
                        pred = vae.decode(tokens / LATENT_SCALE)
                        pred = pred.view(1, N_FRAMES, 3, UVA_IMG_SIZE, UVA_IMG_SIZE)
                        pred_np = ((pred[0].cpu() + 1.0) / 2.0).clamp(0, 1)
                        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f:
                            tmp_path = f.name
                        frames_np = (pred_np.permute(0, 2, 3, 1).numpy() * 255).astype("uint8")
                        torchvision.io.write_video(tmp_path, torch.from_numpy(frames_np), fps=4)
                        wandb.log({"vis/predicted_video": wandb.Video(tmp_path, format="mp4")}, step=global_step)
                        Path(tmp_path).unlink(missing_ok=True)

                if not args.freeze_mar:
                    mar.train()
                para_heads.train()

            # --- Step-based checkpoint ---
            if global_step > 0 and global_step % args.checkpoint_every == 0:
                ckpt_data = {
                    "epoch": epoch, "global_step": global_step,
                    "mar_state_dict": mar.state_dict(),
                    "para_heads_state_dict": para_heads.state_dict(),
                    "optimizer_state_dict": opt.state_dict(),
                    "stats": stats,
                    "min_height": MIN_HEIGHT, "max_height": MAX_HEIGHT,
                    "min_gripper": MIN_GRIPPER, "max_gripper": MAX_GRIPPER,
                    "min_rot": MIN_ROT, "max_rot": MAX_ROT,
                }
                torch.save(ckpt_data, ckpt_dir / "latest.pth")
                print(f"  Saved latest.pth (step {global_step})")

            global_step += 1

        # --- Epoch summary ---
        n = max(1, n_batches)
        avg = {k: v / n for k, v in epoch_losses.items()}
        print(f"Epoch {epoch}: total={avg['total']:.4f} vid={avg['video']:.4f} "
              f"vol={avg['volume']:.4f} grip={avg['gripper']:.4f} rot={avg['rotation']:.4f}")

        # --- Validation ---
        mar.eval()
        para_heads.eval()
        val_losses = {"total": 0, "volume": 0, "pixel_error": 0}
        val_n = 0
        with torch.no_grad():
            for batch in tqdm(val_loader, desc="Val", leave=False):
                rgb_frames_raw = batch['rgb_frames_raw'].to(device)
                video = frames_to_video_tensor(rgb_frames_raw, N_FRAMES, UVA_IMG_SIZE)
                B = video.shape[0]
                traj_2d = batch['trajectory_2d'].to(device)[:, :N_FRAMES]
                traj_3d = batch['trajectory_3d'].to(device)[:, :N_FRAMES]
                traj_gripper = batch['trajectory_gripper'].to(device)[:, :N_FRAMES]
                traj_euler = batch['trajectory_euler'].to(device)[:, :N_FRAMES]
                traj_para = traj_2d * coord_scale
                traj_para = traj_para.clamp(0, PARA_OUT_SIZE - 1.001)
                target_height_bins = discretize_height(traj_3d[:, :, 2], MIN_HEIGHT, MAX_HEIGHT)

                frames_flat = rearrange(video, "b c t h w -> (b t) c h w")
                posterior = vae.encode(frames_flat.float())
                z = posterior.sample() * LATENT_SCALE
                z = rearrange(z, "(b t) c h w -> b t c h w", b=B)
                z_flat = rearrange(z, "b t c h w -> (b t) c h w")
                x_tokens = mar.patchify(z_flat)
                x_tokens = rearrange(x_tokens, "(b t) s c -> b t s c", b=B)
                cond_tokens = x_tokens[:, :1].expand(-1, N_FRAMES, -1, -1)

                dec_tokens = mar.forward_decode_tokens(x_tokens, cond_tokens, mask=None)
                volume_logits, feats, _, _ = para_heads(dec_tokens)
                volume_loss = compute_volume_loss(volume_logits, traj_para, target_height_bins)

                pred_2d, _ = extract_pred_2d_and_height(volume_logits, MIN_HEIGHT, MAX_HEIGHT)
                # pixel error in PARA output space (64)
                pixel_err = torch.norm(pred_2d - traj_para, dim=-1).mean().item()

                val_losses["volume"] += volume_loss.item() * B
                val_losses["pixel_error"] += pixel_err * B
                val_n += B

        val_n = max(1, val_n)
        val_vol = val_losses["volume"] / val_n
        val_px = val_losses["pixel_error"] / val_n
        print(f"  Val: volume={val_vol:.4f}, pixel_error={val_px:.2f}px (in 64-space)")

        if args.log_wandb:
            import wandb
            wandb.log({
                "epoch": epoch,
                "train/total_loss": avg["total"],
                "train/video_loss": avg["video"],
                "train/volume_loss": avg["volume"],
                "train/gripper_loss": avg["gripper"],
                "train/rotation_loss": avg["rotation"],
                "val/volume_loss": val_vol,
                "val/pixel_error_64": val_px,
            }, step=global_step)

        # --- Best checkpoint (val-based) ---
        if val_vol < best_val_loss:
            best_val_loss = val_vol
            ckpt_data = {
                "epoch": epoch, "global_step": global_step,
                "mar_state_dict": mar.state_dict(),
                "para_heads_state_dict": para_heads.state_dict(),
                "optimizer_state_dict": opt.state_dict(),
                "stats": stats,
                "min_height": MIN_HEIGHT, "max_height": MAX_HEIGHT,
                "min_gripper": MIN_GRIPPER, "max_gripper": MAX_GRIPPER,
                "min_rot": MIN_ROT, "max_rot": MAX_ROT,
            }
            torch.save(ckpt_data, ckpt_dir / "best.pth")
            print(f"  Saved best (val_vol={val_vol:.4f}, step {global_step})")

    if args.log_wandb:
        import wandb
        wandb.finish()
    print(f"Done. Checkpoints at {ckpt_dir}")


if __name__ == "__main__":
    main()
