"""Train trajectory volume predictor on real data.

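Model predicts a pixel-aligned volume: N_WINDOW x N_HEIGHT_BINS logits per pixel. Volume loss: per timestep, softmax cross-entropy over all (height bin, y, x) cells, with the ground-truth trajectory cell as the target.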
Gripper is per-pixel (N_WINDOW x N_GRIPPER_BINS per pixel): supervised at GT pixel (teacher forcing), decoded at pred pixel in val/inference.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, ConcatDataset
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from pathlib import Path
from tqdm import tqdm
import argparse

import sys
import os
sys.path.insert(0, os.path.dirname(__file__))

from data import RealTrajectoryDataset, N_WINDOW
from model import TrajectoryHeatmapPredictor, N_HEIGHT_BINS, N_GRIPPER_BINS, DEFAULT_SMOLVLA_CKPT
import model as model_module  # Import module to access updated MIN_HEIGHT/MAX_HEIGHT at runtime
from utils import recover_3d_from_direct_keypoint_and_height

# Helper functions for discretization
def discretize_height(height_values):
    """Quantize continuous heights to the nearest of N_HEIGHT_BINS evenly spaced levels.

    Heights are normalized with ``model_module.MIN_HEIGHT`` / ``MAX_HEIGHT``
    (read at call time because the training script overwrites them at startup),
    clamped to [0, 1], and rounded to the nearest grid level
    ``i / (N_HEIGHT_BINS - 1)``. Rounding (instead of truncating) keeps the
    target bin consistent with decoding a bin ``i`` back to
    ``i / (N_HEIGHT_BINS - 1)``: the encode -> decode error is at most half a
    bin step and is not biased downward.

    Args:
        height_values: (B, N_WINDOW) or (N_WINDOW,) tensor of heights.

    Returns:
        bin_indices: long tensor of the same shape with values in [0, N_HEIGHT_BINS-1].
    """
    min_h = model_module.MIN_HEIGHT
    max_h = model_module.MAX_HEIGHT
    # Normalize to [0, 1]; the epsilon avoids division by zero when min == max.
    normalized = ((height_values - min_h) / (max_h - min_h + 1e-8)).clamp(0.0, 1.0)
    # Nearest grid level. Truncation would shift every target down by ~half a bin on average.
    bin_indices = torch.round(normalized * (N_HEIGHT_BINS - 1)).long()
    return bin_indices.clamp(0, N_HEIGHT_BINS - 1)


def discretize_gripper(gripper_values):
    """Quantize continuous gripper values to the nearest of N_GRIPPER_BINS evenly spaced levels.

    Values are normalized with ``model_module.MIN_GRIPPER`` / ``MAX_GRIPPER``
    (read at call time because the training script overwrites them at startup),
    clamped to [0, 1], and rounded to the nearest grid level
    ``i / (N_GRIPPER_BINS - 1)``. Rounding (instead of truncating) keeps the
    target bin consistent with decoding a bin ``i`` back to
    ``i / (N_GRIPPER_BINS - 1)``, so encode -> decode is unbiased.

    Args:
        gripper_values: (B, N_WINDOW) or (N_WINDOW,) tensor of gripper values.

    Returns:
        bin_indices: long tensor of the same shape with values in [0, N_GRIPPER_BINS-1].
    """
    min_g = model_module.MIN_GRIPPER
    max_g = model_module.MAX_GRIPPER
    # Normalize to [0, 1]; the epsilon avoids division by zero when min == max.
    normalized = ((gripper_values - min_g) / (max_g - min_g + 1e-8)).clamp(0.0, 1.0)
    # Nearest grid level. Truncation would bias targets toward the lower bin.
    bin_indices = torch.round(normalized * (N_GRIPPER_BINS - 1)).long()
    return bin_indices.clamp(0, N_GRIPPER_BINS - 1)


def decode_height_bins(bin_logits):
    """Turn height-bin logits into continuous heights via argmax decoding.

    The most likely bin ``i`` on the last axis is placed on a uniform grid
    ``i / (N_HEIGHT_BINS - 1)`` in [0, 1]. That value is rescaled to
    ``[model_module.MIN_HEIGHT, model_module.MAX_HEIGHT]``. The range is read
    at call time because the training script overwrites those module attributes.

    Args:
        bin_logits: (B, N_WINDOW, N_HEIGHT_BINS) logits.

    Returns:
        (B, N_WINDOW) tensor of heights.
    """
    lo, hi = model_module.MIN_HEIGHT, model_module.MAX_HEIGHT
    grid = torch.linspace(0.0, 1.0, N_HEIGHT_BINS, device=bin_logits.device)
    winning_bin = bin_logits.argmax(dim=-1)
    return grid[winning_bin] * (hi - lo) + lo


def decode_gripper_bins(bin_logits):
    """Turn gripper-bin logits into continuous gripper values via argmax decoding.

    The most likely bin ``i`` on the last axis is placed on a uniform grid
    ``i / (N_GRIPPER_BINS - 1)`` in [0, 1]. That value is rescaled to
    ``[model_module.MIN_GRIPPER, model_module.MAX_GRIPPER]``. The range is read
    at call time because the training script overwrites those module attributes.

    Args:
        bin_logits: (B, N_WINDOW, N_GRIPPER_BINS) logits.

    Returns:
        (B, N_WINDOW) tensor of gripper values.
    """
    lo, hi = model_module.MIN_GRIPPER, model_module.MAX_GRIPPER
    grid = torch.linspace(0.0, 1.0, N_GRIPPER_BINS, device=bin_logits.device)
    winning_bin = bin_logits.argmax(dim=-1)
    return grid[winning_bin] * (hi - lo) + lo

# ---------------------------------------------------------------------------
# Training configuration.
# BATCH_SIZE / LEARNING_RATE / NUM_EPOCHS serve as argparse defaults in main().
# ---------------------------------------------------------------------------
BATCH_SIZE: int = 8
LEARNING_RATE: float = 1e-4
NUM_EPOCHS: int = 1000

# Input resolution; higher resolution for finer spatial precision
# (noted as 28x28 patches at this size).
IMAGE_SIZE: int = 448

# Weight of the gripper term in the total loss. The gripper head is supervised
# at the ground-truth pixel (teacher forcing) and decoded at the predicted pixel
# during validation / inference.
GRIPPER_LOSS_WEIGHT: float = 1.0


def build_volume_3d_points_for_vis(H, W, camera_pose, cam_K, height_bucket_centers, pixel_step=32):
    """Sample a sparse 3D point cloud of the prediction volume for plotting.

    Every ``pixel_step``-th pixel (rows outer, columns inner) is unprojected at
    every height in ``height_bucket_centers`` with
    ``recover_3d_from_direct_keypoint_and_height``. Unprojections that return
    None are dropped.

    Returns:
        (N, 3) numpy array of points, or an empty (0, 3) array if none succeed.
    """
    candidates = (
        recover_3d_from_direct_keypoint_and_height(
            np.array([col, row], dtype=np.float64), float(level), camera_pose, cam_K
        )
        for row in range(0, H, pixel_step)
        for col in range(0, W, pixel_step)
        for level in height_bucket_centers
    )
    kept = [point for point in candidates if point is not None]
    if not kept:
        return np.zeros((0, 3))
    return np.array(kept)


def compute_volume_loss(pred_volume_logits, trajectory_2d, target_height_bins):
    """Cross-entropy over all 3D cells of the predicted volume, one target cell per timestep.

    For every (sample, timestep), the volume slice ``(N_HEIGHT_BINS, H, W)`` is
    flattened and soft-maxed as a single distribution. The target is the cell
    ``(height_bin, y, x)`` at flat index ``height_bin * H * W + y * W + x``.
    All timesteps go through one vectorized call. Every timestep contributes
    exactly B terms, so the overall mean equals the mean over timesteps of the
    per-timestep batch means.

    Pixel coordinates are truncated to integers (``.long()``). They are clamped
    into the volume, as are the height bins, so out-of-frame targets land on
    the border.

    Args:
        pred_volume_logits: (B, N_WINDOW, N_HEIGHT_BINS, H, W) logits.
        trajectory_2d: (B, N_WINDOW, 2) pixel coords [x, y].
        target_height_bins: (B, N_WINDOW) integer bin indices.

    Returns:
        Scalar tensor: mean cross-entropy over all samples and timesteps.
    """
    B, N, Nh, H, W = pred_volume_logits.shape
    px = trajectory_2d[..., 0].long().clamp(0, W - 1)  # (B, N)
    py = trajectory_2d[..., 1].long().clamp(0, H - 1)  # (B, N)
    h_bin = target_height_bins.long().clamp(0, Nh - 1)  # (B, N)
    # Flat index of the target cell inside each (Nh, H, W) volume.
    target_idx = h_bin * (H * W) + py * W + px  # (B, N)
    # reshape (not view) so non-contiguous logits are handled too.
    logits_flat = pred_volume_logits.reshape(B * N, Nh * H * W)
    return F.cross_entropy(logits_flat, target_idx.reshape(B * N), reduction='mean')


def extract_pred_2d_and_height_from_volume(volume_logits):
    """Decode a predicted pixel and height for every timestep of a volume.

    For each (sample, timestep):
    1. Take the max logit along the height axis, giving one score per pixel.
    2. Pick the pixel with the highest score.
    3. Pick the highest-scoring height bin at that pixel.

    Bin ``i`` maps to ``i / (N_HEIGHT_BINS - 1)`` in [0, 1], rescaled to
    ``[model_module.MIN_HEIGHT, model_module.MAX_HEIGHT]``.

    Args:
        volume_logits: (B, N_WINDOW, N_HEIGHT_BINS, H, W) logits.

    Returns:
        pred_2d: (B, N_WINDOW, 2) float32 pixel coords [x, y].
        pred_height: (B, N_WINDOW) continuous heights.
    """
    B, N, Nh, H, W = volume_logits.shape
    device = volume_logits.device
    # One column of Nh logits per pixel (flat pixel index = y * W + x).
    rays = volume_logits.reshape(B, N, Nh, H * W)
    best_pixel = rays.max(dim=2).values.argmax(dim=2)  # (B, N)
    xs = best_pixel % W
    ys = best_pixel // W
    pred_2d = torch.stack((xs, ys), dim=-1).to(dtype=torch.float32)  # (B, N, 2)

    # Logits of the chosen pixel's ray, then the winning height bin on it.
    ray_index = best_pixel[:, :, None, None].expand(B, N, Nh, 1)
    chosen_ray = rays.gather(3, ray_index).squeeze(3)  # (B, N, Nh)
    best_bin = chosen_ray.argmax(dim=2)  # (B, N)

    levels = torch.linspace(0.0, 1.0, N_HEIGHT_BINS, device=device)
    lo, hi = model_module.MIN_HEIGHT, model_module.MAX_HEIGHT
    pred_height = levels[best_bin] * (hi - lo) + lo
    return pred_2d, pred_height


def extract_gripper_logits_at_pixels(gripper_logits, pixel_2d):
    """Read the per-pixel gripper logits at one (x, y) location per timestep.

    Coordinates are truncated to integers and clamped into the image. Use GT
    pixels for teacher forcing and predicted pixels at validation/inference.

    Args:
        gripper_logits: (B, N_WINDOW, N_GRIPPER_BINS, H, W) per-pixel logits.
        pixel_2d: (B, N_WINDOW, 2) pixel coords [x, y].

    Returns:
        (B, N_WINDOW, N_GRIPPER_BINS) logits at the requested pixels.
    """
    B, N, Ng, H, W = gripper_logits.shape
    cols = pixel_2d[..., 0].long().clamp(0, W - 1)
    rows = pixel_2d[..., 1].long().clamp(0, H - 1)
    # Flatten space so each pixel is a single index, then gather that column for every bin.
    flat_pixel = rows * W + cols  # (B, N)
    index = flat_pixel[:, :, None, None].expand(B, N, Ng, 1)
    per_pixel = gripper_logits.reshape(B, N, Ng, H * W)
    return per_pixel.gather(3, index).squeeze(3)


def compute_gripper_loss(pred_gripper_logits, target_gripper, trajectory_2d, pos_weight=None):
    """Gripper cross-entropy read out at the ground-truth pixel (teacher forcing).

    The per-pixel gripper logits are sampled at each timestep's GT pixel.
    Continuous targets are mapped to bins with ``discretize_gripper``. The mean
    cross-entropy is taken over all (sample, timestep) pairs.

    Note: the result is multiplied by a fixed factor of 4, independent of any
    weight the caller applies afterwards.

    Args:
        pred_gripper_logits: (B, N_WINDOW, N_GRIPPER_BINS, H, W) per-pixel logits.
        target_gripper: (B, N_WINDOW) continuous gripper targets.
        trajectory_2d: (B, N_WINDOW, 2) GT pixel coords [x, y].
        pos_weight: ignored; retained so older call sites keep working.

    Returns:
        Scalar tensor equal to 4 * mean cross-entropy.
    """
    gt_logits = extract_gripper_logits_at_pixels(pred_gripper_logits, trajectory_2d)  # (B, N, Ng)
    batch, steps, n_bins = gt_logits.shape
    flat_targets = discretize_gripper(target_gripper).reshape(batch * steps)
    flat_logits = gt_logits.reshape(batch * steps, n_bins)
    return F.cross_entropy(flat_logits, flat_targets, reduction='mean') * 4

# Commented out binary classification loss (kept for reference)
# def compute_gripper_loss(pred_gripper, target_gripper, pos_weight=None):
#     """Compute weighted binary cross-entropy loss for gripper prediction across all timesteps.
#     
#     Args:
#         pred_gripper: (B, N_WINDOW) predicted gripper values (sigmoid output, [0, 1])
#         target_gripper: (B, N_WINDOW) target gripper values (binary: 0.0=closed, 1.0=open)
#         pos_weight: scalar weight for positive class (1.0 = open). 
#                     pos_weight = n_negative / n_positive = n_closed / n_open
#                     This weights the open class (majority), so we invert it to weight closed (minority) more
#     
#     Returns:
#         loss: scalar weighted BCE loss (averaged over timesteps)
#     """
#     # Binary cross-entropy loss with class weighting to handle imbalance
#     # Note: target_gripper encoding: 0.0 = closed (minority), 1.0 = open (majority)
#     if pos_weight is not None and pos_weight != 1.0:
#         # Manually compute weighted BCE since pred_gripper is already sigmoid output
#         # BCE = -(y * log(p) + (1-y) * log(1-p))
#         # We want to weight the minority class (closed = 0.0) more
#         # So we weight the (1-y) term, which corresponds to closed predictions
#         eps = 1e-7
#         pred_gripper_clamped = pred_gripper.clamp(eps, 1 - eps)
#         
#         # Loss for positive class (open = 1.0) - standard weight
#         pos_loss = -target_gripper * torch.log(pred_gripper_clamped)
#         # Loss for negative class (closed = 0.0) - weighted more heavily
#         # pos_weight = n_open / n_closed, so we use it to weight closed (minority) more
#         neg_loss = -(1 - target_gripper) * torch.log(1 - pred_gripper_clamped) * pos_weight
#         
#         loss = (pos_loss + neg_loss).mean() * 1e1
#     else:
#         # Standard BCE (pred_gripper is already sigmoid output from model)
#         loss = F.binary_cross_entropy(pred_gripper, target_gripper) * 1e1
#     return loss


def visualize_sample(rgb, target_heatmap, pred_heatmap, target_2d):
    """Convert one sample's tensors into numpy arrays ready for plotting.

    Args:
        rgb: (3, H, W) ImageNet-normalized RGB tensor.
        target_heatmap: not used by this function.
        pred_heatmap: (H, W) predicted heatmap tensor.
        target_2d: (2,) ground-truth pixel location tensor.

    Returns:
        rgb_vis: (H, W, 3) de-normalized image clipped to [0, 1].
        pred_heat_vis: predicted heatmap as numpy.
        target_pt: ground-truth 2D location as numpy.
        pred_pt: [x, y] of the heatmap argmax.
    """
    # Undo ImageNet normalization (x * std + mean), then move channels last.
    imagenet_mean = torch.tensor([0.485, 0.456, 0.406], device=rgb.device).view(3, 1, 1)
    imagenet_std = torch.tensor([0.229, 0.224, 0.225], device=rgb.device).view(3, 1, 1)
    channels_last = (rgb * imagenet_std + imagenet_mean).cpu().numpy().transpose(1, 2, 0)
    rgb_vis = np.clip(channels_last, 0, 1)

    heat = pred_heatmap.cpu().numpy()
    row, col = np.unravel_index(heat.argmax(), heat.shape)
    return rgb_vis, heat, target_2d.cpu().numpy(), np.array([col, row])


def train_epoch(model, dataloader, optimizer, device, just_heatmap=False, *,
                task="Pick the orange mug and place it onto the black saucer"):
    """Train ``model`` for one pass over ``dataloader``.

    The per-batch objective is
    ``compute_volume_loss(...) + GRIPPER_LOSS_WEIGHT * compute_gripper_loss(...)``.
    The volume term supervises pixel location and height bin jointly. The
    gripper term is read out at the ground-truth pixel (teacher forcing).

    Args:
        model: network called as ``model(rgb, training=True, start_keypoint_2d=..., task=...)``
            and returning ``(volume_logits, gripper_logits)``.
        dataloader: yields dicts with 'rgb', 'trajectory_3d', 'trajectory_2d', 'trajectory_gripper'.
        optimizer: stepped once per batch.
        device: device the batch tensors are moved to.
        just_heatmap: unused; kept for backward compatibility with existing callers.
        task: instruction string forwarded to the model.

    Returns:
        (avg_total_loss, avg_volume_loss, avg_height_loss, avg_gripper_loss), each
        averaged over batches. avg_height_loss is always 0.0 because height is
        supervised inside the volume loss. An empty dataloader yields zeros
        instead of raising ZeroDivisionError.
    """
    model.train()
    total_loss = 0.0
    total_volume_loss = 0.0
    total_gripper_loss = 0.0
    n_batches = 0

    for batch in tqdm(dataloader):
        rgb = batch['rgb'].to(device)
        trajectory_3d = batch['trajectory_3d'].to(device)  # (B, N_WINDOW, 3)
        trajectory_2d = batch['trajectory_2d'].to(device)  # (B, N_WINDOW, 2)
        trajectory_gripper = batch['trajectory_gripper'].to(device)  # (B, N_WINDOW)

        volume_logits, gripper_logits = model(
            rgb,
            training=True,
            start_keypoint_2d=trajectory_2d[:, 0],  # (B, 2) first waypoint
            task=task,
        )  # (B, N_WINDOW, N_HEIGHT_BINS, H, W), (B, N_WINDOW, N_GRIPPER_BINS, H, W)

        # Height (z of each waypoint) is supervised via the target cell's height bin.
        target_height_bins = discretize_height(trajectory_3d[:, :, 2])  # (B, N_WINDOW)

        volume_loss = compute_volume_loss(volume_logits, trajectory_2d, target_height_bins)
        gripper_loss = compute_gripper_loss(gripper_logits, trajectory_gripper, trajectory_2d)
        loss = volume_loss + GRIPPER_LOSS_WEIGHT * gripper_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_volume_loss += volume_loss.item()
        total_gripper_loss += gripper_loss.item()
        n_batches += 1

    # Guard against an empty dataloader, which would otherwise divide by zero.
    denom = max(1, n_batches)
    return total_loss / denom, total_volume_loss / denom, 0.0, total_gripper_loss / denom


def _ray_max_heatmaps(volume):
    """Per-timestep 2D heatmap for one sample: softmax over each timestep's full volume, then max along height.

    Args:
        volume: (N, N_HEIGHT_BINS, H, W) logits of a single sample.

    Returns:
        (N, H, W) tensor holding the highest cell probability on each pixel's ray.
    """
    n_steps = volume.shape[0]
    # reshape (not view): a slice of the model output is not guaranteed to be contiguous.
    probs = F.softmax(volume.reshape(n_steps, -1), dim=1).reshape(volume.shape)
    return probs.max(dim=1).values


def _unproject_pred_trajectory(pred_2d, pred_height, camera_pose, cam_K, fallback_3d):
    """Lift one sample's predicted pixels + heights to 3D points.

    Args:
        pred_2d: (N, 2) predicted pixel coords [x, y].
        pred_height: (N,) predicted heights.
        camera_pose: numpy camera pose passed through to the unprojection helper.
        cam_K: numpy intrinsics in pixel units (matching ``pred_2d``).
        fallback_3d: (N, 3) tensor; its row is used wherever
            ``recover_3d_from_direct_keypoint_and_height`` returns None.

    Returns:
        (N, 3) numpy array with one point per timestep.
    """
    fallback_np = fallback_3d.cpu().numpy()
    points = []
    for (px, py), h, fallback in zip(pred_2d.tolist(), pred_height.tolist(), fallback_np):
        pt = recover_3d_from_direct_keypoint_and_height(
            np.array([px, py], dtype=np.float64), h, camera_pose, cam_K
        )
        points.append(fallback if pt is None else pt)
    return np.array(points)


def validate(model, dataloader, device, image_size=IMAGE_SIZE, *,
             task="Pick the orange mug and place it onto the black saucer"):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    For every batch, this:
    - computes the same volume + gripper losses as training;
    - decodes a predicted pixel and height per timestep from the volume;
    - reads the gripper prediction at that *predicted* pixel.

    The first sample of the first batch is also packaged for visualization,
    including its predicted trajectory unprojected to 3D.

    Args:
        model: network returning ``(volume_logits, gripper_logits)``.
        dataloader: yields dicts with 'rgb', 'trajectory_2d', 'trajectory_3d',
            'trajectory_gripper', 'camera_pose', 'cam_K_norm', 'heatmap_target'.
        device: device the batch tensors are moved to.
        image_size: factor applied to the first two rows of ``cam_K_norm`` to get
            pixel-unit intrinsics for unprojection.
        task: instruction string forwarded to the model.

    Returns:
        (avg_loss, avg_volume_loss, 0.0, avg_gripper_loss, avg_pixel_error,
        avg_height_error, avg_height_error_tf, avg_gripper_error, sample_data).
        - Losses and errors are per-sample averages.
        - avg_pixel_error is additionally averaged over N_WINDOW timesteps.
        - avg_height_error_tf is always 0.0 (no teacher forcing for the volume).
        - sample_data is None if the dataloader is empty.
    """
    model.eval()
    total_loss = 0.0
    total_volume_loss = 0.0
    total_gripper_loss = 0.0
    total_pixel_error = 0.0
    total_height_error = 0.0
    total_gripper_error = 0.0
    n_samples = 0
    sample_data = None

    with torch.no_grad():
        for batch in dataloader:
            rgb = batch['rgb'].to(device)
            trajectory_2d = batch['trajectory_2d'].to(device)  # (B, N_WINDOW, 2)
            trajectory_3d = batch['trajectory_3d'].to(device)  # (B, N_WINDOW, 3)
            trajectory_gripper = batch['trajectory_gripper'].to(device)  # (B, N_WINDOW)
            target_height = trajectory_3d[:, :, 2]  # (B, N_WINDOW)
            batch_size = rgb.shape[0]

            volume_logits, gripper_logits = model(
                rgb,
                training=False,
                start_keypoint_2d=trajectory_2d[:, 0],  # (B, 2)
                task=task,
            )  # (B, N_WINDOW, N_HEIGHT_BINS, H, W), (B, N_WINDOW, N_GRIPPER_BINS, H, W)

            target_height_bins = discretize_height(target_height)  # (B, N_WINDOW)
            volume_loss = compute_volume_loss(volume_logits, trajectory_2d, target_height_bins)
            gripper_loss = compute_gripper_loss(gripper_logits, trajectory_gripper, trajectory_2d)
            loss = volume_loss + GRIPPER_LOSS_WEIGHT * gripper_loss

            # Weight by batch size so the final averages are per sample.
            total_loss += loss.item() * batch_size
            total_volume_loss += volume_loss.item() * batch_size
            total_gripper_loss += gripper_loss.item() * batch_size

            pred_2d, pred_height = extract_pred_2d_and_height_from_volume(volume_logits)  # (B, N, 2), (B, N)
            gripper_logits_at_pred = extract_gripper_logits_at_pixels(gripper_logits, pred_2d)  # (B, N, Ng)
            pred_gripper = decode_gripper_bins(gripper_logits_at_pred)  # (B, N)

            # Euclidean pixel error summed over every sample and timestep.
            total_pixel_error += torch.norm(pred_2d - trajectory_2d, dim=-1).sum().item()
            total_height_error += (pred_height - target_height).abs().mean(dim=1).sum().item()
            total_gripper_error += (pred_gripper - trajectory_gripper).abs().mean(dim=1).sum().item()
            n_samples += batch_size

            if sample_data is None:
                camera_pose_np = batch['camera_pose'][0].cpu().numpy()
                cam_K_norm_np = batch['cam_K_norm'][0].cpu().numpy()
                # Normalized -> pixel intrinsics: scale rows 0 and 1 by the image size.
                cam_K_np = cam_K_norm_np.copy()
                cam_K_np[:2] *= image_size
                sample_data = {
                    'rgb': rgb[0],
                    'target_heatmap': batch['heatmap_target'][0].to(device),  # (N_WINDOW, H, W)
                    'pred_heatmap': _ray_max_heatmaps(volume_logits[0]),
                    'trajectory_2d': trajectory_2d[0],
                    'trajectory_3d': trajectory_3d[0],
                    'pred_trajectory_3d': _unproject_pred_trajectory(
                        pred_2d[0], pred_height[0], camera_pose_np, cam_K_np, trajectory_3d[0]
                    ),
                    'camera_pose': camera_pose_np,
                    'cam_K_norm': cam_K_norm_np,
                    'cam_K_at_size': cam_K_np,
                    'pred_height': pred_height[0],
                    'target_height': target_height[0],
                    'pred_gripper': pred_gripper[0],
                    'target_gripper': trajectory_gripper[0],
                }

    n = max(1, n_samples)
    avg_pixel_error = total_pixel_error / (n * N_WINDOW)
    return (
        total_loss / n, total_volume_loss / n, 0.0, total_gripper_loss / n,
        avg_pixel_error, total_height_error / n, 0.0, total_gripper_error / n,
        sample_data
    )


def main():
    # Parse arguments
    parser = argparse.ArgumentParser(description="Train trajectory heatmap predictor")
    parser.add_argument("--dataset_root", "-d",nargs="+", default=["scratch/parsed_school_long_recap"],
                       help="One or more root directories with episodes (same structure); datasets are concatenated")
    parser.add_argument("--val_split", type=float, default=0.05,
                       help="Fraction of episodes to use for validation")
    parser.add_argument("--batch_size", type=int, default=BATCH_SIZE,
                       help="Batch size for training")
    parser.add_argument("--lr", type=float, default=LEARNING_RATE,
                       help="Learning rate")
    parser.add_argument("--epochs", type=int, default=NUM_EPOCHS,
                       help="Number of epochs")
    parser.add_argument("--checkpoint", type=str, default="",
                       help="Path to checkpoint to resume from (default: 2D-only model)")
    parser.add_argument("--run_name", type=str, default="volume_dino_tracks",
                       help="Name of run (used for checkpoint and visualization paths)")
    parser.add_argument("--pretrained_ckpt", type=str, default=None,
                       help="SmolVLA pretrained checkpoint (default: lerobot/smolvla_base)")
    args = parser.parse_args()

    # Checkpoint and visualization paths
    CHECKPOINT_DIR = Path(f"volume_tracks_smolvla/checkpoints/{args.run_name}")
    CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
    
    # Setup device: prefer CUDA, then MPS, then CPU
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    print(f"Using device: {device}")
    
    # Load full dataset (single root or concatenation of multiple roots)
    dataset_roots = args.dataset_root if isinstance(args.dataset_root, list) else [args.dataset_root]
    print("\nLoading dataset...")
    if len(dataset_roots) == 1:
        full_dataset = RealTrajectoryDataset(
            dataset_root=dataset_roots[0],
            image_size=IMAGE_SIZE
        )
        print(f"  Single root: {dataset_roots[0]}")
    else:
        datasets = [
            RealTrajectoryDataset(dataset_root=root, image_size=IMAGE_SIZE)
            for root in dataset_roots
        ]
        full_dataset = ConcatDataset(datasets)
        for root, d in zip(dataset_roots, datasets):
            print(f"  {root}: {len(d)} samples")
    print(f"  Total: {len(full_dataset)} samples")
    
    # Split into train/val (simple split by samples)
    dataset_size = len(full_dataset)
    val_size = int(dataset_size * args.val_split)
    train_size = dataset_size - val_size
    
    train_dataset, val_dataset = torch.utils.data.random_split(
        full_dataset, [train_size, val_size],
        generator=torch.Generator().manual_seed(42)  # For reproducibility
    )
    
    print(f"✓ Train: {len(train_dataset)} samples")
    print(f"✓ Val: {len(val_dataset)} samples")
    
    
    # Create dataloaders
    train_loader = DataLoader(
        train_dataset, 
        batch_size=args.batch_size, 
        shuffle=True, 
        num_workers=0
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=0
    )
    
    print(f"✓ Train: {len(train_dataset)} samples")
    print(f"✓ Val: {len(val_dataset)} samples")
    
    # Initialize model
    print("\nInitializing model...")
    ckpt = args.pretrained_ckpt if args.pretrained_ckpt else DEFAULT_SMOLVLA_CKPT
    print("making model")
    model = TrajectoryHeatmapPredictor(
        target_size=IMAGE_SIZE, n_window=N_WINDOW, freeze_backbone=False, pretrained_ckpt=ckpt
    )
    print("model made")
    model = model.to(device)
    
    # Count parameters
    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    n_total = sum(p.numel() for p in model.parameters())
    print(f"Trainable parameters: {n_trainable:,} / {n_total:,} ({100*n_trainable/n_total:.2f}%)")
    
    # Setup optimizer
    optimizer = optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=args.lr,
        weight_decay=1e-4
    )
    
    # Load checkpoint if specified (before computing stats, so we can use checkpoint values if available).
    # Supports 2D-only model (latest.pt): only matching keys are loaded; height/gripper heads are randomly initialized.
    # Try .pt and .pth so both user's "latest.pt" and our saved "latest.pth" work.
    start_epoch = 0
    checkpoint_height_values = None
    checkpoint_gripper_values = None
    checkpoint_path = args.checkpoint
    if args.checkpoint:
        if not os.path.exists(checkpoint_path):
            alt = checkpoint_path.rsplit(".", 1)
            if len(alt) == 2:
                other_ext = ".pth" if alt[1].lower() == "pt" else ".pt"
                alt_path = alt[0] + other_ext
                if os.path.exists(alt_path):
                    checkpoint_path = alt_path
                    print(f"Checkpoint not found at {args.checkpoint}, using {checkpoint_path}")
    if args.checkpoint and os.path.exists(checkpoint_path):
        print(f"\nLoading checkpoint: {checkpoint_path} (initializing from 2D-only model; height/gripper 32-bin heads random if missing)")
        checkpoint = torch.load(checkpoint_path, map_location=device)
        
        # Load model state dict with partial loading (skip missing keys and shape mismatches for height/gripper heads if 2D-only model)
        model_state = checkpoint['model_state_dict']
        model_dict = model.state_dict()
        
        # Filter: only load keys that exist in current model AND have matching shapes
        filtered_state = {}
        shape_mismatches = []
        for k, v in model_state.items():
            if k in model_dict:
                if v.shape == model_dict[k].shape:
                    filtered_state[k] = v
                else:
                    shape_mismatches.append(f"{k}: checkpoint {v.shape} vs model {model_dict[k].shape}")
            # else: key not in current model, skip it
        
        missing_keys = set(model_dict.keys()) - set(model_state.keys())
        unexpected_keys = set(model_state.keys()) - set(model_dict.keys())
        
        if missing_keys:
            print(f"⚠ Missing keys in checkpoint (will use random initialization): {sorted(missing_keys)}")
        if unexpected_keys:
            print(f"⚠ Unexpected keys in checkpoint (will be ignored): {sorted(unexpected_keys)}")
        if shape_mismatches:
            print(f"⚠ Shape mismatches (will use random initialization for these):")
            for msg in shape_mismatches:
                print(f"    {msg}")
        
        model_dict.update(filtered_state)
        model.load_state_dict(model_dict, strict=False)  # strict=False allows remaining missing keys
        
        # Load optimizer state if available (may not match if architecture changed)
        try:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        except Exception as e:
            print(f"⚠ Could not load optimizer state (architecture may have changed): {e}")
            print("  Starting with fresh optimizer state")
        
        start_epoch = checkpoint.get('epoch', 0) + 1
        
        # Save height values from checkpoint if available (use these instead of recomputing)
        if 'min_height' in checkpoint and 'max_height' in checkpoint:
            checkpoint_height_values = (checkpoint['min_height'], checkpoint['max_height'])
            print(f"✓ Found height range in checkpoint: [{checkpoint_height_values[0]:.6f}, {checkpoint_height_values[1]:.6f}] m")
        # Load gripper min/max from checkpoint if available
        if 'min_gripper' in checkpoint and 'max_gripper' in checkpoint:
            checkpoint_gripper_values = (checkpoint['min_gripper'], checkpoint['max_gripper'])
            print(f"✓ Found gripper range in checkpoint: [{checkpoint_gripper_values[0]:.6f}, {checkpoint_gripper_values[1]:.6f}]")
        print(f"✓ Resumed from epoch {start_epoch}")
    elif args.checkpoint:
        print(f"\n⚠ Checkpoint not found: {args.checkpoint}")
        print("  Initializing model from scratch (height/gripper heads will be randomly initialized)")
    
    # Compute or load min/max height and gripper from dataset (computed once, used as constants throughout training)
    # If resuming from checkpoint with values, use those instead of recomputing
    import model as model_module
    if checkpoint_height_values is not None:
        # Use height values from checkpoint (model was trained with these)
        model_module.MIN_HEIGHT, model_module.MAX_HEIGHT = checkpoint_height_values
        print(f"✓ Using height range from checkpoint: [{model_module.MIN_HEIGHT:.6f}, {model_module.MAX_HEIGHT:.6f}] m")
        if abs(model_module.MIN_HEIGHT - model_module.MAX_HEIGHT) < 1e-6:
            print(f"  ⚠ WARNING: MIN_HEIGHT == MAX_HEIGHT! All height predictions will be constant!")
            print(f"  This will cause all bins to decode to the same value. Check dataset or checkpoint.")
    else:
        # Compute min/max height from entire dataset (once, used as constants)
        print("\nComputing height statistics from dataset...")
        all_heights = []
        for dataset in [train_dataset, val_dataset]:
            for i in range(len(dataset)):
                sample = dataset[i]
                trajectory_3d = sample['trajectory_3d'].numpy()  # (N_WINDOW, 3)
                waypoint_heights = trajectory_3d[:, 2]  # z-coordinates for all waypoints
                all_heights.extend(waypoint_heights.tolist())
        
        all_heights = np.array(all_heights)
        min_height = float(all_heights.min())
        max_height = float(all_heights.max())
        model_module.MIN_HEIGHT = min_height
        model_module.MAX_HEIGHT = max_height
        print(f"✓ Height range computed from dataset: [{model_module.MIN_HEIGHT:.6f}, {model_module.MAX_HEIGHT:.6f}] m")
        print(f"  ({model_module.MIN_HEIGHT*1000:.2f}mm to {model_module.MAX_HEIGHT*1000:.2f}mm)")
        print(f"  (computed from {len(all_heights)} waypoint heights across all samples)")
        if abs(model_module.MIN_HEIGHT - model_module.MAX_HEIGHT) < 1e-6:
            print(f"  ⚠ WARNING: MIN_HEIGHT == MAX_HEIGHT! All height predictions will be constant!")
            print(f"  This will cause all bins to decode to the same value. Check dataset.")
        print(f"✓ These values will be used as constants throughout training and saved in checkpoints")
    
    # Gripper range: compute from dataset or use checkpoint/hard-coded values
    if checkpoint_gripper_values is not None:
        # Use gripper values from checkpoint (model was trained with these)
        model_module.MIN_GRIPPER, model_module.MAX_GRIPPER = checkpoint_gripper_values
        print(f"✓ Using gripper range from checkpoint: [{model_module.MIN_GRIPPER:.6f}, {model_module.MAX_GRIPPER:.6f}]")
        if abs(model_module.MIN_GRIPPER - model_module.MAX_GRIPPER) < 1e-6:
            print(f"  ⚠ WARNING: MIN_GRIPPER == MAX_GRIPPER! All gripper predictions will be constant!")
            print(f"  This will cause all bins to decode to the same value. Check dataset or checkpoint.")
    else:
        # Compute min/max gripper from entire dataset (once, used as constants)
        print("\nComputing gripper statistics from dataset...")
        all_grippers = []
        for dataset in [train_dataset, val_dataset]:
            for i in range(len(dataset)):
                sample = dataset[i]
                trajectory_gripper = sample['trajectory_gripper'].numpy()  # (N_WINDOW,)
                all_grippers.extend(trajectory_gripper.tolist())
        
        all_grippers = np.array(all_grippers)
        min_gripper = float(all_grippers.min())
        max_gripper = float(all_grippers.max())
        model_module.MIN_GRIPPER = min_gripper
        model_module.MAX_GRIPPER = max_gripper
        print(f"✓ Gripper range computed from dataset: [{model_module.MIN_GRIPPER:.6f}, {model_module.MAX_GRIPPER:.6f}]")
        print(f"  (computed from {len(all_grippers)} gripper values across all samples)")
        if abs(model_module.MIN_GRIPPER - model_module.MAX_GRIPPER) < 1e-6:
            print(f"  ⚠ WARNING: MIN_GRIPPER == MAX_GRIPPER! All gripper predictions will be constant!")
            print(f"  This will cause all bins to decode to the same value. Check dataset.")
        print(f"✓ These values will be used as constants throughout training and saved in checkpoints")
    
    # Setup live visualization
    print("\nSetting up live visualization...")
    plt.ion()
    from mpl_toolkits.mplot3d import Axes3D
    # Grid: 3 loss + 2 image rows + 1 height line + 1 gripper line + 1 big 3D = 8 rows
    fig = plt.figure(figsize=(4*N_WINDOW, 14))
    gs = GridSpec(8, N_WINDOW, figure=fig, hspace=0.35, wspace=0.2,
                  height_ratios=[1, 1, 1, 2, 2, 0.8, 0.8, 5])
    
    ax_loss_heatmap = fig.add_subplot(gs[0, :])
    ax_loss_height = fig.add_subplot(gs[1, :])
    ax_loss_gripper = fig.add_subplot(gs[2, :])
    
    # 2 rows for train, val images (each row N_WINDOW cols)
    axes_vis = []
    for row_idx in range(2):
        row_axes = []
        for t in range(N_WINDOW):
            ax = fig.add_subplot(gs[3 + row_idx, t])
            ax.axis('off')
            if row_idx == 0:
                ax.set_title(f"t={t}", fontweight='bold', fontsize=10)
            row_axes.append(ax)
        axes_vis.append(row_axes)
    
    # Single line charts: height (pred vs GT over timesteps), gripper (pred vs GT over timesteps)
    ax_height_line = fig.add_subplot(gs[5, :])
    ax_gripper_line = fig.add_subplot(gs[6, :])

    # 3D volume pane (GT vs Pred trajectory) - larger height_ratio
    ax_3d = fig.add_subplot(gs[7, :], projection='3d')
    
    plt.show(block=False)
    fig.canvas.draw()  # Initial draw to ensure window is ready
    plt.pause(0.1)  # Small pause to let window initialize
    
    # Training loop
    print(f"\nStarting training for {args.epochs} epochs...")
    best_val_loss = float('inf')
    
    # Track losses separately
    train_heatmap_losses = []
    train_height_losses = []
    train_gripper_losses = []
    val_heatmap_losses = []
    val_height_losses = []
    val_gripper_losses = []
    
    for epoch in range(start_epoch, args.epochs):
        print(f"\n{'='*60}")
        print(f"Epoch {epoch}/{args.epochs}")
        print(f"{'='*60}")
        
        # Train
        train_loss, train_volume_loss, _, train_gripper_loss = train_epoch(
            model, train_loader, optimizer, device, just_heatmap=epoch<2
        )
        train_heatmap_losses.append(train_volume_loss)  # store volume loss in same list for plot
        train_height_losses.append(0.0)
        train_gripper_losses.append(train_gripper_loss)
        print(f"Train Loss: {train_loss:.4f} (Volume: {train_volume_loss:.4f}, Gripper: {train_gripper_loss:.6f})")
        
        # Validate and get sample data for visualization
        val_loss, val_heatmap_loss, val_height_loss, val_gripper_loss, \
        val_error, val_height_error, val_height_error_tf, val_gripper_error, sample_val = validate(
            model, val_loader, device
        )
        val_heatmap_losses.append(val_heatmap_loss)  # volume loss
        val_height_losses.append(val_height_loss)
        val_gripper_losses.append(val_gripper_loss)

        print(f"Val - Loss: {val_loss:.4f}, Volume: {val_heatmap_loss:.4f}, Pixel Error: {val_error:.2f}px, Height Error: {val_height_error*1000:.3f}mm, Gripper: {val_gripper_error:.4f}")
        
        # Get train sample for visualization (all timesteps)
        model.eval()
        with torch.no_grad():
            train_sample_batch = next(iter(train_loader))
            train_rgb = train_sample_batch['rgb'][0:1].to(device)
            train_target_heatmap_all = train_sample_batch['heatmap_target'][0]  # (N_WINDOW, H, W)
            train_trajectory_2d = train_sample_batch['trajectory_2d'][0]  # (N_WINDOW, 2)
            train_trajectory_3d = train_sample_batch['trajectory_3d'][0]  # (N_WINDOW, 3)
            train_trajectory_gripper = train_sample_batch['trajectory_gripper'][0]  # (N_WINDOW,)
            train_start_keypoint_2d = train_trajectory_2d[0]  # (2,)
            train_camera_pose = train_sample_batch['camera_pose'][0].cpu().numpy()
            train_cam_K_norm = train_sample_batch['cam_K_norm'][0].cpu().numpy()
            train_cam_K = train_cam_K_norm.copy()
            train_cam_K[0] *= IMAGE_SIZE
            train_cam_K[1] *= IMAGE_SIZE

            train_volume_logits, train_gripper_logits = model(
                train_rgb,
                training=False,
                start_keypoint_2d=train_start_keypoint_2d,
                task="Pick the orange mug and place it onto the black saucer",
            )
            train_pred_2d, train_pred_height = extract_pred_2d_and_height_from_volume(train_volume_logits[0:1])
            train_pred_height = train_pred_height[0]  # (N_WINDOW,)
            train_gripper_at_pred = extract_gripper_logits_at_pixels(train_gripper_logits, train_pred_2d)
            train_pred_gripper = decode_gripper_bins(train_gripper_at_pred)[0]  # (N_WINDOW,)

            train_pred_heatmaps = []
            for t in range(N_WINDOW):
                vol_t = train_volume_logits[0, t]  # (Nh, H, W)
                # Softmax over full volume, then max along the ray for visualization
                vol_probs = F.softmax(vol_t.view(-1), dim=0).view(vol_t.shape[0], vol_t.shape[1], vol_t.shape[2])
                max_along_ray = vol_probs.max(dim=0)[0]  # (H, W)
                train_pred_heatmaps.append(max_along_ray)
            train_pred_heatmaps = torch.stack(train_pred_heatmaps)  # (N_WINDOW, H, W)

            train_pred_trajectory_3d_list = []
            for t in range(N_WINDOW):
                px, py = train_pred_2d[0, t, 0].item(), train_pred_2d[0, t, 1].item()
                h = train_pred_height[t].item()
                pt = recover_3d_from_direct_keypoint_and_height(
                    np.array([px, py], dtype=np.float64), h, train_camera_pose, train_cam_K
                )
                train_pred_trajectory_3d_list.append(pt if pt is not None else train_trajectory_3d[t].cpu().numpy())
            train_pred_trajectory_3d_np = np.array(train_pred_trajectory_3d_list)

            sample_train = {
                'rgb': train_rgb[0],
                'target_heatmap': train_target_heatmap_all,
                'pred_heatmap': train_pred_heatmaps,
                'trajectory_2d': train_trajectory_2d,
                'trajectory_3d': train_trajectory_3d,
                'pred_trajectory_3d': train_pred_trajectory_3d_np,
                'camera_pose': train_camera_pose,
                'cam_K_at_size': train_cam_K,
                'pred_height': train_pred_height,
                'target_height': train_trajectory_3d[:, 2],
                'pred_gripper': train_pred_gripper,
                'target_gripper': train_trajectory_gripper,
            }
        
        # Update visualization
        epochs_range = np.arange(len(train_heatmap_losses))
        ax_loss_heatmap.clear()
        ax_loss_heatmap.plot(epochs_range, train_heatmap_losses, 'o-', label='Train', color='blue', linewidth=2)
        ax_loss_heatmap.plot(epochs_range, val_heatmap_losses, 's-', label='Val', color='green', linewidth=2)
        ax_loss_heatmap.set_xlabel('Epoch')
        ax_loss_heatmap.set_ylabel('Volume Loss (CE)')
        ax_loss_heatmap.set_title(f'Volume Loss | Train: {train_volume_loss:.3f} | Val: {val_heatmap_loss:.3f}')
        ax_loss_heatmap.legend(loc='upper right')
        ax_loss_heatmap.grid(alpha=0.3)

        ax_loss_height.clear()
        ax_loss_height.plot(epochs_range, train_height_losses, 'o-', label='Train', color='blue', linewidth=2)
        ax_loss_height.plot(epochs_range, val_height_losses, 's-', label='Val', color='green', linewidth=2)
        ax_loss_height.set_xlabel('Epoch')
        ax_loss_height.set_ylabel('Height (unused)')
        ax_loss_height.set_title('Height loss (0)')
        ax_loss_height.legend(loc='upper right')
        ax_loss_height.grid(alpha=0.3)
        
        # Gripper loss plot
        ax_loss_gripper.clear()
        epochs_range = np.arange(len(train_gripper_losses))
        ax_loss_gripper.plot(epochs_range, train_gripper_losses, 'o-', label='Train', color='blue', linewidth=2)
        ax_loss_gripper.plot(epochs_range, val_gripper_losses, 's-', label='Val', color='green', linewidth=2)
        ax_loss_gripper.set_xlabel('Epoch')
        ax_loss_gripper.set_ylabel('Gripper Loss (MSE)')
        ax_loss_gripper.set_title(f'Gripper Loss | Train: {train_gripper_loss:.6f} | Val: {val_gripper_loss:.6f}')
        ax_loss_gripper.legend(loc='upper right')
        ax_loss_gripper.grid(alpha=0.3)
        
        # Visualize samples from train and val (show all timesteps horizontally)
        for row_idx, (sample, split_name) in enumerate([
            (sample_train, 'Train'),
            (sample_val, 'Val')
        ]):
            if sample is None:
                continue
            
            # Denormalize RGB for visualization
            mean = torch.tensor([0.485, 0.456, 0.406], device=sample['rgb'].device).view(3, 1, 1)
            std = torch.tensor([0.229, 0.224, 0.225], device=sample['rgb'].device).view(3, 1, 1)
            rgb_denorm = (sample['rgb'] * std + mean).cpu().numpy()
            rgb_vis = np.clip(rgb_denorm.transpose(1, 2, 0), 0, 1)
            
            # Visualize each timestep
            for t in range(N_WINDOW):
                ax = axes_vis[row_idx][t]
                ax.clear()
                
                # Get heatmaps and keypoints for this timestep
                target_heatmap_t = sample['target_heatmap'][t].cpu().numpy()  # (H, W)
                pred_heatmap_t = sample['pred_heatmap'][t].cpu().numpy()  # (H, W)
                target_2d_t = sample['trajectory_2d'][t].cpu().numpy()  # (2,)
                
                # Get predicted location (argmax)
                pred_y, pred_x = np.unravel_index(pred_heatmap_t.argmax(), pred_heatmap_t.shape)
                pred_2d_t = np.array([pred_x, pred_y])
                
                # Show RGB with heatmap overlay and keypoints
                ax.imshow(rgb_vis)
                ax.imshow(pred_heatmap_t, alpha=0.6, cmap='hot')
                
                # Plot GT and predicted keypoints
                ax.scatter(target_2d_t[0], target_2d_t[1], c='white', s=100, 
                          marker='o', edgecolors='black', linewidths=2, 
                          label='GT', zorder=10)
                ax.scatter(pred_2d_t[0], pred_2d_t[1], c='lime', s=100, 
                          marker='x', linewidths=3, 
                          label='Pred', zorder=10)
                
                # Compute errors
                pixel_err = np.linalg.norm(target_2d_t - pred_2d_t)
                title_parts = [f"t={t}", f"Px:{pixel_err:.1f}px"]
                if 'pred_height' in sample and 'target_height' in sample:
                    # Ensure pred_height is a 1D tensor before indexing
                    pred_h_tensor = sample['pred_height']
                    if pred_h_tensor.dim() == 0:
                        # Scalar - use it directly (shouldn't happen, but handle gracefully)
                        pred_h = pred_h_tensor.cpu().item()
                    else:
                        pred_h = pred_h_tensor[t].cpu().item() if pred_h_tensor.dim() > 0 else pred_h_tensor.cpu().item()
                    target_h_tensor = sample['target_height']
                    target_h = target_h_tensor[t].cpu().item() if target_h_tensor.dim() > 0 else target_h_tensor.cpu().item()
                    height_err = abs(pred_h - target_h) * 1000  # mm
                    title_parts.append(f"H:{pred_h*1000:.1f}mm")
                if 'pred_gripper' in sample and 'target_gripper' in sample:
                    # Ensure pred_gripper is a 1D tensor before indexing
                    pred_g_tensor = sample['pred_gripper']
                    pred_g = pred_g_tensor[t].cpu().item() if pred_g_tensor.dim() > 0 else pred_g_tensor.cpu().item()
                    target_g_tensor = sample['target_gripper']
                    target_g = target_g_tensor[t].cpu().item() if target_g_tensor.dim() > 0 else target_g_tensor.cpu().item()
                    title_parts.append(f"G:{pred_g:.2f}(GT:{target_g:.2f})")
                ax.set_title("\n".join(title_parts), fontsize=8)
                
                ax.axis('off')
                if t == 0:
                    ax.legend(loc='upper right', fontsize=6)

        # Single line charts: height and gripper (pred vs GT over timesteps), using val sample
        import model as model_module
        sample_for_lines = sample_val if sample_val is not None else sample_train
        if sample_for_lines is not None:
            ts = np.arange(N_WINDOW)
            ax_height_line.clear()
            ax_height_line.set_xlabel('Timestep')
            ax_height_line.set_ylabel('Height (mm)')
            ax_height_line.set_title('Height: Pred vs GT')
            ax_height_line.grid(alpha=0.3)
            ax_height_line.set_ylim([model_module.MIN_HEIGHT * 1000 - 10, model_module.MAX_HEIGHT * 1000 + 10])
            if 'pred_height' in sample_for_lines and 'target_height' in sample_for_lines:
                pred_h = sample_for_lines['pred_height']
                target_h = sample_for_lines['target_height']
                if hasattr(pred_h, 'cpu'):
                    pred_h = np.atleast_1d(pred_h.cpu().numpy())
                else:
                    pred_h = np.atleast_1d(np.asarray(pred_h))
                if hasattr(target_h, 'cpu'):
                    target_h = np.atleast_1d(target_h.cpu().numpy())
                else:
                    target_h = np.atleast_1d(np.asarray(target_h))
                if len(pred_h) == N_WINDOW and len(target_h) == N_WINDOW:
                    ax_height_line.plot(ts, pred_h * 1000, 'o-', color='red', label='Pred', linewidth=2, markersize=4)
                    ax_height_line.plot(ts, target_h * 1000, 's-', color='green', label='GT', linewidth=2, markersize=4)
            ax_height_line.legend(loc='upper right')

            ax_gripper_line.clear()
            ax_gripper_line.set_xlabel('Timestep')
            ax_gripper_line.set_ylabel('Gripper')
            ax_gripper_line.set_title('Gripper: Pred vs GT')
            ax_gripper_line.grid(alpha=0.3)
            ax_gripper_line.set_ylim([model_module.MIN_GRIPPER - 0.1, model_module.MAX_GRIPPER + 0.1])
            if 'pred_gripper' in sample_for_lines and 'target_gripper' in sample_for_lines:
                pred_g = sample_for_lines['pred_gripper']
                target_g = sample_for_lines['target_gripper']
                if hasattr(pred_g, 'cpu'):
                    pred_g = np.atleast_1d(pred_g.cpu().numpy())
                else:
                    pred_g = np.atleast_1d(np.asarray(pred_g))
                if hasattr(target_g, 'cpu'):
                    target_g = np.atleast_1d(target_g.cpu().numpy())
                else:
                    target_g = np.atleast_1d(np.asarray(target_g))
                if len(pred_g) == N_WINDOW and len(target_g) == N_WINDOW:
                    ax_gripper_line.plot(ts, pred_g, 'o-', color='red', label='Pred', linewidth=2, markersize=4)
                    ax_gripper_line.plot(ts, target_g, 's-', color='green', label='GT', linewidth=2, markersize=4)
            ax_gripper_line.legend(loc='upper right')

        # 3D volume pane: volume (transparent white) + GT (red) + Pred (blue)
        ax_3d.clear()
        sample_for_3d = sample_val if sample_val is not None else sample_train
        if sample_for_3d is not None and 'trajectory_3d' in sample_for_3d and 'pred_trajectory_3d' in sample_for_3d:
            traj_3d = sample_for_3d['trajectory_3d']
            if hasattr(traj_3d, 'cpu'):
                traj_3d = traj_3d.cpu().numpy()
            pred_3d = sample_for_3d['pred_trajectory_3d']
            cam_pose = sample_for_3d['camera_pose']
            cam_K = sample_for_3d['cam_K_at_size']
            z_min = float(traj_3d[:, 2].min()) - 0.01
            z_max = float(traj_3d[:, 2].max()) + 0.01
            height_centers = np.linspace(z_min, z_max, N_HEIGHT_BINS)
            volume_pts = build_volume_3d_points_for_vis(
                IMAGE_SIZE, IMAGE_SIZE, cam_pose, cam_K, height_centers, pixel_step=32
            )
            if len(volume_pts) > 0:
                ax_3d.scatter(
                    volume_pts[:, 0], volume_pts[:, 1], volume_pts[:, 2],
                    c='white', alpha=0.03, s=1, edgecolors='none'
                )
            ax_3d.scatter(
                traj_3d[:, 0], traj_3d[:, 1], traj_3d[:, 2],
                c='red', alpha=1.0, s=60, edgecolors='darkred', linewidths=1.5, label='GT'
            )
            ax_3d.scatter(
                pred_3d[:, 0], pred_3d[:, 1], pred_3d[:, 2],
                c='blue', alpha=1.0, s=60, edgecolors='darkblue', linewidths=1.5, label='Pred'
            )
            ax_3d.set_xlabel('X (m)')
            ax_3d.set_ylabel('Y (m)')
            ax_3d.set_zlabel('Z (m)')
            ax_3d.set_title('Volume (white) + GT (red) vs Pred (blue)')
            ax_3d.legend()
        
        # Update figure (better for macOS)
        fig.canvas.draw()
        fig.canvas.flush_events()
        plt.pause(0.01)  # Slightly longer pause for better responsiveness
        
        # Use val loss for checkpointing
        
        # Save checkpoint (use model constants, computed once from dataset)
        import model as model_module
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': train_loss,
            'val_loss': val_loss,
            'min_height': model_module.MIN_HEIGHT,  # Use model constants (computed once from dataset)
            'max_height': model_module.MAX_HEIGHT,
            'min_gripper': model_module.MIN_GRIPPER,  # Regression: [-0.2, 0.8]
            'max_gripper': model_module.MAX_GRIPPER,
        }
        
        # Save latest
        torch.save(checkpoint, CHECKPOINT_DIR / 'latest.pth')
        
        # Save best
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(checkpoint, CHECKPOINT_DIR / 'best.pth')
            print(f"✓ Saved best model (val_loss={val_loss:.4f})")
        
    
    plt.ioff()
    plt.show()
    
    print("\n" + "="*60)
    print("✓ Training complete!")
    print(f"Best val loss: {best_val_loss:.4f}")
    print(f"Checkpoints saved to: {CHECKPOINT_DIR}")


if __name__ == "__main__":
    # Run training only when this file is executed as a script, not when it is imported.
    main()
