"""Train trajectory volume predictor on real data.

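Model predicts a pixel-aligned volume: N_WINDOW x N_HEIGHT_BINS logits per pixel. Training uses, per timestep, a cross-entropy over all N_HEIGHT_BINS x H x W cells whose target is the ground-truth (height bin, pixel) cell.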
Gripper is per-pixel (N_WINDOW x N_GRIPPER_BINS per pixel): supervised at GT pixel (teacher forcing), decoded at pred pixel in val/inference.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, ConcatDataset
import numpy as np
import matplotlib
#matplotlib.use('TkAgg')
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from pathlib import Path
from tqdm import tqdm
import argparse

import sys
import os
sys.path.insert(0, os.path.dirname(__file__))

from data import RealTrajectoryDataset, N_WINDOW
from model import TrajectoryHeatmapPredictor, N_HEIGHT_BINS, N_GRIPPER_BINS
import model as model_module  # Import module to access updated MIN_HEIGHT/MAX_HEIGHT at runtime
from utils import recover_3d_from_direct_keypoint_and_height

# Helper functions for discretization
def discretize_to_bins(values, lo, hi, n_bins):
    """Map continuous values onto ``n_bins`` evenly spaced bin indices over [lo, hi].

    Bin ``i`` represents ``lo + i * (hi - lo) / (n_bins - 1)``. This is the same
    ``linspace(0, 1, n_bins)`` grid the ``decode_*_bins`` helpers use, so every value is
    rounded to its *nearest* bin. Values outside [lo, hi] are clamped to the end bins.

    Args:
        values: tensor of any shape.
        lo, hi: range mapped onto bins 0 and n_bins - 1.
        n_bins: number of bins.

    Returns:
        Long tensor with the shape of ``values``, entries in [0, n_bins - 1].
    """
    # The 1e-8 keeps the division finite when hi == lo (all in-range values then map to bin 0).
    normalized = ((values - lo) / (hi - lo + 1e-8)).clamp(0.0, 1.0)
    # Round rather than truncate. Truncation biases every value down by up to one bin, and
    # because the epsilon makes normalized(hi) slightly < 1 it can leave the top bin unreachable.
    return torch.round(normalized * (n_bins - 1)).long().clamp(0, n_bins - 1)


def discretize_height(height_values):
    """Discretize continuous height values into bin indices.

    Args:
        height_values: (B, N_WINDOW) or (N_WINDOW,) tensor of heights in [MIN_HEIGHT, MAX_HEIGHT]

    Returns:
        bin_indices: tensor of the same shape with nearest-bin indices in [0, N_HEIGHT_BINS-1]
    """
    # MIN_HEIGHT/MAX_HEIGHT are read from the model module at call time because main() overwrites them.
    return discretize_to_bins(
        height_values, model_module.MIN_HEIGHT, model_module.MAX_HEIGHT, N_HEIGHT_BINS
    )


def discretize_gripper(gripper_values):
    """Discretize continuous gripper values into bin indices.

    Args:
        gripper_values: (B, N_WINDOW) or (N_WINDOW,) tensor of gripper values in [MIN_GRIPPER, MAX_GRIPPER]

    Returns:
        bin_indices: tensor of the same shape with nearest-bin indices in [0, N_GRIPPER_BINS-1]
    """
    # MIN_GRIPPER/MAX_GRIPPER are read from the model module at call time because main() overwrites them.
    return discretize_to_bins(
        gripper_values, model_module.MIN_GRIPPER, model_module.MAX_GRIPPER, N_GRIPPER_BINS
    )


def decode_bin_logits(bin_logits, lo, hi):
    """Argmax-decode bin logits to continuous values on an evenly spaced grid.

    With ``n = bin_logits.shape[-1]``, bin ``i`` decodes to ``lo + i * (hi - lo) / (n - 1)``,
    i.e. the grid ``linspace(lo, hi, n)`` that the ``discretize_*`` helpers map onto.

    Args:
        bin_logits: (..., n_bins) logits; the last axis indexes bins.
        lo, hi: values represented by the first and last bins.

    Returns:
        Tensor of shape ``bin_logits.shape[:-1]`` with decoded values.
    """
    # Take the bin count from the logits so decoding always matches what the head produced.
    n_bins = bin_logits.shape[-1]
    grid = torch.linspace(0.0, 1.0, n_bins, device=bin_logits.device)
    return grid[bin_logits.argmax(dim=-1)] * (hi - lo) + lo


def decode_height_bins(bin_logits):
    """Decode height bin logits back to continuous height values.

    Args:
        bin_logits: (B, N_WINDOW, N_HEIGHT_BINS) logits for each bin

    Returns:
        height_values: (B, N_WINDOW) continuous height values in [MIN_HEIGHT, MAX_HEIGHT]
    """
    # MIN_HEIGHT/MAX_HEIGHT are read from the model module at call time because main() overwrites them.
    return decode_bin_logits(bin_logits, model_module.MIN_HEIGHT, model_module.MAX_HEIGHT)


def decode_gripper_bins(bin_logits):
    """Decode gripper bin logits back to continuous gripper values.

    Args:
        bin_logits: (B, N_WINDOW, N_GRIPPER_BINS) logits for each bin

    Returns:
        gripper_values: (B, N_WINDOW) continuous gripper values in [MIN_GRIPPER, MAX_GRIPPER]
    """
    # MIN_GRIPPER/MAX_GRIPPER are read from the model module at call time because main() overwrites them.
    return decode_bin_logits(bin_logits, model_module.MIN_GRIPPER, model_module.MAX_GRIPPER)

# ---- Training configuration -------------------------------------------------
# BATCH_SIZE, LEARNING_RATE and NUM_EPOCHS are the defaults of the
# --batch_size, --lr and --epochs command-line flags in main().
BATCH_SIZE: int = 8
LEARNING_RATE: float = 1e-4
NUM_EPOCHS: int = 50
# Input resolution passed to the dataset and the model. Higher resolution gives
# finer spatial precision (the original note says this yields 28x28 patches).
IMAGE_SIZE: int = 448

# Multiplier on the gripper loss in the total loss. The gripper head is
# supervised at the ground-truth pixel (teacher forcing) and decoded at the
# predicted pixel during validation/inference.
GRIPPER_LOSS_WEIGHT: float = 1.0


def build_volume_3d_points_for_vis(H, W, camera_pose, cam_K, height_bucket_centers, pixel_step=32):
    """Sample the volume grid as 3D points for plotting.

    Every ``pixel_step``-th pixel (row-major, starting at (0, 0)) is paired with each
    height in ``height_bucket_centers`` and lifted with
    ``recover_3d_from_direct_keypoint_and_height``. Pairs for which that helper returns
    None are dropped.

    Returns:
        (N, 3) numpy array, or an empty (0, 3) array when nothing was recovered.
    """
    lifted = (
        recover_3d_from_direct_keypoint_and_height(
            np.array([col, row], dtype=np.float64), float(h), camera_pose, cam_K
        )
        for row in range(0, H, pixel_step)
        for col in range(0, W, pixel_step)
        for h in height_bucket_centers
    )
    kept = [p for p in lifted if p is not None]
    if not kept:
        return np.zeros((0, 3))
    return np.array(kept)


def compute_volume_loss(pred_volume_logits, trajectory_2d, target_height_bins):
    """Cross-entropy with softmax over all 3D cells, per (sample, timestep).

    Each (sample, timestep) volume is flattened to Nh*H*W cells and softmax-normalized.
    The target is the ground-truth cell ``(height_bin, y, x) -> height_bin*(H*W) + y*W + x``.

    Args:
        pred_volume_logits: (B, N_WINDOW, N_HEIGHT_BINS, H, W)
        trajectory_2d: (B, N_WINDOW, 2) pixel coords [x, y]; truncated to integers and
            clamped into the H x W grid
        target_height_bins: (B, N_WINDOW) bin indices; clamped into [0, N_HEIGHT_BINS-1]

    Returns:
        Scalar mean cross-entropy over all B * N_WINDOW (sample, timestep) pairs.
    """
    B, N, Nh, H, W = pred_volume_logits.shape
    px = trajectory_2d[..., 0].long().clamp(0, W - 1)  # (B, N)
    py = trajectory_2d[..., 1].long().clamp(0, H - 1)  # (B, N)
    h_bin = target_height_bins.long().clamp(0, Nh - 1)  # (B, N)
    # Flat index of the target cell inside one (Nh, H, W) volume.
    target_idx = h_bin * (H * W) + py * W + px  # (B, N)
    # One softmax row per (sample, timestep). Every row has the same size, so the mean over
    # B*N rows equals the mean over timesteps of the per-timestep batch means.
    logits_flat = pred_volume_logits.reshape(B * N, Nh * H * W)
    return F.cross_entropy(logits_flat, target_idx.reshape(B * N), reduction='mean')


def extract_pred_2d_and_height_from_volume(volume_logits):
    """Decode the most likely 3D cell of every timestep from volume logits.

    The pixel is found by scoring each pixel with its best logit along the height axis and
    taking the argmax over pixels. The height bin is the argmax of the logits along that
    pixel's height axis.

    Args:
        volume_logits: (B, N_WINDOW, N_HEIGHT_BINS, H, W)

    Returns:
        pred_2d: (B, N_WINDOW, 2) float32 pixel coords [x, y]
        pred_height: (B, N_WINDOW) chosen bin mapped linearly onto [MIN_HEIGHT, MAX_HEIGHT]
    """
    n_b, n_t, n_h, rows, cols = volume_logits.shape
    # Flatten pixels row-major so a flat index is y * W + x.
    cells = volume_logits.reshape(n_b, n_t, n_h, rows * cols)
    best_pixel = cells.max(dim=2).values.argmax(dim=-1)  # (B, N)
    ys = best_pixel // cols
    xs = best_pixel % cols
    pred_2d = torch.stack([xs, ys], dim=-1).to(torch.float32)

    # Logits of every height bin at the chosen pixel: (B, N, Nh).
    gather_idx = best_pixel[:, :, None, None].expand(n_b, n_t, n_h, 1)
    column_logits = cells.gather(3, gather_idx).squeeze(3)
    bin_idx = column_logits.argmax(dim=-1)  # (B, N)

    # Bin i -> i / (N_HEIGHT_BINS - 1), then denormalize with the runtime height range.
    lo = model_module.MIN_HEIGHT
    hi = model_module.MAX_HEIGHT
    centers = torch.linspace(0.0, 1.0, N_HEIGHT_BINS, device=volume_logits.device)
    pred_height = centers[bin_idx] * (hi - lo) + lo
    return pred_2d, pred_height


def extract_gripper_logits_at_pixels(gripper_logits, pixel_2d):
    """Read the per-pixel gripper logits at one (x, y) location per timestep.

    Used at the GT pixel for teacher-forced training and at the predicted pixel for
    validation/inference. Coordinates are truncated to integers and clamped into the grid.

    Args:
        gripper_logits: (B, N_WINDOW, N_GRIPPER_BINS, H, W)
        pixel_2d: (B, N_WINDOW, 2) pixel coords [x, y]

    Returns:
        logits_at_pixels: (B, N_WINDOW, N_GRIPPER_BINS)
    """
    n_batch, n_steps, _, height, width = gripper_logits.shape
    cols = pixel_2d[..., 0].long().clamp(0, width - 1)
    rows = pixel_2d[..., 1].long().clamp(0, height - 1)
    dev = gripper_logits.device
    b_idx = torch.arange(n_batch, device=dev)[:, None]  # broadcasts to (B, N)
    t_idx = torch.arange(n_steps, device=dev)[None, :]  # broadcasts to (B, N)
    # Move the bin axis last so the four indexed axes are adjacent -> result (B, N, Ng).
    bins_last = gripper_logits.permute(0, 1, 3, 4, 2)
    return bins_last[b_idx, t_idx, rows, cols]


def compute_gripper_loss(pred_gripper_logits, target_gripper, trajectory_2d, pos_weight=None, loss_scale=4.0):
    """Cross-entropy on gripper bins read at the ground-truth pixel (teacher forcing).

    Args:
        pred_gripper_logits: (B, N_WINDOW, N_GRIPPER_BINS, H, W) per-pixel gripper logits
        target_gripper: (B, N_WINDOW) continuous targets in [MIN_GRIPPER, MAX_GRIPPER],
            discretized with discretize_gripper
        trajectory_2d: (B, N_WINDOW, 2) GT pixel coords [x, y]
        pos_weight: unused; kept for API compatibility with the old binary-loss signature
        loss_scale: multiplier on the mean cross-entropy. The default 4.0 is the factor that
            used to be hard-coded here. train_epoch/validate additionally multiply by
            GRIPPER_LOSS_WEIGHT, so the effective weight is GRIPPER_LOSS_WEIGHT * loss_scale.

    Returns:
        Scalar: mean cross-entropy over all B * N_WINDOW (sample, timestep) pairs, times loss_scale.
    """
    logits_at_gt = extract_gripper_logits_at_pixels(pred_gripper_logits, trajectory_2d)  # (B, N, Ng)
    target_bins = discretize_gripper(target_gripper)  # (B, N)
    n_bins = logits_at_gt.shape[-1]
    loss = F.cross_entropy(
        logits_at_gt.reshape(-1, n_bins), target_bins.reshape(-1), reduction='mean'
    )
    return loss * loss_scale

# Commented out binary classification loss (kept for reference)
# def compute_gripper_loss(pred_gripper, target_gripper, pos_weight=None):
#     """Compute weighted binary cross-entropy loss for gripper prediction across all timesteps.
#     
#     Args:
#         pred_gripper: (B, N_WINDOW) predicted gripper values (sigmoid output, [0, 1])
#         target_gripper: (B, N_WINDOW) target gripper values (binary: 0.0=closed, 1.0=open)
#         pos_weight: scalar weight for positive class (1.0 = open). 
#                     pos_weight = n_negative / n_positive = n_closed / n_open
#                     This weights the open class (majority), so we invert it to weight closed (minority) more
#     
#     Returns:
#         loss: scalar weighted BCE loss (averaged over timesteps)
#     """
#     # Binary cross-entropy loss with class weighting to handle imbalance
#     # Note: target_gripper encoding: 0.0 = closed (minority), 1.0 = open (majority)
#     if pos_weight is not None and pos_weight != 1.0:
#         # Manually compute weighted BCE since pred_gripper is already sigmoid output
#         # BCE = -(y * log(p) + (1-y) * log(1-p))
#         # We want to weight the minority class (closed = 0.0) more
#         # So we weight the (1-y) term, which corresponds to closed predictions
#         eps = 1e-7
#         pred_gripper_clamped = pred_gripper.clamp(eps, 1 - eps)
#         
#         # Loss for positive class (open = 1.0) - standard weight
#         pos_loss = -target_gripper * torch.log(pred_gripper_clamped)
#         # Loss for negative class (closed = 0.0) - weighted more heavily
#         # pos_weight = n_open / n_closed, so we use it to weight closed (minority) more
#         neg_loss = -(1 - target_gripper) * torch.log(1 - pred_gripper_clamped) * pos_weight
#         
#         loss = (pos_loss + neg_loss).mean() * 1e1
#     else:
#         # Standard BCE (pred_gripper is already sigmoid output from model)
#         loss = F.binary_cross_entropy(pred_gripper, target_gripper) * 1e1
#     return loss


def visualize_sample(rgb, target_heatmap, pred_heatmap, target_2d):
    """Convert one sample's tensors into numpy arrays ready for plotting.

    Args:
        rgb: (3, H, W) ImageNet-normalized RGB image tensor
        target_heatmap: accepted for interface compatibility; not used here
        pred_heatmap: (H, W) predicted heatmap tensor
        target_2d: (2,) target pixel location tensor

    Returns:
        rgb_vis: (H, W, 3) de-normalized RGB clipped to [0, 1]
        pred_heat_vis: predicted heatmap as numpy
        target_pt: GT 2D location as numpy
        pred_pt: [x, y] of the predicted heatmap's argmax
    """
    # Undo ImageNet normalization channel-wise: x * std + mean.
    imagenet_mean = torch.tensor([0.485, 0.456, 0.406], device=rgb.device)[:, None, None]
    imagenet_std = torch.tensor([0.229, 0.224, 0.225], device=rgb.device)[:, None, None]
    image_hwc = (rgb * imagenet_std + imagenet_mean).cpu().numpy().transpose(1, 2, 0)
    rgb_vis = np.clip(image_hwc, 0, 1)

    heat_np = pred_heatmap.cpu().numpy()
    row, col = np.unravel_index(np.argmax(heat_np), heat_np.shape)
    pred_pt = np.array([col, row])

    return rgb_vis, heat_np, target_2d.cpu().numpy(), pred_pt


def train_epoch(model, dataloader, optimizer, device, just_heatmap=False):
    """Run one training epoch.

    Each batch minimizes ``volume_loss + GRIPPER_LOSS_WEIGHT * gripper_loss``. The volume loss
    is a cross-entropy over all (height_bin, y, x) cells per timestep. The gripper loss is a
    cross-entropy on the per-pixel gripper logits read at the ground-truth pixel (teacher forcing).

    Args:
        model: called as ``model(rgb, training=True, start_keypoint_2d=...)``; returns
            ``(volume_logits, gripper_logits)``.
        dataloader: iterable of dict batches with 'rgb', 'trajectory_2d', 'trajectory_3d'
            and 'trajectory_gripper' tensors.
        optimizer: optimizer stepping the model's trainable parameters.
        device: device the batch tensors are moved to.
        just_heatmap: unused; kept so existing call sites keep working.

    Returns:
        ``(avg_total_loss, avg_volume_loss, avg_height_loss, avg_gripper_loss)``.
        - Averages are per sample: each batch is weighted by its size, as in ``validate``.
        - ``avg_height_loss`` is always 0.0 because height is supervised inside the volume loss.
        - An empty dataloader yields zeros instead of raising ZeroDivisionError.
    """
    model.train()
    total_loss = 0.0
    total_volume_loss = 0.0
    total_gripper_loss = 0.0
    n_samples = 0

    for batch in tqdm(dataloader):
        rgb = batch['rgb'].to(device)
        trajectory_3d = batch['trajectory_3d'].to(device)  # (B, N_WINDOW, 3)
        trajectory_2d = batch['trajectory_2d'].to(device)  # (B, N_WINDOW, 2)
        trajectory_gripper = batch['trajectory_gripper'].to(device)  # (B, N_WINDOW)
        start_keypoint_2d = trajectory_2d[:, 0]  # (B, 2)

        volume_logits, gripper_logits = model(
            rgb, training=True, start_keypoint_2d=start_keypoint_2d
        )  # (B, N_WINDOW, N_HEIGHT_BINS, H, W), (B, N_WINDOW, N_GRIPPER_BINS, H, W)

        # Height is the z coordinate of the 3D trajectory, supervised via its bin in the volume.
        target_height_bins = discretize_height(trajectory_3d[:, :, 2])  # (B, N_WINDOW)

        volume_loss = compute_volume_loss(volume_logits, trajectory_2d, target_height_bins)
        gripper_loss = compute_gripper_loss(gripper_logits, trajectory_gripper, trajectory_2d)
        loss = volume_loss + GRIPPER_LOSS_WEIGHT * gripper_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_size = rgb.shape[0]
        total_loss += loss.item() * batch_size
        total_volume_loss += volume_loss.item() * batch_size
        total_gripper_loss += gripper_loss.item() * batch_size
        n_samples += batch_size

    n = max(1, n_samples)
    return total_loss / n, total_volume_loss / n, 0.0, total_gripper_loss / n


def _ray_max_heatmaps(volume):
    """Per-timestep 2D heatmaps (N, H, W) from one sample's volume logits (N, Nh, H, W).

    Softmax over the whole (Nh, H, W) volume of each timestep, as in the training loss,
    followed by the max along the height axis of every pixel.
    """
    n_steps = volume.shape[0]
    probs = F.softmax(volume.reshape(n_steps, -1), dim=-1).reshape(volume.shape)
    return probs.max(dim=1).values


def _unproject_pred_trajectory(pred_2d, pred_height, fallback_3d, cam_pose, cam_K):
    """Lift one sample's predicted pixels (N, 2) + heights (N,) to a numpy (N, 3) array.

    Uses recover_3d_from_direct_keypoint_and_height per timestep. When it returns None,
    the matching ground-truth point from ``fallback_3d`` is used to keep one row per step.
    NOTE(review): a failed unprojection is therefore drawn at the GT location.
    """
    points = []
    for t in range(pred_2d.shape[0]):
        px, py = pred_2d[t, 0].item(), pred_2d[t, 1].item()
        pt = recover_3d_from_direct_keypoint_and_height(
            np.array([px, py], dtype=np.float64), pred_height[t].item(), cam_pose, cam_K
        )
        points.append(pt if pt is not None else fallback_3d[t].cpu().numpy())
    return np.array(points)


def validate(model, dataloader, device, image_size=IMAGE_SIZE):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Computes the training losses, then decodes predictions from the volume. For each
    timestep it takes the argmax pixel and the argmax height bin along it, and reads the
    gripper logits at the *predicted* pixel. It reports per-sample error metrics.

    For the first sample of the first batch it also collects visualization data. This
    includes a predicted 3D trajectory obtained by unprojecting predicted pixel + height with
    the normalized intrinsics scaled by ``image_size``. NOTE(review): this assumes the volume's
    H x W grid shares the pixel frame of trajectory_2d and of the image_size intrinsics.

    Returns:
        9-tuple, in order:
        - avg_loss, avg_volume_loss
        - 0.0 (no separate height loss)
        - avg_gripper_loss
        - avg_pixel_error: additionally averaged over N_WINDOW timesteps
        - avg_height_error
        - avg_height_error_tf: always 0.0, since there is no teacher forcing for the volume
        - avg_gripper_error
        - sample_data: None for an empty dataloader; otherwise a dict that includes
          trajectory_3d, pred_trajectory_3d, camera_pose and cam_K.
    """
    model.eval()
    total_loss = 0.0
    total_volume_loss = 0.0
    total_gripper_loss = 0.0
    total_pixel_error = 0.0
    total_height_error = 0.0
    total_gripper_error = 0.0
    n_samples = 0
    sample_data = None

    with torch.no_grad():
        for batch_idx, batch in enumerate(dataloader):
            rgb = batch['rgb'].to(device)
            trajectory_2d = batch['trajectory_2d'].to(device)  # (B, N_WINDOW, 2)
            trajectory_3d = batch['trajectory_3d'].to(device)  # (B, N_WINDOW, 3)
            trajectory_gripper = batch['trajectory_gripper'].to(device)  # (B, N_WINDOW)
            camera_pose = batch['camera_pose']  # (B, 4, 4); not moved to device
            cam_K_norm = batch['cam_K_norm']  # (B, 3, 3) normalized intrinsics
            start_keypoint_2d = trajectory_2d[:, 0]  # (B, 2)
            target_height = trajectory_3d[:, :, 2]  # (B, N_WINDOW)
            batch_size = rgb.shape[0]

            volume_logits, gripper_logits = model(
                rgb, training=False, start_keypoint_2d=start_keypoint_2d
            )  # (B, N_WINDOW, N_HEIGHT_BINS, H, W), (B, N_WINDOW, N_GRIPPER_BINS, H, W)

            target_height_bins = discretize_height(target_height)  # (B, N_WINDOW)
            volume_loss = compute_volume_loss(volume_logits, trajectory_2d, target_height_bins)
            gripper_loss = compute_gripper_loss(gripper_logits, trajectory_gripper, trajectory_2d)
            loss = volume_loss + GRIPPER_LOSS_WEIGHT * gripper_loss

            total_loss += loss.item() * batch_size
            total_volume_loss += volume_loss.item() * batch_size
            total_gripper_loss += gripper_loss.item() * batch_size

            # Decode predictions; the gripper is read at the predicted pixel (no teacher forcing).
            pred_2d, pred_height = extract_pred_2d_and_height_from_volume(volume_logits)  # (B, N, 2), (B, N)
            gripper_logits_at_pred = extract_gripper_logits_at_pixels(gripper_logits, pred_2d)  # (B, N, Ng)
            pred_gripper = decode_gripper_bins(gripper_logits_at_pred)  # (B, N)

            # Euclidean pixel error summed over every sample and timestep of the batch.
            total_pixel_error += torch.norm(pred_2d - trajectory_2d, dim=-1).sum().item()
            total_height_error += torch.abs(pred_height - target_height).mean(dim=1).sum().item()
            total_gripper_error += torch.abs(pred_gripper - trajectory_gripper).mean(dim=1).sum().item()
            n_samples += batch_size

            if batch_idx == 0:
                cam_pose_np = camera_pose[0].cpu().numpy()
                cam_K_norm_np = cam_K_norm[0].cpu().numpy()
                # Copy before scaling rows 0 and 1 (fx, cx / fy, cy) from normalized to pixel units.
                cam_K_np = cam_K_norm_np.copy()
                cam_K_np[:2] *= image_size

                sample_data = {
                    'rgb': rgb[0],
                    'target_heatmap': batch['heatmap_target'][0].to(device),  # (N_WINDOW, H, W)
                    'pred_heatmap': _ray_max_heatmaps(volume_logits[0]),  # (N_WINDOW, H, W)
                    'trajectory_2d': trajectory_2d[0],
                    'trajectory_3d': trajectory_3d[0],
                    'pred_trajectory_3d': _unproject_pred_trajectory(
                        pred_2d[0], pred_height[0], trajectory_3d[0], cam_pose_np, cam_K_np
                    ),
                    'camera_pose': cam_pose_np,
                    'cam_K_norm': cam_K_norm_np,
                    'cam_K_at_size': cam_K_np,
                    'pred_height': pred_height[0],
                    'target_height': target_height[0],
                    'pred_gripper': pred_gripper[0],
                    'target_gripper': trajectory_gripper[0],
                }

    n = max(1, n_samples)
    return (
        total_loss / n, total_volume_loss / n, 0.0, total_gripper_loss / n,
        total_pixel_error / (n * N_WINDOW), total_height_error / n, 0.0, total_gripper_error / n,
        sample_data,
    )


def main():
    # Parse arguments
    parser = argparse.ArgumentParser(description="Train trajectory heatmap predictor")
    parser.add_argument("--dataset_root", "-d",nargs="+", default=["scratch/parsed_pickplace_exp1_feb9"],
                       help="One or more root directories with episodes (same structure); datasets are concatenated")
    parser.add_argument("--val_split", type=float, default=0.05,
                       help="Fraction of episodes to use for validation")
    parser.add_argument("--batch_size", type=int, default=BATCH_SIZE,
                       help="Batch size for training")
    parser.add_argument("--lr", type=float, default=LEARNING_RATE,
                       help="Learning rate")
    parser.add_argument("--epochs", type=int, default=NUM_EPOCHS,
                       help="Number of epochs")
    parser.add_argument("--checkpoint", type=str, default="",
                       help="Path to checkpoint to resume from (default: 2D-only model)")
    parser.add_argument("--run_name", type=str, default="volume_dino_tracks",
                       help="Name of run (used for checkpoint and visualization paths)")
    args = parser.parse_args()

    # Checkpoint and visualization paths
    CHECKPOINT_DIR = Path(f"volume_dino_tracks/checkpoints/{args.run_name}")
    CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
    
    # Setup device
    device = torch.device("mps" if torch.backends.mps.is_available() else 
                         "cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    
    # Load full dataset (single root or concatenation of multiple roots)
    dataset_roots = args.dataset_root if isinstance(args.dataset_root, list) else [args.dataset_root]
    print("\nLoading dataset...")
    if len(dataset_roots) == 1:
        full_dataset = RealTrajectoryDataset(
            dataset_root=dataset_roots[0],
            image_size=IMAGE_SIZE
        )
        print(f"  Single root: {dataset_roots[0]}")
    else:
        datasets = [
            RealTrajectoryDataset(dataset_root=root, image_size=IMAGE_SIZE)
            for root in dataset_roots
        ]
        full_dataset = ConcatDataset(datasets)
        for root, d in zip(dataset_roots, datasets):
            print(f"  {root}: {len(d)} samples")
    print(f"  Total: {len(full_dataset)} samples")
    
    # Split into train/val (simple split by samples)
    dataset_size = len(full_dataset)
    val_size = int(dataset_size * args.val_split)
    train_size = dataset_size - val_size
    
    train_dataset, val_dataset = torch.utils.data.random_split(
        full_dataset, [train_size, val_size],
        generator=torch.Generator().manual_seed(42)  # For reproducibility
    )
    
    print(f"✓ Train: {len(train_dataset)} samples")
    print(f"✓ Val: {len(val_dataset)} samples")
    
    
    # Create dataloaders
    train_loader = DataLoader(
        train_dataset, 
        batch_size=args.batch_size, 
        shuffle=True, 
        num_workers=0
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=0
    )
    
    print(f"✓ Train: {len(train_dataset)} samples")
    print(f"✓ Val: {len(val_dataset)} samples")
    
    # Initialize model
    print("\nInitializing model...")
    model = TrajectoryHeatmapPredictor(target_size=IMAGE_SIZE, n_window=N_WINDOW, freeze_backbone=False)
    model = model.to(device)
    
    # Count parameters
    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    n_total = sum(p.numel() for p in model.parameters())
    print(f"Trainable parameters: {n_trainable:,} / {n_total:,} ({100*n_trainable/n_total:.2f}%)")
    
    # Setup optimizer
    optimizer = optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=args.lr,
        weight_decay=1e-4
    )
    
    # Load checkpoint if specified (before computing stats, so we can use checkpoint values if available).
    # Supports 2D-only model (latest.pt): only matching keys are loaded; height/gripper heads are randomly initialized.
    # Try .pt and .pth so both user's "latest.pt" and our saved "latest.pth" work.
    start_epoch = 0
    checkpoint_height_values = None
    checkpoint_gripper_values = None
    checkpoint_path = args.checkpoint
    if args.checkpoint:
        if not os.path.exists(checkpoint_path):
            alt = checkpoint_path.rsplit(".", 1)
            if len(alt) == 2:
                other_ext = ".pth" if alt[1].lower() == "pt" else ".pt"
                alt_path = alt[0] + other_ext
                if os.path.exists(alt_path):
                    checkpoint_path = alt_path
                    print(f"Checkpoint not found at {args.checkpoint}, using {checkpoint_path}")
    if args.checkpoint and os.path.exists(checkpoint_path):
        print(f"\nLoading checkpoint: {checkpoint_path} (initializing from 2D-only model; height/gripper 32-bin heads random if missing)")
        checkpoint = torch.load(checkpoint_path, map_location=device)
        
        # Load model state dict with partial loading (skip missing keys and shape mismatches for height/gripper heads if 2D-only model)
        model_state = checkpoint['model_state_dict']
        model_dict = model.state_dict()
        
        # Filter: only load keys that exist in current model AND have matching shapes
        filtered_state = {}
        shape_mismatches = []
        for k, v in model_state.items():
            if k in model_dict:
                if v.shape == model_dict[k].shape:
                    filtered_state[k] = v
                else:
                    shape_mismatches.append(f"{k}: checkpoint {v.shape} vs model {model_dict[k].shape}")
            # else: key not in current model, skip it
        
        missing_keys = set(model_dict.keys()) - set(model_state.keys())
        unexpected_keys = set(model_state.keys()) - set(model_dict.keys())
        
        if missing_keys:
            print(f"⚠ Missing keys in checkpoint (will use random initialization): {sorted(missing_keys)}")
        if unexpected_keys:
            print(f"⚠ Unexpected keys in checkpoint (will be ignored): {sorted(unexpected_keys)}")
        if shape_mismatches:
            print(f"⚠ Shape mismatches (will use random initialization for these):")
            for msg in shape_mismatches:
                print(f"    {msg}")
        
        model_dict.update(filtered_state)
        model.load_state_dict(model_dict, strict=False)  # strict=False allows remaining missing keys
        
        # Load optimizer state if available (may not match if architecture changed)
        try:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        except Exception as e:
            print(f"⚠ Could not load optimizer state (architecture may have changed): {e}")
            print("  Starting with fresh optimizer state")
        
        #start_epoch = checkpoint.get('epoch', 0) + 1
        
        # Save height values from checkpoint if available (use these instead of recomputing)
        if 'min_height' in checkpoint and 'max_height' in checkpoint:
            checkpoint_height_values = (checkpoint['min_height'], checkpoint['max_height'])
            print(f"✓ Found height range in checkpoint: [{checkpoint_height_values[0]:.6f}, {checkpoint_height_values[1]:.6f}] m")
        # Load gripper min/max from checkpoint if available
        if 'min_gripper' in checkpoint and 'max_gripper' in checkpoint:
            checkpoint_gripper_values = (checkpoint['min_gripper'], checkpoint['max_gripper'])
            print(f"✓ Found gripper range in checkpoint: [{checkpoint_gripper_values[0]:.6f}, {checkpoint_gripper_values[1]:.6f}]")
        print(f"✓ Resumed from epoch {start_epoch}")
    elif args.checkpoint:
        print(f"\n⚠ Checkpoint not found: {args.checkpoint}")
        print("  Initializing model from scratch (height/gripper heads will be randomly initialized)")
    
    # Compute or load min/max height and gripper from dataset (computed once, used as constants throughout training)
    # If resuming from checkpoint with values, use those instead of recomputing
    import model as model_module
    if checkpoint_height_values is not None:
        # Use height values from checkpoint (model was trained with these)
        model_module.MIN_HEIGHT, model_module.MAX_HEIGHT = checkpoint_height_values
        print(f"✓ Using height range from checkpoint: [{model_module.MIN_HEIGHT:.6f}, {model_module.MAX_HEIGHT:.6f}] m")
        if abs(model_module.MIN_HEIGHT - model_module.MAX_HEIGHT) < 1e-6:
            print(f"  ⚠ WARNING: MIN_HEIGHT == MAX_HEIGHT! All height predictions will be constant!")
            print(f"  This will cause all bins to decode to the same value. Check dataset or checkpoint.")
    else:
        # Compute min/max height from entire dataset (once, used as constants)
        print("\nComputing height statistics from dataset...")
        all_heights = []
        for dataset in [train_dataset, val_dataset]:
            for i in range(len(dataset)):
                sample = dataset[i]
                trajectory_3d = sample['trajectory_3d'].numpy()  # (N_WINDOW, 3)
                waypoint_heights = trajectory_3d[:, 2]  # z-coordinates for all waypoints
                all_heights.extend(waypoint_heights.tolist())
        
        all_heights = np.array(all_heights)
        min_height = float(all_heights.min())
        max_height = float(all_heights.max())
        model_module.MIN_HEIGHT = min_height
        model_module.MAX_HEIGHT = max_height
        print(f"✓ Height range computed from dataset: [{model_module.MIN_HEIGHT:.6f}, {model_module.MAX_HEIGHT:.6f}] m")
        print(f"  ({model_module.MIN_HEIGHT*1000:.2f}mm to {model_module.MAX_HEIGHT*1000:.2f}mm)")
        print(f"  (computed from {len(all_heights)} waypoint heights across all samples)")
        if abs(model_module.MIN_HEIGHT - model_module.MAX_HEIGHT) < 1e-6:
            print(f"  ⚠ WARNING: MIN_HEIGHT == MAX_HEIGHT! All height predictions will be constant!")
            print(f"  This will cause all bins to decode to the same value. Check dataset.")
        print(f"✓ These values will be used as constants throughout training and saved in checkpoints")
    
    # Gripper range: compute from dataset or use checkpoint/hard-coded values
    if checkpoint_gripper_values is not None:
        # Use gripper values from checkpoint (model was trained with these)
        model_module.MIN_GRIPPER, model_module.MAX_GRIPPER = checkpoint_gripper_values
        print(f"✓ Using gripper range from checkpoint: [{model_module.MIN_GRIPPER:.6f}, {model_module.MAX_GRIPPER:.6f}]")
        if abs(model_module.MIN_GRIPPER - model_module.MAX_GRIPPER) < 1e-6:
            print(f"  ⚠ WARNING: MIN_GRIPPER == MAX_GRIPPER! All gripper predictions will be constant!")
            print(f"  This will cause all bins to decode to the same value. Check dataset or checkpoint.")
    else:
        # Compute min/max gripper from entire dataset (once, used as constants)
        print("\nComputing gripper statistics from dataset...")
        all_grippers = []
        for dataset in [train_dataset, val_dataset]:
            for i in range(len(dataset)):
                sample = dataset[i]
                trajectory_gripper = sample['trajectory_gripper'].numpy()  # (N_WINDOW,)
                all_grippers.extend(trajectory_gripper.tolist())
        
        all_grippers = np.array(all_grippers)
        min_gripper = float(all_grippers.min())
        max_gripper = float(all_grippers.max())
        model_module.MIN_GRIPPER = min_gripper
        model_module.MAX_GRIPPER = max_gripper
        print(f"✓ Gripper range computed from dataset: [{model_module.MIN_GRIPPER:.6f}, {model_module.MAX_GRIPPER:.6f}]")
        print(f"  (computed from {len(all_grippers)} gripper values across all samples)")
        if abs(model_module.MIN_GRIPPER - model_module.MAX_GRIPPER) < 1e-6:
            print(f"  ⚠ WARNING: MIN_GRIPPER == MAX_GRIPPER! All gripper predictions will be constant!")
            print(f"  This will cause all bins to decode to the same value. Check dataset.")
        print(f"✓ These values will be used as constants throughout training and saved in checkpoints")
    
    # Setup live visualization
    print("\nSetting up live visualization...")
    plt.ion()
    from mpl_toolkits.mplot3d import Axes3D
    # Grid: 3 loss + 2 image rows + 1 height line + 1 gripper line + 1 big 3D = 8 rows
    fig = plt.figure(figsize=(4*N_WINDOW, 14))
    gs = GridSpec(8, N_WINDOW, figure=fig, hspace=0.35, wspace=0.2,
                  height_ratios=[1, 1, 1, 2, 2, 0.8, 0.8, 5])
    
    ax_loss_heatmap = fig.add_subplot(gs[0, :])
    ax_loss_height = fig.add_subplot(gs[1, :])
    ax_loss_gripper = fig.add_subplot(gs[2, :])
    
    # 2 rows for train, val images (each row N_WINDOW cols)
    axes_vis = []
    for row_idx in range(2):
        row_axes = []
        for t in range(N_WINDOW):
            ax = fig.add_subplot(gs[3 + row_idx, t])
            ax.axis('off')
            if row_idx == 0:
                ax.set_title(f"t={t}", fontweight='bold', fontsize=10)
            row_axes.append(ax)
        axes_vis.append(row_axes)
    
    # Single line charts: height (pred vs GT over timesteps), gripper (pred vs GT over timesteps)
    ax_height_line = fig.add_subplot(gs[5, :])
    ax_gripper_line = fig.add_subplot(gs[6, :])

    # 3D volume pane (GT vs Pred trajectory) - larger height_ratio
    ax_3d = fig.add_subplot(gs[7, :], projection='3d')
    
    plt.show(block=False)
    fig.canvas.draw()  # Initial draw to ensure window is ready
    plt.pause(0.1)  # Small pause to let window initialize
    
    # Training loop
    print(f"\nStarting training for {args.epochs} epochs...")
    best_val_loss = float('inf')
    
    # Track losses separately
    train_heatmap_losses = []
    train_height_losses = []
    train_gripper_losses = []
    val_heatmap_losses = []
    val_height_losses = []
    val_gripper_losses = []
    
    for epoch in range(start_epoch, args.epochs):
        print(f"\n{'='*60}")
        print(f"Epoch {epoch}/{args.epochs}")
        print(f"{'='*60}")
        
        # Train
        train_loss, train_volume_loss, _, train_gripper_loss = train_epoch(
            model, train_loader, optimizer, device, just_heatmap=epoch<2
        )
        train_heatmap_losses.append(train_volume_loss)  # store volume loss in same list for plot
        train_height_losses.append(0.0)
        train_gripper_losses.append(train_gripper_loss)
        print(f"Train Loss: {train_loss:.4f} (Volume: {train_volume_loss:.4f}, Gripper: {train_gripper_loss:.6f})")
        
        # Validate and get sample data for visualization
        val_loss, val_heatmap_loss, val_height_loss, val_gripper_loss, \
        val_error, val_height_error, val_height_error_tf, val_gripper_error, sample_val = validate(
            model, val_loader, device
        )
        val_heatmap_losses.append(val_heatmap_loss)  # volume loss
        val_height_losses.append(val_height_loss)
        val_gripper_losses.append(val_gripper_loss)

        print(f"Val - Loss: {val_loss:.4f}, Volume: {val_heatmap_loss:.4f}, Pixel Error: {val_error:.2f}px, Height Error: {val_height_error*1000:.3f}mm, Gripper: {val_gripper_error:.4f}")
        
        # Get train sample for visualization (all timesteps)
        model.eval()
        with torch.no_grad():
            train_sample_batch = next(iter(train_loader))
            train_rgb = train_sample_batch['rgb'][0:1].to(device)
            train_target_heatmap_all = train_sample_batch['heatmap_target'][0]  # (N_WINDOW, H, W)
            train_trajectory_2d = train_sample_batch['trajectory_2d'][0]  # (N_WINDOW, 2)
            train_trajectory_3d = train_sample_batch['trajectory_3d'][0]  # (N_WINDOW, 3)
            train_trajectory_gripper = train_sample_batch['trajectory_gripper'][0]  # (N_WINDOW,)
            train_start_keypoint_2d = train_trajectory_2d[0]  # (2,)
            train_camera_pose = train_sample_batch['camera_pose'][0].cpu().numpy()
            train_cam_K_norm = train_sample_batch['cam_K_norm'][0].cpu().numpy()
            train_cam_K = train_cam_K_norm.copy()
            train_cam_K[0] *= IMAGE_SIZE
            train_cam_K[1] *= IMAGE_SIZE

            train_volume_logits, train_gripper_logits = model(
                train_rgb, training=False, start_keypoint_2d=train_start_keypoint_2d
            )
            train_pred_2d, train_pred_height = extract_pred_2d_and_height_from_volume(train_volume_logits[0:1])
            train_pred_height = train_pred_height[0]  # (N_WINDOW,)
            train_gripper_at_pred = extract_gripper_logits_at_pixels(train_gripper_logits, train_pred_2d)
            train_pred_gripper = decode_gripper_bins(train_gripper_at_pred)[0]  # (N_WINDOW,)

            train_pred_heatmaps = []
            for t in range(N_WINDOW):
                vol_t = train_volume_logits[0, t]  # (Nh, H, W)
                # Softmax over full volume, then max along the ray for visualization
                vol_probs = F.softmax(vol_t.view(-1), dim=0).view(vol_t.shape[0], vol_t.shape[1], vol_t.shape[2])
                max_along_ray = vol_probs.max(dim=0)[0]  # (H, W)
                train_pred_heatmaps.append(max_along_ray)
            train_pred_heatmaps = torch.stack(train_pred_heatmaps)  # (N_WINDOW, H, W)

            train_pred_trajectory_3d_list = []
            for t in range(N_WINDOW):
                px, py = train_pred_2d[0, t, 0].item(), train_pred_2d[0, t, 1].item()
                h = train_pred_height[t].item()
                pt = recover_3d_from_direct_keypoint_and_height(
                    np.array([px, py], dtype=np.float64), h, train_camera_pose, train_cam_K
                )
                train_pred_trajectory_3d_list.append(pt if pt is not None else train_trajectory_3d[t].cpu().numpy())
            train_pred_trajectory_3d_np = np.array(train_pred_trajectory_3d_list)

            sample_train = {
                'rgb': train_rgb[0],
                'target_heatmap': train_target_heatmap_all,
                'pred_heatmap': train_pred_heatmaps,
                'trajectory_2d': train_trajectory_2d,
                'trajectory_3d': train_trajectory_3d,
                'pred_trajectory_3d': train_pred_trajectory_3d_np,
                'camera_pose': train_camera_pose,
                'cam_K_at_size': train_cam_K,
                'pred_height': train_pred_height,
                'target_height': train_trajectory_3d[:, 2],
                'pred_gripper': train_pred_gripper,
                'target_gripper': train_trajectory_gripper,
            }
        
        # Update visualization
        epochs_range = np.arange(len(train_heatmap_losses))
        ax_loss_heatmap.clear()
        ax_loss_heatmap.plot(epochs_range, train_heatmap_losses, 'o-', label='Train', color='blue', linewidth=2)
        ax_loss_heatmap.plot(epochs_range, val_heatmap_losses, 's-', label='Val', color='green', linewidth=2)
        ax_loss_heatmap.set_xlabel('Epoch')
        ax_loss_heatmap.set_ylabel('Volume Loss (CE)')
        ax_loss_heatmap.set_title(f'Volume Loss | Train: {train_volume_loss:.3f} | Val: {val_heatmap_loss:.3f}')
        ax_loss_heatmap.legend(loc='upper right')
        ax_loss_heatmap.grid(alpha=0.3)

        ax_loss_height.clear()
        ax_loss_height.plot(epochs_range, train_height_losses, 'o-', label='Train', color='blue', linewidth=2)
        ax_loss_height.plot(epochs_range, val_height_losses, 's-', label='Val', color='green', linewidth=2)
        ax_loss_height.set_xlabel('Epoch')
        ax_loss_height.set_ylabel('Height (unused)')
        ax_loss_height.set_title('Height loss (0)')
        ax_loss_height.legend(loc='upper right')
        ax_loss_height.grid(alpha=0.3)
        
        # Gripper loss plot
        ax_loss_gripper.clear()
        epochs_range = np.arange(len(train_gripper_losses))
        ax_loss_gripper.plot(epochs_range, train_gripper_losses, 'o-', label='Train', color='blue', linewidth=2)
        ax_loss_gripper.plot(epochs_range, val_gripper_losses, 's-', label='Val', color='green', linewidth=2)
        ax_loss_gripper.set_xlabel('Epoch')
        ax_loss_gripper.set_ylabel('Gripper Loss (MSE)')
        ax_loss_gripper.set_title(f'Gripper Loss | Train: {train_gripper_loss:.6f} | Val: {val_gripper_loss:.6f}')
        ax_loss_gripper.legend(loc='upper right')
        ax_loss_gripper.grid(alpha=0.3)
        
        # Visualize samples from train and val (show all timesteps horizontally)
        for row_idx, (sample, split_name) in enumerate([
            (sample_train, 'Train'),
            (sample_val, 'Val')
        ]):
            if sample is None:
                continue
            
            # Denormalize RGB for visualization
            mean = torch.tensor([0.485, 0.456, 0.406], device=sample['rgb'].device).view(3, 1, 1)
            std = torch.tensor([0.229, 0.224, 0.225], device=sample['rgb'].device).view(3, 1, 1)
            rgb_denorm = (sample['rgb'] * std + mean).cpu().numpy()
            rgb_vis = np.clip(rgb_denorm.transpose(1, 2, 0), 0, 1)
            
            # Visualize each timestep
            for t in range(N_WINDOW):
                ax = axes_vis[row_idx][t]
                ax.clear()
                
                # Get heatmaps and keypoints for this timestep
                target_heatmap_t = sample['target_heatmap'][t].cpu().numpy()  # (H, W)
                pred_heatmap_t = sample['pred_heatmap'][t].cpu().numpy()  # (H, W)
                target_2d_t = sample['trajectory_2d'][t].cpu().numpy()  # (2,)
                
                # Get predicted location (argmax)
                pred_y, pred_x = np.unravel_index(pred_heatmap_t.argmax(), pred_heatmap_t.shape)
                pred_2d_t = np.array([pred_x, pred_y])
                
                # Show RGB with heatmap overlay and keypoints
                ax.imshow(rgb_vis)
                ax.imshow(pred_heatmap_t, alpha=0.6, cmap='hot')
                
                # Plot GT and predicted keypoints
                ax.scatter(target_2d_t[0], target_2d_t[1], c='white', s=100, 
                          marker='o', edgecolors='black', linewidths=2, 
                          label='GT', zorder=10)
                ax.scatter(pred_2d_t[0], pred_2d_t[1], c='lime', s=100, 
                          marker='x', linewidths=3, 
                          label='Pred', zorder=10)
                
                # Compute errors
                pixel_err = np.linalg.norm(target_2d_t - pred_2d_t)
                title_parts = [f"t={t}", f"Px:{pixel_err:.1f}px"]
                if 'pred_height' in sample and 'target_height' in sample:
                    # Ensure pred_height is a 1D tensor before indexing
                    pred_h_tensor = sample['pred_height']
                    if pred_h_tensor.dim() == 0:
                        # Scalar - use it directly (shouldn't happen, but handle gracefully)
                        pred_h = pred_h_tensor.cpu().item()
                    else:
                        pred_h = pred_h_tensor[t].cpu().item() if pred_h_tensor.dim() > 0 else pred_h_tensor.cpu().item()
                    target_h_tensor = sample['target_height']
                    target_h = target_h_tensor[t].cpu().item() if target_h_tensor.dim() > 0 else target_h_tensor.cpu().item()
                    height_err = abs(pred_h - target_h) * 1000  # mm
                    title_parts.append(f"H:{pred_h*1000:.1f}mm")
                if 'pred_gripper' in sample and 'target_gripper' in sample:
                    # Ensure pred_gripper is a 1D tensor before indexing
                    pred_g_tensor = sample['pred_gripper']
                    pred_g = pred_g_tensor[t].cpu().item() if pred_g_tensor.dim() > 0 else pred_g_tensor.cpu().item()
                    target_g_tensor = sample['target_gripper']
                    target_g = target_g_tensor[t].cpu().item() if target_g_tensor.dim() > 0 else target_g_tensor.cpu().item()
                    title_parts.append(f"G:{pred_g:.2f}(GT:{target_g:.2f})")
                ax.set_title("\n".join(title_parts), fontsize=8)
                
                ax.axis('off')
                if t == 0:
                    ax.legend(loc='upper right', fontsize=6)

        # Single line charts: height and gripper (pred vs GT over timesteps), using val sample
        import model as model_module
        sample_for_lines = sample_val if sample_val is not None else sample_train
        if sample_for_lines is not None:
            ts = np.arange(N_WINDOW)
            ax_height_line.clear()
            ax_height_line.set_xlabel('Timestep')
            ax_height_line.set_ylabel('Height (mm)')
            ax_height_line.set_title('Height: Pred vs GT')
            ax_height_line.grid(alpha=0.3)
            ax_height_line.set_ylim([model_module.MIN_HEIGHT * 1000 - 10, model_module.MAX_HEIGHT * 1000 + 10])
            if 'pred_height' in sample_for_lines and 'target_height' in sample_for_lines:
                pred_h = sample_for_lines['pred_height']
                target_h = sample_for_lines['target_height']
                if hasattr(pred_h, 'cpu'):
                    pred_h = np.atleast_1d(pred_h.cpu().numpy())
                else:
                    pred_h = np.atleast_1d(np.asarray(pred_h))
                if hasattr(target_h, 'cpu'):
                    target_h = np.atleast_1d(target_h.cpu().numpy())
                else:
                    target_h = np.atleast_1d(np.asarray(target_h))
                if len(pred_h) == N_WINDOW and len(target_h) == N_WINDOW:
                    ax_height_line.plot(ts, pred_h * 1000, 'o-', color='red', label='Pred', linewidth=2, markersize=4)
                    ax_height_line.plot(ts, target_h * 1000, 's-', color='green', label='GT', linewidth=2, markersize=4)
            ax_height_line.legend(loc='upper right')

            ax_gripper_line.clear()
            ax_gripper_line.set_xlabel('Timestep')
            ax_gripper_line.set_ylabel('Gripper')
            ax_gripper_line.set_title('Gripper: Pred vs GT')
            ax_gripper_line.grid(alpha=0.3)
            ax_gripper_line.set_ylim([model_module.MIN_GRIPPER - 0.1, model_module.MAX_GRIPPER + 0.1])
            if 'pred_gripper' in sample_for_lines and 'target_gripper' in sample_for_lines:
                pred_g = sample_for_lines['pred_gripper']
                target_g = sample_for_lines['target_gripper']
                if hasattr(pred_g, 'cpu'):
                    pred_g = np.atleast_1d(pred_g.cpu().numpy())
                else:
                    pred_g = np.atleast_1d(np.asarray(pred_g))
                if hasattr(target_g, 'cpu'):
                    target_g = np.atleast_1d(target_g.cpu().numpy())
                else:
                    target_g = np.atleast_1d(np.asarray(target_g))
                if len(pred_g) == N_WINDOW and len(target_g) == N_WINDOW:
                    ax_gripper_line.plot(ts, pred_g, 'o-', color='red', label='Pred', linewidth=2, markersize=4)
                    ax_gripper_line.plot(ts, target_g, 's-', color='green', label='GT', linewidth=2, markersize=4)
            ax_gripper_line.legend(loc='upper right')

        # 3D volume pane: volume (transparent white) + GT (red) + Pred (blue)
        ax_3d.clear()
        sample_for_3d = sample_val if sample_val is not None else sample_train
        if sample_for_3d is not None and 'trajectory_3d' in sample_for_3d and 'pred_trajectory_3d' in sample_for_3d:
            traj_3d = sample_for_3d['trajectory_3d']
            if hasattr(traj_3d, 'cpu'):
                traj_3d = traj_3d.cpu().numpy()
            pred_3d = sample_for_3d['pred_trajectory_3d']
            cam_pose = sample_for_3d['camera_pose']
            cam_K = sample_for_3d['cam_K_at_size']
            z_min = float(traj_3d[:, 2].min()) - 0.01
            z_max = float(traj_3d[:, 2].max()) + 0.01
            height_centers = np.linspace(z_min, z_max, N_HEIGHT_BINS)
            volume_pts = build_volume_3d_points_for_vis(
                IMAGE_SIZE, IMAGE_SIZE, cam_pose, cam_K, height_centers, pixel_step=32
            )
            if len(volume_pts) > 0:
                ax_3d.scatter(
                    volume_pts[:, 0], volume_pts[:, 1], volume_pts[:, 2],
                    c='white', alpha=0.03, s=1, edgecolors='none'
                )
            ax_3d.scatter(
                traj_3d[:, 0], traj_3d[:, 1], traj_3d[:, 2],
                c='red', alpha=1.0, s=60, edgecolors='darkred', linewidths=1.5, label='GT'
            )
            ax_3d.scatter(
                pred_3d[:, 0], pred_3d[:, 1], pred_3d[:, 2],
                c='blue', alpha=1.0, s=60, edgecolors='darkblue', linewidths=1.5, label='Pred'
            )
            ax_3d.set_xlabel('X (m)')
            ax_3d.set_ylabel('Y (m)')
            ax_3d.set_zlabel('Z (m)')
            ax_3d.set_title('Volume (white) + GT (red) vs Pred (blue)')
            ax_3d.legend()
        
        # Update figure (better for macOS)
        fig.canvas.draw()
        fig.canvas.flush_events()
        plt.pause(0.01)  # Slightly longer pause for better responsiveness
        
        # Use val loss for checkpointing
        
        # Save checkpoint (use model constants, computed once from dataset)
        import model as model_module
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': train_loss,
            'val_loss': val_loss,
            'min_height': model_module.MIN_HEIGHT,  # Use model constants (computed once from dataset)
            'max_height': model_module.MAX_HEIGHT,
            'min_gripper': model_module.MIN_GRIPPER,  # Regression: [-0.2, 0.8]
            'max_gripper': model_module.MAX_GRIPPER,
        }
        
        # Save latest
        torch.save(checkpoint, CHECKPOINT_DIR / 'latest.pth')
        
        # Save best
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(checkpoint, CHECKPOINT_DIR / 'best.pth')
            print(f"✓ Saved best model (val_loss={val_loss:.4f})")
        
    
    plt.ioff()
    plt.show()
    
    print("\n" + "="*60)
    print("✓ Training complete!")
    print(f"Best val loss: {best_val_loss:.4f}")
    print(f"Checkpoints saved to: {CHECKPOINT_DIR}")


if __name__ == "__main__":
    # Script entry point: start training only when this file is executed
    # directly; importing the module does not trigger a training run.
    main()
