"""Extract hand wrist keypoints from video frames using mediapipe solutions API."""
import mediapipe as mp
import cv2
import numpy as np
from pathlib import Path

# Shorthand for the MediaPipe Hands solution module; its ``Hands`` class is
# the detector instantiated in extract_wrist_tracks().
mp_hands = mp.solutions.hands


def extract_wrist_tracks(frames_dir, image_size=448):
    """Extract wrist (landmark 0) 2D pixel coordinates from all frames.

    Every ``*.jpg`` file directly inside ``frames_dir`` is processed in
    sorted filename order with MediaPipe Hands in static-image mode (at most
    one hand per frame, detection confidence threshold 0.3). Frames without
    a detection, and frames that cannot be read from disk, are marked as not
    visible and their coordinates are filled by linear interpolation over
    the frame index (edge values are held outside the detected range).

    Args:
        frames_dir: Directory containing the ``*.jpg`` frames (str or Path).
        image_size: Factor applied to both normalized landmark coordinates
            to obtain pixels. NOTE(review): this assumes square frames of
            ``image_size`` pixels per side -- confirm against the step that
            produced the frames.

    Returns:
        A tuple ``(wrist_coords, visibility)`` where ``wrist_coords`` is a
        float32 array of shape ``(num_frames, 2)`` holding ``(x, y)`` and
        ``visibility`` is a bool array of shape ``(num_frames,)`` that is
        True only where the wrist was actually detected (not interpolated).
        If no frame has a detection, all coordinates stay at ``(0, 0)``.
        An empty directory yields arrays of shape ``(0, 2)`` and ``(0,)``.
    """
    frames_dir = Path(frames_dir)
    frame_files = sorted(frames_dir.glob("*.jpg"))
    n_frames = len(frame_files)

    # Nothing to process: return correctly shaped empty arrays without
    # paying for detector construction.
    if not frame_files:
        return np.zeros((0, 2), dtype=np.float32), np.zeros(0, dtype=bool)

    # Undetected frames keep the (0, 0) placeholder until interpolation.
    wrist_coords = np.zeros((n_frames, 2), dtype=np.float32)
    visibility = np.zeros(n_frames, dtype=bool)

    with mp_hands.Hands(
        static_image_mode=True,
        max_num_hands=1,
        min_detection_confidence=0.3
    ) as hands:
        for i, fpath in enumerate(frame_files):
            img = cv2.imread(str(fpath))
            if img is None:
                # cv2.imread signals an unreadable/corrupt file by returning
                # None; keep the frame slot so indices stay aligned with the
                # video and let interpolation fill it.
                print(f"  WARNING: could not read {fpath}; treating as no detection")
            else:
                rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                result = hands.process(rgb)
                if result.multi_hand_landmarks:
                    wrist = result.multi_hand_landmarks[0].landmark[0]  # landmark 0 = wrist
                    wrist_coords[i] = (wrist.x * image_size, wrist.y * image_size)
                    visibility[i] = True

            if (i + 1) % 50 == 0:
                print(f"  {i+1}/{len(frame_files)} frames processed")

    _fill_missing_detections(wrist_coords, visibility)
    return wrist_coords, visibility


def _fill_missing_detections(wrist_coords, visibility):
    """Fill rows of undetected frames in place by interpolating over frame index."""
    valid_idx = np.flatnonzero(visibility)
    # No anchor points at all, or nothing missing: leave the array untouched.
    if valid_idx.size == 0 or valid_idx.size == len(visibility):
        return

    missing_idx = np.flatnonzero(~visibility)
    for dim in range(wrist_coords.shape[1]):
        # np.interp holds the first/last detected value outside the detected
        # range, and with a single detection it returns that value everywhere.
        wrist_coords[missing_idx, dim] = np.interp(
            missing_idx,
            valid_idx,
            wrist_coords[valid_idx, dim]
        )


def main():
    """Extract and save wrist tracks for each hard-coded frame directory.

    For every ``handN_frames`` directory under the base path, runs
    ``extract_wrist_tracks`` and writes ``handN_wrist_uv.npy`` (coordinates)
    and ``handN_wrist_vis.npy`` (detection mask) into the base directory,
    then prints a short detection summary.
    """
    base = Path("/data/cameron/scratch_files/hand_vids")

    for name in ["hand1_frames", "hand2_frames", "hand3_frames"]:
        frames_dir = base / name
        print(f"\nProcessing {name}...")
        coords, vis = extract_wrist_tracks(frames_dir)

        # Output files share the directory name minus its "_frames" part.
        stem = name.replace('_frames', '')
        out_coords = base / f"{stem}_wrist_uv.npy"
        out_vis = base / f"{stem}_wrist_vis.npy"
        np.save(out_coords, coords)
        np.save(out_vis, vis)

        n_frames = len(coords)
        n_detected = int(vis.sum())
        # A missing or empty directory yields zero frames; avoid dividing by
        # zero (which would print "nan%" and emit a NumPy warning).
        pct_detected = 100 * n_detected / n_frames if n_frames else 0.0
        print(f"  {name}: {n_frames} frames, {n_detected}/{n_frames} detected ({pct_detected:.0f}%)")
        if n_detected > 0:
            # Report the range over truly detected frames only, not interpolated ones.
            print(f"  UV range: x=[{coords[vis,0].min():.1f}, {coords[vis,0].max():.1f}], y=[{coords[vis,1].min():.1f}, {coords[vis,1].max():.1f}]")
        print(f"  Saved: {out_coords}")


# Run the batch extraction only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
