"""Point-track heatmap predictor: DINO + query-conditioned 64x64 heatmaps per timestep."""
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

# DINO configuration — override via env vars on servers
# DINO_REPO_DIR is a local clone of the DINOv3 repo (used as a torch.hub
# source); DINO_WEIGHTS_PATH is the pretrained ViT-S/16+ checkpoint it loads.
DINO_REPO_DIR = os.environ.get("DINO_REPO_DIR", "/data/cameron/keygrip/point_track_pretraining/dinov3")
DINO_WEIGHTS_PATH = os.environ.get("DINO_WEIGHTS_PATH", "/data/cameron/keygrip/point_track_pretraining/dinov3/weights/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth")
DINO_PATCH_SIZE = 16  # ViT patch side: a 448x448 input yields a 28x28 patch grid
IMAGENET_MEAN = (0.485, 0.456, 0.406)  # standard ImageNet normalization statistics
IMAGENET_STD = (0.229, 0.224, 0.225)

# Point-track pretraining
N_WINDOW_POINT_TRACK = 15  # timesteps predicted per query point
N_QUERY_POINTS = 32  # default number of query points per sample
HEATMAP_SIZE = 64  # side length of each predicted heatmap


def _conv3x3(c_in, c_out):
    return nn.Conv2d(c_in, c_out, 3, padding=1)


