"""Model for trajectory volume prediction using DINOv3.

Predicts a pixel-aligned volume: N_WINDOW x N_HEIGHT_BINS logits per pixel (cross-entropy).
Gripper (N_GRIPPER_BINS logits per pixel) and rotation (3 x N_ROT_BINS logits per pixel) come from
dense heads shared across the window; they are sampled at one query pixel per window step: the GT
pixel during training (teacher forcing) and the predicted (volume-argmax) pixel in val/inference.
"""
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

# DINO configuration — override via env vars on servers without local weights.
# NOTE(review): the defaults are absolute paths on one developer's machine;
# every other host must set DINO_REPO_DIR / DINO_WEIGHTS_PATH /
# DINO_CONVNEXT_WEIGHTS_PATH (or torch.hub.load in the model will fail).
DINO_REPO_DIR    = os.environ.get("DINO_REPO_DIR",    "/Users/cameronsmith/Projects/robotics_testing/random/dinov3")
DINO_WEIGHTS_PATH = os.environ.get("DINO_WEIGHTS_PATH", "/Users/cameronsmith/Projects/robotics_testing/random/dinov3/weights/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth")
DINO_CONVNEXT_WEIGHTS_PATH = os.environ.get(
    "DINO_CONVNEXT_WEIGHTS_PATH",
    "/Users/cameronsmith/Projects/robotics_testing/random/dinov3/weights/dinov3_convnext_small_pretrain_lvd1689m-296db49d.pth",
)
# Backbone selector: "vits16plus" (default; ViT-S/16+ whose gripper/rotation
# heads see the per-pixel features concatenated with the broadcast CLS token)
# or "convnext_small" (DINOv3 ConvNeXt-S; its synthetic global-pool "CLS" is
# ignored and the heads see local features only). Both backbones sample the
# gripper/rotation outputs at sub-pixel query locations with grid_sample.
DINO_BACKBONE = os.environ.get("DINO_BACKBONE", "vits16plus")
DINO_PATCH_SIZE = 16
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

N_WINDOW = 6  # trajectory window length (one volume/gripper/rotation slot per step)
# NOTE(review): MIN_HEIGHT == MAX_HEIGHT, so any (h - MIN) / (MAX - MIN)
# normalization would divide by zero with these defaults. Presumably they
# are overwritten from dataset stats like the ranges below — confirm.
MIN_HEIGHT = 0.043347
MAX_HEIGHT = 0.043347
MIN_GRIPPER = -1.0
MAX_GRIPPER = 1.0
MIN_ROT = [-3.14159, -3.14159, -3.14159]  # per-axis euler XYZ min (updated from dataset stats)
MAX_ROT = [ 3.14159,  3.14159,  3.14159]  # per-axis euler XYZ max
MIN_POS = [-1.0, -1.0, 0.0]  # per-axis XYZ min (updated from dataset stats)
MAX_POS = [ 1.0,  1.0,  2.0]  # per-axis XYZ max
N_HEIGHT_BINS = 32
# Smith300 gripper: 10 bins between MIN_GRIPPER and MAX_GRIPPER (the range is
# set per-run from dataset stats in train_smith300_para.py). The original
# libero head was a single-logit BCE; for the smith300's continuous gripper q
# in radians (~0.03 to ~1.07 in our recording, never crossing 0) the binary
# threshold-at-zero target degenerates to "always closed". CE over discretized
# bins gives a real supervised signal across the recorded range.
N_GRIPPER_BINS = 10
N_ROT_BINS = 32  # bins per euler axis (rotation head emits 3 * N_ROT_BINS channels)

PRED_SIZE = 64  # supervision resolution; upsample to IMAGE_SIZE for vis/downstream


