"""
Filter a point-track .pt file by motion: keep only the points whose total distance traveled
falls in the top (100 - percentile)%.
Render the filtered tracks over the source video as an MP4 (default: 85th percentile, i.e. the
top 15% of points by motion).
"""

import argparse
import sys
from pathlib import Path

import numpy as np
import torch

# Repository root: the parent of the directory that contains this file.
REPO_ROOT = Path(__file__).resolve().parents[1]
# Put <repo>/scripts first on the import path so the `point_track_rtx` import below
# resolves when this file is run directly as a script.
sys.path.insert(0, str(REPO_ROOT / "scripts"))

from point_track_rtx import draw_tracks, load_gif_frames


def per_point_total_motion(tracks: np.ndarray, visibility: np.ndarray) -> np.ndarray:
    """Return each point's path length, summed over frame-to-frame steps.

    A step from frame t to t+1 contributes only when the point is visible
    (visibility > 0.5) in both frames; otherwise it contributes 0.

    Args:
        tracks: (T, N, 2) per-frame point positions.
        visibility: visibility scores reshapeable to (T, N), e.g. (T, N) or (T, N, 1).

    Returns:
        (N,) array of total distance traveled per point.
    """
    num_frames, num_points, _ = tracks.shape

    # Boolean visibility per (frame, point); a step counts only if both ends are visible.
    visible = np.reshape(visibility > 0.5, (num_frames, num_points))
    step_counts = np.logical_and(visible[:-1], visible[1:])  # (T-1, N)

    # Euclidean length of every frame-to-frame step.
    step_len = np.sqrt(np.sum(np.square(np.diff(tracks, axis=0)), axis=-1))  # (T-1, N)

    return np.where(step_counts, step_len, 0.0).sum(axis=0)


def _write_video(out_path: Path, frames, fps: float) -> None:
    """Write a sequence of RGB uint8 frames to an MP4 at ``out_path``.

    Tries imageio's PyAV plugin with libx264 first. If that fails for any
    reason (import error, codec unavailable, ...), the error is reported on
    stderr and OpenCV is used instead, trying the "avc1" then "mp4v" FourCC.

    Raises:
        RuntimeError: if the OpenCV fallback cannot open a video writer.
    """
    try:
        import imageio.v3 as iio
        iio.imwrite(str(out_path), frames, plugin="pyav", codec="libx264", fps=fps)
    except Exception as err:  # best-effort: any imageio/PyAV failure triggers the cv2 fallback
        print(f"imageio/pyav write failed ({err!r}); falling back to OpenCV", file=sys.stderr)
    else:
        return

    import cv2
    # cv2.VideoWriter takes (width, height) and silently drops frames whose size
    # differs, so size it from the frames actually being written.
    frame_h, frame_w = np.shape(frames[0])[:2]
    size = (int(frame_w), int(frame_h))
    writer = cv2.VideoWriter(str(out_path), cv2.VideoWriter_fourcc(*"avc1"), fps, size)
    if not writer.isOpened():
        writer = cv2.VideoWriter(str(out_path), cv2.VideoWriter_fourcc(*"mp4v"), fps, size)
    if not writer.isOpened():
        raise RuntimeError(f"OpenCV could not open a video writer for {out_path}")
    try:
        for frame in frames:
            writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))  # OpenCV expects BGR
    finally:
        writer.release()


def main():
    """CLI entry point: keep the highest-motion tracks from a .pt file and render them to MP4.

    The .pt payload must provide "tracks", "visibility", "frame_indices",
    "gif_path" and "fps". Points whose total motion (see
    ``per_point_total_motion``) is >= the requested percentile are drawn over
    the GIF frames listed in "frame_indices".

    Raises:
        FileNotFoundError: if the .pt path does not exist.
        SystemExit: on an invalid --percentile or a payload with no points.
    """
    p = argparse.ArgumentParser(description="Filter tracks by motion percentile; visualize top movers")
    p.add_argument("pt", type=str, help="Path to .pt file (e.g. /data/RTX/tracks/viola_99.pt)")
    p.add_argument("--percentile", type=float, default=85.0,
                   help="Keep points with motion >= this percentile (default 85 = top 15%%)")
    p.add_argument("--out", type=str, default=None,
                   help="Output MP4 path (default: same stem as .pt + _motion<percentile>.mp4)")
    args = p.parse_args()
    # np.percentile rejects values outside [0, 100]; fail early with a usage message
    # (this comparison is also False for NaN).
    if not 0.0 <= args.percentile <= 100.0:
        p.error(f"--percentile must be in [0, 100], got {args.percentile}")

    pt_path = Path(args.pt)
    if not pt_path.is_file():
        raise FileNotFoundError(f"Not found: {pt_path}")

    # NOTE(security): weights_only=False unpickles arbitrary objects; only load trusted .pt files.
    # map_location="cpu" so tensors saved from a GPU can be converted with .numpy().
    payload = torch.load(pt_path, map_location="cpu", weights_only=False)
    tracks = payload["tracks"].numpy()   # (T, N, 2)
    visibility = payload["visibility"].numpy()  # (T, N, 1) or (T, N)
    frame_indices = payload["frame_indices"].numpy()
    gif_path = Path(payload["gif_path"])
    fps = float(payload["fps"])

    T, N, _ = tracks.shape
    if N == 0:
        # np.percentile cannot handle an empty array.
        raise SystemExit(f"No tracked points in {pt_path}; nothing to visualize.")

    motion = per_point_total_motion(tracks, visibility)
    thresh = np.percentile(motion, args.percentile)
    # `>=` keeps every point tied at the threshold, so with many equal values
    # (e.g. points never visible in two consecutive frames have motion == 0)
    # more than the nominal top (100 - percentile)% may be kept.
    keep = motion >= thresh
    n_keep = keep.sum()
    print(f"Motion percentile {args.percentile}: threshold = {thresh:.2f}, keeping {n_keep}/{N} points (top {100 - args.percentile:.0f}%)")

    tracks_f = tracks[:, keep]
    vis_f = visibility[:, keep]
    if vis_f.ndim == 2:  # normalize (T, K) -> (T, K, 1) before drawing
        vis_f = vis_f[:, :, None]

    # Load GIF frames up to the highest tracked index (presumably load_gif_frames
    # stops after max_frames frames — TODO confirm), then select the tracked frames.
    max_idx = int(frame_indices.max()) + 1
    all_frames = load_gif_frames(gif_path, max_frames=max_idx)
    frames = all_frames[frame_indices]  # (T, H, W, 3)

    out_frames = draw_tracks(frames, tracks_f, vis_f)
    # `:g` keeps integer percentiles as before ("85") but stops 85 and 85.5 from colliding.
    default_name = f"{pt_path.stem}_motion{args.percentile:g}.mp4"
    out_path = Path(args.out) if args.out else pt_path.with_name(default_name)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    _write_video(out_path, out_frames, fps)
    print(f"Saved: {out_path}")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
