"""ACT baseline — direct regression from DINO CLS token + proprioception.

Same DINOv3 ViT-S+/16 backbone (`dinov3_vits16plus`) as PARA, but instead of
pixel-aligned heatmaps, the CLS token is concatenated with the current EEF pixel
keypoint, proprioceptive state (current EEF position + gripper) and an optional
projected CLIP task embedding, then passed through three MLP heads:
  - 3D EEF position (N_WINDOW × 3)  — sigmoid output; target normalized to [0,1] via dataset min/max
  - Euler rotation  (N_WINDOW × 3)  — sigmoid output; target normalized to [0,1] via dataset min/max
  - Gripper value   (N_WINDOW)      — raw logits (no sigmoid), intended for BCE-with-logits;
                                      target normalized to [0,1] via dataset min/max

Because the position and rotation outputs are sigmoid → [0,1] and their targets
are min/max normalized → [0,1], their MSE losses are on a comparable scale
without manual weight tuning.
"""

import os
import torch
import torch.nn as nn

# Local DINOv3 checkout and pretrained weights; each can be overridden through
# the environment variable of the same name.
DINO_REPO_DIR = os.getenv(
    "DINO_REPO_DIR",
    "/Users/cameronsmith/Projects/robotics_testing/random/dinov3",
)
DINO_WEIGHTS_PATH = os.getenv(
    "DINO_WEIGHTS_PATH",
    "/Users/cameronsmith/Projects/robotics_testing/random/dinov3/weights/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth",
)
DINO_PATCH_SIZE = 16  # patch edge (pixels) of the /16 backbone

# Default prediction window: each head emits N_WINDOW entries per sample.
N_WINDOW = 4


