"""Code for pixelnerf and alternatives."""
import torch, torchvision
from torch import nn
import kornia
from einops import rearrange, repeat
import torchvision.ops as ops
from torch.nn import functional as F
import numpy as np
import sys,random,time,os
from copy import deepcopy 
from matplotlib import cm
import wandb
from tqdm import tqdm
from typing import Callable, List, Optional, Tuple, Generator, Dict
from collections import defaultdict

sys.path.append("./third_party/co-tracker/")
from cotracker.utils.visualizer import Visualizer, read_video_from_path

import geometry

# Small tensor helpers: channel layout, homogeneous coordinates, sampling, projection
def ch_sec(x):
    """Flatten the two trailing axes and move channels last: `... c x y -> ... (x y) c`."""
    return rearrange(x, "... c x y -> ... (x y) c")


def ch_fst(src, x=None):
    """Undo the `ch_sec` layout: `... (x y) c -> ... c x y`.

    `x` is the length of the first unflattened axis. When it is omitted it is
    taken as int(sqrt(src.size(-2))), i.e. the flattened axis is treated as a
    square grid.
    """
    first_axis = int(src.size(-2) ** (.5)) if x is None else x
    return rearrange(src, "... (x y) c -> ... c x y", x=first_axis)


def hom(x):
    """Append a constant 1 along the last axis (homogeneous coordinates)."""
    ones = torch.ones_like(x[..., [0]])
    return torch.cat((x, ones), -1)


def unhom(x):
    """Divide all but the last component by the last one (offset by 1e-5 to avoid division by zero)."""
    return x[..., :-1] / (1e-5 + x[..., -1:])


def grid_samp_(x, y):
    """Bilinear `F.grid_sample` with border padding; `y` is assumed to be in [0,1] and is remapped to [-1,1]."""
    return F.grid_sample(x, y * 2 - 1, mode="bilinear", padding_mode="border")


def grid_samp(x, y):
    """`grid_samp_` for 4-D input; otherwise the two leading axes of `x` and `y` are merged for sampling and restored afterwards."""
    # todo use more general flatten recipe
    if len(x.shape) == 4:
        return grid_samp_(x, y)
    leading = x.shape[:2]
    sampled = grid_samp_(x.flatten(0, 1), y.flatten(0, 1))
    return sampled.unflatten(0, leading)


def project(crds, K):
    """Multiply each coordinate (last axis) by the matrix `K` and de-homogenize the result."""
    return unhom(torch.einsum("b...cij,b...ckj->b...cki", K, crds))


def warp(crds, poses, K):
    """Apply `poses` to the homogenized `crds`, keep the first three components, then `project` with `K`."""
    transformed = torch.einsum("b...cij,b...ckj->b...cki", poses, hom(crds))[..., :3]
    return project(transformed, K)

def make_net(dims):
    """Build an MLP whose layer widths are given by `dims`.

    Consecutive entries of `dims` become `nn.Linear` layers with a `nn.ReLU`
    between them (no activation after the final layer). After construction the
    weight of every linear layer is re-initialised with Kaiming-normal
    (fan_in, relu); biases keep their default initialisation.

    Returns:
        nn.Sequential containing the layers (empty when `dims` has fewer than two entries).
    """
    def _kaiming_init(module):
        if isinstance(module, nn.Linear) and hasattr(module, 'weight'):
            nn.init.kaiming_normal_(module.weight, a=0.0, nonlinearity='relu', mode='fan_in')

    modules = []
    for width_in, width_out in zip(dims[:-1], dims[1:]):
        if modules:
            # activation goes between linear layers only
            modules.append(nn.ReLU())
        modules.append(nn.Linear(width_in, width_out))

    mlp = nn.Sequential(*modules)
    # initialise after all layers exist, visiting them in order
    mlp.apply(_kaiming_init)
    return mlp

