"""DinoVolumeQuery with InternVL VLM trunk instead of DINOv3.

Take the same query-MLP architecture (refine + EEF query + AdaLN blocks + volume scoring)
but swap the DINOv3 patch extractor for InternVL's language-conditioned vision pipeline
(the default checkpoint is InternVL2.5-1B; see ``vlm_model_name``).

Forward: image + task prompt → full VLM forward → extract image hidden states from the LLM's
last layer at the image-token positions → reshape to (B, D_lm, grid, grid) → projected down
to (D_dino_equiv, grid, grid) → existing query-MLP pipeline.

The VLM can be frozen with ``freeze_vlm=True`` (fine-tuning it needs a lot of memory); note that
the constructor default is ``freeze_vlm=False``. Image features are not cached.
"""
import math
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

from model_dino_volume_query import (
    DinoVolumeQuery,
    N_WINDOW, N_HEIGHT_BINS, N_GRIPPER_BINS, N_ROT_BINS,
    D_FEAT, D_SINZ, D_SINT, IMG_SIZE, PRED_SIZE,
)

# Per-channel normalization constants (3 values each). Presumably the standard
# ImageNet RGB mean / std — the visible code in this file does not reference them.
IMAGENET_MEAN = (
    0.485,
    0.456,
    0.406,
)
IMAGENET_STD = (
    0.229,
    0.224,
    0.225,
)


