"""
Reusable replay utilities for IsaacLab DROID environment.

This module provides a small context that initializes the IsaacLab AppLauncher
and environment configuration once, and a function that replays a sequence of
joint positions for a given camera pose and intrinsics, recording a video that
matches the behavior in `run_random.py`.

Typical usage:

    from pathlib import Path
    import numpy as np
    from src.replay_utils import create_replay_context, replay_sequence, close_context

    ctx = create_replay_context(headless=True, scene=1)
    try:
        replay_sequence(
            ctx=ctx,
            camera_pose=camera_pose_np,
            intrinsics=intrinsics_np,
            joint_positions=joint_positions_np,
            output_path=Path("runs/2025-01-01/12-00-00/replay.mp4"),
        )
        # ... call replay_sequence again with different inputs, reusing ctx ...
    finally:
        close_context(ctx)

Notes:
- IsaacLab-dependent modules are imported only after the app is launched.
- The function constructs a fresh env for each replay to apply updated camera
  parameters, while reusing the heavier app and config objects.
"""

from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import os
import argparse
import numpy as np
import torch
import gymnasium as gym
import mediapy
import cv2
import h5py
import mediapy
from tqdm import tqdm
import json

@dataclass
class ReplayContext:
    """Bundle of long-lived objects shared across replays to avoid re-initialization.

    Instances are built by `create_replay_context` and handed to the
    `replay_sequence*` functions; `close_context` shuts the app down.

    Attributes
    - app_launcher: IsaacLab AppLauncher instance
    - simulation_app: The running Omniverse app (taken from `app_launcher.app`)
    - env_cfg: IsaacLab environment configuration; `replay_sequence` and
      `replay_sequence_multiple_cams` modify it in place before building an env
    - device: Device string used by the environment
    - scene: Scene identifier (stored for reference only)
    """

    # Typed as `object` because the IsaacLab classes can only be imported
    # after the app has been launched.
    app_launcher: object
    simulation_app: object
    env_cfg: object
    device: str
    scene: int


def create_replay_context(
    *,
    headless: bool = True,
    scene: int = 1,
    device: Optional[str] = None,
    num_envs: int = 1,
) -> ReplayContext:
    """Initialize IsaacLab app and environment configuration once.

    Unknown command-line arguments are ignored (`parse_known_args`), so this can
    be called from scripts that define their own CLI flags.

    Parameters
    - headless: Whether to run without rendering to a window
    - scene: Scene identifier. It is only recorded on the returned context; the
      environment configuration is currently built without an explicit scene.
    - device: Optional device override (e.g., "cuda:0"). If None, the
      AppLauncher CLI default is used. The override is applied to both the
      launched app and the environment configuration.
    - num_envs: Number of parallel environments requested in the env config

    Returns
    - A `ReplayContext` that should be reused across multiple replays.
    """
    # Import inside function to avoid side effects during module import.
    from isaaclab.app import AppLauncher

    parser = argparse.ArgumentParser(description="Create IsaacLab app for DROID replays")
    AppLauncher.add_app_launcher_args(parser)
    args_cli, _ = parser.parse_known_args()
    # Cameras are always required because every replay records camera frames.
    args_cli.enable_cameras = True
    args_cli.headless = headless
    if device is not None:
        # Launch the app on the same device the environment will use; previously
        # the override only reached the env config, not the launcher.
        args_cli.device = str(device)

    app_launcher = AppLauncher(args_cli)
    simulation_app = app_launcher.app

    # Import IsaacLab-dependent modules only after the app is launched
    import src.environments  # noqa: F401
    from isaaclab_tasks.utils import parse_env_cfg

    resolved_device = device if device is not None else args_cli.device

    env_cfg = parse_env_cfg(
        "DROID",
        device=resolved_device,
        num_envs=num_envs,
        use_fabric=True,
    )

    # The caller can still mutate env_cfg before an env is made from it.
    return ReplayContext(
        app_launcher=app_launcher,
        simulation_app=simulation_app,
        env_cfg=env_cfg,
        device=str(resolved_device),
        scene=scene,
    )


