
from __future__ import annotations

from dataclasses import MISSING

import numpy as np
import isaaclab.sim as sim_utils
# from uwlab_tasks.manager_based.manipulation.reset_states.mdp import utils
from isaaclab.assets import AssetBaseCfg, RigidObjectCfg
from isaaclab.envs import ManagerBasedRLEnvCfg, ViewerCfg, ManagerBasedEnv, ManagerBasedRLEnv
from isaaclab.managers import EventTermCfg as EventTerm
from isaaclab.managers import ObservationGroupCfg as ObsGroup
from isaaclab.managers import ObservationTermCfg as ObsTerm
from isaaclab.managers import RewardTermCfg as RewTerm
from isaaclab.managers import SceneEntityCfg
from isaaclab.managers import TerminationTermCfg as DoneTerm
from isaaclab.scene import InteractiveSceneCfg
from isaaclab.sensors import TiledCameraCfg
from isaaclab.utils import configclass, noise
from isaaclab.utils.assets import ISAAC_NUCLEUS_DIR

# from uwlab_assets import UWLAB_CLOUD_ASSETS_DIR, ARHANJAIN_CLOUD_ASSETS_DIR
# from uwlab_assets.robots.ur5e_robotiq_gripper import (
#     EXPLICIT_UR5E_ROBOTIQ_2F85,
#     IMPLICIT_UR5E_ROBOTIQ_2F85,
#     Ur5eRobotiq2f85RelativeJointPositionAction,
# )

# from uwlab_tasks.manager_based.manipulation.reset_states.config.droid.actions import (
#     DROIDRelativeOSCAction,
# )
# from uwlab_assets.robots.DROID import DROIDJointPositionAction, IMPLICIT_DROID, EXPLICIT_DROID, DROIDRelativeJointPositionAction, DROIDIkDeltaAction

# from ... import mdp as task_mdp
import sim_improvement.environments.mdp as mdp
from sim_improvement.environments.droid.robot import IMPLICIT_DROID, DROIDJointPositionAction, DROIDIkRelativeAction

# Base URL of the UW Lab asset bucket; the cube and table USD paths below are built from it.
UWLAB_CLOUD_ASSETS_DIR: str = "https://uwlab-assets.s3.us-west-004.backblazeb2.com"

