"""Train the diffusion-policy baseline on real data.

Model: diffusion over the (trajectory_3d, gripper) state, conditioned on the RGB image
(N ~ 10 steps; positions in the global robot frame).
Same output interface: trajectory_3d (N_WINDOW, 3) and gripper (N_WINDOW,).
Training progress is visualized with line plots (no heatmaps).
"""
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, ConcatDataset
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from pathlib import Path
import argparse
import os
import sys

sys.path.insert(0, os.path.dirname(__file__))

from data import RealTrajectoryDataset, N_WINDOW
from model import DiffusionTrajectoryPredictor
import model as model_module

# Image size passed to the dataset (image_size=) and to the model (target_size=).
IMAGE_SIZE = 448
# Defaults for the --batch_size, --lr and --epochs command-line options.
BATCH_SIZE = 8
LEARNING_RATE = 1e-4
NUM_EPOCHS = 1000


def train_epoch(model, dataloader, optimizer, device):
    """Run one optimization pass over ``dataloader`` and return the mean batch loss.

    The model is called in training mode and is expected to return a scalar
    loss tensor. The first timestep of each ground-truth window is passed as
    the "current" 3D position and gripper state.

    Returns:
        The arithmetic mean of the per-batch losses, or 0.0 if the loader
        yields no batches.
    """
    model.train()
    batch_losses = []

    for batch in dataloader:
        images = batch['rgb'].to(device)
        gt_xyz = batch['trajectory_3d'].to(device)
        gt_grip = batch['trajectory_gripper'].to(device)

        loss = model(
            images,
            gt_trajectory_3d=gt_xyz,
            gt_gripper=gt_grip,
            training=True,
            # Window start doubles as the robot's current state.
            current_3d=gt_xyz[:, 0, :],
            current_gripper_state=gt_grip[:, 0],
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())

    return sum(batch_losses) / max(1, len(batch_losses))


def validate(model, dataloader, device):
    """Sample predictions for every validation batch and score them against ground truth.

    The model is called with ``training=False`` and must return
    ``(pred_3d, pred_gripper)``. Those are compared to the batch's
    ``trajectory_3d`` / ``trajectory_gripper`` with mean-squared error. As in
    training, the first timestep of each ground-truth window is passed as the
    current state.

    Args:
        model: the trajectory predictor (switched to eval mode here).
        dataloader: iterable of dicts with 'rgb', 'trajectory_3d' and
            'trajectory_gripper' tensors.
        device: device the batch tensors are moved to.

    Returns:
        ``(val_loss, mse_3d, mse_gripper, sample_data)``.
        ``mse_3d`` and ``mse_gripper`` are sample-weighted means over the whole
        loader, and ``val_loss`` is their sum. All three are 0.0 if the loader
        yields no samples.
        ``sample_data`` describes the first sample of the first batch, or is
        None for an empty loader:
        - 'rgb' and 'trajectory_3d' are tensors still on ``device``.
        - 'pred_trajectory_3d', 'pred_gripper' and 'target_gripper' are
          numpy arrays.
    """
    model.eval()
    sum_mse_3d = 0.0
    sum_mse_gripper = 0.0
    n_samples = 0
    sample_data = None

    with torch.no_grad():
        for batch in dataloader:
            rgb = batch['rgb'].to(device)
            trajectory_3d = batch['trajectory_3d'].to(device)
            trajectory_gripper = batch['trajectory_gripper'].to(device)
            batch_size = rgb.shape[0]

            pred_3d, pred_gripper = model(
                rgb,
                training=False,
                current_3d=trajectory_3d[:, 0, :],
                current_gripper_state=trajectory_gripper[:, 0],
            )

            # F.mse_loss averages within the batch; re-weight by batch size so
            # a short final batch does not count as much as a full one.
            sum_mse_3d += F.mse_loss(pred_3d, trajectory_3d).item() * batch_size
            sum_mse_gripper += F.mse_loss(pred_gripper, trajectory_gripper).item() * batch_size
            n_samples += batch_size

            # Keep only the first sample of the first batch for visualization.
            if sample_data is None:
                sample_data = {
                    'rgb': rgb[0],
                    'trajectory_3d': trajectory_3d[0],
                    'pred_trajectory_3d': pred_3d[0].cpu().numpy(),
                    'pred_gripper': pred_gripper[0].cpu().numpy(),
                    'target_gripper': trajectory_gripper[0].cpu().numpy(),
                }

    n = max(1, n_samples)  # avoid division by zero for an empty loader
    val_loss = (sum_mse_3d + sum_mse_gripper) / n
    return val_loss, sum_mse_3d / n, sum_mse_gripper / n, sample_data


