"""Visualize ACT baseline: line plots for xyz and gripper (GT vs Pred). No heatmaps."""
import argparse
import os
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F

sys.path.insert(0, os.path.dirname(__file__))

from data import RealTrajectoryDataset, N_WINDOW
from model import ACTTrajectoryPredictor
import model as model_module

# Image size shared by the model (passed as ``target_size``) and the dataset
# (passed as ``image_size``) in main(), so both are built with the same value.
IMAGE_SIZE = 448
# Directory containing this script; used to build the default checkpoint path.
_SCRIPT_DIR = Path(__file__).resolve().parent
# Default value of the --checkpoint command-line option.
DEFAULT_CHECKPOINT_PATH = str(_SCRIPT_DIR / "checkpoints" / "act_baseline" / "best.pth")


def run_act_model(model, rgb, device, current_3d=None, current_gripper=None):
    """Run one gradient-free forward pass of the ACT model.

    Args:
        model: Callable invoked as
            ``model(rgb, current_3d=..., current_gripper_state=...)`` and
            returning a ``(pred_3d, pred_gripper)`` pair.
        rgb: Image tensor; moved to ``device`` before the call.
        device: Device on which the inputs are placed.
        current_3d: Optional current position. Converted to a float32 tensor on
            ``device``; a 1-D value is given a leading batch axis.
        current_gripper: Optional current gripper state. Converted to a float32
            tensor on ``device``; a 0-D value is given a leading batch axis.

    Returns:
        The ``(pred_3d, pred_gripper)`` pair produced by the model. The original
        documentation gives their shapes as (1, N_WINDOW, 3) and (1, N_WINDOW);
        that contract lives in the model, not here.

    Note:
        ``None`` conditioning inputs are forwarded to the model unchanged.
        NOTE(review): presumably the model substitutes zeros in that case --
        confirm against the model implementation.
    """
    position = None
    if current_3d is not None:
        position = torch.as_tensor(current_3d, device=device, dtype=torch.float32)
        # (3,) -> (1, 3): add the batch axis when a bare vector is given.
        position = position.unsqueeze(0) if position.dim() == 1 else position

    gripper = None
    if current_gripper is not None:
        gripper = torch.as_tensor(current_gripper, device=device, dtype=torch.float32)
        # scalar -> (1,): add the batch axis when a bare scalar is given.
        gripper = gripper.unsqueeze(0) if gripper.dim() == 0 else gripper

    image = rgb.to(device)
    with torch.no_grad():
        pred_3d, pred_gripper = model(
            image,
            current_3d=position,
            current_gripper_state=gripper,
        )
    return pred_3d, pred_gripper


def _plot_gt_vs_pred(ax, ts, gt, pred, gt_label, pred_label, ylabel, panel_title):
    """Draw one GT-vs-prediction line plot on ``ax`` with the shared styling."""
    ax.plot(ts, gt, 's-', label=gt_label, color='green', markersize=4)
    ax.plot(ts, pred, 'o-', label=pred_label, color='red', markersize=4)
    ax.set_xlabel('Timestep')
    ax.set_ylabel(ylabel)
    ax.set_title(panel_title)
    ax.legend()
    ax.grid(alpha=0.3)


