"""Visualize motion tracks baseline: line plots for 2d, height, gripper (GT vs Pred). Lift to 3D for metrics."""
import argparse
import os
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F

sys.path.insert(0, os.path.dirname(__file__))

from data import RealTrajectoryDataset, N_WINDOW
from model import MotionTracksTrajectoryPredictor
import model as model_module
from utils import recover_3d_from_direct_keypoint_and_height

# Square input resolution (px): the model is built with target_size=IMAGE_SIZE and
# normalized intrinsics are denormalized by this value in lift_to_3d.
IMAGE_SIZE = 448
# Directory containing this script; used to resolve the default checkpoint path.
_SCRIPT_DIR = Path(__file__).resolve().parent
# Default location of the trained motion-tracks baseline checkpoint.
DEFAULT_CHECKPOINT_PATH = str(_SCRIPT_DIR / "checkpoints" / "motion_tracks_baseline" / "best.pth")


def run_motion_tracks_model(model, rgb, device, current_2d=None, current_height=None, current_gripper=None):
    """Run motion tracks model; returns pred_2d (1, N_WINDOW, 2), pred_height (1, N_WINDOW), pred_gripper (1, N_WINDOW).
    current_2d: (2,) or (1, 2), current_height: scalar or (1,), current_gripper: scalar or (1,). If None, zeros.
    """

    def _batched(value, target_dim):
        """Convert to a float32 tensor on `device`, prepending a batch dim if needed."""
        if value is None:
            return None
        tensor = torch.as_tensor(value, device=device, dtype=torch.float32)
        if tensor.dim() == target_dim - 1:
            tensor = tensor.unsqueeze(0)
        return tensor

    rgb = rgb.to(device)
    batched_2d = _batched(current_2d, 2)        # -> (1, 2)
    batched_height = _batched(current_height, 1)  # -> (1,)
    batched_gripper = _batched(current_gripper, 1)  # -> (1,)
    with torch.no_grad():
        outputs = model(rgb, current_2d=batched_2d, current_height=batched_height, current_gripper_state=batched_gripper)
    pred_2d, pred_height, pred_gripper = outputs
    return pred_2d, pred_height, pred_gripper


def lift_to_3d(pred_2d, pred_height, camera_pose, cam_K_norm, image_size=IMAGE_SIZE):
    """Lift (pred_2d, pred_height) to 3D using recover_3d_from_direct_keypoint_and_height.
    pred_2d: (N_WINDOW, 2), pred_height: (N_WINDOW,), camera_pose: (4,4), cam_K_norm: (3,3) normalized.
    Returns (N_WINDOW, 3); waypoints that fail to lift are filled with NaN.
    """
    # Denormalize intrinsics: scale the first two rows (fx/cx and fy/cy) to pixels.
    cam_K = cam_K_norm.copy()
    cam_K[:2] *= image_size
    lifted = []
    for keypoint, height in zip(pred_2d, pred_height):
        point_3d = recover_3d_from_direct_keypoint_and_height(keypoint, float(height), camera_pose, cam_K)
        # NaN-fill failed waypoints so the output keeps a fixed (N_WINDOW, 3) shape.
        lifted.append(np.full(3, np.nan) if point_3d is None else point_3d)
    return np.array(lifted)


def _plot_gt_vs_pred(ax, ts, gt, pred, gt_label, pred_label, ylabel):
    """Draw one GT-vs-Pred comparison panel with the shared line styling."""
    ax.plot(ts, gt, 's-', label=gt_label, color='green', markersize=4)
    ax.plot(ts, pred, 'o-', label=pred_label, color='red', markersize=4)
    ax.set_xlabel('Timestep')
    ax.set_ylabel(ylabel)
    ax.legend()
    ax.grid(alpha=0.3)


