"""Train ACT joints baseline on real data.

Model predicts trajectory_joints (N_WINDOW, 6) + gripper (N_WINDOW,) — direct joint regression, no IK.
MSE loss; line plots for joints and gripper.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, ConcatDataset
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from pathlib import Path
import argparse
import os
import sys

# Put this file's directory first on sys.path so the `data` and `model`
# modules imported below are resolved from next to this script.
sys.path.insert(0, os.path.dirname(__file__))

from data import RealTrajectoryDataset, N_WINDOW, N_JOINTS
from model import ACTJointsTrajectoryPredictor
import model as model_module

# Passed as `image_size` to RealTrajectoryDataset and as `target_size` to the
# model (presumably the square input resolution in pixels — confirm in data.py).
IMAGE_SIZE = 448
# Default value of the --batch_size CLI option.
BATCH_SIZE = 8
# Default value of the --lr CLI option (learning rate given to AdamW).
LEARNING_RATE = 1e-4
# Default value of the --epochs CLI option.
NUM_EPOCHS = 1000


def train_epoch(model, dataloader, optimizer, device):
    """Run one optimisation pass over ``dataloader`` and return the mean losses.

    The first timestep of each target window is fed back to the model as the
    current joint / gripper state; the model is trained to regress the whole
    window with an MSE loss on joints plus an MSE loss on the gripper.

    Losses are averaged per sample (each batch weighted by its size), the same
    way ``validate`` does it, so a smaller final batch does not skew the mean
    and the train and validation curves are directly comparable.

    Args:
        model: Module called as ``model(rgb, current_joints=...,
            current_gripper_state=...)`` returning ``(pred_joints, pred_gripper)``.
        dataloader: Iterable of dict batches with keys ``'rgb'``,
            ``'trajectory_joints'`` (B, N_WINDOW, 6) and
            ``'trajectory_gripper'`` (B, N_WINDOW).
        optimizer: Optimizer stepping the parameters of ``model``.
        device: Device the batch tensors are moved to.

    Returns:
        Tuple ``(total_loss, joints_loss, gripper_loss)`` of per-sample means;
        all zeros when ``dataloader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    total_joints_loss = 0.0
    total_gripper_loss = 0.0
    n_samples = 0
    for batch in dataloader:
        rgb = batch['rgb'].to(device)
        trajectory_joints = batch['trajectory_joints'].to(device)  # (B, N_WINDOW, 6)
        trajectory_gripper = batch['trajectory_gripper'].to(device)  # (B, N_WINDOW)
        # The first timestep of the target window doubles as the current state.
        current_joints = trajectory_joints[:, 0, :]  # (B, 6)
        current_gripper = trajectory_gripper[:, 0]  # (B,)

        pred_joints, pred_gripper = model(
            rgb,
            current_joints=current_joints,
            current_gripper_state=current_gripper,
        )

        loss_joints = F.mse_loss(pred_joints, trajectory_joints)
        loss_gripper = F.mse_loss(pred_gripper, trajectory_gripper)
        loss = loss_joints + loss_gripper

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # mse_loss is already a per-batch mean; scale by the batch size so the
        # final division yields a true per-sample mean.
        batch_size = rgb.shape[0]
        total_loss += loss.item() * batch_size
        total_joints_loss += loss_joints.item() * batch_size
        total_gripper_loss += loss_gripper.item() * batch_size
        n_samples += batch_size

    n = max(1, n_samples)  # avoid division by zero on an empty loader
    return total_loss / n, total_joints_loss / n, total_gripper_loss / n


