import json
import os
import traceback
from copy import deepcopy
from pathlib import Path

import cv2
import mediapy
import numpy as np
from scipy.spatial.transform import Rotation as R
from tqdm import tqdm

from replay_utils import create_replay_context, replay_sequence_multiple_cams_v2, close_context, is_valid_extrinsic, resize_intrinsics, load_episode_data, load_episode_images, sort_episode_length
from config import path_to_droid_repo, path_to_droid_calib_v2, local_dataset_path, replay_img_save_path, IMAGE2SCENE_CAM_MAP, DEFAULT_CAMERA_INTRINSICS, DEFAULT_CAMERA_INTRINSICS_WRIST, CAM2IMAGE_MAP, format_intrinsics


# Extrinsic-validation thresholds. None of them is used by the active replay
# path in this script (the is_valid_extrinsic filter is disabled).
IoU_threshold = 0.75
reprojection_error_threshold = 5
default_keypoint_extrinsics_px_error = 10
predicted_keypoint_extrinsics_px_error = 8

# When True, the real camera videos are neither loaded nor passed to the replay.
skip_real_obs = True
# Number of simulation environments, i.e. episodes replayed in parallel per batch.
N = 12  # TODO: change to 24

# Episode IDs (one per line) that failed on a previous run; only these are retried.
# The replay loop rewrites this same file with the episodes that still fail.
# Stored as a set: it is probed once per episode in the replay loop, so an O(1)
# membership test matters when iterating the full dataset.
with open('failed_episodes.txt', 'r') as f:
    failed_at_first_try = {line.strip() for line in f}

# Load the per-episode camera-to-base extrinsics. Each entry holds a
# 'relative_path' plus one entry per camera serial.
cam2base_extrinsics_path = f"{path_to_droid_calib_v2}/pnp_cam2base_multiview.json"
with open(cam2base_extrinsics_path, "r") as f:
    cam2base_extrinsics = json.load(f)

# Keep only non-"failure" trajectories that have at least three entries
# (presumably 'relative_path' plus at least two calibrated cameras).
episodes_with_good_extrinsics = {
    ep_id: ep_meta
    for ep_id, ep_meta in cam2base_extrinsics.items()
    if "failure" not in ep_meta['relative_path'] and not len(ep_meta) < 3
}
print("Number of episodes with good extrinsics: ", len(episodes_with_good_extrinsics))

# Order the surviving episodes using the episode-length metadata file.
episodes_with_good_extrinsics = sort_episode_length(f"{path_to_droid_repo}/metadata_all_new.jsonl", episodes_with_good_extrinsics)

# Build a reverse lookup from an episode's relative "r2d2" path to its episode ID.
episode_id_to_path_path = f"{path_to_droid_calib_v2}/full_episode_id_to_path.json"
with open(episode_id_to_path_path, "r") as f:
    episode_id_to_path = json.load(f)
episode_path_to_id = {}
for ep_id, ep_paths in episode_id_to_path.items():
    episode_path_to_id[ep_paths["r2d2"]] = ep_id

# Load per-episode camera intrinsics. intrinsics.json is keyed by the GCS recording path, e.g.
#   'gs://xembodiment_data/r2d2/r2d2-data-full/TRI/success/2023-11-07/Tue_Nov__7_16:18:39_2023/recordings/MP4--gs://.../trajectory.h5'
# with values like {'exterior_image_1_left': [], 'exterior_image_2_left': [], 'wrist_image_left': []}.
intrinsics_path = f"{path_to_droid_calib_v2}/intrinsics.json"
with open(intrinsics_path, "r") as f:
    intrinsics_raw = json.load(f)

# Re-key by episode ID: {episode_id: {camera name: formatted intrinsics}}.
intrinsics = {}
for episode_path, intrinsics_dict in intrinsics_raw.items():
    # Episode directory relative to the dataset root, e.g. "TRI/success/2023-11-07/Tue_Nov__7_16:18:39_2023".
    rel_path = episode_path.split("gs://xembodiment_data/r2d2/r2d2-data-full/")[1].split("/recordings/MP4")[0]
    try:
        episode_id = episode_path_to_id[rel_path]
    except KeyError:
        # Fall back to the spelling with ':' replaced by '_' (how some mapping
        # entries apparently store timestamps). A second miss propagates as KeyError.
        episode_id = episode_path_to_id[rel_path.replace(":", "_")]
    # format_intrinsics reformats each camera's entry; per the original note it
    # substitutes default intrinsics when none are available.
    intrinsics[episode_id] = {
        cam_name: format_intrinsics(cam_intrinsics, "wrist" in cam_name)
        for cam_name, cam_intrinsics in intrinsics_dict.items()
    }

# camera_serials.json maps episode_id -> {camera name: serial}. Despite its name,
# camera_name_to_serial holds the inverse per episode: {serial: camera name}.
camera_serials_path = f"{path_to_droid_repo}/camera_serials.json"
with open(camera_serials_path, "r") as f:
    camera_serials = json.load(f)