def _build_datasets(dataset_roots, val_split):
    """Load every dataset root and split the combined samples into seeded train/val subsets.

    Returns:
        ``(full_dataset, train_dataset, val_dataset)``.

    Raises:
        ValueError: if the roots contain no samples.
    """
    datasets = [
        RealTrajectoryDataset(dataset_root=root, image_size=IMAGE_SIZE)
        for root in dataset_roots
    ]
    full_dataset = datasets[0] if len(datasets) == 1 else ConcatDataset(datasets)

    n_total = len(full_dataset)
    print(f"Total samples: {n_total}")
    if n_total == 0:
        raise ValueError(f"No samples found under dataset root(s): {dataset_roots}")

    n_val = int(n_total * val_split)
    # int() truncates, so a small dataset with a non-zero val_split could
    # silently get an empty validation set; keep at least one sample then.
    if val_split > 0 and n_val == 0 and n_total > 1:
        n_val = 1
    n_train = n_total - n_val
    train_dataset, val_dataset = torch.utils.data.random_split(
        full_dataset, [n_train, n_val],
        generator=torch.Generator().manual_seed(42),
    )
    return full_dataset, train_dataset, val_dataset


def _value_ranges(dataset):
    """Return ``(min_height, max_height, min_gripper, max_gripper)`` over all samples.

    Height is column 2 of each sample's 'trajectory_3d'. Every sample is
    fetched via ``dataset[i]``, which presumably includes image loading, so
    this is a full pass over the data. Only per-sample extrema are kept.
    """
    height_mins, height_maxs, grip_mins, grip_maxs = [], [], [], []
    for i in range(len(dataset)):
        sample = dataset[i]
        heights = sample['trajectory_3d'].numpy()[:, 2]
        grippers = sample['trajectory_gripper'].numpy()
        height_mins.append(heights.min())
        height_maxs.append(heights.max())
        grip_mins.append(grippers.min())
        grip_maxs.append(grippers.max())
    return (
        float(np.min(height_mins)), float(np.max(height_maxs)),
        float(np.min(grip_mins)), float(np.max(grip_maxs)),
    )


def _draw_progress(axes, train_losses, val_losses, sample_data):
    """Redraw loss curves and, if ``sample_data`` is given, one validation prediction vs. ground truth."""
    ax_loss, ax_xyz, ax_gripper, ax_3d = axes

    ax_loss.clear()
    ax_loss.plot(np.arange(len(train_losses)), train_losses, 'o-', label='Train (diff)', color='blue')
    ax_loss.plot(np.arange(len(val_losses)), val_losses, 's-', label='Val (MSE)', color='green')
    ax_loss.set_xlabel('Epoch')
    ax_loss.set_ylabel('Loss')
    ax_loss.legend()
    ax_loss.grid(alpha=0.3)

    if sample_data is None:
        return

    gt_3d = sample_data['trajectory_3d'].cpu().numpy()
    pred_3d = sample_data['pred_trajectory_3d']
    ts = np.arange(gt_3d.shape[0])

    ax_xyz.clear()
    for name, idx, color in (('x', 0, 'red'), ('y', 1, 'green'), ('z', 2, 'blue')):
        ax_xyz.plot(ts, gt_3d[:, idx], 's-', label=f'GT {name}', color=color, alpha=0.8)
        ax_xyz.plot(ts, pred_3d[:, idx], 'o--', label=f'Pred {name}', color=color)
    ax_xyz.set_xlabel('Timestep')
    ax_xyz.set_ylabel('Position (m)')
    ax_xyz.legend(ncol=2, fontsize=8)
    ax_xyz.grid(alpha=0.3)

    ax_gripper.clear()
    ax_gripper.plot(ts, sample_data['target_gripper'], 's-', label='GT gripper', color='green')
    ax_gripper.plot(ts, sample_data['pred_gripper'], 'o-', label='Pred gripper', color='red')
    ax_gripper.set_xlabel('Timestep')
    ax_gripper.set_ylabel('Gripper')
    ax_gripper.legend()
    ax_gripper.grid(alpha=0.3)

    ax_3d.clear()
    ax_3d.scatter(gt_3d[:, 0], gt_3d[:, 1], gt_3d[:, 2], c='red', s=60, label='GT')
    ax_3d.scatter(pred_3d[:, 0], pred_3d[:, 1], pred_3d[:, 2], c='blue', s=60, label='Pred')
    ax_3d.set_xlabel('X')
    ax_3d.set_ylabel('Y')
    ax_3d.set_zlabel('Z')
    ax_3d.legend()


