from pathlib import Path
import numpy as np
import json
from tqdm import tqdm
import mediapy
import cv2
from scipy.spatial.transform import Rotation as R
import os
from copy import deepcopy

from custom.replay_utils import create_replay_context, replay_sequence_multiple_cams_v2, close_context, is_valid_extrinsic, resize_intrinsics, load_episode_data, load_episode_images, sort_episode_length
from custom.config import path_to_droid_repo, local_dataset_path, replay_img_save_path, IMAGE2SCENE_CAM_MAP, DEFAULT_CAMERA_INTRINSICS, DEFAULT_CAMERA_INTRINSICS_WRIST, CAM2IMAGE_MAP


# Quality gates for the camera extrinsics: an episode is kept only if every IoU
# metric is strictly above IoU_threshold and every reprojection error is strictly
# below reprojection_error_threshold (see the filtering loop below).
IoU_threshold = 0.75
reprojection_error_threshold = 5
# when True, the recorded camera images are not loaded and `real_obs=None` is passed to the replay
skip_real_obs = True
# how many envs to run in parallel (also the number of episodes accumulated before a batch is replayed)
N = 1 # TODO: change to 24

# failed_at_first_try = json.load(open(os.path.join(path_to_droid_repo, "need_further_check.json")))

# Read the per-episode camera-to-base extrinsics together with their quality metrics.
cam2base_extrinsics_path = f"{path_to_droid_repo}/cam2base_extrinsic_superset.json"
with open(cam2base_extrinsics_path, "r") as f:
    cam2base_extrinsics = json.load(f)

# Keep an episode only when it is not a failure trajectory, has at least two
# rated cameras, and every rating passes its threshold.
episodes_with_good_extrinsics = {}
for episode_id, episode_metadata in cam2base_extrinsics.items():
    if "failure" in episode_metadata['relative_path']:
        continue  # failure trajectories are dropped

    episode_IoUs = []
    reprojection_errors = []
    for key, metric_type in episode_metadata.items():
        if "metric_type" not in key:
            continue
        # "<prefix>metric_type" names the metric; "<prefix>quality_metric" holds its value
        if metric_type == "IoU":
            episode_IoUs.append(episode_metadata[key.replace("metric_type", "quality_metric")])
        elif metric_type == "Reprojection_error":
            reprojection_errors.append(episode_metadata[key.replace("metric_type", "quality_metric")])

    if len(episode_IoUs) + len(reprojection_errors) < 2:
        continue
    if not all(x > IoU_threshold for x in episode_IoUs):
        continue
    if not all(x < reprojection_error_threshold for x in reprojection_errors):
        continue
    episodes_with_good_extrinsics[episode_id] = episode_metadata
print("Number of episodes with good extrinsics: ", len(episodes_with_good_extrinsics))

# presumably reorders the retained episodes by their length as recorded in the
# metadata file -- see custom.replay_utils.sort_episode_length to confirm
episodes_with_good_extrinsics = sort_episode_length(
    f"{path_to_droid_repo}/metadata_all_new.jsonl",
    episodes_with_good_extrinsics,
)

# per-episode camera intrinsics, keyed by episode ID and then by camera serial
intrinsics_path = f"{path_to_droid_repo}/intrinsics_updated.json"
intrinsics = json.loads(Path(intrinsics_path).read_text())

# episode ID -> path mapping, plus the reverse lookup (path -> episode ID)
episode_id_to_path_path = f"{path_to_droid_repo}/episode_id_to_path.json"
episode_id_to_path = json.loads(Path(episode_id_to_path_path).read_text())
episode_path_to_id = {path: ep_id for ep_id, path in episode_id_to_path.items()}

# per-episode mapping from camera name to camera serial
camera_serials_path = f"{path_to_droid_repo}/camera_serials.json"
camera_serials = json.loads(Path(camera_serials_path).read_text())


