"""Model for trajectory volume prediction using DINOv2.

Predicts a pixel-aligned volume: N_WINDOW x N_HEIGHT_BINS logits per pixel (cross-entropy).
Gripper is per-pixel (N_WINDOW x N_GRIPPER_BINS per pixel): supervised at GT pixel during training,
decoded at predicted pixel during inference (teacher forcing in train, argmax at pred pixel in val/inference).
"""
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

# DINO configuration — override via env vars on servers without local weights
DINO_REPO_DIR    = os.environ.get("DINO_REPO_DIR",    "/Users/cameronsmith/Projects/robotics_testing/random/dinov3")
DINO_WEIGHTS_PATH = os.environ.get("DINO_WEIGHTS_PATH", "/Users/cameronsmith/Projects/robotics_testing/random/dinov3/weights/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth")
DINO_PATCH_SIZE = 16
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

N_WINDOW = 4
MIN_HEIGHT = 0.043347  # NOTE: min == max makes the height bins degenerate; update both from dataset stats
MAX_HEIGHT = 0.043347
MIN_GRIPPER = -1.0
MAX_GRIPPER = 1.0
MIN_ROT = [-3.14159, -3.14159, -3.14159]  # per-axis min (updated from dataset stats)
MAX_ROT = [ 3.14159,  3.14159,  3.14159]  # per-axis max
REF_ROTATION_QUAT = [0.0, 0.0, 0.0, 1.0]  # reference rotation for delta axis-angle (identity default)
MIN_POS = [-1.0, -1.0, 0.0]  # per-axis XYZ min (updated from dataset stats)
MAX_POS = [ 1.0,  1.0,  2.0]  # per-axis XYZ max
N_HEIGHT_BINS = 32
N_GRIPPER_BINS = 32  # unused by this model: the gripper head below is 2-class (open/close)
N_ROT_BINS = 32

PRED_SIZE = 64  # supervision resolution; upsample to IMAGE_SIZE for vis/downstream
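
# Illustrative bin helpers (a sketch, not part of the original module): the uniform
# binning convention here is an assumption, chosen to match the MIN_*/MAX_*/N_*_BINS
# constants above.
def value_to_bin(v, lo, hi, n_bins):
    """Map a continuous value to a bin index in [0, n_bins) under uniform binning."""
    span = max(hi - lo, 1e-8)  # guard against degenerate bounds (e.g. lo == hi)
    t = min(max((v - lo) / span, 0.0), 1.0 - 1e-9)
    return int(t * n_bins)


def bin_to_value(b, lo, hi, n_bins):
    """Map a bin index back to its bin-center value."""
    return lo + (b + 0.5) * (hi - lo) / n_bins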


class TrajectoryHeatmapPredictor(nn.Module):
    """Predicts pixel-aligned volume (N_WINDOW x N_HEIGHT_BINS per pixel) and per-pixel gripper (N_WINDOW x N_GRIPPER_BINS per pixel)."""

    def __init__(self, target_size=448, pred_size=PRED_SIZE, n_window=N_WINDOW, freeze_backbone=False):
        super().__init__()
        self.target_size = target_size
        self.pred_size = pred_size
        self.n_window = n_window
        self.patch_size = DINO_PATCH_SIZE

        print("Loading DINOv2 model...")
        self.dino = torch.hub.load(
            DINO_REPO_DIR,
            'dinov3_vits16plus',
            source='local',
            weights=DINO_WEIGHTS_PATH
        )
        if freeze_backbone:
            for param in self.dino.parameters():
                param.requires_grad = False
            self.dino.eval()
            print("✓ Frozen DINOv2 backbone")
        else:
            print("✓ DINOv2 backbone is trainable")

        self.embed_dim = self.dino.embed_dim
        print(f"✓ DINO embedding dim: {self.embed_dim}")

        self.start_keypoint_embedding = nn.Parameter(torch.randn(self.embed_dim) * 0.02)
        print(f"✓ Learnable start keypoint embedding (dim={self.embed_dim})")

        # Shared feature refinement: upsample patch grid to pred_size, then 3 conv layers
        D = self.embed_dim
        self.feature_convs = nn.Sequential(
            nn.Conv2d(D, D, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv2d(D, D, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv2d(D, D, kernel_size=3, padding=1),
            nn.GELU(),
        )
        print(f"✓ Feature convs: 3× Conv2d(3×3) at pred_size={pred_size}")

        # Volume head: 1×1 conv at pred_size (spatial — needed for heatmap)
        self.volume_head = nn.Conv2d(D, self.n_window * N_HEIGHT_BINS, kernel_size=1)
        print(f"✓ Volume   head → (B, {self.n_window}, {N_HEIGHT_BINS}, {pred_size}, {pred_size})")

        # Gripper / rotation: dense 1×1 conv heads (same pattern as volume head).
        # Produces spatial maps, indexed at GT pixel (train) or predicted pixel (eval).
        N_GRIPPER_CLASSES = 2  # open / close
        self.n_gripper_classes = N_GRIPPER_CLASSES
        self.gripper_head = nn.Conv2d(D, self.n_window * N_GRIPPER_CLASSES, kernel_size=1)
        self.rotation_head = nn.Conv2d(D, self.n_window * 3 * N_ROT_BINS, kernel_size=1)
        print(f"✓ Gripper  head → (B, {self.n_window}, {N_GRIPPER_CLASSES}, {pred_size}, {pred_size})  [1×1 conv, CE 2-class]")
        print(f"✓ Rotation head → (B, {self.n_window}, 3, {N_ROT_BINS}, {pred_size}, {pred_size})  [1×1 conv, CE]")

    def _extract_dino_features(self, x):
        """Extract patch features and CLS token.
        Returns:
            patch_features: (B, D, H_p, W_p)
            cls_token: (B, D)
        """
        B = x.shape[0]
        x_tokens, (H_p, W_p) = self.dino.prepare_tokens_with_masks(x)
        # RoPE depends only on the grid size, so compute it once and reuse per block
        rope_sincos = self.dino.rope_embed(H=H_p, W=W_p) if self.dino.rope_embed is not None else None
        for blk in self.dino.blocks:
            x_tokens = blk(x_tokens, rope_sincos)
        if self.dino.untie_cls_and_patch_norms:
            x_norm_cls = self.dino.cls_norm(x_tokens[:, : self.dino.n_storage_tokens + 1])
            x_norm_patches = self.dino.norm(x_tokens[:, self.dino.n_storage_tokens + 1 :])
            x_tokens = torch.cat([x_norm_cls, x_norm_patches], dim=1)
        else:
            x_tokens = self.dino.norm(x_tokens)

        cls_token = x_tokens[:, 0]  # (B, D)
        patch_tokens = x_tokens[:, self.dino.n_storage_tokens + 1 :]
        patch_features = patch_tokens.reshape(B, H_p, W_p, self.embed_dim)
        patch_features = patch_features.permute(0, 3, 1, 2).contiguous()  # (B, D, H_p, W_p)
        return patch_features, cls_token

    def _index_features(self, feats, query_pixels):
        """Index spatial feature map at specified pixel locations.

        Args:
            feats:         (B, D, H, W)
            query_pixels:  (B, N, 2) pixel coords [x, y] in feats coordinate space

        Returns:
            indexed: (B, N, D)
        """
        B, D, H, W = feats.shape
        N = query_pixels.shape[1]
        px = query_pixels[..., 0].long().clamp(0, W - 1)  # (B, N)
        py = query_pixels[..., 1].long().clamp(0, H - 1)  # (B, N)
        batch_idx = torch.arange(B, device=feats.device).view(B, 1).expand(B, N)
        return feats[batch_idx, :, py, px]  # (B, N, D)

    def predict_at_pixels(self, feats, query_pixels):
        """Index dense gripper/rotation maps at specified pixel locations.

        Args:
            feats:         (B, D, pred_size, pred_size)
            query_pixels:  (B, N_WINDOW, 2) in pred_size coordinate space

        Returns:
            gripper_logits:  (B, N_WINDOW, 2) logits for [open, close]
            rotation_logits: (B, N_WINDOW, 3, N_ROT_BINS)
        """
        B = feats.shape[0]
        N = query_pixels.shape[1]
        H = W = self.pred_size
        Nc = self.n_gripper_classes
        Nr = N_ROT_BINS

        px = query_pixels[..., 0].long().clamp(0, W - 1)  # (B, N)
        py = query_pixels[..., 1].long().clamp(0, H - 1)  # (B, N)

        # Dense gripper map
        grip_map = self.gripper_head(feats).view(B, N, Nc, H, W)  # (B, N, 2, H, W)
        # Index: for each (b, t), extract grip_map[b, t, :, py[b,t], px[b,t]]
        batch_idx = torch.arange(B, device=feats.device).view(B, 1).expand(B, N)
        time_idx = torch.arange(N, device=feats.device).view(1, N).expand(B, N)
        gripper_logits = grip_map[batch_idx, time_idx, :, py, px]  # (B, N, 2)

        # Dense rotation map
        rot_map = self.rotation_head(feats).view(B, N, 3, Nr, H, W)  # (B, N, 3, Nr, H, W)
        rotation_logits = rot_map[batch_idx, time_idx, :, :, py, px]  # (B, N, 3, Nr)

        return gripper_logits, rotation_logits
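
    def decode_trajectory(self, volume_logits, feats):
        """Illustrative inference-time decode (a sketch, not part of the original
        module): marginalize the volume over height bins, take the spatial argmax
        per window step as the predicted pixel, then query the dense gripper and
        rotation heads there, as the module docstring describes.

        Returns:
            pred_pixels:     (B, N_WINDOW, 2) [x, y] in pred_size space
            height_bins:     (B, N_WINDOW) argmax height bin at the predicted pixel
            gripper_logits:  (B, N_WINDOW, 2)
            rotation_logits: (B, N_WINDOW, 3, N_ROT_BINS)
        """
        B, N, Nh, H, W = volume_logits.shape
        spatial = volume_logits.logsumexp(dim=2).view(B, N, H * W)  # marginal over height bins
        flat_idx = spatial.argmax(dim=-1)                           # (B, N)
        px, py = flat_idx % W, flat_idx // W
        pred_pixels = torch.stack([px, py], dim=-1).float()         # (B, N, 2) [x, y]
        batch_idx = torch.arange(B, device=feats.device).view(B, 1).expand(B, N)
        time_idx = torch.arange(N, device=feats.device).view(1, N).expand(B, N)
        height_bins = volume_logits[batch_idx, time_idx, :, py, px].argmax(dim=-1)  # (B, N)
        gripper_logits, rotation_logits = self.predict_at_pixels(feats, pred_pixels)
        return pred_pixels, height_bins, gripper_logits, rotation_logits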

    def forward(self, x, start_keypoint_2d, query_pixels=None):
        """
        Args:
            x:                  (B, 3, H, W)
            start_keypoint_2d:  (B, 2) or (2,) current EEF pixel in image coords
            query_pixels:       (B, N_WINDOW, 2) pixel coords in pred_size space to query
                                for gripper/rotation.  Pass GT pixels during training;
                                pass predicted pixels (volume argmax) during inference.
                                If None, gripper/rotation logits are not computed.

        Returns:
            volume_logits:   (B, N_WINDOW, N_HEIGHT_BINS, pred_size, pred_size)
            gripper_logits:  (B, N_WINDOW, 2)  or None
            rotation_logits: (B, N_WINDOW, 3, N_ROT_BINS)   or None
            feats:           (B, D, pred_size, pred_size)  — for predict_at_pixels at eval
        """
        B = x.shape[0]
        patch_features, cls_token = self._extract_dino_features(x)  # (B, D, H_p, W_p), (B, D)
        _, D, H_p, W_p = patch_features.shape

        # Start keypoint conditioning: inject into patch token at current EEF location
        if start_keypoint_2d.dim() == 1:
            start_keypoint_2d = start_keypoint_2d.unsqueeze(0).expand(B, -1)
        start_patch_x = (start_keypoint_2d[:, 0] * W_p / self.target_size).long().clamp(0, W_p - 1)
        start_patch_y = (start_keypoint_2d[:, 1] * H_p / self.target_size).long().clamp(0, H_p - 1)
        batch_indices = torch.arange(B, device=patch_features.device)
        patch_features[batch_indices, :, start_patch_y, start_patch_x] += self.start_keypoint_embedding.unsqueeze(0)

        # Upsample patch grid (target_size // patch_size per side; 28×28 for 448 input) → pred_size × pred_size
        feats = F.interpolate(patch_features, size=(self.pred_size, self.pred_size), mode='bilinear', align_corners=False)

        # Refine with 3 conv layers at pred_size resolution
        feats = self.feature_convs(feats)  # (B, D, pred_size, pred_size)

        # Volume head: dense spatial prediction (full heatmap needed)
        vol = self.volume_head(feats)  # (B, N_W*Nh, pred_size, pred_size)
        volume_logits = vol.view(B, self.n_window, N_HEIGHT_BINS, self.pred_size, self.pred_size)

        # Gripper/rotation: dense 1×1 conv heads, indexed at query pixels
        if query_pixels is not None:
            gripper_logits, rotation_logits = self.predict_at_pixels(feats, query_pixels)
        else:
            gripper_logits = rotation_logits = None

        return volume_logits, gripper_logits, rotation_logits, feats

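# Illustrative training-loss sketch (an assumption, not the original training code):
# cross-entropy over the flattened (height-bin, pixel) volume per window step, plus
# CE on gripper/rotation logits queried at the GT pixels (teacher forcing). Target
# shapes below are assumptions.
def trajectory_losses(volume_logits, gripper_logits, rotation_logits,
                      gt_volume_idx, gt_gripper, gt_rot_bins):
    """
    Args (assumed shapes):
        volume_logits:   (B, N, Nh, H, W) from forward()
        gripper_logits:  (B, N, 2) queried at GT pixels
        rotation_logits: (B, N, 3, Nr) queried at GT pixels
        gt_volume_idx:   (B, N) long, GT index into the flattened Nh*H*W volume
        gt_gripper:      (B, N) long in {0, 1}
        gt_rot_bins:     (B, N, 3) long in [0, Nr)
    """
    B, N = gt_volume_idx.shape
    vol_loss = F.cross_entropy(volume_logits.reshape(B * N, -1), gt_volume_idx.view(-1))
    grip_loss = F.cross_entropy(gripper_logits.reshape(B * N, -1), gt_gripper.view(-1))
    rot_loss = F.cross_entropy(rotation_logits.reshape(B * N * 3, -1), gt_rot_bins.reshape(-1))
    return vol_loss + grip_loss + rot_loss
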

if __name__ == "__main__":
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    model = TrajectoryHeatmapPredictor(target_size=448, n_window=N_WINDOW, freeze_backbone=True)
    model = model.to(device)
    x = torch.randn(2, 3, 448, 448).to(device)
    fake_query = torch.zeros(2, N_WINDOW, 2).to(device)
    with torch.no_grad():
        vol, grip, rot, feats = model(x, start_keypoint_2d=torch.tensor([224.0, 224.0]).to(device), query_pixels=fake_query)
    print("volume_logits  ", vol.shape)
    print("gripper_logits ", grip.shape)
    print("rotation_logits", rot.shape)
    print("feats          ", feats.shape)