def visualize_gt_vs_pred_lineplots(model, dataset, indices, title, save_path=None):
    """Plot ground-truth vs predicted x, y, z and gripper curves per sample.

    One figure row is drawn per entry of ``indices`` with four panels
    (x, y, z, gripper). The first timestep of each ground-truth trajectory is
    fed to the model as the current position / gripper state.

    Args:
        model: Model whose parameters determine the inference device; it is
            switched to eval mode.
        dataset: Indexable collection whose items provide "rgb",
            "trajectory_3d" and "trajectory_gripper" tensors (trajectories are
            converted with ``.numpy()``, so they are expected on the CPU).
        indices: Sized iterable of dataset indices to plot; must be non-empty.
        title: Figure super-title.
        save_path: If truthy, the figure is written there; otherwise it is
            shown interactively.

    Raises:
        ValueError: If ``indices`` is empty (matplotlib cannot build a grid
            with zero rows, so this is reported up front with a clear message).
    """
    n_samples = len(indices)
    if n_samples == 0:
        raise ValueError("visualize_gt_vs_pred_lineplots requires at least one sample index")

    device = next(model.parameters()).device
    model.eval()
    ts = np.arange(N_WINDOW)

    # squeeze=False keeps ``axes`` two-dimensional even for a single row.
    fig, axes = plt.subplots(
        n_samples, 4, figsize=(14, 3 * n_samples), squeeze=False
    )  # x, y, z, gripper per sample
    try:
        for i, idx in enumerate(indices):
            sample = dataset[idx]
            rgb = sample["rgb"].unsqueeze(0)  # run_act_model moves it to the device
            trajectory_3d = sample["trajectory_3d"].numpy()  # (N_WINDOW, 3)
            trajectory_gripper = sample["trajectory_gripper"].numpy()  # (N_WINDOW,)

            pred_3d, pred_gripper = run_act_model(
                model,
                rgb,
                device,
                current_3d=trajectory_3d[0],
                current_gripper=float(trajectory_gripper[0]),
            )
            pred_3d_np = pred_3d[0].cpu().numpy()
            pred_gripper_np = pred_gripper[0].cpu().numpy()

            for col, name in enumerate(('x', 'y', 'z')):
                _plot_gt_vs_pred(
                    axes[i, col], ts, trajectory_3d[:, col], pred_3d_np[:, col],
                    'GT', 'Pred', f'{name} (m)', f'Sample {i} {name}',
                )
            _plot_gt_vs_pred(
                axes[i, 3], ts, trajectory_gripper, pred_gripper_np,
                'GT gripper', 'Pred gripper', 'Gripper', f'Sample {i} gripper',
            )

        fig.suptitle(title, fontsize=12, fontweight='bold')
        fig.tight_layout()
        if save_path:
            fig.savefig(save_path, dpi=150, bbox_inches='tight')
            print(f"Saved: {save_path}")
        else:
            plt.show()
    finally:
        # Close this specific figure even if inference or plotting failed.
        plt.close(fig)


def compute_metrics(model, dataset, indices=None):
    """Compute average trajectory and gripper errors over dataset samples.

    For each sample the first timestep of the ground-truth trajectory is fed to
    the model as the current position / gripper state, and the prediction is
    compared with the full ground-truth window.

    Args:
        model: Model called as
            ``model(rgb, current_3d=..., current_gripper_state=...)``; its
            parameters determine the device and it is switched to eval mode.
        dataset: Indexable collection whose items provide "rgb",
            "trajectory_3d" (timesteps x 3) and "trajectory_gripper"
            (timesteps,) tensors.
        indices: Dataset indices to evaluate; all samples when ``None``.

    Returns:
        dict with:
            "n_samples": number of evaluated samples.
            "avg_3d_mse": mean over samples of the element-wise MSE between the
                predicted and ground-truth trajectories.
            "avg_3d_rmse_m": square root of "avg_3d_mse". This is an
                element-wise (per-coordinate) RMS error, not a Euclidean
                distance per timestep.
            "avg_gripper_abs_error": mean over samples of the mean absolute
                gripper error.
        All averages are 0 when no sample was evaluated.
    """
    device = next(model.parameters()).device
    model.eval()
    if indices is None:
        indices = range(len(dataset))

    total_3d_err = 0.0
    total_gripper_err = 0.0
    n = 0
    with torch.no_grad():
        for idx in indices:
            sample = dataset[idx]
            rgb = sample["rgb"].unsqueeze(0).to(device)
            trajectory_3d = sample["trajectory_3d"].to(device)
            trajectory_gripper = sample["trajectory_gripper"].to(device)
            # The per-sample tensors carry no batch axis, so slice (rather than
            # index) the first timestep to keep a leading batch dimension.
            current_3d = trajectory_3d[:1]  # (1, 3)
            current_gripper = trajectory_gripper[:1]  # (1,)

            pred_3d, pred_gripper = model(rgb, current_3d=current_3d, current_gripper_state=current_gripper)
            total_3d_err += F.mse_loss(pred_3d, trajectory_3d.unsqueeze(0)).item()
            total_gripper_err += torch.abs(pred_gripper[0] - trajectory_gripper).mean().item()
            n += 1

    return {
        "n_samples": n,
        "avg_3d_mse": total_3d_err / n if n else 0,
        "avg_3d_rmse_m": np.sqrt(total_3d_err / n) if n else 0,
        "avg_gripper_abs_error": total_gripper_err / n if n else 0,
    }


