"""
Run center-frame CoTracker on all RTX GIFs; save tracks as .pt tensors.

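GIFs are processed in shuffled order, so if a run is stopped early the finished
subset is still diverse; GIFs whose .pt already exists are skipped, so a rerun
resumes where it left off. Tracks are saved to /data/RTX/tracks/ (by default),
named after the GIF's path relative to RTX_Video with / replaced by _ and the
.gif suffix dropped. Every 10th GIF in the shuffled order that is newly tracked
also gets a track visualization MP4 in /data/RTX/tracks_vis/.

Each .pt contains: tracks, visibility, frame_indices (indices into the original
GIF, one per tracked frame), gif_path, fps, source_fps, grid_size, height,
width, num_frames and num_points -- enough to overlay the tracks on the video.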
"""

import argparse
import math
import random
import sys
from pathlib import Path

import numpy as np
import torch

# Repository root is the parent of this script's directory; its scripts/ folder
# is put first on sys.path so the point_track_rtx import below resolves.
_THIS_SCRIPT = Path(__file__).resolve()
REPO_ROOT = _THIS_SCRIPT.parents[1]
sys.path.insert(0, str(REPO_ROOT / "scripts"))

from point_track_rtx import (
    draw_tracks,
    find_gifs,
    get_gif_effective_fps,
    load_gif_frames,
    run_cotracker,
)

# Default locations, used as the --data-root / --tracks-dir / --tracks-vis-dir defaults.
_RTX_BASE = Path("/data/RTX")
DEFAULT_DATA_ROOT = _RTX_BASE / "RTX_Video"        # where input GIFs are searched for
DEFAULT_TRACKS_DIR = _RTX_BASE / "tracks"          # output .pt track files
DEFAULT_TRACKS_VIS_DIR = _RTX_BASE / "tracks_vis"  # output MP4 track visualizations


def path_to_track_name(gif_path: Path, data_root: Path) -> str:
    """Flatten ``gif_path`` into a single output file-name stem.

    Uses the path relative to ``data_root`` (or just the file name when
    ``gif_path`` does not lie under it), turns every forward slash and
    backslash into ``_``, then drops one trailing ``.gif`` if present.
    Example: ``<root>/a/b/c.gif`` -> ``a_b_c``.
    """
    try:
        relative = str(gif_path.resolve().relative_to(data_root.resolve()))
    except ValueError:
        # Not under data_root: fall back to the bare file name.
        relative = gif_path.name
    flattened = relative.translate(str.maketrans("/\\", "__"))
    suffix = ".gif"
    if flattened.endswith(suffix):
        flattened = flattened[: -len(suffix)]
    return flattened


def _write_vis_mp4(out_frames: np.ndarray, out_path: Path, fps: float, H: int, W: int) -> None:
    """Encode RGB frames as an MP4 file at ``out_path``.

    Tries imageio's PyAV plugin with libx264 first. If that raises for any
    reason, falls back to OpenCV, trying the ``avc1`` fourcc and then ``mp4v``.

    Args:
        out_frames: Frames to write; expected (T, H, W, 3) uint8 RGB.
        out_path: Destination file; missing parent directories are created.
        fps: Output frame rate.
        H: Frame height, passed to the OpenCV writer.
        W: Frame width, passed to the OpenCV writer.

    Raises:
        RuntimeError: If imageio failed and OpenCV could not open a writer
            with either codec (previously the frames were silently dropped).
    """
    out_path.parent.mkdir(parents=True, exist_ok=True)
    try:
        import imageio.v3 as iio
        iio.imwrite(str(out_path), out_frames, plugin="pyav", codec="libx264", fps=fps)
        return
    except Exception as err:  # deliberate best effort: any imageio/PyAV failure -> try OpenCV
        imageio_error = err

    import cv2
    writer = cv2.VideoWriter(str(out_path), cv2.VideoWriter_fourcc(*"avc1"), fps, (W, H))
    if not writer.isOpened():
        writer = cv2.VideoWriter(str(out_path), cv2.VideoWriter_fourcc(*"mp4v"), fps, (W, H))
    if not writer.isOpened():
        raise RuntimeError(
            f"Could not open an MP4 writer for {out_path} "
            f"(imageio/pyav failed with {imageio_error!r})"
        )
    try:
        for frame in out_frames:
            # OpenCV expects BGR channel order.
            writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    finally:
        writer.release()


def _subsample_frames(
    frames: np.ndarray,
    source_fps: float,
    target_fps: float,
    max_frames: int | None,
) -> tuple[np.ndarray, np.ndarray]:
    """Resample ``frames`` in time from ``source_fps`` down to ``target_fps``, then cap the length.

    Subsampling happens only when ``source_fps > 0`` and ``target_fps < source_fps``;
    otherwise frames are kept in order. At most ``max_frames`` frames are kept
    (``None`` means no cap).

    Returns:
        ``(kept_frames, indices)`` where ``indices[i]`` is the position in ``frames``
        of ``kept_frames[i]``.
    """
    n = len(frames)
    if source_fps > 0 and target_fps < source_fps:
        step = source_fps / target_fps
        indices = np.round(np.arange(0, n, step)).astype(int)
        # Rounding can push the last index up to n; keep it in range.
        indices = np.clip(indices, 0, n - 1)[:max_frames]
        return frames[indices], indices
    kept = frames[:max_frames]
    return kept, np.arange(len(kept))


def _track_from_center(
    video: torch.Tensor, grid_size: int, device: str
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Track a point grid from the middle frame backward to frame 0 and forward to the end.

    ``video`` is the (1, T, C, H, W) float tensor built in ``process_one`` with T >= 2.
    Assumes ``run_cotracker`` seeds its grid on the first frame of the clip it is given,
    so both passes share the same query points and can be concatenated point-wise.

    Returns:
        ``(tracks, visibility, order)``. ``order[i]`` is the frame of ``video`` that row
        ``i`` of ``tracks``/``visibility`` belongs to, laid out query-frame first:
        ``[mid, mid-1, ..., 0, mid+1, ..., T-1]``. Every frame appears exactly once.
        ``visibility`` always has a trailing singleton axis.
    """
    num_frames = video.shape[1]
    middle = (num_frames - 1) // 2

    fwd_order = list(range(middle, num_frames))
    tracks_fwd, vis_fwd = run_cotracker(video[:, fwd_order], grid_size, device=device)
    tracks_fwd = tracks_fwd.numpy()[0]
    vis_fwd = vis_fwd.numpy()[0]

    if middle > 0:
        # Walk back down to and including frame 0 (the old range stopped one frame short for odd T).
        back_order = list(range(middle, -1, -1))
        tracks_back, vis_back = run_cotracker(video[:, back_order], grid_size, device=device)
        tracks_back = tracks_back.numpy()[0]
        vis_back = vis_back.numpy()[0]
    else:
        # Query frame is frame 0: nothing to track backward; avoid a 1-frame CoTracker call.
        back_order = [0]
        tracks_back, vis_back = tracks_fwd[:1], vis_fwd[:1]

    # Both passes contain the query frame; keep it once (from the backward pass).
    tracks = np.concatenate([tracks_back, tracks_fwd[1:]], axis=0)
    visibility = np.concatenate([vis_back, vis_fwd[1:]], axis=0)
    if visibility.ndim == 2:
        visibility = visibility[:, :, None]
    order = np.array(back_order + fwd_order[1:], dtype=np.int64)
    return tracks, visibility, order


def process_one(
    gif_path: Path,
    data_root: Path,
    tracks_dir: Path,
    grid_size: int,
    fps: float,
    source_fps: float,
    max_frames: int | None,
    device: str,
    vis_dir: Path | None = None,
    index: int = 0,
) -> tuple[bool, bool]:
    """Track one GIF from its center frame and save the result as a .pt file.

    Frames are subsampled to ``fps`` (using the GIF's own frame rate when
    ``get_gif_effective_fps`` gives a positive value, else ``source_fps``) and capped
    at ``max_frames``. If ``vis_dir`` is set and ``(index + 1) % 10 == 0``, an MP4
    visualization is also written.

    Returns:
        ``(success, skipped)``. ``skipped`` is True when the .pt already existed;
        ``success`` is False when fewer than 2 frames remain after subsampling.
    """
    base_name = path_to_track_name(gif_path, data_root)
    out_path = tracks_dir / f"{base_name}.pt"
    if out_path.exists():
        return (True, True)  # skip already done (resume-friendly)

    source_fps_eff = get_gif_effective_fps(gif_path)
    if source_fps_eff is None or source_fps_eff <= 0:
        source_fps_eff = source_fps
    load_limit = max_frames
    if load_limit is not None and source_fps_eff > 0 and fps < source_fps_eff:
        # Load enough source frames that up to max_frames remain after subsampling.
        load_limit = math.ceil(load_limit * source_fps_eff / fps)

    frames = load_gif_frames(gif_path, max_frames=load_limit)
    _, H, W, _ = frames.shape
    frames, subsampled_indices = _subsample_frames(frames, source_fps_eff, fps, max_frames)
    # Check the length AFTER subsampling/truncation: that is what gets tracked.
    if len(frames) < 2:
        return (False, False)

    video = torch.from_numpy(frames).permute(0, 3, 1, 2).float().unsqueeze(0)
    pred_tracks, pred_visibility, out_indices = _track_from_center(video, grid_size, device)

    # Frame indices into the original GIF (before subsampling), one per track row.
    frame_indices = torch.from_numpy(subsampled_indices[out_indices])

    payload = {
        "tracks": torch.from_numpy(pred_tracks).float(),          # (T, N, 2) xy, rows ordered like frame_indices
        "visibility": torch.from_numpy(pred_visibility).float(),  # (T, N, 1)
        "frame_indices": frame_indices.long(),                    # (T,) indices in original GIF for overlay
        "gif_path": str(gif_path.resolve()),
        "fps": fps,
        "source_fps": float(source_fps_eff),
        "grid_size": grid_size,
        "height": H,
        "width": W,
        "num_frames": len(out_indices),
        "num_points": pred_tracks.shape[1],
    }
    tracks_dir.mkdir(parents=True, exist_ok=True)
    # Write to a temp file and rename, so an interrupted save never leaves a
    # truncated .pt that the exists() check above would skip on resume.
    tmp_path = out_path.with_name(out_path.name + ".tmp")
    torch.save(payload, tmp_path)
    tmp_path.replace(out_path)

    # Save track visualization every 10th video, in chronological frame order.
    if vis_dir is not None and (index + 1) % 10 == 0:
        chrono = np.argsort(out_indices)
        vis_frames = draw_tracks(
            frames[out_indices[chrono]], pred_tracks[chrono], pred_visibility[chrono]
        )
        _write_vis_mp4(vis_frames, vis_dir / f"{base_name}_tracks.mp4", fps, H, W)

    return (True, False)  # success, not skipped


def main():
    """Command-line entry point: track every GIF under --data-root and save .pt files.

    GIFs are shuffled (reproducibly with --seed), optionally truncated with
    --limit, and processed one at a time. An exception on one GIF is reported
    and the loop continues; a summary with tracked/skipped/failed counts is
    printed at the end.
    """
    p = argparse.ArgumentParser(description="Run CoTracker on all RTX GIFs; save .pt tracks")
    p.add_argument("--data-root", type=str, default=str(DEFAULT_DATA_ROOT))
    p.add_argument("--tracks-dir", type=str, default=str(DEFAULT_TRACKS_DIR))
    p.add_argument("--tracks-vis-dir", type=str, default=str(DEFAULT_TRACKS_VIS_DIR),
                   help="Save track MP4 visualization every 10th video here")
    p.add_argument("--grid-size", type=int, default=32)
    p.add_argument("--fps", type=float, default=4.0)
    p.add_argument("--source-fps", type=float, default=20.0)
    p.add_argument("--max-frames", type=int, default=60)
    p.add_argument("--device", type=str, default="cuda")
    p.add_argument("--seed", type=int, default=None)
    p.add_argument("--limit", type=int, default=None, help="Max number of GIFs to process (for testing)")
    args = p.parse_args()

    # --fps is a divisor when subsampling; reject it up front instead of failing on every GIF.
    if args.fps <= 0:
        p.error("--fps must be > 0")
    # Tracking needs at least two frames; a smaller cap would make every GIF fail.
    if args.max_frames is not None and args.max_frames < 2:
        p.error("--max-frames must be >= 2")

    data_root = Path(args.data_root).resolve()
    tracks_dir = Path(args.tracks_dir).resolve()
    tracks_vis_dir = Path(args.tracks_vis_dir).resolve()
    if not data_root.is_dir():
        raise FileNotFoundError(f"Data root not found: {data_root}")

    # Materialize as a list: random.shuffle needs a mutable sequence.
    gifs = list(find_gifs(data_root))
    if not gifs:
        raise FileNotFoundError(f"No *.gif under {data_root}")

    if args.seed is not None:
        random.seed(args.seed)
    random.shuffle(gifs)
    if args.limit is not None:
        gifs = gifs[: args.limit]

    total = len(gifs)
    print(f"Processing {total} GIFs (shuffled), saving to {tracks_dir}, vis every 10th to {tracks_vis_dir}")
    done = 0
    tracked = 0
    failed = 0
    for i, path in enumerate(gifs):
        prefix = f"  [{i+1}/{total}] {path.name}"
        try:
            ok, skipped = process_one(
                path,
                data_root,
                tracks_dir,
                args.grid_size,
                args.fps,
                args.source_fps,
                args.max_frames,
                args.device,
                vis_dir=tracks_vis_dir,
                index=i,
            )
        except Exception as e:
            # One bad GIF must not abort the whole batch; report and move on.
            failed += 1
            print(f"{prefix} ERROR: {type(e).__name__}: {e}")
            continue

        if ok:
            done += 1
            if not skipped:
                tracked += 1
        else:
            failed += 1

        if skipped:
            status = "skip (exists)"
        elif ok:
            status = "saved"
        else:
            status = "fail"
        vis_note = " +vis" if (ok and not skipped and (i + 1) % 10 == 0) else ""
        print(f"{prefix} -> {status}{vis_note}")
    print(f"Done. Tracked & saved {tracked}/{total} (skipped {done - tracked} existing, failed {failed}).")


if __name__ == "__main__":
    main()
