"""Train PARA point tracker on hand wrist keypoints.

Input: single frame + current wrist position
Output: heatmap over future wrist positions (N_WINDOW timesteps ahead)
"""
import os, sys
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from pathlib import Path
import cv2
import wandb

# Make the PARA model package (model.py) importable and point it at the DINOv3
# backbone repo/weights. os.environ.setdefault keeps any value already exported
# in the environment, so these paths are only fallbacks.
# NOTE(review): placed before `import model` below, presumably because model
# reads the DINO_* variables at import or construction time — confirm in model.py.
sys.path.insert(0, "/data/cameron/para_normalized_losses/libero")
os.environ.setdefault("DINO_REPO_DIR", "/data/cameron/keygrip/dinov3")
os.environ.setdefault("DINO_WEIGHTS_PATH", "/data/cameron/keygrip/dinov3/weights/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth")

import model as model_module
from model import TrajectoryHeatmapPredictor, PRED_SIZE

# Side length in pixels of the square network input. Frames are resized to
# IMAGE_SIZE x IMAGE_SIZE, and the loss / metric code treats keypoint (u, v)
# coordinates as pixels in this same space.
IMAGE_SIZE = 448
# Number of future wrist positions predicted per sample.
N_WINDOW = 4
# Frames between consecutive predicted steps (1 = every next frame). The
# original note assumed 4fps extraction, i.e. 0.25s per step — unverified here.
FRAME_STRIDE = 1  # predict 1 frame ahead per step (at 4fps = 0.25s per step)
# ImageNet per-channel RGB mean/std used to normalize inputs (and to undo that
# normalization when rendering visualizations).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


