import json
import os
import traceback
from copy import deepcopy

import numpy as np
from scipy.spatial.transform import Rotation as R

from droid_meta.config import format_intrinsics, CAM2IMAGE_MAP, IMAGE2SCENE_CAM_MAP, SCENE_CAM2IMAGE_MAP
from replay.utils import resize_intrinsics, load_episode_data, load_episode_images
from replay.replay_functions import create_replay_context, replay_sequence_multiple_cams_v2, close_context
from replay.replay import batched_replay

# --- Configuration ------------------------------------------------------------
# Root of the local DROID raw dataset; per-episode relative paths are joined onto it.
local_dataset_path = "/home/junjieye/datasets/droid_raw/1.0.1"
# When True, recorded camera videos are not loaded and each episode's real_obs is None.
skip_real_obs = False
# How recorded trajectories become actions: "joint_position" or "end_effector".
control_mode = "end_effector"
# Number of parallel envs, and number of episodes accumulated before each batched replay.
N_Parallel = 2
# Output root for replayed images; each episode is written under its relative path.
replay_img_save_path = "runs/droid_replay"

VALID_CONTROL_MODES = ("joint_position", "end_effector")

# Fail fast on bad settings. Without this, an invalid control_mode would only be
# caught per episode inside the replay loop's error handler (so nothing is
# replayed but the script "succeeds"), and the env choice would silently fall
# back to the end-effector env.
if control_mode not in VALID_CONTROL_MODES:
    raise ValueError(f"Invalid control mode: {control_mode!r}; expected one of {VALID_CONTROL_MODES}")
if N_Parallel < 1:
    raise ValueError(f"N_Parallel must be >= 1, got {N_Parallel}")

# Per-episode camera-to-base extrinsics keyed by episode id. Each entry holds a
# 'relative_path' plus extrinsics keyed by camera serial.
cam2base_extrinsics_path = "droid_meta/pnp_cam2base_multiview.json"
with open(cam2base_extrinsics_path, "r") as f:
    cam2base_extrinsics = json.load(f)

# Keep episodes whose path is not marked as a failure and that carry camera
# parameters (an entry with a single key has only 'relative_path').
successful_episodes = {
    episode_id: episode_metadata
    for episode_id, episode_metadata in cam2base_extrinsics.items()
    if "failure" not in episode_metadata['relative_path'] and len(episode_metadata) != 1
}
print("Number of successful episodes:", len(successful_episodes))

# Lookup from episode id to episode path, plus the inverse (path -> id).
episode_id_to_path_path = "droid_meta/droid/episode_id_to_path.json"
with open(episode_id_to_path_path, "r") as f:
    episode_id_to_path = json.load(f)
episode_path_to_id = dict(zip(episode_id_to_path.values(), episode_id_to_path.keys()))

# Per-episode camera intrinsics: {episode_id: {camera key: formatted intrinsics}}.
# Each raw entry is passed through format_intrinsics, which per the original
# note reformats it and sets default intrinsics when values are not available.
intrinsics_path = "droid_meta/droid/intrinsics.json"
with open(intrinsics_path, "r") as f:
    intrinsics_raw = json.load(f)

intrinsics = {}
for episode_id, intrinsics_dict in intrinsics_raw.items():
    # NOTE(review): downstream lookups index these dicts by camera serial, so
    # cam_name is presumably a serial and the "wrist" test may never be true;
    # confirm against intrinsics.json / format_intrinsics.
    intrinsics[episode_id] = {
        cam_name: format_intrinsics(cam_intrinsics['cameraMatrix'], "wrist" in cam_name)
        for cam_name, cam_intrinsics in intrinsics_dict.items()
    }

# Load the per-episode {camera name: serial} table and invert it so camera
# names can be looked up by serial: {episode_id: {serial: camera name}}.
camera_serials_path = "droid_meta/droid/camera_serials.json"
with open(camera_serials_path, "r") as f:
    camera_serials = json.load(f)
camera_serial_to_name = {
    episode_id: {serial: name for name, serial in name_to_serial.items()}
    for episode_id, name_to_serial in camera_serials.items()
}

