import numpy as np

################# Franka Panda Forward Kinematics ##############################



def get_fk_solution(joint_angles, num_frames=8):
    """Compute the Franka Panda forward-kinematics transform from the base frame.

    Each DH row ``[a, d, alpha, theta]`` is turned into the homogeneous
    transform Rx(alpha) @ Tx(a) @ Rz(theta) @ Tz(d) (the modified / Craig DH
    convention), and the first ``num_frames`` rows are chained together.

    Args:
        joint_angles: sequence of at least 7 joint angles in radians; only the
            first 7 entries are used.
        num_frames: number of DH rows to chain (0-10). The default of 8 matches
            the original behaviour and stops after the ``d=0.107`` row.
            ``num_frames=10`` additionally applies the fixed -pi/4 rotation
            about z and the 0.1034 offset along z. These are presumably the
            hand mounting rotation and TCP offset; confirm against the robot
            setup.

    Returns:
        np.ndarray of shape (4, 4): the homogeneous transform of the selected
        frame, expressed in the base frame.

    Raises:
        ValueError: if ``num_frames`` is outside ``[0, 10]``.
    """
    def dh_transform(a, d, alpha, theta):
        # Single modified-DH link transform.
        ct, st = np.cos(theta), np.sin(theta)
        ca, sa = np.cos(alpha), np.sin(alpha)
        return np.array([[ct, -st, 0, a],
                         [st * ca, ct * ca, -sa, -sa * d],
                         [st * sa, ct * sa, ca, ca * d],
                         [0, 0, 0, 1]])

    # Rows: [a, d, alpha, theta]. Rows 0-6 are the 7 joints. The remaining
    # rows are fixed offsets (theta constant).
    dh_params = [[0, 0.333, 0, joint_angles[0]],
                 [0, 0, -np.pi/2, joint_angles[1]],
                 [0, 0.316, np.pi/2, joint_angles[2]],
                 [0.0825, 0, np.pi/2, joint_angles[3]],
                 [-0.0825, 0.384, -np.pi/2, joint_angles[4]],
                 [0, 0, np.pi/2, joint_angles[5]],
                 [0.088, 0, np.pi/2, joint_angles[6]],
                 [0, 0.107, 0, 0],
                 [0, 0, 0, -np.pi/4],
                 [0.0, 0.1034, 0, 0]]

    if not 0 <= num_frames <= len(dh_params):
        raise ValueError(
            f"num_frames must be in [0, {len(dh_params)}], got {num_frames}")

    T = np.eye(4)
    for a, d, alpha, theta in dh_params[:num_frames]:
        T = T @ dh_transform(a, d, alpha, theta)
    return T
  


####################### For keyboard control task ##############################

# Hand-recorded 5-waypoint trajectories for task '1799'. A plain keyboard
# move is too coarse to grasp the marker there. Columns follow the pose
# layout used by key_board_control (x, y, z, three further pose components,
# gripper). The gripper column is replaced by the current gripper value at
# call time.
# Segment from the "high" anchor down to the "mid" anchor.
_TASK_1799_HIGH_TO_MID = np.array([
    [0.6571552157402039, -0.23330219089984894, 0.14893116056919098, 3.024440050125122, -0.10950495302677155, -0.5373510122299194, 0.0],
    [0.6629241108894348, -0.23121032118797302, 0.1249917671084404, 3.033341884613037, -0.07552512735128403, -0.5513433814048767, 0.0],
    [0.6683378219604492, -0.23391945660114288, 0.10137390345335007, 3.064735174179077, -0.05264944210648537, -0.5547294616699219, 0.0],
    [0.6717594265937805, -0.24730165302753448, 0.08032714575529099, 3.097148895263672, -0.04734605923295021, -0.550068199634552, 0.0],
    [0.6743124723434448, -0.264811635017395, 0.06130369007587433, 3.118835210800171, -0.05084453523159027, -0.5484786629676819, 0.0],
])
# Segment from the "mid" anchor down to the "low" anchor.
_TASK_1799_MID_TO_LOW = np.array([
    [0.6743124723434448, -0.264811635017395, 0.06130369007587433, 3.118835210800171, -0.05084453523159027, -0.5484786629676819, 0.0],
    [0.6779722571372986, -0.27722054719924927, 0.04469338804483414, 3.1374242305755615, -0.05760820955038071, -0.5545875430107117, 0.0],
    [0.6827353239059448, -0.27674779295921326, 0.030309824272990227, -3.1302576065063477, -0.058622077107429504, -0.5516106486320496, 0.0],
    [0.6880980134010315, -0.2685723602771759, 0.019759373739361763, -3.1161246299743652, -0.052770618349313736, -0.5417410731315613, 0.0],
    [0.691747784614563, -0.262683629989624, 0.012840778566896915, -3.112227201461792, -0.04243526607751846, -0.5334827303886414, 0.0],
])
# Freeze the shared tables (and, below, every view of them) so no caller can
# corrupt them. Results are always built from a fresh copy.
_TASK_1799_HIGH_TO_MID.setflags(write=False)
_TASK_1799_MID_TO_LOW.setflags(write=False)