# Base model for FlowCam or FlowCams
class FlowMap(nn.Module):
    def __init__(self, args):
        """Load the pretrained helper networks and create the model's own layers.

        Instantiates RAFT (torchvision), GMFlow and CoTracker from hard-coded
        relative paths, builds the encoder and prediction heads, and - unless
        `args.load_save` is truthy - fetches the ZoeDepth estimator through
        torch.hub. GMFlow, CoTracker and ZoeDepth are moved to CUDA here.

        Args:
            args: configuration object, stored as `self.args`. Only
                `args.load_save` is read inside this constructor.
        """
        super().__init__()
        self.args=args

        # Optical flow model - RAFT
        from torchvision.models.optical_flow import Raft_Large_Weights
        # preprocessing transforms that ship with the default RAFT weights
        self.raft_transforms = Raft_Large_Weights.DEFAULT.transforms()
        from torchvision.models.optical_flow import raft_large
        self.raft_ = raft_large(weights=Raft_Large_Weights.DEFAULT, progress=False)

        # GMFlow is a faster alternative to RAFT - only used during large-scale training, for overfitting we use raft
        sys.path.append("./gmflow/")
        from gmflow.gmflow import GMFlow
        self.gm_flow = GMFlow(feature_channels=128, num_scales=1, upsample_factor=8, num_head=1, attention_type="swin", ffn_dim_expansion=4, num_transformer_layers=6,).cuda()
        checkpoint = torch.load("./gmflow/gmflow-scale1-mixdata-train320x576-4c3a6e9a.pth")
        # the checkpoint is either the state dict itself or wraps it under a 'model' key
        weights = checkpoint['model'] if 'model' in checkpoint else checkpoint
        # strict=False: mismatching keys are tolerated rather than raising
        self.gm_flow.load_state_dict(weights, strict=False)
        # GMFlow is frozen
        for param in self.gm_flow.parameters(): param.requires_grad = False

        # Point tracking model
        from cotracker.predictor import CoTrackerPredictor
        self.co_tracker = CoTrackerPredictor( checkpoint=os.path.join( './third_party/co-tracker/checkpoints/cotracker_stride_4_wind_8.pth')).cuda()

        # Our actual model
        fdim=128
        # NOTE(review): fmap_downproj, depth_est, static_est and mask_scorer are not
        # referenced by the methods visible in this part of the file - confirm they are still needed
        self.fmap_downproj = nn.Linear(512,fdim)
        self.depth_est = make_net([fdim,64,1])
        # takes two concatenated fdim-sized features per correspondence (hence fdim*2) and outputs one logit
        self.corr_weighter_perpoint = make_net([fdim*2,128,64,1])
        self.static_est = make_net([fdim,64,1])
        self.mask_scorer = make_net([fdim,64,1])

        import conv_modules
        # encoder with 4 input channels (the forward methods feed RGB concatenated with one depth channel),
        # followed by a 3x3 conv reducing 512 channels to fdim
        self.resnet_enc=nn.Sequential(conv_modules.PixelNeRFEncoder(in_ch=4),nn.Conv2d(512,fdim,3,padding=1))
        # per-pixel 3x3 conv heads applied to the fdim-channel feature map
        self.depth_conv=nn.Conv2d(fdim,1,3,padding=1)
        self.focal_conv=nn.Conv2d(fdim,1,3,padding=1)
        self.temperature_conv=nn.Conv2d(fdim,1,3,padding=1)
        self.affinities_conv=nn.Conv2d(fdim,32,3,padding=1)

        self.n_rig=2
        self.rig_predictor = make_net([fdim,128,64,self.n_rig])
        # tiles the leading axis n_rig times: "b ... -> (b o) ..."
        self.o_rep = lambda x: repeat(x,"b ... -> (b o) ...",o=self.n_rig)
        # NOTE(review): o_switch is not called by the methods visible here - confirm the pattern is still valid
        self.o_switch = lambda x: repeat(x,"b t o ... -> (b o) t ...")

        # Depth-as-variable vars in case that setting is used
        self.corr_weights,self.depth=None,None
        # incremented at the start of every forward* call
        self.step=0

        if not args.load_save:
            print("loading depth estimator")
            repo = "isl-org/ZoeDepth"
            torch.hub.help("intel-isl/MiDaS", "DPT_BEiT_L_384", force_reload=True)  # Triggers fresh download of MiDaS repo
            self.model_zoe_n = torch.hub.load(repo, "ZoeD_N", pretrained=True).cuda()
            # ZoeDepth is frozen
            for param in self.model_zoe_n.parameters(): param.requires_grad = False
            print("done loading depth estimator")

    def forward(self, model_input, out={}):
        """Dynamic-scene forward pass with per-track rigid poses blended per pixel.

        Steps, in order: cache flow / depth / tracks via `get_flow`; refine the
        ZoeDepth input with a positive residual; optionally predict intrinsics;
        lift depth to camera-space points; weight flow correspondences; build a
        soft affinity mask between every pixel and every point track of frame 0;
        solve one weighted Procrustes problem per track from samples around that
        track; blend the resulting poses per pixel; and finally derive the
        pose-induced optical flow and point-track reprojections.

        Args:
            model_input: dict. Reads "rgb", "x_pix", "pred_tracks" and (after
                `get_flow`) "zoe_depth" and "bwd_flow"; reads "org_ratio" when
                `args.use_gt_intrinsics` is falsy, otherwise "intrinsics".
                Writes "fmap" and, when intrinsics are predicted, "intrinsics".
                NOTE: "pred_tracks" is only produced by `get_flow` when
                `args.point_track` is set, so it must be available here.
            out: dict merged into the result. It is never mutated; the return
                value is the new dict `out | {...}`.

        Returns:
            dict with the keys "rig_masks", "rig_masks_unnorm", "affinity_emb",
            "res_depth", "depth", "flow_from_pose", "lie_perpix", "pose_perpix",
            "eye_surf", "poses_all", "track_reprojs", "zoe_depth", "zoe_d_loss",
            "corr_weights", "poses" and "flow_inp_" (plus whatever `out` held).
        """
        self.step+=1

        imsize=model_input["rgb"].shape[-2:]
        (b,_),n_trgt=model_input["rgb"].shape[:2],model_input["rgb"].size(1)
        # NOTE(review): n_samp / rand_subset are only referenced by the commented-out
        # procrustes call further down; the randperm still consumes RNG state.
        n_samp=300
        rand_subset = torch.randperm(imsize[0]*imsize[1])[:n_samp]
        if self.args.overfit: rand_subset = torch.linspace(0,imsize[0]*imsize[1]-1,n_samp).long()

        # Run optical flow and point track networks
        self.get_flow(model_input)

        # Est depth
        depth_inp = ch_fst(model_input["zoe_depth"],imsize[0])
        # stack RGB and the depth input along the channel axis (dim 2)
        rgbd = torch.cat((model_input["rgb"],depth_inp),2)
        # encode each frame, then upsample the feature map back to the input resolution
        model_input["fmap"] = F.interpolate(self.resnet_enc(rgbd.flatten(0,1)*.5+.5),imsize,mode="bilinear").unflatten(0,(b,n_trgt))
        # softplus keeps the residual positive, so the final depth is always above the ZoeDepth input
        res_depth = F.softplus(self.depth_conv(model_input["fmap"].flatten(0,1)).unflatten(0,(b,n_trgt))+1) 
        depth = depth_inp + res_depth

        # Plan: the static flowmap produces poses, then point tracks at every n-th frame produce poses. Composite the dynamic
        # point-track poses to the image plane with the affinity-embedding attention approach, then composite that with the static
        # poses via an 'is static' binary map, with a small regularization to use the static map where possible;
        # tracks that are completely out of bounds can use the is_static map for those frames.
        # Point tracks use local sampling with ~300 pts and global uses ~1k or more points.

        # Est intrinsics
        if not self.args.use_gt_intrinsics:
            # a single focal value per batch element: sigmoid of the head, averaged over all frames and pixels
            focal = self.focal_conv(model_input["fmap"].flatten(0,1)).unflatten(0,(b,n_trgt)).sigmoid().flatten(1,-1).mean(dim=-1)
            model_input["intrinsics"]=torch.eye(3)[None].float().to(depth).repeat(b,n_trgt,1,1)
            # principal point fixed at the image centre (.5, .5)
            model_input["intrinsics"][...,0,2]=model_input["intrinsics"][...,1,2]=.5
            # NOTE(review): focal has shape (b,) while the assignment targets have shape (b, n_trgt);
            # broadcasting lines focal up with the n_trgt axis - confirm this is intended when b > 1
            model_input["intrinsics"][...,0,0]=focal*model_input["org_ratio"]
            model_input["intrinsics"][...,1,1]=focal

        # Lift depth map into point cloud
        # location in the previous frame that each pixel of frames 1.. corresponds to, according to the backward flow
        corresp_uv = (model_input["x_pix"][:,:-1]+ch_sec(model_input["bwd_flow"]))
        rds = geometry.get_world_rays(model_input["x_pix"],model_input["intrinsics"],None)[1]
        # camera-space surface points: ray directions scaled by depth
        eye_surf=rds*ch_sec(depth)
        # surface points of frames :-1 sampled at the corresponding locations (uv remapped from [0,1] to [-1,1])
        corresp_surf = F.grid_sample(ch_fst(eye_surf,imsize[0])[:,:-1].flatten(0,1),corresp_uv.flatten(0,1).unsqueeze(1)*2-1).squeeze(-2).permute(0,2,1).unflatten(0,(b,n_trgt-1))

        # Estimate correspondence weights
        corresp_feat = F.grid_sample(model_input["fmap"][:,:-1].flatten(0,1),corresp_uv.flatten(0,1).unsqueeze(1)*2-1).squeeze(-2).permute(0,2,1).unflatten(0,(b,n_trgt-1))
        # one weight per correspondence from the concatenated (sampled, own) features; clipped away from 0
        corr_weights = self.corr_weighter_perpoint(torch.cat((corresp_feat,ch_sec(model_input["fmap"])[:,1:]),-1)).sigmoid().clip(min=1e-4)

        # Estimate rigid masks -- with n_rig=1 reduces exactly to static flowmap code -- note since normalizing here we're doing cosine sim, try unnorm dot product too
        # First memory reduction is make the affinity emb downsampled to 4x lower emb 
        # unit-normalized along the channel axis (dim 2)
        affinity_emb = F.normalize( self.affinities_conv(model_input["fmap"].flatten(0,1)).unflatten(0,(b,n_trgt)), dim=2 )

        #aff_sources = rearrange(grid_samp(affinity_emb[:,[0]], model_input["pred_tracks"][:,[0]].unsqueeze(-2)),"b t c p 1 -> b t p c") # random 2 choices for emb
        #aff_src_samp_locs = model_input["pred_tracks"][:,:,[27]].unsqueeze(-2)
        aff_src_samp_locs = model_input["pred_tracks"].unsqueeze(-2)
        # source embeddings: frame-0 affinity embedding sampled at every track's frame-0 location
        aff_sources = rearrange(grid_samp(affinity_emb[:,[0]], aff_src_samp_locs[:,[0]]),"b t c p 1 -> b t p c") # random 2 choices for emb

        # work on a 64x64 copy of the embedding to limit the size of the all-pairs similarity
        affinity_emb_low = F.interpolate(affinity_emb.flatten(0,1),(64,64),mode="bilinear").unflatten(0,(b,n_trgt))

        aff_sim_allpix = torch.einsum('b t p c, b s c -> b t p s', ch_sec(affinity_emb_low), aff_sources.squeeze(1)) # from all source pix to all other pix
        # fold the source axis into the batch axis: one mask per (batch, source) pair
        affinity_mask_unnorm = rearrange(aff_sim_allpix,"b t p s -> (b s) t p 1")
        # single scalar temperature: max over the whole temperature map, scaled and offset
        temperature = self.temperature_conv(model_input["fmap"].flatten(0,1)).unflatten(0,(b,n_trgt)).max()*5e1+1
        # NOTE(review): softmax over dim 0 normalizes across the merged (batch, source) axis,
        # which equals a softmax over sources only when b == 1 - confirm
        affinity_mask = (affinity_mask_unnorm*temperature*(np.sqrt(aff_sim_allpix.size(-1))/8)).softmax(dim=0).clip(min=1e-4)

        # Below is rigidity bucket sanity case before going into discretized embedding set
        #affinity_mask=rearrange( self.rig_predictor(ch_sec(model_input["fmap"])).softmax(-1), "b t x o -> (b o) t x 1")
        #corr_weights_obj = (corr_weights * affinity_mask_unnorm[:,1:]).clip(min=1e-4)

        #from pdb import set_trace as pdb_;pdb_() 
        # n x n grid of 2-D coordinates spanning (s, e) with a .01 margin, returned flattened as (2, n*n)
        get_xpix = lambda s,e,n: torch.stack(torch.meshgrid(*[torch.linspace(s+.01,e-.01,n) for _ in range(2)])).flatten(1,2)
        #self.o_rep = lambda x: repeat(x,"b ... -> (b o) ...",o=aff_sources.size(2))
        n_total_samp,samp_amts = 324,[.5,.25,.1]
        # 18x18 uniform grid over the unit square, repeated for every (batch, source) pair and frame pair
        uniform_grid_pix = repeat(get_xpix(0,1,int(n_total_samp**.5)),"c p -> (b o) t p 1 c",o=aff_sources.size(2),b=b,t=n_trgt-1).cuda()
        # three local 10x10 offset grids with half-widths .5, .25 and .1 (300 offsets in total)
        local_samp_grids = torch.cat([get_xpix(-i,i,int((n_total_samp/len(samp_amts))**.5)) for i in samp_amts],1).T
        # offsets are centred on each track's location in frames 1..
        sample_locs = rearrange(aff_src_samp_locs[:,1:],"b t s 1 c -> (b s) t 1 1 c")+local_samp_grids.cuda()[None,None,:,None]
        # the replacement is elementwise, i.e. each out-of-range coordinate is swapped on its own
        sample_locs = torch.where((sample_locs<0)|(sample_locs>1),uniform_grid_pix[:,:,:sample_locs.size(2)],sample_locs)# replace where out of bounds with uniform grid

        # Estimate poses via procrustes and integrate
        # this rearanging is ugly below for efficiency but can definitely refactor
        # arguments, all sampled at sample_locs: current-frame points, their correspondences in the previous frame,
        # and weights = correspondence weight * unnormalized affinity to the source track
        poses = geometry.procrustes(
            rearrange(grid_samp(ch_fst(eye_surf[:,1:],imsize[0]),rearrange(sample_locs,"(b s) t p 1 c -> b t (p s) 1 c",b=b)),"b t c (p s) 1 -> (b s) t p c",p=sample_locs.size(2)),
            rearrange(grid_samp(ch_fst(corresp_surf,imsize[0]),rearrange(sample_locs,"(b s) t p 1 c -> b t (p s) 1 c",b=b)),"b t c (p s) 1 -> (b s) t p c",p=sample_locs.size(2)),
           (rearrange(grid_samp(ch_fst(corr_weights,imsize[0]),rearrange(sample_locs,"(b s) t p 1 c -> b t (p s) 1 c",b=b)),"b t c (p s) 1 -> (b s) t p c",p=sample_locs.size(2))*
            grid_samp(ch_fst(affinity_mask_unnorm[:,1:],64),sample_locs).squeeze(2)).clip(min=1e-4)
                                              #,"bo t c p 1 -> bo t p c").clip(min=1e-4), 
                        )[1]
        #poses = geometry.procrustes(self.o_rep(eye_surf[:,1:,rand_subset]), self.o_rep(corresp_surf[:,:,rand_subset]), corr_weights_obj[:,:,rand_subset])[1]
        # chain the frame-to-frame transforms into cumulative ones
        for i in range(n_trgt-1,0,-1): poses = torch.cat((poses[:,:i],poses[:,[i-1]]@poses[:,i:]),1)
        poses = torch.cat((torch.eye(4).to(poses)[None,None].expand(poses.size(0),-1,-1,-1),poses),1) # add for starting pose

        # Composite the poses here to be per-pixel so that each pixel has an se3 trajectory (lie space composition then se3 exponentiation) 
        # 7-vector per pose: 4 quaternion components followed by the 3 translation components
        poses_lie = torch.cat((kornia.geometry.conversions.rotation_matrix_to_quaternion(poses[...,:3,:3],eps=1e-5),poses[...,:3,-1]),-1)
        # affinity-weighted sum over the source axis
        lie_perpix = (affinity_mask*poses_lie.unsqueeze(-2)).unflatten(0,(b,aff_sources.size(2))).sum(1)

        lie_perpix = ch_sec(F.interpolate(ch_fst(lie_perpix,64).flatten(0,1),imsize,mode="bilinear").unflatten(0,(b,n_trgt))) # upsample to hires (just bilinear rn)

        # NOTE(review): the in-place writes below need pose_perpix to own its memory; the tensor is created
        # by expand(), so this relies on .to(lie_perpix) actually making a copy - confirm
        pose_perpix = torch.eye(4)[None,None,None].expand(b,n_trgt,lie_perpix.size(-2),-1,-1).to(lie_perpix)
        pose_perpix[...,:3,:3] = kornia.geometry.conversions.quaternion_to_rotation_matrix(lie_perpix[...,:4])
        pose_perpix[...,:3,-1] = lie_perpix[...,4:]

        # Compute pose flow
        # move each pixel's point by its relative per-pixel pose, project it, and subtract the pixel's own coordinate
        adj_opt_flow = project( torch.einsum("btpij,btpj->btpi",pose_perpix[:,:-1].inverse()@pose_perpix[:,1:],hom(eye_surf[:,1:]))[...,:3], 
                                                model_input["intrinsics"][:,1:] )-model_input["x_pix"][:,1:]

        ## Compute pose induced point tracks
        # surface points and per-pixel poses sampled at the tracked locations
        eye_tracks=rearrange(F.grid_sample(ch_fst(eye_surf,imsize[0]).flatten(0,1),model_input["pred_tracks"].flatten(0,1).unsqueeze(-2)*2-1).squeeze(-2).unflatten(0,(b,n_trgt)),
                             "b t c p 1 -> b p t c",)
        pose_tracks=rearrange(F.grid_sample(ch_fst(pose_perpix.flatten(-2,-1),imsize[0]).flatten(0,1),model_input["pred_tracks"].flatten(0,1).unsqueeze(-2)*2-1,
                                padding_mode="border").squeeze(-2).unflatten(0,(b,n_trgt)), "b t (x y) p 1 -> b p t x y",x=4)
        # relative pose between every pair of frames, per track
        pose_all_to_all_perpix = repeat(pose_tracks,"b p t x y -> b p s t x y",s=n_trgt).inverse()@repeat(pose_tracks,"b p t x y -> b p t s x y",s=n_trgt)
        track_surf_reprojs = torch.einsum("bksnij,bksj->bksni",pose_all_to_all_perpix,hom(eye_tracks))[...,:3]
        # projected with the intrinsics of frame 0 and clipped to the unit square
        track_reprojs = unhom(torch.einsum("bij,bksnj->bksni",model_input["intrinsics"][:,0],track_surf_reprojs)).clip(0,1)

        ## Unpack the per-object estimations and composite over them 
        #adj_opt_flow =  (adj_opt_flow*affinity_mask[:,1:]).unflatten(0,(b,self.n_rig)).sum(1)
        #track_affinities=rearrange(F.grid_sample(ch_fst(affinity_mask,imsize[0]).flatten(0,1),self.o_rep(model_input["pred_tracks"]
        #                    ).flatten(0,1).unsqueeze(-2)*2-1).squeeze(-2).unflatten(0,(b*self.n_rig,n_trgt)), "b t c p 1 -> b p t c",)
        #track_reprojs = (track_reprojs*track_affinities.unsqueeze(-2)).unflatten(0,(b,self.n_rig)).sum(1)

        # `out | {...}` builds a new dict, so the `out` argument itself is left untouched
        return out | {
            "rig_masks":rearrange(affinity_mask,"(b o) t xy 1 -> b t o xy 1",o=aff_sources.size(2)),
            "rig_masks_unnorm":rearrange(affinity_mask_unnorm,"(b o) t xy 1 -> b t o xy 1",o=aff_sources.size(2)),
            "affinity_emb":affinity_emb,
            "res_depth":ch_sec(res_depth),
            "depth":ch_sec(depth),
            "flow_from_pose":adj_opt_flow,
            "lie_perpix":lie_perpix.unflatten(-2,imsize),
            "pose_perpix":pose_perpix.unflatten(-3,imsize),
            "eye_surf":eye_surf,
            "poses_all":poses,
            "track_reprojs":track_reprojs,
            "zoe_depth":model_input["zoe_depth"],
            # squared difference of inverse depths between the ZoeDepth input and the refined depth
            "zoe_d_loss":(1/(1e-5+model_input["zoe_depth"])-1/(1e-5+ch_sec(depth))).square().mean()*1e1,
            "corr_weights": ch_fst(corr_weights,imsize[0]),
            # only the first entry along the merged (batch, source) axis
            "poses":poses[:1],
            "flow_inp_": model_input["bwd_flow"],
        }

    def forward_static(self, model_input, out={}):
        """Static-scene forward pass: one camera pose per frame.

        Refines the ZoeDepth input with a positive residual, optionally predicts
        intrinsics, lifts depth to camera-space points, and solves a weighted
        Procrustes problem between flow correspondences of adjacent frames on a
        subset of `n_samp` pixels. The adjacent-frame transforms are chained
        into cumulative poses, from which the pose-induced optical flow and
        point-track reprojections are derived.

        Args:
            model_input: dict. Reads "rgb", "x_pix", "pred_tracks" and (after
                `get_flow`) "zoe_depth" and "bwd_flow"; reads "org_ratio" when
                `args.use_gt_intrinsics` is falsy, otherwise "intrinsics".
                Writes "fmap" and, when intrinsics are predicted, "intrinsics".
            out: dict merged into the result. It is never mutated; the return
                value is the new dict `out | {...}`.

        Returns:
            dict with the keys "res_depth", "depth", "flow_from_pose",
            "zoe_depth", "track_reprojs", "zoe_d_loss", "corr_weights", "poses"
            and "flow_inp_" (plus whatever `out` held).
        """
        self.step+=1

        imsize=model_input["rgb"].shape[-2:]
        (b,_),n_trgt=model_input["rgb"].shape[:2],model_input["rgb"].size(1)
        # pixel subset used for procrustes: random by default, evenly spaced (deterministic) when overfitting
        n_samp=1000
        rand_subset = torch.randperm(imsize[0]*imsize[1])[:n_samp]
        if self.args.overfit: rand_subset = torch.linspace(0,imsize[0]*imsize[1]-1,n_samp).long()

        # Run optical flow and point track networks
        self.get_flow(model_input)

        # Est depth
        depth_inp = ch_fst(model_input["zoe_depth"],imsize[0])
        # stack RGB and the depth input along the channel axis (dim 2)
        rgbd = torch.cat((model_input["rgb"],depth_inp),2)
        # encode each frame, then upsample the feature map back to the input resolution
        model_input["fmap"] = F.interpolate(self.resnet_enc(rgbd.flatten(0,1)*.5+.5),imsize,mode="bilinear").unflatten(0,(b,n_trgt))
        # softplus keeps the residual positive, so the final depth is always above the ZoeDepth input
        res_depth = F.softplus(self.depth_conv(model_input["fmap"].flatten(0,1)).unflatten(0,(b,n_trgt))+1) 
        depth = res_depth + depth_inp

        # Est intrinsics
        if not self.args.use_gt_intrinsics:
            # a single focal value per batch element: sigmoid of the head, averaged over all frames and pixels
            focal = self.focal_conv(model_input["fmap"].flatten(0,1)).unflatten(0,(b,n_trgt)).sigmoid().flatten(1,-1).mean(dim=-1)
            model_input["intrinsics"]=torch.eye(3)[None].float().to(depth).repeat(b,n_trgt,1,1)
            # principal point fixed at the image centre (.5, .5)
            model_input["intrinsics"][...,0,2]=model_input["intrinsics"][...,1,2]=.5
            # NOTE(review): focal has shape (b,) while the assignment targets have shape (b, n_trgt);
            # broadcasting lines focal up with the n_trgt axis - confirm this is intended when b > 1
            model_input["intrinsics"][...,0,0]=focal*model_input["org_ratio"]
            model_input["intrinsics"][...,1,1]=focal

        # Lift depth map into point cloud
        # location in the previous frame that each pixel of frames 1.. corresponds to, according to the backward flow
        corresp_uv = (model_input["x_pix"][:,:-1]+ch_sec(model_input["bwd_flow"]))
        rds = geometry.get_world_rays(model_input["x_pix"],model_input["intrinsics"],None)[1]
        # camera-space surface points: ray directions scaled by depth
        eye_surf=rds*ch_sec(depth)
        # surface points of frames :-1 sampled at the corresponding locations (uv remapped from [0,1] to [-1,1])
        corresp_surf = F.grid_sample(ch_fst(eye_surf,imsize[0])[:,:-1].flatten(0,1),corresp_uv.flatten(0,1).unsqueeze(1)*2-1).squeeze(-2).permute(0,2,1).unflatten(0,(b,n_trgt-1))

        # Estimate correspondence weights
        corresp_feat = F.grid_sample(model_input["fmap"][:,:-1].flatten(0,1),corresp_uv.flatten(0,1).unsqueeze(1)*2-1).squeeze(-2).permute(0,2,1).unflatten(0,(b,n_trgt-1))
        # one weight per correspondence from the concatenated (sampled, own) features; clipped away from 0
        corr_weights = self.corr_weighter_perpoint(torch.cat((corresp_feat,ch_sec(model_input["fmap"])[:,1:]),-1)).sigmoid().clip(min=1e-4)

        # Estimate poses via procrustes 
        # only the second element returned by geometry.procrustes is used
        adj_transf = geometry.procrustes(eye_surf[:,1:,rand_subset], corresp_surf[:,:,rand_subset], corr_weights[:,:,rand_subset] )[1]
        # Integrate poses
        poses = adj_transf
        # chain the frame-to-frame transforms into cumulative ones
        for i in range(n_trgt-1,0,-1): poses = torch.cat((poses[:,:i],poses[:,[i-1]]@poses[:,i:]),1)
        # prepend the identity as the pose of the first frame
        poses = torch.cat((torch.eye(4).to(poses)[None,None].expand(poses.size(0),-1,-1,-1),poses),1)

        # Compute pose-induced optical flow
        # warp the points of frames 1.. by the relative pose, project, and subtract the pixels' own coordinates
        adj_opt_flow = warp(eye_surf[:,1:],poses[:,:-1].inverse()@poses[:,1:],model_input["intrinsics"][:,1:])-model_input["x_pix"][:,1:]

        # Compute pose induced point tracks
        # surface points sampled at the tracked locations
        eye_tracks=rearrange(F.grid_sample(ch_fst(eye_surf,imsize[0]).flatten(0,1),model_input["pred_tracks"].flatten(0,1).unsqueeze(-2)*2-1).squeeze(-2).unflatten(0,(b,n_trgt)),"b t c p 1 -> b p t c",)
        # relative pose between every pair of frames applied to each tracked point
        track_surf_reprojs = torch.einsum("bsnij,bksj->bksni",repeat(poses,"b t x y -> b s t x y",s=n_trgt).inverse()@repeat(poses,"b t x y -> b t s x y",s=n_trgt),hom(eye_tracks))[...,:3]
        # projected with the intrinsics of frame 0 and clipped to the unit square
        track_reprojs = unhom(torch.einsum("bij,bksnj->bksni",model_input["intrinsics"][:,0],track_surf_reprojs)).clip(0,1)

        # `out | {...}` builds a new dict, so the `out` argument itself is left untouched
        return out | {
            "res_depth":ch_sec(res_depth),
            "depth":ch_sec(depth),
            "flow_from_pose":adj_opt_flow,
            "zoe_depth":model_input["zoe_depth"],
            "track_reprojs":track_reprojs,
            # squared difference of inverse depths between the ZoeDepth input and the refined depth
            "zoe_d_loss":(1/(1e-5+model_input["zoe_depth"])-1/(1e-5+ch_sec(depth))).square().mean()*1e1,
            "corr_weights": ch_fst(corr_weights,imsize[0]),
            "poses":poses,
            "flow_inp_": model_input["bwd_flow"],
        }

    def forward_(self, model_input, out={}):
        """Point-track-only variant: poses are solved from tracked points relative to frame 0.

        Depth and intrinsics are estimated as in the other forward methods.
        Every tracked point is lifted to a camera-space point in each frame, and
        a weighted Procrustes problem aligns each frame's points with the same
        points in frame 0 (using a strided subset of the tracks). The resulting
        poses give a track reprojection loss and a pose-induced optical flow.

        Args:
            model_input: dict. Reads "rgb", "x_pix", "pred_tracks",
                "pred_visibility" and (after `get_flow`) "zoe_depth" and
                "bwd_flow"; reads "org_ratio" when `args.use_gt_intrinsics` is
                falsy, otherwise "intrinsics". Writes "fmap" and, when
                intrinsics are predicted, "intrinsics".
            out: dict merged into the result. It is never mutated; the return
                value is the new dict `out | {...}`.

        Returns:
            dict with the keys "res_depth", "depth", "flow_from_pose",
            "point_track_loss", "zoe_depth", "zoe_d_loss", "corr_weights",
            "poses" and "flow_inp_" (plus whatever `out` held).
        """
        self.step+=1

        imsize=model_input["rgb"].shape[-2:]
        (b,_),n_trgt=model_input["rgb"].shape[:2],model_input["rgb"].size(1)

        # Run optical flow and point track networks
        self.get_flow(model_input)

        # Est depth
        depth_inp = ch_fst(model_input["zoe_depth"],imsize[0])
        # stack RGB and the depth input along the channel axis (dim 2)
        rgbd = torch.cat((model_input["rgb"],depth_inp),2)
        # encode each frame, then upsample the feature map back to the input resolution
        model_input["fmap"] = F.interpolate(self.resnet_enc(rgbd.flatten(0,1)*.5+.5),imsize,mode="bilinear").unflatten(0,(b,n_trgt))
        # softplus keeps the residual positive, so the final depth is always above the ZoeDepth input
        res_depth = F.softplus(self.depth_conv(model_input["fmap"].flatten(0,1)).unflatten(0,(b,n_trgt))+1) 
        depth = res_depth + depth_inp

        # Est intrinsics
        if not self.args.use_gt_intrinsics:
            # a single focal value per batch element: sigmoid of the head, averaged over all frames and pixels
            focal = self.focal_conv(model_input["fmap"].flatten(0,1)).unflatten(0,(b,n_trgt)).sigmoid().flatten(1,-1).mean(dim=-1)
            model_input["intrinsics"]=torch.eye(3)[None].float().to(depth).repeat(b,n_trgt,1,1)
            # principal point fixed at the image centre (.5, .5)
            model_input["intrinsics"][...,0,2]=model_input["intrinsics"][...,1,2]=.5
            # NOTE(review): focal has shape (b,) while the assignment targets have shape (b, n_trgt);
            # broadcasting lines focal up with the n_trgt axis - confirm this is intended when b > 1
            model_input["intrinsics"][...,0,0]=focal*model_input["org_ratio"]
            model_input["intrinsics"][...,1,1]=focal

        # Get corresp weights per track corresp 
        # features sampled at every tracked location in every frame
        track_feat   = rearrange(grid_samp(model_input["fmap"], model_input["pred_tracks"].unsqueeze(-2)),"b t c p 1 -> b t p c")
        # weight from the track's frame-0 feature concatenated with its feature in the current frame
        corr_weights = self.corr_weighter_perpoint(torch.cat((track_feat[:,[0]].expand(-1,n_trgt,-1,-1),track_feat),-1)).sigmoid().clip(min=1e-4)

        # since too expensive without local sampling, will for now just use N point tracks until implemented
        # NOTE(review): 42**2 here and the 42 passed to ch_fst in the return dict hard-code a 42x42 track grid,
        # whereas get_flow currently requests grid_size=8 - confirm the track count before using this path
        neighb_list = np.arange(42**2)[::10]

        # Est poses per source point 
        rds_track = geometry.get_world_rays(model_input["pred_tracks"],model_input["intrinsics"],None)[1]
        # camera-space point of each track: ray direction times the depth sampled at the tracked location
        eye_surf_track = rds_track * grid_samp(depth,model_input["pred_tracks"].unsqueeze(-2)).squeeze(2)
        # align every frame's track points with the frame-0 track points; only the second return value is used
        poses = geometry.procrustes(eye_surf_track[...,neighb_list,:], eye_surf_track[:,[0]].expand(-1,n_trgt,-1,-1)[...,neighb_list,:], corr_weights[...,neighb_list,:])[1]

        # map the frame-0 track points through the inverse poses and project them into each frame
        point_track_surf_reproj = torch.einsum('btij,btpj->btpi',poses.inverse(),hom(eye_surf_track[:,[0]].expand(-1,n_trgt,-1,-1)))[...,:3]
        point_track_reproj = project(point_track_surf_reproj,model_input["intrinsics"]).clip(0,1)
        # elementwise minimum of the visibility in the current frame and in frame 0
        vis_mask = torch.minimum(model_input["pred_visibility"],model_input["pred_visibility"][:,[0]])
        point_track_loss = ( (point_track_reproj - model_input["pred_tracks"]) * vis_mask.unsqueeze(-1) ).square().mean()

        ## Supervise optical flow 
        eye_surf = geometry.get_world_rays(model_input["x_pix"],model_input["intrinsics"],None)[1] * ch_sec(depth)
        # warp the points of frames 1.. by the relative pose, project, and subtract the pixels' own coordinates
        adj_opt_flow = warp(eye_surf[:,1:],poses[:,:-1].inverse()@poses[:,1:],model_input["intrinsics"][:,1:])-model_input["x_pix"][:,1:]
        
        # `out | {...}` builds a new dict, so the `out` argument itself is left untouched
        return out | {
            "res_depth":ch_sec(res_depth),
            "depth":ch_sec(depth),
            "flow_from_pose":adj_opt_flow,
            "point_track_loss":point_track_loss*1e2,
            "zoe_depth":model_input["zoe_depth"],
            # squared difference of inverse depths between the ZoeDepth input and the refined depth
            "zoe_d_loss":(1/(1e-5+model_input["zoe_depth"])-1/(1e-5+ch_sec(depth))).square().mean()*1e1,
            "corr_weights": ch_fst(corr_weights,42),
            "poses":poses,
            "flow_inp_": model_input["bwd_flow"],
        }
    def get_flow(self,model_input):
        (b,_),n_trgt,imsize=model_input["rgb"].shape[:2],model_input["rgb"].size(1),model_input["rgb"].shape[-2:]

        if "zoe_depth" not in model_input: 
            print("doing zoe depth")
            with torch.no_grad():
                zoe_est = F.interpolate(self.model_zoe_n.infer(F.interpolate(model_input["rgb_large"].flatten(0,1)/255,scale_factor=.25)),model_input["rgb"].shape[-2:]) # todo run offline
                #zoe_est = self.model_zoe_n.infer(model_input["rgb"].flatten(0,1)*.5+.5) # todo run offline
                model_input["zoe_depth"] = ch_sec(zoe_est.unflatten(0,model_input["rgb"].shape[:2]))
            print("done doing zoe depth")

        # RAFT optical flow 
        def raft(x,y):

            print("doing raft",x.shape)

            # Split query into 2 if >4 since raft has high memory constraints  
            if self.args.overfit and len(x)>4: 
                x_left,y_left = x[:len(x)//2],y[:len(x)//2]
                x_right,y_right = x[len(x)//2:],y[len(x)//2:]
                return torch.cat((raft(x_left,y_left),raft(x_right,y_right)))
            # Use raft is overfitting since raft is slow but only need to compute once in this setting; otherwise use gm flow
            if self.args.overfit and not self.args.gm_flow: 
                out=self.raft_(x,y,num_flow_updates=32)[-1]
            else: 
                x,y=(x*.5+.5)*255,(y*.5+.5)*255 # gm flow uses inputs in [0,255]
                out=self.gm_flow(x,y,attn_splits_list=[2],corr_radius_list=[-1],prop_radius_list=[-1],pred_bidir_flow=False)["flow_preds"][-1]
            # Rescale flow to be in resolution-independent pixel coordinates (-1,1)
            out = out/(torch.tensor(x.shape[-2:][::-1])-1).to(x)[None,:,None,None] # normalize flow coordinates to [-1,1]
            return F.interpolate(out,imsize,mode="bilinear",antialias=True) # downsample from hires to lowres

        with torch.no_grad():
            self.n_track_frames=1 if self.args.overfit else 1
            if "bwd_flow" not in model_input: # don't run if flow already in scene dict

                # Run optical flow net
                raft_inp_x,raft_inp_y = self.raft_transforms(model_input["rgb_large"][:,1:].flatten(0,1).to(torch.uint8), model_input["rgb_large"][:,:-1].flatten(0,1).to(torch.uint8))
                model_input["bwd_flow"] = raft(raft_inp_x,raft_inp_y,).unflatten(0,(b,n_trgt-1))

                # Run point tracking
                if self.args.point_track:
                    pred_tracks,visibilities=[],[]
                    for i,start_frame in enumerate(torch.linspace(0,n_trgt-1,self.n_track_frames).long()):
                        print("doing video tracking",start_frame)
                        pred_track,visibility = self.co_tracker(model_input["rgb_large"][:1], grid_size=8 if 1 else 42 if self.args.overfit else 20, grid_query_frame=start_frame, backward_tracking=True)
                        pred_track_norm = pred_track/(torch.tensor(model_input["rgb_large"].shape[-2:][::-1])-1)[None,None,None].cuda()
                        pred_tracks.append(pred_track_norm);visibilities.append(visibility)

                        # Save point track visualization video 
                        if self.args.overfit or (not self.args.overfit and self.step%25==0):
                            print("writing point track video for visualization")
                            Visualizer(save_dir='./output/', pad_value=100).visualize( video=model_input["rgb_large"], tracks=pred_track, visibility=visibility, filename='tracks_%02d'%i, query_frame=start_frame);
                    model_input["pred_tracks"]=torch.cat(pred_tracks,2)
                    visibilities=torch.cat(visibilities,2)

                    # Pad batch dimension with dummy 0 values since point tracks don't support batched inference
                    model_input["pred_tracks"]=torch.cat((model_input["pred_tracks"],torch.ones_like(model_input["pred_tracks"]).expand(b-1,-1,-1,-1)))
                    model_input["pred_visibility"]=torch.cat((visibilities,torch.zeros_like(visibilities).expand(b-1,-1,-1)))

            # Add sky mask heuristic to mask out bad point tracks in outdoor scenes
            #if "not_sky" not in model_input and self.args.point_track: 
            #    model_input["not_sky"]= torch.stack([(self.seg_model(model_input["rgb_large"][:,i]/255).max(1)[1][:,None]!=10).float() for i in range(n_trgt)],1)
            #    not_sky = grid_samp(model_input["not_sky"].flatten(0,1),model_input["pred_tracks"].flatten(0,1).unsqueeze(1)).squeeze(1).squeeze(1).unflatten(0,(b,n_trgt))
            #    model_input["pred_visibility"]=model_input["pred_visibility"]*not_sky
