"""DINOv3-conditioned heatmap diffusion (DDPM).

Per Cameron 2026-05-19: denoise heatmaps directly. The forward process is
  x_t = sqrt(α̅_t) · x_0 + sqrt(1 - α̅_t) · ε,  ε ~ N(0, I)
where x_0 is the GT heatmap stack (B, T, H, W): a Gaussian blob (σ = 2 heatmap pixels) at each
future GT pixel, rescaled from [0, 1] to [-1, 1]. The model predicts ε given (rgb, x_t, t).
Standard cosine β schedule, 1000 train timesteps, 10-step DDIM at inference.

Why heatmaps (not coords): multiple plausible futures = multiple modes in the heatmap;
diffusion samples from one mode per pass rather than collapsing to the centroid.
"""
import os, sys, math
import torch
import torch.nn as nn
import torch.nn.functional as F

# DINOv3 backbone location; both paths can be overridden through environment variables.
DINO_REPO_DIR = os.environ.get("DINO_REPO_DIR", "/data/cameron/keygrip/dinov3")
DINO_WEIGHTS_PATH = os.environ.get(
    "DINO_WEIGHTS_PATH",
    "/data/cameron/keygrip/dinov3/weights/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth",
)

# Geometry
DINO_PATCH_SIZE = 16
IMG_SIZE = 448                                  # side length images are resized to for the backbone
N_WINDOW = 8                                    # heatmaps per sample (the T axis)
HEATMAP_RES = IMG_SIZE // DINO_PATCH_SIZE * 2   # = 56, twice the 28x28 patch grid
GAUSSIAN_SIGMA = 2.0                            # blob sigma, in HEATMAP_RES pixels

# ImageNet normalization statistics (RGB order)
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Cosine β schedule (Nichol & Dhariwal 2021)
def cosine_betas(T, s=0.008, max_beta=0.999):
    """Return the length-T cosine beta schedule (Nichol & Dhariwal, 2021).

    ᾱ(u) = cos²(((u + s) / (1 + s)) · π/2) / cos²((s / (1 + s)) · π/2) for u = t / T, and
    β_t = 1 - ᾱ(t+1) / ᾱ(t), clipped from above at ``max_beta`` so the final steps stay finite.
    """
    u = torch.linspace(0, T, T + 1) / T
    f = torch.cos((u + s) / (1 + s) * math.pi * 0.5).pow(2)
    alpha_bar = f / f[0]
    step_ratio = alpha_bar[1:] / alpha_bar[:-1]
    return torch.clamp(1 - step_ratio, max=max_beta)


class SinusoidalTimeEmb(nn.Module):
    """Sinusoidal embedding of diffusion timesteps.

    Maps t of shape (B,) to (B, dim). The first ``dim // 2`` columns are sin(t·ω_k) and the
    next ``dim // 2`` are cos(t·ω_k), with ω_k = exp(-ln(10000) · k / (dim // 2)).
    When ``dim`` is odd, one zero column is appended so the output always has exactly
    ``dim`` features. The Linear(dim, ...) layers that consume it depend on that width.
    """
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        # t: (B,) integer or float timesteps
        half = self.dim // 2
        freqs = torch.exp(-math.log(10000) * torch.arange(half, device=t.device) / half)
        a = t.float().unsqueeze(-1) * freqs.unsqueeze(0)                       # (B, half)
        emb = torch.cat([torch.sin(a), torch.cos(a)], dim=-1)                  # (B, 2*half)
        if self.dim % 2:
            # Odd dim: without padding the width would be dim - 1.
            emb = F.pad(emb, (0, 1))                                           # (B, dim)
        return emb


class ResBlock(nn.Module):
    """Pre-activation residual block with an additive timestep bias.

    Computes x + conv2(silu(norm2(conv1(silu(norm1(x))) + proj(t_e)))). The output has the
    same shape as x. ``ch`` must be divisible by 8, which GroupNorm(8, ch) requires.
    """
    def __init__(self, ch, t_dim):
        super().__init__()
        # Construction order matches the original so that parameter init is unchanged.
        self.norm1 = nn.GroupNorm(8, ch)
        self.conv1 = nn.Conv2d(ch, ch, 3, padding=1)
        self.norm2 = nn.GroupNorm(8, ch)
        self.conv2 = nn.Conv2d(ch, ch, 3, padding=1)
        self.t_proj = nn.Linear(t_dim, ch)

    def forward(self, x, t_e):
        residual = x
        out = self.conv1(F.silu(self.norm1(x)))
        # Per-channel timestep bias, broadcast over the spatial dimensions.
        t_bias = self.t_proj(t_e).unsqueeze(-1).unsqueeze(-1)
        out = out + t_bias
        out = self.conv2(F.silu(self.norm2(out)))
        return residual + out


