"""Model for trajectory volume prediction using SmolVLA (LeRobot).

Uses pretrained SmolVLA: image + language go through the VLM; we take the image token
outputs after language-image self-attention, upsample the feature grid to an intermediate
resolution (e.g. 64×64), then a small conv stack (3×3 → 3×3 → 1×1) for volume and gripper
logits, then upsample to target resolution. All parameters are fine-tuned (no freezing).
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

# Lerobot is imported lazily in TrajectoryHeatmapPredictor.__init__ to avoid mutex/hang
# when "from model import ..." runs (lerobot pulls in transformers etc. which can block).

# Number of trajectory timesteps predicted per forward pass.
N_WINDOW = 12
# NOTE(review): MIN_HEIGHT == MAX_HEIGHT makes the height range degenerate —
# presumably one of these should be a different value; confirm against the
# dataset statistics before relying on height binning.
MIN_HEIGHT = 0.043347
MAX_HEIGHT = 0.043347
# Gripper opening range covered by the gripper bins.
MIN_GRIPPER = -0.2
MAX_GRIPPER = 0.8
# Discretization of the per-pixel height / gripper distributions.
N_HEIGHT_BINS = 32
N_GRIPPER_BINS = 32

# Default pretrained checkpoint (downloaded from HF on first use)
DEFAULT_SMOLVLA_CKPT = "lerobot/smolvla_base"


def _pad_vector(vector, new_dim, device, dtype):
    """Pad state vector to new_dim (batch, dim) -> (batch, new_dim)."""
    if vector.shape[-1] == new_dim:
        return vector
    b = vector.shape[0]
    out = torch.zeros(b, new_dim, dtype=dtype, device=device)
    out[:, : vector.shape[-1]] = vector
    return out


class TrajectoryHeatmapPredictor(nn.Module):
    """Predicts pixel-aligned volume and gripper using SmolVLA image features after VLM attention.

    Pipeline:
      1. Load pretrained SmolVLA (e.g. lerobot/smolvla_base) and run only the prefix
         (image + language + state) through the VLM.
      2. Slice the image token outputs after language-image self-attention and reshape
         them into a 2D patch grid.
      3. Bilinearly upsample the grid to ``feature_grid_size``, apply small conv heads
         (3x3 -> 3x3 -> 1x1) for volume and gripper logits, then upsample to
         ``target_size``.

    All SmolVLA parameters are fine-tuned (no freezing) unless ``freeze_backbone=True``.
    """

    def __init__(
        self,
        target_size=448,
        feature_grid_size=64,
        n_window=N_WINDOW,
        pretrained_ckpt=DEFAULT_SMOLVLA_CKPT,
        freeze_backbone=False,
        max_lang_len=128,
        head_hidden_dim=256,
    ):
        """
        Args:
            target_size: output resolution (H = W) of the predicted logit maps.
            feature_grid_size: intermediate grid size the patch features are
                upsampled to before the conv heads (e.g. 64 -> 64x64).
            n_window: number of trajectory timesteps predicted per pixel.
            pretrained_ckpt: HF repo id or local path of the SmolVLA checkpoint.
            freeze_backbone: if True, freeze all SmolVLA parameters.
            max_lang_len: max language token length (padded / truncated).
            head_hidden_dim: hidden channel count of the conv heads.
        """
        super().__init__()
        # Deferred import so "from model import ..." does not trigger lerobot/transformers init
        from lerobot.policies.smolvla.modeling_smolvla import (
            SmolVLAPolicy,
            resize_with_pad,
            make_att_2d_masks,
        )

        self._resize_with_pad = resize_with_pad
        self._make_att_2d_masks = make_att_2d_masks

        self.target_size = target_size
        self.feature_grid_size = feature_grid_size
        self.n_window = n_window
        self.max_lang_len = max_lang_len

        print(f"Loading SmolVLA from {pretrained_ckpt} ...")
        policy = SmolVLAPolicy.from_pretrained(pretrained_ckpt)

        self.vla_model = policy.model  # VLAFlowMatching
        self.config = policy.config
        # Processor for tokenizing language
        self.processor = self.vla_model.vlm_with_expert.processor
        self.tokenizer = self.processor.tokenizer

        # Unfreeze all parameters for full fine-tuning (unless freeze_backbone)
        if freeze_backbone:
            for p in self.vla_model.parameters():
                p.requires_grad = False
            self.vla_model.eval()
            print("✓ SmolVLA backbone frozen")
        else:
            for p in self.vla_model.parameters():
                p.requires_grad = True
            print("✓ SmolVLA backbone trainable (full fine-tune)")

        # Infer the image token layout from a dummy forward pass.
        # NOTE(review): this indexes resize_imgs_with_padding as (height, width);
        # _preprocess_images unpacks it as (width, height). The default config is
        # square so both work — confirm the ordering before using non-square sizes.
        with torch.no_grad():
            dummy_img = torch.zeros(
                1, 3,
                self.config.resize_imgs_with_padding[0],
                self.config.resize_imgs_with_padding[1],
                device=next(self.vla_model.parameters()).device,
            )
            img_emb = self.vla_model.vlm_with_expert.embed_image(dummy_img)
            self._num_img_embs = img_emb.shape[1]
            self._embed_dim = img_emb.shape[2]
        # Special image begin/end tokens occupy the first two prefix slots when enabled.
        self._image_start_ix = 2 if self.vla_model.add_image_special_tokens else 0
        self._image_end_ix = self._image_start_ix + self._num_img_embs

        # Spatial grid: assume square layout and row-major token order (standard ViT/SigLIP:
        # token index i = row i // W_p, col i % W_p; top-left to bottom-right).
        s = int(math.sqrt(self._num_img_embs))
        if s * s < self._num_img_embs:
            s += 1
        self._H_p = self._W_p = s
        self._num_spatial = self._H_p * self._W_p
        if self._num_spatial != self._num_img_embs:
            print(f"⚠ SmolVLA image tokens {self._num_img_embs} is not a perfect square; using grid {self._H_p}x{self._W_p} (may truncate/pad)")
        print(f"✓ SmolVLA image tokens: {self._num_img_embs} -> grid {self._H_p}x{self._W_p}, embed_dim={self._embed_dim}")

        # Projection from VLM hidden_size to a common dimension if needed (VLM uses text_config.hidden_size)
        self.embed_dim = self._embed_dim
        # Upsample patch grid to feature_grid_size (e.g. 8×8 -> 64×64), then conv stack (3×3 -> 3×3 -> 1×1)
        self.volume_head = nn.Sequential(
            nn.Conv2d(self.embed_dim, head_hidden_dim, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(head_hidden_dim, head_hidden_dim, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(head_hidden_dim, self.n_window * N_HEIGHT_BINS, kernel_size=1),
        )
        # Learned embedding added at the start-keypoint location on the feature grid.
        self.start_keypoint_embedding = nn.Parameter(torch.randn(self.embed_dim) * 0.02)
        self.gripper_head = nn.Sequential(
            nn.Conv2d(self.embed_dim, head_hidden_dim, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(head_hidden_dim, head_hidden_dim, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(head_hidden_dim, self.n_window * N_GRIPPER_BINS, kernel_size=1),
        )
        print(f"✓ Feature grid: patch {self._H_p}x{self._W_p} -> upsample to {self.feature_grid_size}x{self.feature_grid_size}, then 3×3→3×3→1×1 -> upsample to {self.target_size}x{self.target_size}")
        print(f"✓ Volume head: (B, embed, {self.feature_grid_size}, {self.feature_grid_size}) -> (B, {self.n_window}*{N_HEIGHT_BINS}, H, W)")
        print(f"✓ Gripper head: (B, embed, {self.feature_grid_size}, {self.feature_grid_size}) -> (B, {self.n_window}*{N_GRIPPER_BINS}, H, W)")

    # NOTE: no `to()` override. The previous override accepted only a positional
    # device and therefore broke `.to(dtype=...)` / `.to(device, dtype)`; it was
    # also redundant, since `vla_model` is a registered submodule and is moved
    # by the inherited nn.Module.to().

    def _preprocess_images(self, x):
        """Resize with padding to SmolVLA size and normalize to [-1, 1] (SigLIP).
        Accepts either [0, 1] or ImageNet-normalized (0.485/0.229 etc.) input.
        """
        # Heuristic: ImageNet-normalized inputs have values well outside [0, 1];
        # undo the normalization to recover [0, 1] pixels.
        if x.min() < -0.5 or x.max() > 1.5:
            mean = torch.tensor([0.485, 0.456, 0.406], device=x.device, dtype=x.dtype).view(1, 3, 1, 1)
            std = torch.tensor([0.229, 0.224, 0.225], device=x.device, dtype=x.dtype).view(1, 3, 1, 1)
            x = (x * std + mean).clamp(0.0, 1.0)
        if self.config.resize_imgs_with_padding is not None:
            # NOTE(review): assumes config order is (width, height) — square by
            # default, so either order works; verify for non-square configs.
            w, h = self.config.resize_imgs_with_padding
            x = self._resize_with_pad(x, w, h, pad_value=0)
        # SigLIP expects inputs in [-1, 1].
        x = x * 2.0 - 1.0
        return x

    def _tokenize_task(self, task, device, batch_size):
        """Tokenize task strings to input_ids and attention_mask. task: list of str or single str."""
        if isinstance(task, str):
            task = [task] * batch_size
        # Pad to a fixed max_lang_len so prefix length is batch-independent.
        enc = self.tokenizer(
            task,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_lang_len,
            add_special_tokens=True,
        )
        input_ids = enc["input_ids"].to(device)
        attention_mask = enc["attention_mask"].to(device)
        return input_ids, attention_mask

    def _get_image_features_after_attention(self, x, task="", state=None):
        """
        Run SmolVLA prefix only (image + language + state), return image token outputs
        after language-image self-attention. Reshape to (B, D, H_p, W_p).
        """
        B = x.shape[0]
        device = x.device

        # Preprocess image to the VLM's expected size and [-1, 1] range.
        img = self._preprocess_images(x)
        images = [img]
        img_masks = [torch.ones(B, dtype=torch.bool, device=device)]

        # Language
        lang_tokens, lang_masks = self._tokenize_task(task, device, B)

        # State: zeros if not provided
        if state is None:
            state = torch.zeros(B, self.config.max_state_dim, device=device, dtype=torch.float32)
        else:
            state = _pad_vector(state, self.config.max_state_dim, device, torch.float32)

        # Build prefix and run VLM (prefix only)
        prefix_embs, prefix_pad_masks, prefix_att_masks = self.vla_model.embed_prefix(
            images, img_masks, lang_tokens, lang_masks, state=state
        )
        prefix_att_2d = self._make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
        # Position ids count only non-padding tokens.
        prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1

        (prefix_out, _), _ = self.vla_model.vlm_with_expert.forward(
            attention_mask=prefix_att_2d.bool(),
            position_ids=prefix_position_ids,
            past_key_values=None,
            inputs_embeds=[prefix_embs, None],
            use_cache=True,
            fill_kv_cache=True,
        )
        # prefix_out: (B, prefix_len, D)
        # Slice image tokens
        img_out = prefix_out[:, self._image_start_ix : self._image_end_ix, :]  # (B, num_img_embs, D)
        # Truncate/zero-pad the token dimension to fill the (possibly rounded-up) grid.
        if img_out.shape[1] > self._num_spatial:
            img_out = img_out[:, : self._num_spatial, :]
        elif img_out.shape[1] < self._num_spatial:
            pad_len = self._num_spatial - img_out.shape[1]
            img_out = F.pad(img_out, (0, 0, 0, pad_len), value=0)

        # Unpack to 2D: assume row-major order (token i = row i//W_p, col i%W_p) → (B, H_p, W_p, D).
        # Then (B, D, H_p, W_p) for convs, with H_p = height (y), W_p = width (x).
        img_out = img_out.reshape(B, self._H_p, self._W_p, self._embed_dim)
        patch_features = img_out.permute(0, 3, 1, 2).contiguous()
        return patch_features

    def forward(
        self,
        x,
        gt_target_heatmap=None,
        training=False,
        start_keypoint_2d=None,
        current_height=None,
        current_gripper=None,
        task="",
        state=None,
    ):
        """
        Args:
            x: (B, 3, H, W) RGB in [0, 1]
            gt_target_heatmap, training, current_height, current_gripper: accepted
                for interface compatibility; unused here.
            start_keypoint_2d: (B, 2) or (2,) optional pixel coordinates (x, y) in
                target_size space; any device (moved to the feature device).
            task: str or list of str, language instruction (optional)
            state: (B, state_dim) optional robot state; zeros if None
        Returns:
            volume_logits: (B, N_WINDOW, N_HEIGHT_BINS, H, W)
            gripper_logits: (B, N_WINDOW, N_GRIPPER_BINS, H, W)
        """
        B = x.shape[0]
        patch_features = self._get_image_features_after_attention(x, task=task, state=state)
        # VLM outputs bfloat16; our heads are float32 — cast for compatibility
        patch_features = patch_features.float()

        # Upsample patch grid (e.g. 8×8) to feature_grid_size (e.g. 64×64)
        patch_features = F.interpolate(
            patch_features,
            size=(self.feature_grid_size, self.feature_grid_size),
            mode="bilinear",
            align_corners=False,
        )
        Hg, Wg = self.feature_grid_size, self.feature_grid_size

        if start_keypoint_2d is not None:
            # Advanced indexing requires index tensors on the same device as
            # patch_features; callers may pass a CPU keypoint.
            start_keypoint_2d = start_keypoint_2d.to(patch_features.device)
            if start_keypoint_2d.dim() == 1:
                start_keypoint_2d = start_keypoint_2d.unsqueeze(0).expand(B, -1)
            # Map pixel coords (target_size space) to feature-grid cells.
            start_patch_x = (start_keypoint_2d[:, 0] * Wg / self.target_size).long().clamp(0, Wg - 1)
            start_patch_y = (start_keypoint_2d[:, 1] * Hg / self.target_size).long().clamp(0, Hg - 1)
            batch_indices = torch.arange(B, device=patch_features.device)
            patch_features[batch_indices, :, start_patch_y, start_patch_x] += self.start_keypoint_embedding.unsqueeze(0)

        # Volume head (3×3 → 3×3 → 1×1) on feature grid, then upsample to target_size
        vol = self.volume_head(patch_features)  # (B, n_window * N_HEIGHT_BINS, Hg, Wg)
        volume_logits = F.interpolate(
            vol,
            size=(self.target_size, self.target_size),
            mode="bilinear",
            align_corners=False,
        )
        volume_logits = volume_logits.view(B, self.n_window, N_HEIGHT_BINS, self.target_size, self.target_size)

        # Gripper head (3×3 → 3×3 → 1×1) on feature grid, then upsample to target_size
        grip = self.gripper_head(patch_features)
        grip = F.interpolate(
            grip,
            size=(self.target_size, self.target_size),
            mode="bilinear",
            align_corners=False,
        )
        gripper_logits = grip.view(B, self.n_window, N_GRIPPER_BINS, self.target_size, self.target_size)

        return volume_logits, gripper_logits


if __name__ == "__main__":
    # Prefer CUDA, then MPS, then CPU
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    model = TrajectoryHeatmapPredictor(target_size=448, feature_grid_size=64, n_window=N_WINDOW, freeze_backbone=False)
    model = model.to(device)
    x = torch.rand(2, 3, 448, 448, device=device)
    # The keypoint must live on the same device as the model's features:
    # forward() uses it in advanced indexing, which fails on a device mismatch
    # when running on CUDA/MPS.
    start_kp = torch.tensor([224.0, 224.0], device=device)
    with torch.no_grad():
        vol, grip = model(x, training=False, start_keypoint_2d=start_kp, task="pick the block")
    print("volume_logits", vol.shape)
    print("gripper_logits", grip.shape)