def main():
    """Command-line entry point: load a checkpoint, report metrics, save plots.

    Steps: parse arguments, pick a device (MPS, then CUDA, then CPU), build the
    model and restore its weights, apply any normalization bounds stored in the
    checkpoint to the ``model`` module globals, load the dataset, choose the
    sample indices (one specific episode frame, or evenly spaced samples),
    print metrics and, unless ``--metrics_only`` is given, save the line plots.

    Raises:
        ValueError: If ``--episode`` and ``--start_frame`` are given but no
            matching sample exists in the dataset.
    """
    cli = argparse.ArgumentParser(description="Visualize ACT baseline (line plots xyz + gripper)")
    cli.add_argument("--checkpoint", type=str, default=DEFAULT_CHECKPOINT_PATH)
    cli.add_argument("--dataset_root", type=str, default="scratch/parsed_school_cap")
    cli.add_argument("--episode", type=str, default=None)
    cli.add_argument("--start_frame", type=int, default=None)
    cli.add_argument("--max_episodes", type=int, default=1)
    cli.add_argument("--num_samples", type=int, default=6)
    cli.add_argument("--save_dir", type=str, default="scratch/act_eval_vis")
    cli.add_argument("--metrics_only", action="store_true")
    args = cli.parse_args()

    # Device preference: Apple MPS first, then CUDA, falling back to CPU.
    if torch.backends.mps.is_available():
        device_name = "mps"
    elif torch.cuda.is_available():
        device_name = "cuda"
    else:
        device_name = "cpu"
    device = torch.device(device_name)
    print(f"Using device: {device}")

    print(f"\nLoading model from {args.checkpoint}...")
    model = ACTTrajectoryPredictor(target_size=IMAGE_SIZE, n_window=N_WINDOW, freeze_backbone=False)
    checkpoint = torch.load(args.checkpoint, map_location=device)
    # A (min, max) pair stored in the checkpoint overrides the matching
    # module-level constants in ``model``; a pair is applied only when both
    # of its keys are present.
    bound_overrides = (
        ("min_height", "max_height", "MIN_HEIGHT", "MAX_HEIGHT"),
        ("min_gripper", "max_gripper", "MIN_GRIPPER", "MAX_GRIPPER"),
    )
    for low_key, high_key, low_attr, high_attr in bound_overrides:
        if low_key in checkpoint and high_key in checkpoint:
            setattr(model_module, low_attr, float(checkpoint[low_key]))
            setattr(model_module, high_attr, float(checkpoint[high_key]))
    model.load_state_dict(checkpoint["model_state_dict"], strict=True)
    model = model.to(device)
    model.eval()
    print(f"Loaded epoch {checkpoint.get('epoch', '?')}")

    print(f"\nLoading dataset: {args.dataset_root}")
    # When a specific episode is requested the episode cap is lifted.
    episode_cap = None if args.episode else args.max_episodes
    dataset = RealTrajectoryDataset(
        dataset_root=args.dataset_root,
        image_size=IMAGE_SIZE,
        episode=args.episode,
        max_episodes=episode_cap,
    )
    print(f"Total samples: {len(dataset)}")

    single_frame = bool(args.episode) and args.start_frame is not None
    if single_frame:
        # Position of the first sample matching the requested episode/frame.
        match = next(
            (
                pos
                for pos, (ep_dir, frame_idx) in enumerate(dataset.samples)
                if ep_dir.name == args.episode and frame_idx == args.start_frame
            ),
            None,
        )
        if match is None:
            raise ValueError(f"No sample with episode={args.episode} and start_frame={args.start_frame}")
        indices = [match]
        title = f"ACT — {args.episode} frame {args.start_frame}"
    else:
        # Evenly spaced samples across the whole dataset.
        count = min(args.num_samples, len(dataset))
        indices = np.linspace(0, len(dataset) - 1, count, dtype=int).tolist()
        title = "ACT — GT vs Pred (xyz + gripper)"

    print("\nComputing metrics...")
    metrics = compute_metrics(model, dataset, indices)
    print(f"Samples: {metrics['n_samples']}")
    print(f"Avg 3D RMSE: {metrics['avg_3d_rmse_m']*1000:.3f} mm")
    print(f"Avg Gripper Abs Error: {metrics['avg_gripper_abs_error']:.4f}")

    if args.metrics_only:
        return

    out_dir = Path(args.save_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    if single_frame:
        filename = f"act_{args.episode}_frame{args.start_frame}.png"
    else:
        filename = "act_gt_vs_pred.png"
    visualize_gt_vs_pred_lineplots(model, dataset, indices, title, save_path=out_dir / filename)


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