class PointTrackHeatmapPredictor(nn.Module):
    """Predicts 64x64 heatmap logits per (query point, timestep).

    Conditioning scheme: the DINO patch feature at each query's start
    location is concatenated channel-wise onto every patch feature, and a
    shared conv decoder maps the result to n_window per-timestep heatmaps.
    """

    def __init__(self, target_size=448, n_window=N_WINDOW_POINT_TRACK, n_query=N_QUERY_POINTS, heatmap_size=HEATMAP_SIZE, freeze_backbone=False):
        """
        Args:
            target_size: side length (pixels) of the (square) input images;
                used to convert pixel coordinates to patch-grid indices.
            n_window: number of timesteps predicted per query point.
            n_query: expected number of query points per sample (forward()
                also accepts any other count at runtime).
            heatmap_size: side length of each output heatmap.
            freeze_backbone: if True, freeze the DINO backbone's parameters
                and put it in eval mode.
        """
        super().__init__()
        self.target_size = target_size
        self.n_window = n_window
        self.n_query = n_query
        self.heatmap_size = heatmap_size
        self.patch_size = DINO_PATCH_SIZE

        # The loaded checkpoint is DINOv3 (dinov3_vits16plus); log it as such.
        print("Loading DINOv3 model...")
        self.dino = torch.hub.load(
            DINO_REPO_DIR,
            'dinov3_vits16plus',
            source='local',
            weights=DINO_WEIGHTS_PATH
        )
        if freeze_backbone:
            for param in self.dino.parameters():
                param.requires_grad = False
            self.dino.eval()
            print("Frozen DINOv3 backbone")
        else:
            print("DINOv3 backbone is trainable")

        self.embed_dim = self.dino.embed_dim
        print(f"DINO embedding dim: {self.embed_dim}")

        # Shared decoder: (B, 2*D, H_p, W_p) -> bilinear 64x64 -> 3 convs -> 1x1 -> (B, n_window, 64, 64)
        dim_concat = self.embed_dim * 2
        # nn.Upsample is equivalent to the F.interpolate call it replaces,
        # but (unlike a lambda attribute) keeps the module picklable.
        self.upsample = nn.Upsample(size=(heatmap_size, heatmap_size), mode='bilinear', align_corners=False)
        self.decoder = nn.Sequential(
            _conv3x3(dim_concat, 256),
            nn.ReLU(inplace=True),
            _conv3x3(256, 256),
            nn.ReLU(inplace=True),
            _conv3x3(256, 128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, n_window, 1),
        )
        print(f"Point-track decoder: 2*D -> bilinear {heatmap_size}x{heatmap_size} -> 3 convs -> 1x1 -> (B, {n_window}, {heatmap_size}, {heatmap_size})")

    def to(self, device):
        """Move the model (DINO backbone included) to `device`."""
        # nn.Module.to already relocates registered submodules (self.dino
        # among them); the explicit reassignment is kept as belt-and-braces
        # for backbones whose .to may be out-of-place.
        super().to(device)
        if hasattr(self, 'dino'):
            self.dino = self.dino.to(device)
        return self

    def _extract_dino_features(self, x):
        """Run the DINO backbone and return normalized features.

        Args:
            x: (B, 3, H, W) image batch, already preprocessed for DINO.

        Returns:
            patch_features: (B, D, H_p, W_p) normalized patch features.
            cls_token: (B, D) normalized CLS token.
        """
        B = x.shape[0]
        x_tokens, (H_p, W_p) = self.dino.prepare_tokens_with_masks(x)
        # The RoPE tables depend only on the grid size, so compute them once
        # up front instead of once per transformer block.
        rope_sincos = self.dino.rope_embed(H=H_p, W=W_p) if self.dino.rope_embed else None
        for blk in self.dino.blocks:
            x_tokens = blk(x_tokens, rope_sincos)
        if self.dino.untie_cls_and_patch_norms:
            # Separate norms for [CLS]+storage tokens vs. patch tokens.
            x_norm_cls = self.dino.cls_norm(x_tokens[:, : self.dino.n_storage_tokens + 1])
            x_norm_patches = self.dino.norm(x_tokens[:, self.dino.n_storage_tokens + 1 :])
            x_tokens = torch.cat([x_norm_cls, x_norm_patches], dim=1)
        else:
            x_tokens = self.dino.norm(x_tokens)

        cls_token = x_tokens[:, 0]
        patch_tokens = x_tokens[:, self.dino.n_storage_tokens + 1 :]
        patch_features = patch_tokens.reshape(B, H_p, W_p, self.embed_dim)
        patch_features = patch_features.permute(0, 3, 1, 2).contiguous()  # (B, D, H_p, W_p)
        return patch_features, cls_token

    def forward(self, x, start_keypoint_2d):
        """
        Args:
            x: (B, 3, H, W) first frame.
            start_keypoint_2d: (B, Q, 2) query start positions in pixel
                coords (target_size space). Q may differ from self.n_query.

        Returns:
            heatmap_logits: (B, Q, n_window, heatmap_size, heatmap_size)
        """
        B = x.shape[0]
        Q = start_keypoint_2d.shape[1]  # generalized: use the actual query count
        patch_features, _ = self._extract_dino_features(x)  # (B, D, H_p, W_p)
        _, D, H_p, W_p = patch_features.shape
        dev = patch_features.device

        # Pixel coords -> patch-grid indices, clamped to the valid grid.
        start_patch_x = (start_keypoint_2d[..., 0] * W_p / self.target_size).long().clamp(0, W_p - 1)
        start_patch_y = (start_keypoint_2d[..., 1] * H_p / self.target_size).long().clamp(0, H_p - 1)

        # Gather the feature at each query's start patch -> (B, Q, D).
        # The (B, Q) advanced indices broadcast and move to the front; the
        # channel slice dimension D follows.
        batch_idx = torch.arange(B, device=dev).view(B, 1).expand(B, Q)
        query_feat = patch_features[batch_idx, :, start_patch_y, start_patch_x]  # (B, Q, D)

        # Broadcast each query feature over the grid and concat with the map:
        # (B, Q, D, H_p, W_p) x2 -> (B, Q, 2*D, H_p, W_p) -> fold queries into batch.
        patch_expand = patch_features.unsqueeze(1).expand(B, Q, D, H_p, W_p)
        query_expand = query_feat.unsqueeze(-1).unsqueeze(-1).expand(B, Q, D, H_p, W_p)
        concat = torch.cat([patch_expand, query_expand], dim=2)  # (B, Q, 2*D, H_p, W_p)
        concat = concat.view(B * Q, 2 * D, H_p, W_p)

        up = self.upsample(concat)
        out = self.decoder(up)  # (B*Q, n_window, heatmap_size, heatmap_size)
        out = out.view(B, Q, self.n_window, self.heatmap_size, self.heatmap_size)
        return out


if __name__ == "__main__":
    # Smoke test: build the predictor and push one random batch through it.
    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = PointTrackHeatmapPredictor(
        target_size=448,
        n_window=N_WINDOW_POINT_TRACK,
        n_query=N_QUERY_POINTS,
        heatmap_size=HEATMAP_SIZE,
        freeze_backbone=True,
    ).to(dev)
    batch, n_pts = 2, N_QUERY_POINTS
    frames = torch.randn(batch, 3, 448, 448).to(dev)
    starts = torch.rand(batch, n_pts, 2).to(dev) * 448
    with torch.no_grad():
        heatmaps = net(frames, starts)
    print("heatmap_logits", heatmaps.shape)  # (B, 32, 15, 64, 64)
