"""MoGe baseline — MoGe v2 pretrained DINOv2 backbone + PARA heads.

Uses the DINOv2 ViT-S/14 backbone from MoGe v2, pretrained for monocular
geometry estimation. The hypothesis: geometry-pretrained features improve
3D EEF prediction.

Architecture: same as PARA (heatmap volume + gripper/rotation MLPs) but with
MoGe's geometry-pretrained backbone instead of vanilla DINOv3.
"""

import os
import torch
import torch.nn as nn
import torch.nn.functional as F

# Default ``n_window`` of MoGePredictor; the volume head emits
# ``n_window * N_HEIGHT_BINS`` channels.
N_WINDOW = 6
# Number of height bins per window slot in the volume logits.
N_HEIGHT_BINS = 32
# In the code visible here this only appears in a startup log message.
N_GRIPPER_BINS = 32
# The rotation MLP emits ``3 * N_ROT_BINS`` values per queried pixel.
N_ROT_BINS = 32
# Default side length of the upsampled feature map / prediction grid.
PRED_SIZE = 64
# Patch size recorded on the predictor (the backbone is a ViT-S/14).
PATCH_SIZE = 14

# Checkpoint location; override with the MOGE_WEIGHTS_PATH environment variable.
MOGE_WEIGHTS_PATH = os.getenv("MOGE_WEIGHTS_PATH", "/data/cameron/moge_weights/model.pt")


def _load_dinov2_vits14_from_moge(weights_path):
    """Build a DINOv2 ViT-S/14 and initialise it from a MoGe checkpoint.

    Every checkpoint entry whose key starts with ``encoder.backbone.`` is
    copied into the backbone with that prefix stripped. The architecture
    itself comes from ``torch.hub`` (``facebookresearch/dinov2``,
    ``dinov2_vits14``) with ``pretrained=False``.

    Args:
        weights_path: Path to the checkpoint. It may be a flat state dict or
            a dict holding the state dict under the ``'model'`` key.

    Returns:
        The backbone module with the matching weights loaded. Loading is
        non-strict: missing / unexpected keys are printed, not raised.

    Raises:
        RuntimeError: If no checkpoint key carries the expected prefix.
            Without this check the function would report success while
            returning a randomly initialised backbone.
    """
    prefix = "encoder.backbone."

    # NOTE: torch.load unpickles the file; only point this at trusted checkpoints.
    sd = torch.load(weights_path, map_location='cpu')
    if 'model' in sd:
        sd = sd['model']

    # Keep only the backbone entries, renamed to the backbone's own key space.
    backbone_sd = {k[len(prefix):]: v for k, v in sd.items() if k.startswith(prefix)}
    if not backbone_sd:
        # Checked before the hub build so a wrong checkpoint fails fast.
        raise RuntimeError(
            f"No keys starting with '{prefix}' found in checkpoint {weights_path!r}; "
            "refusing to return a randomly initialised backbone."
        )

    backbone = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14', pretrained=False)

    missing, unexpected = backbone.load_state_dict(backbone_sd, strict=False)
    if missing:
        print(f"  MoGe backbone missing keys ({len(missing)}): {missing[:5]}...")
    if unexpected:
        print(f"  MoGe backbone unexpected keys ({len(unexpected)}): {unexpected[:5]}...")
    print(f"✓ Loaded MoGe pretrained backbone ({len(backbone_sd)} keys)")
    return backbone


