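"""Shared bounds + voxel-center + projection helpers for the volume AR model.

Active bounds are the X/Y/Z_MIN/MAX constants below (currently the smith300
baseBody frame, in metres):
  X[0.05, 0.45]  Y[-0.25, 0.35]  Z[-0.05, 0.35]
  32 × 32 × 32 voxels  →  cells 12.5mm x, ~18.8mm y, 12.5mm z

Earlier LIBERO bounds, locked in 2026-05-17 (symmetric around the robot base,
verified visually):
  X[-0.40, 0.40]  Y[-0.40, 0.40]  Z[0.90, 1.25]  →  cells ~25mm xy, ~11mm z
"""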
import numpy as np
import torch

# ── Volume config (single source of truth) ────────────────────────
# Workspace bounds in the smith300 baseBody frame (metres), derived from the
# per-session EEF statistics reported by the mac agent and padded ~2-3 cm
# beyond the observed extremes.
# To go back to the LIBERO workspace, use:
#   X[-0.4, 0.4]  Y[-0.4, 0.4]  Z[0.9, 1.25]
X_MIN = 0.05
X_MAX = 0.45
Y_MIN = -0.25
Y_MAX = 0.35
Z_MIN = -0.05
Z_MAX = 0.35

# Voxels per axis. With the bounds above a cell measures
# 12.5 mm (x) × 18.75 mm (y) × 12.5 mm (z).
N_X, N_Y, N_Z = 32, 32, 32
N_VOX = N_X * N_Y * N_Z  # 32 ** 3 = 32,768 voxels in total

# ── Other constants for the model + data loader ───────────────────
N_PAST_EEF = 20  # count of past EEF samples (TODO confirm usage in the data loader)
T_FUTURE = 8  # = window_length, number of timesteps predicted
N_GRIPPER_CLASS = 1  # gripper head is a single logit trained with BCE
N_ROT_BINS = 32  # presumably bins per rotation component over [MIN_ROT, MAX_ROT]; confirm

# Rotation bounds, measured on libero_spatial task 0 demos after fixing the
# WXYZ→XYZW quaternion unpack. Observed ranges:
#   x ∈ [-0.17, 0.0]   y ∈ [-0.07, 0.04]   z ∈ [0, 0.02]
# They are padded ~2-3× for headroom on out-of-distribution scenes.
# NOTE(review): these were measured on LIBERO data, while the position bounds
# above are smith300 — confirm they still cover smith300 rotations.
MIN_ROT = [-0.40, -0.20, -0.15]
MAX_ROT = [0.20, 0.20, 0.15]

IMAGE_SIZE = 448  # image side length in pixels (one value is used for both u and v)


def voxel_centers_world() -> torch.Tensor:
    """Return the world-frame centre of every voxel as a (N_VOX, 3) tensor.

    Rows are flat-indexed x-major (z varies fastest):
        flat = ix * (N_Y * N_Z) + iy * N_Z + iz
    This is the same layout that ``world_to_voxel_idx`` produces and
    ``voxel_idx_to_world`` inverts. The tensor uses torch's default float dtype
    on the default device.
    """
    # Voxel centres sit half a cell inside each bound.
    half_x = (X_MAX - X_MIN) / (2 * N_X)
    half_y = (Y_MAX - Y_MIN) / (2 * N_Y)
    half_z = (Z_MAX - Z_MIN) / (2 * N_Z)
    xs = torch.linspace(X_MIN + half_x, X_MAX - half_x, N_X)
    ys = torch.linspace(Y_MIN + half_y, Y_MAX - half_y, N_Y)
    zs = torch.linspace(Z_MIN + half_z, Z_MAX - half_z, N_Z)
    # indexing='ij' gives grids of shape (N_X, N_Y, N_Z). Reshaping the
    # contiguous stack to (-1, 3) therefore makes iz the fastest-varying index
    # and ix the slowest, i.e. the x-major flat order documented above.
    gx, gy, gz = torch.meshgrid(xs, ys, zs, indexing='ij')
    return torch.stack([gx, gy, gz], dim=-1).reshape(-1, 3)


def world_to_voxel_idx(pts_world: torch.Tensor) -> torch.Tensor:
    """Convert world coordinates (..., 3) into flat voxel indices (...,).

    Each axis is handled independently: the coordinate is expressed in cell
    units, truncated to an integer and clamped into [0, n - 1]. Points outside
    the volume therefore snap to the nearest boundary voxel on that axis.
    The flat layout is x-major: ix * (N_Y * N_Z) + iy * N_Z + iz.
    """
    def _axis_cell(coord: torch.Tensor, lo: float, hi: float, n: int) -> torch.Tensor:
        # Cell size along this axis, then the integer cell index, clamped into range.
        width = (hi - lo) / n
        return ((coord - lo) / width).long().clamp(0, n - 1)

    ix = _axis_cell(pts_world[..., 0], X_MIN, X_MAX, N_X)
    iy = _axis_cell(pts_world[..., 1], Y_MIN, Y_MAX, N_Y)
    iz = _axis_cell(pts_world[..., 2], Z_MIN, Z_MAX, N_Z)
    # Horner form of ix * N_Y * N_Z + iy * N_Z + iz (exact integer arithmetic).
    return (ix * N_Y + iy) * N_Z + iz


def voxel_idx_to_world(idx: torch.Tensor) -> torch.Tensor:
    """Map flat voxel indices (...,) back to voxel-centre world coords (..., 3).

    This inverts the x-major layout of ``world_to_voxel_idx``:
    flat = ix * (N_Y * N_Z) + iy * N_Z + iz. Indices are converted with
    ``.float()`` before the centre offset is applied.
    """
    # Recover the per-axis integer indices from the flat index.
    ix = idx // (N_Y * N_Z)
    iy = (idx // N_Z) % N_Y
    iz = idx % N_Z

    axes = (
        (ix, X_MIN, X_MAX, N_X),
        (iy, Y_MIN, Y_MAX, N_Y),
        (iz, Z_MIN, Z_MAX, N_Z),
    )
    # Each centre is the lower bound plus (index + 0.5) cells.
    coords = [lo + (i.float() + 0.5) * ((hi - lo) / n) for i, lo, hi, n in axes]
    return torch.stack(coords, dim=-1)


def world_to_pixel_torch(pts_world: torch.Tensor, world_to_camera: torch.Tensor) -> torch.Tensor:
    """Project batched world points to pixel coordinates with a 4×4 camera matrix.

    pts_world: (B, M, 3) world-frame points.
    world_to_camera: (B, 4, 4), the robosuite-style matrix from
        get_camera_transform_matrix. It already encodes the agentview flip, so
        pixel coords refer to the flipud(obs) image (see the comment in data.py).
    returns: (B, M, 2) pixel coords ordered (u, v) = (column, row).

    Depth is clamped to >= 1e-3 before the perspective divide. Points at or
    behind the camera plane therefore get very large (but finite) coordinates
    instead of dividing by zero or flipping sign.
    """
    batch, n_pts = pts_world.shape[:2]
    # Homogeneous coordinates: append w = 1 to every point.
    homogeneous = torch.cat([pts_world, pts_world.new_ones((batch, n_pts, 1))], dim=-1)  # (B, M, 4)
    projected = torch.einsum('bij,bmj->bmi', world_to_camera, homogeneous)            # (B, M, 4)
    # Verified empirically against robosuite's project_points_from_world_to_camera:
    # x/z is the column (u) and y/z is the row (v). The matrix already includes the
    # flipud convention, so no extra flip is applied here.
    depth = torch.clamp(projected[..., 2:3], min=1e-3)
    return projected[..., :2] / depth


def pixel_to_normalized_grid(
    pix_uv: torch.Tensor,
    image_size: int = IMAGE_SIZE,
    *,
    align_corners: bool = True,
) -> torch.Tensor:
    """Convert (B, M, 2) pixel uv coords to the [-1, 1] range used by F.grid_sample.

    The default (``align_corners=True``) maps pixel 0 to -1 and pixel
    ``image_size - 1`` to +1. This matches ``F.grid_sample(..., align_corners=True)``
    and is the original behaviour of this function.

    Pass ``align_corners=False`` when the sampler uses PyTorch's default
    ``align_corners=False``. In that convention -1/+1 are the outer edges of the
    border pixels, so pixel centres map to ``(u + 0.5) / image_size * 2 - 1``.
    Using the ``True`` mapping with an ``align_corners=False`` sampler shifts
    samples by up to half a pixel.
    NOTE(review): confirm which convention the grid_sample call site uses.
    """
    if align_corners:
        return pix_uv / (image_size - 1) * 2.0 - 1.0
    return (pix_uv + 0.5) / image_size * 2.0 - 1.0


def sincos_pe_3d(xyz: torch.Tensor, n_bands: int = 6, scale: float = 1.0) -> torch.Tensor:
    """Sinusoidal positional encoding of 3-D coordinates.

    Every coordinate c is expanded with frequencies 2**k for k = 0 .. n_bands-1:
    first the n_bands sines sin(pi * scale * c * 2**k), then the matching cosines.
    The per-coordinate blocks are concatenated in x, y, z order, so
    (B, M, 3) → (B, M, 3 * 2 * n_bands).
    """
    exponents = torch.arange(n_bands, device=xyz.device, dtype=xyz.dtype)
    freqs = 2.0 ** exponents                                   # (n_bands,)
    phase = xyz[..., None] * freqs * (np.pi * scale)           # (B, M, 3, n_bands)
    encoded = torch.cat((torch.sin(phase), torch.cos(phase)), dim=-1)  # (B, M, 3, 2*n_bands)
    return encoded.flatten(start_dim=-2)                       # (B, M, 3 * 2 * n_bands)


# Output width of sincos_pe_3d at its default n_bands=6:
# 3 coordinates × (sin, cos) × 6 bands.
PE_DIM = 36
