import warnings
from pathlib import Path
from typing import Tuple

import torch
import torch.nn as nn


class MARVAEWrapper(nn.Module):
    """
    Thin, frozen wrapper around the MAR / LDM-style VAE used in UVA.

    This expects that you have cloned and installed the MAR repo:
      https://github.com/LTH14/mar

    and that its `AutoencoderKL` implementation and checkpoint (e.g. `kl16.ckpt`)
    are available. We keep the wrapper small so you can easily adapt it to the
    exact MAR version / ckpt you use.

    The wrapped VAE is moved to ``device``, has gradients disabled, and is
    kept in eval mode even when the wrapper itself is switched to training
    mode via ``.train()``.
    """

    def __init__(self, ckpt_path: str, device: torch.device):
        """
        Build the VAE and load its weights from ``ckpt_path``.

        Args:
            ckpt_path: path to the VAE checkpoint file.
            device: device the VAE is moved to.

        Raises:
            FileNotFoundError: if ``ckpt_path`` is not an existing file.
            ImportError: if MAR's ``AutoencoderKL`` cannot be imported.
            RuntimeError: if none of the checkpoint's keys match the VAE's
                state dict (loading would leave the VAE randomly initialised).
        """
        super().__init__()
        # Filled lazily by ``latent_shape``.
        self._latent_shape = None

        path = Path(ckpt_path)
        if not path.is_file():
            raise FileNotFoundError(f"MAR VAE checkpoint not found: {path}")

        # Import from MAR. Adjust the import path if your clone layout differs.
        try:
            from mar.models.autoencoder_kl import AutoencoderKL  # type: ignore
        except Exception as e:
            raise ImportError(
                "Could not import AutoencoderKL from MAR. Make sure the MAR repo "
                "is on your PYTHONPATH (e.g. `pip install -e .` inside the MAR repo)."
            ) from e

        # Intended to match MAR's KL-16 VAE (kl16.ckpt).
        # If your checkpoint / config differs, adjust kwargs accordingly.
        self.vae = AutoencoderKL(
            embed_dim=16,
            z_channels=16,
            resolution=256,
            in_channels=3,
            out_ch=3,
        )

        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        state = torch.load(str(path), map_location="cpu")
        state_dict = state
        if isinstance(state, dict):
            # Lightning / LDM checkpoints nest weights under "state_dict";
            # MAR's own VAE loader reads them from "model". Fall back to the
            # top-level dict when neither wrapper key is present.
            for key in ("state_dict", "model"):
                if isinstance(state.get(key), dict):
                    state_dict = state[key]
                    break

        # strict=False tolerates small key differences, but a checkpoint that
        # matches nothing would otherwise silently yield a random VAE.
        expected_keys = set(self.vae.state_dict().keys())
        if not expected_keys & set(state_dict.keys()):
            raise RuntimeError(
                f"No key in checkpoint {path} matches the VAE state dict; check "
                "the checkpoint layout and the AutoencoderKL config."
            )
        result = self.vae.load_state_dict(state_dict, strict=False)
        if result.missing_keys or result.unexpected_keys:
            warnings.warn(
                f"MAR VAE checkpoint {path} loaded non-strictly: "
                f"{len(result.missing_keys)} missing key(s) "
                f"(first: {result.missing_keys[:5]}), "
                f"{len(result.unexpected_keys)} unexpected key(s) "
                f"(first: {result.unexpected_keys[:5]}).",
                stacklevel=2,
            )

        self.vae.eval().to(device)
        self.vae.requires_grad_(False)

    def train(self, mode: bool = True) -> "MARVAEWrapper":
        """
        Set the wrapper's training mode while keeping the frozen VAE in eval.

        Without this override, a parent module calling ``.train()`` would
        switch the VAE back to training mode despite its frozen weights.
        """
        super().train(mode)
        self.vae.eval()
        return self

    @torch.no_grad()
    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encode RGB in [-1, 1] into latents.

        The latent is a random sample from the posterior (``.sample()``), so
        repeated calls on the same input generally differ and consume the
        torch RNG.

        Args:
            x: (B, 3, H, W) in [-1, 1]

        Returns:
            z: (B, C, H_l, W_l) latent tensor
        """
        # NOTE(review): this rescales to [0, 1] before encoding. Upstream MAR
        # and LDM-style pipelines normalize images to [-1, 1] before calling
        # their VAE — confirm which range your checkpoint actually expects.
        x01 = (x + 1.0) / 2.0
        z = self.vae.encode(x01).latent_dist.sample()
        return z

    @torch.no_grad()
    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """
        Decode latents back to RGB in [-1, 1].

        The VAE output is clamped to [0, 1] and then mapped to [-1, 1]
        (see the range note in ``encode``).

        Args:
            z: (B, C, H_l, W_l)

        Returns:
            x: (B, 3, H, W) in [-1, 1]
        """
        x01 = self.vae.decode(z).clamp(0.0, 1.0)
        x = x01 * 2.0 - 1.0
        return x

    @property
    def latent_shape(self) -> Tuple[int, int, int]:
        """
        Returns (C_l, H_l, W_l) for 256x256 inputs.

        Computed on first access by encoding one dummy frame (on the VAE's
        device and dtype) and cached afterwards, so the VAE forward pass and
        the single RNG draw from ``encode`` happen only once.
        """
        if self._latent_shape is None:
            ref = next(self.vae.parameters())
            dummy = torch.zeros(1, 3, 256, 256, device=ref.device, dtype=ref.dtype)
            _, c, h, w = self.encode(dummy).shape
            self._latent_shape = (c, h, w)
        return self._latent_shape