def replay_sequence(
    *,
    ctx: ReplayContext,
    camera_pose: np.ndarray,
    intrinsics: np.ndarray,
    joint_positions: np.ndarray,
    output_path: Path | str,
    fps: int = 15,
    width: int = 1280,
    height: int = 720,
) -> Path:
    """Replay a joint-position sequence under given camera params and save video.

    This mirrors the behavior in `run_random.py` but reuses `ctx` to avoid
    reinitializing the AppLauncher and base env configuration. Note that
    `ctx.env_cfg` is modified in place (camera pose, initial joint positions
    and the `external_cam1` spawn configuration).

    Parameters
    - ctx: A `ReplayContext` created by `create_replay_context`
    - camera_pose: (4,4) world transform of the camera as a homogeneous matrix
    - intrinsics: (3,3) camera intrinsic matrix
    - joint_positions: (T, N) array of joint configurations. The full first row
      seeds the initial state; during replay the last column of every row is
      dropped and the remaining values replace the leading robot joints.
    - output_path: Path to the output mp4 file
    - fps: Video framerate
    - width, height: Camera resolution

    Returns
    - The `Path` of the saved video.

    Raises
    - ValueError: if `joint_positions` is not 2-D or has no rows. This is checked
      before `ctx.env_cfg` is touched.
    """
    # Validate first so a bad input cannot leave ctx.env_cfg half-updated.
    joint_positions = np.asarray(joint_positions)
    if joint_positions.ndim != 2:
        raise ValueError("joint_positions must be a 2-D array of shape (T, N)")
    if joint_positions.shape[0] < 1:
        raise ValueError("joint_positions must have at least one row")

    # Per-timestep targets with the last column removed, converted once up front
    # (presumably the dropped column is the gripper value -- TODO confirm).
    arm_targets = torch.from_numpy(
        np.ascontiguousarray(joint_positions[:, :-1], dtype=np.float32)
    )

    # Late imports for IsaacLab-dependent modules
    from custom.utils import overwrite_camera_pose, overwrite_joint_positions
    import isaaclab.sim as sim_utils

    # 1) Apply camera pose and initial joint positions to the env config
    overwrite_camera_pose(ctx.env_cfg, np.asarray(camera_pose))
    overwrite_joint_positions(ctx.env_cfg, joint_positions[0].tolist())

    # Configure external camera intrinsics; convert to flat list for API
    intrinsic_list: list[float] = np.asarray(intrinsics, dtype=float).flatten().tolist()
    ctx.env_cfg.scene.external_cam1.spawn = sim_utils.PinholeCameraCfg.from_intrinsic_matrix(
        intrinsic_matrix=intrinsic_list,
        width=width,
        height=height,
        focal_length=2.1,
        focus_distance=28.0,
    )

    # 2) Build a fresh env for this replay using the updated config
    env = gym.make("DROID", cfg=ctx.env_cfg)
    try:
        # Reset twice to ensure materials and sensors are ready
        obs, _ = env.reset()
        obs, _ = env.reset()

        # Collect frames while stepping through provided joint positions
        frames: list[np.ndarray] = []

        with torch.no_grad():
            robot = env.unwrapped.scene["robot"]
            arm_targets = arm_targets.to(robot.device)
            for arm_target in arm_targets:
                # Replace the leading joints while keeping the rest from the env
                env_joint_position = robot.data.joint_pos[0]
                desired_joint_positions = torch.concatenate(
                    [arm_target, env_joint_position[7:]]
                )
                robot.write_joint_position_to_sim(desired_joint_positions)

                # Apply without a full step by resetting to current scene state
                obs, _ = env.unwrapped.reset_to(env.unwrapped.scene.get_state(), None)

                # Capture the external camera frame of the first environment
                frames.append(obs["policy"]["external_cam1"].cpu().numpy()[0])

        # 3) Write video
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        mediapy.write_video(output_path, frames, fps=fps)
        print(f"Saved video to {output_path}")
        return output_path
    finally:
        # Errors propagate unchanged; the env is always released.
        env.close()

