# Copyright (c) 2022-2025, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

import math

import isaaclab.sim as sim_utils
from isaaclab.assets import ArticulationCfg, AssetBaseCfg, RigidObjectCfg
from isaaclab.devices import DevicesCfg
from isaaclab.devices.gamepad import Se3GamepadCfg
from isaaclab.devices.keyboard import Se3KeyboardCfg
from isaaclab.devices.spacemouse import Se3SpaceMouseCfg
from isaaclab.envs import ManagerBasedRLEnvCfg
from isaaclab.managers import ActionTermCfg as ActionTerm
from isaaclab.managers import EventTermCfg as EventTerm
from isaaclab.managers import ObservationGroupCfg as ObsGroup
from isaaclab.managers import ObservationTermCfg as ObsTerm
from isaaclab.managers import RewardTermCfg as RewTerm
from isaaclab.managers import SceneEntityCfg
from isaaclab.managers import TerminationTermCfg as DoneTerm
from isaaclab.scene import InteractiveSceneCfg
from isaaclab.utils import configclass
from isaaclab.utils.assets import ISAAC_NUCLEUS_DIR
from isaaclab.utils.noise import AdditiveUniformNoiseCfg as Unoise

# import isaaclab_tasks.manager_based.manipulation.reach.mdp as mdp

from isaaclab_assets import FRANKA_PANDA_HIGH_PD_CFG  # isort: skip
# from polaris.environments.robot_cfg import NVIDIA_DROID
from isaaclab.markers.config import FRAME_MARKER_CFG
from isaaclab.sensors.frame_transformer.frame_transformer_cfg import (
    FrameTransformerCfg,
    OffsetCfg,
)
import torch

import sim_improvement.environments.mdp as mdp

##
# Scene definition
##


@configclass
class ReachSceneCfg(InteractiveSceneCfg):
    """Scene for the Franka ice-cream tasks.

    Holds a ground plane, a table, an ice-cream toy rigid body, a Franka Panda
    arm and a dome light. An ``ee_frame`` frame transformer that tracks
    ``panda_hand`` relative to ``panda_link0`` is attached in :meth:`__post_init__`.
    """

    # -- world
    ground = AssetBaseCfg(
        prim_path="/World/ground",
        spawn=sim_utils.GroundPlaneCfg(),
        init_state=AssetBaseCfg.InitialStateCfg(pos=(0.0, 0.0, -1.05)),
    )

    table = AssetBaseCfg(
        prim_path="{ENV_REGEX_NS}/Table",
        spawn=sim_utils.UsdFileCfg(
            usd_path=f"{ISAAC_NUCLEUS_DIR}/Props/Mounts/SeattleLabTable/table_instanceable.usd"
        ),
        init_state=AssetBaseCfg.InitialStateCfg(
            pos=(0.55, 0.0, 0.0),
            rot=(0.70711, 0.0, 0.0, 0.70711),
        ),
    )

    # -- objects
    # NOTE(review): absolute, user-specific USD path; it will not resolve on
    # other machines.
    # NOTE(review): if quaternions are (w, x, y, z), rot=(0, 0, 0, 1) is a
    # 180-degree yaw rather than identity -- confirm this is intended.
    icecream = RigidObjectCfg(
        prim_path="{ENV_REGEX_NS}/Icecream",
        spawn=sim_utils.UsdFileCfg(
            usd_path="/home/arhanjain/projects/sim-improvement/PolaRiS-Hub/lightwheel_g60/RA_LW_Assets_20251203/Kitchen_Scene/Assets/Toy005/Toy005.usd",
        ),
        init_state=RigidObjectCfg.InitialStateCfg(pos=(0.35, 0.0, 0.15), rot=(0.0, 0.0, 0.0, 1.0)),
    )

    # -- robot (alternative: NVIDIA_DROID.replace(prim_path="{ENV_REGEX_NS}/Robot"))
    robot: ArticulationCfg = FRANKA_PANDA_HIGH_PD_CFG.replace(prim_path="{ENV_REGEX_NS}/Robot")

    # -- lighting
    light = AssetBaseCfg(
        prim_path="/World/light",
        spawn=sim_utils.DomeLightCfg(color=(0.75, 0.75, 0.75), intensity=2500.0),
    )

    def __post_init__(self):
        """Attach the ``ee_frame`` transformer with a scaled-down debug marker."""
        frame_marker = FRAME_MARKER_CFG.copy()
        frame_marker.markers["frame"].scale = (0.1, 0.1, 0.1)
        frame_marker.prim_path = "/Visuals/FrameTransformer"

        # For a Robotiq 2F-85 gripper, target
        # "{ENV_REGEX_NS}/Robot/Gripper/Robotiq_2F_85/base_link" instead.
        hand = FrameTransformerCfg.FrameCfg(
            prim_path="{ENV_REGEX_NS}/Robot/panda_hand",
            name="end_effector",
            offset=OffsetCfg(pos=[0.0, 0.0, 0.0]),
        )

        self.ee_frame = FrameTransformerCfg(
            prim_path="{ENV_REGEX_NS}/Robot/panda_link0",
            debug_vis=True,
            visualizer_cfg=frame_marker,
            target_frames=[hand],
        )


