"""
Run CoTracker 3 on RTX GIFs: forward point tracking from the first frame.

Uses a grid of points on frame 0 and tracks them to the end. Grid density and
video FPS are parameters you can inspect.

Requires: pip install opencv-python (and PyTorch with CUDA for best speed).
CoTracker3 is loaded via torch.hub from facebookresearch/co-tracker.
"""

import argparse
import math
from pathlib import Path

import numpy as np
import torch
from PIL import Image


# Default for --data-root: a directory whose subdirectories hold *.gif files (searched by find_gifs).
DEFAULT_DATA_ROOT = Path("/data/RTX/RTX_Video")


def find_gifs(root: Path):
    """List every ``*.gif`` located exactly one directory below ``root``.

    Subdirectories are visited in sorted order and the GIFs inside each one are
    sorted too, so the result is deterministic. Files sitting directly in
    ``root`` are ignored. Returned paths are rooted at the resolved ``root``.
    """
    base = Path(root).resolve()
    return [
        gif
        for child in sorted(base.iterdir())
        if child.is_dir()
        for gif in sorted(child.glob("*.gif"))
    ]


def get_gif_effective_fps(path: Path) -> float | None:
    """Compute a GIF's average frame rate from its per-frame ``duration`` metadata.

    The ``duration`` of every frame (milliseconds, as stored by Pillow in
    ``info``) is summed; a missing or ``None`` duration counts as 0.
    Returns frame_count / total_seconds, or None when the summed duration is
    not positive (i.e. the GIF carries no usable timing).
    """
    with Image.open(path) as gif:
        frame_count = getattr(gif, "n_frames", 1)
        durations = []
        for idx in range(frame_count):
            gif.seek(idx)
            ms = gif.info.get("duration", 0)
            durations.append(0 if ms is None else int(ms))
        total_ms = sum(durations)
    if total_ms <= 0:
        return None
    return frame_count / (total_ms / 1000.0)


def load_gif_frames(path: Path, max_frames: int | None = None):
    """Decode a GIF into a uint8 RGB array of shape (T, H, W, 3).

    Each frame is converted to RGB before stacking. When ``max_frames`` is
    given, only the first ``max_frames`` frames are decoded.
    """
    with Image.open(path) as gif:
        total = getattr(gif, "n_frames", 1)
        limit = total if max_frames is None else min(total, max_frames)
        decoded = []
        for idx in range(limit):
            gif.seek(idx)
            decoded.append(np.array(gif.convert("RGB")))
    return np.stack(decoded, axis=0)


def subsample_to_fps(frames: np.ndarray, source_fps: float, target_fps: float):
    """
    Subsample frames from source_fps down to approximately target_fps.

    frames: (T, H, W, C) array (anything indexable along axis 0 with an index array works).
    The frame nearest to each step of source_fps / target_fps source frames is kept.
    Returns ``frames`` unchanged when either rate is non-positive or when
    target_fps >= source_fps (no upsampling). Otherwise returns a (T', H, W, C)
    array whose source indices are strictly increasing (no duplicated frames).
    """
    if target_fps <= 0 or source_fps <= 0 or target_fps >= source_fps:
        return frames
    step = source_fps / target_fps
    indices = np.round(np.arange(0, len(frames), step)).astype(int)
    # Rounding the last sample can land on len(frames). Clip it back into range,
    # then drop the duplicate that clipping can create
    # (e.g. T=10, step=1.1 -> ..., 9, 10 -> ..., 9, 9).
    # np.unique keeps the order because the indices are already non-decreasing.
    indices = np.unique(np.clip(indices, 0, len(frames) - 1))
    return frames[indices]


# CoTracker3 models already loaded, keyed by device string. Repeated calls (e.g. a
# backward clip and a forward clip) reuse one model instead of re-running torch.hub.load.
_COTRACKER_MODELS: dict[str, torch.nn.Module] = {}