camera_name_to_serial = {
    ep_id: {serial: cam_name for cam_name, serial in name_to_serial.items()}
    for ep_id, name_to_serial in camera_serials.items()
}

# Re-key each episode's extrinsics by image name (via CAM2IMAGE_MAP) instead of
# camera serial. Wrist cameras get no extrinsics entry, but every camera (wrist
# included) records its serial under "<image name>_to_serial".
reformatted_episodes_with_good_extrinsics = {}
for episode_id, episode_metadata in episodes_with_good_extrinsics.items():
    serial_to_cam = camera_name_to_serial[episode_id]
    extrinsics_by_image = {
        CAM2IMAGE_MAP[serial_to_cam[serial]]: episode_metadata[serial]
        for serial in serial_to_cam.keys()
        if 'wrist' not in serial_to_cam[serial]
    }
    serial_by_image = {
        f"{CAM2IMAGE_MAP[serial_to_cam[serial]]}_to_serial": serial
        for serial in serial_to_cam.keys()
    }
    reformatted_episodes_with_good_extrinsics[episode_id] = {
        'relative_path': episode_metadata['relative_path'],
        **extrinsics_by_image,
        **serial_by_image,
    }

# breakpoint()
# with open("reformatted_episodes_with_good_extrinsics.json", "w") as f:
#     json.dump(reformatted_episodes_with_good_extrinsics, f, indent=4)

