
            # either extract median extrinsic or just most confident one (using weighted mean here) 
            # just trying max idx
            #est_extrinsics = est_links[ torch.tensor([x.sum() for x in segs]).max(dim=0)[1].item() ]
            #weights = np.array([x.sum() for x in segs])
            #weights = weights/weights.max()
            #try: est_extrinsics=geometry.average_se3_split(est_links,weights)
            #except: print("bad se3 est");  est_extrinsics=np.eye(4)
            #print((np.linalg.inv(model_input["cam2world_cv"].cpu().numpy())-est_extrinsics).round(2)) # should be 0 for GT
            # try procrustes on all points and then with ransac procrustes
            pos_emb=torch.stack(torch.meshgrid(torch.arange(model_input["img"].size(-2)),
                    torch.arange(model_input["img"].size(-1))))[None].expand(len(model_input["img"]),-1,-1,-1).to(model_input["img"])*2/512-1
            hires_inp = torch.cat((model_input["img"].squeeze(1),pos_emb),1)
            "joint_sdfs": self.joint_sdf_pred(full_res_feats).unflatten(1,(12,3))
