"""
Fine-tune Ctrl-World for single image-to-video rollouts on DROID.
Uses zero actions (unconditional action conditioning) so the model learns to predict
future frames from the current image only. Same CLI and wandb setup as diffusion_dino/train.py.
"""
import argparse
import os
import sys
import tempfile
from pathlib import Path

import einops
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

# Ctrl-World root
ctrl_world_root = Path(__file__).resolve().parents[1]
vidgen_root = ctrl_world_root.parent
uva_root = vidgen_root / "unified_video_action"
if ctrl_world_root.exists() and str(ctrl_world_root) not in sys.path:
    sys.path.insert(0, str(ctrl_world_root))
for p in [vidgen_root, uva_root]:
    if p.exists() and str(p) not in sys.path:
        sys.path.insert(0, str(p))

from config import wm_args
from models.ctrl_world import CrtlWorld
from models.pipeline_ctrl_world import CtrlWorldDiffusionPipeline


class CachedClipDataset(Dataset):
    """Load .pt video clips from cache_dir (same format as diffusion_dino).

    Each .pt file is expected to hold a tensor shaped (C, T, H, W) or
    (1, C, T, H, W); ``__getitem__`` returns a (1, C, T', H, W) tensor with
    T' <= num_frames (shorter clips are returned as-is; padding happens
    downstream in video_cache_batch_to_latent). Corrupt or unreadable files
    are skipped by retrying neighbouring indices.
    """

    def __init__(self, cache_dir: str, num_frames: int = 11, max_load_retries: int = 5):
        cache_dir = Path(cache_dir)
        self.clips = sorted(cache_dir.glob("*.pt"))
        self.num_frames = num_frames
        # Clamp to >= 1 so __getitem__ always attempts at least one load;
        # previously max_load_retries=0 raised a NameError (unbound `i`)
        # instead of a useful error.
        self.max_load_retries = max(1, max_load_retries)
        if not self.clips:
            raise FileNotFoundError(f"No .pt clips in {cache_dir}. Run precache first.")

    def __len__(self):
        return len(self.clips)

    def _load_one(self, path):
        """Load one clip, cap its temporal length, and add a leading batch dim."""
        try:
            # weights_only=True is the safe path on newer torch versions.
            out = torch.load(path, map_location="cpu", weights_only=True)
        except (TypeError, EOFError):
            # Older torch rejects the weights_only kwarg; fall back to full unpickle.
            out = torch.load(path, map_location="cpu", weights_only=False)
        if out.dim() == 5:
            # Stored with an extra leading batch dim -> drop it.
            out = out[0]
        out = out[:, : self.num_frames]  # (C, T, H, W): truncate T
        return out.unsqueeze(0)  # -> (1, C, T, H, W) for collate_video_clips

    def __getitem__(self, idx):
        """Return clip `idx`; on load failure, try up to max_load_retries
        successive clips (wrapping around) before raising."""
        last_err = None
        i = idx  # keep `i` bound for the error message even if the loop body never assigns
        for attempt in range(self.max_load_retries):
            # On retry, step to the next clip instead of re-reading the same
            # corrupt file.
            i = (idx + attempt) % len(self.clips) if attempt > 0 else idx
            try:
                return self._load_one(self.clips[i])
            except (EOFError, OSError, RuntimeError) as e:
                last_err = e
                continue
        raise RuntimeError(f"Failed to load clip (last idx={i}): {last_err}")

# Default paths (align with diffusion_dino where applicable; Ctrl-World uses dataset_example by default)
DEFAULT_KEYGRIP_ROOT = Path("/data/cameron/keygrip")  # keygrip repo root on the shared data mount
DEFAULT_SCRATCH_ROOT = DEFAULT_KEYGRIP_ROOT / "scratch"  # scratch space under the keygrip root
DEFAULT_DROID_CACHE_DIR = Path("/data/cameron/vidgen/dino_vid_model/vid_cache")  # precached .pt video clips
DEFAULT_VAE_CKPT = Path("/data/cameron/vidgen/unified_video_action/pretrained_models/vae/kl16.ckpt")  # unused here; SVD uses its built-in VAE (see --vae_ckpt help)
# Ctrl-World DROID: preprocessed latents live under dataset_root_path; meta under dataset_meta_info_path
DEFAULT_CTRLWORLD_DATA_ROOT = Path("dataset_example")  # relative path; resolved from the CWD
DEFAULT_CTRLWORLD_META = "dataset_meta_info"

