from __future__ import annotations

import numpy as np
from isaaclab.envs.mdp.actions.task_space_actions import DifferentialInverseKinematicsAction
import isaaclab.sim as sim_utils
from isaaclab.assets import AssetBaseCfg, RigidObjectCfg, ArticulationCfg
from isaaclab.envs import ManagerBasedRLEnvCfg, ViewerCfg, ManagerBasedEnv, ManagerBasedRLEnv
from isaaclab.managers import EventTermCfg as EventTerm
from isaaclab.managers import ObservationGroupCfg as ObsGroup
from isaaclab.managers import ObservationTermCfg as ObsTerm
from isaaclab.managers import RewardTermCfg as RewTerm
from isaaclab.managers import SceneEntityCfg
from isaaclab.managers import TerminationTermCfg as DoneTerm
from isaaclab.scene import InteractiveSceneCfg
from isaaclab.sensors import TiledCameraCfg
from isaaclab.utils import configclass, noise
from isaaclab.utils.assets import ISAAC_NUCLEUS_DIR

# from isaaclab.envs.mdp.events import reset_scene_to_default
import sim_improvement.environments.mdp as mdp

from sim_improvement.environments.lbm.robot import  (
    IMPLICIT_PANDA,
    EXPLICIT_PANDA,
    LEFT_PANDA_DEFAULT_JOINT_POS,
    RIGHT_PANDA_DEFAULT_JOINT_POS,
    ARM_IK_ACTION,
    LBM_GRAVITY_COMP_DIFF_IK,
    GRIPPER_ACTION,
    LBM_JOINT_EFFORT,
    )

