import argparse
from itertools import islice
from pathlib import Path

import imageio.v3 as iio
import numpy as np
import torch
import torch.nn.functional as F
from torchvision.utils import make_grid, save_image

from scripts.inference_evaluate import load_model_from_config


def load_video_first_chunk(video_path: str, num_frames: int = 17, size: int = 256) -> torch.Tensor:
    """Load the first ``num_frames`` frames of a video as a model-ready clip.

    Args:
        video_path: Path to a video file readable by imageio/FFMPEG.
        num_frames: Number of leading frames to keep.
        size: Side length for bilinear resizing (frames become size x size).

    Returns:
        Float tensor of shape [1, C, T, H, W] with values in [-1, 1].

    Raises:
        ValueError: If the video contains fewer than ``num_frames`` frames.
    """
    # Stream only the first num_frames frames via islice instead of decoding
    # the whole video into memory (the original materialized every frame).
    # Prefer plugin="FFMPEG" (caps, as imageio v3 expects); fall back to
    # backend auto-selection if that plugin raises ValueError.
    try:
        frames = list(islice(iio.imiter(video_path, plugin="FFMPEG"), num_frames))
    except ValueError:
        frames = list(islice(iio.imiter(video_path), num_frames))
    # If the video is shorter than num_frames, islice was exhausted early and
    # len(frames) is the true frame count, so the message stays accurate.
    if len(frames) < num_frames:
        raise ValueError(f"Video only has {len(frames)} frames, need at least {num_frames}.")

    frames = np.stack(frames, axis=0)               # [T, H, W, C], uint8
    x = torch.from_numpy(frames).float() / 255.0    # [T, H, W, C]
    x = x.permute(0, 3, 1, 2)                       # [T, C, H, W]
    x = F.interpolate(x, size=(size, size), mode="bilinear", align_corners=False)
    x = x.permute(1, 0, 2, 3).unsqueeze(0)          # [1, C, T, H, W]
    x = x * 2.0 - 1.0                               # [-1, 1]
    return x


def to_01_video(x):
    """Map a [1, C, T, H, W] tensor in [-1, 1] to a [T, C, H, W] tensor in [0, 1]."""
    video = x.squeeze(0).transpose(0, 1)        # [C, T, H, W] -> [T, C, H, W]
    return ((video + 1.0) * 0.5).clamp(0.0, 1.0)


def _load_image_as_clip(img_path: Path, size: int = 256) -> torch.Tensor:
    """Load a single image as a 1-frame video clip [1, C, 1, size, size] in [-1, 1]."""
    frame = iio.imread(str(img_path))                 # [H, W, C], uint8
    x = torch.from_numpy(frame).float() / 255.0       # [H, W, C] in [0, 1]
    x = x.permute(2, 0, 1).unsqueeze(0)               # [1, C, H, W]
    x = F.interpolate(x, size=(size, size), mode="bilinear", align_corners=False)
    return (x * 2.0 - 1.0).unsqueeze(2)               # [1, C, 1, size, size] in [-1, 1]


def main():
    """Encode a 1-frame clip with a VidTok model, save tokens, and report recon quality."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--video", type=str, default="../tmpvids/22008760.mp4")
    parser.add_argument("--num-frames", type=int, default=8)
    parser.add_argument("--config", type=str, default="configs/vidtok_v1_1/vidtok_kl_causal_288_8chn_v1_1.yaml")
    parser.add_argument("--ckpt", type=str, default="checkpoints/vidtok_kl_causal_288_8chn_v1_1.ckpt")
    parser.add_argument("--outdir", type=str, default="out/tokenize_demo")
    parser.add_argument("--device", type=str, default="cuda")
    args = parser.parse_args()

    outdir = Path(args.outdir)
    outdir.mkdir(parents=True, exist_ok=True)

    model = load_model_from_config(args.config, args.ckpt)
    model = model.to(args.device).eval()

    # To tokenize the first chunk of a real video instead, swap in:
    # x = load_video_first_chunk(args.video, num_frames=args.num_frames, size=256).to(args.device)
    # NOTE(review): the image path is hard-coded relative to this file and
    # ignores --video; presumably intentional for this demo — confirm.
    img_path = Path(__file__).resolve().parents[1] / "unified_video_action" / "fake_libero_img.png"
    x = _load_image_as_clip(img_path, size=256).to(args.device)

    # Match the autocast context to the actual device: previously "cuda"/fp16
    # were hard-coded, which breaks --device cpu (fp16 autocast is CUDA-only;
    # CPU autocast supports bf16).
    device_type = "cuda" if str(args.device).startswith("cuda") else "cpu"
    amp_dtype = torch.float16 if device_type == "cuda" else torch.bfloat16
    with torch.no_grad(), torch.autocast(device_type=device_type, dtype=amp_dtype):
        z, reg_log = model.encode(x, return_reg_log=True)
        x_recon = model.decode(z)

    # Save continuous tokens (latents).
    torch.save({"z": z.detach().cpu()}, outdir / "tokens.pt")

    # If this model exposes discrete indices (e.g. FSQ checkpoints), save them too.
    if isinstance(reg_log, dict) and "indices" in reg_log:
        torch.save({"indices": reg_log["indices"].detach().cpu()}, outdir / "indices.pt")
        print("Saved discrete indices:", tuple(reg_log["indices"].shape))
    else:
        print("No discrete indices for this checkpoint; saved continuous latents only.")

    print("Input shape:", tuple(x.shape))
    print("Latent z shape:", tuple(z.shape))
    print("Recon shape:", tuple(x_recon.shape))

    orig = to_01_video(x)
    recon = to_01_video(x_recon)

    # PSNR from MSE in [0, 1] space; clamp MSE to avoid log10(0) on a perfect recon.
    mse = torch.mean((orig - recon) ** 2).item()
    psnr = -10.0 * np.log10(max(mse, 1e-12))
    print(f"MSE:  {mse:.6f}")
    print(f"PSNR: {psnr:.3f} dB")

    # Save frame grids.
    save_image(make_grid(orig, nrow=4), outdir / "original_grid.png")
    save_image(make_grid(recon, nrow=4), outdir / "recon_grid.png")

    # Save a side-by-side comparison grid: [orig | recon] for each frame.
    paired = torch.cat([orig, recon], dim=3)  # concat along width
    save_image(make_grid(paired, nrow=4), outdir / "comparison_grid.png")

    print(f"Saved outputs to: {outdir}")


if __name__ == "__main__":
    main()