import gradio as gr
from PIL import Image, ImageDraw
from collections import defaultdict
import numpy as np
import torch

import data,geometry
import pybullet_data
import pybullet as p
import fpsample

import argparse
import glob
import os

parser = argparse.ArgumentParser(description="Closed-loop evaluation of a policy checkpoint in a PyBullet SO-100 reaching scene")
# pybullet / scene params
parser.add_argument("--gui", '-g', action="store_true")  # NOTE(review): not read by this script; the sim always connects with p.DIRECT
parser.add_argument("--test_views", action="store_true")  # evaluate from held-out camera yaws (viewpoint generalization)
parser.add_argument("--cube", action="store_true")  # use cube instead of sphere as follower shape
parser.add_argument("--no_follower", '-nf', action="store_true")  # no follower arm; NOTE(review): not read by this script
# model params
parser.add_argument('--depth_noise', default=False, action='store_true', help="add depth noise")
parser.add_argument('--cam_noise', default=False, action='store_true', help="add cam noise")
parser.add_argument('--cam_pointmap_inp', default=False, action='store_true', help="use camera space pointmap")
parser.add_argument('--rob_pointmap_inp', default=False, action='store_true', help="use robot space pointmap")
parser.add_argument('--cam_pointcloud_inp', default=False, action='store_true', help="use camera space pointcloud")
parser.add_argument('--rob_pointcloud_inp', default=False, action='store_true', help="use robot space pointcloud")
parser.add_argument('--rob_rerender_inp', default=False, action='store_true', help="use robot space re-render image")
parser.add_argument('--endeffector_action', default=False, action='store_true', help="use endeffector prediction instead of joint angle prediction")
# Required: the checkpoint is loaded unconditionally and its path also names the output folder.
parser.add_argument('-c', '--init_ckpt', type=str, required=True,
                    help="File for checkpoint loading. If a folder is given, the most recently modified .pt file inside it is used")
args = parser.parse_args()

# Resolve a checkpoint folder to its newest .pt file, as promised by the help text.
if os.path.isdir(args.init_ckpt):
    ckpt_candidates = glob.glob(os.path.join(args.init_ckpt, "*.pt"))
    if not ckpt_candidates:
        parser.error("no .pt checkpoint found in %s" % args.init_ckpt)
    args.init_ckpt = max(ckpt_candidates, key=os.path.getmtime)

#args.cam_pointcloud_inp="cam" in args.init_ckpt
#args.rob_pointcloud_inp="rob_pc" in args.init_ckpt
#args.rob_rerender_inp="rerender" in args.init_ckpt
#args.endeffector_action="effect" in args.init_ckpt

import sys,os
# Project helpers live outside this repo. tmpfns is star-imported; names such as
# width/height/proj_matrix_/near/far/fx/fy/cx/cy, Tc, hom, depth_to_point_cloud and
# plt are not defined in this file and presumably come from it. TODO make explicit.
sys.path.append("/home/cameronsmith/misc/")
# First (and only needed) physics client: every p.* call defaults to physicsClientId=0.
physicsClient = p.connect(p.DIRECT); p.setGravity(0, 0, -9.8)
p.setAdditionalSearchPath(pybullet_data.getDataPath())
p.setAdditionalSearchPath("/home/cameronsmith/misc")
from tmpfns import *

import vis
import models

