"""ACT / vanilla regression baseline: image + current robot state -> N_WINDOW future (3D + gripper) in global robot frame.

Conditioned on current 3D position (3) and gripper state (1) concatenated to CLS before regression.
No heatmaps; direct regression of trajectory_3d (N_WINDOW, 3) and gripper (N_WINDOW,).
"""
import torch
import torch.nn as nn
import torch.nn.functional as F

DINO_REPO_DIR = "dinov3"
DINO_WEIGHTS_PATH = "dinov3/weights/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth"
DINO_PATCH_SIZE = 16
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

N_WINDOW = 12
MIN_HEIGHT = 0.043347
MAX_HEIGHT = 0.043347
MIN_GRIPPER = -0.2
MAX_GRIPPER = 0.8
# No bins for ACT; kept for checkpoint/API compatibility
N_HEIGHT_BINS = 32
N_GRIPPER_BINS = 32


CURRENT_STATE_DIM = 4  # current_3d (3) + current_gripper (1)


class ACTTrajectoryPredictor(nn.Module):
    """Vanilla regression: image + current (3d, gripper) -> (N_WINDOW, 3) trajectory_3d + (N_WINDOW,) gripper in global robot frame."""

    def __init__(self, target_size=448, n_window=N_WINDOW, freeze_backbone=False, hidden_dim=512):
        """Build the DINOv3 backbone and the MLP regression head.

        Args:
            target_size: expected square input resolution (stored; not enforced here).
            n_window: number of future timesteps predicted per forward pass.
            freeze_backbone: if True, freeze all backbone parameters and set it to eval().
            hidden_dim: width of the two hidden layers of the regression head.
        """
        super().__init__()
        self.target_size = target_size
        self.n_window = n_window
        self.patch_size = DINO_PATCH_SIZE

        # Fixed: messages previously said "DINOv2" although the entry point and
        # weights loaded here are DINOv3 (dinov3_vits16plus).
        print("Loading DINOv3 model...")
        self.dino = torch.hub.load(
            DINO_REPO_DIR,
            'dinov3_vits16plus',
            source='local',
            weights=DINO_WEIGHTS_PATH
        )
        if freeze_backbone:
            for param in self.dino.parameters():
                param.requires_grad = False
            # NOTE(review): a later model.train() call flips the backbone back to
            # train mode (dropout/norm behavior); re-call self.dino.eval() if that matters.
            self.dino.eval()
            print("✓ Frozen DINOv3 backbone")
        else:
            print("✓ DINOv3 backbone is trainable")

        self.embed_dim = self.dino.embed_dim
        # MLP: [CLS, current_3d (3), current_gripper (1)] -> hidden -> (n_window * 3 + n_window)
        self.mlp = nn.Sequential(
            nn.Linear(self.embed_dim + CURRENT_STATE_DIM, hidden_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, n_window * 3 + n_window),
        )
        print(f"✓ ACT head: [CLS, current_3d, current_gripper] -> MLP -> (B, {n_window}*3 + {n_window})")

    def to(self, device):
        """Move the model to `device`. nn.Module.to already moves all registered
        submodules (including self.dino) in place, so no extra handling is needed;
        the override is kept only for backward API compatibility."""
        return super().to(device)

    def _extract_cls(self, x):
        """Run the DINOv3 backbone on x (B, 3, H, W) and return the CLS token (B, D)."""
        x_tokens, (H_p, W_p) = self.dino.prepare_tokens_with_masks(x)
        # The rotary embedding depends only on the patch grid, not on the block
        # index, so compute it once instead of once per transformer block.
        rope_sincos = self.dino.rope_embed(H=H_p, W=W_p) if self.dino.rope_embed else None
        for blk in self.dino.blocks:
            x_tokens = blk(x_tokens, rope_sincos)
        if self.dino.untie_cls_and_patch_norms:
            # Backbone variant with separate norms for CLS/storage tokens and patches.
            x_norm_cls = self.dino.cls_norm(x_tokens[:, : self.dino.n_storage_tokens + 1])
            x_norm_patches = self.dino.norm(x_tokens[:, self.dino.n_storage_tokens + 1 :])
            x_tokens = torch.cat([x_norm_cls, x_norm_patches], dim=1)
        else:
            x_tokens = self.dino.norm(x_tokens)
        cls_token = x_tokens[:, 0]  # (B, D)
        return cls_token

    def forward(self, x, gt_target_heatmap=None, training=False, start_keypoint_2d=None, current_height=None, current_gripper=None,
                current_3d=None, current_gripper_state=None):
        """
        Args:
            x: (B, 3, H, W)
            current_3d: (B, 3) current gripper keypoint 3D position in world frame. If None, zeros.
            current_gripper_state: (B,) or (B, 1) current gripper value. If None, zeros.
            gt_target_heatmap, training, start_keypoint_2d, current_height, current_gripper:
                ignored (API compatibility with the heatmap-based predictors).

        Returns:
            trajectory_3d: (B, N_WINDOW, 3) in global robot frame
            gripper: (B, N_WINDOW) gripper joint value per timestep
        """
        cls = self._extract_cls(x)  # (B, D)
        B = cls.shape[0]
        device = cls.device
        # Missing conditioning defaults to zeros so the model can still be run
        # without robot-state inputs (e.g. for shape checks).
        if current_3d is None:
            current_3d = torch.zeros(B, 3, device=device, dtype=cls.dtype)
        if current_gripper_state is None:
            current_gripper_state = torch.zeros(B, device=device, dtype=cls.dtype)
        if current_gripper_state.dim() == 1:
            current_gripper_state = current_gripper_state.unsqueeze(1)  # (B, 1)
        cond = torch.cat([cls, current_3d, current_gripper_state], dim=1)  # (B, D+4)
        out = self.mlp(cond)
        # First n_window*3 outputs are the 3D trajectory; the rest are gripper values.
        trajectory_3d = out[:, : self.n_window * 3].view(B, self.n_window, 3)
        gripper = out[:, self.n_window * 3 :].view(B, self.n_window)
        return trajectory_3d, gripper


if __name__ == "__main__":
    # Smoke test: build the predictor, push a tiny random batch through it,
    # and print the output shapes.
    dev = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    predictor = ACTTrajectoryPredictor(
        target_size=448, n_window=N_WINDOW, freeze_backbone=True
    ).to(dev)

    images = torch.randn(2, 3, 448, 448).to(dev)
    state_3d = torch.randn(2, 3).to(dev) * 0.1
    state_grip = torch.rand(2).to(dev)

    with torch.no_grad():
        pred_traj, pred_grip = predictor(
            images, current_3d=state_3d, current_gripper_state=state_grip
        )
    print("trajectory_3d", pred_traj.shape)
    print("gripper", pred_grip.shape)
