import gradio as gr
from PIL import Image, ImageDraw
from collections import defaultdict
import numpy as np
import torch

import data,geometry
import pybullet_data
import pybullet as p
# Headless PyBullet session (DIRECT mode, no GUI) that holds the robot used for rendering.
physicsClient = p.connect(p.DIRECT)
p.setGravity(0, 0, -9.8)

# Make the assets bundled with pybullet_data (including the Franka Panda URDF) resolvable.
p.setAdditionalSearchPath(pybullet_data.getDataPath())
rob_id = p.loadURDF("franka_panda/panda.urdf")

# Optional scene props, currently disabled:
# table_id = p.loadURDF("table/table.urdf", basePosition=[.0, -0.0, -.6], baseOrientation=[0, 0, 0.7071, 0.7071])
# plane_id = p.loadURDF("plane.urdf", basePosition=[.0, -0.0, -.6])

# Render resolution; an alternative of 1200 x 1200 was noted here.
# NOTE(review): not referenced by the code visible in this file — confirm it is still used.
width = 256
height = width
def to_gpu(ob, device="cuda"):
    """Recursively convert ``ob`` to tensors on ``device``.

    Dict values are recursed into *before* any tensor conversion, so nested
    dicts are supported. Leaves are converted with ``torch.as_tensor``.
    The previous version called ``torch.tensor(v)`` first, which raised on
    nested dicts and copied (with a warning) any value that was already a
    tensor.

    Args:
        ob: a ``torch.Tensor``, anything ``torch.as_tensor`` accepts (numpy
            array, list, scalar), or a dict whose values are any of these,
            nested to any depth.
        device: target device. Defaults to ``"cuda"``, matching the original
            ``.cuda()`` behaviour.

    Returns:
        The same structure (dicts preserved) with every leaf as a tensor on
        ``device``.
    """
    if isinstance(ob, dict):
        return {k: to_gpu(v, device) for k, v in ob.items()}
    return torch.as_tensor(ob).to(device)

import vis
import our_models as models
# Disable autograd for this inference-only demo.
# NOTE: grad mode is thread-local in PyTorch, so this only covers the importing
# thread. Gradio typically runs sync callbacks on worker threads, so the model
# call itself should also sit under torch.no_grad().
torch.set_grad_enabled(False)
model = models.CanonRobotTrajPredTransfEuclidean().cuda()
model.eval()
ckpt_file = "/tmp/newlinksetup_sanity/checkpoint.pt"
# Load the tensors onto the CPU first so a checkpoint saved on a different GPU
# index still loads. load_state_dict then copies the weights into the model's
# CUDA parameters.
# NOTE(review): torch.load unpickles the file — only point this at trusted checkpoints.
_ckpt = torch.load(ckpt_file, map_location="cpu")
_load_result = model.load_state_dict(_ckpt["model_state_dict"], strict=False)
# strict=False tolerates key mismatches silently. Report them so a stale or
# mismatched checkpoint does not quietly leave layers at their random init.
if _load_result.missing_keys or _load_result.unexpected_keys:
    print(
        f"[checkpoint] {ckpt_file}: "
        f"missing keys={_load_result.missing_keys}, "
        f"unexpected keys={_load_result.unexpected_keys}"
    )
del _ckpt

def generate_image(params):
    """Render the robot for the camera/joint configuration in ``params``.

    Args:
        params: flat list of slider values laid out as
            ``[eye_x, eye_y, eye_z, target_x, target_y, target_z, *joint_values]``.

    Returns:
        A ``PIL.Image`` built from the first image of the rendered batch.
    """
    joint_states = params[6:]
    # PyBullet returns the view matrix as 16 floats in column-major order,
    # hence order='F'. The camera up vector is +Z.
    view_matrix_ = np.array(
        p.computeViewMatrix(params[:3], params[3:6], [0, 0, 1])
    ).reshape(4, 4, order="F")
    render = data.format_pybullet_render(joint_states, view_matrix_, rob_id=rob_id)
    # assumes render["img"] is a (batch, C, H, W) float tensor in [0, 1] — TODO confirm.
    # detach/cpu are no-ops for a CPU tensor without grad, but make .numpy() safe otherwise.
    img = render["img"][0].detach().cpu().permute(1, 2, 0).numpy()
    # Clip before the uint8 cast. Out-of-range floats (e.g. 1.01 * 255) would
    # otherwise wrap around or give platform-dependent values.
    return Image.fromarray((np.clip(img, 0.0, 1.0) * 255).astype(np.uint8))