@configclass
class SceneCfg(InteractiveSceneCfg):
    """Scene configuration for RL state environment.

    Declares the DROID robot, a pre-built background scene ("splat"), a
    dynamic cube, a kinematic plate, an external and a wrist-mounted RGB
    camera (both 1280x720), an invisible kinematic table, and a dome sky light.

    NOTE(review): the wrist camera's prim path is nested under the robot's
    prim, so if entries are ever reordered confirm the robot is still spawned
    before the camera.
    """

    # DROID robot articulation (imported IMPLICIT_DROID config) placed under each environment namespace.
    robot = IMPLICIT_DROID.replace(prim_path="{ENV_REGEX_NS}/Robot")

    # Static background scene loaded from a local USD file.
    # NOTE(review): relative usd_path -- presumably resolved against the launch
    # working directory; confirm where the process is started from.
    # NOTE(review): rot=(0, 0, 0, 1) is a 180-degree rotation about z if read in
    # Isaac Lab's (w, x, y, z) order (the other quaternions here, e.g. the cube's
    # (1, 0, 0, 0), look like that order) -- confirm this is intended.
    splat = AssetBaseCfg(
        prim_path="{ENV_REGEX_NS}/splat",
        spawn=sim_utils.UsdFileCfg(
            usd_path="./envs/assets/tri_droid_scene/combined.usd",
        ),
        init_state=AssetBaseCfg.InitialStateCfg(pos=(0.0, -0.05, 0.0), rot=(0.0, 0.0, 0.0, 1.0)),
    )

    # Dynamic cube (gravity on, not kinematic) fetched from the UW Lab asset bucket.
    cube = RigidObjectCfg(
        prim_path="{ENV_REGEX_NS}/cube",
        spawn=sim_utils.UsdFileCfg(
            usd_path=f"{UWLAB_CLOUD_ASSETS_DIR}/Props/Custom/InsertiveCube/insertive_cube.usd",
            scale=(1, 1, 1),
            rigid_props=sim_utils.RigidBodyPropertiesCfg(
                solver_position_iteration_count=4,
                solver_velocity_iteration_count=0,
                disable_gravity=False,
                kinematic_enabled=False,
            ),
            # NOTE(review): mass=0.001 is 1 g if the unit is kg -- very light; confirm intended.
            mass_props=sim_utils.MassPropertiesCfg(mass=0.001),
        ),
        init_state=RigidObjectCfg.InitialStateCfg(pos=(0.3, -0.15, 0.1), rot=(1.0, 0.0, 0.0, 0.0)),
    )

    # Kinematic plate loaded from a local USD file (relative path; see the splat note above).
    plate = RigidObjectCfg(
        prim_path="{ENV_REGEX_NS}/plate",
        spawn=sim_utils.UsdFileCfg(
            usd_path="./envs/assets/plate/plate.usd",
            scale=(1, 1, 1),
            rigid_props=sim_utils.RigidBodyPropertiesCfg(
                solver_position_iteration_count=4,
                solver_velocity_iteration_count=0,
                disable_gravity=False,
                kinematic_enabled=True,
            ),
            # since kinematic_enabled=True, mass does not matter
            mass_props=sim_utils.MassPropertiesCfg(mass=1.0),
        ),
        init_state=RigidObjectCfg.InitialStateCfg(pos=(0.3, 0.15, 0.03), rot=(1.0, 0.0, 0.0, 0.0)),
    )

    # Fixed third-person RGB camera; the offset uses the OpenGL camera convention.
    # NOTE(review): horizontal_aperture / vertical_aperture = 1.0 / 1.4721 ~= 0.68,
    # but width / height = 1280 / 720 ~= 1.78 (the wrist camera below matches its
    # resolution: 5.376 / 3.024 ~= 1.78). This implies non-square pixels or an
    # aperture that the renderer overrides -- confirm against the real camera's
    # calibration.
    external_camera = TiledCameraCfg(
        prim_path="{ENV_REGEX_NS}/external_camera",
        offset=TiledCameraCfg.OffsetCfg(
            pos = (-0.05277, -0.55604, 0.47632),
            rot = (0.80826, 0.45584, -0.18309, -0.32465),
            convention="opengl"
        ),
        data_types=["rgb"],
        spawn=sim_utils.PinholeCameraCfg(
            focal_length = 0.8,
            horizontal_aperture = 1.0,
            vertical_aperture=1.4721,
        ),
        height=720,
        width=1280,
    )

    # RGB camera whose prim is nested under the gripper's robotiq_base_link prim, so it
    # presumably moves with the gripper. Its aperture ratio matches the 1280x720 resolution.
    wrist_camera = TiledCameraCfg(
        prim_path="{ENV_REGEX_NS}/Robot/robotiq_2f85_gripper/robotiq_base_link/wrist_camera",
        height=720,
        width=1280,
        data_types=["rgb"],
        spawn=sim_utils.PinholeCameraCfg(
            focal_length=2.8,
            focus_distance=28.0,
            horizontal_aperture=5.376,
            vertical_aperture=3.024,
            clipping_range=(0.001, 20.0),
        ),
        offset=TiledCameraCfg.OffsetCfg(
            pos = (0.011, 0.031, 0.074),
            rot = (0.57291, 0.41446, -0.41446, -0.57291),
            convention="opengl",
        ),
    )

    # Environment
    # Invisible (visible=False), kinematic table/mount prop.
    # NOTE(review): presumably kept only for collision support underneath the visual splat -- confirm.
    table = RigidObjectCfg(
        prim_path="{ENV_REGEX_NS}/Table",
        init_state=RigidObjectCfg.InitialStateCfg(pos=(0.4, 0.0, -0.881), rot=(0.707, 0.0, 0.0, -0.707)),
        spawn=sim_utils.UsdFileCfg(
            usd_path=f"{UWLAB_CLOUD_ASSETS_DIR}/Props/Mounts/UWPatVention/pat_vention.usd",
            rigid_props=sim_utils.RigidBodyPropertiesCfg(kinematic_enabled=True),
            visible=False,
        ),
    )

    # Global dome light (a single /World prim, not per-environment) textured with a Nucleus sky HDR.
    sky_light = AssetBaseCfg(
        prim_path="/World/skyLight",
        spawn=sim_utils.DomeLightCfg(
            intensity=10000.0,
            texture_file=f"{ISAAC_NUCLEUS_DIR}/Materials/Textures/Skies/PolyHaven/kloofendal_43d_clear_puresky_4k.hdr",
            # texture_file="./envs/assets/monochrome.hdr",
        ),
    )


