"""Model for trajectory volume prediction using DINOv2.

Predicts a pixel-aligned volume: N_WINDOW x N_HEIGHT_BINS logits per pixel (cross-entropy).
Gripper is per-pixel (N_WINDOW x N_GRIPPER_BINS per pixel): supervised at GT pixel during training,
decoded at predicted pixel during inference (teacher forcing in train, argmax at pred pixel in val/inference).
"""
import torch
import torch.nn as nn
import torch.nn.functional as F

# DINO configuration — the repo/weights below are DINOv3 (ViT-S/16+, LVD-1689M pretraining).
DINO_REPO_DIR = "dinov3"  # local torch.hub repo directory
DINO_WEIGHTS_PATH = "dinov3/weights/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth"
DINO_PATCH_SIZE = 16  # ViT patch size: feature grid is (H/16, W/16)
IMAGENET_MEAN = (0.485, 0.456, 0.406)  # standard ImageNet normalization stats
IMAGENET_STD = (0.229, 0.224, 0.225)

# Trajectory discretization settings.
N_WINDOW = 12  # number of trajectory timesteps predicted per pixel
# NOTE(review): MIN_HEIGHT == MAX_HEIGHT gives a zero-width height range — this looks
# like a copy/paste mistake; confirm the intended MIN_HEIGHT against the dataset.
MIN_HEIGHT = 0.043347
MAX_HEIGHT = 0.043347
MIN_GRIPPER = -0.2  # gripper value range used for binning
MAX_GRIPPER = 0.8
N_HEIGHT_BINS = 32  # height bins per timestep (cross-entropy targets)
N_GRIPPER_BINS = 32  # gripper bins per timestep (cross-entropy targets)