def predict_image_grid(params):
    """Render the scene for ``params``, run the model, and return visualisations.

    Args:
        params: flat slider values ``[eye_xyz (3), target_xyz (3), *joint_values]``.

    Returns:
        A list of ``(image, caption)`` pairs, one per entry returned by
        ``vis.wandb_summary``, in the form expected by ``gr.Gallery``.
    """
    eye, target, joints = params[:3], params[3:6], params[6:]
    flat_view = p.computeViewMatrix(eye, target, [0, 0, 1])
    view_matrix_ = np.array(flat_view).reshape(4, 4, order='F')

    render = data.format_pybullet_render(joints, view_matrix_, rob_id=rob_id)
    sample = data.format_render_sample(render, use_rand_bg=False, use_color_aug=False)
    # Prepend a batch dimension of size 1 to every entry.
    data_sample = {name: tensor[None] for name, tensor in to_gpu(sample).items()}

    with torch.no_grad():
        out = model(data_sample)

    # TODO: use the model's own prediction instead of the ground-truth foreground
    # mask, so the demo does not depend on GT segmentation at all.
    # Only "points_link" is masked; "points_cam" is deliberately left unmasked.
    fg_mask = data_sample["fg"]
    for target_dict in (data_sample, out):
        target_dict["points_link"] = target_dict["points_link"] * fg_mask

    wandb_out = vis.wandb_summary(0, out, data_sample, data_sample, 0, dont_log=True)
    return [(panel.image, caption) for caption, panel in wandb_out.items()]

# CSS override for the gallery with elem_id="model-output-gallery": it clears
# the max-height/height cap applied through the `.grid-wrap.fixed-height`
# class, so the gallery grows with its content.
custom_css = """
#model-output-gallery .grid-wrap.fixed-height {
    max-height: none !important;
    height: auto !important;
}
"""
with gr.Blocks(css=custom_css, fill_height=True) as demo:
    with gr.Row(equal_height=True):
        # ─── Left panel: image, buttons, then grouped sliders ───
        with gr.Column(scale=1, min_width=300):
            rendered_img = gr.Image(label="Rendered Image")

            # Buttons above the sliders
            with gr.Row():
                render_btn = gr.Button("Render")
                predict_btn = gr.Button("Predict")

            # Initial camera pose from the sampler; `roll` is returned but unused here.
            cam_pos, look_at, roll = geometry.sample_camera_positions(1.5)

            # ── Group 1 sliders (params[0:3]): camera center ──
            gr.Markdown("**— Camera Center —**")
            sliders_grp1 = [
                gr.Slider(-3, 3, cam_pos[i], label=f"Camera Center {axis}")
                for i, axis in enumerate("XYZ")
            ]

            # ── Group 2 sliders (params[3:6]): camera look-at target ──
            gr.Markdown("**— Camera LookAt —**")
            sliders_grp2 = [
                gr.Slider(-3, 3, look_at[i], label=f"Camera Lookat {axis}")
                for i, axis in enumerate("XYZ")
            ]

            # ── Group 3 sliders (params[6:]): one per robot joint ──
            gr.Markdown("**— Joint States —**")
            num_joints = p.getNumJoints(rob_id)
            joint_states = [s[0] for s in p.getJointStates(rob_id, list(range(num_joints)))]
            joint_sliders = []
            for i in range(num_joints):
                info = p.getJointInfo(rob_id, i)
                lower, upper = info[8], info[9]  # joint lower / upper limits
                # Joints without a usable range (e.g. fixed joints) can report
                # upper < lower; give the slider a tiny valid range instead.
                if upper < lower:
                    upper = lower + .01
                # Pass step by keyword. gr.Slider's third positional parameter
                # is `value`, so the old positional 0.01 collided with value=...
                # and the joint's current position was not applied as the initial value.
                joint_sliders.append(
                    gr.Slider(lower, upper, value=joint_states[i], step=0.01, label=f"Joint {i+1}")
                )

            # Flatten all sliders for callback wiring. This order defines the
            # params layout that generate_image / predict_image_grid receive.
            all_sliders = sliders_grp1 + sliders_grp2 + joint_sliders

        # ─── Right panel: expanded gallery ───
        with gr.Column(scale=2):
            gallery = gr.Gallery(label="Model Outputs", columns=3, rows=None, elem_id="model-output-gallery")

    # Wire up the two buttons
    render_btn.click(fn=lambda *vals: generate_image(list(vals)), inputs=all_sliders, outputs=rendered_img)
    predict_btn.click(fn=lambda *vals: predict_image_grid(list(vals)), inputs=all_sliders, outputs=gallery)

# Serve the Gradio app on a fixed port. When run as a script, launch() blocks
# by default while the app is being served.
SERVER_PORT = 8080
demo.launch(server_port=SERVER_PORT)

