
# Static point-track estimation: depth, intrinsics and camera poses fitted to point tracks.
def forward(self, model_input, out=None):
    """Estimate depth, intrinsics and per-frame poses, and score them against point tracks.

    Pipeline (all tensors are indexed as ``(b, n_trgt, ...)``, taken from the
    first two dims of ``model_input["rgb"]``):

    1. ``self.get_flow(model_input)`` is run first (optical flow / point-track
       networks, per the original comment; its outputs are not visible here).
    2. A residual depth is predicted from RGB + ``zoe_depth`` features and
       added to the input depth.
    3. Unless ``self.args.use_gt_intrinsics`` is set, a single focal length per
       batch element is predicted and written to ``model_input["intrinsics"]``.
    4. Per-point correspondence weights are predicted from features sampled at
       the track locations.
    5. Poses are obtained from ``geometry.procrustes`` between each frame's
       back-projected track points and those of frame 0.
    6. Frame-0 points are reprojected into every frame with the estimated
       poses and compared against the tracks.

    Args:
        model_input: dict read for ``"rgb"``, ``"zoe_depth"``,
            ``"pred_tracks"``, ``"pred_visibility"``, ``"bwd_flow"``,
            ``"org_ratio"`` (only when intrinsics are estimated) and
            ``"intrinsics"`` (must already be present when ground-truth
            intrinsics are used). Some of these may be filled in by
            ``self.get_flow`` -- TODO confirm which.
        out: optional dict of existing outputs to merge into the result.
            It is not modified; ``None`` means "no previous outputs".

    Returns:
        A new dict: ``out`` merged with ``"res_depth"``, ``"depth"``,
        ``"point_track_loss"``, ``"zoe_depth"``, ``"zoe_d_loss"``,
        ``"corr_weights"``, ``"poses"``, ``"flow_inp_"`` and ``"intrinsics"``.

    Side effects:
        Increments ``self.step``; writes ``model_input["fmap"]`` and, when
        intrinsics are estimated, ``model_input["intrinsics"]``.
    """
    # Avoid a shared mutable default; the merge below builds a new dict anyway.
    out = {} if out is None else out
    self.step += 1

    rgb = model_input["rgb"]
    imsize = rgb.shape[-2:]
    b, n_trgt = rgb.shape[:2]

    # Run optical flow and point track networks
    self.get_flow(model_input)
    # NOTE: disabled alternative kept for reference.
    #corresp_uv = (model_input["x_pix"][:,:-1]+ch_sec(model_input["bwd_flow"]))

    # Est depth: features from RGB concatenated with the input depth on dim 2.
    # NOTE(review): ch_fst / ch_sec are project helpers; presumably they move
    # the channel axis to / from the front of the per-frame dims -- confirm.
    depth_inp = ch_fst(model_input["zoe_depth"], imsize[0])
    rgbd = torch.cat((rgb, depth_inp), 2)
    # The encoder input is rescaled with *0.5+0.5 (this applies to the depth
    # channel as well as RGB); features are resized back to the image size.
    fmap = F.interpolate(
        self.resnet_enc(rgbd.flatten(0, 1) * .5 + .5), imsize, mode="bilinear"
    ).unflatten(0, (b, n_trgt))
    model_input["fmap"] = fmap
    # softplus is strictly positive, so the refined depth is always larger
    # than the input depth.
    res_depth = F.softplus(
        self.depth_conv(fmap.flatten(0, 1)).unflatten(0, (b, n_trgt)) + 1
    )
    depth = res_depth + depth_inp

    # Est intrinsics
    if not self.args.use_gt_intrinsics:
        # One focal value per batch element: mean over frames and all
        # remaining dims of the sigmoid-activated prediction -> shape (b,).
        focal = (
            self.focal_conv(fmap.flatten(0, 1))
            .unflatten(0, (b, n_trgt))
            .sigmoid()
            .flatten(1, -1)
            .mean(dim=-1)
        )
        # Keep a trailing singleton so the value broadcasts over the frame
        # axis of the (b, n_trgt) slices below. A bare (b,) tensor would be
        # right-aligned against n_trgt, which only works for b == 1.
        focal = focal.unsqueeze(-1)

        # org_ratio's layout is not visible here; a 1-D tensor of length b is
        # treated as one ratio per batch element -- TODO confirm with caller.
        ratio = model_input["org_ratio"]
        if torch.is_tensor(ratio) and ratio.dim() == 1 and ratio.shape[0] == b:
            ratio = ratio.unsqueeze(-1)

        intrinsics = torch.eye(3)[None].float().to(depth).repeat(b, n_trgt, 1, 1)
        intrinsics[..., 0, 2] = intrinsics[..., 1, 2] = .5  # principal point
        intrinsics[..., 0, 0] = focal * ratio
        intrinsics[..., 1, 1] = focal
        model_input["intrinsics"] = intrinsics

    # NOTE: disabled MiDaS-based depth alternative kept for reference.
    #midas_feats = self.midas((model_input["rgb"]*.5+.5).flatten(0,1))
    #model_input["fmap"]=F.interpolate(midas_feats,imsize,mode="bilinear").unflatten(0,(b,n_trgt))/20
    #depth = 1e3/(self.midas_out(midas_feats).unflatten(0,(b,n_trgt))+1e-1)

    # Get corresp weights per track corresp: each point's frame-0 feature is
    # paired with its feature in every frame, then mapped to (1e-4, 1).
    tracks = model_input["pred_tracks"]
    track_feat = rearrange(
        grid_samp(fmap, tracks.unsqueeze(-2)), "b t c p 1 -> b t p c"
    )
    first_feat = track_feat[:, [0]].expand(-1, n_trgt, -1, -1)
    corr_weights = (
        self.corr_weighter_perpoint(torch.cat((first_feat, track_feat), -1))
        .sigmoid()
        .clip(min=1e-4)
    )

    # Est pose: scale the second output of get_world_rays (presumably ray
    # directions; pose argument is None) by the depth sampled at each track
    # point, then align every frame's points with those of frame 0.
    rds = geometry.get_world_rays(tracks, model_input["intrinsics"], None)[1]
    eye_surf = rds * grid_samp(depth, tracks.unsqueeze(-2)).squeeze(2)
    first_surf = eye_surf[:, [0]].expand(-1, n_trgt, -1, -1)
    poses = geometry.procrustes(eye_surf, first_surf, corr_weights)[1]

    # Reproject point tracks using est poses (move frame 0 to all other frames)
    point_track_surf_reproj = torch.einsum(
        'btij,btpj->btpi', poses.inverse(), hom(first_surf)
    )[..., :3]
    point_track_reproj = project(
        point_track_surf_reproj, model_input["intrinsics"]
    ).clip(0, 1)
    # Squared reprojection error, with the residual multiplied by the
    # visibility flag of each track point.
    point_track_loss = (
        (point_track_reproj - tracks)
        * model_input["pred_visibility"].unsqueeze(-1)
    ).square().mean()

    # NOTE: disabled canonical-point computation kept for reference.
    #pts_canon = torch.einsum("brtij,brtxj->brtxi",poses,hom(eye_surf.unflatten(0,(b,self.n_rig))))[...,:3]

    return out | {
        "res_depth": ch_sec(res_depth),
        "depth": ch_sec(depth),
        "point_track_loss": point_track_loss * 1e2,
        "zoe_depth": model_input["zoe_depth"],
        # Penalises the difference between inverse input depth and inverse
        # refined depth (1e-5 guards the division).
        "zoe_d_loss": (
            1 / (1e-5 + model_input["zoe_depth"]) - 1 / (1e-5 + ch_sec(depth))
        ).square().mean() * 1e1,
        # NOTE(review): 42 is passed where ch_fst received an image height
        # above -- confirm this constant matches the track layout.
        "corr_weights": ch_fst(corr_weights, 42),
        "poses": poses,
        "flow_inp_": model_input["bwd_flow"],
        "intrinsics": model_input["intrinsics"],
    }