# Re-key each successful episode by image name (via CAM2IMAGE_MAP) instead of
# camera serial. Besides 'relative_path', every episode gets:
#   <image name>              -> extrinsics (non-wrist cameras only)
#   <image name>_intrinsics   -> formatted intrinsics
#   <image name>_to_serial    -> the originating camera serial
# Episodes whose lookups fail are reported and left out.
reformatted_episodes_with_good_extrinsics = {}
for episode_id, episode_metadata in successful_episodes.items():
    try:
        relative_path = episode_metadata['relative_path']
        serial_to_name = camera_serial_to_name[episode_id]
        extrinsics_by_image = {
            CAM2IMAGE_MAP[name]: episode_metadata[serial]
            for serial, name in serial_to_name.items()
            if 'wrist' not in name
        }
        intrinsics_by_image = {
            f"{CAM2IMAGE_MAP[name]}_intrinsics": intrinsics[episode_id][serial]
            for serial, name in serial_to_name.items()
        }
        serial_by_image = {
            f"{CAM2IMAGE_MAP[name]}_to_serial": serial
            for serial, name in serial_to_name.items()
        }
    except Exception as e:
        print(f"Error processing episode {episode_id}: {e}")
        continue
    reformatted_episodes_with_good_extrinsics[episode_id] = {
        'relative_path': relative_path,
        **extrinsics_by_image,
        **intrinsics_by_image,
        **serial_by_image,
    }

# --- Replay -------------------------------------------------------------------
# Episodes whose length (in timesteps) falls outside these bounds are skipped.
MAX_EPISODE_LENGTH = 512
MIN_EPISODE_LENGTH = 45
# Directory for per-episode replay videos.
# NOTE(review): differs from replay_img_save_path; confirm this debug directory is intended.
replay_video_save_path = "runs/debug_filtered_1028"
# Keys of a batch, in the order batched_replay receives them.
_BATCH_KEYS = ("camera_params", "actions", "real_obs", "video_paths", "img_paths", "episode_ids")


def _new_batch():
    """Return an empty batch: one list per key in _BATCH_KEYS."""
    return {key: [] for key in _BATCH_KEYS}


def _camera_pose_from_extrinsics(cam_extrinsics):
    """Return a 4x4 pose from [x, y, z, rx, ry, rz] extrinsics, or None if absent.

    The rotation part is interpreted as extrinsic "xyz" Euler angles in radians
    (scipy's default when ``degrees`` is not given).
    """
    if cam_extrinsics is None:
        return None
    camera_pose_np = np.eye(4)
    camera_pose_np[:3, :3] = R.from_euler("xyz", cam_extrinsics[3:6]).as_matrix()
    camera_pose_np[:3, 3] = cam_extrinsics[0:3]
    return camera_pose_np


def _build_camera_params(episode_metadata):
    """Build {scene camera: {camera_pose, intrinsics, camera_serial}} for one episode.

    Raises KeyError if the episode lacks intrinsics or a serial for any camera in
    SCENE_CAM2IMAGE_MAP; a camera without extrinsics gets camera_pose None.
    """
    camera_params = {}
    for cam_name, img_name in SCENE_CAM2IMAGE_MAP.items():
        camera_pose_np = _camera_pose_from_extrinsics(episode_metadata.get(img_name, None))

        cam_intrinsics = episode_metadata[f"{img_name}_intrinsics"]
        # Assumes format_intrinsics stores cameraMatrix as (fx, cx, fy, cy) -- TODO confirm.
        fx, cx, fy, cy = cam_intrinsics["cameraMatrix"]
        intrinsics_np = np.array([[fx, 0, cx],
                                  [0, fy, cy],
                                  [0, 0, 1]])
        # Intrinsics given at 672x376 are rescaled to 1280x720.
        if cam_intrinsics["width"] == 672 and cam_intrinsics["height"] == 376:
            intrinsics_np = resize_intrinsics(intrinsics_np, src_size=(672, 376), dst_size=(1280, 720))

        camera_params[cam_name] = {
            "camera_pose": camera_pose_np,
            "intrinsics": intrinsics_np,
            "camera_serial": episode_metadata[f"{img_name}_to_serial"],
        }
    return camera_params


