"""Diffusion policy baseline: image + current robot state -> diffusion over (3D + gripper) state, N~10 steps, global robot frame.

Conditioned on current 3D position (3) and gripper state (1) concatenated to CLS.
State = (trajectory_3d flattened, gripper flattened) = (N_WINDOW*4,). Condition = [CLS, current_3d, current_gripper].
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

DINO_REPO_DIR = "dinov3"
DINO_WEIGHTS_PATH = "dinov3/weights/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth"
DINO_PATCH_SIZE = 16
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

N_WINDOW = 12
MIN_HEIGHT = 0.043347
MAX_HEIGHT = 0.043347
MIN_GRIPPER = -0.2
MAX_GRIPPER = 0.8
N_HEIGHT_BINS = 32
N_GRIPPER_BINS = 32

# Diffusion
NUM_DIFFUSION_STEPS = 10
CURRENT_STATE_DIM = 4  # current_3d (3) + current_gripper (1)


def cosine_beta_schedule(timesteps, s=0.008):
    """Cosine schedule as proposed in improved DDPM."""
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0.0001, 0.9999)
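
# Schedule sketch: cosine_beta_schedule(10) yields 10 betas that grow with t (the last
# clips at 0.9999), so cumprod(1 - betas) decays from ~1 toward ~0 over the 10 steps.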


class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device).float() * -embeddings)
        embeddings = time[:, None].float() * embeddings[None, :]
        embeddings = torch.cat([embeddings.sin(), embeddings.cos()], dim=-1)
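        # e.g. dim=64 and time of shape (3,) -> (3, 64): first 32 columns sin, last 32 cos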
        return embeddings


class DiffusionTrajectoryPredictor(nn.Module):
    """Diffusion over (trajectory_3d, gripper) state conditioned on image. Global robot frame. N~10 steps."""

    def __init__(self, target_size=448, n_window=N_WINDOW, freeze_backbone=False, hidden_dim=512,
                 num_steps=NUM_DIFFUSION_STEPS, t_embed_dim=64):
        super().__init__()
        self.target_size = target_size
        self.n_window = n_window
        self.state_dim = n_window * 4  # xyz (3 * n_window) + gripper (n_window)
        self.num_steps = num_steps
        self.patch_size = DINO_PATCH_SIZE

        print("Loading DINOv2 model...")
        self.dino = torch.hub.load(
            DINO_REPO_DIR,
            'dinov3_vits16plus',
            source='local',
            weights=DINO_WEIGHTS_PATH
        )
        if freeze_backbone:
            for param in self.dino.parameters():
                param.requires_grad = False
            self.dino.eval()
            print("✓ Frozen DINOv2 backbone")
        else:
            print("✓ DINOv2 backbone is trainable")

        self.embed_dim = self.dino.embed_dim
        # Condition = [CLS, current_3d, current_gripper]
        self.cond_proj = nn.Linear(self.embed_dim + CURRENT_STATE_DIM, hidden_dim)

        # Beta schedule
        betas = cosine_beta_schedule(num_steps)
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        self.register_buffer("betas", betas)
        self.register_buffer("alphas", alphas)
        self.register_buffer("alphas_cumprod", alphas_cumprod)
        self.register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))
        self.register_buffer("sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod))

        # Denoiser: predicts noise given (noisy_state, t_embed, cond)
        self.t_embed = SinusoidalPositionEmbeddings(t_embed_dim)
        self.denoiser = nn.Sequential(
            nn.Linear(self.state_dim + t_embed_dim + hidden_dim, hidden_dim * 2),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, self.state_dim),
        )
        print(f"✓ Diffusion: state_dim={self.state_dim}, num_steps={num_steps}, cond=[CLS, current_3d, current_gripper]")

    def to(self, device):
        super().to(device)
        if hasattr(self, 'dino'):
            self.dino = self.dino.to(device)
        return self

    def _extract_cls(self, x):
        x_tokens, (H_p, W_p) = self.dino.prepare_tokens_with_masks(x)
        # RoPE depends only on the patch grid, so compute it once instead of per block.
        rope_sincos = self.dino.rope_embed(H=H_p, W=W_p) if self.dino.rope_embed else None
        for blk in self.dino.blocks:
            x_tokens = blk(x_tokens, rope_sincos)
        if self.dino.untie_cls_and_patch_norms:
            x_norm_cls = self.dino.cls_norm(x_tokens[:, : self.dino.n_storage_tokens + 1])
            x_norm_patches = self.dino.norm(x_tokens[:, self.dino.n_storage_tokens + 1 :])
            x_tokens = torch.cat([x_norm_cls, x_norm_patches], dim=1)
        else:
            x_tokens = self.dino.norm(x_tokens)
        cls_token = x_tokens[:, 0]
        return cls_token

    def _state_to_trajectory_gripper(self, state):
        """state (B, state_dim) -> trajectory_3d (B, N_WINDOW, 3), gripper (B, N_WINDOW)."""
        B = state.shape[0]
        traj = state[:, : self.n_window * 3].view(B, self.n_window, 3)
        grip = state[:, self.n_window * 3 :].view(B, self.n_window)
        return traj, grip

    def _trajectory_gripper_to_state(self, trajectory_3d, gripper):
        """(B, N_WINDOW, 3), (B, N_WINDOW) -> (B, state_dim)."""
        B = trajectory_3d.shape[0]
        traj_flat = trajectory_3d.reshape(B, -1)
        grip_flat = gripper.reshape(B, -1)
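        # _state_to_trajectory_gripper inverts this packing: the first 3*n_window entries
        # are the flattened xyz trajectory, the remaining n_window the gripper values.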
        return torch.cat([traj_flat, grip_flat], dim=1)

    def _denoise_step(self, x_t, t, cond):
        """Single denoise step: predict noise and return x_{t-1} (or the x_0 estimate for the last step).

        All entries of t hold the same timestep, so t[0] decides whether noise is added.
        """
        t_embed = self.t_embed(t)  # (B, t_embed_dim)
        denoiser_in = torch.cat([x_t, t_embed, cond], dim=1)
        noise_pred = self.denoiser(denoiser_in)
        alpha_t = self.alphas[t].view(-1, 1)
        sqrt_alpha_bar = self.sqrt_alphas_cumprod[t].view(-1, 1)
        sqrt_one_minus = self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1)
        x0_pred = (x_t - sqrt_one_minus * noise_pred) / sqrt_alpha_bar
        # DDPM posterior: x_{t-1} = (x_t - beta_t / sqrt(1 - alpha_bar_t) * noise_pred) / sqrt(alpha_t) + sigma_t * z
        mean = (x_t - (1.0 - alpha_t) / sqrt_one_minus * noise_pred) / torch.sqrt(alpha_t)
        if int(t[0]) > 0:
            sigma_t = torch.sqrt(self.betas[t]).view(-1, 1)
            x_prev = mean + sigma_t * torch.randn_like(x_t)
        else:
            x_prev = x0_pred  # at t=0 the posterior mean equals the x_0 estimate; return it noise-free
        return x_prev, x0_pred

    def forward(self, x, gt_trajectory_3d=None, gt_gripper=None, training=False,
                start_keypoint_2d=None, current_height=None, current_gripper=None,
                current_3d=None, current_gripper_state=None):
        """
        Args:
            x: (B, 3, H, W)
            current_3d: (B, 3) current gripper 3D in world. If None, zeros.
            current_gripper_state: (B,) or (B, 1) current gripper. If None, zeros.
            gt_trajectory_3d, gt_gripper: for training only
            training: if True, compute diffusion loss; else sample and return (trajectory_3d, gripper)

        Returns:
            If training: loss (scalar)
            Else: trajectory_3d (B, N_WINDOW, 3), gripper (B, N_WINDOW)
        """
        cls = self._extract_cls(x)  # (B, embed_dim)
        B = cls.shape[0]
        device = cls.device
        if current_3d is None:
            current_3d = torch.zeros(B, 3, device=device, dtype=cls.dtype)
        if current_gripper_state is None:
            current_gripper_state = torch.zeros(B, device=device, dtype=cls.dtype)
        if current_gripper_state.dim() == 1:
            current_gripper_state = current_gripper_state.unsqueeze(1)  # (B, 1)
        cond = torch.cat([cls, current_3d, current_gripper_state], dim=1)  # (B, embed_dim+4)
        cond = self.cond_proj(cond)   # (B, hidden_dim)

        if training and gt_trajectory_3d is not None and gt_gripper is not None:
            x0 = self._trajectory_gripper_to_state(gt_trajectory_3d, gt_gripper)  # (B, state_dim)
            t = torch.randint(0, self.num_steps, (B,), device=x.device, dtype=torch.long)
            noise = torch.randn_like(x0)  # same shape, dtype, and device as x0
            sqrt_alpha = self.sqrt_alphas_cumprod[t].view(-1, 1)
            sqrt_one_minus = self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1)
            x_t = sqrt_alpha * x0 + sqrt_one_minus * noise
            t_embed = self.t_embed(t)
            denoiser_in = torch.cat([x_t, t_embed, cond], dim=1)
            noise_pred = self.denoiser(denoiser_in)
            loss = F.mse_loss(noise_pred, noise)
            return loss

        # Inference: ancestral sampling through the reverse process
        x_t = torch.randn(B, self.state_dim, device=x.device, dtype=x.dtype)
        for t in reversed(range(self.num_steps)):
            t_batch = torch.full((B,), t, device=x.device, dtype=torch.long)
            x_t, _ = self._denoise_step(x_t, t_batch, cond)
        trajectory_3d, gripper = self._state_to_trajectory_gripper(x_t)
        return trajectory_3d, gripper


if __name__ == "__main__":
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    model = DiffusionTrajectoryPredictor(target_size=448, n_window=N_WINDOW, freeze_backbone=True)
    model = model.to(device)
    x = torch.randn(2, 3, 448, 448).to(device)
    cur_3d = torch.randn(2, 3).to(device) * 0.1
    cur_grip = torch.rand(2).to(device)
    with torch.no_grad():
        traj, grip = model(x, training=False, current_3d=cur_3d, current_gripper_state=cur_grip)
    print("trajectory_3d", traj.shape)
    print("gripper", grip.shape)
