"""
Side-by-side GT vs PRED video rollout comparison.
Takes a GT trajectory, picks ~5 equally spaced frames as conditioning,
generates predicted video from each, and stitches GT | PRED with labels.
"""
import os
# Pin the process to one GPU. Must be set before torch initializes CUDA,
# hence it sits above the torch import.
# NOTE(review): hardcoded GPU index "9" — consider making this configurable.
os.environ["CUDA_VISIBLE_DEVICES"] = "9"

import argparse
import torch
import numpy as np
import imageio
import decord
from PIL import Image, ImageDraw, ImageFont
from svd.pipelines import StableVideoDiffusionPipeline
from svd.models import UNetSpatioTemporalConditionModel

def _load_font(font_size):
    """Load the DejaVu bold TrueType font, falling back to PIL's built-in
    bitmap font when the font file is missing (e.g. minimal containers)."""
    try:
        return ImageFont.truetype(
            "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", font_size
        )
    except OSError:
        # ImageFont.truetype raises OSError when the font file cannot be read.
        return ImageFont.load_default()


def add_label(frame_np, text, position="top-left", font_size=20):
    """Draw `text` on a black background box over a copy of `frame_np`.

    Args:
        frame_np: HxWx3 uint8 RGB frame.
        text: label string to render.
        position: one of "top-left" (default), "top-right", "bottom-left",
            "bottom-right". (Previously this argument was ignored and the
            label was always drawn top-left; default behavior is unchanged.)
        font_size: point size for the TrueType font.

    Returns:
        A new numpy array (the input frame is not modified in place).
    """
    img = Image.fromarray(frame_np)
    draw = ImageDraw.Draw(img)
    font = _load_font(font_size)

    # Measure the rendered text to size the background rectangle.
    bbox = draw.textbbox((0, 0), text, font=font)
    tw, th = bbox[2] - bbox[0], bbox[3] - bbox[1]

    # Anchor the label according to `position`, keeping a small margin.
    margin = 5
    img_w, img_h = img.size
    x = margin if "left" in position else max(margin, img_w - tw - margin)
    y = margin if position.startswith("top") else max(margin, img_h - th - margin)

    draw.rectangle([x - 2, y - 2, x + tw + 4, y + th + 4], fill=(0, 0, 0))
    draw.text((x, y), text, fill=(255, 255, 255), font=font)
    return np.array(img)


