"""Train PARA on real Panda robot data.

Usage:
  cd /data/cameron/para/panda_streaming
  CUDA_VISIBLE_DEVICES=6 MUJOCO_GL=egl \
  DINO_REPO_DIR=/data/cameron/keygrip/dinov3 \
  DINO_WEIGHTS_PATH=/data/cameron/.cache/torch/hub/checkpoints/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth \
  python train_panda_para.py \
    --data_dir /data/cameron/panda_data/data_20260420_115853_632_frames \
    --run_name panda_para_test_v1 \
    --epochs 500 --batch_size 4
"""
import sys, os
sys.path.insert(0, os.path.dirname(__file__))

import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
import numpy as np
from pathlib import Path
from tqdm import tqdm
import argparse
import wandb
import json
import glob
import cv2

from data_panda_para import (
    PandaTrajectoryDataset, T_CAM_WORLD, CAM_K, IMG_W, IMG_H,
    N_WINDOW, project_to_pixel,
)

# Import model directly (avoid importing libero data.py which needs sim).
# ../libero/model.py is loaded by file path under the module name "model",
# so the libero package's data.py is not imported by this script.
libero_dir = os.path.join(os.path.dirname(__file__), '..', 'libero')
import importlib.util
_spec = importlib.util.spec_from_file_location("model", os.path.join(libero_dir, "model.py"))
model_module = importlib.util.module_from_spec(_spec)
_spec.loader.exec_module(model_module)
# Values copied out of the model module once, at import time. The range
# globals (MIN_HEIGHT, MAX_HEIGHT, MIN_ROT, MAX_ROT, ...) are deliberately NOT
# copied: main() overwrites them on model_module with dataset statistics, and
# the discretize/extract helpers read them from model_module at call time.
TrajectoryHeatmapPredictor = model_module.TrajectoryHeatmapPredictor
N_HEIGHT_BINS = model_module.N_HEIGHT_BINS
N_ROT_BINS = model_module.N_ROT_BINS
PRED_SIZE = model_module.PRED_SIZE  # GT pixel coords are rescaled by PRED_SIZE / IMAGE_SIZE before comparing to model output

IMAGE_SIZE = 448  # side (px) of the square image space used by trajectory_2d and the visualization tiles
# Weights of the volume, gripper and rotation terms in the total loss.
VOLUME_LOSS_WEIGHT = 1.0
GRIPPER_LOSS_WEIGHT = 5.0
ROTATION_LOSS_WEIGHT = 0.5
VIS_EVERY_EPOCHS = 10  # default for --vis_every: log wandb strips every N epochs


# ── Loss helpers (self-contained, no robosuite dependency) ──────────────

def discretize_height(height_values):
    """Map continuous EEF heights to integer height-bin indices.

    Heights are normalised into [0, 1] with ``model_module.MIN_HEIGHT`` /
    ``MAX_HEIGHT`` (read at call time, so values set by ``main`` apply),
    scaled by ``N_HEIGHT_BINS - 1`` and truncated.

    Returns:
        Long tensor shaped like ``height_values`` with entries in
        ``[0, N_HEIGHT_BINS - 1]``.
    """
    lo = model_module.MIN_HEIGHT
    hi = model_module.MAX_HEIGHT
    last_bin = N_HEIGHT_BINS - 1
    unit = ((height_values - lo) / (hi - lo + 1e-8)).clamp(0, 1)
    return (unit * last_bin).long().clamp(0, last_bin)


def compute_volume_loss(pred_volume_logits, trajectory_2d, target_height_bins):
    """Cross-entropy over the flattened (height, y, x) volume, averaged over timesteps.

    Args:
        pred_volume_logits: (B, N, Nh, H, W) logits.
        trajectory_2d: (B, N, 2) target pixel coords as (x, y); truncated to
            integers and clamped into the H x W grid.
        target_height_bins: (B, N) integer height-bin targets, clamped to
            ``[0, Nh - 1]``.

    Returns:
        Scalar tensor: mean over the N timesteps of each timestep's
        batch-mean cross-entropy.
    """
    batch, n_steps, n_h, height, width = pred_volume_logits.shape
    xs = trajectory_2d[:, :, 0].long().clamp(0, width - 1)
    ys = trajectory_2d[:, :, 1].long().clamp(0, height - 1)
    hs = target_height_bins.clamp(0, n_h - 1)
    # Class index of the target cell in the volume flattened in (h, y, x) order.
    flat_targets = (hs * (height * width) + ys * width + xs).long()
    per_step = [
        F.cross_entropy(pred_volume_logits[:, step].reshape(batch, -1), flat_targets[:, step])
        for step in range(n_steps)
    ]
    return torch.stack(per_step).mean()


