
# TODO(research): the PCA of the upsampled DINO features currently looks bad. Start with FeatUp in testing.py / the notebook, then integrate it into our code and confirm with the same input/output PCA visualization.
# Code for pixelnerf and alternatives.
import torch, torchvision
from torch import nn
import kornia
from einops import rearrange, repeat
import torchvision.ops as ops
from torch.nn import functional as F
import numpy as np
import sys,random,time,os
from copy import deepcopy 
from matplotlib import cm
import wandb
from tqdm import tqdm
from typing import Callable, List, Optional, Tuple, Generator, Dict
from collections import defaultdict
import torchvision.transforms as T

sys.path.append("./third_party/co-tracker/")
from cotracker.utils.visualizer import Visualizer, read_video_from_path

import geometry

# Small tensor helpers: channel layout, homogeneous coordinates, sampling and projection
def ch_sec(x):
    """Flatten the two trailing spatial axes and move channels last: ``... c x y -> ... (x y) c``."""
    return rearrange(x, "... c x y -> ... (x y) c")


def ch_fst(src, x=None):
    """Undo the ``ch_sec`` layout: ``... (x y) c -> ... c x y``.

    When ``x`` is omitted, the flattened axis is treated as square and ``x`` is
    taken as ``int(src.size(-2) ** .5)``.
    """
    first_spatial = int(src.size(-2) ** (.5)) if x is None else x
    return rearrange(src, "... (x y) c -> ... c x y", x=first_spatial)


def hom(x):
    """Append a coordinate of ones along the last axis."""
    ones = torch.ones_like(x[..., [0]])
    return torch.cat((x, ones), -1)


def unhom(x):
    """Divide every coordinate but the last by the last one (offset by 1e-5) and drop the last."""
    leading, scale = x[..., :-1], x[..., -1:]
    return leading / (1e-5 + scale)


def grid_samp_(x, y):
    """Bilinear ``F.grid_sample`` with border padding; assumes y in [0,1] and moves it to [-1,1]."""
    return F.grid_sample(x, y * 2 - 1, mode="bilinear", padding_mode="border")


def grid_samp(x, y):
    """``grid_samp_`` that also accepts inputs carrying two leading batch axes.

    A 4-D ``x`` is sampled directly; otherwise the first two axes of both ``x`` and
    ``y`` are merged for sampling and restored on the result.
    """
    # todo use more general flatten recipe
    if len(x.shape) == 4:
        return grid_samp_(x, y)
    merged = grid_samp_(x.flatten(0, 1), y.flatten(0, 1))
    return merged.unflatten(0, x.shape[:2])


def project(crds, K):
    """Multiply each point in ``crds`` by the matching matrix in ``K`` and dehomogenize with ``unhom``."""
    return unhom(torch.einsum("b...cij,b...ckj->b...cki", K, crds))


def warp(crds, poses, K):
    """Apply ``poses`` to homogeneous ``crds``, keep the first three coordinates, then ``project`` with ``K``."""
    moved = torch.einsum("b...cij,b...ckj->b...cki", poses, hom(crds))[..., :3]
    return project(moved, K)

def make_net(dims):
    """Build an MLP whose ``Linear`` layers are sized by consecutive entries of ``dims``.

    A ``ReLU`` sits between consecutive linear layers; the final linear layer is left
    unactivated. Linear weights are re-initialised with Kaiming-normal (fan-in, ReLU
    gain); biases keep whatever ``nn.Linear`` initialised them to.

    Args:
        dims: Sequence of layer widths ``[in, hidden..., out]``.

    Returns:
        An ``nn.Sequential``; it is empty when ``dims`` has fewer than two entries.
    """
    def init_weights_normal(m):
        # isinstance instead of an exact type comparison: the idiomatic check, which
        # also covers Linear subclasses. Every nn.Linear has a weight, so no hasattr.
        if isinstance(m, nn.Linear):
            nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_in')

    layers = []
    for i, (d_in, d_out) in enumerate(zip(dims[:-1], dims[1:])):
        if i > 0:
            # Activation goes between layers only, never after the last one.
            layers.append(nn.ReLU())
        layers.append(nn.Linear(d_in, d_out))
    net = nn.Sequential(*layers)
    # Applied after construction so the layers are created (and consume RNG) in the same order as before.
    net.apply(init_weights_normal)
    return net

