import os
import glob
import numpy as np
import mediapy as media
import json
from tqdm import tqdm
import multiprocessing as mp
from functools import partial
import cv2

from custom.config import path_to_droid_repo, local_dataset_path, replay_img_save_path, \
    CAM2IMAGE_MAP, IMAGE2SCENE_CAM_MAP

def process_episode(episode_data):
    """Render a comparison video of original vs. replayed frames for one episode.

    For every camera in ``IMAGE2SCENE_CAM_MAP``, five evenly spaced frames are
    sampled from the original MP4 recording and from the replayed JPEG sequence
    (``rgb_sim_rollout_masked``). Cameras are tiled horizontally and three rows are
    stacked vertically: original frames, replayed frames, and original frames kept
    only where the replayed pixel is non-zero. The canvas is downsampled by 4 in
    each spatial dimension and written at 5 fps to ``debug/vis/<episode_id>.mp4``.

    Args:
        episode_data: ``(episode_id, episode_metadata)`` pair; ``episode_metadata``
            must contain a ``'relative_path'`` key.

    Returns:
        None if the video was written now or already exists; otherwise the
        ``episode_id``, meaning the episode needs further checking (replay shorter
        than the original, empty original video, or any error while loading or
        rendering).
    """
    episode_id, episode_metadata = episode_data

    # Skip episodes rendered by a previous run so the script can be resumed.
    vis_save_path = os.path.join("debug", "vis", f"{episode_id}.mp4")
    if os.path.exists(vis_save_path):
        print(f"Episode {episode_id} already processed")
        return None

    num_samples = 5       # frames sampled per camera
    downsample_ratio = 4  # spatial downsampling factor of the final canvas
    cams = list(IMAGE2SCENE_CAM_MAP.values())

    # All per-episode work lives inside the try: any failure (missing serials or
    # metadata, unreadable files, too few frames) flags this episode instead of
    # crashing the whole worker pool.
    try:
        episode_rel_path = episode_metadata['relative_path']
        episode_ori_img_path = os.path.join(local_dataset_path, episode_rel_path, "recordings", "MP4")
        episode_replay_img_path = os.path.join(replay_img_save_path, episode_rel_path, "rgb_sim_rollout_masked")

        # Scene camera name -> serial (the serial is the MP4 file stem).
        camera_name_to_serial = {
            IMAGE2SCENE_CAM_MAP[CAM2IMAGE_MAP[k]]: v for k, v in camera_serials[episode_id].items()
        }

        # Compare frame counts on the first camera. read_video decodes the whole
        # video, so keep it to avoid decoding it a second time below.
        first_cam_serial = camera_name_to_serial[cams[0]]
        first_ori_video = media.read_video(os.path.join(episode_ori_img_path, f"{first_cam_serial}.mp4"))
        h, w = first_ori_video.shape[1:3]
        ori_length = len(first_ori_video)
        replay_length = len(glob.glob(os.path.join(episode_replay_img_path, cams[0], "*.jpg")))

        if ori_length > replay_length:
            print(f"Episode {episode_id}: replayed frames ({replay_length}) is shorter than original frames ({ori_length})")
            return episode_id
        if ori_length == 0:
            print(f"Episode {episode_id}: original video has no frames")
            return episode_id

        # Here ori_length <= replay_length, so the original length bounds sampling.
        sample_indices = np.linspace(0, ori_length - 1, num_samples, dtype=int)

        ori_rows, replay_rows = [], []
        for cam_name in cams:
            if cam_name == cams[0]:
                ori_video = first_ori_video
            else:
                cam_serial = camera_name_to_serial[cam_name]
                ori_video = media.read_video(os.path.join(episode_ori_img_path, f"{cam_serial}.mp4"))
            # cv2.resize takes (width, height); every camera is resized to the
            # first camera's resolution so they can be tiled.
            ori_rows.append(np.array([cv2.resize(ori_video[i], (w, h)) for i in sample_indices]))

            # NOTE(review): lexicographic sort assumes replay frame file names sort
            # in frame order (e.g. zero-padded) — confirm.
            replay_paths = sorted(glob.glob(os.path.join(episode_replay_img_path, cam_name, "*.jpg")))
            replay_rows.append(np.array([
                cv2.resize(media.read_image(replay_paths[i]), (w, h)) for i in sample_indices
            ]))

        # Tile cameras side by side along the width axis.
        ori_obs_mul = np.concatenate(ori_rows, axis=2)
        replay_obs_mul = np.concatenate(replay_rows, axis=2)
        # Keep original pixels only where the replayed pixel's channel sum is > 0.
        ori_obs_mul_masked = ori_obs_mul * (replay_obs_mul.sum(axis=3, keepdims=True) > 0)
        # Stack the three comparison rows along the height axis.
        vis_canvas = np.concatenate([
            ori_obs_mul,
            replay_obs_mul,
            ori_obs_mul_masked,
        ], axis=1)

        # Downsample every frame by downsample_ratio; resize_image takes (height, width).
        out_shape = (vis_canvas.shape[1] // downsample_ratio, vis_canvas.shape[2] // downsample_ratio)
        downsampled_canvas = np.stack([media.resize_image(frame, out_shape) for frame in vis_canvas])

        os.makedirs(os.path.dirname(vis_save_path), exist_ok=True)
        media.write_video(vis_save_path, downsampled_canvas, fps=5)

        return None
    except Exception as e:
        # Per-episode fault boundary: report and flag, never crash the pool.
        print(f"Error processing episode {episode_id}: {e!r}")
        return episode_id

# Loaded at import time (outside the __main__ guard) because process_episode reads
# this global inside worker processes; with the "spawn" start method workers
# re-import this module instead of inheriting the parent's globals.
with open(os.path.join(path_to_droid_repo, "camera_serials.json"), "r") as f:
    camera_serials = json.load(f)

NUM_WORKERS = 1  # size of the worker pool used to render episodes


if __name__ == "__main__":
    with open(os.path.join(path_to_droid_repo, "episodes_with_good_extrinsics.json"), "r") as f:
        episodes_with_good_extrinsics = json.load(f)

    # Episodes flagged by an earlier run (presumably a JSON list of episode ids —
    # verify); a set makes the membership test below O(1).
    with open(os.path.join(path_to_droid_repo, "need_further_check.json"), "r") as f:
        failed_at_first_try = set(json.load(f))

    # Only re-check episodes that were flagged before.
    episodes_with_good_extrinsics = {
        k: v for k, v in episodes_with_good_extrinsics.items() if k in failed_at_first_try
    }

    # Render episodes in a worker pool with a progress bar; process_episode
    # returns the episode id for episodes that still need checking, else None.
    with mp.Pool(processes=NUM_WORKERS) as pool:
        need_further_check = list(filter(None, tqdm(
            pool.imap(process_episode, episodes_with_good_extrinsics.items()),
            total=len(episodes_with_good_extrinsics),
            desc="Processing episodes",
        )))

    with open(os.path.join(path_to_droid_repo, "need_further_check_new.json"), "w") as f:
        json.dump(need_further_check, f, indent=4)
    print(f"Need further check: {len(need_further_check)}")