"""Train DinoVolumeQuery (built here as `DinoPerPixelMLP`) on smith300 izzy_home_recording_2
with a T=50 long horizon.

Per Cameron 2026-05-20: query-MLP design. Same DINO trunk (resume from v9),
per-timestep MLP with AdaLN-Zero(sin(t)) produces a spatial query + gripper/rotation
heads. Volume scoring is a cheap factored dot product (no 6D volume materialised).

Long horizon (T=50): the dataset already pads future frames at the last-real frame.
Per Cameron, we DON'T mask padded positions in the loss — let the model learn to "stay
at end pose" naturally. The logged pixel/accuracy metrics, by contrast, are computed on
valid (non-padded) positions only.
"""
import os, sys, time, argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from pathlib import Path
from tqdm import tqdm
import wandb
import cv2

sys.path.insert(0, os.path.dirname(__file__))
from data_da3_volume import Smith300DA3VolumeDataset, DA3_INPUT
from model_dino_per_pixel import (DinoPerPixelMLP, IMG_SIZE, N_HEIGHT_BINS,
                                   N_GRIPPER_BINS, N_ROT_BINS, PRED_SIZE)


# ---------------- Viz helpers ----------------

def rainbow_overlay(rgb, pred_pix, gt_pix, img_size=IMG_SIZE):
    """Draw predicted vs. GT trajectory points on top of the de-normalised RGB frame.

    rgb:      (3, H, W) tensor normalised with ImageNet mean/std (undone here).
    pred_pix: (T, >=2) indexable of (x, y) pixel coordinates, drawn as filled dots whose
              hue sweeps 0..170 (OpenCV's 0-179 hue scale) over the horizon.
    gt_pix:   (T, >=2) indexable of (x, y) pixel coordinates, drawn as white outlines.
    img_size: currently unused.

    Returns an (H, W, 3) uint8 RGB image.
    """
    imagenet_mean = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
    imagenet_std = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)
    denorm = np.clip(rgb.cpu().numpy() * imagenet_std + imagenet_mean, 0, 1)
    # cv2 drawing needs a C-contiguous HWC buffer; the transpose alone is only a view.
    canvas = np.ascontiguousarray((denorm.transpose(1, 2, 0) * 255).astype(np.uint8))

    n_steps = pred_pix.shape[0]
    hue_den = max(n_steps - 1, 1)
    white = (255, 255, 255)
    for step in range(n_steps):
        hsv_px = np.uint8([[[int(step / hue_den * 170), 255, 255]]])
        colour = cv2.cvtColor(hsv_px, cv2.COLOR_HSV2RGB)[0, 0].tolist()
        pred_xy = (int(pred_pix[step, 0]), int(pred_pix[step, 1]))
        gt_xy = (int(gt_pix[step, 0]), int(gt_pix[step, 1]))
        cv2.circle(canvas, pred_xy, 4, colour, -1)
        cv2.circle(canvas, gt_xy, 4, white, 1)
    return canvas


def marginal_heatmap(vol_logits, cols=10):
    """Tile the per-step (y, x) marginal of the joint volume distribution into one image.

    vol_logits: (T, Z, H, W) logits. For each step a softmax is taken jointly over all
                Z*H*W cells and the height axis Z is summed out, giving p(y, x).
    cols:       number of tiles per row; tiles are laid out row-major and unused cells
                stay black.

    Each tile is the max-normalised marginal rendered with the JET colormap, at the
    volume's own spatial size (H x W).
    Returns a (rows*H, cols*W, 3) uint8 RGB image.
    """
    T, Z, H, W = vol_logits.shape
    p = torch.softmax(vol_logits.reshape(T, -1), dim=-1).reshape(T, Z, H, W)
    margin = p.sum(dim=1).cpu().numpy()                                   # (T, H, W)
    rows = (T + cols - 1) // cols
    # Tiles take the volume's spatial size. A fixed 56x56 tile only fits when
    # H == W == 56 and raises a broadcast error for any other prediction grid.
    tile_h, tile_w = H, W
    grid = np.zeros((rows * tile_h, cols * tile_w, 3), dtype=np.uint8)
    for i in range(T):
        m = margin[i]
        m = (m / (m.max() + 1e-8) * 255).astype(np.uint8)
        r, c = divmod(i, cols)
        grid[r * tile_h:(r + 1) * tile_h, c * tile_w:(c + 1) * tile_w] = \
            cv2.applyColorMap(m, cv2.COLORMAP_JET)                          # BGR tile
    return cv2.cvtColor(grid, cv2.COLOR_BGR2RGB)