def compute_gripper_loss(pred_gripper_logits, target_gripper):
    """Binary cross-entropy for the gripper prediction.

    A target counts as the positive class when it is strictly greater than
    zero. Logits and targets must have matching shapes.
    """
    is_positive = target_gripper.gt(0)
    return F.binary_cross_entropy_with_logits(pred_gripper_logits, is_positive.float())


def discretize_rotation(euler_values):
    """Map Euler angles to per-axis rotation-bin indices.

    ``model_module.MIN_ROT`` / ``MAX_ROT`` (read at call time) are turned into
    float32 tensors on ``euler_values``' device and broadcast against its last
    dimension. Each value is normalised to [0, 1], scaled by
    ``N_ROT_BINS - 1`` and truncated.

    Returns:
        Long tensor shaped like ``euler_values`` with entries in
        ``[0, N_ROT_BINS - 1]``.
    """
    opts = dict(device=euler_values.device, dtype=torch.float32)
    lo = torch.tensor(model_module.MIN_ROT, **opts)
    hi = torch.tensor(model_module.MAX_ROT, **opts)
    top = N_ROT_BINS - 1
    unit = ((euler_values - lo) / (hi - lo + 1e-8)).clamp(0, 1)
    return (unit * top).long().clamp(0, top)


def compute_rotation_loss(pred_rotation_logits, target_euler):
    """Mean cross-entropy over the three Euler axes.

    ``pred_rotation_logits`` is indexed as (B, N, axis, Nr). For each of the
    first three axes a cross-entropy is taken over all B*N positions against
    the bins from ``discretize_rotation(target_euler)``; the three per-axis
    losses are then averaged.
    """
    bins = discretize_rotation(target_euler)
    batch, n_steps, _, n_rot = pred_rotation_logits.shape
    flat = batch * n_steps
    per_axis = [
        F.cross_entropy(
            pred_rotation_logits[:, :, ax, :].reshape(flat, n_rot),
            bins[:, :, ax].reshape(flat),
        )
        for ax in range(3)
    ]
    return torch.stack(per_axis).mean()


def extract_pred_2d_and_height(volume_logits):
    """Decode the argmax pixel and height for every (batch, timestep) slice.

    For each slice of ``volume_logits`` (B, N, Nh, H, W), the predicted pixel
    is the argmax of the max-over-height map. The height bin is the argmax
    over Nh at that pixel. Bin i is mapped back to
    ``MIN_HEIGHT + i / (N_HEIGHT_BINS - 1) * (MAX_HEIGHT - MIN_HEIGHT)``
    using the current ``model_module`` ranges.

    Returns:
        (pred_2d, pred_height): (B, N, 2) float pixel coords as (x, y), and
        (B, N) heights.
    """
    B, N, Nh, H, W = volume_logits.shape
    dev = volume_logits.device
    # Treat every (batch, timestep) pair as one independent volume.
    cells = volume_logits.reshape(B * N, Nh, H, W)
    best_flat = cells.max(dim=1).values.flatten(start_dim=1).argmax(dim=1)
    rows = best_flat // W
    cols = best_flat % W
    h_bins = cells[torch.arange(B * N, device=dev), :, rows, cols].argmax(dim=1)

    pred_2d = torch.zeros(B, N, 2, device=dev)
    pred_2d[..., 0] = cols.reshape(B, N).float()
    pred_2d[..., 1] = rows.reshape(B, N).float()

    lo, hi = model_module.MIN_HEIGHT, model_module.MAX_HEIGHT
    centers = torch.linspace(0, 1, N_HEIGHT_BINS, device=dev)
    pred_height = centers[h_bins.reshape(B, N)] * (hi - lo) + lo
    return pred_2d, pred_height


