"""Co-training: circle (3D, no gripper) + robot (full 3D + gripper) in mixed batches.

Both datasets use the SAME model heads:
- PARA: same volume head (uv + height bins) for both. Circle = no gripper loss.
- ACT: same position MLP. Circle targets = 3D camera-frame coords. Robot = 3D world-frame coords.

Schedule: 50% circle + 50% robot per batch for the first 75% of the --max_minutes wall-clock budget, then 25% circle + 75% robot.

Usage:
    python train_cotrain.py --model_type para --robot_demos 10 --run_name para_cotrain_10demo --max_minutes 30
"""
import argparse, os, sys, time
from pathlib import Path

import numpy as np, torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, ConcatDataset, Dataset

# Put this script's directory first on sys.path so the sibling modules imported
# below (model, data, train) resolve to the local copies.
sys.path.insert(0, os.path.dirname(__file__))
# Default DINOv3 repo/weights locations; setdefault lets the caller's environment override them.
# NOTE(review): these are set before `import model`, presumably because the model reads them
# at import/construction time — confirm.
os.environ.setdefault("DINO_REPO_DIR", "/data/cameron/keygrip/dinov3")
os.environ.setdefault("DINO_WEIGHTS_PATH", "/data/cameron/keygrip/dinov3/weights/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth")

import model as model_module
from model import TrajectoryHeatmapPredictor, N_HEIGHT_BINS, PRED_SIZE
from data import CachedTrajectoryDataset
from train import (compute_volume_loss, compute_gripper_loss, discretize_height,
                   normalize_to_01, denormalize_from_01)

IMAGE_SIZE = 448  # image resolution passed as image_size / target_size to datasets and models
N_WINDOW = 4  # number of trajectory timesteps per sample (the T axis of the trajectory tensors)


class LabeledDataset(Dataset):
    """Wrap a dataset and tag each sample with a ``'source'`` label (e.g. circle vs robot).

    Args:
        dataset: Indexable dataset whose items are mappings (dict-like samples).
        source_label: Value stored under the ``'source'`` key of every returned sample.
    """

    def __init__(self, dataset, source_label):
        self.dataset = dataset
        self.source_label = source_label

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # Return a shallow copy rather than mutating the wrapped dataset's item in
        # place: a dataset that re-serves the same cached dict (or the same dataset
        # wrapped twice with different labels) would otherwise be altered.
        return {**self.dataset[idx], 'source': self.source_label}


