"""Train ACT / vanilla regression baseline on real data.

Model predicts trajectory_3d (N_WINDOW, 3) + gripper (N_WINDOW,) in global robot frame.
MSE loss; line plots for xyz and gripper (no heatmaps).
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, ConcatDataset
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from pathlib import Path
import argparse
import os
import sys

sys.path.insert(0, os.path.dirname(__file__))

from data import RealTrajectoryDataset, N_WINDOW
from model import ACTTrajectoryPredictor
import model as model_module

# Passed as ``image_size`` to RealTrajectoryDataset and as ``target_size`` to
# ACTTrajectoryPredictor.
IMAGE_SIZE = 448
# Defaults for the --batch_size, --lr and --epochs command-line flags.
BATCH_SIZE = 8
LEARNING_RATE = 1e-4
NUM_EPOCHS = 1000


def train_epoch(model, dataloader, optimizer, device):
    """Run one optimisation pass over *dataloader* and return mean losses.

    Each batch must provide ``'rgb'``, ``'trajectory_3d'`` (B, N_WINDOW, 3) and
    ``'trajectory_gripper'`` (B, N_WINDOW). The first waypoint of each target
    trajectory is fed to the model as the current state, and the optimised
    loss is MSE(xyz) + MSE(gripper).

    Returns:
        ``(total, xyz, gripper)`` losses averaged per sample (each batch's
        mean loss weighted by its batch size), the same convention used by
        ``validate`` so train and val numbers are comparable. All three are
        0.0 for an empty dataloader.
    """
    model.train()
    total_loss = 0.0
    total_3d_loss = 0.0
    total_gripper_loss = 0.0
    n_samples = 0
    for batch in dataloader:
        rgb = batch['rgb'].to(device)
        trajectory_3d = batch['trajectory_3d'].to(device)  # (B, N_WINDOW, 3)
        trajectory_gripper = batch['trajectory_gripper'].to(device)  # (B, N_WINDOW)
        current_3d = trajectory_3d[:, 0, :]  # (B, 3)
        current_gripper = trajectory_gripper[:, 0]  # (B,)

        pred_3d, pred_gripper = model(rgb, current_3d=current_3d, current_gripper_state=current_gripper)

        loss_3d = F.mse_loss(pred_3d, trajectory_3d)
        loss_gripper = F.mse_loss(pred_gripper, trajectory_gripper)
        loss = loss_3d + loss_gripper

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight by batch size so a short final batch doesn't skew the epoch mean.
        batch_size = rgb.shape[0]
        total_loss += loss.item() * batch_size
        total_3d_loss += loss_3d.item() * batch_size
        total_gripper_loss += loss_gripper.item() * batch_size
        n_samples += batch_size

    n = max(1, n_samples)  # avoid division by zero on an empty loader
    return total_loss / n, total_3d_loss / n, total_gripper_loss / n


def validate(model, dataloader, device):
    """Evaluate *model* on *dataloader* with gradients disabled.

    The model is conditioned on the first waypoint of each target trajectory,
    exactly as during training.

    Returns:
        ``(total, xyz, gripper, sample)``. The three losses are per-sample
        means (each batch's MSE weighted by its batch size; 0.0 when the
        loader yields nothing). ``sample`` holds the first example of the
        first batch for plotting: ``'rgb'``, ``'trajectory_3d'``,
        ``'pred_gripper'`` and ``'target_gripper'`` as tensors on *device*,
        and ``'pred_trajectory_3d'`` as a NumPy array. It is ``None`` when
        there were no batches.
    """
    model.eval()
    sums = [0.0, 0.0, 0.0]  # running [total, xyz, gripper] loss * samples
    seen = 0
    preview = None
    with torch.no_grad():
        for step, batch in enumerate(dataloader):
            images = batch['rgb'].to(device)
            target_xyz = batch['trajectory_3d'].to(device)
            target_grip = batch['trajectory_gripper'].to(device)

            out_xyz, out_grip = model(
                images,
                current_3d=target_xyz[:, 0, :],
                current_gripper_state=target_grip[:, 0],
            )

            xyz_loss = F.mse_loss(out_xyz, target_xyz)
            grip_loss = F.mse_loss(out_grip, target_grip)
            combined = xyz_loss + grip_loss

            count = images.shape[0]
            for slot, value in enumerate((combined, xyz_loss, grip_loss)):
                sums[slot] += value.item() * count
            seen += count

            # Keep only the very first example for visualisation.
            if step == 0:
                preview = {
                    'rgb': images[0],
                    'trajectory_3d': target_xyz[0],
                    'pred_trajectory_3d': out_xyz[0].cpu().numpy(),
                    'pred_gripper': out_grip[0],
                    'target_gripper': target_grip[0],
                }
    denom = max(1, seen)
    return sums[0] / denom, sums[1] / denom, sums[2] / denom, preview


def _resolve_checkpoint_path(path):
    """Return *path* if it exists, else its ``.pt``/``.pth`` sibling, else ``None``."""
    if os.path.exists(path):
        return path
    root, ext = os.path.splitext(path)
    if ext:
        # Tolerate the common .pt / .pth mix-up. splitext (unlike rsplit('.'))
        # only looks at the final path component, so dotted directory names
        # are not mistaken for an extension.
        alt_path = root + (".pth" if ext.lower() == ".pt" else ".pt")
        if os.path.exists(alt_path):
            print(f"Checkpoint not found at {path}, using {alt_path}")
            return alt_path
    return None


def _compute_ranges(datasets):
    """Return ``(min_height, max_height, min_gripper, max_gripper)`` over *datasets*.

    Height is column 2 of every ``trajectory_3d`` waypoint. Gripper values are
    all entries of ``trajectory_gripper``. Every sample is fetched through
    ``__getitem__``, so this costs one full pass over the data.

    Raises:
        ValueError: if the datasets contain no samples.
    """
    heights, grippers = [], []
    for ds in datasets:
        for i in range(len(ds)):
            sample = ds[i]
            heights.append(sample['trajectory_3d'].numpy()[:, 2])
            grippers.append(sample['trajectory_gripper'].numpy().ravel())
    if not heights:
        raise ValueError("Cannot compute height/gripper ranges from an empty dataset")
    heights = np.concatenate(heights)
    grippers = np.concatenate(grippers)
    return (float(heights.min()), float(heights.max()),
            float(grippers.min()), float(grippers.max()))


def _update_plots(axes, epochs, train_losses, val_losses, sample_data):
    """Redraw loss curves and, when available, one validation sample's predictions.

    ``axes`` is ``(ax_loss, ax_xyz, ax_gripper, ax_3d)``. ``sample_data`` is the
    dict returned by ``validate``, or ``None`` when the val set was empty.
    """
    ax_loss, ax_xyz, ax_gripper, ax_3d = axes

    ax_loss.clear()
    ax_loss.plot(epochs, train_losses, 'o-', label='Train', color='blue')
    ax_loss.plot(epochs, val_losses, 's-', label='Val', color='green')
    ax_loss.set_xlabel('Epoch')
    ax_loss.set_ylabel('MSE Loss')
    ax_loss.legend()
    ax_loss.grid(alpha=0.3)

    if sample_data is None:
        return

    gt_3d = sample_data['trajectory_3d'].cpu().numpy()
    pred_3d = sample_data['pred_trajectory_3d']
    ts = np.arange(len(gt_3d))

    ax_xyz.clear()
    for name, idx, color in (('x', 0, 'red'), ('y', 1, 'green'), ('z', 2, 'blue')):
        ax_xyz.plot(ts, gt_3d[:, idx], 's-', label=f'GT {name}', color=color, alpha=0.8)
        ax_xyz.plot(ts, pred_3d[:, idx], 'o--', label=f'Pred {name}', color=color)
    ax_xyz.set_xlabel('Timestep')
    ax_xyz.set_ylabel('Position (m)')
    ax_xyz.legend(ncol=2, fontsize=8)
    ax_xyz.grid(alpha=0.3)

    ax_gripper.clear()
    ax_gripper.plot(ts, sample_data['target_gripper'].cpu().numpy(), 's-', label='GT gripper', color='green')
    ax_gripper.plot(ts, sample_data['pred_gripper'].cpu().numpy(), 'o-', label='Pred gripper', color='red')
    ax_gripper.set_xlabel('Timestep')
    ax_gripper.set_ylabel('Gripper')
    ax_gripper.legend()
    ax_gripper.grid(alpha=0.3)

    ax_3d.clear()
    ax_3d.scatter(gt_3d[:, 0], gt_3d[:, 1], gt_3d[:, 2], c='red', s=60, label='GT')
    ax_3d.scatter(pred_3d[:, 0], pred_3d[:, 1], pred_3d[:, 2], c='blue', s=60, label='Pred')
    ax_3d.set_xlabel('X')
    ax_3d.set_ylabel('Y')
    ax_3d.set_zlabel('Z')
    ax_3d.legend()


def main():
    """Parse CLI args, build data/model, and run the train/validate loop.

    Each epoch writes the following into
    ``volume_dino_tracks_act_baseline/checkpoints/<run_name>/``:

    - ``latest.pth``;
    - ``best.pth``, whenever validation loss improves;
    - ``training_progress.png``, containing the loss curves and a sample
      prediction.

    Checkpoints store the height/gripper min/max written to ``model_module``
    so a resumed run can reuse them.
    """
    parser = argparse.ArgumentParser(description="Train ACT baseline (vanilla regression)")
    parser.add_argument("--dataset_root", "-d", nargs="+", default=["scratch/parsed_pickplace_exp1_feb9"])
    parser.add_argument("--val_split", type=float, default=0.05)
    parser.add_argument("--batch_size", type=int, default=BATCH_SIZE)
    parser.add_argument("--lr", type=float, default=LEARNING_RATE)
    parser.add_argument("--epochs", type=int, default=NUM_EPOCHS)
    parser.add_argument("--checkpoint", type=str, default="",
                        help="Path to checkpoint to resume from")
    parser.add_argument("--run_name", type=str, default="act_baseline")
    args = parser.parse_args()

    checkpoint_dir = Path(f"volume_dino_tracks_act_baseline/checkpoints/{args.run_name}")
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    device = torch.device("mps" if torch.backends.mps.is_available() else
                          "cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    print("\nLoading dataset...")
    dataset_roots = args.dataset_root if isinstance(args.dataset_root, list) else [args.dataset_root]
    if len(dataset_roots) == 1:
        full_dataset = RealTrajectoryDataset(dataset_root=dataset_roots[0], image_size=IMAGE_SIZE)
        print(f"  Single root: {dataset_roots[0]}")
    else:
        datasets = [
            RealTrajectoryDataset(dataset_root=root, image_size=IMAGE_SIZE)
            for root in dataset_roots
        ]
        full_dataset = ConcatDataset(datasets)
        for root, d in zip(dataset_roots, datasets):
            print(f"  {root}: {len(d)} samples")
    print(f"  Total: {len(full_dataset)} samples")

    n_total = len(full_dataset)
    n_val = int(n_total * args.val_split)
    if args.val_split > 0 and n_val == 0 and n_total > 1:
        # Small datasets would otherwise get an empty val set: val loss would
        # be a constant 0.0 and best.pth would never reflect real progress.
        n_val = 1
    n_train = n_total - n_val
    train_dataset, val_dataset = torch.utils.data.random_split(
        full_dataset, [n_train, n_val],
        generator=torch.Generator().manual_seed(42)
    )

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0)

    model = ACTTrajectoryPredictor(target_size=IMAGE_SIZE, n_window=N_WINDOW, freeze_backbone=False)
    model = model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-4)

    # NOTE(review): the epoch counter is not restored from the checkpoint
    # (the original restore line was commented out), so a resumed run always
    # trains epochs 0..epochs-1 starting from the loaded weights.
    start_epoch = 0
    checkpoint_has_minmax = False
    if args.checkpoint:
        checkpoint_path = _resolve_checkpoint_path(args.checkpoint)
        if checkpoint_path is None:
            print(f"Checkpoint not found: {args.checkpoint}, training from scratch")
        else:
            print(f"Loading checkpoint: {checkpoint_path}")
            ckpt = torch.load(checkpoint_path, map_location=device)
            model.load_state_dict(ckpt["model_state_dict"], strict=True)
            try:
                optimizer.load_state_dict(ckpt["optimizer_state_dict"])
            except Exception as e:  # best-effort: fall back to a fresh optimizer
                print(f"Could not load optimizer state: {e}")
            range_keys = ("min_height", "max_height", "min_gripper", "max_gripper")
            if all(key in ckpt for key in range_keys):
                model_module.MIN_HEIGHT = float(ckpt["min_height"])
                model_module.MAX_HEIGHT = float(ckpt["max_height"])
                model_module.MIN_GRIPPER = float(ckpt["min_gripper"])
                model_module.MAX_GRIPPER = float(ckpt["max_gripper"])
                checkpoint_has_minmax = True
                print("Using height/gripper range from checkpoint")
            print(f"Loaded weights from epoch {ckpt.get('epoch', '?')}; "
                  f"training starts at epoch {start_epoch}")

    if not checkpoint_has_minmax:
        # Compute min/max height and gripper from the dataset (saved into
        # checkpoints for compatibility).
        (model_module.MIN_HEIGHT, model_module.MAX_HEIGHT,
         model_module.MIN_GRIPPER, model_module.MAX_GRIPPER) = _compute_ranges(
            [train_dataset, val_dataset])
    print(f"Height range: [{model_module.MIN_HEIGHT:.6f}, {model_module.MAX_HEIGHT:.6f}] m")
    print(f"Gripper range: [{model_module.MIN_GRIPPER:.6f}, {model_module.MAX_GRIPPER:.6f}]")

    # The Agg backend (selected at import) cannot open windows, so the figure
    # is written to disk every epoch instead of being shown interactively.
    fig = plt.figure(figsize=(14, 10))
    gs = GridSpec(3, 2, figure=fig, hspace=0.35, wspace=0.25)
    axes = (
        fig.add_subplot(gs[0, :]),
        fig.add_subplot(gs[1, :]),
        fig.add_subplot(gs[2, 0]),
        fig.add_subplot(gs[2, 1], projection='3d'),
    )
    plot_path = checkpoint_dir / 'training_progress.png'

    epochs, train_losses, val_losses = [], [], []
    best_val_loss = float('inf')

    for epoch in range(start_epoch, args.epochs):
        train_loss, train_3d, train_grip = train_epoch(model, train_loader, optimizer, device)
        val_loss, val_3d, val_grip, sample_data = validate(model, val_loader, device)

        epochs.append(epoch)
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        print(f"Epoch {epoch} Train loss: {train_loss:.4f} (3d: {train_3d:.6f}, grip: {train_grip:.6f}) "
              f"Val loss: {val_loss:.4f} (3d: {val_3d:.6f}, grip: {val_grip:.6f})")

        _update_plots(axes, epochs, train_losses, val_losses, sample_data)
        fig.savefig(plot_path)

        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': train_loss,
            'val_loss': val_loss,
            'min_height': model_module.MIN_HEIGHT,
            'max_height': model_module.MAX_HEIGHT,
            'min_gripper': model_module.MIN_GRIPPER,
            'max_gripper': model_module.MAX_GRIPPER,
        }
        torch.save(checkpoint, checkpoint_dir / 'latest.pth')
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(checkpoint, checkpoint_dir / 'best.pth')
            print("  -> Saved best.pth")

    plt.close(fig)
    print("Done.")


if __name__ == "__main__":
    # Run training only when executed as a script, not when imported.
    main()