class HandKeypointDataset(Dataset):
    """Dataset of hand video frames with wrist keypoint tracks.

    Each sample is a start frame ``t`` plus the wrist positions at the
    ``n_window`` future frames ``t + k * frame_stride`` for k = 1..n_window.
    A window is kept only if the start frame and every future frame are marked
    visible, and all of them lie inside the same episode.

    If episodes.json exists in the frames directory, only samples windows
    from within annotated episode boundaries (``start``/``end`` are inclusive
    frame indices). Otherwise uses the whole video. Ranges are clamped to the
    frames that have both an image and a keypoint/visibility entry.

    NOTE(review): keypoints are returned exactly as stored in ``wrist_files``
    while the image is resized to IMAGE_SIZE x IMAGE_SIZE. The loss treats them
    as IMAGE_SIZE-space pixels, so the .npy tracks are assumed to already be
    in that space — confirm against whatever produced them.
    """

    def __init__(self, frames_dirs, wrist_files, vis_files, n_window=N_WINDOW, frame_stride=FRAME_STRIDE):
        import json

        if n_window < 1:
            raise ValueError(f"n_window must be >= 1, got {n_window}")
        if frame_stride < 1:
            # A non-positive stride would "predict" the current or past frames
            # (and negative indices would silently wrap around).
            raise ValueError(f"frame_stride must be >= 1, got {frame_stride}")

        self.samples = []
        self.n_window = n_window
        self.frame_stride = frame_stride
        # Offset from the start frame to the last future frame of a window.
        span = n_window * frame_stride

        for frames_dir, wrist_file, vis_file in zip(frames_dirs, wrist_files, vis_files):
            frames_dir = Path(frames_dir)
            wrist_uv = np.load(wrist_file)  # (T, 2)
            visibility = np.load(vis_file)  # (T,)
            frame_files = sorted(frames_dir.glob("*.jpg"))

            # Frames, keypoints and visibility flags are index-aligned; only the
            # common prefix can be used without IndexError.
            n_avail = min(len(frame_files), len(wrist_uv), len(visibility))
            if not (len(frame_files) == len(wrist_uv) == len(visibility)):
                print(f"  WARNING {frames_dir.name}: {len(frame_files)} frames, "
                      f"{len(wrist_uv)} keypoints, {len(visibility)} vis flags; "
                      f"using first {n_avail}")

            # Load episode boundaries if available
            ranges = [(0, n_avail - 1)]
            episodes_file = frames_dir / "episodes.json"
            if episodes_file.exists():
                with open(episodes_file) as f:
                    episodes = json.load(f).get("episodes", [])
                if episodes:
                    ranges = [(ep["start"], ep["end"]) for ep in episodes]
                    print(f"  {frames_dir.name}: {len(episodes)} episodes, "
                          f"{sum(e-s+1 for s,e in ranges)} frames in episodes / {len(frame_files)} total")

            for ep_start, ep_end in ranges:
                lo, hi = max(ep_start, 0), min(ep_end, n_avail - 1)
                # The last future frame t + span must stay within [lo, hi].
                for t in range(lo, hi - span + 1):
                    future_indices = [t + (i + 1) * frame_stride for i in range(n_window)]
                    if not visibility[t] or not all(visibility[fi] for fi in future_indices):
                        continue

                    self.samples.append({
                        'frame_path': str(frame_files[t]),
                        'future_frame_paths': [str(frame_files[fi]) for fi in future_indices],
                        'start_uv': wrist_uv[t],
                        'target_uv': wrist_uv[future_indices],  # (n_window, 2)
                        'video_name': frames_dir.name,
                    })

        print(f"HandKeypointDataset: {len(self.samples)} samples from {len(frames_dirs)} videos")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return one sample as tensors.

        Returns a dict with:
            'rgb': (3, IMAGE_SIZE, IMAGE_SIZE) float32, ImageNet-normalized RGB.
            'start_uv': (2,) float32 wrist position in the start frame.
            'target_uv': (n_window, 2) float32 future wrist positions.

        Raises:
            FileNotFoundError: if the start frame cannot be decoded.
        """
        s = self.samples[idx]

        img = cv2.imread(s['frame_path'])
        if img is None:  # cv2.imread reports failure by returning None, not raising
            raise FileNotFoundError(f"Could not read frame: {s['frame_path']}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (IMAGE_SIZE, IMAGE_SIZE))
        mean = np.asarray(IMAGENET_MEAN, dtype=np.float32)
        std = np.asarray(IMAGENET_STD, dtype=np.float32)
        img = (img.astype(np.float32) / 255.0 - mean) / std  # broadcasts over the channel axis
        img_tensor = torch.from_numpy(img).permute(2, 0, 1)  # (3, H, W)

        start_uv = torch.from_numpy(s['start_uv'].copy()).float()  # (2,)
        target_uv = torch.from_numpy(s['target_uv'].copy()).float()  # (n_window, 2)

        return {
            'rgb': img_tensor,
            'start_uv': start_uv,
            'target_uv': target_uv,
        }


def compute_heatmap_loss(volume_logits, target_uv, pred_size=PRED_SIZE):
    """Spatial cross-entropy between predicted heatmaps and target keypoints.

    Every target (u, v) is quantized to a single cell of a
    ``pred_size x pred_size`` grid. That cell index is the class label for a
    softmax over all grid cells of the matching heatmap.

    Args:
        volume_logits: (B, T, 1, pred_size, pred_size) logits; the size-1 axis
            is a single height bin and is dropped.
        target_uv: (B, T, 2) target pixel coordinates in IMAGE_SIZE space.
        pred_size: side length of the prediction grid.

    Returns:
        Scalar tensor: cross-entropy averaged over all B * T heatmaps.
    """
    batch, horizon = target_uv.shape[:2]
    n_maps = batch * horizon
    scores = volume_logits[:, :, 0].reshape(n_maps, pred_size * pred_size)

    # Pixel -> grid cell for both coordinates at once. `.long()` truncates
    # toward zero; the clamp pins off-image targets to the nearest border cell.
    grid_uv = (target_uv[:, :, :2] / IMAGE_SIZE * pred_size).long().clamp(0, pred_size - 1)
    col, row = grid_uv[..., 0], grid_uv[..., 1]
    labels = (row * pred_size + col).reshape(n_maps)  # row-major flat index

    return F.cross_entropy(scores, labels)


def compute_pixel_error(volume_logits, target_uv, pred_size=PRED_SIZE):
    """Mean Euclidean pixel error between each heatmap's argmax and its target.

    Args:
        volume_logits: (B, T, 1, pred_size, pred_size) logits; only height bin 0 is read.
        target_uv: (B, T, 2) target (u, v) in IMAGE_SIZE pixel coordinates.
        pred_size: side length of the prediction grid.

    Returns:
        Python float: mean error in IMAGE_SIZE pixels over all B * T predictions.
    """
    B, T = target_uv.shape[:2]
    # Metric only; never part of the autograd graph.
    logits = volume_logits[:, :, 0].detach()  # (B, T, H, W)

    argmax = logits.reshape(B, T, pred_size * pred_size).argmax(dim=-1)  # (B, T)
    pred_v = argmax // pred_size
    pred_u = argmax % pred_size

    # compute_heatmap_loss assigns pixel u to cell floor(u / IMAGE_SIZE * pred_size),
    # so cell c covers [c, c + 1) * cell_px. Decode to the cell centre. The old
    # top-left-corner decoding inflated the error even when the right cell was
    # predicted, and disagreed with the vis panels' upsampled-peak error.
    cell_px = IMAGE_SIZE / pred_size
    pred_px = (pred_u.float() + 0.5) * cell_px
    pred_py = (pred_v.float() + 0.5) * cell_px

    err = torch.hypot(pred_px - target_uv[:, :, 0], pred_py - target_uv[:, :, 1])
    return err.mean().item()


def _denormalize_rgb(rgb):
    """Undo ImageNet normalization of a (3, H, W) tensor; return (H, W, 3) uint8 RGB."""
    img = rgb.permute(1, 2, 0).numpy().copy()
    for c in range(3):
        img[:, :, c] = img[:, :, c] * IMAGENET_STD[c] + IMAGENET_MEAN[c]
    return (img * 255).clip(0, 255).astype(np.uint8)


def _load_rgb_frame(path):
    """Read a frame as IMAGE_SIZE x IMAGE_SIZE uint8 RGB, raising if it cannot be decoded."""
    img = cv2.imread(path)
    if img is None:  # cv2.imread signals failure by returning None
        raise FileNotFoundError(f"Could not read frame: {path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return cv2.resize(img, (IMAGE_SIZE, IMAGE_SIZE))


def _predict_volume(model, sample, device):
    """Run ``model`` on one dataset sample; return logits as (1, N_WINDOW, 1, PRED_SIZE, PRED_SIZE)."""
    img_t = sample['rgb'].unsqueeze(0).to(device)
    start_uv = sample['start_uv'].unsqueeze(0).to(device)
    with torch.no_grad():
        vol, _, _, _ = model(img_t, start_uv)
    return vol.reshape(1, N_WINDOW, 1, PRED_SIZE, PRED_SIZE)


def _draw_start_point(img, start_uv):
    """Draw the start keypoint (green dot, white outline) on ``img`` in place; return it as int (u, v)."""
    su, sv = int(start_uv[0].item()), int(start_uv[1].item())
    cv2.circle(img, (su, sv), 8, (0, 255, 0), -1)  # start = green filled
    cv2.circle(img, (su, sv), 8, (255, 255, 255), 2)  # white outline
    return su, sv


def _heatmap_panel(background, logits, gt_uv, t, thumb):
    """Render one timestep's heatmap over ``background`` as a ``thumb`` x ``thumb`` RGB tile.

    Args:
        background: (IMAGE_SIZE, IMAGE_SIZE, 3) uint8 RGB image to overlay on.
        logits: (PRED_SIZE, PRED_SIZE) CPU tensor of logits for this timestep.
        gt_uv: (2,) ground-truth (u, v) in IMAGE_SIZE pixels.
        t: zero-based timestep index (labelled as t+1).
        thumb: output tile side length.

    Draws the GT as a green ring, the heatmap peak as a white dot, and a yellow
    line between them, and labels the tile with their pixel distance.
    """
    prob = torch.softmax(logits.reshape(-1), dim=0).reshape(PRED_SIZE, PRED_SIZE).numpy()
    hm = cv2.resize(prob, (IMAGE_SIZE, IMAGE_SIZE))
    hm_norm = (hm / (hm.max() + 1e-8) * 255).astype(np.uint8)
    hm_color = cv2.cvtColor(cv2.applyColorMap(hm_norm, cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB)
    overlay = (background.astype(np.float32) * 0.4
               + hm_color.astype(np.float32) * 0.6).clip(0, 255).astype(np.uint8)

    gu, gv = int(gt_uv[0].item()), int(gt_uv[1].item())
    cv2.circle(overlay, (gu, gv), 7, (0, 255, 0), 2)

    peak_v, peak_u = np.unravel_index(hm.argmax(), hm.shape)
    peak_u, peak_v = int(peak_u), int(peak_v)
    cv2.circle(overlay, (peak_u, peak_v), 5, (255, 255, 255), -1)
    cv2.line(overlay, (peak_u, peak_v), (gu, gv), (255, 255, 0), 1, cv2.LINE_AA)

    px_err = np.hypot(peak_u - gu, peak_v - gv)
    cv2.putText(overlay, f"t+{t+1} ({px_err:.0f}px)", (5, 20),
                cv2.FONT_HERSHEY_SIMPLEX, 0.45, (255, 255, 255), 1)
    return cv2.resize(overlay, (thumb, thumb))


def make_vis_panel(model, dataset, device, n_vis=4):
    """Create visualization panel: input + heatmap overlays for each timestep.

    One row per sample (the first ``n_vis`` of ``dataset``). The first tile is
    the input frame with the start point and the GT trajectory. It is followed
    by one heatmap-overlay tile per predicted timestep, all drawn on the start
    frame. Leaves ``model`` in eval mode.

    Returns RGB numpy array suitable for wandb.Image.

    Raises:
        ValueError: if there is nothing to visualize (empty dataset or n_vis <= 0).
    """
    model.eval()
    n_vis = min(n_vis, len(dataset))
    if n_vis <= 0:
        raise ValueError("make_vis_panel: nothing to visualize (empty dataset or n_vis <= 0)")

    thumb = 200
    colors_gt = [(255, 80, 80), (80, 80, 255), (80, 255, 80), (255, 200, 50)]
    rows = []
    for i in range(n_vis):
        sample = dataset[i]
        target_uv = sample['target_uv']
        vol = _predict_volume(model, sample, device)
        img_rgb = _denormalize_rgb(sample['rgb'])

        # Input panel: start keypoint + GT trajectory as connected dots.
        vis_img = img_rgb.copy()
        prev = _draw_start_point(vis_img, sample['start_uv'])
        for t in range(N_WINDOW):
            # Cycle colours so N_WINDOW > len(colors_gt) does not raise IndexError.
            color = colors_gt[t % len(colors_gt)]
            gu, gv = int(target_uv[t, 0].item()), int(target_uv[t, 1].item())
            cv2.circle(vis_img, (gu, gv), 6, color, -1)
            cv2.circle(vis_img, (gu, gv), 6, (255, 255, 255), 1)
            cv2.line(vis_img, prev, (gu, gv), color, 1, cv2.LINE_AA)
            prev = (gu, gv)
        cv2.putText(vis_img, "input+GT", (5, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

        panels = [cv2.resize(vis_img, (thumb, thumb))]
        panels += [_heatmap_panel(img_rgb, vol[0, t, 0].cpu(), target_uv[t], t, thumb)
                   for t in range(N_WINDOW)]
        rows.append(np.concatenate(panels, axis=1))

    return np.concatenate(rows, axis=0)


def make_vis_panel_future_frames(model, dataset, device, n_vis=4):
    """Same as make_vis_panel but overlays heatmaps on the FUTURE frames instead of start frame.

    Requires ``dataset`` to expose ``samples[i]['future_frame_paths']`` (as
    HandKeypointDataset does). The first tile is the start frame with the start
    point. Each timestep's heatmap is drawn on the frame it predicts. Leaves
    ``model`` in eval mode.

    Returns RGB numpy array suitable for wandb.Image.

    Raises:
        ValueError: if there is nothing to visualize (empty dataset or n_vis <= 0).
        FileNotFoundError: if a future frame cannot be read.
    """
    model.eval()
    n_vis = min(n_vis, len(dataset))
    if n_vis <= 0:
        raise ValueError("make_vis_panel_future_frames: nothing to visualize (empty dataset or n_vis <= 0)")

    thumb = 200
    rows = []
    for i in range(n_vis):
        sample = dataset[i]
        future_paths = dataset.samples[i]['future_frame_paths']
        target_uv = sample['target_uv']
        vol = _predict_volume(model, sample, device)

        # First panel: start frame with keypoint.
        vis_img = _denormalize_rgb(sample['rgb'])
        _draw_start_point(vis_img, sample['start_uv'])
        cv2.putText(vis_img, "start", (5, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
        panels = [cv2.resize(vis_img, (thumb, thumb))]

        # Future frame panels.
        for t in range(N_WINDOW):
            fut_img = _load_rgb_frame(future_paths[t])
            panels.append(_heatmap_panel(fut_img, vol[0, t, 0].cpu(), target_uv[t], t, thumb))

        rows.append(np.concatenate(panels, axis=1))

    return np.concatenate(rows, axis=0)


def _run_epoch(model, loader, device, optimizer=None):
    """Run one pass over ``loader`` and collect per-batch metrics.

    Trains when ``optimizer`` is given: model.train(), one step per batch, with
    the gradient norm clipped to 1.0. Otherwise evaluates: model.eval(),
    autograd disabled.

    Returns:
        (losses, pixel_errors): lists of per-batch Python floats.
    """
    training = optimizer is not None
    model.train(training)
    losses, errors = [], []
    with torch.set_grad_enabled(training):
        for batch in loader:
            rgb = batch['rgb'].to(device)
            start_uv = batch['start_uv'].to(device)
            target_uv = batch['target_uv'].to(device)

            vol, _, _, _ = model(rgb, start_uv)
            vol = vol.reshape(rgb.shape[0], N_WINDOW, 1, PRED_SIZE, PRED_SIZE)

            loss = compute_heatmap_loss(vol, target_uv)
            px_err = compute_pixel_error(vol, target_uv)

            if training:
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()

            losses.append(loss.item())
            errors.append(px_err)
    return losses, errors


def _mean_or(values, default):
    """Mean of ``values`` as a float, or ``default`` when ``values`` is empty."""
    return float(np.mean(values)) if values else default


def main():
    """Train the wrist heatmap predictor on hand1/hand2 and evaluate on hand3.

    Logs metrics (and visualizations on epoch 1 and every 5th epoch) to wandb.
    Writes ``best.pth`` (lowest test loss) and ``latest.pth`` to
    ``<base>/checkpoints``.
    """
    import argparse
    import time

    parser = argparse.ArgumentParser()
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--epochs", type=int, default=999)
    parser.add_argument("--max_minutes", type=float, default=0)  # <= 0 disables the time limit
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--frame_stride", type=int, default=FRAME_STRIDE)
    parser.add_argument("--gpu", type=int, default=4)
    args = parser.parse_args()

    device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
    base = Path("/data/cameron/scratch_files/hand_vids")

    # Set N_HEIGHT_BINS=1 for 2D-only prediction (must happen before the model is built)
    model_module.N_HEIGHT_BINS = 1

    # Create datasets
    train_ds = HandKeypointDataset(
        frames_dirs=[base / "hand1_frames", base / "hand2_frames"],
        wrist_files=[base / "hand1_wrist_uv.npy", base / "hand2_wrist_uv.npy"],
        vis_files=[base / "hand1_wrist_vis.npy", base / "hand2_wrist_vis.npy"],
        frame_stride=args.frame_stride,
    )
    test_ds = HandKeypointDataset(
        frames_dirs=[base / "hand3_frames"],
        wrist_files=[base / "hand3_wrist_uv.npy"],
        vis_files=[base / "hand3_wrist_vis.npy"],
        frame_stride=args.frame_stride,
    )

    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=True)
    # DataLoader rejects batch_size=0, so keep at least 1 even if test_ds is empty.
    test_loader = DataLoader(test_ds, batch_size=max(1, min(args.batch_size, len(test_ds))),
                             shuffle=False, num_workers=0)

    # Build model
    model = TrajectoryHeatmapPredictor(target_size=IMAGE_SIZE, n_window=N_WINDOW)
    model = model.to(device)
    print(f"Model params: {sum(p.numel() for p in model.parameters()):,}")

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-5)

    ckpt_dir = base / "checkpoints"
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    # The run name and config use the stride actually in effect. Previously both
    # were hard-coded to stride 1 / FRAME_STRIDE, mislabelling --frame_stride runs.
    wandb.init(
        project="para_hand_tracking",
        name=f"hand_wrist_episodes_stride{args.frame_stride}",
        config={
            "lr": args.lr,
            "epochs": args.epochs,
            "batch_size": args.batch_size,
            "n_window": N_WINDOW,
            "frame_stride": args.frame_stride,
            "image_size": IMAGE_SIZE,
            "train_samples": len(train_ds),
            "test_samples": len(test_ds),
        }
    )

    best_test_loss = float('inf')
    # Defined up front so the final checkpoint is valid even if no epoch runs
    # (--epochs 0), and records the epoch actually reached on a time-limit stop.
    last_epoch = -1
    avg_test_loss = float('inf')
    avg_test_err = float('nan')
    start_time = time.time()

    for epoch in range(args.epochs):
        if args.max_minutes > 0 and (time.time() - start_time) / 60 > args.max_minutes:
            print(f"Time limit reached ({args.max_minutes:.0f} min). Stopping.")
            break

        train_losses, train_errors = _run_epoch(model, train_loader, device, optimizer)
        test_losses, test_errors = _run_epoch(model, test_loader, device)
        last_epoch = epoch

        avg_train_loss = _mean_or(train_losses, float('nan'))
        avg_train_err = _mean_or(train_errors, float('nan'))
        # inf keeps an empty test set from ever being saved as "best".
        avg_test_loss = _mean_or(test_losses, float('inf'))
        avg_test_err = _mean_or(test_errors, float('nan'))

        print(f"Epoch {epoch+1}/{args.epochs}: "
              f"train_loss={avg_train_loss:.4f} train_px={avg_train_err:.1f}px | "
              f"test_loss={avg_test_loss:.4f} test_px={avg_test_err:.1f}px")

        # wandb logging
        log_dict = {
            "train/loss": avg_train_loss,
            "train/pixel_error": avg_train_err,
            "test/loss": avg_test_loss,
            "test/pixel_error": avg_test_err,
            "epoch": epoch + 1,
        }

        # Visualize every 5 epochs or first
        if (epoch + 1) % 5 == 0 or epoch == 0:
            for split, label, ds in (("train", "Train", train_ds), ("test", "Test", test_ds)):
                if len(ds) == 0:
                    continue  # the vis helpers cannot render an empty dataset
                vis = make_vis_panel(model, ds, device, n_vis=4)
                fut = make_vis_panel_future_frames(model, ds, device, n_vis=4)
                log_dict[f"vis/{split}"] = wandb.Image(vis, caption=f"{label} epoch {epoch+1}")
                log_dict[f"vis/{split}_future_frames"] = wandb.Image(
                    fut, caption=f"{label} on future frames epoch {epoch+1}")

        wandb.log(log_dict)

        # Save best
        if avg_test_loss < best_test_loss:
            best_test_loss = avg_test_loss
            torch.save({
                'model_state_dict': model.state_dict(),
                'epoch': epoch,
                'test_loss': avg_test_loss,
                'test_px_error': avg_test_err,
            }, ckpt_dir / "best.pth")
            print(f"  -> Saved best (test_loss={avg_test_loss:.4f}, test_px={avg_test_err:.1f}px)")

    # Final save (records the last epoch actually completed, not args.epochs - 1)
    torch.save({
        'model_state_dict': model.state_dict(),
        'epoch': last_epoch,
        'test_loss': avg_test_loss,
        'test_px_error': avg_test_err,
    }, ckpt_dir / "latest.pth")

    print(f"\nDone! Best test loss: {best_test_loss:.4f}")
    print(f"Checkpoints: {ckpt_dir}")
    print(f"wandb run: {wandb.run.get_url()}")
    wandb.finish()


if __name__ == "__main__":
    main()