##
# MDP settings
##


@configclass
class CommandsCfg:
    """Command terms for the MDP.

    Intentionally empty: no command generator is active for these tasks.
    """

    # Former goal-pose command, kept for reference:
    # ee_pose = mdp.UniformPoseCommandCfg(
    #     asset_name="robot", body_name="panda_hand",
    #     resampling_time_range=(4.0, 4.0), debug_vis=True,
    #     ranges=mdp.UniformPoseCommandCfg.Ranges(
    #         pos_x=(0.35, 0.65), pos_y=(-0.2, 0.2), pos_z=(0.15, 0.5),
    #         roll=(0.0, 0.0), pitch=(math.pi / 2, math.pi), yaw=(-3.14, 3.14),
    #     ),
    # )


class BinaryJointPositionZeroToOneAction(mdp.BinaryJointPositionAction):
    """Binary joint-position action whose input is read on a ``[0, 1]`` scale.

    Float inputs strictly above ``0.5`` select the close command; everything
    else selects the open command. Boolean inputs follow the same convention,
    i.e. ``True`` (``1``) closes and ``False`` (``0``) opens.
    """

    # override
    def process_actions(self, actions: torch.Tensor):
        """Convert raw actions into the open/close joint-position targets.

        Args:
            actions: Raw actions from the action manager, boolean or floating
                point.
        """
        # store the raw actions
        self._raw_actions[:] = actions
        # binary_mask is True wherever the gripper should close.
        if actions.dtype == torch.bool:
            # True means "1", so booleans agree with the float threshold below
            # (previously True opened the gripper while 1.0 closed it).
            binary_mask = actions
        else:
            binary_mask = actions > 0.5
        self._processed_actions = torch.where(
            binary_mask, self._close_command, self._open_command
        )
        if self.cfg.clip is not None:
            self._processed_actions = torch.clamp(
                self._processed_actions,
                min=self._clip[:, :, 0],
                max=self._clip[:, :, 1],
            )


@configclass
class BinaryJointPositionZeroToOneActionCfg(mdp.BinaryJointPositionActionCfg):
    """Configuration for :class:`BinaryJointPositionZeroToOneAction`.

    Inherits every field of the parent configuration and only overrides
    ``class_type``, so the ``[0, 1]``-scaled action term is instantiated.
    """

    class_type = BinaryJointPositionZeroToOneAction


@configclass
class ActionsCfg:
    """Action terms: relative differential-IK arm control plus a binary gripper (Franka)."""

    # Arm: relative pose deltas solved with damped-least-squares IK, tracking a
    # point offset by (0, 0, 0.107) from panda_hand.
    # Alternatives: body_name="base_link" for the Robotiq setup, or
    # use_relative_mode=False for absolute pose targets.
    arm_action = mdp.DifferentialInverseKinematicsActionCfg(
        asset_name="robot",
        joint_names=["panda_joint.*"],
        body_name="panda_hand",
        controller=mdp.DifferentialIKControllerCfg(
            command_type="pose",
            use_relative_mode=True,
            ik_method="dls",
        ),
        body_offset=mdp.DifferentialInverseKinematicsActionCfg.OffsetCfg(pos=(0.0, 0.0, 0.107)),
    )

    # Gripper: both finger joints driven to 0.04 when open and 0.0 when closed.
    # Robotiq alternative: joint_names=["finger_joint"], open 0.0, close math.pi / 4.
    gripper_action = BinaryJointPositionZeroToOneActionCfg(
        asset_name="robot",
        joint_names=["panda_finger.*"],
        open_command_expr={"panda_finger_.*": 0.04},
        close_command_expr={"panda_finger_.*": 0.0},
    )


