"""Precompute Depth Anything v3 (DA3-SMALL) depth predictions for our datasets.

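For each input frame, saves a (DA3_INPUT, DA3_INPUT) float16 .npy of the predicted depth,
named after the image stem, in a `da3_depth/` directory (configurable via --depth_subdir)
inside the session directory (smith300, umi) or the demo directory (libero).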

Usage:
  CUDA_VISIBLE_DEVICES=9 PYTHONPATH=/data/cameron/da3_repo/src:$PYTHONPATH \\
    python precompute_da3_depth.py --dataset smith300
  CUDA_VISIBLE_DEVICES=9 PYTHONPATH=/data/cameron/da3_repo/src:$PYTHONPATH \\
    python precompute_da3_depth.py --dataset libero
"""
import argparse, os, sys, types, glob, time
from pathlib import Path

# Stand-ins for two DA3 utility modules (export, pose_align) whose heavy dependencies
# are not needed at inference time. Each stub exposes a single no-op callable; because
# the stubs are registered in sys.modules, a later import of these module names
# resolves to them instead of loading the real modules.
for n, fn_name in (
    ('depth_anything_3.utils.export', 'export'),
    ('depth_anything_3.utils.pose_align', 'align_poses_umeyama'),
):
    m = types.ModuleType(n)
    setattr(m, fn_name, lambda *a, **k: None)
    sys.modules[n] = m

import cv2
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm

from depth_anything_3.api import DepthAnything3

# Side length (pixels) of the square image fed to the model.
# DA3 default — patch_size 14 × 36 patches = 504.
DA3_INPUT = 14 * 36
# Default location of the DA3 weights directory (overridable with --weights_path).
DA3_WEIGHTS = "/data/cameron/da3_weights"


def preprocess(bgr, size=DA3_INPUT):
    """Convert one OpenCV frame into a model-ready tensor.

    Args:
        bgr: uint8 image in BGR channel order with HWC layout.
        size: side length of the square output; the frame is resized to
            (size, size) with bilinear interpolation, so the aspect ratio
            is not preserved.

    Returns:
        float32 torch tensor of shape (1, 1, 3, size, size), RGB channel
        order, values scaled to [0, 1].
    """
    resized = cv2.resize(
        cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB),
        (size, size),
        interpolation=cv2.INTER_LINEAR,
    )
    # HWC uint8 -> CHW float32 in [0, 1]
    chw = np.transpose(resized.astype(np.float32) / 255.0, (2, 0, 1))
    # Two leading singleton axes: (3, H, W) -> (1, 1, 3, H, W)
    return torch.from_numpy(chw)[None, None]


def _queue_frames(images, out_dir, overwrite, todo):
    """Append (image_path, depth_path) pairs for frames that still need a depth file.

    Creates ``out_dir`` if it does not exist (its parent must already exist).
    A frame is queued when ``overwrite`` is true or its ``<stem>.npy`` target
    is not yet present in ``out_dir``.
    """
    out_dir.mkdir(exist_ok=True)
    for ip in images:
        op = out_dir / f"{ip.stem}.npy"
        if overwrite or not op.exists():
            todo.append((ip, op))


def _save_npy_atomic(op, arr):
    """Write ``arr`` to ``op`` via a temporary file followed by a rename.

    The resume logic treats any existing ``.npy`` as finished, so a file
    truncated by an interrupted run must never appear under the final name.
    """
    tmp = op.with_name(op.name + ".tmp")
    # Passing a file object stops np.save from appending ".npy" to the temp name.
    with open(tmp, "wb") as f:
        np.save(f, arr)
    os.replace(tmp, op)


def main():
    """CLI entry point: cache DA3 depth maps for every frame of one dataset.

    Steps:
      1. Parse arguments and build the list of (image, output .npy) pairs for
         the selected dataset, skipping frames whose output already exists
         unless --overwrite is given.
      2. If there is work, load the model from --weights_path and run it over
         the frames in batches of --batch_size.
      3. Save each predicted depth map as float16, one file per frame.

    Unreadable images are reported and skipped; they are not counted in the
    final throughput summary. Exits with an argparse error if --batch_size
    is smaller than 1.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--dataset", type=str, choices=["smith300", "libero", "umi"], required=True)
    p.add_argument("--smith300_root", type=str,
                   default="/data/cameron/mac_robot_datasets/first_mobile_collection")
    p.add_argument("--libero_root", type=str,
                   default="/data/libero/parsed_libero/libero_spatial/task_0")
    p.add_argument("--umi_root", type=str,
                   default="/data/cameron/mac_robot_datasets/umi_fold_towel")
    p.add_argument("--max_libero_demos", type=int, default=50)
    p.add_argument("--batch_size", type=int, default=8)
    p.add_argument("--overwrite", action="store_true")
    p.add_argument("--weights_path", type=str, default=DA3_WEIGHTS,
                   help="Path to DA3 weights dir (default: SMALL at /data/cameron/da3_weights)")
    p.add_argument("--depth_subdir", type=str, default="da3_depth",
                   help="Sub-dir name for cached depths (e.g. da3_depth, da3_depth_large)")
    args = p.parse_args()

    # A zero step makes range() raise; a negative one would silently process nothing.
    if args.batch_size < 1:
        p.error("--batch_size must be >= 1")

    # Collect (img_path, out_path) pairs
    todo = []
    if args.dataset == "smith300":
        root = Path(args.smith300_root)
        sessions = sorted(d for d in root.iterdir() if d.is_dir())
        for sess in sessions:
            imgs = sorted(sess.glob("rgb_*.jpg"))
            if not imgs:
                continue
            _queue_frames(imgs, sess / args.depth_subdir, args.overwrite, todo)
        print(f"smith300: {len(sessions)} sessions, {len(todo)} frames to process")
    elif args.dataset == "umi":
        # UMI is a single session — the root IS the session dir
        sess = Path(args.umi_root)
        _queue_frames(sorted(sess.glob("rgb_*.jpg")), sess / args.depth_subdir,
                      args.overwrite, todo)
        print(f"umi: 1 session, {len(todo)} frames to process")
    else:  # libero
        root = Path(args.libero_root)
        demos = sorted(root.glob("demo_*"))[: args.max_libero_demos]
        for demo in demos:
            imgs = sorted((demo / "frames").glob("*.png"))
            if not imgs:
                continue
            _queue_frames(imgs, demo / args.depth_subdir, args.overwrite, todo)
        print(f"libero: {len(demos)} demos, {len(todo)} frames to process")

    if not todo:
        print("Nothing to do.")
        return

    # Load the model only once we know there is work, so fully-cached reruns
    # (and dataset-path errors above) do not pay for weight loading.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Loading DA3 from {args.weights_path}...")
    model = DepthAnything3.from_pretrained(args.weights_path).to(device).eval()

    t_start = time.time()
    n_written = 0
    # Process in batches for throughput
    BS = args.batch_size
    for start in tqdm(range(0, len(todo), BS), desc="batches"):
        # Load + preprocess, keeping tensors and output paths aligned so that
        # unreadable frames simply drop out of both lists.
        tensors, out_paths = [], []
        for ip, op in todo[start:start + BS]:
            bgr = cv2.imread(str(ip))
            if bgr is None:
                print(f"  could not read {ip}")
                continue
            tensors.append(preprocess(bgr))
            out_paths.append(op)
        if not tensors:
            continue
        batch_t = torch.cat(tensors, dim=0).to(device)                 # (k, 1, 3, H, W)
        with torch.no_grad():
            out = model(batch_t, export_feat_layers=[])
        # Expected (k, 1, H, W) per the original note — TODO confirm against the DA3 API.
        depths = out['depth'].cpu().float().numpy()
        for op, depth in zip(out_paths, depths):
            _save_npy_atomic(op, depth[0].astype(np.float16))
            n_written += 1

    elapsed = time.time() - t_start
    # Report frames actually written, and avoid dividing by a zero-length interval.
    fps = n_written / elapsed if elapsed > 0 else 0.0
    print(f"Done. {n_written}/{len(todo)} frames in {elapsed:.1f}s ({fps:.1f} fps)")


# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
