"""Dual-camera PARA with DA3 pretrained backbone + DPT upsampling.

Processes both the agentview and wrist camera images through a shared DA3
backbone, refines the multi-scale features with the pretrained DPT refinement
blocks, bilinearly upsamples them to 64×64 with 64-dim features, then applies
per-view 1×1 conv heads for volume/rotation/gripper prediction.

At eval, picks the view with higher heatmap confidence per timestep.
"""

import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from safetensors.torch import load_file

# Location of the DA3 safetensors checkpoint; override via the
# DA3_WEIGHTS_PATH environment variable.
DA3_WEIGHTS_PATH = os.getenv("DA3_WEIGHTS_PATH", "/data/cameron/da3_weights/model.safetensors")

N_WINDOW = 4        # number of predicted timesteps per sample
N_HEIGHT_BINS = 32  # volume-head bins per pixel per timestep
N_ROT_BINS = 32     # bins for each of the 3 rotation components
PRED_SIZE = 64      # side length of the square prediction grid
PATCH_SIZE = 14     # ViT patch size (DINOv2 ViT-S/14)


# ---------------------------------------------------------------------------
# DPT head components (matching DA3 architecture)
# ---------------------------------------------------------------------------

class ResidualConvUnit(nn.Module):
    """Pre-activation residual block: ``x + conv2(relu(conv1(relu(x))))``.

    Two 3x3 convs (padding 1, constant channel count) keep DPT's parameter
    names (``conv1`` / ``conv2``) so ``resConfUnit*`` weights load by name.

    Args:
        features: number of input and output channels.
    """

    def __init__(self, features):
        super().__init__()
        self.conv1 = nn.Conv2d(features, features, 3, padding=1)
        self.conv2 = nn.Conv2d(features, features, 3, padding=1)
        # Must NOT be in-place: the first activation is applied to the block
        # input itself, so an in-place ReLU would overwrite ``x``. The residual
        # would then add relu(x) instead of x, and the caller's tensor would be
        # mutated as a side effect.
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        out = self.conv1(self.relu(x))
        out = self.conv2(self.relu(out))
        return out + x


class RefineNet(nn.Module):
    """One DPT fusion stage: merge a coarse path ``x`` with a skip feature.

    Wired like DPT's feature-fusion block, so the pretrained ``resConfUnit1``,
    ``resConfUnit2`` and ``out_conv`` weights act on the tensors they were
    trained on::

        out = out_conv(resConfUnit2(x + resConfUnit1(skip)))

    ``resConfUnit1`` only ever processes the skip input. When ``has_rcu1`` is
    False the skip is added unchanged. When ``skip`` is None the block reduces
    to ``out_conv(resConfUnit2(x))``.

    Spatial resampling is left to the caller, so ``x`` and ``skip`` must
    already share a spatial size. (A 1x1 conv commutes with bilinear
    resizing, so resizing after ``out_conv`` is equivalent to resizing before.)

    Args:
        features: channel count of ``x``, ``skip`` and the output.
        has_rcu1: whether to build ``resConfUnit1`` for the skip branch.
    """

    def __init__(self, features, has_rcu1=True):
        super().__init__()
        self.out_conv = nn.Conv2d(features, features, 1)
        self.resConfUnit1 = ResidualConvUnit(features) if has_rcu1 else None
        self.resConfUnit2 = ResidualConvUnit(features)

    def forward(self, x, skip=None):
        if skip is not None:
            # The skip branch gets its own residual unit before being merged.
            if self.resConfUnit1 is not None:
                skip = self.resConfUnit1(skip)
            x = x + skip
        x = self.resConfUnit2(x)
        return self.out_conv(x)


