from pathlib import Path
import numpy as np
import json
from tqdm import tqdm
import mediapy
import cv2
from scipy.spatial.transform import Rotation as R
import os
from copy import deepcopy

import tensorflow as tf
# Turn on memory growth for every visible GPU before anything else touches
# TensorFlow. NOTE(review): presumably this keeps TF from reserving GPU memory
# that the replay simulator needs -- confirm.
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu_device in gpus:
    tf.config.experimental.set_memory_growth(gpu_device, True)
import tensorflow_datasets as tfds

from custom.replay_utils import create_replay_context, replay_sequence_multiple_cams, close_context, is_valid_extrinsic

# One row per camera: (name used in camera_serials.json,
#                      image key in the episode observations,
#                      camera name used for the replay's camera_params).
_CAMERA_TABLE = (
    ("ext1_cam_serial", "exterior_image_1_left", "external_cam1"),
    ("ext2_cam_serial", "exterior_image_2_left", "external_cam2"),
    ("wrist_cam_serial", "wrist_image_left", "wrist_cam"),
)

# Calibration camera name -> observation image key.
CAM2IMAGE_MAP = {calib_name: image_key for calib_name, image_key, _ in _CAMERA_TABLE}

# Observation image key -> scene camera name (iteration order follows the table).
IMAGE2SCENE_CAM_MAP = {image_key: scene_cam for _, image_key, scene_cam in _CAMERA_TABLE}

# Fallback 3x3 intrinsics, laid out as [[fx, 0, cx], [0, fy, cy], [0, 0, 1]].
# They are used when an episode's stored calibration contains a zero entry.
DEFAULT_CAMERA_INTRINSICS = np.array(
    [
        [525.31878662, 0.0, 648.12060547],
        [0.0, 525.31878662, 374.60479736],
        [0.0, 0.0, 1.0],
    ],
    dtype=np.float64,
)

# Same layout; selected for image keys containing "wrist".
DEFAULT_CAMERA_INTRINSICS_WRIST = np.array(
    [
        [732.0, 0.0, 640.0],
        [0.0, 732.0, 360.0],
        [0.0, 0.0, 1.0],
    ],
    dtype=np.float64,
)

def _load_json(path):
    """Parse and return the JSON document stored at ``path``."""
    # JSON is read as UTF-8 explicitly so the result does not depend on the
    # locale's default encoding.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


# Directory holding the calibration/metadata JSON files read below.
# Set DROID_REPO_PATH to override; the default is unchanged.
path_to_droid_repo = os.environ.get("DROID_REPO_PATH") or "./droid"

# TFDS data directory. Set DROID_DATASET_DIR to override the
# machine-specific default.
ds = tfds.load(
    "droid_100",
    data_dir=os.environ.get("DROID_DATASET_DIR")
    or "/home/junjieye/workspace/sim-evals/dataset",
    split="train",
)

# load the extrinsics (indexed by episode id, then by camera serial)
cam2base_extrinsics_path = f"{path_to_droid_repo}/cam2base_extrinsic_superset.json"
cam2base_extrinsics = _load_json(cam2base_extrinsics_path)

# load the intrinsics (indexed by episode id, then by camera serial)
intrinsics_path = f"{path_to_droid_repo}/intrinsics.json"
intrinsics = _load_json(intrinsics_path)

# load mapping from episode ID to path, then invert
episode_id_to_path_path = f"{path_to_droid_repo}/episode_id_to_path.json"
episode_id_to_path = _load_json(episode_id_to_path_path)
episode_path_to_id = {path: episode_id for episode_id, path in episode_id_to_path.items()}

# load camera serials (indexed by episode id; camera name -> serial)
camera_serials_path = f"{path_to_droid_repo}/camera_serials.json"
camera_serials = _load_json(camera_serials_path)


def _camera_pose_from_extrinsics(extrinsics):
    """Build a 4x4 camera pose from a stored extrinsic, or return None.

    ``extrinsics`` is indexed as six values: entries 0-2 become the
    translation column and entries 3-5 are passed to SciPy as "xyz" Euler
    angles (SciPy treats them as radians unless ``degrees=True`` is given).
    ``None`` (no extrinsic stored for the camera) is passed through.
    """
    if extrinsics is None:
        return None
    pose = np.eye(4)
    pose[:3, :3] = R.from_euler("xyz", extrinsics[3:6]).as_matrix()
    pose[:3, 3] = extrinsics[0:3]
    return pose


