"""Train motion tracks baseline on real data.

Model predicts 2D (N_WINDOW, 2) + height (N_WINDOW,) + gripper (N_WINDOW,) in camera frame.
MSE loss; line plots for 2d (pixel err), height, gripper (no heatmaps).
Lift to 3D with recover_3d_from_direct_keypoint_and_height for eval.
"""
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, ConcatDataset
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from pathlib import Path
import argparse
import os
import sys

sys.path.insert(0, os.path.dirname(__file__))

from data import RealTrajectoryDataset, N_WINDOW
from model import MotionTracksTrajectoryPredictor
import model as model_module
from utils import recover_3d_from_direct_keypoint_and_height

IMAGE_SIZE = 448
BATCH_SIZE = 8
LEARNING_RATE = 1e-4
NUM_EPOCHS = 1000

# 2D targets are in pixel coords (0..IMAGE_SIZE); raw MSE is huge. Scale 2D loss so it's
# initially ~10 and comparable to height/gripper: use normalized coords then scale for balance.
SCALE_2D_LOSS = 70.0  # normalized MSE ~0.15 -> scaled ~10


def train_epoch(model, dataloader, optimizer, device):
    """Run one optimization pass over ``dataloader``.

    Args:
        model: predictor returning (pred_2d, pred_height, pred_gripper).
        dataloader: yields dicts with 'rgb', 'trajectory_2d' (B, N_WINDOW, 2)
            in pixels, 'trajectory_3d' (B, N_WINDOW, 3), and
            'trajectory_gripper' (B, N_WINDOW).
        optimizer: optimizer over ``model``'s parameters.
        device: torch device to move batches onto.

    Returns:
        Tuple of per-batch mean losses: (total, 2d, height, gripper).
    """
    model.train()
    sums = [0.0, 0.0, 0.0, 0.0]  # running sums: total, 2d, height, gripper
    batches = 0
    for batch in dataloader:
        rgb = batch['rgb'].to(device)
        gt_2d = batch['trajectory_2d'].to(device)            # (B, N_WINDOW, 2) pixels
        gt_3d = batch['trajectory_3d'].to(device)            # (B, N_WINDOW, 3)
        gt_gripper = batch['trajectory_gripper'].to(device)  # (B, N_WINDOW)
        gt_height = gt_3d[:, :, 2]                           # (B, N_WINDOW)

        # Condition the model on the t=0 state of each trajectory.
        pred_2d, pred_height, pred_gripper = model(
            rgb,
            current_2d=gt_2d[:, 0, :],
            current_height=gt_3d[:, 0, 2],
            current_gripper_state=gt_gripper[:, 0],
        )

        # Normalize 2D to [0, 1] before MSE, then rescale so the term starts
        # ~10 and stays comparable with height/gripper (raw pixel MSE ~30k).
        loss_2d = F.mse_loss(pred_2d / IMAGE_SIZE, gt_2d / IMAGE_SIZE) * SCALE_2D_LOSS
        loss_height = F.mse_loss(pred_height, gt_height)
        loss_gripper = F.mse_loss(pred_gripper, gt_gripper)
        loss = loss_2d + loss_height + loss_gripper

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        for slot, term in enumerate((loss, loss_2d, loss_height, loss_gripper)):
            sums[slot] += term.item()
        batches += 1

    denom = max(1, batches)  # guard against an empty loader
    return tuple(total / denom for total in sums)


