"""
Stitch ACT vs PARA comparison clips into a single video.

Sequence:
  1. ACT succeeds at in-dist position (establishes baseline works)
  2. ACT fails at OOD position (object shifted right)
  3. ACT fails at OOD viewpoint (camera shifted)
  4. PARA succeeds at same OOD position
  5. PARA succeeds at same OOD viewpoint

Each clip is trimmed, labeled, and concatenated with brief transition frames.
"""

import os
import subprocess
from pathlib import Path

import cv2
import numpy as np


def load_video_frames(path, max_frames=None):
    """Decode an MP4 file into a list of BGR frames.

    Stops early once *max_frames* frames have been collected (a falsy
    value means "no limit").
    """
    capture = cv2.VideoCapture(str(path))
    collected = []
    ok, frame = capture.read()
    while ok:
        collected.append(frame)
        if max_frames and len(collected) >= max_frames:
            break
        ok, frame = capture.read()
    capture.release()
    return collected


def add_label(frame, text, position="top", font_scale=0.7, thickness=2, color=(255, 255, 255)):
    """Return a copy of *frame* with *text* drawn over a black backing box.

    *position* is "top" (centered near the top), "bottom" (centered near
    the bottom), or an explicit (x, y) baseline origin.
    """
    out = frame.copy()
    height, width = out.shape[:2]
    font = cv2.FONT_HERSHEY_SIMPLEX
    (text_w, text_h), baseline = cv2.getTextSize(text, font, font_scale, thickness)

    if position == "top":
        x, y = (width - text_w) // 2, text_h + 15
    elif position == "bottom":
        x, y = (width - text_w) // 2, height - 15
    else:
        x, y = position

    # Filled black rectangle behind the text for legibility on busy frames.
    cv2.rectangle(out, (x - 6, y - text_h - 6), (x + text_w + 6, y + baseline + 6),
                  (0, 0, 0), -1)
    cv2.putText(out, text, (x, y), font, font_scale, color, thickness, cv2.LINE_AA)
    return out


def add_status_badge(frame, text, success=True):
    """Return a copy of *frame* with a colored badge in the top-right corner.

    The badge is green when *success* is truthy, red otherwise (BGR colors).
    """
    out = frame.copy()
    height, width = out.shape[:2]
    font = cv2.FONT_HERSHEY_SIMPLEX
    scale, weight = 0.6, 2
    (text_w, text_h), baseline = cv2.getTextSize(text, font, scale, weight)

    badge_color = (0, 180, 0) if success else (0, 0, 200)  # green / red in BGR
    x = width - text_w - 20
    y = text_h + 15

    cv2.rectangle(out, (x - 8, y - text_h - 8), (x + text_w + 8, y + baseline + 8),
                  badge_color, -1)
    cv2.putText(out, text, (x, y), font, scale, (255, 255, 255), weight, cv2.LINE_AA)
    return out