def _intrinsics_matrix(calibration, image_name):
    """Return the 3x3 intrinsics matrix for one camera.

    ``calibration["cameraMatrix"]`` is unpacked in the order fx, cx, fy, cy.
    If any of the four values is zero the stored calibration is treated as
    unusable and a default matrix is returned instead (the wrist default when
    ``image_name`` contains "wrist").
    """
    fx, cx, fy, cy = calibration["cameraMatrix"]
    if fx == 0 or fy == 0 or cx == 0 or cy == 0:
        default = (
            DEFAULT_CAMERA_INTRINSICS_WRIST
            if "wrist" in image_name
            else DEFAULT_CAMERA_INTRINSICS
        )
        # Copy so that downstream code can never mutate the shared default.
        return default.copy()
    return np.array([[fx, 0, cx],
                     [0, fy, cy],
                     [0, 0, 1]])


# Directory that receives one replay video per episode.
output_dir = "runs/debug_filtered"

# create the replay context
ctx = create_replay_context(headless=True, scene=1)
try:
    os.makedirs(output_dir, exist_ok=True)

    for ep in tqdm(ds):
        file_path = ep["episode_metadata"]["file_path"].numpy().decode("utf-8")

        # The episode path is the part between the dataset root marker and
        # "/trajectory"; skip paths that do not contain the marker.
        path_parts = file_path.split("r2d2-data-full/")
        if len(path_parts) < 2:
            continue
        episode_path = path_parts[1].split("/trajectory")[0]

        episode_id = episode_path_to_id.get(episode_path)
        if episode_id is None:
            continue

        # All three metadata files are indexed by episode id below. An id
        # missing from any of them would otherwise raise KeyError and abort
        # the whole run, so such episodes are skipped.
        if (
            episode_id not in cam2base_extrinsics
            or episode_id not in intrinsics
            or episode_id not in camera_serials
        ):
            continue

        episode_extrinsics = cam2base_extrinsics[episode_id]
        # camera_serials maps camera name -> serial; invert it once per episode.
        serial_to_name = {
            serial: name for name, serial in camera_serials[episode_id].items()
        }

        camera_params = {}
        for camera_serial, calibration in intrinsics[episode_id].items():
            calib_image_name = CAM2IMAGE_MAP[serial_to_name[camera_serial]]
            camera_params[IMAGE2SCENE_CAM_MAP[calib_image_name]] = {
                "camera_serial": camera_serial,
                "camera_pose": _camera_pose_from_extrinsics(
                    episode_extrinsics.get(camera_serial)
                ),
                "intrinsics": _intrinsics_matrix(calibration, calib_image_name),
                "extrinsic_quality": {
                    "metric": episode_extrinsics.get(f"{camera_serial}_metric_type"),
                    "quality": episode_extrinsics.get(f"{camera_serial}_quality_metric"),
                    "source": episode_extrinsics.get(f"{camera_serial}_source"),
                },
            }

        if not is_valid_extrinsic(camera_params, IoU_threshold=0.8, reprojection_error_threshold=5):
            continue

        # Collect the per-step robot state and the real camera frames.
        joint_positions = []
        gripper_actions = []
        images = {cam_scene: [] for cam_scene in IMAGE2SCENE_CAM_MAP.values()}
        for step in ep["steps"]:
            observation = step["observation"]
            joint_positions.append(np.concatenate([
                observation["joint_position"].numpy(),
                observation["gripper_position"].numpy(),
            ]))
            gripper_actions.append(step["action_dict"]["gripper_position"].numpy())
            for cam_real, cam_scene in IMAGE2SCENE_CAM_MAP.items():
                images[cam_scene].append(observation[cam_real].numpy())

        # Nothing to replay for an episode without steps.
        if not joint_positions:
            continue

        images = {cam_scene: np.array(frames) for cam_scene, frames in images.items()}
        # NOTE(review): the commanded gripper position is appended after the
        # observed one, so each row ends with two gripper values -- presumably
        # the layout replay_sequence_multiple_cams expects; confirm.
        joint_positions = np.concatenate(
            [np.array(joint_positions), np.array(gripper_actions)], axis=-1
        )

        replay_sequence_multiple_cams(
            ctx=ctx,
            camera_params=camera_params,
            joint_positions=joint_positions,
            output_path=os.path.join(output_dir, f"replay_{episode_id}.mp4"),
            real_obs=images,
        )

finally:
    # Always release the replay context, even if an episode fails.
    close_context(ctx)