def main():
    """Roll out the fine-tuned SVD model from ~equally spaced GT frames and
    save a labeled GT | PRED side-by-side comparison video."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt", type=str, default=None, help="Checkpoint dir (default: latest)")
    parser.add_argument("--video", type=str, default="dataset/libero/libero_0000.mp4", help="GT video path")
    parser.add_argument("--num_starts", type=int, default=5, help="Number of conditioning frames")
    parser.add_argument("--num_frames", type=int, default=7, help="Frames to generate per rollout")
    parser.add_argument("--width", type=int, default=576)
    parser.add_argument("--height", type=int, default=320)
    parser.add_argument("--output", type=str, default="eval_output/side_by_side.mp4")
    args = parser.parse_args()

    # Resolve checkpoint: default to the numerically latest checkpoint dir.
    if args.ckpt is None:
        ckpt_dir = "output_libero_7f"
        ckpts = [d for d in os.listdir(ckpt_dir) if d.startswith("checkpoint-")]
        if ckpts:
            # Sort by step number, not lexicographically — a plain string sort
            # ranks "checkpoint-9000" above "checkpoint-10000".
            def _step(name):
                suffix = name.rsplit("-", 1)[-1]
                return int(suffix) if suffix.isdigit() else -1

            ckpts.sort(key=_step)
            args.ckpt = os.path.join(ckpt_dir, ckpts[-1])
            print(f"Using latest checkpoint: {args.ckpt}")
        else:
            args.ckpt = ckpt_dir
            print(f"Using final model: {args.ckpt}")

    # Load the fine-tuned UNet into the base SVD img2vid pipeline.
    print("Loading fine-tuned model...")
    unet = UNetSpatioTemporalConditionModel.from_pretrained(
        args.ckpt, subfolder="unet", torch_dtype=torch.float16
    )
    pipe = StableVideoDiffusionPipeline.from_pretrained(
        "checkpoints/stable-video-diffusion-img2vid-xt-1-1",
        unet=unet, torch_dtype=torch.float16, variant="fp16"
    )
    pipe.to("cuda")
    print(f"  GPU memory: {torch.cuda.memory_allocated()/1e9:.2f} GB")

    # Load the ground-truth video.
    print(f"Loading GT video: {args.video}")
    vr = decord.VideoReader(args.video)
    total_frames = len(vr)
    print(f"  Total frames: {total_frames}")

    # Pick equally spaced start frames, leaving room for num_frames after
    # each; for very short videos fall back to spanning the whole clip.
    max_start = total_frames - args.num_frames
    if max_start <= 0:
        max_start = total_frames - 1
    start_indices = np.linspace(0, max_start, args.num_starts, dtype=int)
    print(f"  Start frames: {start_indices.tolist()}")

    all_combined_frames = []

    for si, start_idx in enumerate(start_indices):
        print(f"\n[{si+1}/{args.num_starts}] Conditioning on frame {start_idx}...")

        # Conditioning image for the pipeline.
        cond_frame = vr[start_idx].asnumpy()
        cond_image = Image.fromarray(cond_frame).resize((args.width, args.height))

        # GT segment; near the end of the video, pad by repeating the last
        # frame so GT and PRED stay the same length.
        end_idx = min(start_idx + args.num_frames, total_frames)
        gt_indices = list(range(start_idx, end_idx))
        while len(gt_indices) < args.num_frames:
            gt_indices.append(gt_indices[-1])
        gt_frames = vr.get_batch(gt_indices).asnumpy()
        gt_frames_resized = [np.array(Image.fromarray(f).resize((args.width, args.height))) for f in gt_frames]

        # Generate the predicted rollout from the conditioning frame.
        with torch.inference_mode():
            pred_pil_frames = pipe(
                cond_image, height=args.height, width=args.width,
                num_frames=args.num_frames, decode_chunk_size=4,
                num_inference_steps=25,
            ).frames[0]
        pred_frames = [np.array(f) for f in pred_pil_frames]

        # Stitch GT | PRED horizontally, each side labeled with its timestep.
        for fi in range(args.num_frames):
            gt_labeled = add_label(gt_frames_resized[fi], f"GT  t={start_idx+fi}")
            pred_labeled = add_label(pred_frames[fi], f"PRED  t={start_idx+fi}")
            combined = np.concatenate([gt_labeled, pred_labeled], axis=1)
            all_combined_frames.append(combined)

        # Append a black separator frame announcing the rollout boundary.
        h, w = all_combined_frames[-1].shape[:2]
        sep_img = Image.fromarray(np.zeros((h, w, 3), dtype=np.uint8))
        draw = ImageDraw.Draw(sep_img)
        try:
            font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 24)
        except OSError:
            # Font file missing: fall back to PIL's built-in bitmap font.
            font = ImageFont.load_default()
        text = f"--- Rollout {si+1}/{args.num_starts}, start frame {start_idx} ---"
        draw.text((w//4, h//2 - 12), text, fill=(255, 255, 255), font=font)
        all_combined_frames.append(np.array(sep_img))

        torch.cuda.empty_cache()

    # Save combined video. dirname is "" when --output is a bare filename,
    # and os.makedirs("") would raise — only create the dir when non-empty.
    out_dir = os.path.dirname(args.output)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    imageio.mimwrite(args.output, all_combined_frames, fps=4, quality=8)
    print(f"\nSaved side-by-side video: {args.output}")
    print(f"  {len(all_combined_frames)} total frames "
          f"({args.num_starts} rollouts × {args.num_frames} frames + {args.num_starts} separators)")


if __name__ == "__main__":
    main()