@configclass
class BaseEventCfg:
    """Reset-mode event terms.

    ``reset_everything`` is declared first and restores the scene via
    ``mdp.reset_scene_to_default``. The cube and the plate are then each passed
    to ``mdp.reset_root_state_uniform`` with x/y ranges of (-0.1, 0.1), a zero z
    range and an empty velocity range. Terms are presumably applied in
    declaration order, so keep this ordering.
    """

    # mode: reset
    reset_everything = EventTerm(mode="reset", func=mdp.reset_scene_to_default, params=dict())

    # Cube: sample x/y in [-0.1, 0.1], keep z, no velocity sampling.
    reset_cube = EventTerm(
        mode="reset",
        func=mdp.reset_root_state_uniform,
        params=dict(
            asset_cfg=SceneEntityCfg("cube"),
            pose_range=dict(x=(-0.1, 0.1), y=(-0.1, 0.1), z=(0.0, 0.0)),
            velocity_range=dict(),
        ),
    )

    # Plate: same sampling ranges as the cube.
    reset_plate = EventTerm(
        mode="reset",
        func=mdp.reset_root_state_uniform,
        params=dict(
            asset_cfg=SceneEntityCfg("plate"),
            pose_range=dict(x=(-0.1, 0.1), y=(-0.1, 0.1), z=(0.0, 0.0)),
            velocity_range=dict(),
        ),
    )