class SmallUNet(nn.Module):
    """UNet with three resolution levels (full, 1/2, 1/4) and additive skip connections.

    Maps (B, in_ch, H, W) to (B, out_ch, H, W); this file uses 56x56. The stride-2
    convolutions halve H and W with a floor, and the transposed convolutions double them,
    so H and W must be divisible by 4 for the skip additions to line up. Every ResBlock
    receives the same MLP-processed sinusoidal timestep embedding. The output head is
    zero-initialised, so the initial ε prediction is exactly 0.
    """
    def __init__(self, in_ch, out_ch, base_ch=128, t_dim=256):
        super().__init__()
        c1, c2, c4 = base_ch, base_ch * 2, base_ch * 4
        # Timestep embedding: sinusoidal features followed by a 2-layer MLP.
        self.t_emb_mlp = nn.Sequential(
            SinusoidalTimeEmb(t_dim),
            nn.Linear(t_dim, t_dim),
            nn.SiLU(),
            nn.Linear(t_dim, t_dim),
        )
        # Encoder: full -> 1/2 -> 1/4 resolution. Attribute names and construction order
        # are kept fixed because state_dict keys and init RNG depend on them.
        self.in_conv = nn.Conv2d(in_ch, c1, kernel_size=3, padding=1)
        self.r1 = ResBlock(c1, t_dim)
        self.down1 = nn.Conv2d(c1, c2, kernel_size=4, stride=2, padding=1)
        self.r2 = ResBlock(c2, t_dim)
        self.down2 = nn.Conv2d(c2, c4, kernel_size=4, stride=2, padding=1)
        self.r_mid = ResBlock(c4, t_dim)
        # Decoder: 1/4 -> 1/2 -> full resolution.
        self.up2 = nn.ConvTranspose2d(c4, c2, kernel_size=4, stride=2, padding=1)
        self.r2b = ResBlock(c2, t_dim)
        self.up1 = nn.ConvTranspose2d(c2, c1, kernel_size=4, stride=2, padding=1)
        self.r1b = ResBlock(c1, t_dim)
        self.out_conv = nn.Conv2d(c1, out_ch, kernel_size=3, padding=1)
        for param in (self.out_conv.weight, self.out_conv.bias):
            nn.init.zeros_(param)                                              # ε-pred starts at 0

    def forward(self, x, t):
        temb = self.t_emb_mlp(t)
        skip_full = self.r1(self.in_conv(x), temb)
        skip_half = self.r2(self.down1(skip_full), temb)
        bottleneck = self.r_mid(self.down2(skip_half), temb)
        dec_half = self.r2b(self.up2(bottleneck) + skip_half, temb)
        dec_full = self.r1b(self.up1(dec_half) + skip_full, temb)
        return self.out_conv(dec_full)


def make_gaussian_heatmap(gt_pix_504, T, H, W, sigma_px, image_size=504, device=None):
    """Render one Gaussian blob per GT pixel into a (B, T, H, W) heatmap stack in [-1, 1].

    Args:
        gt_pix_504: (B, T, 2) GT pixel coordinates (x, y) in ``image_size`` pixel space.
        T: number of heatmaps per sample (must match gt_pix_504.shape[1]).
        H, W: heatmap height and width.
        sigma_px: Gaussian sigma, in heatmap pixels.
        image_size: side length of the square image the coordinates refer to.
            NOTE(review): the default (504) differs from this file's IMG_SIZE (448);
            callers here pass it explicitly.
        device: device of the output. ``None`` (the default) uses ``gt_pix_504``'s device.
            Otherwise the coordinates are first moved to ``device``. The old default of 'cpu'
            raised a device-mismatch error for CUDA inputs.

    Returns:
        (B, T, H, W) heatmaps, equal to 2·g - 1 where g is a unit-peak Gaussian. The value is
        exactly +1 when the GT lands on a grid point and tends to -1 far away.

    The output is rescaled to [-1, 1] so that N(0, 1) diffusion noise does not overwhelm the
    signal at moderate t. With [0, 1] (the previous version), ~99% of pixels sit at 0, giving
    a poor signal-to-noise ratio for most pixels.

    Heatmap cell i corresponds to image coordinate i * image_size / W, with no half-pixel
    offset. Decoders that map argmax back to pixels should use the same convention.
    """
    gt = gt_pix_504 if device is None else gt_pix_504.to(device)
    dev = gt.device
    B = gt.shape[0]
    cx = (gt[..., 0] * (W / image_size)).view(B, T, 1, 1)                      # heatmap x
    cy = (gt[..., 1] * (H / image_size)).view(B, T, 1, 1)                      # heatmap y
    ys = torch.arange(H, device=dev, dtype=torch.float32).view(1, 1, H, 1)
    xs = torch.arange(W, device=dev, dtype=torch.float32).view(1, 1, 1, W)
    g = torch.exp(-((xs - cx) ** 2 + (ys - cy) ** 2) / (2 * sigma_px ** 2))     # peak 1
    return g * 2.0 - 1.0                                                        # -> [-1, 1]


