"""Train DinoHeatmapDiffusion on smith300.

Loss: MSE between predicted ε and sampled ε (standard DDPM).

Tracked metrics:
  train/mse_loss          — DDPM MSE loss
  train/pix_err_argmax    — argmax of x0_hat (decoded from current sample) vs GT, in 504-space.
    Computed via a single x_T → x_0 path with a small number of sampling steps, periodically
    (every --sample_every_steps). Expensive, so don't run every batch.

Headline accuracy (per Cameron): full-train-set argmax pix err. Logged epoch-end.
"""
import argparse, os, sys, time
from pathlib import Path

import cv2
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import wandb

sys.path.insert(0, os.path.dirname(__file__))
sys.path.insert(0, "/data/cameron/keygrip/dinov3")
sys.path.insert(0, "/data/cameron/da3_repo/src")

from data_da3_volume import Smith300DA3VolumeDataset, DA3_INPUT, N_WINDOW
from model_dino_diffusion import (DinoHeatmapDiffusion, make_gaussian_heatmap,
                                   HEATMAP_RES, GAUSSIAN_SIGMA)


def rainbow_bgr(t, T, hue_span_deg=300.0):
    """Return a per-timestep rainbow colour as an OpenCV-style BGR tuple.

    The hue sweeps linearly from 0° (red) at ``t == 0`` to ``hue_span_deg`` at
    ``t == T - 1``, at full saturation and value. The span is kept below 360° on
    purpose: a full-circle sweep (OpenCV 8-bit hue 0..180) wraps the last step
    back onto red, which makes the first and last timesteps indistinguishable.

    Args:
        t: timestep index (normally 0..T-1).
        T: number of timesteps; T <= 1 maps every ``t`` to red.
        hue_span_deg: hue (degrees) assigned to the last timestep.

    Returns:
        (b, g, r) tuple of Python ints in [0, 255], usable directly as a cv2 colour.
    """
    import colorsys  # local import: only this helper needs it

    frac = t / max(T - 1, 1)
    hue = (hue_span_deg * frac / 360.0) % 1.0  # colorsys expects hue in [0, 1)
    r, g, b = colorsys.hsv_to_rgb(hue, 1.0, 1.0)
    return (int(round(b * 255)), int(round(g * 255)), int(round(r * 255)))


def rainbow_keypoints_overlay(rgb_chw, pred_pix, gt_pix, valid):
    """Draw GT and predicted keypoint trajectories over an image.

    GT points (only those whose ``valid`` entry is truthy) are drawn in white as a
    polyline with cross markers. Predicted points are drawn as a polyline coloured
    per timestep via ``rainbow_bgr``, with tilted-cross markers and index labels.

    Args:
        rgb_chw: (3, H, W) tensor; assumed RGB with values in [0, 1] (it is scaled
            by 255) — TODO confirm against the dataset.
        pred_pix: (T, 2) array of predicted (x, y) pixel coordinates.
        gt_pix: (T, 2) array of ground-truth (x, y) pixel coordinates.
        valid: length-T mask; GT points with a falsy entry are skipped.

    Returns:
        (H, W, 3) uint8 RGB image.
    """
    img = rgb_chw.detach().cpu().permute(1, 2, 0).numpy()
    # Clip before the cast so out-of-range values saturate instead of wrapping;
    # make contiguous because cv2 drawing needs a plain C-ordered buffer.
    img = np.ascontiguousarray(np.clip(img * 255, 0, 255).astype(np.uint8))
    # cv2 colours (rainbow_bgr, white) are BGR, so draw on a BGR canvas and convert
    # back once at the end. Converting only at the end would swap R/B of the photo.
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    T = pred_pix.shape[0]
    white = (255, 255, 255)

    pts_gt = [tuple(int(c) for c in p) for p, ok in zip(gt_pix, valid) if ok]
    for a, b in zip(pts_gt, pts_gt[1:]):
        cv2.line(img, a, b, white, 1)
    for p in pts_gt:
        cv2.drawMarker(img, p, white, cv2.MARKER_CROSS, 6, 1)

    pts_pred = [tuple(int(c) for c in p) for p in pred_pix]
    for i in range(T - 1):
        cv2.line(img, pts_pred[i], pts_pred[i + 1], rainbow_bgr(i, T), 2)
    for i, p in enumerate(pts_pred):
        col = rainbow_bgr(i, T)
        cv2.drawMarker(img, p, col, cv2.MARKER_TILTED_CROSS, 10, 2)
        cv2.putText(img, str(i), (p[0] + 5, p[1] - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.4, col, 1)
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)