torch.set_grad_enabled(False)  # inference only
model = models.PolicyModel(args).cuda()
model.eval()
ckpt_file = args.init_ckpt
# strict=False tolerates key mismatches; report them so a checkpoint that does not
# match the model configuration (input-mode flags) is not silently half-loaded.
load_result = model.load_state_dict(torch.load(ckpt_file)["model_state_dict"], strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print("checkpoint key mismatch: missing=%s unexpected=%s"
          % (load_result.missing_keys, load_result.unexpected_keys))

# Step 1: PyBullet scene setup.
# Every p.* call below uses the default physicsClientId=0, i.e. the DIRECT client
# created during module setup; opening additional clients here would only leak
# simulations that nothing uses.
p.setAdditionalSearchPath(pybullet_data.getDataPath()); p.setGravity(0, 0, -9.8)
# Only has a visible effect with a GUI connection (this script connects DIRECT).
p.resetDebugVisualizerCamera(cameraDistance=1.00, cameraYaw=242.74, cameraPitch=-12.06, cameraTargetPosition=[0.47, -0.31, 0.02])
robot_id = robot = p.loadURDF("/home/cameronsmith/misc/so_100_arm/urdf/so_100_arm.urdf", useFixedBase=True, flags=p.URDF_USE_SELF_COLLISION, basePosition=[0, 0, 0])

# Red target to follow: sphere by default, cube with --cube. It is visual-only
# (no collision shape). Named so it does not shadow the imported `vis` module.
if args.cube:
    ball_visual_shape = p.createVisualShape(p.GEOM_BOX, halfExtents=[0.03] * 3, rgbaColor=[1, 0, 0, 1])
else:
    ball_visual_shape = p.createVisualShape(p.GEOM_SPHERE, radius=0.03, rgbaColor=[1, 0, 0, 1])
ball_id = p.createMultiBody(baseMass=0, baseCollisionShapeIndex=-1, baseVisualShapeIndex=ball_visual_shape, basePosition=[.3, -.3, 0.1])
plane_id = p.loadURDF("plane.urdf")

# Make cameras: fixed viewpoints around the robot base. Training yaws by default;
# --test_views swaps in held-out yaws.
yaw_set = [-30, 30] if not args.test_views else [-15, 15]
cameras = []
for yaw in yaw_set:
    # target at origin, distance .8, pitch -37, roll 0, z-up (upAxisIndex=2).
    # Uses the file's own `p` alias; this call is stateless so no client is involved.
    view_matrix_ = p.computeViewMatrixFromYawPitchRoll([0.0, 0, 0], .8, yaw, -37, 0, 2)
    # The view matrix is a flat 16-float list read column-major (order='F').
    # Tc presumably comes from the tmpfns star import. TODO confirm.
    cam2world_cv = np.linalg.inv(Tc) @ np.array(view_matrix_).reshape(4, 4, order='F')
    cameras.append((view_matrix_, cam2world_cv))
view_matrix_, cam2world_cv = cameras[1]  # TODO: random choice of camera

org_ballpos = np.array(p.getBasePositionAndOrientation(ball_id)[0])

# Output folder: evals/<checkpoint's parent folder name>, plus one suffix per
# generalization / noise setting that is enabled. Each noise suffix matches the
# flag that enables it.
savedir = "evals/%s" % args.init_ckpt.split("/")[-2]
if args.test_views: savedir += "_viewpoint_gen"
if args.cube: savedir += "_obj_gen"
if args.depth_noise: savedir += "_depth_noise"
if args.cam_noise: savedir += "_cam_noise"
def _fps_indices(points, candidate_idx, n_samples):
    """Return `n_samples` indices into `points` chosen by farthest-point sampling.

    Sampling is restricted to `candidate_idx`. If there are fewer candidates than
    `n_samples`, the list is padded with index 0 (plus 10 spare entries) so that
    fpsample always has at least `n_samples` points to draw from.
    """
    if n_samples - candidate_idx.shape[0] > 0:
        pad = np.zeros(10 + n_samples - candidate_idx.shape[0])
        candidate_idx = np.concatenate([candidate_idx, pad]).astype(int)
    return candidate_idx[fpsample.bucket_fps_kdtree_sampling(points[candidate_idx], n_samples)]


# The camera is fixed for the whole run, so invert once. Applying this inverse to
# camera-space points yields the robot-base-frame points used below.
# NOTE(review): the variable name cam2world_cv suggests the opposite direction. Confirm naming.
points_cam_to_rob = np.linalg.inv(cam2world_cv)
lims_rob_base = [[-.4, .4], [-.4, .1], [.03, .3]]  # robot-frame crop box: x, y, z limits
n_points = 10000  # points kept per cloud after FPS

for scene_i in range(10):  # number of evaluations/trials
    savedir_ep = savedir + "/%03d" % scene_i
    os.makedirs(savedir_ep, exist_ok=True)
    # Seed per trial so every checkpoint / flag combination sees the same target
    # positions and noise draws, which makes the evaluations comparable.
    np.random.seed(scene_i)

    # Reset the arm. resetJointState does not clear motor targets, so the
    # POSITION_CONTROL targets set during the previous trial would keep driving the
    # arm during the stepSimulation below. Re-target the position-controlled joints
    # to 0 as well. In end-effector mode the gripper (joint 6) is set via
    # resetJointState, never position-controlled, so it is left alone.
    for j in range(6):
        p.resetJointState(robot, j + 1, 0)
    for j in range(5 if args.endeffector_action else 6):
        p.setJointMotorControl2(robot_id, j + 1, p.POSITION_CONTROL, targetPosition=0)

    # Randomize the target position.
    target_pos = [np.random.uniform(-.2, .2), np.random.uniform(-.2, -.1), np.random.uniform(.01, .2)]
    p.resetBasePositionAndOrientation(ball_id, target_pos, (0.0, 0.0, 0.0, 1.0))
    p.stepSimulation()
    org_ballpos = np.array(p.getBasePositionAndOrientation(ball_id)[0])

    for time_i in range(10):  # number of timesteps to unroll
        # Render the image and derive point clouds from the depth buffer.
        _, _, rgb, depth_ndc, seg = p.getCameraImage(width, height, viewMatrix=view_matrix_, projectionMatrix=proj_matrix_)
        rgb = rgb.reshape(height, width, 4)[..., :3].astype(np.float32) / 255
        # NOTE(review): pixel noise is gated on --depth_noise. Confirm this is intended.
        if args.depth_noise:
            rgb = np.clip(rgb + np.random.normal(0, 10 / 255, rgb.shape), 0, 1)
        plt.imsave(savedir_ep + "/%03d.png" % time_i, rgb)
        # Linearize the depth buffer, then back-project to camera space.
        points_cam = depth_to_point_cloud((far * near) / (far - (far - near) * depth_ndc), fx, fy, cx, cy)
        points_rob_base = np.einsum('ij,nj->ni', points_cam_to_rob, hom(points_cam))[:, :3]

        # Robot-frame cloud: pixels within 1.2 of the base AND inside the crop box.
        dist_rob = np.linalg.norm(points_rob_base, axis=-1)
        in_crop = dist_rob < 1.2
        for dim, (lo, hi) in enumerate(lims_rob_base):
            in_crop &= (points_rob_base[:, dim] > lo) & (points_rob_base[:, dim] < hi)
        mask_rob_base = _fps_indices(points_rob_base, np.flatnonzero(in_crop), n_points)

        # Camera-space cloud: pixels within 1.2 of the robot base (no box crop).
        # Padded the same way as above so FPS cannot be asked for more points than exist.
        mask_points_cam = _fps_indices(points_cam, np.flatnonzero(dist_rob < 1.2), n_points)

        if args.depth_noise:
            # Range-proportional Gaussian noise (2% of each point's distance).
            points_cam += np.random.normal(scale=(0.02 * np.linalg.norm(points_cam, axis=1))[:, None], size=points_cam.shape)
            points_rob_base += np.random.normal(scale=(0.02 * np.linalg.norm(points_rob_base, axis=1))[:, None], size=points_rob_base.shape)
        if args.cam_noise:
            # Randomly shift/rotate points (camera noise).
            from scipy.spatial.transform import Rotation as R
            trans_std = 0.04
            rot_deg_std = 40.0
            axis = np.random.randn(3)
            axis /= np.linalg.norm(axis) + 1e-8
            angle = np.deg2rad(np.random.normal(scale=rot_deg_std))
            R_noise = R.from_rotvec(axis * angle).as_matrix()
            t_noise = np.random.normal(scale=trans_std, size=3)
            points_rob_base = (R_noise @ points_rob_base.T).T + t_noise
            points_cam = (R_noise @ points_cam.T).T + t_noise

        # Re-render the sampled robot-frame cloud from a fixed viewpoint. The PNG is
        # read back below as the "rerender_rgb" model input.
        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')
        ax.view_init(elev=15, azim=-45)
        ax.set_xlim(*lims_rob_base[0]); ax.set_ylim(*lims_rob_base[1]); ax.set_zlim(*lims_rob_base[2])
        # matplotlib RGB colours must lie in [0, 1].
        ax.scatter(*points_rob_base[mask_rob_base].T, c=np.clip(rgb.reshape(-1, 3)[mask_rob_base], 0, 1),
                   marker='o', alpha=.99, label='Point Cloud', s=3.5)
        plt.tight_layout()
        for ax_axis in (ax.xaxis, ax.yaxis, ax.zaxis):
            ax_axis.set_ticklabels([])
            ax_axis.set_pane_color((1, 1, 1))
        ax.set_facecolor('none')
        plt.savefig("pc_rgb.png", dpi=100)
        plt.close()

        data_sample = {"points_cam": np.concatenate((points_cam, rgb.reshape(-1, 3)), axis=-1)[mask_points_cam],
                       "points_rob_base": np.concatenate((points_rob_base, rgb.reshape(-1, 3)), axis=-1)[mask_rob_base],
                       "pointmap_cam": points_cam.reshape(height, width, 3),
                       "pointmap_rob": points_rob_base.reshape(height, width, 3),
                       "rgb": rgb,
                       "rerender_rgb": plt.imread("pc_rgb.png"),
                       }
        data_sample = data.MultiModalFolder.format_sample(None, data_sample)

        with torch.no_grad():
            model_out = model({k: v[None].cuda() for k, v in data_sample.items()})
        action_pred = model_out["action_state"][0].cpu().numpy()
        print(action_pred)
        if args.endeffector_action:
            # Assumes the layout [xyz (3), orientation (4), gripper (1)]. TODO confirm.
            ee_pos, ee_orn, pred_grip_angle = action_pred[:3].tolist(), action_pred[3:-1].tolist(), action_pred[-1]
            targ_joint_poses = p.calculateInverseKinematics(robot_id, 5, ee_pos, ee_orn)
            for j, q in enumerate(targ_joint_poses[:-1]):
                p.setJointMotorControl2(bodyIndex=robot_id, jointIndex=j + 1, controlMode=p.POSITION_CONTROL, targetPosition=q)
            p.resetJointState(robot, 6, pred_grip_angle)
        else:
            for joint_i, joint_state in enumerate(action_pred):
                p.setJointMotorControl2(robot_id, joint_i + 1, p.POSITION_CONTROL, joint_state, 0)

        for _ in range(6):
            p.stepSimulation()
