"""InternVL VLA — an InternVL vision-language backbone (default checkpoint: OpenGVLab/InternVL2_5-1B) + PARA action heads.

Runs the image + task description through InternVL's full pipeline (vision encoder
+ projector + LLM). Language-conditioned image patch features from the LLM's last
hidden layer are extracted, projected down, upsampled to pred_size, refined with
conv layers, and fed to PARA heads for heatmap + gripper/rotation prediction.

Usage:
    python train.py --model_type internvl --task_ids all --cache_root /data/libero/parsed_libero
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

# Per-channel normalization statistics used by both the dataset pipeline and
# the VLM input (see InternVLAPredictor._renormalize_image).
IMAGENET_MEAN = 0.485, 0.456, 0.406
IMAGENET_STD = 0.229, 0.224, 0.225

# Head geometry.
N_WINDOW = 6           # window slots: every head emits one prediction per slot
N_HEIGHT_BINS = 32     # volume_head channels per window slot
N_GRIPPER_BINS = 32    # gripper_mlp output bins
N_ROT_BINS = 32        # rotation_mlp output bins per rotation axis (x3 axes)
PRED_SIZE = 64         # spatial side length of the predicted heatmaps
FEAT_DIM = 256         # internal feature dim for PARA heads


class InternVLAPredictor(nn.Module):
    """InternVL backbone + PARA-style heatmap heads.

    The default checkpoint is ``OpenGVLab/InternVL2_5-1B``; another InternVL
    checkpoint can be selected with ``model_name`` as long as it exposes the
    attribute names probed in ``__init__``.

    The task text is placed *before* the image tokens in the LLM input. The LLM
    is a causal decoder, so only positions after the text can attend to it; this
    ordering is what makes the extracted image features language-conditioned.
    """

    def __init__(self, target_size=448, pred_size=PRED_SIZE, n_window=N_WINDOW,
                 freeze_backbone=False, model_name="OpenGVLab/InternVL2_5-1B", **kwargs):
        """Load the VLM and build the feature projection + PARA heads.

        Args:
            target_size:     side length (px) of the coordinate frame that
                             ``start_keypoint_2d`` in ``forward`` is expressed in.
            pred_size:       spatial side length of the predicted heatmaps.
            n_window:        number of window slots predicted by each head.
            freeze_backbone: if True, VLM parameters get ``requires_grad=False`` and
                             the VLM stays in eval mode (also across ``train()``).
            model_name:      Hugging Face model id passed to ``from_pretrained``.
            **kwargs:        accepted and ignored (presumably so all model variants
                             share one constructor call — verify against train.py).

        Raises:
            AttributeError: if the VLM lacks an expected component or LLM config.
        """
        super().__init__()
        self.target_size = target_size
        self.pred_size   = pred_size
        self.n_window    = n_window
        self.model_type  = "internvl"
        self.freeze_backbone = freeze_backbone

        # --- Load VLM ---------------------------------------------------------
        print(f"Loading InternVL model: {model_name}")
        from transformers import AutoModel, AutoTokenizer

        self.vlm = AutoModel.from_pretrained(
            model_name, torch_dtype=torch.bfloat16, trust_remote_code=True,
        )
        self._tokenizer = AutoTokenizer.from_pretrained(
            model_name, trust_remote_code=True,
        )

        if freeze_backbone:
            for p in self.vlm.parameters():
                p.requires_grad = False
            self.vlm.eval()
            print("  Frozen InternVL backbone")
        else:
            print("  InternVL backbone is trainable")

        # --- Detect VLM component attribute names -----------------------------
        # InternVL2.5 (trust_remote_code): vision_model, mlp1, language_model
        # HF-native InternVL variants may instead name the vision encoder
        # 'vision_tower' and the projector 'multi_modal_projector'.
        self._vision_enc = self._find_attr('vision_model', 'vision_tower', 'visual')
        self._projector  = self._find_attr('mlp1', 'multi_modal_projector')
        self._lm         = self._find_attr('language_model')

        # --- Config -----------------------------------------------------------
        cfg = self.vlm.config
        # InternVL2.5 uses llm_config; HF variant uses text_config
        llm_cfg = getattr(cfg, 'llm_config', None) or getattr(cfg, 'text_config', None)
        if llm_cfg is None:
            raise AttributeError(
                f"VLM config ({type(cfg).__name__}) has neither 'llm_config' nor 'text_config'"
            )
        self.vlm_hidden_dim = llm_cfg.hidden_size
        vis_cfg = cfg.vision_config
        # Some configs store sizes as [H, W]; the resize below needs a single int.
        self.vlm_image_size = self._side_length(getattr(vis_cfg, 'image_size', 448))
        self.vlm_patch_size = self._side_length(getattr(vis_cfg, 'patch_size', 14))
        self.downsample_ratio = getattr(cfg, 'downsample_ratio', 0.5)
        print(f"  VLM hidden dim: {self.vlm_hidden_dim}")
        print(f"  Vision: {self.vlm_image_size}px, patch {self.vlm_patch_size}px, downsample {self.downsample_ratio}")

        # Image re-normalization: the dataset and the VLM both use ImageNet
        # mean/std, so the round-trip is currently an identity; the separate
        # buffers are kept so the VLM statistics can diverge from the dataset's.
        self.register_buffer('vlm_mean', torch.tensor(IMAGENET_MEAN, dtype=torch.float32).view(3, 1, 1))
        self.register_buffer('vlm_std',  torch.tensor(IMAGENET_STD,  dtype=torch.float32).view(3, 1, 1))
        self.register_buffer('inet_mean', torch.tensor(IMAGENET_MEAN, dtype=torch.float32).view(3, 1, 1))
        self.register_buffer('inet_std',  torch.tensor(IMAGENET_STD,  dtype=torch.float32).view(3, 1, 1))

        # --- Feature projection + PARA heads ---------------------------------
        D = FEAT_DIM
        self.embed_dim = D  # used by predict_at_pixels

        self.feat_proj = nn.Sequential(
            nn.Linear(self.vlm_hidden_dim, D),
            nn.GELU(),
        )

        self.feature_convs = nn.Sequential(
            nn.Conv2d(D, D, 3, padding=1), nn.GELU(),
            nn.Conv2d(D, D, 3, padding=1), nn.GELU(),
            nn.Conv2d(D, D, 3, padding=1), nn.GELU(),
        )

        # Learned vector added to the feature map at the current EEF cell.
        self.start_keypoint_embedding = nn.Parameter(torch.randn(D) * 0.02)

        self.volume_head = nn.Conv2d(D, n_window * N_HEIGHT_BINS, kernel_size=1)
        self.gripper_mlp = nn.Sequential(
            nn.LayerNorm(D), nn.Linear(D, D), nn.GELU(), nn.Linear(D, N_GRIPPER_BINS),
        )
        self.rotation_mlp = nn.Sequential(
            nn.LayerNorm(D), nn.Linear(D, D), nn.GELU(), nn.Linear(D, 3 * N_ROT_BINS),
        )

        print(f"  Feature proj: {self.vlm_hidden_dim} -> {D}")
        print(f"  Feature convs: 3x Conv2d(3x3) at pred_size={pred_size}")
        print(f"  Volume   head -> (B, {n_window}, {N_HEIGHT_BINS}, {pred_size}, {pred_size})")
        print(f"  Gripper  MLP  -> (B, {n_window}, {N_GRIPPER_BINS})")
        print(f"  Rotation MLP  -> (B, {n_window}, 3, {N_ROT_BINS})")

    def train(self, mode=True):
        """Set train/eval mode while keeping a frozen VLM backbone in eval mode.

        ``nn.Module.train`` recurses into every submodule; without this override
        a later ``model.train()`` would undo the ``self.vlm.eval()`` done in
        ``__init__`` for a frozen backbone.

        Returns:
            self, as ``nn.Module.train`` does.
        """
        super().train(mode)
        if getattr(self, 'freeze_backbone', False):
            self.vlm.eval()
        return self

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _side_length(value):
        """Return an int side length from either a scalar or an (H, W) sequence."""
        if isinstance(value, (list, tuple)):
            return int(value[0])
        return int(value)

    def _find_attr(self, *names):
        """Return the first matching attribute on self.vlm, or raise AttributeError."""
        for n in names:
            if hasattr(self.vlm, n):
                return getattr(self.vlm, n)
        raise AttributeError(
            f"Cannot find any of {names} on VLM ({type(self.vlm).__name__}). "
            f"Available: {[a for a in dir(self.vlm) if not a.startswith('_')]}"
        )

    def _renormalize_image(self, x):
        """ImageNet-normalized (B,3,H,W) -> VLM-normalized (B,3,S,S), S = vlm_image_size."""
        x_raw = x * self.inet_std + self.inet_mean          # -> [0, 1]
        x_vlm = (x_raw - self.vlm_mean) / self.vlm_std      # -> VLM space
        # Resize to VLM's expected resolution if needed
        if x_vlm.shape[-1] != self.vlm_image_size or x_vlm.shape[-2] != self.vlm_image_size:
            x_vlm = F.interpolate(x_vlm, size=(self.vlm_image_size, self.vlm_image_size),
                                  mode='bilinear', align_corners=False)
        return x_vlm

    # ------------------------------------------------------------------
    # VLM feature extraction
    # ------------------------------------------------------------------
    def _get_image_tokens(self, pixel_values):
        """Vision encoder + projector -> image tokens for the LLM.

        Args:
            pixel_values: (B, 3, H, W) VLM-normalized

        Returns:
            image_tokens: (B, N_img, D_lm) in VLM dtype (bf16)
        """
        # InternVL2.5 (trust_remote_code): extract_feature() bundles
        # ViT + CLS removal + pixel shuffle + mlp1 projection — returns ready tokens.
        if hasattr(self.vlm, 'extract_feature'):
            return self.vlm.extract_feature(pixel_values)
        # Fallback for HF-native models: vision encoder + projector separately.
        # NOTE(review): unlike extract_feature, this path does no CLS removal or
        # pixel shuffle; if the projector expects pixel-shuffled features it will
        # fail on a shape mismatch — verify before relying on HF-native checkpoints.
        vis_out = self._vision_enc(pixel_values)
        vit_features = vis_out.last_hidden_state if hasattr(vis_out, 'last_hidden_state') else vis_out[0]
        return self._projector(vit_features)

    def _extract_vlm_features(self, pixel_values, task_texts):
        """Full VLM pipeline: vision + text -> language-conditioned image features.

        The LLM input is ``[text tokens, image tokens]``. Because the decoder is
        causal, only tokens *after* the text can attend to it; with the image
        first, its hidden states would be independent of ``task_texts``.

        Args:
            pixel_values: (B, 3, H, W) VLM-normalized, bf16 or fp32
            task_texts:   list of B strings

        Returns:
            image_hidden: (B, grid_size**2, D_lm) fp32 (zero-padded if N_img is
                          not a perfect square)
            grid_size:    int, spatial side length (ceil of sqrt(N_img))
        """
        B = pixel_values.shape[0]
        device = pixel_values.device

        # 1. Get image tokens from vision encoder + projector
        image_tokens = self._get_image_tokens(pixel_values.to(self.vlm.dtype))
        N_img = image_tokens.shape[1]

        # 2. Tokenize task text
        text_inputs = self._tokenizer(
            task_texts, return_tensors="pt", padding=True,
            truncation=True, max_length=64,
        ).to(device)
        text_embeds = self._lm.get_input_embeddings()(text_inputs['input_ids'])
        text_embeds = text_embeds.to(image_tokens.dtype)

        # 3. Merge [text_embeds, image_tokens] so image positions can attend to the text
        inputs_embeds = torch.cat([text_embeds, image_tokens], dim=1)
        attn_mask = torch.cat([
            text_inputs['attention_mask'].to(torch.long),
            torch.ones(B, N_img, device=device, dtype=torch.long),
        ], dim=1)
        # Padding-aware positions: text padding (on either side) does not shift the
        # positions of the image tokens. Values at masked positions are irrelevant.
        position_ids = (attn_mask.cumsum(dim=1) - 1).clamp_(min=0)

        lm_out = self._lm(
            inputs_embeds=inputs_embeds,
            attention_mask=attn_mask,
            position_ids=position_ids,
            output_hidden_states=True,
        )

        # 4. Extract image features from last hidden layer (last N_img positions)
        hidden = lm_out.hidden_states[-1]               # (B, T_text+N_img, D_lm)
        image_hidden = hidden[:, -N_img:, :].float()    # (B, N_img, D_lm), fp32

        # 5. Spatial grid size
        grid_size = int(math.sqrt(N_img))
        if grid_size * grid_size != N_img:
            # Not a perfect square — use ceil and zero-pad
            grid_size = int(math.ceil(math.sqrt(N_img)))
            pad_len = grid_size * grid_size - N_img
            if pad_len > 0:
                padding = torch.zeros(B, pad_len, image_hidden.shape[-1],
                                      device=device, dtype=image_hidden.dtype)
                image_hidden = torch.cat([image_hidden, padding], dim=1)

        return image_hidden, grid_size

    # ------------------------------------------------------------------
    # PARA head helpers (shared with other model variants)
    # ------------------------------------------------------------------
    def _index_features(self, feats, query_pixels):
        """Gather per-query feature vectors.

        Args:
            feats:        (B, D, H, W)
            query_pixels: (B, N, 2) as (x, y); truncated to int and clamped to the map

        Returns:
            (B, N, D) — advanced indices separated by a slice put the (B, N)
            broadcast dims first, followed by the sliced D dim.
        """
        B, D, H, W = feats.shape
        N = query_pixels.shape[1]
        px = query_pixels[..., 0].long().clamp(0, W - 1)
        py = query_pixels[..., 1].long().clamp(0, H - 1)
        batch_idx = torch.arange(B, device=feats.device).view(B, 1).expand(B, N)
        return feats[batch_idx, :, py, px]

    def predict_at_pixels(self, feats, query_pixels):
        """Apply gripper / rotation MLPs at specified pixel locations.

        ``feats`` is detached, so these heads do not backpropagate into the
        feature trunk.

        Args:
            feats:        (B, D, pred_size, pred_size)
            query_pixels: (B, N_WINDOW, 2) in pred_size coords

        Returns:
            gripper_logits:  (B, N_WINDOW, N_GRIPPER_BINS)
            rotation_logits: (B, N_WINDOW, 3, N_ROT_BINS)
        """
        B, N = query_pixels.shape[:2]
        indexed = self._index_features(feats.detach(), query_pixels)  # (B, N, D)
        flat = indexed.reshape(B * N, self.embed_dim)
        gripper  = self.gripper_mlp(flat).reshape(B, N, N_GRIPPER_BINS)
        rotation = self.rotation_mlp(flat).reshape(B, N, 3, N_ROT_BINS)
        return gripper, rotation

    # ------------------------------------------------------------------
    # Forward
    # ------------------------------------------------------------------
    def forward(self, x, start_keypoint_2d, query_pixels=None, task_text=None):
        """
        Args:
            x:                 (B, 3, H, W) ImageNet-normalized
            start_keypoint_2d: (B, 2) or (2,) current EEF pixel, in a frame of
                               side ``target_size``
            query_pixels:      (B, N_WINDOW, 2) in pred_size space (GT during train)
            task_text:         list of B strings — LIBERO task description; a
                               generic pick-and-place instruction is used if None

        Returns:
            volume_logits:   (B, N_WINDOW, N_HEIGHT_BINS, pred_size, pred_size)
            gripper_logits:  (B, N_WINDOW, N_GRIPPER_BINS)  or None
            rotation_logits: (B, N_WINDOW, 3, N_ROT_BINS)   or None
            feats:           (B, D, pred_size, pred_size)
        """
        B = x.shape[0]

        if task_text is None:
            task_text = ["Pick up the object and place it on the target."] * B

        # --- VLM feature extraction -------------------------------------------
        pixel_values = self._renormalize_image(x)
        image_hidden, grid_size = self._extract_vlm_features(pixel_values, task_text)
        # image_hidden: (B, grid_size^2, D_lm), fp32

        # --- Project to working dim and reshape to spatial grid ---------------
        feat_tokens = self.feat_proj(image_hidden)   # (B, grid_size^2, FEAT_DIM)
        feat_map = feat_tokens.permute(0, 2, 1).reshape(B, self.embed_dim, grid_size, grid_size)

        # --- Start keypoint conditioning --------------------------------------
        # Map the keypoint from the target_size frame to a grid cell and add the
        # learned embedding there (one cell per batch element).
        if start_keypoint_2d.dim() == 1:
            start_keypoint_2d = start_keypoint_2d.unsqueeze(0).expand(B, -1)
        kp_x = (start_keypoint_2d[:, 0] * grid_size / self.target_size).long().clamp(0, grid_size - 1)
        kp_y = (start_keypoint_2d[:, 1] * grid_size / self.target_size).long().clamp(0, grid_size - 1)
        batch_idx = torch.arange(B, device=feat_map.device)
        feat_map[batch_idx, :, kp_y, kp_x] += self.start_keypoint_embedding.unsqueeze(0)

        # --- Upsample + refine ------------------------------------------------
        feats = F.interpolate(feat_map, size=(self.pred_size, self.pred_size),
                              mode='bilinear', align_corners=False)
        feats = self.feature_convs(feats)  # (B, D, pred_size, pred_size)

        # --- PARA heads -------------------------------------------------------
        vol = self.volume_head(feats)
        volume_logits = vol.view(B, self.n_window, N_HEIGHT_BINS, self.pred_size, self.pred_size)

        if query_pixels is not None:
            gripper_logits, rotation_logits = self.predict_at_pixels(feats, query_pixels)
        else:
            gripper_logits = rotation_logits = None

        return volume_logits, gripper_logits, rotation_logits, feats


# ---------------------------------------------------------------------------
# Quick smoke test
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Smoke test: build a frozen-backbone model, then run one no-grad forward pass
    # on random images and print the output shapes.
    dev = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    predictor = InternVLAPredictor(target_size=448, n_window=N_WINDOW, freeze_backbone=True).to(dev)

    batch_size = 2
    images = torch.randn(batch_size, 3, 448, 448).to(dev)
    queries = torch.zeros(batch_size, N_WINDOW, 2).to(dev)
    instructions = [
        "pick up the black bowl between the plate and the ramekin and place it on the plate",
        "pick up the black bowl on the stove and place it on the plate",
    ]

    with torch.no_grad():
        outputs = predictor(
            images,
            start_keypoint_2d=torch.tensor([224.0, 224.0]).to(dev),
            query_pixels=queries,
            task_text=instructions,
        )

    labels = ("volume_logits  ", "gripper_logits ", "rotation_logits", "feats          ")
    for label, tensor in zip(labels, outputs):
        print(label, tensor.shape)