def replay_sequence_multiple_cams(
    *,
    ctx: ReplayContext,
    camera_params: dict[str, dict[str, np.ndarray | str]],
    joint_positions: np.ndarray,
    output_path: Path | str,
    fps: int = 15,
    width: int = 1280,
    height: int = 720,
    real_obs: Optional[dict[str, np.ndarray]] = None,
    img_path: Optional[str] = None,
) -> bool:
    """Replay a joint-position sequence for several cameras and save one video.

    This mirrors the behavior in `run_random.py` but reuses `ctx` to avoid
    reinitializing the AppLauncher and base env configuration. `ctx.env_cfg` is
    modified in place (camera poses, camera spawn configs, initial joints).

    Each video frame tiles the cameras horizontally. When `real_obs` is given,
    three rows are stacked vertically: the real frames, the simulated frames and
    the real frames masked by the non-black simulated pixels (black background
    painted white for visualization). Frames are downscaled by 4 before writing.

    Parameters
    - ctx: A `ReplayContext` created by `create_replay_context`
    - camera_params: Dictionary of camera parameters, with keys as camera names
      (attributes of `ctx.env_cfg.scene`) and values as dictionaries containing:
      - camera_pose: (4,4) world transform of the camera as a homogeneous matrix,
        or None to keep the configured pose (e.g. wrist camera)
      - intrinsics: (3,3) camera intrinsic matrix
      - extrinsic_quality (optional): dict with "metric", "quality" and "source",
        drawn onto the frame when the environment variable DEBUG is "1"
    - joint_positions: (T, N) numpy array. Per row, the last column is treated as
      an action value (an env step is issued when it is non-zero), all but the
      last two columns replace the leading robot joints, and all but the last
      column of the first row seed the initial state.
    - output_path: Path to the output mp4 file
    - fps: Video framerate
    - width, height: Camera resolution
    - real_obs: Optional dictionary mapping camera name to a sequence of real
      frames indexed by timestep; if provided, the real observations are shown
      and masked with the replay observations.
    - img_path: Optional directory; per-camera simulated frames (and masked real
      frames when `real_obs` is given) are saved below it as JPEG files.

    Returns
    - True if the video was written; False if any exception occurred during the
      rollout or while writing (the error is printed, not raised).

    Raises
    - ValueError: if `joint_positions` is not 2-D or has no rows. This is checked
      before `ctx.env_cfg` is touched.
    """
    # Validate first so a bad input cannot leave ctx.env_cfg half-updated.
    joint_positions = np.asarray(joint_positions)
    if joint_positions.ndim != 2:
        raise ValueError("joint_positions must be a 2-D array of shape (T, N)")
    if joint_positions.shape[0] < 1:
        raise ValueError("joint_positions must have at least one row")

    # Late imports for IsaacLab-dependent modules
    from custom.utils import overwrite_camera_pose, overwrite_joint_positions
    import isaaclab.sim as sim_utils
    import omni.usd
    import gc

    cam_names = list(camera_params)
    debug_overlay = os.environ.get("DEBUG", "0") == "1"

    # 1) Apply camera pose and initial joint positions to the env config
    for cam_name, cam_params in camera_params.items():
        if cam_params["camera_pose"] is not None:  # wrist camera pose is not available
            overwrite_camera_pose(ctx.env_cfg, np.asarray(cam_params["camera_pose"]), cam_name)
        # Configure external camera intrinsics; convert to flat list for API
        intrinsic_list: list[float] = np.asarray(cam_params["intrinsics"], dtype=float).flatten().tolist()
        getattr(ctx.env_cfg.scene, cam_name).spawn = sim_utils.PinholeCameraCfg.from_intrinsic_matrix(
            intrinsic_matrix=intrinsic_list,
            width=width,
            height=height,
            focal_length=2.1,
            focus_distance=28.0,
        )

    overwrite_joint_positions(ctx.env_cfg, np.asarray(joint_positions[0][:-1]).tolist())

    # Prepare output directories for the optional per-frame image dumps
    sim_masked_save_path = None
    real_masked_save_path = None
    if img_path is not None:
        sim_masked_save_path = os.path.join(img_path, "rgb_sim_rollout_masked")
        for cam in cam_names:
            os.makedirs(os.path.join(sim_masked_save_path, cam), exist_ok=True)
        if real_obs is not None:
            real_masked_save_path = os.path.join(img_path, "rgb_real_rollout_masked")
            for cam in cam_names:
                os.makedirs(os.path.join(real_masked_save_path, cam), exist_ok=True)

    # 2) Build a fresh env for this replay using the updated config
    env = gym.make("DROID", cfg=ctx.env_cfg)
    try:
        # Reset twice to ensure materials and sensors are ready
        obs, _ = env.reset()
        obs, _ = env.reset()

        # Collect frames while stepping through provided joint positions
        frames: list[np.ndarray] = []

        with torch.no_grad():
            robot = env.unwrapped.scene["robot"]
            for idx in range(joint_positions.shape[0]):
                row = joint_positions[idx]
                arm_values = row[:-2]

                # Replace the leading joints while keeping the rest from the env
                env_joint_position = robot.data.joint_pos[0]
                desired_joint_positions = torch.concatenate(
                    [
                        torch.from_numpy(arm_values).to(torch.float32).to(robot.device),
                        env_joint_position[7:],
                    ]
                )
                robot.write_joint_position_to_sim(desired_joint_positions)

                # Apply without a full step by resetting to current scene state
                obs, _ = env.unwrapped.reset_to(env.unwrapped.scene.get_state(), None)

                if row[-1] != 0:
                    action = torch.concatenate([
                        torch.from_numpy(arm_values).to(robot.device),
                        torch.tensor([row[-1]]).to(robot.device),
                    ], dim=-1).unsqueeze(0)
                    # NOTE(review): the observation returned by this step is
                    # discarded, so the frame below still comes from `reset_to`
                    # -- confirm this is intended.
                    env.step(action)

                # Capture one tile per camera (first environment only)
                sim_tiles = []
                real_tiles = []
                for cam in cam_names:
                    img = obs["policy"][cam].cpu().numpy()[0]
                    if img_path is not None:
                        # NOTE(review): cv2.imwrite interprets arrays as BGR;
                        # confirm the channel order of these frames.
                        cv2.imwrite(os.path.join(sim_masked_save_path, cam, f"{idx:010d}.jpg"), img)
                    if real_obs is not None:
                        # Resize once and reuse for both the image dump and the video
                        real_img = cv2.resize(real_obs[cam][idx], (width, height), interpolation=cv2.INTER_LINEAR)
                        real_tiles.append(real_img)
                        if img_path is not None:
                            real_img_masked = real_img * (img.sum(axis=2, keepdims=True) > 0)
                            cv2.imwrite(os.path.join(real_masked_save_path, cam, f"{idx:010d}.jpg"), real_img_masked)
                    if debug_overlay:
                        cam_quality = camera_params[cam].get("extrinsic_quality", None)
                        if cam_quality is not None and cam_quality["metric"] is not None:
                            img = cv2.putText(img, f"{cam_quality['metric']}: {cam_quality['quality']} ({cam_quality['source']})", (10, 50), cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 255), 2)
                    sim_tiles.append(img)
                frame = np.concatenate(sim_tiles, axis=1)

                if real_obs is not None:
                    real_frame = np.concatenate(real_tiles, axis=1)
                    # Keep real pixels only where the simulated frame is non-black
                    real_frame_masked = real_frame * (frame.sum(axis=2, keepdims=True) > 0)
                    # make the background white just for visualization
                    real_frame_masked[np.all(real_frame_masked == 0, axis=2)] = [255, 255, 255]
                    frame[np.all(frame == 0, axis=2)] = [255, 255, 255]
                    frame = np.concatenate([real_frame, frame, real_frame_masked], axis=0)

                frames.append(cv2.resize(frame, (frame.shape[1] // 4, frame.shape[0] // 4), interpolation=cv2.INTER_LINEAR))

        # 3) Write video
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        mediapy.write_video(output_path, frames, fps=fps)
        return True
    except Exception as e:
        # Best-effort contract: report the failure and signal it via the return value.
        print(f"replay_sequence_multiple_cams failed: {type(e).__name__}: {e}")
        return False
    finally:
        try:
            env.close()
        finally:
            # Always drop the env and start from a new stage, even if close() raised.
            del env
            gc.collect()
            omni.usd.get_context().new_stage()

def replay_sequence_multiple_cams_v2(
    *,
    ctx: ReplayContext,
    camera_params,
    joint_positions: np.ndarray,
    output_path,
    padding_mask=None,
    fps = 15,
    width = 1280,
    height = 720,
    real_obs = None,
    img_paths = None,
) -> list[bool]:
    """Replay a batch of joint-position sequences and save one video per env.

    Unlike `replay_sequence_multiple_cams`, camera intrinsics and poses are set
    on the live camera objects after the env is built instead of through
    `ctx.env_cfg`. The arrays inside `camera_params` are not modified.

    Each video frame tiles the cameras horizontally. When real observations are
    given, three rows are stacked vertically: the real frames, the simulated
    frames and the real frames masked by the non-black simulated pixels (black
    background painted white). Frames are downscaled by 4 before writing.

    Parameters
    - ctx: A `ReplayContext` created by `create_replay_context`
    - camera_params: Dictionary of camera parameters, with keys as camera names and values as dictionaries containing:
      - camera_pose: camera transform(s) as homogeneous matrices, or None to
        keep the current pose. The robot base position of each env is added to
        the translation before the pose is applied.
      - intrinsics: camera intrinsic matrix/matrices as a numpy array
      - extrinsic_quality (optional): per-env sequence of dicts with "metric",
        "quality" and "source", drawn onto the frame when DEBUG is "1"
    - joint_positions: (T, N) or (B, T, N) numpy array. Per timestep, the last
      column is an action value (an env step is issued when any env has a
      non-zero value) and all but the last two columns replace the leading
      robot joints.
    - output_path: Sequence of B output mp4 paths; a single path is accepted
      when B == 1.
    - padding_mask: Optional (B, T) boolean numpy array, True where a timestep
      is valid. Padded timesteps keep the env's current joints and produce no
      frame. If None, all timesteps are valid.
    - fps: Video framerate
    - width, height: Size the real frames are resized to
    - real_obs: Optional real observations: either a list with one entry per env
      (dict camera name -> frames, or None), or a single dict shared by all envs.
    - img_paths: Optional list with one directory (or None) per env; per-camera
      simulated frames (and masked real frames) are saved below it as JPEGs.

    Returns
    - A list of B booleans, True where that env's video was written. If the
      rollout itself fails, every entry is False (the error is printed).

    Raises
    - ValueError: if `joint_positions` has no timesteps.
    - AssertionError: if `padding_mask`, `output_path` or `real_obs` do not
      match the batch size.
    """
    # Normalize inputs to batched form
    jp = torch.from_numpy(joint_positions).to(torch.float32).to(ctx.device)
    if jp.ndim == 2:
        jp = jp[None, ...]  # (1, T, N)
    if jp.shape[1] < 1:
        raise ValueError("joint_positions must have at least one timestep")
    B, T, _ = jp.shape

    # padding mask: True where valid; if None, all valid
    if padding_mask is None:
        padding_mask = torch.ones((B, T), dtype=bool).to(ctx.device)
    else:
        padding_mask = torch.from_numpy(padding_mask).to(torch.bool).to(ctx.device)
        assert padding_mask.shape == (B, T), "padding_mask must have shape (batch, time)"
    # Host copy so the per-env validity checks below do not sync with the device
    valid_steps = padding_mask.cpu().numpy()

    # A bare path would otherwise be iterated character by character
    if isinstance(output_path, (str, os.PathLike)):
        output_path = [output_path]
    output_paths = [Path(p) for p in output_path]
    assert len(output_paths) == B, "output_path list must match batch size"

    if img_paths is None:
        img_paths = [None for _ in range(B)]

    # real_obs can be dict (broadcast) or list per env
    if real_obs is None:
        real_obs_list = [None for _ in range(B)]
    elif isinstance(real_obs, dict):
        real_obs_list = [real_obs for _ in range(B)]
    else:
        real_obs_list = real_obs
        assert len(real_obs_list) == B, "real_obs list must match batch size"

    # Late imports for IsaacLab-dependent modules
    from custom.utils import world_pose_to_opengl_pose
    import omni.usd
    import gc

    cam_names = list(camera_params.keys())
    debug_overlay = os.environ.get("DEBUG", "0") == "1"

    # Prepare directories if saving images
    sim_masked_save_paths = [None for _ in range(B)]
    real_masked_save_paths = [None for _ in range(B)]
    for b in range(B):
        if img_paths[b] is not None:
            sim_masked = os.path.join(img_paths[b], "rgb_sim_rollout_masked")
            sim_masked_save_paths[b] = sim_masked
            for cam in cam_names:
                os.makedirs(os.path.join(sim_masked, cam), exist_ok=True)
            if real_obs_list[b] is not None:
                real_masked = os.path.join(img_paths[b], "rgb_real_rollout_masked")
                real_masked_save_paths[b] = real_masked
                for cam in cam_names:
                    os.makedirs(os.path.join(real_masked, cam), exist_ok=True)

    results = [False] * B

    # Build a fresh env for this replay
    # NOTE(review): the number of envs in ctx.env_cfg presumably has to equal B
    # -- it is not checked here.
    env = gym.make("DROID", cfg=ctx.env_cfg)
    try:
        # Reset twice to ensure materials and sensors are ready
        obs, _ = env.reset()
        obs, _ = env.reset()

        robot = env.unwrapped.scene["robot"]

        # Configure cameras on the live env
        for cam_name in cam_names:
            cam = env.unwrapped.scene[cam_name]
            cam.set_intrinsic_matrices(
                torch.from_numpy(camera_params[cam_name]["intrinsics"]),
                focal_length=2.1,
            )
            camera_pose = camera_params[cam_name]["camera_pose"]
            if camera_pose is not None:
                base_positions = robot.data.body_pos_w[:, 0]
                # Work on a copy: writing the offset into the caller's array would
                # accumulate the base position every time the params are reused.
                world_pose = np.array(camera_pose)
                # NOTE: the orientations seem to be always in the default openGL convention, so we just do camera_pose + base_positions
                world_pose[..., :3, 3] = world_pose[..., :3, 3] + base_positions.cpu().numpy()
                positions, orientations, convention = world_pose_to_opengl_pose(torch.from_numpy(world_pose))
                cam.set_world_poses(
                    positions=positions,
                    orientations=orientations,
                    env_ids=None,
                    convention=convention,
                )

        # Collect frames while stepping through provided joint positions
        frames_per_env: list[list[np.ndarray]] = [[] for _ in range(B)]

        with torch.no_grad():
            for t in tqdm(range(T), desc="Replaying sequence", position=1, leave=True):
                # Replace the leading joints while keeping the rest from the env
                env_joint_positions = robot.data.joint_pos.clone()
                desired_joint_positions = torch.concatenate([
                    jp[:, t, :-2],
                    env_joint_positions[:, 7:],
                ], dim=-1)
                # Padded envs keep their current joint state
                desired_joint_positions = torch.where(padding_mask[:, t, None], desired_joint_positions, env_joint_positions)
                robot.write_joint_position_to_sim(desired_joint_positions)

                # Apply without a full step by resetting to current scene state
                obs, _ = env.unwrapped.reset_to(env.unwrapped.scene.get_state(), None)

                # Step the env when any env has a non-zero action value
                if (jp[:, t, -1] != 0).any():
                    actions = torch.concatenate([
                        jp[:, t, :-2],
                        jp[:, t, -1:],
                    ], dim=-1)
                    # NOTE(review): the observation returned by this step is
                    # discarded, so the frames below still come from `reset_to`
                    # -- confirm this is intended.
                    env.step(actions)

                # Copy each camera batch to the host once per timestep
                cam_batches = {cam: obs["policy"][cam].cpu().numpy() for cam in cam_names}

                # Capture frames per env
                for b in range(B):
                    if not valid_steps[b, t]:
                        continue
                    env_real_obs = real_obs_list[b]
                    sim_tiles = []
                    real_tiles = []
                    for cam in cam_names:
                        img = cam_batches[cam][b]
                        if img_paths[b] is not None:
                            # NOTE(review): cv2.imwrite interprets arrays as BGR;
                            # confirm the channel order of these frames.
                            cv2.imwrite(os.path.join(sim_masked_save_paths[b], cam, f"{t:010d}.jpg"), img)
                        if env_real_obs is not None:
                            # Resize once and reuse for both the image dump and the video
                            real_img = cv2.resize(env_real_obs[cam][t], (width, height), interpolation=cv2.INTER_LINEAR)
                            real_tiles.append(real_img)
                            if img_paths[b] is not None:
                                real_img_masked = real_img * (img.sum(axis=2, keepdims=True) > 0)
                                cv2.imwrite(os.path.join(real_masked_save_paths[b], cam, f"{t:010d}.jpg"), real_img_masked)
                        if debug_overlay:
                            quality_per_env = camera_params[cam].get("extrinsic_quality")
                            cam_quality = quality_per_env[b] if quality_per_env is not None else None
                            if cam_quality is not None and cam_quality["metric"] is not None:
                                img = cv2.putText(img, f"{cam_quality['metric']}: {cam_quality['quality']} ({cam_quality['source']})", (10, 50), cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 255), 2)
                        sim_tiles.append(img)
                    frame = np.concatenate(sim_tiles, axis=1)

                    if env_real_obs is not None:
                        real_frame = np.concatenate(real_tiles, axis=1)
                        # Keep real pixels only where the simulated frame is non-black
                        real_frame_masked = real_frame * (frame.sum(axis=2, keepdims=True) > 0)
                        # White background, for visualization only
                        real_frame_masked[np.all(real_frame_masked == 0, axis=2)] = [255, 255, 255]
                        frame[np.all(frame == 0, axis=2)] = [255, 255, 255]
                        frame = np.concatenate([real_frame, frame, real_frame_masked], axis=0)

                    frames_per_env[b].append(cv2.resize(frame, (frame.shape[1] // 4, frame.shape[0] // 4), interpolation=cv2.INTER_LINEAR))

        # Write videos per env; one failure must not discard the other results
        for b in range(B):
            if not frames_per_env[b]:
                print(f"replay_sequence_multiple_cams_v2: no valid frames for env {b}, skipping {output_paths[b]}")
                continue
            try:
                output_paths[b].parent.mkdir(parents=True, exist_ok=True)
                mediapy.write_video(output_paths[b], frames_per_env[b], fps=fps)
                results[b] = True
            except Exception as e:
                print(f"replay_sequence_multiple_cams_v2: writing {output_paths[b]} failed: {type(e).__name__}: {e}")
        return results
    except Exception as e:
        # Best-effort contract: report the failure and signal it via the return value.
        print(f"replay_sequence_multiple_cams_v2 failed: {type(e).__name__}: {e}")
        return results
    finally:
        try:
            env.close()
        finally:
            # Always drop the env and start from a new stage, even if close() raised.
            del env
            gc.collect()
            omni.usd.get_context().new_stage()

def close_context(ctx: ReplayContext) -> None:
    """Close the Omniverse app held by the context.

    Shutdown is best-effort: any error raised while closing is reported on
    stdout and otherwise ignored, so this is safe to call from a `finally`
    block.

    Parameters
    - ctx: The `ReplayContext` whose `simulation_app` should be closed
    """
    try:
        ctx.simulation_app.close()
    except Exception as exc:
        # Never raise during shutdown, but do not hide the failure either.
        print(f"close_context: failed to close simulation app: {type(exc).__name__}: {exc}")


def is_valid_extrinsic(camera_params, IoU_threshold=0.8, reprojection_error_threshold=5):
    """Return True when every camera's extrinsic calibration passes its quality check.

    Parameters
    - camera_params: Dictionary mapping camera name to a parameter dictionary
      whose "extrinsic_quality" entry is either None or a dictionary with the
      keys "metric" and "quality"
    - IoU_threshold: For metric "IoU" the quality must be strictly greater
    - reprojection_error_threshold: For metric "Reprojection_error" the quality
      must be strictly smaller

    Returns
    - False if any camera fails its threshold or reports an unknown metric;
      True otherwise (including when there are no cameras). Cameras whose
      "extrinsic_quality" or metric is None are not checked.
    """
    for cam_params in camera_params.values():
        quality_info = cam_params["extrinsic_quality"]
        if quality_info is None:
            continue
        metric = quality_info["metric"]
        if metric is None:
            # wrist camera pose has no extrinsic quality
            continue
        if metric == "IoU":
            passed = quality_info["quality"] > IoU_threshold
        elif metric == "Reprojection_error":
            passed = quality_info["quality"] < reprojection_error_threshold
        else:
            # Unknown metrics are treated as invalid
            passed = False
        if not passed:
            return False
    return True

def resize_intrinsics(intrinsics, src_size, dst_size):
    """Rescale a pinhole intrinsic matrix from one image size to another.

    The first entries of `src_size`/`dst_size` scale fx and cx, the second
    entries scale fy and cy. Off-diagonal terms other than the principal point
    are set to zero in the result.

    Parameters
    - intrinsics: (3,3) array indexed as [row, col]
    - src_size: Size the intrinsics currently refer to (two entries)
    - dst_size: Size the returned intrinsics should refer to (two entries)

    Returns
    - A new (3,3) numpy array with the rescaled intrinsics.
    """
    src_w, src_h = src_size[0], src_size[1]
    dst_w, dst_h = dst_size[0], dst_size[1]

    # Horizontal quantities follow the first size entry ...
    focal_x = intrinsics[0, 0] * dst_w / src_w
    center_x = intrinsics[0, 2] * dst_w / src_w
    # ... vertical quantities the second.
    focal_y = intrinsics[1, 1] * dst_h / src_h
    center_y = intrinsics[1, 2] * dst_h / src_h

    top_row = [focal_x, 0, center_x]
    middle_row = [0, focal_y, center_y]
    bottom_row = [0, 0, 1]
    return np.array([top_row, middle_row, bottom_row])

def load_episode_data(trajectory_path):
    """Read robot state and gripper actions of one episode from an HDF5 file.

    Parameters
    - trajectory_path: Path of the HDF5 trajectory file

    Returns
    - joint_positions: the "joint_positions" dataset with the "gripper_position"
      observation appended as an extra last column
    - gripper_actions: the action "gripper_position" dataset as a float32 column
    - cartesian_poses: the "cartesian_position" observation dataset as float32
    """
    with h5py.File(trajectory_path, "r") as episode:
        robot_state = episode["observation"]["robot_state"]

        arm_joints = robot_state["joint_positions"][()]
        # Add a trailing axis so the gripper value becomes one more column
        gripper_column = robot_state["gripper_position"][()][:, None]
        joint_positions = np.concatenate([arm_joints, gripper_column], axis=-1)

        commanded_gripper = episode["action"]["gripper_position"][()][:, None]
        gripper_actions = np.array(commanded_gripper, dtype=np.float32)

        cartesian_poses = np.array(robot_state["cartesian_position"][()], dtype=np.float32)

    return joint_positions, gripper_actions, cartesian_poses
        
def load_episode_images(videos_path, camera_params):
    """Load the recorded video of every camera as a numpy array.

    Parameters
    - videos_path: Directory containing one `<camera_serial>.mp4` per camera
    - camera_params: Dictionary mapping camera name to a dictionary that holds
      the camera's "camera_serial"

    Returns
    - Dictionary mapping camera name to the frames read from its video.
    """
    return {
        cam_name: np.array(
            mediapy.read_video(
                os.path.join(videos_path, f"{cam_params['camera_serial']}.mp4")
            )
        )
        for cam_name, cam_params in camera_params.items()
    }

def sort_episode_length(metadata_path, episode_dict):
    """Return `episode_dict` reordered by ascending trajectory length.

    The metadata file is read as JSON Lines: each non-blank line is a JSON
    object whose first value is either null or a dictionary with "uuid" and
    "trajectory_length". Blank lines and empty objects are skipped.

    Parameters
    - metadata_path: Path of the JSON Lines metadata file
    - episode_dict: Dictionary keyed by episode uuid

    Returns
    - A new dictionary with the same items, sorted by trajectory length.
      Episodes without a metadata entry are placed last; ties keep their
      original relative order.
    """
    episode_id_to_length = {}
    with open(metadata_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank or trailing empty lines
                continue
            record = json.loads(line)
            if not record:
                continue
            # Only the first entry of each record is used
            entry = next(iter(record.values()))
            if entry is not None:
                episode_id_to_length[entry["uuid"]] = entry["trajectory_length"]

    def _length_of(item):
        # Unknown episodes sort after every known length
        return episode_id_to_length.get(item[0], float("inf"))

    return dict(sorted(episode_dict.items(), key=_length_of))