"""Train trajectory volume predictor on LIBERO demonstrations.

Model predicts a pixel-aligned volume: N_WINDOW x N_HEIGHT_BINS logits per pixel (CE at trajectory pixel only).
Gripper is per-pixel (N_WINDOW x N_GRIPPER_BINS per pixel): supervised at GT pixel (teacher forcing), decoded at pred pixel in val/inference.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, ConcatDataset
import numpy as np
from pathlib import Path
from tqdm import tqdm
import argparse
import wandb
import json
import math
import random
import subprocess
import time

import sys
import os
sys.path.insert(0, os.path.dirname(__file__))

import cv2
from robosuite.utils.camera_utils import project_points_from_world_to_camera
import robosuite.utils.transform_utils as T_rob

from data import RealTrajectoryDataset, CachedTrajectoryDataset, N_WINDOW
from model import TrajectoryHeatmapPredictor, N_HEIGHT_BINS, N_GRIPPER_BINS, N_ROT_BINS, PRED_SIZE
import model as model_module  # Import module to access updated MIN_HEIGHT/MAX_HEIGHT at runtime
from utils import recover_3d_from_direct_keypoint_and_height

# Lazy imports for ablation models (avoids import errors if deps missing)
def get_model_class(model_type):
    """Resolve a model_type string to its predictor class.

    Imports are deferred to call time so that missing optional dependencies
    only fail when that specific model is actually requested.
    """
    # 'para' and 'wrist_only' share the default predictor; 'wrist_only'
    # differs only in which camera's data is fed to it.
    if model_type in ("para", "wrist_only"):
        return TrajectoryHeatmapPredictor
    lazy_registry = {
        "act": ("model_act", "ACTPredictor"),
        "da3": ("model_da3", "DA3Predictor"),
        "moge": ("model_moge", "MoGePredictor"),
        "dino_vla": ("model_dino_vla", "DinoVLAPredictor"),
        "internvl": ("model_vla_internvl", "InternVLAPredictor"),
        "internvl_act": ("model_vla_internvl_act", "InternVLACTPredictor"),
        "dual_da3": ("model_dual_da3", "DualDA3Predictor"),
        "dual_para": ("model_dual_para", "DualParaPredictor"),
        "cost_volume": ("model_cost_volume", "CostVolumePredictor"),
    }
    if model_type not in lazy_registry:
        raise ValueError(f"Unknown model_type: {model_type}")
    module_name, class_name = lazy_registry[model_type]
    return getattr(__import__(module_name), class_name)

def is_heatmap_model(model_type):
    """ACT-style models use direct regression; everything else uses pixel-aligned heatmaps."""
    regression_types = ("act", "internvl_act")
    return model_type not in regression_types

def is_dual_model(model_type):
    """Dual-camera models process two views."""
    return model_type == "dual_da3" or model_type == "dual_para"

# Helper functions for discretization
def discretize_height(height_values):
    """Map continuous height values onto integer bin indices.

    Args:
        height_values: (B, N_WINDOW) or (N_WINDOW,) tensor of heights in [MIN_HEIGHT, MAX_HEIGHT]

    Returns:
        bin_indices: same shape, long tensor with indices in [0, N_HEIGHT_BINS-1]
    """
    # Read bounds from the module at call time so updated values are respected.
    lo = model_module.MIN_HEIGHT
    hi = model_module.MAX_HEIGHT
    frac = ((height_values - lo) / (hi - lo + 1e-8)).clamp(0.0, 1.0)
    return (frac * (N_HEIGHT_BINS - 1)).long().clamp(0, N_HEIGHT_BINS - 1)




def decode_height_bins(bin_logits):
    """Decode height bin logits back to continuous height values.

    Args:
        bin_logits: (B, N_WINDOW, N_HEIGHT_BINS) logits for each bin

    Returns:
        height_values: (B, N_WINDOW) continuous heights in [MIN_HEIGHT, MAX_HEIGHT]
    """
    lo = model_module.MIN_HEIGHT
    hi = model_module.MAX_HEIGHT
    # Bin centers are evenly spaced over [0, 1]; pick the argmax bin per step.
    centers = torch.linspace(0.0, 1.0, N_HEIGHT_BINS, device=bin_logits.device)
    frac = centers[bin_logits.argmax(dim=-1)]  # (B, N_WINDOW)
    return frac * (hi - lo) + lo


def discretize_gripper(gripper_values):
    """Discretize continuous gripper values into bin indices in [0, N_GRIPPER_BINS-1]."""
    lo = model_module.MIN_GRIPPER
    hi = model_module.MAX_GRIPPER
    frac = ((gripper_values - lo) / (hi - lo + 1e-8)).clamp(0.0, 1.0)
    return (frac * (N_GRIPPER_BINS - 1)).long().clamp(0, N_GRIPPER_BINS - 1)


def decode_gripper_bins(bin_logits):
    """Decode (B, N_WINDOW, N_GRIPPER_BINS) logits → (B, N_WINDOW) continuous gripper values."""
    lo = model_module.MIN_GRIPPER
    hi = model_module.MAX_GRIPPER
    centers = torch.linspace(0.0, 1.0, N_GRIPPER_BINS, device=bin_logits.device)
    return centers[bin_logits.argmax(dim=-1)] * (hi - lo) + lo


def discretize_rotation(euler_values):
    """Discretize (B, N_WINDOW, 3) euler angles → (B, N_WINDOW, 3) bin indices, per axis."""
    # MIN_ROT/MAX_ROT are per-axis bounds; broadcasting handles the last dim.
    lo = torch.tensor(model_module.MIN_ROT, device=euler_values.device, dtype=torch.float32)
    hi = torch.tensor(model_module.MAX_ROT, device=euler_values.device, dtype=torch.float32)
    frac = ((euler_values - lo) / (hi - lo + 1e-8)).clamp(0.0, 1.0)
    return (frac * (N_ROT_BINS - 1)).long().clamp(0, N_ROT_BINS - 1)


def decode_rotation_bins(rot_logits):
    """Decode (B, N_WINDOW, 3, N_ROT_BINS) → (B, N_WINDOW, 3) continuous euler angles."""
    lo = torch.tensor(model_module.MIN_ROT, device=rot_logits.device, dtype=torch.float32)
    hi = torch.tensor(model_module.MAX_ROT, device=rot_logits.device, dtype=torch.float32)
    centers = torch.linspace(0.0, 1.0, N_ROT_BINS, device=rot_logits.device)
    return centers[rot_logits.argmax(dim=-1)] * (hi - lo) + lo


def compute_rotation_loss(pred_rotation_logits, target_euler, mask=None):
    """Cross-entropy for 3 euler axes, averaged.

    Args:
        pred_rotation_logits: (B, N_WINDOW, 3, N_ROT_BINS)
        target_euler:         (B, N_WINDOW, 3) euler angles in radians
        mask: (B, N_WINDOW) optional, 1=valid 0=ignore
    """
    target_bins = discretize_rotation(target_euler)
    B, N, _, n_bins = pred_rotation_logits.shape
    per_axis = []
    for ax in range(3):
        ce = F.cross_entropy(
            pred_rotation_logits[:, :, ax, :].reshape(B * N, n_bins),
            target_bins[:, :, ax].reshape(B * N),
            reduction='none',
        )
        if mask is not None:
            ce = ce * mask.reshape(B * N)
        per_axis.append(ce.mean())
    total = torch.stack(per_axis).sum()
    if mask is None:
        return total / 3.0
    # Each per-axis mean divided by B*N including masked-out entries;
    # rescale so the result is the mean over *valid* entries only.
    n_valid = mask.sum().clamp(min=1.0)
    return total * (B * N) / (3.0 * n_valid)


def compute_rotation_loss_mse(pred_rotation_sigmoid, target_delta_rotvec):
    """MSE loss for continuous rotation prediction (sigmoid [0,1] vs normalized target).

    Args:
        pred_rotation_sigmoid: (B, N_WINDOW, 3) sigmoid outputs in [0, 1]
        target_delta_rotvec:   (B, N_WINDOW, 3) delta axis-angle values
    """
    dev = pred_rotation_sigmoid.device
    lo = torch.tensor(model_module.MIN_ROT, device=dev, dtype=torch.float32)
    hi = torch.tensor(model_module.MAX_ROT, device=dev, dtype=torch.float32)
    # Bring the target into the same normalized [0, 1] space as the prediction.
    normalized_target = ((target_delta_rotvec - lo) / (hi - lo + 1e-8)).clamp(0.0, 1.0)
    return F.mse_loss(pred_rotation_sigmoid, normalized_target)


# Training configuration
BATCH_SIZE = 8  # samples per optimizer step
LEARNING_RATE = 1e-4  # base learning rate (optimizer constructed elsewhere)
NUM_EPOCHS = 1000  # upper bound on training epochs
IMAGE_SIZE = 448  # 28x28 DINO patches at patch_size=16
SKIP_ROTATION = False  # when True, rotation losses are zeroed in train_epoch

# Loss terms are balanced online via EMA-based inverse-magnitude weighting in
# train_epoch (see loss_emas/ema_alpha); no manual weight tuning needed.


def build_volume_3d_points_for_vis(H, W, camera_pose, cam_K, height_bucket_centers, pixel_step=32):
    """Build 3D points for volume visualization (numpy). Returns (N, 3).

    Samples the image on a pixel_step grid, back-projecting every
    (pixel, height-bucket) pair; unprojectable points are dropped.
    """
    candidates = (
        recover_3d_from_direct_keypoint_and_height(
            np.array([x, y], dtype=np.float64), float(h), camera_pose, cam_K
        )
        for y in range(0, H, pixel_step)
        for x in range(0, W, pixel_step)
        for h in height_bucket_centers
    )
    points = [pt for pt in candidates if pt is not None]
    return np.array(points) if points else np.zeros((0, 3))


def compute_volume_loss(pred_volume_logits, trajectory_2d, target_height_bins, mask=None):
    """Cross-entropy with softmax over all 3D cells (per timestep).

    The target per timestep is the single volume cell at the GT pixel and GT
    height bin. The result is the mean CE per timestep; with a mask, the mean
    over *valid* (batch, timestep) entries.

    Args:
        pred_volume_logits: (B, N_WINDOW, N_HEIGHT_BINS, H, W)
        trajectory_2d: (B, N_WINDOW, 2) pixel coords [x, y]
        target_height_bins: (B, N_WINDOW) bin indices in [0, N_HEIGHT_BINS-1]
        mask: (B, N_WINDOW) optional, 1=valid 0=ignore (for out-of-view wrist targets)

    Returns:
        Scalar loss tensor.
    """
    B, N, Nh, H, W = pred_volume_logits.shape
    px = trajectory_2d[:, :, 0].long().clamp(0, W - 1)  # (B, N)
    py = trajectory_2d[:, :, 1].long().clamp(0, H - 1)  # (B, N)
    h_bin = target_height_bins.clamp(0, Nh - 1)  # (B, N)
    losses = []
    for t in range(N):
        logits_flat = pred_volume_logits[:, t].reshape(B, -1)  # (B, Nh*H*W)
        # Flatten (bin, y, x) into a single class index into the Nh*H*W cells.
        target_idx = (h_bin[:, t] * (H * W) + py[:, t] * W + px[:, t]).long()  # (B,)
        ce = F.cross_entropy(logits_flat, target_idx, reduction='none')  # (B,)
        if mask is not None:
            ce = ce * mask[:, t]
        losses.append(ce.mean())
    total = torch.stack(losses).sum()
    if mask is not None:
        # Each ce.mean() above divides by B (including masked-out entries),
        # so total == sum(masked CE) / B. Multiply by B — not N, which was a
        # bug — before dividing by the valid count; with an all-ones mask
        # this now reduces exactly to the unmasked `total / N` path.
        n_valid = mask.sum().clamp(min=1.0)
        return total * B / n_valid
    return total / N


def extract_pred_2d_and_height_from_volume(volume_logits):
    """From volume (B, N_WINDOW, N_HEIGHT_BINS, H, W) get pred 2D and height per timestep.

    For each t: max over height bins gives an (H, W) score map; its argmax
    gives (x, y); at that pixel, the argmax over bins gives the height bin.

    Returns:
        pred_2d: (B, N_WINDOW, 2) float pixel coords
        pred_height: (B, N_WINDOW) continuous height decoded from the bin at that pixel
    """
    B, N, Nh, H, W = volume_logits.shape
    device = volume_logits.device
    pred_2d = torch.zeros(B, N, 2, device=device, dtype=torch.float32)
    pred_height_bins = torch.zeros(B, N, device=device, dtype=torch.long)
    batch_idx = torch.arange(B, device=device)
    for t in range(N):
        vol_t = volume_logits[:, t]  # (B, Nh, H, W)
        best_over_h = vol_t.max(dim=1).values  # (B, H, W)
        flat = best_over_h.reshape(B, -1).argmax(dim=1)  # (B,)
        row = flat // W
        col = flat % W
        pred_2d[:, t, 0] = col.float()
        pred_2d[:, t, 1] = row.float()
        # Advanced indexing picks the Nh-vector at each sample's best pixel → (B, Nh).
        pred_height_bins[:, t] = vol_t[batch_idx, :, row, col].argmax(dim=1)
    # Decode bin indices to continuous heights via evenly spaced bin centers.
    lo = model_module.MIN_HEIGHT
    hi = model_module.MAX_HEIGHT
    centers = torch.linspace(0.0, 1.0, N_HEIGHT_BINS, device=device)
    pred_height = centers[pred_height_bins] * (hi - lo) + lo
    return pred_2d, pred_height


def compute_gripper_loss(pred_gripper_logits, target_gripper, mask=None):
    """2-class CE for gripper (open/close).

    Args:
        pred_gripper_logits: (B, N_WINDOW, 2) logits for [open, close]
        target_gripper:      (B, N_WINDOW) values in [-1, 1] → class 0 (open) or 1 (close)
        mask: (B, N_WINDOW) optional, 1=valid 0=ignore
    """
    B, N = target_gripper.shape
    flat_logits = pred_gripper_logits.reshape(B * N, 2)
    # Positive gripper value ⇒ "close" (class 1), otherwise "open" (class 0).
    flat_target = (target_gripper > 0).long().reshape(B * N)
    ce = F.cross_entropy(flat_logits, flat_target, reduction='none')
    if mask is None:
        return ce.mean()
    ce = ce * mask.reshape(B * N)
    return ce.sum() / mask.sum().clamp(min=1.0)


def visualize_sample(rgb, target_heatmap, pred_heatmap, target_2d):
    """Get visualization arrays for a single sample.

    Undoes ImageNet normalization on rgb and extracts the predicted pixel as
    the heatmap argmax. Note: target_heatmap is accepted for interface
    compatibility but not used.
    """
    imagenet_mean = torch.tensor([0.485, 0.456, 0.406], device=rgb.device).view(3, 1, 1)
    imagenet_std = torch.tensor([0.229, 0.224, 0.225], device=rgb.device).view(3, 1, 1)
    denorm = (rgb * imagenet_std + imagenet_mean).cpu().numpy()
    rgb_vis = np.clip(denorm.transpose(1, 2, 0), 0, 1)  # HWC in [0, 1]
    pred_heat = pred_heatmap.cpu().numpy()
    target_pt = target_2d.cpu().numpy()
    row, col = np.unravel_index(pred_heat.argmax(), pred_heat.shape)
    pred_pt = np.array([col, row])  # [x, y] ordering
    return rgb_vis, pred_heat, target_pt, pred_pt


def normalize_to_01(values, min_vals, max_vals):
    """Normalize values to [0, 1] given per-axis min/max. Clamps to [0, 1]."""
    shifted = values - min_vals
    span = max_vals - min_vals + 1e-8  # epsilon guards zero-width ranges
    return (shifted / span).clamp(0.0, 1.0)

def denormalize_from_01(values, min_vals, max_vals):
    """Denormalize values from [0, 1] back to original scale."""
    span = max_vals - min_vals
    return min_vals + values * span

def compute_act_loss(pos_pred, rot_pred, gripper_pred, trajectory_3d, trajectory_euler, trajectory_gripper):
    """ACT losses: MSE for pos/rot (normalized [0,1]), BCE for binary gripper."""
    device = pos_pred.device

    def _bound(attr):
        # Pull MIN_/MAX_ bounds from the model module at call time.
        return torch.tensor(getattr(model_module, attr), device=device, dtype=torch.float32)

    pos_target = normalize_to_01(trajectory_3d, _bound('MIN_POS'), _bound('MAX_POS'))
    rot_target = normalize_to_01(trajectory_euler, _bound('MIN_ROT'), _bound('MAX_ROT'))

    pos_loss = F.mse_loss(pos_pred, pos_target)
    rot_loss = F.mse_loss(rot_pred, rot_target)
    # Binary gripper: -1 (open) → 0, +1 (close) → 1
    grip_target = (trajectory_gripper > 0).float()
    gripper_loss = F.binary_cross_entropy_with_logits(gripper_pred, grip_target)
    return pos_loss, rot_loss, gripper_loss


def train_epoch(
    model,
    dataloader,
    optimizer,
    device,
    just_heatmap=False,
    global_step_start=0,
    vis_every_steps=50,
    vis_callback=None,
    log_scalars_every=10,
    model_type="para",
    save_every_steps=0,
    save_callback=None,
    loss_emas=None,
    ema_alpha=0.05,
):
    """Train for one epoch.

    Args:
        model: predictor module; its call signature depends on model_type
            (see the three branches below: dual-camera, single-camera
            heatmap, and ACT direct regression).
        dataloader: yields dict batches with 'rgb', 'trajectory_3d',
            'trajectory_2d', 'trajectory_gripper', 'trajectory_euler' and,
            depending on model_type, wrist/camera/language keys.
        optimizer: stepped once per batch (zero_grad → backward → step).
        device: target device for batch tensors.
        just_heatmap: if True, only volume loss is applied (gripper loss skipped).
        global_step_start: resume value for the global step counter.
        vis_every_steps: call vis_callback every this many steps (0 disables).
        vis_callback: optional fn(global_step) hook for visualization.
        log_scalars_every: wandb scalar logging cadence in steps (0 disables).
        model_type: 'para'/'da3'/'moge' use heatmap CE, 'act' uses direct MSE.
        save_every_steps: call save_callback every this many steps (0 disables).
        save_callback: optional fn(global_step) hook for checkpointing.
        loss_emas: mutable dict of per-term loss EMAs, updated in place; when
            provided, loss terms are weighted by inverse EMA magnitude so each
            term contributes comparably. None → unit weights.
        ema_alpha: EMA smoothing factor for loss_emas updates.

    Returns:
        (avg_loss, avg_volume_loss, avg_gripper_loss, avg_rotation_loss,
         global_step): per-batch averages and the final step counter.
    """
    model.train()
    total_loss = 0
    total_volume_loss = 0
    total_gripper_loss = 0
    total_rotation_loss = 0
    n_batches = 0

    heatmap_mode = is_heatmap_model(model_type)

    global_step = int(global_step_start)
    pbar = tqdm(dataloader, desc="Train", leave=False)
    for batch in pbar:
        rgb = batch['rgb'].to(device)
        trajectory_3d = batch['trajectory_3d'].to(device)  # (B, N_WINDOW, 3)
        trajectory_2d = batch['trajectory_2d'].to(device)  # (B, N_WINDOW, 2)
        trajectory_gripper = batch['trajectory_gripper'].to(device)  # (B, N_WINDOW)
        start_keypoint_2d = trajectory_2d[:, 0]  # (B, 2)

        trajectory_euler = batch['trajectory_euler'].to(device)  # (B, N_WINDOW, 3)
        # Use delta rotvec for rotation supervision (avoids euler wrapping)
        trajectory_rot = batch['trajectory_delta_rotvec'].to(device) if 'trajectory_delta_rotvec' in batch else trajectory_euler

        # Extra kwargs for models that accept them
        extra_kwargs = {}
        if model_type in ("dino_vla", "act") and 'clip_embedding' in batch:
            extra_kwargs['clip_embedding'] = batch['clip_embedding'].to(device)
        elif model_type in ("internvl", "internvl_act") and 'task_description' in batch:
            extra_kwargs['task_text'] = batch['task_description']  # list of strings

        if is_dual_model(model_type):
            # Dual-camera: compute losses for both views
            target_height = trajectory_3d[:, :, 2]
            target_height_bins = discretize_height(target_height)
            # Scale pixel targets from input resolution to prediction resolution.
            coord_scale = PRED_SIZE / IMAGE_SIZE
            trajectory_2d_pred = trajectory_2d * coord_scale

            wrist_rgb = batch['wrist_rgb'].to(device)
            wrist_traj_2d = batch['wrist_trajectory_2d'].to(device)
            wrist_traj_2d_pred = wrist_traj_2d * coord_scale
            wrist_mask = batch['wrist_in_view'].to(device)  # (B, N_WINDOW)

            out = model(rgb, wrist_rgb, start_keypoint_2d=start_keypoint_2d,
                        agent_query_pixels=trajectory_2d_pred,
                        wrist_query_pixels=wrist_traj_2d_pred)

            # Agent view losses (always valid)
            a_vol  = compute_volume_loss(out['agent_volume'], trajectory_2d_pred, target_height_bins)
            a_grip = compute_gripper_loss(out['agent_gripper'], trajectory_gripper)
            a_rot  = compute_rotation_loss(out["agent_rotation"], trajectory_rot)
            # Wrist view losses (masked by visibility)
            # NOTE(review): wrist_mask is loaded above but mask=None is passed
            # here, so visibility masking is effectively disabled — confirm
            # whether that is intentional or a leftover.
            w_vol  = compute_volume_loss(out['wrist_volume'], wrist_traj_2d_pred, target_height_bins, mask=None)
            w_grip = compute_gripper_loss(out['wrist_gripper'], trajectory_gripper, mask=None)
            w_rot  = compute_rotation_loss(out["wrist_rotation"], trajectory_rot, mask=None)

            # Use volume_loss/rotation_loss/gripper_loss for logging (agent view)
            volume_loss = a_vol
            rotation_loss = a_rot
            gripper_loss = a_grip

            # EMA for all 6 terms
            raw_losses = {
                'a_vol': a_vol.item(), 'a_rot': a_rot.item(), 'a_grip': a_grip.item(),
                'w_vol': w_vol.item(), 'w_rot': w_rot.item(), 'w_grip': w_grip.item(),
            }
            if loss_emas is not None:
                for k in raw_losses:
                    if k not in loss_emas:
                        loss_emas[k] = raw_losses[k]
                    loss_emas[k] = (1 - ema_alpha) * loss_emas[k] + ema_alpha * raw_losses[k]
                # Inverse-EMA weighting, normalized so the weights average to 1.
                inv_sum = sum(1.0 / (loss_emas[k] + 1e-8) for k in raw_losses)
                n_terms = len(raw_losses)
                weights = {k: (n_terms / inv_sum) / (loss_emas[k] + 1e-8) for k in raw_losses}
            else:
                weights = {k: 1.0 for k in raw_losses}

            loss = (weights['a_vol'] * a_vol + weights['a_rot'] * a_rot + weights['a_grip'] * a_grip +
                    weights['w_vol'] * w_vol + weights['w_rot'] * w_rot + weights['w_grip'] * w_grip)

        elif heatmap_mode:
            # Single-camera heatmap models
            # For wrist_only: swap in wrist data
            if model_type == "wrist_only" and 'wrist_rgb' in batch:
                rgb = batch['wrist_rgb'].to(device)
                trajectory_2d = batch['wrist_trajectory_2d'].to(device)
                start_keypoint_2d = trajectory_2d[:, 0]

            target_height = trajectory_3d[:, :, 2]
            target_height_bins = discretize_height(target_height)
            coord_scale = PRED_SIZE / IMAGE_SIZE
            trajectory_2d_pred = trajectory_2d * coord_scale

            if model_type == "cost_volume" and 'wrist_rgb' in batch:
                # Cost-volume model additionally needs both cameras' poses/intrinsics.
                wrist_rgb = batch['wrist_rgb'].to(device)
                volume_logits, gripper_logits, rotation_logits, _feats = model(
                    rgb, wrist_rgb, start_keypoint_2d=start_keypoint_2d,
                    agent_query_pixels=trajectory_2d_pred,
                    agent_query_height_bins=target_height_bins,
                    agent_cam_pose=batch['camera_pose'].to(device),
                    agent_cam_K_norm=batch['cam_K_norm'].to(device),
                    wrist_cam_pose=batch['wrist_camera_pose'].to(device),
                    wrist_cam_K_norm=batch['wrist_cam_K_norm'].to(device),
                )
            else:
                volume_logits, gripper_logits, rotation_logits, _feats = model(
                    rgb, start_keypoint_2d, query_pixels=trajectory_2d_pred, **extra_kwargs
                )

            volume_loss   = compute_volume_loss(volume_logits, trajectory_2d_pred, target_height_bins)
            gripper_loss  = compute_gripper_loss(gripper_logits, trajectory_gripper)
            if rotation_logits is None or SKIP_ROTATION:
                rotation_loss = torch.tensor(0.0, device=rgb.device)
            elif model_type == "cost_volume":
                # cost_volume emits continuous sigmoid rotation → MSE instead of CE.
                rotation_loss = compute_rotation_loss_mse(rotation_logits, trajectory_rot)
            else:
                rotation_loss = compute_rotation_loss(rotation_logits, trajectory_rot)
        else:
            # ACT: direct regression with proprioception
            device = rgb.device
            min_pos = torch.tensor(model_module.MIN_POS, device=device, dtype=torch.float32)
            max_pos = torch.tensor(model_module.MAX_POS, device=device, dtype=torch.float32)
            min_grip = torch.tensor(model_module.MIN_GRIPPER, device=device, dtype=torch.float32)
            max_grip = torch.tensor(model_module.MAX_GRIPPER, device=device, dtype=torch.float32)
            # Current (t=0) end-effector state is fed to the model as proprioception.
            current_eef_norm = normalize_to_01(trajectory_3d[:, 0], min_pos, max_pos)
            current_grip_norm = normalize_to_01(trajectory_gripper[:, 0], min_grip, max_grip)

            pos_pred, rot_pred, gripper_pred = model(
                rgb, start_keypoint_2d,
                current_eef_pos=current_eef_norm,
                current_gripper=current_grip_norm,
                **extra_kwargs,
            )
            # Reuse the volume/rotation loss slots for pos/rot so downstream
            # logging and averaging code is shared with the heatmap branch.
            volume_loss, rotation_loss, gripper_loss = compute_act_loss(
                pos_pred, rot_pred, gripper_pred,
                trajectory_3d, trajectory_euler, trajectory_gripper,
            )
            if SKIP_ROTATION:
                rotation_loss = torch.tensor(0.0, device=device)

        # --- EMA adaptive loss weighting (single-camera models) ---
        if not is_dual_model(model_type):
            # Rotation term participates only when it is non-zero this step.
            has_rot = rotation_loss.item() > 0
            raw_losses = {'vol': volume_loss.item(), 'grip': gripper_loss.item()}
            if has_rot:
                raw_losses['rot'] = rotation_loss.item()

            if loss_emas is not None:
                for k in raw_losses:
                    if k not in loss_emas:
                        loss_emas[k] = raw_losses[k]
                    loss_emas[k] = (1 - ema_alpha) * loss_emas[k] + ema_alpha * raw_losses[k]
                # Only terms with a meaningfully non-zero EMA get a weight.
                active_keys = [k for k in raw_losses if loss_emas.get(k, 0) > 1e-10]
                if active_keys:
                    inv_sum = sum(1.0 / (loss_emas[k] + 1e-8) for k in active_keys)
                    n_terms = len(active_keys)
                    weights = {k: (n_terms / inv_sum) / (loss_emas[k] + 1e-8) for k in active_keys}
                else:
                    weights = {k: 1.0 for k in raw_losses}
                # In this branch w_vol/w_rot/w_grip are scalar float weights
                # (unlike the dual branch, where the same names hold wrist
                # loss tensors).
                w_vol = weights.get('vol', 1.0)
                w_rot = weights.get('rot', 0.0)
                w_grip = weights.get('grip', 1.0)
            else:
                w_vol = w_rot = w_grip = 1.0

            if heatmap_mode and just_heatmap:
                loss = volume_loss
            else:
                loss = w_vol * volume_loss + w_grip * gripper_loss
                if has_rot:
                    loss = loss + w_rot * rotation_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Log losses
        total_loss += loss.item()
        n_batches += 1
        global_step += 1

        if is_dual_model(model_type):
            # Dual: log per-view raw losses separately
            # (w_vol/w_rot/w_grip here are the wrist loss tensors from above.)
            pbar.set_postfix(loss=f"{loss.item():.4f}",
                             a_vol=f"{a_vol.item():.4f}", a_rot=f"{a_rot.item():.4f}", a_grip=f"{a_grip.item():.4f}",
                             w_vol=f"{w_vol.item():.4f}", w_rot=f"{w_rot.item():.4f}", w_grip=f"{w_grip.item():.4f}")
            total_volume_loss += a_vol.item()
            total_gripper_loss += a_grip.item()
            total_rotation_loss += a_rot.item()
            if log_scalars_every > 0 and (global_step % log_scalars_every == 0):
                log_data = {
                    "train_step/loss": loss.item(),
                    "train_step/agent_volume_loss": a_vol.item(),
                    "train_step/agent_rotation_loss": a_rot.item(),
                    "train_step/agent_gripper_loss": a_grip.item(),
                    "train_step/wrist_volume_loss": w_vol.item(),
                    "train_step/wrist_rotation_loss": w_rot.item(),
                    "train_step/wrist_gripper_loss": w_grip.item(),
                }
                if loss_emas is not None:
                    for k in weights:
                        log_data[f"train_step/w_{k}"] = weights[k]
                wandb.log(log_data, step=global_step)
        else:
            # Log the *weighted* per-term contributions, matching what was optimized.
            weighted_vol  = (w_vol * volume_loss).item()
            weighted_rot  = (w_rot * rotation_loss).item()
            weighted_grip = (w_grip * gripper_loss).item()
            total_volume_loss += weighted_vol
            total_gripper_loss += weighted_grip
            total_rotation_loss += weighted_rot
            pbar.set_postfix(loss=f"{loss.item():.4f}", vol=f"{weighted_vol:.4f}",
                             grip=f"{weighted_grip:.4f}", rot=f"{weighted_rot:.4f}")
            if log_scalars_every > 0 and (global_step % log_scalars_every == 0):
                wandb.log({
                    "train_step/loss":          loss.item(),
                    "train_step/volume_loss":   weighted_vol,
                    "train_step/gripper_loss":  weighted_grip,
                    "train_step/rotation_loss": weighted_rot,
                    "train_step/w_vol":  w_vol,
                    "train_step/w_rot":  w_rot,
                    "train_step/w_grip": w_grip,
                }, step=global_step)
        if vis_callback is not None and vis_every_steps > 0 and (global_step % vis_every_steps == 0):
            vis_callback(global_step)
        if save_callback is not None and save_every_steps > 0 and (global_step % save_every_steps == 0):
            save_callback(global_step)

    return total_loss / n_batches, total_volume_loss / n_batches, total_gripper_loss / n_batches, total_rotation_loss / n_batches, global_step


def validate(model, dataloader, device, image_size=IMAGE_SIZE, model_type="para", loss_emas=None):
    """Run one validation pass over ``dataloader`` and return aggregate metrics.

    Three code paths depending on ``model_type``:
      * dual / cost_volume models: simplified loss-only validation (pixel
        metrics are skipped and ``sample_data`` is ``None``).
      * non-heatmap models (e.g. ACT): direct regression of normalized
        position/rotation/gripper; position error is reported in mm via the
        pixel-error slot of the return tuple.
      * heatmap models: volume CE supervised at GT pixels (teacher forcing),
        with pixel/height/gripper errors decoded at the *predicted* pixels.

    Args:
        model: predictor whose architecture matches ``model_type``.
        dataloader: validation DataLoader yielding dict batches (keys such as
            'rgb', 'trajectory_2d', 'trajectory_3d', 'trajectory_gripper', ...).
        device: torch device to run inference on.
        image_size: side length of the (square) input image; GT pixel coords
            are scaled by ``PRED_SIZE / image_size`` for loss supervision.
        model_type: architecture key (see ``get_model_class`` / ``is_dual_model``).
        loss_emas: optional dict of EMA loss magnitudes (keys like 'vol',
            'grip'). When given, the same inverse-EMA weighting used in
            training is applied to the val loss so best.pth selection reflects
            the weighted objective.

    Returns:
        9-tuple ``(loss, volume_loss, rotation_loss, gripper_loss,
        pixel_error, height_error, height_error_tf, gripper_error,
        sample_data)``. All are per-sample averages; ``rotation_loss`` and
        ``height_error_tf`` are currently reported as 0.0 on the main path,
        and ``sample_data`` is a visualization dict built from the first batch
        (or ``None`` on the simplified path).
    """
    model.eval()

    # Dual models: simplified validation (just compute loss, skip pixel metrics for now)
    if is_dual_model(model_type) or model_type == "cost_volume":
        total_loss = 0
        n = 0
        with torch.no_grad():
            for batch in dataloader:
                rgb = batch['rgb'].to(device)
                traj_2d = batch['trajectory_2d'].to(device)
                traj_3d = batch['trajectory_3d'].to(device)
                traj_grip = batch['trajectory_gripper'].to(device)
                traj_euler = batch['trajectory_euler'].to(device)
                # cs: scale factor from image pixel space to prediction-grid space.
                cs = PRED_SIZE / image_size
                th = discretize_height(traj_3d[:, :, 2])

                if model_type == "cost_volume":
                    wrist_rgb = batch['wrist_rgb'].to(device)
                    # Cost-volume model additionally needs camera poses/intrinsics
                    # of both views to build its epipolar correlation volume.
                    vol, grip, rot, _ = model(
                        rgb, wrist_rgb, start_keypoint_2d=traj_2d[:, 0],
                        agent_query_pixels=traj_2d * cs,
                        agent_query_height_bins=th,
                        agent_cam_pose=batch['camera_pose'].to(device),
                        agent_cam_K_norm=batch['cam_K_norm'].to(device),
                        wrist_cam_pose=batch['wrist_camera_pose'].to(device),
                        wrist_cam_K_norm=batch['wrist_cam_K_norm'].to(device),
                    )
                    loss = compute_volume_loss(vol, traj_2d * cs, th) + compute_gripper_loss(grip, traj_grip)
                    if rot is not None:
                        # Prefer delta-rotvec targets when the dataset provides them.
                        traj_rot_val = batch["trajectory_delta_rotvec"].to(device) if "trajectory_delta_rotvec" in batch else traj_euler
                        loss = loss + compute_rotation_loss_mse(rot, traj_rot_val)
                else:
                    # Dual-view model: supervise agent and wrist heads jointly.
                    wrist_rgb = batch['wrist_rgb'].to(device)
                    wrist_traj_2d = batch['wrist_trajectory_2d'].to(device)
                    out = model(rgb, wrist_rgb, start_keypoint_2d=traj_2d[:, 0],
                                agent_query_pixels=traj_2d * cs, wrist_query_pixels=wrist_traj_2d * cs)
                    a_vol = compute_volume_loss(out['agent_volume'], traj_2d * cs, th)
                    w_vol = compute_volume_loss(out['wrist_volume'], wrist_traj_2d * cs, th, mask=None)
                    a_grip = compute_gripper_loss(out['agent_gripper'], traj_grip)
                    w_grip = compute_gripper_loss(out['wrist_gripper'], traj_grip, mask=None)
                    a_rot = compute_rotation_loss(out['agent_rotation'], traj_euler)
                    w_rot = compute_rotation_loss(out['wrist_rotation'], traj_euler, mask=None)
                    loss = a_vol + w_vol + a_grip + w_grip + a_rot + w_rot

                total_loss += loss.item() * rgb.shape[0]
                n += rgb.shape[0]
        val_loss = total_loss / max(n, 1)
        # Mirror the main path's 9-tuple shape: metric slots zeroed, no sample.
        return val_loss, val_loss, 0, 0, 0, 0, 0, 0, None

    total_loss = 0
    total_volume_loss = 0
    total_gripper_loss = 0
    total_pixel_error = 0
    total_height_error = 0
    total_height_error_tf = 0
    total_gripper_error = 0
    n_samples = 0
    sample_data = None

    heatmap_mode = is_heatmap_model(model_type)

    with torch.no_grad():
        pbar = tqdm(dataloader, desc="Val", leave=False)
        for batch_idx, batch in enumerate(pbar):
            rgb = batch['rgb'].to(device)
            trajectory_2d = batch['trajectory_2d'].to(device)  # (B, N_WINDOW, 2)
            trajectory_3d = batch['trajectory_3d'].to(device)  # (B, N_WINDOW, 3)
            trajectory_gripper = batch['trajectory_gripper'].to(device)  # (B, N_WINDOW)
            camera_pose = batch['camera_pose']  # (B, 4, 4)
            cam_K_norm = batch['cam_K_norm']  # (B, 3, 3) normalized
            start_keypoint_2d = trajectory_2d[:, 0]  # (B, 2)
            target_height = trajectory_3d[:, :, 2]  # (B, N_WINDOW)

            trajectory_euler = batch['trajectory_euler'].to(device)
            # Fall back to euler targets when delta-rotvec is absent from the batch.
            trajectory_rot = batch['trajectory_delta_rotvec'].to(device) if 'trajectory_delta_rotvec' in batch else trajectory_euler

            # Model-specific conditioning: CLIP embedding vs raw task text.
            extra_kwargs = {}
            if model_type in ("dino_vla", "act") and 'clip_embedding' in batch:
                extra_kwargs['clip_embedding'] = batch['clip_embedding'].to(device)
            elif model_type in ("internvl", "internvl_act") and 'task_description' in batch:
                extra_kwargs['task_text'] = batch['task_description']

            if not heatmap_mode:
                # ACT: direct regression validation with proprioception
                min_pos_t = torch.tensor(model_module.MIN_POS, device=device, dtype=torch.float32)
                max_pos_t = torch.tensor(model_module.MAX_POS, device=device, dtype=torch.float32)
                min_grip_t = torch.tensor(model_module.MIN_GRIPPER, device=device, dtype=torch.float32)
                max_grip_t = torch.tensor(model_module.MAX_GRIPPER, device=device, dtype=torch.float32)
                # Current EEF state (t=0) normalized to [0,1] as model input.
                current_eef_norm = normalize_to_01(trajectory_3d[:, 0], min_pos_t, max_pos_t)
                current_grip_norm = normalize_to_01(trajectory_gripper[:, 0], min_grip_t, max_grip_t)

                pos_pred, rot_pred, gripper_pred = model(
                    rgb, start_keypoint_2d,
                    current_eef_pos=current_eef_norm,
                    current_gripper=current_grip_norm,
                    **extra_kwargs,
                )
                vol_loss, rot_loss, grip_loss = compute_act_loss(
                    pos_pred, rot_pred, gripper_pred,
                    trajectory_3d, trajectory_euler, trajectory_gripper,
                )
                loss = vol_loss + grip_loss + rot_loss
                total_loss += loss.item() * rgb.shape[0]
                total_volume_loss += vol_loss.item() * rgb.shape[0]
                total_gripper_loss += grip_loss.item() * rgb.shape[0]
                # Denormalize pos predictions for mm error metric
                pos_pred_denorm = denormalize_from_01(pos_pred, min_pos_t, max_pos_t)
                pos_err_mm = (pos_pred_denorm - trajectory_3d).norm(dim=-1).mean(dim=1).sum().item() * 1000
                # Reuse the pixel-error slot to carry position error (mm) for ACT.
                total_pixel_error += pos_err_mm
                # Binary gripper accuracy: logit > 0 → close (+1), else open (-1)
                grip_pred_binary = torch.where(gripper_pred > 0,
                                               torch.ones_like(gripper_pred),
                                               -torch.ones_like(gripper_pred))
                total_gripper_error += torch.abs(grip_pred_binary - trajectory_gripper).mean(dim=1).sum().item()
                n_samples += rgb.shape[0]
                pbar.set_postfix(loss=f"{loss.item():.4f}", pos_mm=f"{pos_err_mm/rgb.shape[0]:.1f}")
                continue

            target_height_bins = discretize_height(target_height)

            # Scale GT coords from IMAGE_SIZE → pred_size for loss supervision
            pred_size = PRED_SIZE
            coord_scale = pred_size / image_size
            trajectory_2d_pred = trajectory_2d * coord_scale

            # Model: volume head spatial + gripper/rotation MLP at GT pixels (teacher forcing)
            volume_logits, gripper_logits, rotation_logits, feats = model(
                rgb, start_keypoint_2d, query_pixels=trajectory_2d_pred, **extra_kwargs
            )
            volume_loss   = compute_volume_loss(volume_logits, trajectory_2d_pred, target_height_bins)
            gripper_loss  = compute_gripper_loss(gripper_logits, trajectory_gripper)
            if SKIP_ROTATION or rotation_logits is None:
                rotation_loss = torch.tensor(0.0, device=device)
            else:
                rotation_loss = compute_rotation_loss(rotation_logits, trajectory_rot)
            # Apply EMA weights to val loss (same as training) so best.pth reflects weighted loss
            if loss_emas is not None:
                active = {k: v for k, v in loss_emas.items() if v > 1e-10}
                if active:
                    # Inverse-EMA weighting normalized so the weights average to 1.
                    inv_sum = sum(1.0 / (v + 1e-8) for v in active.values())
                    n = len(active)
                    w_v = (n / inv_sum) / (loss_emas.get('vol', 1) + 1e-8) if 'vol' in active else 1.0
                    w_g = (n / inv_sum) / (loss_emas.get('grip', 1) + 1e-8) if 'grip' in active else 1.0
                    loss = w_v * volume_loss + w_g * gripper_loss + rotation_loss
                else:
                    loss = volume_loss + gripper_loss + rotation_loss
            else:
                loss = volume_loss + gripper_loss + rotation_loss

            total_loss += loss.item() * rgb.shape[0]
            total_volume_loss += volume_loss.item() * rgb.shape[0]
            total_gripper_loss += gripper_loss.item() * rgb.shape[0]

            # pred_2d is in pred_size space; scale back to IMAGE_SIZE for pixel error + 3D recovery
            pred_2d, pred_height = extract_pred_2d_and_height_from_volume(volume_logits)
            pred_2d_full = pred_2d / coord_scale
            # Gripper at predicted pixel (inference mode) for error metric
            pred_gripper_logits, _ = model.predict_at_pixels(feats, pred_2d)
            # 2-class gripper: argmax → class 0 (open, -1) or class 1 (close, +1)
            pred_gripper_class = pred_gripper_logits.argmax(dim=-1)  # (B, N)
            pred_gripper = pred_gripper_class.float() * 2.0 - 1.0   # 0→-1, 1→+1

            # volume_logits is (B, N, Nh, H, W); dim 2 is the height-bin axis.
            B, N, H, W = volume_logits.shape[0], volume_logits.shape[1], volume_logits.shape[3], volume_logits.shape[4]
            for t in range(N):
                pixel_error_t = torch.norm(pred_2d_full[:, t] - trajectory_2d[:, t], dim=1).sum()
                total_pixel_error += pixel_error_t.item()
            total_height_error += torch.abs(pred_height - target_height).mean(dim=1).sum().item()
            # Teacher-forced height error not computed on this path; keep slot at 0.
            total_height_error_tf += 0.0
            total_gripper_error += torch.abs(pred_gripper - trajectory_gripper).mean(dim=1).sum().item()
            n_samples += rgb.shape[0]
            pbar.set_postfix(loss=f"{loss.item():.4f}", px=f"{(torch.norm(pred_2d_full[:, 0] - trajectory_2d[:, 0], dim=1).mean().item()):.2f}")

            # Capture one visualization sample from the first batch only.
            if batch_idx == 0 and sample_data is None:
                pred_heatmaps = []
                for t in range(N_WINDOW):
                    vol_t = volume_logits[0, t].contiguous()  # (Nh, pred_size, pred_size)
                    # Softmax over the full (height × spatial) volume, then max over height
                    # → per-pixel max probability as a 2D heatmap.
                    vol_probs = F.softmax(vol_t.reshape(-1), dim=0).reshape(vol_t.shape[0], vol_t.shape[1], vol_t.shape[2])
                    heatmap_t = vol_probs.max(dim=0)[0]  # (pred_size, pred_size)
                    heatmap_up = F.interpolate(heatmap_t.unsqueeze(0).unsqueeze(0), size=(image_size, image_size), mode='bilinear', align_corners=False)[0, 0]
                    pred_heatmaps.append(heatmap_up)
                pred_heatmaps = torch.stack(pred_heatmaps)  # (N_WINDOW, image_size, image_size)

                pred_h_0 = pred_height[0]
                pred_g_0 = pred_gripper[0]
                # Guard against scalar predictions: broadcast to N_WINDOW for plotting.
                if pred_h_0.dim() == 0:
                    pred_h_0 = pred_h_0.unsqueeze(0).expand(N_WINDOW)
                if pred_g_0.dim() == 0:
                    pred_g_0 = pred_g_0.unsqueeze(0).expand(N_WINDOW)

                cam_pose_np = camera_pose[0].cpu().numpy()
                cam_K_norm_np = cam_K_norm[0].cpu().numpy()
                # Denormalize intrinsics to pixel units at image_size.
                cam_K_np = cam_K_norm_np.copy()
                cam_K_np[0] *= image_size
                cam_K_np[1] *= image_size
                pred_trajectory_3d_list = []
                for t in range(N_WINDOW):
                    px, py = pred_2d_full[0, t, 0].item(), pred_2d_full[0, t, 1].item()
                    h = pred_height[0, t].item()
                    pt = recover_3d_from_direct_keypoint_and_height(
                        np.array([px, py], dtype=np.float64), h, cam_pose_np, cam_K_np
                    )
                    if pt is not None:
                        pred_trajectory_3d_list.append(pt)
                    else:
                        # Ray/plane recovery failed; fall back to GT so plots stay valid.
                        pred_trajectory_3d_list.append(trajectory_3d[0, t].cpu().numpy())
                pred_trajectory_3d_np = np.array(pred_trajectory_3d_list)

                sample_data = {
                    'rgb': rgb[0],
                    'target_heatmap': batch['heatmap_target'][0].to(device),
                    'pred_heatmap': pred_heatmaps,
                    'trajectory_2d': trajectory_2d[0],
                    'trajectory_3d': trajectory_3d[0],
                    'trajectory_quat': batch['trajectory_quat'][0],
                    'rgb_frames_raw': batch['rgb_frames_raw'][0],
                    'world_to_camera': batch['world_to_camera'][0].cpu().numpy(),
                    'base_z': batch['base_z'][0],
                    'pred_trajectory_3d': pred_trajectory_3d_np,
                    'camera_pose': camera_pose[0].cpu().numpy(),
                    'cam_K_norm': cam_K_norm[0].cpu().numpy(),
                    'cam_K_at_size': cam_K_np,
                    'pred_height': pred_h_0,
                    'target_height': target_height[0],
                    'pred_gripper': pred_g_0,
                    'target_gripper': trajectory_gripper[0],
                }

    n = max(1, n_samples)
    # Pixel error averaged over both samples and timesteps.
    avg_pixel_error = total_pixel_error / (n * N_WINDOW)
    return (
        total_loss / n, total_volume_loss / n, 0.0, total_gripper_loss / n,
        avg_pixel_error, total_height_error / n, total_height_error_tf / n, total_gripper_error / n,
        sample_data
    )


def _proj_world_to_vis(points_3d, world_to_camera, H, W):
    """Project world-frame 3D points onto the training image (flipud of obs).

    Follows the same convention as debug_libero_projection.py:
    project_points_from_world_to_camera yields (row, col) pairs that land
    directly on flipud(obs_img) — no extra row flip is applied here.

    Args:
        points_3d: (3,) or (N, 3) array-like of world coordinates.
        world_to_camera: (4, 4) world→camera transform.
        H, W: image height and width in pixels.

    Returns:
        List of integer (u, v) = (col, row) tuples ready for cv2 drawing.
    """
    world_pts = np.asarray(points_3d, dtype=np.float64)
    if world_pts.ndim == 1:
        world_pts = world_pts.reshape(1, 3)
    rows_cols = project_points_from_world_to_camera(
        points=world_pts,
        world_to_camera_transform=world_to_camera,
        camera_height=H,
        camera_width=W,
    )
    # Swap (row, col) → (col, row) and round to the nearest integer pixel.
    uv_pairs = []
    for row_col in rows_cols:
        u = int(round(float(row_col[1])))
        v = int(round(float(row_col[0])))
        uv_pairs.append((u, v))
    return uv_pairs


def build_wandb_timestep_strip(sample, split_name):
    """Build a horizontal strip (one tile per timestep) matching debug_libero_projection.py style.

    Each tile shows the actual RGB frame at that timestep with:
      - predicted heatmap blended in red
      - predicted pixel: green crosshair
      - GT EEF projection: white filled circle + label
      - GT base-plane projection: cyan ring + yellow line to EEF
      - GT EEF rotation axes: red (x), green (y), blue (z) lines

    Args:
        sample: visualization dict as produced by validate() /
            build_sample_data_for_logging(), or None (returns None).
        split_name: label used in the wandb.Image caption (e.g. "val").

    Returns:
        wandb.Image of the N_WINDOW tiles concatenated left-to-right, or
        None when ``sample`` is None.
    """
    if sample is None:
        return None

    world_to_camera = sample['world_to_camera']  # (4, 4) numpy
    trajectory_3d = sample['trajectory_3d']       # (N, 3) tensor
    trajectory_quat = sample['trajectory_quat']   # (N, 4) tensor
    trajectory_2d = sample['trajectory_2d']       # (N, 2) tensor  [x, y] upright
    rgb_frames_raw = sample['rgb_frames_raw']      # (N, H, W, 3) tensor float [0,1]
    base_z = float(sample['base_z'])

    tiles = []
    for t in range(N_WINDOW):
        frame = rgb_frames_raw[t].cpu().numpy()  # (H, W, 3) float [0,1]
        H, W = frame.shape[:2]

        # --- Heatmap overlay (red channel blend) ---
        pred_heatmap_t = sample['pred_heatmap'][t].detach().cpu().numpy()
        # Min-max normalize the heatmap to [0, 1] before blending.
        heat = pred_heatmap_t - pred_heatmap_t.min()
        if heat.max() > 1e-8:
            heat = heat / heat.max()
        heat_rgb = np.zeros_like(frame)
        heat_rgb[..., 0] = heat
        overlay = np.clip(frame * 0.55 + heat_rgb * 0.45, 0, 1)
        vis = (overlay * 255.0).astype(np.uint8)

        # --- Predicted pixel: green crosshair ---
        # argmax over the upsampled heatmap gives (row, col) = (y, x).
        pred_y, pred_x = np.unravel_index(pred_heatmap_t.argmax(), pred_heatmap_t.shape)
        px, py = int(pred_x), int(pred_y)
        if 0 <= px < W and 0 <= py < H:
            cv2.drawMarker(vis, (px, py), (0, 255, 0), cv2.MARKER_CROSS, 14, 2, cv2.LINE_AA)
            cv2.putText(vis, "pred", (px + 8, py - 8), cv2.FONT_HERSHEY_SIMPLEX, 0.35, (0, 255, 0), 1, cv2.LINE_AA)

        # --- GT EEF: project directly onto training image (flipud of obs), no extra row flip needed ---
        eef_pos = trajectory_3d[t].cpu().numpy().astype(np.float64)
        u, v = _proj_world_to_vis(eef_pos, world_to_camera, H, W)[0]
        if 0 <= u < W and 0 <= v < H:
            cv2.circle(vis, (u, v), 6, (255, 255, 255), -1)
            cv2.putText(vis, "eef", (u + 8, v - 8), cv2.FONT_HERSHEY_SIMPLEX, 0.35, (255, 255, 255), 1, cv2.LINE_AA)

            # --- GT base-plane projection (same XY, z=robot base z) ---
            eef_base = eef_pos.copy()
            eef_base[2] = base_z
            ug, vg = _proj_world_to_vis(eef_base, world_to_camera, H, W)[0]
            if 0 <= ug < W and 0 <= vg < H:
                cv2.circle(vis, (ug, vg), 6, (0, 255, 255), 2)
                cv2.putText(vis, f"gt_base", (ug + 8, vg + 12), cv2.FONT_HERSHEY_SIMPLEX, 0.35, (0, 255, 255), 1, cv2.LINE_AA)
                cv2.line(vis, (u, v), (ug, vg), (255, 255, 0), 2, cv2.LINE_AA)

            # --- Predicted base-plane projection (from predicted 3D keypoint) ---
            if 'pred_trajectory_3d' in sample:
                pred_3d_t = sample['pred_trajectory_3d'][t]
                pred_base = pred_3d_t.copy()
                pred_base[2] = base_z
                up, vp = _proj_world_to_vis(pred_base, world_to_camera, H, W)[0]
                if 0 <= up < W and 0 <= vp < H:
                    cv2.circle(vis, (up, vp), 6, (255, 128, 0), 2)  # orange ring
                    cv2.putText(vis, "pred_base", (up + 8, vp - 8), cv2.FONT_HERSHEY_SIMPLEX, 0.35, (255, 128, 0), 1, cv2.LINE_AA)
                    # Line from predicted pixel to predicted base
                    if 0 <= px < W and 0 <= py < H:
                        cv2.line(vis, (px, py), (up, vp), (255, 128, 0), 1, cv2.LINE_AA)

            # --- EEF rotation axes: RGB colors matching debug_libero_projection.py ---
            eef_quat = trajectory_quat[t].cpu().numpy().astype(np.float64)
            eef_rot = T_rob.quat2mat(eef_quat)  # (3, 3)
            axis_len = 0.08  # axis segment length in meters (world units)
            axis_colors_rgb = [(255, 0, 0), (0, 255, 0), (0, 0, 255)]  # x=red, y=green, z=blue (image is RGB)
            for i, color in enumerate(axis_colors_rgb):
                endpoint = eef_pos + eef_rot[:, i] * axis_len
                ua, va = _proj_world_to_vis(endpoint, world_to_camera, H, W)[0]
                if 0 <= ua < W and 0 <= va < H:
                    cv2.line(vis, (u, v), (ua, va), color, 2, cv2.LINE_AA)
                    cv2.circle(vis, (ua, va), 3, color, -1)

        # --- Timestep label ---
        # Drawn twice (thick white under thin dark) for readability on any background.
        cv2.putText(vis, f"t={t}", (8, 16), cv2.FONT_HERSHEY_SIMPLEX, 0.45, (255, 255, 255), 2, cv2.LINE_AA)
        cv2.putText(vis, f"t={t}", (8, 16), cv2.FONT_HERSHEY_SIMPLEX, 0.45, (20, 20, 20), 1, cv2.LINE_AA)
        tiles.append(vis)

    strip = np.concatenate(tiles, axis=1)
    return wandb.Image(strip, caption=f"{split_name}: timesteps 0..{N_WINDOW-1} (left->right)")


def build_sample_data_for_logging(model, batch, device, image_size=IMAGE_SIZE, model_type="para"):
    """Build visualization sample dict from a single batch.

    Runs the model on the first element of ``batch`` (inference, no grad) and
    assembles the dict consumed by build_wandb_timestep_strip(): predicted
    heatmaps upsampled to image_size, recovered 3D trajectory, and GT tensors.
    Model-specific branches handle wrist-only, dual-view, and cost-volume
    architectures.

    Args:
        model: trained predictor matching ``model_type``.
        batch: one dataloader batch (dict of tensors; only index 0 is used).
        device: torch device to run inference on.
        image_size: side length of the square input image.
        model_type: architecture key (see get_model_class / is_dual_model).

    Returns:
        Dict with keys like 'rgb', 'pred_heatmap', 'pred_trajectory_3d',
        'trajectory_2d/3d/quat', camera metadata, and — for dual models —
        additional 'wrist_*' entries.
    """
    model.eval()
    with torch.no_grad():
        rgb = batch['rgb'][0:1].to(device)
        trajectory_2d = batch['trajectory_2d'][0:1].to(device)
        trajectory_3d = batch['trajectory_3d'][0:1].to(device)
        trajectory_gripper = batch['trajectory_gripper'][0:1].to(device)
        camera_pose = batch['camera_pose'][0].cpu().numpy()
        cam_K_norm = batch['cam_K_norm'][0].cpu().numpy()
        # Denormalize intrinsics to pixel units at image_size.
        cam_K = cam_K_norm.copy()
        cam_K[0] *= image_size
        cam_K[1] *= image_size
        start_keypoint_2d = trajectory_2d[:, 0]

        # For wrist_only: swap in wrist data for visualization
        if model_type == "wrist_only" and 'wrist_rgb' in batch:
            rgb = batch['wrist_rgb'][0:1].to(device)
            trajectory_2d = batch['wrist_trajectory_2d'][0:1].to(device)
            start_keypoint_2d = trajectory_2d[:, 0]
            camera_pose = batch['wrist_camera_pose'][0].cpu().numpy()
            cam_K_norm = batch['wrist_cam_K_norm'][0].cpu().numpy()
            cam_K = cam_K_norm.copy()
            cam_K[0] *= image_size
            cam_K[1] *= image_size

        # Model-specific conditioning: CLIP embedding vs raw task text.
        extra_kwargs = {}
        if model_type == "dino_vla" and 'clip_embedding' in batch:
            extra_kwargs['clip_embedding'] = batch['clip_embedding'][0:1].to(device)
        elif model_type in ("internvl", "internvl_act") and 'task_description' in batch:
            extra_kwargs['task_text'] = [batch['task_description'][0]]

        if is_dual_model(model_type):
            wrist_rgb = batch['wrist_rgb'][0:1].to(device)
            wrist_traj_2d = batch['wrist_trajectory_2d'][0:1].to(device)
            # cs: scale factor from image pixel space to prediction-grid space.
            cs = PRED_SIZE / image_size
            out = model(rgb, wrist_rgb, start_keypoint_2d=start_keypoint_2d,
                        agent_query_pixels=trajectory_2d * cs,
                        wrist_query_pixels=wrist_traj_2d * cs)
            volume_logits = out['agent_volume']
            feats = out['agent_feats']
        elif model_type == "cost_volume":
            wrist_rgb = batch['wrist_rgb'][0:1].to(device)
            cs = PRED_SIZE / image_size
            volume_logits, _, _, feats = model(
                rgb, wrist_rgb, start_keypoint_2d=start_keypoint_2d,
                agent_query_pixels=trajectory_2d * cs,
                agent_cam_pose=batch['camera_pose'][0:1].to(device),
                agent_cam_K_norm=batch['cam_K_norm'][0:1].to(device),
                wrist_cam_pose=batch['wrist_camera_pose'][0:1].to(device),
                wrist_cam_K_norm=batch['wrist_cam_K_norm'][0:1].to(device),
            )
        else:
            # Single-view models: no query pixels → pure inference decode below.
            volume_logits, _, _, feats = model(rgb, start_keypoint_2d, **extra_kwargs)

        pred_2d, pred_height = extract_pred_2d_and_height_from_volume(volume_logits)
        pred_size = volume_logits.shape[-1]
        coord_scale = pred_size / image_size
        # pred_2d is in prediction-grid space; scale back to image space.
        pred_2d_full = pred_2d / coord_scale
        if model_type == "cost_volume":
            # Extract height bins from volume argmax for predict_at_pixels
            _pred_hbins = []
            for _t in range(N_WINDOW):
                _vol_t = volume_logits[0, _t]
                _px = pred_2d[0, _t, 0].long().clamp(0, pred_size - 1)
                _py = pred_2d[0, _t, 1].long().clamp(0, pred_size - 1)
                # Volume layout is (height_bins, y, x): take argmax over height at (y, x).
                _pred_hbins.append(_vol_t[:, _py, _px].argmax().item())
            _pred_hbins_t = torch.tensor(_pred_hbins, device=device).unsqueeze(0)
            pred_gripper_logits, _ = model.predict_at_pixels(feats, pred_2d, _pred_hbins_t)
        elif is_dual_model(model_type):
            pred_gripper_logits, _ = model.predict_at_pixels(feats, pred_2d, view_name="agent")
        else:
            pred_gripper_logits, _ = model.predict_at_pixels(feats, pred_2d)
        # 2-class gripper: argmax class 0 → -1 (open), class 1 → +1 (close).
        pred_gripper_class = pred_gripper_logits.argmax(dim=-1)
        pred_gripper = pred_gripper_class.float() * 2.0 - 1.0

        # Heatmaps are pred_size×pred_size; upsample to image_size for visualization overlay
        pred_heatmaps = []
        for t in range(N_WINDOW):
            vol_t = volume_logits[0, t].contiguous()  # (Nh, pred_size, pred_size)
            # Softmax over the full volume, then max over height → 2D heatmap.
            vol_probs = F.softmax(vol_t.reshape(-1), dim=0).reshape(vol_t.shape[0], vol_t.shape[1], vol_t.shape[2])
            heatmap_t = vol_probs.max(dim=0)[0]  # (pred_size, pred_size)
            heatmap_up = F.interpolate(heatmap_t.unsqueeze(0).unsqueeze(0), size=(image_size, image_size), mode='bilinear', align_corners=False)[0, 0]
            pred_heatmaps.append(heatmap_up)
        pred_heatmaps = torch.stack(pred_heatmaps)

        pred_trajectory_3d_list = []
        for t in range(N_WINDOW):
            px, py = pred_2d_full[0, t, 0].item(), pred_2d_full[0, t, 1].item()
            h = pred_height[0, t].item()
            pt = recover_3d_from_direct_keypoint_and_height(np.array([px, py], dtype=np.float64), h, camera_pose, cam_K)
            # Recovery can fail (None); fall back to GT so plots stay valid.
            pred_trajectory_3d_list.append(pt if pt is not None else trajectory_3d[0, t].cpu().numpy())
        pred_trajectory_3d_np = np.array(pred_trajectory_3d_list)

        # For wrist_only: use wrist RGB as frame background
        if model_type == "wrist_only" and 'wrist_rgb' in batch:
            # Undo ImageNet normalization to get a displayable [0,1] RGB frame.
            _mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
            _std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
            wrist_denorm = (batch['wrist_rgb'][0].cpu() * _std + _mean).clamp(0, 1).permute(1, 2, 0)
            # Static wrist frame repeated for all N_WINDOW timesteps.
            rgb_frames_for_vis = wrist_denorm.unsqueeze(0).expand(N_WINDOW, -1, -1, -1).contiguous()
            w2c_for_vis = batch['wrist_world_to_camera'][0].cpu().numpy()
        else:
            rgb_frames_for_vis = batch['rgb_frames_raw'][0]
            w2c_for_vis = batch['world_to_camera'][0].cpu().numpy()

        result = {
            'rgb': rgb[0],
            'target_heatmap': batch['heatmap_target'][0].to(device),
            'pred_heatmap': pred_heatmaps,
            'trajectory_2d': trajectory_2d[0],
            'trajectory_3d': trajectory_3d[0],
            'trajectory_quat': batch['trajectory_quat'][0],
            'rgb_frames_raw': rgb_frames_for_vis,
            'world_to_camera': w2c_for_vis,
            'base_z': batch['base_z'][0],
            'pred_trajectory_3d': pred_trajectory_3d_np,
            'camera_pose': camera_pose,
            'cam_K_at_size': cam_K,
            'pred_height': pred_height[0],
            'target_height': trajectory_3d[0, :, 2],
            'pred_gripper': pred_gripper[0],
            'target_gripper': trajectory_gripper[0],
        }

        # Add wrist view for dual models
        if is_dual_model(model_type) and 'wrist_volume' in out:
            wrist_vol = out['wrist_volume']
            w_pred_2d, w_pred_height = extract_pred_2d_and_height_from_volume(wrist_vol)
            w_pred_heatmaps = []
            for t in range(N_WINDOW):
                vol_t = wrist_vol[0, t].contiguous()
                vol_probs = F.softmax(vol_t.reshape(-1), dim=0).reshape(vol_t.shape)
                hm_t = vol_probs.max(dim=0)[0]
                hm_up = F.interpolate(hm_t.unsqueeze(0).unsqueeze(0),
                                      size=(image_size, image_size), mode='bilinear', align_corners=False)[0, 0]
                w_pred_heatmaps.append(hm_up)
            result['wrist_rgb'] = wrist_rgb[0]
            result['wrist_pred_heatmap'] = torch.stack(w_pred_heatmaps)
            result['wrist_trajectory_2d'] = wrist_traj_2d[0]
            result['wrist_world_to_camera'] = batch['wrist_world_to_camera'][0].cpu().numpy()
            result['wrist_in_view'] = batch['wrist_in_view'][0].cpu().numpy()  # (N_WINDOW,)

        return result


def main():
    parser = argparse.ArgumentParser(description="Train PARA trajectory heatmap predictor on LIBERO")
    parser.add_argument("--model_type", type=str, default="para",
                        choices=["para", "act", "da3", "moge", "dino_vla", "internvl", "internvl_act", "dual_da3", "dual_para", "wrist_only", "cost_volume"],
                        help="Model architecture to train")
    parser.add_argument("--model_name", type=str, default="OpenGVLab/InternVL2_5-1B",
                        help="HuggingFace model name (used by internvl model_type)")
    parser.add_argument("--benchmark", type=str, default="libero_spatial",
                        help="LIBERO benchmark name")
    parser.add_argument("--task_id", type=int, default=0,
                        help="Task index within benchmark")
    parser.add_argument("--task_ids", type=str, default="",
                        help="Comma-separated task indices to train on all at once (e.g. '0,1,2' or 'all'). Overrides --task_id.")
    parser.add_argument("--camera", type=str, default="agentview",
                        help="Camera name used for training observations")
    parser.add_argument("--max_demos", type=int, default=0,
                        help="Maximum demos to load from task dataset")
    parser.add_argument("--val_split", type=float, default=0.05,
                        help="Fraction of episodes to use for validation")
    parser.add_argument("--batch_size", type=int, default=BATCH_SIZE,
                        help="Batch size for training")
    parser.add_argument("--lr", type=float, default=LEARNING_RATE,
                        help="Learning rate")
    parser.add_argument("--epochs", type=int, default=NUM_EPOCHS,
                        help="Number of epochs")
    parser.add_argument("--checkpoint", type=str, default="",
                        help="Path to checkpoint to resume from")
    parser.add_argument("--run_name", type=str, default="para_libero",
                        help="Name of run (used for checkpoint paths and W&B)")
    parser.add_argument("--wandb_project", type=str, default="para_libero",
                        help="W&B project name")
    parser.add_argument("--wandb_entity", type=str, default=None,
                        help="W&B entity/team (optional)")
    parser.add_argument("--wandb_mode", type=str, default="online", choices=["online", "offline", "disabled"],
                        help="W&B mode")
    parser.add_argument("--stats_cache_path", type=str, default="",
                        help="Path to JSON cache for height/gripper stats")
    parser.add_argument("--stats_sample_limit", type=int, default=500,
                        help="Random number of samples to use for stats computation (0 = full dataset)")
    parser.add_argument("--stats_seed", type=int, default=42,
                        help="Random seed for stats subsampling")
    parser.add_argument("--vis_every_steps", type=int, default=50,
                        help="Log visualization images every N train steps")
    parser.add_argument("--frame_stride", type=int, default=3,
                        help="Sample every Nth frame from the demo (default 3 → ~6.7Hz @ 20Hz, N_WINDOW=4 spans ~0.6s)")
    parser.add_argument("--cache_root", type=str, default="",
                        help="Path to pre-rendered dataset (e.g. /data/libero/parsed_libero). Uses CachedTrajectoryDataset when set.")
    parser.add_argument("--pos_loss_weight", type=float, default=1.0,
                        help="Weight for position loss (ACT: normalizes pos MSE to match rot/grip)")
    parser.add_argument("--volume_loss_weight", type=float, default=1.0,
                        help="Weight for volume/heatmap CE loss (PARA: scale relative to grip/rot CE)")
    parser.add_argument("--gripper_loss_weight", type=float, default=5.0,
                        help="Weight for gripper CE loss")
    parser.add_argument("--rotation_loss_weight", type=float, default=0.5,
                        help="Weight for rotation CE loss")
    parser.add_argument("--act_gripper_loss_weight", type=float, default=1.0,
                        help="Weight for ACT gripper BCE loss")
    parser.add_argument("--save_every_steps", type=int, default=1000,
                        help="Save checkpoint every N training steps (0 = only save per-epoch)")
    parser.add_argument("--max_minutes", type=float, default=0,
                        help="Stop training after N minutes (0 = no limit, use epochs)")
    parser.add_argument("--async_eval_every", type=int, default=500,
                        help="Launch background eval every N training steps (0 = disabled)")
    parser.add_argument("--no_ema_loss", action="store_true",
                        help="Disable EMA adaptive loss weighting (use equal weights)")
    parser.add_argument("--skip_rotation", action="store_true",
                        help="Skip rotation loss entirely (for zero-rotation tasks)")
    parser.add_argument("--overfit_one_sample", action="store_true",
                        help="Overfit to a single dataset sample (sanity check)")
    parser.add_argument("--pretrained_backbone", type=str, default="",
                        help="Path to point-track pretrained checkpoint (e.g. point_track_pretraining/checkpoints/.../best.pth). "
                             "Loads only the DINO backbone weights (dino.* keys) from the checkpoint, "
                             "initializing all other heads randomly. Use this for the pretrain→finetune pipeline.")
    args = parser.parse_args()

    # Losses are auto-normalized by log(N_classes) — no manual weights needed

    # BUG FIX: checkpoint dir is now relative to script location (libero/checkpoints/),
    # not hardcoded to the old volume_dino_tracks/ path
    script_dir = Path(__file__).parent
    CHECKPOINT_DIR = script_dir / "checkpoints" / args.run_name
    CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
    stats_cache_path = Path(args.stats_cache_path) if args.stats_cache_path else (CHECKPOINT_DIR / "dataset_stats.json")

    # Device: MPS on Mac, CUDA on server
    device = torch.device(
        "cuda" if torch.cuda.is_available() else
        "mps" if torch.backends.mps.is_available() else
        "cpu"
    )
    print(f"Using device: {device}")

    wandb.init(
        project=args.wandb_project,
        entity=args.wandb_entity,
        name=args.run_name,
        mode=args.wandb_mode,
        config={
            "benchmark": args.benchmark,
            "task_id": args.task_id,
            "camera": args.camera,
            "max_demos": args.max_demos,
            "val_split": args.val_split,
            "batch_size": args.batch_size,
            "lr": args.lr,
            "epochs": args.epochs,
            "image_size": IMAGE_SIZE,
            "n_window": N_WINDOW,
            "n_height_bins": N_HEIGHT_BINS,
            "n_gripper_bins": N_GRIPPER_BINS,
        },
    )

    print("\nLoading dataset...")
    if args.cache_root:
        # Use pre-rendered cached dataset (fast, supports num_workers>0)
        task_id_list = None
        if args.task_ids:
            if args.task_ids.strip().lower() == "all":
                task_id_list = None  # CachedTrajectoryDataset loads all by default
            else:
                task_id_list = [int(x) for x in args.task_ids.split(",")]
        else:
            task_id_list = [args.task_id]
        full_dataset = CachedTrajectoryDataset(
            cache_root=args.cache_root,
            benchmark_name=args.benchmark,
            task_ids=task_id_list,
            image_size=IMAGE_SIZE,
            n_window=N_WINDOW,
            frame_stride=args.frame_stride,
            max_demos=args.max_demos,
        )
        dataset_source = {"cache_root": args.cache_root, "benchmark": args.benchmark, "task_ids": task_id_list, "frame_stride": args.frame_stride}
    elif args.task_ids:
        from libero.libero import benchmark as _bm
        n_tasks = _bm.get_benchmark_dict()[args.benchmark]().get_num_tasks()
        if args.task_ids.strip().lower() == "all":
            task_id_list = list(range(n_tasks))
        else:
            task_id_list = [int(x) for x in args.task_ids.split(",")]
        print(f"  Multi-task mode: tasks {task_id_list}")
        datasets = []
        for tid in task_id_list:
            ds = RealTrajectoryDataset(
                image_size=IMAGE_SIZE,
                benchmark_name=args.benchmark,
                task_id=tid,
                camera=args.camera,
                max_demos=args.max_demos,
                frame_stride=args.frame_stride,
            )
            datasets.append(ds)
        full_dataset = ConcatDataset(datasets)
        dataset_source = {"benchmark": args.benchmark, "task_ids": task_id_list, "camera": args.camera, "max_demos": args.max_demos}
    else:
        full_dataset = RealTrajectoryDataset(
            image_size=IMAGE_SIZE,
            benchmark_name=args.benchmark,
            task_id=args.task_id,
            camera=args.camera,
            max_demos=args.max_demos,
            frame_stride=args.frame_stride,
        )
        dataset_source = {
            "benchmark": args.benchmark,
            "task_id": args.task_id,
            "camera": args.camera,
            "max_demos": args.max_demos,
        }
    print(f"  Source: {dataset_source}")
    print(f"  Total: {len(full_dataset)} samples")

    if args.overfit_one_sample:
        # Grab one sample and repeat it — useful for verifying the supervision signal
        print("\n⚠ OVERFIT MODE: using a single repeated sample for train and val")
        one_sample = full_dataset[0]
        class _RepeatDataset(torch.utils.data.Dataset):
            # Serves one pre-fetched sample `n` times so the normal DataLoader
            # pipeline can be reused for single-sample overfit sanity checks.
            def __init__(self, sample, n=2000):
                # sample: one fully-processed dataset item; n: virtual dataset length
                self.sample = sample
                self.n = n
            def __len__(self):
                return self.n
            def __getitem__(self, _):
                # Index is deliberately ignored — every item is the same sample.
                return self.sample
        train_dataset = _RepeatDataset(one_sample, n=2000)
        val_dataset   = _RepeatDataset(one_sample, n=10)
    else:
        dataset_size = len(full_dataset)
        val_size = max(1, int(dataset_size * args.val_split))
        train_size = dataset_size - val_size
        train_dataset, val_dataset = torch.utils.data.random_split(
            full_dataset, [train_size, val_size],
            generator=torch.Generator().manual_seed(42)
        )

    print(f"✓ Train: {len(train_dataset)} samples")
    print(f"✓ Val: {len(val_dataset)} samples")

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=16, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=16, pin_memory=True)

    print(f"\nInitializing model (type={args.model_type})...")
    ModelClass = get_model_class(args.model_type)
    model_kwargs = dict(target_size=IMAGE_SIZE, n_window=N_WINDOW, freeze_backbone=False)
    if args.model_type in ("internvl", "internvl_act"):
        model_kwargs["model_name"] = args.model_name
    model = ModelClass(**model_kwargs)
    model = model.to(device)

    # Load pretrained backbone from point-track pretraining (before checkpoint resume)
    if args.pretrained_backbone and os.path.exists(args.pretrained_backbone):
        print(f"\nLoading pretrained backbone from: {args.pretrained_backbone}")
        pt_ckpt = torch.load(args.pretrained_backbone, map_location=device)
        pt_state = pt_ckpt['model_state_dict']
        model_dict = model.state_dict()
        # Transfer only dino.* keys (the shared backbone)
        backbone_keys = {k: v for k, v in pt_state.items() if k.startswith("dino.")}
        loaded = []
        skipped = []
        for k, v in backbone_keys.items():
            if k in model_dict and v.shape == model_dict[k].shape:
                model_dict[k] = v
                loaded.append(k)
            else:
                skipped.append(k)
        model.load_state_dict(model_dict, strict=False)
        print(f"  Loaded {len(loaded)} backbone keys from pretrained checkpoint")
        if skipped:
            print(f"  Skipped {len(skipped)} keys (shape mismatch or not in model): {skipped[:5]}")
        pt_epoch = pt_ckpt.get('epoch', '?')
        pt_val_loss = pt_ckpt.get('val_loss', '?')
        print(f"  Pretrained checkpoint: epoch={pt_epoch}, val_loss={pt_val_loss}")
    elif args.pretrained_backbone:
        print(f"\nWARNING: pretrained backbone not found: {args.pretrained_backbone}")

    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    n_total = sum(p.numel() for p in model.parameters())
    print(f"Trainable parameters: {n_trainable:,} / {n_total:,} ({100*n_trainable/n_total:.2f}%)")

    optimizer = optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=args.lr,
        weight_decay=1e-4
    )

    start_epoch = 0
    checkpoint = None
    checkpoint_height_values = None
    checkpoint_gripper_values = None
    checkpoint_path = args.checkpoint
    if args.checkpoint:
        if not os.path.exists(checkpoint_path):
            alt = checkpoint_path.rsplit(".", 1)
            if len(alt) == 2:
                other_ext = ".pth" if alt[1].lower() == "pt" else ".pt"
                alt_path = alt[0] + other_ext
                if os.path.exists(alt_path):
                    checkpoint_path = alt_path
                    print(f"Checkpoint not found at {args.checkpoint}, using {checkpoint_path}")
    if args.checkpoint and os.path.exists(checkpoint_path):
        print(f"\nLoading checkpoint: {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location=device)

        model_state = checkpoint['model_state_dict']
        model_dict = model.state_dict()
        filtered_state = {}
        shape_mismatches = []
        for k, v in model_state.items():
            if k in model_dict:
                if v.shape == model_dict[k].shape:
                    filtered_state[k] = v
                else:
                    shape_mismatches.append(f"{k}: checkpoint {v.shape} vs model {model_dict[k].shape}")

        missing_keys = set(model_dict.keys()) - set(model_state.keys())
        unexpected_keys = set(model_state.keys()) - set(model_dict.keys())

        if missing_keys:
            print(f"⚠ Missing keys (random init): {sorted(missing_keys)}")
        if unexpected_keys:
            print(f"⚠ Unexpected keys (ignored): {sorted(unexpected_keys)}")
        if shape_mismatches:
            print(f"⚠ Shape mismatches (random init):")
            for msg in shape_mismatches:
                print(f"    {msg}")

        model_dict.update(filtered_state)
        model.load_state_dict(model_dict, strict=False)

        if shape_mismatches:
            print(f"⚠ Skipping optimizer state (model shape changed)")
        else:
            try:
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            except Exception as e:
                print(f"⚠ Could not load optimizer state: {e}")

        start_epoch = checkpoint.get('epoch', 0) + 1

        if 'min_height' in checkpoint and 'max_height' in checkpoint:
            checkpoint_height_values = (checkpoint['min_height'], checkpoint['max_height'])
            print(f"✓ Height range from checkpoint: [{checkpoint_height_values[0]:.6f}, {checkpoint_height_values[1]:.6f}] m")
        if 'min_gripper' in checkpoint and 'max_gripper' in checkpoint:
            checkpoint_gripper_values = (checkpoint['min_gripper'], checkpoint['max_gripper'])
            print(f"✓ Gripper range from checkpoint: [{checkpoint_gripper_values[0]:.6f}, {checkpoint_gripper_values[1]:.6f}]")
        print(f"✓ Resumed from epoch {start_epoch}")
    elif args.checkpoint:
        print(f"\n⚠ Checkpoint not found: {args.checkpoint}")

    stats_cache = None
    if stats_cache_path.exists():
        try:
            with open(stats_cache_path, "r", encoding="utf-8") as f:
                stats_cache = json.load(f)
            print(f"✓ Loaded stats cache: {stats_cache_path}")
        except Exception as e:
            print(f"⚠ Failed to read stats cache: {e}")
            stats_cache = None

    def compute_stats_from_dataset():
        """Scan (a random subset of) train+val samples and collect normalization stats.

        Returns a dict with min/max for height (z), gripper, per-axis delta rotation
        (axis-angle relative to the first sampled quaternion), per-axis position,
        plus bookkeeping fields. Falls back to the module defaults in model.py when
        no sample can be read. Closure over: args, train_dataset, val_dataset,
        dataset_source, N_WINDOW, IMAGE_SIZE.
        """
        total_len = len(train_dataset) + len(val_dataset)
        if args.stats_sample_limit and args.stats_sample_limit > 0:
            sample_count = min(args.stats_sample_limit, total_len)
            rng = random.Random(args.stats_seed)
            print(f"\nComputing dataset stats from random subset: {sample_count}/{total_len} samples (seed={args.stats_seed})")
        else:
            sample_count = total_len
            print(f"\nComputing dataset stats from full dataset: {total_len} samples")

        all_heights = []
        all_grippers = []
        all_eulers = []   # list of (3,) arrays
        all_positions = []  # list of (N, 3) arrays for per-axis pos min/max
        all_quats = []     # list of (N, 4) arrays for reference rotation computation
        seen = set()
        success_count = 0
        # Cap attempts so a dataset full of unreadable samples cannot loop forever.
        max_attempts = max(sample_count * 20, total_len if sample_count == total_len else sample_count)
        attempts = 0
        pbar = tqdm(total=sample_count, desc="Stats subset", leave=False)
        while success_count < sample_count and attempts < max_attempts:
            attempts += 1
            if sample_count == total_len:
                # Full-dataset pass: walk indices sequentially (no rng in this branch).
                global_idx = success_count
            else:
                # Subset pass: draw without replacement via the `seen` set.
                global_idx = rng.randrange(total_len)
                if global_idx in seen:
                    continue
            seen.add(global_idx)
            # Global indices cover train first, then val.
            dataset = train_dataset if global_idx < len(train_dataset) else val_dataset
            local_idx = global_idx if global_idx < len(train_dataset) else (global_idx - len(train_dataset))
            try:
                sample = dataset[local_idx]
            except Exception:
                # Unreadable sample: skip (its index stays in `seen`, never retried).
                continue
            trajectory_3d = sample['trajectory_3d'].numpy()
            trajectory_gripper = sample['trajectory_gripper'].numpy()
            trajectory_euler = sample['trajectory_euler'].numpy()  # (N, 3)
            trajectory_quat = sample['trajectory_quat'].numpy()  # (N, 4)
            all_heights.extend(trajectory_3d[:, 2].tolist())
            all_grippers.extend(trajectory_gripper.tolist())
            all_eulers.append(trajectory_euler)
            all_positions.append(trajectory_3d)
            all_quats.append(trajectory_quat)
            success_count += 1
            pbar.update(1)
        pbar.close()

        if len(all_heights) == 0 or len(all_grippers) == 0:
            print("⚠ No valid samples found for stats; falling back to model defaults.")
            return {
                "min_height": float(model_module.MIN_HEIGHT),
                "max_height": float(model_module.MAX_HEIGHT),
                "min_gripper": float(model_module.MIN_GRIPPER),
                "max_gripper": float(model_module.MAX_GRIPPER),
                "min_rot": model_module.MIN_ROT,
                "max_rot": model_module.MAX_ROT,
                "num_height_values": 0,
                "num_gripper_values": 0,
            }

        all_heights_np  = np.array(all_heights, dtype=np.float64)
        all_grippers_np = np.array(all_grippers, dtype=np.float64)
        all_eulers_np   = np.concatenate(all_eulers, axis=0)  # (total_N, 3)
        all_positions_np = np.concatenate(all_positions, axis=0)  # (total_N, 3)
        all_quats_np = np.concatenate(all_quats, axis=0)  # (total_N, 4)

        # Use first sample as reference rotation (avoids quaternion sign ambiguity in mean)
        from scipy.spatial.transform import Rotation as ScipyR
        ref_quat = all_quats_np[0].copy()
        ref_rot = ScipyR.from_quat(ref_quat)
        # Compute delta axis-angle from reference for all samples
        all_delta_rotvec = np.stack([
            (ref_rot.inv() * ScipyR.from_quat(q)).as_rotvec() for q in all_quats_np
        ], axis=0)

        return {
            "min_height":  float(all_heights_np.min()),
            "max_height":  float(all_heights_np.max()),
            "min_gripper": float(all_grippers_np.min()),
            "max_gripper": float(all_grippers_np.max()),
            "min_rot":     all_delta_rotvec.min(axis=0).tolist(),  # delta axis-angle min per axis
            "max_rot":     all_delta_rotvec.max(axis=0).tolist(),  # delta axis-angle max per axis
            "ref_rotation_quat": ref_quat.tolist(),  # reference rotation as quaternion
            "min_pos":     all_positions_np.min(axis=0).tolist(),
            "max_pos":     all_positions_np.max(axis=0).tolist(),
            "num_height_values":  int(all_heights_np.size),
            "num_gripper_values": int(all_grippers_np.size),
            "dataset_source": dataset_source,
            "n_window":   int(N_WINDOW),
            "image_size": int(IMAGE_SIZE),
        }

    # Resolve height range from checkpoint > cache > dataset
    if checkpoint_height_values is not None:
        model_module.MIN_HEIGHT, model_module.MAX_HEIGHT = checkpoint_height_values
        print(f"✓ Using height range from checkpoint: [{model_module.MIN_HEIGHT:.6f}, {model_module.MAX_HEIGHT:.6f}] m")
    else:
        if stats_cache is None:
            stats_cache = compute_stats_from_dataset()
            stats_cache_path.parent.mkdir(parents=True, exist_ok=True)
            with open(stats_cache_path, "w", encoding="utf-8") as f:
                json.dump(stats_cache, f, indent=2)
            print(f"✓ Saved stats cache: {stats_cache_path}")
        model_module.MIN_HEIGHT = float(stats_cache["min_height"])
        model_module.MAX_HEIGHT = float(stats_cache["max_height"])
        print(f"✓ Height range from dataset: [{model_module.MIN_HEIGHT:.6f}, {model_module.MAX_HEIGHT:.6f}] m")
        if abs(model_module.MIN_HEIGHT - model_module.MAX_HEIGHT) < 1e-6:
            print("  ⚠ WARNING: MIN_HEIGHT == MAX_HEIGHT — all height predictions will be constant!")

    # Resolve gripper range from checkpoint > cache > dataset
    if checkpoint_gripper_values is not None:
        model_module.MIN_GRIPPER, model_module.MAX_GRIPPER = checkpoint_gripper_values
        print(f"✓ Using gripper range from checkpoint: [{model_module.MIN_GRIPPER:.6f}, {model_module.MAX_GRIPPER:.6f}]")
    else:
        if stats_cache is None:
            stats_cache = compute_stats_from_dataset()
            stats_cache_path.parent.mkdir(parents=True, exist_ok=True)
            with open(stats_cache_path, "w", encoding="utf-8") as f:
                json.dump(stats_cache, f, indent=2)
            print(f"✓ Saved stats cache: {stats_cache_path}")
        model_module.MIN_GRIPPER = float(stats_cache["min_gripper"])
        model_module.MAX_GRIPPER = float(stats_cache["max_gripper"])
        print(f"✓ Gripper range from dataset: [{model_module.MIN_GRIPPER:.6f}, {model_module.MAX_GRIPPER:.6f}]")

    # Resolve rotation range + reference rotation from checkpoint > cache > dataset
    checkpoint_rot_values = None
    if checkpoint is not None and 'min_rot' in checkpoint and 'max_rot' in checkpoint:
        checkpoint_rot_values = (checkpoint['min_rot'], checkpoint['max_rot'])
    if checkpoint_rot_values is not None:
        model_module.MIN_ROT, model_module.MAX_ROT = checkpoint_rot_values
        print(f"✓ Rotation range from checkpoint: {model_module.MIN_ROT} .. {model_module.MAX_ROT}")
    else:
        if stats_cache is None:
            stats_cache = compute_stats_from_dataset()
            stats_cache_path.parent.mkdir(parents=True, exist_ok=True)
            with open(stats_cache_path, "w", encoding="utf-8") as f:
                json.dump(stats_cache, f, indent=2)
        if "min_rot" in stats_cache:
            model_module.MIN_ROT = stats_cache["min_rot"]
            model_module.MAX_ROT = stats_cache["max_rot"]
            print(f"✓ Rotation range (delta rotvec): {[f'{v:.3f}' for v in model_module.MIN_ROT]} .. {[f'{v:.3f}' for v in model_module.MAX_ROT]}")
    # Reference rotation
    if checkpoint is not None and 'ref_rotation_quat' in checkpoint:
        model_module.REF_ROTATION_QUAT = checkpoint['ref_rotation_quat']
    elif stats_cache is not None and 'ref_rotation_quat' in stats_cache:
        model_module.REF_ROTATION_QUAT = stats_cache['ref_rotation_quat']
    print(f"✓ Reference rotation: {[f'{v:.4f}' for v in model_module.REF_ROTATION_QUAT]}")

    # Resolve position range from checkpoint > cache > dataset
    checkpoint_pos_values = None
    if checkpoint is not None and 'min_pos' in checkpoint and 'max_pos' in checkpoint:
        checkpoint_pos_values = (checkpoint['min_pos'], checkpoint['max_pos'])
    if checkpoint_pos_values is not None:
        model_module.MIN_POS, model_module.MAX_POS = checkpoint_pos_values
        print(f"✓ Position range from checkpoint: {model_module.MIN_POS} .. {model_module.MAX_POS}")
    else:
        if stats_cache is None:
            stats_cache = compute_stats_from_dataset()
            stats_cache_path.parent.mkdir(parents=True, exist_ok=True)
            with open(stats_cache_path, "w", encoding="utf-8") as f:
                json.dump(stats_cache, f, indent=2)
        if "min_pos" in stats_cache:
            model_module.MIN_POS = stats_cache["min_pos"]
            model_module.MAX_POS = stats_cache["max_pos"]
            print(f"✓ Position range from dataset: {[f'{v:.3f}' for v in model_module.MIN_POS]} .. {[f'{v:.3f}' for v in model_module.MAX_POS]}")

    print(f"\nStarting training for {args.epochs} epochs...")
    best_val_loss = float('inf')
    global_step = 0

    global SKIP_ROTATION
    SKIP_ROTATION = args.skip_rotation
    if SKIP_ROTATION:
        print("✓ Rotation loss SKIPPED")

    # EMA-based adaptive loss weighting: initialize with expected raw CE values
    if args.no_ema_loss:
        loss_emas = None
        print("✓ EMA loss weighting DISABLED (equal weights)")
    else:
        loss_emas = {'vol': 11.8, 'rot': 3.5, 'grip': 0.69}  # ~log(N_classes) initial estimates
        print(f"✓ EMA loss weights initialized (vol={loss_emas['vol']:.1f}, rot={loss_emas['rot']:.1f}, grip={loss_emas['grip']:.2f})")

    def _build_wrist_strip(sample, split_name):
        """Render the wrist-camera strip for dual models (mirrors the agentview layout)."""
        if 'wrist_pred_heatmap' not in sample:
            return None
        # Undo ImageNet normalization so the frame can be shown as a raw image.
        imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
        imagenet_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
        frame_chw = (sample['wrist_rgb'].cpu() * imagenet_std + imagenet_mean).clamp(0, 1)
        frame_hwc = frame_chw.permute(1, 2, 0)
        # Rebind the wrist fields under the agentview key names so the shared
        # strip builder can be reused unchanged.
        view_sample = dict(sample)
        view_sample['rgb_frames_raw'] = frame_hwc.unsqueeze(0).expand(N_WINDOW, -1, -1, -1).contiguous()
        view_sample['pred_heatmap'] = sample['wrist_pred_heatmap']
        view_sample['trajectory_2d'] = sample['wrist_trajectory_2d']
        if 'wrist_world_to_camera' in sample:
            view_sample['world_to_camera'] = sample['wrist_world_to_camera']
        strip = build_wandb_timestep_strip(view_sample, f"{split_name}_wrist")
        # Stamp "OUT OF FRUSTUM" on any timestep tile whose target left the wrist view.
        if strip is not None and 'wrist_in_view' in sample:
            try:
                visible = sample['wrist_in_view']
                # NOTE(review): `_image` is a private wandb.Image attribute — may break on upgrade.
                canvas = np.array(strip._image)
                if canvas.ndim == 3:
                    tile_w = canvas.shape[1] // N_WINDOW
                    for t in range(N_WINDOW):
                        if visible[t] >= 0.5:
                            continue
                        x0 = t * tile_w
                        # Red-tint a top band, then draw the warning label on it.
                        band = canvas[0:30, x0:x0+tile_w, 0].astype(int) + 80
                        canvas[0:30, x0:x0+tile_w, 0] = np.clip(band, 0, 255).astype(np.uint8)
                        cv2.putText(canvas, "OUT OF FRUSTUM", (x0 + 5, 20),
                                    cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 50, 50), 1, cv2.LINE_AA)
                    strip = wandb.Image(canvas)
            except Exception:
                pass  # best-effort overlay; the plain strip is still useful
        return strip

    def log_visualizations(step):
        """Log train/val prediction strips (plus wrist strips for dual models) to W&B."""
        if not is_heatmap_model(args.model_type):
            return  # skip heatmap-specific visualization for ACT
        # NOTE(review): next(iter(loader)) builds a fresh iterator (and its workers)
        # on every call — acceptable at the current vis frequency.
        batches = {
            "train": next(iter(train_loader)),
            "val": next(iter(val_loader)),
        }
        samples = {
            split: build_sample_data_for_logging(model, b, device, image_size=IMAGE_SIZE, model_type=args.model_type)
            for split, b in batches.items()
        }
        payload = {}
        for split, sample in samples.items():
            strip = build_wandb_timestep_strip(sample, split)
            if strip is not None:
                payload[f"vis/{split}_strip"] = strip
        # Dual model: add wrist view strips
        if is_dual_model(args.model_type):
            for split, sample in samples.items():
                wrist = _build_wrist_strip(sample, split)
                if wrist is not None:
                    payload[f"vis/wrist_{split}_strip"] = wrist
        if payload:
            wandb.log(payload, step=step)

    if args.vis_every_steps > 0:
        log_visualizations(0)

    def save_step_checkpoint(step):
        """Save a mid-epoch checkpoint with all normalization ranges at `step`.

        Writes checkpoints/<run>/step_<step>.pth and also overwrites latest.pth
        so eval scripts always find the newest weights.
        """
        ckpt = {
            'global_step': step,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'min_height':  model_module.MIN_HEIGHT,
            'max_height':  model_module.MAX_HEIGHT,
            'min_gripper': model_module.MIN_GRIPPER,
            'max_gripper': model_module.MAX_GRIPPER,
            'min_rot':     model_module.MIN_ROT,
            'max_rot':     model_module.MAX_ROT,
            # BUG FIX: persist the reference rotation — resume and the async-eval
            # scratch checkpoint both carry 'ref_rotation_quat', but step checkpoints
            # were omitting it, so rotation decoding was lost when resuming/evaluating
            # from a step checkpoint.
            'ref_rotation_quat': model_module.REF_ROTATION_QUAT,
            'min_pos':     model_module.MIN_POS,
            'max_pos':     model_module.MAX_POS,
        }
        torch.save(ckpt, CHECKPOINT_DIR / f'step_{step}.pth')
        # Also overwrite latest so eval scripts can find it
        torch.save(ckpt, CHECKPOINT_DIR / 'latest.pth')
        print(f"\n  Saved step checkpoint: step_{step}.pth")

    # --- Async eval: launch background eval subprocess every N steps ---
    scratch_dir = script_dir / "logs" / "scratch"
    scratch_dir.mkdir(parents=True, exist_ok=True)
    _last_logged_eval_video = [None]  # mutable container for closure
    _eval_proc = [None]  # track running eval subprocess

    def maybe_launch_async_eval(step):
        """Save scratch checkpoint and launch eval.py in background.

        Fires every `args.async_eval_every` steps (never at step 0), and only if
        the previous eval subprocess has already exited. The eval process reads
        the freshly saved scratch checkpoint; its video is picked up later by
        maybe_log_eval_video.
        """
        if args.async_eval_every <= 0:
            return
        if step % args.async_eval_every != 0 or step == 0:
            return
        # Don't launch if previous eval is still running
        if _eval_proc[0] is not None and _eval_proc[0].poll() is None:
            return
        # Save scratch checkpoint (weights + all normalization ranges eval needs)
        scratch_ckpt = CHECKPOINT_DIR / "scratch_eval.pth"
        torch.save({
            'model_state_dict': model.state_dict(),
            'min_height': model_module.MIN_HEIGHT, 'max_height': model_module.MAX_HEIGHT,
            'min_gripper': model_module.MIN_GRIPPER, 'max_gripper': model_module.MAX_GRIPPER,
            'min_rot': model_module.MIN_ROT, 'max_rot': model_module.MAX_ROT,
            'ref_rotation_quat': model_module.REF_ROTATION_QUAT,
            'min_pos': model_module.MIN_POS, 'max_pos': model_module.MAX_POS,
        }, scratch_ckpt)
        # BUG FIX: removed unused `video_path` local — the eval subprocess decides
        # its own video output path under --out_dir.
        # Forward the relevant env vars explicitly so the subprocess sees them.
        eval_cmd = (
            f"PYTHONPATH={os.environ.get('PYTHONPATH', '')} "
            f"DINO_REPO_DIR={os.environ.get('DINO_REPO_DIR', '')} "
            f"DINO_WEIGHTS_PATH={os.environ.get('DINO_WEIGHTS_PATH', '')} "
            f"CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES', '')} "
            f"python {script_dir}/eval.py "
            f"--model_type {args.model_type} "
            f"--checkpoint {scratch_ckpt} "
            f"--benchmark {args.benchmark} "
            f"--task_id {args.task_id if not args.task_ids else 0} "
            f"--n_episodes 1 --save_video "
            f"--teleport --zero_rotation --max_steps 600 "
            f"--out_dir {scratch_dir}/eval_run "
            f"--clip_embeddings_dir /data/libero/parsed_libero"
        )
        # shell=True is needed for the inline env-var prefix; output is discarded
        # (progress is observed via the video file instead).
        _eval_proc[0] = subprocess.Popen(eval_cmd, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
        print(f"\n  [async eval] launched at step {step}")

    def maybe_log_eval_video(step):
        """Check if a new eval video exists and log it to wandb.

        Re-encodes the newest raw MP4 to H.264 for browser playback; the
        mtime of the last logged video is tracked so each video is logged
        only once.
        """
        if args.async_eval_every <= 0:
            return
        # Look for the latest eval video produced by the background eval run.
        video_dir = scratch_dir / "eval_run" / "videos" / f"task_{args.task_id if not args.task_ids else 0}"
        if not video_dir.exists():
            return
        # BUGFIX: exclude our own '.h264.mp4' re-encodes from the candidate
        # list — they land in the same directory and would otherwise become
        # the newest "*.mp4", causing the same video to be re-logged and
        # re-encoded (video.h264.h264.mp4, ...) on every call.
        videos = sorted(
            (p for p in video_dir.glob("*.mp4") if not p.name.endswith('.h264.mp4')),
            key=lambda p: p.stat().st_mtime, reverse=True,
        )
        if not videos:
            return
        latest = videos[0]
        if _last_logged_eval_video[0] != latest.stat().st_mtime:
            _last_logged_eval_video[0] = latest.stat().st_mtime
            try:
                # Re-encode to H.264 for browser compatibility
                h264_path = latest.with_suffix('.h264.mp4')
                # Argument list + shell=False: filenames with spaces or shell
                # metacharacters cannot break or hijack the command.
                subprocess.run(
                    ["ffmpeg", "-y", "-i", str(latest), "-vcodec", "libx264",
                     "-pix_fmt", "yuv420p", "-loglevel", "error", str(h264_path)],
                    check=False,
                )
                vid_path = str(h264_path) if h264_path.exists() else str(latest)
                wandb.log({"eval/rollout_video": wandb.Video(vid_path, fps=10, format="mp4")}, step=step)
                print(f"\n  [async eval] logged video to wandb: {latest.name}")
            except Exception as e:
                # Best-effort logging: never let a wandb/ffmpeg hiccup kill training.
                print(f"\n  [async eval] failed to log video: {e}")

    # Wrap existing vis_callback to also handle async eval
    _orig_vis_callback = log_visualizations
    def vis_and_eval_callback(step):
        """Run the standard visualization hook, then the async-eval hooks, in order."""
        for hook in (_orig_vis_callback, maybe_launch_async_eval, maybe_log_eval_video):
            hook(step)

    training_start_time = time.time()

    # Main training loop: one train pass + one validation pass per epoch.
    for epoch in tqdm(range(start_epoch, args.epochs), desc="Epochs"):
        # Time limit check — stop cleanly once the wall-clock budget is spent
        # (checked at epoch granularity only).
        if args.max_minutes > 0:
            elapsed = (time.time() - training_start_time) / 60.0
            if elapsed >= args.max_minutes:
                print(f"\n⏱ Time limit reached ({elapsed:.1f} / {args.max_minutes:.0f} min). Stopping.")
                break

        print(f"\n{'='*60}")
        print(f"Epoch {epoch}/{args.epochs}")
        print(f"{'='*60}")

        # train_epoch returns the updated global_step so step-based hooks
        # (visualization, step checkpoints, async eval) stay continuous
        # across epochs.
        train_loss, train_volume_loss, train_gripper_loss, train_rotation_loss, global_step = train_epoch(
            model, train_loader, optimizer, device,
            just_heatmap=False,
            global_step_start=global_step,
            vis_every_steps=args.vis_every_steps,
            vis_callback=vis_and_eval_callback,
            model_type=args.model_type,
            save_every_steps=args.save_every_steps,
            save_callback=save_step_checkpoint,
            loss_emas=loss_emas,
        )
        print(f"Train Loss: {train_loss:.4f} (Volume: {train_volume_loss:.4f}, Gripper: {train_gripper_loss:.4f}, Rotation: {train_rotation_loss:.4f})")

        val_loss, val_heatmap_loss, val_height_loss, val_gripper_loss, \
        val_error, val_height_error, val_height_error_tf, val_gripper_error, sample_val = validate(
            model, val_loader, device, model_type=args.model_type, loss_emas=loss_emas,
        )
        print(f"Val - Loss: {val_loss:.4f}, Volume: {val_heatmap_loss:.4f}, Pixel Error: {val_error:.2f}px, Height Error: {val_height_error*1000:.3f}mm, Gripper: {val_gripper_error:.4f}")

        # Log per-epoch metrics against the (step-based) wandb x-axis.
        wandb.log({
            "epoch": epoch,
            "train/loss": train_loss,
            "train/volume_loss": train_volume_loss,
            "train/gripper_loss": train_gripper_loss,
            "train/rotation_loss": train_rotation_loss,
            "val/loss": val_loss,
            "val/volume_loss": val_heatmap_loss,
            "val/gripper_loss": val_gripper_loss,
            "val/pixel_error": val_error,
            "val/height_error_mm": val_height_error * 1000.0,
            "val/gripper_abs_error": val_gripper_error,
        }, step=global_step)

        # Checkpoint carries the normalization ranges alongside the weights so
        # eval/inference can decode the discretized outputs without the
        # training config.
        checkpoint_data = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': train_loss,
            'val_loss': val_loss,
            'min_height':  model_module.MIN_HEIGHT,
            'max_height':  model_module.MAX_HEIGHT,
            'min_gripper': model_module.MIN_GRIPPER,
            'max_gripper': model_module.MAX_GRIPPER,
            'min_rot':     model_module.MIN_ROT,
            'max_rot':     model_module.MAX_ROT,
            'ref_rotation_quat': model_module.REF_ROTATION_QUAT,
            'min_pos':     model_module.MIN_POS,
            'max_pos':     model_module.MAX_POS,
        }

        # 'latest.pth' is overwritten every epoch; 'best.pth' tracks the
        # lowest validation loss seen so far.
        torch.save(checkpoint_data, CHECKPOINT_DIR / 'latest.pth')

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(checkpoint_data, CHECKPOINT_DIR / 'best.pth')
            print(f"✓ Saved best model (val_loss={val_loss:.4f})")

    wandb.finish()
    print("\n" + "=" * 60)
    print("✓ Training complete!")
    print(f"Best val loss: {best_val_loss:.4f}")
    print(f"Checkpoints saved to: {CHECKPOINT_DIR}")


if __name__ == "__main__":
    # Script entry point: run the full training routine.
    main()
