"""Train DA3PixelModel on smith300.

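Losses:
  - heatmap CE (per timestep, over 288×288 flat; timesteps with invalid GT are masked out)
  - depth distillation vs precomputed frozen-DA3 depth: L1 on raw depth or on log-depth (--depth_loss_type)
  EMA loss balancing (--use_ema_loss_balance): each loss is divided by its running EMA so they
  contribute roughly equally; the depth term is additionally scaled by --depth_loss_weight.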

Wandb visualizations (every --vis_every_steps):
  1. Rainbow predicted keypoints (8 timesteps, polyline + numbered crosshairs) over GT (white)
  2. Per-timestep heatmap strip (softmax → colormap)
  3. Pred depth vs GT depth side-by-side (color-normalized)
  4. DINO feature PCA (RGB from 3 principal components)
"""
import argparse, os, sys, time
from pathlib import Path

import cv2
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import wandb

sys.path.insert(0, os.path.dirname(__file__))
sys.path.insert(0, "/data/cameron/da3_repo/src")

from data_da3_pixel import Smith300DA3Dataset, DA3_INPUT, N_WINDOW
from model_da3_pixel import DA3PixelModel


def rainbow_bgr(t, T, max_hue_deg=270.0):
    """Return a BGR colour (tuple of Python ints) for timestep ``t`` of ``T``.

    Hue sweeps linearly from red (0 deg) at ``t = 0`` to ``max_hue_deg`` at
    ``t = T - 1``. The sweep deliberately stops short of 360 deg: a full turn
    maps the last timestep back to red, making it indistinguishable from the
    first one (OpenCV's 8-bit hue 180 is exactly that wrap-around).

    Args:
        t: timestep index, expected in [0, T - 1].
        T: number of timesteps; T <= 1 is treated as a single step.
        max_hue_deg: hue in degrees assigned to the last timestep.

    Returns:
        (b, g, r) tuple of ints in [0, 255], usable directly as an OpenCV colour.
    """
    import colorsys  # stdlib; local so the module-level import block is untouched

    frac = t / max(T - 1, 1)
    r, g, b = colorsys.hsv_to_rgb((max_hue_deg / 360.0) * frac, 1.0, 1.0)
    return tuple(int(round(c * 255)) for c in (b, g, r))


def heatmap_strip(heat_logits, H_out):
    """Render per-timestep heatmap logits as a horizontal strip of INFERNO tiles.

    Args:
        heat_logits: (T, h, w) tensor of raw (pre-softmax) logits, one map per
            timestep. May live on any device and may require grad.
        H_out: height in pixels of every tile. Tile width is H_out * w / h, so
            non-square heatmaps keep their aspect ratio (square maps give
            H_out x H_out tiles).

    Returns:
        uint8 BGR image of shape (H_out, T * tile_width, 3); each tile is labelled
        with its timestep index in that timestep's rainbow colour.
    """
    # detach(): calling .numpy() on a tensor that requires grad raises.
    # float(): run the softmax in fp32 even if the model produced fp16/bf16 logits.
    heat = heat_logits.detach().float()
    T, h, w = heat.shape
    tile_w = max(1, int(round(H_out * w / h)))
    # Softmax over each timestep's whole spatial grid (same as one softmax per frame).
    probs = torch.softmax(heat.reshape(T, -1), dim=-1).reshape(T, h, w).cpu().numpy()
    tiles = []
    for t, prob in enumerate(probs):
        # Normalize per-frame for visibility (flat, low-confidence maps still show structure).
        prob = (prob - prob.min()) / (prob.max() - prob.min() + 1e-8)
        col = cv2.applyColorMap((prob * 255).astype(np.uint8), cv2.COLORMAP_INFERNO)
        col = cv2.resize(col, (tile_w, H_out), interpolation=cv2.INTER_NEAREST)
        cv2.putText(col, str(t), (4, 16), cv2.FONT_HERSHEY_SIMPLEX, 0.5,
                    rainbow_bgr(t, T), 2, cv2.LINE_AA)
        tiles.append(col)
    return np.concatenate(tiles, axis=1)


def rainbow_keypoints_overlay(rgb_504, pred_pix_504, gt_pix_504, gt_valid):
    """Draw predicted (rainbow) and GT (white) keypoint trajectories on an image.

    Args:
        rgb_504: (3, H, W) tensor, expected in [0, 1]. Values outside that range
            are clipped instead of wrapping around in the uint8 cast.
        pred_pix_504: (T, 2) predicted (x, y) pixels in image space.
        gt_pix_504: (T, 2) GT (x, y) pixels in image space.
        gt_valid: (T,) truthy mask; invalid GT timesteps are not drawn.

    Returns:
        (H, W, 3) uint8 RGB image.
    """
    def pt(xy):
        # OpenCV point arguments must be plain Python ints (truncated, as before).
        return int(xy[0]), int(xy[1])

    arr = rgb_504.detach().cpu().numpy().transpose(1, 2, 0)
    img = np.ascontiguousarray(np.clip(arr * 255.0, 0, 255).astype(np.uint8))
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    T = pred_pix_504.shape[0]

    # GT: hollow circles at valid points, joined where both neighbours are valid.
    for i in range(T):
        if gt_valid[i]:
            cv2.circle(img, pt(gt_pix_504[i]), 5, (240, 240, 240), 1, cv2.LINE_AA)
    for i in range(1, T):
        if gt_valid[i] and gt_valid[i - 1]:
            cv2.line(img, pt(gt_pix_504[i - 1]), pt(gt_pix_504[i]),
                     (220, 220, 220), 2, cv2.LINE_AA)

    # Prediction: rainbow polyline, then numbered crosshairs drawn on top of it.
    for i in range(1, T):
        cv2.line(img, pt(pred_pix_504[i - 1]), pt(pred_pix_504[i]),
                 rainbow_bgr(i, T), 2, cv2.LINE_AA)
    for i in range(T):
        c = rainbow_bgr(i, T)
        x, y = pt(pred_pix_504[i])
        cv2.drawMarker(img, (x, y), c, cv2.MARKER_CROSS, 14, 2, cv2.LINE_AA)
        cv2.putText(img, str(i), (x + 6, y - 6),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.45, c, 1, cv2.LINE_AA)
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)


def depth_strip(pred_depth, gt_depth, shared_range=False):
    """Colour-map predicted and GT depth and place them side by side (pred | GT).

    Args:
        pred_depth: (H, W) predicted depth tensor (any device/dtype, may require grad).
        gt_depth: (H, W) reference depth tensor of the same shape.
        shared_range: if False (default), each map is normalized by its own
            2nd-98th percentile, which maximises contrast but hides a global
            scale/offset error in the prediction. If True, both maps use the GT's
            2nd-98th percentile so absolute differences become visible.

    Returns:
        uint8 BGR image of shape (H, 2W, 3) using the VIRIDIS colormap.
    """
    # detach()/float(): .numpy() fails on grad-requiring or bf16 tensors.
    pred = pred_depth.detach().float().cpu().numpy()
    gt = gt_depth.detach().float().cpu().numpy()

    def colorize(d, lo, hi):
        d = np.clip((d - lo) / (hi - lo + 1e-8), 0, 1)
        return cv2.applyColorMap((d * 255).astype(np.uint8), cv2.COLORMAP_VIRIDIS)

    gt_lo, gt_hi = np.percentile(gt, [2, 98])
    if shared_range:
        pred_lo, pred_hi = gt_lo, gt_hi
    else:
        pred_lo, pred_hi = np.percentile(pred, [2, 98])
    return np.concatenate([colorize(pred, pred_lo, pred_hi), colorize(gt, gt_lo, gt_hi)], axis=1)


def dino_pca(dino_feats_layer, B_idx=0, out_size=288):
    """Visualise one image's patch tokens as RGB via their top-3 principal components.

    Args:
        dino_feats_layer: token tensor of shape (B, T, C) or (B, S, T, C).
        B_idx: index into the FIRST (batch) dim; falls back to 0 if out of range.
            Any further leading dims (e.g. S) are indexed at 0.
        out_size: side length of the square output image.

    The last n*n tokens (n = isqrt(T)) are treated as the square patch grid, which
    drops prefix tokens (cls/register) as long as there are fewer than 2n+1 of them.
    NOTE(review): assumes a square patch grid; non-square inputs are not handled.

    Returns:
        (out_size, out_size, 3) uint8 RGB image (nearest-neighbour upsampled).

    Raises:
        ValueError: if the feature tensor holds no tokens.
    """
    import math  # stdlib; local so the module-level import block is untouched

    f = dino_feats_layer.detach()
    if f.dim() > 2:
        f = f[B_idx if f.shape[0] > B_idx else 0]      # batch dim
    while f.dim() > 2:
        f = f[0]                                       # remaining leading dims (e.g. views)
    tokens = f.float().cpu().numpy()                   # (T, C)

    n = math.isqrt(tokens.shape[0])
    if n == 0:
        raise ValueError("dino_pca: feature tensor has no tokens")
    grid = tokens[-(n * n):]
    grid = grid - grid.mean(axis=0, keepdims=True)
    # Right singular vectors of the centred tokens are the principal axes.
    _, _, vt = np.linalg.svd(grid, full_matrices=False)
    pcs = grid @ vt[:3].T                              # (n*n, k), k <= 3
    if pcs.shape[1] < 3:
        # Fewer than 3 components available (tiny grid / channel count): pad with zeros.
        pcs = np.pad(pcs, ((0, 0), (0, 3 - pcs.shape[1])))
    lo, hi = pcs.min(axis=0), pcs.max(axis=0)
    pcs = (pcs - lo) / (hi - lo + 1e-8)
    img = (pcs * 255).astype(np.uint8).reshape(n, n, 3)
    return cv2.resize(img, (out_size, out_size), interpolation=cv2.INTER_NEAREST)


def _heatmap_ce_sum(pred_h, gt_pix, valid, input_size):
    """Masked per-timestep heatmap cross-entropy, summed over the batch.

    Args:
        pred_h: (B, N, h, w) heatmap logits.
        gt_pix: (B, N, 2) GT (x, y) pixels in ``input_size``-pixel space.
        valid: (B, N) mask/weights; zero entries are excluded.
        input_size: side length of the square input image the pixels refer to.

    GT pixels are floored onto the h x w grid (and clamped to it); CE is taken over
    the flattened grid. Returns (sum of weighted CE, sum of weights) as tensors.
    """
    h_out, w_out = pred_h.shape[-2:]
    gx = (gt_pix[..., 0] * (w_out / input_size)).long().clamp(0, w_out - 1)
    gy = (gt_pix[..., 1] * (h_out / input_size)).long().clamp(0, h_out - 1)
    tgt_flat = (gy * w_out + gx).reshape(-1)                                  # (B*N,)
    ce = F.cross_entropy(pred_h.reshape(tgt_flat.shape[0], -1), tgt_flat, reduction="none")
    mask = valid.float().reshape(-1)
    return (ce * mask).sum(), mask.sum()


def _depth_loss(pred_d, gt_depth, loss_type, eps=1e-6):
    """Mean depth distillation loss: L1 on raw depth ("l1") or on log-depth ("log_l1")."""
    if loss_type == "l1":
        return F.l1_loss(pred_d, gt_depth)
    # log_l1 — matches DA3's exp activation; same as L1 on pre-exp logits.
    return F.l1_loss(torch.log(pred_d.clamp(min=eps)), torch.log(gt_depth.clamp(min=eps)))


def _decode_heatmap_argmax(pred_h, input_size):
    """Argmax of each (h, w) heatmap -> (x, y) pixel in ``input_size`` space, shape (..., 2).

    Decodes to the CENTRE of the winning cell, (idx + 0.5) / scale. GT targets are
    floor(gt * scale), so the cell centre is the midpoint estimate; decoding to the
    cell's top-left corner biases every pixel error by half a cell.
    """
    h_out, w_out = pred_h.shape[-2:]
    flat = pred_h.reshape(*pred_h.shape[:-2], -1).argmax(dim=-1)
    py = ((flat // w_out).float() + 0.5) * (input_size / h_out)
    px = ((flat % w_out).float() + 0.5) * (input_size / w_out)
    return torch.stack([px, py], dim=-1)


def _masked_pix_err_sum(pred_pix, gt_pix, valid):
    """Sum of weighted Euclidean pixel errors; invalid steps contribute exactly 0.

    torch.where (rather than multiplying by the mask) keeps NaN/inf GT coordinates
    on invalid steps from poisoning the sum.
    """
    weights = valid.float()
    err = (pred_pix - gt_pix).norm(dim=-1)
    return torch.where(weights > 0, err * weights, torch.zeros_like(err)).sum()


def main():
    """CLI entry point: parse args, build data + model, train, validate, checkpoint.

    Each epoch trains over the train split. Scalars are logged every
    --log_scalars_every steps and images every --vis_every_steps steps. The epoch
    then runs validation and, on save epochs, writes latest.pth (and best.pth when
    the validation heatmap loss improves).
    """
    p = argparse.ArgumentParser()
    p.add_argument("--root_dir", type=str, default="/data/cameron/mac_robot_datasets/first_mobile_collection")
    p.add_argument("--frame_stride", type=int, default=1)
    p.add_argument("--batch_size", type=int, default=4)
    p.add_argument("--lr", type=float, default=5e-5)
    p.add_argument("--epochs", type=int, default=50)
    p.add_argument("--val_split", type=float, default=0.05)
    p.add_argument("--vis_every_steps", type=int, default=20)
    p.add_argument("--log_scalars_every", type=int, default=5)
    p.add_argument("--num_workers", type=int, default=0,
                   help="DataLoader workers. 0 = main process (cheap, no copy-on-fork). "
                        "With in-memory tensors, workers share memory via CoW on Linux fork.")
    p.add_argument("--overfit_episode", type=int, default=-1,
                   help="If >=0, restrict to one episode's samples for an overfit sanity check.")
    p.add_argument("--overfit_sample", type=int, default=-1,
                   help="If >=0, restrict to ONE specific sample index. Takes precedence over --overfit_episode.")
    p.add_argument("--use_ema_loss_balance", type=int, default=1,
                   help="If 0, just sum the losses (with depth weight) without EMA scaling.")
    p.add_argument("--grad_clip", type=float, default=0.0,
                   help="If >0, clip gradient norm to this value before optimizer.step().")
    p.add_argument("--ema_momentum", type=float, default=0.99)
    p.add_argument("--depth_loss_weight", type=float, default=1.0,
                   help="Multiplier on the depth distillation loss in the total. Set 0 to ablate.")
    p.add_argument("--depth_loss_type", type=str, default="log_l1", choices=["l1", "log_l1"],
                   help="l1 = L1 on raw depth (penalizes near/far equally in metres). "
                        "log_l1 = L1 in log-depth space (matches DA3's exp activation; "
                        "equivalent to L1 on the head's pre-exp raw logits).")
    p.add_argument("--run_name", type=str, default="da3_pixel_smith300_v0")
    p.add_argument("--wandb_project", type=str, default="para_libero")
    p.add_argument("--wandb_mode", type=str, default="online")
    p.add_argument("--weights_path", type=str, default="/data/cameron/da3_weights",
                   help="DA3 weights dir (SMALL=/data/cameron/da3_weights, LARGE=/data/cameron/da3_large_weights)")
    p.add_argument("--depth_subdir", type=str, default="da3_depth",
                   help="Sub-dir name to load cached depth from (da3_depth or da3_depth_large)")
    p.add_argument("--save_every_epochs", type=int, default=1,
                   help="Save latest.pth every N epochs (best.pth still on improvement). For LARGE/GIANT overfit runs, set ~50.")
    args = p.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    script_dir = Path(__file__).parent
    ckpt_dir = script_dir / "checkpoints" / args.run_name
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    wandb.init(project=args.wandb_project, name=args.run_name, mode=args.wandb_mode, config=vars(args))

    print("Loading dataset...")
    full = Smith300DA3Dataset(root_dir=args.root_dir, image_size=DA3_INPUT,
                               n_window=N_WINDOW, frame_stride=args.frame_stride,
                               depth_subdir=args.depth_subdir)
    if args.overfit_sample >= 0:
        keep = [args.overfit_sample]
        full = torch.utils.data.Subset(full, keep)
        train_ds = val_ds = full
        print(f"  OVERFIT mode — single sample {args.overfit_sample}")
    elif args.overfit_episode >= 0:
        keep = [i for i, (ep, _t) in enumerate(full.samples) if ep == args.overfit_episode]
        if not keep:
            raise SystemExit(f"No samples for overfit_episode={args.overfit_episode}")
        full = torch.utils.data.Subset(full, keep)
        train_ds = val_ds = full
        print(f"  OVERFIT mode — episode {args.overfit_episode}: {len(keep)} samples")
    else:
        n = len(full); n_val = max(1, int(n * args.val_split)); n_tr = n - n_val
        train_ds, val_ds = torch.utils.data.random_split(
            full, [n_tr, n_val], generator=torch.Generator().manual_seed(42))
        print(f"  Train: {n_tr}  Val: {n_val}")

    pin = device.type == "cuda"   # pinned host memory only helps host->GPU copies
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.num_workers, pin_memory=pin,
                              # drop_last on a dataset smaller than one batch yields ZERO
                              # batches (e.g. --overfit_sample with batch_size > 1).
                              drop_last=len(train_ds) >= args.batch_size,
                              persistent_workers=args.num_workers > 0)
    val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False,
                            num_workers=args.num_workers, pin_memory=pin,
                            persistent_workers=args.num_workers > 0)

    print("Building model...")
    model = DA3PixelModel(weights_path=args.weights_path).to(device)
    trainable = [q for q in model.parameters() if q.requires_grad]
    n_t = sum(q.numel() for q in trainable)
    print(f"Trainable: {n_t:,}")
    wandb.config.update({"trainable_params": n_t})

    opt = optim.AdamW(trainable, lr=args.lr, weight_decay=1e-4)

    # EMA loss tracking, seeded from the first batch's losses (None until then). A fixed
    # 1.0 start leaves the two terms unbalanced for hundreds of steps (at momentum 0.99)
    # whenever their natural scales differ from 1.
    ema = {"heatmap": None, "depth": None}
    mom = args.ema_momentum

    best_val = float("inf"); global_step = 0; t_start = time.time()
    for epoch in range(args.epochs):
        model.train()
        pbar = tqdm(train_loader, desc=f"Epoch {epoch}", leave=False)
        for batch in pbar:
            rgb      = batch["rgb"].to(device, non_blocking=True)
            gt_pix   = batch["gt_pix_504"].to(device, non_blocking=True)
            valid    = batch["gt_pix_valid"].to(device, non_blocking=True)
            gt_depth = batch["da3_depth"].to(device, non_blocking=True)

            out = model(rgb)
            pred_h = out["pred_heatmap"]        # (B, N_WINDOW, h_out, w_out)
            pred_d = out["pred_depth"]          # (B, H_in, W_in)

            ce_sum, n_valid = _heatmap_ce_sum(pred_h, gt_pix, valid, DA3_INPUT)
            denom = n_valid.clamp_min(1.0)
            heatmap_loss = ce_sum / denom
            depth_loss = _depth_loss(pred_d, gt_depth, args.depth_loss_type)

            # EMA balancing
            h_val, d_val = float(heatmap_loss.detach()), float(depth_loss.detach())
            if ema["heatmap"] is None:
                ema["heatmap"], ema["depth"] = h_val, d_val
            else:
                ema["heatmap"] = mom * ema["heatmap"] + (1 - mom) * h_val
                ema["depth"] = mom * ema["depth"] + (1 - mom) * d_val
            if args.use_ema_loss_balance:
                scaled_h = heatmap_loss / (ema["heatmap"] + 1e-8)
                scaled_d = depth_loss / (ema["depth"] + 1e-8)
                total = scaled_h + args.depth_loss_weight * scaled_d
            else:
                total = heatmap_loss + args.depth_loss_weight * depth_loss

            opt.zero_grad(); total.backward()
            if args.grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            opt.step()
            # +1e-8 matches the balancing denominator and avoids ZeroDivisionError.
            pbar.set_postfix(h=f"{heatmap_loss.item():.3f}", d=f"{depth_loss.item():.3f}",
                             w_h=f"{1 / (ema['heatmap'] + 1e-8):.2g}",
                             w_d=f"{1 / (ema['depth'] + 1e-8):.2g}")
            global_step += 1

            if args.log_scalars_every > 0 and global_step % args.log_scalars_every == 0:
                # Per-step pixel error: argmax of pred -> cell centre in 504-space -> distance to GT
                with torch.no_grad():
                    pred_pix = _decode_heatmap_argmax(pred_h, DA3_INPUT)           # (B, N_WINDOW, 2)
                    pix_err = _masked_pix_err_sum(pred_pix, gt_pix, valid) / denom
                wandb.log({"train/total": total.item(),
                           "train/heatmap_loss": heatmap_loss.item(),
                           "train/depth_loss": depth_loss.item(),
                           "train/pix_err_504": pix_err.item(),
                           "train/ema_h": ema["heatmap"], "train/ema_d": ema["depth"],
                           "epoch": epoch}, step=global_step)

            if args.vis_every_steps > 0 and global_step % args.vis_every_steps == 0:
                # Random TRAIN sample each viz call so we see variety
                model.eval()
                with torch.no_grad():
                    rand_idx = int(torch.randint(0, len(train_ds), (1,)).item())
                    sample = train_ds[rand_idx]
                    v_out = model(sample["rgb"].unsqueeze(0).to(device))
                    v_pred_h = v_out["pred_heatmap"][0]                          # (N_WINDOW, h, w)
                    v_pred_d = v_out["pred_depth"][0]
                    v_pred_pix = _decode_heatmap_argmax(v_pred_h, DA3_INPUT).cpu().numpy()
                # (1) Rainbow keypoints
                viz_kp = rainbow_keypoints_overlay(sample["rgb"], v_pred_pix,
                                                   sample["gt_pix_504"].cpu().numpy(),
                                                   sample["gt_pix_valid"].cpu().numpy())
                # (2) Heatmap strip
                viz_hm = heatmap_strip(v_pred_h, H_out=144)
                # (3) Depth pred vs GT
                viz_d = depth_strip(v_pred_d, sample["da3_depth"])
                # (4) DINO PCA — pick the deepest layer
                viz_pca = dino_pca(v_out["dino_feats"][-1])
                wandb.log({
                    "vis/keypoints": wandb.Image(viz_kp),
                    "vis/heatmap_strip": wandb.Image(viz_hm),
                    "vis/depth_pred_vs_gt": wandb.Image(viz_d),
                    "vis/dino_pca": wandb.Image(viz_pca),
                }, step=global_step)
                model.train()

        # End-of-epoch validation, pooled over the whole split (not a mean of
        # per-batch means, which over-weights a short final batch).
        model.eval()
        h_sum = h_cnt = d_sum = pix_sum = 0.0
        n_samples = 0
        with torch.no_grad():
            for batch in val_loader:
                rgb = batch["rgb"].to(device); gt_pix = batch["gt_pix_504"].to(device)
                valid = batch["gt_pix_valid"].to(device); gt_depth = batch["da3_depth"].to(device)
                out = model(rgb); pred_h = out["pred_heatmap"]; pred_d = out["pred_depth"]
                b = rgb.shape[0]
                ce_sum, n_valid = _heatmap_ce_sum(pred_h, gt_pix, valid, DA3_INPUT)
                h_sum += ce_sum.item(); h_cnt += n_valid.item()
                # Depth loss is a per-batch mean; weight it by batch size.
                d_sum += _depth_loss(pred_d, gt_depth, args.depth_loss_type).item() * b
                n_samples += b
                pred_pix = _decode_heatmap_argmax(pred_h, DA3_INPUT)
                pix_sum += _masked_pix_err_sum(pred_pix, gt_pix, valid).item()

        v_h = h_sum / max(h_cnt, 1.0)
        v_d = d_sum / max(n_samples, 1)
        v_pe = pix_sum / max(h_cnt, 1.0)
        print(f"Epoch {epoch}: val_h={v_h:.4f}  val_d={v_d:.4f}  val_pix={v_pe:.1f}px")
        wandb.log({"epoch_end/val_h": v_h, "epoch_end/val_d": v_d, "epoch_end/val_pix": v_pe,
                   "epoch": epoch, "elapsed_min": (time.time() - t_start) / 60.0}, step=global_step)

        # Save periodically (default every epoch). best.pth saved when val improves AND
        # this is a save epoch — avoids saving a 4 GB best on every epoch with LARGE.
        # best_val only tracks SAVED epochs: updating it on unsaved epochs could block
        # every later save epoch from ever writing best.pth.
        is_save_epoch = ((epoch + 1) % max(1, args.save_every_epochs) == 0) or (epoch + 1 == args.epochs)
        if is_save_epoch:
            ckpt = {"epoch": epoch, "global_step": global_step,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": opt.state_dict(),
                    "val_h": v_h, "args": vars(args)}
            torch.save(ckpt, ckpt_dir / "latest.pth")
            if v_h < best_val:
                best_val = v_h
                torch.save(ckpt, ckpt_dir / "best.pth")
                print(f"  ✓ best (val_h={v_h:.4f})")

    wandb.finish()
    print(f"Done. Checkpoints: {ckpt_dir}")


if __name__ == "__main__":
    # Script entry point; main() returns None, so the process exits with status 0.
    raise SystemExit(main())
