"""Render the middle-frame YAM bimanual overlay for every episode of every task,
then build a contact-sheet per task for fast visual QA of camera calibration.

Output layout:
  /home/robot-lab/cameron/yam_overlay/calibration_qa/
    <task>/
      ep_<eid>.png           # side-by-side raw|overlay per episode (full res)
      _contact_sheet.png     # 3-col grid of all successfully rendered episodes' raw|overlay panels (thumbnailed)
    _manifest.json           # which episodes succeeded / failed

Usage:
  MUJOCO_GL=egl python yam_batch_qa.py [--tasks task1,task2] [--limit 5]
"""
import os
# Default MuJoCo to the EGL (headless) rendering backend unless the caller has
# already exported MUJOCO_GL. MuJoCo reads this variable when it is imported, so
# this must stay above the imports below (presumably yam_overlay_render pulls in
# mujoco — confirm) for the setting to take effect.
os.environ.setdefault("MUJOCO_GL", "egl")

import argparse
import json
import time
import traceback
from pathlib import Path

import cv2
import numpy as np

import yam_overlay_render as R  # local module


# Processed dataset: one directory per task, each containing episode
# directories (an episode is recognised by its ``lowdim/`` sub-directory).
DATA_ROOT = Path("/home/robot-lab") / "data" / "processed"
# Destination for everything this script writes (per-episode PNGs, per-task
# contact sheets and the run manifest).
OUT_ROOT = Path("/home/robot-lab") / "cameron" / "yam_overlay" / "calibration_qa"


def _episode_dirs(task_dir):
    """Return an iterator over the episode directories directly under *task_dir*.

    An episode is any sub-directory that itself contains a ``lowdim/``
    directory; stray files and other directories are ignored.
    """
    return (e for e in task_dir.iterdir() if e.is_dir() and (e / "lowdim").is_dir())


def list_tasks(root=None):
    """Return the sorted names of task directories that contain at least one episode.

    Args:
        root: Dataset root to scan. Defaults to ``DATA_ROOT``.

    Returns:
        ``list[str]`` of task directory names in sorted order.
    """
    root = DATA_ROOT if root is None else Path(root)
    # any() stops at the first episode found instead of listing the whole task dir.
    return [d.name for d in sorted(root.iterdir()) if d.is_dir() and any(_episode_dirs(d))]


def list_episodes_for(task, root=None):
    """Return the sorted episode directory names for *task*.

    Args:
        task: Task directory name under *root*.
        root: Dataset root. Defaults to ``DATA_ROOT``.

    Raises:
        FileNotFoundError: if ``root / task`` does not exist.
    """
    root = DATA_ROOT if root is None else Path(root)
    return sorted(e.name for e in _episode_dirs(root / task))


def thumbnail_panel(panel, max_w=640):
    """Downscale *panel* so its width is at most *max_w*, preserving aspect ratio.

    Panels that are already narrow enough are returned unchanged (the same
    object, not a copy).

    Args:
        panel: Image array indexed as ``(height, width, ...)``.
        max_w: Maximum output width in pixels.

    Returns:
        *panel* itself, or a resized copy exactly ``max_w`` pixels wide.
    """
    h, w = panel.shape[:2]
    if w <= max_w:
        return panel
    # Clamp to 1 px: an extremely wide, short panel would otherwise round to a
    # zero height, which cv2.resize rejects.
    new_h = max(1, int(round(h * max_w / w)))
    # INTER_AREA is OpenCV's recommended interpolation when shrinking.
    return cv2.resize(panel, (max_w, new_h), interpolation=cv2.INTER_AREA)


def build_contact_sheet(panels_with_labels, ncols=3, thumb_max_w=640, pad=8, bg=(20, 20, 20)):
    """Arrange labelled panels into a single grid image.

    Each ``(panel, label)`` pair is shrunk with ``thumbnail_panel`` to at most
    ``thumb_max_w`` pixels wide and given a 24-px strip underneath that carries
    ``label``. Tiles are placed row-major into ``ncols`` columns. Every cell is
    sized to the largest tile, and cells are separated and framed by ``pad``
    pixels of the ``bg`` colour.

    Returns:
        A 3-channel ``uint8`` image, or ``None`` when there is nothing to tile.
    """
    if not panels_with_labels:
        return None

    strip_h = 24
    tiles = []
    for img, caption in panels_with_labels:
        small = thumbnail_panel(img, max_w=thumb_max_w)
        th, tw = small.shape[:2]
        tile = np.full((th + strip_h, tw, 3), bg, dtype=np.uint8)
        tile[:th, :tw] = small
        # putText's origin is the text baseline, placed 18 px into the strip.
        cv2.putText(tile, caption, (6, th + 18), cv2.FONT_HERSHEY_SIMPLEX,
                    0.5, (220, 220, 220), 1, cv2.LINE_AA)
        tiles.append(tile)

    cell_h = max(tile.shape[0] for tile in tiles)
    cell_w = max(tile.shape[1] for tile in tiles)
    n_rows = (len(tiles) + ncols - 1) // ncols
    canvas_h = pad + n_rows * (cell_h + pad)
    canvas_w = pad + ncols * (cell_w + pad)
    canvas = np.full((canvas_h, canvas_w, 3), bg, dtype=np.uint8)

    for idx, tile in enumerate(tiles):
        row, col = divmod(idx, ncols)
        top = pad + row * (cell_h + pad)
        left = pad + col * (cell_w + pad)
        canvas[top:top + tile.shape[0], left:left + tile.shape[1]] = tile
    return canvas