def _parse_args():
    """Parse and validate the command-line arguments for co-training."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_type", type=str, default="para", choices=["para", "act"])
    parser.add_argument("--robot_demos", type=int, required=True)
    parser.add_argument("--run_name", type=str, required=True)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--max_minutes", type=int, default=30)
    parser.add_argument("--wandb_mode", type=str, default="disabled")
    args = parser.parse_args()
    # Every batch holds >= 1 circle AND >= 1 robot sample; with batch_size 1 the
    # robot half would be empty and torch.stack would fail mid-training.
    if args.batch_size < 2:
        parser.error("--batch_size must be >= 2 (each batch mixes circle and robot samples)")
    # max_minutes is the divisor of the schedule's progress fraction.
    if args.max_minutes <= 0:
        parser.error("--max_minutes must be positive")
    return args


def _position_stats(dataset, max_samples=500):
    """Return per-axis ``(min, max)`` lists of ``trajectory_3d`` over up to ``max_samples`` samples.

    Samples are spread evenly over the whole dataset instead of taken from its
    start, so the range is not biased toward whichever demos come first.

    Raises:
        ValueError: if ``dataset`` is empty.
    """
    if len(dataset) == 0:
        raise ValueError("cannot compute position statistics from an empty dataset")
    n_samples = min(max_samples, len(dataset))
    indices = np.linspace(0, len(dataset) - 1, num=n_samples).round().astype(int)
    positions = np.concatenate(
        [dataset[int(i)]['trajectory_3d'].numpy() for i in indices], axis=0)
    return positions.min(axis=0).tolist(), positions.max(axis=0).tolist()


def _circle_ratio(progress):
    """Fraction of each batch drawn from the circle dataset at time-budget ``progress`` (0..1)."""
    return 0.5 if progress < 0.75 else 0.25


def _epoch_indices(n_circle_ds, n_robot_ds, circle_ratio):
    """Build the shuffled circle/robot index streams for one epoch.

    The robot stream is resampled (with replacement) to a target length derived
    from the circle dataset size and ``circle_ratio`` when the robot dataset is
    smaller than that target; otherwise the circle stream is resampled to the
    robot dataset's length. Returns ``(circle_indices, robot_indices)``.
    """
    n_circle = int(n_circle_ds * circle_ratio / (1 - circle_ratio + 1e-8))
    if n_robot_ds < n_circle:
        robot_indices = np.random.choice(n_robot_ds, size=n_circle, replace=True)
        circle_indices = np.arange(n_circle_ds)
    else:
        circle_indices = np.random.choice(n_circle_ds, size=n_robot_ds, replace=True)
        robot_indices = np.arange(n_robot_ds)
    np.random.shuffle(circle_indices)
    np.random.shuffle(robot_indices)
    return circle_indices, robot_indices


def _stack(samples, key, device):
    """Stack ``sample[key]`` across ``samples`` into one tensor on ``device``."""
    return torch.stack([s[key] for s in samples]).to(device)


def _para_volume_loss(model, samples, device):
    """Volume (uv + height-bin) loss for one group of samples; also returns the gripper output."""
    imgs = _stack(samples, 'rgb', device)
    traj_2d = _stack(samples, 'trajectory_2d', device)
    traj_3d = _stack(samples, 'trajectory_3d', device)
    vol, grip, _, _ = model(imgs, traj_2d[:, 0])  # conditioned on the first 2D keypoint
    bins = discretize_height(traj_3d[:, :, 2])  # index 2 of the last axis is used as height
    vol = vol.reshape(len(samples), N_WINDOW, N_HEIGHT_BINS, PRED_SIZE, PRED_SIZE)
    return compute_volume_loss(vol, traj_2d, bins), grip


def _para_loss(model, c_samples, r_samples, device):
    """PARA co-training loss: the shared volume loss on circle plus robot samples."""
    c_loss, _ = _para_volume_loss(model, c_samples, device)
    r_loss, _r_grip = _para_volume_loss(model, r_samples, device)
    # TODO(review): the robot gripper prediction (_r_grip) is never supervised even
    # though `compute_gripper_loss` is imported; add it once its expected arguments
    # are confirmed against train.py.
    return c_loss + r_loss


def _act_loss(model, c_samples, r_samples, device):
    """ACT co-training loss: camera-frame position MSE on circle + world-frame position MSE
    and gripper BCE on robot samples, all through the same position MLP."""
    min_pos = torch.tensor(model_module.MIN_POS, device=device, dtype=torch.float32)
    max_pos = torch.tensor(model_module.MAX_POS, device=device, dtype=torch.float32)

    # --- Circle: targets are the trajectory transformed into the camera frame ---
    c_imgs = _stack(c_samples, 'rgb', device)
    c_traj_2d = _stack(c_samples, 'trajectory_2d', device)
    c_traj_3d = _stack(c_samples, 'trajectory_3d', device)
    c_w2c = _stack(c_samples, 'world_to_camera', device)  # presumably (B, 4, 4) homogeneous — confirm
    n_c = len(c_samples)
    c_3d_hom = torch.cat([c_traj_3d, torch.ones_like(c_traj_3d[..., :1])], dim=-1)  # (B, T, 4)
    c_cam_3d = torch.einsum('bij,btj->bti', c_w2c[:, :3, :], c_3d_hom)  # (B, T, 3)
    # Camera-frame coords are normalized with the robot's world-frame range as an approximation.
    c_pos_target = normalize_to_01(c_cam_3d, min_pos, max_pos)
    # NOTE(review): the current-EEF input is world-frame while the targets above are
    # camera-frame; confirm this frame mismatch is intended.
    c_eef_norm = normalize_to_01(c_traj_3d[:, 0], min_pos, max_pos)
    c_grip = torch.zeros(n_c, 1, device=device)  # circle data has no gripper signal
    c_pos_pred, _, _ = model(c_imgs, c_traj_2d[:, 0],
                             current_eef_pos=c_eef_norm, current_gripper=c_grip)
    c_loss = F.mse_loss(c_pos_pred, c_pos_target)

    # --- Robot: world-frame targets plus gripper supervision ---
    r_imgs = _stack(r_samples, 'rgb', device)
    r_traj_2d = _stack(r_samples, 'trajectory_2d', device)
    r_traj_3d = _stack(r_samples, 'trajectory_3d', device)
    r_traj_grip = _stack(r_samples, 'trajectory_gripper', device)
    r_pos_target = normalize_to_01(r_traj_3d, min_pos, max_pos)
    r_eef_norm = normalize_to_01(r_traj_3d[:, 0], min_pos, max_pos)
    r_grip = r_traj_grip[:, 0:1]
    r_pos_pred, _, r_grip_pred = model(r_imgs, r_traj_2d[:, 0],
                                       current_eef_pos=r_eef_norm, current_gripper=r_grip)
    r_pos_loss = F.mse_loss(r_pos_pred, r_pos_target)
    r_grip_loss = F.binary_cross_entropy_with_logits(r_grip_pred, (r_traj_grip > 0).float())

    return c_loss + r_pos_loss + r_grip_loss


def main():
    """Co-train a PARA or ACT model on mixed circle + robot batches.

    Trains until ``--max_minutes`` of wall-clock time elapse. After each epoch,
    saves ``checkpoints/<run_name>/best.pth`` if that epoch's mean training loss
    is the lowest seen so far.
    """
    args = _parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load datasets
    circle_ds = CachedTrajectoryDataset(
        cache_root="/data/libero/ood_objpos_circle_overlay",
        task_ids=[0], image_size=IMAGE_SIZE, n_window=N_WINDOW, frame_stride=3)

    n = args.robot_demos
    robot_cache = "/data/libero/ood_objpos_v3" if n == 256 else f"/data/libero/ood_objpos_v3_splits/exp_{n}demo_finetune_train"
    robot_ds = CachedTrajectoryDataset(
        cache_root=robot_cache,
        task_ids=[0], image_size=IMAGE_SIZE, n_window=N_WINDOW, frame_stride=3)

    print(f"Circle: {len(circle_ds)} samples, Robot ({n} demos): {len(robot_ds)} samples")
    if len(circle_ds) == 0 or len(robot_ds) == 0:
        raise RuntimeError("both circle and robot datasets must be non-empty for co-training")

    # The normalization range shared by both sources comes from robot data.
    model_module.MIN_POS, model_module.MAX_POS = _position_stats(robot_ds)
    print(f"Position range: {model_module.MIN_POS} .. {model_module.MAX_POS}")

    # Build model (standard, no extra heads)
    if args.model_type == "para":
        model = TrajectoryHeatmapPredictor(target_size=IMAGE_SIZE, n_window=N_WINDOW)
        compute_loss = _para_loss
    else:
        from model_act import ACTPredictor
        model = ACTPredictor(target_size=IMAGE_SIZE, n_window=N_WINDOW)
        compute_loss = _act_loss
    model = model.to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-5)

    ckpt_dir = Path(f"checkpoints/{args.run_name}")
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    # wandb is optional: any failure to import/initialize it just disables logging
    # (narrower than a bare except, so Ctrl-C still interrupts).
    try:
        import wandb
        wandb.init(project="para_libero", name=args.run_name, mode=args.wandb_mode)
    except Exception as exc:
        print(f"wandb unavailable ({exc!r}); continuing without it")
        wandb = None

    # Best mean *training* loss per epoch; the name is kept for checkpoint compatibility.
    best_val_loss = float('inf')
    start_time = time.time()
    global_step = 0

    def elapsed_minutes():
        return (time.time() - start_time) / 60

    while True:
        elapsed_min = elapsed_minutes()
        if elapsed_min > args.max_minutes:
            break

        circle_ratio = _circle_ratio(elapsed_min / args.max_minutes)
        circle_indices, robot_indices = _epoch_indices(len(circle_ds), len(robot_ds), circle_ratio)
        n_c_per_batch = max(1, int(args.batch_size * circle_ratio))
        n_r_per_batch = args.batch_size - n_c_per_batch

        model.train()
        epoch_loss = 0.0
        n_batches = 0
        phase_changed = False
        ci = ri = 0

        while ci + n_c_per_batch <= len(circle_indices) and ri + n_r_per_batch <= len(robot_indices):
            elapsed_min = elapsed_minutes()
            if elapsed_min > args.max_minutes:
                break
            # The mixing ratio is baked into this epoch's index streams, so end the
            # epoch as soon as the schedule enters a new phase; otherwise a long
            # epoch would delay (or entirely skip) the switch to the new ratio.
            if _circle_ratio(elapsed_min / args.max_minutes) != circle_ratio:
                phase_changed = True
                break

            c_samples = [circle_ds[i] for i in circle_indices[ci:ci + n_c_per_batch]]
            r_samples = [robot_ds[i] for i in robot_indices[ri:ri + n_r_per_batch]]
            ci += n_c_per_batch
            ri += n_r_per_batch

            total_loss = compute_loss(model, c_samples, r_samples, device)

            optimizer.zero_grad()
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            loss_value = total_loss.item()
            epoch_loss += loss_value
            n_batches += 1
            global_step += 1

            if global_step % 50 == 0:
                print(f"  step {global_step}: loss={loss_value:.4f} circle_ratio={circle_ratio:.2f}", flush=True)
                if wandb is not None:
                    try:
                        wandb.log({"train/loss": loss_value, "step": global_step})
                    except Exception as exc:  # logging is best-effort; never kill training
                        print(f"  wandb.log failed: {exc!r}", flush=True)

        if n_batches == 0:
            if phase_changed:
                continue  # rebuild the epoch with the new mixing ratio
            break  # time is up, or the index streams cannot form a single batch
        avg_loss = epoch_loss / n_batches
        print(f"Epoch done: step={global_step} loss={avg_loss:.4f} [{elapsed_minutes():.1f}min]")

        if avg_loss < best_val_loss:
            best_val_loss = avg_loss
            torch.save({
                'model_state_dict': model.state_dict(),
                'global_step': global_step,
                'best_val_loss': best_val_loss,
                'model_type': args.model_type,
                'robot_demos': args.robot_demos,
                'cotrain': True,
                'min_pos': model_module.MIN_POS,
                'max_pos': model_module.MAX_POS,
            }, ckpt_dir / "best.pth")

    print(f"\n{'='*50}")
    print(f"Co-training complete. Best loss: {best_val_loss:.4f}, steps: {global_step}")
    print(f"Checkpoint: {ckpt_dir / 'best.pth'}")
    print(f"{'='*50}")


# Script entry point: run co-training only when executed directly, not on import.
if __name__ == "__main__":
    main()