class DPTFeatureExtractor(nn.Module):
    """DPT refinement stack producing ``features``-dim maps on a resized grid.

    Holds the ``layer*_rn`` 3x3 convs and the ``refinenet2``-``refinenet4``
    fusion stages. They use the names found under the DA3
    ``model.head.scratch.*`` weights, so they can be filled with
    ``load_state_dict``. ``refinenet1`` is omitted because the result is
    bilinearly resized to a small grid (PRED_SIZE x PRED_SIZE by default)
    rather than decoded at full resolution.

    NOTE(review): all four levels are refined at the same patch-grid resolution
    (e.g. 32x32 for a 448 input). The DPT head this mirrors presumably resizes
    the levels to different scales first; confirm whether that matters for
    the pretrained refinement weights.

    Args:
        in_channels: channel counts of the four already-projected levels, in
            backbone-layer order [5, 7, 9, 11].
        features: channel count of every refined map.

    Raises:
        ValueError: if ``in_channels`` does not have exactly four entries.
    """

    def __init__(self, in_channels=(48, 96, 192, 384), features=64):
        super().__init__()
        in_channels = tuple(in_channels)
        if len(in_channels) != 4:
            raise ValueError(f"in_channels must have 4 entries, got {len(in_channels)}")
        self.features = features

        # Project each level to the common feature dim
        self.layer1_rn = nn.Conv2d(in_channels[0], features, 3, padding=1, bias=False)
        self.layer2_rn = nn.Conv2d(in_channels[1], features, 3, padding=1, bias=False)
        self.layer3_rn = nn.Conv2d(in_channels[2], features, 3, padding=1, bias=False)
        self.layer4_rn = nn.Conv2d(in_channels[3], features, 3, padding=1, bias=False)

        # Fusion stages, coarse to fine (refinenet1 intentionally omitted)
        self.refinenet4 = RefineNet(features, has_rcu1=False)
        self.refinenet3 = RefineNet(features, has_rcu1=True)
        self.refinenet2 = RefineNet(features, has_rcu1=True)

    def refine(self, projected, out_size=None):
        """Refine four spatial feature maps and resize the result.

        Args:
            projected: sequence of 4 tensors (B, in_channels[i], H_p, W_p).
                Entry 0 is accepted but not used, because its only consumer
                would be the omitted ``refinenet1``.
            out_size: int or (H, W) output size; defaults to
                (PRED_SIZE, PRED_SIZE).

        Returns:
            (B, features, H_out, W_out) tensor.
        """
        if out_size is None:
            out_size = (PRED_SIZE, PRED_SIZE)
        elif isinstance(out_size, int):
            out_size = (out_size, out_size)

        # layer1_rn stays registered (its pretrained weights still load) but is
        # not run: its output would only feed the omitted refinenet1.
        l2 = self.layer2_rn(projected[1])
        l3 = self.layer3_rn(projected[2])
        l4 = self.layer4_rn(projected[3])

        r4 = self.refinenet4(l4)
        r4_up = F.interpolate(r4, size=l3.shape[2:], mode='bilinear', align_corners=False)
        r3 = self.refinenet3(r4_up, l3)
        r3_up = F.interpolate(r3, size=l2.shape[2:], mode='bilinear', align_corners=False)
        r2 = self.refinenet2(r3_up, l2)

        return F.interpolate(r2, size=out_size, mode='bilinear', align_corners=False)

    def forward(self, layer_outputs, patch_h, patch_w, out_size=None):
        """
        Args:
            layer_outputs: list of 4 token tensors, each
                (B, patch_h * patch_w, in_channels[i]). They must already be
                projected to ``in_channels[i]``, because ``layer*_rn`` expects
                that channel count.
            patch_h, patch_w: patch grid dimensions.
            out_size: optional int or (H, W) output size; defaults to
                (PRED_SIZE, PRED_SIZE).

        Returns:
            features: (B, features, H_out, W_out)
        """
        B = layer_outputs[0].shape[0]
        # Row-major token order -> (B, C, H_p, W_p) spatial grids
        grids = [t.reshape(B, patch_h, patch_w, -1).permute(0, 3, 1, 2) for t in layer_outputs]
        return self.refine(grids, out_size)


# ---------------------------------------------------------------------------
# Full dual-camera model
# ---------------------------------------------------------------------------

