"""Train DA3VolumeModel on smith300 — factored KV-attention volume.

Losses:
  - volume CE over the flattened (Z × H × W) volume, per timestep and masked by
    gt_pix_valid, with target = z*(H*W) + v*W + u (joint pixel + height bin)
  - depth distillation vs precomputed frozen DA3-LARGE depth: L1 or log-L1
    (--depth_loss_type, default log-L1), scaled by --depth_loss_weight (0 disables it)
  EMA loss balancing (--use_ema_loss_balance): each loss is divided by its running EMA
  so that both contribute at a comparable scale.

Metrics (all decoded from the joint argmax over (Z × H × W), masked by gt_pix_valid):
  - pixel error in 504-space   train/pix_err_504,    epoch_end/val_pix
  - |z-bin error| in bins      train/z_bin_err,      epoch_end/val_z_err
  - joint top-1 accuracy       train/joint_top1_acc, epoch_end/val_joint_top1
    (fraction of timesteps where the argmax hits the GT (z, v, u) cell)

Wandb visualizations (every --vis_every_steps, on a random training sample):
  1. Rainbow predicted keypoints (joint argmax) over the GT track (white)
  2. Per-timestep heatmap strip (joint softmax marginalised over Z)
  3. Per-timestep height distribution at the GT pixel (bar chart, GT z bin outlined in red)
  4. Depth pred vs GT side-by-side
  5. DINO feature PCA
"""
import argparse, os, sys, time
from pathlib import Path

import cv2
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import wandb

sys.path.insert(0, os.path.dirname(__file__))
sys.path.insert(0, "/data/cameron/da3_repo/src")

from data_da3_volume import Smith300DA3VolumeDataset, DA3_INPUT, N_WINDOW, N_HEIGHT_BINS
from model_da3_volume import DA3VolumeModel


def rainbow_bgr(t, T, max_hue=150):
    """Return a BGR colour tuple for step ``t`` of ``T`` along a rainbow sweep.

    Hue goes linearly from 0 (red) at t=0 to ``max_hue`` at t=T-1, in OpenCV's
    8-bit HSV hue units (0-179, i.e. degrees / 2). The sweep deliberately stops
    short of 180: hue is circular and 180 maps back to red, so a full 0->180
    sweep would give the first and last timesteps the same colour.

    Args:
        t: step index, expected in [0, T-1].
        T: number of steps; T <= 1 maps every step to hue 0.
        max_hue: hue for the last step (clamped to the valid 0-179 range).

    Returns:
        (b, g, r) tuple of Python ints, usable directly as a cv2 colour.
    """
    frac = t / max(T - 1, 1)
    # Clamp so out-of-range t or max_hue can't overflow the uint8 hue channel.
    hue = min(max(int(max_hue * frac), 0), 179)
    hsv = np.uint8([[[hue, 255, 255]]])
    b, g, r = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)[0, 0]
    return int(b), int(g), int(r)


def rainbow_keypoints_overlay(rgb_chw, pred_pix, gt_pix, valid):
    """Draw the predicted track (rainbow by timestep) over the GT track (white).

    Args:
        rgb_chw: (3, H, W) CPU tensor with values in [0, 1]; channel order is
            assumed to be RGB (per the name) — confirm against the dataset.
        pred_pix: (T, 2) predicted (x, y) in 504-space; every step is drawn.
        gt_pix: (T, 2) GT (x, y) in 504-space.
        valid: (T,) bool; only valid GT steps are drawn and connected.

    Returns:
        (H, W, 3) uint8 RGB image.
    """
    hwc = np.clip(rgb_chw.permute(1, 2, 0).float().numpy(), 0.0, 1.0)
    # cv2 draws in BGR: move the photo into BGR first so that the final
    # BGR->RGB swap restores the photo's colours as well as the marker colours.
    img = cv2.cvtColor(np.ascontiguousarray((hwc * 255).astype(np.uint8)), cv2.COLOR_RGB2BGR)
    T = pred_pix.shape[0]

    # GT polyline + crosses (white), valid steps only.
    pts_gt = [tuple(int(c) for c in p) for p, ok in zip(gt_pix, valid) if ok]
    for a, b in zip(pts_gt, pts_gt[1:]):
        cv2.line(img, a, b, (255, 255, 255), 1)
    for p in pts_gt:
        cv2.drawMarker(img, p, (255, 255, 255), cv2.MARKER_CROSS, 6, 1)

    # Predicted polyline: segment i takes the colour of its starting step.
    pts_pred = [tuple(int(c) for c in p) for p in pred_pix]
    for i in range(T - 1):
        cv2.line(img, pts_pred[i], pts_pred[i + 1], rainbow_bgr(i, T), 2)
    for i, p in enumerate(pts_pred):
        col = rainbow_bgr(i, T)
        cv2.drawMarker(img, p, col, cv2.MARKER_TILTED_CROSS, 9, 2)
        cv2.putText(img, str(i), (p[0] + 5, p[1] - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.4, col, 1)
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)


