"""Train PARA on smith300 arm data.

Adapted from panda_streaming/train_panda_para.py. Same loss heads (volume,
gripper, rotation) and the same wandb EEF heatmap visualizations, pointed at
the smith300 data loader. The gripper head is trained as a multi-bin CE over
the continuous gripper joint (q_motors[:, 6]) when a dataset records it; when
no dataset does, the stats fall back to a placeholder [-1, 1] range so the
head stays wired up and can be supervised once gripper data exists, without
modifying the model.

Usage:
  cd /data/cameron/para/para_mac
  CUDA_VISIBLE_DEVICES=9 MUJOCO_GL=egl \\
  DINO_REPO_DIR=/data/cameron/keygrip/dinov3 \\
  DINO_WEIGHTS_PATH=/data/cameron/.cache/torch/hub/checkpoints/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth \\
  python train_smith300_para.py \\
    --data_dir /data/cameron/mac_robot_datasets/dataset_20260501_180125 \\
    --run_name smith300_para_v0 \\
    --epochs 500 --batch_size 4
"""
import sys, os
sys.path.insert(0, os.path.dirname(__file__))

import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
import numpy as np
from pathlib import Path
import argparse
import wandb
import json
import cv2

from data_smith300_para import (
    Smith300TrajectoryDataset,
    N_WINDOW,
    project_to_pixel,
    DEFAULT_SMITH300_XML,
    EEF_BODY_NAME,
)

# Use the smith300-local model.py (sibling file) -- it differs from
# /data/cameron/para/libero/model.py in the gripper head, which we changed
# from a single-logit BCE to a multi-bin CE so smith300's continuous gripper q
# in radians can be supervised across the recorded range.
# The module is loaded by explicit file path from this script's directory and
# bound only to the local name `model_module` (it is not inserted into
# sys.modules here). main() later overwrites its MIN_*/MAX_* range globals with
# the per-run dataset stats, and the discretize/decode helpers below read them
# back from this module object.
import importlib.util
_local_dir = os.path.dirname(__file__)
_spec = importlib.util.spec_from_file_location(
    "model", os.path.join(_local_dir, "model.py"))
model_module = importlib.util.module_from_spec(_spec)
_spec.loader.exec_module(model_module)
# Re-export the model class and bin-count / grid-size constants used below.
TrajectoryHeatmapPredictor = model_module.TrajectoryHeatmapPredictor
N_HEIGHT_BINS = model_module.N_HEIGHT_BINS
N_GRIPPER_BINS = model_module.N_GRIPPER_BINS
N_ROT_BINS = model_module.N_ROT_BINS
PRED_SIZE = model_module.PRED_SIZE   # side of the model's prediction grid (query pixels clamp to it)

# Pixel resolution trajectory_2d is expressed in; build_vis_sample rescales by
# PRED_SIZE / IMAGE_SIZE for query pixels and upsamples heatmaps back to it.
IMAGE_SIZE = 448
VOLUME_LOSS_WEIGHT = 1.0
GRIPPER_LOSS_WEIGHT = 5.0   # matches panda; data loader passes 0 for datasets
                            #  without a recorded gripper. NOTE(review): with the
                            #  multi-bin CE head a 0 target still supervises the
                            #  bin that 0 rad falls into (see discretize_gripper)
                            #  rather than being ignored -- confirm intended.
ROTATION_LOSS_WEIGHT = 0.5
VIS_EVERY_EPOCHS = 10       # default for the legacy --vis_every flag


# ── Loss helpers (adapted from panda) ─────────────────────────────────────

def discretize_height(height_values, min_h=None, max_h=None, n_bins=None):
    """Map continuous EEF heights to height-bin indices.

    Bin ``i`` stands for the normalized height ``i / (n_bins - 1)`` -- the
    ``torch.linspace(0, 1, n_bins)`` centers that extract_pred_2d_and_height
    decodes with. Values are therefore rounded to the nearest center. Flooring
    (the previous behavior) put every target on the bin whose center is at or
    below the value, biasing decoded heights low by up to one bin and making
    the top bin reachable only at exactly ``max_h``.

    Args:
        height_values: tensor of heights (any shape).
        min_h, max_h: height range; default to model_module.MIN_HEIGHT /
            model_module.MAX_HEIGHT (set from dataset stats in main()).
        n_bins: number of height bins; defaults to N_HEIGHT_BINS.

    Returns:
        Long tensor of the same shape with indices in ``[0, n_bins - 1]``.
    """
    if min_h is None:
        min_h = model_module.MIN_HEIGHT
    if max_h is None:
        max_h = model_module.MAX_HEIGHT
    if n_bins is None:
        n_bins = N_HEIGHT_BINS
    normalized = ((height_values - min_h) / (max_h - min_h + 1e-8)).clamp(0, 1)
    return torch.round(normalized * (n_bins - 1)).long().clamp(0, n_bins - 1)


def compute_volume_loss(pred_volume_logits, trajectory_2d, target_height_bins,
                        valid_mask=None):
    """Cross-entropy of the target voxel over each timestep's flattened volume.

    Every (sample, timestep) volume of shape (Nh, H, W) is treated as one
    Nh*H*W-way classification whose target is the voxel at
    (height bin, y, x), with coordinates clamped into the grid.

    Args:
        pred_volume_logits: (B, N, Nh, H, W) logits.
        trajectory_2d: (B, N, 2) target (x, y) pixel coordinates in the H x W grid.
        target_height_bins: (B, N) height-bin indices.
        valid_mask: optional (B, N) mask. Timesteps with an off-frame or
            behind-camera EEF (the loader writes pixel=(0,0) for those) are
            excluded, so they don't train a top-left attractor in the heatmap.

    Returns:
        Scalar: the mean over timesteps of each timestep's (masked) batch mean.
    """
    B, N, Nh, H, W = pred_volume_logits.shape
    col = trajectory_2d[..., 0].long().clamp(0, W - 1)
    row = trajectory_2d[..., 1].long().clamp(0, H - 1)
    height = target_height_bins.clamp(0, Nh - 1)
    flat_target = (height * (H * W) + row * W + col).long()  # (B, N)

    ce = F.cross_entropy(pred_volume_logits.reshape(B * N, -1),
                         flat_target.reshape(B * N),
                         reduction='none').reshape(B, N)

    if valid_mask is None:
        per_step = ce.mean(dim=0)
    else:
        weights = valid_mask.float()
        # Per-timestep masked mean; a fully-masked timestep contributes 0.
        per_step = (ce * weights).sum(dim=0) / weights.sum(dim=0).clamp(min=1.0)
    return per_step.mean()