class MoGePredictor(nn.Module):
    """MoGe-pretrained DINOv2 ViT-S/14 backbone + PARA-style heatmap heads.

    Pipeline (see ``forward``): backbone patch features -> add a learnable
    embedding at the start-keypoint patch -> bilinear upsample to
    ``pred_size`` x ``pred_size`` -> three 3x3 convs -> a 1x1 volume head,
    plus optional per-pixel gripper / rotation MLPs.

    Args:
        target_size: Extent used to map ``start_keypoint_2d`` onto the patch
            grid (presumably the input image side length in pixels --
            confirm against the caller).
        pred_size: Side length of the upsampled feature map and of the
            volume logits.
        n_window: Number of window slots in the volume logits.
        freeze_backbone: If True, backbone parameters do not receive
            gradients and the backbone is kept in eval mode.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, target_size=448, pred_size=PRED_SIZE, n_window=N_WINDOW, freeze_backbone=False, **kwargs):
        super().__init__()
        self.target_size = target_size
        self.pred_size = pred_size
        self.n_window = n_window
        self.patch_size = PATCH_SIZE
        self.model_type = "moge"
        # Remembered so ``train()`` can keep a frozen backbone in eval mode.
        self.freeze_backbone = freeze_backbone

        print("Loading MoGe pretrained DINOv2 backbone...")
        self.backbone = _load_dinov2_vits14_from_moge(MOGE_WEIGHTS_PATH)

        if freeze_backbone:
            for param in self.backbone.parameters():
                param.requires_grad = False
            self.backbone.eval()
            print("✓ Frozen MoGe backbone")
        else:
            print("✓ MoGe backbone is trainable")

        self.embed_dim = 384
        D = self.embed_dim

        # Added to the feature vector of the patch containing the start keypoint.
        self.start_keypoint_embedding = nn.Parameter(torch.randn(D) * 0.02)
        print(f"✓ Learnable start keypoint embedding (dim={D})")

        self.feature_convs = nn.Sequential(
            nn.Conv2d(D, D, kernel_size=3, padding=1), nn.GELU(),
            nn.Conv2d(D, D, kernel_size=3, padding=1), nn.GELU(),
            nn.Conv2d(D, D, kernel_size=3, padding=1), nn.GELU(),
        )
        print(f"✓ Feature convs: 3× Conv2d(3×3) at pred_size={pred_size}")

        self.volume_head = nn.Conv2d(D, n_window * N_HEIGHT_BINS, kernel_size=1)
        print(f"✓ Volume   head → (B, {n_window}, {N_HEIGHT_BINS}, {pred_size}, {pred_size})")

        self.gripper_mlp = nn.Sequential(
            nn.LayerNorm(D), nn.Linear(D, D), nn.GELU(), nn.Linear(D, 1)
        )
        self.rotation_mlp = nn.Sequential(
            nn.LayerNorm(D), nn.Linear(D, D), nn.GELU(), nn.Linear(D, 3 * N_ROT_BINS)
        )
        # NOTE(review): the gripper message below advertises N_GRIPPER_BINS, but
        # gripper_mlp emits a single value per queried pixel -- confirm which
        # one is intended before relying on the logged shape.
        print(f"✓ Gripper  MLP  → (B, {n_window}, {N_GRIPPER_BINS})")
        print(f"✓ Rotation MLP  → (B, {n_window}, 3, {N_ROT_BINS})")

    # ``nn.Module.to`` is inherited unchanged: the backbone is a registered
    # submodule, so it already follows the parent across devices and dtypes,
    # and every ``to(...)`` overload (dtype=, non_blocking=, ...) keeps working.

    def train(self, mode=True):
        """Set train/eval mode, keeping a frozen backbone in eval mode.

        Args:
            mode: True for training mode, False for evaluation mode.

        Returns:
            ``self``, as ``nn.Module.train`` does.
        """
        super().train(mode)
        # getattr: tolerate instances created before this attribute existed.
        if getattr(self, "freeze_backbone", False):
            self.backbone.eval()
        return self

    def _extract_features(self, x):
        """Extract patch features using MoGe's DINOv2 backbone.

        Args:
            x: Image batch of shape (B, C, H, W).

        Returns:
            Tensor of shape (B, embed_dim, H // patch_size, W // patch_size).
        """
        B, _, H, W = x.shape
        # Derive the grid from the input instead of sqrt(N_patches) so that
        # non-square inputs are laid out correctly.
        H_p, W_p = H // self.patch_size, W // self.patch_size
        feats = self.backbone.get_intermediate_layers(x, n=[11])  # last layer
        patch_tokens = feats[0]  # (B, N_patches, D)
        patch_features = patch_tokens.reshape(B, H_p, W_p, self.embed_dim)
        return patch_features.permute(0, 3, 1, 2).contiguous()  # (B, D, H_p, W_p)

    def _index_features(self, feats, query_pixels):
        """Gather one feature vector per query pixel.

        Args:
            feats: Feature map of shape (B, D, H, W).
            query_pixels: (B, N, >=2) tensor; ``[..., 0]`` is read as x and
                ``[..., 1]`` as y, in ``feats`` grid coordinates. Values are
                truncated to integers and clamped to the map bounds.

        Returns:
            Tensor of shape (B, N, D).
        """
        B, D, H, W = feats.shape
        N = query_pixels.shape[1]
        px = query_pixels[..., 0].long().clamp(0, W - 1)
        py = query_pixels[..., 1].long().clamp(0, H - 1)
        batch_idx = torch.arange(B, device=feats.device).view(B, 1).expand(B, N)
        return feats[batch_idx, :, py, px]

    def predict_at_pixels(self, feats, query_pixels):
        """Run the gripper and rotation MLPs on features at ``query_pixels``.

        The features are detached first, so these two heads do not
        back-propagate into the convolutional trunk or the backbone.

        Args:
            feats: Feature map of shape (B, embed_dim, H, W).
            query_pixels: (B, N, >=2) pixel coordinates, see ``_index_features``.

        Returns:
            ``(gripper, rotation)`` with shapes (B, N) and (B, N, 3, N_ROT_BINS).
        """
        B, N = query_pixels.shape[:2]
        indexed = self._index_features(feats.detach(), query_pixels)
        flat = indexed.reshape(B * N, self.embed_dim)
        gripper = self.gripper_mlp(flat).reshape(B, N)
        rotation = self.rotation_mlp(flat).reshape(B, N, 3, N_ROT_BINS)
        return gripper, rotation

    def forward(self, x, start_keypoint_2d, query_pixels=None):
        """Predict volume logits and, optionally, per-pixel gripper/rotation.

        Args:
            x: Image batch of shape (B, C, H, W).
            start_keypoint_2d: (2,) or (B, >=2) tensor of (x, y) coordinates
                in a frame of extent ``target_size``; a 1-D tensor is shared
                across the batch.
            query_pixels: Optional (B, N, >=2) coordinates on the
                ``pred_size`` grid at which to evaluate the MLP heads.

        Returns:
            ``(volume_logits, gripper_logits, rotation_logits, feats)`` where
            ``volume_logits`` is (B, n_window, N_HEIGHT_BINS, pred_size,
            pred_size), ``feats`` is (B, embed_dim, pred_size, pred_size),
            and the two MLP outputs are None when ``query_pixels`` is None.
        """
        B = x.shape[0]

        # Extract patch features from MoGe backbone
        patch_features = self._extract_features(x)  # (B, D, H_p, W_p)
        H_p, W_p = patch_features.shape[-2:]

        # Start keypoint conditioning: map the keypoint to its patch cell and
        # add the learnable embedding to that cell's feature vector.
        if start_keypoint_2d.dim() == 1:
            start_keypoint_2d = start_keypoint_2d.unsqueeze(0).expand(B, -1)
        start_patch_x = (start_keypoint_2d[:, 0] * W_p / self.target_size).long().clamp(0, W_p - 1)
        start_patch_y = (start_keypoint_2d[:, 1] * H_p / self.target_size).long().clamp(0, H_p - 1)
        batch_indices = torch.arange(B, device=patch_features.device)
        patch_features[batch_indices, :, start_patch_y, start_patch_x] += self.start_keypoint_embedding.unsqueeze(0)

        # Upsample to pred_size
        feats = F.interpolate(patch_features, size=(self.pred_size, self.pred_size), mode='bilinear', align_corners=False)
        feats = self.feature_convs(feats)

        vol = self.volume_head(feats)
        volume_logits = vol.view(B, self.n_window, N_HEIGHT_BINS, self.pred_size, self.pred_size)

        if query_pixels is not None:
            gripper_logits, rotation_logits = self.predict_at_pixels(feats, query_pixels)
        else:
            gripper_logits = rotation_logits = None

        return volume_logits, gripper_logits, rotation_logits, feats
