"""DinoVolumeQuery2View with a dual-frustum sampling strategy.

Two volumes of world XYZ points are sampled:
  1. BEV-anchored: world XYZ at each (z_bin, y_grid_bev, x_grid_bev) — the original
     single-volume approach.
  2. Wrist-anchored: world XYZ at each (z_bin, y_grid_wrist, x_grid_wrist) — added by
     this module.

Each volume's voxels are scored against BOTH views: a direct lookup in the view whose
pixel grid anchors the volume, and a grid_sample projection into the other view. The two
volumes are stacked along an "anchor" axis so the cross-entropy (CE) loss is computed over
the union of sample points.
"""
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

from model_dino_volume_query_2view import (
    DinoVolumeQuery2View,
    project_world_to_wrist_uv_grid, build_bev_world_xyz_table_batched,
)


def build_wrist_world_xyz_table_batched(K_norm_wrist, wrist_extrinsic, n_height_bins,
                                          min_height, max_height, H, W, image_size):
    """Intersect every wrist-camera pixel ray with a stack of horizontal world planes.

    Wrist-camera counterpart of ``build_bev_world_xyz_table_batched``. An ``H x W`` grid
    of pixel centres is laid over the wrist image (``image_size`` pixels on each side).
    Each centre is back-projected to a ray and intersected with ``n_height_bins`` planes
    ``z = h``, with ``h`` evenly spaced in ``[min_height, max_height]``.

    Args:
        K_norm_wrist: (B, 3, 3) intrinsics in image-size-normalised units. Rows 0 and 1
            are multiplied by ``image_size`` to obtain pixel units.
        wrist_extrinsic: (B, 4, 4) camera-to-world transform. ``[:3, :3]`` rotates
            camera-frame rays into the world and ``[:3, 3]`` is used as the ray origin
            (camera centre).
        n_height_bins: number of height planes (Z).
        min_height, max_height: world-z values of the first and last plane.
        H, W: output grid resolution. The pixel-centre stride is computed per axis, so
            non-square grids still span the whole image.
        image_size: side length of the wrist image in pixels.

    Returns:
        (B, Z, H, W, 3) world XYZ of each ray/plane intersection. The tensor has the
        same dtype as ``K_norm_wrist``. The ray parameter is not clamped, so a plane
        behind the camera yields points at a negative ray parameter.
    """
    device = K_norm_wrist.device
    dtype = K_norm_wrist.dtype
    B = K_norm_wrist.shape[0]

    # Normalised -> pixel intrinsics.
    K = K_norm_wrist.clone()
    K[:, 0] *= float(image_size)
    K[:, 1] *= float(image_size)

    # Pixel-centre coordinates of the H x W grid, in full-resolution pixels.
    # x and y get their own stride so W != H still covers the image correctly.
    scale_x = float(image_size) / float(W)
    scale_y = float(image_size) / float(H)
    ys = (torch.arange(H, device=device, dtype=dtype) + 0.5) * scale_y
    xs = (torch.arange(W, device=device, dtype=dtype) + 0.5) * scale_x
    grid_x, grid_y = torch.meshgrid(xs, ys, indexing='xy')  # each (H, W)
    ones = torch.ones_like(grid_x)
    uv1 = torch.stack([grid_x, grid_y, ones], dim=-1).reshape(-1, 3).T  # (3, H*W)

    # Camera-frame rays (B, H, W, 3), then rotated into the world frame.
    K_inv = torch.inverse(K)
    rays_cam = (K_inv @ uv1.unsqueeze(0).expand(B, -1, -1)).transpose(1, 2)
    rays_cam = rays_cam.reshape(B, H, W, 3)
    R_cw = wrist_extrinsic[:, :3, :3]
    t_cw = wrist_extrinsic[:, :3, 3]
    rays_world = torch.einsum('bij,bhwj->bhwi', R_cw, rays_cam)

    heights = torch.linspace(min_height, max_height, n_height_bins, device=device, dtype=dtype)
    rwz = rays_world[..., 2].unsqueeze(1)  # (B, 1, H, W)
    # Avoid dividing by ~0 for rays parallel to the planes. Such rays produce very distant
    # points. The negative replacement presumably assumes a downward-looking camera.
    rwz_safe = torch.where(rwz.abs() < 1e-6, torch.full_like(rwz, -1e-6), rwz)
    # Ray parameter s at which each ray reaches each plane height: (B, Z, H, W).
    s = (heights.view(1, n_height_bins, 1, 1) - t_cw[:, 2].view(B, 1, 1, 1)) / rwz_safe
    xyz = t_cw.view(B, 1, 1, 1, 3) + s.unsqueeze(-1) * rays_world.unsqueeze(1)
    return xyz