def discretize_gripper(values, min_g=None, max_g=None, n_bins=None):
    """Map continuous gripper q (rad) to gripper-bin indices.

    Bin ``i`` stands for the normalized value ``i / (n_bins - 1)``, matching
    the ``linspace(0, 1, n_bins)`` bin-center convention used by the decoders
    in this file. Values are rounded to the nearest center instead of
    floored; flooring biased every target down by up to one bin and left the
    top bin unreachable inside the padded stats range.

    Args:
        values: tensor of gripper positions (any shape).
        min_g, max_g: range; default to model_module.MIN_GRIPPER /
            model_module.MAX_GRIPPER (set from dataset stats in main()).
        n_bins: number of bins; defaults to N_GRIPPER_BINS.

    Returns:
        Long tensor of the same shape with indices in ``[0, n_bins - 1]``.
    """
    min_g = float(model_module.MIN_GRIPPER if min_g is None else min_g)
    max_g = float(model_module.MAX_GRIPPER if max_g is None else max_g)
    if n_bins is None:
        n_bins = N_GRIPPER_BINS
    normalized = ((values - min_g) / (max_g - min_g + 1e-8)).clamp(0, 1)
    return torch.round(normalized * (n_bins - 1)).long().clamp(0, n_bins - 1)


def compute_gripper_loss(pred_gripper_logits, target_gripper, valid_mask=None):
    """Cross-entropy of the multi-bin gripper head against discretized targets.

    Args:
        pred_gripper_logits: (B, N, Ng) logits over N_GRIPPER_BINS bins.
        target_gripper: (B, N) continuous gripper q in rad, mapped to bins
            with discretize_gripper.
        valid_mask: optional (B, N) mask. Off-frame timesteps are dropped,
            because their head input was sampled at the (0,0) clamped pixel
            and is meaningless.

    Returns:
        Scalar: the mean CE over all B*N entries, or the mask-weighted mean
        (denominator clamped to >= 1) when valid_mask is given.
    """
    bins = discretize_gripper(target_gripper)
    batch, steps, n_bins = pred_gripper_logits.shape
    flat = batch * steps
    ce = F.cross_entropy(pred_gripper_logits.reshape(flat, n_bins),
                         bins.reshape(flat), reduction='none')  # (B*N,)
    if valid_mask is None:
        return ce.mean()
    weights = valid_mask.reshape(flat).float()
    return (ce * weights).sum() / weights.sum().clamp(min=1.0)


def discretize_rotation(euler_values, min_rot=None, max_rot=None, n_bins=None):
    """Map per-axis euler angles to rotation-bin indices.

    Bin ``i`` on each axis stands for the normalized angle ``i / (n_bins - 1)``,
    the same ``linspace(0, 1, n_bins)`` centers decode_pred_euler uses, so
    values are rounded to the nearest center. Flooring (the previous behavior)
    biased every decoded angle low by up to one bin.

    Args:
        euler_values: (..., 3) euler angles.
        min_rot, max_rot: per-axis ranges (length-3 sequences or tensors);
            default to model_module.MIN_ROT / model_module.MAX_ROT.
        n_bins: bins per axis; defaults to N_ROT_BINS.

    Returns:
        Long tensor shaped like ``euler_values`` with indices in
        ``[0, n_bins - 1]``.
    """
    if min_rot is None:
        min_rot = model_module.MIN_ROT
    if max_rot is None:
        max_rot = model_module.MAX_ROT
    if n_bins is None:
        n_bins = N_ROT_BINS
    min_r = torch.as_tensor(min_rot, device=euler_values.device, dtype=torch.float32)
    max_r = torch.as_tensor(max_rot, device=euler_values.device, dtype=torch.float32)
    normalized = ((euler_values - min_r) / (max_r - min_r + 1e-8)).clamp(0, 1)
    return torch.round(normalized * (n_bins - 1)).long().clamp(0, n_bins - 1)


def compute_rotation_loss(pred_rotation_logits, target_euler, valid_mask=None):
    """Mean of three per-axis cross-entropies over discretized euler targets.

    Args:
        pred_rotation_logits: (B, N, 3, Nr) logits.
        target_euler: (B, N, 3) euler angles, mapped to bins with
            discretize_rotation.
        valid_mask: optional (B, N) mask; off-frame timesteps are dropped.

    Returns:
        Scalar: the average over the 3 axes of each axis's (masked) mean CE.
    """
    target_bins = discretize_rotation(target_euler)
    B, N, _, Nr = pred_rotation_logits.shape
    weights = None if valid_mask is None else valid_mask.reshape(B * N).float()

    def _axis_loss(axis):
        ce = F.cross_entropy(pred_rotation_logits[:, :, axis, :].reshape(B * N, Nr),
                             target_bins[:, :, axis].reshape(B * N),
                             reduction='none')
        if weights is None:
            return ce.mean()
        return (ce * weights).sum() / weights.sum().clamp(min=1.0)

    return torch.stack([_axis_loss(axis) for axis in range(3)]).mean()


class EMALossEqualizer:
    """Multi-task loss balancer: each task's raw loss is divided by an EMA
    of its own magnitude (detached) so every term contributes ~unit-scale
    gradient regardless of intrinsic loss-value differences. Eliminates
    the need to hand-tune VOLUME_LOSS_WEIGHT / GRIPPER_LOSS_WEIGHT / ...
    constants when adding new heads.

    Non-finite losses (NaN/inf) are still returned (scaled) but never folded
    into the EMA. Otherwise a single bad batch would poison the running
    scale forever: ``max(nan, min_ema)`` is nan, so every later normalized
    loss would also be NaN.
    """

    def __init__(self, beta: float = 0.99, min_ema: float = 1e-3):
        # beta: EMA decay; min_ema: floor on the divisor so near-zero losses
        # are not blown up.
        self.beta = float(beta)
        self.min_ema = float(min_ema)
        self.ema: dict[str, float] = {}

    def normalize(self, name: str, loss):
        """Update the EMA for ``name`` with ``loss`` and return ``loss / scale``.

        The first finite value seeds the EMA. The divisor is the EMA clamped
        from below by ``min_ema``; gradients flow through ``loss`` only.
        """
        raw = float(loss.detach().item())
        if np.isfinite(raw):
            prev = self.ema.get(name)
            if prev is None:
                self.ema[name] = raw
            else:
                self.ema[name] = self.beta * prev + (1.0 - self.beta) * raw
        scale = max(self.ema.get(name, self.min_ema), self.min_ema)
        return loss / scale

    def state(self):
        """Return a copy of the current per-task EMA values."""
        return dict(self.ema)


def decode_pred_euler(rotation_logits):
    """Turn (B, N, 3, Nr) rotation logits into (B, N, 3) euler angles.

    Takes the argmax bin per axis and maps it through the
    ``linspace(0, 1, Nr)`` bin centers into [MIN_ROT, MAX_ROT] (read from
    model_module).
    """
    dev = rotation_logits.device
    n_bins = rotation_logits.shape[-1]
    lo = torch.tensor(model_module.MIN_ROT, device=dev, dtype=torch.float32)
    hi = torch.tensor(model_module.MAX_ROT, device=dev, dtype=torch.float32)
    centers = torch.linspace(0, 1, n_bins, device=dev)
    frac = centers[rotation_logits.argmax(dim=-1)]  # (B, N, 3)
    return lo + frac * (hi - lo)


