import gradio as gr
from PIL import Image, ImageDraw
from collections import defaultdict
import numpy as np
import torch

import data,geometry
import pybullet_data
import pybullet as p
import fpsample

import argparse
# Command-line configuration. The parsed namespace is also handed to
# models.PolicyModel(args), so the model may read any of these flags.
parser = argparse.ArgumentParser(description="A simple example")
# pyb params
parser.add_argument("--gui",'-g', action="store_true") # NOTE(review): not read anywhere in this file
parser.add_argument("--test_views",action="store_true") # use the [-15, 15] camera yaw pair instead of [-30, 30]
parser.add_argument("--cube",action="store_true") # use cube instead of sphere as follower shape (NOTE(review): not read in this file)
parser.add_argument("--no_follower",'-nf', action="store_true") # no follower arm (NOTE(review): not read in this file)
# observation noise: generate_image() reads both attributes on every render,
# so they must exist on the namespace even when left at their default (off).
parser.add_argument('--depth_noise', default=False, action='store_true',help="add distance-proportional gaussian noise to the point clouds")
parser.add_argument('--cam_noise', default=False, action='store_true',help="apply a random rotation/translation to the point clouds (camera pose noise)")
# model params
parser.add_argument('--cam_pointmap_inp', default=False, action='store_true',help="use camera space pointmap")
parser.add_argument('--rob_pointmap_inp', default=False, action='store_true',help="use robot space pointmap")
parser.add_argument('--cam_pointcloud_inp', default=False, action='store_true',help="use camera space pointcloud")
parser.add_argument('--rob_pointcloud_inp', default=False, action='store_true',help="use robot space pointcloud")
parser.add_argument('--rob_rerender_inp', default=False, action='store_true',help="use robot space re-render image")
parser.add_argument('--endeffector_action', default=False, action='store_true',help="use endeffector prediction instead of joint angle prediction")
parser.add_argument('-c','--init_ckpt', type=str,default=None,required=False,help="File for checkpoint loading. If folder specific, will use latest .pt file")
args = parser.parse_args()

#args.cam_pointcloud_inp="cam" in args.init_ckpt
#args.rob_pointcloud_inp="rob_pc" in args.init_ckpt
#args.rob_rerender_inp="rerender" in args.init_ckpt
#args.endeffector_action="effect" in args.init_ckpt

import os
import sys

# Resolve the checkpoint before any heavy setup so a bad --init_ckpt fails
# fast with a readable message instead of an obscure torch.load(None) error.
ckpt_file = args.init_ckpt
if ckpt_file is None:
    parser.error("--init_ckpt is required: pass a checkpoint file or a folder containing .pt files")
if os.path.isdir(ckpt_file):
    # The --init_ckpt help text promises that a folder resolves to its latest
    # .pt file; "latest" is taken to mean most recently modified.
    pt_files = [os.path.join(ckpt_file, f) for f in os.listdir(ckpt_file) if f.endswith(".pt")]
    if not pt_files:
        parser.error("no .pt checkpoint found in folder %s" % ckpt_file)
    ckpt_file = max(pt_files, key=os.path.getmtime)

# tmpfns lives outside the project tree, so its folder is added to sys.path.
sys.path.append("/home/cameronsmith/misc/")
physicsClient = p.connect(p.DIRECT);p.setGravity(0, 0, -9.8)
p.setAdditionalSearchPath(pybullet_data.getDataPath())
p.setAdditionalSearchPath("/home/cameronsmith/misc")
# NOTE(review): the star import presumably supplies names this file uses
# without defining (plt, Tc, width, height, proj_matrix_, hom, ...) — confirm.
from tmpfns import *

import vis
import models