def run_cotracker(video: torch.Tensor, grid_size: int, device: str = "cuda"):
    """
    Run CoTracker3 offline on video. Forward tracking only (grid on first frame).

    The model is loaded via torch.hub on the first call for a given device and cached
    for later calls.

    video: (1, T, C, H, W) float in [0, 255]
    grid_size: the model is queried with a grid_size x grid_size point grid.
    Returns CPU tensors: pred_tracks (1, T, N, 2) and pred_visibility. The visibility
    shape is (1, T, N, 1) per the CoTracker README; some versions return (1, T, N).
    NOTE(review): confirm against the installed hub version. Callers here handle both.
    """
    model = _COTRACKER_MODELS.get(device)
    if model is None:
        model = torch.hub.load(
            "facebookresearch/co-tracker",
            "cotracker3_offline",
            trust_repo=True,
        ).to(device)
        _COTRACKER_MODELS[device] = model
    video = video.to(device)
    with torch.no_grad():
        pred_tracks, pred_visibility = model(video, grid_size=grid_size)
    return pred_tracks.cpu(), pred_visibility.cpu()


def draw_tracks(
    frames: np.ndarray,
    pred_tracks: np.ndarray,
    pred_visibility: np.ndarray,
    point_radius: int = 3,
    line_thickness: int = 2,
):
    """
    Render point tracks on top of a copy of ``frames``.

    frames: (T, H, W, C) images. pred_tracks: (T, N, 2) pixel positions, with column 0
    as x and column 1 as y. pred_visibility: any array holding T*N scores (e.g. (T, N)
    or (T, N, 1)); a score > 0.5 means visible.

    Each visible point is drawn as a filled green dot. A point visible in both frame t-1
    and frame t also gets a line from its previous to its current position, drawn on
    frame t after the dots. Coordinates are rounded to the nearest pixel and clamped
    into the image. Returns the annotated copy; ``frames`` itself is not modified.
    """
    try:
        import cv2
    except ImportError:
        raise ImportError("opencv-python is required for drawing tracks. pip install opencv-python")

    num_frames, height, width, _channels = frames.shape
    canvas = frames.copy()
    _, num_points, _ = pred_tracks.shape

    # Normalise visibility into a boolean (T, N) mask.
    flat = pred_visibility.reshape(-1)
    if flat.size != num_frames * num_points:
        mask = pred_visibility[..., 0].squeeze() > 0.5
        if mask.ndim == 1:
            mask = mask.reshape(num_frames, num_points)
    else:
        mask = (flat > 0.5).reshape(num_frames, num_points)
    assert mask.shape == (num_frames, num_points), (
        f"visibility shape {mask.shape} vs (T={num_frames}, N={num_points})"
    )

    # Nearest-pixel integer coordinates, clamped into the frame (x to width, y to height).
    pixels = np.round(pred_tracks).astype(np.int32)
    pixels[..., 0] = pixels[..., 0].clip(0, width - 1)
    pixels[..., 1] = pixels[..., 1].clip(0, height - 1)

    dot_color, trail_color = (0, 255, 0), (0, 200, 0)
    for t in range(num_frames):
        image, here = canvas[t], pixels[t]
        # Dots first, so the segments below are drawn on top of them.
        for n in np.flatnonzero(mask[t]):
            cv2.circle(image, (int(here[n, 0]), int(here[n, 1])), point_radius, dot_color, -1)
        if t == 0:
            continue
        before = pixels[t - 1]
        for n in np.flatnonzero(mask[t - 1] & mask[t]):
            start = (int(before[n, 0]), int(before[n, 1]))
            end = (int(here[n, 0]), int(here[n, 1]))
            cv2.line(image, start, end, trail_color, line_thickness)
    return canvas