# Ctrl-World needs num_history + num_frames (e.g. 6+5=11). Cache may have fewer; we pad.
CTRLWORLD_NUM_LATENT_FRAMES = 6 + 5  # num_history + num_frames from wm_args


def collate_batch(batch_list):
    """Stack a list of Dataset_mix sample dicts into one batch dict.

    `latent` and `action` tensors are stacked along a new batch axis;
    `text` entries are gathered into a plain list.
    """
    return {
        "latent": torch.stack([sample["latent"] for sample in batch_list], dim=0),
        "action": torch.stack([sample["action"] for sample in batch_list], dim=0),
        "text": [sample["text"] for sample in batch_list],
    }


def collate_video_clips(batch):
    """Merge per-sample (1, C, T, H, W) clip tensors into a (B, C, T, H, W) batch.

    The dataset already emits a singleton batch dim per clip, so a simple
    concatenation along dim 0 yields the batched tensor.
    """
    clips = list(batch)
    return torch.cat(clips, dim=0)


def video_cache_batch_to_latent(model, video, device, args):
    """Convert a raw video batch (B, 3, T, H, W) from the clip cache into the
    latent dict format the Ctrl-World model consumes ({"latent", "action", "text"}).

    Steps: pad (repeat last frame) or trim to CTRLWORLD_NUM_LATENT_FRAMES
    frames, resize each frame to 192x320, encode frame-by-frame with the SVD
    VAE, then tile the single camera view 3x along the latent height to fill
    the 3-view layout (here all three "views" are the same camera). Actions
    are all zeros (image-to-video / unconditional action conditioning) and
    text prompts are empty strings.

    NOTE(review): assumes `video` pixel values are already in the range the
    SVD VAE expects (presumably [-1, 1]) — confirm against the cache writer.
    """
    B, C, T, H, W = video.shape
    need = CTRLWORLD_NUM_LATENT_FRAMES
    if T < need:
        # Too short: repeat the last frame until the required length is reached.
        pad = video[:, :, -1:].expand(-1, -1, need - T, -1, -1)
        video = torch.cat([video, pad], dim=2)
        T = need
    elif T > need:
        # Too long: keep only the first `need` frames.
        video = video[:, :, :need]
        T = need
    h_svd, w_svd = 192, 320  # per-view resolution fed to the SVD VAE
    if H != h_svd or W != w_svd:
        # (B, C, T, H, W) -> (B*T, C, H, W) for interpolate
        video = video.permute(0, 2, 1, 3, 4).reshape(B * T, C, H, W)
        video = torch.nn.functional.interpolate(
            video, size=(h_svd, w_svd), mode="bilinear", align_corners=False
        )
        video = video.reshape(B, T, C, h_svd, w_svd).permute(0, 2, 1, 3, 4)
    vae = model.pipeline.vae
    scaling = vae.config.scaling_factor
    latents_one_view = []
    # Encode one temporal frame at a time to bound VAE memory usage.
    for t in range(T):
        frame = video[:, :, t]  # (B, C, h_svd, w_svd)
        with torch.no_grad():
            # Deterministic encoding: take the mode of the posterior rather
            # than sampling, and apply the VAE's scaling factor.
            enc = vae.encode(frame.to(device)).latent_dist.mode()
            enc = enc * scaling
        latents_one_view.append(enc)
    lat = torch.stack(latents_one_view, dim=1)  # (B, T, C_lat, h_lat, w_lat)
    # Tile the single view 3x along latent height (dim 3) to mimic the
    # 3-camera stacked layout used elsewhere (see height=3*args.height in
    # run_vis_sample); all three slots hold the same camera here.
    lat = lat.repeat(1, 1, 1, 3, 1)
    return {
        "latent": lat,
        # Zero actions -> model learns pure image-to-video prediction.
        "action": torch.zeros(B, need, args.action_dim, device=device, dtype=lat.dtype),
        "text": [""] * B,
    }