def main():
    """Parse CLI args, train the diffusion baseline, and write checkpoints plus a progress plot.

    Each epoch writes the following to ``checkpoints/<run_name>/`` next to
    this script:
    - ``latest.pth``;
    - ``best.pth`` whenever the selection loss improves;
    - ``training_progress.png``.
    """
    parser = argparse.ArgumentParser(description="Train diffusion baseline (diffusion over trajectory state)")
    parser.add_argument("--dataset_root", "-d", nargs="+", default=["scratch/parsed_school_long_recap"])
    parser.add_argument("--val_split", type=float, default=0.05)
    parser.add_argument("--batch_size", type=int, default=BATCH_SIZE)
    parser.add_argument("--lr", type=float, default=LEARNING_RATE)
    parser.add_argument("--epochs", type=int, default=NUM_EPOCHS)
    parser.add_argument("--run_name", type=str, default="diffusion_baseline")
    args = parser.parse_args()
    if not 0.0 <= args.val_split < 1.0:
        parser.error("--val_split must be in [0, 1)")

    checkpoint_dir = Path(__file__).resolve().parent / "checkpoints" / args.run_name
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    device = torch.device("mps" if torch.backends.mps.is_available() else
                          "cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    dataset_roots = args.dataset_root if isinstance(args.dataset_root, list) else [args.dataset_root]
    full_dataset, train_dataset, val_dataset = _build_datasets(dataset_roots, args.val_split)

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0)

    # Normalization ranges are published on the model module and stored in
    # every checkpoint. Set them BEFORE constructing the model in case its
    # constructor reads them (model.py is not visible here).
    min_h, max_h, min_g, max_g = _value_ranges(full_dataset)
    model_module.MIN_HEIGHT = min_h
    model_module.MAX_HEIGHT = max_h
    model_module.MIN_GRIPPER = min_g
    model_module.MAX_GRIPPER = max_g
    print(f"Height range: [{model_module.MIN_HEIGHT:.6f}, {model_module.MAX_HEIGHT:.6f}] m")
    print(f"Gripper range: [{model_module.MIN_GRIPPER:.6f}, {model_module.MAX_GRIPPER:.6f}]")

    model = DiffusionTrajectoryPredictor(target_size=IMAGE_SIZE, n_window=N_WINDOW, freeze_backbone=False)
    model = model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-4)

    # This module forces the non-interactive Agg backend, so the figure can
    # never be shown on screen; it is saved to disk every epoch instead.
    fig = plt.figure(figsize=(14, 10))
    gs = GridSpec(3, 2, figure=fig, hspace=0.35, wspace=0.25)
    axes = (
        fig.add_subplot(gs[0, :]),
        fig.add_subplot(gs[1, :]),
        fig.add_subplot(gs[2, 0]),
        fig.add_subplot(gs[2, 1], projection='3d'),
    )
    plot_path = checkpoint_dir / 'training_progress.png'

    # validate() reports 0.0 for an empty loader, which would freeze best.pth
    # at epoch 0; without validation data, select on training loss instead.
    has_val = len(val_dataset) > 0
    if not has_val:
        print("Warning: validation set is empty; best.pth will track the training loss.")

    train_losses, val_losses = [], []
    best_score = float('inf')

    for epoch in range(args.epochs):
        train_loss = train_epoch(model, train_loader, optimizer, device)
        val_loss, val_3d, val_grip, sample_data = validate(model, val_loader, device)

        train_losses.append(train_loss)
        val_losses.append(val_loss)
        print(f"Epoch {epoch} Train loss (diff): {train_loss:.4f} Val loss (MSE): {val_loss:.4f} (3d: {val_3d:.6f}, grip: {val_grip:.6f})")

        _draw_progress(axes, train_losses, val_losses, sample_data)
        fig.savefig(plot_path)

        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': train_loss,
            'val_loss': val_loss,
            'min_height': model_module.MIN_HEIGHT,
            'max_height': model_module.MAX_HEIGHT,
            'min_gripper': model_module.MIN_GRIPPER,
            'max_gripper': model_module.MAX_GRIPPER,
        }
        torch.save(checkpoint, checkpoint_dir / 'latest.pth')
        score = val_loss if has_val else train_loss
        if score < best_score:
            best_score = score
            torch.save(checkpoint, checkpoint_dir / 'best.pth')
            print("  -> Saved best.pth")

    plt.close(fig)
    print("Done.")


if __name__ == "__main__":
    # Start training only when executed as a script, not when imported.
    main()
