import torch
import h5py
import numpy as np
import os
from pathlib import Path

from isaaclab.envs import ManagerBasedRLEnv, ManagerBasedEnv
from isaaclab.assets.articulation import Articulation
from isaaclab.assets.rigid_object import RigidObject
from isaaclab.managers import ManagerTermBase, EventTermCfg
from isaaclab.managers import SceneEntityCfg
from isaaclab.utils.math import sample_uniform
from isaaclab.envs.mdp.actions.actions_cfg import DifferentialInverseKinematicsActionCfg, DifferentialIKControllerCfg
from isaaclab.utils import math as math_utils



class reset_from_state_dataset(ManagerTermBase):
    """Reset scene entities to states sampled uniformly from a recorded HDF5 dataset.

    The dataset is read from groups ``data/<demo>/states/<entity_name>``. For every
    ``<entity_name>``, the arrays of all demos are concatenated along their first axis.
    At reset time the same sampled row index is used for every entity, so the per-entity
    arrays are assumed to be frame-aligned (row ``i`` of each belongs to the same recording).

    This term is constructed directly with a dataset path (``reset_from_state_dataset(path)``)
    rather than with ``(cfg, env)``; ``ManagerTermBase.__init__`` is not called.
    """

    def __init__(self, dataset_path: str):
        """Load every ``states`` entry of every demo into CPU tensors.

        Args:
            dataset_path: Path to the HDF5 dataset file.

        Raises:
            ValueError: If the dataset contains no ``data/<demo>/states`` entries.
        """
        self.dataset_path = dataset_path
        collected: dict[str, list[np.ndarray]] = {}
        # The context manager releases the file handle even if reading a group fails.
        with h5py.File(self.dataset_path, "r") as dataset:
            for demo in dataset["data"].keys():  # type: ignore
                demo_states_path = f"data/{demo}/states"
                if demo_states_path not in dataset:
                    continue
                states_group = dataset[demo_states_path]  # type: ignore
                for entry in states_group.keys():  # type: ignore
                    collected.setdefault(entry, []).append(states_group[entry][()])  # type: ignore
        if not collected:
            raise ValueError(f"No states found in dataset {self.dataset_path}")
        self.all_states = {k: torch.from_numpy(np.concatenate(v, axis=0)) for k, v in collected.items()}
        # Number of rows of the first entity; all entities are assumed to share this length.
        self.num_states = len(next(iter(self.all_states.values())))

        print(f"Loaded {self.num_states} states from dataset {self.dataset_path}")

    def __call__(
        self,
        env: ManagerBasedRLEnv,
        env_ids: torch.Tensor | None,
    ):
        """Write one randomly sampled dataset state into each of the given environments.

        Articulations receive the stored row as joint positions. Rigid objects receive the
        first seven stored values as their root pose, with the position shifted by the env
        origin. Velocities are not written.

        Args:
            env: Environment whose scene entities are reset. Every dataset key must name a
                scene entity (``env.scene[key]`` is looked up unconditionally).
            env_ids: Indices of environments to reset; ``None`` resets all environments.
        """
        if env_ids is None:
            env_ids = torch.arange(env.scene.num_envs, device=env.device)
        if len(env_ids) == 0:
            return
        # One index per env, shared by every entity so all entities come from the same row.
        state_idx = torch.randint(0, self.num_states, (len(env_ids),))
        for key, value in self.all_states.items():
            asset = env.scene[key]
            if isinstance(asset, Articulation):
                asset.write_joint_position_to_sim(value[state_idx].to(env.device), env_ids=env_ids)
            elif isinstance(asset, RigidObject):
                # Advanced indexing copies, so editing `pose` never mutates the cached dataset.
                pose = value[state_idx].to(env.device)[..., :7]
                # Add each env's origin so the stored (presumably env-local) position becomes world-frame.
                pose[..., :3] = pose[..., :3] + env.scene.env_origins[env_ids]
                asset.write_root_pose_to_sim(pose, env_ids=env_ids)  # for now ignore velocities