def visualize_gt_vs_pred_lineplots(model, dataset, indices, title, save_path=None):
    """Plot line plots: 2d (x,y), height, gripper (GT vs Pred) for each sample.

    Args:
        model: trained predictor; device is inferred from its parameters.
        dataset: dataset yielding dict samples with "rgb", "trajectory_2d",
            "trajectory_3d", "trajectory_gripper" tensors.
        indices: dataset indices to plot, one figure row per sample.
        title: figure suptitle.
        save_path: if given, save a PNG there; otherwise show interactively.
    """
    device = next(model.parameters()).device
    model.eval()
    n_samples = len(indices)
    fig, axes = plt.subplots(n_samples, 4, figsize=(14, 3 * n_samples))
    if n_samples == 1:
        axes = axes[np.newaxis, :]  # keep 2D (row, col) indexing for a single row
    ts = np.arange(N_WINDOW)

    for i, idx in enumerate(indices):
        sample = dataset[idx]
        rgb = sample["rgb"].unsqueeze(0).to(device)
        trajectory_2d = sample["trajectory_2d"].numpy()
        trajectory_3d = sample["trajectory_3d"].numpy()
        trajectory_height = trajectory_3d[:, 2]
        trajectory_gripper = sample["trajectory_gripper"].numpy()
        # Condition the model on the current (t=0) state of the window.
        current_2d = trajectory_2d[0]
        current_height = float(trajectory_3d[0, 2])
        current_gripper = float(trajectory_gripper[0])

        pred_2d, pred_height, pred_gripper = run_motion_tracks_model(
            model, rgb, device,
            current_2d=current_2d, current_height=current_height, current_gripper=current_gripper,
        )
        pred_2d_np = pred_2d[0].cpu().numpy()
        pred_height_np = pred_height[0].cpu().numpy()
        pred_gripper_np = pred_gripper[0].cpu().numpy()

        # One panel per quantity; heights are plotted in mm for readability.
        _plot_gt_vs_pred(axes[i, 0], ts, trajectory_2d[:, 0], pred_2d_np[:, 0], 'GT x', 'Pred x', '2D x (px)')
        _plot_gt_vs_pred(axes[i, 1], ts, trajectory_2d[:, 1], pred_2d_np[:, 1], 'GT y', 'Pred y', '2D y (px)')
        _plot_gt_vs_pred(axes[i, 2], ts, trajectory_height * 1000, pred_height_np * 1000,
                         'GT height (mm)', 'Pred height (mm)', 'Height (mm)')
        _plot_gt_vs_pred(axes[i, 3], ts, trajectory_gripper, pred_gripper_np, 'GT gripper', 'Pred gripper', 'Gripper')

    plt.suptitle(title, fontsize=12, fontweight='bold')
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved: {save_path}")
    else:
        plt.show()
    plt.close()


def compute_metrics(model, dataset, indices=None):
    """Compute pixel error, height error, gripper error, and 3D RMSE (after lifting).

    Args:
        model: trained predictor; device is inferred from its parameters.
        dataset: dataset yielding dict samples (see visualize_gt_vs_pred_lineplots).
        indices: sample indices to evaluate; defaults to the whole dataset.

    Returns:
        dict of per-sample averages. ``avg_3d_rmse_m`` is averaged only over
        samples where at least one waypoint lifted to 3D successfully — the
        previous code divided by the total sample count, silently diluting
        the metric whenever lifting failed for a sample.
    """
    device = next(model.parameters()).device
    model.eval()
    if indices is None:
        indices = range(len(dataset))

    total_pixel = 0.0
    total_height = 0.0
    total_gripper = 0.0
    total_3d_rmse = 0.0
    n = 0
    n_3d = 0  # samples that contributed a valid 3D RMSE
    with torch.no_grad():
        for idx in indices:
            sample = dataset[idx]
            rgb = sample["rgb"].unsqueeze(0).to(device)
            trajectory_2d = sample["trajectory_2d"].numpy()
            trajectory_3d = sample["trajectory_3d"].numpy()
            trajectory_gripper = sample["trajectory_gripper"].numpy()
            camera_pose = sample["camera_pose"].numpy()
            cam_K_norm = sample["cam_K_norm"].numpy()
            # Condition on the current (t=0) state, batched as (1, ...).
            current_2d_t = torch.from_numpy(trajectory_2d[0:1, :].copy()).float().to(device)
            current_height_t = torch.from_numpy(trajectory_3d[0:1, 2].copy()).float().to(device)
            current_gripper_t = torch.from_numpy(trajectory_gripper[0:1].copy()).float().to(device)

            pred_2d, pred_height, pred_gripper = model(
                rgb,
                current_2d=current_2d_t,
                current_height=current_height_t,
                current_gripper_state=current_gripper_t,
            )
            pred_2d_np = pred_2d[0].cpu().numpy()
            pred_height_np = pred_height[0].cpu().numpy()
            pred_gripper_np = pred_gripper[0].cpu().numpy()

            # Vectorized per-waypoint errors (replaces the per-timestep Python loop).
            total_pixel += float(np.linalg.norm(trajectory_2d[:N_WINDOW] - pred_2d_np[:N_WINDOW], axis=1).mean())
            total_height += float(np.mean(np.abs(trajectory_3d[:, 2] - pred_height_np)))
            total_gripper += float(np.mean(np.abs(trajectory_gripper - pred_gripper_np)))

            pred_3d = lift_to_3d(pred_2d_np, pred_height_np, camera_pose, cam_K_norm)
            valid = ~np.isnan(pred_3d).any(axis=1)
            if valid.any():
                total_3d_rmse += float(np.sqrt(np.mean((pred_3d[valid] - trajectory_3d[valid]) ** 2)))
                n_3d += 1
            n += 1

    return {
        "n_samples": n,
        "avg_pixel_error_px": total_pixel / n if n else 0,
        "avg_height_error_m": total_height / n if n else 0,
        "avg_height_error_mm": (total_height / n * 1000) if n else 0,
        "avg_gripper_abs_error": total_gripper / n if n else 0,
        "avg_3d_rmse_m": total_3d_rmse / n_3d if n_3d else 0,
    }


