"""Visualize diffusion baseline: line plots for xyz and gripper (GT vs Pred). No heatmaps."""
import argparse
import os
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F

sys.path.insert(0, os.path.dirname(__file__))

from data import RealTrajectoryDataset, N_WINDOW
from model import DiffusionTrajectoryPredictor
import model as model_module

# Image size handed to both the model (as ``target_size``) and the dataset
# (as ``image_size``) in main(); presumably pixels -- confirm against data.py.
IMAGE_SIZE = 448
# Directory containing this script; the default checkpoint is resolved relative to it.
_SCRIPT_DIR = Path(__file__).resolve().parent
# Default value of the --checkpoint CLI option.
DEFAULT_CHECKPOINT_PATH = str(_SCRIPT_DIR.joinpath("checkpoints", "diffusion_baseline", "best.pth"))


def run_diffusion_model(model, rgb, device, current_3d=None, current_gripper=None):
    """Sample from the diffusion model without tracking gradients.

    Args:
        model: callable invoked as
            ``model(rgb, training=False, current_3d=..., current_gripper_state=...)``.
        rgb: image tensor; moved to ``device`` before the call.
        device: torch device used for ``rgb`` and the conditioning tensors.
        current_3d: optional conditioning position, shape (3,) or (1, 3).
            Converted to a float32 tensor; a 1-D value gets a leading batch dim.
        current_gripper: optional conditioning gripper value, scalar or (1,).
            Converted to a float32 tensor; a 0-D value gets a leading batch dim.

    Returns:
        The ``(pred_trajectory_3d, pred_gripper)`` pair returned by the model;
        expected shapes are (1, N_WINDOW, 3) and (1, N_WINDOW).

    Note:
        ``None`` conditioning is forwarded to the model unchanged. Presumably
        the model substitutes zeros in that case -- confirm in model.py.
    """

    def _as_batched(value, unbatched_rank):
        # Leave "not provided" untouched so the model can apply its own default.
        if value is None:
            return None
        tensor = torch.as_tensor(value, device=device, dtype=torch.float32)
        # Only add the batch dim when the value arrived without one.
        return tensor.unsqueeze(0) if tensor.dim() == unbatched_rank else tensor

    rgb_on_device = rgb.to(device)
    cond_3d = _as_batched(current_3d, unbatched_rank=1)
    cond_gripper = _as_batched(current_gripper, unbatched_rank=0)

    with torch.no_grad():
        outputs = model(
            rgb_on_device,
            training=False,
            current_3d=cond_3d,
            current_gripper_state=cond_gripper,
        )
    pred_3d, pred_gripper = outputs
    return pred_3d, pred_gripper


def _plot_gt_vs_pred(ax, ts, gt, pred, gt_label, pred_label, ylabel, title):
    """Draw one GT-vs-prediction line plot on ``ax`` (green squares = GT, red circles = Pred)."""
    ax.plot(ts, gt, 's-', label=gt_label, color='green', markersize=4)
    ax.plot(ts, pred, 'o-', label=pred_label, color='red', markersize=4)
    ax.set_xlabel('Timestep')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(alpha=0.3)