def run_vis_sample(model, batch, args, device, global_step):
    """Generate one rollout with zero actions and log it to wandb.

    Takes the first sample of `batch`, runs the Ctrl-World diffusion pipeline
    conditioned on the history latents and zero actions, decodes both the
    predicted and ground-truth latents with the VAE (first view only), and
    logs still frames plus — best-effort — mp4 videos under the `vis/`
    namespace at `global_step`.
    """
    import wandb

    pipeline = model.pipeline
    num_history = args.num_history
    num_frames = args.num_frames

    # Take first sample
    video_gt = batch["latent"][:1].to(device, non_blocking=True)
    texts = [batch["text"][0]]
    # Zero actions for image-to-video
    actions = torch.zeros(
        1, int(num_frames + num_history), args.action_dim,
        device=device, dtype=video_gt.dtype
    )

    # Split latents into conditioning history and ground-truth future; the
    # first future frame is the "current image" the rollout starts from.
    his_latent_gt = video_gt[:, :num_history]
    future_latent_ft = video_gt[:, num_history:]
    current_latent = future_latent_ft[:, 0]

    with torch.no_grad():
        # Encode the (zero) actions + text prompt into the conditioning latent.
        action_latent = model.action_encoder(
            actions, texts, model.tokenizer, model.text_encoder,
            frame_level_cond=args.frame_level_cond,
        )

        # Invoke the pipeline class's __call__ explicitly with `pipeline` as
        # self; output_type="latent" keeps results in latent space so we can
        # control VAE decoding (chunking, first view only) ourselves below.
        _, pred_latents = CtrlWorldDiffusionPipeline.__call__(
            pipeline,
            image=current_latent,
            text=action_latent,
            width=args.width,
            height=int(3 * args.height),  # 3 views stacked along height
            num_frames=num_frames,
            history=his_latent_gt,
            num_inference_steps=args.num_inference_steps,
            decode_chunk_size=args.decode_chunk_size,
            max_guidance_scale=args.guidance_scale,
            fps=args.fps,
            motion_bucket_id=args.motion_bucket_id,
            mask=None,
            output_type="latent",
            return_dict=False,
            frame_level_cond=args.frame_level_cond,
            his_cond_zero=args.his_cond_zero,
        )

    # Un-stack the 3 views: split the latent height (3*h) into 3 batch rows.
    # (1, num_frames, 4, 72, 40) -> (1, num_frames, 4, 24, 40) per view for decode
    pred_latents = einops.rearrange(
        pred_latents, "b f c (m h) (n w) -> (b m n) f c h w", m=3, n=1
    )
    video_gt_cat = torch.cat([his_latent_gt, future_latent_ft], dim=1)
    video_gt_cat = einops.rearrange(
        video_gt_cat, "b f c (m h) (n w) -> (b m n) f c h w", m=3, n=1
    )

    # Decode (first view only for single-view viz: b=0)
    bsz, frame_num = pred_latents.shape[:2]
    decode_chunk = args.decode_chunk_size
    decoded_pred = []
    flat_pred = pred_latents[:1].flatten(0, 1)  # (frame_num, C_lat, h, w)
    for i in range(0, flat_pred.shape[0], decode_chunk):
        # Undo the VAE scaling factor before decoding; chunked to bound memory.
        chunk = flat_pred[i : i + decode_chunk] / pipeline.vae.config.scaling_factor
        decode_kwargs = {"num_frames": chunk.shape[0]}
        decoded_pred.append(pipeline.vae.decode(chunk, **decode_kwargs).sample)
    pred_video = torch.cat(decoded_pred, dim=0).reshape(1, frame_num, *decoded_pred[0].shape[1:])

    # Same chunked decode for the ground truth (history + future, first view).
    decoded_gt = []
    flat_gt = video_gt_cat[:1].flatten(0, 1)
    num_gt_frames = flat_gt.shape[0]
    for i in range(0, flat_gt.shape[0], decode_chunk):
        chunk = flat_gt[i : i + decode_chunk] / pipeline.vae.config.scaling_factor
        decode_kwargs = {"num_frames": chunk.shape[0]}
        decoded_gt.append(pipeline.vae.decode(chunk, **decode_kwargs).sample)
    gt_video = torch.cat(decoded_gt, dim=0).reshape(1, num_gt_frames, *decoded_gt[0].shape[1:])

    # (1, T, 3, H, W) in [-1,1] -> (T, H, W, 3) [0,1] -> uint8
    pred_np = ((pred_video[0].detach().cpu() + 1.0) / 2.0).clamp(0, 1)
    pred_np = pred_np.permute(0, 2, 3, 1).numpy()
    pred_frames = (pred_np * 255).astype(np.uint8)

    gt_np = ((gt_video[0].detach().cpu() + 1.0) / 2.0).clamp(0, 1)
    gt_np = gt_np.permute(0, 2, 3, 1).numpy()
    gt_frames = (gt_np * 255).astype(np.uint8)

    # Cond frame = first frame of gt (current image)
    cond_frame = gt_frames[0]

    wandb.log(
        {
            "vis/cond_frame": wandb.Image(cond_frame),
            "vis/start_frame": wandb.Image(cond_frame),
            "vis/pred_frame0": wandb.Image(pred_frames[0] if len(pred_frames) > 0 else cond_frame),
        },
        step=global_step,
    )
    # Best-effort video logging: write mp4s to temp files, upload, then delete.
    # Any failure (missing torchvision/codec, wandb hiccup) is swallowed so a
    # viz problem never kills the training loop.
    try:
        import torchvision
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f_pred:
            pred_path = f_pred.name
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f_gt:
            gt_path = f_gt.name
        torchvision.io.write_video(pred_path, torch.from_numpy(pred_frames), fps=4)
        torchvision.io.write_video(gt_path, torch.from_numpy(gt_frames), fps=4)
        wandb.log(
            {
                "vis/pred_video": wandb.Video(pred_path, format="mp4"),
                "vis/gt_video": wandb.Video(gt_path, format="mp4"),
            },
            step=global_step,
        )
        Path(pred_path).unlink(missing_ok=True)
        Path(gt_path).unlink(missing_ok=True)
    except Exception:
        pass


