"""
Load a UVA checkpoint into a policy (MAR + VAE) for video-only sampling.
Self-contained: no imports from unified_video_action.
Uses the simplified video-only MAR (no text, proprio, action, wrist).
"""

import os
import warnings
from pathlib import Path
from types import SimpleNamespace
from typing import Optional

import dill
import torch
from omegaconf import OmegaConf

from simple_uva.vae import AutoencoderKL
from simple_uva.model import mar_base_video_only


def _resolve_path(repo_root: Path, raw_path) -> str:
    p = str(raw_path)
    if os.path.isabs(p):
        return p
    return str(repo_root / p)


def _strip_prefix(state_dict, prefix: str) -> dict:
    """Return the entries of ``state_dict`` whose key starts with ``prefix``, with that prefix removed.

    Keys that do not start with ``prefix`` are dropped.
    """
    cut = len(prefix)
    return {key[cut:]: value for key, value in state_dict.items() if key.startswith(prefix)}


def _load_matching(module: torch.nn.Module, sub_state_dict: dict, label: str) -> None:
    """Non-strictly load ``sub_state_dict`` into ``module``, warning about uncovered parameters.

    Extra checkpoint keys (parts of the full UVA policy that this module does
    not have) are ignored silently. Missing keys mean part of ``module`` keeps
    its initial values, which is almost certainly unintended, so they trigger a
    ``UserWarning`` instead of passing unnoticed.
    """
    result = module.load_state_dict(sub_state_dict, strict=False)
    missing = list(getattr(result, "missing_keys", None) or [])
    if missing:
        warnings.warn(
            f"{label}: {len(missing)} parameter(s)/buffer(s) not found in checkpoint and left "
            f"at their initial values (first few: {missing[:5]})",
            stacklevel=3,  # point at the caller of load_policy_from_checkpoint
        )


def load_policy_from_checkpoint(ckpt_path: Path, device: torch.device, repo_root: Optional[Path] = None):
    """Load the video-only policy (MAR + VAE) and its config from a full UVA checkpoint.

    Args:
        ckpt_path: Checkpoint file readable by ``torch.load`` with ``dill``; it
            must contain ``"cfg"`` and ``"state_dicts"`` entries.
        device: Device the returned models are moved to.
        repo_root: Base directory used to resolve a relative
            ``vae_model_params.autoencoder_path``. Defaults to the parent of
            the directory containing this file.

    Returns:
        ``(policy, cfg)``: ``policy`` is a ``SimpleNamespace`` with ``model``
        (the MAR) and ``vae_model`` (the VAE), both in eval mode on ``device``;
        ``cfg`` is the config stored in the checkpoint.

    Raises:
        KeyError: if the checkpoint has neither ``state_dicts['ema_model']``
            nor ``state_dicts['model']``.

    Warns:
        UserWarning: if the checkpoint does not cover every parameter/buffer
            of the MAR or the VAE.
    """
    repo_root = repo_root or Path(__file__).resolve().parents[1]

    # Deserialize on CPU: the payload carries every stored state dict (both
    # "ema_model" and "model" may be present) but only one is used, so mapping
    # it straight onto ``device`` wastes accelerator memory. The chosen weights
    # reach ``device`` when the models are moved at the end.
    # NOTE(security): dill unpickling can execute arbitrary code; only load
    # checkpoints from trusted sources.
    with open(ckpt_path, "rb") as f:
        payload = torch.load(f, map_location="cpu", pickle_module=dill)

    cfg = payload["cfg"]
    state_dicts = payload["state_dicts"]
    # Prefer EMA weights; an absent (or empty) EMA dict falls back to the raw model.
    state_dict = state_dicts.get("ema_model") or state_dicts.get("model")
    if state_dict is None:
        raise KeyError("Checkpoint must contain state_dicts['ema_model'] or state_dicts['model']")

    # Resolve VAE path relative to repo
    policy_cfg = cfg.model.policy
    ae_path = getattr(policy_cfg.vae_model_params, "autoencoder_path", None)
    vae_params = {
        "autoencoder_path": _resolve_path(repo_root, ae_path) if ae_path else None,
        "ddconfig": policy_cfg.vae_model_params.ddconfig,
    }

    # Build a frozen VAE
    with torch.no_grad():
        vae_model = AutoencoderKL(**vae_params)
    vae_model.eval()
    for p in vae_model.parameters():
        p.requires_grad = False

    # Video-only MAR: only architectural params are taken from the checkpoint cfg
    ar_params = policy_cfg.autoregressive_model_params
    mar_model = mar_base_video_only(
        img_size=ar_params.img_size,
        vae_stride=ar_params.vae_stride,
        patch_size=ar_params.patch_size,
        vae_embed_dim=ar_params.vae_embed_dim,
        mask_ratio_min=ar_params.mask_ratio_min,
        label_drop_prob=ar_params.label_drop_prob,
        attn_dropout=ar_params.attn_dropout,
        proj_dropout=ar_params.proj_dropout,
        diffloss_d=ar_params.diffloss_d,
        diffloss_w=ar_params.diffloss_w,
        num_sampling_steps=str(ar_params.num_sampling_steps),
        grad_checkpointing=ar_params.grad_checkpointing,
    )

    # The full policy stores submodules under "model." and "vae_model.". Keys
    # for components the video-only model lacks are dropped by the non-strict
    # load; keys our models need but the checkpoint lacks produce a warning.
    _load_matching(mar_model, _strip_prefix(state_dict, "model."), "MAR")
    _load_matching(vae_model, _strip_prefix(state_dict, "vae_model."), "VAE")

    policy = SimpleNamespace()
    policy.model = mar_model
    policy.vae_model = vae_model
    policy.model.to(device)
    policy.vae_model.to(device)
    policy.model.eval()
    policy.vae_model.eval()

    return policy, cfg
