"""Dataset for point-track pretraining on RTX robot video tracks."""
import random
import torch
import numpy as np
from torch.utils.data import Dataset
from pathlib import Path
import cv2
from PIL import Image

# Point-track pretraining constants
N_WINDOW_POINT_TRACK = 15
N_QUERY_POINTS = 32
HEATMAP_SIZE = 64
TRACKS_ROOT = Path("/data/RTX/tracks")
IMAGE_SIZE_PT = 448
# Motion filter: only use tracks in top 15% by total distance traveled (85th percentile)
MOTION_PERCENTILE = 85


def _top_motion_track_indices(tracks, visibility, percentile=85):
    """Return indices of tracks in the top (100 - percentile)% by total distance traveled.
    So percentile=85 means top 15% (most motion). tracks (T, N, 2), visibility (T, N)."""
    T, N, _ = tracks.shape
    total_dist = np.zeros(N, dtype=np.float64)
    for n in range(N):
        for t in range(T - 1):
            if visibility[t, n] and visibility[t + 1, n]:
                d = np.linalg.norm(tracks[t + 1, n] - tracks[t, n])
                total_dist[n] += d
    if N == 0 or np.all(total_dist == 0):
        return np.arange(N)
    thresh = np.percentile(total_dist, percentile)
    mask = total_dist >= thresh
    indices = np.where(mask)[0]
    return indices if len(indices) > 0 else np.arange(N)


def load_gif_frame(path, frame_idx):
    """Load a single frame from a GIF as (H, W, 3) uint8."""
    with Image.open(path) as gif:
        gif.seek(int(frame_idx))
        rgb_frame = gif.convert("RGB")
        return np.array(rgb_frame)


def load_gif_frames_at_indices(path, indices):
    """Load GIF frames at given indices. Returns (len(indices), H, W, 3) uint8."""
    with Image.open(path) as gif:
        def _grab(frame_idx):
            # seek() mutates the GIF's current frame, so frames must be
            # decoded one at a time in order of the requested indices.
            gif.seek(int(frame_idx))
            return np.array(gif.convert("RGB"))
        frames = [_grab(i) for i in indices]
    return np.stack(frames)


