"""2D Point Track Pretraining — trains PARA or ACT on 2D pixel trajectories only.

PARA mode: N_HEIGHT_BINS=1, cross-entropy over the H×W grid per timestep (predict the future EEF pixel location)
ACT mode: CLS token + start keypoint → MLP → sigmoid (u, v) normalized by the image size, MSE (L2) loss

No height, no gripper, no rotation supervision.

Usage:
    # PARA 2D pretrain on arm-deleted data
    python train_pretrain_2d.py --model_type para --cache_root /data/libero/ood_objpos_arm_deleted \
        --run_name para_pretrain_arm_deleted --max_minutes 10

    # ACT pixel pretrain on circle overlay data
    python train_pretrain_2d.py --model_type act --cache_root /data/libero/ood_objpos_circle_overlay \
        --run_name act_pretrain_circle --max_minutes 10
"""
import argparse
import json
import os
import sys
import time
from pathlib import Path

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

# Put this script's directory first on sys.path so the sibling `model` and
# `data` modules imported below resolve.
sys.path.insert(0, os.path.dirname(__file__))

# N_HEIGHT_BINS is NOT overridden here: main() assigns
# model_module.N_HEIGHT_BINS = 1 for PARA mode right before building the model.
# Holding the module object (instead of importing the name) is what makes that
# later assignment visible to model.py — assumes model.py reads N_HEIGHT_BINS
# from its module globals when the model is constructed; TODO confirm.
import model as model_module
ORIGINAL_N_HEIGHT_BINS = model_module.N_HEIGHT_BINS  # value before any override; not used by main()

from data import CachedTrajectoryDataset
from model import TrajectoryHeatmapPredictor, PRED_SIZE

# Local DINOv3 torch.hub repo directory and weights file (empty string if unset).
DINO_REPO_DIR = os.environ.get("DINO_REPO_DIR", "")
DINO_WEIGHTS_PATH = os.environ.get("DINO_WEIGHTS_PATH", "")

N_WINDOW = 4  # trajectory timesteps per sample (dataset window / model output length)
IMAGE_SIZE = 448  # input image side in pixels; (u, v) targets are divided by this to normalize


def compute_2d_pixel_loss_para(pred_volume, trajectory_2d, pred_size, image_size=None):
    """CE loss over H×W for each timestep. No height dimension.

    Each target (u, v) pixel coordinate is quantized onto the H×W prediction
    grid, and a softmax cross-entropy is taken over all H*W cells per timestep.
    The loss is averaged over batch and time.

    Args:
        pred_volume: (B, N_WINDOW, 1, H, W) logits with a single height bin.
        trajectory_2d: (B, N_WINDOW, 2) target (u, v) pixel coords, where u is
            the x/column coordinate and v is the y/row coordinate.
        pred_size: kept for call compatibility. The grid size is read from
            ``pred_volume.shape``.
        image_size: side length (pixels) of the image the (u, v) coords live
            in. Defaults to the module-level IMAGE_SIZE.

    Returns:
        Scalar cross-entropy loss tensor.
    """
    if image_size is None:
        image_size = IMAGE_SIZE

    B, T, _, H, W = pred_volume.shape
    logits = pred_volume[:, :, 0].reshape(B * T, H * W)

    # u indexes columns (width) and v indexes rows (height). Scaling each by its
    # own axis keeps the targets correct when the grid is not square.
    # Out-of-image points are clamped onto the border cells.
    col = (trajectory_2d[..., 0] / image_size * W).long().clamp(0, W - 1)
    row = (trajectory_2d[..., 1] / image_size * H).long().clamp(0, H - 1)
    target_flat = row * W + col  # row-major index into the flattened grid, (B, T)

    return F.cross_entropy(logits, target_flat.reshape(B * T))


def compute_2d_pixel_loss_act(pred_uv, trajectory_2d, image_size=None):
    """L2 loss on predicted (u, v) pixel coordinates.

    Args:
        pred_uv: (B, N_WINDOW, 2) predicted coords in [0, 1] (sigmoid output).
        trajectory_2d: (B, N_WINDOW, 2) target (u, v) pixel coords.
        image_size: side length (pixels) used to normalize targets to [0, 1].
            Defaults to the module-level IMAGE_SIZE.

    Returns:
        Scalar mean-squared-error tensor in the prediction's dtype.
    """
    if image_size is None:
        image_size = IMAGE_SIZE
    # Match the prediction dtype. A float64 target against float32 predictions
    # makes mse_loss's backward fail with a dtype mismatch.
    target_norm = (trajectory_2d / image_size).to(pred_uv.dtype)
    return F.mse_loss(pred_uv, target_norm)