class reset_end_effector_round_fixed_asset(ManagerTermBase):
    """Reset the robot so its IK end-effector is driven toward a random pose around a fixed asset.

    On each call a position offset and an absolute roll/pitch/yaw orientation are sampled
    uniformly from ``pose_range_b``. The position offset is added, in world axes, to the fixed
    asset's root position. When ``fixed_asset_offset`` is configured, it is added to the
    position returned by ``fixed_asset_offset.apply`` instead. A differential-IK action term is
    then iterated to move the robot's joints toward that target, writing joint states directly
    to the simulation for the reset environments only.
    """

    def __init__(self, cfg: EventTermCfg, env: ManagerBasedEnv):
        """Resolve scene handles and build the differential-IK solver.

        Args:
            cfg: Event term configuration. ``cfg.params`` is read for:
                ``fixed_asset_cfg``, the entity to place the end-effector around;
                ``fixed_asset_offset``, an optional object whose ``apply(asset)`` returns a
                world ``(pos, quat)`` (may be ``None``);
                ``pose_range_b``, a ``(low, high)`` range per key in ``x, y, z, roll, pitch, yaw``,
                where missing keys default to ``(0.0, 0.0)``; and
                ``robot_ik_cfg``, the robot entity whose ``joint_names``/``body_names`` configure
                the IK (defaults to ``SceneEntityCfg("robot")``).
            env: The environment owning the scene.
        """
        super().__init__(cfg, env)
        fixed_asset_cfg: SceneEntityCfg = cfg.params.get("fixed_asset_cfg")  # type: ignore
        # NOTE(review): `Offset` is not imported in this file; local annotations are not evaluated at runtime.
        fixed_asset_offset: Offset = cfg.params.get("fixed_asset_offset")  # type: ignore
        pose_range_b: dict[str, tuple[float, float]] = cfg.params.get("pose_range_b")  # type: ignore
        robot_ik_cfg: SceneEntityCfg = cfg.params.get("robot_ik_cfg", SceneEntityCfg("robot"))

        range_list = [pose_range_b.get(key, (0.0, 0.0)) for key in ["x", "y", "z", "roll", "pitch", "yaw"]]
        # Shape (6, 2): column 0 = low bound, column 1 = high bound.
        self.ranges = torch.tensor(range_list, device=env.device)
        self.fixed_asset: Articulation | RigidObject = env.scene[fixed_asset_cfg.name]
        self.fixed_asset_offset = fixed_asset_offset
        self.robot: Articulation = env.scene[robot_ik_cfg.name]
        self.joint_ids: list[int] | slice = robot_ik_cfg.joint_ids
        self.n_joints: int = self.robot.num_joints if isinstance(self.joint_ids, slice) else len(self.joint_ids)
        robot_ik_solver_cfg = DifferentialInverseKinematicsActionCfg(
            asset_name=robot_ik_cfg.name,
            joint_names=robot_ik_cfg.joint_names,  # type: ignore
            body_name=robot_ik_cfg.body_names,  # type: ignore
            controller=DifferentialIKControllerCfg(command_type="pose", use_relative_mode=False, ik_method="dls"),
            scale=1.0,
        )
        # NOTE(review): `DifferentialInverseKinematicsAction` is not imported here; annotation only.
        self.solver: DifferentialInverseKinematicsAction = robot_ik_solver_cfg.class_type(robot_ik_solver_cfg, env)  # type: ignore
        self.reset_velocity = torch.zeros((env.num_envs, self.robot.data.joint_vel.shape[1]), device=env.device)
        self.reset_position = torch.zeros((env.num_envs, self.robot.data.joint_pos.shape[1]), device=env.device)

    def __call__(
        self,
        env: ManagerBasedEnv,
        env_ids: torch.Tensor,
        fixed_asset_cfg: SceneEntityCfg,
        fixed_asset_offset: None,
        pose_range_b: dict[str, tuple[float, float]],
        robot_ik_cfg: SceneEntityCfg,
    ) -> None:
        """Sample end-effector targets and iterate IK to move the reset envs' joints toward them.

        Args:
            env: The environment owning the scene.
            env_ids: Environments to reset; ``None`` resets all environments.
            fixed_asset_cfg: Same entity as configured at construction (the cached handle is used).
            fixed_asset_offset: When ``None``, the fixed asset's root position is used as the tip;
                otherwise the configured offset's ``apply`` result is used.
            pose_range_b: Unused here; the ranges are taken from construction time.
            robot_ik_cfg: Unused here; the IK setup is taken from construction time.
        """
        if env_ids is None:
            env_ids = torch.arange(env.num_envs, device=env.device)

        if fixed_asset_offset is None:
            fixed_tip_pos_w = self.fixed_asset.data.root_pos_w
        else:
            fixed_tip_pos_w, _ = self.fixed_asset_offset.apply(self.fixed_asset)

        samples = math_utils.sample_uniform(self.ranges[:, 0], self.ranges[:, 1], (env.num_envs, 6), device=env.device)
        target_pos_w = fixed_tip_pos_w + samples[:, 0:3]
        # NOTE(review): orientation is sampled as an absolute world rotation, not relative to the fixed asset.
        target_quat_w = math_utils.quat_from_euler_xyz(samples[:, 3], samples[:, 4], samples[:, 5])
        target_pos_b, target_quat_b = math_utils.subtract_frame_transforms(
            self.robot.data.root_link_pos_w, self.robot.data.root_link_quat_w, target_pos_w, target_quat_w
        )

        # Non-reset envs keep their current end-effector pose as the IK target, so the solver
        # does not push their joint targets toward a freshly sampled pose.
        pos_b, quat_b = self.solver._compute_frame_pose()
        pos_b, quat_b = pos_b.clone(), quat_b.clone()
        pos_b[env_ids] = target_pos_b[env_ids]
        quat_b[env_ids] = target_quat_b[env_ids]
        self.solver.process_actions(torch.cat([pos_b, quat_b], dim=1))

        zero_velocity = torch.zeros((len(env_ids), self.n_joints), device=env.device)
        # Each step moves 25% of the way toward the IK solution: residual error ~ 0.75 ** 10 ~= 0.056.
        for _ in range(10):
            self.solver.apply_actions()
            delta_joint_pos = 0.25 * (self.robot.data.joint_pos_target[env_ids] - self.robot.data.joint_pos[env_ids])
            self.robot.write_joint_state_to_sim(
                position=(delta_joint_pos + self.robot.data.joint_pos[env_ids])[:, self.joint_ids],
                velocity=zero_velocity,
                joint_ids=self.joint_ids,
                env_ids=env_ids,  # type: ignore
            )

    