def heatmap_strip(heat, H_out=128):
    """Render a (T, H, W) heatmap stack as a single horizontal RGB strip.

    Every frame is min-max normalised independently (tiles are therefore not
    comparable in absolute scale), colour-mapped with INFERNO, resized to
    H_out x H_out with nearest-neighbour, and stamped with its index.

    Returns:
        (H_out, T * H_out, 3) uint8 RGB array.
    """
    def _render_tile(idx):
        frame = heat[idx].detach().cpu().numpy()
        lo, hi = frame.min(), frame.max()
        scaled = ((frame - lo) / (hi - lo + 1e-8) * 255).astype(np.uint8)
        tile = cv2.resize(cv2.applyColorMap(scaled, cv2.COLORMAP_INFERNO),
                          (H_out, H_out), interpolation=cv2.INTER_NEAREST)
        cv2.putText(tile, str(idx), (4, 16), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
        return tile

    # applyColorMap emits BGR; convert the assembled strip to RGB in one go.
    strip_bgr = np.concatenate([_render_tile(i) for i in range(heat.shape[0])], axis=1)
    return cv2.cvtColor(strip_bgr, cv2.COLOR_BGR2RGB)


def argmax_pix_from_heatmap(hm, image_size=DA3_INPUT):
    """Locate each frame's heatmap peak in image pixel coordinates.

    Args:
        hm: (B, T, H, W) tensor of heatmaps.
        image_size: side length of the target image space; row indices are scaled
            by image_size / H and column indices by image_size / W.

    Returns:
        (B, T, 2) float tensor of (x, y). The result is the integer cell index times
        the scale factor; no half-cell offset is added.
    """
    _, _, rows, cols = hm.shape
    peak = hm.flatten(start_dim=2).argmax(dim=-1)          # (B, T) flat index
    row_idx = torch.div(peak, cols, rounding_mode="floor")
    col_idx = peak - row_idx * cols
    x = col_idx.float() * (image_size / cols)
    y = row_idx.float() * (image_size / rows)
    return torch.stack((x, y), dim=-1)


def _parse_args():
    """Parse the command-line flags for a training run."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--root_dir", type=str, default="/data/cameron/mac_robot_datasets/first_mobile_collection")
    parser.add_argument("--depth_subdir", type=str, default="da3_depth_large")
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument("--val_split", type=float, default=0.05)
    parser.add_argument("--vis_every_steps", type=int, default=200)
    parser.add_argument("--log_scalars_every", type=int, default=10)
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--grad_clip", type=float, default=1.0)
    parser.add_argument("--sigma_px", type=float, default=GAUSSIAN_SIGMA)
    parser.add_argument("--T_diff", type=int, default=1000)
    parser.add_argument("--sample_steps", type=int, default=10)
    parser.add_argument("--freeze_backbone", type=int, default=1)
    parser.add_argument("--save_every_epochs", type=int, default=5)
    parser.add_argument("--run_name", type=str, default="da3_dino_diffusion_v0")
    parser.add_argument("--wandb_project", type=str, default="para_libero")
    parser.add_argument("--wandb_mode", type=str, default="online")
    return parser.parse_args()


@torch.no_grad()
def _eval_argmax_pix_err(model, loader, device, n_steps, desc=None):
    """Sample heatmaps for every batch of `loader` and measure argmax pixel error.

    Errors are summed over the whole loader and divided once at the end. Every
    valid keypoint therefore carries equal weight, independent of batch size or
    how many keypoints a batch happens to mask out.

    Returns:
        (err_all, err_full): mean error over all valid keypoints, and mean error
        over every timestep of samples whose whole window is valid. Either value
        is NaN when there is nothing to average.
    """
    err_sum, n_valid = 0.0, 0.0
    full_sum, n_full = 0.0, 0
    batches = tqdm(loader, desc=desc, leave=False) if desc is not None else loader
    for batch in batches:
        rgb = batch["rgb"].to(device)
        gt_pix = batch["gt_pix_504"].to(device)
        valid = batch["gt_pix_valid"].to(device)
        samp = model.sample(rgb, n_steps=n_steps)                      # (B, T, H, W)
        pe = (argmax_pix_from_heatmap(samp) - gt_pix).norm(dim=-1)     # (B, T) px error
        m = valid.float()
        err_sum += (pe * m).sum().item()
        n_valid += m.sum().item()
        # "Full" = every timestep in the sample's window is valid (was a hard-coded == 8).
        full = valid.bool().all(dim=1)
        if full.any():
            full_pe = pe[full]
            full_sum += full_pe.sum().item()
            n_full += full_pe.numel()
    err_all = err_sum / n_valid if n_valid > 0 else float("nan")
    err_full = full_sum / n_full if n_full > 0 else float("nan")
    return err_all, err_full


def _log_random_sample_vis(model, dataset, device, n_steps, step):
    """Sample heatmaps for one random example of `dataset` and log images to wandb.

    Logs `vis/keypoints` (predicted argmax trajectory drawn over the GT trajectory)
    and `vis/sampled_heatmap` (per-timestep heatmap strip) at wandb step `step`.
    The caller is responsible for switching `model` between eval and train mode.
    """
    idx = int(torch.randint(0, len(dataset), (1,)).item())
    sample = dataset[idx]
    with torch.no_grad():
        samp = model.sample(sample["rgb"].unsqueeze(0).to(device), n_steps=n_steps)  # (1, T, H, W)
        pred_pix = argmax_pix_from_heatmap(samp).cpu().numpy()[0]
    viz_kp = rainbow_keypoints_overlay(sample["rgb"], pred_pix,
                                       sample["gt_pix_504"].cpu().numpy(),
                                       sample["gt_pix_valid"].cpu().numpy())
    viz_hm = heatmap_strip(samp[0], H_out=128)
    wandb.log({"vis/keypoints": wandb.Image(viz_kp),
               "vis/sampled_heatmap": wandb.Image(viz_hm)}, step=step)


def main():
    """Train DinoHeatmapDiffusion with a masked DDPM ε-MSE loss.

    After each epoch, heatmaps are sampled for the entire train split and the
    val split, and the argmax pixel errors are logged. `latest.pth` is written
    every --save_every_epochs epochs and on the final epoch. `best.pth` is written
    whenever the full-window train error improves.
    """
    args = _parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    script_dir = Path(__file__).parent
    ckpt_dir = script_dir / "checkpoints" / args.run_name
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    wandb.init(project=args.wandb_project, name=args.run_name, mode=args.wandb_mode, config=vars(args))

    print("Loading dataset…")
    full = Smith300DA3VolumeDataset(root_dir=args.root_dir, image_size=DA3_INPUT,
                                     n_window=N_WINDOW, depth_subdir=args.depth_subdir)
    n = len(full); n_val = max(1, int(n * args.val_split)); n_tr = n - n_val
    train_ds, val_ds = torch.utils.data.random_split(
        full, [n_tr, n_val], generator=torch.Generator().manual_seed(42))
    print(f"  Train: {n_tr}  Val: {n_val}")
    persistent = args.num_workers > 0
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.num_workers, pin_memory=True, drop_last=True,
                              persistent_workers=persistent)
    # The training loader drops its last partial batch. A separate ordered,
    # non-dropping loader is used so the "full train set" metric covers every sample.
    train_eval_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=False,
                                   num_workers=args.num_workers, pin_memory=True,
                                   persistent_workers=persistent)
    val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False,
                            num_workers=args.num_workers, pin_memory=True,
                            persistent_workers=persistent)

    print("Building model…")
    model = DinoHeatmapDiffusion(T_diff=args.T_diff, freeze_backbone=bool(args.freeze_backbone)).to(device)
    trainable = [q for q in model.parameters() if q.requires_grad]
    n_t = sum(q.numel() for q in trainable)
    print(f"Trainable: {n_t:,}")
    wandb.config.update({"trainable_params": n_t})
    opt = optim.AdamW(trainable, lr=args.lr, weight_decay=1e-4)

    best_train_pix = float("inf"); global_step = 0; t_start = time.time()
    for epoch in range(args.epochs):
        model.train()
        pbar = tqdm(train_loader, desc=f"Epoch {epoch}", leave=False)
        for batch in pbar:
            rgb     = batch["rgb"].to(device, non_blocking=True)
            gt_pix  = batch["gt_pix_504"].to(device, non_blocking=True)
            valid   = batch["gt_pix_valid"].to(device, non_blocking=True)
            B = rgb.shape[0]
            # GT heatmap stack (B, T, H, W) as the clean sample x0.
            x0 = make_gaussian_heatmap(gt_pix, N_WINDOW, HEATMAP_RES, HEATMAP_RES,
                                        args.sigma_px, image_size=DA3_INPUT, device=device)
            # Forward-diffuse x0 to a random timestep and predict the injected noise.
            t = torch.randint(0, model.T_diff, (B,), device=device)
            noise = torch.randn_like(x0)
            x_t = model.q_sample(x0, t, noise)
            eps_pred = model(rgb, x_t, t)
            # Masked MSE: ignore timesteps whose GT keypoint is invalid.
            mse_per = ((eps_pred - noise) ** 2).mean(dim=(2, 3))                # (B, T)
            mask = valid.float()
            loss = (mse_per * mask).sum() / mask.sum().clamp_min(1.0)

            opt.zero_grad(); loss.backward()
            if args.grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(trainable, args.grad_clip)
            opt.step()
            pbar.set_postfix(mse=f"{loss.item():.4f}")
            global_step += 1

            if args.log_scalars_every > 0 and global_step % args.log_scalars_every == 0:
                wandb.log({"train/mse_loss": loss.item(), "epoch": epoch}, step=global_step)

            if args.vis_every_steps > 0 and global_step % args.vis_every_steps == 0:
                model.eval()
                _log_random_sample_vis(model, train_ds, device, args.sample_steps, global_step)
                model.train()

        # End-of-epoch: full train-set and val-set evaluation by sampling.
        model.eval()
        tr_pix_all, tr_pix_full = _eval_argmax_pix_err(
            model, train_eval_loader, device, args.sample_steps, desc=f"  eval train ep{epoch}")
        val_pix_all, val_pix_full = _eval_argmax_pix_err(
            model, val_loader, device, args.sample_steps)

        print(f"Epoch {epoch}: train_pix_argmax(all)={tr_pix_all:.1f}px  train_pix_argmax(full8)={tr_pix_full:.1f}px  "
              f"val_pix(all)={val_pix_all:.1f}px  val_pix(full8)={val_pix_full:.1f}px")
        wandb.log({"epoch_end/train_pix_argmax_all":  tr_pix_all,
                   "epoch_end/train_pix_argmax_full8": tr_pix_full,
                   "epoch_end/val_pix_argmax_all":   val_pix_all,
                   "epoch_end/val_pix_argmax_full8": val_pix_full,
                   "epoch": epoch, "elapsed_min": (time.time() - t_start) / 60.0}, step=global_step)

        is_save = ((epoch + 1) % max(1, args.save_every_epochs) == 0) or (epoch + 1 == args.epochs)
        # Eval runs every epoch, so "best" is tracked every epoch rather than only on
        # save epochs. A NaN metric (no full-window samples) never counts as a best.
        is_best = bool(np.isfinite(tr_pix_full)) and tr_pix_full < best_train_pix
        if is_save or is_best:
            ckpt = {"epoch": epoch, "global_step": global_step,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": opt.state_dict(),
                    "train_pix_argmax_full8": tr_pix_full, "args": vars(args)}
            if is_save:
                torch.save(ckpt, ckpt_dir / "latest.pth")
            if is_best:
                best_train_pix = tr_pix_full
                torch.save(ckpt, ckpt_dir / "best.pth")
                print(f"  ✓ best (train_pix_full8={tr_pix_full:.2f})")

    wandb.finish()
    print(f"Done. Checkpoints: {ckpt_dir}")


if __name__ == "__main__":
    main()