def extract_pred_2d_and_height(volume_logits, min_h=None, max_h=None):
    """Decode each timestep's argmax pixel and height from volume logits.

    For every (sample, timestep), the pixel is the argmax over H*W of the
    per-pixel max over height bins. The height bin is then the argmax of the
    logits column at that pixel, decoded through ``linspace(0, 1, Nh)`` bin
    centers into ``[min_h, max_h]``. Nh is taken from the tensor itself, not
    the global N_HEIGHT_BINS, so decoding always matches the logits produced.

    Args:
        volume_logits: (B, N, Nh, H, W) logits.
        min_h, max_h: height range; default to model_module.MIN_HEIGHT /
            model_module.MAX_HEIGHT.

    Returns:
        pred_2d: (B, N, 2) float tensor of (x, y) pixel indices in the H x W grid.
        pred_height: (B, N) decoded heights.
    """
    if min_h is None:
        min_h = model_module.MIN_HEIGHT
    if max_h is None:
        max_h = model_module.MAX_HEIGHT
    B, N, Nh, H, W = volume_logits.shape

    # Best pixel per (sample, timestep) on the height-collapsed map.
    max_over_h = volume_logits.max(dim=2).values            # (B, N, H, W)
    flat_idx = max_over_h.reshape(B, N, H * W).argmax(dim=-1)  # (B, N)
    py = flat_idx // W
    px = flat_idx % W
    pred_2d = torch.stack((px, py), dim=-1).float()

    # Height-logit column at the chosen pixel -> (B, N, Nh).
    gather_idx = flat_idx[:, :, None, None].expand(B, N, Nh, 1)
    column = volume_logits.reshape(B, N, Nh, H * W).gather(3, gather_idx).squeeze(3)
    pred_h_bins = column.argmax(dim=-1)

    bin_centers = torch.linspace(0, 1, Nh, device=volume_logits.device)
    pred_height = bin_centers[pred_h_bins] * (max_h - min_h) + min_h
    return pred_2d, pred_height


# ── Smith300-specific dataset stats (FK + euler ranges) ──────────────────

def compute_dataset_stats_fast(data_dirs_with_episodes, mujoco_xml=DEFAULT_SMITH300_XML):
    """Compute EEF height / rotation / gripper ranges across one or more datasets.

    Each entry in ``data_dirs_with_episodes`` is a ``(data_dir, episodes)``
    pair. Each episode dict has ``"start"`` / ``"end"`` frame indices
    (inclusive; end is clamped to the last frame) into
    ``<data_dir>/joints.npz``.
    - Datasets whose joints.npz carries ``eef_pos`` and ``eef_euler`` (UMI
      capture path) use those arrays directly.
    - Datasets without them (smith300 capture path) run MuJoCo FK on
      ``q_motors[:, :nq]``.
    The MuJoCo model is loaded only when some dataset actually needs FK, so
    eef-only datasets do not require the XML.

    The gripper range comes from ``q_motors[:, 6]`` of every dataset that
    records at least 7 motors, padded by 5% of the span on each side. If no
    dataset records it, a placeholder ``[-1, 1]`` range is used.

    Args:
        data_dirs_with_episodes: iterable of ``(data_dir, episodes_list)``.
        mujoco_xml: MJCF path used for FK when a dataset lacks saved EEF poses.

    Returns:
        dict with float ``min_height``/``max_height`` and
        ``min_gripper``/``max_gripper``, plus per-axis list
        ``min_rot``/``max_rot`` (xyz euler).

    Raises:
        ValueError: if the given episodes cover no frames at all.
    """
    from scipy.spatial.transform import Rotation as Rot

    fk = None  # (mujoco, MjModel, MjData, eef body id); built on first FK need
    heights, eulers, grippers = [], [], []
    for data_dir, episodes in data_dirs_with_episodes:
        # NpzFile keeps the archive open until closed; read what we need into
        # memory inside the context manager so the handle is released.
        with np.load(os.path.join(data_dir, "joints.npz")) as joints:
            q_motors = np.asarray(joints["q_motors"], dtype=np.float64)
            has_saved_eef = ("eef_pos" in joints.files
                             and "eef_euler" in joints.files)
            if has_saved_eef:
                eef_pos = np.asarray(joints["eef_pos"], dtype=np.float64)
                eef_euler = np.asarray(joints["eef_euler"], dtype=np.float64)
        if not has_saved_eef and fk is None:
            fk = _load_eef_fk(mujoco_xml)
        has_gripper = q_motors.shape[1] >= 7
        for ep in episodes:
            ep_end = min(int(ep["end"]), q_motors.shape[0] - 1)
            for idx in range(int(ep["start"]), ep_end + 1):
                if has_saved_eef:
                    heights.append(float(eef_pos[idx, 2]))
                    eulers.append(eef_euler[idx])
                else:
                    height, euler = _fk_eef_height_euler(fk, q_motors[idx], Rot)
                    heights.append(height)
                    eulers.append(euler)
                if has_gripper:
                    grippers.append(float(q_motors[idx, 6]))

    if not heights:
        raise ValueError(
            "compute_dataset_stats_fast: the given episodes cover no frames; "
            "cannot derive height/rotation ranges.")
    heights = np.array(heights)
    eulers = np.array(eulers)
    if grippers:
        gmin, gmax = float(min(grippers)), float(max(grippers))
        # Pad the range slightly so the boundary bins are reachable and a
        # tiny excursion past min/max during training doesn't clip to the
        # edge bin (small but non-trivial CE asymmetry otherwise).
        pad = 0.05 * max(1e-3, gmax - gmin)
        gmin -= pad
        gmax += pad
        gripper_recorded = True
    else:
        gmin, gmax = -1.0, 1.0
        gripper_recorded = False
    stats = {
        "min_height": float(heights.min()), "max_height": float(heights.max()),
        "min_gripper": gmin, "max_gripper": gmax,
        "min_rot": eulers.min(axis=0).tolist(),
        "max_rot": eulers.max(axis=0).tolist(),
    }
    msg = (f"recorded range [{gmin:+.3f}, {gmax:+.3f}] "
           f"({len(grippers)} samples)") if gripper_recorded else "PLACEHOLDER (not recorded)"
    print(f"Stats: height=[{stats['min_height']:.4f}, {stats['max_height']:.4f}], "
          f"gripper={msg}")
    return stats


def _load_eef_fk(mujoco_xml):
    """Load the MuJoCo model/data pair and EEF body id used for forward kinematics."""
    import mujoco
    mj_model = mujoco.MjModel.from_xml_path(mujoco_xml)
    mj_data = mujoco.MjData(mj_model)
    eef_id = mujoco.mj_name2id(mj_model, mujoco.mjtObj.mjOBJ_BODY, EEF_BODY_NAME)
    return mujoco, mj_model, mj_data, eef_id