class MultiResetManager(ManagerTermBase):
    """Reset environments from states sampled across several HDF5 state datasets.

    On each reset, every environment independently picks one dataset ("task") with probability
    given by the normalized ``probs`` weights. A uniformly random recorded state from that
    dataset is then written into the scene. Optionally, per-task success rates are tracked and
    logged to ``env.extras["log"]``.

    ``cfg.params`` keys read at construction:
        hdf5_paths: HDF5 dataset files (preferred).
        base_paths: Legacy alias for ``hdf5_paths``; used only when ``hdf5_paths`` is empty.
        probs: One non-negative sampling weight per dataset file.
        success: Optional expression string; enables success monitoring when set.
    """

    def __init__(self, cfg: EventTermCfg, env: ManagerBasedEnv):
        """Load all datasets and set up sampling state.

        Raises:
            ValueError: If no dataset paths are given, if their count differs from ``probs``,
                if any weight is negative or the weights do not sum to a positive value, or if
                a dataset contains no states.
            FileNotFoundError: If a dataset file does not exist.
        """
        super().__init__(cfg, env)

        # Get hdf5 file paths - can be direct paths or base paths
        hdf5_paths: list[str] = cfg.params.get("hdf5_paths", [])  # type: ignore
        base_paths: list[str] = cfg.params.get("base_paths", [])  # type: ignore
        probabilities: list[float] = cfg.params.get("probs", [])  # type: ignore

        # Support both hdf5_paths (direct) and base_paths (backward compatibility; must point to .hdf5 files)
        if hdf5_paths:
            dataset_files = hdf5_paths
        elif base_paths:
            dataset_files = base_paths
        else:
            raise ValueError("Either 'hdf5_paths' or 'base_paths' must be provided")

        if len(dataset_files) != len(probabilities):
            raise ValueError("Number of dataset files must match number of probabilities")
        # Negative or all-zero weights would yield NaN / invalid distributions in torch.multinomial.
        if any(p < 0 for p in probabilities) or sum(probabilities) <= 0:
            raise ValueError("Probabilities must be non-negative and sum to a positive value")

        self.all_states_list: list[dict[str, torch.Tensor]] = []
        num_states = []
        for dataset_file in dataset_files:
            all_states_torch = self._load_states(dataset_file)
            # Number of rows of the first entity; all entities are assumed to share this length.
            num_states_for_dataset = len(next(iter(all_states_torch.values())))
            num_states.append(num_states_for_dataset)
            self.all_states_list.append(all_states_torch)
            print(f"Loaded {num_states_for_dataset} states from dataset {dataset_file}")

        # Normalized probabilities and dataset lengths; kept on CPU, moved to device when used.
        self.probs = torch.tensor(probabilities, device="cpu") / sum(probabilities)
        self.num_states = torch.tensor(num_states, device="cpu")
        self.num_tasks = len(self.all_states_list)

        # Initialize success monitor if provided (best-effort: disabled if unavailable).
        self.success_monitor = None
        if cfg.params.get("success") is not None:
            try:
                from isaaclab.managers import SuccessMonitorCfg  # type: ignore
                success_monitor_cfg = SuccessMonitorCfg(
                    monitored_history_len=100, num_monitored_data=self.num_tasks, device=env.device
                )
                self.success_monitor = success_monitor_cfg.class_type(success_monitor_cfg)
            except (ImportError, AttributeError):
                print("Warning: SuccessMonitorCfg not available, success monitoring disabled")

        # Dataset ("task") index last assigned to each env; starts on CPU, moved to device on first reset.
        self.task_id = torch.randint(0, self.num_tasks, (env.num_envs,), device="cpu")

    @staticmethod
    def _load_states(dataset_file: str) -> dict[str, torch.Tensor]:
        """Load one HDF5 file's ``data/<demo>/states/<entity>`` arrays, concatenated per entity."""
        dataset_path = Path(dataset_file)
        if not dataset_path.exists():
            raise FileNotFoundError(f"Dataset file {dataset_file} not found.")

        collected: dict[str, list[np.ndarray]] = {}
        # The context manager releases the file handle even if reading a group fails.
        with h5py.File(dataset_path, "r") as dataset:
            for demo in dataset["data"].keys():  # type: ignore
                demo_path = f"data/{demo}/states"
                if demo_path not in dataset:
                    continue
                states_group = dataset[demo_path]  # type: ignore
                for entry in states_group.keys():  # type: ignore
                    collected.setdefault(entry, []).append(states_group[entry][()])  # type: ignore

        if not collected:
            raise ValueError(f"No states found in dataset {dataset_file}")
        return {k: torch.from_numpy(np.concatenate(v, axis=0)) for k, v in collected.items()}

    def __call__(
        self,
        env: ManagerBasedEnv,
        env_ids: torch.Tensor,
        hdf5_paths: list[str] | None = None,
        base_paths: list[str] | None = None,
        probs: list[float] | None = None,
        success: str | None = None,
    ) -> None:
        """Log success for the finishing episodes, then reset ``env_ids`` from sampled dataset states.

        Args:
            env: The environment owning the scene.
            env_ids: Environments to reset; ``None`` resets all environments.
            hdf5_paths, base_paths, probs: Unused here; consumed at construction time.
            success: Optional Python expression evaluated with this method's locals in scope
                (e.g. ``env``). Its result is indexed by ``env_ids`` and treated as a boolean mask.
        """
        if env_ids is None:
            env_ids = torch.arange(env.num_envs, device=env.device)

        # Log current data if success monitor is available
        if success is not None and self.success_monitor is not None:
            # SECURITY NOTE: `success` is executed as arbitrary Python; only use trusted config strings.
            success_mask = torch.where(eval(success)[env_ids], 1.0, 0.0)
            self.success_monitor.success_update(self.task_id[env_ids], success_mask)

            success_rates = self.success_monitor.get_success_rate()
            if "log" not in env.extras:
                env.extras["log"] = {}
            for task_idx in range(self.num_tasks):
                env.extras["log"].update({
                    f"Metrics/task_{task_idx}_success_rate": success_rates[task_idx].item(),
                    f"Metrics/task_{task_idx}_prob": self.probs[task_idx].item(),
                    f"Metrics/task_{task_idx}_normalized_prob": self.probs[task_idx].item(),
                })

        # Nothing to reset; also avoids sampling zero elements from torch.multinomial.
        if len(env_ids) == 0:
            return

        # Sample which dataset to use for each environment being reset.
        dataset_indices = torch.multinomial(self.probs.to(env.device), len(env_ids), replacement=True)
        self.task_id = self.task_id.to(env.device)
        self.task_id[env_ids] = dataset_indices

        for dataset_idx, all_states in enumerate(self.all_states_list):
            mask = dataset_indices == dataset_idx
            if not mask.any():
                continue

            current_env_ids = env_ids[mask]
            # CPU indices, matching the CPU-resident dataset tensors.
            state_indices = torch.randint(0, int(self.num_states[dataset_idx].item()), (len(current_env_ids),))

            for key, value in all_states.items():
                # Skip dataset entries that have no matching entity in the scene.
                try:
                    scene_obj = env.scene[key]
                except KeyError:
                    continue

                # Advanced indexing copies, so in-place edits below never mutate the cached dataset.
                selected_states = value[state_indices].to(env.device)

                if isinstance(scene_obj, Articulation):
                    scene_obj.write_joint_position_to_sim(selected_states, env_ids=current_env_ids)  # type: ignore
                elif isinstance(scene_obj, RigidObject):
                    pose = selected_states[..., :7]
                    # Add each env's origin so the stored (presumably env-local) position becomes world-frame.
                    pose[..., :3] = pose[..., :3] + env.scene.env_origins[current_env_ids]
                    scene_obj.write_root_pose_to_sim(pose, env_ids=current_env_ids)  # type: ignore

        # Reset velocities: zero the actual simulated joint velocity (not only the PD velocity
        # target) so motion from the previous episode does not carry over into the new state.
        robot: Articulation = env.scene["robot"]
        zero_joint_vel = torch.zeros_like(robot.data.joint_vel[env_ids])
        robot.write_joint_velocity_to_sim(zero_joint_vel, env_ids=env_ids)  # type: ignore
        robot.set_joint_velocity_target(zero_joint_vel, env_ids=env_ids)  # type: ignore