# Base model for FlowCam or FlowCams
class FlowMap(nn.Module):
    def __init__(self, args):
        """Load the pretrained flow / tracking / upsampling models and build the trainable heads.

        Args:
            args: Run configuration, stored as ``self.args``. The forward methods in this
                class read ``args.overfit`` and ``args.use_gt_intrinsics`` from it.

        Note:
            Several sub-models are moved to the GPU with ``.cuda()`` and checkpoints are read
            from relative paths (``./gmflow/``, ``./third_party/co-tracker/checkpoints/``), so
            construction expects a CUDA device and a working directory where those paths resolve.
            ``torch.hub.load`` is also called for the FeatUp upsampler.
        """
        super().__init__()
        self.args=args

        # Optical flow model - RAFT
        from torchvision.models.optical_flow import Raft_Large_Weights
        self.raft_transforms = Raft_Large_Weights.DEFAULT.transforms()
        from torchvision.models.optical_flow import raft_large
        self.raft_ = raft_large(weights=Raft_Large_Weights.DEFAULT, progress=False)

        # GMFlow is a faster alternative to RAFT - only used during large-scale training, for overfitting we use raft
        # The gmflow package is imported from a local checkout, hence the sys.path entry.
        sys.path.append("./gmflow/")
        from gmflow.gmflow import GMFlow
        self.gm_flow = GMFlow(feature_channels=128, num_scales=1, upsample_factor=8, num_head=1, attention_type="swin", ffn_dim_expansion=4, num_transformer_layers=6,).cuda()
        checkpoint = torch.load("./gmflow/gmflow-scale1-mixdata-train320x576-4c3a6e9a.pth")
        # The checkpoint may either wrap the state dict under 'model' or be the state dict itself.
        weights = checkpoint['model'] if 'model' in checkpoint else checkpoint
        # strict=False: missing or unexpected keys are tolerated rather than raising.
        self.gm_flow.load_state_dict(weights, strict=False)
        # GMFlow is frozen; it is the only sub-model whose requires_grad is disabled in this constructor.
        for param in self.gm_flow.parameters(): param.requires_grad = False

        # Point tracking model
        from cotracker.predictor import CoTrackerPredictor
        self.co_tracker = CoTrackerPredictor( checkpoint=os.path.join( './third_party/co-tracker/checkpoints/cotracker_stride_4_wind_8.pth')).cuda()

        # Our actual model
        self.time_stride = 4
        fdim=64//self.time_stride # feature width per frame (16 with time_stride=4); 384 is presumably an earlier setting -- TODO confirm
        self.fmap_downproj = nn.Linear(512,fdim)
        self.depth_est = make_net([fdim,64,1])
        # Takes a pair of concatenated fdim-wide features (hence fdim*2) and emits one value per point.
        self.corr_weighter_perpoint = make_net([fdim*2,16,1])
        #self.corr_weighter_perpoint = make_net([fdim*2,128,64,1])
        self.static_est = make_net([fdim,64,1])
        self.mask_scorer = make_net([fdim,64,1])


        import conv_modules
        # Encoder input has 6*time_stride channels and the trailing conv emits fdim*time_stride channels;
        # forward() stacks time_stride frames on the channel axis before calling it and splits them again after.
        # Presumably 6 = RGB(3) + depth(1) + flow(2) per frame -- TODO confirm.
        self.resnet_enc=nn.Sequential(conv_modules.PixelNeRFEncoder(in_ch=6*self.time_stride),nn.Conv2d(512,fdim*self.time_stride,3,padding=1))
        # 3x3 conv heads applied to the fdim-channel feature map.
        self.depth_conv=nn.Conv2d(fdim,1,3,padding=1)
        self.focal_conv=nn.Conv2d(fdim,1,3,padding=1)
        self.temperature_conv=nn.Conv2d(fdim,1,3,padding=1)
        self.affinities_conv=nn.Conv2d(fdim,32,3,padding=1)
        self.affinities_conv2=nn.Conv2d(fdim,32,3,padding=1)

        # Number of rigid groups predicted per pixel by rig_predictor (one output channel per group).
        self.n_rig=1
        self.rig_predictor = make_net([fdim,64,self.n_rig])
        #self.rig_predictor = make_net([fdim,128,64,self.n_rig])
        # Helpers that fold the rigid-group axis "o" into the batch axis.
        self.o_rep = lambda x: repeat(x,"b ... -> (b o) ...",o=self.n_rig)
        self.o_switch = lambda x: repeat(x,"b t o ... -> (b o) t ...")

        # Depth-as-variable vars in case that setting is used
        self.corr_weights,self.depth=None,None
        # Counter incremented once per forward call.
        self.step=0

        # FeatUp feature upsampler (DINOv2 backbone) fetched through torch.hub.
        self.upsampler = torch.hub.load("mhamilton723/FeatUp", 'dinov2', use_norm=True).cuda()

        #if not args.load_save:
        #    print("loading depth estimator")
        #    repo = "isl-org/ZoeDepth"
        #    torch.hub.help("intel-isl/MiDaS", "DPT_BEiT_L_384", force_reload=True)  # Triggers fresh download of MiDaS repo
        #    self.model_zoe_n = torch.hub.load(repo, "ZoeD_N", pretrained=True).cuda()
        #    for param in self.model_zoe_n.parameters(): param.requires_grad = False
        #    print("done loading depth estimator")

    def forward(self, model_input, out=None): # mlp rigid bucket
        """Estimate depth, intrinsics, rigid-group masks and a per-pixel pose trajectory for a clip.

        In this variant ``self.rig_predictor`` (an MLP on the feature map) softly assigns every
        pixel to one of ``self.n_rig`` rigid groups. One pose per group and adjacent frame pair
        is solved with ``geometry.procrustes``, the poses are chained over time, converted to
        quaternion + translation, blended per pixel with the group masks and turned back into
        4x4 matrices.

        Args:
            model_input: Dict of clip tensors. Reads "rgb", "depth_inp", "x_pix", "org_ratio",
                "bwd_flow" and "pred_tracks" (the last two are presumably filled in by
                ``self.get_flow`` -- TODO confirm), plus "intrinsics" when
                ``args.use_gt_intrinsics`` is set. Mutated in place: "fmap" is always written and
                "intrinsics" is overwritten when intrinsics are estimated.
            out: Optional dict whose entries are carried into the result. If given, it also
                receives "track_reprojs" in place. Defaults to a fresh dict per call.

        Returns:
            A new dict: ``out`` overlaid with this call's outputs (masks, depth, pose-induced
            flow, per-pixel poses, world / rgb / lie coordinates for pixels and tracks, ...).
        """
        # Fresh dict per call. The former `out={}` default was a single dict shared by every call and
        # is written to below ("track_reprojs"), so tensors from one call leaked into the next.
        out = {} if out is None else out
        self.step+=1

        imsize=model_input["rgb"].shape[-2:]
        low_imres=(64,64)
        (b,_),n_trgt=model_input["rgb"].shape[:2],model_input["rgb"].size(1)
        n_samp=300
        # NOTE(review): rand_subset is not used further down in this variant. It is kept because
        # torch.randperm advances the global RNG, so removing it would change the random stream of existing runs.
        rand_subset = torch.randperm(imsize[0]*imsize[1])[:n_samp]
        if self.args.overfit: rand_subset = torch.linspace(0,imsize[0]*imsize[1]-1,n_samp).long()

        # Run optical flow and point track networks
        self.get_flow(model_input)

        # Est depth
        # Flow input: a zero frame is prepended along axis 1 so it lines up with the rgb frames; flow is scaled by 5e2.
        flow_inp = torch.cat((torch.zeros_like(model_input["bwd_flow"][:,:1]),model_input["bwd_flow"]*5e2),1)
        depth_inp = ch_fst(model_input["depth_inp"],imsize[0])
        rgbdf = torch.cat((model_input["rgb"],depth_inp,flow_inp),2)

        # Stack groups of time_stride consecutive frames on the channel axis, encode, then split them back out.
        rgbdf=rearrange(rgbdf,"b (t s) c x y -> b t (s c) x y",s=self.time_stride)
        fmap_out = F.interpolate(self.resnet_enc(rgbdf.flatten(0,1)*.5+.5),imsize,mode="bilinear")
        model_input["fmap"] = rearrange(fmap_out,"(b t) (s c) x y -> b (t s) c x y",s=self.time_stride,b=b)

        #model_input["fmap"] = F.interpolate(self.resnet_enc(rgbdf.flatten(0,1)*.5+.5),imsize,mode="bilinear").unflatten(0,(b,n_trgt))

        # Predicted depth = input depth + a positive (softplus) residual.
        res_depth = F.softplus(self.depth_conv(model_input["fmap"].flatten(0,1)).unflatten(0,(b,n_trgt))+1)
        depth = depth_inp + res_depth

        # Est intrinsics
        if not self.args.use_gt_intrinsics:
            # One focal value per batch element: sigmoid of the conv output averaged over frames and pixels, offset by .3.
            focal = self.focal_conv(model_input["fmap"].flatten(0,1)).unflatten(0,(b,n_trgt)).sigmoid().flatten(1,-1).mean(dim=-1)+.3
            model_input["intrinsics"]=torch.eye(3)[None].float().to(depth).repeat(b,n_trgt,1,1)
            model_input["intrinsics"][...,0,2]=model_input["intrinsics"][...,1,2]=.5
            model_input["intrinsics"][...,:,0,0]=focal[:,None]*model_input["org_ratio"][:1]
            model_input["intrinsics"][...,:,1,1]=focal[:,None]

        # Lift depth map into point cloud
        corresp_uv = (model_input["x_pix"][:,:-1]+ch_sec(model_input["bwd_flow"]))
        rds = geometry.get_world_rays(model_input["x_pix"],model_input["intrinsics"],None)[1]
        eye_surf=rds*ch_sec(depth)
        corresp_surf = F.grid_sample(ch_fst(eye_surf,imsize[0])[:,:-1].flatten(0,1),corresp_uv.flatten(0,1).unsqueeze(1)*2-1).squeeze(-2).permute(0,2,1).unflatten(0,(b,n_trgt-1))

        # Estimate correspondence weights
        corresp_feat = F.grid_sample(model_input["fmap"][:,:-1].flatten(0,1),corresp_uv.flatten(0,1).unsqueeze(1)*2-1).squeeze(-2).permute(0,2,1).unflatten(0,(b,n_trgt-1))
        corr_weights = self.corr_weighter_perpoint(torch.cat((corresp_feat,ch_sec(model_input["fmap"])[:,1:]),-1)).sigmoid().clip(min=1e-4)

        # Estimate rigid masks -- with n_rig=1 reduces exactly to static flowmap code -- note since normalizing here we're doing cosine sim, try unnorm dot product too
        affinity_mask=affinity_mask_unnorm=rearrange( (2e0*self.rig_predictor(ch_sec(model_input["fmap"]))).softmax(-1), "b t x o -> (b o) t x 1")
        #affinity_mask=affinity_mask_unnorm=rearrange( (2e0*self.rig_predictor(ch_sec(model_input["fmap"]+model_input["dino_feats"]/5))).softmax(-1), "b t x o -> (b o) t x 1")
        # affinity_mask is resized to low_imres; affinity_mask_unnorm keeps referring to the mask before resizing.
        affinity_mask=ch_sec(F.interpolate(ch_fst(affinity_mask,imsize[0]).flatten(0,1),low_imres).unflatten(0,affinity_mask.shape[:2]))
        #low_imres=imsize
        #corr_weights_obj = (corr_weights * affinity_mask_unnorm[:,1:]).clip(min=1e-4)

        # Local sampling per embedding based on point track corresponding to embedding source location
        get_xpix = lambda s,e,n: torch.stack(torch.meshgrid(*[torch.linspace(s+.01,e-.01,n) for _ in range(2)])).flatten(1,2)
        n_total_samp,samp_amts = 500,[.5,.25,.1]
        # This variant samples on a uniform grid only; the local sampling below is disabled.
        sample_locs = uniform_grid_pix = repeat(get_xpix(0,1,int(n_total_samp**.5)),"c p -> (b o) t p 1 c",o=self.n_rig,b=b,t=n_trgt-1).cuda()
        #local_samp_grids = torch.cat([get_xpix(-i,i,int((n_total_samp/len(samp_amts))**.5)) for i in samp_amts],1).T
        #sample_locs = rearrange(aff_src_samp_locs[:,1:],"b t s 1 c -> (b s) t 1 1 c")+local_samp_grids.cuda()[None,None,:,None]
        #sample_locs = torch.where((sample_locs<0)|(sample_locs>1),uniform_grid_pix[:,:,:sample_locs.size(2)],sample_locs)# replace where out of bounds with uniform grid

        #sample_locs=uniform_grid_pix

        # Estimate poses via procrustes and integrate ; this rearranging is ugly below for efficiency but can definitely refactor
        poses = geometry.procrustes(
            rearrange(grid_samp(ch_fst(eye_surf[:,1:],imsize[0]),rearrange(sample_locs,"(b s) t p 1 c -> b t (p s) 1 c",b=b)),"b t c (p s) 1 -> (b s) t p c",p=sample_locs.size(2)),
            rearrange(grid_samp(ch_fst(corresp_surf,imsize[0]),  rearrange(sample_locs,"(b s) t p 1 c -> b t (p s) 1 c",b=b)),"b t c (p s) 1 -> (b s) t p c",p=sample_locs.size(2)),
           (rearrange(grid_samp(ch_fst(corr_weights,imsize[0]),  rearrange(sample_locs,"(b s) t p 1 c -> b t (p s) 1 c",b=b)),"b t c (p s) 1 -> (b s) t p c",p=sample_locs.size(2))*
            grid_samp(ch_fst(affinity_mask_unnorm[:,1:],low_imres[0]),sample_locs).squeeze(2)).clip(min=1e-4)
                        )[1]
        # Chain the per-pair transforms: working from the last frame backwards, everything from index i on is left-multiplied by pose i-1.
        for i in range(n_trgt-1,0,-1): poses = torch.cat((poses[:,:i],poses[:,[i-1]]@poses[:,i:]),1)
        poses = torch.cat((torch.eye(4).to(poses)[None,None].expand(poses.size(0),-1,-1,-1),poses),1) # add for starting pose

        # Composite the poses here to be per-pixel so that each pixel has an se3 trajectory (lie space composition then se3 exponentiation)
        # The per-group representation is a 7-vector: quaternion (4) followed by translation (3).
        poses_lie = torch.cat((kornia.geometry.conversions.rotation_matrix_to_quaternion(poses[...,:3,:3],eps=1e-5),poses[...,:3,-1]),-1)
        lie_perpix = (affinity_mask*poses_lie.unsqueeze(-2)).unflatten(0,(b,self.n_rig)).sum(1)
        lie_perpix = ch_sec(F.interpolate(ch_fst(lie_perpix,low_imres[0]).flatten(0,1),imsize,mode="bilinear").unflatten(0,(b,n_trgt))) # upsample to hires (just bilinear rn)

        pose_perpix = torch.eye(4)[None,None,None].expand(b,n_trgt,lie_perpix.size(-2),-1,-1).to(lie_perpix)
        pose_perpix[...,:3,:3] = kornia.geometry.conversions.quaternion_to_rotation_matrix(lie_perpix[...,:4])
        pose_perpix[...,:3,-1] = lie_perpix[...,4:]

        # Compute pose flow
        adj_opt_flow = project( torch.einsum("btpij,btpj->btpi",pose_perpix[:,:-1].inverse()@pose_perpix[:,1:],hom(eye_surf[:,1:]))[...,:3],
                                                model_input["intrinsics"][:,1:] )-model_input["x_pix"][:,1:]

        ## Compute pose induced point tracks
        # NOTE(review): pose_tracks / eye_tracks and model_input["pred_tracks"] are used unconditionally after
        # this block, so this method effectively requires "pred_tracks" despite the guard -- TODO confirm get_flow always sets it.
        if "pred_tracks" in model_input:
            eye_tracks=rearrange(F.grid_sample(ch_fst(eye_surf,imsize[0]).flatten(0,1),model_input["pred_tracks"].flatten(0,1).unsqueeze(-2)*2-1,padding_mode="border",
                ).unflatten(0,(b,n_trgt)), "b t c p 1 -> b p t c",)
            pose_tracks=rearrange(F.grid_sample(ch_fst(pose_perpix.flatten(-2,-1),imsize[0]).flatten(0,1),model_input["pred_tracks"].flatten(0,1).unsqueeze(-2)*2-1,
                                    padding_mode="border").unflatten(0,(b,n_trgt)), "b t (x y) p 1 -> b p t x y",x=4)
            pose_all_to_all_perpix = repeat(pose_tracks,"b p t x y -> b p s t x y",s=n_trgt).inverse()@repeat(pose_tracks,"b p t x y -> b p t s x y",s=n_trgt)
            track_surf_reprojs = torch.einsum("bksnij,bksj->bksni",pose_all_to_all_perpix,hom(eye_tracks))[...,:3]
            out["track_reprojs"] = unhom(torch.einsum("bij,bksnj->bksni",model_input["intrinsics"][:,0],track_surf_reprojs)).clip(0,1)

        # Pose sequence of the rigid group with the largest total mask weight.
        bg_pose = poses[affinity_mask.flatten(1,-1).sum(-1).max(dim=0)[1]][None] # take max

        world_crds = torch.einsum("btpij,btpj->btpi",pose_perpix,hom(eye_surf))[...,:3]
        world_crds_tracks = torch.einsum("btpij,btpj->btpi",pose_tracks.permute(0,2,1,3,4),hom(eye_tracks.permute(0,2,1,3)))[...,:3]
        rgb_tracks=F.grid_sample(model_input["rgb"].flatten(0,1),model_input["pred_tracks"].flatten(0,1).unsqueeze(-2)*2-1,
                                                            padding_mode="border").unflatten(0,(b,n_trgt)).permute(0,1,3,2,4).squeeze(-1)
        # Tracks get a fixed quaternion [1,0,0,0] and, as translation, their eye-space offset from the first frame.
        lie_tracks = torch.cat((torch.tensor([1,0,0,0]).float().cuda()[None,None,None].expand(b,n_trgt,rgb_tracks.size(-2),-1),
                                (eye_tracks[:,:,[0]]-eye_tracks).permute(0,2,1,3)
                                ),-1).unsqueeze(-2)
        rgb_crds = ch_sec(model_input["rgb"])

        #I think using only up to 80 frames is probably fine — most Davis sequences are at around 80 frames — and can always cite just using larger gpus
        #Also we can just reduce the image resolution to ~64x64 for a 4x decrease in res
        #Next up then — low res, point track striding, run on more videos, switch to better point tracker

        return out | {
            "rig_masks":rearrange(affinity_mask,"(b o) t xy 1 -> b t o xy 1",o=self.n_rig),
            #"rig_masks_unnorm":rearrange(affinity_mask_unnorm,"(b o) t xy 1 -> b t o xy 1",o=self.n_rig),
            #"affinity_emb":affinity_emb,
            "res_depth":ch_sec(res_depth),
            "depth":ch_sec(depth),
            "flow_from_pose":adj_opt_flow,

            "lie_crds":torch.cat((lie_perpix.unsqueeze(-2),lie_tracks),-3),
            "world_crds":torch.cat((world_crds,world_crds_tracks),-2),
            "rgb_crds":torch.cat((rgb_crds,rgb_tracks),-2),

            #"lie_crds":lie_tracks,
            #"world_crds":world_crds_tracks,
            #"rgb_crds":rgb_tracks,

            #"lie_crds":lie_perpix.unflatten(-2,imsize),
            #"world_crds":world_crds,
            #"rgb_crds":rgb_crds,

            #"affinity_emb":model_input["dino_feats"],

            "lie_perpix":lie_perpix.unflatten(-2,imsize),
            "pose_perpix":pose_perpix.unflatten(-3,imsize),
            "eye_surf":eye_surf,
            "poses_all":poses,
            "poses":bg_pose,
            "depth_inp":model_input["depth_inp"],
            "corr_weights": ch_fst(corr_weights,imsize[0]),
            "flow_inp_": model_input["bwd_flow"],
        }

    def forward_(self, model_input, out={}):
        """Affinity-embedding variant: one rigid pose track per point-track source.

        A conv head (``self.affinities_conv``) produces a per-pixel embedding. Embeddings sampled at
        the first-frame point-track locations act as sources; the dot product between every
        (64x64-downsampled) pixel embedding and every source, soft-maxed over the source axis, gives
        the soft rigid masks. One pose per source is solved with ``geometry.procrustes`` from samples
        taken around that source's track, and the poses are blended per pixel.

        NOTE(review): this class defines ``forward_`` again further down. In a Python class body the
        last definition of a name wins, so this variant is not the one bound to ``FlowMap.forward_``.

        Args:
            model_input: Dict of clip tensors. Reads "rgb", "depth_inp", "x_pix", "org_ratio",
                "bwd_flow" and "pred_tracks" (the last two presumably filled in by
                ``self.get_flow`` -- TODO confirm). "fmap" is written, and "intrinsics" is
                overwritten unless ``args.use_gt_intrinsics`` is set.
            out: Dict whose entries are carried into the result; it is only read here.

        Returns:
            A new dict: ``out`` overlaid with this call's outputs.
        """
        self.step+=1

        imsize=model_input["rgb"].shape[-2:]
        low_imres=(64,64)
        (b,_),n_trgt=model_input["rgb"].shape[:2],model_input["rgb"].size(1)
        n_samp=300
        # rand_subset is not referenced again in this variant.
        rand_subset = torch.randperm(imsize[0]*imsize[1])[:n_samp]
        if self.args.overfit: rand_subset = torch.linspace(0,imsize[0]*imsize[1]-1,n_samp).long()

        # Run optical flow and point track networks
        self.get_flow(model_input)

        # Est depth
        depth_inp = ch_fst(model_input["depth_inp"],imsize[0])
        rgbd = torch.cat((model_input["rgb"],depth_inp),2)
        model_input["fmap"] = F.interpolate(self.resnet_enc(rgbd.flatten(0,1)*.5+.5),imsize,mode="bilinear").unflatten(0,(b,n_trgt))
        res_depth = F.softplus(self.depth_conv(model_input["fmap"].flatten(0,1)).unflatten(0,(b,n_trgt))+1) 
        # The residual is computed (and returned) but not added in this variant: depth is the input depth as-is.
        depth = depth_inp# + res_depth

        # Est intrinsics
        if not self.args.use_gt_intrinsics:
            focal = self.focal_conv(model_input["fmap"].flatten(0,1)).unflatten(0,(b,n_trgt)).sigmoid().flatten(1,-1).mean(dim=-1)
            model_input["intrinsics"]=torch.eye(3)[None].float().to(depth).repeat(b,n_trgt,1,1)
            model_input["intrinsics"][...,0,2]=model_input["intrinsics"][...,1,2]=.5
            model_input["intrinsics"][...,:,0,0]=focal[:,None]*model_input["org_ratio"][:1]
            model_input["intrinsics"][...,:,1,1]=focal[:,None]

        # Lift depth map into point cloud
        corresp_uv = (model_input["x_pix"][:,:-1]+ch_sec(model_input["bwd_flow"]))
        rds = geometry.get_world_rays(model_input["x_pix"],model_input["intrinsics"],None)[1]
        eye_surf=rds*ch_sec(depth)
        corresp_surf = F.grid_sample(ch_fst(eye_surf,imsize[0])[:,:-1].flatten(0,1),corresp_uv.flatten(0,1).unsqueeze(1)*2-1).squeeze(-2).permute(0,2,1).unflatten(0,(b,n_trgt-1))

        # Estimate correspondence weights
        corresp_feat = F.grid_sample(model_input["fmap"][:,:-1].flatten(0,1),corresp_uv.flatten(0,1).unsqueeze(1)*2-1).squeeze(-2).permute(0,2,1).unflatten(0,(b,n_trgt-1))
        corr_weights = self.corr_weighter_perpoint(torch.cat((corresp_feat,ch_sec(model_input["fmap"])[:,1:]),-1)).sigmoid().clip(min=1e-4)

        #corr_weights = torch.ones_like(corr_weights)

        # Estimate rigid masks -- with n_rig=1 reduces exactly to static flowmap code -- note since normalizing here we're doing cosine sim, try unnorm dot product too
        # First memory reduction is make the affinity emb downsampled to 4x lower emb 
        #affinity_emb = F.normalize( self.affinities_conv(model_input["fmap"].flatten(0,1)).unflatten(0,(b,n_trgt)), dim=2 )
        affinity_emb = self.affinities_conv(model_input["fmap"].flatten(0,1)).unflatten(0,(b,n_trgt))
        #affinity_emb2= self.affinities_conv(model_input["fmap"].flatten(0,1)).unflatten(0,(b,n_trgt))

        #aff_sources = rearrange(grid_samp(affinity_emb[:,[0]], model_input["pred_tracks"][:,[0]].unsqueeze(-2)),"b t c p 1 -> b t p c") # random 2 choices for emb
        #aff_src_samp_locs = model_input["pred_tracks"][:,:,[27]].unsqueeze(-2)
        aff_src_samp_locs = model_input["pred_tracks"].unsqueeze(-2)
        # Source embeddings: the first-frame embedding sampled at the first-frame track locations.
        aff_sources = rearrange(grid_samp(affinity_emb[:,[0]], aff_src_samp_locs[:,[0]]),"b t c p 1 -> b t p c")
        #aff_sources2= rearrange(grid_samp(affinity_emb2[:,[0]], aff_src_samp_locs[:,[0]]),"b t c p 1 -> b t p c")

        affinity_emb_low = F.interpolate(affinity_emb.flatten(0,1),(64,64),mode="bilinear").unflatten(0,(b,n_trgt))
        #affinity_emb_low2= F.interpolate(affinity_emb2.flatten(0,1),(64,64),mode="bilinear").unflatten(0,(b,n_trgt))

        # TODO simplify this by only doing once but taking product of norms

        aff_sim_allpix = torch.einsum('b t p c, b s c -> b t p s', ch_sec(affinity_emb_low), aff_sources.squeeze(1)) # from all source pix to all other pix
        affinity_mask_unnorm = rearrange(aff_sim_allpix,"b t p s -> (b s) t p 1")
        #aff_sim_allpix2= torch.einsum('b t p c, b s c -> b t p s', ch_sec(affinity_emb_low2), aff_sources2.squeeze(1)) # from all source pix to all other pix
        #affinity_mask_unnorm2= rearrange(aff_sim_allpix2,"b t p s -> (b s) t p 1")
        #temperature = self.temperature_conv(model_input["fmap"].flatten(0,1)).unflatten(0,(b,n_trgt)).max()*5e1+1
        #affinity_mask = (affinity_mask_unnorm*temperature*(np.sqrt(aff_sim_allpix.size(-1))/8)).softmax(dim=0).clip(min=1e-4)
        #affinity_mask = (affinity_mask_unnorm2*1e1).softmax(dim=0).clip(min=1e-4)
        # Softmax runs over axis 0, the merged (b s) axis; both names are rebound to the normalised mask here.
        affinity_mask =affinity_mask_unnorm= (affinity_mask_unnorm).softmax(dim=0).clip(min=1e-4)

        # Below is rigidity bucket sanity case before going into discretized embedding set
        #affinity_mask=affinity_mask_unnorm=rearrange( self.rig_predictor(ch_sec(model_input["fmap"])).softmax(-1), "b t x o -> (b o) t x 1")
        #affinity_mask=rearrange( (1e1*self.rig_predictor(ch_sec(model_input["fmap"]))).softmax(-1), "b t x o -> (b o) t x 1")
        #affinity_mask=ch_sec(F.interpolate(ch_fst(affinity_mask,imsize[0]).flatten(0,1),low_imres).unflatten(0,affinity_mask.shape[:2]))
        #low_imres=imsize
        #corr_weights_obj = (corr_weights * affinity_mask_unnorm[:,1:]).clip(min=1e-4)

        # Local sampling per embedding based on point track corresponding to embedding source location 
        get_xpix = lambda s,e,n: torch.stack(torch.meshgrid(*[torch.linspace(s+.01,e-.01,n) for _ in range(2)])).flatten(1,2)
        n_total_samp,samp_amts = 225,[.5,.25,.1]
        uniform_grid_pix = repeat(get_xpix(0,1,int(n_total_samp**.5)),"c p -> (b o) t p 1 c",o=aff_sources.size(2),b=b,t=n_trgt-1).cuda()
        # Offsets for square grids of half-width .5, .25 and .1, added to each track location below.
        local_samp_grids = torch.cat([get_xpix(-i,i,int((n_total_samp/len(samp_amts))**.5)) for i in samp_amts],1).T
        sample_locs = rearrange(aff_src_samp_locs[:,1:],"b t s 1 c -> (b s) t 1 1 c")+local_samp_grids.cuda()[None,None,:,None]
        sample_locs = torch.where((sample_locs<0)|(sample_locs>1),uniform_grid_pix[:,:,:sample_locs.size(2)],sample_locs)# replace where out of bounds with uniform grid

        #sample_locs=uniform_grid_pix

        # Estimate poses via procrustes and integrate ; this rearranging is ugly below for efficiency but can definitely refactor
        poses = geometry.procrustes(
            rearrange(grid_samp(ch_fst(eye_surf[:,1:],imsize[0]),rearrange(sample_locs,"(b s) t p 1 c -> b t (p s) 1 c",b=b)),"b t c (p s) 1 -> (b s) t p c",p=sample_locs.size(2)),
            rearrange(grid_samp(ch_fst(corresp_surf,imsize[0]),  rearrange(sample_locs,"(b s) t p 1 c -> b t (p s) 1 c",b=b)),"b t c (p s) 1 -> (b s) t p c",p=sample_locs.size(2)),
           (rearrange(grid_samp(ch_fst(corr_weights,imsize[0]),  rearrange(sample_locs,"(b s) t p 1 c -> b t (p s) 1 c",b=b)),"b t c (p s) 1 -> (b s) t p c",p=sample_locs.size(2))*
            grid_samp(ch_fst(affinity_mask_unnorm[:,1:],low_imres[0]),sample_locs).squeeze(2)).clip(min=1e-4)
                        )[1]
        # Chain the per-pair transforms: working from the last frame backwards, everything from index i on is left-multiplied by pose i-1.
        for i in range(n_trgt-1,0,-1): poses = torch.cat((poses[:,:i],poses[:,[i-1]]@poses[:,i:]),1)
        poses = torch.cat((torch.eye(4).to(poses)[None,None].expand(poses.size(0),-1,-1,-1),poses),1) # add for starting pose

        # Composite the poses here to be per-pixel so that each pixel has an se3 trajectory (lie space composition then se3 exponentiation) 
        # The per-source representation is a 7-vector: quaternion (4) followed by translation (3).
        poses_lie = torch.cat((kornia.geometry.conversions.rotation_matrix_to_quaternion(poses[...,:3,:3],eps=1e-5),poses[...,:3,-1]),-1)
        lie_perpix = (affinity_mask*poses_lie.unsqueeze(-2)).unflatten(0,(b,aff_sources.size(2))).sum(1)
        lie_perpix = ch_sec(F.interpolate(ch_fst(lie_perpix,low_imres[0]).flatten(0,1),imsize,mode="bilinear").unflatten(0,(b,n_trgt))) # upsample to hires (just bilinear rn)

        pose_perpix = torch.eye(4)[None,None,None].expand(b,n_trgt,lie_perpix.size(-2),-1,-1).to(lie_perpix)
        pose_perpix[...,:3,:3] = kornia.geometry.conversions.quaternion_to_rotation_matrix(lie_perpix[...,:4])
        pose_perpix[...,:3,-1] = lie_perpix[...,4:]

        # Compute pose flow
        adj_opt_flow = project( torch.einsum("btpij,btpj->btpi",pose_perpix[:,:-1].inverse()@pose_perpix[:,1:],hom(eye_surf[:,1:]))[...,:3], 
                                                model_input["intrinsics"][:,1:] )-model_input["x_pix"][:,1:]

        ## Compute pose induced point tracks
        eye_tracks=rearrange(F.grid_sample(ch_fst(eye_surf,imsize[0]).flatten(0,1),model_input["pred_tracks"].flatten(0,1).unsqueeze(-2)*2-1).unflatten(0,(b,n_trgt)),
                             "b t c p 1 -> b p t c",)
        pose_tracks=rearrange(F.grid_sample(ch_fst(pose_perpix.flatten(-2,-1),imsize[0]).flatten(0,1),model_input["pred_tracks"].flatten(0,1).unsqueeze(-2)*2-1,
                                padding_mode="border").unflatten(0,(b,n_trgt)), "b t (x y) p 1 -> b p t x y",x=4)
        pose_all_to_all_perpix = repeat(pose_tracks,"b p t x y -> b p s t x y",s=n_trgt).inverse()@repeat(pose_tracks,"b p t x y -> b p t s x y",s=n_trgt)
        track_surf_reprojs = torch.einsum("bksnij,bksj->bksni",pose_all_to_all_perpix,hom(eye_tracks))[...,:3]
        track_reprojs = unhom(torch.einsum("bij,bksnj->bksni",model_input["intrinsics"][:,0],track_surf_reprojs)).clip(0,1)

        ## Unpack the per-object estimations and composite over them 
        #adj_opt_flow =  (adj_opt_flow*affinity_mask[:,1:]).unflatten(0,(b,self.n_rig)).sum(1)
        #track_affinities=rearrange(F.grid_sample(ch_fst(affinity_mask,imsize[0]).flatten(0,1),self.o_rep(model_input["pred_tracks"]
        #                    ).flatten(0,1).unsqueeze(-2)*2-1).squeeze(-2).unflatten(0,(b*self.n_rig,n_trgt)), "b t c p 1 -> b p t c",)
        #track_reprojs = (track_reprojs*track_affinities.unsqueeze(-2)).unflatten(0,(b,self.n_rig)).sum(1)

        return out | {
            "rig_masks":rearrange(affinity_mask,"(b o) t xy 1 -> b t o xy 1",o=aff_sources.size(2)),
            "rig_masks_unnorm":rearrange(affinity_mask_unnorm,"(b o) t xy 1 -> b t o xy 1",o=aff_sources.size(2)),
            "affinity_emb":affinity_emb,
            "res_depth":ch_sec(res_depth),
            "depth":ch_sec(depth),
            "flow_from_pose":adj_opt_flow,
            "lie_perpix":lie_perpix.unflatten(-2,imsize),
            "pose_perpix":pose_perpix.unflatten(-3,imsize),
            "eye_surf":eye_surf,
            "poses_all":poses,
            "track_reprojs":track_reprojs,
            "depth_inp":model_input["depth_inp"],
            "corr_weights": ch_fst(corr_weights,imsize[0]),
            "poses":pose_perpix[:,:,10], # until good static approximation just choose rand pix
            "flow_inp_": model_input["bwd_flow"],
        }

    def forward_(self, model_input, out=None): # static, clean
        """Static-scene variant: one camera pose per frame from weighted Procrustes on a pixel subset.

        Depth (input depth plus a softplus residual) is lifted to a point cloud, optical-flow
        correspondences are weighted by ``self.corr_weighter_perpoint``, and a single transform per
        adjacent frame pair is solved with ``geometry.procrustes`` on ``n_samp`` pixels. The
        transforms are chained over time and used to compute pose-induced flow and, when
        "pred_tracks" is present, reprojected point tracks.

        NOTE(review): this class defines ``forward_`` again further down. In a Python class body the
        last definition of a name wins, so this variant is not the one bound to ``FlowMap.forward_``.

        Args:
            model_input: Dict of clip tensors. Reads "rgb", "depth_inp", "x_pix", "org_ratio",
                "bwd_flow" and optionally "pred_tracks" (the last two presumably filled in by
                ``self.get_flow`` -- TODO confirm). "fmap" is written, and "intrinsics" is
                overwritten unless ``args.use_gt_intrinsics`` is set.
            out: Optional dict whose entries are carried into the result. If given, it also
                receives "track_reprojs" in place. Defaults to a fresh dict per call.

        Returns:
            A new dict: ``out`` overlaid with this call's outputs.
        """
        # Fresh dict per call. The former `out={}` default was a single dict shared by every call and
        # is written to below ("track_reprojs"), so tensors from one call leaked into the next.
        out = {} if out is None else out
        self.step+=1

        imsize=model_input["rgb"].shape[-2:]
        (b,_),n_trgt=model_input["rgb"].shape[:2],model_input["rgb"].size(1)
        n_samp=1000
        # Pixel subset used for Procrustes: random by default, evenly spaced (deterministic) when overfitting.
        rand_subset = torch.randperm(imsize[0]*imsize[1])[:n_samp]
        if self.args.overfit: rand_subset = torch.linspace(0,imsize[0]*imsize[1]-1,n_samp).long()

        # Run optical flow and point track networks
        self.get_flow(model_input)

        # Est depth
        depth_inp = ch_fst(model_input["depth_inp"],imsize[0])
        rgbd = torch.cat((model_input["rgb"],depth_inp),2)
        model_input["fmap"] = F.interpolate(self.resnet_enc(rgbd.flatten(0,1)*.5+.5),imsize,mode="bilinear").unflatten(0,(b,n_trgt))
        res_depth = F.softplus(self.depth_conv(model_input["fmap"].flatten(0,1)).unflatten(0,(b,n_trgt))+1)
        depth = res_depth + depth_inp

        # Est intrinsics
        if not self.args.use_gt_intrinsics:
            focal = self.focal_conv(model_input["fmap"].flatten(0,1)).unflatten(0,(b,n_trgt)).sigmoid().flatten(1,-1).mean(dim=-1)
            model_input["intrinsics"]=torch.eye(3)[None].float().to(depth).repeat(b,n_trgt,1,1)
            model_input["intrinsics"][...,0,2]=model_input["intrinsics"][...,1,2]=.5
            model_input["intrinsics"][...,:,0,0]=focal[:,None]*model_input["org_ratio"][:1]
            model_input["intrinsics"][...,:,1,1]=focal[:,None]

        # Lift depth map into point cloud
        corresp_uv = (model_input["x_pix"][:,:-1]+ch_sec(model_input["bwd_flow"]))
        rds = geometry.get_world_rays(model_input["x_pix"],model_input["intrinsics"],None)[1]
        eye_surf=rds*ch_sec(depth)
        corresp_surf = F.grid_sample(ch_fst(eye_surf,imsize[0])[:,:-1].flatten(0,1),corresp_uv.flatten(0,1).unsqueeze(1)*2-1).squeeze(-2).permute(0,2,1).unflatten(0,(b,n_trgt-1))

        # Estimate correspondence weights
        corresp_feat = F.grid_sample(model_input["fmap"][:,:-1].flatten(0,1),corresp_uv.flatten(0,1).unsqueeze(1)*2-1).squeeze(-2).permute(0,2,1).unflatten(0,(b,n_trgt-1))
        corr_weights = self.corr_weighter_perpoint(torch.cat((corresp_feat,ch_sec(model_input["fmap"])[:,1:]),-1)).sigmoid().clip(min=1e-4)

        # Estimate poses via procrustes
        adj_transf = geometry.procrustes(eye_surf[:,1:,rand_subset], corresp_surf[:,:,rand_subset], corr_weights[:,:,rand_subset] )[1]
        # Integrate poses: working from the last frame backwards, everything from index i on is left-multiplied by pose i-1.
        poses = adj_transf
        for i in range(n_trgt-1,0,-1): poses = torch.cat((poses[:,:i],poses[:,[i-1]]@poses[:,i:]),1)
        # Prepend the identity as the pose of the first frame.
        poses = torch.cat((torch.eye(4).to(poses)[None,None].expand(poses.size(0),-1,-1,-1),poses),1)

        # Compute pose-induced optical flow
        adj_opt_flow = warp(eye_surf[:,1:],poses[:,:-1].inverse()@poses[:,1:],model_input["intrinsics"][:,1:])-model_input["x_pix"][:,1:]

        # Compute pose induced point tracks
        if "pred_tracks" in model_input:
            eye_tracks=rearrange(F.grid_sample(ch_fst(eye_surf,imsize[0]).flatten(0,1),model_input["pred_tracks"].flatten(0,1).unsqueeze(-2)*2-1).squeeze(-2).unflatten(0,(b,n_trgt)),
                                 "b t c p 1 -> b p t c",)
            track_surf_reprojs = torch.einsum("bsnij,bksj->bksni",repeat(poses,"b t x y -> b s t x y",s=n_trgt).inverse()@repeat(poses,"b t x y -> b t s x y",s=n_trgt),hom(eye_tracks))[...,:3]
            out["track_reprojs"] = unhom(torch.einsum("bij,bksnj->bksni",model_input["intrinsics"][:,0],track_surf_reprojs)).clip(0,1)

        return out | {
            "res_depth":ch_sec(res_depth),
            "depth":ch_sec(depth),
            "flow_from_pose":adj_opt_flow,
            "depth_inp":model_input["depth_inp"],
            # Squared difference of inverse depths between input and predicted depth, scaled by 1e1.
            "zoe_d_loss":(1/(1e-5+model_input["depth_inp"])-1/(1e-5+ch_sec(depth))).square().mean()*1e1,
            "corr_weights": ch_fst(corr_weights,imsize[0]),
            "poses":poses,
            "flow_inp_": model_input["bwd_flow"],
        }

    def forward_(self, model_input, out={}):
        """Point-track variant: solve each frame's pose against frame 0 from tracked points.

        Depth and rays are sampled at the point-track locations, a strided subset of the tracks is
        aligned to its first-frame positions with ``geometry.procrustes`` (weighted by
        ``self.corr_weighter_perpoint``), and the resulting poses are used both for a
        track-reprojection loss and for pose-induced optical flow.

        NOTE(review): this is the last ``forward_`` definition visible in this part of the class,
        so it replaces the earlier ones of the same name -- unless yet another follows later in the file.

        Args:
            model_input: Dict of clip tensors. Reads "rgb", "depth_inp", "x_pix", "org_ratio",
                "bwd_flow", "pred_tracks" and "pred_visibility" (the last three presumably filled in
                by ``self.get_flow`` -- TODO confirm). "fmap" is written, and "intrinsics" is
                overwritten unless ``args.use_gt_intrinsics`` is set.
            out: Dict whose entries are carried into the result; it is only read here.

        Returns:
            A new dict: ``out`` overlaid with this call's outputs, including "point_track_loss".
        """
        self.step+=1

        imsize=model_input["rgb"].shape[-2:]
        (b,_),n_trgt=model_input["rgb"].shape[:2],model_input["rgb"].size(1)

        # Run optical flow and point track networks
        self.get_flow(model_input)

        # Est depth
        depth_inp = ch_fst(model_input["depth_inp"],imsize[0])
        rgbd = torch.cat((model_input["rgb"],depth_inp),2)
        model_input["fmap"] = F.interpolate(self.resnet_enc(rgbd.flatten(0,1)*.5+.5),imsize,mode="bilinear").unflatten(0,(b,n_trgt))
        res_depth = F.softplus(self.depth_conv(model_input["fmap"].flatten(0,1)).unflatten(0,(b,n_trgt))+1) 
        depth = res_depth + depth_inp

        # Est intrinsics
        if not self.args.use_gt_intrinsics:
            focal = self.focal_conv(model_input["fmap"].flatten(0,1)).unflatten(0,(b,n_trgt)).sigmoid().flatten(1,-1).mean(dim=-1)
            model_input["intrinsics"]=torch.eye(3)[None].float().to(depth).repeat(b,n_trgt,1,1)
            model_input["intrinsics"][...,0,2]=model_input["intrinsics"][...,1,2]=.5
            model_input["intrinsics"][...,0,0]=focal*model_input["org_ratio"]
            model_input["intrinsics"][...,1,1]=focal

        # Get corresp weights per track corresp 
        track_feat   = rearrange(grid_samp(model_input["fmap"], model_input["pred_tracks"].unsqueeze(-2)),"b t c p 1 -> b t p c")
        # Each track's feature in every frame is paired with its feature in frame 0.
        corr_weights = self.corr_weighter_perpoint(torch.cat((track_feat[:,[0]].expand(-1,n_trgt,-1,-1),track_feat),-1)).sigmoid().clip(min=1e-4)

        # since too expensive without local sampling, will for now just use N point tracks until implemented
        # Every 10th index out of 42**2; presumably the tracker is queried on a 42x42 grid -- TODO confirm.
        neighb_list = np.arange(42**2)[::10]

        # Est poses per source point 
        rds_track = geometry.get_world_rays(model_input["pred_tracks"],model_input["intrinsics"],None)[1]
        eye_surf_track = rds_track * grid_samp(depth,model_input["pred_tracks"].unsqueeze(-2)).squeeze(2)
        poses = geometry.procrustes(eye_surf_track[...,neighb_list,:], eye_surf_track[:,[0]].expand(-1,n_trgt,-1,-1)[...,neighb_list,:], corr_weights[...,neighb_list,:])[1]

        # Move the first-frame track points with the inverse poses and project them into each frame.
        point_track_surf_reproj = torch.einsum('btij,btpj->btpi',poses.inverse(),hom(eye_surf_track[:,[0]].expand(-1,n_trgt,-1,-1)))[...,:3]
        point_track_reproj = project(point_track_surf_reproj,model_input["intrinsics"]).clip(0,1)
        # A track contributes in a frame only if it is marked visible both there and in frame 0.
        vis_mask = torch.minimum(model_input["pred_visibility"],model_input["pred_visibility"][:,[0]])
        point_track_loss = ( (point_track_reproj - model_input["pred_tracks"]) * vis_mask.unsqueeze(-1) ).square().mean()

        ## Supervise optical flow 
        eye_surf = geometry.get_world_rays(model_input["x_pix"],model_input["intrinsics"],None)[1] * ch_sec(depth)
        adj_opt_flow = warp(eye_surf[:,1:],poses[:,:-1].inverse()@poses[:,1:],model_input["intrinsics"][:,1:])-model_input["x_pix"][:,1:]
        
        return out | {
            "res_depth":ch_sec(res_depth),
            "depth":ch_sec(depth),
            "flow_from_pose":adj_opt_flow,
            "point_track_loss":point_track_loss*1e2,
            "depth_inp":model_input["depth_inp"],
            # Squared difference of inverse depths between input and predicted depth, scaled by 1e1.
            "zoe_d_loss":(1/(1e-5+model_input["depth_inp"])-1/(1e-5+ch_sec(depth))).square().mean()*1e1,
            # Reshaped with a hard-coded first spatial size of 42 (see neighb_list above).
            "corr_weights": ch_fst(corr_weights,42),
            "poses":poses,
            "flow_inp_": model_input["bwd_flow"],
        }
    def get_flow(self,model_input):
        """Compute optical flow (and optionally point tracks) and store them in ``model_input``.

        Everything runs under ``torch.no_grad()``. Nothing is recomputed when
        ``model_input`` already contains a ``"bwd_flow"`` entry.

        Reads:
            model_input["rgb"]: leading dims are ``(b, n_trgt)``; its last two dims
                give the spatial size the flow is resized to.
            model_input["rgb_large"]: frames fed to the flow net and the tracker,
                indexed as ``[batch, frame, ...]``.

        Writes (in place):
            model_input["bwd_flow"]: flow-net output for the pairs
                (frame ``t``, frame ``t-1``), ``t = 1..n_trgt-1``, divided by
                ``(W-1, H-1)`` of the flow-net input and resized to the ``rgb``
                spatial size; leading dims ``(b, n_trgt-1)``.
            model_input["pred_tracks"], model_input["pred_visibility"]: only when
                ``self.args.point_track`` is set. The tracker is run on the first
                batch element only; tracks are divided by ``(W-1, H-1)`` of
                ``rgb_large`` and the remaining ``b-1`` batch entries are padded
                with ones (tracks) and zeros (visibility).

        Side effects:
            Sets ``self.n_track_frames`` to 1, prints progress messages, and may
            write a track visualization video under ``./output/``.

        Returns:
            None.
        """
        b, n_trgt = model_input["rgb"].shape[:2]
        imsize = model_input["rgb"].shape[-2:]

        #if "depth_inp" not in model_input: 
        #    print("doing zoe depth")
        #    with torch.no_grad():
        #        zoe_est = F.interpolate(self.model_zoe_n.infer(F.interpolate(model_input["rgb_large"].flatten(0,1)/255,scale_factor=.25)),model_input["rgb"].shape[-2:]) # todo run offline
        #        #zoe_est = self.model_zoe_n.infer(model_input["rgb"].flatten(0,1)*.5+.5) # todo run offline
        #        model_input["depth_inp"] = ch_sec(zoe_est.unflatten(0,model_input["rgb"].shape[:2]))
        #    print("done doing zoe depth")

        def raft(x, y):
            """Return flow for the image pairs (x, y), normalized and resized to ``imsize``."""
            # Split the query in two when overfitting with >4 pairs, since RAFT is memory hungry
            if self.args.overfit and len(x) > 4:
                half = len(x) // 2
                return torch.cat((raft(x[:half], y[:half]), raft(x[half:], y[half:])))
            # RAFT is slow but only needs to run once when overfitting; otherwise use GM flow
            if self.args.overfit and not self.args.gm_flow:
                out = self.raft_(x, y, num_flow_updates=32)[-1]
            else:
                x, y = (x * .5 + .5) * 255, (y * .5 + .5) * 255  # gm flow uses inputs in [0,255]
                out = self.gm_flow(
                    x, y,
                    attn_splits_list=[2], corr_radius_list=[-1], prop_radius_list=[-1],
                    pred_bidir_flow=False,
                )["flow_preds"][-1]
            # Divide the (x, y) flow channels by (W-1, H-1) of the flow-net input so the
            # displacement no longer depends on the resolution the flow net ran at
            h, w = x.shape[-2:]
            out = out / torch.tensor([w - 1, h - 1]).to(x)[None, :, None, None]
            return F.interpolate(out, imsize, mode="bilinear", antialias=True)  # downsample from hires to lowres

        with torch.no_grad():
            # The previous expression `1 if self.args.overfit else 1` always evaluated to 1
            self.n_track_frames = 1
            if "bwd_flow" not in model_input: # don't run if flow already in scene dict
                rgb_large = model_input["rgb_large"]

                # Run optical flow net on (frame t, frame t-1) pairs
                raft_inp_x, raft_inp_y = self.raft_transforms(
                    rgb_large[:, 1:].flatten(0, 1).to(torch.uint8),
                    rgb_large[:, :-1].flatten(0, 1).to(torch.uint8),
                )
                model_input["bwd_flow"] = raft(raft_inp_x, raft_inp_y).unflatten(0, (b, n_trgt - 1))
                #dino_norm = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                #print("doing dino feats")
                #model_input["dino_feats"] = F.interpolate(self.upsampler(F.interpolate(model_input["rgb_large"].flatten(0,1),(280,280))/256),imsize).unflatten(0,(b,n_trgt))
                #model_input["dino_feats"] = torch.stack([F.interpolate(self.upsampler(F.interpolate(x,(280,280),mode="bilinear")/256),imsize,mode="bilinear") for x in model_input["rgb_large"].unbind(1)],1)
                #print("done dino feats")

                # Run point tracking
                if self.args.point_track:
                    large_h, large_w = rgb_large.shape[-2:]
                    pred_tracks, visibilities = [], []
                    query_frames = torch.linspace(0, n_trgt - 1, self.n_track_frames).long()
                    for i, start_frame in enumerate(query_frames):
                        print("doing video tracking",start_frame)
                        # grid_size is fixed at 16: the previous expression
                        # `16 if 1 else 42 if overfit else 20` could only ever yield 16
                        pred_track, visibility = self.co_tracker(
                            rgb_large[:1], grid_size=16, grid_query_frame=start_frame, backward_tracking=True,
                        )
                        # Normalize on the tracker output's own device/dtype instead of
                        # forcing the default CUDA device, which breaks on CPU and on
                        # non-default GPUs
                        track_scale = torch.tensor([large_w - 1, large_h - 1]).to(pred_track)
                        pred_tracks.append(pred_track / track_scale[None, None, None])
                        visibilities.append(visibility)

                        # Save point track visualization video (always when overfitting, else every 25 steps)
                        if self.args.overfit or self.step % 25 == 0:
                            print("writing point track video for visualization")
                            Visualizer(save_dir='./output/', pad_value=100).visualize(
                                video=rgb_large, tracks=pred_track, visibility=visibility,
                                filename='tracks_%02d'%i, query_frame=start_frame,
                            )
                    tracks = torch.cat(pred_tracks, 2)
                    visibilities = torch.cat(visibilities, 2)

                    # The tracker only ran on the first batch element; pad the other b-1
                    # entries with dummy tracks (ones) that are marked invisible (zeros)
                    track_pad = tracks.new_ones((b - 1, *tracks.shape[1:]))
                    vis_pad = visibilities.new_zeros((b - 1, *visibilities.shape[1:]))
                    model_input["pred_tracks"] = torch.cat((tracks, track_pad))
                    model_input["pred_visibility"] = torch.cat((visibilities, vis_pad))

            # Add sky mask heuristic to mask out bad point tracks in outdoor scenes
            #if "not_sky" not in model_input and self.args.point_track: 
            #    model_input["not_sky"]= torch.stack([(self.seg_model(model_input["rgb_large"][:,i]/255).max(1)[1][:,None]!=10).float() for i in range(n_trgt)],1)
            #    not_sky = grid_samp(model_input["not_sky"].flatten(0,1),model_input["pred_tracks"].flatten(0,1).unsqueeze(1)).squeeze(1).squeeze(1).unflatten(0,(b,n_trgt))
            #    model_input["pred_visibility"]=model_input["pred_visibility"]*not_sky
