"""
Build a GIF from anim_pics: 3 panes (left=RGB, middle=heatmap, right=volume).
First timestep: animate sliding through height bins (~2 sec). Rest: one frame per timestep.
Run make_heatmap_vis.py first to populate volume_dino_tracks/scratch/anim_pics/.
"""
import argparse
from pathlib import Path

import cv2
import numpy as np
from PIL import Image

ANIM_PICS_DIR = Path("volume_dino_tracks/scratch/anim_pics")
OUT_GIF = Path("volume_dino_tracks/scratch/volume_anim.gif")


# Label bar and text styling. Sizes are 3x the original layout
# (32 px bar, 0.6 font scale, 2 px stroke).
LABEL_H = 3 * 32
FONT = cv2.FONT_HERSHEY_SIMPLEX
FONT_SCALE = 1.8  # 3 * 0.6, spelled out because 3 * 0.6 != 1.8 in floating point
FONT_THICKNESS = 3 * 2
TEXT_COLOR = (255,) * 3  # white; applied to BGR frames via cv2.putText
BAR_COLOR = (40,) * 3  # dark grey background of the label bar (BGR)


def load_bgr(path: Path):
    """Read an image with OpenCV, or return None when ``path`` does not exist.

    Existing paths are passed straight to ``cv2.imread`` and its result
    (a BGR, HWC array, or None if OpenCV cannot decode the file) is returned
    unchanged. Missing paths short-circuit without calling OpenCV.
    """
    return cv2.imread(str(path)) if path.exists() else None


def put_text_centered(img_bgr, text, pane_x_start, pane_w):
    """Draw ``text`` in place on ``img_bgr``, centred over one pane in the label bar.

    Args:
        img_bgr: Frame to draw on (modified in place).
        text: Label string.
        pane_x_start: Left x coordinate of the pane in pixels.
        pane_w: Pane width in pixels.

    The text uses FONT/FONT_SCALE/FONT_THICKNESS when it fits. If it would be
    wider than the pane minus a small margin, scale and thickness are reduced
    so the label stays inside its own pane instead of spilling into the
    neighbouring pane's label (or starting at a negative x).
    Vertically, the above-baseline height of the glyphs is centred in LABEL_H.
    """
    margin = 8
    scale, thickness = FONT_SCALE, FONT_THICKNESS
    (tw, th), _ = cv2.getTextSize(text, FONT, scale, thickness)
    available_w = max(1, pane_w - 2 * margin)
    if tw > available_w:
        # Text width scales roughly linearly with font scale, so shrink proportionally.
        shrink = available_w / tw
        scale *= shrink
        thickness = max(1, int(round(thickness * shrink)))
        (tw, th), _ = cv2.getTextSize(text, FONT, scale, thickness)
    # Clamp so rounding after shrinking can never push the text left of its pane.
    x = pane_x_start + max(0, (pane_w - tw) // 2)
    y = LABEL_H // 2 + th // 2
    cv2.putText(img_bgr, text, (x, y), FONT, scale, TEXT_COLOR, thickness, cv2.LINE_AA)


def _parse_args() -> argparse.Namespace:
    """Parse the command-line options of this script."""
    parser = argparse.ArgumentParser(description="Build 3-pane volume animation GIF from anim_pics")
    parser.add_argument("--pics_dir", type=Path, default=ANIM_PICS_DIR, help="Directory with saved anim pics")
    parser.add_argument("--out", type=Path, default=OUT_GIF, help="Output GIF path")
    parser.add_argument("--t0_duration_sec", type=float, default=2.0, help="Duration of first-timestep height animation in seconds")
    parser.add_argument("--t_other_duration_ms", type=int, default=500, help="Duration per frame for other timesteps (ms)")
    return parser.parse_args()


def _timestep_index(path: Path):
    """Return N from a ``t<N>_...`` filename (e.g. t12_volume_kp.png -> 12), else None."""
    head = path.stem.split("_", 1)[0]
    if not head.startswith("t"):
        return None
    try:
        return int(head[1:])
    except ValueError:
        return None


def _load_height_values_mm(pics_dir: Path):
    """Best-effort load of ``height_values.npy`` (metres), returned flattened in millimetres.

    Returns None if the file is absent or cannot be read/converted. Callers then
    label the height frames by layer index instead.
    """
    hv_path = pics_dir / "height_values.npy"
    if not hv_path.exists():
        return None
    try:
        # ravel() so even a 0-d or nested array supports len() and [h] indexing.
        return np.ravel(np.load(str(hv_path))) * 1000  # m -> mm
    except Exception as err:  # optional metadata: fall back to index labels, but say why
        print(f"Warning: ignoring {hv_path}: {err}")
        return None


def _compose_frame(panes_bgr, labels, pane_size):
    """Place BGR panes side by side under a labelled bar and return an RGB PIL image.

    Args:
        panes_bgr: Sequence of BGR images, left to right.
        labels: One label string per pane.
        pane_size: (width, height) every pane is resized to if its shape differs.
    """
    pane_w, pane_h = pane_size
    frame_bgr = np.full((pane_h + LABEL_H, pane_w * len(panes_bgr), 3), BAR_COLOR, dtype=np.uint8)
    for i, (pane, label) in enumerate(zip(panes_bgr, labels)):
        if pane.shape[:2] != (pane_h, pane_w):
            pane = cv2.resize(pane, (pane_w, pane_h))
        x0 = i * pane_w
        frame_bgr[LABEL_H:, x0:x0 + pane_w] = pane
        put_text_centered(frame_bgr, label, x0, pane_w)
    return Image.fromarray(cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB))