def main():
    """Fine-tune Ctrl-World for single image-to-video rollouts on DROID.

    Parses CLI args (mirroring diffusion_dino/train.py), merges them over the
    wm_args() config, builds the model and dataloader (either precached raw
    video clips or Dataset_mix latents), then trains with zero-action
    conditioning. Optionally logs loss/visualizations to wandb and saves a
    checkpoint every --save_every steps; returns after --steps updates.
    """
    parser = argparse.ArgumentParser(description="Fine-tune Ctrl-World for image-to-video on DROID")
    # Same args as diffusion_dino/train.py
    parser.add_argument("--keygrip_root", type=Path, default=DEFAULT_KEYGRIP_ROOT, help="Keygrip repo root")
    parser.add_argument("--vae_ckpt", type=Path, default=DEFAULT_VAE_CKPT, help="VAE ckpt (unused; SVD uses built-in)")
    parser.add_argument("--data_root", type=Path, default=DEFAULT_CTRLWORLD_DATA_ROOT, help="Data root (maps to dataset_root_path)")
    parser.add_argument("--dataset", type=str, default="droid_cache", choices=["self_collected", "droid_cache"], help="Dataset source")
    parser.add_argument("--cache_dir", type=Path, default=DEFAULT_DROID_CACHE_DIR, help="Cache dir (unused for Ctrl-World DROID)")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--workers", type=int, default=4)
    parser.add_argument("--steps", type=int, default=100000)
    parser.add_argument("--vis_every", type=int, default=100, help="Visualization / sampling cadence")
    parser.add_argument("--lr", type=float, default=1e-5)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--log_wandb", action="store_true")
    parser.add_argument("--name", type=str, default="ctrl_world_img2vid")
    parser.add_argument("--save_every", type=int, default=500)
    parser.add_argument("--ckpt_dir", type=Path, default=Path("ckpt_ctrl_world_img2vid"))

    # Ctrl-World specific
    parser.add_argument("--svd_model_path", type=str, default=None)
    parser.add_argument("--clip_model_path", type=str, default=None)
    parser.add_argument("--ckpt_path", type=str, default=None, help="Resume from this checkpoint")
    parser.add_argument("--dataset_root_path", type=str, default=None, help="Override data root for Dataset_mix")
    parser.add_argument("--dataset_meta_info_path", type=str, default=None)
    parser.add_argument("--dataset_names", type=str, default="droid_subset")

    args_parse = parser.parse_args()
    args = wm_args()

    # Merge CLI over config: CLI values take precedence where provided.
    args.dataset_root_path = str(args_parse.dataset_root_path or args_parse.data_root)
    args.dataset_meta_info_path = args_parse.dataset_meta_info_path or DEFAULT_CTRLWORLD_META
    args.dataset_names = args_parse.dataset_names
    args.dataset_cfgs = args.dataset_names
    args.train_batch_size = args_parse.batch_size
    args.learning_rate = args_parse.lr
    if args_parse.svd_model_path is not None:
        args.svd_model_path = args_parse.svd_model_path
    if args_parse.clip_model_path is not None:
        args.clip_model_path = args_parse.clip_model_path

    device = torch.device(args_parse.device if torch.cuda.is_available() else "cpu")
    if torch.cuda.is_available():
        # With CUDA_VISIBLE_DEVICES=X, only that GPU is visible and PyTorch reports it as "cuda:0"
        print(f"Using GPU: {torch.cuda.get_device_name(0)} (visible device 0)")

    # Import wandb once here; the training loop reuses this binding instead of
    # re-importing every step.
    if args_parse.log_wandb:
        import wandb
        if "WANDB_API_KEY" in os.environ:
            wandb.login(key=os.environ["WANDB_API_KEY"])
        wandb.init(project="ctrl_world_img2vid", config=vars(args_parse), name=args_parse.name, mode="online")

    # Model (optionally resumed from a checkpoint).
    model = CrtlWorld(args)
    if args_parse.ckpt_path:
        state = torch.load(args_parse.ckpt_path, map_location="cpu")
        model.load_state_dict(state, strict=True)
        print(f"Loaded checkpoint from {args_parse.ckpt_path}")
    model.to(device)
    model.train()

    optimizer = torch.optim.AdamW(model.parameters(), lr=args_parse.lr)
    args.ckpt_dir = args_parse.ckpt_dir
    args.ckpt_dir.mkdir(parents=True, exist_ok=True)

    # Dataset: either precached raw clips (encoded on the fly) or Dataset_mix latents.
    use_cache = args_parse.dataset == "droid_cache"
    if use_cache:
        cache_path = args_parse.cache_dir
        train_dataset = CachedClipDataset(
            str(cache_path), num_frames=CTRLWORLD_NUM_LATENT_FRAMES
        )
        train_loader = DataLoader(
            train_dataset,
            batch_size=args_parse.batch_size,
            shuffle=True,
            num_workers=args_parse.workers,
            collate_fn=collate_video_clips,
            pin_memory=True,
            persistent_workers=args_parse.workers > 0,
        )
        print(f"Train dataset (droid_cache): {len(train_dataset)} clips from {cache_path}")
    else:
        from dataset.dataset_droid_exp33 import Dataset_mix
        train_dataset = Dataset_mix(args, mode="train")
        train_loader = DataLoader(
            train_dataset,
            batch_size=args_parse.batch_size,
            shuffle=True,
            num_workers=args_parse.workers,
            collate_fn=collate_batch,
            pin_memory=True,
            persistent_workers=args_parse.workers > 0,
        )
        print(f"Train dataset (Dataset_mix): {len(train_dataset)} samples")

    global_step = 0
    pbar = tqdm(total=args_parse.steps, desc="steps", unit="step")

    # Endless epochs; the loop only terminates via the `return` at --steps.
    while True:
        for batch in train_loader:
            if use_cache:
                # Raw clips: encode to Ctrl-World latents (zero actions inside).
                video = batch.to(device, non_blocking=True)
                batch = video_cache_batch_to_latent(model, video, device, args)
            else:
                # Dataset_mix latents: zero out actions for image-to-video training.
                batch["action"] = torch.zeros_like(batch["action"])
                batch["latent"] = batch["latent"].to(device, non_blocking=True)
                batch["action"] = batch["action"].to(device, non_blocking=True)

            loss, _ = model(batch)
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()

            # .item() synchronizes with the GPU — call it once and reuse.
            loss_val = float(loss.item())
            if global_step % 50 == 0:
                print(f"step {global_step} loss {loss_val:.4f}")
            if args_parse.log_wandb:
                wandb.log({"train/loss": loss_val}, step=global_step)

            if args_parse.log_wandb and args_parse.vis_every and (global_step > 0 and global_step % args_parse.vis_every == 0):
                print("Sampling video (zero actions)...")
                model.eval()
                run_vis_sample(model, batch, args, device, global_step)
                model.train()

            global_step += 1
            pbar.update(1)

            if args_parse.save_every and (global_step % args_parse.save_every == 0):
                ckpt_path = args_parse.ckpt_dir / f"ctrl_world_img2vid_step{global_step}.pt"
                torch.save(model.state_dict(), ckpt_path)
                print(f"Saved {ckpt_path}")

            if global_step >= args_parse.steps:
                pbar.close()
                if args_parse.log_wandb:
                    wandb.finish()
                return


# Script entry point.
if __name__ == "__main__":
    main()