class ACTPredictor(nn.Module):
    """Direct regression baseline: DINO CLS token + proprioception → MLP heads.

    The position and rotation heads end in a sigmoid and predict normalized
    [0, 1] targets. The gripper head emits raw logits (no sigmoid) so it can be
    trained with a BCE-with-logits loss.

    Args:
        target_size:     Divisor used to normalize ``start_keypoint_2d`` pixel
                         coordinates in :meth:`forward`.
        n_window:        Number of entries each head predicts per sample.
        freeze_backbone: If True, DINO parameters get ``requires_grad=False``
                         and the backbone is kept in eval mode, even when
                         ``train()`` is called on this module.
        **kwargs:        ``clip_dim`` (default 512) sets the width of the CLIP
                         embedding accepted by ``forward``; other keys are ignored.
    """

    def __init__(self, target_size=448, n_window=N_WINDOW, freeze_backbone=False, **kwargs):
        super().__init__()
        self.target_size = target_size
        self.n_window = n_window
        self.model_type = "act"
        # Read by train(); set before anything can switch modes.
        self.freeze_backbone = freeze_backbone

        print("Loading DINOv3 model...")
        self.dino = torch.hub.load(
            DINO_REPO_DIR,
            'dinov3_vits16plus',
            source='local',
            weights=DINO_WEIGHTS_PATH,
        )
        if freeze_backbone:
            for param in self.dino.parameters():
                param.requires_grad = False
            self.dino.eval()
            print("✓ Frozen DINOv3 backbone")
        else:
            print("✓ DINOv3 backbone is trainable")

        self.embed_dim = self.dino.embed_dim
        D = self.embed_dim

        # CLIP task conditioning (for multi-task). The projection is always
        # built; forward() substitutes zeros when no embedding is supplied.
        clip_dim = kwargs.get('clip_dim', 512)
        self.clip_proj = nn.Sequential(
            nn.Linear(clip_dim, D),
            nn.GELU(),
            nn.Linear(D, D),
        )
        print(f"✓ CLIP projection: {clip_dim} → {D}")

        # Head input: CLS (D) + start_keypoint_norm (2) + current_eef_pos (3)
        # + current_gripper (1) + projected CLIP embedding (D).
        inp_dim = D + 2 + 3 + 1 + D

        # Construction order (clip_proj, pos, rot, gripper) fixes the RNG draws
        # used for weight initialization; keep it stable for reproducibility.
        self.pos_mlp = self._make_head(inp_dim, D, n_window * 3, sigmoid=True)
        self.rot_mlp = self._make_head(inp_dim, D, n_window * 3, sigmoid=True)
        # No sigmoid — outputs raw logits for BCE with logits
        self.gripper_mlp = self._make_head(inp_dim, D, n_window, sigmoid=False)

        n_total = sum(p.numel() for p in self.parameters())
        n_trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"✓ ACT model: {n_trainable:,} / {n_total:,} trainable params")
        print(f"  Input: CLS({D}) + start_kp(2) + eef_pos(3) + gripper(1) + clip({D}) = {inp_dim}")
        print(f"  pos_mlp:     → (B, {n_window}, 3) [sigmoid, normalized]")
        print(f"  rot_mlp:     → (B, {n_window}, 3) [sigmoid, normalized]")
        print(f"  gripper_mlp: → (B, {n_window})    [raw logits]")

    @staticmethod
    def _make_head(inp_dim, hidden_dim, out_dim, sigmoid):
        """Build LayerNorm → Linear → GELU → Linear → GELU → Linear [→ Sigmoid].

        The layer order matches the original inline definitions, so
        ``nn.Sequential`` indices (and therefore state_dict keys) are unchanged.
        """
        layers = [
            nn.LayerNorm(inp_dim),
            nn.Linear(inp_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, out_dim),
        ]
        if sigmoid:
            layers.append(nn.Sigmoid())
        return nn.Sequential(*layers)

    def train(self, mode=True):
        """Set train/eval mode, keeping a frozen backbone in eval mode.

        ``nn.Module.train`` recursively switches every submodule. Without this
        override, a frozen DINO backbone would be put back into training mode
        by a plain ``model.train()`` call. ``eval()`` routes through here too.
        """
        super().train(mode)
        # getattr: tolerate instances unpickled from before this attribute existed.
        if getattr(self, 'freeze_backbone', False):
            self.dino.eval()
        return self

    def _extract_cls(self, x):
        """Run the DINO backbone on ``x`` and return the normalized CLS token (B, D)."""
        dino = self.dino
        x_tokens, (H_p, W_p) = dino.prepare_tokens_with_masks(x)
        for blk in dino.blocks:
            # RoPE is recomputed for every block, as in the original loop.
            # NOTE(review): rope_embed may be stochastic in train mode — confirm
            # before hoisting this out of the loop.
            rope_sincos = dino.rope_embed(H=H_p, W=W_p) if dino.rope_embed is not None else None
            x_tokens = blk(x_tokens, rope_sincos)
        # Only the CLS token plus the storage tokens are normalized; index 0 is CLS.
        norm = dino.cls_norm if dino.untie_cls_and_patch_norms else dino.norm
        x_norm_cls = norm(x_tokens[:, :dino.n_storage_tokens + 1])
        return x_norm_cls[:, 0]  # (B, D)

    def forward(self, x, start_keypoint_2d, current_eef_pos=None, current_gripper=None,
                query_pixels=None, clip_embedding=None):
        """Predict a window of EEF positions, rotations and gripper logits.

        Args:
            x:                 (B, 3, H, W) image batch
            start_keypoint_2d: (B, 2) or (2,) current EEF pixel in image coords;
                               divided by ``target_size``. A 1-D input is shared
                               across the batch.
            current_eef_pos:   (B, 3) or (3,) current EEF 3D position (normalized to
                               [0,1]); zeros if None. A 1-D input is shared.
            current_gripper:   (B, 1) or (B,) current gripper state (normalized to
                               [0,1]); zeros if None
            query_pixels:      ignored (kept for interface compatibility)
            clip_embedding:    (B, clip_dim) or (clip_dim,) precomputed CLIP text
                               embedding for the task; a 1-D input is shared across
                               the batch. If None, a zero vector replaces the projection.

        Returns:
            pos_pred:     (B, N_WINDOW, 3)  normalized [0,1] position
            rot_pred:     (B, N_WINDOW, 3)  normalized [0,1] rotation
            gripper_pred: (B, N_WINDOW)     raw logits (apply sigmoid for [0,1])
        """
        B = x.shape[0]
        device = x.device
        cls_token = self._extract_cls(x)  # (B, D)
        # Fallback zero tensors follow the backbone output dtype (e.g. half precision).
        dtype = cls_token.dtype

        if start_keypoint_2d.dim() == 1:
            start_keypoint_2d = start_keypoint_2d.unsqueeze(0).expand(B, -1)
        start_kp_norm = start_keypoint_2d / self.target_size

        # Proprioception: current EEF position + gripper (already normalized to [0,1])
        if current_eef_pos is None:
            current_eef_pos = torch.zeros(B, 3, device=device, dtype=dtype)
        elif current_eef_pos.dim() == 1:
            current_eef_pos = current_eef_pos.unsqueeze(0).expand(B, -1)

        if current_gripper is None:
            current_gripper = torch.zeros(B, 1, device=device, dtype=dtype)
        elif current_gripper.dim() == 1:
            current_gripper = current_gripper.unsqueeze(-1)

        # CLIP task conditioning
        if clip_embedding is None:
            clip_proj = torch.zeros(B, self.embed_dim, device=device, dtype=dtype)
        else:
            if clip_embedding.dim() == 1:
                clip_embedding = clip_embedding.unsqueeze(0).expand(B, -1)
            clip_proj = self.clip_proj(clip_embedding)  # (B, D)

        inp = torch.cat([cls_token, start_kp_norm, current_eef_pos, current_gripper, clip_proj], dim=-1)

        pos_pred     = self.pos_mlp(inp).reshape(B, self.n_window, 3)
        rot_pred     = self.rot_mlp(inp).reshape(B, self.n_window, 3)
        gripper_pred = self.gripper_mlp(inp).reshape(B, self.n_window)

        return pos_pred, rot_pred, gripper_pred


if __name__ == "__main__":
    # Smoke test: build the model and report output shapes/ranges on random inputs.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ACTPredictor(target_size=448, n_window=N_WINDOW)
    model = model.to(device)
    model.eval()  # put every submodule in eval mode for this inference check
    x = torch.randn(2, 3, 448, 448, device=device)
    kp = torch.tensor([224.0, 224.0], device=device)  # 1-D: shared across the batch
    # Proprioception is documented as normalized to [0, 1], so sample uniformly there.
    eef = torch.rand(2, 3, device=device)
    grip = torch.rand(2, 1, device=device)
    with torch.no_grad():
        pos, rot, grip_logits = model(x, kp, current_eef_pos=eef, current_gripper=grip)
    # The gripper head returns raw logits; sigmoid maps them to [0, 1].
    grip_prob = torch.sigmoid(grip_logits)
    print("pos  ", pos.shape, f"range=[{pos.min():.3f}, {pos.max():.3f}]")
    print("rot  ", rot.shape, f"range=[{rot.min():.3f}, {rot.max():.3f}]")
    print("grip ", grip_logits.shape,
          f"logit range=[{grip_logits.min():.3f}, {grip_logits.max():.3f}]",
          f"prob range=[{grip_prob.min():.3f}, {grip_prob.max():.3f}]")
