"""InternVL VLA baseline — action-token decoding via the LLM (pi0-style).

Same InternVL2.5-1B backbone as model_vla_internvl.py, but instead of PARA
pixel-aligned heatmaps, appends learnable action query tokens + a proprioception
token to the LLM input sequence:

    [image_patches, text_tokens, proprio_token, action_0, ..., action_{N-1}]

The LLM's causal self-attention lets each action token attend to all image
patches, the task description, proprioception, and all *preceding* action
tokens — giving autoregressive action prediction for free.  The hidden states
at the action token positions are then decoded by lightweight per-timestep
heads into normalised [0,1] position / rotation / gripper values (MSE loss).

This is the standard modern VLA formulation (pi0, RT-2, Octo) and serves as
the non-PARA baseline for comparing action representations on the same backbone.

Usage:
    python train.py --model_type internvl_act --task_ids all --cache_root /data/libero/parsed_libero
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

# ImageNet channel statistics. `_renormalize_image` uses the inet_* buffers to
# undo the dataloader's ImageNet normalisation and the vlm_* buffers to apply
# the VLM's expected statistics (currently identical, kept separate on purpose).
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD  = (0.229, 0.224, 0.225)

# Number of future action timesteps predicted per forward pass (one learnable
# action query token each).
N_WINDOW = 6
PROPRIO_DIM = 6  # start_kp_norm(2) + eef_pos(3) + gripper(1)


class InternVLACTPredictor(nn.Module):
    """InternVL2.5-1B + action-token VLA (pi0-style).

    Builds the LLM input sequence as
    ``[image_patches, text_tokens, proprio_token, action_0 .. action_{N-1}]``.
    Causal self-attention lets each action token attend to the image patches,
    the task text, proprioception, and all preceding action tokens.  The hidden
    states at the action-token positions are decoded by shared per-token
    sigmoid heads into normalised [0,1] position / rotation / gripper values.
    """

    def __init__(self, target_size=448, n_window=N_WINDOW,
                 freeze_backbone=False, model_name="OpenGVLab/InternVL2_5-1B", **kwargs):
        """Load the InternVL backbone and build the action/proprio modules.

        Args:
            target_size:     image side length used to normalise pixel
                             keypoints to [0,1] in ``forward``.
            n_window:        number of predicted action timesteps.
            freeze_backbone: if True, freeze all VLM parameters and keep the
                             VLM permanently in eval mode (enforced by the
                             ``train`` override below).
            model_name:      HuggingFace model id; loaded with
                             ``trust_remote_code=True``.
            **kwargs:        ignored; accepted for config compatibility with
                             sibling model classes.
        """
        super().__init__()
        self.target_size = target_size
        self.n_window    = n_window
        self.model_type  = "internvl_act"
        # Remembered so train() can keep a frozen backbone in eval mode even
        # after the training loop calls model.train().
        self._freeze_backbone = bool(freeze_backbone)

        # --- Load VLM ---------------------------------------------------------
        print(f"Loading InternVL model: {model_name}")
        from transformers import AutoModel, AutoTokenizer

        self.vlm = AutoModel.from_pretrained(
            model_name, torch_dtype=torch.bfloat16, trust_remote_code=True,
        )
        self._tokenizer = AutoTokenizer.from_pretrained(
            model_name, trust_remote_code=True,
        )

        if freeze_backbone:
            for p in self.vlm.parameters():
                p.requires_grad = False
            self.vlm.eval()
            print("  Frozen InternVL backbone")
        else:
            print("  InternVL backbone is trainable")

        # --- Detect VLM components --------------------------------------------
        # Attribute names differ between InternVL remote-code versions and
        # generic HF vision-language wrappers, so probe a few candidates.
        self._vision_enc = self._find_attr('vision_model', 'visual')
        self._projector  = self._find_attr('mlp1', 'multi_modal_projector')
        self._lm         = self._find_attr('language_model')

        # --- Config -----------------------------------------------------------
        cfg = self.vlm.config
        llm_cfg = getattr(cfg, 'llm_config', getattr(cfg, 'text_config', None))
        self.vlm_hidden_dim = D = llm_cfg.hidden_size
        vis_cfg = cfg.vision_config
        self.vlm_image_size = getattr(vis_cfg, 'image_size', 448)
        self.downsample_ratio = getattr(cfg, 'downsample_ratio', 0.5)
        print(f"  VLM hidden dim: {D}")
        print(f"  Vision: {self.vlm_image_size}px, downsample {self.downsample_ratio}")

        # Image re-normalization buffers.  inet_* describe the incoming
        # normalisation, vlm_* the backbone's expected one; kept as separate
        # buffers even though the constants currently coincide.
        self.register_buffer('vlm_mean', torch.tensor(IMAGENET_MEAN, dtype=torch.float32).view(3, 1, 1))
        self.register_buffer('vlm_std',  torch.tensor(IMAGENET_STD,  dtype=torch.float32).view(3, 1, 1))
        self.register_buffer('inet_mean', torch.tensor(IMAGENET_MEAN, dtype=torch.float32).view(3, 1, 1))
        self.register_buffer('inet_std',  torch.tensor(IMAGENET_STD,  dtype=torch.float32).view(3, 1, 1))

        # --- Action tokens + proprioception -----------------------------------
        # Learnable action query tokens — one per predicted timestep.
        # Injected at the end of the sequence so causal attention lets them
        # attend to image + text + proprio + preceding action tokens.
        self.action_tokens = nn.Parameter(torch.randn(n_window, D) * 0.02)

        # Proprioception projection: low-dim state → LLM embedding space
        self.proprio_proj = nn.Sequential(
            nn.Linear(PROPRIO_DIM, D),
            nn.GELU(),
            nn.Linear(D, D),
        )

        # --- Action decoding heads --------------------------------------------
        # Shared across timesteps: each action token's hidden state → (pos, rot, grip).
        # Sigmoid outputs match the normalised [0,1] action targets (MSE loss).
        self.pos_head = nn.Sequential(
            nn.LayerNorm(D),
            nn.Linear(D, D),
            nn.GELU(),
            nn.Linear(D, 3),
            nn.Sigmoid(),
        )
        self.rot_head = nn.Sequential(
            nn.LayerNorm(D),
            nn.Linear(D, D),
            nn.GELU(),
            nn.Linear(D, 3),
            nn.Sigmoid(),
        )
        self.gripper_head = nn.Sequential(
            nn.LayerNorm(D),
            nn.Linear(D, D),
            nn.GELU(),
            nn.Linear(D, 1),
            nn.Sigmoid(),
        )

        print(f"  Action tokens: {n_window} x {D}")
        print(f"  Proprio proj:  {PROPRIO_DIM} -> {D}")
        print(f"  Sequence: [img_patches, text, proprio, act_0..act_{n_window-1}]")
        print(f"  pos_head:     per-token -> (3,) [sigmoid]")
        print(f"  rot_head:     per-token -> (3,) [sigmoid]")
        print(f"  gripper_head: per-token -> (1,) [sigmoid]")

    # ------------------------------------------------------------------
    def train(self, mode=True):
        """Switch train/eval mode, keeping a frozen backbone in eval mode.

        Fixes a standard freezing pitfall: without this override, the training
        loop's ``model.train()`` would recursively flip ``self.vlm`` back to
        train mode, re-enabling its dropout / stochastic layers despite the
        backbone being frozen in ``__init__``.
        """
        super().train(mode)
        if self._freeze_backbone:
            self.vlm.eval()
        return self

    def _find_attr(self, *names):
        """Return the first of *names* present on ``self.vlm``.

        Raises:
            AttributeError: if none of the candidate names exists, listing the
                VLM's public attributes to aid debugging new model versions.
        """
        for n in names:
            if hasattr(self.vlm, n):
                return getattr(self.vlm, n)
        raise AttributeError(
            f"Cannot find any of {names} on VLM ({type(self.vlm).__name__}). "
            f"Available: {[a for a in dir(self.vlm) if not a.startswith('_')]}"
        )

    def _renormalize_image(self, x):
        """Convert an ImageNet-normalised batch to the VLM's input statistics.

        Undoes the inet_* normalisation, re-applies vlm_* (currently a no-op
        since the constants coincide, but kept for backbones with different
        stats), and resizes to the VLM's native resolution if needed.
        """
        x_raw = x * self.inet_std + self.inet_mean
        x_vlm = (x_raw - self.vlm_mean) / self.vlm_std
        if x_vlm.shape[-1] != self.vlm_image_size or x_vlm.shape[-2] != self.vlm_image_size:
            x_vlm = F.interpolate(x_vlm, size=(self.vlm_image_size, self.vlm_image_size),
                                  mode='bilinear', align_corners=False)
        return x_vlm

    def _get_image_tokens(self, pixel_values):
        """Encode pixels into projected image tokens, shape (B, N_img, D).

        Prefers the model's own ``extract_feature`` (InternVL remote code,
        which includes its pixel-shuffle downsample); otherwise falls back to
        vision encoder + projector. NOTE(review): the fallback assumes the
        projector accepts raw ViT features — verify for non-InternVL backbones.
        """
        if hasattr(self.vlm, 'extract_feature'):
            return self.vlm.extract_feature(pixel_values)
        vis_out = self._vision_enc(pixel_values)
        vit_features = vis_out.last_hidden_state if hasattr(vis_out, 'last_hidden_state') else vis_out[0]
        return self._projector(vit_features)

    # ------------------------------------------------------------------
    def forward(self, x, start_keypoint_2d, current_eef_pos=None, current_gripper=None,
                query_pixels=None, task_text=None, **kwargs):
        """Predict an action chunk from image + text + proprioception.

        Args:
            x:                 (B, 3, H, W) ImageNet-normalized
            start_keypoint_2d: (B, 2) or (2,) current EEF pixel
            current_eef_pos:   (B, 3) normalized [0,1]
            current_gripper:   (B, 1) or (B,) normalized [0,1]
            query_pixels:      unused; kept for signature compatibility
            task_text:         list of B strings

        Returns:
            pos_pred:     (B, N_WINDOW, 3)  [0,1]
            rot_pred:     (B, N_WINDOW, 3)  [0,1]
            gripper_pred: (B, N_WINDOW)     [0,1]
        """
        B = x.shape[0]
        device = x.device

        if task_text is None:
            task_text = ["Pick up the object and place it on the target."] * B

        # --- Prepare proprioception -------------------------------------------
        # Missing state components default to zeros; 1-D inputs are broadcast
        # to the batch so a single shared value can be passed.
        if start_keypoint_2d.dim() == 1:
            start_keypoint_2d = start_keypoint_2d.unsqueeze(0).expand(B, -1)
        start_kp_norm = start_keypoint_2d / self.target_size

        if current_eef_pos is None:
            current_eef_pos = torch.zeros(B, 3, device=device)
        if current_gripper is None:
            current_gripper = torch.zeros(B, 1, device=device)
        if current_gripper.dim() == 1:
            current_gripper = current_gripper.unsqueeze(-1)
        if current_eef_pos.dim() == 1:
            current_eef_pos = current_eef_pos.unsqueeze(0).expand(B, -1)

        proprio = torch.cat([start_kp_norm, current_eef_pos, current_gripper], dim=-1)  # (B, 6)

        # --- Image tokens -----------------------------------------------------
        pixel_values = self._renormalize_image(x)
        image_tokens = self._get_image_tokens(pixel_values.to(self.vlm.dtype))  # (B, N_img, D)
        N_img = image_tokens.shape[1]

        # --- Text tokens ------------------------------------------------------
        text_inputs = self._tokenizer(
            task_text, return_tensors="pt", padding=True,
            truncation=True, max_length=64,
        ).to(device)
        text_embeds = self._lm.get_input_embeddings()(text_inputs['input_ids'])
        text_embeds = text_embeds.to(image_tokens.dtype)
        T_text = text_embeds.shape[1]

        # --- Proprioception token ---------------------------------------------
        # Projection runs in fp32 (module dtype), then casts to the VLM dtype.
        proprio_token = self.proprio_proj(proprio.float()).to(image_tokens.dtype).unsqueeze(1)  # (B, 1, D)

        # --- Action query tokens ----------------------------------------------
        action_queries = self.action_tokens.unsqueeze(0).expand(B, -1, -1)  # (B, N_WINDOW, D)
        action_queries = action_queries.to(image_tokens.dtype)

        # --- Assemble sequence ------------------------------------------------
        # [image_patches | text | proprio | action_0 ... action_{N-1}]
        # Causal attention: action_t sees image + text + proprio + action_0..t-1
        inputs_embeds = torch.cat([
            image_tokens, text_embeds, proprio_token, action_queries,
        ], dim=1)

        # Padded text positions are masked out; image, proprio and action
        # positions are always attended.
        attn_mask = torch.cat([
            torch.ones(B, N_img, device=device, dtype=torch.long),
            text_inputs['attention_mask'],
            torch.ones(B, 1 + self.n_window, device=device, dtype=torch.long),  # proprio + actions
        ], dim=1)

        # --- LLM forward (causal self-attention) ------------------------------
        lm_out = self._lm(
            inputs_embeds=inputs_embeds,
            attention_mask=attn_mask,
            output_hidden_states=True,
        )

        # --- Extract action token hidden states -------------------------------
        hidden = lm_out.hidden_states[-1]  # (B, total_seq_len, D)
        action_start = N_img + T_text + 1  # after image + text + proprio
        action_hidden = hidden[:, action_start:action_start + self.n_window, :].float()
        # action_hidden: (B, N_WINDOW, D)

        # --- Decode actions ---------------------------------------------------
        pos_pred     = self.pos_head(action_hidden)                   # (B, N_WINDOW, 3)
        rot_pred     = self.rot_head(action_hidden)                   # (B, N_WINDOW, 3)
        gripper_pred = self.gripper_head(action_hidden).squeeze(-1)   # (B, N_WINDOW)

        return pos_pred, rot_pred, gripper_pred


# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Smoke test: build the model on the best available device, run a single
    # forward pass on a random two-sample batch, and report output ranges.
    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    predictor = InternVLACTPredictor(target_size=448, n_window=N_WINDOW, freeze_backbone=True)
    predictor = predictor.to(dev)

    images    = torch.randn(2, 3, 448, 448).to(dev)
    keypoint  = torch.tensor([224.0, 224.0]).to(dev)
    eef_state = torch.randn(2, 3).to(dev)
    grip_in   = torch.randn(2, 1).to(dev)
    prompts   = ["pick up the bowl", "place it on the plate"]

    with torch.no_grad():
        pos, rot, grip_out = predictor(
            images, keypoint,
            current_eef_pos=eef_state, current_gripper=grip_in, task_text=prompts,
        )

    print("pos  ", pos.shape, f"range=[{pos.min():.3f}, {pos.max():.3f}]")
    print("rot  ", rot.shape, f"range=[{rot.min():.3f}, {rot.max():.3f}]")
    print("grip ", grip_out.shape, f"range=[{grip_out.min():.3f}, {grip_out.max():.3f}]")