class DinoVolumeQuery2ViewDualFrustum(DinoVolumeQuery2View):
    """Adds a wrist-anchored second volume alongside the BEV-anchored one.

    Output:
      volume_logits: (B, T, Z, 2, H, W) — stacked [bev-anchored, wrist-anchored] along axis 3.
      Trainer should flatten to (B, T, Z*2*H*W) for CE, with the GT label re-encoded to
      always live in the BEV-anchor slot (anchor=0).

    The backbone, heads and embeddings (``_extract_dino_tokens``, ``cross_attn``,
    ``refine_bev``/``refine_wrist``, ``blocks``, ``q_head`` etc.) come from
    ``DinoVolumeQuery2View``. Only the forward pass and the dual-volume scoring are
    defined here.
    """

    @staticmethod
    def _sample_view_at_points(feats, xyz_table, K_norm, extrinsic, image_size):
        """Project world points into a camera and bilinearly sample its feature map.

        Args:
            feats: (B, C, H_f, W_f) feature map of the target camera.
            xyz_table: world points of one volume. Its projection via
                ``project_world_to_wrist_uv_grid`` must be a (B, Z, H, W, 2) grid in
                grid_sample's [-1, 1] coordinates. The table is presumably (B, Z, H, W, 3);
                confirm against the table builders.
            K_norm, extrinsic: intrinsics/extrinsics of the target camera, in whatever
                convention ``project_world_to_wrist_uv_grid`` expects.
            image_size: image side length forwarded to the projection.

        Returns:
            (B, C, Z, H, W) features, zero where a point falls outside the image.
        """
        # The projection is pure geometry on inputs, so it is computed without autograd.
        # Gradients still reach ``feats`` through grid_sample below.
        with torch.no_grad():
            uv = project_world_to_wrist_uv_grid(xyz_table, K_norm, extrinsic, image_size)
        # Take the grid dims from THIS projection, never from the other volume's grid.
        Bv, Zv, Hv, Wv, _ = uv.shape
        sampled = F.grid_sample(
            feats, uv.reshape(Bv, Zv * Hv, Wv, 2),
            mode='bilinear', padding_mode='zeros', align_corners=True,
        )
        return sampled.view(Bv, feats.shape[1], Zv, Hv, Wv)

    def forward(self, rgb_bev, rgb_wrist, start_pix_bev,
                 bev_xyz_table, wrist_K_norm, wrist_extrinsic,
                 bev_K_norm=None, bev_extrinsic=None, wrist_xyz_table=None):
        """Score both sample volumes and predict gripper/rotation logits.

        Args:
            rgb_bev, rgb_wrist: BEV and wrist images fed to ``_extract_dino_tokens``. The
                two must produce patch grids of the same size.
            start_pix_bev: (B, 2) end-effector pixel (x, y) in ``image_size`` pixel units
                of the BEV image. It is rescaled to the feature-map grid and clamped.
            bev_xyz_table: world points of the BEV-anchored volume, projected into the
                wrist camera.
            wrist_K_norm, wrist_extrinsic: wrist camera used for that projection.
            bev_K_norm, bev_extrinsic: BEV camera used to project the wrist-anchored
                volume. Required despite the ``None`` defaults.
            wrist_xyz_table: world points of the wrist-anchored volume, e.g. from
                ``build_wrist_world_xyz_table_batched``. Required despite the ``None``
                default.

        Returns:
            dict with ``volume_logits`` (B, T, Z, 2, H, W), ``vol_bev`` and ``vol_wrist``
            (B, T, Z, H, W), ``gripper_logits``, ``rotation_logits``, and the per-view
            feature maps ``pixel_feats`` / ``pixel_feats_wrist``.

        Raises:
            ValueError: if a wrist-volume input is missing, the two views yield different
                patch grids, or the two volumes end up with different shapes.
        """
        # The wrist-anchored volume is always built, so fail early and clearly if its
        # inputs were left at their defaults.
        missing = [name for name, value in (("bev_K_norm", bev_K_norm),
                                            ("bev_extrinsic", bev_extrinsic),
                                            ("wrist_xyz_table", wrist_xyz_table))
                   if value is None]
        if missing:
            raise ValueError(
                "DinoVolumeQuery2ViewDualFrustum.forward requires "
                + ", ".join(missing) + " to build the wrist-anchored volume"
            )

        B = rgb_bev.shape[0]
        T, d = self.n_window, self.d_model
        device = rgb_bev.device

        # --- Per-view DINO tokens, fused by attention over both views' patches ---
        cls_bev,   patches_bev,   (Hp_b, Wp_b) = self._extract_dino_tokens(rgb_bev)
        cls_wrist, patches_wrist, (Hp_w, Wp_w) = self._extract_dino_tokens(rgb_wrist)
        if (Hp_b, Wp_b) != (Hp_w, Wp_w):
            raise ValueError(
                f"BEV and wrist patch grids differ: {(Hp_b, Wp_b)} vs {(Hp_w, Wp_w)}"
            )
        n_p = Hp_b * Wp_b

        # View embedding: 0 for BEV patches, 1 for wrist patches.
        view_ids = torch.cat([
            torch.zeros(n_p, dtype=torch.long, device=device),
            torch.ones(n_p, dtype=torch.long, device=device),
        ])
        view_e = self.view_emb(view_ids).unsqueeze(0)
        joint = self.cross_attn(torch.cat([patches_bev, patches_wrist], dim=1) + view_e)
        pat_bev_x, pat_wrist_x = joint[:, :n_p], joint[:, n_p:]

        # --- Dense per-view feature maps at pred_size x pred_size ---
        def _to_upsampled_grid(p):
            grid = p.reshape(B, Hp_b, Wp_b, self.embed_dim).permute(0, 3, 1, 2).contiguous()
            return F.interpolate(grid, size=(self.pred_size, self.pred_size),
                                 mode='bilinear', align_corners=False)
        feat_bev_up = _to_upsampled_grid(pat_bev_x)
        feat_wrist_up = _to_upsampled_grid(pat_wrist_x)
        F_bev = self.refine_bev(feat_bev_up)
        F_wrist = self.refine_wrist(feat_wrist_up)
        H, W = F_bev.shape[-2:]

        # --- Query token: end-effector pixel feature + both CLS tokens ---
        sx = (start_pix_bev[..., 0] * (W / self.image_size)).long().clamp(0, W - 1)
        sy = (start_pix_bev[..., 1] * (H / self.image_size)).long().clamp(0, H - 1)
        b_idx = torch.arange(B, device=device)
        eef_feat = F_bev[b_idx, :, sy, sx]  # (B, C): one feature vector per batch item
        q_in = self.input_proj(torch.cat([eef_feat, cls_bev, cls_wrist], dim=-1))
        q_in_bt = q_in.unsqueeze(1).expand(B, T, d).reshape(B * T, d)

        # Same query for every timestep; timesteps differ only through the condition.
        cond_t = self.t_cond_proj(self.t_sin)
        cond_bt = cond_t.unsqueeze(0).expand(B, T, -1).reshape(B * T, -1)

        h = q_in_bt
        for blk in self.blocks:
            h = blk(h, cond_bt)
        penult = self.final_norm(h).view(B, T, d)

        q_spatial = self.q_head(penult)
        gripper = self.grip_head(penult)
        rotation = self.rot_head(penult)
        if self.rotation_mode == 'per_axis':
            rotation = rotation.view(B, T, 3, self.n_rot_bins)

        # q_spatial layout: [BEV feature query | wrist feature query | z query | t query].
        d_F, d_z = self.d_feat, self.d_sin_z
        q_F_bev = q_spatial[..., :d_F]
        q_F_wrist = q_spatial[..., d_F:2 * d_F]
        q_z = q_spatial[..., 2 * d_F:2 * d_F + d_z]
        q_t = q_spatial[..., 2 * d_F + d_z:]

        score_z = torch.einsum('btc, zc -> btz', q_z, self.z_sin)
        score_t = torch.einsum('btc, tc -> bt', q_t, self.t_sin)
        z_term = score_z[..., None, None]        # (B, T, Z, 1, 1)
        t_term = score_t[..., None, None, None]  # (B, T, 1, 1, 1)

        # --- VOLUME 1: BEV-anchored ---
        # Direct BEV lookup (identical for every z) + wrist features sampled at the
        # projections of the BEV-anchored world points.
        score_bev_yx = torch.einsum('btc, bchw -> bthw', q_F_bev, F_bev)
        F_w_sampled_for_bev = self._sample_view_at_points(
            F_wrist, bev_xyz_table, wrist_K_norm, wrist_extrinsic, self.image_size)
        score_wrist_on_bev = torch.einsum('btc, bczhw -> btzhw', q_F_wrist, F_w_sampled_for_bev)
        vol_bev = score_bev_yx.unsqueeze(2) + score_wrist_on_bev + z_term + t_term

        # --- VOLUME 2: Wrist-anchored ---
        # Each (z, y_w, x_w) voxel lies on a wrist ray, so the wrist lookup at (y_w, x_w)
        # is the same for every z and is simply broadcast. BEV features are sampled at
        # the projections of the wrist-anchored world points.
        score_wrist_yx = torch.einsum('btc, bchw -> bthw', q_F_wrist, F_wrist)
        F_b_sampled_for_wrist = self._sample_view_at_points(
            F_bev, wrist_xyz_table, bev_K_norm, bev_extrinsic, self.image_size)
        score_bev_on_wrist = torch.einsum('btc, bczhw -> btzhw', q_F_bev, F_b_sampled_for_wrist)
        vol_wrist = score_wrist_yx.unsqueeze(2) + score_bev_on_wrist + z_term + t_term

        if vol_bev.shape != vol_wrist.shape:
            raise ValueError(
                f"BEV-anchored volume {tuple(vol_bev.shape)} and wrist-anchored volume "
                f"{tuple(vol_wrist.shape)} differ; build both xyz tables with the same "
                f"(Z, H, W) as the feature maps"
            )
        volume_logits = torch.stack([vol_bev, vol_wrist], dim=3)  # (B, T, Z, 2, H, W)

        return {
            "volume_logits":     volume_logits,
            "vol_bev":           vol_bev,
            "vol_wrist":         vol_wrist,
            "gripper_logits":    gripper,
            "rotation_logits":   rotation,
            "pixel_feats":       F_bev,
            "pixel_feats_wrist": F_wrist,
        }