def _track_from_center(video: torch.Tensor, grid_size: int, device: str):
    """
    Track a point grid seeded on the center frame, backward to frame 0 and forward to the end.

    video: (1, T, C, H, W) float tensor, as built in main().
    Returns numpy (pred_tracks, pred_visibility) in chronological order, so row t belongs
    to video frame t: tracks (T, N, 2) and visibility (T, N, 1).
    Clips shorter than 3 frames are tracked forward-only from frame 0, because a center
    split would leave a single-frame clip on one side.
    """
    T = video.shape[1]
    if T < 3:
        print(f"Running CoTracker 3 (grid_size={grid_size}, forward only) ...")
        tracks, vis = run_cotracker(video, grid_size, device=device)
        pred_tracks, pred_visibility = tracks.numpy()[0], vis.numpy()[0]
    else:
        middle = (T - 1) // 2
        # Both clips start on frame `middle`. Point n refers to the same query in both runs,
        # assuming the grid layout depends only on grid_size and frame size.
        clip_back = video[:, : middle + 1].flip(1)  # frames middle, middle-1, ..., 0
        clip_fwd = video[:, middle:]  # frames middle, middle+1, ..., T-1
        print(
            f"Running CoTracker 3 from center frame {middle} (grid_size={grid_size}): "
            f"{middle} back + {T - 1 - middle} forward ..."
        )
        tracks_back, vis_back = run_cotracker(clip_back, grid_size, device=device)
        tracks_fwd, vis_fwd = run_cotracker(clip_fwd, grid_size, device=device)
        # Reverse the backward result into time order (0..middle), then append the forward
        # result minus its first row (frame `middle`, which is already included).
        pred_tracks = np.concatenate(
            [tracks_back.numpy()[0][::-1], tracks_fwd.numpy()[0][1:]], axis=0
        )
        pred_visibility = np.concatenate(
            [vis_back.numpy()[0][::-1], vis_fwd.numpy()[0][1:]], axis=0
        )
        print(f"  Tracks shape: {pred_tracks.shape}")
    if pred_visibility.ndim == 2:
        pred_visibility = pred_visibility[:, :, None]
    return pred_tracks, pred_visibility


def _write_mp4(out_path: Path, frames: np.ndarray, fps: float) -> None:
    """
    Write RGB frames (T, H, W, 3) to an MP4 file at ``fps``.

    First tries imageio's pyav plugin with libx264. If imageio/pyav is missing or
    encoding fails, the reason is printed and OpenCV is used (avc1, then mp4v).
    Raises RuntimeError if no OpenCV writer can be opened either.
    """
    try:
        import imageio.v3 as iio

        iio.imwrite(str(out_path), frames, plugin="pyav", codec="libx264", fps=fps)
        return
    except Exception as exc:  # best-effort path: any import/encoder failure -> OpenCV fallback
        print(f"  imageio/pyav write failed ({exc!r}); falling back to OpenCV")

    import cv2

    H, W = frames.shape[1:3]
    writer = cv2.VideoWriter(str(out_path), cv2.VideoWriter_fourcc(*"avc1"), fps, (W, H))
    if not writer.isOpened():
        writer = cv2.VideoWriter(str(out_path), cv2.VideoWriter_fourcc(*"mp4v"), fps, (W, H))
    if not writer.isOpened():
        raise RuntimeError(f"Could not open an OpenCV video writer for {out_path}")
    try:
        for f in frames:
            writer.write(cv2.cvtColor(f, cv2.COLOR_RGB2BGR))
    finally:
        writer.release()


