"""Filesystem scan/parse for the mock-wandb viewer (omidlab.net/runs/).

Pure functions, no Flask. serve.py imports list_runs / run_detail / find_run and
wraps them in thin JSON routes (mirrors the data_viewer pattern).

Design notes live in vault/fleet/agents/our_wandb/{overview,memory}.md. Highlights:
- Run dirs: ``run-<YYYYMMDD_HHMMSS>-<id>/`` under the configured roots.
- Grouping key = ``--name`` (newer train.py) or ``--run_name`` (older), read from
  ``files/wandb-metadata.json`` args (always present). Fallback: program basename.
- Loss step-series come from ``files/output.log`` (newer train.py prints
  ``ckpt @ <step>  total=.. vol=.. grip=.. ..``). Older runs have no inline losses;
  we surface the final value from wandb-summary.json instead. No .wandb parsing (v2).
- config.yaml / wandb-summary.json only exist on COMPLETED runs — never required.

Fail-loud per fleet GUIDELINES, EXCEPT the list path tolerates a single unreadable
run dir (stale FUSE) by skipping it — one wedged dir must not blank the dashboard.
"""

import json
import os
import re
from pathlib import Path

# Directories scanned for run dirs when the WANDB_ROOTS environment variable is
# unset or empty (see get_roots).
DEFAULT_ROOTS = [
    "/home/cameronsmith/mnt/yukon/cameron/puget/wandb",
    "/home/cameronsmith/mnt/yukon/cameron/puget/code/wandb",
    "/home/cameronsmith/mnt/yukon/cameron/puget/data/wandb",
]

# Run directory name ``run-<YYYYMMDD_HHMMSS>-<id>``: group(1) is the timestamp,
# group(2) the alphanumeric run id.
_RUN_RE = re.compile(r"^run-(\d{8}_\d{6})-([0-9a-zA-Z]+)$")
# newer train.py log line: "  ckpt @ 200  total=14.1160 vol=10.3021 grip=0.99 rot=2.82  2.13 it/s"
# group(1) is the step; group(2) is the rest of the line, excluding an optional
# trailing "<rate> it/s" suffix.
_CKPT_RE = re.compile(r"ckpt @ (\d+)\s+(.*?)(?:\s+[\d.]+\s*it/s)?\s*$")
# One ``key=number`` pair; the number may carry a sign and an exponent. Note the
# numeric part is permissive (any run of digits and dots).
_KV_RE = re.compile(r"([A-Za-z_/]+)=([-+]?[\d.]+(?:[eE][-+]?\d+)?)")


def get_roots():
    """Return the list of root directories to scan for runs.

    A non-empty ``WANDB_ROOTS`` environment variable is treated as a
    comma-separated list (entries are stripped, blank entries dropped);
    otherwise the module-level ``DEFAULT_ROOTS`` list is returned.
    """
    raw = os.environ.get("WANDB_ROOTS")
    if not raw:
        return DEFAULT_ROOTS
    trimmed = (piece.strip() for piece in raw.split(","))
    return [piece for piece in trimmed if piece]


def _arg_value(args, *flags):
    """Return the value of the first matching flag in a wandb args list.

    Flags are tried in the order given, so earlier ``flags`` take priority.
    Both command-line spellings are recognised: ``--flag value`` (two tokens)
    and ``--flag=value`` (one token). A bare flag that is the final token has
    no value and is ignored.

    Args:
        args: sequence of argv tokens.
        *flags: flag spellings to look for, e.g. ``"--name"``.

    Returns:
        The value string, or None when no flag yields a value.
    """
    for flag in flags:
        inline_prefix = flag + "="
        for i, tok in enumerate(args):
            if tok == flag:
                if i + 1 < len(args):
                    return args[i + 1]
            elif isinstance(tok, str) and tok.startswith(inline_prefix):
                # single-token form: everything after the first "=" is the value
                return tok[len(inline_prefix):]
    return None


def _read_metadata(run_dir):
    """Parse ``files/wandb-metadata.json`` under *run_dir*.

    Args:
        run_dir: ``pathlib.Path`` of the run directory.

    Returns:
        The decoded JSON value, or ``{}`` when the file does not exist.

    Raises:
        OSError: if the file cannot be read.
        ValueError: if the content is not valid JSON (``json.JSONDecodeError``).
    """
    p = run_dir / "files" / "wandb-metadata.json"
    if not p.exists():
        return {}
    # Hand json the raw bytes so decoding follows the JSON rules (UTF-8/16/32)
    # instead of whatever the process locale happens to be.
    return json.loads(p.read_bytes())


def _run_name(meta, run_id):
    """Pick the display/grouping name for a run.

    Preference order: the value of ``--name`` or ``--run_name`` in the
    metadata ``args``, then the stem of the metadata ``program`` path, then
    *run_id* itself.
    """
    explicit = _arg_value(meta.get("args", []) or [], "--name", "--run_name")
    if explicit:
        return explicit
    script = meta.get("program")
    return Path(script).stem if script else run_id