def launch_view_port(
    env: ManagerBasedEnv,
    env_ids: torch.Tensor,
    view_portname: str,
    camera_path: str,
    viewport_size: tuple[int, int] = (640, 360),
    position: tuple[int, int] = (0, 0),
):
    """Open a GUI viewport bound to a camera prim when a GUI is running and the viewport is new.

    Args:
        env: Environment whose simulation context is queried for GUI availability.
        env_ids: Unused; presumably accepted to match the event-term call signature.
        view_portname: Name of the viewport to create; nothing happens if it already exists.
        camera_path: Prim path of the camera the viewport renders from.
        viewport_size: ``(width, height)`` of the new viewport.
        position: ``(x, y)`` screen position of the new viewport.
    """
    # Headless runs have no viewport system to attach to.
    if not env.sim.has_gui():
        return

    from isaacsim.core.utils.viewports import create_viewport_for_camera, get_viewport_names

    # Avoid creating a duplicate viewport on repeated invocations.
    if view_portname in get_viewport_names():
        return

    width, height = viewport_size[0], viewport_size[1]
    pos_x, pos_y = position[0], position[1]
    create_viewport_for_camera(
        viewport_name=view_portname,
        camera_prim_path=camera_path,
        width=width,
        height=height,
        position_x=pos_x,
        position_y=pos_y,
    )