class ACTPixelPredictor(nn.Module):
    """ACT-style model that predicts (u, v) pixel coordinates instead of 3D positions.

    A DINOv3 ViT-S/16+ backbone (loaded from a local torch.hub repo) yields a
    CLS embedding. It is concatenated with the start keypoint divided by
    ``target_size`` and fed to an MLP ending in a sigmoid. The output has shape
    ``(B, n_window, 2)`` with values in (0, 1). main() trains it against pixel
    targets divided by IMAGE_SIZE.
    """

    def __init__(self, target_size=448, n_window=N_WINDOW, **kwargs):
        # Extra keyword arguments are accepted and ignored.
        super().__init__()
        self.target_size = target_size  # divisor used to normalize the start keypoint
        self.n_window = n_window  # number of predicted timesteps
        self.model_type = "act_pixel"

        self.dino = torch.hub.load(DINO_REPO_DIR, 'dinov3_vits16plus',
                                    source='local', weights=DINO_WEIGHTS_PATH)
        self.embed_dim = self.dino.embed_dim
        D = self.embed_dim

        # Input: CLS(D) + start_kp(2)
        inp_dim = D + 2

        # The final Sigmoid bounds every coordinate to (0, 1).
        self.pixel_mlp = nn.Sequential(
            nn.LayerNorm(inp_dim),
            nn.Linear(inp_dim, D),
            nn.GELU(),
            nn.Linear(D, D),
            nn.GELU(),
            nn.Linear(D, n_window * 2),
            nn.Sigmoid(),
        )
        # Counts all parameters, including the (unfrozen) DINO backbone.
        n_params = sum(p.numel() for p in self.parameters())
        print(f"✓ ACT Pixel model: {n_params:,} params, output (B, {n_window}, 2)")

    def _extract_cls(self, x):
        """Run the DINO backbone on images ``x`` and return the normalized CLS token (B, D).

        Reimplements the backbone's token pass using its internal attributes
        (prepare_tokens_with_masks, blocks, rope_embed, cls_norm/norm).
        NOTE(review): this is coupled to the DINOv3 implementation; re-check it
        if the hub repo changes.
        """
        x_tokens, (H_p, W_p) = self.dino.prepare_tokens_with_masks(x)
        for blk in self.dino.blocks:
            # rope_embed is called once per block, not hoisted out of the loop.
            # NOTE(review): upstream RoPE may apply random coordinate
            # augmentation per call in training mode — confirm before hoisting.
            # Truthiness check on a module; `is not None` would be more explicit.
            rope_sincos = self.dino.rope_embed(H=H_p, W=W_p) if self.dino.rope_embed else None
            x_tokens = blk(x_tokens, rope_sincos)
        # Normalize only the leading CLS + storage (register) tokens.
        if self.dino.untie_cls_and_patch_norms:
            x_norm = self.dino.cls_norm(x_tokens[:, :self.dino.n_storage_tokens + 1])
        else:
            x_norm = self.dino.norm(x_tokens[:, :self.dino.n_storage_tokens + 1])
        # Index 0 is taken as the CLS token (assumes DINOv3 places CLS first).
        return x_norm[:, 0]

    def forward(self, x, start_keypoint_2d, **kwargs):
        """Predict normalized (u, v) coords, shape (B, n_window, 2).

        ``start_keypoint_2d`` is either (2,), which is broadcast to the whole
        batch, or (B, 2), in the same pixel units as ``target_size``. Extra
        keyword arguments are ignored.
        """
        B = x.shape[0]
        cls = self._extract_cls(x)
        if start_keypoint_2d.dim() == 1:
            start_keypoint_2d = start_keypoint_2d.unsqueeze(0).expand(B, -1)
        kp_norm = start_keypoint_2d / self.target_size
        inp = torch.cat([cls, kp_norm], dim=-1)
        pred = self.pixel_mlp(inp).reshape(B, self.n_window, 2)
        return pred


def _compute_loss(model, model_type, img, traj_2d, start_kp):
    """Run one forward pass and return the scalar 2D pretraining loss.

    Shared by the training and validation loops so both always use the same
    forward/reshape/loss path.

    Args:
        model: TrajectoryHeatmapPredictor when ``model_type == "para"``,
            otherwise ACTPixelPredictor.
        model_type: the ``--model_type`` flag ("para" or "act").
        img: batch of input images, already on the model's device.
        traj_2d: (B, N_WINDOW, 2) target (u, v) pixel coordinates.
        start_kp: (B, 2) start keypoint passed to the model.
    """
    if model_type == "para":
        volume_logits, _, _, _ = model(img, start_kp)
        # Expected (B, N_WINDOW * 1, H, W): main() forces N_HEIGHT_BINS to 1.
        vol = volume_logits.reshape(img.shape[0], N_WINDOW, 1, PRED_SIZE, PRED_SIZE)
        return compute_2d_pixel_loss_para(vol, traj_2d, PRED_SIZE)
    pred_uv = model(img, start_kp)
    return compute_2d_pixel_loss_act(pred_uv, traj_2d)


