"""Motion tracks baseline: image + current (2d, height, gripper) -> 2D location + height + gripper per timestep (camera frame).

Conditioned on current 2D (2), height (1), and gripper (1) concatenated to CLS.
Factorized as 2d (N_WINDOW, 2) + height (N_WINDOW,) + gripper (N_WINDOW,). Same lifting as volume for eval/live.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F


# --- DINOv3 backbone loading (torch.hub with source='local') ---
DINO_REPO_DIR = "dinov3"  # local hub repo directory passed to torch.hub.load
DINO_WEIGHTS_PATH = "dinov3/weights/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth"  # pretrained ViT-S/16+ checkpoint
DINO_PATCH_SIZE = 16  # patch size of the dinov3_vits16plus backbone
# Normalization stats — not applied in this chunk; presumably used by the data
# pipeline / preprocessing elsewhere — TODO confirm against the caller.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

N_WINDOW = 12  # number of predicted timesteps per trajectory window
# NOTE(review): MIN_HEIGHT == MAX_HEIGHT (degenerate range). One of these is
# almost certainly wrong — confirm the intended height bounds. Neither value
# is used in this chunk (presumably normalization bounds elsewhere).
MIN_HEIGHT = 0.043347
MAX_HEIGHT = 0.043347
MIN_GRIPPER = -0.2  # gripper value range — not used in this chunk; verify against training code
MAX_GRIPPER = 0.8
# Bin counts — unused here; presumably for a discretized height/gripper variant — TODO confirm.
N_HEIGHT_BINS = 32
N_GRIPPER_BINS = 32
CURRENT_STATE_DIM = 4  # current_2d (2) + current_height (1) + current_gripper (1)


class MotionTracksTrajectoryPredictor(nn.Module):
    """Predict 2D (camera/image) + height + gripper per timestep.

    Conditioned on the current 2D position, height, and gripper state, which
    are concatenated to the DINOv3 CLS token before the MLP head. The 2D
    output is in image/camera pixel coordinates and is lifted to 3D for
    eval/live use (lifting happens outside this module).
    """

    def __init__(self, target_size=448, n_window=N_WINDOW, freeze_backbone=False, hidden_dim=512):
        """
        Args:
            target_size: expected square input resolution (stored, not enforced here).
            n_window: number of predicted timesteps.
            freeze_backbone: if True, freeze DINOv3 weights and set it to eval mode.
            hidden_dim: width of the two hidden MLP layers.
        """
        super().__init__()
        self.target_size = target_size
        self.n_window = n_window
        self.patch_size = DINO_PATCH_SIZE

        # Fix: messages previously said "DINOv2" but the checkpoint loaded is DINOv3.
        print("Loading DINOv3 model...")
        self.dino = torch.hub.load(
            DINO_REPO_DIR,
            'dinov3_vits16plus',
            source='local',
            weights=DINO_WEIGHTS_PATH,
        )
        if freeze_backbone:
            for param in self.dino.parameters():
                param.requires_grad = False
            self.dino.eval()
            print("✓ Frozen DINOv3 backbone")
        else:
            print("✓ DINOv3 backbone is trainable")

        self.embed_dim = self.dino.embed_dim
        # Factorized output: 2d (n_window * 2) + height (n_window) + gripper (n_window).
        self.out_dim = n_window * 2 + n_window + n_window  # 4 * n_window
        # MLP: [CLS, current_2d (2), current_height (1), current_gripper (1)] -> hidden -> out
        self.mlp = nn.Sequential(
            nn.Linear(self.embed_dim + CURRENT_STATE_DIM, hidden_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, self.out_dim),
        )
        print(f"✓ MotionTracks head: [CLS, current_2d, current_height, current_gripper] -> MLP -> (B, {self.out_dim})")

    def to(self, device):
        # nn.Module.to already moves registered submodules (including self.dino);
        # the explicit re-assignment is kept for backward compatibility with callers
        # that rely on this override existing.
        super().to(device)
        if hasattr(self, 'dino'):
            self.dino = self.dino.to(device)
        return self

    def _extract_cls(self, x):
        """Run the DINOv3 backbone manually and return the normalized CLS token (B, embed_dim)."""
        B = x.shape[0]
        x_tokens, (H_p, W_p) = self.dino.prepare_tokens_with_masks(x)
        # Hoisted out of the block loop: the RoPE embedding depends only on the
        # patch grid size, which is constant across blocks.
        rope_sincos = self.dino.rope_embed(H=H_p, W=W_p) if self.dino.rope_embed else None
        for blk in self.dino.blocks:
            x_tokens = blk(x_tokens, rope_sincos)
        if self.dino.untie_cls_and_patch_norms:
            # CLS/storage tokens and patch tokens use separate norms in this config.
            x_norm_cls = self.dino.cls_norm(x_tokens[:, : self.dino.n_storage_tokens + 1])
            x_norm_patches = self.dino.norm(x_tokens[:, self.dino.n_storage_tokens + 1 :])
            x_tokens = torch.cat([x_norm_cls, x_norm_patches], dim=1)
        else:
            x_tokens = self.dino.norm(x_tokens)
        cls_token = x_tokens[:, 0]
        return cls_token

    def forward(self, x, gt_target_heatmap=None, training=False, start_keypoint_2d=None, current_height=None, current_gripper=None,
                current_2d=None, current_gripper_state=None):
        """
        Args:
            x: (B, 3, H, W) input image batch.
            gt_target_heatmap: ignored (API compatibility).
            training: ignored (API compatibility).
            start_keypoint_2d: ignored (API compatibility; use current_2d).
            current_height: (B,), (B, 1), or scalar current height (z). If None, zeros.
            current_gripper: ignored (API compatibility; use current_gripper_state).
            current_2d: (B, 2) current 2D position in image coords. If None, zeros.
            current_gripper_state: (B,), (B, 1), or scalar current gripper. If None, zeros.

        Returns:
            trajectory_2d: (B, N_WINDOW, 2) in image/camera pixel coords
            trajectory_height: (B, N_WINDOW) height (z in world) per timestep
            gripper: (B, N_WINDOW) gripper value per timestep
        """
        cls = self._extract_cls(x)
        B = cls.shape[0]
        device = cls.device
        if current_2d is None:
            current_2d = torch.zeros(B, 2, device=device, dtype=cls.dtype)
        if current_height is None:
            current_height = torch.zeros(B, device=device, dtype=cls.dtype)
        if current_height.dim() == 0:
            current_height = current_height.unsqueeze(0)
        if current_height.dim() == 1:
            current_height = current_height.unsqueeze(1)  # (B, 1)
        if current_gripper_state is None:
            current_gripper_state = torch.zeros(B, device=device, dtype=cls.dtype)
        # Fix: 0-dim gripper was not handled (height was), which broke torch.cat below.
        if current_gripper_state.dim() == 0:
            current_gripper_state = current_gripper_state.unsqueeze(0)
        if current_gripper_state.dim() == 1:
            current_gripper_state = current_gripper_state.unsqueeze(1)  # (B, 1)
        cond = torch.cat([cls, current_2d, current_height, current_gripper_state], dim=1)  # (B, embed_dim+4)
        out = self.mlp(cond)
        # Slice the flat head output into the three factorized predictions.
        trajectory_2d = out[:, : self.n_window * 2].view(B, self.n_window, 2)
        trajectory_height = out[:, self.n_window * 2 : self.n_window * 3].view(B, self.n_window)
        gripper = out[:, self.n_window * 3 :].view(B, self.n_window)
        return trajectory_2d, trajectory_height, gripper


if __name__ == "__main__":
    # Smoke test: build a frozen-backbone model, run one forward pass on random
    # inputs, and print the output shapes.
    device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
    model = MotionTracksTrajectoryPredictor(
        target_size=448, n_window=N_WINDOW, freeze_backbone=True
    ).to(device)

    batch = 2
    images = torch.randn(batch, 3, 448, 448).to(device)
    start_xy = (torch.rand(batch, 2).to(device)) * 448
    start_z = (torch.rand(batch).to(device)) * 0.1
    start_grip = torch.rand(batch).to(device)

    with torch.no_grad():
        traj_xy, traj_z, traj_grip = model(
            images,
            current_2d=start_xy,
            current_height=start_z,
            current_gripper_state=start_grip,
        )

    print("trajectory_2d", traj_xy.shape)
    print("trajectory_height", traj_z.shape)
    print("gripper", traj_grip.shape)
