"""Dual-camera PARA with vanilla DINOv3 backbone.

Same as PARA but processes both agentview and wrist camera through a shared
DINOv3 backbone → bilinear upsample → conv refinement → per-view 1×1 conv heads.
No cross-attention or communication between views.
"""

import os
import torch
import torch.nn as nn
import torch.nn.functional as F

# Local DINOv3 checkout and pretrained weights; both overridable via env vars.
DINO_REPO_DIR     = os.environ.get("DINO_REPO_DIR",    "/Users/cameronsmith/Projects/robotics_testing/random/dinov3")
DINO_WEIGHTS_PATH = os.environ.get("DINO_WEIGHTS_PATH", "/Users/cameronsmith/Projects/robotics_testing/random/dinov3/weights/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth")
# ViT patch size of the dinov3_vits16plus backbone.
DINO_PATCH_SIZE = 16

# Number of prediction timesteps emitted per pixel.
N_WINDOW = 4
# Discretization bins for the volume head (per timestep).
N_HEIGHT_BINS = 32
# Discretization bins per rotation axis (3 axes per timestep).
N_ROT_BINS = 32
# Side length of the upsampled square prediction grid.
PRED_SIZE = 64


class DualParaPredictor(nn.Module):
    """Dual-camera PARA: shared DINOv3 backbone, independent per-view heads.

    Both the agentview and wrist images pass through the same DINOv3 ViT;
    patch features are bilinearly upsampled to ``pred_size`` and refined by
    shared 3×3 convs, then each view gets its own 1×1-conv heads for volume,
    gripper, and rotation logits. There is no cross-view communication.
    """

    def __init__(self, target_size=448, pred_size=PRED_SIZE, n_window=N_WINDOW,
                 freeze_backbone=False, **kwargs):
        """
        Args:
            target_size: side length (px) of the square input images; used to
                map ``start_keypoint_2d`` pixel coords onto the patch grid.
            pred_size: side length of the square prediction grid.
            n_window: number of prediction timesteps per pixel.
            freeze_backbone: if True, freeze all DINOv3 params and set the
                backbone to eval mode.
            **kwargs: ignored; absorbs unused config keys from callers.
        """
        super().__init__()
        self.target_size = target_size
        self.pred_size = pred_size
        self.n_window = n_window
        self.patch_size = DINO_PATCH_SIZE
        self.model_type = "dual_para"

        # Backbone is DINOv3 (messages previously said "DINOv2" — fixed).
        print("Loading DINOv3 model...")
        self.dino = torch.hub.load(
            DINO_REPO_DIR, 'dinov3_vits16plus', source='local', weights=DINO_WEIGHTS_PATH
        )
        if freeze_backbone:
            for param in self.dino.parameters():
                param.requires_grad = False
            self.dino.eval()
            print("✓ Frozen DINOv3 backbone")
        else:
            print("✓ DINOv3 backbone is trainable")

        self.embed_dim = self.dino.embed_dim
        D = self.embed_dim

        # Learned marker added to the patch token nearest the start keypoint
        # (agentview only) so the network knows where the trajectory begins.
        self.start_keypoint_embedding = nn.Parameter(torch.randn(D) * 0.02)

        # Shared refinement convs, applied to both views after upsampling.
        self.feature_convs = nn.Sequential(
            nn.Conv2d(D, D, 3, padding=1), nn.GELU(),
            nn.Conv2d(D, D, 3, padding=1), nn.GELU(),
            nn.Conv2d(D, D, 3, padding=1), nn.GELU(),
        )
        print(f"✓ Shared feature convs: 3× Conv2d(3×3) at pred_size={pred_size}")

        # Per-view 1×1 conv prediction heads (no weight sharing across views).
        N_GRIP = 2  # two gripper classes per timestep
        for view in ['agent', 'wrist']:
            setattr(self, f'{view}_volume_head', nn.Conv2d(D, n_window * N_HEIGHT_BINS, 1))
            setattr(self, f'{view}_gripper_head', nn.Conv2d(D, n_window * N_GRIP, 1))
            setattr(self, f'{view}_rotation_head', nn.Conv2d(D, n_window * 3 * N_ROT_BINS, 1))

        print(f"✓ Per-view heads: volume({n_window}×{N_HEIGHT_BINS}), gripper({n_window}×{N_GRIP}), rotation({n_window}×3×{N_ROT_BINS})")

        n_total = sum(p.numel() for p in self.parameters())
        n_trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"✓ DualPara: {n_trainable:,} / {n_total:,} trainable params")

    # NOTE: the previous `to(device)` override was removed — it was redundant,
    # since nn.Module.to already moves all registered submodules (self.dino
    # included) and returns self.

    def _extract_features(self, x):
        """Run the DINOv3 backbone and return per-patch features.

        Args:
            x: image batch, (B, 3, H, W); H and W are assumed divisible by the
                patch size — TODO confirm against the data pipeline.

        Returns:
            (B, D, H_p, W_p) patch feature map, channels-first.
        """
        B = x.shape[0]
        x_tokens, (H_p, W_p) = self.dino.prepare_tokens_with_masks(x)
        # RoPE tables depend only on the patch-grid size, so compute them once
        # instead of rebuilding inside the per-block loop.
        rope_sincos = self.dino.rope_embed(H=H_p, W=W_p) if self.dino.rope_embed else None
        for blk in self.dino.blocks:
            x_tokens = blk(x_tokens, rope_sincos)
        # Drop CLS + storage tokens, keeping only patch tokens; the norm is
        # applied before or after the slice depending on backbone config.
        if self.dino.untie_cls_and_patch_norms:
            x_norm_patches = self.dino.norm(x_tokens[:, self.dino.n_storage_tokens + 1:])
        else:
            x_norm_patches = self.dino.norm(x_tokens)[:, self.dino.n_storage_tokens + 1:]

        patch_features = x_norm_patches.reshape(B, H_p, W_p, self.embed_dim)
        patch_features = patch_features.permute(0, 3, 1, 2).contiguous()
        return patch_features  # (B, D, H_p, W_p)

    def _upsample_and_refine(self, patch_features):
        """Bilinearly upsample patch features to pred_size and refine with the
        shared conv stack."""
        feats = F.interpolate(patch_features, size=(self.pred_size, self.pred_size),
                              mode='bilinear', align_corners=False)
        feats = self.feature_convs(feats)
        return feats  # (B, D, pred_size, pred_size)

    def _get_view_predictions(self, feats, view_name, query_pixels=None):
        """Apply one view's 1×1 conv heads; optionally sample at query pixels.

        Args:
            feats: (B, D, pred_size, pred_size) refined feature map.
            view_name: 'agent' or 'wrist' — selects which head set to use.
            query_pixels: optional (B, n_window, 2) pixel coords in (x, y)
                order on the prediction grid; coords are clamped in-range.

        Returns:
            (volume, gripper, rotation) where
                volume: (B, n_window, N_HEIGHT_BINS, H, W) dense logits,
                gripper: (B, n_window, 2) logits at query pixels or None,
                rotation: (B, n_window, 3, N_ROT_BINS) logits at query pixels
                    or None.
        """
        B = feats.shape[0]
        N = self.n_window
        H = W = self.pred_size

        vol = getattr(self, f'{view_name}_volume_head')(feats).view(B, N, N_HEIGHT_BINS, H, W)

        if query_pixels is not None:
            px = query_pixels[..., 0].long().clamp(0, W - 1)
            py = query_pixels[..., 1].long().clamp(0, H - 1)
            batch_idx = torch.arange(B, device=feats.device).view(B, 1).expand(B, N)
            time_idx = torch.arange(N, device=feats.device).view(1, N).expand(B, N)

            # Advanced indexing with a mid-position slice: result dims are the
            # (B, N) index shape first, then the sliced channel dims.
            grip_map = getattr(self, f'{view_name}_gripper_head')(feats).view(B, N, 2, H, W)
            gripper = grip_map[batch_idx, time_idx, :, py, px]

            rot_map = getattr(self, f'{view_name}_rotation_head')(feats).view(B, N, 3, N_ROT_BINS, H, W)
            rotation = rot_map[batch_idx, time_idx, :, :, py, px]
        else:
            gripper = rotation = None

        return vol, gripper, rotation

    def predict_at_pixels(self, feats, query_pixels, view_name='agent'):
        """For eval: predict gripper/rotation logits at specific pixels.

        Returns (gripper, rotation) as in ``_get_view_predictions``.
        """
        _, grip, rot = self._get_view_predictions(feats, view_name, query_pixels)
        return grip, rot

    def forward(self, agent_img, wrist_img=None, start_keypoint_2d=None,
                agent_query_pixels=None, wrist_query_pixels=None):
        """Run both views through the shared backbone and per-view heads.

        Args:
            agent_img: (B, 3, target_size, target_size) agentview batch.
            wrist_img: optional wrist-view batch, same shape; if None the
                wrist_* outputs are omitted from the result.
            start_keypoint_2d: optional (2,) or (B, 2) keypoint in image-pixel
                coords (target_size space); marks the start patch (agent only).
            agent_query_pixels / wrist_query_pixels: optional (B, n_window, 2)
                pixel coords on the prediction grid for sparse head sampling.

        Returns:
            Dict with agent_volume/gripper/rotation/feats and, when wrist_img
            is given, the wrist_* equivalents.
        """
        B = agent_img.shape[0]
        result = {}

        # --- Agentview ---
        agent_patches = self._extract_features(agent_img)
        _, D, H_p, W_p = agent_patches.shape

        if start_keypoint_2d is not None:
            if start_keypoint_2d.dim() == 1:
                start_keypoint_2d = start_keypoint_2d.unsqueeze(0).expand(B, -1)
            # Map image-pixel coords onto the patch grid, clamped in-range.
            skx = (start_keypoint_2d[:, 0] * W_p / self.target_size).long().clamp(0, W_p - 1)
            sky = (start_keypoint_2d[:, 1] * H_p / self.target_size).long().clamp(0, H_p - 1)
            bi = torch.arange(B, device=agent_patches.device)
            agent_patches[bi, :, sky, skx] += self.start_keypoint_embedding.unsqueeze(0)

        agent_feats = self._upsample_and_refine(agent_patches)
        av, ag, ar = self._get_view_predictions(agent_feats, 'agent', agent_query_pixels)
        result['agent_volume'] = av
        result['agent_gripper'] = ag
        result['agent_rotation'] = ar
        result['agent_feats'] = agent_feats

        # --- Wrist view --- (no keypoint marking; heads are view-specific)
        if wrist_img is not None:
            wrist_patches = self._extract_features(wrist_img)
            wrist_feats = self._upsample_and_refine(wrist_patches)
            wv, wg, wr = self._get_view_predictions(wrist_feats, 'wrist', wrist_query_pixels)
            result['wrist_volume'] = wv
            result['wrist_gripper'] = wg
            result['wrist_rotation'] = wr
            result['wrist_feats'] = wrist_feats

        return result


if __name__ == "__main__":
    # Smoke test: build the model and run one dual-view forward on random data.
    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    predictor = DualParaPredictor(target_size=448, n_window=N_WINDOW)
    predictor = predictor.to(dev)

    agent_img = torch.randn(2, 3, 448, 448).to(dev)
    wrist_img = torch.randn(2, 3, 448, 448).to(dev)
    start_kp = torch.tensor([224.0, 224.0]).to(dev)
    agent_q = torch.zeros(2, N_WINDOW, 2).to(dev)
    wrist_q = torch.zeros(2, N_WINDOW, 2).to(dev)

    with torch.no_grad():
        preds = predictor(agent_img, wrist_img, start_keypoint_2d=start_kp,
                          agent_query_pixels=agent_q, wrist_query_pixels=wrist_q)

    for name, tensor in preds.items():
        if tensor is not None:
            print(f"{name}: {tensor.shape}")
