"""DA3-based pixel-aligned heatmap predictor.

Wraps Depth-Anything-3 (DA3-SMALL). Repurposes the DualDPT head:
  - main head: 1 ch depth + 1 ch conf — kept, supervised by frozen-DA3 depth predictions.
  - aux  head: original 7 + 1 ch — REPLACED with (N_WINDOW + 1) channels so the
               DualDPT split (`aux_pred=[..., :-1]`, `aux_conf=[..., -1]`) yields
               N_WINDOW heatmap channels and 1 unused conf channel.

We bypass DA3's `_process_camera_estimation` (which DELETES `ray`/`ray_conf` from the
output) and call `backbone` + `head` directly.

`forward(rgb)` returns a dict with keys:
  pred_heatmap: (B, N_WINDOW, H_out, W_out)  raw logits
  pred_depth:   (B, H_out, W_out)
  dino_feats:   list of (B, n_patches, C) intermediate DINO features (for PCA viz)
"""
import sys, types, os
# Register placeholder modules for two DA3 utility submodules so the
# `depth_anything_3.api` import below does not need the real ones (presumably to
# avoid their extra dependencies — confirm). Each stub exposes a single no-op
# function that accepts any arguments and returns None. Registration is
# unconditional: an already-imported real module of the same name is replaced.
for n, attr in (('depth_anything_3.utils.export', 'export'),
                ('depth_anything_3.utils.pose_align', 'align_poses_umeyama')):
    m = types.ModuleType(n)
    setattr(m, attr, lambda *a, **k: None)
    sys.modules[n] = m

import torch
import torch.nn as nn

from depth_anything_3.api import DepthAnything3

# Checkpoint location for DepthAnything3.from_pretrained. Override with the
# DA3_WEIGHTS environment variable; defaults to the original hard-coded path.
DA3_WEIGHTS = os.environ.get("DA3_WEIGHTS", "/data/cameron/da3_weights")
# Square input resolution used by the self-test (504 = 36 × 14; presumably chosen
# as a multiple of the ViT patch size — confirm against the DA3 config).
DA3_INPUT = 504
# Number of heatmap channels (timesteps) predicted per pixel.
N_WINDOW = 8


class FiLMHead(nn.Module):
    """FiLM-modulated, timestep-shared 1×1 conv → one heatmap logit per (pixel, timestep).

    Each timestep t has its own learned per-channel scale and shift (FiLM) that is
    applied to the shared per-pixel features. A single 1×1 conv, shared across all
    timesteps, then maps each modulated feature map to one logit.

    NOTE: DA3PixelModel in this file does not use this head; it swaps in a plain
    Conv2d with one output channel per timestep instead.

    Input:  (B, in_dim, H, W) shared per-pixel features (intended to come from the
            DA3 DPT aux branch before its final layer).
    Output: (B, n_window + 1, H, W). Channels [:n_window] are heatmap logits. The last
            channel is an all-zero dummy conf that DualDPT's split discards.

    Raises:
        ValueError: if the input channel count differs from ``in_dim``.
    """

    def __init__(self, in_dim: int = 32, n_window: int = N_WINDOW):
        super().__init__()
        self.n_window = n_window
        self.in_dim = in_dim
        # FiLM parameters: one (in_dim,) scale/shift vector per timestep.
        # scale = 1, shift = 0 at init, so every timestep starts out identical.
        self.scale = nn.Parameter(torch.ones(n_window, in_dim))
        self.shift = nn.Parameter(torch.zeros(n_window, in_dim))
        # One 1×1 conv in_dim → 1, shared across timesteps; FiLM alone differentiates them.
        self.conv = nn.Conv2d(in_dim, 1, kernel_size=1)
        # Small init so the initial logits are close to zero.
        nn.init.zeros_(self.conv.bias)
        nn.init.normal_(self.conv.weight, std=0.01)

    def forward(self, x):
        B, C, H, W = x.shape
        if C != self.in_dim:
            raise ValueError(f"FiLMHead expected {self.in_dim} input channels, got {C}")
        # x_t[b, t] = x[b] * scale[t] + shift[t], broadcast to (B, T, C, H, W).
        x_t = (x.unsqueeze(1) *
               self.scale[None, :, :, None, None] +
               self.shift[None, :, :, None, None])
        # Fold timesteps into the batch so one conv call handles all of them.
        out = self.conv(x_t.reshape(B * self.n_window, C, H, W))        # (B*T, 1, H, W)
        out = out.reshape(B, self.n_window, H, W)
        # Append a zero conf channel. DualDPT's `aux_pred=fmap[..., :-1]`,
        # `aux_conf=fmap[..., -1]` split then yields n_window heatmaps plus a discarded conf.
        # Build it from `out` so it matches the logits' dtype/device (e.g. under autocast).
        dummy_conf = out.new_zeros(B, 1, H, W)
        return torch.cat([out, dummy_conf], dim=1)                       # (B, T+1, H, W)