@configclass
class SceneCfg(InteractiveSceneCfg):
    """Scene layout: two Panda arms, a dome light, the scenario USD and four RGB cameras.

    All prims live under ``{ENV_REGEX_NS}``. The arms are built from
    ``EXPLICIT_PANDA``; two cameras are placed in the scene and one camera is
    parented to ``panda_link8`` of each arm.
    """

    # Left arm: EXPLICIT_PANDA with its own prim path, base pose and default joints.
    left_panda = EXPLICIT_PANDA.replace(
        prim_path="{ENV_REGEX_NS}/left_panda",
        init_state=ArticulationCfg.InitialStateCfg(
            pos=(-0.5937, -0.34362, -0.08484),
            rot=(0.36811, 0.01027, 0.00078, 0.92973),
            joint_pos=LEFT_PANDA_DEFAULT_JOINT_POS
        )
    )

    # Right arm: same robot config, mirrored base pose and its own default joints.
    right_panda = EXPLICIT_PANDA.replace(
        prim_path="{ENV_REGEX_NS}/right_panda",
        init_state=ArticulationCfg.InitialStateCfg(
            pos=(-0.5937, 0.32962, -0.08062),
            rot=(0.91675, 0.01312, 0.01089, 0.39909),
            joint_pos=RIGHT_PANDA_DEFAULT_JOINT_POS,
        )
    )

    # Dome light textured with an HDR image; the texture path is relative.
    # NOTE(review): if quaternions here are (w, x, y, z) as in Isaac Lab, then
    # rot=(0, 0, 0, 1) is a 180 degree turn about z rather than identity -- confirm
    # this is the intended orientation of the HDR.
    background = AssetBaseCfg(
        prim_path="{ENV_REGEX_NS}/background",
        spawn=sim_utils.DomeLightCfg(
            intensity=500.0,
            texture_file = "envs/2_riverway_shelf/poly_haven_studio_4k.hdr"
            # Disabled alternative texture:
            # texture_file = "/home/arhanjain/Downloads/99589f71-228b-4ee0-9b73-25f43965e9b5.exr",
        ),
        init_state=AssetBaseCfg.InitialStateCfg(
            pos=(0.0, 0.0, 0.0),
            rot=(0.0, 0.0, 0.0, 1.0),
        )
    )

    # Static scenario geometry loaded from a USD file (path is relative).
    scenario = AssetBaseCfg(
        prim_path="{ENV_REGEX_NS}/scenario",
        spawn=sim_utils.UsdFileCfg(
            usd_path="./envs/2_riverway_shelf/scene.usda",
        ),
    )

    # External scene cameras from Riverway station calibration (2024-11-11).
    # Poses are relative to manipuland_table::table_top_center (= world origin).
    # Both are 640x480 RGB pinhole cameras using the "opengl" offset convention.
    scene_camera_right = TiledCameraCfg(
        prim_path="{ENV_REGEX_NS}/scene_camera_right",
        offset=TiledCameraCfg.OffsetCfg(
            pos=(0.54755, 0.18739, 0.83597),
            rot=(0.63083, 0.15475, 0.27309, 0.70960),
            convention="opengl",
        ),
        data_types=["rgb"],
        spawn=sim_utils.PinholeCameraCfg(
            focal_length=0.5953,
            horizontal_aperture=1.0,
            vertical_aperture=0.75,
        ),
        height=480,
        width=640,
    )

    scene_camera_left = TiledCameraCfg(
        prim_path="{ENV_REGEX_NS}/scene_camera_left",
        offset=TiledCameraCfg.OffsetCfg(
            pos=(0.49714, -0.36177, 0.77877),
            rot=(0.72133, 0.33411, 0.07049, 0.60257),
            convention="opengl",
        ),
        data_types=["rgb"],
        spawn=sim_utils.PinholeCameraCfg(
            focal_length=0.5953,
            horizontal_aperture=1.0,
            vertical_aperture=0.75,
        ),
        height=480,
        width=640,
    )

    # Wrist cameras from Riverway station calibration.
    # Poses are relative to panda_link8 on each arm.
    # Derived from anzu scenario: flange_rotated + [-0.06577, 0, 0.0048] Rpy(0,0,-90)
    # converted to panda_link8 frame via Rz(-45) (flange_rotated definition).
    # The 0.04651 components are 0.06577 / sqrt(2), i.e. the x-offset after that
    # 45 degree rotation. Wrist cameras use the "ros" offset convention.
    wrist_camera_right = TiledCameraCfg(
        prim_path="{ENV_REGEX_NS}/right_panda/panda_link8/wrist_camera_right",
        offset=TiledCameraCfg.OffsetCfg(
            # Disabled alternative offset:
            # pos=(-0.05, 0.05, 0.01),
            # rot=(0.37687, -0.06645, 0.16043, -0.90984),
            pos=(-0.04651, 0.04651, 0.0048),
            rot=(0.38268, 0.0, 0.0, -0.92388),
            convention="ros",
        ),
        data_types=["rgb"],
        spawn=sim_utils.PinholeCameraCfg(
            focal_length=0.4188,
            horizontal_aperture=1.0,
            vertical_aperture=0.75,
        ),
        height=480,
        width=640,
    )

    # Derived from anzu scenario: flange_rotated + [0.06577, 0, 0.0048] Rpy(0,0,90)
    wrist_camera_left = TiledCameraCfg(
        prim_path="{ENV_REGEX_NS}/left_panda/panda_link8/wrist_camera_left",
        offset=TiledCameraCfg.OffsetCfg(
            # Disabled alternative offset:
            # pos=(0.05, -0.05, 0.01),
            # rot=(0.90984, -0.16043, -0.06645, 0.37687),
            pos=(0.04651, -0.04651, 0.0048),
            rot=(0.92388, 0.0, 0.0, 0.38268),
            convention="ros",
        ),
        data_types=["rgb"],
        spawn=sim_utils.PinholeCameraCfg(
            focal_length=0.4188,
            horizontal_aperture=1.0,
            vertical_aperture=0.75,
        ),
        height=480,
        width=640,
    )

@configclass
class ActionsCfg:
    """Default action terms: one arm term and one gripper term per Panda.

    Both arms use ``LBM_GRAVITY_COMP_DIFF_IK`` and both grippers use
    ``GRIPPER_ACTION``; each term is bound to its robot through ``asset_name``.
    """
    # NOTE(review): declaration order presumably fixes the layout of the action
    # vector (left arm, left gripper, right arm, right gripper) -- keep it stable.
    left_panda_arm = LBM_GRAVITY_COMP_DIFF_IK.replace(asset_name="left_panda")
    left_panda_gripper = GRIPPER_ACTION.replace(asset_name="left_panda")
    right_panda_arm = LBM_GRAVITY_COMP_DIFF_IK.replace(asset_name="right_panda")
    right_panda_gripper = GRIPPER_ACTION.replace(asset_name="right_panda")