def piano_roll(logits, gt_bins, valid, cell_h=6, cell_w=6, gt_color=(0, 0, 255)):
    """Render a compact per-step distribution heatmap ("piano roll") for long horizons.

    logits:   (T, n_bins) tensor of class logits, one row per timestep.
    gt_bins:  (T,) integer GT bin per step; that cell is painted ``gt_color``
              (BGR, so the default shows as red in the returned RGB image).
    valid:    (T,) truthy per step; rows with a falsy entry are dimmed to 30%.
    cell_h/w: upscaling factor per cell (nearest-neighbour).

    Each row is the step's softmax normalised by its own maximum and coloured with
    VIRIDIS. Returns a (T*cell_h, n_bins*cell_w, 3) uint8 RGB image.
    """
    n_steps, n_bins = logits.shape
    probs = torch.softmax(logits, dim=-1).cpu().numpy()                 # (T, n_bins)
    row_peak = probs.max(axis=1, keepdims=True)
    scaled = (probs / (row_peak + 1e-8) * 255).astype(np.uint8)
    canvas = cv2.applyColorMap(scaled, cv2.COLORMAP_VIRIDIS)            # (T, n_bins, 3) BGR
    for row in range(n_steps):
        canvas[row, int(gt_bins[row])] = gt_color
        if valid[row]:
            continue
        # padded step: darken the whole row (GT marker included)
        canvas[row] = (canvas[row].astype(np.float32) * 0.3).astype(np.uint8)
    canvas = cv2.resize(canvas, (n_bins * cell_w, n_steps * cell_h),
                        interpolation=cv2.INTER_NEAREST)
    return cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)


def height_piano_roll(vol_logits, gt_pix, gt_z_bin, valid, cell_h=6, cell_w=6):
    """Piano-roll of the per-step height distribution taken at the GT pixel.

    vol_logits: (T, Z, h, w) volume logits.
    gt_pix:     (T, 2) (x, y) coordinates in IMG_SIZE space, rescaled to the h x w grid
                and clamped inside it.
    gt_z_bin:   (T,) GT height bin, marked in the roll.
    valid:      (T,) truthy per step; invalid rows are dimmed by ``piano_roll``.

    Per step, the joint softmax over all Z*h*w cells is sliced at the GT (y, x) column,
    and its log is handed to ``piano_roll``. Re-applying a softmax there renormalises
    the slice to p(z | y, x). Returns the (T*cell_h, Z*cell_w, 3) RGB image.
    """
    n_steps, n_z, grid_h, grid_w = vol_logits.shape
    sx, sy = grid_w / IMG_SIZE, grid_h / IMG_SIZE
    # Preallocated float32 buffer: keeps the downstream numpy conversion happy
    # even when the logits come in a reduced-precision dtype.
    log_pz = torch.zeros(n_steps, n_z, device=vol_logits.device)
    for step in range(n_steps):
        col = int(min(max(0, gt_pix[step, 0] * sx), grid_w - 1))
        row = int(min(max(0, gt_pix[step, 1] * sy), grid_h - 1))
        joint = torch.softmax(vol_logits[step].reshape(-1), dim=0).reshape(n_z, grid_h, grid_w)
        log_pz[step] = torch.log(joint[:, row, col].clamp_min(1e-12))
    return piano_roll(log_pz, gt_z_bin, valid, cell_h=cell_h, cell_w=cell_w)