if __name__ == "__main__":
    # Smoke test: random images and synthetic cameras. It prints the parameter count, the
    # output shapes and whether the volume logits are finite. It is not a correctness check.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    m = DinoVolumeQuery2ViewDualFrustum(n_window=8, image_size=448, rotation_mode='1d_pca').to(device).eval()
    n_t = sum(p.numel() for p in m.parameters() if p.requires_grad)
    print(f"Trainable: {n_t:,}")
    B, IMG = 2, 448
    rgb_bev = torch.rand(B, 3, IMG, IMG).to(device)
    rgb_wrist = torch.rand(B, 3, IMG, IMG).to(device)
    sp = torch.rand(B, 2).to(device) * IMG
    bev_K = torch.tensor([[[0.6, 0, 0.5], [0, 0.6, 0.5], [0, 0, 1]]] * B, dtype=torch.float32).to(device)
    wrist_K = bev_K.clone()

    # Camera-to-world extrinsics, the convention build_wrist_world_xyz_table_batched reads.
    # A 180-degree rotation about x points the optical axis down world -z. With the
    # identity rotation, every sampled height (0.9-1.2) would lie behind the cameras.
    bev_ext = torch.eye(4).repeat(B, 1, 1)
    bev_ext[:, :3, :3] = torch.diag(torch.tensor([1.0, -1.0, -1.0]))
    bev_ext[:, 2, 3] = 1.5
    wrist_ext = bev_ext.clone()
    # Keep the wrist camera above max_height. A height plane through a camera centre
    # gives zero-depth points when projected into that camera.
    wrist_ext[:, 2, 3] = 1.4
    bev_ext, wrist_ext = bev_ext.to(device), wrist_ext.to(device)

    bev_xyz = build_bev_world_xyz_table_batched(bev_K, bev_ext, 32, 0.9, 1.2, 56, 56, IMG)
    wrist_xyz = build_wrist_world_xyz_table_batched(wrist_K, wrist_ext, 32, 0.9, 1.2, 56, 56, IMG)
    with torch.no_grad():
        out = m(rgb_bev, rgb_wrist, sp, bev_xyz, wrist_K, wrist_ext, bev_K, bev_ext, wrist_xyz)
    for k, v in out.items():
        if v is not None:
            print(f"  {k}: {tuple(v.shape)}")
    print(f"  volume_logits finite: {bool(torch.isfinite(out['volume_logits']).all())}")