# ── Fast dataset stats (NPY only, no image loading) ────────────────────

def compute_dataset_stats_fast(data_dir, episodes):
    """Compute height/gripper/rotation stats from NPY files only.

    For every frame index in each episode's inclusive ``[start, end]`` range,
    the function loads ``<data_dir>/<idx:06d>.npy``; missing files are skipped.
    It writes the first 7 entries into MuJoCo ``qpos`` and runs forward
    kinematics. It then records three values per frame:

    - the ``hand`` body's world z-position;
    - its ``'xyz'`` Euler angles (scipy convention);
    - a gripper value ``2 * js[7] - 1``, using 1.0 for ``js[7]`` when the
      file has only 7 entries.

    Args:
        data_dir: directory containing the per-frame ``.npy`` joint-state files.
        episodes: iterable of dicts with integer ``"start"`` and ``"end"`` keys.

    Returns:
        dict with float ``min_height``/``max_height`` and
        ``min_gripper``/``max_gripper``, plus ``min_rot``/``max_rot`` as
        per-axis lists.

    Raises:
        ValueError: if no frame file was found for any episode, so no stats
            can be computed.
    """
    import mujoco
    from scipy.spatial.transform import Rotation as Rot
    from ExoConfigs.panda_exo_handeye_4x2 import PANDA_HANDEYE_4X2_CONFIG
    mj_model = mujoco.MjModel.from_xml_string(PANDA_HANDEYE_4X2_CONFIG.xml)
    mj_data = mujoco.MjData(mj_model)
    hand_id = mujoco.mj_name2id(mj_model, mujoco.mjtObj.mjOBJ_BODY, "hand")

    heights, grippers, eulers = [], [], []
    for ep in episodes:
        for idx in range(ep["start"], ep["end"] + 1):  # "end" is inclusive
            npy_path = os.path.join(data_dir, f"{idx:06d}.npy")
            if not os.path.exists(npy_path):
                continue
            js = np.load(npy_path).astype(np.float64)
            mj_data.qpos[:7] = js[:7]
            mujoco.mj_forward(mj_model, mj_data)
            pos = mj_data.xpos[hand_id].copy()
            # MuJoCo quaternions are (w, x, y, z); scipy's from_quat expects (x, y, z, w).
            quat_xyzw = mj_data.xquat[hand_id][[1, 2, 3, 0]]
            euler = Rot.from_quat(quat_xyzw).as_euler('xyz')
            heights.append(pos[2])
            # Affine map g -> 2g - 1 (presumably [0, 1] -> [-1, 1]; confirm the stored range).
            gw = js[7] if len(js) > 7 else 1.0
            grippers.append(2.0 * gw - 1.0)
            eulers.append(euler)

    if not heights:
        # Without this, numpy's min() below fails with an opaque "zero-size array" error.
        raise ValueError(
            f"No frame .npy files found under {data_dir!r} for the given episodes; "
            "cannot compute dataset stats.")

    heights = np.array(heights)
    grippers = np.array(grippers)
    eulers = np.array(eulers)
    stats = {
        "min_height": float(heights.min()), "max_height": float(heights.max()),
        "min_gripper": float(grippers.min()), "max_gripper": float(grippers.max()),
        "min_rot": eulers.min(axis=0).tolist(), "max_rot": eulers.max(axis=0).tolist(),
    }
    print(f"Stats: height=[{stats['min_height']:.4f}, {stats['max_height']:.4f}], "
          f"gripper=[{stats['min_gripper']:.2f}, {stats['max_gripper']:.2f}]")
    return stats


# ── Visualization (wandb timestep strip) ────────────────────────────────

def project_world_to_pixel(pos_3d, image_size):
    """Project a 3D world point into a square ``image_size`` x ``image_size`` image.

    The point is projected with the dataset camera (``T_CAM_WORLD``,
    ``CAM_K``). The resulting coordinates are rescaled from the ``IMG_W`` x
    ``IMG_H`` frame to ``image_size`` and rounded to ints.

    Returns:
        ``(u, v)`` ints, or ``None`` when ``project_to_pixel`` returns ``None``.
    """
    native = project_to_pixel(pos_3d, T_CAM_WORLD, CAM_K)
    if native is None:
        return None
    u_scaled = native[0] * image_size / IMG_W
    v_scaled = native[1] * image_size / IMG_H
    return int(round(u_scaled)), int(round(v_scaled))


