"""
Load a DROID episode from HuggingFace (cadene/droid_1.0.1 LeRobot format).
Downloads parquet data + video for a given episode, extracts frames,
and returns joint positions, camera extrinsics, and images.
"""

import os
import json
import numpy as np
from pathlib import Path
from huggingface_hub import hf_hub_download

# Hugging Face Hub dataset repository that episode files are downloaded from.
REPO_ID = "cadene/droid_1.0.1"
# Default download cache: a "cache" directory located next to this source file.
CACHE_DIR = Path(__file__).parent / "cache"


def download_episode(episode_index: int, cache_dir: Path = CACHE_DIR) -> dict:
    """Fetch the parquet table and both exterior videos of one episode.

    Files are grouped on the Hub in chunk directories of 1000 episodes, so the
    chunk number is ``episode_index // 1000``.

    Args:
        episode_index: Index of the episode to download.
        cache_dir: Directory handed to ``hf_hub_download`` as its cache; it is
            created if it does not exist yet.

    Returns:
        Dict with keys ``"parquet"``, ``"ext1_video"`` and ``"ext2_video"``,
        each holding the value returned by ``hf_hub_download`` for that file.
    """
    cache_dir.mkdir(parents=True, exist_ok=True)

    chunk_dir = f"chunk-{episode_index // 1000:03d}"
    episode_name = f"episode_{episode_index:06d}"

    def fetch(repo_file: str) -> str:
        # Every file comes from the same dataset repo and shares one cache.
        return hf_hub_download(
            REPO_ID,
            repo_file,
            repo_type="dataset",
            cache_dir=str(cache_dir),
        )

    # Insertion order is the download order: parquet, exterior 1, exterior 2.
    remote_files = {
        "parquet": f"data/{chunk_dir}/{episode_name}.parquet",
        "ext1_video": f"videos/{chunk_dir}/observation.images.exterior_1_left/{episode_name}.mp4",
        "ext2_video": f"videos/{chunk_dir}/observation.images.exterior_2_left/{episode_name}.mp4",
    }
    return {key: fetch(repo_file) for key, repo_file in remote_files.items()}


def decode_video_frames(video_path: str) -> np.ndarray:
    """Decode all frames from an mp4 video. Returns (T, H, W, 3) uint8 array.

    Args:
        video_path: Path of the video file to open with PyAV.

    Returns:
        All frames of the first video stream, converted to RGB and stacked
        along a new leading axis.

    Raises:
        ValueError: If no frame could be decoded from the video.
    """
    import av

    # The context manager closes the container even when decoding fails
    # part-way through, so the file handle is never leaked.
    with av.open(video_path) as container:
        frames = [
            frame.to_ndarray(format="rgb24")
            for frame in container.decode(video=0)
        ]

    if not frames:
        # np.stack([]) would also raise ValueError, but without naming the file.
        raise ValueError(f"no video frames decoded from {video_path}")
    return np.stack(frames)


def load_episode(episode_index: int, cache_dir: Path = CACHE_DIR) -> dict:
    """
    Load a full DROID episode: images, joint positions, camera extrinsics.

    Args:
        episode_index: Index of the episode to load.
        cache_dir: Download cache directory forwarded to ``download_episode``.

    Returns dict with:
        - ext1_images: (T, H, W, 3) uint8
        - ext2_images: (T, H, W, 3) uint8
        - joint_positions: (T, 7) float
        - gripper_positions: (T,) float
        - cartesian_positions: (T, 6) float
        - ext1_extrinsics: (6,) float [x,y,z,roll,pitch,yaw] camera-to-robot-base
        - ext2_extrinsics: (6,) float
        - wrist_extrinsics: (6,) float
        - language_instruction: first row of the ``language_instruction``
          column, or "" when the parquet has no such column
        - num_frames: number of parquet rows (T)

    Raises:
        ValueError: If the episode parquet contains no rows.
        AssertionError: If a video's frame count differs from the number of
            parquet rows.
    """
    import pandas as pd

    paths = download_episode(episode_index, cache_dir)

    # Load parquet
    df = pd.read_parquet(paths["parquet"])
    T = len(df)
    if T == 0:
        # Fail early with a clear message instead of an opaque np.stack error,
        # and before spending time decoding the videos.
        raise ValueError(f"episode {episode_index} parquet has no rows")

    # Per-frame robot state; each row holds one vector, stacked to (T, ...).
    joint_positions = np.stack(df["observation.state.joint_position"].values)
    gripper_positions = df["observation.state.gripper_position"].values.astype(np.float32)
    cartesian_positions = np.stack(df["observation.state.cartesian_position"].values)

    # Camera extrinsics are constant per episode - take first row
    first_row = df.iloc[0]
    ext1_extrinsics = np.array(first_row["camera_extrinsics.exterior_1_left"], dtype=np.float32)
    ext2_extrinsics = np.array(first_row["camera_extrinsics.exterior_2_left"], dtype=np.float32)
    wrist_extrinsics = np.array(first_row["camera_extrinsics.wrist_left"], dtype=np.float32)

    # Decode video frames
    ext1_images = decode_video_frames(paths["ext1_video"])
    ext2_images = decode_video_frames(paths["ext2_video"])

    # Verify frame counts match. This is an explicit raise rather than an
    # `assert` so the check still runs under `python -O`; AssertionError is
    # kept so existing callers observe the same exception type and message.
    for label, images in (("ext1", ext1_images), ("ext2", ext2_images)):
        if images.shape[0] != T:
            raise AssertionError(f"{label} frames {images.shape[0]} != parquet rows {T}")

    if "language_instruction" in df.columns:
        language_instruction = first_row["language_instruction"]
    else:
        language_instruction = ""

    return {
        "ext1_images": ext1_images,
        "ext2_images": ext2_images,
        "joint_positions": joint_positions,
        "gripper_positions": gripper_positions,
        "cartesian_positions": cartesian_positions,
        "ext1_extrinsics": ext1_extrinsics,
        "ext2_extrinsics": ext2_extrinsics,
        "wrist_extrinsics": wrist_extrinsics,
        "language_instruction": language_instruction,
        "num_frames": T,
    }


if __name__ == "__main__":
    import argparse

    # Command-line entry point: load one episode and print a short summary.
    cli = argparse.ArgumentParser()
    cli.add_argument("--episode", type=int, default=0)
    episode_index = cli.parse_args().episode

    print(f"Loading episode {episode_index}...")
    episode = load_episode(episode_index)
    joints = episode["joint_positions"]

    summary = (
        f"  Instruction: {episode['language_instruction']}",
        f"  Frames: {episode['num_frames']}",
        f"  ext1 images: {episode['ext1_images'].shape}",
        f"  ext2 images: {episode['ext2_images'].shape}",
        f"  joint_positions: {joints.shape}",
        f"  ext1_extrinsics: {episode['ext1_extrinsics']}",
        f"  ext2_extrinsics: {episode['ext2_extrinsics']}",
        f"  Joint pos range: [{joints.min():.3f}, {joints.max():.3f}]",
        f"  First joint pos: {joints[0]}",
    )
    for line in summary:
        print(line)