def gripper_pos(
    env: ManagerBasedEnv,
    asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
    joint_names: tuple[str, ...] = ("finger_joint",),
    max_joint_pos: float = np.pi / 4,
):
    """Return the gripper joint position(s) rescaled by ``max_joint_pos``.

    Looks up the articulation named ``asset_cfg.name`` in ``env.scene``, selects
    the columns of ``robot.data.joint_pos`` whose joint name is in
    ``joint_names``, and divides them by ``max_joint_pos``. A joint sitting at
    ``max_joint_pos`` therefore maps to 1.0.

    Args:
        env: Environment whose scene holds the articulation.
        asset_cfg: Scene entity selecting the articulation; only ``name`` is used.
        joint_names: Names of the joints to read. A single string is accepted
            as one name. Selected columns follow the articulation's own joint
            order, not the order of this argument.
        max_joint_pos: Joint position that maps to 1.0 (default pi/4).

    Returns:
        The selected, rescaled slice of ``robot.data.joint_pos``. Presumably a
        ``(num_envs, num_selected_joints)`` tensor -- TODO confirm.

    Raises:
        ValueError: If none of ``joint_names`` exists on the articulation. The
            previous code silently returned an empty observation in that case.
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(joint_names, str):
        joint_names = (joint_names,)
    robot = env.scene[asset_cfg.name]
    wanted = set(joint_names)
    joint_indices = [i for i, name in enumerate(robot.data.joint_names) if name in wanted]
    if not joint_indices:
        raise ValueError(
            f"None of the joints {list(joint_names)} exist on asset '{asset_cfg.name}'. "
            f"Available joints: {list(robot.data.joint_names)}"
        )
    joint_pos = robot.data.joint_pos[:, joint_indices]
    # rescale
    return joint_pos / max_joint_pos


@configclass
class ObservationsCfg:
    """Observation groups produced by the environment (a single ``vision`` group)."""

    @configclass
    class VisionCfg(ObsGroup):
        """Camera images together with arm and gripper proprioception.

        Terms are kept as separate entries rather than concatenated, and
        corruption is switched off. The Gaussian noise configured on
        ``gripper_pos`` is therefore presumably inert for this group -- confirm.
        """

        # Raw RGB frames (normalize=False) from the two scene cameras.
        wrist_camera = ObsTerm(
            func=mdp.image,
            params=dict(sensor_cfg=SceneEntityCfg("wrist_camera"), data_type="rgb", normalize=False),
        )
        external_camera = ObsTerm(
            func=mdp.image,
            params=dict(sensor_cfg=SceneEntityCfg("external_camera"), data_type="rgb", normalize=False),
        )

        # Positions of the seven Panda arm joints: panda_joint1 .. panda_joint7.
        arm_joint_pos = ObsTerm(
            func=mdp.selected_joint_pos,
            params=dict(
                asset_cfg=SceneEntityCfg("robot"),
                joint_names=[f"panda_joint{idx}" for idx in range(1, 8)],
            ),
        )

        # Rescaled gripper opening from the module-level gripper_pos() helper, clipped to (0, 1).
        gripper_pos = ObsTerm(
            func=gripper_pos,
            clip=(0, 1),
            noise=noise.GaussianNoiseCfg(std=0.05),
        )

        def __post_init__(self):
            self.enable_corruption = False
            self.concatenate_terms = False

    # observation groups
    vision: VisionCfg = VisionCfg()


@configclass
class RewardsCfg:
    """Reward terms.

    ``progress_context`` is a class-based term (``mdp.ProgressContext``) over the
    cube (insertive) and plate (receptive) assets. The ``success`` termination
    refers to it by this attribute name through its ``context`` parameter, so
    do not rename it.
    """

    progress_context = RewTerm(
        func=mdp.ProgressContext,
        weight=0.1,
        params=dict(
            insertive_asset_cfg=SceneEntityCfg("cube"),
            receptive_asset_cfg=SceneEntityCfg("plate"),
            obb_check_axes=(0, 1),
            open_finger_threshold=0.1,
            robot_cfg=SceneEntityCfg("robot"),
        ),
    )

@configclass
class TerminationsCfg:
    """Termination terms for the MDP."""

    # Episode time limit (flagged with time_out=True).
    time_out = DoneTerm(time_out=True, func=mdp.time_out)

    # Terminate when mdp.abnormal_robot_state reports an abnormal robot.
    abnormal_robot = DoneTerm(func=mdp.abnormal_robot_state)

    # Success check; "progress_context" is presumably looked up as the reward term of that name.
    success = DoneTerm(
        func=mdp.success_termination,
        params=dict(context="progress_context"),
    )



@configclass
class EnvCfg(ManagerBasedRLEnvCfg):
    """Base manager-based RL configuration shared by the DROID rollout configs.

    Bundles the scene, observation, reward, termination and reset-event
    configurations of this module. ``actions`` is not set here; the rollout
    subclasses supply it. ``__post_init__`` fixes the control/physics rates,
    the PhysX solver and GPU-buffer sizes, and the render settings.
    """

    scene = SceneCfg(num_envs=32, env_spacing=1.5)
    observations = ObservationsCfg()
    # actions = DROIDJointPositionAction()  (supplied by the subclasses instead)
    rewards = RewardsCfg()
    terminations = TerminationsCfg()
    events = BaseEventCfg()
    # commands = CommandsCfg()
    viewer: ViewerCfg = ViewerCfg(eye=(2.0, 0.0, 0.75), origin_type="world", env_index=0, asset_name="robot")

    def __post_init__(self):
        # Timing: 8 physics sub-steps per environment step. With dt = 1 / (15 * 8),
        # this gives a 15 Hz environment step over 120 Hz physics.
        self.decimation = 8
        self.episode_length_s = 15.0
        # simulation settings
        self.sim.dt = 1 / (15.0 * self.decimation)
        # Render once per environment step instead of on every physics sub-step.
        # With Isaac Lab's default render_interval of 1, each of the `decimation`
        # sub-steps re-renders both 1280x720 cameras, although observations are
        # only read once per environment step. Isaac Lab warns when
        # render_interval < decimation. Lower this only if a smoother GUI
        # viewport is needed.
        self.sim.render_interval = self.decimation

        # Contact and solver settings
        self.sim.physx.solver_type = 1
        self.sim.physx.max_position_iteration_count = 192
        self.sim.physx.max_velocity_iteration_count = 1
        self.sim.physx.bounce_threshold_velocity = 0.02
        self.sim.physx.friction_offset_threshold = 0.01
        self.sim.physx.friction_correlation_distance = 0.0005

        # GPU pipeline buffer capacities.
        # NOTE(review): gpu_collision_stack_size = 2**31 is 2 GiB if measured in bytes -- confirm it is needed.
        self.sim.physx.gpu_found_lost_aggregate_pairs_capacity = 1024 * 1024 * 4
        self.sim.physx.gpu_total_aggregate_pairs_capacity = 2**23
        self.sim.physx.gpu_max_rigid_contact_count = 2**23
        self.sim.physx.gpu_max_rigid_patch_count = 2**23
        self.sim.physx.gpu_collision_stack_size = 2**31

        # Render settings
        self.sim.render.enable_dlssg = True
        self.sim.render.enable_ambient_occlusion = True
        self.sim.render.enable_reflections = True
        self.sim.render.enable_dl_denoiser = True

        # SCALING HACK TO GET AROUND CAMERA CLIPPING DISTANCE ISSUE
        # self.scene.splat.spawn.scale = (10, 10, 10)
        # self.scene.robot.spawn.scale = (10.0, 10.0, 10.0)
        # self.scene.external_camera.offset.pos = [entry * 10.0 for entry in self.scene.external_camera.offset.pos]
        # self.scene.table.spawn.scale = (10.0, 10.0, 10.0)
        # self.scene.table.init_state.pos = [entry * 10.0 for entry in self.scene.table.init_state.pos]
        # self.scene.cube.spawn.scale = (10.0, 10.0, 10.0)
        # self.scene.cube.init_state.pos = [entry * 10.0 for entry in self.scene.cube.init_state.pos]
        # self.scene.plate.spawn.scale = (10.0, 10.0, 10.0)
        # self.scene.plate.init_state.pos = [entry * 10.0 for entry in self.scene.plate.init_state.pos]


@configclass
class DROIDJointPosRolloutCfg(EnvCfg):
    """Rollout configuration using the ``DROIDJointPositionAction`` action term.

    Everything else, including both cameras and the vision observation group,
    is inherited from ``EnvCfg`` unchanged. No ``__post_init__`` override is
    needed because the parent's is used as-is.
    """

    # events: TrainEventCfg = TrainEventCfg()
    actions = DROIDJointPositionAction()


@configclass
class DROIDRelativeIKRolloutCfg(EnvCfg):
    """Camera-free rollout configuration using the ``DROIDIkRelativeAction`` action term.

    Relative to ``EnvCfg``, it lengthens episodes to 1000 s, disables both
    cameras, and disables the vision observation group that reads them.
    """

    # events: TrainEventCfg = TrainEventCfg()
    actions = DROIDIkRelativeAction()

    def __post_init__(self):
        super().__post_init__()
        self.episode_length_s = 1000.0

        # Disable the cameras and the observation group that consumes them.
        # Assigning None (rather than `del`) keeps the dataclass fields present.
        # The generated __repr__/__eq__ and dataclasses.replace() would
        # otherwise raise AttributeError. Isaac Lab's scene and observation
        # manager skip entries that are None.
        # self.scene.splat = None
        self.scene.external_camera = None
        self.scene.wrist_camera = None
        self.observations.vision = None

        # self.scene.robot.actuators["arm"].stiffness= 4000.0
        # self.scene.robot.actuators["arm"].damping= 800.0
        # self.scene.robot.actuators["arm"].effort_limit= 870.0