def _fk_eef_height_euler(fk, q_motor_row, rot_cls):
    """Run FK for one motor row; return (EEF z-height, xyz euler of EEF body)."""
    mujoco, mj_model, mj_data, eef_id = fk
    n_qpos = mj_model.nq
    q = np.zeros(n_qpos, dtype=np.float64)
    # Copy as many recorded motors as the model has qpos slots; pad the rest with 0.
    n_copy = min(q_motor_row.shape[0], n_qpos)
    q[:n_copy] = q_motor_row[:n_copy]
    mj_data.qpos[:n_qpos] = q
    mujoco.mj_forward(mj_model, mj_data)
    pos = mj_data.xpos[eef_id]
    # MuJoCo quaternions are wxyz; scipy expects xyzw (fancy index copies).
    quat_xyzw = mj_data.xquat[eef_id][[1, 2, 3, 0]]
    return float(pos[2]), rot_cls.from_quat(quat_xyzw).as_euler('xyz')


# ── Visualization (wandb timestep strip) ──────────────────────────────────

def project_world_to_pixel(pos_3d, image_size, T_cam_world, K_orig, img_w, img_h):
    """Project a 3D world point and rescale it to an image_size x image_size raster.

    The calibration (T_cam_world, K_orig) is passed in per dataset; K_orig
    belongs to the original img_w x img_h capture. Returns integer (u, v), or
    None when project_to_pixel cannot project the point.
    """
    projected = project_to_pixel(pos_3d, T_cam_world, K_orig)
    if projected is not None:
        u_px = projected[0] * image_size / img_w
        v_px = projected[1] * image_size / img_h
        return int(round(u_px)), int(round(v_px))
    return None


def _scale_K_to_image(K_orig, img_w, img_h, image_size):
    """Anisotropic K for projection at image_size x image_size raster."""
    K = K_orig.copy().astype(np.float64)
    K[0] *= image_size / float(img_w)
    K[1] *= image_size / float(img_h)
    return K


def draw_pose_axes(vis_bgr, origin_3d, euler_xyz, T_cam_world, K_image,
                   length=0.04, thickness=2, dashed=False):
    """Draw the three orientation axes of a pose onto ``vis_bgr`` in place.

    Each axis runs from the projected origin to the projection of
    ``origin + length * R[:, axis]``, where R comes from the xyz euler angles.
    Colors are given in BGR order: X=(0,0,255), Y=(0,255,0), Z=(255,0,0).
    Use ``dashed=True`` for the predicted overlay so GT and prediction can be
    compared at the same origin. Nothing is drawn if the origin cannot be
    projected or lands outside the image; an axis whose tip cannot be
    projected is skipped.

    NOTE(review): build_wandb_strip passes an RGB frame (it blends the heatmap
    into channel 0 and hands the result to wandb), so there X renders blue and
    Z red. Confirm the intended channel order.
    """
    from scipy.spatial.transform import Rotation as ScipyR
    rot_mat = ScipyR.from_euler('xyz', euler_xyz).as_matrix()
    origin = np.asarray(origin_3d, dtype=np.float64)
    origin_pix = project_to_pixel(origin, T_cam_world, K_image)
    if origin_pix is None:
        return
    img_h, img_w = vis_bgr.shape[:2]
    u0 = int(round(origin_pix[0]))
    v0 = int(round(origin_pix[1]))
    if u0 < 0 or u0 >= img_w or v0 < 0 or v0 >= img_h:
        return

    axis_colors = ((0, 0, 255), (0, 255, 0), (255, 0, 0))
    n_seg = 8  # dashed lines: draw every other of 8 equal segments
    for axis_idx, color in enumerate(axis_colors):
        tip_pix = project_to_pixel(origin + length * rot_mat[:, axis_idx],
                                   T_cam_world, K_image)
        if tip_pix is None:
            continue
        u1 = int(round(tip_pix[0]))
        v1 = int(round(tip_pix[1]))
        if not dashed:
            cv2.line(vis_bgr, (u0, v0), (u1, v1), color, thickness, cv2.LINE_AA)
            continue
        for seg in range(0, n_seg, 2):
            a = seg / n_seg
            b = (seg + 1) / n_seg
            start = (int(round(u0 + a * (u1 - u0))), int(round(v0 + a * (v1 - v0))))
            end = (int(round(u0 + b * (u1 - u0))), int(round(v0 + b * (v1 - v0))))
            cv2.line(vis_bgr, start, end, color, thickness, cv2.LINE_AA)


def build_wandb_strip(sample, split_name):
    """Render one annotated tile per timestep and concatenate them horizontally.

    Each tile shows:
    - the frame, optionally blended (channel 0) with the predicted heatmap,
      with a "pred" marker at the heatmap argmax;
    - the projected GT EEF ("eef"), plus its projection onto z=0 with a
      height label;
    - GT (solid) and predicted (dashed) orientation axes.

    Calibration is taken per sample: ``sample['world_to_camera']`` as
    T_cam_world and ``sample['cam_K_norm']`` as normalized intrinsics. This
    lets the strip work for samples drawn from any dataset (e.g. mixed
    smith300 + UMI training).

    Args:
        sample: dict as built by build_vis_sample.
        split_name: label used in the image caption.

    Returns:
        wandb.Image of the concatenated strip.
    """
    tiles = []
    K_image = None
    T_cam_world = sample['world_to_camera'].cpu().numpy()
    cam_K_norm = sample['cam_K_norm'].cpu().numpy()
    n_window = int(sample['rgb_frames_raw'].shape[0])  # honor per-run n_window
    for t in range(n_window):
        frame = sample['rgb_frames_raw'][t].cpu().numpy()
        H, W = frame.shape[:2]
        if K_image is None:
            # cam_K_norm is K / (W_orig, H_orig): row 0 (fx, cx) was divided by
            # the original width and row 1 (fy, cy) by the original height.
            # Re-scale each row by the matching rendered dimension. Previously
            # both rows used H, which is only correct for square frames.
            K_image = cam_K_norm.copy().astype(np.float64)
            K_image[0, :] *= W
            K_image[1, :] *= H
        vis = (frame * 255).astype(np.uint8).copy()

        if 'pred_heatmap' in sample:
            heat = sample['pred_heatmap'][t].detach().cpu().numpy()
            heat = heat - heat.min()
            if heat.max() > 1e-8:
                heat = heat / heat.max()
            heat_rgb = np.zeros_like(frame)
            heat_rgb[..., 0] = heat
            vis = np.clip(frame * 0.55 + heat_rgb * 0.45, 0, 1)
            vis = (vis * 255).astype(np.uint8)

            pred_y, pred_x = np.unravel_index(heat.argmax(), heat.shape)
            if 0 <= pred_x < W and 0 <= pred_y < H:
                cv2.drawMarker(vis, (int(pred_x), int(pred_y)), (0, 255, 0),
                               cv2.MARKER_CROSS, 14, 2)
                cv2.putText(vis, "pred", (int(pred_x) + 8, int(pred_y) - 8),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.35, (0, 255, 0), 1)

        eef_pos = sample['trajectory_3d'][t].cpu().numpy().astype(np.float64)
        # Project EEF using per-sample T_cam_world + K (matched to image size).
        pix = project_to_pixel(eef_pos, T_cam_world, K_image)
        pt = (int(round(pix[0])), int(round(pix[1]))) if pix is not None else None
        if pt is not None:
            u, v = pt
            if 0 <= u < W and 0 <= v < H:
                cv2.circle(vis, (u, v), 6, (255, 255, 255), -1)
                cv2.putText(vis, "eef", (u + 8, v - 8),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.35, (255, 255, 255), 1)

                # Drop the EEF onto the z=0 plane to visualize its height.
                ground_pos = eef_pos.copy()
                ground_pos[2] = 0.0
                gpix = project_to_pixel(ground_pos, T_cam_world, K_image)
                gpt = (int(round(gpix[0])), int(round(gpix[1]))) if gpix is not None else None
                if gpt is not None:
                    ug, vg = gpt
                    if 0 <= ug < W and 0 <= vg < H:
                        cv2.circle(vis, (ug, vg), 6, (0, 255, 255), 2)
                        cv2.line(vis, (u, v), (ug, vg), (255, 255, 0), 2)
                        cv2.putText(vis, f"h={eef_pos[2]:.3f}", (ug + 8, vg + 12),
                                    cv2.FONT_HERSHEY_SIMPLEX, 0.3, (0, 255, 255), 1)

        # Pose axes: GT (solid) + Pred (dashed) at the GT EEF 3D position.
        # Drawing both at the same origin makes the orientation difference
        # the visible signal: perfect prediction -> dashed lies on solid.
        if 'gt_euler' in sample and 'pred_euler' in sample:
            try:
                gt_euler = sample['gt_euler'][t].cpu().numpy().astype(np.float64)
                pred_euler = sample['pred_euler'][t].cpu().numpy().astype(np.float64)
                draw_pose_axes(vis, eef_pos, gt_euler, T_cam_world, K_image,
                               length=0.04, thickness=2, dashed=False)
                draw_pose_axes(vis, eef_pos, pred_euler, T_cam_world, K_image,
                               length=0.04, thickness=2, dashed=True)
            except Exception as exc:
                # Best-effort visualization: never let an axes failure kill training.
                print(f"axes draw failed t={t}: {exc}", flush=True)

        # Outlined timestep label (thick white under thin dark).
        cv2.putText(vis, f"t={t}", (8, 16), cv2.FONT_HERSHEY_SIMPLEX, 0.45, (255, 255, 255), 2)
        cv2.putText(vis, f"t={t}", (8, 16), cv2.FONT_HERSHEY_SIMPLEX, 0.45, (20, 20, 20), 1)
        tiles.append(vis)

    strip = np.concatenate(tiles, axis=1)
    return wandb.Image(strip, caption=f"{split_name}: t=0..{n_window-1}")