def _init_wandb(args):
    """Import and initialize wandb. Return the module, or None if unavailable.

    Logging is best-effort: any import or init failure disables it instead of
    aborting training. KeyboardInterrupt still propagates (no bare except).
    """
    try:
        import wandb
        wandb.init(project=args.wandb_project, name=args.run_name, mode=args.wandb_mode)
    except Exception as e:
        print(f"wandb disabled ({type(e).__name__}: {e})")
        return None
    return wandb


def _wandb_log(wb, payload):
    """Log ``payload`` to wandb if enabled. Failures are reported, not raised."""
    if wb is None:
        return
    try:
        wb.log(payload)
    except Exception as e:
        print(f"  wandb.log failed: {e}")


def _save_heatmap_vis(model, img, start_kp, traj_2d, vis_dir, global_step, wb):
    """Render predicted heatmaps vs. GT targets for the first sample of a batch.

    Writes ``vis_dir/step_XXXXXX.png``. The top row is the input followed by
    per-timestep heatmap overlays with the argmax peak. The bottom row is the
    input followed by the GT target dots. If ``wb`` is not None, the same image
    is also logged. Intended for PARA models only.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            vis_vol, _, _, _ = model(img[:1], start_kp[:1])
    finally:
        # Restore the previous mode even if the forward pass raises; otherwise
        # the rest of the epoch would silently train in eval mode.
        model.train(was_training)
    vis_vol = vis_vol.reshape(1, N_WINDOW, 1, PRED_SIZE, PRED_SIZE)

    # NOTE(review): assumes batch['rgb'] is float RGB in [0, 1], CHW — confirm
    # against CachedTrajectoryDataset (normalized inputs would render wrongly).
    vis_img = img[0].cpu().permute(1, 2, 0).numpy()
    vis_img = (vis_img * 255).clip(0, 255).astype(np.uint8)
    vis_img = cv2.cvtColor(vis_img, cv2.COLOR_RGB2BGR)
    img_h, img_w = vis_img.shape[:2]

    # Draw start keypoint
    kp = start_kp[0].cpu().numpy()
    cv2.circle(vis_img, (int(kp[0]), int(kp[1])), 6, (0, 255, 0), -1)

    panels = [cv2.resize(vis_img, (200, 200))]
    gt_panels = [cv2.resize(vis_img.copy(), (200, 200))]

    for t in range(N_WINDOW):
        # Predicted heatmap: softmax over the full H×W grid, matching the CE loss.
        logits = vis_vol[0, t, 0].cpu()
        prob = torch.softmax(logits.reshape(-1), dim=0).reshape(PRED_SIZE, PRED_SIZE).numpy()
        # Upsample to the actual image size so the overlay and peak align with vis_img.
        hm = cv2.resize(prob, (img_w, img_h))
        hm_norm = (hm / (hm.max() + 1e-8) * 255).astype(np.uint8)
        hm_color = cv2.applyColorMap(hm_norm, cv2.COLORMAP_JET)
        overlay = cv2.addWeighted(vis_img, 0.4, hm_color, 0.6, 0)
        peak_row, peak_col = np.unravel_index(hm.argmax(), hm.shape)
        # Pass plain Python ints as the (x, y) center; some OpenCV versions
        # reject numpy integer types here.
        cv2.circle(overlay, (int(peak_col), int(peak_row)), 5, (255, 255, 255), -1)
        cv2.putText(overlay, f"t+{t+1}", (10, 25), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255,255,255), 2)
        panels.append(cv2.resize(overlay, (200, 200)))

        # GT target
        gt_img = vis_img.copy()
        gt_uv = traj_2d[0, t].cpu().numpy()
        cv2.circle(gt_img, (int(gt_uv[0]), int(gt_uv[1])), 8, (0, 0, 255), -1)
        cv2.putText(gt_img, f"GT t+{t+1}", (10, 25), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,0,255), 2)
        gt_panels.append(cv2.resize(gt_img, (200, 200)))

    vis_combined = np.vstack([np.concatenate(panels, axis=1),
                              np.concatenate(gt_panels, axis=1)])

    vis_dir.mkdir(parents=True, exist_ok=True)
    cv2.imwrite(str(vis_dir / f"step_{global_step:06d}.png"), vis_combined)

    if wb is not None:
        vis_rgb = cv2.cvtColor(vis_combined, cv2.COLOR_BGR2RGB)
        wb.log({"vis/heatmap": wb.Image(vis_rgb), "step": global_step})


def main():
    """Train PARA or ACT on 2D pixel trajectories and keep the best-val checkpoint.

    Training stops when ``--epochs`` or ``--max_minutes`` is exhausted. The
    best checkpoint is written to ``checkpoints/<run_name>/best.pth``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_type", type=str, default="para", choices=["para", "act"])
    parser.add_argument("--cache_root", type=str, required=True)
    parser.add_argument("--run_name", type=str, required=True)
    parser.add_argument("--benchmark", type=str, default="libero_spatial")
    parser.add_argument("--task_id", type=int, default=0)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--epochs", type=int, default=9999)
    parser.add_argument("--max_minutes", type=int, default=10)
    parser.add_argument("--wandb_project", type=str, default="para_libero")
    parser.add_argument("--wandb_mode", type=str, default="online")
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load dataset
    dataset = CachedTrajectoryDataset(
        cache_root=args.cache_root,
        benchmark_name=args.benchmark,
        task_ids=[args.task_id],
        image_size=IMAGE_SIZE,
        n_window=N_WINDOW,
        frame_stride=3,
    )
    print(f"Dataset: {len(dataset)} samples")

    n_val = max(1, len(dataset) // 20)
    n_train = len(dataset) - n_val
    train_ds, val_ds = torch.utils.data.random_split(dataset, [n_train, n_val],
                                                       generator=torch.Generator().manual_seed(42))
    pin = device.type == "cuda"  # pinned memory only helps host->GPU copies
    # drop_last drops the ragged final batch. With fewer training samples than
    # batch_size it would drop *every* batch and end training before step one.
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,
                              num_workers=4, pin_memory=pin,
                              drop_last=n_train >= args.batch_size)
    val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False,
                            num_workers=2, pin_memory=pin)

    # Build model
    if args.model_type == "para":
        # Set N_HEIGHT_BINS=1 for 2D-only pretraining (keep it set during training)
        model_module.N_HEIGHT_BINS = 1
        model = TrajectoryHeatmapPredictor(target_size=IMAGE_SIZE, n_window=N_WINDOW)
    else:
        model = ACTPixelPredictor(target_size=IMAGE_SIZE, n_window=N_WINDOW)

    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-5)

    # Checkpoint dir
    ckpt_dir = Path(f"checkpoints/{args.run_name}")
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    wb = _init_wandb(args)

    # Training loop
    best_val_loss = float('inf')
    start_time = time.time()
    global_step = 0

    def minutes_elapsed():
        return (time.time() - start_time) / 60

    for epoch in range(args.epochs):
        if minutes_elapsed() > args.max_minutes:
            break

        model.train()
        epoch_loss = 0.0
        n_batches = 0

        for batch in train_loader:
            if minutes_elapsed() > args.max_minutes:
                break

            img = batch['rgb'].to(device)
            traj_2d = batch['trajectory_2d'].to(device)  # (B, N_WINDOW, 2)
            start_kp = traj_2d[:, 0]  # first timestep as start keypoint

            loss = _compute_loss(model, args.model_type, img, traj_2d, start_kp)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            epoch_loss += loss.item()
            n_batches += 1
            global_step += 1

            if global_step % 50 == 0:
                _wandb_log(wb, {"train/loss": loss.item(), "step": global_step})

            # Visualize heatmap every 200 steps (best-effort: never kills training)
            if global_step % 200 == 0 and args.model_type == "para":
                try:
                    _save_heatmap_vis(model, img, start_kp, traj_2d,
                                      ckpt_dir / "vis", global_step, wb)
                except Exception as e:
                    print(f"  Vis error: {e}")

        if n_batches == 0:
            break
        avg_train = epoch_loss / n_batches

        # Validation
        model.eval()
        val_loss = 0.0
        val_n = 0
        with torch.no_grad():
            for batch in val_loader:
                img = batch['rgb'].to(device)
                traj_2d = batch['trajectory_2d'].to(device)
                start_kp = traj_2d[:, 0]
                loss = _compute_loss(model, args.model_type, img, traj_2d, start_kp)
                val_loss += loss.item()
                val_n += 1

        avg_val = val_loss / max(1, val_n)
        elapsed = minutes_elapsed()

        print(f"Epoch {epoch+1}: train={avg_train:.4f} val={avg_val:.4f} [{elapsed:.1f}min]")

        if avg_val < best_val_loss:
            best_val_loss = avg_val
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'epoch': epoch,
                'val_loss': avg_val,
                'model_type': args.model_type,
                'n_height_bins': 1 if args.model_type == "para" else None,
                'pretrain_2d': True,
            }, ckpt_dir / "best.pth")

        _wandb_log(wb, {"val/loss": avg_val, "epoch": epoch})

    print(f"\n{'='*50}")
    print(f"Training complete. Best val loss: {best_val_loss:.4f}")
    print(f"Checkpoint: {ckpt_dir / 'best.pth'}")
    print(f"{'='*50}")

    if wb is not None:
        try:
            wb.finish()
        except Exception as e:
            print(f"  wandb.finish failed: {e}")


if __name__ == "__main__":
    main()