def main():
    """Assemble the 3-pane GIF (RGB | heatmap | volume) from the images in ``--pics_dir``.

    Frame order:
      1. One frame per timestep-0 height bin (t0_heatmap_hNN.png + t0_volume_hNN.png).
         Together these frames last ``--t0_duration_sec``.
      2. One frame per timestep t > 0 (t{t}_max_heatmap.png + t{t}_volume_kp.png),
         each shown for ``--t_other_duration_ms``.
    Frames whose input images are missing are skipped. Each written frame
    keeps its own duration, so skipped frames do not shift timings.

    Raises:
        FileNotFoundError: pics dir or rgb.png is missing, or no t0 height images exist.
        RuntimeError: no frame could be assembled.
    """
    args = _parse_args()

    pics_dir = args.pics_dir
    if not pics_dir.exists():
        raise FileNotFoundError(f"Run make_heatmap_vis.py first to create {pics_dir}")

    rgb_path = pics_dir / "rgb.png"
    rgb_bgr = load_bgr(rgb_path)
    if rgb_bgr is None:
        raise FileNotFoundError(f"Missing {rgb_path}")

    # Number of first-timestep height bins that have both a heatmap and a volume image.
    n_height = min(
        len(list(pics_dir.glob("t0_heatmap_h*.png"))),
        len(list(pics_dir.glob("t0_volume_h*.png"))),
    )
    if n_height == 0:
        raise FileNotFoundError(f"No t0_heatmap_h*.png or t0_volume_h*.png in {pics_dir}")

    # Timesteps that have a t<N>_volume_kp.png (t0 is handled by the height sweep).
    timestep_indices = sorted({
        t for t in map(_timestep_index, pics_dir.glob("t*_volume_kp.png")) if t is not None
    })

    pane_size = (rgb_bgr.shape[1], rgb_bgr.shape[0])  # (W, H); every pane matches the RGB image
    height_values_mm = _load_height_values_mm(pics_dir)

    frames_pil = []
    durations_ms = []  # kept in lockstep with frames_pil

    # First timestep: sweep height bins from first to last.
    t0_frames = []
    for h in range(n_height):
        heat_bgr = load_bgr(pics_dir / f"t0_heatmap_h{h:02d}.png")
        vol_bgr = load_bgr(pics_dir / f"t0_volume_h{h:02d}.png")
        if heat_bgr is None or vol_bgr is None:
            continue
        if height_values_mm is not None and h < len(height_values_mm):
            mid_label = f"timestep 0, heatmap for height={height_values_mm[h]:.0f}mm"
        else:
            mid_label = f"timestep 0, heatmap for height layer {h}/{n_height}"
        t0_frames.append(
            _compose_frame((rgb_bgr, heat_bgr, vol_bgr), ("RGB", mid_label, "volume visualization"), pane_size)
        )
    if t0_frames:
        # Spread the requested sweep time over the frames actually loaded. The 20 ms
        # floor presumably avoids viewers replacing very short GIF delays with a slow default.
        t0_ms = max(20, int(args.t0_duration_sec * 1000 / len(t0_frames)))
        frames_pil.extend(t0_frames)
        durations_ms.extend([t0_ms] * len(t0_frames))

    # Later timesteps: left=rgb, middle=max_heatmap, right=volume_kp.
    for t in timestep_indices:
        if t <= 0:
            continue
        max_bgr = load_bgr(pics_dir / f"t{t}_max_heatmap.png")
        vol_bgr = load_bgr(pics_dir / f"t{t}_volume_kp.png")
        if max_bgr is None or vol_bgr is None:
            continue
        mid_label = f"timestep {t} height-maxed heatmap"
        frames_pil.append(
            _compose_frame((rgb_bgr, max_bgr, vol_bgr), ("RGB", mid_label, "volume visualization"), pane_size)
        )
        durations_ms.append(args.t_other_duration_ms)

    if not frames_pil:
        raise RuntimeError("No frames to write")

    out_path = args.out
    out_path.parent.mkdir(parents=True, exist_ok=True)

    frames_pil[0].save(
        str(out_path),
        save_all=True,
        append_images=frames_pil[1:],
        duration=durations_ms,
        loop=0,
    )
    print(f"Saved {len(frames_pil)} frames to {out_path}")


if __name__ == "__main__":
    main()