# create the replay context
ctx = create_replay_context(headless=True, scene=1, num_envs=N)
try:
    failed_episodes = []
    
    # Simple batch manager for grouping inputs for parallel replay
    def _flush_batch(batch_store):
        if len(batch_store["episode_ids"]) == 0:
            return
        # Pad joint positions to the maximum length in the batch
        seq_lengths = [arr.shape[0] for arr in batch_store["joint_positions"]]
        max_len = max(seq_lengths)
        feat_dim = batch_store["joint_positions"][0].shape[1]
        B = len(batch_store["joint_positions"]) 
        padded = np.zeros((B, max_len, feat_dim), dtype=batch_store["joint_positions"][0].dtype)
        mask = np.zeros((B, max_len), dtype=np.bool_)  # True for valid timesteps
        for b_idx, seq in enumerate(batch_store["joint_positions"]):
            L = seq.shape[0]
            padded[b_idx, :L] = seq
            mask[b_idx, :L] = True
        # batchify the camera params
        batched_camera_params = {cam: {
            # "camera_serial": [],
            "camera_pose": [],
            "intrinsics": [],
            "extrinsic_quality": [],
        } for cam in IMAGE2SCENE_CAM_MAP.values()}
        for b_idx, cam_params in enumerate(batch_store["camera_params"]):
            for cam, v in cam_params.items():
                # batched_camera_params[cam]["camera_serial"].append(v["camera_serial"])
                batched_camera_params[cam]["camera_pose"].append(v["camera_pose"])
                batched_camera_params[cam]["intrinsics"].append(v["intrinsics"])
                batched_camera_params[cam]["extrinsic_quality"].append(v["extrinsic_quality"])
        for k, v in batched_camera_params.items():
            # the camera_pose for each view should either be all None or all not None
            if v["camera_pose"][0] is not None:
                v["camera_pose"] = np.stack(v["camera_pose"])
            else:
                v["camera_pose"] = None
            v["intrinsics"] = np.stack(v["intrinsics"])
        # # batchify the real_obs NOTE: seems more convenient to leave as a list of dicts
        # batched_real_obs = {cam: [] for cam in IMAGE2SCENE_CAM_MAP.values()}
        # for b_idx, real_obs in enumerate(batch_store["real_obs"]):
        #     for cam in IMAGE2SCENE_CAM_MAP.values():
        #         batched_real_obs[cam].append(real_obs[cam]) # each with shape (T, H, W, 3)

        # Call replay with batched inputs
        success_list = replay_sequence_multiple_cams_v2(
            ctx=ctx,
            camera_params=batched_camera_params,
            joint_positions=padded,
            padding_mask=mask,
            output_path=batch_store["video_paths"],
            img_paths=batch_store["img_paths"],
            real_obs=batch_store["real_obs"] if not skip_real_obs else None,
        )
        # Handle success/failure per episode if available
        for ep_id, ok in zip(batch_store["episode_ids"], success_list):
            if not bool(ok):
                failed_episodes.append(ep_id)
        # Reset batch
        batch_store["camera_params"].clear()
        batch_store["joint_positions"].clear()
        batch_store["real_obs"].clear()
        batch_store["video_paths"].clear()
        batch_store["img_paths"].clear()
        batch_store["episode_ids"].clear()

    batch_store = {
        "camera_params": [],
        "joint_positions": [],
        "real_obs": [],
        "video_paths": [],
        "img_paths": [],
        "episode_ids": [],
    }
    for episode_id, episode_metadata in tqdm(reformatted_episodes_with_good_extrinsics.items(), position=0, leave=True):
        try:
            if episode_id not in failed_at_first_try:
                continue
            if not os.path.exists(os.path.join(local_dataset_path, episode_metadata["relative_path"], "trajectory.h5")):
                print(f"Skipping episode {episode_id} because it does not exist")
                failed_episodes.append(episode_id)
                continue
            # tmp_img_path = os.path.join(replay_img_save_path, episode_metadata["relative_path"], "rgb_sim_rollout_masked")
            # if os.path.exists(tmp_img_path) and len(os.listdir(tmp_img_path)) > 0:
            #     print(f"Skipping episode {episode_id} because it already exists")
            #     continue
            camera_params = {}
            for k, v in intrinsics[episode_id].items():
                extracted_intrinsics = v

                extracted_extrinsics = episode_metadata.get(k, None)
                extrinsic_quality = {
                    "metric": None,
                    "quality": None,
                    "source": None,
                }

                # get camera pose
                if extracted_extrinsics is None:
                    camera_pose_np = None
                else:
                    pos = extracted_extrinsics[0:3]
                    rot = R.from_euler("xyz", extracted_extrinsics[3:6])
                    rot_mat = rot.as_matrix()
                    camera_pose_np = np.eye(4)
                    camera_pose_np[:3, :3] = rot_mat
                    camera_pose_np[:3, 3] = pos

                # convert the intrinsics to a matrix
                fx, cx, fy, cy = extracted_intrinsics["cameraMatrix"]
                intrinsics_np = np.array([[fx, 0, cx],
                                            [0, fy, cy],
                                            [0, 0, 1]])

                if extracted_intrinsics["width"] == 672 and extracted_intrinsics["height"] == 376:
                    intrinsics_np = resize_intrinsics(intrinsics_np, src_size=(672, 376), dst_size=(1280, 720))

                camera_params[IMAGE2SCENE_CAM_MAP[k]] = {}
                camera_params[IMAGE2SCENE_CAM_MAP[k]]["camera_pose"] = camera_pose_np
                camera_params[IMAGE2SCENE_CAM_MAP[k]]["intrinsics"] = intrinsics_np
                camera_params[IMAGE2SCENE_CAM_MAP[k]]["extrinsic_quality"] = extrinsic_quality
                camera_params[IMAGE2SCENE_CAM_MAP[k]]["camera_serial"] = episode_metadata[f"{k}_to_serial"]

            # if not is_valid_extrinsic(camera_params, IoU_threshold=IoU_threshold, reprojection_error_threshold=reprojection_error_threshold):
            #     continue

            joint_positions, gripper_actions, cartesian_poses = load_episode_data(os.path.join(local_dataset_path, episode_metadata["relative_path"], "trajectory.h5"))
            if not skip_real_obs:
                images = load_episode_images(os.path.join(local_dataset_path, episode_metadata["relative_path"], "recordings", "MP4"), camera_params)
            else:
                images = None

            # episode length
            episode_length = min(len(joint_positions), len(gripper_actions), len(cartesian_poses))
            if episode_length > 512 or episode_length < 45:
                print(f"Skipping episode {episode_id} because it is too long or too short ({episode_length})")
                continue
            if not skip_real_obs:
                episode_length = min(episode_length, len(images[list(images.keys())[0]]))

            tmp_img_path = os.path.join(replay_img_save_path, episode_metadata["relative_path"], "rgb_sim_rollout_masked", IMAGE2SCENE_CAM_MAP[k])
            if os.path.exists(tmp_img_path) and len(os.listdir(tmp_img_path)) >= (episode_length-1):
                print(f"Skipping episode {episode_id} because it already exists")
                continue
    
            joint_positions = joint_positions[:episode_length]
            gripper_actions = gripper_actions[:episode_length]
            cartesian_poses = cartesian_poses[:episode_length]
            if not skip_real_obs:
                for cam in images.keys():
                    images[cam] = images[cam][:episode_length]

            joint_positions = np.concatenate([joint_positions, gripper_actions], axis=-1)

            video_path = os.path.join("runs/debug_filtered_1028", f"replay_{episode_id}.mp4")

            # Accumulate into batch
            batch_store["camera_params"].append(deepcopy(camera_params))
            batch_store["joint_positions"].append(deepcopy(joint_positions))
            batch_store["real_obs"].append(deepcopy(images))
            batch_store["video_paths"].append(video_path)
            batch_store["img_paths"].append(os.path.join(replay_img_save_path, episode_metadata["relative_path"]))
            batch_store["episode_ids"].append(episode_id)

            print(f"saving to {batch_store['img_paths'][-1]}")

            # If batch full, flush
            if len(batch_store["episode_ids"]) >= N:
                _flush_batch(batch_store)
                continue
        except Exception as e:
            print(f"Skipping episode {episode_id} because of error: {e}")
            failed_episodes.append(episode_id)
            continue
    # Flush any remaining episodes in the batch
    _flush_batch(batch_store)
    with open("failed_episodes.txt", "w") as f:
        for episode_id in failed_episodes:
            f.write(f"{episode_id}\n")
except Exception as e:
    print(e)
finally:
    close_context(ctx)