"""Cost-volume PARA: fused third-view + wrist-view features for heatmap prediction.

For each pixel (u,v) in the agentview and each height bin h:
  1. Project the 3D cell onto the wrist camera, sample 16-dim wrist features
  2. Get 16-dim agentview features at (u,v)
  3. Get 16-dim learned height embedding for bin h
  4. Concatenate → 48-dim → 3-layer MLP → per-timestep heatmap score

Gripper: at the selected 3D cell (GT during train, argmax during eval), extract
the same 48-dim fused feature and map it through a 2-layer MLP. Rotation is not
predicted; the rotation output is always None.
"""

import os
import torch
import torch.nn as nn
import torch.nn.functional as F

# --- DINOv3 backbone location (override via environment variables) ---------
DINO_REPO_DIR = os.environ.get(
    "DINO_REPO_DIR",
    "/Users/cameronsmith/Projects/robotics_testing/random/dinov3",
)
DINO_WEIGHTS_PATH = os.environ.get(
    "DINO_WEIGHTS_PATH",
    "/Users/cameronsmith/Projects/robotics_testing/random/dinov3/weights/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth",
)
# Patch side length (pixels) of the dinov3_vits16plus backbone.
DINO_PATCH_SIZE = 16

# --- Default model hyperparameters ------------------------------------------
N_WINDOW = 4         # timesteps scored per cost-volume cell
N_HEIGHT_BINS = 32   # height planes in the cost volume
N_ROT_BINS = 32      # rotation bins; not used by the model code in this file
PRED_SIZE = 64       # side length of the predicted heatmap grid
FEAT_DIM = 16        # per-stream feature width (agent / wrist / height embedding)