def visualize_gt_vs_pred_lineplots(model, dataset, indices, title, save_path=None):
    """Plot line plots: xyz and gripper (GT vs Pred) for each sample.

    One figure row is drawn per entry of ``indices`` with four columns:
    x, y, z and gripper. The model is conditioned on the first ground-truth
    waypoint and gripper value of each sample.

    Args:
        model: trajectory predictor; switched to eval mode and run on the
            device of its first parameter.
        dataset: indexable; each item is a dict with tensors "rgb",
            "trajectory_3d" (indexed here as (N_WINDOW, 3)) and
            "trajectory_gripper" (plotted against N_WINDOW timesteps).
        indices: dataset indices to plot, one row each. If empty, nothing is
            plotted and the function returns after printing a notice.
        title: figure-level title.
        save_path: if truthy, the figure is written there; otherwise it is
            shown interactively.
    """
    indices = list(indices)
    n_samples = len(indices)
    if n_samples == 0:
        # plt.subplots() rejects zero rows; bail out with a readable message instead.
        print("No samples to visualize; skipping plot.")
        return

    device = next(model.parameters()).device
    model.eval()
    # squeeze=False always yields a 2-D axes array, so a single row needs no special case.
    fig, axes = plt.subplots(n_samples, 4, figsize=(14, 3 * n_samples), squeeze=False)
    ts = np.arange(N_WINDOW)

    try:
        for i, idx in enumerate(indices):
            sample = dataset[idx]
            rgb = sample["rgb"].unsqueeze(0).to(device)
            trajectory_3d = sample["trajectory_3d"].numpy()
            trajectory_gripper = sample["trajectory_gripper"].numpy()
            # Condition the sampler on the first ground-truth state of the window.
            current_3d = trajectory_3d[0]
            current_gripper = float(trajectory_gripper[0])

            pred_3d, pred_gripper = run_diffusion_model(
                model, rgb, device, current_3d=current_3d, current_gripper=current_gripper
            )
            pred_3d_np = pred_3d[0].cpu().numpy()
            pred_gripper_np = pred_gripper[0].cpu().numpy()

            for col, name in enumerate(('x', 'y', 'z')):
                _plot_gt_vs_pred(
                    axes[i, col], ts, trajectory_3d[:, col], pred_3d_np[:, col],
                    'GT', 'Pred', f'{name} (m)', f'Sample {i} {name}',
                )
            _plot_gt_vs_pred(
                axes[i, 3], ts, trajectory_gripper, pred_gripper_np,
                'GT gripper', 'Pred gripper', 'Gripper', f'Sample {i} gripper',
            )

        fig.suptitle(title, fontsize=12, fontweight='bold')
        fig.tight_layout()
        if save_path:
            fig.savefig(save_path, dpi=150, bbox_inches='tight')
            print(f"Saved: {save_path}")
        else:
            plt.show()
    finally:
        # Always release the figure, even when the model or plotting raises.
        plt.close(fig)


def compute_metrics(model, dataset, indices=None):
    """Compute average 3D MSE/RMSE and gripper absolute error over dataset samples.

    The model is run in sampling mode (``training=False``) once per sample,
    conditioned on the first ground-truth waypoint and gripper value.

    Args:
        model: module called as ``model(rgb, training=False, current_3d=...,
            current_gripper_state=...)`` and returning ``(pred_3d, pred_gripper)``
            with a leading batch dimension of 1. It is switched to eval mode and
            inputs are moved to the device of its first parameter.
        dataset: indexable; each item is a dict with unbatched tensors "rgb",
            "trajectory_3d" (steps, 3) and "trajectory_gripper" (steps,).
        indices: iterable of dataset indices to evaluate; ``None`` means all.

    Returns:
        dict with keys:
            "n_samples": number of evaluated samples.
            "avg_3d_mse": mean over samples of the per-sample element-wise MSE.
            "avg_3d_rmse_m": square root of "avg_3d_mse" (not the mean of
                per-sample RMSEs); the "_m" suffix assumes trajectories in metres.
            "avg_gripper_abs_error": mean over samples of the mean absolute
                gripper error.
        All averages are 0 when no samples were evaluated.
    """
    device = next(model.parameters()).device
    model.eval()
    if indices is None:
        indices = range(len(dataset))

    total_3d_err = 0.0
    total_gripper_err = 0.0
    n = 0
    with torch.no_grad():
        for idx in indices:
            sample = dataset[idx]
            rgb = sample["rgb"].unsqueeze(0).to(device)
            trajectory_3d = sample["trajectory_3d"].to(device)
            trajectory_gripper = sample["trajectory_gripper"].to(device)
            # Samples are unbatched, so slicing [0:1] takes the first waypoint while
            # keeping a leading batch dim: (1, 3) and (1,). float32 matches the
            # conditioning dtype used by run_diffusion_model.
            current_3d = trajectory_3d[0:1].to(dtype=torch.float32)
            current_gripper = trajectory_gripper[0:1].to(dtype=torch.float32)

            pred_3d, pred_gripper = model(
                rgb, training=False, current_3d=current_3d, current_gripper_state=current_gripper
            )
            # Add the batch dim to the target so it lines up with the (1, steps, 3) prediction.
            total_3d_err += F.mse_loss(pred_3d, trajectory_3d.unsqueeze(0)).item()
            total_gripper_err += torch.abs(pred_gripper[0] - trajectory_gripper).mean().item()
            n += 1

    return {
        "n_samples": n,
        "avg_3d_mse": total_3d_err / n if n else 0,
        "avg_3d_rmse_m": np.sqrt(total_3d_err / n) if n else 0,
        "avg_gripper_abs_error": total_gripper_err / n if n else 0,
    }


