"""Precompute CLIP text embeddings for LIBERO task names.

Saves one .pt file per task to <output_dir>/<benchmark>/task_<id>_clip.pt.
Each file contains a single (D_clip,) float32 tensor: the L2-normalized CLIP
text embedding of the task name with underscores replaced by spaces.

Usage:
    python precompute_clip_embeddings.py --benchmark libero_spatial --output_dir /data/libero/parsed_libero
"""
import argparse
import os
import sys
from pathlib import Path

import torch

# Make sibling modules importable regardless of the current working directory.
# abspath() matters on Python < 3.9, where __file__ of a script run as
# ``python precompute_clip_embeddings.py`` is relative, so dirname() returns ""
# (which sys.path interprets as the CWD, not this file's directory).
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--benchmark", type=str, default="libero_spatial")
    parser.add_argument("--output_dir", type=str, default="/data/libero/parsed_libero")
    parser.add_argument("--clip_model", type=str, default="openai/clip-vit-base-patch32")
    args = parser.parse_args()

    from transformers import CLIPModel, CLIPTokenizer

    print(f"Loading CLIP model: {args.clip_model}")
    tokenizer = CLIPTokenizer.from_pretrained(args.clip_model)
    model = CLIPModel.from_pretrained(args.clip_model)
    model.eval()

    from libero.libero import benchmark as bm
    bench = bm.get_benchmark_dict()[args.benchmark]()
    n_tasks = bench.get_num_tasks()

    out_dir = Path(args.output_dir) / args.benchmark
    out_dir.mkdir(parents=True, exist_ok=True)

    for i in range(n_tasks):
        task = bench.get_task(i)
        # Convert underscores to spaces for natural language
        text = task.name.replace("_", " ")
        print(f"  task_{i}: {text}")

        inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
        with torch.no_grad():
            text_features = model.get_text_features(**inputs)  # (1, D_clip)
        # Normalize
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        embedding = text_features[0]  # (D_clip,)

        out_path = out_dir / f"task_{i}_clip.pt"
        torch.save(embedding, out_path)
        print(f"    → {out_path}  shape={embedding.shape}")

    print(f"\nDone. Saved {n_tasks} embeddings (dim={embedding.shape[0]}) to {out_dir}")


if __name__ == "__main__":
    # Script entry point. main() returns None, so SystemExit(None) exits with
    # status 0, exactly like falling off the end of the script.
    raise SystemExit(main())