class TrajectoryHeatmapPredictor(nn.Module):
    """Pixel-aligned trajectory volume plus per-query gripper/rotation predictor.

    Heads (see ``forward`` for exact shapes):
      * volume:   N_WINDOW x N_HEIGHT_BINS logits at every pixel of the
                  pred_size x pred_size grid;
      * gripper:  N_GRIPPER_BINS logits per pixel, sampled at one query pixel
                  per window step -> (B, N_WINDOW, N_GRIPPER_BINS);
      * rotation: 3 x N_ROT_BINS logits per pixel, sampled the same way
                  -> (B, N_WINDOW, 3, N_ROT_BINS).

    The gripper/rotation heads have no window dimension of their own; the
    N_WINDOW axis comes from the query pixels (GT pixels in training,
    volume-argmax pixels at inference).
    """

    def __init__(self, target_size=448, pred_size=PRED_SIZE, n_window=N_WINDOW,
                 freeze_backbone=False, backbone: str | None = None):
        super().__init__()
        self.target_size = target_size
        self.pred_size = pred_size
        self.n_window = n_window
        self.patch_size = DINO_PATCH_SIZE
        self.backbone = (backbone or DINO_BACKBONE)
        # Remembered so train() can keep a frozen backbone in eval mode.
        self._freeze_backbone = bool(freeze_backbone)

        if self.backbone == "vits16plus":
            print("Loading DINOv3 ViT-S/16+ model...")
            self.dino = torch.hub.load(
                DINO_REPO_DIR, 'dinov3_vits16plus',
                source='local', weights=DINO_WEIGHTS_PATH,
            )
            # ViT path uses CLS broadcast concat for grip/rot heads.
            self._head_channel_mult = 2
        elif self.backbone == "convnext_small":
            print("Loading DINOv3 ConvNeXt-Small model...")
            self.dino = torch.hub.load(
                DINO_REPO_DIR, 'dinov3_convnext_small',
                source='local', weights=DINO_CONVNEXT_WEIGHTS_PATH,
            )
            # ConvNeXt path ignores the (synthetic) CLS — heads see only the
            # local conv feature.
            self._head_channel_mult = 1
        else:
            raise ValueError(f"unknown backbone {self.backbone!r} "
                             "(expected 'vits16plus' or 'convnext_small')")
        if freeze_backbone:
            for param in self.dino.parameters():
                param.requires_grad = False
            self.dino.eval()
            print(f"✓ Frozen {self.backbone} backbone")
        else:
            print(f"✓ {self.backbone} backbone is trainable")

        self.embed_dim = self.dino.embed_dim
        print(f"✓ DINO embedding dim: {self.embed_dim}")

        # Added to the patch feature under the current EEF pixel (see forward).
        self.start_keypoint_embedding = nn.Parameter(torch.randn(self.embed_dim) * 0.02)
        print(f"✓ Learnable start keypoint embedding (dim={self.embed_dim})")

        # Shared feature refinement: upsample patch grid to pred_size, then 3 conv layers
        D = self.embed_dim
        self.feature_convs = nn.Sequential(
            nn.Conv2d(D, D, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv2d(D, D, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv2d(D, D, kernel_size=3, padding=1),
            nn.GELU(),
        )
        print(f"✓ Feature convs: 3× Conv2d(3×3) at pred_size={pred_size}")

        # Volume head: 1×1 conv at pred_size (spatial — needed for heatmap)
        self.volume_head = nn.Conv2d(D, self.n_window * N_HEIGHT_BINS, kernel_size=1)
        print(f"✓ Volume   head → (B, {self.n_window}, {N_HEIGHT_BINS}, {pred_size}, {pred_size})")

        # Gripper / rotation: 1×1 conv heads applied densely to the head input,
        # then bilinearly sampled (grid_sample) at the query pixels. On the ViT
        # path the head input is the per-pixel feature concatenated with the
        # DINO CLS token broadcast over space; the CLS gives the heads global
        # scene context (cup at this height, table location, …) instead of
        # ONLY the local pixel feature — which on the UMI was dominated by the
        # green-gripper appearance and didn't transfer to the white robot
        # gripper. The ConvNeXt path uses the local feature only. No detach()
        # so gradients shape the shared features — safe because the
        # volume-only warm-up (--volume_only_steps) locks in good features
        # before grip/rot losses turn on.
        head_in_dim = self._head_channel_mult * D
        self.gripper_head = nn.Conv2d(head_in_dim, N_GRIPPER_BINS, kernel_size=1)
        self.rotation_head = nn.Conv2d(head_in_dim, 3 * N_ROT_BINS, kernel_size=1)
        # Both paths are sampled with grid_sample in predict_at_pixels.
        _head_tag = ("feats||CLS, grid_sample sub-pixel" if self.backbone == "vits16plus"
                     else "feats only, grid_sample sub-pixel")
        print(f"✓ Gripper  head → (B, {self.n_window}, {N_GRIPPER_BINS})   [1×1 conv on {_head_tag}]")
        print(f"✓ Rotation head → (B, {self.n_window}, 3, {N_ROT_BINS})   [1×1 conv on {_head_tag}]")

    # No custom .to(): self.dino is a registered submodule, so the inherited
    # nn.Module.to moves it (and accepts dtype / non_blocking / kwargs forms).

    def train(self, mode: bool = True):
        """Set train/eval mode, keeping a frozen backbone in eval mode.

        nn.Module.train recurses into self.dino, so without this override a
        ``model.train()`` call in a training loop would undo the
        ``self.dino.eval()`` performed in __init__ for frozen backbones.
        """
        super().train(mode)
        # getattr: tolerate instances unpickled from before this attribute existed.
        if getattr(self, "_freeze_backbone", False):
            self.dino.eval()
        return self

    def _extract_dino_features(self, x):
        """ViT path: extract patch features and CLS token.

        Re-implements the backbone's block loop so the normalized patch grid
        and CLS token are both available.

        Returns:
            patch_features: (B, D, H_p, W_p)   stride=patch_size (16)
            cls_token: (B, D)
        """
        B = x.shape[0]
        x_tokens, (H_p, W_p) = self.dino.prepare_tokens_with_masks(x)
        for blk in self.dino.blocks:
            # RoPE is recomputed per block (mirrors the upstream loop).
            # NOTE(review): it may be stochastic in train mode — check before
            # hoisting it out of the loop.
            rope_sincos = self.dino.rope_embed(H=H_p, W=W_p) if self.dino.rope_embed is not None else None
            x_tokens = blk(x_tokens, rope_sincos)
        if self.dino.untie_cls_and_patch_norms:
            # Separate norms for [CLS + storage tokens] and patch tokens.
            x_norm_cls = self.dino.cls_norm(x_tokens[:, : self.dino.n_storage_tokens + 1])
            x_norm_patches = self.dino.norm(x_tokens[:, self.dino.n_storage_tokens + 1 :])
            x_tokens = torch.cat([x_norm_cls, x_norm_patches], dim=1)
        else:
            x_tokens = self.dino.norm(x_tokens)

        cls_token = x_tokens[:, 0]  # (B, D)
        # Skip CLS + storage (register) tokens to get the patch grid.
        patch_tokens = x_tokens[:, self.dino.n_storage_tokens + 1 :]
        patch_features = patch_tokens.reshape(B, H_p, W_p, self.embed_dim)
        patch_features = patch_features.permute(0, 3, 1, 2).contiguous()  # (B, D, H_p, W_p)
        return patch_features, cls_token

    def _extract_convnext_features(self, x):
        """ConvNeXt path: take the stride-32 final stage feature map.

        Returns:
            patch_features: (B, D, H_p, W_p)   stride 32 (448 -> 14×14)
            cls_token: (B, D)                   synthetic global-pool; ignored
                                                downstream (kept for API parity)

        Raises:
            ValueError: if the token count is not a perfect square (the grid
                shape is recovered as side x side, so square inputs are required).
        """
        B = x.shape[0]
        out = self.dino.forward_features(x)
        cls_token = out["x_norm_clstoken"]                          # (B, D)
        patch_tokens = out["x_norm_patchtokens"]                    # (B, H_p*W_p, D)
        N = patch_tokens.shape[1]
        side = int(round(N ** 0.5))
        # Explicit raise instead of assert so the check survives `python -O`.
        if side * side != N:
            raise ValueError(f"convnext patch tokens not square: N={N}")
        patch_features = patch_tokens.reshape(B, side, side, self.embed_dim)
        patch_features = patch_features.permute(0, 3, 1, 2).contiguous()
        return patch_features, cls_token

    def _index_features(self, feats, query_pixels):
        """Index spatial feature map at specified pixel locations (integer lookup).

        Coordinates are truncated to integers and clamped into the map. Not
        used by predict_at_pixels, which samples bilinearly instead.

        Args:
            feats:         (B, D, H, W)
            query_pixels:  (B, N, 2) pixel coords [x, y] in feats coordinate space

        Returns:
            indexed: (B, N, D)
        """
        B, D, H, W = feats.shape
        N = query_pixels.shape[1]
        px = query_pixels[..., 0].long().clamp(0, W - 1)  # (B, N)
        py = query_pixels[..., 1].long().clamp(0, H - 1)  # (B, N)
        batch_idx = torch.arange(B, device=feats.device).view(B, 1).expand(B, N)
        return feats[batch_idx, :, py, px]  # (B, N, D)

    def predict_at_pixels(self, head_input, query_pixels):
        """Apply gripper/rotation 1×1 conv heads densely, then sample at the
        query pixels via F.grid_sample (bilinear, sub-pixel). head_input is
        (B, C, H, W) where C is 2D (ViT path: feats||CLS) or D (ConvNeXt path:
        feats only).

        Called with GT pixels during training (teacher forcing) and with
        predicted pixels (from volume argmax) during inference. No detach:
        gradients flow back into the shared features. Queries outside the
        map are zero-padded (grid_sample's default padding_mode).

        Args:
            head_input:    (B, C, pred_size, pred_size)
            query_pixels:  (B, N_WINDOW, 2) [x, y] in pred_size coordinate
                           space; can be float for sub-pixel queries.

        Returns:
            gripper_logits:  (B, N_WINDOW, N_GRIPPER_BINS)
            rotation_logits: (B, N_WINDOW, 3, N_ROT_BINS)
        """
        B, _, H, W = head_input.shape
        N = query_pixels.shape[1]
        grip_map = self.gripper_head(head_input)       # (B, N_GRIPPER_BINS, H, W)
        rot_map = self.rotation_head(head_input)       # (B, 3*N_ROT_BINS, H, W)

        # Normalize pixel coords (in [0, W-1] / [0, H-1]) to grid_sample's
        # align_corners=True frame ([-1, +1] at the corner pixel centers).
        # Computed in float32, then cast: grid_sample requires the grid dtype
        # to match the input dtype (matters for half / bf16 models).
        qx = query_pixels[..., 0].float()
        qy = query_pixels[..., 1].float()
        norm_x = 2.0 * qx / max(W - 1, 1) - 1.0
        norm_y = 2.0 * qy / max(H - 1, 1) - 1.0
        grid = torch.stack([norm_x, norm_y], dim=-1).unsqueeze(1)  # (B, 1, N, 2)
        grid = grid.to(head_input.dtype)

        grip = F.grid_sample(grip_map, grid, mode='bilinear', align_corners=True)
        rot = F.grid_sample(rot_map, grid, mode='bilinear', align_corners=True)
        gripper = grip.squeeze(2).transpose(1, 2)                     # (B, N, NG)
        rotation = rot.squeeze(2).transpose(1, 2).reshape(B, N, 3, N_ROT_BINS)
        return gripper, rotation

    def forward(self, x, start_keypoint_2d, query_pixels=None):
        """Run backbone, start-keypoint conditioning, and all heads.

        Args:
            x:                  (B, 3, H, W)
            start_keypoint_2d:  (B, 2) or (2,) current EEF pixel [x, y] in image
                                coords; mapped to the patch grid by dividing by
                                target_size on both axes (assumes a square
                                target_size x target_size image frame).
            query_pixels:       (B, N_WINDOW, 2) pixel coords in pred_size space to query
                                for gripper/rotation.  Pass GT pixels during training;
                                pass predicted pixels (volume argmax) during inference.
                                If None, gripper/rotation logits are not computed.

        Returns:
            volume_logits:   (B, N_WINDOW, N_HEIGHT_BINS, pred_size, pred_size)
            gripper_logits:  (B, N_WINDOW, N_GRIPPER_BINS)  or None
            rotation_logits: (B, N_WINDOW, 3, N_ROT_BINS)   or None
            head_input:      (B, C, pred_size, pred_size) with C = 2*D on the
                             ViT path (feats||CLS) or D on the ConvNeXt path;
                             pass it to predict_at_pixels to query other pixels.
        """
        B = x.shape[0]
        if self.backbone == "vits16plus":
            patch_features, cls_token = self._extract_dino_features(x)
        else:
            patch_features, cls_token = self._extract_convnext_features(x)
        H_p, W_p = patch_features.shape[-2:]

        # Start keypoint conditioning: inject into patch token at current EEF location
        if start_keypoint_2d.dim() == 1:
            start_keypoint_2d = start_keypoint_2d.unsqueeze(0).expand(B, -1)
        start_patch_x = (start_keypoint_2d[:, 0] * W_p / self.target_size).long().clamp(0, W_p - 1)
        start_patch_y = (start_keypoint_2d[:, 1] * H_p / self.target_size).long().clamp(0, H_p - 1)
        batch_indices = torch.arange(B, device=patch_features.device)
        patch_features[batch_indices, :, start_patch_y, start_patch_x] += self.start_keypoint_embedding.unsqueeze(0)

        # Upsample patch grid (28×28 ViT / 14×14 ConvNeXt at 448) → pred_size
        feats = F.interpolate(patch_features, size=(self.pred_size, self.pred_size), mode='bilinear', align_corners=False)

        # Refine with 3 conv layers at pred_size resolution
        feats = self.feature_convs(feats)  # (B, D, pred_size, pred_size)

        # Volume head: dense spatial prediction (full heatmap needed)
        vol = self.volume_head(feats)  # (B, N_W*Nh, pred_size, pred_size)
        volume_logits = vol.view(B, self.n_window, N_HEIGHT_BINS, self.pred_size, self.pred_size)

        # Build head input. ViT path concatenates the CLS token broadcast
        # over space (global scene context for grip/rot). ConvNeXt path
        # ignores the synthetic CLS and feeds local conv features only.
        if self.backbone == "vits16plus":
            cls_broadcast = cls_token[:, :, None, None].expand(
                -1, -1, self.pred_size, self.pred_size
            )
            head_input = torch.cat([feats, cls_broadcast], dim=1)  # (B, 2D, H, W)
        else:
            head_input = feats  # (B, D, H, W)

        if query_pixels is not None:
            gripper_logits, rotation_logits = self.predict_at_pixels(
                head_input, query_pixels)
        else:
            gripper_logits = rotation_logits = None

        return volume_logits, gripper_logits, rotation_logits, head_input


if __name__ == "__main__":
    # Smoke test: build a frozen-backbone model, push one random batch
    # through it, and print the shape of every output.
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    model = TrajectoryHeatmapPredictor(
        target_size=448, n_window=N_WINDOW, freeze_backbone=True,
    ).to(device)

    images = torch.randn(2, 3, 448, 448).to(device)
    queries = torch.zeros(2, N_WINDOW, 2).to(device)
    start_kp = torch.tensor([224.0, 224.0]).to(device)

    with torch.no_grad():
        outputs = model(images, start_keypoint_2d=start_kp, query_pixels=queries)

    labels = ("volume_logits  ", "gripper_logits ", "rotation_logits", "feats          ")
    for label, tensor in zip(labels, outputs):
        print(label, tensor.shape)