def _flush_batch(ctx, batch_store):
    """Replay the accumulated episodes, if any, then reset batch_store to empty.

    The batch is reset even when batched_replay raises, so a failing batch is
    reported once instead of being re-replayed (and grown) on every later
    episode. Fresh lists are rebound rather than cleared in place, so any
    references batched_replay kept to the old lists stay intact.
    """
    if not batch_store["episode_ids"]:
        return
    episode_ids = list(batch_store["episode_ids"])
    try:
        batched_replay(ctx, batch_store, control_mode, skip_real_obs)
    except Exception as e:
        print(f"Error replaying batch {episode_ids}: {type(e).__name__}: {e}")
    finally:
        batch_store.update(_new_batch())


ctx = create_replay_context(headless=True, scene=1, num_envs=N_Parallel, env_name="DROID" if control_mode == "joint_position" else "DROID_EE")
try:
    batch_store = _new_batch()
    for episode_id, episode_metadata in reformatted_episodes_with_good_extrinsics.items():
        try:
            episode_dir = os.path.join(local_dataset_path, episode_metadata["relative_path"])
            trajectory_path = os.path.join(episode_dir, "trajectory.h5")
            if not os.path.exists(trajectory_path):
                print(f"Skipping episode {episode_id} because it does not exist")
                continue

            camera_params = _build_camera_params(episode_metadata)

            joint_positions, gripper_actions, cartesian_poses = load_episode_data(trajectory_path)

            # Check the bounds before loading videos so rejected episodes stay cheap.
            episode_length = min(len(joint_positions), len(gripper_actions), len(cartesian_poses))
            if episode_length > MAX_EPISODE_LENGTH or episode_length < MIN_EPISODE_LENGTH:
                print(f"Skipping episode {episode_id} because it is too long or too short ({episode_length})")
                continue

            if skip_real_obs:
                images = None
            else:
                images = load_episode_images(os.path.join(episode_dir, "recordings", "MP4"), camera_params)
                if not images:
                    print(f"Skipping episode {episode_id} because no camera recordings were loaded")
                    continue
                # Truncate to the shortest camera stream so every camera and the
                # actions stay frame-aligned (not just the first camera).
                episode_length = min(episode_length, min(len(frames) for frames in images.values()))
                if episode_length < MIN_EPISODE_LENGTH:
                    print(f"Skipping episode {episode_id} because it is too long or too short ({episode_length})")
                    continue
                for cam in images:
                    images[cam] = images[cam][:episode_length]

            joint_positions = joint_positions[:episode_length]
            gripper_actions = gripper_actions[:episode_length]
            cartesian_poses = cartesian_poses[:episode_length]

            if control_mode == "joint_position":
                actions = np.concatenate([joint_positions, gripper_actions], axis=-1)
            elif control_mode == "end_effector":
                actions = np.concatenate([cartesian_poses, gripper_actions], axis=-1)
            else:
                raise ValueError(f"Invalid control mode: {control_mode}")

            # Build the whole entry first so a failure (e.g. in deepcopy) cannot
            # leave the batch lists with different lengths.
            entry = {
                "camera_params": deepcopy(camera_params),
                "actions": deepcopy(actions),
                "real_obs": deepcopy(images),
                "video_paths": os.path.join(replay_video_save_path, f"replay_{episode_id}.mp4"),
                "img_paths": os.path.join(replay_img_save_path, episode_metadata["relative_path"]),
                "episode_ids": episode_id,
            }
        except Exception as e:
            print(f"Error processing episode {episode_id}: {type(e).__name__}: {e}")
            continue

        for key, value in entry.items():
            batch_store[key].append(value)
        if len(batch_store["episode_ids"]) >= N_Parallel:
            _flush_batch(ctx, batch_store)

    # Replay the final partial batch (no-op when it is empty).
    _flush_batch(ctx, batch_store)

except Exception:
    # Keep the stack trace; the message alone is rarely enough to debug.
    traceback.print_exc()
finally:
    close_context(ctx)