def main():
    """CLI entry point: load checkpoint and dataset, report metrics, optionally save plots."""
    parser = argparse.ArgumentParser(description="Visualize motion tracks baseline (2d + height + gripper)")
    parser.add_argument("--checkpoint", type=str, default=DEFAULT_CHECKPOINT_PATH)
    parser.add_argument("--dataset_root", type=str, default="scratch/parsed_school_cap")
    parser.add_argument("--episode", type=str, default=None)
    parser.add_argument("--start_frame", type=int, default=None)
    parser.add_argument("--max_episodes", type=int, default=1)
    parser.add_argument("--num_samples", type=int, default=6)
    parser.add_argument("--save_dir", type=str, default="scratch/motion_tracks_eval_vis")
    parser.add_argument("--metrics_only", action="store_true")
    args = parser.parse_args()

    # Prefer MPS (Apple), then CUDA, then CPU.
    if torch.backends.mps.is_available():
        device = torch.device("mps")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    print(f"Using device: {device}")

    print(f"\nLoading model from {args.checkpoint}...")
    model = MotionTracksTrajectoryPredictor(target_size=IMAGE_SIZE, n_window=N_WINDOW, freeze_backbone=False)
    checkpoint = torch.load(args.checkpoint, map_location=device)
    # Restore the normalization ranges this checkpoint was trained with, when present.
    for low_key, high_key, low_attr, high_attr in (
        ("min_height", "max_height", "MIN_HEIGHT", "MAX_HEIGHT"),
        ("min_gripper", "max_gripper", "MIN_GRIPPER", "MAX_GRIPPER"),
    ):
        if low_key in checkpoint and high_key in checkpoint:
            setattr(model_module, low_attr, float(checkpoint[low_key]))
            setattr(model_module, high_attr, float(checkpoint[high_key]))
    model.load_state_dict(checkpoint["model_state_dict"], strict=True)
    model = model.to(device)
    model.eval()
    print(f"Loaded epoch {checkpoint.get('epoch', '?')}")

    print(f"\nLoading dataset: {args.dataset_root}")
    dataset = RealTrajectoryDataset(
        dataset_root=args.dataset_root,
        image_size=IMAGE_SIZE,
        episode=args.episode,
        max_episodes=None if args.episode else args.max_episodes,
    )
    print(f"Total samples: {len(dataset)}")

    # A specific (episode, start_frame) pair selects exactly one sample.
    single_frame = args.episode and args.start_frame is not None
    if single_frame:
        idx = next(
            (i for i, (ep_dir, frame_idx) in enumerate(dataset.samples)
             if ep_dir.name == args.episode and frame_idx == args.start_frame),
            None,
        )
        if idx is None:
            raise ValueError(f"No sample with episode={args.episode} and start_frame={args.start_frame}")
        indices = [idx]
        title = f"Motion tracks — {args.episode} frame {args.start_frame}"
    else:
        num_samples = min(args.num_samples, len(dataset))
        indices = np.linspace(0, len(dataset) - 1, num_samples, dtype=int).tolist()
        title = "Motion tracks — GT vs Pred (2d + height + gripper)"

    print("\nComputing metrics...")
    metrics = compute_metrics(model, dataset, indices)
    print(f"Samples: {metrics['n_samples']}")
    print(f"Avg Pixel Error: {metrics['avg_pixel_error_px']:.2f} px")
    print(f"Avg Height Error: {metrics['avg_height_error_mm']:.2f} mm")
    print(f"Avg Gripper Abs Error: {metrics['avg_gripper_abs_error']:.4f}")
    print(f"Avg 3D RMSE (lifted): {metrics['avg_3d_rmse_m']*1000:.2f} mm")

    if not args.metrics_only:
        save_dir = Path(args.save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)
        if single_frame:
            save_path = save_dir / f"motion_tracks_{args.episode}_frame{args.start_frame}.png"
        else:
            save_path = save_dir / "motion_tracks_gt_vs_pred.png"
        visualize_gt_vs_pred_lineplots(model, dataset, indices, title, save_path=save_path)


if __name__ == "__main__":
    main()