def build_vis_sample(model, batch, device):
    """Run ``model`` on the first sample of ``batch`` and package it for the strip.

    Switches the model to eval mode (the caller is responsible for switching
    back) and runs under no_grad. The rotation/gripper heads are *indexed*:
    they only return logits when query_pixels is supplied. As in the training
    loop, the GT pixels are used (rescaled to the PRED_SIZE grid), so GT and
    predicted orientation are compared at a known location. Per-timestep
    heatmaps are the volume softmax max-reduced over height bins and
    upsampled to IMAGE_SIZE x IMAGE_SIZE.
    """
    model.eval()
    with torch.no_grad():
        rgb = batch['rgb'][0:1].to(device)
        traj_2d = batch['trajectory_2d'][0:1].to(device)
        first_kp = traj_2d[:, 0, :]
        query_pixels = (traj_2d * (PRED_SIZE / IMAGE_SIZE)).long().clamp(0, PRED_SIZE - 1)
        volume_logits, _, rotation_logits, _ = model(rgb, first_kp,
                                                     query_pixels=query_pixels)
        pred_euler = decode_pred_euler(rotation_logits)[0]  # (N, 3)

        def _heatmap(vol_t):
            # Softmax over the whole (Nh, H, W) volume, then best height per pixel.
            probs = F.softmax(vol_t.reshape(-1), dim=0).reshape(vol_t.shape)
            peak = probs.max(dim=0)[0]
            return F.interpolate(peak[None, None], size=(IMAGE_SIZE, IMAGE_SIZE),
                                 mode='bilinear', align_corners=False)[0, 0]

        steps = int(volume_logits.shape[1])  # honor per-run n_window
        heatmaps = [_heatmap(volume_logits[0, t]) for t in range(steps)]

        return {
            'rgb_frames_raw': batch['rgb_frames_raw'][0],
            'trajectory_3d': batch['trajectory_3d'][0],
            'trajectory_2d': batch['trajectory_2d'][0],
            'gt_euler': batch['trajectory_euler'][0],
            'pred_euler': pred_euler,
            'pred_heatmap': torch.stack(heatmaps),
            # per-sample calibration so the strip viz works for any dataset
            'world_to_camera': batch['world_to_camera'][0],
            'cam_K_norm': batch['cam_K_norm'][0],
        }


# ── Main training loop ────────────────────────────────────────────────────

