
        # Main loss signal -- reproject all points to all other points.
        # TODO: this may need to become all frames to all other frames -- right now it only goes from frame 0 back to all other frames.
        # Wait, do we? Check whether it even makes a difference; I think the current version might be sufficient.
        # I think so, because e.g. if the first-frame tracks are bad there is no good transport.

        #out["track_idxs"] = torch.randperm(model_input["pred_tracks"].size(-2))[:int(8000//(n_trgt**2/20**2))] # choose random smaller subset to combat quadratic complexity
        #pose_all_to_all_perpix = repeat(pose_tracks[:,out["track_idxs"]],"b p t x y -> b p s t x y",s=n_trgt).inverse()@repeat(pose_tracks[:,out["track_idxs"]],"b p t x y -> b p t s x y",s=n_trgt)
        #pose_tracks = poses.unsqueeze(1) # b p t x y
        #pose_all_to_all_pertrack = repeat(pose_tracks,"b p t x y -> b p s t x y",s=n_trgt).inverse()@repeat(pose_tracks,"b p t x y -> b p t s x y",s=n_trgt)

        #track_surf_reprojs = torch.einsum("bksnij,bksj->bksni",pose_all_to_all_pertrack,hom(eye_surf_tracks))[...,:3]
        #track_surf_reprojs = torch.einsum("bksnij,bksj->bksni",pose_all_to_all_perpix,hom(eye_tracks[:,out["track_idxs"]]))[...,:3]

        # first non-vectorized
        #point_track_loss=0
        #reprojs=[]
        #tmps=[],[]
        #for i in range(n_trgt):
        #    tmps[0].append(poses.unsqueeze(2).inverse()@poses[:,[i]].unsqueeze(2))
        #    tmps[1].append(eye_surf_track[:,[i]].expand(-1,n_trgt,-1,-1))
        #    point_track_surf_reproj = torch.einsum('btpij,btpj->btpi',poses.unsqueeze(2).inverse()@poses[:,[i]].unsqueeze(2),hom(eye_surf_track[:,[i]].expand(-1,n_trgt,-1,-1)))[...,:3]
        #    point_track_reproj = project(point_track_surf_reproj,model_input["intrinsics"]).clip(0,1)
        #    reprojs.append(point_track_reproj)
        #    point_track_loss += ( (point_track_reproj - model_input["pred_tracks"]) * model_input["pred_visibility"].unsqueeze(-1) ).square().mean()
        #reprojs=torch.stack(reprojs)
        #point_track_loss/=n_trgt

        #torch.stack([eye_surf_track[:,[i]].expand(-1,n_trgt,-1,-1) for i in range(n_trgt)],1)

        #point_track_surf_reproj = torch.einsum('btpij,btpj->btpi',poses.unsqueeze(2).inverse(),hom(eye_surf_track[:,[0]].expand(-1,n_trgt,-1,-1)))[...,:3]
        #point_track_reproj = project(point_track_surf_reproj,model_input["intrinsics"]).clip(0,1)
        #point_track_loss = ( (point_track_reproj - model_input["pred_tracks"]) * model_input["pred_visibility"].unsqueeze(-1) ).square().mean()