@configclass
class ObservationsCfg:
    """Observation groups exposed by the environment."""

    @configclass
    class VisionCfg(ObsGroup):
        """Camera images plus end-effector and gripper state for both arms.

        Terms are kept separate (``concatenate_terms`` is off) and corruption
        is disabled.
        """

        # -- RGB images from the wrist cameras (normalize=False)
        wrist_camera_right = ObsTerm(
            func=mdp.image,
            params={
                "sensor_cfg": SceneEntityCfg("wrist_camera_right"),
                "data_type": "rgb",
                "normalize": False,
            },
        )
        wrist_camera_left = ObsTerm(
            func=mdp.image,
            params={
                "sensor_cfg": SceneEntityCfg("wrist_camera_left"),
                "data_type": "rgb",
                "normalize": False,
            },
        )

        # -- RGB images from the external scene cameras (normalize=False)
        external_camera_right = ObsTerm(
            func=mdp.image,
            params={
                "sensor_cfg": SceneEntityCfg("scene_camera_right"),
                "data_type": "rgb",
                "normalize": False,
            },
        )
        external_camera_left = ObsTerm(
            func=mdp.image,
            params={
                "sensor_cfg": SceneEntityCfg("scene_camera_left"),
                "data_type": "rgb",
                "normalize": False,
            },
        )

        # -- pose of panda_link8 on the left arm
        left_ee_pos = ObsTerm(
            func=mdp.link_pos,
            params={
                "root_asset_cfg": SceneEntityCfg("left_panda"),
                "link_name": "panda_link8",
            },
        )
        left_ee_quat = ObsTerm(
            func=mdp.link_quat,
            params={
                "root_asset_cfg": SceneEntityCfg("left_panda"),
                "link_name": "panda_link8",
            },
        )

        # -- pose of panda_link8 on the right arm
        right_ee_pos = ObsTerm(
            func=mdp.link_pos,
            params={
                "root_asset_cfg": SceneEntityCfg("right_panda"),
                "link_name": "panda_link8",
            },
        )
        right_ee_quat = ObsTerm(
            func=mdp.link_quat,
            params={
                "root_asset_cfg": SceneEntityCfg("right_panda"),
                "link_name": "panda_link8",
            },
        )

        # -- gripper state: position of panda_finger_joint1 scaled by -2
        left_gripper_pos = ObsTerm(
            func=lambda env, asset_cfg: mdp.joint_pos(env, asset_cfg) * -2,
            params={"asset_cfg": SceneEntityCfg(name="left_panda", joint_names=["panda_finger_joint1"])},
        )
        right_gripper_pos = ObsTerm(
            func=lambda env, asset_cfg: mdp.joint_pos(env, asset_cfg) * -2,
            params={"asset_cfg": SceneEntityCfg(name="right_panda", joint_names=["panda_finger_joint1"])},
        )

        def __post_init__(self):
            """Disable corruption and keep every term as its own entry."""
            self.enable_corruption = False
            self.concatenate_terms = False

    # observation groups
    vision: VisionCfg = VisionCfg()

@configclass
class RewardsCfg:
    """Reward specifications for the MDP.

    Intentionally empty: this configuration defines no reward terms.
    """

@configclass
class TerminationsCfg:
    """Termination specifications for the MDP.

    The only termination term is the episode timeout.
    """
    # Evaluated with ``mdp.time_out`` and flagged as a time-out termination.
    timeout = DoneTerm(func=mdp.time_out, time_out=True)

@configclass
class EventCfg:
    """Event specifications for the MDP.

    Defines a single event that runs in ``"reset"`` mode.
    """
    # Calls ``mdp.reset_scene_to_default`` with ``reset_joint_targets=True``.
    reset_scene = EventTerm(
        func=mdp.reset_scene_to_default,
        mode="reset",
        params={"reset_joint_targets": True},
    )

@configclass
class EnvCfg(ManagerBasedRLEnvCfg):
    """Base configuration of the bimanual Panda environment.

    Wires together the scene, observation, reward, termination, action and
    event configs defined in this module and tunes the simulation, PhysX and
    render settings in ``__post_init__``.
    """

    scene = SceneCfg(num_envs=32, env_spacing=1.5)
    observations = ObservationsCfg()
    rewards = RewardsCfg()
    terminations = TerminationsCfg()
    actions = ActionsCfg()
    events = EventCfg()
    viewer: ViewerCfg = ViewerCfg(
        eye=(-1.5, 0, 1.45),
        origin_type="world",
        env_index=0,
        asset_name="scenario",
    )

    def __post_init__(self):
        """Set timing, PhysX solver/GPU buffer sizes and render options."""
        # -- timing: dt * decimation = 0.1 s per environment step
        self.decimation = 5
        self.episode_length_s = 45.0
        self.sim.dt = 1 / (10.0 * self.decimation)

        # -- contact and solver settings
        physx = self.sim.physx
        physx.solver_type = 1
        physx.max_position_iteration_count = 192
        physx.max_velocity_iteration_count = 1
        physx.bounce_threshold_velocity = 0.02
        physx.friction_offset_threshold = 0.01
        physx.friction_correlation_distance = 0.0005

        # -- GPU buffer capacities
        physx.gpu_found_lost_aggregate_pairs_capacity = 4 * 1024 * 1024
        physx.gpu_total_aggregate_pairs_capacity = 2**23
        physx.gpu_max_rigid_contact_count = 2**23
        physx.gpu_max_rigid_patch_count = 2**23
        physx.gpu_collision_stack_size = 2**31

        # -- render settings
        render = self.sim.render
        render.enable_dlssg = True
        render.enable_ambient_occlusion = True
        render.enable_reflections = True
        render.enable_dl_denoiser = True

        # -- rendering behavior around resets
        self.wait_for_textures = True
        self.rerender_on_reset = True