def main():
    """CLI entry point: load checkpoint and dataset, print metrics, optionally save line plots.

    With both --episode and --start_frame given, exactly that sample is
    evaluated; otherwise up to --num_samples evenly spaced samples are used.
    Plotting is skipped when --metrics_only is set.

    Raises:
        ValueError: if the requested episode/start_frame pair is not in the dataset.
    """
    cli = argparse.ArgumentParser(description="Visualize diffusion baseline (line plots xyz + gripper)")
    option_specs = (
        ("--checkpoint", {"type": str, "default": DEFAULT_CHECKPOINT_PATH}),
        ("--dataset_root", {"type": str, "default": "scratch/parsed_school_cap"}),
        ("--episode", {"type": str, "default": None}),
        ("--start_frame", {"type": int, "default": None}),
        ("--max_episodes", {"type": int, "default": 1}),
        ("--num_samples", {"type": int, "default": 6}),
        ("--save_dir", {"type": str, "default": "scratch/diffusion_eval_vis"}),
        ("--metrics_only", {"action": "store_true"}),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    args = cli.parse_args()

    # Device preference order: MPS first, then CUDA, then CPU.
    if torch.backends.mps.is_available():
        device_name = "mps"
    elif torch.cuda.is_available():
        device_name = "cuda"
    else:
        device_name = "cpu"
    device = torch.device(device_name)
    print(f"Using device: {device}")

    print(f"\nLoading model from {args.checkpoint}...")
    model = DiffusionTrajectoryPredictor(target_size=IMAGE_SIZE, n_window=N_WINDOW, freeze_backbone=False)
    # NOTE(review): torch.load may unpickle arbitrary objects; only load trusted checkpoints.
    checkpoint = torch.load(args.checkpoint, map_location=device)
    # When the checkpoint stores both ends of a range, override the matching
    # module-level constants in `model` (presumably normalization bounds -- confirm).
    range_overrides = (
        ("min_height", "max_height", "MIN_HEIGHT", "MAX_HEIGHT"),
        ("min_gripper", "max_gripper", "MIN_GRIPPER", "MAX_GRIPPER"),
    )
    for low_key, high_key, low_attr, high_attr in range_overrides:
        if low_key in checkpoint and high_key in checkpoint:
            setattr(model_module, low_attr, float(checkpoint[low_key]))
            setattr(model_module, high_attr, float(checkpoint[high_key]))
    model.load_state_dict(checkpoint["model_state_dict"], strict=True)
    model = model.to(device)
    model.eval()
    print(f"Loaded epoch {checkpoint.get('epoch', '?')}")

    print(f"\nLoading dataset: {args.dataset_root}")
    dataset = RealTrajectoryDataset(
        dataset_root=args.dataset_root,
        image_size=IMAGE_SIZE,
        episode=args.episode,
        # An explicit episode disables the episode-count cap.
        max_episodes=None if args.episode else args.max_episodes,
    )
    total = len(dataset)
    print(f"Total samples: {total}")

    single_frame = bool(args.episode) and args.start_frame is not None
    if single_frame:
        # dataset.samples holds (episode_dir, frame_idx) pairs; take the first exact match.
        match = next(
            (
                pos
                for pos, (ep_dir, frame_idx) in enumerate(dataset.samples)
                if ep_dir.name == args.episode and frame_idx == args.start_frame
            ),
            None,
        )
        if match is None:
            raise ValueError(f"No sample with episode={args.episode} and start_frame={args.start_frame}")
        indices = [match]
        title = f"Diffusion — {args.episode} frame {args.start_frame}"
    else:
        # Evenly spaced sample indices across the whole dataset.
        num_samples = min(args.num_samples, total)
        indices = np.linspace(0, total - 1, num_samples, dtype=int).tolist()
        title = "Diffusion — GT vs Pred (xyz + gripper)"

    print("\nComputing metrics...")
    metrics = compute_metrics(model, dataset, indices)
    rmse_mm = metrics['avg_3d_rmse_m'] * 1000
    print(f"Samples: {metrics['n_samples']}")
    print(f"Avg 3D RMSE: {rmse_mm:.3f} mm")
    print(f"Avg Gripper Abs Error: {metrics['avg_gripper_abs_error']:.4f}")

    if args.metrics_only:
        return

    out_dir = Path(args.save_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    if single_frame:
        filename = f"diffusion_{args.episode}_frame{args.start_frame}.png"
    else:
        filename = "diffusion_gt_vs_pred.png"
    visualize_gt_vs_pred_lineplots(model, dataset, indices, title, save_path=out_dir / filename)


# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