def ee_pose_obs(
    env,
    ee_frame_name: str = "ee_frame",
    relative_to_env_origin: bool = True,
) -> torch.Tensor:
    """Return the end-effector pose as ``[position, quaternion]`` per environment.

    Reads the first target frame (index 0) of the frame-transformer sensor
    ``env.scene[ee_frame_name]`` and concatenates its position and orientation
    along the last dimension (presumably a ``(num_envs, 7)`` tensor).

    Args:
        env: The environment; must expose ``env.scene[ee_frame_name]`` and,
            when ``relative_to_env_origin`` is set, ``env.scene.env_origins``.
        ee_frame_name: Scene key of the frame-transformer sensor.
        relative_to_env_origin: If True (default), subtract each env's origin
            from the world-frame position. Set to False for raw world positions.

    Returns:
        The position concatenated with the world-frame orientation quaternion.
    """
    ee_frame = env.scene[ee_frame_name]
    ee_pos = ee_frame.data.target_pos_w[:, 0]
    ee_quat = ee_frame.data.target_quat_w[:, 0]
    if relative_to_env_origin:
        # World positions include each env's grid offset, so identical robot
        # poses in different envs would otherwise produce different
        # observations.
        ee_pos = ee_pos - env.scene.env_origins[:, :3]
    return torch.cat([ee_pos, ee_quat], dim=-1)


@configclass
class ObservationsCfg:
    """Observation groups for the ice-cream reach/lift task."""

    @configclass
    class PolicyCfg(ObsGroup):
        """Policy observations, concatenated in declaration order into one vector."""

        joint_pos = ObsTerm(func=mdp.joint_pos_rel, noise=Unoise(n_min=-0.01, n_max=0.01))
        joint_vel = ObsTerm(func=mdp.joint_vel_rel, noise=Unoise(n_min=-0.01, n_max=0.01))
        # (a pose_command term via mdp.generated_commands was dropped with CommandsCfg.ee_pose)
        actions = ObsTerm(func=mdp.last_action)
        ee_pose = ObsTerm(func=ee_pose_obs)
        icecream_pose = ObsTerm(
            func=mdp.object_pose_in_robot_root_frame,
            params={
                "robot_cfg": SceneEntityCfg("robot"),
                "object_cfg": SceneEntityCfg("icecream"),
            },
        )

        def __post_init__(self):
            # Join all terms into a single tensor and apply the configured noise.
            self.concatenate_terms = True
            self.enable_corruption = True

    policy: PolicyCfg = PolicyCfg()


@configclass
class EventCfg:
    """Reset events: restore scene defaults and apply a state from a recorded dataset."""

    reset_all = EventTerm(func=mdp.reset_scene_to_default, mode="reset")

    # NOTE(review): the factory is called when this class body executes, and the
    # relative dataset path is presumably resolved against the working directory.
    reset_from_state_dataset = EventTerm(
        func=mdp.reset_from_state_dataset("./demonstrations/ice_cream_reset_states/trajectory.hdf5"),
        mode="reset",
        params={},  # the dataset path is passed to the factory call above
    )






@configclass
class RewardsCfg:
    """Reward terms: approach and lift the ice cream, plus small regularizers."""

    # -- task terms
    reaching_object = RewTerm(
        func=mdp.object_ee_distance,
        weight=1.0,
        params={
            "std": 0.1,
            "object_cfg": SceneEntityCfg("icecream"),
            "ee_frame_cfg": SceneEntityCfg("ee_frame"),
        },
    )

    lift = RewTerm(
        func=mdp.lift_reward,
        weight=20.0,
        params={"object_cfg": SceneEntityCfg("icecream")},
    )

    # -- regularizers (action rate and joint velocity)
    action_rate = RewTerm(func=mdp.action_rate_l2, weight=-0.0001)
    joint_vel = RewTerm(
        func=mdp.joint_vel_l2,
        weight=-0.0001,
        params={"asset_cfg": SceneEntityCfg("robot")},
    )