@configclass
class LBMIKRolloutCfg(EnvCfg):
    """``EnvCfg`` variant whose arm controllers have relative mode switched off."""

    def __post_init__(self):
        """Apply the base setup, then disable ``use_relative_mode`` on both arms."""
        super().__post_init__()
        for arm_action in (self.actions.left_panda_arm, self.actions.right_panda_arm):
            arm_action.controller.use_relative_mode = False


@configclass
class LBMImplicityIKRolloutCfg(EnvCfg):
    """``EnvCfg`` variant built on ``IMPLICIT_PANDA`` robots and ``ARM_IK_ACTION``.

    Both arms are re-created from ``IMPLICIT_PANDA`` at the same base poses as
    in ``SceneCfg``, gravity is disabled on their rigid bodies, and the arm
    action terms are replaced with ``ARM_IK_ACTION``.
    """

    def __post_init__(self):
        """Apply the base setup, then swap in the implicit robots and IK actions."""
        super().__post_init__()

        scene = self.scene
        actions = self.actions

        # Base poses and default joint positions for each arm.
        left_init = ArticulationCfg.InitialStateCfg(
            pos=(-0.5937, -0.34362, -0.08484),
            rot=(0.36811, 0.01027, 0.00078, 0.92973),
            joint_pos=LEFT_PANDA_DEFAULT_JOINT_POS,
        )
        right_init = ArticulationCfg.InitialStateCfg(
            pos=(-0.5937, 0.32962, -0.08062),
            rot=(0.91675, 0.01312, 0.01089, 0.39909),
            joint_pos=RIGHT_PANDA_DEFAULT_JOINT_POS,
        )

        scene.left_panda = IMPLICIT_PANDA.replace(
            prim_path="{ENV_REGEX_NS}/left_panda",
            init_state=left_init,
        )
        scene.right_panda = IMPLICIT_PANDA.replace(
            prim_path="{ENV_REGEX_NS}/right_panda",
            init_state=right_init,
        )

        # Gravity is turned off on both robots.
        for robot in (scene.left_panda, scene.right_panda):
            robot.spawn.rigid_props.disable_gravity = True

        # Action terms: IK for the arms, the shared gripper action for the hands.
        actions.left_panda_arm = ARM_IK_ACTION.replace(asset_name="left_panda")
        actions.right_panda_arm = ARM_IK_ACTION.replace(asset_name="right_panda")
        actions.left_panda_gripper = GRIPPER_ACTION.replace(asset_name="left_panda")
        actions.right_panda_gripper = GRIPPER_ACTION.replace(asset_name="right_panda")




@configclass
class CustomActionsCfg():
    """Action terms that drive both arms with ``LBM_JOINT_EFFORT``.

    The gripper terms use ``GRIPPER_ACTION``, as in ``ActionsCfg``; only the
    arm terms differ.
    """
    left_panda_arm = LBM_JOINT_EFFORT.replace(asset_name="left_panda")
    left_panda_gripper = GRIPPER_ACTION.replace(asset_name="left_panda")
    right_panda_arm = LBM_JOINT_EFFORT.replace(asset_name="right_panda")
    right_panda_gripper = GRIPPER_ACTION.replace(asset_name="right_panda")
@configclass
class LBMJointEffortRolloutCfg(EnvCfg):
    """``EnvCfg`` variant whose action terms come from ``CustomActionsCfg``."""

    def __post_init__(self):
        """Apply the base setup, then replace the action configuration."""
        # Run EnvCfg.__post_init__ first: this override would otherwise skip it,
        # leaving decimation, episode_length_s, sim.dt and the PhysX/render
        # settings unset. The other EnvCfg subclasses in this module do the same.
        super().__post_init__()
        self.actions = CustomActionsCfg()

