import argparse
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from torchvision.utils import make_grid, save_image
import imageio.v3 as iio

from scripts.inference_evaluate import load_model_from_config


def _natural_sort_key(p: Path) -> tuple:
    """Sort key for frame paths (e.g. 0.png, 1.png, 10.png)."""
    stem = p.stem
    try:
        return (int(stem),)
    except ValueError:
        return (stem,)


def load_episode_first_chunk(episode_dir: str, num_frames: int = 8, size: int = 256):
    """Load the first ``num_frames`` frames from a robot episode directory.

    The directory is expected to hold an image sequence (e.g. 0.png, 1.png,
    ..., 10.png), sorted naturally by numeric stem.

    Args:
        episode_dir: Path to the directory containing the image sequence.
        num_frames: Number of leading frames to load.
        size: Output spatial resolution; frames are resized to size x size.

    Returns:
        Float tensor of shape [1, C, T, H, W] with values in [-1, 1].

    Raises:
        FileNotFoundError: If ``episode_dir`` is not a directory.
        ValueError: If the episode has fewer than ``num_frames`` frames.
    """
    episode_path = Path(episode_dir)
    if not episode_path.is_dir():
        raise FileNotFoundError(f"Episode dir not found: {episode_dir}")

    exts = {".png", ".jpg", ".jpeg", ".bmp", ".webp"}
    paths = sorted(
        [p for p in episode_path.iterdir() if p.suffix.lower() in exts],
        key=_natural_sort_key,
    )
    if len(paths) < num_frames:
        raise ValueError(
            f"Episode has {len(paths)} frames, need at least {num_frames}."
        )

    frames = []
    # Take the leading chunk. The previous paths[5:num_frames+5] silently
    # dropped the first 5 frames, contradicting the docstring and the
    # len(paths) < num_frames guard above (which would let the slice
    # silently truncate short episodes instead of raising).
    for p in paths[:num_frames]:
        img = iio.imread(str(p))
        if img.ndim == 2:
            # Grayscale frame -> replicate to 3 channels.
            img = np.stack([img] * 3, axis=-1)
        frames.append(img)
    frames = np.stack(frames, axis=0)  # [T, H, W, C], uint8

    x = torch.from_numpy(frames).float() / 255.0    # [T, H, W, C]
    x = x.permute(0, 3, 1, 2)                       # [T, C, H, W]
    x = F.interpolate(x, size=(size, size), mode="bilinear", align_corners=False)
    x = x.permute(1, 0, 2, 3).unsqueeze(0)          # [1, C, T, H, W]
    x = x * 2.0 - 1.0                               # [-1, 1]
    return x


def load_video_first_chunk(video_path: str, num_frames: int = 17, size: int = 256):
    """Read a video file and return its leading ``num_frames`` frames.

    Frames are resized to ``size`` x ``size`` and returned as a float tensor
    of shape [1, C, T, H, W] scaled to [-1, 1].

    Raises:
        ValueError: If the video holds fewer than ``num_frames`` frames.
    """
    # imageio v3 expects plugin="FFMPEG" (upper case); fall back to the
    # auto-selected backend if that plugin name is rejected.
    try:
        all_frames = list(iio.imiter(video_path, plugin="FFMPEG"))
    except ValueError:
        all_frames = list(iio.imiter(video_path))

    if len(all_frames) < num_frames:
        raise ValueError(f"Video only has {len(all_frames)} frames, need at least {num_frames}.")

    clip = np.stack(all_frames[:num_frames], axis=0)   # [T, H, W, C], uint8
    video = torch.from_numpy(clip).float().div(255.0)  # [T, H, W, C] in [0, 1]
    video = video.permute(0, 3, 1, 2)                  # [T, C, H, W]
    video = F.interpolate(video, size=(size, size), mode="bilinear", align_corners=False)
    video = video.permute(1, 0, 2, 3).unsqueeze(0)     # [1, C, T, H, W]
    return video.mul(2.0).sub(1.0)                     # [-1, 1]