def _heatmap_overlay(frame, heat_map):
    """Blend a min-max-normalised heatmap into the red channel of ``frame``.

    Args:
        frame: float image (H, W, C); presumably RGB in [0, 1] — TODO confirm.
        heat_map: (H, W) array of heatmap values.

    Returns:
        (uint8 blended image, (x, y) pixel of the heatmap maximum).
    """
    heat = heat_map - heat_map.min()
    if heat.max() > 1e-8:
        heat = heat / heat.max()
    red_layer = np.zeros_like(frame)
    red_layer[..., 0] = heat
    blended = np.clip(frame * 0.55 + red_layer * 0.45, 0, 1)
    peak_row, peak_col = np.unravel_index(heat.argmax(), heat.shape)
    return (blended * 255).astype(np.uint8), (int(peak_col), int(peak_row))


def _annotate_gt_eef(canvas, eef_pos):
    """Draw the GT end-effector projection and its z=0 ground point on ``canvas`` in place.

    The function draws a filled white dot labelled "eef". If the same point
    with z set to 0 also lands inside the tile, it adds a cyan ring, a yellow
    connecting line and an "h=<z>" label. Points projecting outside the tile
    are not drawn.
    """
    rows, cols = canvas.shape[:2]
    eef_px = project_world_to_pixel(eef_pos, rows)  # tiles are square: H == W == image_size
    if eef_px is None:
        return
    u, v = eef_px
    if not (0 <= u < cols and 0 <= v < rows):
        return
    cv2.circle(canvas, (u, v), 6, (255, 255, 255), -1)
    cv2.putText(canvas, "eef", (u + 8, v - 8),
                cv2.FONT_HERSHEY_SIMPLEX, 0.35, (255, 255, 255), 1)

    ground = eef_pos.copy()
    ground[2] = 0.0
    ground_px = project_world_to_pixel(ground, rows)
    if ground_px is None:
        return
    ug, vg = ground_px
    if not (0 <= ug < cols and 0 <= vg < rows):
        return
    cv2.circle(canvas, (ug, vg), 6, (0, 255, 255), 2)
    cv2.line(canvas, (u, v), (ug, vg), (255, 255, 0), 2)
    cv2.putText(canvas, f"h={eef_pos[2]:.3f}", (ug + 8, vg + 12),
                cv2.FONT_HERSHEY_SIMPLEX, 0.3, (0, 255, 255), 1)


def build_wandb_strip(sample, split_name):
    """Render one annotated tile per timestep and join them into a wandb image.

    Each tile starts from ``sample['rgb_frames_raw'][t]``. When
    ``'pred_heatmap'`` is present, it is blended into the red channel and a
    green "pred" cross marks its maximum. The GT EEF from
    ``sample['trajectory_3d'][t]`` and its ground (z=0) projection are drawn
    on top, followed by a "t=<t>" label. Tiles are concatenated left to right.
    """
    with_heat = 'pred_heatmap' in sample
    tiles = []
    for t in range(N_WINDOW):
        frame = sample['rgb_frames_raw'][t].cpu().numpy()
        rows, cols = frame.shape[:2]
        if with_heat:
            heat = sample['pred_heatmap'][t].detach().cpu().numpy()
            tile, (peak_x, peak_y) = _heatmap_overlay(frame, heat)
            if 0 <= peak_x < cols and 0 <= peak_y < rows:
                cv2.drawMarker(tile, (peak_x, peak_y), (0, 255, 0),
                               cv2.MARKER_CROSS, 14, 2)
                cv2.putText(tile, "pred", (peak_x + 8, peak_y - 8),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.35, (0, 255, 0), 1)
        else:
            tile = (frame * 255).astype(np.uint8).copy()

        eef_pos = sample['trajectory_3d'][t].cpu().numpy().astype(np.float64)
        _annotate_gt_eef(tile, eef_pos)

        # Label drawn twice: thick white underlay, then thin dark text on top.
        for color, thickness in (((255, 255, 255), 2), ((20, 20, 20), 1)):
            cv2.putText(tile, f"t={t}", (8, 16), cv2.FONT_HERSHEY_SIMPLEX, 0.45, color, thickness)
        tiles.append(tile)

    strip = np.concatenate(tiles, axis=1)
    return wandb.Image(strip, caption=f"{split_name}: t=0..{N_WINDOW-1}")