def main():
    """CLI entry point: track a point grid through one GIF and save an MP4 with the tracks drawn."""
    p = argparse.ArgumentParser(
        description="Run CoTracker 3 on one RTX GIF (grid tracking from the center frame, backward and forward)"
    )
    p.add_argument(
        "--data-root",
        type=str,
        default=str(DEFAULT_DATA_ROOT),
        help="Root with subdirs of .gif files (used if --video not set)",
    )
    p.add_argument(
        "--video",
        type=str,
        default=None,
        help="Path to a single GIF. If not set, use first GIF under --data-root.",
    )
    p.add_argument(
        "--grid-size",
        type=int,
        default=32,
        help="Grid density: grid_size x grid_size points on the center frame (default 32).",
    )
    p.add_argument(
        "--fps",
        type=float,
        default=2.0,
        help="Subsample the GIF to this FPS and write the output video at it; frames are never "
        "upsampled, so a GIF slower than this is written at its own FPS (default 2).",
    )
    p.add_argument(
        "--source-fps",
        type=float,
        default=20.0,
        help="Fallback FPS when GIF has no duration metadata; used for subsampling to --fps (default 20).",
    )
    p.add_argument(
        "--max-frames",
        type=int,
        default=60,
        help="Max frames to process (after subsampling to --fps). When subsampling, loads more from GIF as needed (default 60).",
    )
    p.add_argument(
        "--out",
        type=str,
        default="out/rtx_track",
        help="Output path prefix: writes {out}_tracks.mp4 (default: out/rtx_track).",
    )
    p.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Device for CoTracker (default: cuda).",
    )
    args = p.parse_args()
    # Validate up front: --fps <= 0 would divide by zero below and break the writer.
    if args.fps <= 0:
        p.error("--fps must be positive")
    if args.max_frames is not None and args.max_frames < 1:
        p.error("--max-frames must be at least 1")
    if args.grid_size < 1:
        p.error("--grid-size must be at least 1")

    if args.video:
        gif_path = Path(args.video)
        if not gif_path.is_file():
            raise FileNotFoundError(f"Video not found: {gif_path}")
    else:
        root = Path(args.data_root).resolve()
        if not root.is_dir():
            raise FileNotFoundError(f"Data root not found: {root}")
        gifs = find_gifs(root)
        if not gifs:
            raise FileNotFoundError(f"No *.gif under {root}")
        gif_path = gifs[0]
        print(f"Using first GIF: {gif_path}")

    # Use GIF's effective FPS from duration metadata when available, else --source-fps
    source_fps = get_gif_effective_fps(gif_path)
    if source_fps is None or source_fps <= 0:
        source_fps = args.source_fps
        print(f"  Using --source-fps {source_fps} (no duration metadata in GIF)")
    else:
        print(f"  GIF effective FPS: {source_fps:.1f}")

    subsampling = source_fps > 0 and args.fps < source_fps

    # When subsampling, load enough source frames so we get up to max_frames after subsampling
    load_limit = args.max_frames
    if load_limit is not None and subsampling:
        load_limit = math.ceil(load_limit * source_fps / args.fps)
    print(f"Loading GIF: {gif_path} (max_frames={load_limit})")
    frames = load_gif_frames(gif_path, max_frames=load_limit)
    T, H, W, _ = frames.shape
    print(f"  Loaded {T} frames, {H}x{W}")

    # Subsample to --fps when lower than source so tracking sees more motion per frame
    if subsampling:
        frames = subsample_to_fps(frames, source_fps, args.fps)
        if args.max_frames is not None and len(frames) > args.max_frames:
            frames = frames[: args.max_frames]
        T = len(frames)
        print(f"  Subsampled to {args.fps} fps: {T} frames")

    # Frames are never upsampled. If --fps exceeds the GIF's rate, the frames are still at
    # source_fps, so write them at that rate to keep real-time playback speed.
    out_fps = args.fps if subsampling or source_fps <= 0 else source_fps

    # (T, H, W, C) -> (1, T, C, H, W) float [0, 255]
    video = torch.from_numpy(frames).permute(0, 3, 1, 2).float().unsqueeze(0)
    assert video.shape[1] == T and video.shape[3] == H and video.shape[4] == W

    pred_tracks, pred_visibility = _track_from_center(video, args.grid_size, args.device)

    print("Drawing tracks ...")
    # Tracks are chronological, so they line up with `frames` and consecutive drawn
    # segments connect consecutive points in time.
    out_frames = draw_tracks(frames, pred_tracks, pred_visibility)

    out_path = Path(args.out).with_suffix("") if Path(args.out).suffix else Path(args.out)
    out_path = out_path.parent / (out_path.name + "_tracks.mp4")
    out_path.parent.mkdir(parents=True, exist_ok=True)

    _write_mp4(out_path, out_frames, out_fps)
    print(f"Saved: {out_path} (fps={out_fps:g})")
    print("Done.")


if __name__ == "__main__":
    main()