def make_transition_frames(w, h, text, n_frames=20, bg_color=(20, 20, 20)):
    """Build *n_frames* identical title-card frames with vertically centered text.

    *text* may contain newlines; each line is centered horizontally and the
    whole block is centered vertically with 15px spacing between lines.
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    scale, weight = 0.9, 2

    lines = text.split('\n')
    metrics = [cv2.getTextSize(line, font, scale, weight) for line in lines]
    heights = [th + baseline for (_, th), baseline in metrics]
    block_h = sum(heights) + 15 * (len(lines) - 1)

    # Render the card once; every returned frame is an independent copy.
    card = np.full((h, w, 3), bg_color, dtype=np.uint8)
    y_top = (h - block_h) // 2
    offset = 0
    for i, line in enumerate(lines):
        (text_w, text_h), _ = metrics[i]
        x = (w - text_w) // 2
        y = y_top + offset + 15 * i + text_h
        cv2.putText(card, line, (x, y), font, scale, (255, 255, 255),
                    weight, cv2.LINE_AA)
        offset += heights[i]
    return [card.copy() for _ in range(n_frames)]


def trim_clip(frames, max_seconds, fps):
    """Trim a clip to at most ``int(max_seconds * fps)`` frames.

    Clips already within budget are returned unchanged. Longer clips are
    subsampled to show key moments: the first, middle, and last third of
    the frame budget (so long failure episodes still show start, midpoint,
    and outcome).

    Args:
        frames: Sequence of frames.
        max_seconds: Maximum duration to keep, in seconds.
        fps: Playback frame rate used to convert seconds to a frame count.

    Returns:
        A list of at most ``int(max_seconds * fps)`` frames.
    """
    max_frames = int(max_seconds * fps)
    if len(frames) <= max_frames:
        return frames

    n = len(frames)
    chunk = max_frames // 3
    if chunk < 1:
        # Budget too small to split into three segments. The previous
        # implementation returned [] here, silently dropping the whole
        # clip; fall back to a uniform subsample within the budget.
        if max_frames <= 0:
            return []
        step = max(1, n // max_frames)
        return frames[::step][:max_frames]

    # First / middle / last thirds of the budget.
    indices = (
        list(range(0, chunk)) +
        list(range(n // 2 - chunk // 2, n // 2 + chunk // 2)) +
        list(range(n - chunk, n))
    )
    return [frames[i] for i in indices if i < n]


def main():
    """Assemble the ACT vs PARA comparison video and write it to disk.

    Pipeline: for each configured clip, emit a short title card followed by
    the trimmed, labeled clip frames; append a closing card; write the
    result with OpenCV's mp4v codec; then best-effort re-encode to H.264
    via ffmpeg for broad browser compatibility.

    Raises:
        FileNotFoundError: If no frames can be read from the first clip
            (the output frame size is derived from it).
    """
    # NOTE: paths are hardcoded for the original experiment layout.
    base = Path("/data/cameron/para/ood_libero/comparison_video_clips_v2")
    output_path = "/data/cameron/para/.agents/reports/project_site/media/act_vs_para_comparison.mp4"

    fps = 15
    max_clip_seconds = 5  # trim each clip to ~5 seconds

    # Clip sequence: ACT baseline success, ACT OOD failures, PARA OOD successes.
    clips = [
        {
            "path": base / "act_indist/videos/task_0/ep000_success.mp4",
            "label": "ACT - Training Position (In-Distribution)",
            "status": "SUCCESS",
            "success": True,
            "transition": "Standard policy (ACT) works\nin the training setup...",
        },
        {
            "path": base / "act_ood_pos_v2/videos/task_0/ep000_fail.mp4",
            "label": "ACT - Object Shifted Right (OOD)",
            "status": "FAILURE",
            "success": False,
            "transition": "...but fails when the\nobject moves to a new position",
        },
        {
            "path": base / "act_ood_view/videos/task_0/ep000_fail.mp4",
            "label": "ACT - Camera Shifted 18deg (OOD)",
            "status": "FAILURE",
            "success": False,
            "transition": "...and fails when the\ncamera viewpoint changes",
        },
        {
            "path": base / "para_ood_pos_v2/videos/task_0/ep001_success.mp4",
            "label": "PARA - Object Shifted Right (OOD)",
            "status": "SUCCESS",
            "success": True,
            "transition": "PARA succeeds at the\nsame OOD object position",
        },
        {
            "path": base / "para_ood_view/videos/task_0/ep001_success.mp4",
            "label": "PARA - Camera Shifted 18deg (OOD)",
            "status": "SUCCESS",
            "success": True,
            "transition": "...and at the same\nOOD camera viewpoint",
        },
    ]

    # For missing files, fall back to any sibling clip with the same outcome
    # ("success"/"fail" in the filename).
    for clip in clips:
        if not clip["path"].exists():
            parent = clip["path"].parent
            alternatives = sorted(parent.glob("*.mp4"))
            print(f"Missing: {clip['path']}")
            print(f"  Available: {[a.name for a in alternatives]}")
            wanted = "success" if clip["success"] else "fail"
            matches = [a for a in alternatives if wanted in a.name]
            if matches:
                clip["path"] = matches[0]
                print(f"  Using: {clip['path']}")

    all_frames = []

    # Derive the output frame size from the first clip; fail loudly instead
    # of crashing with an IndexError when nothing can be read.
    test_frames = load_video_frames(clips[0]["path"], max_frames=1)
    if not test_frames:
        raise FileNotFoundError(
            f"Could not read any frames from {clips[0]['path']}; "
            "cannot determine output frame size."
        )
    h, w = test_frames[0].shape[:2]
    print(f"Frame size: {w}x{h}")

    for i, clip in enumerate(clips):
        print(f"Processing clip {i + 1}/{len(clips)}: {clip['path'].name}")

        # Title card before each clip (~1.5 s).
        all_frames.extend(
            make_transition_frames(w, h, clip["transition"], n_frames=int(fps * 1.5))
        )

        frames = load_video_frames(clip["path"])
        print(f"  Loaded {len(frames)} frames")

        frames = trim_clip(frames, max_clip_seconds, fps)
        print(f"  Trimmed to {len(frames)} frames")

        # Caption at the bottom, SUCCESS/FAILURE badge in the top-right.
        for j in range(len(frames)):
            frames[j] = add_label(frames[j], clip["label"], position="bottom")
            frames[j] = add_status_badge(frames[j], clip["status"], clip["success"])

        all_frames.extend(frames)

    # Closing card (~3 s).
    all_frames.extend(make_transition_frames(
        w, h,
        "Same backbone. Same data.\nDifferent action head.\n\nPARA: Pixel-Aligned Robot Actions",
        n_frames=int(fps * 3),
    ))

    # Write video; ensure the destination directory exists first, since
    # cv2.VideoWriter fails silently when it does not.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    print(f"\nWriting {len(all_frames)} total frames to {output_path}")
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    writer = cv2.VideoWriter(output_path, fourcc, fps, (w, h))
    try:
        for frame in all_frames:
            # Ensure frame matches expected size
            if frame.shape[:2] != (h, w):
                frame = cv2.resize(frame, (w, h))
            writer.write(frame)
    finally:
        writer.release()

    # Best-effort re-encode to H.264 (mp4v is poorly supported by browsers).
    # Argument list with shell=False avoids shell-injection via the paths.
    h264_path = output_path.replace(".mp4", "_h264.mp4")
    try:
        result = subprocess.run(
            ["ffmpeg", "-y", "-i", output_path, "-c:v", "libx264",
             "-preset", "ultrafast", "-crf", "23",
             "-movflags", "+faststart", h264_path],
            stderr=subprocess.DEVNULL,
        )
        encoded = result.returncode == 0
    except FileNotFoundError:
        # ffmpeg not installed; keep the mp4v output.
        encoded = False
    if encoded:
        os.replace(h264_path, output_path)
        print(f"Re-encoded to H.264: {output_path}")

    print(f"Done! Total duration: {len(all_frames) / fps:.1f}s")


if __name__ == "__main__":
    # Script entry point: build and write the comparison video.
    main()