def build_vis_sample(model, batch, device):
    """Run ``model`` on the first sample of ``batch`` and package it for visualization.

    The model runs in eval mode without gradients. Its previous train/eval
    mode is restored before returning, even on error, so calling this
    mid-training does not silently leave the model in eval mode.

    Args:
        model: predictor called as ``model(rgb, start_kp)`` and returning a
            4-tuple whose first element is the (B, N, Nh, H, W) volume logits.
        batch: dataloader batch dict with ``'rgb'``, ``'trajectory_2d'``,
            ``'trajectory_3d'`` and ``'rgb_frames_raw'`` entries.
        device: device to run the model on.

    Returns:
        dict with the first sample's ``'rgb_frames_raw'``, ``'trajectory_3d'``,
        ``'trajectory_2d'`` and ``'pred_heatmap'``. ``'pred_heatmap'`` is an
        (N, IMAGE_SIZE, IMAGE_SIZE) tensor. For each timestep it holds the
        softmax over the whole (Nh, H, W) volume, max-reduced over height and
        bilinearly upsampled.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            rgb = batch['rgb'][0:1].to(device)
            start_kp = batch['trajectory_2d'][0:1, 0, :].to(device)
            volume_logits = model(rgb, start_kp)[0]

            vol = volume_logits[0]  # (N, Nh, H, W)
            n_steps = vol.shape[0]
            # Softmax over each timestep's full volume, then keep the best height per pixel.
            probs = F.softmax(vol.reshape(n_steps, -1), dim=1).reshape(vol.shape)
            heat = probs.max(dim=1)[0]  # (N, H, W)
            heat_up = F.interpolate(heat[:, None], size=(IMAGE_SIZE, IMAGE_SIZE),
                                    mode='bilinear', align_corners=False)[:, 0]
    finally:
        model.train(was_training)

    return {
        'rgb_frames_raw': batch['rgb_frames_raw'][0],
        'trajectory_3d': batch['trajectory_3d'][0],
        'trajectory_2d': batch['trajectory_2d'][0],
        'pred_heatmap': heat_up,
    }


# ── Main training loop ──────────────────────────────────────────────────

def _parse_args():
    """Parse the command-line options for a training run."""
    p = argparse.ArgumentParser()
    p.add_argument("--data_dir", required=True)
    p.add_argument("--run_name", default="panda_para_test")
    p.add_argument("--epochs", type=int, default=500)
    p.add_argument("--batch_size", type=int, default=4)
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--frame_stride", type=int, default=1)
    p.add_argument("--wandb_mode", default="online")
    p.add_argument("--freeze_backbone", action="store_true")
    p.add_argument("--val_split", type=float, default=0.15)
    p.add_argument("--vis_every", type=int, default=VIS_EVERY_EPOCHS)
    return p.parse_args()


def _apply_model_ranges(stats):
    """Overwrite the normalisation-range globals on ``model_module`` with dataset stats."""
    model_module.MIN_HEIGHT = stats["min_height"]
    model_module.MAX_HEIGHT = stats["max_height"]
    model_module.MIN_GRIPPER = stats["min_gripper"]
    model_module.MAX_GRIPPER = stats["max_gripper"]
    model_module.MIN_ROT = stats["min_rot"]
    model_module.MAX_ROT = stats["max_rot"]
    model_module.MIN_POS = [0, 0, stats["min_height"]]
    model_module.MAX_POS = [1, 1, stats["max_height"]]


def _forward_losses(model, batch, device):
    """Run one forward pass on ``batch`` and compute the weighted loss terms.

    Returns:
        (loss, vol_loss, grip_loss, rot_loss, volume_logits, traj_2d_scaled),
        where ``traj_2d_scaled`` is the GT 2D trajectory rescaled from
        IMAGE_SIZE to PRED_SIZE pixel coordinates.
    """
    rgb = batch["rgb"].to(device)
    traj_2d = batch["trajectory_2d"].to(device)
    traj_3d = batch["trajectory_3d"].to(device)
    traj_grip = batch["trajectory_gripper"].to(device)
    traj_euler = batch["trajectory_euler"].to(device)

    traj_2d_scaled = traj_2d * (PRED_SIZE / IMAGE_SIZE)
    height_bins = discretize_height(traj_3d[:, :, 2])

    # The start keypoint stays in IMAGE_SIZE coordinates; query pixels are in PRED_SIZE.
    start_kp = traj_2d[:, 0, :]
    query_pixels = traj_2d_scaled.long().clamp(0, PRED_SIZE - 1)
    volume_logits, gripper_logits, rotation_logits, _ = model(
        rgb, start_kp, query_pixels=query_pixels)

    vol_loss = compute_volume_loss(volume_logits, traj_2d_scaled, height_bins)
    grip_loss = compute_gripper_loss(gripper_logits, traj_grip)
    rot_loss = compute_rotation_loss(rotation_logits, traj_euler)
    loss = (VOLUME_LOSS_WEIGHT * vol_loss
            + GRIPPER_LOSS_WEIGHT * grip_loss
            + ROTATION_LOSS_WEIGHT * rot_loss)
    return loss, vol_loss, grip_loss, rot_loss, volume_logits, traj_2d_scaled


def _train_one_epoch(model, loader, optimizer, device):
    """Train ``model`` for one pass over ``loader``.

    Returns:
        dict with the mean over batches of ``"loss"``, ``"vol_loss"``,
        ``"grip_loss"`` and ``"rot_loss"``.
    """
    model.train()
    history = {"loss": [], "vol_loss": [], "grip_loss": [], "rot_loss": []}
    for batch in loader:
        loss, vol_loss, grip_loss, rot_loss, _, _ = _forward_losses(model, batch, device)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        for key, value in zip(history, (loss, vol_loss, grip_loss, rot_loss)):
            history[key].append(value.item())
    return {key: float(np.mean(values)) for key, values in history.items()}


def _evaluate(model, loader, device):
    """Evaluate ``model`` on ``loader`` without gradients.

    Both metrics are sample-weighted (each batch mean is multiplied by its
    size), so a smaller final batch is not over-weighted. The pixel error is
    the Euclidean distance between the argmax pixel and the GT pixel in
    PRED_SIZE coordinates, averaged over all samples and timesteps.

    Returns:
        (mean loss, mean pixel error); both NaN if ``loader`` is empty.
    """
    model.eval()
    loss_sum, err_sum, n_samples, n_points = 0.0, 0.0, 0, 0
    with torch.no_grad():
        for batch in loader:
            loss, _, _, _, volume_logits, traj_2d_scaled = _forward_losses(model, batch, device)
            batch_size = traj_2d_scaled.shape[0]
            loss_sum += loss.item() * batch_size
            n_samples += batch_size

            pred_2d, _ = extract_pred_2d_and_height(volume_logits)
            dists = (pred_2d - traj_2d_scaled).norm(dim=-1)
            err_sum += dists.sum().item()
            n_points += dists.numel()
    if n_samples == 0:
        return float("nan"), float("nan")
    return loss_sum / n_samples, err_sum / n_points


def _visualization_logs(model, train_loader, val_loader, device):
    """Best-effort wandb strips for one train and one val sample.

    Any exception is printed and swallowed, so visualization never stops
    training. Returns a (possibly empty) dict of wandb log entries.
    """
    logs = {}
    try:
        train_batch = next(iter(train_loader))
        val_batch = next(iter(val_loader))
        train_sample = build_vis_sample(model, train_batch, device)
        val_sample = build_vis_sample(model, val_batch, device)
        train_strip = build_wandb_strip(train_sample, "train")
        val_strip = build_wandb_strip(val_sample, "val")
        if train_strip:
            logs["vis/train_strip"] = train_strip
        if val_strip:
            logs["vis/val_strip"] = val_strip
    except Exception as e:
        print(f"Vis error: {e}", flush=True)
    return logs


def main():
    """Train PARA on Panda data.

    Steps: compute dataset stats, push them into ``model_module``, split the
    dataset, then train with AdamW + cosine LR.

    Writes ``checkpoints/<run_name>/dataset_stats.json``. Also writes
    ``best.pth``, saved whenever val loss improves, and ``latest.pth``, saved
    every 50 epochs and after the final epoch. Metrics and periodic
    visualizations are logged to wandb.

    Raises:
        ValueError: if the training split is smaller than ``--batch_size``.
            The train loader drops the last incomplete batch, so it would
            yield no batches at all.
    """
    args = _parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}", flush=True)

    # Load episode annotations for stats
    with open(os.path.join(args.data_dir, "episodes.json")) as f:
        ep_data = json.load(f)

    # Fast stats (NPY only), then set model ranges from them.
    stats = compute_dataset_stats_fast(args.data_dir, ep_data["episodes"])
    _apply_model_ranges(stats)

    # Checkpoint dir
    ckpt_dir = Path(f"checkpoints/{args.run_name}")
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    with open(ckpt_dir / "dataset_stats.json", "w") as f:
        json.dump(stats, f, indent=2)

    # Dataset
    dataset = PandaTrajectoryDataset(args.data_dir, frame_stride=args.frame_stride)
    n_val = max(1, int(len(dataset) * args.val_split))
    n_train = len(dataset) - n_val
    print(f"Train: {n_train}, Val: {n_val}", flush=True)
    if n_train < args.batch_size:
        raise ValueError(
            f"Training split has {n_train} samples but batch_size={args.batch_size}; "
            "with drop_last=True the train loader would yield no batches.")
    train_ds, val_ds = random_split(dataset, [n_train, n_val],
                                     generator=torch.Generator().manual_seed(42))

    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,
                              num_workers=4, pin_memory=True, drop_last=True,
                              persistent_workers=True)
    val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False,
                            num_workers=2, pin_memory=True)

    # Model
    model = TrajectoryHeatmapPredictor().to(device)
    if args.freeze_backbone:
        for param in model.dino.parameters():
            param.requires_grad = False
        print("Backbone frozen", flush=True)

    optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()),
                            lr=args.lr, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)

    wandb.init(project="para_panda", name=args.run_name, mode=args.wandb_mode,
               config={**vars(args), **stats})

    best_val_loss = float("inf")

    for epoch in range(args.epochs):
        train_metrics = _train_one_epoch(model, train_loader, optimizer, device)
        # LR actually used during this epoch; get_last_lr() after step() would
        # report the next epoch's value.
        epoch_lr = scheduler.get_last_lr()[0]
        scheduler.step()

        val_loss, val_px_err = _evaluate(model, val_loader, device)
        train_loss = train_metrics["loss"]

        log_dict = {
            "epoch": epoch,
            "train/loss": train_loss,
            "train/vol_loss": train_metrics["vol_loss"],
            "train/grip_loss": train_metrics["grip_loss"],
            "train/rot_loss": train_metrics["rot_loss"],
            "val/loss": val_loss,
            "val/pixel_error": val_px_err,
            "lr": epoch_lr,
        }
        if epoch % args.vis_every == 0:
            log_dict.update(_visualization_logs(model, train_loader, val_loader, device))

        wandb.log(log_dict)

        if epoch % 10 == 0:
            print(f"Epoch {epoch:4d} | train={train_loss:.4f} | "
                  f"val={val_loss:.4f} | px_err={val_px_err:.1f}px", flush=True)

        # Save
        ckpt_data = {
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "epoch": epoch, **stats,
        }
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save({**ckpt_data, "val_loss": val_loss}, ckpt_dir / "best.pth")

        # Also save after the final epoch so end-of-training weights are kept
        # even when args.epochs - 1 is not a multiple of 50.
        if epoch % 50 == 0 or epoch == args.epochs - 1:
            torch.save(ckpt_data, ckpt_dir / "latest.pth")

    wandb.finish()
    print(f"Done! Best val loss: {best_val_loss:.4f}")
    print(f"Checkpoints: {ckpt_dir}")


if __name__ == "__main__":
    # Run training only when executed as a script, not when imported.
    main()