class DualDA3Predictor(nn.Module):
    """Dual-camera PARA: shared DA3 backbone + DPT features + per-view prediction heads.

    These modules are shared by both views:
    - the DINOv2 ViT-S/14 backbone,
    - the DPT ``projects`` convs,
    - ``head_norm``,
    - the DPT refinement stack (``self.dpt``).

    Only the 1x1 conv prediction heads (``agent_*`` / ``wrist_*``) are
    view-specific. The learned start-keypoint embedding is added to
    agentview features only.

    Args:
        target_size: side length (pixels) of the coordinate space of
            ``start_keypoint_2d``; used to rescale it onto the pred_size grid.
        pred_size: side length of the output feature / prediction grid.
        n_window: number of timesteps predicted per sample.
        freeze_backbone: if True, backbone parameters get ``requires_grad=False``
            and the backbone is held in eval mode, including across later
            ``train()`` calls.
        **kwargs: accepted for call-site compatibility and ignored.
    """

    def __init__(self, target_size=448, pred_size=PRED_SIZE, n_window=N_WINDOW,
                 freeze_backbone=False, **kwargs):
        super().__init__()
        self.target_size = target_size
        self.pred_size = pred_size
        self.n_window = n_window
        self.patch_size = PATCH_SIZE
        self.model_type = "dual_da3"
        self._backbone_frozen = bool(freeze_backbone)

        print("Loading DA3 backbone (DINOv2 ViT-S/14)...")
        self.backbone = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14', pretrained=False)
        self.embed_dim = 384  # ViT-S
        self.out_layers = [5, 7, 9, 11]

        # Load DA3 pretrained weights
        sd = load_file(DA3_WEIGHTS_PATH)

        # Backbone
        prefix = "model.backbone.pretrained."
        backbone_sd = {k[len(prefix):]: v for k, v in sd.items() if k.startswith(prefix)}
        missing, unexpected = self.backbone.load_state_dict(backbone_sd, strict=False)
        print(f"✓ Loaded DA3 backbone ({len(backbone_sd)} keys, {len(missing)} missing)")
        if unexpected:
            print(f"  ! {len(unexpected)} unexpected backbone keys ignored")

        if freeze_backbone:
            self.backbone.requires_grad_(False)
            self.backbone.eval()
            print("✓ Frozen DA3 backbone")

        # Project backbone outputs to DPT input channels
        in_channels = [48, 96, 192, 384]
        self.projects = nn.ModuleList([
            nn.Conv2d(self.embed_dim * 2, ch, 1) for ch in in_channels  # *2 because cat_token=True
        ])

        # load_state_dict (rather than `.data =`) checks shapes and casts dtype.
        n_proj_loaded = 0
        for i, proj in enumerate(self.projects):
            pk = f"model.head.projects.{i}"
            if f"{pk}.weight" in sd:
                proj.load_state_dict({"weight": sd[f"{pk}.weight"], "bias": sd[f"{pk}.bias"]})
                n_proj_loaded += 1

        self.head_norm = nn.LayerNorm(self.embed_dim * 2)
        norm_loaded = "model.head.norm.weight" in sd
        if norm_loaded:
            self.head_norm.load_state_dict({
                "weight": sd["model.head.norm.weight"],
                "bias": sd["model.head.norm.bias"],
            })
        print(f"✓ Loaded DPT projection weights ({n_proj_loaded}/{len(self.projects)} projects, "
              f"norm: {'yes' if norm_loaded else 'no'})")

        # DPT feature extractor (shared between views)
        self.dpt = DPTFeatureExtractor(in_channels=in_channels, features=64)

        # Load DPT refinement weights; our module names match the scratch.* keys.
        dpt_prefix = "model.head.scratch."
        dpt_sd = {k[len(dpt_prefix):]: v for k, v in sd.items() if k.startswith(dpt_prefix)}
        loaded = self.dpt.load_state_dict(dpt_sd, strict=False)
        print(f"✓ Loaded DPT refinement weights ({len(dpt_sd) - len(loaded.unexpected_keys)} matched)")
        if loaded.missing_keys:
            print(f"  ! {len(loaded.missing_keys)} DPT refinement params left at random init")

        D = 64  # DPT feature dim

        # Start keypoint embedding (for agentview)
        self.start_keypoint_embedding = nn.Parameter(torch.randn(D) * 0.02)

        # Per-view 1x1 conv prediction heads
        N_GRIP = 2
        for view in ['agent', 'wrist']:
            setattr(self, f'{view}_volume_head', nn.Conv2d(D, n_window * N_HEIGHT_BINS, 1))
            setattr(self, f'{view}_gripper_head', nn.Conv2d(D, n_window * N_GRIP, 1))
            setattr(self, f'{view}_rotation_head', nn.Conv2d(D, n_window * 3 * N_ROT_BINS, 1))

        print(f"✓ Per-view heads: volume({n_window}×{N_HEIGHT_BINS}), gripper({n_window}×{N_GRIP}), rotation({n_window}×3×{N_ROT_BINS})")
        print(f"✓ Feature dim: {D}, pred_size: {pred_size}")

        n_total = sum(p.numel() for p in self.parameters())
        n_trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"✓ DualDA3: {n_trainable:,} / {n_total:,} trainable params")

    def train(self, mode=True):
        """Set train/eval mode; a frozen backbone always stays in eval mode."""
        super().train(mode)
        if getattr(self, "_backbone_frozen", False):
            self.backbone.eval()
        return self

    def _extract_features(self, x):
        """Extract multi-scale features from the DA3 backbone.

        Args:
            x: (B, 3, H, W) image batch.

        Returns:
            layer_outputs: list of 4 tensors, each (B, N_patches, D*2) for layers [5,7,9,11].
                D*2 because cat_token=True (patch token concatenated with CLS),
                then normalized by ``head_norm``.
            patch_h, patch_w: spatial dimensions of the patch grid, taken from
                the image shape so non-square inputs work.
        """
        patch_h = x.shape[-2] // self.patch_size
        patch_w = x.shape[-1] // self.patch_size

        # List of (patch_tokens, cls_token) tuples, one per requested layer
        feats = self.backbone.get_intermediate_layers(x, n=self.out_layers, return_class_token=True)
        if feats[0][0].shape[1] != patch_h * patch_w:
            raise ValueError(
                f"backbone returned {feats[0][0].shape[1]} patch tokens, expected "
                f"{patch_h}x{patch_w} for input of size {tuple(x.shape[-2:])}"
            )

        layer_outputs = []
        for patch_tokens, cls_token in feats:
            # cat_token: append the CLS token to every patch token
            cls_expanded = cls_token.unsqueeze(1).expand(-1, patch_tokens.shape[1], -1)
            combined = torch.cat([patch_tokens, cls_expanded], dim=-1)  # (B, N, D*2)
            layer_outputs.append(self.head_norm(combined))

        return layer_outputs, patch_h, patch_w

    def _project_features(self, layer_outputs, patch_h, patch_w):
        """Reshape token features to spatial grids and apply the DPT projections."""
        B = layer_outputs[0].shape[0]
        projected = []
        for proj, layer_out in zip(self.projects, layer_outputs):
            f = layer_out.reshape(B, patch_h, patch_w, -1).permute(0, 3, 1, 2)  # (B, D*2, H_p, W_p)
            projected.append(proj(f))  # (B, in_channels[i], H_p, W_p)
        return projected

    def _dpt_features(self, projected):
        """Run the shared DPT refinement stack and upsample to pred_size.

        ``dpt.layer1_rn`` is not run: its output would only feed the omitted
        refinenet1.
        """
        l2 = self.dpt.layer2_rn(projected[1])
        l3 = self.dpt.layer3_rn(projected[2])
        l4 = self.dpt.layer4_rn(projected[3])
        r4 = self.dpt.refinenet4(l4)
        r4_up = F.interpolate(r4, size=l3.shape[2:], mode='bilinear', align_corners=False)
        r3 = self.dpt.refinenet3(r4_up, l3)
        r3_up = F.interpolate(r3, size=l2.shape[2:], mode='bilinear', align_corners=False)
        r2 = self.dpt.refinenet2(r3_up, l2)
        return F.interpolate(r2, size=(self.pred_size, self.pred_size), mode='bilinear', align_corners=False)

    def _encode_view(self, img):
        """Image batch -> (B, 64, pred_size, pred_size) DPT features (shared by both views)."""
        layers, patch_h, patch_w = self._extract_features(img)
        return self._dpt_features(self._project_features(layers, patch_h, patch_w))

    def _get_view_predictions(self, feats, view_name, query_pixels=None):
        """Apply per-view prediction heads.

        Args:
            feats: (B, 64, pred_size, pred_size)
            view_name: 'agent' or 'wrist'
            query_pixels: (B, N_WINDOW, 2) (x, y) pixels in pred_size space used to
                index the gripper/rotation maps; clamped to the grid.

        Returns:
            volume_logits: (B, N_WINDOW, N_HEIGHT_BINS, pred_size, pred_size)
            gripper_logits: (B, N_WINDOW, 2) or None
            rotation_logits: (B, N_WINDOW, 3, N_ROT_BINS) or None
        """
        B = feats.shape[0]
        N = self.n_window
        H = W = self.pred_size

        vol_head = getattr(self, f'{view_name}_volume_head')
        grip_head = getattr(self, f'{view_name}_gripper_head')
        rot_head = getattr(self, f'{view_name}_rotation_head')

        vol = vol_head(feats).view(B, N, N_HEIGHT_BINS, H, W)

        if query_pixels is None:
            return vol, None, None

        px = query_pixels[..., 0].long().clamp(0, W - 1)
        py = query_pixels[..., 1].long().clamp(0, H - 1)
        batch_idx = torch.arange(B, device=feats.device).view(B, 1).expand(B, N)
        time_idx = torch.arange(N, device=feats.device).view(1, N).expand(B, N)

        # Each timestep t reads its own channels at its own query pixel.
        grip_map = grip_head(feats).view(B, N, 2, H, W)
        gripper_logits = grip_map[batch_idx, time_idx, :, py, px]  # (B, N, 2)

        rot_map = rot_head(feats).view(B, N, 3, N_ROT_BINS, H, W)
        rotation_logits = rot_map[batch_idx, time_idx, :, :, py, px]  # (B, N, 3, Nr)

        return vol, gripper_logits, rotation_logits

    def predict_at_pixels(self, feats, query_pixels, view_name='agent'):
        """For eval: predict gripper/rotation at specific pixels."""
        _, grip, rot = self._get_view_predictions(feats, view_name, query_pixels)
        return grip, rot

    def forward(self, agent_img, wrist_img=None, start_keypoint_2d=None,
                agent_query_pixels=None, wrist_query_pixels=None):
        """
        Args:
            agent_img:           (B, 3, H, W) agentview image
            wrist_img:           (B, 3, H, W) wrist camera image (optional)
            start_keypoint_2d:   (B, 2) or (2,) current EEF pixel on agentview,
                                 in target_size pixel coordinates
            agent_query_pixels:  (B, N_WINDOW, 2) GT pixels on agentview (pred_size space)
            wrist_query_pixels:  (B, N_WINDOW, 2) GT pixels on wrist (pred_size space)

        Returns dict with:
            agent_volume, agent_gripper, agent_rotation, agent_feats
            wrist_volume, wrist_gripper, wrist_rotation, wrist_feats (if wrist_img given)
        """
        B = agent_img.shape[0]
        result = {}

        # --- Agentview ---
        agent_feats = self._encode_view(agent_img)

        # Start keypoint conditioning: add the learned embedding at the EEF cell
        if start_keypoint_2d is not None:
            if start_keypoint_2d.dim() == 1:
                start_keypoint_2d = start_keypoint_2d.unsqueeze(0).expand(B, -1)
            skp_x = (start_keypoint_2d[:, 0] * self.pred_size / self.target_size).long().clamp(0, self.pred_size - 1)
            skp_y = (start_keypoint_2d[:, 1] * self.pred_size / self.target_size).long().clamp(0, self.pred_size - 1)
            bi = torch.arange(B, device=agent_feats.device)
            agent_feats[bi, :, skp_y, skp_x] += self.start_keypoint_embedding.unsqueeze(0)

        av, ag, ar = self._get_view_predictions(agent_feats, 'agent', agent_query_pixels)
        result['agent_volume'] = av
        result['agent_gripper'] = ag
        result['agent_rotation'] = ar
        result['agent_feats'] = agent_feats

        # --- Wrist view ---
        if wrist_img is not None:
            wrist_feats = self._encode_view(wrist_img)
            wv, wg, wr = self._get_view_predictions(wrist_feats, 'wrist', wrist_query_pixels)
            result['wrist_volume'] = wv
            result['wrist_gripper'] = wg
            result['wrist_rotation'] = wr
            result['wrist_feats'] = wrist_feats

        return result


if __name__ == "__main__":
    # Smoke test: build the model, run one dual-view forward pass on random
    # inputs without gradients, and print the shape of every returned tensor.
    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = DualDA3Predictor(target_size=448, n_window=N_WINDOW).to(dev)

    batch, side = 2, 448
    agent_img, wrist_img = (torch.randn(batch, 3, side, side).to(dev) for _ in range(2))
    kp = torch.tensor([224.0, 224.0]).to(dev)
    aq, wq = (torch.zeros(batch, N_WINDOW, 2).to(dev) for _ in range(2))

    with torch.no_grad():
        out = model(
            agent_img,
            wrist_img,
            start_keypoint_2d=kp,
            agent_query_pixels=aq,
            wrist_query_pixels=wq,
        )

    for k, v in out.items():
        if v is None:
            continue
        print(f"{k}: {v.shape}")