def height_dist_strip(vol_logits, gt_pix, gt_z_bin, valid):
    """Stack one bar-chart tile per step showing the height distribution at the GT pixel.

    vol_logits: (T, Z, h, w) volume logits.
    gt_pix:     (T, 2) (x, y) coordinates in IMG_SIZE space, rescaled to the h x w grid
                and clamped inside it.
    gt_z_bin:   (T,) GT height bin, outlined in red on each tile.
    valid:      (T,) truthy per step; invalid steps get an all-black tile.

    Each tile is 80 x (Z*6 + 20) pixels. Bars are the joint softmax sliced at the GT
    (y, x), renormalised over Z and scaled so that the tallest bar is 70 px. Tiles are
    concatenated vertically and returned as a uint8 RGB image.
    """
    n_steps, n_z, grid_h, grid_w = vol_logits.shape
    sx, sy = grid_w / IMG_SIZE, grid_h / IMG_SIZE
    tile_shape = (80, n_z * 6 + 20, 3)
    bar_colour, gt_colour, text_colour = (100, 200, 100), (0, 0, 255), (255, 255, 255)

    def _render(step):
        # One tile for timestep `step` (BGR, converted to RGB after stacking).
        tile = np.zeros(tile_shape, dtype=np.uint8)
        if not valid[step]:
            return tile
        col = int(min(max(0, gt_pix[step, 0] * sx), grid_w - 1))
        row = int(min(max(0, gt_pix[step, 1] * sy), grid_h - 1))
        logits_t = vol_logits[step]
        joint = torch.softmax(logits_t.reshape(-1), dim=0).reshape(logits_t.shape)
        column = joint[:, row, col]
        p_z = (column / (column.sum() + 1e-8)).cpu().numpy()
        peak = p_z.max() + 1e-8
        for z in range(n_z):
            bar_h = int(p_z[z] / peak * 70)
            x0 = 10 + z * 6
            cv2.rectangle(tile, (x0, 75 - bar_h), (x0 + 4, 75), bar_colour, -1)
        gt_x0 = 10 + int(gt_z_bin[step]) * 6
        cv2.rectangle(tile, (gt_x0, 5), (gt_x0 + 4, 75), gt_colour, 1)
        cv2.putText(tile, f"t{step}", (2, 12), cv2.FONT_HERSHEY_SIMPLEX, 0.4, text_colour, 1)
        return tile

    tiles = [_render(step) for step in range(n_steps)]
    return cv2.cvtColor(np.concatenate(tiles, axis=0), cv2.COLOR_BGR2RGB)


def dino_pca(feats):
    """Project feature tokens onto their top-3 principal components and render them as RGB.

    feats: tensor (detached, cast to float32, moved to CPU) or array-like. Layouts:
      (B, C, H, W) -> batch element 0; H*W tokens laid out on their H x W grid
      (C, H, W)    -> H*W tokens laid out on their H x W grid
      (N, C)       -> N tokens laid out on the largest square grid that fits
                      (trailing tokens beyond side*side are dropped)
      other ranks  -> flattened to (-1, C) and treated like (N, C)

    Each component is min-max normalised to [0, 1]. If fewer than 3 components exist
    (C < 3 or fewer than 3 tokens), the missing channels are zero. If the SVD fails to
    converge, the first 3 centred feature channels are used instead.
    Returns a 256x256x3 uint8 image (nearest-neighbour upsampled).
    """
    f = feats.detach().float().cpu().numpy() if hasattr(feats, 'detach') else np.asarray(feats)
    grid_hw = None
    if f.ndim == 4:                                    # (B, C, H, W) — take batch 0
        f = f[0]
    if f.ndim == 3:                                    # (C, H, W) — remember the grid, flatten spatial
        grid_hw = f.shape[1:]
        f = f.reshape(f.shape[0], -1).T                # (H*W, C)
    elif f.ndim != 2:
        f = f.reshape(-1, f.shape[-1])
    f = f - f.mean(0, keepdims=True)
    try:
        _, _, vt = np.linalg.svd(f, full_matrices=False)
        pcs = f @ vt[:3].T
    except np.linalg.LinAlgError:                      # SVD did not converge
        pcs = f[:, :3]
    if pcs.shape[1] < 3:                               # too few components for RGB — zero-fill
        pcs = np.pad(pcs, ((0, 0), (0, 3 - pcs.shape[1])))
    pcs = (pcs - pcs.min(0)) / (pcs.max(0) - pcs.min(0) + 1e-8)
    if grid_hw is not None:
        # Known spatial layout: keep it, so non-square maps are not truncated/garbled.
        gh, gw = grid_hw
    else:
        n = pcs.shape[0]
        side = int(round(np.sqrt(n)))
        if side * side != n:
            side = int(np.floor(np.sqrt(n)))
            pcs = pcs[:side * side]
        gh = gw = side
    img = (pcs.reshape(gh, gw, 3) * 255).astype(np.uint8)
    img = cv2.resize(img, (256, 256), interpolation=cv2.INTER_NEAREST)
    return img