# Rounded xyz positions at which the scripted segments are triggered.
_TASK_1799_ANCHOR_HIGH = np.array([0.657, -0.233, 0.148])
_TASK_1799_ANCHOR_MID = np.array([0.674, -0.264, 0.061])
_TASK_1799_ANCHOR_LOW = np.array([0.691, -0.262, 0.0128])
_TASK_1799_ANCHOR_TOL = 0.01

# (anchor, key, trajectory): moving up replays a downward segment in reverse.
# The anchor neighbourhoods do not overlap, so at most one rule can match.
_TASK_1799_SCRIPTED_MOVES = (
    (_TASK_1799_ANCHOR_HIGH, 'd', _TASK_1799_HIGH_TO_MID),
    (_TASK_1799_ANCHOR_MID, 'u', _TASK_1799_HIGH_TO_MID[::-1]),
    (_TASK_1799_ANCHOR_MID, 'd', _TASK_1799_MID_TO_LOW),
    (_TASK_1799_ANCHOR_LOW, 'u', _TASK_1799_MID_TO_LOW[::-1]),
)

# key -> (pose column, sign) for the Cartesian translation keys.
_KEY_TRANSLATIONS = {
    'l': (1, -1.0),  # left: -y
    'r': (1, 1.0),   # right: +y
    'f': (0, 1.0),   # forward: +x
    'b': (0, -1.0),  # backward: -x
    'u': (2, 1.0),   # up: +z
    'd': (2, -1.0),  # down: -z
}
# key -> absolute gripper target value (column 6).
_KEY_GRIPPER_TARGETS = {'o': 0.0, 'c': 0.70}

# Workspace box applied to the x, y, z columns of every returned chunk.
_XYZ_LOWER = np.array([0.3, -0.5, 0.01])
_XYZ_UPPER = np.array([0.8, 0.5, 0.5])


def _near_anchor(xyz, anchor, tol=_TASK_1799_ANCHOR_TOL):
    """Return True when every coordinate of ``xyz`` is strictly within ``tol`` of ``anchor``."""
    return bool(np.all(np.abs(xyz - anchor) < tol))


def _task_1799_scripted_chunk(pose1, action):
    """Return a fresh copy of the scripted trajectory matching ``pose1``/``action``, or None."""
    xyz = pose1[0, :3]
    for anchor, key, trajectory in _TASK_1799_SCRIPTED_MOVES:
        if action == key and _near_anchor(xyz, anchor):
            chunk = trajectory.copy()  # never hand out / mutate the shared table
            chunk[:, -1] = pose1[0, -1]  # keep the gripper where it currently is
            return chunk
    return None


def _clip_to_workspace(action_chunk):
    """Clip the x, y, z columns of ``action_chunk`` in place to the workspace box and return it."""
    action_chunk[:, :3] = np.clip(action_chunk[:, :3], _XYZ_LOWER, _XYZ_UPPER)
    return action_chunk


def key_board_control(pose1, action, task_id='1799', distance=0.08):
    """Turn one keyboard command into a short chunk of target poses.

    Args:
        pose1: current pose, shape (1, 7); a flat length-7 sequence is also
            accepted. Columns 0-2 are x, y, z and column 6 is the gripper
            value. Columns 3-5 are passed through unchanged by the translation
            keys (presumably orientation; confirm against the caller).
        action: one of
            'l'/'r' (-y/+y), 'f'/'b' (+x/-x), 'u'/'d' (+z/-z),
            'o' (gripper -> 0.0), 'c' (gripper -> 0.70).
            Any other key prints a warning and yields a chunk that holds the
            current pose.
        task_id: when '1799', pressing 'u'/'d' within 0.01 (per axis) of one
            of three hand-picked positions replays a recorded 5-waypoint
            trajectory instead of a straight vertical move.
        distance: translation per key press along the selected axis.

    Returns:
        np.ndarray of shape (5, 7): the current pose followed by four evenly
        spaced steps toward the target (or the scripted waypoints), with x, y,
        z clipped to [0.3, 0.8] x [-0.5, 0.5] x [0.01, 0.5].
    """
    pose1 = np.atleast_2d(np.asarray(pose1, dtype=float))

    if task_id == '1799':
        # Special handling near the marker; otherwise grasping it with plain
        # keyboard moves is too hard.
        scripted = _task_1799_scripted_chunk(pose1, action)
        if scripted is not None:
            return _clip_to_workspace(scripted)

    delta = np.zeros((1, 7))
    if action in _KEY_TRANSLATIONS:
        axis, sign = _KEY_TRANSLATIONS[action]
        delta[0, axis] = sign * distance
    elif action in _KEY_GRIPPER_TARGETS:
        # The gripper is driven to an absolute target, not moved by an increment.
        delta[0, 6] = _KEY_GRIPPER_TARGETS[action] - pose1[0, 6]
    else:
        print("wrong action key, please use l,r,f,b,u,d,o,c")

    # Row 0 is the current pose; rows 1-4 are 1/4, 2/4, 3/4 and 4/4 of the way.
    steps = [pose1 + delta * k / 4.0 for k in range(1, 5)]
    action_chunk = np.concatenate([pose1, *steps], axis=0)  # (5, 7)
    return _clip_to_workspace(action_chunk)