def marginal_heatmap_strip(vol_logits, H_out):
    """Render a horizontal strip of per-timestep heatmaps marginalised over height.

    For each t, a softmax is taken jointly over all (Z, h, w) logits and summed
    over Z, giving p(v, u). Each map is min-max normalised independently,
    colour-mapped (inferno), resized to an H_out x H_out tile (nearest) and
    labelled with t.

    Args:
        vol_logits: (T, Z, h, w) raw logits tensor (any device / float dtype).
        H_out: side length in pixels of each square output tile.

    Returns:
        (H_out, T * H_out, 3) uint8 RGB image.
    """
    T = vol_logits.shape[0]
    # float32 softmax: avoids half-precision loss over Z*h*w bins and lets
    # .numpy() work when the logits are bfloat16 (numpy has no bfloat16).
    logits = vol_logits.detach().float().reshape(T, -1)
    p_joint = torch.softmax(logits, dim=-1).reshape(vol_logits.shape)
    marg = p_joint.sum(dim=1)  # (T, h, w)
    lo = marg.amin(dim=(1, 2), keepdim=True)
    hi = marg.amax(dim=(1, 2), keepdim=True)
    marg = ((marg - lo) / (hi - lo + 1e-8)).cpu().numpy()

    tiles = []
    for t, m in enumerate(marg):
        tile = cv2.applyColorMap((m * 255).astype(np.uint8), cv2.COLORMAP_INFERNO)
        tile = cv2.resize(tile, (H_out, H_out), interpolation=cv2.INTER_NEAREST)
        cv2.putText(tile, str(t), (4, 16), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
        tiles.append(tile)
    return cv2.cvtColor(np.concatenate(tiles, axis=1), cv2.COLOR_BGR2RGB)


def height_dist_strip(vol_logits, gt_pix, gt_z_bin, valid):
    """Stack per-timestep bar charts of the height (Z) distribution at the GT pixel.

    For each valid t, the GT pixel is quantised onto the (h, w) grid and
    p(z | pixel) is drawn as green bars, scaled so the tallest bar is 70 px.
    The GT z bin is outlined in red, clamped to [0, Z-1] as training does.
    Invalid steps get a blank tile of the same size, so rows stay aligned.

    Args:
        vol_logits: (T, Z, h, w) raw logits tensor.
        gt_pix: (T, 2) array of GT (x, y) in 504-space.
        gt_z_bin: (T,) GT height bins.
        valid: (T,) bool mask.

    Returns:
        (80 * T, Z * 6 + 20, 3) uint8 RGB image.
    """
    T, Z, h, w = vol_logits.shape
    scale_x = w / DA3_INPUT
    scale_y = h / DA3_INPUT
    tile_h, tile_w = 80, Z * 6 + 20
    tiles = []
    for t in range(T):
        bar = np.zeros((tile_h, tile_w, 3), dtype=np.uint8)
        if not valid[t]:
            tiles.append(bar)
            continue
        gx = int(min(max(0, gt_pix[t, 0] * scale_x), w - 1))
        gy = int(min(max(0, gt_pix[t, 1] * scale_y), h - 1))
        # p(z | v=gy, u=gx) under the joint softmax over (Z, h, w) equals a softmax
        # over the Z logits at that pixel: the joint normaliser cancels.
        p_z = torch.softmax(vol_logits[t, :, gy, gx].detach().float(), dim=0).cpu().numpy()
        p_max = p_z.max() + 1e-8
        for z, pz in enumerate(p_z):
            h_bar = int(pz / p_max * 70)
            x0 = 10 + z * 6
            cv2.rectangle(bar, (x0, 75 - h_bar), (x0 + 4, 75), (100, 200, 100), -1)
        # GT bar in red (BGR), outline only.
        z_gt = min(max(int(gt_z_bin[t]), 0), Z - 1)
        cv2.rectangle(bar, (10 + z_gt * 6, 5), (14 + z_gt * 6, 75), (0, 0, 255), 1)
        cv2.putText(bar, f"t{t}", (2, 12), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1)
        tiles.append(bar)
    return cv2.cvtColor(np.concatenate(tiles, axis=0), cv2.COLOR_BGR2RGB)


def depth_strip(pred_d, gt_d):
    """Render colour-mapped depth side by side: prediction (left) | reference (right).

    Each map is normalised independently to its own 2nd-98th percentile range.
    The colours therefore compare relative structure, not absolute scale.

    Args:
        pred_d, gt_d: depth tensors with the same spatial size (any device /
            float dtype). Singleton dimensions, e.g. a leading channel axis,
            are squeezed away so each becomes a 2-D map.

    Returns:
        (H, 2 * W, 3) uint8 RGB image.
    """
    def _colorize(d):
        # float() first: numpy has no bfloat16, so .numpy() would fail on it.
        a = np.squeeze(d.detach().float().cpu().numpy())
        lo, hi = np.percentile(a, [2, 98])
        a = np.clip((a - lo) / max(hi - lo, 1e-6), 0, 1)
        return cv2.applyColorMap((a * 255).astype(np.uint8), cv2.COLORMAP_MAGMA)

    side_by_side = np.concatenate([_colorize(pred_d), _colorize(gt_d)], axis=1)
    return cv2.cvtColor(side_by_side, cv2.COLOR_BGR2RGB)


def dino_pca(feats):
    """Render token features as an RGB image via their top-3 principal components.

    Args:
        feats: (B*S, T, C) tensor or array-like; only ``feats[0]`` is used.

    Tokens are mean-centred and projected onto the first three right-singular
    vectors. If the SVD fails to converge, the first three raw channels are used
    instead. Each component is min-max scaled to [0, 1] independently. Tokens
    are laid out row-major on the largest square grid that fits; tokens beyond
    side*side are dropped.

    Returns:
        (256, 256, 3) uint8 image (nearest-neighbour upsampled), PC k in channel k.
    """
    if hasattr(feats, "detach"):
        f = feats[0].detach().float().cpu().numpy()
    else:
        f = np.asarray(feats[0], dtype=np.float32)
    f = f.reshape(-1, f.shape[-1])
    f = f - f.mean(axis=0, keepdims=True)
    try:
        _, _, vt = np.linalg.svd(f, full_matrices=False)
        pcs = f @ vt[:3].T
    except np.linalg.LinAlgError:
        pcs = f[:, :3]
    if pcs.shape[1] < 3:
        # Fewer than 3 components (C < 3 or < 3 tokens): pad with zeros so the RGB reshape works.
        pcs = np.pad(pcs, ((0, 0), (0, 3 - pcs.shape[1])))
    lo, hi = pcs.min(axis=0), pcs.max(axis=0)
    pcs = (pcs - lo) / (hi - lo + 1e-8)
    side = int(np.floor(np.sqrt(pcs.shape[0])))
    pcs = pcs[:side * side]
    img = (pcs.reshape(side, side, 3) * 255).astype(np.uint8)
    return cv2.resize(img, (256, 256), interpolation=cv2.INTER_NEAREST)


def _parse_args():
    """Parse the training command-line flags."""
    p = argparse.ArgumentParser()
    p.add_argument("--root_dir", type=str, default="/data/cameron/mac_robot_datasets/first_mobile_collection")
    p.add_argument("--frame_stride", type=int, default=1)
    p.add_argument("--batch_size", type=int, default=4)
    p.add_argument("--lr", type=float, default=5e-5)
    p.add_argument("--epochs", type=int, default=50)
    p.add_argument("--val_split", type=float, default=0.05)
    p.add_argument("--vis_every_steps", type=int, default=25)
    p.add_argument("--log_scalars_every", type=int, default=5)
    p.add_argument("--num_workers", type=int, default=0)
    p.add_argument("--overfit_episode", type=int, default=-1)
    p.add_argument("--overfit_sample", type=int, default=-1)
    p.add_argument("--use_ema_loss_balance", type=int, default=1)
    p.add_argument("--grad_clip", type=float, default=1.0)
    p.add_argument("--ema_momentum", type=float, default=0.99)
    p.add_argument("--depth_loss_weight", type=float, default=1.0)
    p.add_argument("--depth_loss_type", type=str, default="log_l1", choices=["l1", "log_l1"])
    p.add_argument("--run_name", type=str, default="da3_volume_LARGE_v0")
    p.add_argument("--wandb_project", type=str, default="para_libero")
    p.add_argument("--wandb_mode", type=str, default="online")
    p.add_argument("--weights_path", type=str, default="/data/cameron/da3_large_weights")
    p.add_argument("--depth_subdir", type=str, default="da3_depth_large")
    p.add_argument("--save_every_epochs", type=int, default=5)
    return p.parse_args()


def _build_datasets(args):
    """Return (train_ds, val_ds): one overfit subset used for both, or a seeded random split."""
    full = Smith300DA3VolumeDataset(root_dir=args.root_dir, image_size=DA3_INPUT,
                                     n_window=N_WINDOW, frame_stride=args.frame_stride,
                                     depth_subdir=args.depth_subdir)
    if args.overfit_sample >= 0:
        subset = torch.utils.data.Subset(full, [args.overfit_sample])
        print(f"  OVERFIT mode — single sample {args.overfit_sample}")
        return subset, subset
    if args.overfit_episode >= 0:
        keep = [i for i, (ep, _t) in enumerate(full.samples) if ep == args.overfit_episode]
        if not keep:
            raise ValueError(f"--overfit_episode {args.overfit_episode} matched no samples")
        subset = torch.utils.data.Subset(full, keep)
        print(f"  OVERFIT mode — episode {args.overfit_episode} ({len(keep)} samples)")
        return subset, subset
    n = len(full)
    n_val = max(1, int(n * args.val_split))
    n_tr = n - n_val
    train_ds, val_ds = torch.utils.data.random_split(
        full, [n_tr, n_val], generator=torch.Generator().manual_seed(42))
    print(f"  Train: {n_tr}  Val: {n_val}")
    return train_ds, val_ds


def _depth_distill_loss(pred_d, gt_depth, loss_type, eps=1e-6):
    """Mean L1 (loss_type "l1") or log-L1 (otherwise) between predicted and precomputed depth."""
    if loss_type == "l1":
        return F.l1_loss(pred_d, gt_depth)
    # Clamp to eps so non-positive depths can't produce -inf / NaN logs.
    return F.l1_loss(torch.log(pred_d.clamp(min=eps)), torch.log(gt_depth.clamp(min=eps)))


def _volume_targets(vol, gt_pix, gt_z_bin):
    """Quantise GT onto the volume grid.

    vol: (B, T, Z, h, w) logits. gt_pix: (B, T, 2) (x, y) in DA3_INPUT space.
    gt_z_bin: (B, T) height bins.
    Returns (gx, gy, gz, tgt_flat), each (B, T) long and clamped into range,
    with tgt_flat = gz*h*w + gy*w + gx.
    """
    _, _, Z, h_out, w_out = vol.shape
    gx = (gt_pix[..., 0] * (w_out / DA3_INPUT)).long().clamp(0, w_out - 1)
    gy = (gt_pix[..., 1] * (h_out / DA3_INPUT)).long().clamp(0, h_out - 1)
    gz = gt_z_bin.long().clamp(0, Z - 1)
    return gx, gy, gz, gz * (h_out * w_out) + gy * w_out + gx


def _masked_volume_ce(vol, tgt_flat, valid):
    """Return (sum of per-(b, t) CE over valid steps, number of valid steps) as 0-d tensors."""
    B, T = tgt_flat.shape
    ce = F.cross_entropy(vol.reshape(B * T, -1), tgt_flat.reshape(-1), reduction="none")
    mask = valid.reshape(-1).bool()
    return torch.where(mask, ce, torch.zeros_like(ce)).sum(), mask.float().sum()


def _decode_argmax(vol):
    """Take the joint top-1 over (Z, h, w) for every (b, t) of a (B, T, Z, h, w) volume.

    Returns pred_z (B, T), pred_yx (B, T) = y*w + x, and pred_pix (B, T, 2)
    (x, y) in DA3_INPUT space at the centre of the argmax cell.
    """
    B, T, _, h_out, w_out = vol.shape
    flat = vol.reshape(B, T, -1).argmax(dim=-1)
    pred_z = flat // (h_out * w_out)
    pred_yx = flat % (h_out * w_out)
    # GT is floor-quantised into cells, so decode to the cell centre (+0.5);
    # the top-left corner would bias pixel error by half a cell per axis.
    px = ((pred_yx % w_out).float() + 0.5) * (DA3_INPUT / w_out)
    py = ((pred_yx // w_out).float() + 0.5) * (DA3_INPUT / h_out)
    return pred_z, pred_yx, torch.stack([px, py], dim=-1)


def _masked_metric_sums(vol, gt_pix, gx, gy, gz, valid):
    """Sum argmax metrics over valid steps (not averaged).

    Returns three 0-d float tensors: pixel error (504-space px), |z-bin error|,
    and joint (z, y, x) hits.
    """
    w_out = vol.shape[-1]
    pred_z, pred_yx, pred_pix = _decode_argmax(vol)
    m = valid.bool()

    def _msum(x):
        # where() rather than x * mask: NaN GT on invalid steps must not leak in.
        return torch.where(m, x, torch.zeros_like(x)).sum()

    pix = _msum((pred_pix - gt_pix.float()).norm(dim=-1))
    zerr = _msum((pred_z - gz).abs().float())
    hits = _msum(((pred_z == gz) & (pred_yx == gy * w_out + gx)).float())
    return pix, zerr, hits


def _log_visualizations(model, train_ds, device, global_step):
    """Run the model on one random training sample and log the wandb image panels.

    Temporarily switches the model to eval mode and always restores train mode.
    """
    model.eval()
    try:
        with torch.no_grad():
            rand_idx = int(torch.randint(0, len(train_ds), (1,)).item())
            s = train_ds[rand_idx]
            vo = model(s["rgb"].unsqueeze(0).to(device))
            v_vol = vo["volume_logits"][0]                          # (T, Z, h, w)
            _, _, pred_pix = _decode_argmax(vo["volume_logits"])
            v_pred_pix = pred_pix[0].cpu().numpy()
            v_pix = s["gt_pix_504"].cpu().numpy()
            v_z = s["gt_z_bin"].cpu().numpy()
            v_valid = s["gt_pix_valid"].cpu().numpy()
            images = {
                "vis/keypoints":        rainbow_keypoints_overlay(s["rgb"], v_pred_pix, v_pix, v_valid),
                "vis/heatmap_marginal": marginal_heatmap_strip(v_vol, H_out=144),
                "vis/height_dist":      height_dist_strip(v_vol, v_pix, v_z, v_valid),
                "vis/depth_pred_vs_gt": depth_strip(vo["pred_depth"][0], s["da3_depth"]),
                "vis/dino_pca":         dino_pca(vo["dino_feats"][-1]),
            }
        wandb.log({k: wandb.Image(v) for k, v in images.items()}, step=global_step)
    finally:
        model.train()


def _validate(model, val_loader, device, depth_loss_type):
    """Evaluate on the full validation set.

    Volume CE and the argmax metrics are averaged over all valid timesteps in
    the set, not per batch, so a ragged last batch is not over-weighted. Depth
    loss is averaged over samples.
    Returns (val_v, val_d, val_pix, val_z_err, val_joint_top1).
    """
    model.eval()
    ce_tot = pix_tot = z_tot = hit_tot = n_valid = 0.0
    d_tot = 0.0
    n_samples = 0
    with torch.no_grad():
        for batch in val_loader:
            rgb = batch["rgb"].to(device)
            gt_pix = batch["gt_pix_504"].to(device)
            gt_z_bin = batch["gt_z_bin"].to(device)
            valid = batch["gt_pix_valid"].to(device)
            gt_depth = batch["da3_depth"].to(device)
            out = model(rgb)
            vol = out["volume_logits"]
            gx, gy, gz, tgt_flat = _volume_targets(vol, gt_pix, gt_z_bin)
            ce_sum, nv = _masked_volume_ce(vol, tgt_flat, valid)
            pix, zerr, hits = _masked_metric_sums(vol, gt_pix, gx, gy, gz, valid)
            ce_tot += ce_sum.item()
            n_valid += nv.item()
            pix_tot += pix.item()
            z_tot += zerr.item()
            hit_tot += hits.item()
            bsz = rgb.shape[0]
            d_tot += _depth_distill_loss(out["pred_depth"], gt_depth, depth_loss_type).item() * bsz
            n_samples += bsz
    denom = max(n_valid, 1.0)
    return (ce_tot / denom, d_tot / max(n_samples, 1),
            pix_tot / denom, z_tot / denom, hit_tot / denom)


def main():
    """Train the model and checkpoint it.

    Parses flags, builds data and model, and trains with (optionally
    EMA-balanced) volume-CE + depth losses. Validates each epoch and writes
    checkpoints to checkpoints/<run_name>/.
    """
    args = _parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    script_dir = Path(__file__).parent
    ckpt_dir = script_dir / "checkpoints" / args.run_name
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    wandb.init(project=args.wandb_project, name=args.run_name, mode=args.wandb_mode, config=vars(args))

    print("Loading dataset...")
    train_ds, val_ds = _build_datasets(args)

    pin = device.type == "cuda"
    persistent = args.num_workers > 0
    # Only drop the ragged last batch when at least one full batch exists;
    # otherwise tiny (overfit) train sets would yield zero steps per epoch.
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.num_workers, pin_memory=pin,
                              drop_last=len(train_ds) >= args.batch_size,
                              persistent_workers=persistent)
    val_loader   = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False,
                              num_workers=args.num_workers, pin_memory=pin,
                              persistent_workers=persistent)

    print("Building model...")
    model = DA3VolumeModel(weights_path=args.weights_path).to(device)
    trainable = [q for q in model.parameters() if q.requires_grad]
    n_t = sum(q.numel() for q in trainable)
    print(f"Trainable: {n_t:,}")
    wandb.config.update({"trainable_params": n_t})

    opt = optim.AdamW(trainable, lr=args.lr, weight_decay=1e-4)
    ema = {"volume": 1.0, "depth": 1.0}
    mom = args.ema_momentum

    best_val = float("inf")
    global_step = 0
    t_start = time.time()
    for epoch in range(args.epochs):
        model.train()
        pbar = tqdm(train_loader, desc=f"Epoch {epoch}", leave=False)
        for batch in pbar:
            rgb      = batch["rgb"].to(device, non_blocking=True)
            gt_pix   = batch["gt_pix_504"].to(device, non_blocking=True)
            gt_z_bin = batch["gt_z_bin"].to(device, non_blocking=True)
            valid    = batch["gt_pix_valid"].to(device, non_blocking=True)
            gt_depth = batch["da3_depth"].to(device, non_blocking=True)

            out = model(rgb)
            vol = out["volume_logits"]              # (B, T, Z, h_out, w_out)

            # Volume CE per timestep over flat (Z x h x w), averaged over valid steps.
            gx, gy, gz, tgt_flat = _volume_targets(vol, gt_pix, gt_z_bin)
            ce_sum, n_valid = _masked_volume_ce(vol, tgt_flat, valid)
            denom = n_valid.clamp_min(1.0)
            volume_loss = ce_sum / denom
            depth_loss = _depth_distill_loss(out["pred_depth"], gt_depth, args.depth_loss_type)

            # Running EMA of each raw loss (updated before use). Dividing by it
            # rescales both terms to ~1 so neither dominates the gradient.
            ema["volume"] = mom * ema["volume"] + (1 - mom) * float(volume_loss.detach())
            ema["depth"]  = mom * ema["depth"]  + (1 - mom) * float(depth_loss.detach())
            if args.use_ema_loss_balance:
                scaled_v = volume_loss / (ema["volume"] + 1e-8)
                scaled_d = depth_loss  / (ema["depth"]  + 1e-8)
                total = scaled_v + args.depth_loss_weight * scaled_d
            else:
                total = volume_loss + args.depth_loss_weight * depth_loss

            opt.zero_grad()
            total.backward()
            if args.grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            opt.step()
            pbar.set_postfix(v=f"{volume_loss.item():.3f}", d=f"{depth_loss.item():.3f}",
                             w_v=f"{1/ema['volume']:.2g}", w_d=f"{1/ema['depth']:.2g}")
            global_step += 1

            if args.log_scalars_every > 0 and global_step % args.log_scalars_every == 0:
                with torch.no_grad():
                    pix, zerr, hits = _masked_metric_sums(vol, gt_pix, gx, gy, gz, valid)
                wandb.log({"train/total": total.item(),
                           "train/volume_loss": volume_loss.item(),
                           "train/depth_loss": depth_loss.item(),
                           "train/pix_err_504": (pix / denom).item(),
                           "train/z_bin_err": (zerr / denom).item(),
                           "train/joint_top1_acc": (hits / denom).item(),
                           "train/ema_v": ema["volume"], "train/ema_d": ema["depth"],
                           "epoch": epoch}, step=global_step)

            if args.vis_every_steps > 0 and global_step % args.vis_every_steps == 0:
                _log_visualizations(model, train_ds, device, global_step)

        # End-of-epoch validation
        v_v, v_d, v_pe, v_ze, v_ja = _validate(model, val_loader, device, args.depth_loss_type)
        print(f"Epoch {epoch}: val_v={v_v:.4f}  val_d={v_d:.4f}  val_pix={v_pe:.1f}px  "
              f"val_z_err={v_ze:.2f}bins  val_joint_top1={v_ja:.3f}")
        wandb.log({"epoch_end/val_v": v_v, "epoch_end/val_d": v_d, "epoch_end/val_pix": v_pe,
                   "epoch_end/val_z_err": v_ze, "epoch_end/val_joint_top1": v_ja,
                   "epoch": epoch, "elapsed_min": (time.time() - t_start) / 60.0}, step=global_step)

        # latest.pth every --save_every_epochs (and on the final epoch). best.pth on
        # every improvement of val volume CE, so it always holds the best epoch seen.
        is_save_epoch = ((epoch + 1) % max(1, args.save_every_epochs) == 0) or (epoch + 1 == args.epochs)
        improved = v_v < best_val
        if improved:
            best_val = v_v
        if is_save_epoch or improved:
            ckpt = {"epoch": epoch, "global_step": global_step,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": opt.state_dict(),
                    "val_v": v_v, "args": vars(args)}
            if is_save_epoch:
                torch.save(ckpt, ckpt_dir / "latest.pth")
            if improved:
                torch.save(ckpt, ckpt_dir / "best.pth")
                print(f"  ✓ best (val_v={v_v:.4f})")

    wandb.finish()
    print(f"Done. Checkpoints: {ckpt_dir}")


# Script entry point: run training only when executed directly, not when imported.
if __name__ == "__main__":
    main()
