"""DINO-VLA baseline — DINOv3 backbone + CLIP task conditioning + PARA heads.

Same architecture as PARA, but with a precomputed CLIP text embedding of the
current task added to every DINO patch token via a learned projection MLP.
This tests whether language-conditioned features help multi-task performance.

The CLIP embeddings are precomputed (frozen) and loaded as .pt tensors —
no CLIP forward pass at train time.
"""

import os
import torch
import torch.nn as nn
import torch.nn.functional as F

# Local DINOv3 checkout (loaded via torch.hub with source='local') and its
# checkpoint file. Override with the DINO_REPO_DIR / DINO_WEIGHTS_PATH
# environment variables; the defaults are machine-specific absolute paths.
DINO_REPO_DIR     = os.environ.get("DINO_REPO_DIR",    "/Users/cameronsmith/Projects/robotics_testing/random/dinov3")
DINO_WEIGHTS_PATH = os.environ.get("DINO_WEIGHTS_PATH", "/Users/cameronsmith/Projects/robotics_testing/random/dinov3/weights/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth")
# ViT patch size in pixels; should agree with the "vits16plus" checkpoint above.
DINO_PATCH_SIZE = 16

# Number of steps in the prediction window (one height-bin volume per step).
N_WINDOW = 6
# Height bins per pixel in the volume head output.
N_HEIGHT_BINS = 32
# NOTE(review): the gripper MLP emits a single logit per query pixel, so this
# bin count is not used by any prediction head.
N_GRIPPER_BINS = 32
# Bins for each of the 3 rotation components predicted by the rotation head.
N_ROT_BINS = 32
# Side length (pixels) of the square prediction heatmaps.
PRED_SIZE = 64
CLIP_DIM = 512  # openai/clip-vit-base-patch32 output dim