class DinoHeatmapDiffusion(nn.Module):
    """DINOv3-conditioned heatmap diffusion model (ε-prediction DDPM, DDIM sampling).

    DINOv3 patch features (frozen by default) are bilinearly resampled to ``heatmap_res``,
    projected to ``cond_dim`` channels, concatenated with the noisy heatmap stack and passed
    to a SmallUNet that predicts the added noise.

    Inputs:  rgb (B,3,*,*) image (presumably in [0, 1], since ImageNet mean/std are applied
             to it directly), x_t (B, T, H, W) noisy heatmap stack, t (B,) ints in [0, T_diff)
    Outputs: ε_pred (B, T, H, W) — DDPM noise prediction.
    """
    def __init__(self, n_window=N_WINDOW, image_size=IMG_SIZE,
                 heatmap_res=HEATMAP_RES, cond_dim=64,
                 T_diff: int = 1000, freeze_backbone: bool = True):
        super().__init__()
        self.n_window     = n_window
        self.image_size   = image_size
        self.heatmap_res  = heatmap_res
        self.T_diff       = T_diff
        # Make the local DINOv3 checkout importable before loading it via torch.hub.
        if DINO_REPO_DIR not in sys.path: sys.path.insert(0, DINO_REPO_DIR)
        self.dino = torch.hub.load(DINO_REPO_DIR, "dinov3_vits16plus",
                                    source="local", weights=DINO_WEIGHTS_PATH)
        if freeze_backbone:
            for p in self.dino.parameters(): p.requires_grad_(False)
        self.embed_dim = getattr(self.dino, "embed_dim", 384)
        # Project DINO patch features to cond_dim @ heatmap_res
        # (GroupNorm(8, cond_dim) requires cond_dim divisible by 8).
        self.cond_proj = nn.Sequential(
            nn.Conv2d(self.embed_dim, cond_dim, 1), nn.GroupNorm(8, cond_dim), nn.SiLU(),
            nn.Conv2d(cond_dim, cond_dim, 3, padding=1),
        )
        # UNet input: noisy heatmap (T ch) + cond (cond_dim) ; output: noise (T ch)
        self.unet = SmallUNet(in_ch=n_window + cond_dim, out_ch=n_window,
                              base_ch=128, t_dim=256)

        # Diffusion schedule. Buffers are non-persistent: they are rebuilt from T_diff and
        # are not stored in checkpoints.
        betas = cosine_betas(T_diff)
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        self.register_buffer("betas",                betas, persistent=False)
        self.register_buffer("alphas",               alphas, persistent=False)
        self.register_buffer("alphas_cumprod",       alphas_cumprod, persistent=False)
        self.register_buffer("sqrt_alphas_cumprod",  torch.sqrt(alphas_cumprod), persistent=False)
        self.register_buffer("sqrt_one_minus_alphas_cumprod",
                              torch.sqrt(1.0 - alphas_cumprod), persistent=False)
        self.register_buffer("mean", torch.tensor(IMAGENET_MEAN).view(1, 3, 1, 1), persistent=False)
        self.register_buffer("std",  torch.tensor(IMAGENET_STD ).view(1, 3, 1, 1), persistent=False)

    def _normalize(self, rgb01):
        """Apply per-channel ImageNet mean/std normalisation."""
        return (rgb01 - self.mean) / self.std

    def _cond_features(self, rgb):
        """Return conditioning features of shape (B, cond_dim, heatmap_res, heatmap_res).

        ``rgb`` is resized to (image_size, image_size) whenever either spatial dimension
        differs. The reshape below assumes a square (image_size // patch)^2 token grid.
        """
        B = rgb.shape[0]
        if tuple(rgb.shape[-2:]) != (self.image_size, self.image_size):
            rgb = F.interpolate(rgb, size=(self.image_size, self.image_size),
                                mode='bilinear', align_corners=False)
        x = self._normalize(rgb)
        if x.device.type == "cuda":
            # Mixed-precision backbone pass: bf16 when the GPU supports it, else fp16.
            autocast_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
            with torch.autocast(device_type="cuda", dtype=autocast_dtype):
                feats = self.dino.forward_features(x)
        else:
            # Off-GPU, run in full precision. is_bf16_supported() describes CUDA only, and
            # fp16 CPU autocast is not supported on older PyTorch versions.
            feats = self.dino.forward_features(x)
        if isinstance(feats, dict):
            # NOTE(review): if the 'x_prenorm' fallback is ever used it may also contain
            # CLS/register tokens, which would break the g x g reshape below — confirm.
            patch_tokens = feats.get("x_norm_patchtokens", feats.get("x_prenorm"))
        else:
            patch_tokens = feats
        patch_tokens = patch_tokens.to(torch.float32)
        D = patch_tokens.shape[-1]
        g = self.image_size // DINO_PATCH_SIZE
        feat_2d = patch_tokens.permute(0, 2, 1).reshape(B, D, g, g)
        feat_hm = F.interpolate(feat_2d, size=(self.heatmap_res, self.heatmap_res),
                                 mode='bilinear', align_corners=False)
        return self.cond_proj(feat_hm)                                         # (B, cond_dim, H, W)

    def forward(self, rgb, x_t, t):
        """ε-prediction. rgb (B,3,*,*), x_t (B, T, H, W), t (B,) int."""
        cond = self._cond_features(rgb)
        inp = torch.cat([x_t, cond], dim=1)                                    # (B, T+C, H, W)
        return self.unet(inp, t)

    def q_sample(self, x_0, t, noise):
        """Forward diffusion: x_t = sqrt(ᾱ_t) x_0 + sqrt(1-ᾱ_t) ε, with t (B,) indices."""
        sa = self.sqrt_alphas_cumprod[t].view(-1, 1, 1, 1)
        soma = self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1, 1)
        return sa * x_0 + soma * noise

    @torch.no_grad()
    def sample(self, rgb, n_steps: int = 10):
        """Deterministic DDIM (η=0) sampling from t = T_diff-1 to a clean heatmap in n_steps.

        The conditioning features are computed once and reused at every step. The final
        step moves to ᾱ = 1, so the result is the model's clean x_0 estimate rather than a
        sample at noise level t = 0.

        Returns:
            (B, T, H, W) heatmap stack.
        Raises:
            ValueError: if n_steps < 1 (no denoising would happen).
        """
        if n_steps < 1:
            raise ValueError(f"n_steps must be >= 1, got {n_steps}")
        device = rgb.device
        B = rgb.shape[0]
        H = W = self.heatmap_res
        x = torch.randn(B, self.n_window, H, W, device=device)
        timesteps = torch.linspace(self.T_diff - 1, 0, n_steps + 1).long().to(device)
        cond = self._cond_features(rgb)
        for i in range(n_steps):
            t = timesteps[i]
            t_batch = torch.full((B,), int(t), device=device, dtype=torch.long)
            inp = torch.cat([x, cond], dim=1)
            eps = self.unet(inp, t_batch)
            a = self.alphas_cumprod[t]
            if i == n_steps - 1:
                # Last step: jump straight to the clean sample (ᾱ = 1).
                a_next = torch.ones((), device=device, dtype=a.dtype)
            else:
                a_next = self.alphas_cumprod[timesteps[i + 1]]
            # DDIM update with η=0 (deterministic):
            x0_hat = (x - (1 - a).sqrt() * eps) / a.sqrt()
            x = a_next.sqrt() * x0_hat + (1 - a_next).sqrt() * eps
        return x                                                                # (B, T, H, W)