def main():
    p = argparse.ArgumentParser()
    p.add_argument("--data_dir", required=True, nargs='+',
                   help="One or more dataset dirs (each with joints.npz, "
                        "meta.json, rgb_overlay/episodes.json, rgb_NNNNNN.jpg). "
                        "Multiple dirs are concatenated for joint training.")
    p.add_argument("--run_name", default="smith300_para_test")
    p.add_argument("--no_color_aug", action="store_true",
                   help="Disable color-channel-permutation + HSV/brightness "
                        "augmentation on the model input frame.")
    p.add_argument("--epochs", type=int, default=500)
    p.add_argument("--batch_size", type=int, default=16)
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--frame_stride", type=int, default=1)
    p.add_argument("--wandb_mode", default="online")
    p.add_argument("--wandb_project", default="para_smith300")
    p.add_argument("--freeze_backbone", action="store_true")
    p.add_argument("--val_split", type=float, default=0.15)
    p.add_argument("--vis_every", type=int, default=VIS_EVERY_EPOCHS,
                   help="(legacy) Epoch-level vis cadence. Kept for back-compat; "
                        "step-level vis (--vis_every_steps) is the primary path.")
    p.add_argument("--vis_every_steps", type=int, default=100,
                   help="Log a fresh wandb strip every N optimization steps. "
                        "Each call uses the CURRENT train batch (different "
                        "samples each time) plus a fresh val batch, instead "
                        "of the same cached pair, so the strip cycles through "
                        "the dataset and reflects recent gradients.")
    p.add_argument("--mujoco_xml", default=DEFAULT_SMITH300_XML)
    p.add_argument("--val_external_dir", default=None,
                   help="Optional path to a held-out evaluation dataset of a "
                        "different rig/session. Same wandb vis is generated "
                        "for it (logged as 'vis/val_external_strip' and "
                        "scalar metrics under 'val_external/'). Stats stay "
                        "from the training data — external is just probed.")
    p.add_argument("--save_every_steps", type=int, default=50,
                   help="Overwrite latest.pth every N optimizer steps. "
                        "Much finer-grained than the per-epoch cadence — "
                        "lets deploy grab a fresh checkpoint every ~minute.")
    p.add_argument("--volume_only_steps", type=int, default=0,
                   help="Curriculum warm-up: for the first N optimizer steps, "
                        "the total loss is volume-only (gripper + rotation are "
                        "computed and logged but excluded from backprop). Lets "
                        "the spatial-heatmap head settle before the MLP heads "
                        "start pulling on shared features. 0 = disabled.")
    p.add_argument("--use_keyframes", action="store_true",
                   help="Sample windows over user-annotated keyframes instead "
                        "of contiguous frames. Each episode pads to n_window "
                        "by repeating its last keyframe.")
    p.add_argument("--backbone", default="vits16plus",
                   choices=("vits16plus", "convnext_small"),
                   help="DINOv3 backbone variant. 'vits16plus' (default) is the "
                        "original ViT-S/16+ path with CLS-concat grip/rot heads. "
                        "'convnext_small' is the DINOv3 ConvNeXt-S path with "
                        "no CLS and grid_sample sub-pixel-indexed heads.")
    p.add_argument("--resume_from", default=None,
                   help="Path to a .pth checkpoint to warm-start model weights "
                        "from. Model weights only — optimizer/scheduler/step "
                        "counters reset (so this is fine-tuning, not literal "
                        "resume). Shape-mismatched tensors are skipped.")
    p.add_argument("--n_window", type=int, default=N_WINDOW,
                   help=f"Prediction window length (number of future "
                        f"timesteps the model predicts). Default {N_WINDOW}. "
                        f"Affects volume_head depth, MLP head outputs, and "
                        f"dataset window slicing -- threaded through both.")
    args = p.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}", flush=True)

    # Episodes per dataset (same fallback path as the dataset class).
    data_dirs = list(args.data_dir)
    dirs_with_episodes = []
    for ddir in data_dirs:
        epi_path = None
        for cand in ["rgb_overlay/episodes.json", "episodes.json"]:
            cp = os.path.join(ddir, cand)
            if os.path.exists(cp):
                epi_path = cp
                break
        if epi_path is None:
            raise FileNotFoundError(f"no episodes.json under {ddir}")
        with open(epi_path) as f:
            ep_data = json.load(f)
        dirs_with_episodes.append((ddir, ep_data["episodes"]))
        print(f"  dataset {ddir}: {len(ep_data['episodes'])} episodes")

    # Joint stats across all datasets.
    stats = compute_dataset_stats_fast(dirs_with_episodes,
                                       mujoco_xml=args.mujoco_xml)

    model_module.MIN_HEIGHT = stats["min_height"]
    model_module.MAX_HEIGHT = stats["max_height"]
    model_module.MIN_GRIPPER = stats["min_gripper"]
    model_module.MAX_GRIPPER = stats["max_gripper"]
    model_module.MIN_ROT = stats["min_rot"]
    model_module.MAX_ROT = stats["max_rot"]
    model_module.MIN_POS = [0, 0, stats["min_height"]]
    model_module.MAX_POS = [1, 1, stats["max_height"]]

    ckpt_dir = Path(f"checkpoints/{args.run_name}")
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    with open(ckpt_dir / "dataset_stats.json", "w") as f:
        json.dump(stats, f, indent=2)

    # Per-dataset: build TWO instances (augmented and non-augmented) and the
    # SAME train/val index split, then assemble train_ds (aug=True) and
    # val_ds (aug=False). ConcatDataset stitches multiple datasets together.
    from torch.utils.data import Subset, ConcatDataset
    train_subsets, val_subsets = [], []
    for ddir, _ in dirs_with_episodes:
        ds_aug = Smith300TrajectoryDataset(
            ddir, frame_stride=args.frame_stride,
            mujoco_xml=args.mujoco_xml,
            augment_color=(not args.no_color_aug),
            n_window=args.n_window,
            use_keyframes=args.use_keyframes,
        )
        ds_clean = Smith300TrajectoryDataset(
            ddir, frame_stride=args.frame_stride,
            mujoco_xml=args.mujoco_xml,
            augment_color=False,
            n_window=args.n_window,
            use_keyframes=args.use_keyframes,
        )
        n_total = len(ds_aug)
        n_val = max(1, int(n_total * args.val_split))
        n_train = n_total - n_val
        # Stable per-dataset seed so val frames are consistent across re-runs
        seed = 42 + (abs(hash(ddir)) % 10_000)
        gen = torch.Generator().manual_seed(seed)
        idxs = torch.randperm(n_total, generator=gen).tolist()
        train_subsets.append(Subset(ds_aug, idxs[:n_train]))
        val_subsets.append(Subset(ds_clean, idxs[n_train:]))
        print(f"  {Path(ddir).name}: train={n_train}, val={n_val}, total={n_total}")

    train_ds = ConcatDataset(train_subsets) if len(train_subsets) > 1 else train_subsets[0]
    val_ds = ConcatDataset(val_subsets) if len(val_subsets) > 1 else val_subsets[0]
    print(f"Combined: train={len(train_ds)}, val={len(val_ds)}, "
          f"color_aug={'OFF' if args.no_color_aug else 'ON'}", flush=True)

    # num_workers=8: in-RAM RGB cache means workers share the (read-only) cache,
    # so they parallelize the per-batch CPU work (heatmap-target construction,
    # uint8->float32 + ImageNet normalize) instead of bottlenecking single-
    # threaded. Historical deadlock comment removed — that was at workers=4
    # with on-disk reads on the Mac mount; cached reads don't trip the same
    # race. If the val iterator ever stalls again, drop back to 0.
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,
                              num_workers=8, pin_memory=True, drop_last=True,
                              persistent_workers=True)
    val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False,
                            num_workers=8, pin_memory=True,
                            persistent_workers=True)
    # Step-level vis: a rolling val iterator so we cycle through the val set
    # over time instead of always showing the same cached pair. Recreated on
    # exhaustion. Train side uses the live training batch each step.
    val_iter_for_vis = [iter(val_loader)]  # list so closure can rebind

    # Optional external val: a DIFFERENT-rig dataset (e.g. UMI->robot probe).
    # Uses the training set's stats so quantization bins are consistent; only
    # the RGB/EEF/K/etc. differ. Visualized identically to val.
    val_external_loader = None
    val_external_iter = [None]
    if args.val_external_dir:
        ext_ds = Smith300TrajectoryDataset(
            args.val_external_dir, frame_stride=args.frame_stride,
            mujoco_xml=args.mujoco_xml, augment_color=False,
            n_window=args.n_window, use_keyframes=args.use_keyframes,
        )
        val_external_loader = DataLoader(
            ext_ds, batch_size=args.batch_size, shuffle=False,
            num_workers=8, pin_memory=True, persistent_workers=True,
        )
        val_external_iter[0] = iter(val_external_loader)
        print(f"val_external: {Path(args.val_external_dir).name} -> "
              f"{len(ext_ds)} samples (logged as val_external/*).", flush=True)

    model = TrajectoryHeatmapPredictor(n_window=args.n_window,
                                       backbone=args.backbone).to(device)
    if args.resume_from:
        ck = torch.load(args.resume_from, map_location="cpu", weights_only=False)
        src_sd = ck.get("model_state_dict", ck)
        cur_sd = model.state_dict()
        loaded, skipped = 0, []
        for k, v in src_sd.items():
            if k in cur_sd and cur_sd[k].shape == v.shape:
                cur_sd[k] = v
                loaded += 1
            else:
                skipped.append(k)
        model.load_state_dict(cur_sd)
        print(f"[resume_from] loaded {loaded}/{len(src_sd)} tensors from "
              f"{args.resume_from} (skipped {len(skipped)} shape-mismatched).",
              flush=True)
    if args.freeze_backbone:
        for param in model.dino.parameters():
            param.requires_grad = False
        print("Backbone frozen", flush=True)

    optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()),
                            lr=args.lr, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)

    wandb.init(project=args.wandb_project, name=args.run_name,
               mode=args.wandb_mode,
               config={**vars(args), **stats,
                       "VOLUME_LOSS_WEIGHT": VOLUME_LOSS_WEIGHT,
                       "GRIPPER_LOSS_WEIGHT": GRIPPER_LOSS_WEIGHT,
                       "ROTATION_LOSS_WEIGHT": ROTATION_LOSS_WEIGHT})

    best_val_loss = float("inf")
    global_step = 0

    def _next_val_batch():
        """Return the next batch from the rolling val iterator.

        The live iterator is stored in ``val_iter_for_vis[0]`` so this closure
        can replace it. Once it runs dry, a new ``iter(val_loader)`` is stored
        there and its first batch is returned. If val_loader yields nothing
        at all, StopIteration propagates to the caller.
        """
        sentinel = object()
        batch = next(val_iter_for_vis[0], sentinel)
        if batch is sentinel:
            # Exhausted: rebuild the iterator from val_loader and retry once.
            val_iter_for_vis[0] = iter(val_loader)
            batch = next(val_iter_for_vis[0])
        return batch

    def _next_val_external_batch():
        """Return the next external-val batch, or None if none is configured.

        ``val_external_loader`` is None unless ``args.val_external_dir`` was
        given. Otherwise the live iterator sits in ``val_external_iter[0]``
        and is recreated from ``val_external_loader`` once it is exhausted.
        An empty loader still raises StopIteration.
        """
        if val_external_loader is not None:
            sentinel = object()
            batch = next(val_external_iter[0], sentinel)
            if batch is sentinel:
                # Exhausted: start a new iterator over the external loader.
                val_external_iter[0] = iter(val_external_loader)
                batch = next(val_external_iter[0])
            return batch
        return None

    def _maybe_log_step_vis(current_train_batch):
        """Build and log wandb visualization strips for the current step.

        Strips come from three sources:
        - ``current_train_batch``;
        - the next batch of the rolling val iterator;
        - the next external-val batch, when an external-val loader is
          configured.

        Successive calls therefore cycle through different samples. All
        strips are sent in a single ``wandb.log`` call that includes the
        current ``global_step``.

        This is best-effort:
        - Exceptions are printed, with a traceback, and never propagate into
          the training loop.
        - A failure in the optional external-val strip is reported on its
          own. It no longer discards the train/val strips already built.
        - ``model.train()`` is always restored, because build_vis_sample
          leaves the model in eval mode.
        """
        import traceback  # local: only needed on the error paths

        try:
            train_sample = build_vis_sample(model, current_train_batch, device)
            val_sample = build_vis_sample(model, _next_val_batch(), device)
            train_strip = build_wandb_strip(train_sample, "train")
            val_strip = build_wandb_strip(val_sample, "val")
            log_payload = {"step": global_step}
            if train_strip is not None:
                log_payload["vis/train_strip"] = train_strip
            if val_strip is not None:
                log_payload["vis/val_strip"] = val_strip
            # Isolate the optional transfer probe so a bad external batch
            # doesn't throw away the train/val strips computed above.
            try:
                ext_batch = _next_val_external_batch()
                if ext_batch is not None:
                    ext_sample = build_vis_sample(model, ext_batch, device)
                    ext_strip = build_wandb_strip(ext_sample, "val_external")
                    if ext_strip is not None:
                        log_payload["vis/val_external_strip"] = ext_strip
            except Exception as e:
                print(f"step-vis val_external err at step={global_step}: {e}",
                      flush=True)
                traceback.print_exc()
            wandb.log(log_payload)
        except Exception as e:
            print(f"step-vis err at step={global_step}: {e}", flush=True)
            traceback.print_exc()
        finally:
            model.train()  # build_vis_sample sets eval; restore for next step

    loss_eq = EMALossEqualizer(beta=0.99)
    for epoch in range(args.epochs):
        model.train()
        train_losses, train_vol_losses, train_grip_losses, train_rot_losses = [], [], [], []
        for batch in train_loader:
            rgb = batch["rgb"].to(device)
            traj_2d = batch["trajectory_2d"].to(device)
            traj_3d = batch["trajectory_3d"].to(device)
            traj_grip = batch["trajectory_gripper"].to(device)
            traj_euler = batch["trajectory_euler"].to(device)
            traj_valid = batch.get("trajectory_valid")
            if traj_valid is not None:
                traj_valid = traj_valid.to(device)

            scale = PRED_SIZE / IMAGE_SIZE
            traj_2d_scaled = traj_2d * scale
            height_bins = discretize_height(traj_3d[:, :, 2])

            start_kp = traj_2d[:, 0, :]
            query_pixels = traj_2d_scaled.long().clamp(0, PRED_SIZE - 1)
            volume_logits, gripper_logits, rotation_logits, _ = model(
                rgb, start_kp, query_pixels=query_pixels)

            vol_loss = compute_volume_loss(volume_logits, traj_2d_scaled,
                                           height_bins, valid_mask=traj_valid)
            grip_loss = compute_gripper_loss(gripper_logits, traj_grip,
                                             valid_mask=traj_valid)
            rot_loss = compute_rotation_loss(rotation_logits, traj_euler,
                                             valid_mask=traj_valid)

            # EMA-equalized: each task contributes ~1 to the gradient. During
            # the volume-only warm-up window, grip+rot are computed for the
            # log but their EMA isn't updated yet, so we don't normalize them.
            vol_only = global_step < args.volume_only_steps
            loss = loss_eq.normalize("vol", vol_loss)
            if not vol_only:
                loss = (loss + loss_eq.normalize("grip", grip_loss)
                             + loss_eq.normalize("rot", rot_loss))

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            train_losses.append(loss.item())
            train_vol_losses.append(vol_loss.item())
            train_grip_losses.append(grip_loss.item())
            train_rot_losses.append(rot_loss.item())
            global_step += 1

            # Step-level wandb visualization: more frequent than per-epoch
            # and cycles through samples (current train batch + rolling val
            # batch) so we see different scenes over time.
            if (args.vis_every_steps > 0
                    and global_step % args.vis_every_steps == 0):
                _maybe_log_step_vis(batch)

            # Fine-grained latest.pth: overwrite every N optimizer steps so
            # deploy can grab a fresh checkpoint without waiting for an
            # epoch boundary. best.pth still saves per-epoch on val win.
            if (args.save_every_steps > 0
                    and global_step % args.save_every_steps == 0):
                ckpt_step = {
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "epoch": epoch, "global_step": global_step, **stats,
                }
                torch.save(ckpt_step, ckpt_dir / "latest.pth")

        scheduler.step()
        train_loss = np.mean(train_losses)

        model.eval()
        val_losses, val_pixel_errs = [], []
        val_vol_losses, val_grip_losses, val_rot_losses = [], [], []
        with torch.no_grad():
            for batch in val_loader:
                rgb = batch["rgb"].to(device)
                traj_2d = batch["trajectory_2d"].to(device)
                traj_3d = batch["trajectory_3d"].to(device)
                traj_grip = batch["trajectory_gripper"].to(device)
                traj_euler = batch["trajectory_euler"].to(device)
                traj_valid = batch.get("trajectory_valid")
                if traj_valid is not None:
                    traj_valid = traj_valid.to(device)

                scale = PRED_SIZE / IMAGE_SIZE
                traj_2d_scaled = traj_2d * scale
                height_bins = discretize_height(traj_3d[:, :, 2])

                start_kp = traj_2d[:, 0, :]
                query_pixels = traj_2d_scaled.long().clamp(0, PRED_SIZE - 1)
                volume_logits, gripper_logits, rotation_logits, _ = model(
                    rgb, start_kp, query_pixels=query_pixels)

                vol_loss = compute_volume_loss(volume_logits, traj_2d_scaled,
                                               height_bins, valid_mask=traj_valid)
                grip_loss = compute_gripper_loss(gripper_logits, traj_grip,
                                                 valid_mask=traj_valid)
                rot_loss = compute_rotation_loss(rotation_logits, traj_euler,
                                                 valid_mask=traj_valid)
                # Same EMA scales used in training so train/val are comparable.
                v_loss = (loss_eq.normalize("vol", vol_loss)
                          + loss_eq.normalize("grip", grip_loss)
                          + loss_eq.normalize("rot", rot_loss))
                val_losses.append(v_loss.item())
                val_vol_losses.append(vol_loss.item())
                val_grip_losses.append(grip_loss.item())
                val_rot_losses.append(rot_loss.item())

                pred_2d, _ = extract_pred_2d_and_height(volume_logits)
                err = (pred_2d - traj_2d_scaled).norm(dim=-1).mean().item()
                val_pixel_errs.append(err)

        val_loss = np.mean(val_losses)
        val_px_err = np.mean(val_pixel_errs)

        # Optional external-val pass (transfer probe). Same forward, same
        # losses, but on a DIFFERENT rig/session. Reported separately.
        val_ext_loss = None
        val_ext_px_err = None
        if val_external_loader is not None:
            ext_losses, ext_pixel_errs = [], []
            ext_vol_losses, ext_grip_losses, ext_rot_losses = [], [], []
            with torch.no_grad():
                for batch in val_external_loader:
                    rgb = batch["rgb"].to(device)
                    traj_2d = batch["trajectory_2d"].to(device)
                    traj_3d = batch["trajectory_3d"].to(device)
                    traj_grip = batch["trajectory_gripper"].to(device)
                    traj_euler = batch["trajectory_euler"].to(device)
                    traj_valid = batch.get("trajectory_valid")
                    if traj_valid is not None:
                        traj_valid = traj_valid.to(device)
                    scale = PRED_SIZE / IMAGE_SIZE
                    traj_2d_scaled = traj_2d * scale
                    height_bins = discretize_height(traj_3d[:, :, 2])
                    start_kp = traj_2d[:, 0, :]
                    query_pixels = traj_2d_scaled.long().clamp(0, PRED_SIZE - 1)
                    volume_logits, gripper_logits, rotation_logits, _ = model(
                        rgb, start_kp, query_pixels=query_pixels)
                    vol_l = compute_volume_loss(volume_logits, traj_2d_scaled,
                                                height_bins, valid_mask=traj_valid)
                    grip_l = compute_gripper_loss(gripper_logits, traj_grip,
                                                  valid_mask=traj_valid)
                    rot_l = compute_rotation_loss(rotation_logits, traj_euler,
                                                  valid_mask=traj_valid)
                    # Use training EMA scales (read-only here — do NOT update
                    # them on external data).
                    def _ema_div(name, l):
                        scale_ = max(loss_eq.ema.get(name, 1.0), loss_eq.min_ema)
                        return (l / scale_).item()
                    ext_losses.append(_ema_div("vol", vol_l) + _ema_div("grip", grip_l) + _ema_div("rot", rot_l))
                    ext_vol_losses.append(vol_l.item())
                    ext_grip_losses.append(grip_l.item())
                    ext_rot_losses.append(rot_l.item())
                    pred_2d, _ = extract_pred_2d_and_height(volume_logits)
                    ext_pixel_errs.append((pred_2d - traj_2d_scaled).norm(dim=-1).mean().item())
            val_ext_loss = float(np.mean(ext_losses))
            val_ext_px_err = float(np.mean(ext_pixel_errs))

        log_dict = {
            "epoch": epoch,
            "train/loss": train_loss,
            "train/vol_loss": np.mean(train_vol_losses),
            "train/grip_loss": np.mean(train_grip_losses),
            "train/rot_loss": np.mean(train_rot_losses),
            "val/loss": val_loss,
            "val/vol_loss": np.mean(val_vol_losses),
            "val/grip_loss": np.mean(val_grip_losses),
            "val/rot_loss": np.mean(val_rot_losses),
            "val/pixel_error": val_px_err,
            "lr": scheduler.get_last_lr()[0],
            **{f"ema/{k}": v for k, v in loss_eq.state().items()},
        }
        if val_ext_loss is not None:
            log_dict["val_external/loss"] = val_ext_loss
            log_dict["val_external/vol_loss"] = float(np.mean(ext_vol_losses))
            log_dict["val_external/grip_loss"] = float(np.mean(ext_grip_losses))
            log_dict["val_external/rot_loss"] = float(np.mean(ext_rot_losses))
            log_dict["val_external/pixel_error"] = val_ext_px_err

        # Note: visualizations now log every --vis_every_steps inside the
        # train loop (see _maybe_log_step_vis above) instead of every N
        # epochs. The cached-batch path was removed because (a) we want
        # different samples over time, (b) epoch-level vis is too sparse for
        # short runs and small datasets where each epoch is fast.

        wandb.log(log_dict)

        if epoch % 10 == 0:
            print(f"Epoch {epoch:4d} | train={train_loss:.4f} | "
                  f"val={val_loss:.4f} | px_err={val_px_err:.1f}px", flush=True)

        ckpt_data = {
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "epoch": epoch, **stats,
        }
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save({**ckpt_data, "val_loss": val_loss}, ckpt_dir / "best.pth")

        # latest.pth is now step-based (--save_every_steps). Per-epoch save
        # removed to avoid double-saving; best.pth still tracks val-loss wins.

    wandb.finish()
    print(f"Done! Best val loss: {best_val_loss:.4f}")
    print(f"Checkpoints: {ckpt_dir}")


if __name__ == "__main__":
    # Script entry point: run training only when this file is executed
    # directly (see the Usage section of the module docstring), not on import.
    main()