class DA3PixelModel(nn.Module):
    """Pixel-aligned heatmap + depth predictor built on a pretrained Depth-Anything-3.

    Only the DA3 ``backbone`` and DualDPT ``head`` are kept. The final 1×1 conv of
    the head's aux branch is replaced so that it emits ``n_window + 1`` channels.
    DualDPT splits the last channel off as a confidence map, which leaves
    ``n_window`` heatmap logits.

    Args:
        n_window: number of heatmap channels (timesteps) to predict.
        weights_path: checkpoint location passed to ``DepthAnything3.from_pretrained``.
        dino_feat_layers: backbone layer indices passed as ``export_feat_layers``.
            When None, the backbone's own ``out_layers`` attribute is used, falling
            back to ``[5, 7, 9, 11]`` if the backbone has no such attribute.
    """

    def __init__(self, n_window: int = N_WINDOW, weights_path: str = DA3_WEIGHTS,
                 dino_feat_layers=None):
        super().__init__()
        self.n_window = n_window
        # Keep the caller's request; the effective list is resolved after loading.
        self._user_feat_layers = dino_feat_layers

        full = DepthAnything3.from_pretrained(weights_path)
        # Keep only backbone + head. Camera encoder/decoder are never called here.
        self.backbone = full.model.backbone
        self.head = full.model.head
        # Drop the wrapper so the sub-modules we did not keep can be garbage-collected.
        del full

        # Default to the backbone's configured `out_layers`, so the 4 features fed to
        # DPT match the channel widths DPT was built for (per the author:
        # small=[5,7,9,11], large=[11,15,19,23]).
        if self._user_feat_layers is None:
            self.dino_feat_layers = list(getattr(self.backbone, 'out_layers', [5, 7, 9, 11]))
        else:
            self.dino_feat_layers = list(self._user_feat_layers)

        self._replace_aux_output_conv(n_window)

    def _replace_aux_output_conv(self, n_window: int) -> None:
        """Swap the aux branch's final 1×1 conv for one with ``n_window + 1`` outputs.

        Each timestep gets its own output channel (no FiLM). The extra channel is
        consumed as the DualDPT conf split.
        """
        last_aux_seq = self.head.scratch.output_conv2_aux[-1]
        old_conv = last_aux_seq[-1]
        new_conv = nn.Conv2d(old_conv.in_channels, n_window + 1, kernel_size=1, stride=1, padding=0)
        nn.init.zeros_(new_conv.bias)
        nn.init.normal_(new_conv.weight, std=0.01)
        # Place the new layer where the rest of the loaded head lives. Otherwise it sits
        # on CPU/float32 even if the checkpoint was loaded elsewhere.
        ref = old_conv.weight
        last_aux_seq[-1] = new_conv.to(device=ref.device, dtype=ref.dtype)

    @staticmethod
    def _autocast_dtype(device: torch.device) -> torch.dtype:
        """Reduced-precision dtype for the backbone's autocast region on ``device``.

        Only query CUDA bf16 support when actually running on CUDA. On CPU that query
        can raise (older PyTorch), and fp16 CPU autocast is not supported everywhere.
        bf16 is the CPU autocast dtype.
        """
        if device.type == "cuda":
            return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        return torch.bfloat16

    @staticmethod
    def _squeeze_depth(depth):
        """Drop the view axis (and a singleton channel axis, if present) → (B, H, W)."""
        if depth.dim() == 5:
            return depth[:, 0, 0]      # (B, S, 1, H, W) -> (B, H, W)
        if depth.dim() == 4:
            return depth[:, 0]         # (B, S, H, W) -> (B, H, W)
        raise RuntimeError(f"unexpected depth shape {tuple(depth.shape)}")

    @staticmethod
    def _heatmap_from_ray(ray):
        """Convert the channel-last aux ('ray') output into (B, C, H, W) heatmaps."""
        if ray.dim() == 5:
            return ray[:, 0].permute(0, 3, 1, 2)   # (B, S, H, W, C) -> (B, C, H, W)
        if ray.dim() == 4:
            return ray.permute(0, 3, 1, 2)         # (B, H, W, C) -> (B, C, H, W)
        raise RuntimeError(f"unexpected ray shape {tuple(ray.shape)}")

    def forward(self, rgb):
        """Predict heatmap logits, depth and backbone features.

        Args:
            rgb: (B, 3, 504, 504) in [0, 1].

        Returns:
            dict with ``pred_heatmap`` (B, n_window, H, W) raw logits,
            ``pred_depth`` (B, H, W), and ``dino_feats`` (list with one tensor per
            exported backbone layer).

        Raises:
            RuntimeError: if the head returns depth/ray tensors of unexpected rank.
        """
        x = rgb.unsqueeze(1)                                     # (B, S=1, 3, H, W)
        device_type = rgb.device.type
        with torch.autocast(device_type=device_type, dtype=self._autocast_dtype(rgb.device)):
            feats, _aux = self.backbone(
                x, cam_token=None,
                export_feat_layers=self.dino_feat_layers,
                ref_view_strategy="saddle_balanced",
            )
        H, W = x.shape[-2], x.shape[-1]
        # DPT head runs with autocast disabled (the author notes this matches DA3's
        # _process_depth_head).
        with torch.autocast(device_type=device_type, enabled=False):
            head_out = self.head(feats, H, W, patch_start_idx=0)

        pred_depth = self._squeeze_depth(head_out['depth'])
        # 'ray' holds the aux split, i.e. our n_window heatmap channels (channel-last).
        pred_heatmap = self._heatmap_from_ray(head_out['ray'])

        # For PCA viz. If a layer entry is a (list, tuple), its first element is taken
        # (presumably the patch tokens — confirm against the DA3 backbone).
        dino_feats = [f[0] if isinstance(f, (list, tuple)) else f for f in feats]

        return {
            "pred_heatmap": pred_heatmap,
            "pred_depth":   pred_depth,
            "dino_feats":   dino_feats,
        }


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    m = DA3PixelModel().to(device).eval()
    n_t = sum(p.numel() for p in m.parameters() if p.requires_grad)
    print(f"Trainable: {n_t:,}")
    rgb = torch.rand(2, 3, DA3_INPUT, DA3_INPUT).to(device)
    with torch.no_grad():
        out = m(rgb)
    for k, v in out.items():
        if hasattr(v, 'shape'):
            print(f"  {k}: {tuple(v.shape)}")
        elif isinstance(v, (list, tuple)):
            print(f"  {k}: list({len(v)})")
            if v and hasattr(v[0], 'shape'): print(f"    first: {tuple(v[0].shape)}")
    if device.type == 'cuda':
        print(f"peak: {torch.cuda.max_memory_allocated()/1e9:.2f} GB")