def _grid_coords(gt_pix, grid_size, input_size):
    """Discretise (..., 2) (x, y) pixel coords in ``input_size`` space into grid cells.

    Returns (gy, gx) long tensors, truncated and clamped to [0, grid_size - 1]. Shared by
    the train step, the visualisation step and validation so all three agree exactly.
    """
    scale = grid_size / input_size
    gx = (gt_pix[..., 0] * scale).long().clamp(0, grid_size - 1)
    gy = (gt_pix[..., 1] * scale).long().clamp(0, grid_size - 1)
    return gy, gx


def _argmax_pix(vol, input_size):
    """Decode the joint argmax of volume logits (..., Z, H, W) into (x, y) pixel coords.

    The winning cell's (row, col) is mapped back to ``input_size`` space at the cell's
    top-left corner. GT is floored into cells by ``_grid_coords``, so a correct argmax can
    still report up to about one cell width of pixel error. Returns a float tensor (..., 2).
    """
    H, W = vol.shape[-2:]
    flat = vol.reshape(*vol.shape[:-3], -1).argmax(dim=-1)
    pyx = flat % (H * W)                                     # drop the height index
    py = (pyx // W).float() / (H / input_size)
    px = (pyx % W).float() / (W / input_size)
    return torch.stack([px, py], dim=-1)


def main():
    """Parse CLI args, then train, visualise, validate and checkpoint the volume-query model.

    Per step: teacher-forced volume CE + gripper CE + per-axis rotation CE (no padding mask).
    Train-fit metrics and visualisations are logged to wandb. Validation runs at the end of
    each epoch, and ``checkpoints/<run_name>/latest.pth`` is written every
    ``save_every_epochs`` epochs and after the final epoch.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--root_dir", type=str, default="/data/cameron/mac_robot_datasets/first_mobile_collection")
    p.add_argument("--sessions_whitelist", type=str, default="izzy_home_recording_2")
    p.add_argument("--n_window", type=int, default=50)
    p.add_argument("--frame_stride", type=int, default=1)
    p.add_argument("--batch_size", type=int, default=32)
    p.add_argument("--lr", type=float, default=5e-5)
    p.add_argument("--epochs", type=int, default=100)
    p.add_argument("--num_workers", type=int, default=2)
    p.add_argument("--val_split", type=float, default=0.05)
    p.add_argument("--grad_clip", type=float, default=1.0)
    p.add_argument("--gripper_loss_weight",  type=float, default=0.5)
    p.add_argument("--rotation_loss_weight", type=float, default=0.5)
    p.add_argument("--depth_subdir", type=str, default="da3_depth_large")
    p.add_argument("--vis_every_steps",  type=int, default=50)
    p.add_argument("--log_scalars_every", type=int, default=5)
    p.add_argument("--save_every_epochs", type=int, default=5)
    p.add_argument("--resume_from", type=str, default="")
    # NOTE(review): --use_eef is recorded in the wandb config but never read in this
    # script (the model below is built without it) — wire it through or drop it.
    p.add_argument("--use_eef", type=int, default=1,
                   help="1 = concat eef_feat + cls as query input (default); 0 = cls only (ablation)")
    p.add_argument("--run_name", type=str, default="query_v0")
    p.add_argument("--wandb_project", type=str, default="para_libero")
    p.add_argument("--wandb_mode", type=str, default="online")
    args = p.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    script_dir = Path(__file__).parent
    ckpt_dir = script_dir / "checkpoints" / args.run_name
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    wandb.init(project=args.wandb_project, name=args.run_name, mode=args.wandb_mode, config=vars(args))

    wl = [s.strip() for s in args.sessions_whitelist.split(",") if s.strip()] if args.sessions_whitelist else None
    print(f"Loading dataset: {args.root_dir} (whitelist={wl})  n_window={args.n_window}")
    full = Smith300DA3VolumeDataset(
        root_dir=args.root_dir, image_size=DA3_INPUT,
        n_window=args.n_window, frame_stride=args.frame_stride,
        depth_subdir=args.depth_subdir, sessions_whitelist=wl,
    )
    n = len(full); n_val = max(1, int(n * args.val_split)); n_tr = n - n_val
    train_ds, val_ds = random_split(full, [n_tr, n_val], generator=torch.Generator().manual_seed(42))
    print(f"  Train: {n_tr}  Val: {n_val}")

    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.num_workers, pin_memory=True, drop_last=True,
                              persistent_workers=args.num_workers > 0)
    val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False,
                            num_workers=args.num_workers, pin_memory=True,
                            persistent_workers=args.num_workers > 0)

    print("Building model...")
    model = DinoPerPixelMLP(
        n_window=args.n_window, n_height_bins=N_HEIGHT_BINS,
        n_gripper_bins=N_GRIPPER_BINS, n_rot_bins=N_ROT_BINS,
        image_size=IMG_SIZE, pred_size=PRED_SIZE,
    ).to(device)
    n_t = sum(prm.numel() for prm in model.parameters() if prm.requires_grad)
    print(f"Trainable: {n_t:,}")
    wandb.config.update({"trainable_params": n_t})

    if args.resume_from:
        print(f"Resuming DINO weights from {args.resume_from}")
        # weights_only=False unpickles arbitrary objects — only resume from trusted checkpoints.
        sd_full = torch.load(args.resume_from, map_location=device, weights_only=False)
        ckpt_sd = sd_full["model_state_dict"]
        cur_sd = model.state_dict()
        # Partial load: keep only keys present in this model with matching shapes.
        loaded = {k: v for k, v in ckpt_sd.items() if k in cur_sd and cur_sd[k].shape == v.shape}
        missing, unexpected = model.load_state_dict(loaded, strict=False)
        print(f"  loaded={len(loaded)} keys (expect ~DINO trunk), missing={len(missing)}, unexpected={len(unexpected)}")

    opt = optim.AdamW(filter(lambda prm: prm.requires_grad, model.parameters()),
                      lr=args.lr, weight_decay=1e-4)

    global_step = 0
    for epoch in range(args.epochs):
        model.train()
        pbar = tqdm(train_loader, desc=f"Epoch {epoch}", leave=False)
        for batch in pbar:
            rgb       = batch["rgb"].to(device, non_blocking=True)
            gt_pix    = batch["gt_pix_504"].to(device, non_blocking=True)        # (B, T, 2)
            gt_z_bin  = batch["gt_z_bin"].to(device, non_blocking=True)          # (B, T)
            gt_rot    = batch["gt_rot_bin"].to(device, non_blocking=True)        # (B, T, 3) per-axis
            gt_grip   = batch["gt_grip_bin"].to(device, non_blocking=True)       # (B, T)
            B, T, _ = gt_pix.shape

            # Teacher forcing: query the model at the per-t GT pixel on the PRED grid.
            gy, gx = _grid_coords(gt_pix, PRED_SIZE, DA3_INPUT)
            query_pixels = torch.stack([gy, gx], dim=-1)                         # (B, T, 2) y,x grid

            out = model(rgb, query_pixels=query_pixels)
            vol         = out["volume_logits"]                                    # (B, T, Z, H, W)
            grip_logits = out["gripper_logits"]
            rot_logits  = out["rotation_logits"]
            Z = vol.shape[2]; H, W = vol.shape[-2:]
            gz = gt_z_bin.clamp(0, Z - 1)
            tgt_flat = gz * (H * W) + gy * W + gx                                 # (B, T)

            # Loss — NO MASK on padded positions (Cameron: "just pad" → supervise on last-real values)
            volume_loss   = F.cross_entropy(vol.reshape(B * T, Z * H * W), tgt_flat.reshape(-1))
            gripper_loss  = F.cross_entropy(grip_logits.reshape(B * T, -1), gt_grip.reshape(-1))
            # Rotation: per-axis CE, mean over 3 axes
            n_rot = rot_logits.shape[-1]
            rotation_loss = sum(
                F.cross_entropy(rot_logits[..., a, :].reshape(B * T, n_rot),
                                gt_rot[..., a].reshape(-1))
                for a in range(3)
            ) / 3.0
            total = volume_loss + args.gripper_loss_weight * gripper_loss \
                                + args.rotation_loss_weight * rotation_loss

            opt.zero_grad(); total.backward()
            if args.grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            opt.step()
            pbar.set_postfix(v=f"{volume_loss.item():.2f}",
                             g=f"{gripper_loss.item():.2f}",
                             r=f"{rotation_loss.item():.2f}")
            global_step += 1

            if global_step % args.log_scalars_every == 0:
                # Train-fit metrics (argmax-decoded), valid-masked. Headline per Cameron.
                with torch.no_grad():
                    valid_mask = batch["gt_pix_valid"].to(device).float()
                    mask_sum = valid_mask.sum().clamp_min(1.0)
                    pred_pix = _argmax_pix(vol, DA3_INPUT)
                    train_pix = ((pred_pix - gt_pix).norm(dim=-1) * valid_mask).sum() / mask_sum
                    train_grip_acc = ((grip_logits.argmax(-1) == gt_grip).float() * valid_mask).sum() / mask_sum
                    rot_correct = (rot_logits.argmax(-1) == gt_rot).float().mean(-1)
                    train_rot_acc = (rot_correct * valid_mask).sum() / mask_sum
                wandb.log({"train/volume_loss":   volume_loss.item(),
                           "train/gripper_loss":  gripper_loss.item(),
                           "train/rotation_loss": rotation_loss.item(),
                           "train/total":         total.item(),
                           "train/pix_argmax":    train_pix.item(),
                           "train/grip_acc":      train_grip_acc.item(),
                           "train/rot_acc":       train_rot_acc.item(),
                           "epoch": epoch}, step=global_step)

            if global_step % args.vis_every_steps == 0:
                model.eval()
                with torch.no_grad():
                    idx = int(torch.randint(0, len(train_ds), (1,)).item())
                    s = train_ds[idx]
                    v_rgb   = s["rgb"].unsqueeze(0).to(device)
                    v_pix_t = s["gt_pix_504"]                                      # (T, 2) tensor, CPU
                    v_pix   = v_pix_t.cpu().numpy()
                    v_grip  = s["gt_grip_bin"].cpu().numpy()
                    v_rot   = s["gt_rot_bin"].cpu().numpy()                        # (T, 3)
                    v_z     = s["gt_z_bin"].cpu().numpy()
                    # Used as a boolean index below: a float mask raises IndexError and a
                    # uint8 mask would silently do integer fancy-indexing.
                    v_valid = s["gt_pix_valid"].cpu().numpy().astype(bool)
                    v_gy, v_gx = _grid_coords(v_pix_t.to(device), PRED_SIZE, DA3_INPUT)
                    v_qp = torch.stack([v_gy, v_gx], dim=-1).unsqueeze(0)         # (1, T, 2)
                    vo = model(v_rgb, query_pixels=v_qp)
                    v_vol = vo["volume_logits"][0]
                    v_grip_logits = vo["gripper_logits"][0]
                    v_rot_logits  = vo["rotation_logits"][0]
                    v_pred_pix = _argmax_pix(v_vol, DA3_INPUT).cpu().numpy()
                    # Train pix err on REAL (valid) positions only — fair metric
                    if v_valid.any():
                        err = np.linalg.norm(v_pred_pix - v_pix, axis=-1)[v_valid].mean()
                    else:
                        err = 0.0
                    v_feats = vo["pixel_feats"][0]                                 # for DINO PCA viz
                # Piano-roll dist viz (compact 2D heatmaps for long horizons)
                viz_kp  = rainbow_overlay(s["rgb"], v_pred_pix, v_pix)
                viz_hm  = marginal_heatmap(v_vol)
                viz_gd  = piano_roll(v_grip_logits, v_grip, v_valid)
                # Rotation: the 3 per-axis piano rolls side-by-side, separated by grey bars
                rot_panels = [piano_roll(v_rot_logits[:, a, :], v_rot[:, a], v_valid)
                              for a in range(3)]
                sep = np.full((rot_panels[0].shape[0], 8, 3), 50, dtype=np.uint8)
                viz_rd = np.concatenate([rot_panels[0], sep, rot_panels[1], sep, rot_panels[2]], axis=1)
                viz_zd  = height_piano_roll(v_vol, v_pix_t, v_z, v_valid)
                viz_pca = dino_pca(v_feats)
                wandb.log({"vis/keypoints":        wandb.Image(viz_kp),
                           "vis/heatmap_marginal": wandb.Image(viz_hm),
                           "vis/gripper_dist":     wandb.Image(viz_gd),
                           "vis/rotation_dist":    wandb.Image(viz_rd),
                           "vis/height_dist":      wandb.Image(viz_zd),
                           "vis/dino_pca":         wandb.Image(viz_pca),
                           "train/pix_argmax_valid": float(err)}, step=global_step)
                model.train()

        # End-of-epoch val. Metrics are exact sums over positions (loss) or valid positions
        # (pix/acc), not a mean of per-batch means that would over-weight a small last batch.
        model.eval()
        loss_sum, n_pos = 0.0, 0
        pix_sum = grip_sum = rot_sum = valid_sum = 0.0
        with torch.no_grad():
            for batch in val_loader:
                rgb = batch["rgb"].to(device); gt_pix = batch["gt_pix_504"].to(device)
                gt_z_bin = batch["gt_z_bin"].to(device); gt_rot = batch["gt_rot_bin"].to(device)
                gt_grip = batch["gt_grip_bin"].to(device)
                valid = batch["gt_pix_valid"].to(device)
                gy_, gx_ = _grid_coords(gt_pix, PRED_SIZE, DA3_INPUT)
                out = model(rgb, query_pixels=torch.stack([gy_, gx_], dim=-1))
                vol = out["volume_logits"]; grip = out["gripper_logits"]; rot = out["rotation_logits"]
                Bv, Tv, Zv, Hv, Wv = vol.shape
                tgt = gt_z_bin.clamp(0, Zv - 1) * (Hv * Wv) + gy_ * Wv + gx_
                loss_sum += F.cross_entropy(vol.reshape(Bv * Tv, -1), tgt.reshape(-1),
                                            reduction="sum").item()
                n_pos += Bv * Tv
                # Pixel / accuracy metrics on valid positions only
                mask = valid.float()
                pp = _argmax_pix(vol, DA3_INPUT)
                pix_sum += ((pp - gt_pix).norm(dim=-1) * mask).sum().item()
                grip_sum += ((grip.argmax(-1) == gt_grip).float() * mask).sum().item()
                # Per-axis rotation accuracy, averaged over 3 axes
                rot_correct = (rot.argmax(-1) == gt_rot).float()                   # (B, T, 3)
                rot_sum += (rot_correct.mean(dim=-1) * mask).sum().item()
                valid_sum += mask.sum().item()
        denom = max(valid_sum, 1.0)
        v_v = loss_sum / max(n_pos, 1)
        v_p, v_g, v_r = pix_sum / denom, grip_sum / denom, rot_sum / denom
        print(f"Epoch {epoch}: val_v={v_v:.3f}  val_pix={v_p:.1f}px  val_grip_acc={v_g:.3f}  val_rot_acc={v_r:.3f}")
        wandb.log({"epoch_end/val_v": v_v, "epoch_end/val_pix": v_p,
                   "epoch_end/val_grip_acc": v_g, "epoch_end/val_rot_acc": v_r,
                   "epoch_end/epoch": epoch}, step=global_step)

        is_save = ((epoch + 1) % max(1, args.save_every_epochs) == 0) or (epoch + 1 == args.epochs)
        if is_save:
            ckpt = {"epoch": epoch, "global_step": global_step,
                    "model_state_dict": model.state_dict(),
                    "args": vars(args),
                    "n_rot_bins":    int(full.n_rot_bins),
                    "min_height":    float(full.min_height),
                    "max_height":    float(full.max_height),
                    "min_grip":      float(full.min_grip),
                    "max_grip":      float(full.max_grip),
                    "min_rot":       list(full.min_rot),     # per-axis (3,)
                    "max_rot":       list(full.max_rot)}     # per-axis (3,)
            # Write to a temp file then rename: an interrupted save can no longer
            # destroy the only existing checkpoint (os.replace is atomic on POSIX).
            latest = ckpt_dir / "latest.pth"
            tmp = latest.with_name(latest.name + ".tmp")
            torch.save(ckpt, tmp)
            os.replace(tmp, latest)

    wandb.finish()
    print(f"Done. Checkpoints: {ckpt_dir}")


if __name__ == "__main__":
    # Script entry point: parse CLI arguments and run the training loop.
    main()