class TrajectoryHeatmapPredictor(nn.Module):
    """Predicts pixel-aligned volume (N_WINDOW x N_HEIGHT_BINS per pixel) and per-pixel gripper (N_WINDOW x N_GRIPPER_BINS per pixel)."""

    def __init__(self, target_size=448, n_window=N_WINDOW, freeze_backbone=False):
        """
        Args:
            target_size: side length in pixels of the (square) input image and output maps.
            n_window: number of trajectory timesteps predicted per pixel.
            freeze_backbone: if True, freeze the DINO backbone (no grads, eval mode).
        """
        super().__init__()
        self.target_size = target_size
        self.n_window = n_window
        self.patch_size = DINO_PATCH_SIZE

        # The backbone is DINOv3 (see DINO_REPO_DIR / DINO_WEIGHTS_PATH), loaded from a
        # local torch.hub repo so no network access is needed.
        print("Loading DINOv3 model...")
        self.dino = torch.hub.load(
            DINO_REPO_DIR,
            'dinov3_vits16plus',
            source='local',
            weights=DINO_WEIGHTS_PATH
        )
        if freeze_backbone:
            for param in self.dino.parameters():
                param.requires_grad = False
            self.dino.eval()
            print("✓ Frozen DINOv3 backbone")
        else:
            print("✓ DINOv3 backbone is trainable")

        self.embed_dim = self.dino.embed_dim
        print(f"✓ DINO embedding dim: {self.embed_dim}")

        # Volume head: 1x1 conv producing per-patch logits for (N_WINDOW, N_HEIGHT_BINS),
        # later bilinearly upsampled to full image resolution.
        self.volume_head = nn.Conv2d(
            self.embed_dim,
            self.n_window * N_HEIGHT_BINS,
            kernel_size=1
        )
        print(f"✓ Volume head: (B, {self.n_window}*{N_HEIGHT_BINS}, H_p, W_p) -> upsample to (B, {self.n_window}, {N_HEIGHT_BINS}, H, W)")

        # Learnable embedding added to the patch containing the start keypoint,
        # conditioning the prediction on where the trajectory begins.
        self.start_keypoint_embedding = nn.Parameter(torch.randn(self.embed_dim) * 0.02)
        print(f"✓ Learnable start keypoint embedding (dim={self.embed_dim})")

        # Gripper head: per-pixel logits (N_WINDOW, N_GRIPPER_BINS) — supervised at the
        # GT pixel during training, decoded at the predicted pixel at inference.
        self.gripper_head = nn.Conv2d(
            self.embed_dim,
            self.n_window * N_GRIPPER_BINS,
            kernel_size=1
        )
        print(f"✓ Gripper head (per-pixel): (B, {self.n_window}*{N_GRIPPER_BINS}, H_p, W_p) -> upsample to (B, {self.n_window}, {N_GRIPPER_BINS}, H, W)")

    def to(self, device):
        """Move the model (including the hub-loaded backbone) to *device*.

        NOTE(review): super().to() already moves all registered submodules (which
        includes self.dino); the explicit move is kept for backward compatibility.
        """
        super().to(device)
        if hasattr(self, 'dino'):
            self.dino = self.dino.to(device)
        return self

    def _extract_dino_features(self, x):
        """Run the DINO backbone and return spatial patch features plus the CLS token.

        Args:
            x: (B, 3, H, W) image batch — presumably ImageNet-normalized
               (IMAGENET_MEAN/STD); confirm against the caller.

        Returns:
            patch_features: (B, D, H_p, W_p) patch-token feature map.
            cls_token: (B, D) CLS token after the final norm.
        """
        B = x.shape[0]
        x_tokens, (H_p, W_p) = self.dino.prepare_tokens_with_masks(x)
        # RoPE depends only on the patch grid size, so compute it once instead of
        # recomputing it inside the per-block loop (loop-invariant hoist).
        rope_sincos = self.dino.rope_embed(H=H_p, W=W_p) if self.dino.rope_embed else None
        for blk in self.dino.blocks:
            x_tokens = blk(x_tokens, rope_sincos)
        if self.dino.untie_cls_and_patch_norms:
            # Separate norms for (CLS + storage) tokens vs patch tokens.
            x_norm_cls = self.dino.cls_norm(x_tokens[:, : self.dino.n_storage_tokens + 1])
            x_norm_patches = self.dino.norm(x_tokens[:, self.dino.n_storage_tokens + 1 :])
            x_tokens = torch.cat([x_norm_cls, x_norm_patches], dim=1)
        else:
            x_tokens = self.dino.norm(x_tokens)

        cls_token = x_tokens[:, 0]  # (B, D)
        patch_tokens = x_tokens[:, self.dino.n_storage_tokens + 1 :]  # drop CLS + storage tokens
        patch_features = patch_tokens.reshape(B, H_p, W_p, self.embed_dim)
        patch_features = patch_features.permute(0, 3, 1, 2).contiguous()  # (B, D, H_p, W_p)
        return patch_features, cls_token

    def forward(self, x, gt_target_heatmap=None, training=False, start_keypoint_2d=None, current_height=None, current_gripper=None):
        """
        Args:
            x: (B, 3, H, W)
            gt_target_heatmap, training: ignored (kept for API compatibility)
            start_keypoint_2d: (B, 2) or (2,) start-keypoint pixel coords; optional —
                when None, no start-keypoint conditioning is injected.
            current_height, current_gripper: ignored (kept for API compatibility)

        Returns:
            volume_logits: (B, N_WINDOW, N_HEIGHT_BINS, H, W)
            gripper_logits: (B, N_WINDOW, N_GRIPPER_BINS, H, W)  # per-pixel; index at GT pixel (train) or pred pixel (inference)
        """
        B = x.shape[0]
        patch_features, cls_token = self._extract_dino_features(x)  # (B, D, H_p, W_p), (B, D)
        _, D, H_p, W_p = patch_features.shape

        # BUGFIX: start_keypoint_2d defaults to None but was dereferenced
        # unconditionally (AttributeError); only inject the embedding when given.
        if start_keypoint_2d is not None:
            if start_keypoint_2d.dim() == 1:
                start_keypoint_2d = start_keypoint_2d.unsqueeze(0).expand(B, -1)
            # Map pixel coords -> patch-grid coords, then add the learnable embedding
            # to the feature of the start patch (in-place indexed add).
            start_patch_x = (start_keypoint_2d[:, 0] * W_p / self.target_size).long().clamp(0, W_p - 1)
            start_patch_y = (start_keypoint_2d[:, 1] * H_p / self.target_size).long().clamp(0, H_p - 1)
            batch_indices = torch.arange(B, device=patch_features.device)
            patch_features[batch_indices, :, start_patch_y, start_patch_x] += self.start_keypoint_embedding.unsqueeze(0)

        # Volume: (B, N_WINDOW*N_HEIGHT_BINS, H_p, W_p) -> bilinear upsample -> 5D logits.
        # (The original reshaped to 5D and immediately back to 4D — removed the round-trip.)
        vol = self.volume_head(patch_features)  # (B, N_W*Nh, H_p, W_p)
        volume_logits = F.interpolate(
            vol,
            size=(self.target_size, self.target_size),
            mode='bilinear',
            align_corners=False
        ).view(B, self.n_window, N_HEIGHT_BINS, self.target_size, self.target_size)

        # Gripper: same pattern — per-pixel (N_WINDOW, N_GRIPPER_BINS) logits.
        grip = self.gripper_head(patch_features)  # (B, N_W*Ng, H_p, W_p)
        gripper_logits = F.interpolate(
            grip,
            size=(self.target_size, self.target_size),
            mode='bilinear',
            align_corners=False
        ).view(B, self.n_window, N_GRIPPER_BINS, self.target_size, self.target_size)

        return volume_logits, gripper_logits


if __name__ == "__main__":
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    model = TrajectoryHeatmapPredictor(target_size=448, n_window=N_WINDOW, freeze_backbone=True)
    model = model.to(device)
    x = torch.randn(2, 3, 448, 448).to(device)
    with torch.no_grad():
        vol, grip = model(x, training=False, start_keypoint_2d=torch.tensor([224.0, 224.0]))
    print("volume_logits", vol.shape)
    print("gripper_logits", grip.shape)