if __name__ == "__main__":
    # Smoke test: build the model, run one noisy forward pass, then a short DDIM sample.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = DinoHeatmapDiffusion().to(device).eval()
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Trainable: {trainable:,}")

    batch = 2
    rgb = torch.rand(batch, 3, IMG_SIZE, IMG_SIZE).to(device)
    gt_pix = torch.rand(batch, N_WINDOW, 2).to(device) * IMG_SIZE
    x0 = make_gaussian_heatmap(
        gt_pix, N_WINDOW, HEATMAP_RES, HEATMAP_RES, GAUSSIAN_SIGMA,
        image_size=IMG_SIZE, device=device,
    )
    print(f"x0 max: {x0.max().item():.3f} min: {x0.min().item():.3f}")

    # One training-style step: noise x0 at random timesteps and predict the noise.
    t = torch.randint(0, model.T_diff, (batch,), device=device)
    noise = torch.randn_like(x0)
    x_t = model.q_sample(x0, t, noise)
    eps_pred = model(rgb, x_t, t)
    print(f"eps_pred: {tuple(eps_pred.shape)}")
    loss = F.mse_loss(eps_pred, noise)
    print(f"init MSE loss: {loss.item():.4f}")

    # Sampling path.
    samp = model.sample(rgb[:1], n_steps=5)
    print(f"sample: {tuple(samp.shape)}, range [{samp.min().item():.2f}, {samp.max().item():.2f}]")

    if device.type == 'cuda':
        print(f"peak: {torch.cuda.max_memory_allocated()/1e9:.2f} GB")
