"""PARA head: map MAR decoder tokens to per-pixel ray (height-bin) logits.

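Takes (B, T, S, C) decoder tokens with S = H_lat * W_lat, reshapes them into
(B*T, C, H_lat, W_lat) feature maps, upsamples those with bilinear interpolation
interleaved with 3x3 convs, and projects to n_bins channels with a 1x1 conv,
giving logits of shape (B, T, n_bins, H_out, W_out).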
"""
import torch
import torch.nn as nn


class ParaHead(nn.Module):
    """Convolutional upsampling head that maps a square token grid to bin logits.

    The (in_grid_size x in_grid_size) token grid is upsampled 4x by two
    ``nn.Upsample(scale_factor=2, mode="bilinear")`` stages interleaved with
    3x3 convs + GELU. A final 1x1 conv then projects to ``n_bins`` channels.
    If that 4x result does not already match ``out_size``, the logits are
    resized bilinearly to ``(out_size, out_size)``. This means the output
    spatial size always equals ``out_size``.

    Args:
        decoder_embed_dim: channel width C of the incoming tokens; also used as
            the hidden width of the conv stack.
        n_bins: number of output channels (logits per output pixel).
        in_grid_size: side length of the square token grid; the token axis S
            of the input must equal ``in_grid_size ** 2``.
        out_size: side length of the square output map.
    """

    def __init__(
        self,
        decoder_embed_dim: int,
        n_bins: int = 32,
        in_grid_size: int = 16,
        out_size: int = 64,
    ):
        super().__init__()
        self.in_grid = in_grid_size
        self.out_size = out_size
        self.n_bins = n_bins
        mid = decoder_embed_dim

        # The layer order/indices are kept fixed so existing checkpoints
        # (state_dict keys conv_head.0/.3/.6/.8) keep loading. Any resize to
        # out_size happens in forward(), not as an extra module here.
        self.conv_head = nn.Sequential(
            nn.Conv2d(decoder_embed_dim, mid, 3, padding=1),
            nn.GELU(),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
            nn.Conv2d(mid, mid, 3, padding=1),
            nn.GELU(),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
            nn.Conv2d(mid, mid, 3, padding=1),
            nn.GELU(),
            nn.Conv2d(mid, n_bins, 1),
        )

    def forward(self, dec_tokens: torch.Tensor) -> torch.Tensor:
        """Map decoder tokens to per-pixel bin logits.

        Args:
            dec_tokens: (B, T, S, C) tensor where S == in_grid_size ** 2
                (e.g. 256 for the default 16x16 grid).

        Returns:
            volume_logits of shape (B, T, n_bins, out_size, out_size).

        Raises:
            ValueError: if S does not equal in_grid_size ** 2.
        """
        B, T, S, C = dec_tokens.shape
        H_lat = W_lat = self.in_grid
        if S != H_lat * W_lat:
            # Fail with a clear message instead of an opaque view() error.
            raise ValueError(
                f"ParaHead expected S == in_grid_size**2 == {H_lat * W_lat} "
                f"tokens, got S={S}"
            )
        x = dec_tokens.view(B, T, H_lat, W_lat, C).permute(0, 1, 4, 2, 3)
        # Fold time into batch so the 2D conv stack sees (B*T, C, H, W).
        x = x.reshape(B * T, C, H_lat, W_lat)
        x = self.conv_head(x)
        target = (self.out_size, self.out_size)
        if tuple(x.shape[-2:]) != target:
            # conv_head always upsamples exactly 4x; honour out_size when it
            # differs from 4 * in_grid_size. With default args this is skipped.
            x = nn.functional.interpolate(
                x, size=target, mode="bilinear", align_corners=False
            )
        _, n_bins, H_out, W_out = x.shape
        return x.reshape(B, T, n_bins, H_out, W_out)