class DinoVLAPredictor(nn.Module):
    """DINOv3 backbone + CLIP task conditioning + PARA-style heatmap heads.

    Forward pipeline:
      1. DINOv3 patch tokens reshaped into a (B, D, H_p, W_p) feature map.
      2. Optional CLIP task embedding, projected to D by ``clip_proj`` and
         added to every patch location.
      3. A learned embedding added at the patch containing the start keypoint.
      4. Bilinear upsampling to ``pred_size`` followed by 3x3 conv refinement.
      5. ``volume_head`` produces per-pixel height-bin logits for every window
         step; the gripper/rotation MLPs are evaluated at query pixels.
    """

    def __init__(self, target_size=448, pred_size=PRED_SIZE, n_window=N_WINDOW,
                 freeze_backbone=False, clip_dim=CLIP_DIM, **kwargs):
        """Load the DINOv3 backbone and build the conditioning layers and heads.

        Args:
            target_size: Input image side length in pixels. Start keypoints are
                mapped onto the patch grid assuming this (square) pixel frame.
            pred_size: Side length of the predicted heatmaps.
            n_window: Number of window steps predicted per sample.
            freeze_backbone: If True, backbone parameters get
                ``requires_grad=False`` and the backbone is kept in eval mode,
                even after ``train()`` is called on this module.
            clip_dim: Size of the precomputed CLIP text embedding.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        self.target_size = target_size
        self.pred_size = pred_size
        self.n_window = n_window
        self.patch_size = DINO_PATCH_SIZE
        self.model_type = "dino_vla"
        # Remembered so that train() can keep a frozen backbone in eval mode.
        self.freeze_backbone = freeze_backbone

        print("Loading DINOv3 model...")
        self.dino = torch.hub.load(
            DINO_REPO_DIR,
            'dinov3_vits16plus',
            source='local',
            weights=DINO_WEIGHTS_PATH
        )
        if freeze_backbone:
            for param in self.dino.parameters():
                param.requires_grad = False
            self.dino.eval()
            print("✓ Frozen DINOv3 backbone")
        else:
            print("✓ DINOv3 backbone is trainable")

        self.embed_dim = self.dino.embed_dim
        D = self.embed_dim

        # CLIP task conditioning: project clip_dim → dino_dim, add to every patch token
        self.clip_proj = nn.Sequential(
            nn.Linear(clip_dim, D),
            nn.GELU(),
            nn.Linear(D, D),
        )
        print(f"✓ CLIP projection: {clip_dim} → {D}")

        # Start keypoint embedding (same as PARA): added at the start keypoint's patch.
        self.start_keypoint_embedding = nn.Parameter(torch.randn(D) * 0.02)
        print(f"✓ Learnable start keypoint embedding (dim={D})")

        # Feature refinement convs, applied after upsampling to pred_size.
        self.feature_convs = nn.Sequential(
            nn.Conv2d(D, D, kernel_size=3, padding=1), nn.GELU(),
            nn.Conv2d(D, D, kernel_size=3, padding=1), nn.GELU(),
            nn.Conv2d(D, D, kernel_size=3, padding=1), nn.GELU(),
        )
        print(f"✓ Feature convs: 3× Conv2d(3×3) at pred_size={pred_size}")

        # Volume head: one channel per (window step, height bin).
        self.volume_head = nn.Conv2d(D, n_window * N_HEIGHT_BINS, kernel_size=1)
        print(f"✓ Volume   head → (B, {n_window}, {N_HEIGHT_BINS}, {pred_size}, {pred_size})")

        # Gripper / rotation MLPs, evaluated per query pixel.
        self.gripper_mlp = nn.Sequential(
            nn.LayerNorm(D), nn.Linear(D, D), nn.GELU(), nn.Linear(D, 1)
        )
        self.rotation_mlp = nn.Sequential(
            nn.LayerNorm(D), nn.Linear(D, D), nn.GELU(), nn.Linear(D, 3 * N_ROT_BINS)
        )
        # The gripper MLP emits one logit per query pixel (see predict_at_pixels).
        print(f"✓ Gripper  MLP  → (B, {n_window})")
        print(f"✓ Rotation MLP  → (B, {n_window}, 3, {N_ROT_BINS})")

    def to(self, *args, **kwargs):
        """Move and/or cast the module; accepts every ``nn.Module.to`` signature.

        ``self.dino`` is a registered submodule, so ``nn.Module.to`` already
        moves it along with all other parameters and buffers. Forwarding all
        arguments keeps ``model.to(device)`` working while also supporting
        dtype casts and ``non_blocking=...``.
        """
        return super().to(*args, **kwargs)

    def train(self, mode=True):
        """Set training mode, keeping a frozen backbone in eval mode.

        ``nn.Module.train`` recurses into every submodule, so without this
        override the first ``model.train()`` call would switch a frozen
        ``self.dino`` back to training-mode behaviour.
        """
        super().train(mode)
        # getattr: instances unpickled from older versions lack the attribute.
        if getattr(self, "freeze_backbone", False):
            self.dino.eval()
        return self

    def _extract_dino_features(self, x):
        """Run the DINOv3 backbone and return its patch tokens as a feature map.

        Args:
            x: (B, 3, H, W) image batch.

        Returns:
            (B, D, H_p, W_p) normalized patch features, with D = ``self.embed_dim``.
        """
        B = x.shape[0]
        x_tokens, (H_p, W_p) = self.dino.prepare_tokens_with_masks(x)
        for blk in self.dino.blocks:
            # Rebuilt for every block on purpose, as DINOv3's own forward does.
            # NOTE(review): the RoPE module may apply random coordinate
            # augmentation in training mode, so hoisting it out of the loop
            # could change behaviour.
            rope_sincos = self.dino.rope_embed(H=H_p, W=W_p) if self.dino.rope_embed is not None else None
            x_tokens = blk(x_tokens, rope_sincos)
        # The first n_storage_tokens + 1 tokens are presumably CLS + storage
        # (register) tokens; only patch tokens are kept.
        n_prefix = self.dino.n_storage_tokens + 1
        if self.dino.untie_cls_and_patch_norms:
            x_norm_patches = self.dino.norm(x_tokens[:, n_prefix:])
        else:
            x_norm_patches = self.dino.norm(x_tokens)[:, n_prefix:]

        patch_features = x_norm_patches.reshape(B, H_p, W_p, self.embed_dim)
        patch_features = patch_features.permute(0, 3, 1, 2).contiguous()  # (B, D, H_p, W_p)
        return patch_features

    def _index_features(self, feats, query_pixels):
        """Gather feature vectors at integer pixel locations.

        Args:
            feats: (B, D, H, W) feature map.
            query_pixels: (B, N, 2) pixel coordinates as (x, y) in the
                feats grid; truncated to int and clamped to the map bounds.

        Returns:
            (B, N, D) features at the requested pixels.
        """
        B, D, H, W = feats.shape
        N = query_pixels.shape[1]
        px = query_pixels[..., 0].long().clamp(0, W - 1)
        py = query_pixels[..., 1].long().clamp(0, H - 1)
        batch_idx = torch.arange(B, device=feats.device).view(B, 1).expand(B, N)
        return feats[batch_idx, :, py, px]

    def predict_at_pixels(self, feats, query_pixels):
        """Evaluate the gripper and rotation MLPs at the query pixels.

        Args:
            feats: (B, D, H, W) refined feature map.
            query_pixels: (B, N, 2) (x, y) pixel coordinates in the feats grid.

        Returns:
            gripper: (B, N) logits.
            rotation: (B, N, 3, N_ROT_BINS) logits.
        """
        B, N = query_pixels.shape[:2]
        # Features are detached, so gripper/rotation losses do not backpropagate
        # into the shared trunk. NOTE(review): confirm this is intentional.
        indexed = self._index_features(feats.detach(), query_pixels)
        flat = indexed.reshape(B * N, self.embed_dim)
        gripper = self.gripper_mlp(flat).reshape(B, N)
        rotation = self.rotation_mlp(flat).reshape(B, N, 3, N_ROT_BINS)
        return gripper, rotation

    def forward(self, x, start_keypoint_2d, query_pixels=None, clip_embedding=None):
        """
        Args:
            x:                 (B, 3, H, W)
            start_keypoint_2d: (B, 2) or (2,) current EEF pixel (x, y), in the
                               target_size pixel frame
            query_pixels:      (B, N_WINDOW, 2) (x, y) pixels in the pred_size
                               grid, for the gripper/rotation heads
            clip_embedding:    (B, clip_dim) precomputed CLIP text embedding for task

        Returns:
            volume_logits:   (B, n_window, N_HEIGHT_BINS, pred_size, pred_size)
            gripper_logits:  (B, N) or None when query_pixels is None
            rotation_logits: (B, N, 3, N_ROT_BINS) or None when query_pixels is None
            feats:           (B, D, pred_size, pred_size) refined features
        """
        B = x.shape[0]

        # Extract DINO patch features
        patch_features = self._extract_dino_features(x)  # (B, D, H_p, W_p)
        _, D, H_p, W_p = patch_features.shape

        # CLIP task conditioning: project and broadcast-add to every patch.
        if clip_embedding is not None:
            clip_proj = self.clip_proj(clip_embedding)  # (B, D)
            patch_features = patch_features + clip_proj.unsqueeze(-1).unsqueeze(-1)

        # Start keypoint conditioning: map the pixel to its patch and add the
        # learned embedding there (in place on a freshly computed tensor).
        if start_keypoint_2d.dim() == 1:
            start_keypoint_2d = start_keypoint_2d.unsqueeze(0).expand(B, -1)
        start_patch_x = (start_keypoint_2d[:, 0] * W_p / self.target_size).long().clamp(0, W_p - 1)
        start_patch_y = (start_keypoint_2d[:, 1] * H_p / self.target_size).long().clamp(0, H_p - 1)
        batch_indices = torch.arange(B, device=patch_features.device)
        patch_features[batch_indices, :, start_patch_y, start_patch_x] += self.start_keypoint_embedding.unsqueeze(0)

        # Upsample to pred_size
        feats = F.interpolate(patch_features, size=(self.pred_size, self.pred_size), mode='bilinear', align_corners=False)
        feats = self.feature_convs(feats)

        vol = self.volume_head(feats)
        volume_logits = vol.view(B, self.n_window, N_HEIGHT_BINS, self.pred_size, self.pred_size)

        if query_pixels is not None:
            gripper_logits, rotation_logits = self.predict_at_pixels(feats, query_pixels)
        else:
            gripper_logits = rotation_logits = None

        return volume_logits, gripper_logits, rotation_logits, feats


if __name__ == "__main__":
    # Smoke test: build the model, push one random batch through it, and
    # report the shape of every returned tensor.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = DinoVLAPredictor(target_size=448, n_window=N_WINDOW).to(device)

    batch = 2
    images = torch.randn(batch, 3, 448, 448).to(device)
    task_embedding = torch.randn(batch, CLIP_DIM).to(device)
    query = torch.zeros(batch, N_WINDOW, 2).to(device)
    start_kp = torch.tensor([224.0, 224.0]).to(device)

    with torch.no_grad():
        outputs = model(
            images,
            start_keypoint_2d=start_kp,
            query_pixels=query,
            clip_embedding=task_embedding,
        )

    labels = (
        "volume_logits  ",
        "gripper_logits ",
        "rotation_logits",
        "feats          ",
    )
    for label, tensor in zip(labels, outputs):
        print(label, tensor.shape)