def process_task(task, ep_limit=None):
    """Render every episode of *task* and build the task's contact sheet.

    For each episode, ``R.render_one`` is asked to write its panel to
    ``OUT_ROOT/<task>/ep_<ep>.png``. Episodes for which ``render_one`` returns
    a ``None`` panel or raises are recorded as failures and skipped. The
    remaining episodes are tiled into ``OUT_ROOT/<task>/_contact_sheet.png``.

    Args:
        task: Task directory name under ``DATA_ROOT``.
        ep_limit: Process only the first ``ep_limit`` episodes in sorted
            order; ``None`` or ``0`` means no limit.

    Returns:
        A list with one dict per attempted episode:
        ``{"ep", "ok": True, "label", "path"}`` on success, or
        ``{"ep", "ok": False, "error"}`` on failure.
    """
    # Width each panel is shrunk to for the contact sheet. Thumbnails are made
    # at this size up front, so build_contact_sheet leaves them untouched.
    thumb_w = 640

    task_out = OUT_ROOT / task
    task_out.mkdir(parents=True, exist_ok=True)
    episodes = list_episodes_for(task)
    if ep_limit:
        episodes = episodes[:ep_limit]
    n_eps = len(episodes)
    print(f"\n=== {task} — {n_eps} episodes ===", flush=True)

    results = []
    thumbs = []  # (thumbnail, episode name) pairs for the contact sheet
    t0 = time.time()
    for i, ep in enumerate(episodes, start=1):
        ep_out = task_out / f"ep_{ep}.png"
        try:
            panel, label = R.render_one(task, ep, out_path=str(ep_out))
            if panel is None:
                # A None panel is treated as a soft failure; label holds the reason.
                results.append({"ep": ep, "ok": False, "error": label})
                print(f"  [{i}/{n_eps}] {ep}: FAIL — {label}", flush=True)
                continue
            # Keep only the thumbnail: holding every full-resolution panel until
            # the end of the task made peak memory grow with the episode count.
            thumbs.append((thumbnail_panel(panel, max_w=thumb_w), ep))
        except Exception as e:
            results.append({"ep": ep, "ok": False, "error": str(e)})
            print(f"  [{i}/{n_eps}] {ep}: EXC — {e}", flush=True)
            traceback.print_exc()
            continue

        results.append({"ep": ep, "ok": True, "label": label, "path": str(ep_out)})
        # Throttle progress output to every 5th episode plus the final one.
        if i % 5 == 0 or i == n_eps:
            elapsed = time.time() - t0
            rate = i / elapsed if elapsed else 0
            eta = (n_eps - i) / rate if rate else 0
            print(f"  [{i}/{n_eps}] {ep}: ok ({elapsed:.0f}s, "
                  f"{rate:.1f}/s, eta {eta:.0f}s)", flush=True)

    # Contact sheet
    sheet = build_contact_sheet(thumbs, ncols=3, thumb_max_w=thumb_w)
    if sheet is not None:
        sheet_path = task_out / "_contact_sheet.png"
        # cv2.imwrite reports failure by returning False rather than raising.
        if cv2.imwrite(str(sheet_path), sheet):
            print(f"  contact sheet: {sheet_path} ({sheet.shape[1]}x{sheet.shape[0]})", flush=True)
        else:
            print(f"  WARNING: could not write contact sheet {sheet_path}", flush=True)
    return results


def main(argv=None):
    """Command-line entry point: render every requested task and write the manifest.

    Args:
        argv: Arguments to parse instead of ``sys.argv[1:]``. ``None`` (the
            default) keeps the normal command-line behaviour.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--tasks", default=None, help="Comma-separated subset of task names.")
    p.add_argument("--limit", type=int, default=None,
                   help="Max episodes per task (for quick iteration).")
    args = p.parse_args(argv)
    if args.limit is not None and args.limit < 0:
        p.error("--limit must be non-negative")

    if args.tasks:
        # Strip whitespace and drop empty entries ("a, b," -> ["a", "b"]). An
        # empty name would otherwise resolve to DATA_ROOT itself.
        tasks = [t.strip() for t in args.tasks.split(",") if t.strip()]
        # Validate up front. A typo used to raise only when its turn came,
        # after earlier tasks had rendered and before any manifest was saved.
        unknown = [t for t in tasks if not (DATA_ROOT / t).is_dir()]
        if unknown:
            p.error(f"unknown task(s) under {DATA_ROOT}: {', '.join(unknown)}")
    else:
        tasks = list_tasks()
    print(f"Tasks: {tasks}", flush=True)

    OUT_ROOT.mkdir(parents=True, exist_ok=True)
    manifest_path = OUT_ROOT / "_manifest.json"
    manifest = {}

    def write_manifest():
        # Serialise before touching the file, so that a value json cannot
        # encode raises without leaving a truncated manifest behind.
        manifest_path.write_text(json.dumps(manifest, indent=2))

    t_start = time.time()
    write_manifest()
    for task in tasks:
        manifest[task] = process_task(task, ep_limit=args.limit)
        # Checkpoint after every task so an interrupted run keeps finished results.
        write_manifest()
    print(f"\nDONE. Manifest: {manifest_path}", flush=True)
    print(f"Total wall time: {time.time() - t_start:.0f}s", flush=True)


if __name__ == "__main__":
    main()