class DinoVolumeQueryVLM(DinoVolumeQuery):
    """1view query-MLP with InternVL trunk instead of DINOv3.

    The parent's ``dino`` backbone is deleted and ``_extract_dino_features`` is
    overridden so the inherited forward pass consumes VLM hidden states
    (projected to ``self.embed_dim``) instead of DINO patches.

    Args:
        *args, **kwargs: forwarded unchanged to ``DinoVolumeQuery.__init__``.
        vlm_model_name: Hugging Face model id loaded with ``AutoModel`` /
            ``AutoTokenizer`` (``trust_remote_code=True``).
        freeze_vlm: if True, every VLM parameter is frozen and the VLM is kept
            in eval mode (also across later ``.train()`` calls).
        freeze_llm_only: if True (and ``freeze_vlm`` is False), the whole VLM is
            frozen and then the vision encoder and projector are re-enabled.
        task_text_default: prompt used when no per-sample task text is given.
    """

    def __init__(self, *args,
                 vlm_model_name="OpenGVLab/InternVL2_5-1B",
                 freeze_vlm=False,
                 freeze_llm_only=False,
                 task_text_default="pick the black bowl between the plate and the ramekin and place it on the plate",
                 **kwargs):
        super().__init__(*args, **kwargs)
        # Override DINO backbone — we don't use it
        del self.dino
        self._task_text_default = task_text_default
        # Remembered so train() can keep a fully frozen VLM in eval mode.
        self._freeze_vlm = bool(freeze_vlm)

        # Load VLM (imported lazily so the module can be imported without transformers)
        from transformers import AutoModel, AutoTokenizer
        print(f"Loading VLM: {vlm_model_name}")
        self.vlm = AutoModel.from_pretrained(
            vlm_model_name, torch_dtype=torch.bfloat16, trust_remote_code=True,
        )
        self._tokenizer = AutoTokenizer.from_pretrained(
            vlm_model_name, trust_remote_code=True,
        )
        if freeze_vlm:
            self.vlm.requires_grad_(False)
            self.vlm.eval()
            print("  VLM fully frozen")
        elif freeze_llm_only:
            # Freeze everything first; vision encoder + projector are re-enabled below
            # once the sub-modules have been located.
            self.vlm.requires_grad_(False)
            print("  VLM LLM frozen, vision + projector trainable")
        else:
            print("  VLM fully trainable")

        # Detect components (InternVL2.5 has these attribute names)
        # NOTE(review): these aliases point at sub-modules already owned by self.vlm.
        # If this class is an nn.Module they are registered a second time, which may
        # duplicate their entries in state_dict() — confirm before changing, since
        # removing the aliases would alter checkpoint keys.
        self._vision_enc = self._find_attr('vision_model', 'visual')
        self._projector  = self._find_attr('mlp1', 'multi_modal_projector')
        self._lm         = self._find_attr('language_model')
        if freeze_llm_only and not freeze_vlm:
            # Unfreeze vision encoder + projector; the language model stays frozen
            self._vision_enc.requires_grad_(True)
            self._projector.requires_grad_(True)

        cfg = self.vlm.config
        llm_cfg = getattr(cfg, 'llm_config', None)
        if llm_cfg is None:
            llm_cfg = getattr(cfg, 'text_config', None)
        if llm_cfg is None:
            # Fail with a readable message instead of "'NoneType' has no attribute 'hidden_size'"
            raise AttributeError("VLM config has neither 'llm_config' nor 'text_config'")
        self.vlm_hidden_dim = llm_cfg.hidden_size
        vis_cfg = cfg.vision_config
        self.vlm_image_size = getattr(vis_cfg, 'image_size', 448)
        # Stored for reference; not read anywhere else in this class.
        self.vlm_patch_size = getattr(vis_cfg, 'patch_size', 14)
        self.downsample_ratio = getattr(cfg, 'downsample_ratio', 0.5)
        print(f"  VLM hidden dim: {self.vlm_hidden_dim}, image_size={self.vlm_image_size}")

        # Project VLM hidden → DINO embed dim so the rest of the architecture is unchanged
        self.vlm_proj = nn.Linear(self.vlm_hidden_dim, self.embed_dim)

    def train(self, mode=True):
        """Switch train/eval mode, keeping a fully frozen VLM in eval mode.

        Without this, calling ``.train()`` on the model would undo the
        ``self.vlm.eval()`` done in ``__init__`` when ``freeze_vlm=True``.
        """
        super().train(mode)
        # getattr guards against train() being called before __init__ has finished.
        if getattr(self, '_freeze_vlm', False):
            self.vlm.eval()
        return self

    def _find_attr(self, *names):
        """Return the first attribute of ``self.vlm`` found among ``names``.

        Raises:
            AttributeError: if the VLM exposes none of the given names.
        """
        for n in names:
            if hasattr(self.vlm, n):
                return getattr(self.vlm, n)
        raise AttributeError(f"VLM has none of {names}")

    def _renormalize_image(self, x):
        """ImageNet [B,3,H,W] → [B,3,vlm_size,vlm_size] (both use ImageNet stats; just resize)."""
        target = (self.vlm_image_size, self.vlm_image_size)
        if tuple(x.shape[-2:]) == target:
            return x
        return F.interpolate(x, size=target, mode='bilinear', align_corners=False)

    def _get_image_tokens(self, pixel_values):
        """Vision encoder + projector → image tokens for LLM."""
        if hasattr(self.vlm, 'extract_feature'):
            return self.vlm.extract_feature(pixel_values)
        # Fallback when the VLM has no extract_feature(): raw encoder output → projector.
        # NOTE(review): this path applies no extra token handling between encoder and
        # projector — confirm it matches what the projector expects for the loaded model.
        vis_out = self._vision_enc(pixel_values)
        feat = vis_out.last_hidden_state if hasattr(vis_out, 'last_hidden_state') else vis_out[0]
        return self._projector(feat)

    def _extract_vlm_features(self, x, task_texts=None):
        """Run vision + LLM forward, return projected image hidden states + pooled vector.

        Args:
            x: (B, 3, H, W) ImageNet-normalized.
            task_texts: list of B strings, a single string applied to the whole
                batch, or None to use the default task text.

        Returns:
            patch: (B, embed_dim, grid, grid) — last-layer hidden states at the image
                positions, projected by ``vlm_proj``. If the number of image tokens is
                not a perfect square, the sequence is zero-padded up to ``grid * grid``
                before projection.
            cls:   (B, embed_dim) — projected mean over the real (unpadded) image tokens.

        Raises:
            ValueError: if ``task_texts`` does not contain exactly B entries.
        """
        B = x.shape[0]
        device = x.device
        if task_texts is None:
            task_texts = [self._task_text_default] * B
        elif isinstance(task_texts, str):
            task_texts = [task_texts] * B
        else:
            task_texts = list(task_texts)
        if len(task_texts) != B:
            raise ValueError(
                f"task_texts has {len(task_texts)} entries but the batch size is {B}"
            )

        x_vlm = self._renormalize_image(x).to(self.vlm.dtype)
        image_tokens = self._get_image_tokens(x_vlm)                # (B, N_img, D_lm)
        N_img = image_tokens.shape[1]

        text_inputs = self._tokenizer(
            task_texts, return_tensors="pt", padding=True,
            truncation=True, max_length=64,
        ).to(device)
        text_embeds = self._lm.get_input_embeddings()(text_inputs['input_ids']).to(image_tokens.dtype)

        # NOTE(review): image tokens come BEFORE the text in the sequence. If the
        # language model uses a causal attention mask (typical for decoder-only LMs —
        # confirm for the loaded checkpoint), the hidden states read back at the image
        # positions cannot attend to the task text and so would not be text-conditioned.
        # Left unchanged here because reordering alters the features existing
        # checkpoints were trained on.
        inputs_embeds = torch.cat([image_tokens, text_embeds], dim=1)
        attn_mask = torch.cat([
            torch.ones(B, N_img, device=device, dtype=torch.long),
            text_inputs['attention_mask'],
        ], dim=1)
        lm_out = self._lm(
            inputs_embeds=inputs_embeds,
            attention_mask=attn_mask,
            output_hidden_states=True,
        )
        hidden = lm_out.hidden_states[-1]                            # (B, N_img + T_text, D_lm)
        image_hidden = hidden[:, :N_img, :].float()                  # (B, N_img, D_lm)

        # CLS: mean of image features as a coarse pooling (no explicit CLS token is used).
        # Pooled BEFORE any padding so that zero pad tokens do not dilute the mean.
        cls = self.vlm_proj(image_hidden.mean(dim=1))                # (B, embed)

        grid = int(math.sqrt(N_img))
        if grid * grid != N_img:
            # Not a perfect square: zero-pad the token axis up to the next square.
            grid = int(math.ceil(math.sqrt(N_img)))
            pad = grid * grid - N_img
            image_hidden = F.pad(image_hidden, (0, 0, 0, pad))       # (B, grid*grid, D_lm)

        # Project to embed_dim so downstream pipeline is unchanged, then lay out as a map.
        patch = self.vlm_proj(image_hidden)                          # (B, grid*grid, embed)
        patch = patch.reshape(B, grid, grid, -1).permute(0, 3, 1, 2).contiguous()  # (B, embed, grid, grid)
        return patch, cls

    def _extract_dino_features(self, x, task_texts=None):
        """Override: return VLM patches in DINO-feature format.

        ``task_texts`` is optional and defaults to the configured default prompt,
        so existing single-argument callers keep working.
        """
        return self._extract_vlm_features(x, task_texts=task_texts)

    # forward inherits from DinoVolumeQuery — uses _extract_dino_features which we just overrode


if __name__ == "__main__":
    # Smoke test: build the model, run a single forward pass on random inputs,
    # and print the shape of every non-None output.
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    model = DinoVolumeQueryVLM(n_window=8, image_size=448, rotation_mode='1d_pca')
    m = model.to(device).eval()

    trainable_params = [p for p in m.parameters() if p.requires_grad]
    n_t = sum(p.numel() for p in trainable_params)
    print(f"Trainable: {n_t:,}")

    # Random image batch and a second (2, 2) tensor scaled by 448
    # (presumably pixel coordinates — see DinoVolumeQuery.forward).
    rgb = torch.rand(2, 3, 448, 448).to(device)
    sp = torch.rand(2, 2).to(device) * 448

    with torch.no_grad():
        out = m(rgb, sp)

    for key, value in out.items():
        if value is None:
            continue
        print(f"  {key}: {tuple(value.shape)}")
