"""Train trajectory volume predictor on LIBERO demonstrations.

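Model predicts a pixel-aligned volume: N_WINDOW x N_HEIGHT_BINS logits per pixel, trained with a cross-entropy
over all (height_bin, y, x) cells of each timestep whose target is the ground-truth trajectory cell.
Gripper (a single open/close logit per timestep, trained with BCE) and rotation (per-axis euler bins) are read
out per pixel and supervised at the GT pixel (teacher forcing); the gripper is decoded at the predicted pixel
in val/inference.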
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, ConcatDataset
import numpy as np
from pathlib import Path
from tqdm import tqdm
import argparse
import wandb
import json
import random

import sys
import os
sys.path.insert(0, os.path.dirname(__file__))

import cv2
from robosuite.utils.camera_utils import project_points_from_world_to_camera
import robosuite.utils.transform_utils as T_rob

from data import RealTrajectoryDataset, CachedTrajectoryDataset, N_WINDOW
from model import TrajectoryHeatmapPredictor, N_HEIGHT_BINS, N_GRIPPER_BINS, N_ROT_BINS, PRED_SIZE
import model as model_module  # Import module to access updated MIN_HEIGHT/MAX_HEIGHT at runtime
from utils import recover_3d_from_direct_keypoint_and_height

# Lazy imports for ablation models (avoids import errors if deps missing)
def get_model_class(model_type):
    """Return the predictor class registered under ``model_type``.

    Only the default "para" model is bound at module import time. Every other backend's
    module is imported on first request, so a missing optional dependency only raises
    when that backend is actually selected.

    Raises:
        ValueError: if ``model_type`` is not one of the known backend names.
    """
    if model_type == "para":
        predictor_cls = TrajectoryHeatmapPredictor
    elif model_type == "act":
        from model_act import ACTPredictor as predictor_cls
    elif model_type == "da3":
        from model_da3 import DA3Predictor as predictor_cls
    elif model_type == "moge":
        from model_moge import MoGePredictor as predictor_cls
    elif model_type == "dino_vla":
        from model_dino_vla import DinoVLAPredictor as predictor_cls
    elif model_type == "internvl":
        from model_vla_internvl import InternVLAPredictor as predictor_cls
    elif model_type == "internvl_act":
        from model_vla_internvl_act import InternVLACTPredictor as predictor_cls
    else:
        raise ValueError(f"Unknown model_type: {model_type}")
    return predictor_cls

def is_heatmap_model(model_type):
    """Return True for models with a pixel-aligned heatmap/volume head.

    The ACT-style backends ("act", "internvl_act") regress actions directly and return False.
    """
    direct_regression_types = ("act", "internvl_act")
    if model_type in direct_regression_types:
        return False
    return True

# Helper functions for discretization
def discretize_height(height_values, *, min_height=None, max_height=None, n_bins=None):
    """Quantize continuous heights into bin indices, rounding to the nearest bin.

    Bin ``i`` is represented by ``min + i / (n_bins - 1) * (max - min)``. This is the value
    ``decode_height_bins`` maps it back to, so each height gets the bin whose
    representative value is closest. Heights outside ``[min, max]`` are clamped to the
    first or last bin.

    Args:
        height_values: tensor of heights, any shape (e.g. (B, N_WINDOW) or (N_WINDOW,)).
        min_height, max_height: range bounds. They default to ``model.MIN_HEIGHT`` /
            ``model.MAX_HEIGHT``, read at call time because they can be updated at runtime.
        n_bins: number of bins; defaults to ``N_HEIGHT_BINS``.

    Returns:
        Long tensor of the same shape with indices in ``[0, n_bins - 1]``.
    """
    min_h = model_module.MIN_HEIGHT if min_height is None else min_height
    max_h = model_module.MAX_HEIGHT if max_height is None else max_height
    n = N_HEIGHT_BINS if n_bins is None else n_bins
    normalized = ((height_values - min_h) / (max_h - min_h + 1e-8)).clamp(0.0, 1.0)
    # Round, do not floor. Flooring would push every value into the bin below it, biasing
    # labels downward by half a bin on average. The top bin would then only be reachable
    # at exactly max.
    return (normalized * (n - 1)).round().long().clamp(0, n - 1)




def decode_height_bins(bin_logits):
    """Map height-bin logits to continuous heights via the argmax bin.

    Bin ``i`` decodes to ``MIN_HEIGHT + i / (N_HEIGHT_BINS - 1) * (MAX_HEIGHT - MIN_HEIGHT)``.
    The bounds are read from the ``model`` module at call time.

    Args:
        bin_logits: (B, N_WINDOW, N_HEIGHT_BINS) logits.

    Returns:
        (B, N_WINDOW) tensor of heights in [MIN_HEIGHT, MAX_HEIGHT].
    """
    lo = model_module.MIN_HEIGHT
    hi = model_module.MAX_HEIGHT
    winners = bin_logits.argmax(dim=-1)
    fractions = torch.linspace(0.0, 1.0, N_HEIGHT_BINS, device=bin_logits.device)[winners]
    return fractions * (hi - lo) + lo


def discretize_gripper(gripper_values, *, min_gripper=None, max_gripper=None, n_bins=None):
    """Quantize continuous gripper values into bin indices, rounding to the nearest bin.

    Bin ``i`` is represented by ``min + i / (n_bins - 1) * (max - min)``. This is the value
    ``decode_gripper_bins`` maps it back to. Out-of-range values are clamped.

    Args:
        gripper_values: tensor of gripper values, any shape.
        min_gripper, max_gripper: range bounds; default to ``model.MIN_GRIPPER`` /
            ``model.MAX_GRIPPER`` (read at call time).
        n_bins: number of bins; defaults to ``N_GRIPPER_BINS``.

    Returns:
        Long tensor of the same shape with indices in ``[0, n_bins - 1]``.
    """
    min_g = model_module.MIN_GRIPPER if min_gripper is None else min_gripper
    max_g = model_module.MAX_GRIPPER if max_gripper is None else max_gripper
    n = N_GRIPPER_BINS if n_bins is None else n_bins
    normalized = ((gripper_values - min_g) / (max_g - min_g + 1e-8)).clamp(0.0, 1.0)
    # Round (not floor) so labels match the bin values the decoder reconstructs.
    return (normalized * (n - 1)).round().long().clamp(0, n - 1)


def decode_gripper_bins(bin_logits):
    """Convert (B, N_WINDOW, N_GRIPPER_BINS) logits into (B, N_WINDOW) gripper values.

    The argmax bin ``i`` maps linearly to
    ``MIN_GRIPPER + i / (N_GRIPPER_BINS - 1) * (MAX_GRIPPER - MIN_GRIPPER)``.
    """
    chosen = bin_logits.argmax(dim=-1)
    grid = torch.linspace(0.0, 1.0, N_GRIPPER_BINS, device=bin_logits.device)
    g_lo, g_hi = model_module.MIN_GRIPPER, model_module.MAX_GRIPPER
    return grid[chosen] * (g_hi - g_lo) + g_lo


def discretize_rotation(euler_values, *, min_rot=None, max_rot=None, n_bins=None):
    """Quantize (B, N_WINDOW, 3) euler angles into per-axis bin indices (nearest bin).

    Each axis has its own [min, max] range. Bin ``i`` is represented by
    ``min + i / (n_bins - 1) * (max - min)``, which is what ``decode_rotation_bins``
    reconstructs. Out-of-range angles are clamped.

    Args:
        euler_values: (..., 3) tensor of euler angles.
        min_rot, max_rot: per-axis bounds (length-3 sequences or tensors). They default to
            ``model.MIN_ROT`` / ``model.MAX_ROT``, read at call time.
        n_bins: number of bins per axis; defaults to ``N_ROT_BINS``.

    Returns:
        Long tensor of the same shape with indices in ``[0, n_bins - 1]``.
    """
    dev = euler_values.device
    min_r = torch.as_tensor(model_module.MIN_ROT if min_rot is None else min_rot,
                            device=dev, dtype=torch.float32)
    max_r = torch.as_tensor(model_module.MAX_ROT if max_rot is None else max_rot,
                            device=dev, dtype=torch.float32)
    n = N_ROT_BINS if n_bins is None else n_bins
    normalized = ((euler_values - min_r) / (max_r - min_r + 1e-8)).clamp(0.0, 1.0)
    # Round (not floor) so labels match the bin values the decoder reconstructs.
    return (normalized * (n - 1)).round().long().clamp(0, n - 1)


def decode_rotation_bins(rot_logits):
    """Turn (B, N_WINDOW, 3, N_ROT_BINS) logits into (B, N_WINDOW, 3) euler angles.

    For each axis the argmax bin ``i`` maps linearly into that axis's
    [MIN_ROT, MAX_ROT] range as ``i / (N_ROT_BINS - 1)``.
    """
    dev = rot_logits.device
    lo = torch.tensor(model_module.MIN_ROT, device=dev, dtype=torch.float32)
    hi = torch.tensor(model_module.MAX_ROT, device=dev, dtype=torch.float32)
    grid = torch.linspace(0.0, 1.0, N_ROT_BINS, device=dev)
    fractions = grid[rot_logits.argmax(dim=-1)]
    return fractions * (hi - lo) + lo


def compute_rotation_loss(pred_rotation_logits, target_euler):
    """Mean of three per-axis cross-entropies between rotation-bin logits and GT euler angles.

    The axes are averaged rather than summed, so the magnitude stays comparable to a single
    classification loss (~log(N_ROT_BINS) at chance), like the gripper loss.

    Args:
        pred_rotation_logits: (B, N_WINDOW, 3, N_ROT_BINS) logits.
        target_euler:         (B, N_WINDOW, 3) euler angles in radians, binned with
                              ``discretize_rotation``.

    Returns:
        Scalar loss tensor.
    """
    gt_bins = discretize_rotation(target_euler)
    batch, steps, _, n_rot = pred_rotation_logits.shape
    rows = batch * steps
    axis_losses = [
        F.cross_entropy(
            pred_rotation_logits[:, :, k, :].reshape(rows, n_rot),
            gt_bins[:, :, k].reshape(rows),
        )
        for k in range(3)
    ]
    return torch.stack(axis_losses).mean()


# Training configuration
BATCH_SIZE = 8
LEARNING_RATE = 1e-4
NUM_EPOCHS = 1000
IMAGE_SIZE = 448  # 28x28 DINO patches at patch_size=16; GT 2D coords arrive in this pixel space

# Loss weights: each multiplies the matching per-term loss in train_epoch / validate.
VOLUME_LOSS_WEIGHT    = 1.0
GRIPPER_LOSS_WEIGHT   = 5.0
ROTATION_LOSS_WEIGHT  = 0.5
POS_LOSS_WEIGHT       = 1.0  # NOTE(review): not referenced in this part of the file — confirm where (or whether) it is applied
ACT_GRIPPER_LOSS_WEIGHT = 1.0  # scales the gripper BCE inside compute_act_loss


def build_volume_3d_points_for_vis(H, W, camera_pose, cam_K, height_bucket_centers, pixel_step=32):
    """Lift a sparse pixel grid, at every candidate height, into 3D world points.

    Pixels are sampled every ``pixel_step`` along rows (outer) and columns (inner). Each one is
    combined with every entry of ``height_bucket_centers`` through
    ``recover_3d_from_direct_keypoint_and_height``. Combinations for which that helper
    returns None are dropped.

    Returns:
        (N, 3) numpy array, or an empty (0, 3) array when nothing was recovered.
    """
    candidates = (
        recover_3d_from_direct_keypoint_and_height(
            np.array([col, row], dtype=np.float64), float(h), camera_pose, cam_K
        )
        for row in range(0, H, pixel_step)
        for col in range(0, W, pixel_step)
        for h in height_bucket_centers
    )
    recovered = [p for p in candidates if p is not None]
    if not recovered:
        return np.zeros((0, 3))
    return np.array(recovered)


def compute_volume_loss(pred_volume_logits, trajectory_2d, target_height_bins):
    """Cross-entropy over all 3D cells of the volume, averaged over batch and timesteps.

    For every (sample, timestep), the (N_HEIGHT_BINS, H, W) logits are flattened into one
    distribution over cells. That distribution is supervised with the GT cell index
    ``h_bin * (H * W) + y * W + x``. Every timestep contributes the same number of samples,
    so a single batched cross-entropy over all B * N_WINDOW rows equals the mean of the
    per-timestep means.

    Args:
        pred_volume_logits: (B, N_WINDOW, N_HEIGHT_BINS, H, W) logits.
        trajectory_2d: (B, N_WINDOW, 2) GT pixel coords [x, y] in the volume's resolution.
            They are truncated to integer indices and clamped into the grid.
        target_height_bins: (B, N_WINDOW) GT bin indices, clamped to [0, N_HEIGHT_BINS - 1].

    Returns:
        Scalar loss tensor.
    """
    B, N, Nh, H, W = pred_volume_logits.shape
    px = trajectory_2d[:, :, 0].long().clamp(0, W - 1)  # (B, N)
    py = trajectory_2d[:, :, 1].long().clamp(0, H - 1)  # (B, N)
    h_bin = target_height_bins.long().clamp(0, Nh - 1)  # (B, N)
    # Row b*N + t of the flattened logits pairs with element (b, t) of the target grid.
    target_idx = (h_bin * (H * W) + py * W + px).reshape(B * N)
    logits_flat = pred_volume_logits.reshape(B * N, Nh * H * W)
    return F.cross_entropy(logits_flat, target_idx, reduction='mean')


def extract_pred_2d_and_height_from_volume(volume_logits):
    """Decode one predicted pixel and height for every (sample, timestep) of a volume.

    The pixel is the argmax over (y, x) of the per-pixel maximum across height bins. The
    height bin is then the argmax of the logits at that pixel, mapped linearly into
    [model.MIN_HEIGHT, model.MAX_HEIGHT] with bin i -> i / (N_HEIGHT_BINS - 1).

    Args:
        volume_logits: (B, N_WINDOW, N_HEIGHT_BINS, H, W) logits.

    Returns:
        pred_2d: (B, N_WINDOW, 2) float32 pixel coords [x, y] in the volume's resolution.
        pred_height: (B, N_WINDOW) decoded heights.
    """
    B, N, Nh, H, W = volume_logits.shape
    device = volume_logits.device

    per_pixel_best = volume_logits.max(dim=2).values                # (B, N, H, W)
    flat_idx = per_pixel_best.reshape(B, N, H * W).argmax(dim=-1)   # (B, N)
    rows = flat_idx // W
    cols = flat_idx % W
    pred_2d = torch.stack((cols, rows), dim=-1).to(torch.float32)   # (B, N, 2) as [x, y]

    # Advanced indices separated by a slice -> result is (B, N, Nh): one height column per pick.
    b_idx = torch.arange(B, device=device).unsqueeze(1).expand(B, N)
    t_idx = torch.arange(N, device=device).unsqueeze(0).expand(B, N)
    column_logits = volume_logits[b_idx, t_idx, :, rows, cols]
    pred_height_bins = column_logits.argmax(dim=-1)                 # (B, N)

    lo = model_module.MIN_HEIGHT
    hi = model_module.MAX_HEIGHT
    grid = torch.linspace(0.0, 1.0, N_HEIGHT_BINS, device=device)
    pred_height = grid[pred_height_bins] * (hi - lo) + lo
    return pred_2d, pred_height


def compute_gripper_loss(pred_gripper_logits, target_gripper):
    """Binary cross-entropy (with logits) for the open/close gripper state.

    IMPORTANT: the gripper MLP outputs (B, N_WINDOW) raw logits — ONE logit per step, NOT bins.
    Do not switch this back to a multi-class cross_entropy.

    Args:
        pred_gripper_logits: (B, N_WINDOW) pre-sigmoid logits.
        target_gripper:      (B, N_WINDOW) values in [-1, 1]; strictly positive means closed (label 1),
                             anything else means open (label 0).

    Returns:
        Scalar mean BCE loss.
    """
    closed_labels = target_gripper.gt(0).float()
    return F.binary_cross_entropy_with_logits(pred_gripper_logits, closed_labels)


def visualize_sample(rgb, target_heatmap, pred_heatmap, target_2d):
    """Convert one sample's tensors into numpy arrays for plotting.

    ``target_heatmap`` is accepted for call compatibility but not used.

    Returns:
        rgb_vis: (H, W, 3) ImageNet-denormalized image clipped to [0, 1].
        pred_heat: predicted heatmap as numpy.
        target_pt: ``target_2d`` as numpy.
        pred_pt: [x, y] integer location of the heatmap argmax.
    """
    dev = rgb.device
    imagenet_mean = torch.tensor([0.485, 0.456, 0.406], device=dev).view(3, 1, 1)
    imagenet_std = torch.tensor([0.229, 0.224, 0.225], device=dev).view(3, 1, 1)
    chw = (rgb * imagenet_std + imagenet_mean).cpu().numpy()
    rgb_vis = np.clip(np.transpose(chw, (1, 2, 0)), 0, 1)

    heat_np = pred_heatmap.cpu().numpy()
    target_np = target_2d.cpu().numpy()
    peak_row, peak_col = np.unravel_index(np.argmax(heat_np), heat_np.shape)
    return rgb_vis, heat_np, target_np, np.array([peak_col, peak_row])


def normalize_to_01(values, min_vals, max_vals):
    """Affinely map ``values`` from [min_vals, max_vals] onto [0, 1] and clamp the result.

    Bounds broadcast per axis. A 1e-8 is added to the span to avoid division by zero.
    """
    span = max_vals - min_vals + 1e-8
    scaled = (values - min_vals) / span
    return scaled.clamp(0.0, 1.0)

def denormalize_from_01(values, min_vals, max_vals):
    """Inverse of ``normalize_to_01`` (without clamping): map [0, 1] back onto [min_vals, max_vals]."""
    span = max_vals - min_vals
    return values * span + min_vals

def compute_act_loss(pos_pred, rot_pred, gripper_pred, trajectory_3d, trajectory_euler, trajectory_gripper):
    """Losses for direct-regression (ACT-style) heads.

    Position and euler targets are normalized to [0, 1] using the MIN_/MAX_ bounds from the
    ``model`` module and compared to the predictions with MSE. The gripper target is
    binarized (strictly positive -> closed = 1, else open = 0) and scored with
    BCE-with-logits, scaled by ACT_GRIPPER_LOSS_WEIGHT.

    Returns:
        (pos_loss, rot_loss, gripper_loss) scalar tensors.
    """
    dev = pos_pred.device
    pos_lo, pos_hi, rot_lo, rot_hi = (
        torch.tensor(bound, device=dev, dtype=torch.float32)
        for bound in (model_module.MIN_POS, model_module.MAX_POS,
                      model_module.MIN_ROT, model_module.MAX_ROT)
    )

    pos_loss = F.mse_loss(pos_pred, normalize_to_01(trajectory_3d, pos_lo, pos_hi))
    rot_loss = F.mse_loss(rot_pred, normalize_to_01(trajectory_euler, rot_lo, rot_hi))

    closed = (trajectory_gripper > 0).float()
    gripper_loss = ACT_GRIPPER_LOSS_WEIGHT * F.binary_cross_entropy_with_logits(gripper_pred, closed)
    return pos_loss, rot_loss, gripper_loss


def train_epoch(
    model,
    dataloader,
    optimizer,
    device,
    just_heatmap=False,
    global_step_start=0,
    vis_every_steps=50,
    vis_callback=None,
    log_scalars_every=10,
    model_type="para",
    save_every_steps=0,
    save_callback=None,
):
    """Train ``model`` for one pass over ``dataloader``.

    Heatmap models (see ``is_heatmap_model``) use three losses:
    - the volume cross-entropy;
    - gripper BCE and rotation CE, read out at the ground-truth pixels (teacher forcing).
    ACT-style models use ``compute_act_loss``. Its position MSE is reported in the
    "volume" slot.

    Args:
        model: predictor returned by ``get_model_class(model_type)``.
        dataloader: yields dicts with 'rgb', 'trajectory_3d', 'trajectory_2d',
            'trajectory_gripper', 'trajectory_euler' and, optionally, 'clip_embedding' or
            'task_description'.
        optimizer: optimizer over the model's parameters.
        device: device the batch tensors are moved to.
        just_heatmap: if True, only the volume loss is back-propagated for heatmap models.
            Gripper and rotation losses are still computed and logged. Ignored for ACT models.
        global_step_start: step counter value before this epoch.
        vis_every_steps, vis_callback: call ``vis_callback(global_step)`` every N steps.
        log_scalars_every: log per-step scalars to wandb every N steps (0 disables).
        model_type: key understood by ``get_model_class``.
        save_every_steps, save_callback: call ``save_callback(global_step)`` every N steps.

    Returns:
        (mean_loss, mean_volume_loss, mean_gripper_loss, mean_rotation_loss, global_step).
        Means are over batches and are 0.0 when the dataloader yields nothing.
    """
    model.train()
    total_loss = 0.0
    total_volume_loss = 0.0
    total_gripper_loss = 0.0
    total_rotation_loss = 0.0
    n_batches = 0

    heatmap_mode = is_heatmap_model(model_type)

    global_step = int(global_step_start)
    pbar = tqdm(dataloader, desc="Train", leave=False)
    for batch in pbar:
        rgb = batch['rgb'].to(device)
        trajectory_3d = batch['trajectory_3d'].to(device)  # (B, N_WINDOW, 3)
        trajectory_2d = batch['trajectory_2d'].to(device)  # (B, N_WINDOW, 2), IMAGE_SIZE pixel space
        trajectory_gripper = batch['trajectory_gripper'].to(device)  # (B, N_WINDOW)
        trajectory_euler = batch['trajectory_euler'].to(device)  # (B, N_WINDOW, 3)
        start_keypoint_2d = trajectory_2d[:, 0]  # (B, 2)

        # Optional conditioning inputs for the models that accept them.
        extra_kwargs = {}
        if model_type in ("dino_vla", "act") and 'clip_embedding' in batch:
            extra_kwargs['clip_embedding'] = batch['clip_embedding'].to(device)
        elif model_type in ("internvl", "internvl_act") and 'task_description' in batch:
            extra_kwargs['task_text'] = batch['task_description']  # list of strings

        if heatmap_mode:
            target_height_bins = discretize_height(trajectory_3d[:, :, 2])
            # GT coords are scaled IMAGE_SIZE -> PRED_SIZE before the forward pass because the
            # model needs them as query pixels. NOTE: assumes the volume head outputs PRED_SIZE.
            trajectory_2d_pred = trajectory_2d * (PRED_SIZE / IMAGE_SIZE)

            # The volume head is spatial. Gripper/rotation heads are read at the GT pixels.
            volume_logits, gripper_logits, rotation_logits, _feats = model(
                rgb, start_keypoint_2d, query_pixels=trajectory_2d_pred, **extra_kwargs
            )

            volume_loss   = VOLUME_LOSS_WEIGHT * compute_volume_loss(volume_logits, trajectory_2d_pred, target_height_bins)
            gripper_loss  = GRIPPER_LOSS_WEIGHT * compute_gripper_loss(gripper_logits, trajectory_gripper)
            rotation_loss = ROTATION_LOSS_WEIGHT * compute_rotation_loss(rotation_logits, trajectory_euler)

            loss = volume_loss if just_heatmap else volume_loss + gripper_loss + rotation_loss
        else:
            # ACT: direct regression conditioned on proprioception normalized to [0, 1].
            batch_device = rgb.device
            min_pos = torch.tensor(model_module.MIN_POS, device=batch_device, dtype=torch.float32)
            max_pos = torch.tensor(model_module.MAX_POS, device=batch_device, dtype=torch.float32)
            min_grip = torch.tensor(model_module.MIN_GRIPPER, device=batch_device, dtype=torch.float32)
            max_grip = torch.tensor(model_module.MAX_GRIPPER, device=batch_device, dtype=torch.float32)
            current_eef_norm = normalize_to_01(trajectory_3d[:, 0], min_pos, max_pos)
            current_grip_norm = normalize_to_01(trajectory_gripper[:, 0], min_grip, max_grip)

            pos_pred, rot_pred, gripper_pred = model(
                rgb, start_keypoint_2d,
                current_eef_pos=current_eef_norm,
                current_gripper=current_grip_norm,
                **extra_kwargs,
            )
            # compute_act_loss returns (pos, rot, gripper); position fills the "volume" slot.
            volume_loss, rotation_loss, gripper_loss = compute_act_loss(
                pos_pred, rot_pred, gripper_pred,
                trajectory_3d, trajectory_euler, trajectory_gripper,
            )
            loss = volume_loss + gripper_loss + rotation_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Read each scalar once; every .item() forces a device->host sync.
        loss_val = loss.item()
        vol_val = volume_loss.item()
        grip_val = gripper_loss.item()
        rot_val = rotation_loss.item()

        total_loss += loss_val
        total_volume_loss += vol_val
        total_gripper_loss += grip_val
        total_rotation_loss += rot_val
        n_batches += 1
        pbar.set_postfix(loss=f"{loss_val:.4f}", vol=f"{vol_val:.4f}",
                         grip=f"{grip_val:.4f}", rot=f"{rot_val:.4f}")
        global_step += 1
        if log_scalars_every > 0 and (global_step % log_scalars_every == 0):
            wandb.log({
                "train_step/loss":          loss_val,
                "train_step/volume_loss":   vol_val,
                "train_step/gripper_loss":  grip_val,
                "train_step/rotation_loss": rot_val,
            }, step=global_step)
        if vis_callback is not None and vis_every_steps > 0 and (global_step % vis_every_steps == 0):
            vis_callback(global_step)
        if save_callback is not None and save_every_steps > 0 and (global_step % save_every_steps == 0):
            save_callback(global_step)

    # An empty dataloader yields zeros instead of ZeroDivisionError (same guard as validate).
    denom = max(1, n_batches)
    return (
        total_loss / denom,
        total_volume_loss / denom,
        total_gripper_loss / denom,
        total_rotation_loss / denom,
        global_step,
    )


def validate(model, dataloader, device, image_size=IMAGE_SIZE, model_type="para"):
    """Evaluate ``model`` on ``dataloader`` without gradient updates.

    Heatmap models:
    - Losses use teacher forcing: gripper and rotation are read at the GT pixels.
    - Metrics decode the volume (argmax pixel plus height bin) and read the gripper at the
      *predicted* pixel.
    ACT-style models: losses come from ``compute_act_loss``. The "pixel error" slot holds the
    mean 3D position error in millimetres.

    Args:
        model: predictor from ``get_model_class(model_type)``.
        dataloader: yields batches with the training keys plus 'camera_pose' and 'cam_K_norm'.
            For heatmap models, the first batch must also provide 'heatmap_target',
            'trajectory_quat', 'rgb_frames_raw', 'world_to_camera' and 'base_z' for the
            visualization sample.
        device: device the batch tensors are moved to.
        image_size: pixel size of the input image and of the GT 2D coordinates.
        model_type: key understood by ``get_model_class``.

    Returns:
        Per-sample means, in order:
        - loss, volume_loss;
        - 0.0 (unused slot);
        - gripper_loss;
        - per-timestep error: pixels for heatmap models, mm for ACT models;
        - height_error;
        - 0.0 (teacher-forced height error, not computed);
        - gripper_error;
        - sample_data: the visualization dict for the first sample of the first batch
          (heatmap models only), otherwise None.
    """
    model.eval()
    total_loss = 0.0
    total_volume_loss = 0.0
    total_gripper_loss = 0.0
    total_pixel_error = 0.0      # summed over samples AND timesteps
    total_height_error = 0.0
    total_height_error_tf = 0.0  # placeholder: teacher-forced height error is not computed
    total_gripper_error = 0.0
    n_samples = 0
    sample_data = None

    heatmap_mode = is_heatmap_model(model_type)

    with torch.no_grad():
        pbar = tqdm(dataloader, desc="Val", leave=False)
        for batch_idx, batch in enumerate(pbar):
            rgb = batch['rgb'].to(device)
            trajectory_2d = batch['trajectory_2d'].to(device)  # (B, N_WINDOW, 2), image_size pixels
            trajectory_3d = batch['trajectory_3d'].to(device)  # (B, N_WINDOW, 3)
            trajectory_gripper = batch['trajectory_gripper'].to(device)  # (B, N_WINDOW)
            camera_pose = batch['camera_pose']  # (B, 4, 4)
            cam_K_norm = batch['cam_K_norm']  # (B, 3, 3), intrinsics divided by image size
            trajectory_euler = batch['trajectory_euler'].to(device)  # (B, N_WINDOW, 3)
            start_keypoint_2d = trajectory_2d[:, 0]  # (B, 2)
            target_height = trajectory_3d[:, :, 2]  # (B, N_WINDOW)
            batch_size = rgb.shape[0]

            extra_kwargs = {}
            if model_type in ("dino_vla", "act") and 'clip_embedding' in batch:
                extra_kwargs['clip_embedding'] = batch['clip_embedding'].to(device)
            elif model_type in ("internvl", "internvl_act") and 'task_description' in batch:
                extra_kwargs['task_text'] = batch['task_description']

            if not heatmap_mode:
                # ACT: direct regression validation with normalized proprioception.
                min_pos_t = torch.tensor(model_module.MIN_POS, device=device, dtype=torch.float32)
                max_pos_t = torch.tensor(model_module.MAX_POS, device=device, dtype=torch.float32)
                min_grip_t = torch.tensor(model_module.MIN_GRIPPER, device=device, dtype=torch.float32)
                max_grip_t = torch.tensor(model_module.MAX_GRIPPER, device=device, dtype=torch.float32)
                current_eef_norm = normalize_to_01(trajectory_3d[:, 0], min_pos_t, max_pos_t)
                current_grip_norm = normalize_to_01(trajectory_gripper[:, 0], min_grip_t, max_grip_t)

                pos_pred, rot_pred, gripper_pred = model(
                    rgb, start_keypoint_2d,
                    current_eef_pos=current_eef_norm,
                    current_gripper=current_grip_norm,
                    **extra_kwargs,
                )
                vol_loss, rot_loss, grip_loss = compute_act_loss(
                    pos_pred, rot_pred, gripper_pred,
                    trajectory_3d, trajectory_euler, trajectory_gripper,
                )
                loss = vol_loss + grip_loss + rot_loss
                loss_val = loss.item()
                total_loss += loss_val * batch_size
                total_volume_loss += vol_loss.item() * batch_size
                total_gripper_loss += grip_loss.item() * batch_size

                # Position error per (sample, timestep), in metres. It is summed over BOTH axes
                # so the final division by (n_samples * N_WINDOW) gives a per-timestep mean,
                # the same convention as the heatmap pixel error below.
                pos_pred_denorm = denormalize_from_01(pos_pred, min_pos_t, max_pos_t)
                pos_err_m = (pos_pred_denorm - trajectory_3d).norm(dim=-1)  # (B, N_WINDOW)
                total_pixel_error += pos_err_m.sum().item() * 1000

                # Binary gripper: logit > 0 -> close (+1), else open (-1).
                grip_pred_binary = torch.where(gripper_pred > 0,
                                               torch.ones_like(gripper_pred),
                                               -torch.ones_like(gripper_pred))
                total_gripper_error += torch.abs(grip_pred_binary - trajectory_gripper).mean(dim=1).sum().item()
                n_samples += batch_size
                pbar.set_postfix(loss=f"{loss_val:.4f}", pos_mm=f"{pos_err_m.mean().item() * 1000:.1f}")
                continue

            target_height_bins = discretize_height(target_height)

            # GT coords are scaled image_size -> PRED_SIZE for supervision and query pixels.
            coord_scale = PRED_SIZE / image_size
            trajectory_2d_pred = trajectory_2d * coord_scale

            # The volume head is spatial. Gripper/rotation are read at the GT pixels (teacher forcing).
            volume_logits, gripper_logits, rotation_logits, feats = model(
                rgb, start_keypoint_2d, query_pixels=trajectory_2d_pred, **extra_kwargs
            )
            volume_loss   = VOLUME_LOSS_WEIGHT * compute_volume_loss(volume_logits, trajectory_2d_pred, target_height_bins)
            gripper_loss  = GRIPPER_LOSS_WEIGHT * compute_gripper_loss(gripper_logits, trajectory_gripper)
            rotation_loss = ROTATION_LOSS_WEIGHT * compute_rotation_loss(rotation_logits, trajectory_euler)
            loss = volume_loss + gripper_loss + rotation_loss
            loss_val = loss.item()

            total_loss += loss_val * batch_size
            total_volume_loss += volume_loss.item() * batch_size
            total_gripper_loss += gripper_loss.item() * batch_size

            # pred_2d is in PRED_SIZE space; scale back to image_size for pixel error and 3D recovery.
            pred_2d, pred_height = extract_pred_2d_and_height_from_volume(volume_logits)
            pred_2d_full = pred_2d / coord_scale
            # Gripper at the predicted pixel (inference mode) for the error metric.
            pred_gripper_logits, _ = model.predict_at_pixels(feats, pred_2d)
            pred_gripper = torch.where(torch.sigmoid(pred_gripper_logits) > 0.5,
                                       torch.ones_like(pred_gripper_logits),
                                       -torch.ones_like(pred_gripper_logits))

            pixel_err = torch.norm(pred_2d_full - trajectory_2d, dim=-1)  # (B, N_WINDOW)
            total_pixel_error += pixel_err.sum().item()
            total_height_error += torch.abs(pred_height - target_height).mean(dim=1).sum().item()
            total_gripper_error += torch.abs(pred_gripper - trajectory_gripper).mean(dim=1).sum().item()
            n_samples += batch_size
            pbar.set_postfix(loss=f"{loss_val:.4f}", px=f"{pixel_err[:, 0].mean().item():.2f}")

            if batch_idx == 0 and sample_data is None:
                # Per-timestep heatmap: softmax over the whole (Nh, H, W) volume, max over height,
                # upsampled to image_size for the overlay.
                pred_heatmaps = []
                for t in range(N_WINDOW):
                    vol_t = volume_logits[0, t].contiguous()  # (Nh, pred_size, pred_size)
                    vol_probs = F.softmax(vol_t.reshape(-1), dim=0).reshape(vol_t.shape)
                    heatmap_t = vol_probs.max(dim=0)[0]  # (pred_size, pred_size)
                    heatmap_up = F.interpolate(heatmap_t.unsqueeze(0).unsqueeze(0), size=(image_size, image_size), mode='bilinear', align_corners=False)[0, 0]
                    pred_heatmaps.append(heatmap_up)
                pred_heatmaps = torch.stack(pred_heatmaps)  # (N_WINDOW, image_size, image_size)

                pred_h_0 = pred_height[0]
                pred_g_0 = pred_gripper[0]
                if pred_h_0.dim() == 0:
                    pred_h_0 = pred_h_0.unsqueeze(0).expand(N_WINDOW)
                if pred_g_0.dim() == 0:
                    pred_g_0 = pred_g_0.unsqueeze(0).expand(N_WINDOW)

                cam_pose_np = camera_pose[0].cpu().numpy()
                cam_K_norm_np = cam_K_norm[0].cpu().numpy()
                cam_K_np = cam_K_norm_np.copy()
                cam_K_np[0] *= image_size
                cam_K_np[1] *= image_size
                pred_trajectory_3d_list = []
                for t in range(N_WINDOW):
                    px, py = pred_2d_full[0, t, 0].item(), pred_2d_full[0, t, 1].item()
                    h = pred_height[0, t].item()
                    pt = recover_3d_from_direct_keypoint_and_height(
                        np.array([px, py], dtype=np.float64), h, cam_pose_np, cam_K_np
                    )
                    # Fall back to the GT point when the ray cannot be lifted at that height.
                    if pt is not None:
                        pred_trajectory_3d_list.append(pt)
                    else:
                        pred_trajectory_3d_list.append(trajectory_3d[0, t].cpu().numpy())
                pred_trajectory_3d_np = np.array(pred_trajectory_3d_list)

                sample_data = {
                    'rgb': rgb[0],
                    'target_heatmap': batch['heatmap_target'][0].to(device),
                    'pred_heatmap': pred_heatmaps,
                    'trajectory_2d': trajectory_2d[0],
                    'trajectory_3d': trajectory_3d[0],
                    'trajectory_quat': batch['trajectory_quat'][0],
                    'rgb_frames_raw': batch['rgb_frames_raw'][0],
                    'world_to_camera': batch['world_to_camera'][0].cpu().numpy(),
                    'base_z': batch['base_z'][0],
                    'pred_trajectory_3d': pred_trajectory_3d_np,
                    'camera_pose': cam_pose_np,
                    'cam_K_norm': cam_K_norm_np,
                    'cam_K_at_size': cam_K_np,
                    'pred_height': pred_h_0,
                    'target_height': target_height[0],
                    'pred_gripper': pred_g_0,
                    'target_gripper': trajectory_gripper[0],
                }

    n = max(1, n_samples)
    avg_pixel_error = total_pixel_error / (n * N_WINDOW)
    return (
        total_loss / n, total_volume_loss / n, 0.0, total_gripper_loss / n,
        avg_pixel_error, total_height_error / n, total_height_error_tf / n, total_gripper_error / n,
        sample_data
    )


def _proj_world_to_vis(points_3d, world_to_camera, H, W):
    """Project world points onto the training image (the flipud of the raw observation).

    This matches debug_libero_projection.py. ``project_points_from_world_to_camera`` yields
    (row, col) pairs that can be drawn directly on flipud(obs_img), with no extra row flip.

    Args:
        points_3d: a single (3,) point or an (N, 3) array of world points.
        world_to_camera: (4, 4) world-to-pixel transform.
        H, W: image height and width in pixels.

    Returns:
        List of integer (u, v) = (col, row) tuples, ready for cv2 drawing.
    """
    pts = np.asarray(points_3d, dtype=np.float64)
    if pts.ndim == 1:
        pts = pts.reshape(1, 3)
    rows_cols = project_points_from_world_to_camera(
        points=pts,
        world_to_camera_transform=world_to_camera,
        camera_height=H,
        camera_width=W,
    )
    uv = []
    for rc in rows_cols:
        col = int(round(float(rc[1])))
        row = int(round(float(rc[0])))
        uv.append((col, row))
    return uv


def build_wandb_timestep_strip(sample, split_name):
    """Render one annotated tile per timestep and join them left-to-right into a wandb.Image.

    The style matches debug_libero_projection.py. Each tile is that timestep's raw RGB frame with:
      - the predicted heatmap blended into the red channel;
      - a green cross and "pred" label at the heatmap argmax;
      - a white dot and "eef" label at the projected GT end-effector;
      - a cyan ring at the GT position dropped to ``base_z``, linked to the EEF by a yellow line;
      - the GT end-effector axes (x red, y green, z blue), each 0.08 long.
    The base-plane marker and the axes are drawn only when the EEF projects inside the tile.

    Returns:
        A ``wandb.Image`` of the strip, or None when ``sample`` is None.
    """
    if sample is None:
        return None

    w2c = sample['world_to_camera']               # (4, 4) numpy
    eef_positions = sample['trajectory_3d']       # (N, 3) tensor
    eef_quats = sample['trajectory_quat']         # (N, 4) tensor
    gt_pixels = sample['trajectory_2d']           # (N, 2) tensor [x, y]; required key, not drawn
    frames = sample['rgb_frames_raw']             # (N, H, W, 3) tensor float [0,1]
    base_z = float(sample['base_z'])

    font = cv2.FONT_HERSHEY_SIMPLEX
    aa = cv2.LINE_AA

    def overlay_gt(canvas, t, H, W):
        """Draw the GT EEF, its base-plane drop and its axes onto ``canvas`` in place."""
        eef = eef_positions[t].cpu().numpy().astype(np.float64)
        u, v = _proj_world_to_vis(eef, w2c, H, W)[0]
        if not (0 <= u < W and 0 <= v < H):
            return
        cv2.circle(canvas, (u, v), 6, (255, 255, 255), -1)
        cv2.putText(canvas, "eef", (u + 8, v - 8), font, 0.35, (255, 255, 255), 1, aa)

        # Same XY position, dropped to the robot base height.
        on_base = eef.copy()
        on_base[2] = base_z
        ug, vg = _proj_world_to_vis(on_base, w2c, H, W)[0]
        if 0 <= ug < W and 0 <= vg < H:
            cv2.circle(canvas, (ug, vg), 6, (0, 255, 255), 2)
            cv2.putText(canvas, f"base z={base_z:.3f}", (ug + 8, vg + 12), font, 0.35, (0, 255, 255), 1, aa)
            cv2.line(canvas, (u, v), (ug, vg), (255, 255, 0), 2, aa)

        # Orientation: x=red, y=green, z=blue (the canvas is RGB).
        rot = T_rob.quat2mat(eef_quats[t].cpu().numpy().astype(np.float64))  # (3, 3)
        for col_idx, color in enumerate(((255, 0, 0), (0, 255, 0), (0, 0, 255))):
            tip = eef + rot[:, col_idx] * 0.08
            tip_u, tip_v = _proj_world_to_vis(tip, w2c, H, W)[0]
            if 0 <= tip_u < W and 0 <= tip_v < H:
                cv2.line(canvas, (u, v), (tip_u, tip_v), color, 2, aa)
                cv2.circle(canvas, (tip_u, tip_v), 3, color, -1)

    tiles = []
    for t in range(N_WINDOW):
        frame = frames[t].cpu().numpy()  # (H, W, 3) float [0,1]
        H, W = frame.shape[:2]

        # Heatmap min-max normalized and blended into the red channel.
        heat_raw = sample['pred_heatmap'][t].detach().cpu().numpy()
        heat = heat_raw - heat_raw.min()
        if heat.max() > 1e-8:
            heat = heat / heat.max()
        red_layer = np.zeros_like(frame)
        red_layer[..., 0] = heat
        canvas = (np.clip(frame * 0.55 + red_layer * 0.45, 0, 1) * 255.0).astype(np.uint8)

        peak_row, peak_col = np.unravel_index(heat_raw.argmax(), heat_raw.shape)
        px, py = int(peak_col), int(peak_row)
        if 0 <= px < W and 0 <= py < H:
            cv2.drawMarker(canvas, (px, py), (0, 255, 0), cv2.MARKER_CROSS, 14, 2, aa)
            cv2.putText(canvas, "pred", (px + 8, py - 8), font, 0.35, (0, 255, 0), 1, aa)

        overlay_gt(canvas, t, H, W)

        # A thick light stroke under a thin dark one keeps the label readable on any background.
        stamp = f"t={t}"
        for color, thickness in (((255, 255, 255), 2), ((20, 20, 20), 1)):
            cv2.putText(canvas, stamp, (8, 16), font, 0.45, color, thickness, aa)
        tiles.append(canvas)

    strip = np.concatenate(tiles, axis=1)
    return wandb.Image(strip, caption=f"{split_name}: timesteps 0..{N_WINDOW-1} (left->right)")


def build_sample_data_for_logging(model, batch, device, image_size=IMAGE_SIZE, model_type="para"):
    """Run the model on the first sample of ``batch`` and collect visualization data.

    The model is put in eval mode for the forward pass. Its previous train/eval
    mode is restored on exit, even if an exception is raised. This matters
    because the function is called periodically from inside the training loop.

    Args:
        model: predictor exposing ``forward(rgb, start_keypoint_2d, **kw)`` that
            returns ``(volume_logits, _, _, feats)``, plus ``predict_at_pixels``.
        batch: a collated dataloader batch. Only index 0 is used.
        device: torch device to run on.
        image_size: full-resolution image size used to rescale normalized
            intrinsics and predicted pixel coordinates.
        model_type: selects extra conditioning inputs
            (``clip_embedding`` for dino_vla, ``task_text`` for internvl*).

    Returns:
        dict of per-sample tensors/arrays. ``pred_heatmap`` is
        ``(N_WINDOW, image_size, image_size)``. For each timestep it holds the
        max over height bins of a softmax taken jointly over the whole
        (height x H x W) volume. ``pred_trajectory_3d`` falls back to the
        ground-truth 3D point for timesteps where 3D recovery returns None.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            rgb = batch['rgb'][0:1].to(device)
            trajectory_2d = batch['trajectory_2d'][0:1].to(device)
            trajectory_3d = batch['trajectory_3d'][0:1].to(device)
            trajectory_gripper = batch['trajectory_gripper'][0:1].to(device)
            camera_pose = batch['camera_pose'][0].cpu().numpy()
            cam_K_norm = batch['cam_K_norm'][0].cpu().numpy()
            # Intrinsics are stored normalized; scale rows 0/1 (fx, cx / fy, cy) to pixels.
            cam_K = cam_K_norm.copy()
            cam_K[0] *= image_size
            cam_K[1] *= image_size
            start_keypoint_2d = trajectory_2d[:, 0]

            extra_kwargs = {}
            if model_type == "dino_vla" and 'clip_embedding' in batch:
                extra_kwargs['clip_embedding'] = batch['clip_embedding'][0:1].to(device)
            elif model_type in ("internvl", "internvl_act") and 'task_description' in batch:
                extra_kwargs['task_text'] = [batch['task_description'][0]]

            volume_logits, _, _, feats = model(rgb, start_keypoint_2d, **extra_kwargs)
            pred_2d, pred_height = extract_pred_2d_and_height_from_volume(volume_logits)
            pred_size = volume_logits.shape[-1]
            coord_scale = pred_size / image_size
            pred_2d_full = pred_2d / coord_scale  # scale back to IMAGE_SIZE space for 3D recovery

            # Gripper is decoded at the *predicted* pixel (pred_size coordinates).
            pred_gripper_logits, _ = model.predict_at_pixels(feats, pred_2d)
            pred_gripper = torch.where(
                torch.sigmoid(pred_gripper_logits) > 0.5,
                torch.ones_like(pred_gripper_logits),
                -torch.ones_like(pred_gripper_logits),
            )

            # Heatmaps: softmax jointly over (Nh, pred_size, pred_size) per timestep,
            # max over height bins, then upsample to image_size for overlay.
            vol = volume_logits[0, :N_WINDOW]  # (T, Nh, pred_size, pred_size)
            vol_probs = F.softmax(vol.reshape(vol.shape[0], -1), dim=1).reshape(vol.shape)
            heat = vol_probs.max(dim=1)[0]  # (T, pred_size, pred_size)
            pred_heatmaps = F.interpolate(
                heat.unsqueeze(1), size=(image_size, image_size),
                mode='bilinear', align_corners=False,
            )[:, 0]  # (T, image_size, image_size)

            pred_trajectory_3d_list = []
            for t in range(N_WINDOW):
                px, py = pred_2d_full[0, t, 0].item(), pred_2d_full[0, t, 1].item()
                h = pred_height[0, t].item()
                pt = recover_3d_from_direct_keypoint_and_height(
                    np.array([px, py], dtype=np.float64), h, camera_pose, cam_K)
                # NOTE: falls back to GT when recovery fails, so the plotted
                # "prediction" can silently coincide with the target there.
                pred_trajectory_3d_list.append(pt if pt is not None else trajectory_3d[0, t].cpu().numpy())
            pred_trajectory_3d_np = np.array(pred_trajectory_3d_list)

            return {
                'rgb': rgb[0],
                'target_heatmap': batch['heatmap_target'][0].to(device),
                'pred_heatmap': pred_heatmaps,
                'trajectory_2d': trajectory_2d[0],
                'trajectory_3d': trajectory_3d[0],
                'trajectory_quat': batch['trajectory_quat'][0],
                'rgb_frames_raw': batch['rgb_frames_raw'][0],
                'world_to_camera': batch['world_to_camera'][0].cpu().numpy(),
                'base_z': batch['base_z'][0],
                'pred_trajectory_3d': pred_trajectory_3d_np,
                'camera_pose': camera_pose,
                'cam_K_at_size': cam_K,
                'pred_height': pred_height[0],
                'target_height': trajectory_3d[0, :, 2],
                'pred_gripper': pred_gripper[0],
                'target_gripper': trajectory_gripper[0],
            }
    finally:
        # Restore the caller's mode: this runs from the training loop's
        # vis_callback, and leaving the model in eval mode would disable
        # dropout / change normalization behaviour for later train steps.
        model.train(was_training)


def main():
    """Train a trajectory predictor on LIBERO demonstrations from the command line.

    Steps:
      1. Parse CLI arguments and publish the loss weights as module-level
         globals (presumably read by the training/loss code elsewhere in this file).
      2. Build the dataset (pre-rendered cache, multi-task, or single task) and
         split it into train/val. With ``--overfit_one_sample`` a single sample
         is repeated instead.
      3. Build the model. Optionally load ``dino.*`` backbone weights from a
         point-track pretraining checkpoint, then optionally resume from a full
         training checkpoint (model, optimizer, epoch, global step).
      4. Resolve the normalization ranges (height, gripper, rotation, position)
         with precedence checkpoint > stats-cache JSON > freshly computed
         dataset stats. Write them into ``model_module`` so model code sees
         them at runtime.
      5. Train for the remaining epochs, logging to W&B. Save ``latest.pth``
         every epoch and ``best.pth`` whenever validation loss improves. Every
         ``--save_every_steps`` steps, also save ``step_<N>.pth`` and overwrite
         ``latest.pth``.
    """
    parser = argparse.ArgumentParser(description="Train PARA trajectory heatmap predictor on LIBERO")
    parser.add_argument("--model_type", type=str, default="para",
                        choices=["para", "act", "da3", "moge", "dino_vla", "internvl", "internvl_act"],
                        help="Model architecture to train")
    parser.add_argument("--model_name", type=str, default="OpenGVLab/InternVL2_5-1B",
                        help="HuggingFace model name (used by internvl model_type)")
    parser.add_argument("--benchmark", type=str, default="libero_spatial",
                        help="LIBERO benchmark name")
    parser.add_argument("--task_id", type=int, default=0,
                        help="Task index within benchmark")
    parser.add_argument("--task_ids", type=str, default="",
                        help="Comma-separated task indices to train on all at once (e.g. '0,1,2' or 'all'). Overrides --task_id.")
    parser.add_argument("--camera", type=str, default="agentview",
                        help="Camera name used for training observations")
    parser.add_argument("--max_demos", type=int, default=10,
                        help="Maximum demos to load from task dataset")
    parser.add_argument("--val_split", type=float, default=0.05,
                        help="Fraction of episodes to use for validation")
    parser.add_argument("--batch_size", type=int, default=BATCH_SIZE,
                        help="Batch size for training")
    parser.add_argument("--lr", type=float, default=LEARNING_RATE,
                        help="Learning rate")
    parser.add_argument("--epochs", type=int, default=NUM_EPOCHS,
                        help="Number of epochs")
    parser.add_argument("--checkpoint", type=str, default="",
                        help="Path to checkpoint to resume from")
    parser.add_argument("--run_name", type=str, default="para_libero",
                        help="Name of run (used for checkpoint paths and W&B)")
    parser.add_argument("--wandb_project", type=str, default="para_libero",
                        help="W&B project name")
    parser.add_argument("--wandb_entity", type=str, default=None,
                        help="W&B entity/team (optional)")
    parser.add_argument("--wandb_mode", type=str, default="online", choices=["online", "offline", "disabled"],
                        help="W&B mode")
    parser.add_argument("--stats_cache_path", type=str, default="",
                        help="Path to JSON cache for height/gripper stats")
    parser.add_argument("--stats_sample_limit", type=int, default=500,
                        help="Random number of samples to use for stats computation (0 = full dataset)")
    parser.add_argument("--stats_seed", type=int, default=42,
                        help="Random seed for stats subsampling")
    parser.add_argument("--vis_every_steps", type=int, default=50,
                        help="Log visualization images every N train steps")
    parser.add_argument("--frame_stride", type=int, default=3,
                        help="Sample every Nth frame from the demo (default 3 → ~6.7Hz @ 20Hz, N_WINDOW=6 spans ~0.9s)")
    parser.add_argument("--cache_root", type=str, default="",
                        help="Path to pre-rendered dataset (e.g. /data/libero/parsed_libero). Uses CachedTrajectoryDataset when set.")
    parser.add_argument("--pos_loss_weight", type=float, default=1.0,
                        help="Weight for position loss (ACT: normalizes pos MSE to match rot/grip)")
    parser.add_argument("--volume_loss_weight", type=float, default=1.0,
                        help="Weight for volume/heatmap CE loss (PARA: scale relative to grip/rot CE)")
    parser.add_argument("--gripper_loss_weight", type=float, default=5.0,
                        help="Weight for gripper CE loss")
    parser.add_argument("--rotation_loss_weight", type=float, default=0.5,
                        help="Weight for rotation CE loss")
    parser.add_argument("--act_gripper_loss_weight", type=float, default=1.0,
                        help="Weight for ACT gripper BCE loss")
    parser.add_argument("--save_every_steps", type=int, default=1000,
                        help="Save checkpoint every N training steps (0 = only save per-epoch)")
    parser.add_argument("--overfit_one_sample", action="store_true",
                        help="Overfit to a single dataset sample (sanity check)")
    parser.add_argument("--pretrained_backbone", type=str, default="",
                        help="Path to point-track pretrained checkpoint (e.g. point_track_pretraining/checkpoints/.../best.pth). "
                             "Loads only the DINO backbone weights (dino.* keys) from the checkpoint, "
                             "initializing all other heads randomly. Use this for the pretrain→finetune pipeline.")
    args = parser.parse_args()

    # Loss weights live at module level so the training helpers can read them.
    global POS_LOSS_WEIGHT, VOLUME_LOSS_WEIGHT, GRIPPER_LOSS_WEIGHT, ROTATION_LOSS_WEIGHT, ACT_GRIPPER_LOSS_WEIGHT
    POS_LOSS_WEIGHT = args.pos_loss_weight
    VOLUME_LOSS_WEIGHT = args.volume_loss_weight
    GRIPPER_LOSS_WEIGHT = args.gripper_loss_weight
    ROTATION_LOSS_WEIGHT = args.rotation_loss_weight
    ACT_GRIPPER_LOSS_WEIGHT = args.act_gripper_loss_weight

    # Checkpoints live next to this script (libero/checkpoints/<run_name>/).
    script_dir = Path(__file__).parent
    CHECKPOINT_DIR = script_dir / "checkpoints" / args.run_name
    CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
    stats_cache_path = Path(args.stats_cache_path) if args.stats_cache_path else (CHECKPOINT_DIR / "dataset_stats.json")

    # Device: CUDA on server, MPS on Mac, else CPU.
    device = torch.device(
        "cuda" if torch.cuda.is_available() else
        "mps" if torch.backends.mps.is_available() else
        "cpu"
    )
    print(f"Using device: {device}")

    wandb.init(
        project=args.wandb_project,
        entity=args.wandb_entity,
        name=args.run_name,
        mode=args.wandb_mode,
        config={
            "benchmark": args.benchmark,
            "task_id": args.task_id,
            "task_ids": args.task_ids,
            "camera": args.camera,
            "max_demos": args.max_demos,
            "val_split": args.val_split,
            "batch_size": args.batch_size,
            "lr": args.lr,
            "epochs": args.epochs,
            "image_size": IMAGE_SIZE,
            "n_window": N_WINDOW,
            "n_height_bins": N_HEIGHT_BINS,
            "n_gripper_bins": N_GRIPPER_BINS,
            "model_type": args.model_type,
            "frame_stride": args.frame_stride,
            "cache_root": args.cache_root,
            "pos_loss_weight": args.pos_loss_weight,
            "volume_loss_weight": args.volume_loss_weight,
            "gripper_loss_weight": args.gripper_loss_weight,
            "rotation_loss_weight": args.rotation_loss_weight,
            "act_gripper_loss_weight": args.act_gripper_loss_weight,
        },
    )

    print("\nLoading dataset...")
    if args.cache_root:
        # Use pre-rendered cached dataset (fast, supports num_workers>0)
        if args.task_ids:
            if args.task_ids.strip().lower() == "all":
                task_id_list = None  # CachedTrajectoryDataset loads all by default
            else:
                task_id_list = [int(x) for x in args.task_ids.split(",")]
        else:
            task_id_list = [args.task_id]
        full_dataset = CachedTrajectoryDataset(
            cache_root=args.cache_root,
            benchmark_name=args.benchmark,
            task_ids=task_id_list,
            image_size=IMAGE_SIZE,
            n_window=N_WINDOW,
            frame_stride=args.frame_stride,
            max_demos=args.max_demos,
        )
        dataset_source = {"cache_root": args.cache_root, "benchmark": args.benchmark, "task_ids": task_id_list, "frame_stride": args.frame_stride}
    elif args.task_ids:
        from libero.libero import benchmark as _bm
        n_tasks = _bm.get_benchmark_dict()[args.benchmark]().get_num_tasks()
        if args.task_ids.strip().lower() == "all":
            task_id_list = list(range(n_tasks))
        else:
            task_id_list = [int(x) for x in args.task_ids.split(",")]
        print(f"  Multi-task mode: tasks {task_id_list}")
        datasets = [
            RealTrajectoryDataset(
                image_size=IMAGE_SIZE,
                benchmark_name=args.benchmark,
                task_id=tid,
                camera=args.camera,
                max_demos=args.max_demos,
                frame_stride=args.frame_stride,
            )
            for tid in task_id_list
        ]
        full_dataset = ConcatDataset(datasets)
        dataset_source = {"benchmark": args.benchmark, "task_ids": task_id_list, "camera": args.camera, "max_demos": args.max_demos}
    else:
        full_dataset = RealTrajectoryDataset(
            image_size=IMAGE_SIZE,
            benchmark_name=args.benchmark,
            task_id=args.task_id,
            camera=args.camera,
            max_demos=args.max_demos,
            frame_stride=args.frame_stride,
        )
        dataset_source = {
            "benchmark": args.benchmark,
            "task_id": args.task_id,
            "camera": args.camera,
            "max_demos": args.max_demos,
        }
    print(f"  Source: {dataset_source}")
    print(f"  Total: {len(full_dataset)} samples")

    if args.overfit_one_sample:
        # Grab one sample and repeat it — useful for verifying the supervision signal
        print("\n⚠ OVERFIT MODE: using a single repeated sample for train and val")
        one_sample = full_dataset[0]

        class _RepeatDataset(torch.utils.data.Dataset):
            """Dataset that returns the same pre-loaded sample ``n`` times."""

            def __init__(self, sample, n=2000):
                self.sample = sample
                self.n = n

            def __len__(self):
                return self.n

            def __getitem__(self, _):
                return self.sample

        train_dataset = _RepeatDataset(one_sample, n=2000)
        val_dataset = _RepeatDataset(one_sample, n=10)
    else:
        dataset_size = len(full_dataset)
        val_size = max(1, int(dataset_size * args.val_split))
        train_size = dataset_size - val_size
        # Fixed seed so the split is identical across resumes.
        train_dataset, val_dataset = torch.utils.data.random_split(
            full_dataset, [train_size, val_size],
            generator=torch.Generator().manual_seed(42)
        )

    print(f"✓ Train: {len(train_dataset)} samples")
    print(f"✓ Val: {len(val_dataset)} samples")

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=16, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=16, pin_memory=True)

    print(f"\nInitializing model (type={args.model_type})...")
    ModelClass = get_model_class(args.model_type)
    model_kwargs = dict(target_size=IMAGE_SIZE, n_window=N_WINDOW, freeze_backbone=False)
    if args.model_type in ("internvl", "internvl_act"):
        model_kwargs["model_name"] = args.model_name
    model = ModelClass(**model_kwargs)
    model = model.to(device)

    # Load pretrained backbone from point-track pretraining (before checkpoint resume)
    if args.pretrained_backbone and os.path.exists(args.pretrained_backbone):
        print(f"\nLoading pretrained backbone from: {args.pretrained_backbone}")
        pt_ckpt = torch.load(args.pretrained_backbone, map_location=device)
        pt_state = pt_ckpt['model_state_dict']
        model_dict = model.state_dict()
        # Transfer only dino.* keys (the shared backbone)
        loaded = []
        skipped = []
        for k, v in pt_state.items():
            if not k.startswith("dino."):
                continue
            if k in model_dict and v.shape == model_dict[k].shape:
                model_dict[k] = v
                loaded.append(k)
            else:
                skipped.append(k)
        model.load_state_dict(model_dict, strict=False)
        print(f"  Loaded {len(loaded)} backbone keys from pretrained checkpoint")
        if skipped:
            print(f"  Skipped {len(skipped)} keys (shape mismatch or not in model): {skipped[:5]}")
        pt_epoch = pt_ckpt.get('epoch', '?')
        pt_val_loss = pt_ckpt.get('val_loss', '?')
        print(f"  Pretrained checkpoint: epoch={pt_epoch}, val_loss={pt_val_loss}")
    elif args.pretrained_backbone:
        print(f"\nWARNING: pretrained backbone not found: {args.pretrained_backbone}")

    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    n_total = sum(p.numel() for p in model.parameters())
    print(f"Trainable parameters: {n_trainable:,} / {n_total:,} ({100*n_trainable/n_total:.2f}%)")

    optimizer = optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=args.lr,
        weight_decay=1e-4
    )

    start_epoch = 0
    resume_global_step = 0
    checkpoint = None
    checkpoint_height_values = None
    checkpoint_gripper_values = None
    checkpoint_path = args.checkpoint
    if args.checkpoint and not os.path.exists(checkpoint_path):
        # Tolerate a .pt / .pth extension mix-up.
        alt = checkpoint_path.rsplit(".", 1)
        if len(alt) == 2:
            other_ext = ".pth" if alt[1].lower() == "pt" else ".pt"
            alt_path = alt[0] + other_ext
            if os.path.exists(alt_path):
                checkpoint_path = alt_path
                print(f"Checkpoint not found at {args.checkpoint}, using {checkpoint_path}")
    if args.checkpoint and os.path.exists(checkpoint_path):
        print(f"\nLoading checkpoint: {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location=device)

        # Load only tensors whose name and shape match the current model.
        model_state = checkpoint['model_state_dict']
        model_dict = model.state_dict()
        filtered_state = {}
        shape_mismatches = []
        for k, v in model_state.items():
            if k in model_dict:
                if v.shape == model_dict[k].shape:
                    filtered_state[k] = v
                else:
                    shape_mismatches.append(f"{k}: checkpoint {v.shape} vs model {model_dict[k].shape}")

        missing_keys = set(model_dict.keys()) - set(model_state.keys())
        unexpected_keys = set(model_state.keys()) - set(model_dict.keys())

        if missing_keys:
            print(f"⚠ Missing keys (random init): {sorted(missing_keys)}")
        if unexpected_keys:
            print(f"⚠ Unexpected keys (ignored): {sorted(unexpected_keys)}")
        if shape_mismatches:
            print("⚠ Shape mismatches (random init):")
            for msg in shape_mismatches:
                print(f"    {msg}")

        model_dict.update(filtered_state)
        model.load_state_dict(model_dict, strict=False)

        if shape_mismatches:
            print("⚠ Skipping optimizer state (model shape changed)")
        else:
            try:
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            except Exception as e:
                print(f"⚠ Could not load optimizer state: {e}")

        if 'epoch' in checkpoint:
            # Epoch checkpoints record the last *completed* epoch.
            start_epoch = checkpoint['epoch'] + 1
        elif 'epoch_in_progress' in checkpoint:
            # Mid-epoch step checkpoints: replay the interrupted epoch.
            start_epoch = checkpoint['epoch_in_progress']
        else:
            # Legacy step checkpoints carry no epoch info; keep the old default.
            start_epoch = 1
        resume_global_step = int(checkpoint.get('global_step', 0))

        if 'min_height' in checkpoint and 'max_height' in checkpoint:
            checkpoint_height_values = (checkpoint['min_height'], checkpoint['max_height'])
            print(f"✓ Height range from checkpoint: [{checkpoint_height_values[0]:.6f}, {checkpoint_height_values[1]:.6f}] m")
        if 'min_gripper' in checkpoint and 'max_gripper' in checkpoint:
            checkpoint_gripper_values = (checkpoint['min_gripper'], checkpoint['max_gripper'])
            print(f"✓ Gripper range from checkpoint: [{checkpoint_gripper_values[0]:.6f}, {checkpoint_gripper_values[1]:.6f}]")
        print(f"✓ Resumed from epoch {start_epoch}")
    elif args.checkpoint:
        print(f"\n⚠ Checkpoint not found: {args.checkpoint}")

    stats_cache = None
    if stats_cache_path.exists():
        try:
            with open(stats_cache_path, "r", encoding="utf-8") as f:
                stats_cache = json.load(f)
            print(f"✓ Loaded stats cache: {stats_cache_path}")
        except Exception as e:
            print(f"⚠ Failed to read stats cache: {e}")
            stats_cache = None

    def compute_stats_from_dataset():
        """Compute min/max normalization stats over train+val samples.

        Global indices ``[0, len(train_dataset))`` address ``train_dataset`` and
        the rest address ``val_dataset``. If ``--stats_sample_limit`` is positive
        and smaller than the dataset, a seeded random subset of distinct indices
        is drawn, with at most 20 draws per requested sample. Otherwise every
        index is visited exactly once. Samples that raise while loading are
        skipped and counted.
        """
        n_train = len(train_dataset)
        total_len = n_train + len(val_dataset)
        rng = None
        if args.stats_sample_limit and args.stats_sample_limit > 0:
            sample_count = min(args.stats_sample_limit, total_len)
            rng = random.Random(args.stats_seed)
            print(f"\nComputing dataset stats from random subset: {sample_count}/{total_len} samples (seed={args.stats_seed})")
        else:
            sample_count = total_len
            print(f"\nComputing dataset stats from full dataset: {total_len} samples")

        def candidate_indices():
            """Yield global sample indices to try, each at most once."""
            if sample_count == total_len:
                # Full pass: one visit per index, so a single unreadable sample
                # cannot stall the sweep on the same index.
                yield from range(total_len)
                return
            seen = set()
            for _ in range(sample_count * 20):
                idx = rng.randrange(total_len)
                if idx not in seen:
                    seen.add(idx)
                    yield idx

        all_heights = []
        all_grippers = []
        all_eulers = []     # list of (N, 3) arrays
        all_positions = []  # list of (N, 3) arrays for per-axis pos min/max
        success_count = 0
        failed_count = 0
        pbar = tqdm(total=sample_count, desc="Stats subset", leave=False)
        for global_idx in candidate_indices():
            if success_count >= sample_count:
                break
            if global_idx < n_train:
                dataset, local_idx = train_dataset, global_idx
            else:
                dataset, local_idx = val_dataset, global_idx - n_train
            try:
                sample = dataset[local_idx]
            except Exception:
                # Best-effort stats: skip unreadable samples, but count them.
                failed_count += 1
                continue
            trajectory_3d = sample['trajectory_3d'].numpy()
            trajectory_gripper = sample['trajectory_gripper'].numpy()
            trajectory_euler = sample['trajectory_euler'].numpy()  # (N, 3)
            all_heights.extend(trajectory_3d[:, 2].tolist())
            all_grippers.extend(trajectory_gripper.tolist())
            all_eulers.append(trajectory_euler)
            all_positions.append(trajectory_3d)
            success_count += 1
            pbar.update(1)
        pbar.close()
        if failed_count:
            print(f"⚠ Skipped {failed_count} samples that failed to load during stats computation")

        if len(all_heights) == 0 or len(all_grippers) == 0:
            print("⚠ No valid samples found for stats; falling back to model defaults.")
            return {
                "min_height": float(model_module.MIN_HEIGHT),
                "max_height": float(model_module.MAX_HEIGHT),
                "min_gripper": float(model_module.MIN_GRIPPER),
                "max_gripper": float(model_module.MAX_GRIPPER),
                "min_rot": model_module.MIN_ROT,
                "max_rot": model_module.MAX_ROT,
                "num_height_values": 0,
                "num_gripper_values": 0,
            }

        all_heights_np = np.array(all_heights, dtype=np.float64)
        all_grippers_np = np.array(all_grippers, dtype=np.float64)
        all_eulers_np = np.concatenate(all_eulers, axis=0)        # (total_N, 3)
        all_positions_np = np.concatenate(all_positions, axis=0)  # (total_N, 3)
        return {
            "min_height":  float(all_heights_np.min()),
            "max_height":  float(all_heights_np.max()),
            "min_gripper": float(all_grippers_np.min()),
            "max_gripper": float(all_grippers_np.max()),
            "min_rot":     all_eulers_np.min(axis=0).tolist(),     # [min_x, min_y, min_z]
            "max_rot":     all_eulers_np.max(axis=0).tolist(),     # [max_x, max_y, max_z]
            "min_pos":     all_positions_np.min(axis=0).tolist(),  # [min_x, min_y, min_z]
            "max_pos":     all_positions_np.max(axis=0).tolist(),  # [max_x, max_y, max_z]
            "num_height_values":  int(all_heights_np.size),
            "num_gripper_values": int(all_grippers_np.size),
            "dataset_source": dataset_source,
            "n_window":   int(N_WINDOW),
            "image_size": int(IMAGE_SIZE),
        }

    def ensure_stats_cache():
        """Return the stats dict, computing and persisting it to JSON on first use."""
        nonlocal stats_cache
        if stats_cache is None:
            stats_cache = compute_stats_from_dataset()
            stats_cache_path.parent.mkdir(parents=True, exist_ok=True)
            with open(stats_cache_path, "w", encoding="utf-8") as f:
                json.dump(stats_cache, f, indent=2)
            print(f"✓ Saved stats cache: {stats_cache_path}")
        return stats_cache

    # Resolve height range from checkpoint > cache > dataset
    if checkpoint_height_values is not None:
        model_module.MIN_HEIGHT, model_module.MAX_HEIGHT = checkpoint_height_values
        print(f"✓ Using height range from checkpoint: [{model_module.MIN_HEIGHT:.6f}, {model_module.MAX_HEIGHT:.6f}] m")
    else:
        stats = ensure_stats_cache()
        model_module.MIN_HEIGHT = float(stats["min_height"])
        model_module.MAX_HEIGHT = float(stats["max_height"])
        print(f"✓ Height range from dataset: [{model_module.MIN_HEIGHT:.6f}, {model_module.MAX_HEIGHT:.6f}] m")
        if abs(model_module.MIN_HEIGHT - model_module.MAX_HEIGHT) < 1e-6:
            print("  ⚠ WARNING: MIN_HEIGHT == MAX_HEIGHT — all height predictions will be constant!")

    # Resolve gripper range from checkpoint > cache > dataset
    if checkpoint_gripper_values is not None:
        model_module.MIN_GRIPPER, model_module.MAX_GRIPPER = checkpoint_gripper_values
        print(f"✓ Using gripper range from checkpoint: [{model_module.MIN_GRIPPER:.6f}, {model_module.MAX_GRIPPER:.6f}]")
    else:
        stats = ensure_stats_cache()
        model_module.MIN_GRIPPER = float(stats["min_gripper"])
        model_module.MAX_GRIPPER = float(stats["max_gripper"])
        print(f"✓ Gripper range from dataset: [{model_module.MIN_GRIPPER:.6f}, {model_module.MAX_GRIPPER:.6f}]")

    # Resolve rotation range from checkpoint > cache > dataset
    if checkpoint is not None and 'min_rot' in checkpoint and 'max_rot' in checkpoint:
        model_module.MIN_ROT, model_module.MAX_ROT = checkpoint['min_rot'], checkpoint['max_rot']
        print(f"✓ Rotation range from checkpoint: {model_module.MIN_ROT} .. {model_module.MAX_ROT}")
    else:
        stats = ensure_stats_cache()
        if "min_rot" in stats:
            model_module.MIN_ROT = stats["min_rot"]
            model_module.MAX_ROT = stats["max_rot"]
            print(f"✓ Rotation range from dataset: {[f'{v:.3f}' for v in model_module.MIN_ROT]} .. {[f'{v:.3f}' for v in model_module.MAX_ROT]}")

    # Resolve position range from checkpoint > cache > dataset
    if checkpoint is not None and 'min_pos' in checkpoint and 'max_pos' in checkpoint:
        model_module.MIN_POS, model_module.MAX_POS = checkpoint['min_pos'], checkpoint['max_pos']
        print(f"✓ Position range from checkpoint: {model_module.MIN_POS} .. {model_module.MAX_POS}")
    else:
        stats = ensure_stats_cache()
        if "min_pos" in stats:
            model_module.MIN_POS = stats["min_pos"]
            model_module.MAX_POS = stats["max_pos"]
            print(f"✓ Position range from dataset: {[f'{v:.3f}' for v in model_module.MIN_POS]} .. {[f'{v:.3f}' for v in model_module.MAX_POS]}")

    print(f"\nStarting training for {args.epochs} epochs...")
    # NOTE(review): best_val_loss is not restored on resume, so the first
    # resumed epoch always overwrites best.pth.
    best_val_loss = float('inf')
    global_step = resume_global_step
    current_epoch = start_epoch  # read by save_step_checkpoint (closure)

    def log_visualizations(step):
        """Log train/val timestep strips to W&B (heatmap-style models only)."""
        if not is_heatmap_model(args.model_type):
            return  # skip heatmap-specific visualization for ACT
        train_batch = next(iter(train_loader))
        val_batch = next(iter(val_loader))
        sample_train = build_sample_data_for_logging(model, train_batch, device, image_size=IMAGE_SIZE, model_type=args.model_type)
        sample_val_local = build_sample_data_for_logging(model, val_batch, device, image_size=IMAGE_SIZE, model_type=args.model_type)
        payload = {}
        train_strip = build_wandb_timestep_strip(sample_train, "train")
        val_strip = build_wandb_timestep_strip(sample_val_local, "val")
        if train_strip is not None:
            payload["vis/train_strip"] = train_strip
        if val_strip is not None:
            payload["vis/val_strip"] = val_strip
        if payload:
            wandb.log(payload, step=step)

    if args.vis_every_steps > 0:
        log_visualizations(global_step)

    def save_step_checkpoint(step):
        """Save a mid-epoch checkpoint at a given global step."""
        ckpt = {
            'global_step': step,
            # Not 'epoch': that key means "last completed epoch" on resume.
            'epoch_in_progress': current_epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'min_height':  model_module.MIN_HEIGHT,
            'max_height':  model_module.MAX_HEIGHT,
            'min_gripper': model_module.MIN_GRIPPER,
            'max_gripper': model_module.MAX_GRIPPER,
            'min_rot':     model_module.MIN_ROT,
            'max_rot':     model_module.MAX_ROT,
            'min_pos':     model_module.MIN_POS,
            'max_pos':     model_module.MAX_POS,
        }
        torch.save(ckpt, CHECKPOINT_DIR / f'step_{step}.pth')
        # Also overwrite latest so eval scripts can find it
        torch.save(ckpt, CHECKPOINT_DIR / 'latest.pth')
        print(f"\n  Saved step checkpoint: step_{step}.pth")

    for epoch in tqdm(range(start_epoch, args.epochs), desc="Epochs"):
        current_epoch = epoch
        print(f"\n{'='*60}")
        print(f"Epoch {epoch}/{args.epochs}")
        print(f"{'='*60}")

        train_loss, train_volume_loss, train_gripper_loss, train_rotation_loss, global_step = train_epoch(
            model, train_loader, optimizer, device,
            just_heatmap=False,  # always supervise gripper + rotation from the start
            global_step_start=global_step,
            vis_every_steps=args.vis_every_steps,
            vis_callback=log_visualizations,
            model_type=args.model_type,
            save_every_steps=args.save_every_steps,
            save_callback=save_step_checkpoint,
        )
        print(f"Train Loss: {train_loss:.4f} (Volume: {train_volume_loss:.4f}, Gripper: {train_gripper_loss:.4f}, Rotation: {train_rotation_loss:.4f})")

        val_loss, val_heatmap_loss, val_height_loss, val_gripper_loss, \
        val_error, val_height_error, val_height_error_tf, val_gripper_error, sample_val = validate(
            model, val_loader, device, model_type=args.model_type,
        )
        print(f"Val - Loss: {val_loss:.4f}, Volume: {val_heatmap_loss:.4f}, Pixel Error: {val_error:.2f}px, Height Error: {val_height_error*1000:.3f}mm, Gripper: {val_gripper_error:.4f}")

        wandb.log({
            "epoch": epoch,
            "train/loss": train_loss,
            "train/volume_loss": train_volume_loss,
            "train/gripper_loss": train_gripper_loss,
            "train/rotation_loss": train_rotation_loss,
            "val/loss": val_loss,
            "val/volume_loss": val_heatmap_loss,
            "val/gripper_loss": val_gripper_loss,
            "val/pixel_error": val_error,
            "val/height_error_mm": val_height_error * 1000.0,
            "val/gripper_abs_error": val_gripper_error,
        }, step=global_step)

        checkpoint_data = {
            'epoch': epoch,
            'global_step': global_step,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': train_loss,
            'val_loss': val_loss,
            'min_height':  model_module.MIN_HEIGHT,
            'max_height':  model_module.MAX_HEIGHT,
            'min_gripper': model_module.MIN_GRIPPER,
            'max_gripper': model_module.MAX_GRIPPER,
            'min_rot':     model_module.MIN_ROT,
            'max_rot':     model_module.MAX_ROT,
            'min_pos':     model_module.MIN_POS,
            'max_pos':     model_module.MAX_POS,
        }

        torch.save(checkpoint_data, CHECKPOINT_DIR / 'latest.pth')

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(checkpoint_data, CHECKPOINT_DIR / 'best.pth')
            print(f"✓ Saved best model (val_loss={val_loss:.4f})")

    wandb.finish()
    print("\n" + "=" * 60)
    print("✓ Training complete!")
    print(f"Best val loss: {best_val_loss:.4f}")
    print(f"Checkpoints saved to: {CHECKPOINT_DIR}")


if __name__ == "__main__":
    # Script entry point: run training only when executed directly, not on import.
    main()
