import json
import torch
import numpy as np
from isaaclab.utils.math import quat_from_euler_xyz, quat_mul, quat_inv, matrix_from_quat, quat_from_matrix, euler_xyz_from_quat, convert_camera_frame_orientation_convention

def world_pose_to_opengl_pose(world_pose):
    """Split homogeneous pose matrices into positions and OpenGL-convention quaternions.

    The rotation block is converted with ``quat_from_matrix``. It is then
    right-multiplied by the quaternion for roll=-pi, pitch=0, yaw=0, built
    with ``quat_from_euler_xyz``. The translation column is returned
    unchanged.

    Args:
        world_pose: Pose matrix of shape ``(4, 4)`` or ``(..., 4, 4)``, as a
            torch tensor or anything ``torch.as_tensor`` accepts. A single
            ``(4, 4)`` pose is promoted to a batch of one, ``(1, 4, 4)``.

    Returns:
        Tuple ``(positions, orientations, convention)``:
            positions: ``(..., 3)`` translation column of each pose.
            orientations: ``(..., 4)`` quaternions after the roll correction.
            convention: always the string ``"opengl"``.
    """
    world_pose = torch.as_tensor(world_pose)
    if world_pose.ndim == 2:
        world_pose = world_pose[None]
    rotation = world_pose[..., :3, :3]
    translation = world_pose[..., :3, 3]
    quat = quat_from_matrix(rotation)
    # Build the correction on the same dtype/device as ``quat``; mixing a CPU
    # float32 constant with CUDA inputs would fail inside ``quat_mul``.
    tensor_opts = {"dtype": quat.dtype, "device": quat.device}
    roll_correction = quat_from_euler_xyz(
        torch.tensor([-np.pi], **tensor_opts),
        torch.tensor([0.0], **tensor_opts),
        torch.tensor([0.0], **tensor_opts),
    )
    # ``quat_mul`` needs identical shapes; ``expand_as`` broadcasts the
    # (1, 4) correction over any number of leading batch dims.
    quat = quat_mul(quat, roll_correction.expand_as(quat))
    return translation, quat, "opengl"

def overwrite_camera_pose(env_cfg, camera_pose, camera_name="external_cam1"):
    """Set a camera's offset in ``env_cfg`` from a pose matrix, using the OpenGL convention.

    The top-left 3x3 rotation is converted to a quaternion. It is then
    right-multiplied by the quaternion for roll=-pi, pitch=0, yaw=0, built
    with ``quat_from_euler_xyz``. The translation column is stored as-is.

    Args:
        env_cfg: Environment config. Its ``scene.<camera_name>.offset`` gets
            ``pos``, ``rot`` and ``convention`` overwritten in place.
        camera_pose: Homogeneous pose matrix (at least 3x4). It may be a
            numpy array, a nested list, or a CPU torch tensor; it is
            converted to float64 numpy first.
        camera_name: Attribute name of the camera config on ``env_cfg.scene``.
    """
    pose = np.asarray(camera_pose, dtype=np.float64)
    quat = quat_from_matrix(torch.from_numpy(pose[:3, :3]))
    roll_correction = quat_from_euler_xyz(
        torch.tensor([-np.pi], dtype=quat.dtype),
        torch.tensor([0.0], dtype=quat.dtype),
        torch.tensor([0.0], dtype=quat.dtype),
    )
    # Both operands are (1, 4) so ``quat_mul`` sees matching shapes.
    quat = quat_mul(quat[None], roll_correction)

    cam_cfg = getattr(env_cfg.scene, camera_name)
    # ``.tolist()`` yields plain Python floats, so numpy/torch scalar types
    # don't end up inside the config tuples.
    cam_cfg.offset.pos = tuple(pose[:3, 3].tolist())
    cam_cfg.offset.rot = tuple(quat.reshape(-1).tolist())
    cam_cfg.offset.convention = "opengl"

def get_camera_pose(pos, quat, convention="opengl"):
    """Rebuild a 4x4 pose matrix from a position and an OpenGL-convention quaternion.

    This undoes the roll correction by right-multiplying ``quat`` by the
    inverse of the roll=-pi (pitch=yaw=0) quaternion. The result is converted
    to a rotation matrix with ``matrix_from_quat``.

    Args:
        pos: 3-element position. May be a tuple, list, numpy array or CPU
            tensor.
        quat: 4-element quaternion. May be a tuple (e.g. a config
            ``offset.rot``), list, numpy array or CPU tensor.
        convention: Must be ``"opengl"``.

    Returns:
        ``np.ndarray`` of shape (4, 4) and dtype float64. The rotation is in
        the top-left 3x3 block and ``pos`` in the last column.

    Raises:
        ValueError: If ``convention`` is not ``"opengl"`` or ``pos`` does not
            have exactly 3 elements.
    """
    # An explicit raise, since ``assert`` is stripped under ``python -O``.
    if convention != "opengl":
        raise ValueError("Only opengl convention is supported for now")

    # ``np.asarray`` accepts tuples/lists/arrays/CPU tensors alike, unlike
    # ``torch.from_numpy``, which only takes ndarrays.
    quat_t = torch.as_tensor(np.asarray(quat, dtype=np.float64)).reshape(1, 4)
    roll_correction = quat_from_euler_xyz(
        torch.tensor([-np.pi], dtype=torch.float64),
        torch.tensor([0.0], dtype=torch.float64),
        torch.tensor([0.0], dtype=torch.float64),
    )
    # Undo the roll correction applied when the OpenGL pose was created.
    quat_t = quat_mul(quat_t, quat_inv(roll_correction))

    rotation = matrix_from_quat(quat_t.reshape(-1)).numpy()

    camera_pose = np.eye(4)
    camera_pose[:3, :3] = rotation
    # ``reshape(3)`` rejects malformed positions instead of silently
    # broadcasting a scalar into all three slots.
    camera_pose[:3, 3] = np.asarray(pos, dtype=np.float64).reshape(3)
    return camera_pose

def overwrite_joint_positions(env_cfg, joint_positions):
    """Write arm joint positions into the robot's initial-state config.

    Args:
        env_cfg: Environment config. Entries of
            ``env_cfg.scene.robot.init_state.joint_pos`` are set in place by
            joint name.
        joint_positions: Sequence of positions. Every entry except the last
            is written, in order, to ``panda_joint1``, ``panda_joint2``, ...
            The last entry is not written yet (see TODO). Expected shape is
            (8,): 7 arm joints plus 1 finger joint.
    """
    joint_pos_cfg = env_cfg.scene.robot.init_state.joint_pos
    for joint_num, joint_pos in enumerate(joint_positions[:-1], start=1):
        # ``float()`` stores a plain Python float instead of a numpy/torch
        # scalar when ``joint_positions`` is an array or tensor.
        joint_pos_cfg[f"panda_joint{joint_num}"] = float(joint_pos)
    # TODO: overwrite finger joint position