class CostVolumePredictor(nn.Module):
    """Heatmap predictor over a (height bin x pixel) cost volume.

    A single DINOv3 backbone encodes both the agentview and the wrist image.
    For every agentview pixel on the pred_size grid and every height plane, the
    agentview feature is concatenated with the wrist feature sampled at the
    ray/plane intersection point and with a learned height embedding. A 1x1-conv
    MLP turns that into n_window scores per cell. A small MLP predicts gripper
    logits at queried cells. Rotation is not predicted.
    """

    def __init__(self, target_size=448, pred_size=PRED_SIZE, n_window=N_WINDOW,
                 n_height_bins=N_HEIGHT_BINS, feat_dim=FEAT_DIM,
                 freeze_backbone=False, **kwargs):
        """Load the backbone and build the fusion heads.

        Args:
            target_size:     side length (pixels) of the input images; normalised
                             intrinsics and the start keypoint are interpreted
                             in this pixel space.
            pred_size:       side length of the predicted heatmap grid.
            n_window:        number of timesteps scored per cell.
            n_height_bins:   number of height planes in the cost volume.
            feat_dim:        width of the agent / wrist / height features; the
                             fused feature is 3 * feat_dim.
            freeze_backbone: if True, DINO parameters get requires_grad=False and
                             the backbone is switched to eval mode.
            **kwargs:        accepted and ignored.
        """
        super().__init__()
        self.target_size = target_size
        self.pred_size = pred_size
        self.n_window = n_window
        self.n_height_bins = n_height_bins
        self.patch_size = DINO_PATCH_SIZE
        self.feat_dim = feat_dim
        self.model_type = "cost_volume"

        print("Loading DINOv3 model...")
        self.dino = torch.hub.load(
            DINO_REPO_DIR, 'dinov3_vits16plus', source='local', weights=DINO_WEIGHTS_PATH
        )
        if freeze_backbone:
            for param in self.dino.parameters():
                param.requires_grad = False
            # NOTE(review): a later `model.train()` call puts the backbone back in
            # train mode; re-call `self.dino.eval()` afterwards if that matters.
            self.dino.eval()

        self.embed_dim = self.dino.embed_dim
        D = self.embed_dim

        # Added to the agentview patch token that contains the start keypoint.
        self.start_keypoint_embedding = nn.Parameter(torch.randn(D) * 0.02)

        # Agentview: D → D → feat_dim
        self.agent_convs = nn.Sequential(
            nn.Conv2d(D, D, 3, padding=1), nn.GELU(),
            nn.Conv2d(D, D, 3, padding=1), nn.GELU(),
            nn.Conv2d(D, feat_dim, 3, padding=1), nn.GELU(),
        )

        # Wrist: D → D/2 → feat_dim
        self.wrist_convs = nn.Sequential(
            nn.Conv2d(D, D, 3, padding=1), nn.GELU(),
            nn.Conv2d(D, D // 2, 3, padding=1), nn.GELU(),
            nn.Conv2d(D // 2, feat_dim, 3, padding=1), nn.GELU(),
        )

        # Learned per-height-bin embedding
        self.height_embeddings = nn.Parameter(torch.randn(n_height_bins, feat_dim) * 0.02)

        # Volume MLP (1x1 convs): concat(agent, wrist, height) = 3*feat_dim → n_window scores
        fused_dim = feat_dim * 3
        self.volume_mlp = nn.Sequential(
            nn.Conv2d(fused_dim, fused_dim // 2, 1), nn.GELU(),
            nn.Conv2d(fused_dim // 2, feat_dim, 1), nn.GELU(),
            nn.Conv2d(feat_dim, n_window, 1),
        )

        # Gripper MLP: fused_dim → 2, applied only at queried cells (not per-pixel)
        self.gripper_mlp = nn.Sequential(
            nn.Linear(fused_dim, fused_dim // 2), nn.GELU(),
            nn.Linear(fused_dim // 2, 2),
        )

        # No rotation prediction — zero rotation at eval

        n_total = sum(p.numel() for p in self.parameters())
        n_trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"✓ CostVolume model: {n_trainable:,} / {n_total:,} trainable params")
        print(f"  Feature dim: {feat_dim}, fused: {fused_dim}")
        print(f"  Volume MLP: {fused_dim} → {n_window} per cell")
        print(f"  Gripper MLP: {fused_dim} → 2")
        print(f"  No rotation prediction")

    # NOTE: no `to()` override. `self.dino` is a registered submodule, so the
    # inherited nn.Module.to already moves it. The inherited method also supports
    # dtype / non_blocking arguments that a `to(self, device)` override would reject.

    def _extract_features(self, x):
        """Run the DINO backbone and return its patch tokens as a feature map.

        Args:
            x: (B, 3, H_img, W_img) image batch.

        Returns:
            (B, embed_dim, H_p, W_p) normalised patch features, with the CLS and
            storage tokens removed.
        """
        B = x.shape[0]
        x_tokens, (H_p, W_p) = self.dino.prepare_tokens_with_masks(x)
        for blk in self.dino.blocks:
            # Kept inside the loop on purpose: the RoPE module may randomise its
            # coordinates in training mode, so each block gets a fresh call.
            rope_sincos = self.dino.rope_embed(H=H_p, W=W_p) if self.dino.rope_embed else None
            x_tokens = blk(x_tokens, rope_sincos)
        # Token layout: [CLS, storage tokens..., patch tokens...]
        n_prefix = self.dino.n_storage_tokens + 1
        if self.dino.untie_cls_and_patch_norms:
            patches = self.dino.norm(x_tokens[:, n_prefix:])
        else:
            patches = self.dino.norm(x_tokens)[:, n_prefix:]
        patches = patches.reshape(B, H_p, W_p, self.embed_dim).permute(0, 3, 1, 2).contiguous()
        return patches

    def _sample_wrist_features(self, wrist_feats, agent_cam_pose, agent_cam_K,
                                wrist_cam_pose, wrist_cam_K, height_bins):
        """Sample wrist features at every (height bin, agentview pixel) cell.

        Each agentview pixel centre is unprojected to a world-space ray and
        intersected with the plane world_z = height for every height bin. The
        resulting 3D points are projected into the wrist camera (+z forward),
        and the wrist feature map is bilinearly sampled there. Fully vectorised.

        Args:
            wrist_feats:    (B, C, H, W) wrist feature map; assumed to cover the
                            full target_size x target_size wrist image.
            agent_cam_pose: (B, 4, 4) agent camera-to-world transform.
            agent_cam_K:    (B, 3, 3) agent intrinsics in target_size pixels.
            wrist_cam_pose: (B, 4, 4) wrist camera-to-world transform.
            wrist_cam_K:    (B, 3, 3) wrist intrinsics in target_size pixels.
            height_bins:    (Nh,) world-z height of each plane.

        Returns:
            (B, Nh, C, H, W) sampled wrist features. A cell is zero when its 3D
            point lies behind the agent camera, behind the wrist camera, or
            projects outside the wrist image.
        """
        B, C, H, W = wrist_feats.shape
        device = wrist_feats.device
        Nh = len(height_bins)
        HW = H * W
        scale = self.target_size / self.pred_size

        # Agentview pixel-centre grid in target_size image coordinates
        # (continuous convention: pixel i spans [i, i+1), centre at i + 0.5).
        ys, xs = torch.meshgrid(
            torch.arange(H, device=device, dtype=torch.float32),
            torch.arange(W, device=device, dtype=torch.float32),
            indexing='ij'
        )
        xs_img = (xs + 0.5) * scale
        ys_img = (ys + 0.5) * scale
        ones = torch.ones_like(xs_img)
        pix_h = torch.stack([xs_img, ys_img, ones], dim=-1).reshape(HW, 3)  # (HW, 3)

        # Unproject all pixels to world-space ray directions: (B, HW, 3)
        agent_K_inv = torch.inverse(agent_cam_K)  # (B, 3, 3)
        rays_cam = torch.einsum('bij,nj->bni', agent_K_inv, pix_h)  # (B, HW, 3)
        agent_R = agent_cam_pose[:, :3, :3]  # (B, 3, 3)
        rays_world = torch.einsum('bij,bnj->bni', agent_R, rays_cam)  # (B, HW, 3)
        cam_pos = agent_cam_pose[:, :3, 3]  # (B, 3)

        # Ray-plane intersection for all height bins at once
        # t[b, h, n] = (height[h] - cam_pos[b, 2]) / rays_world[b, n, 2]
        ray_z = rays_world[:, :, 2]  # (B, HW)
        heights = height_bins.view(1, Nh, 1)  # (1, Nh, 1)
        cam_z = cam_pos[:, 2].view(B, 1, 1)  # (B, 1, 1)
        t = (heights - cam_z) / (ray_z.unsqueeze(1) + 1e-8)  # (B, Nh, HW)
        in_front_of_agent = t > 0  # (B, Nh, HW)

        # 3D points: cam_pos + t * rays_world -> (B, Nh, HW, 3)
        points_3d = cam_pos.view(B, 1, 1, 3) + t.unsqueeze(-1) * rays_world.unsqueeze(1)

        # World -> wrist camera frame: R^T (p - t)
        wrist_R_inv = wrist_cam_pose[:, :3, :3].transpose(1, 2)  # (B, 3, 3)
        wrist_t = wrist_cam_pose[:, :3, 3]  # (B, 3)
        pts_centered = points_3d - wrist_t.view(B, 1, 1, 3)  # (B, Nh, HW, 3)
        pts_flat = pts_centered.reshape(B, Nh * HW, 3)
        pts_cam = torch.einsum('bij,bnj->bni', wrist_R_inv, pts_flat).reshape(B, Nh, HW, 3)

        # Perspective projection. Points with non-positive depth are behind the
        # wrist camera; the clamp only keeps the division finite, and the mask
        # below discards them (clamped depths can otherwise land inside the image).
        depth = pts_cam[:, :, :, 2]  # (B, Nh, HW)
        in_front_of_wrist = depth > 1e-4
        z = depth.clamp(min=1e-4)
        fx = wrist_cam_K[:, 0, 0].view(B, 1, 1)
        fy = wrist_cam_K[:, 1, 1].view(B, 1, 1)
        cx = wrist_cam_K[:, 0, 2].view(B, 1, 1)
        cy = wrist_cam_K[:, 1, 2].view(B, 1, 1)
        u_w = fx * pts_cam[:, :, :, 0] / z + cx  # (B, Nh, HW)
        v_w = fy * pts_cam[:, :, :, 1] / z + cy

        # Normalise to grid_sample coordinates. With align_corners=False, -1/+1
        # are the outer edges of the image, so divide by the full size (matching
        # the pixel-centre convention used for the agent grid above).
        grid_x = 2.0 * u_w / self.target_size - 1.0
        grid_y = 2.0 * v_w / self.target_size - 1.0
        # Push invalid cells far out of bounds; zeros padding then yields 0.
        valid = in_front_of_agent & in_front_of_wrist
        grid_x = torch.where(valid, grid_x, torch.full_like(grid_x, 10.0))
        grid_y = torch.where(valid, grid_y, torch.full_like(grid_y, 10.0))

        # Fold the height bins into the grid's row axis, (B, Nh*H, W, 2), so a
        # single grid_sample per batch element is enough and the wrist map does
        # not have to be copied Nh times.
        grid = torch.stack([grid_x, grid_y], dim=-1).reshape(B, Nh * H, W, 2)
        sampled = F.grid_sample(wrist_feats, grid, mode='bilinear',
                                padding_mode='zeros', align_corners=False)  # (B, C, Nh*H, W)
        # (B, C, Nh, H, W) → (B, Nh, C, H, W)
        return sampled.reshape(B, C, Nh, H, W).permute(0, 2, 1, 3, 4)

    def _build_fused_volume(self, agent_feats_16, wrist_sampled):
        """Build fused feature volume: concat(agent, wrist, height) at each cell.

        Args:
            agent_feats_16: (B, feat_dim, H, W)
            wrist_sampled:  (B, Nh, feat_dim, H, W)

        Returns:
            fused: (B, Nh, 3*feat_dim, H, W)
        """
        B, C, H, W = agent_feats_16.shape
        Nh = self.n_height_bins

        # Same agentview feature for every height bin
        agent_exp = agent_feats_16.unsqueeze(1).expand(B, Nh, C, H, W)

        # Same height embedding at every pixel
        h_emb = self.height_embeddings.view(1, Nh, C, 1, 1).expand(B, Nh, C, H, W)

        # Concatenate along channels: (B, Nh, 3C, H, W)
        fused = torch.cat([agent_exp, wrist_sampled, h_emb], dim=2)
        return fused

    def predict_at_pixels(self, fused_volume, query_pixels, query_height_bins):
        """Gather fused features at (x, y, height-bin) cells and predict gripper logits.

        Query coordinates are cast to integers (truncation) and clamped into range.

        Args:
            fused_volume:      (B, Nh, 3*feat_dim, H, W)
            query_pixels:      (B, N, 2) (x, y) coords in pred_size pixel space
            query_height_bins: (B, N) height bin indices

        Returns:
            gripper_logits: (B, N, 2)
            rotation:       always None (rotation is not predicted)
        """
        B, Nh, C3, H, W = fused_volume.shape
        N = query_pixels.shape[1]
        device = fused_volume.device

        px = query_pixels[..., 0].long().clamp(0, W - 1)
        py = query_pixels[..., 1].long().clamp(0, H - 1)
        hb = query_height_bins.long().clamp(0, Nh - 1)

        batch_idx = torch.arange(B, device=device).view(B, 1).expand(B, N)

        # Advanced indexing on dims (0, 1, 3, 4) → (B, N, 3*feat_dim)
        feats = fused_volume[batch_idx, hb, :, py, px]

        flat = feats.reshape(B * N, C3)
        gripper = self.gripper_mlp(flat).reshape(B, N, 2)
        rotation = None  # no rotation prediction

        return gripper, rotation

    def forward(self, agent_img, wrist_img=None, start_keypoint_2d=None,
                agent_query_pixels=None, agent_query_height_bins=None,
                agent_cam_pose=None, agent_cam_K_norm=None,
                wrist_cam_pose=None, wrist_cam_K_norm=None):
        """Score every (timestep, height bin, pixel) cell; optionally predict gripper.

        Height planes are spaced linearly between `model.MIN_HEIGHT` and
        `model.MAX_HEIGHT`, read from the `model` module at call time.

        Args:
            agent_img:               (B, 3, H_img, W_img) agentview images; assumed
                                     target_size x target_size.
            wrist_img:               wrist images, same layout as agent_img.
            start_keypoint_2d:       (B, 2) or (2,) start keypoint (x, y) in
                                     target_size pixels; its patch token gets
                                     `start_keypoint_embedding` added.
            agent_query_pixels:      (B, N, 2) (x, y) in pred_size pixels.
            agent_query_height_bins: (B, N) height-bin indices.
            agent_cam_pose, wrist_cam_pose:     (B, 4, 4) camera-to-world transforms.
            agent_cam_K_norm, wrist_cam_K_norm: (B, 3, 3) intrinsics whose first two
                                     rows are normalised by target_size.

        Returns:
            volume_logits:   (B, n_window, n_height_bins, pred_size, pred_size);
                             all zeros when wrist_img or agent_cam_pose is missing.
            gripper_logits:  (B, N, 2), or None when no query cells are given.
            rotation_logits: always None (rotation is not predicted).
            fused_volume:    (B, Nh, 3*feat_dim, H, W) for predict_at_pixels at
                             eval; None in the no-wrist fallback.
        """
        B = agent_img.shape[0]
        device = agent_img.device

        # Extract backbone features
        agent_patches = self._extract_features(agent_img)
        _, D, H_p, W_p = agent_patches.shape

        # Start keypoint conditioning: add a learned vector to the patch token
        # that contains the keypoint.
        if start_keypoint_2d is not None:
            if start_keypoint_2d.dim() == 1:
                start_keypoint_2d = start_keypoint_2d.unsqueeze(0).expand(B, -1)
            skx = (start_keypoint_2d[:, 0] * W_p / self.target_size).long().clamp(0, W_p - 1)
            sky = (start_keypoint_2d[:, 1] * H_p / self.target_size).long().clamp(0, H_p - 1)
            bi = torch.arange(B, device=device)
            agent_patches[bi, :, sky, skx] += self.start_keypoint_embedding.unsqueeze(0)

        # Upsample + refine to feat_dim
        agent_feats = F.interpolate(agent_patches, size=(self.pred_size, self.pred_size),
                                    mode='bilinear', align_corners=False)
        agent_feats_16 = self.agent_convs(agent_feats)  # (B, feat_dim, pred_size, pred_size)

        if wrist_img is not None and agent_cam_pose is not None:
            wrist_patches = self._extract_features(wrist_img)
            wrist_feats = F.interpolate(wrist_patches, size=(self.pred_size, self.pred_size),
                                        mode='bilinear', align_corners=False)
            wrist_feats_16 = self.wrist_convs(wrist_feats)  # (B, feat_dim, pred_size, pred_size)

            # Unnormalise intrinsics (rows 0-1 hold fx, cx / fy, cy)
            agent_K = agent_cam_K_norm.clone()
            agent_K[:, 0] *= self.target_size
            agent_K[:, 1] *= self.target_size
            wrist_K = wrist_cam_K_norm.clone()
            wrist_K[:, 0] *= self.target_size
            wrist_K[:, 1] *= self.target_size

            import model as model_module
            height_bins = torch.linspace(model_module.MIN_HEIGHT, model_module.MAX_HEIGHT,
                                         self.n_height_bins, device=device)

            # Sample wrist features at each 3D cell
            wrist_sampled = self._sample_wrist_features(
                wrist_feats_16, agent_cam_pose, agent_K, wrist_cam_pose, wrist_K, height_bins
            )  # (B, Nh, feat_dim, H, W)

            # Build fused volume
            fused = self._build_fused_volume(agent_feats_16, wrist_sampled)  # (B, Nh, 3*feat_dim, H, W)

            # Volume MLP: fused → per-timestep scores
            Nh = self.n_height_bins
            H = W = self.pred_size
            fused_flat = fused.reshape(B * Nh, self.feat_dim * 3, H, W)
            vol_flat = self.volume_mlp(fused_flat)  # (B*Nh, n_window, H, W)
            vol = vol_flat.reshape(B, Nh, self.n_window, H, W)
            volume_logits = vol.permute(0, 2, 1, 3, 4).contiguous()  # (B, n_window, Nh, H, W)

            # Gripper at query cells (rotation is not predicted)
            if agent_query_pixels is not None and agent_query_height_bins is not None:
                gripper_logits, rotation_logits = self.predict_at_pixels(
                    fused, agent_query_pixels, agent_query_height_bins
                )
            else:
                gripper_logits = rotation_logits = None

            return volume_logits, gripper_logits, rotation_logits, fused
        else:
            # Fallback without wrist view / camera pose (shouldn't happen in practice)
            volume_logits = torch.zeros(B, self.n_window, self.n_height_bins,
                                        self.pred_size, self.pred_size, device=device)
            return volume_logits, None, None, None


if __name__ == "__main__":
    # Smoke test: random images through the full model, with one synthetic
    # downward-looking camera shared by both views.
    import model as model_module
    model_module.MIN_HEIGHT = 0.91
    model_module.MAX_HEIGHT = 1.20

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    m = CostVolumePredictor(target_size=448, n_window=N_WINDOW)
    m = m.to(device)

    B = 2
    a = torch.randn(B, 3, 448, 448, device=device)
    w = torch.randn(B, 3, 448, 448, device=device)
    kp = torch.tensor([[224.0, 224.0]] * B, device=device)
    qp = torch.zeros(B, N_WINDOW, 2, device=device)
    qh = torch.zeros(B, N_WINDOW, dtype=torch.long, device=device)

    # Camera-to-world pose. `.repeat` gives every batch element its own storage;
    # writing into an `.expand`-ed tensor raises a RuntimeError.
    # diag(1, -1, -1) turns the camera's +z viewing axis downward so the rays
    # hit the height planes (z in [0.91, 1.20]) below the camera at z = 1.6.
    cam = torch.eye(4, device=device).repeat(B, 1, 1)
    cam[:, 1, 1] = -1.0
    cam[:, 2, 2] = -1.0
    cam[:, 0, 3] = 0.65
    cam[:, 2, 3] = 1.6
    K = torch.tensor([[0.69, 0, 0.5], [0, 0.69, 0.5], [0, 0, 1.0]],
                     device=device).repeat(B, 1, 1)

    with torch.no_grad():
        vol, grip, rot, fused = m(a, w, start_keypoint_2d=kp,
                                  agent_query_pixels=qp, agent_query_height_bins=qh,
                                  agent_cam_pose=cam, agent_cam_K_norm=K,
                                  wrist_cam_pose=cam, wrist_cam_K_norm=K)
    print(f"volume:   {vol.shape}")
    print(f"gripper:  {grip.shape}")
    # The model does not predict rotation, so `rot` is None.
    print(f"rotation: {None if rot is None else rot.shape}")
    print(f"fused:    {fused.shape}")