class RTXPointTrackDataset(Dataset):
    """Point-track pretraining dataset over .pt track files (default /data/RTX/tracks).

    Each sample uses the clip's first frame as the RGB input and asks the model to
    predict, for n_query tracked points, the next n_window positions as flat cell
    indices on a heatmap_size x heatmap_size grid (cross-entropy targets).
    Optionally restricts query sampling to the highest-motion tracks
    (motion_percentile=85 -> top 15% by total distance traveled).
    """

    def __init__(self, tracks_root=None, image_size=IMAGE_SIZE_PT, n_window=N_WINDOW_POINT_TRACK, n_query=N_QUERY_POINTS, heatmap_size=HEATMAP_SIZE, max_samples=None, motion_percentile=MOTION_PERCENTILE, shuffle_seed=42):
        """Index the .pt files under tracks_root.

        Args:
            tracks_root: directory of .pt track files (defaults to TRACKS_ROOT).
            image_size: side length of the square RGB input fed to the model.
            n_window: number of future steps to predict per query point.
            n_query: number of query points sampled per clip.
            heatmap_size: side length of the CE target grid.
            max_samples: optional cap on dataset size (applied after shuffling).
            motion_percentile: keep only tracks at/above this percentile of total
                motion when sampling queries; None disables the filter.
            shuffle_seed: deterministic shuffle seed for the file list; None keeps
                sorted order.

        Raises:
            FileNotFoundError: if tracks_root contains no .pt files.
        """
        self.tracks_root = Path(tracks_root or TRACKS_ROOT)
        self.image_size = image_size
        self.n_window = n_window
        self.n_query = n_query
        self.heatmap_size = heatmap_size
        self.motion_percentile = motion_percentile
        self.pt_paths = sorted(self.tracks_root.glob("*.pt"))
        if not self.pt_paths:
            raise FileNotFoundError(f"No .pt files in {self.tracks_root}")
        if shuffle_seed is not None:
            # Shuffle deterministically so max_samples selects a reproducible
            # random subset rather than the lexicographically-first files.
            rng = random.Random(shuffle_seed)
            rng.shuffle(self.pt_paths)
        if max_samples is not None:
            self.pt_paths = self.pt_paths[:max_samples]
        # Bug fix: f-strings do not apply %-formatting, so "%%" printed a
        # literal double percent sign.
        motion_note = f" (motion filter: top {100 - motion_percentile}%)" if motion_percentile is not None else ""
        shuffle_note = " (paths shuffled)" if shuffle_seed is not None else ""
        print(f"RTXPointTrackDataset: {len(self.pt_paths)} .pt files from {self.tracks_root}{motion_note}{shuffle_note}")

    def __len__(self):
        """Number of .pt clips indexed."""
        return len(self.pt_paths)

    @staticmethod
    def _normalize_visibility(visibility, T, N):
        """Coerce a squeezed visibility array of various saved layouts to (T, N).

        Unrecognized 2-D layouts fall back to all-visible; arrays with ndim > 2
        are returned unchanged. NOTE(review): the returned array may be a
        read-only broadcast view; callers only read from it.
        """
        if visibility.ndim == 1:
            if len(visibility) == T:
                return np.broadcast_to(visibility.reshape(-1, 1), (T, N))
            if len(visibility) == N:
                return np.broadcast_to(visibility.reshape(1, -1), (T, N))
            # Length matches neither axis; attempt a per-frame broadcast.
            return np.broadcast_to(visibility[:, None], (T, N))
        if visibility.ndim == 2:
            if visibility.shape == (T, N):
                return visibility
            if visibility.shape == (N, T):
                return visibility.T
            if visibility.shape in ((T, 1), (1, N)):
                return np.broadcast_to(visibility, (T, N)).copy()
            return np.ones((T, N), dtype=np.float32)
        return visibility

    def __getitem__(self, idx):
        """Build one training sample from the idx-th clip.

        Returns a dict with:
            rgb: (3, image_size, image_size) ImageNet-normalized first frame.
            query_start_2d: (n_query, 2) start positions in image_size space.
            target_heatmap_indices: (n_query, n_window) long CE targets.
            visibility: (n_query, n_window) bool mask for the CE loss.
            frames_vis / tracks_vis: original-resolution frames and tracks for wandb.
            pt_name: stem of the source .pt file.
        """
        pt_path = self.pt_paths[idx]
        payload = torch.load(pt_path, weights_only=False)
        tracks = payload["tracks"].numpy()   # (T, N, 2) in original pixel coords
        T, N = tracks.shape[0], tracks.shape[1]
        # Visibility: saved by point_track_rtx_batch.py from CoTracker (T, N, 1).
        # Used to mask CE loss and for the motion filter.
        if "visibility" in payload:
            visibility = self._normalize_visibility(payload["visibility"].numpy().squeeze(), T, N)
        else:
            visibility = np.ones((T, N), dtype=np.float32)
        frame_indices = payload["frame_indices"].numpy()
        gif_path = Path(payload["gif_path"])
        H_orig = int(payload["height"])
        W_orig = int(payload["width"])
        # Pad short clips by repeating the last step so n_window + 1 steps exist.
        if T < self.n_window + 1:
            pad = self.n_window + 1 - T
            tracks = np.concatenate([tracks, np.tile(tracks[-1:], (pad, 1, 1))], axis=0)
            visibility = np.concatenate([visibility, np.broadcast_to(visibility[-1:], (pad, N))], axis=0)
            # Bug fix: frame_indices must be padded alongside tracks, otherwise
            # frames_vis below comes out shorter than tracks_vis for short clips.
            frame_indices = np.concatenate([frame_indices, np.repeat(frame_indices[-1:], pad)])
            T = tracks.shape[0]
        # Restrict query sampling to top-motion tracks when the filter is on.
        if self.motion_percentile is not None:
            candidate_indices = _top_motion_track_indices(tracks, visibility, self.motion_percentile)
        else:
            candidate_indices = np.arange(N)
        if len(candidate_indices) == 0:
            candidate_indices = np.arange(N)
        n_candidates = len(candidate_indices)
        # Sample without replacement when there are enough candidates; otherwise
        # with replacement, then top up so exactly n_query indices come back.
        query_idx = np.random.choice(candidate_indices, size=min(self.n_query, n_candidates), replace=(n_candidates < self.n_query))
        if len(query_idx) < self.n_query:
            query_idx = np.concatenate([query_idx, np.random.choice(candidate_indices, size=self.n_query - len(query_idx))])
        query_idx = query_idx[: self.n_query]
        # First frame as input.
        rgb_tensor = self._load_input_frame(gif_path, frame_indices[0])
        # Query start positions scaled into model-input (image_size) space.
        scale_input = np.array([self.image_size / W_orig, self.image_size / H_orig])
        query_start_2d = torch.from_numpy(tracks[0, query_idx] * scale_input).float()
        target_heatmap_indices, visibility_mask = self._build_targets(tracks, visibility, query_idx, W_orig, H_orig)
        # For wandb: frames and tracks at original resolution.
        frame_inds_vis = frame_indices[: self.n_window + 1]
        frames_vis = load_gif_frames_at_indices(gif_path, frame_inds_vis)
        tracks_vis = tracks[: self.n_window + 1, query_idx]
        return {
            "rgb": rgb_tensor,
            "query_start_2d": query_start_2d,
            "target_heatmap_indices": target_heatmap_indices,
            "visibility": visibility_mask,
            "frames_vis": torch.from_numpy(frames_vis),
            "tracks_vis": torch.from_numpy(tracks_vis).float(),
            "pt_name": pt_path.stem,
        }

    def _load_input_frame(self, gif_path, frame_idx):
        """Load, resize, and ImageNet-normalize one frame -> (3, image_size, image_size) float."""
        frame = load_gif_frame(gif_path, frame_idx)  # (H_orig, W_orig, 3) uint8
        frame = cv2.resize(frame, (self.image_size, self.image_size), interpolation=cv2.INTER_LINEAR)
        rgb = torch.from_numpy(frame).permute(2, 0, 1).float()
        # Guard keeps already-[0,1] inputs from being divided twice.
        if rgb.max() > 1.0:
            rgb = rgb / 255.0
        mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
        return (rgb - mean) / std

    def _build_targets(self, tracks, visibility, query_idx, W_orig, H_orig):
        """CE targets for future frames 1..n_window on the heatmap_size grid.

        Returns:
            target_heatmap_indices: (n_query, n_window) long, flat cell index y*S + x.
            visibility_mask: (n_query, n_window) bool, True where the point is visible.
        """
        # float64 so rounding matches scalar arithmetic regardless of saved dtype.
        future = tracks[1 : self.n_window + 1, query_idx].astype(np.float64)  # (n_window, n_query, 2)
        x64 = np.clip(np.rint(future[..., 0] * (self.heatmap_size / W_orig)).astype(np.int64), 0, self.heatmap_size - 1)
        y64 = np.clip(np.rint(future[..., 1] * (self.heatmap_size / H_orig)).astype(np.int64), 0, self.heatmap_size - 1)
        flat = y64 * self.heatmap_size + x64                                  # (n_window, n_query)
        vis = visibility[1 : self.n_window + 1][:, query_idx] > 0.5           # (n_window, n_query)
        target_heatmap_indices = torch.from_numpy(np.ascontiguousarray(flat.T)).long()
        visibility_mask = torch.from_numpy(np.ascontiguousarray(vis.T)).bool()
        return target_heatmap_indices, visibility_mask