def _run_summary_brief(run_dir):
    """Final scalar metrics from ``files/wandb-summary.json``.

    Args:
        run_dir: ``pathlib.Path`` of the run directory.

    Returns:
        ``{key: number}`` for every top-level entry whose value is an int or
        float and whose key does not start with ``_``; ``{}`` when the file
        does not exist.

    Raises:
        OSError: if the file cannot be read.
        ValueError: if the content is not valid JSON.
    """
    p = run_dir / "files" / "wandb-summary.json"
    if not p.exists():
        return {}
    # Raw bytes → json handles the decoding (UTF-8/16/32), independent of locale.
    data = json.loads(p.read_bytes())
    # keep only scalar entries (drop image-file refs and _wandb internals)
    return {
        k: v
        for k, v in data.items()
        if isinstance(v, (int, float)) and not k.startswith("_")
    }


def list_runs(roots=None):
    """Scan roots → list of run summaries grouped by name.

    Args:
        roots: root directories to scan; falls back to ``get_roots()`` when
            None or empty.

    Returns:
        ``{"groups": [{"name", "count", "newest_mtime", "runs": [...]}],
        "n_skipped": int, "roots": [...]}``. Runs are newest-first (by run-dir
        mtime) within each group; groups are sorted by their newest run's
        mtime, newest first. Each run carries ``id``, ``name``, ``started_at``,
        ``host``, ``gpu``, ``project``, ``mtime`` and ``root``.

    A run id already seen under an earlier root is not listed again. Roots and
    run dirs that cannot be read are skipped and counted in ``n_skipped``
    instead of raising, so one wedged directory cannot blank the whole list.
    """
    roots = roots or get_roots()
    runs = []
    skipped = 0
    seen_ids = set()
    for root in roots:
        root = Path(root)
        try:
            # Path.is_dir() only swallows "does not exist"-style errors; a stale
            # mount (ENOTCONN, EIO, ...) propagates, so it needs the same guard
            # as iterdir().
            if not root.is_dir():
                continue
            entries = sorted(root.iterdir())
        except OSError:  # stale FUSE / unreadable root
            skipped += 1
            continue
        for d in entries:
            m = _RUN_RE.match(d.name)
            if not m:
                continue
            run_id = m.group(2)
            if run_id in seen_ids:  # same run can appear under multiple roots
                continue
            try:
                meta = _read_metadata(d)
                mtime = d.stat().st_mtime
            except (OSError, ValueError):
                # OSError: wedged dir. ValueError: truncated/corrupt metadata
                # JSON (JSONDecodeError / UnicodeDecodeError). Either way one
                # bad run must not blank the list.
                skipped += 1
                continue
            seen_ids.add(run_id)
            runs.append({
                "id": run_id,
                "name": _run_name(meta, run_id),
                "started_at": meta.get("startedAt"),
                "host": meta.get("host"),
                "gpu": meta.get("gpu"),
                "project": _arg_value(meta.get("args", []) or [], "--wandb_project"),
                "mtime": mtime,
                "root": str(root),
            })

    groups = {}
    for r in runs:
        groups.setdefault(r["name"], []).append(r)
    out = []
    for name, group_runs in groups.items():
        group_runs.sort(key=lambda r: r["mtime"], reverse=True)
        out.append({
            "name": name,
            "count": len(group_runs),
            "newest_mtime": group_runs[0]["mtime"],
            "runs": group_runs,
        })
    out.sort(key=lambda g: g["newest_mtime"], reverse=True)
    return {"groups": out, "n_skipped": skipped, "roots": [str(r) for r in roots]}


def find_run(run_id, roots=None):
    """Locate a run dir by its id across roots.

    Args:
        run_id: the ``<id>`` part of a ``run-<timestamp>-<id>`` directory name.
        roots: root directories to search; falls back to ``get_roots()`` when
            None or empty.

    Returns:
        The ``Path`` of the first matching run dir (roots are searched in
        order), or None when no root contains it. Roots that are missing or
        raise OSError while being inspected are skipped.
    """
    roots = roots or get_roots()
    for root in roots:
        root = Path(root)
        try:
            # is_dir() propagates errors such as ENOTCONN/EIO from a stale
            # mount, so it has to live inside the guard together with iterdir().
            if not root.is_dir():
                continue
            for d in root.iterdir():
                m = _RUN_RE.match(d.name)
                if m and m.group(2) == run_id:
                    return d
        except OSError:
            continue
    return None