@configclass
class TerminationsCfg:
    """Episode ends on timeout or when the ice cream's root height falls below 0.0."""

    time_out = DoneTerm(func=mdp.time_out, time_out=True)

    fall_below = DoneTerm(
        func=mdp.root_height_below_minimum,
        params={"asset_cfg": SceneEntityCfg("icecream"), "minimum_height": 0.0},
    )


@configclass
class CurriculumCfg:
    """Curriculum terms for the MDP (none are active).

    Previously considered: ``mdp.modify_reward_weight`` setting the
    ``action_rate`` weight to -0.005 and the ``joint_vel`` weight to -0.001,
    each with ``num_steps=4500``.

    NOTE(review): re-enabling those terms also requires importing
    ``CurriculumTermCfg as CurrTerm`` from ``isaaclab.managers``; this file
    does not import it.
    """


##
# Environment configuration
##


@configclass
class ReachEnvCfg(ManagerBasedRLEnvCfg):
    """Franka ice-cream environment: 4096 envs, 15 Hz control, 5 s episodes."""

    # -- scene
    scene = ReachSceneCfg(num_envs=4096, env_spacing=2.5)
    # -- managers
    observations = ObservationsCfg()
    actions = ActionsCfg()
    commands = CommandsCfg()
    rewards = RewardsCfg()
    terminations = TerminationsCfg()
    events = EventCfg()
    curriculum = CurriculumCfg()

    def __post_init__(self):
        """Configure timing, viewer, PhysX buffers and teleoperation devices."""
        control_hz = 15.0
        self.decimation = 4
        self.episode_length_s = 5.0
        # physics step = 1 / (control_hz * decimation); rendering matched to decimation
        self.sim.dt = 1.0 / (control_hz * self.decimation)
        self.sim.render_interval = self.decimation
        self.sim.physx.gpu_max_rigid_patch_count = 10 * 2**15

        self.viewer.eye = (10, 1, 3)
        self.viewer.lookat = (0, 0, 0)

        # NOTE(review): gripper_term=False presumably drops the gripper command
        # from device output, while ActionsCfg also declares gripper_action --
        # confirm the action dimensions line up when teleoperating.
        device_kwargs = {"gripper_term": False, "sim_device": self.sim.device}
        self.teleop_devices = DevicesCfg(
            devices={
                "keyboard": Se3KeyboardCfg(**device_kwargs),
                "gamepad": Se3GamepadCfg(**device_kwargs),
                "spacemouse": Se3SpaceMouseCfg(**device_kwargs),
            },
        )


@configclass
class PnpSceneCfg(ReachSceneCfg):
    """Pick-and-place scene: the reach scene plus a bowl; the ice cream is re-declared."""

    # NOTE(review): both USD paths below are absolute, user-specific paths.
    # The prim is "icecream" here but "Icecream" in the parent scene.
    icecream = RigidObjectCfg(
        prim_path="{ENV_REGEX_NS}/icecream",
        spawn=sim_utils.UsdFileCfg(
            usd_path="/home/arhanjain/projects/sim-improvement/PolaRiS-Hub/lightwheel_g60/RA_LW_Assets_20251203/Kitchen_Scene/Assets/Toy005/Toy005.usd",
        ),
        init_state=RigidObjectCfg.InitialStateCfg(pos=(0.35, 0.0, 0.15), rot=(0.0, 0.0, 0.0, 1.0)),
    )

    bowl = RigidObjectCfg(
        prim_path="{ENV_REGEX_NS}/bowl",
        spawn=sim_utils.UsdFileCfg(
            usd_path="/home/arhanjain/projects/sim-improvement/PolaRiS-Hub/lightwheel_g60/RA_LW_Assets_20251203/Kitchen_Scene/Assets/Bowl056/Bowl056.usd",
        ),
        init_state=RigidObjectCfg.InitialStateCfg(pos=(0.45, 0.0, 0.15), rot=(1.0, 0.0, 0.0, 0.0)),
    )