# create the replay context with N parallel envs; it is released by
# close_context(ctx) in the `finally` clause at the end of this script
ctx = create_replay_context(headless=True, scene=1, num_envs=N)
try:
    failed_episodes = []
    
    # Simple batch manager for grouping inputs for parallel replay
    def _flush_batch(batch_store):
        """Replay every episode queued in ``batch_store`` as one batch, then empty it.

        Joint trajectories are zero-padded to the longest sequence in the batch
        and accompanied by a boolean mask marking the real timesteps. Camera
        parameters are regrouped per scene camera before being handed to
        ``replay_sequence_multiple_cams_v2``.

        Episodes the replay reports as unsuccessful are appended to
        ``failed_episodes``. If building the batch or the replay itself raises,
        the error is printed and every queued episode is recorded as failed.
        The batch is emptied in all cases, so a failing batch can never leak
        its episodes into the next one.

        Args:
            batch_store: dict of parallel lists with the keys "camera_params",
                "joint_positions", "real_obs", "video_paths", "img_paths" and
                "episode_ids". It is cleared in place.
        """
        episode_ids = batch_store["episode_ids"]
        if len(episode_ids) == 0:
            return
        try:
            # Pad joint positions to the maximum length in the batch
            sequences = batch_store["joint_positions"]
            max_len = max(seq.shape[0] for seq in sequences)
            feat_dim = sequences[0].shape[1]
            batch_size = len(sequences)
            padded = np.zeros((batch_size, max_len, feat_dim), dtype=sequences[0].dtype)
            mask = np.zeros((batch_size, max_len), dtype=np.bool_)  # True for valid timesteps
            for b_idx, seq in enumerate(sequences):
                seq_len = seq.shape[0]
                padded[b_idx, :seq_len] = seq
                mask[b_idx, :seq_len] = True

            # batchify the camera params: one list per field and scene camera
            batched_camera_params = {cam: {
                "camera_serial": [],
                "camera_pose": [],
                "intrinsics": [],
                "extrinsic_quality": [],
            } for cam in IMAGE2SCENE_CAM_MAP.values()}
            for cam_params in batch_store["camera_params"]:
                for cam, v in cam_params.items():
                    batched_camera_params[cam]["camera_serial"].append(v["camera_serial"])
                    batched_camera_params[cam]["camera_pose"].append(v["camera_pose"])
                    batched_camera_params[cam]["intrinsics"].append(v["intrinsics"])
                    batched_camera_params[cam]["extrinsic_quality"].append(v["extrinsic_quality"])
            for v in batched_camera_params.values():
                # the camera_pose for each view should either be all None or all not None.
                # A scene camera that no queued episode provides leaves these lists
                # empty and raises here; the handler below then fails the whole batch.
                if v["camera_pose"][0] is not None:
                    v["camera_pose"] = np.stack(v["camera_pose"])
                else:
                    v["camera_pose"] = None
                v["intrinsics"] = np.stack(v["intrinsics"])
            # NOTE: real_obs is deliberately left as a list of per-episode dicts

            # Call replay with batched inputs
            success_list = replay_sequence_multiple_cams_v2(
                ctx=ctx,
                camera_params=batched_camera_params,
                joint_positions=padded,
                padding_mask=mask,
                output_path=batch_store["video_paths"],
                img_paths=batch_store["img_paths"],
                real_obs=batch_store["real_obs"] if not skip_real_obs else None,
            )
            # Collect first and record afterwards, so an error while reading the
            # results cannot leave an episode listed twice.
            newly_failed = [ep_id for ep_id, ok in zip(episode_ids, success_list) if not bool(ok)]
            failed_episodes.extend(newly_failed)
        except Exception as e:
            # None of the queued episodes can be assumed to have been replayed.
            print(f"Replay of batch {list(episode_ids)} failed: {e}")
            failed_episodes.extend(episode_ids)
        finally:
            # Reset batch, also after an error, so stale episodes are never replayed again
            for key in ("camera_params", "joint_positions", "real_obs", "video_paths", "img_paths", "episode_ids"):
                batch_store[key].clear()

    # Parallel per-episode lists that accumulate the inputs of the next replay batch;
    # index i of every list belongs to the same episode.
    batch_store = {
        field: []
        for field in (
            "camera_params",
            "joint_positions",
            "real_obs",
            "video_paths",
            "img_paths",
            "episode_ids",
        )
    }
    # For every retained episode: build the per-camera parameters, load the
    # trajectory, skip episodes whose output already exists, and queue the rest
    # for batched replay. Any error marks the episode as failed and moves on.
    for episode_id, episode_metadata in tqdm(episodes_with_good_extrinsics.items(), position=0, leave=True):
        try:
            # if episode_id not in failed_at_first_try:
            #     continue
            episode_dir = os.path.join(local_dataset_path, episode_metadata["relative_path"])
            trajectory_path = os.path.join(episode_dir, "trajectory.h5")
            if not os.path.exists(trajectory_path):
                failed_episodes.append(episode_id)
                continue

            if intrinsics.get(episode_id, None) is None and camera_serials.get(episode_id, None) is not None:
                # there is a small chance that the intrinsics are not available for some episodes
                # in this case, we use the default intrinsics
                # NOTE(review): wrist cameras also receive DEFAULT_CAMERA_INTRINSICS here, while
                # the zero-intrinsics fallback below uses DEFAULT_CAMERA_INTRINSICS_WRIST for
                # them -- confirm whether this difference is intended.
                intrinsics[episode_id] = {
                    cam_serial: {
                        "cameraMatrix": [DEFAULT_CAMERA_INTRINSICS[0, 0], DEFAULT_CAMERA_INTRINSICS[0, 2], DEFAULT_CAMERA_INTRINSICS[1, 1], DEFAULT_CAMERA_INTRINSICS[1, 2]],
                        "distCoeffs": [0.0] * 12,
                        "width": 1280,
                        "height": 720,
                    }
                    for cam_serial in camera_serials[episode_id].values()
                }

            # serial -> camera name; inverted once per episode instead of once per camera
            serial_to_name = {serial: name for name, serial in (camera_serials.get(episode_id) or {}).items()}

            camera_params = {}
            # Scene camera written last. It selects the output directory inspected by
            # the "already rendered" check below, and stays None if no camera was built.
            last_scene_cam = None
            for camera_serial, extracted_intrinsics in intrinsics[episode_id].items():
                extracted_extrinsics = episode_metadata.get(camera_serial, None)
                extrinsic_quality = {
                    "metric": episode_metadata.get(f"{camera_serial}_metric_type", None),
                    "quality": episode_metadata.get(f"{camera_serial}_quality_metric", None),
                    "source": episode_metadata.get(f"{camera_serial}_source", None),
                }

                calib_image_name = CAM2IMAGE_MAP[serial_to_name[camera_serial]]
                scene_cam = IMAGE2SCENE_CAM_MAP[calib_image_name]

                # get camera pose: extrinsics are [x, y, z, rx, ry, rz] with "xyz" Euler angles in radians
                if extracted_extrinsics is None:
                    camera_pose_np = None
                else:
                    camera_pose_np = np.eye(4)
                    camera_pose_np[:3, :3] = R.from_euler("xyz", extracted_extrinsics[3:6]).as_matrix()
                    camera_pose_np[:3, 3] = extracted_extrinsics[0:3]

                # convert the intrinsics to a matrix; a zero entry marks them as unusable
                fx, cx, fy, cy = extracted_intrinsics["cameraMatrix"]
                if fx == 0 or fy == 0 or cx == 0 or cy == 0:
                    if "wrist" in calib_image_name:
                        intrinsics_np = DEFAULT_CAMERA_INTRINSICS_WRIST
                    else:
                        intrinsics_np = DEFAULT_CAMERA_INTRINSICS
                else:
                    intrinsics_np = np.array([[fx, 0, cx],
                                              [0, fy, cy],
                                              [0, 0, 1]])

                # intrinsics recorded at 672x376 are rescaled to 1280x720
                if extracted_intrinsics["width"] == 672 and extracted_intrinsics["height"] == 376:
                    intrinsics_np = resize_intrinsics(intrinsics_np, src_size=(672, 376), dst_size=(1280, 720))

                camera_params[scene_cam] = {
                    "camera_serial": camera_serial,
                    "camera_pose": camera_pose_np,
                    "intrinsics": intrinsics_np,
                    "extrinsic_quality": extrinsic_quality,
                }
                last_scene_cam = scene_cam

            if not is_valid_extrinsic(camera_params, IoU_threshold=IoU_threshold, reprojection_error_threshold=reprojection_error_threshold):
                continue

            if last_scene_cam is None:
                # No camera entry exists, so there is nothing to render. Fail the episode
                # explicitly instead of reading a camera name left over from a previous
                # episode (or hitting a NameError on the first one).
                failed_episodes.append(episode_id)
                continue

            joint_positions, gripper_actions, cartesian_poses = load_episode_data(trajectory_path)
            if not skip_real_obs:
                images = load_episode_images(os.path.join(episode_dir, "recordings", "MP4"), camera_params)
            else:
                images = None

            # episode length: the shortest of all loaded streams
            episode_length = min(len(joint_positions), len(gripper_actions), len(cartesian_poses))
            if not skip_real_obs:
                episode_length = min(episode_length, len(next(iter(images.values()))))

            img_root = os.path.join(replay_img_save_path, episode_metadata["relative_path"])
            tmp_img_path = os.path.join(img_root, "rgb_sim_rollout_masked", last_scene_cam)
            if os.path.exists(tmp_img_path) and len(os.listdir(tmp_img_path)) >= (episode_length-1):
                print(f"Skipping episode {episode_id} because it already exists")
                continue

            joint_positions = joint_positions[:episode_length]
            gripper_actions = gripper_actions[:episode_length]
            if not skip_real_obs:
                for cam in images:
                    images[cam] = images[cam][:episode_length]

            # the gripper action is appended as extra trailing feature(s) of the joint vector
            joint_positions = np.concatenate([joint_positions, gripper_actions], axis=-1)

            video_path = os.path.join("runs/debug_filtered", f"replay_{episode_id}.mp4")

            # Build every entry first so that an error cannot leave the parallel
            # batch lists with different lengths, then accumulate into batch
            queued_camera_params = deepcopy(camera_params)
            queued_joint_positions = deepcopy(joint_positions)
            queued_images = deepcopy(images)
            batch_store["camera_params"].append(queued_camera_params)
            batch_store["joint_positions"].append(queued_joint_positions)
            batch_store["real_obs"].append(queued_images)
            batch_store["video_paths"].append(video_path)
            batch_store["img_paths"].append(img_root)
            batch_store["episode_ids"].append(episode_id)

            # If batch full, flush
            if len(batch_store["episode_ids"]) >= N:
                _flush_batch(batch_store)
        except Exception as e:
            print(e)
            failed_episodes.append(episode_id)
    # Replay whatever is still queued after the last full batch
    _flush_batch(batch_store)
    # Persist the IDs of all failed episodes, one per line
    with open("failed_episodes.txt", "w") as f:
        f.writelines(f"{failed_id}\n" for failed_id in failed_episodes)
except Exception as e:
    print(e)
finally:
    close_context(ctx)