def parse_loss_series(run_dir):
    """Parse newer train.py's output.log into {metric: [[step, val], ...]}.

    Only lines matching the ``ckpt @ <step> key=val ...`` pattern contribute;
    every ``key=val`` pair on such a line is appended to that key's series in
    file order.

    Args:
        run_dir: ``pathlib.Path`` of the run directory.

    Returns:
        ``{}`` when ``files/output.log`` is missing or contains no inline loss
        lines (older runs / pre-training).
    """
    p = run_dir / "files" / "output.log"
    if not p.exists():
        return {}
    series = {}
    # Decode explicitly as UTF-8 (undecodable bytes are replaced) so the result
    # does not depend on the process locale.
    text = p.read_text(encoding="utf-8", errors="replace")
    for line in text.splitlines():
        cm = _CKPT_RE.search(line)
        if not cm:
            continue
        step = int(cm.group(1))
        for key, val in _KV_RE.findall(cm.group(2)):
            try:
                value = float(val)
            except ValueError:
                # The pair regex is permissive ("path=./x" yields ".", a version
                # string yields "1.2.3"); such tokens are not metrics — skip them
                # rather than fail the whole detail view.
                continue
            series.setdefault(key, []).append([step, value])
    return series


def list_media(run_dir):
    """Group logged PNGs by panel key. Returns [{"panel", "images": [...]}].

    PNGs are collected recursively from ``files/media/images``. A stem shaped
    like ``<key>_<step>_<hash>`` yields ``step``; the panel is ``<key>`` with
    its last ``_``-separated token removed (falling back to the parent
    directory name if that leaves nothing). Any other stem gets ``step=None``
    and the parent directory name as panel.

    Each image: {"rel": <path under files/>, "step": int|None}. Within a panel
    images are ordered by step (stepless ones last), ties broken by ``rel`` so
    the order is stable across filesystem enumeration orders. Panels are sorted
    by name. Returns ``[]`` when the media directory does not exist.
    """
    files_root = run_dir / "files"
    media_root = files_root / "media" / "images"
    if not media_root.is_dir():
        return []
    panels = {}
    for png in media_root.rglob("*.png"):
        rel = png.relative_to(files_root)
        # strip trailing _<hash>; step is the trailing-but-one integer token
        parts = png.stem.rsplit("_", 2)  # [key_idx, step, hash]
        step = None
        panel = png.parent.name
        # isdecimal(), not isdigit(): the latter also accepts characters such as
        # superscript digits, which int() rejects.
        if len(parts) == 3 and parts[1].isdecimal():
            step = int(parts[1])
            key = parts[0].rsplit("_", 1)[0]
            panel = key or panel
        panels.setdefault(panel, []).append({"rel": str(rel), "step": step})
    for imgs in panels.values():
        # rel as final tiebreaker: several images commonly share one step, and
        # rglob() order is filesystem-dependent.
        imgs.sort(key=lambda i: (i["step"] is None, i["step"] or 0, i["rel"]))
    return [{"panel": k, "images": panels[k]} for k in sorted(panels)]


def _config_args(run_dir, meta):
    """Best-effort flat config rebuilt from the recorded command-line args.

    Walks ``meta["args"]`` and maps each ``--key`` to:
      * the text after ``=`` for the single-token form ``--key=value``;
      * the following token, when that token does not itself start with ``--``;
      * ``True`` otherwise (bare switch).
    Tokens that are not ``--`` options and were not consumed as a value are
    ignored. Later occurrences of a key overwrite earlier ones.

    Args:
        run_dir: unused; kept for the existing call signature.
        meta: metadata dict; a missing or null ``args`` yields ``{}``.

    Returns:
        ``{key: str | True}`` with leading dashes stripped from the keys.
    """
    args = meta.get("args", []) or []
    cfg = {}
    n = len(args)
    i = 0
    while i < n:
        tok = args[i]
        i += 1
        if not tok.startswith("--"):
            continue  # positional / stray token
        key, has_inline, inline = tok.lstrip("-").partition("=")
        if has_inline:
            # --key=value: the value is part of this token; do not consume the next
            cfg[key] = inline
        elif i < n and not args[i].startswith("--"):
            cfg[key] = args[i]
            i += 1
        else:
            cfg[key] = True
    return cfg


def run_detail(run_id, roots=None):
    """Assemble the full per-run payload for the detail view.

    Looks the run up with ``find_run``; returns None when it is not found.
    Otherwise returns a dict with identity fields (``id``, ``name``, ``dir``),
    selected metadata fields, the CLI-derived ``config``, the scalar
    ``summary``, the parsed ``loss_series`` (``has_step_history`` is True when
    that series is non-empty) and the ``media`` panels.
    """
    location = find_run(run_id, roots)
    if location is None:
        return None

    metadata = _read_metadata(location)
    losses = parse_loss_series(location)

    detail = {
        "id": run_id,
        "name": _run_name(metadata, run_id),
        "dir": str(location),
    }
    # payload key -> key in wandb-metadata.json
    copied_fields = (
        ("started_at", "startedAt"),
        ("host", "host"),
        ("gpu", "gpu"),
        ("python", "python"),
        ("program", "program"),
    )
    for out_key, meta_key in copied_fields:
        detail[out_key] = metadata.get(meta_key)
    detail["config"] = _config_args(location, metadata)
    detail["summary"] = _run_summary_brief(location)
    detail["loss_series"] = losses
    detail["has_step_history"] = bool(losses)
    detail["media"] = list_media(location)
    return detail