def validate(model, dataloader, device):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Losses are weighted by batch size so the result is a per-sample mean
    (unlike train_epoch's per-batch mean).

    Returns:
        (total, 2d, height, gripper, sample_data) where the first four are
        sample-weighted mean losses and ``sample_data`` is a dict of numpy
        arrays for the first validation sample (for plotting), or None if
        the loader yielded no batches.
    """
    model.eval()
    sum_total = sum_2d = sum_height = sum_gripper = 0.0
    seen = 0
    sample_data = None
    with torch.no_grad():
        for batch in dataloader:
            rgb = batch['rgb'].to(device)
            gt_2d = batch['trajectory_2d'].to(device)
            gt_3d = batch['trajectory_3d'].to(device)
            gt_gripper = batch['trajectory_gripper'].to(device)
            gt_height = gt_3d[:, :, 2]

            # Condition on the t=0 state, mirroring training.
            pred_2d, pred_height, pred_gripper = model(
                rgb,
                current_2d=gt_2d[:, 0, :],
                current_height=gt_3d[:, 0, 2],
                current_gripper_state=gt_gripper[:, 0],
            )

            # Same loss formulation (and 2D scaling) as train_epoch.
            loss_2d = F.mse_loss(pred_2d / IMAGE_SIZE, gt_2d / IMAGE_SIZE) * SCALE_2D_LOSS
            loss_height = F.mse_loss(pred_height, gt_height)
            loss_gripper = F.mse_loss(pred_gripper, gt_gripper)
            loss = loss_2d + loss_height + loss_gripper

            bsz = rgb.shape[0]
            sum_total += loss.item() * bsz
            sum_2d += loss_2d.item() * bsz
            sum_height += loss_height.item() * bsz
            sum_gripper += loss_gripper.item() * bsz
            seen += bsz

            if sample_data is None:
                # Keep the first sample of the first batch for visualization.
                sample_data = {
                    'rgb': rgb[0],
                    'trajectory_2d': gt_2d[0].cpu().numpy(),
                    'trajectory_height': gt_height[0].cpu().numpy(),
                    'trajectory_gripper': gt_gripper[0].cpu().numpy(),
                    'trajectory_3d': gt_3d[0].cpu().numpy(),
                    'pred_2d': pred_2d[0].cpu().numpy(),
                    'pred_height': pred_height[0].cpu().numpy(),
                    'pred_gripper': pred_gripper[0].cpu().numpy(),
                }
    denom = max(1, seen)  # guard against an empty loader
    return (sum_total / denom, sum_2d / denom, sum_height / denom,
            sum_gripper / denom, sample_data)


def _build_dataset(dataset_roots):
    """Load one RealTrajectoryDataset per root, concatenating when several."""
    if len(dataset_roots) == 1:
        full_dataset = RealTrajectoryDataset(dataset_root=dataset_roots[0], image_size=IMAGE_SIZE)
        print(f"  Single root: {dataset_roots[0]}")
    else:
        datasets = [
            RealTrajectoryDataset(dataset_root=root, image_size=IMAGE_SIZE)
            for root in dataset_roots
        ]
        full_dataset = ConcatDataset(datasets)
        for root, d in zip(dataset_roots, datasets):
            print(f"  {root}: {len(d)} samples")
    print(f"  Total: {len(full_dataset)} samples")
    return full_dataset


def _estimate_ranges(dataset, max_samples=500):
    """Return (min_height, max_height, min_gripper, max_gripper) over a prefix.

    NOTE(review): scans the first ``max_samples`` samples, not a random
    subset, so the estimate may be biased if the dataset is ordered.
    """
    all_heights = []
    all_gripper = []
    for i in range(min(max_samples, len(dataset))):
        s = dataset[i]
        all_heights.append(s['trajectory_3d'][:, 2].numpy())
        all_gripper.append(s['trajectory_gripper'].numpy())
    all_heights = np.concatenate(all_heights)
    all_gripper = np.concatenate(all_gripper)
    return (float(all_heights.min()), float(all_heights.max()),
            float(all_gripper.min()), float(all_gripper.max()))


def _publish_ranges(min_height, max_height, min_gripper, max_gripper):
    """Mirror the height/gripper ranges into the model module's globals."""
    model_module.MIN_HEIGHT = min_height
    model_module.MAX_HEIGHT = max_height
    model_module.MIN_GRIPPER = min_gripper
    model_module.MAX_GRIPPER = max_gripper


def _resolve_checkpoint_path(requested):
    """Return an existing checkpoint path, trying the sibling .pt/.pth suffix.

    Falls back to ``requested`` unchanged when neither file exists.
    """
    if os.path.exists(requested):
        return requested
    parts = requested.rsplit(".", 1)
    if len(parts) == 2:
        other_ext = ".pth" if parts[1].lower() == "pt" else ".pt"
        alt_path = parts[0] + other_ext
        if os.path.exists(alt_path):
            print(f"Checkpoint not found at {requested}, using {alt_path}")
            return alt_path
    return requested


def _save_checkpoint(path, epoch_1indexed, model, optimizer, ranges):
    """Write a training checkpoint; ``ranges`` is the 4-tuple from _estimate_ranges."""
    min_height, max_height, min_gripper, max_gripper = ranges
    torch.save({
        'epoch': epoch_1indexed,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'min_height': min_height,
        'max_height': max_height,
        'min_gripper': min_gripper,
        'max_gripper': max_gripper,
    }, path)


def _plot_validation_sample(sample_data, epoch_1indexed, out_path):
    """Save a 2x2 grid of GT-vs-predicted line plots (2d x/y, height, gripper)."""
    ts = np.arange(N_WINDOW)
    fig, axes = plt.subplots(2, 2, figsize=(10, 8))
    panels = [
        (axes[0, 0], sample_data['trajectory_2d'][:, 0], sample_data['pred_2d'][:, 0],
         'GT x', 'Pred x', '2D x (px)'),
        (axes[0, 1], sample_data['trajectory_2d'][:, 1], sample_data['pred_2d'][:, 1],
         'GT y', 'Pred y', '2D y (px)'),
        (axes[1, 0], sample_data['trajectory_height'] * 1000, sample_data['pred_height'] * 1000,
         'GT height (mm)', 'Pred height (mm)', 'Height (mm)'),
        (axes[1, 1], sample_data['trajectory_gripper'], sample_data['pred_gripper'],
         'GT gripper', 'Pred gripper', 'Gripper'),
    ]
    for ax, gt, pred, gt_label, pred_label, ylabel in panels:
        ax.plot(ts, gt, 's-', label=gt_label, color='green', markersize=4)
        ax.plot(ts, pred, 'o-', label=pred_label, color='red', markersize=4)
        ax.set_xlabel('Timestep')
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.grid(alpha=0.3)
    plt.suptitle(f"Motion tracks epoch {epoch_1indexed}")
    plt.tight_layout()
    plt.savefig(out_path, dpi=120, bbox_inches='tight')
    plt.close(fig)


def main():
    """CLI entry point: train the motion-tracks baseline and write checkpoints/plots."""
    parser = argparse.ArgumentParser(description="Train motion tracks baseline (2d + height + gripper)")
    parser.add_argument("--dataset_root", "-d", nargs="+", default=["scratch/parsed_school_long_recap"])
    parser.add_argument("--val_split", type=float, default=0.05)
    parser.add_argument("--batch_size", type=int, default=BATCH_SIZE)
    parser.add_argument("--lr", type=float, default=LEARNING_RATE)
    parser.add_argument("--epochs", type=int, default=NUM_EPOCHS)
    parser.add_argument("--checkpoint", type=str, default="",
                        help="Path to checkpoint to resume from")
    parser.add_argument("--run_name", type=str, default="motion_tracks_baseline")
    args = parser.parse_args()

    checkpoint_dir = Path(__file__).resolve().parent / "checkpoints" / args.run_name
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    device = torch.device("mps" if torch.backends.mps.is_available() else
                          "cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    print("\nLoading dataset...")
    dataset_roots = args.dataset_root if isinstance(args.dataset_root, list) else [args.dataset_root]
    full_dataset = _build_dataset(dataset_roots)
    n_total = len(full_dataset)
    n_val = max(1, int(n_total * args.val_split))
    n_train = n_total - n_val
    train_dataset, val_dataset = torch.utils.data.random_split(full_dataset, [n_train, n_val])
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0)

    # Height/gripper normalization ranges go both into the model module and
    # into every checkpoint so inference can reproduce them.
    min_height, max_height, min_gripper, max_gripper = _estimate_ranges(full_dataset)
    _publish_ranges(min_height, max_height, min_gripper, max_gripper)
    print(f"Height range: [{min_height*1000:.2f}, {max_height*1000:.2f}] mm")
    print(f"Gripper range: [{min_gripper:.3f}, {max_gripper:.3f}]")

    model = MotionTracksTrajectoryPredictor(target_size=IMAGE_SIZE, n_window=N_WINDOW, freeze_backbone=False)
    model = model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=args.lr)

    start_epoch = 0
    if args.checkpoint:
        checkpoint_path = _resolve_checkpoint_path(args.checkpoint)
        if os.path.exists(checkpoint_path):
            print(f"Loading checkpoint: {checkpoint_path}")
            ckpt = torch.load(checkpoint_path, map_location=device)
            model.load_state_dict(ckpt["model_state_dict"], strict=True)
            try:
                optimizer.load_state_dict(ckpt["optimizer_state_dict"])
            except Exception as e:
                # Best-effort: an optimizer mismatch should not block resuming weights.
                print(f"Could not load optimizer state: {e}")
            # Checkpoint stores the 1-indexed completed epoch, which is exactly
            # the next 0-indexed epoch to run. (Previously this line was
            # commented out, so resume always restarted from epoch 0.)
            start_epoch = int(ckpt.get("epoch", 0))
            if all(k in ckpt for k in ("min_height", "max_height", "min_gripper", "max_gripper")):
                min_height = float(ckpt["min_height"])
                max_height = float(ckpt["max_height"])
                min_gripper = float(ckpt["min_gripper"])
                max_gripper = float(ckpt["max_gripper"])
                _publish_ranges(min_height, max_height, min_gripper, max_gripper)
                print(f"Using height/gripper range from checkpoint")
            print(f"Resumed from epoch {start_epoch}")
        else:
            print(f"Checkpoint not found: {args.checkpoint}, training from scratch")

    best_val_loss = float('inf')
    ranges = (min_height, max_height, min_gripper, max_gripper)

    for epoch in range(start_epoch, args.epochs):
        t_loss, t_2d, t_h, t_g = train_epoch(model, train_loader, optimizer, device)
        v_loss, v_2d, v_h, v_g, sample_data = validate(model, val_loader, device)

        print(f"Epoch {epoch+1}/{args.epochs} train_loss={t_loss:.6f} (2d={t_2d:.4f} h={t_h:.6f} g={t_g:.4f}) "
              f"val_loss={v_loss:.6f} (2d={v_2d:.4f} h={v_h:.6f} g={v_g:.4f})")

        if v_loss < best_val_loss:
            best_val_loss = v_loss
            _save_checkpoint(checkpoint_dir / "best.pth", epoch + 1, model, optimizer, ranges)
        _save_checkpoint(checkpoint_dir / "latest.pth", epoch + 1, model, optimizer, ranges)

        # Line-plot visualization every 50 epochs.
        if sample_data is not None and (epoch + 1) % 50 == 0:
            _plot_validation_sample(sample_data, epoch + 1, checkpoint_dir / f"vis_epoch_{epoch+1}.png")

    print(f"Done. Best val loss: {best_val_loss:.6f}")


if __name__ == "__main__":
    main()