@configclass
class PnpRewardsCfg:
    """Reward terms for the ice-cream-to-bowl pick-and-place task."""

    # Single active term: mdp.is_within_xy comparing the icecream and bowl assets,
    # with its thresholds passed through params.
    icecream_in_bowl = RewTerm(
        func=mdp.is_within_xy,
        weight=50.0,
        params={
            "object1": "icecream",
            "object2": "bowl",
            "percent_threshold": 0.6,
            "open_finger_threshold": 0.02,
        },
    )

    # Disabled shaping terms, kept for reference:
    #   object_to_object_distance: mdp.object_to_object_distance, weight=1.0,
    #       std=0.2, object1_cfg=icecream, object2_cfg=bowl
    #   reaching_object: mdp.object_ee_distance, weight=1.0,
    #       std=0.1, object_cfg=icecream, ee_frame_cfg=ee_frame



@configclass
class PnpEventsCfg:
    """Reset events for pick-and-place: scene defaults, then a recorded reset state."""

    reset_all = EventTerm(func=mdp.reset_scene_to_default, mode="reset")

    # NOTE(review): relative dataset path, presumably resolved against the
    # working directory; the factory runs when this class body executes.
    reset_from_state_dataset = EventTerm(
        func=mdp.reset_from_state_dataset("./demonstrations/icecream-to-bowl/trajectory.hdf5"),
        mode="reset",
    )


@configclass
class PnpObservationsCfg:
    """Observation groups for the pick-and-place task (adds the bowl pose)."""

    @configclass
    class PolicyCfg(ObsGroup):
        """Policy observations, concatenated in declaration order into one vector."""

        joint_pos = ObsTerm(func=mdp.joint_pos_rel, noise=Unoise(n_min=-0.01, n_max=0.01))
        joint_vel = ObsTerm(func=mdp.joint_vel_rel, noise=Unoise(n_min=-0.01, n_max=0.01))
        actions = ObsTerm(func=mdp.last_action)
        ee_pose = ObsTerm(func=ee_pose_obs)
        # object poses, both expressed via mdp.object_pose_in_robot_root_frame
        icecream_pose = ObsTerm(
            func=mdp.object_pose_in_robot_root_frame,
            params={
                "robot_cfg": SceneEntityCfg("robot"),
                "object_cfg": SceneEntityCfg("icecream"),
            },
        )
        bowl_pose = ObsTerm(
            func=mdp.object_pose_in_robot_root_frame,
            params={
                "robot_cfg": SceneEntityCfg("robot"),
                "object_cfg": SceneEntityCfg("bowl"),
            },
        )

        def __post_init__(self):
            # Join all terms into a single tensor and apply the configured noise.
            self.concatenate_terms = True
            self.enable_corruption = True

    policy: PolicyCfg = PolicyCfg()


@configclass
class PnpTerminationsCfg:
    """Episode ends on timeout or when either object's root height falls below 0.0."""

    time_out = DoneTerm(func=mdp.time_out, time_out=True)

    icecream_below = DoneTerm(
        func=mdp.root_height_below_minimum,
        params={"asset_cfg": SceneEntityCfg("icecream"), "minimum_height": 0.0},
    )

    bowl_below = DoneTerm(
        func=mdp.root_height_below_minimum,
        params={"asset_cfg": SceneEntityCfg("bowl"), "minimum_height": 0.0},
    )


@configclass
class PnpEnvCfg(ReachEnvCfg):
    """Pick-and-place variant: move the ice cream into the bowl.

    Reuses the parent's actions, commands, curriculum and timing; swaps in the
    pick-and-place scene, rewards, events, observations and terminations.
    """

    scene = PnpSceneCfg(num_envs=4096, env_spacing=2.5)
    rewards = PnpRewardsCfg()
    events = PnpEventsCfg()
    observations = PnpObservationsCfg()
    terminations = PnpTerminationsCfg()

    def __post_init__(self):
        """Apply the parent settings, then raise the PhysX rigid-patch capacity."""
        super().__post_init__()
        # 2**19 exceeds the parent's 10 * 2**15, presumably for the extra bowl contacts.
        self.sim.physx.gpu_max_rigid_patch_count = 2**19