def validate(model, dataloader, device):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Args:
        model: Module called as ``model(rgb, current_joints=...,
            current_gripper_state=...)`` returning ``(pred_joints, pred_gripper)``.
        dataloader: Iterable of dict batches with keys ``'rgb'``,
            ``'trajectory_joints'`` and ``'trajectory_gripper'``.
        device: Device the batch tensors are moved to.

    Returns:
        Tuple ``(total_loss, joints_loss, gripper_loss, sample_data)``. The
        losses are per-sample means (zero for an empty loader). ``sample_data``
        holds the first sample of the first batch for plotting: tensors under
        ``'rgb'``, ``'trajectory_joints'``, ``'pred_gripper'`` and
        ``'target_gripper'``, and a numpy array under
        ``'pred_trajectory_joints'``; it is ``None`` for an empty loader.
    """
    model.eval()
    loss_sum = 0
    joints_sum = 0
    gripper_sum = 0
    seen = 0
    sample_data = None

    with torch.no_grad():
        for batch in dataloader:
            images = batch['rgb'].to(device)
            gt_joints = batch['trajectory_joints'].to(device)
            gt_gripper = batch['trajectory_gripper'].to(device)

            # The first timestep of the target window is the current state.
            pred_joints, pred_gripper = model(
                images,
                current_joints=gt_joints[:, 0, :],
                current_gripper_state=gt_gripper[:, 0],
            )

            joints_mse = F.mse_loss(pred_joints, gt_joints)
            gripper_mse = F.mse_loss(pred_gripper, gt_gripper)
            combined = joints_mse + gripper_mse

            # Weight each batch mean by its size to get a per-sample average.
            batch_size = images.shape[0]
            loss_sum += combined.item() * batch_size
            joints_sum += joints_mse.item() * batch_size
            gripper_sum += gripper_mse.item() * batch_size
            seen += batch_size

            # Only the very first batch contributes the plotting sample.
            if sample_data is None:
                sample_data = {
                    'rgb': images[0],
                    'trajectory_joints': gt_joints[0],
                    'pred_trajectory_joints': pred_joints[0].cpu().numpy(),
                    'pred_gripper': pred_gripper[0],
                    'target_gripper': gt_gripper[0],
                }

    denom = seen if seen > 0 else 1
    return loss_sum / denom, joints_sum / denom, gripper_sum / denom, sample_data


def _gripper_range(datasets):
    """Return ``(min, max)`` over every ``trajectory_gripper`` value in ``datasets``.

    Each sample is reduced to its own min/max as it is read, so only two
    scalars per sample are retained instead of every gripper value.

    Raises:
        ValueError: If the datasets contain no gripper values at all.
    """
    sample_mins, sample_maxs = [], []
    for ds in datasets:
        # Map-style datasets are indexed explicitly, as they only guarantee
        # __len__ and __getitem__.
        for i in range(len(ds)):
            gripper = ds[i]['trajectory_gripper']
            if gripper.numel() == 0:
                continue
            sample_mins.append(float(gripper.min()))
            sample_maxs.append(float(gripper.max()))
    if not sample_mins:
        raise ValueError("Cannot compute gripper range: no gripper values found in the dataset")
    return float(np.min(sample_mins)), float(np.max(sample_maxs))


def _draw_progress(axes, train_losses, val_losses, sample_data):
    """Redraw the loss curves and, when available, the sample prediction plots."""
    ax_loss, ax_joints, ax_gripper, ax_joints_2d = axes

    # Line plot: loss
    ax_loss.clear()
    ax_loss.plot(np.arange(len(train_losses)), train_losses, 'o-', label='Train', color='blue')
    ax_loss.plot(np.arange(len(val_losses)), val_losses, 's-', label='Val', color='green')
    ax_loss.set_xlabel('Epoch')
    ax_loss.set_ylabel('MSE Loss')
    ax_loss.legend()
    ax_loss.grid(alpha=0.3)

    # Without a validation sample there is nothing else to draw.
    if sample_data is None:
        return

    gt_j = sample_data['trajectory_joints'].cpu().numpy()  # (N_WINDOW, 6)
    pred_j = sample_data['pred_trajectory_joints']
    ts = np.arange(N_WINDOW)

    # Line plots: first three joints, ground truth vs. prediction.
    joint_colors = ('red', 'green', 'blue')
    ax_joints.clear()
    for j in range(min(3, N_JOINTS)):
        ax_joints.plot(ts, gt_j[:, j], 's-', label=f'GT j{j}', color=joint_colors[j], alpha=0.8)
        ax_joints.plot(ts, pred_j[:, j], 'o--', label=f'Pred j{j}', color=joint_colors[j])
    ax_joints.set_xlabel('Timestep')
    ax_joints.set_ylabel('Joint (rad)')
    ax_joints.legend(ncol=2, fontsize=8)
    ax_joints.grid(alpha=0.3)

    ax_gripper.clear()
    ax_gripper.plot(ts, sample_data['target_gripper'].cpu().numpy(), 's-', label='GT gripper', color='green')
    ax_gripper.plot(ts, sample_data['pred_gripper'].cpu().numpy(), 'o-', label='Pred gripper', color='red')
    ax_gripper.set_xlabel('Timestep')
    ax_gripper.set_ylabel('Gripper')
    ax_gripper.legend()
    ax_gripper.grid(alpha=0.3)

    # Show joint 0 vs 1 (2D slice of joint space)
    ax_joints_2d.clear()
    ax_joints_2d.plot(gt_j[:, 0], gt_j[:, 1], 's-', label='GT', color='green', alpha=0.8)
    ax_joints_2d.plot(pred_j[:, 0], pred_j[:, 1], 'o-', label='Pred', color='red')
    ax_joints_2d.set_xlabel('Joint 0')
    ax_joints_2d.set_ylabel('Joint 1')
    ax_joints_2d.legend()
    ax_joints_2d.grid(alpha=0.3)


def main():
    """Parse CLI options, train the ACT joints model and checkpoint every epoch.

    Each epoch writes ``latest.pth`` and, when the selection loss improves,
    ``best.pth`` into
    ``volume_dino_tracks_act_baseline_joints/checkpoints/<run_name>``. The
    progress figure is saved to ``training_progress.png`` in the same directory:
    this module selects the non-interactive Agg backend, so nothing can be
    shown on screen and saving is the only way to make the plots visible.

    Raises:
        ValueError: If the dataset roots contain no samples.
    """
    parser = argparse.ArgumentParser(description="Train ACT joints baseline (6D joint regression)")
    parser.add_argument("--dataset_root", "-d", nargs="+", default=["scratch/parsed_school_long_recap"])
    parser.add_argument("--val_split", type=float, default=0.05)
    parser.add_argument("--batch_size", type=int, default=BATCH_SIZE)
    parser.add_argument("--lr", type=float, default=LEARNING_RATE)
    parser.add_argument("--epochs", type=int, default=NUM_EPOCHS)
    parser.add_argument("--run_name", type=str, default="act_baseline_joints")
    args = parser.parse_args()

    checkpoint_dir = Path("volume_dino_tracks_act_baseline_joints/checkpoints") / args.run_name
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    plot_path = checkpoint_dir / 'training_progress.png'

    device = torch.device("mps" if torch.backends.mps.is_available() else
                          "cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    dataset_roots = args.dataset_root if isinstance(args.dataset_root, list) else [args.dataset_root]
    if len(dataset_roots) == 1:
        full_dataset = RealTrajectoryDataset(dataset_root=dataset_roots[0], image_size=IMAGE_SIZE)
    else:
        full_dataset = ConcatDataset([
            RealTrajectoryDataset(dataset_root=r, image_size=IMAGE_SIZE) for r in dataset_roots
        ])
    n_total = len(full_dataset)
    print(f"Total samples: {n_total}")
    if n_total == 0:
        # Fail early with a clear message instead of an opaque error later on.
        raise ValueError(f"No samples found under dataset root(s): {dataset_roots}")

    n_val = int(n_total * args.val_split)
    n_train = n_total - n_val
    has_val = n_val > 0
    if not has_val:
        # validate() reports 0.0 for an empty loader, which would freeze
        # best.pth at epoch 0; fall back to the train loss for selection.
        print("Warning: validation split is empty; best.pth will be selected by train loss.")
    train_dataset, val_dataset = torch.utils.data.random_split(
        full_dataset, [n_train, n_val],
        generator=torch.Generator().manual_seed(42)  # fixed seed: reproducible split
    )

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0)

    model = ACTJointsTrajectoryPredictor(target_size=IMAGE_SIZE, n_window=N_WINDOW, freeze_backbone=False)
    model = model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-4)

    # Gripper range from dataset (for checkpoint compatibility)
    model_module.MIN_GRIPPER, model_module.MAX_GRIPPER = _gripper_range([train_dataset, val_dataset])
    print(f"Gripper range: [{model_module.MIN_GRIPPER:.6f}, {model_module.MAX_GRIPPER:.6f}]")

    fig = plt.figure(figsize=(14, 10))
    gs = GridSpec(3, 2, figure=fig, hspace=0.35, wspace=0.25)
    axes = (
        fig.add_subplot(gs[0, :]),   # loss curves
        fig.add_subplot(gs[1, :]),   # joints over time
        fig.add_subplot(gs[2, 0]),   # gripper over time
        fig.add_subplot(gs[2, 1]),   # joint 0 vs joint 1
    )

    train_losses, val_losses = [], []
    best_selection_loss = float('inf')

    for epoch in range(args.epochs):
        train_loss, train_joints, train_grip = train_epoch(model, train_loader, optimizer, device)
        val_loss, val_joints, val_grip, sample_data = validate(model, val_loader, device)

        train_losses.append(train_loss)
        val_losses.append(val_loss)
        print(f"Epoch {epoch} Train loss: {train_loss:.4f} (joints: {train_joints:.6f}, grip: {train_grip:.6f}) "
              f"Val loss: {val_loss:.4f} (joints: {val_joints:.6f}, grip: {val_grip:.6f})")

        _draw_progress(axes, train_losses, val_losses, sample_data)
        # Agg cannot display a window, so persist the figure instead.
        fig.savefig(plot_path)

        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': train_loss,
            'val_loss': val_loss,
            'min_height': model_module.MIN_HEIGHT,
            'max_height': model_module.MAX_HEIGHT,
            'min_gripper': model_module.MIN_GRIPPER,
            'max_gripper': model_module.MAX_GRIPPER,
        }
        torch.save(checkpoint, checkpoint_dir / 'latest.pth')

        selection_loss = val_loss if has_val else train_loss
        if selection_loss < best_selection_loss:
            best_selection_loss = selection_loss
            torch.save(checkpoint, checkpoint_dir / 'best.pth')
            print("  -> Saved best.pth")

    plt.close(fig)
    print("Done.")


# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