def to_01_video(x):
    """Map a [1, C, T, H, W] tensor in [-1, 1] to [T, C, H, W] in [0, 1]."""
    video = x[0].permute(1, 0, 2, 3)         # drop batch, move T to the front
    return ((video + 1.0) * 0.5).clamp(0.0, 1.0)


def main():
    """Encode the first chunk of a robot episode with VidTok, then decode it
    and save latents, reconstruction grids, and PSNR/MSE metrics."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--episode-dir", type=str, default="../../keygrip/scratch/parsed_longer_school_slowgrasp_feb10/episode_001")
    parser.add_argument("--num-frames", type=int, default=8)
    parser.add_argument("--config", type=str, default="configs/vidtok_v1_1/vidtok_kl_causal_288_8chn_v1_1.yaml")
    parser.add_argument("--ckpt", type=str, default="checkpoints/vidtok_kl_causal_288_8chn_v1_1.ckpt")
    parser.add_argument("--outdir", type=str, default="out/tokenize_demo")
    parser.add_argument("--device", type=str, default="cuda")
    # Previously hard-coded to 512 at the call site; now configurable.
    parser.add_argument("--size", type=int, default=512,
                        help="Spatial resolution frames are resized to.")
    args = parser.parse_args()

    outdir = Path(args.outdir)
    outdir.mkdir(parents=True, exist_ok=True)

    model = load_model_from_config(args.config, args.ckpt)
    model = model.to(args.device).eval()

    # Load first num_frames from robot episode directory.
    x = load_episode_first_chunk(args.episode_dir, num_frames=args.num_frames, size=args.size).to(args.device)

    # Match the autocast context to the requested device; the previous
    # hard-coded device_type="cuda" broke `--device cpu` runs. bfloat16 is
    # the supported CPU autocast dtype.
    device_type = "cuda" if str(args.device).startswith("cuda") else "cpu"
    autocast_dtype = torch.float16 if device_type == "cuda" else torch.bfloat16
    with torch.no_grad(), torch.autocast(device_type=device_type, dtype=autocast_dtype):
        z, reg_log = model.encode(x, return_reg_log=True)
        x_recon = model.decode(z)

    # Save continuous tokens (latents).
    torch.save({"z": z.detach().cpu()}, outdir / "tokens.pt")

    # If this model exposes discrete indices, save them too.
    if isinstance(reg_log, dict) and "indices" in reg_log:
        torch.save({"indices": reg_log["indices"].detach().cpu()}, outdir / "indices.pt")
        print("Saved discrete indices:", tuple(reg_log["indices"].shape))
    else:
        print("No discrete indices for this checkpoint; saved continuous latents only.")

    print("Input shape:", tuple(x.shape))
    print("Latent z shape:", tuple(z.shape))
    print("Recon shape:", tuple(x_recon.shape))

    orig = to_01_video(x)
    recon = to_01_video(x_recon)

    # Metrics over [0, 1] videos; clamp MSE away from zero so a perfect
    # reconstruction doesn't produce log10(0).
    mse = torch.mean((orig - recon) ** 2).item()
    psnr = -10.0 * np.log10(max(mse, 1e-12))
    print(f"MSE:  {mse:.6f}")
    print(f"PSNR: {psnr:.3f} dB")

    # Save frame grids.
    save_image(make_grid(orig, nrow=4), outdir / "original_grid.png")
    save_image(make_grid(recon, nrow=4), outdir / "recon_grid.png")

    # Save a side-by-side comparison grid: [orig | recon] for each frame.
    paired = torch.cat([orig, recon], dim=3)  # concat width
    save_image(make_grid(paired, nrow=4), outdir / "comparison_grid.png")

    print(f"Saved outputs to: {outdir}")


# Script entry point: run only when executed directly, not when imported.
if __name__ == "__main__":
    main()