# Inference only: gradients are disabled globally.
torch.set_grad_enabled(False)
model = (models.PolicyModel)(args).cuda()
model.eval()
# strict=False tolerates key mismatches between checkpoint and model; report
# them so a wrong checkpoint/flag combination does not go unnoticed.
load_result = model.load_state_dict(torch.load(ckpt_file)["model_state_dict"],strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print("checkpoint key mismatch; missing: %s, unexpected: %s" % (load_result.missing_keys, load_result.unexpected_keys))

# Step 1: PyBullet scene setup
# The physics server is already connected (DIRECT) near the top of the file.
# PyBullet calls made without a physicsClientId go to client 0, i.e. that
# first connection, so connecting again here would only create unused servers.
p.setGravity(0, 0, -9.8)
p.resetDebugVisualizerCamera( cameraDistance=1.00, cameraYaw=242.74, cameraPitch= -12.06, cameraTargetPosition=[0.47, -0.31, 0.02])
# plane.urdf is resolved through the pybullet_data search path.
p.setAdditionalSearchPath(pybullet_data.getDataPath())
# Both names are used below: `robot` by reset_state, `robot_id` elsewhere.
robot_id = robot = p.loadURDF("/home/cameronsmith/misc/so_100_arm/urdf/so_100_arm.urdf", useFixedBase=True,flags=p.URDF_USE_SELF_COLLISION,basePosition=[0,0,0])
# red ball to follow (visual-only: no collision shape, zero mass).
# Kept under its own name so the imported `vis` module is not overwritten.
ball_visual_id = p.createVisualShape(p.GEOM_SPHERE,   radius=0.03, rgbaColor=[1,0,0,1])
ball_id = p.createMultiBody( baseMass=0, baseCollisionShapeIndex=-1, baseVisualShapeIndex=ball_visual_id, basePosition=[.3, -.3, 0.1])
plane_id = p.loadURDF("plane.urdf")

# Make cameras
# Two viewpoints that differ only in yaw; --test_views selects the narrower pair.
cameras = []
yaw_set = [-15, 15] if args.test_views else [-30, 30]
for yaw in yaw_set:
    # Uses the file's own `p` alias for pybullet rather than a `pb` name that
    # this file never defines itself.
    view_matrix_ = p.computeViewMatrixFromYawPitchRoll(
        cameraTargetPosition=[0.0, 0, 0], distance=.8, yaw=yaw, pitch=-37, roll=0, upAxisIndex=2)
    # PyBullet returns the 4x4 matrix as a flat column-major list, hence order='F'.
    # NOTE(review): generate_image applies the *inverse* of this matrix to
    # camera-space points, which suggests it maps world -> camera despite the
    # name — confirm against the definition of Tc.
    cam2world_cv = np.linalg.inv(Tc) @ np.array(view_matrix_).reshape(4, 4, order='F')
    cameras.append((view_matrix_, cam2world_cv))
view_matrix_, cam2world_cv = cameras[1] # todo rand choice

# Ball position at startup.
org_ballpos = np.array(p.getBasePositionAndOrientation(ball_id)[0])

# Reset ball and joint states
def reset_state(params):
    """Zero the arm joints, move the ball to a random spot and return a fresh render.

    Args:
        params: Slider values from the UI; only forwarded to generate_image.

    Returns:
        The image produced by ``generate_image(params, just_img=True)``.
    """
    # Without this declaration the assignment below would bind a function
    # local and the module-level org_ballpos would never be refreshed.
    global org_ballpos

    for joint_index in range(1, 7):
        p.resetJointState(robot, joint_index, 0)  # Reset state

    # Randomize ball position
    target_pos = [np.random.uniform(-.2, .2), np.random.uniform(-.2, -.1), np.random.uniform(.01, .2)]
    p.resetBasePositionAndOrientation(ball_id, target_pos, (0.0, 0.0, 0.0, 1.0))
    # NOTE(review): motor targets set by an earlier generate_image call are
    # not cleared here, so this step may already pull the arm toward them.
    p.stepSimulation()

    org_ballpos = np.array(p.getBasePositionAndOrientation(ball_id)[0])
    return generate_image(params, just_img=True)
# Move wrist to ball position
#joint_poses = p.calculateInverseKinematics( robot_id, 6, p.getBasePositionAndOrientation(ball_id)[0])
#for joint_i,joint_state in enumerate(joint_poses): p.setJointMotorControl2(robot_id, joint_i+1, p.POSITION_CONTROL, joint_state, 0)
# Simulate time forward and save intermediate states
#for t in range(8):
#    rgb = np.reshape(pb.getCameraImage(width=width, height=height, viewMatrix=view_matrix, projectionMatrix=proj_matrix)[2], (height, width, 4))[:, :, :3].astype(np.float32) / 255.0
#    for _ in range(3): p.stepSimulation()

def _sample_indices(points, candidate_idx, n_samples):
    """Choose ``n_samples`` of ``candidate_idx`` by farthest-point sampling over ``points``.

    If there are fewer candidates than requested, the list is padded with
    index 0 (deficit plus a margin of 10, as the robot-frame code always did)
    so the sampler has at least ``n_samples`` points to pick from.
    """
    deficit = n_samples - candidate_idx.shape[0]
    if deficit > 0:
        candidate_idx = np.concatenate([candidate_idx, np.zeros(10 + deficit)]).astype(int)
    chosen = fpsample.bucket_fps_kdtree_sampling(points[candidate_idx], n_samples)
    return candidate_idx[chosen]


def _render_point_cloud(points, colors, idx, lims, out_path):
    """Scatter-plot ``points[idx]`` coloured by ``colors`` from a fixed viewpoint and save it to ``out_path``."""
    print("doing render")
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.view_init(elev=15, azim=-45)
    ax.set_xlim(*lims[0]); ax.set_ylim(*lims[1]); ax.set_zlim(*lims[2])
    ax.scatter(*points[idx].T, c=np.clip(colors.reshape(-1, 3)[idx], -1, 1),
               marker='o', alpha=.99, label='Point Cloud', s=3.5)
    plt.tight_layout()
    for plot_axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        plot_axis.set_ticklabels([])
        plot_axis.set_pane_color((1, 1, 1))
    ax.set_facecolor('none')
    plt.savefig(out_path, dpi=100)
    # Close this figure explicitly so repeated UI clicks do not pile up figures.
    plt.close(fig)
    print("done render")


def _apply_action(action_pred):
    """Send one predicted action to the arm's position controllers."""
    if args.endeffector_action:
        # Layout used here: xyz position, orientation, then gripper angle last.
        target_pos = action_pred[:3].tolist()
        target_orientation = action_pred[3:-1].tolist()
        pred_grip_angle = float(action_pred[-1])
        targ_joint_poses = p.calculateInverseKinematics(robot_id, 5, target_pos, target_orientation)
        for j, q in enumerate(targ_joint_poses[:-1]):
            p.setJointMotorControl2(bodyIndex=robot_id, jointIndex=j + 1, controlMode=p.POSITION_CONTROL, targetPosition=q)
        # The gripper joint is set directly rather than motor-driven.
        p.resetJointState(robot_id, 6, pred_grip_angle)
    else:
        for joint_i, joint_state in enumerate(action_pred):
            p.setJointMotorControl2(robot_id, joint_i + 1, p.POSITION_CONTROL, joint_state, 0)


def generate_image(params,just_img=False):
    """Render the camera view and, unless ``just_img`` is set, run one policy step.

    The full path builds camera-frame and robot-frame point clouds from the
    depth image, writes a re-rendered point-cloud image to ``pc_rgb.png``,
    feeds the sample to the model, applies the predicted action and advances
    the simulation by four steps.

    Args:
        params: Slider values from the UI. Currently unused: the code that
            offset the ball by these values is disabled.
        just_img: If true, return the camera image right away without running
            the model or stepping the simulation.

    Returns:
        A PIL image of the camera view. With ``just_img`` it is built from the
        4-channel camera buffer; otherwise from its first three channels, as
        captured before the action was applied.
    """
    #ball_pos = org_ballpos+np.array(params)
    #p.resetBasePositionAndOrientation(ball_id, ball_pos, (0.0, 0.0, 0.0, 1.0));p.stepSimulation() # Randomize ball position

    _, _, rgba, depth_ndc, _ = p.getCameraImage(width, height, viewMatrix=view_matrix_, projectionMatrix=proj_matrix_)
    # Same (height, width, 4) layout for both return paths.
    rgba = np.asarray(rgba).reshape(height, width, 4)
    if just_img:
        return Image.fromarray(rgba.astype(np.uint8))

    rgb = rgba[..., :3].astype(np.float32) / 255
    # Convert the non-linear depth buffer to metric depth, then back-project.
    points_cam = depth_to_point_cloud((far * near) / (far - (far - near) * depth_ndc), fx, fy, cx, cy)
    points_rob_base = np.einsum('ij,nj->ni', np.linalg.inv(cam2world_cv), hom(points_cam))[:, :3]

    n_points = 10000#128**2

    # Robot-frame cloud: keep points within 1.2 of the origin and inside the box.
    lims_rob_base = [[-.4, .4], [-.4, .1], [.03, .3]]
    within_range = np.linalg.norm(points_rob_base, axis=-1) < 1.2
    within_box = within_range.copy()
    for dim, (lower, upper) in enumerate(lims_rob_base):
        within_box &= (points_rob_base[:, dim] > lower) & (points_rob_base[:, dim] < upper)
    mask_rob_base = _sample_indices(points_rob_base, np.flatnonzero(within_box), n_points)

    # Camera-frame cloud: filtered only by the robot-frame distance test. It
    # now gets the same padding as the robot-frame cloud, so a sparse view no
    # longer asks the sampler for more points than exist.
    mask_points_cam = np.flatnonzero(within_range)
    print(mask_points_cam.shape)
    mask_points_cam = _sample_indices(points_cam, mask_points_cam, n_points)

    # add random noise (getattr: the flags may be absent from the parsed args)
    if getattr(args, "depth_noise", False):
        points_cam += np.random.normal(scale=(0.02 * np.linalg.norm(points_cam, axis=1))[:, None], size=points_cam.shape)
        points_rob_base += np.random.normal(scale=(0.02 * np.linalg.norm(points_rob_base, axis=1))[:, None], size=points_rob_base.shape)
    if getattr(args, "cam_noise", False):
        # randomly shift/rotate points (camera noise); the same rigid
        # transform is applied to both clouds
        from scipy.spatial.transform import Rotation as R
        trans_std = 0.04
        rot_deg_std = 40.0
        rot_axis = np.random.randn(3)
        rot_axis /= np.linalg.norm(rot_axis) + 1e-8
        angle = np.deg2rad(np.random.normal(scale=rot_deg_std))
        R_noise = R.from_rotvec(rot_axis * angle).as_matrix()
        t_noise = np.random.normal(scale=trans_std, size=3)
        points_rob_base = (R_noise @ points_rob_base.T).T + t_noise
        points_cam = (R_noise @ points_cam.T).T + t_noise

    # Only the rgb-coloured re-render is produced; the pointmap-coloured
    # variant ("pc_pointmap.png") is disabled.
    _render_point_cloud(points_rob_base, rgb, mask_rob_base, lims_rob_base, "pc_rgb.png")

    flat_rgb = rgb.reshape(-1, 3)
    data_sample = {
        "points_cam": np.concatenate((points_cam, flat_rgb), axis=-1)[mask_points_cam],
        "points_rob_base": np.concatenate((points_rob_base, flat_rgb), axis=-1)[mask_rob_base],
        "pointmap_cam": points_cam.reshape(height, width, 3),
        "pointmap_rob": points_rob_base.reshape(height, width, 3),
        "rgb": rgb,
        "rerender_rgb": plt.imread("pc_rgb.png"), #"rerender_pointmap":plt.imread("pc_pointmap.png"),
    }

    use_ik_ground_truth = False  # developer toggle: drive the arm with IK toward the ball instead of the model
    if use_ik_ground_truth:
        joint_poses = p.calculateInverseKinematics(robot_id, 6, p.getBasePositionAndOrientation(ball_id)[0])
        for joint_i, joint_state in enumerate(joint_poses):
            p.setJointMotorControl2(robot_id, joint_i + 1, p.POSITION_CONTROL, joint_state, 0)
    else:
        data_sample = data.MultiModalFolder.format_sample(None, data_sample)
        with torch.no_grad():
            model_out = model({k: v[None].cuda() for k, v in data_sample.items()})
        action_pred = model_out["action_state"][0].cpu().numpy()
        print(action_pred)
        _apply_action(action_pred)

    for _ in range(4):
        p.stepSimulation()

    return Image.fromarray((rgb * 255).astype(np.uint8))

# Stylesheet override: lift the fixed-height limit from the gallery grid.
custom_css = """
#model-output-gallery .grid-wrap.fixed-height { max-height: none !important; height: auto !important; }
"""


def _on_render_click(*slider_values):
    """Button handler: pass the slider values on to generate_image as a list."""
    return generate_image(list(slider_values))


def _on_reset_click(*slider_values):
    """Button handler: pass the slider values on to reset_state as a list."""
    return reset_state(list(slider_values))


with gr.Blocks(css=custom_css, fill_height=True) as demo:
    with gr.Row(equal_height=True):
        # Single column: image on top, then the buttons, then the sliders.
        with gr.Column(scale=1, min_width=300):
            rendered_img = gr.Image(label="Rendered Image")

            # Action buttons sit between the image and the sliders.
            with gr.Row():
                render_btn = gr.Button("Render")
                reset_btn = gr.Button("Reset Joints")

            # One slider per axis, each ranging over [-0.2, 0.2] and starting at 0.
            gr.Markdown("**— Ball offset —**")
            sliders_grp1 = [
                gr.Slider(-.2, .2, 0, label=axis_label)
                for axis_label in ("X", "Y", "Z")
            ]
            all_sliders = sliders_grp1

    # Both buttons read every slider and write their result to the image.
    render_btn.click(fn=_on_render_click, inputs=all_sliders, outputs=rendered_img)
    reset_btn.click(fn=_on_reset_click, inputs=all_sliders, outputs=rendered_img)

# Serve the UI on port 8080.
demo.launch(server_port=8080)

