"""Code for pixelnerf and alternatives."""
import torch, torchvision
from torch import nn
import kornia
from einops import rearrange, repeat
import torchvision.ops as ops
from torch.nn import functional as F
import numpy as np
import sys,random,time,os
from copy import deepcopy 
from matplotlib import cm
import wandb
from tqdm import tqdm
from typing import Callable, List, Optional, Tuple, Generator, Dict
from collections import defaultdict

sys.path.append("./third_party/co-tracker/")
from cotracker.utils.visualizer import Visualizer, read_video_from_path

import geometry

# Tensor helpers
def ch_sec(x):
    """Flatten the two trailing spatial axes and move channels last: ``... c x y -> ... (x y) c``."""
    return rearrange(x, "... c x y -> ... (x y) c")


def ch_fst(src, x=None):
    """Undo ``ch_sec``: ``... (x y) c -> ... c x y``.

    ``x`` is the length of the first spatial axis. When it is omitted the
    truncated square root of the flattened axis length is used, i.e. the
    spatial grid is assumed to be square.
    """
    if x is None:
        x = int(src.size(-2) ** (.5))
    return rearrange(src, "... (x y) c -> ... c x y", x=x)


def hom(x):
    """Append a coordinate of ones along the last axis."""
    return torch.cat((x, torch.ones_like(x[..., [0]])), -1)


def unhom(x):
    """Divide by the last coordinate (offset by 1e-5 against division by zero) and drop it."""
    return x[..., :-1] / (1e-5 + x[..., -1:])


def grid_samp_(x, y):
    """Bilinearly sample the 4-D map ``x`` at coordinates ``y``.

    Assumes ``y`` is in [0, 1]; it is remapped to the [-1, 1] range that
    ``F.grid_sample`` expects.
    """
    # align_corners=False is already the effective default of F.grid_sample;
    # passing it explicitly keeps results identical and avoids its UserWarning.
    return F.grid_sample(x, y * 2 - 1, mode="bilinear", align_corners=False)


def grid_samp(x, y):
    """``grid_samp_`` that also accepts inputs carrying two leading batch axes."""
    if len(x.shape) == 4:
        return grid_samp_(x, y)
    # todo use more general flatten recipe
    lead_shape = x.shape[:2]
    return grid_samp_(x.flatten(0, 1), y.flatten(0, 1)).unflatten(0, lead_shape)


def project(crds, K):
    """Multiply every point in ``crds`` by the matching matrix in ``K`` and dehomogenize."""
    return unhom(torch.einsum("b...cij,b...ckj->b...cki", K, crds))


def warp(crds, poses, K):
    """Transform ``crds`` by ``poses`` (in homogeneous coordinates), then ``project`` with ``K``."""
    moved = torch.einsum("b...cij,b...ckj->b...cki", poses, hom(crds))[..., :3]
    return project(moved, K)

def make_net(dims):
    """Build an MLP whose layer widths follow ``dims``.

    Each consecutive pair in ``dims`` becomes an ``nn.Linear``; an ``nn.ReLU``
    sits between linear layers and there is no activation after the last one.
    Linear weights are re-initialised with Kaiming-normal (fan-in, ReLU gain);
    biases keep the ``nn.Linear`` default. Fewer than two entries in ``dims``
    yields an empty ``nn.Sequential``.
    """
    def init_weights_normal(m):
        # isinstance instead of an exact type comparison; every nn.Linear has a weight
        if isinstance(m, nn.Linear):
            nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_in')

    layers = []
    for fan_in, fan_out in zip(dims[:-1], dims[1:]):
        if layers:
            layers.append(nn.ReLU())
        layers.append(nn.Linear(fan_in, fan_out))
    net = nn.Sequential(*layers)
    # Re-initialise all linear weights in one pass once the stack is assembled
    net.apply(init_weights_normal)
    return net

# Base model for FlowCam or FlowCams
class FlowMap(nn.Module):
    def __init__(self, args):
        """Instantiate the pretrained helper networks and the trainable heads.

        Args:
            args: run configuration object. It is stored on ``self.args``; this
                constructor itself reads ``args.load_save`` and
                ``args.scratch_net``.

        NOTE(review): construction has side effects -- it appends to
        ``sys.path``, loads checkpoints from hard-coded relative paths, calls
        ``torch.hub.load`` and moves GMFlow and CoTracker to CUDA via
        ``.cuda()``, so it presumably must be run from the repository root on a
        CUDA machine.
        """
        super().__init__()
        self.args=args

        # Optical flow model - RAFT
        from torchvision.models.optical_flow import Raft_Large_Weights
        self.raft_transforms = Raft_Large_Weights.DEFAULT.transforms()
        from torchvision.models.optical_flow import raft_large
        self.raft_ = raft_large(weights=Raft_Large_Weights.DEFAULT, progress=False)

        # GMFlow is a faster alternative to RAFT - only used during large-scale training, for overfitting we use raft
        sys.path.append("./gmflow/")
        from gmflow.gmflow import GMFlow
        self.gm_flow = GMFlow(feature_channels=128, num_scales=1, upsample_factor=8, num_head=1, attention_type="swin", ffn_dim_expansion=4, num_transformer_layers=6,).cuda()
        checkpoint = torch.load("./gmflow/gmflow-scale1-mixdata-train320x576-4c3a6e9a.pth")
        # The checkpoint may wrap the state dict under a 'model' key or be the state dict itself
        weights = checkpoint['model'] if 'model' in checkpoint else checkpoint
        # strict=False: keys that do not match between checkpoint and model do not raise
        self.gm_flow.load_state_dict(weights, strict=False)
        # GMFlow is frozen
        for param in self.gm_flow.parameters(): param.requires_grad = False

        # Point tracking model
        from cotracker.predictor import CoTrackerPredictor
        self.co_tracker = CoTrackerPredictor( checkpoint=os.path.join( './third_party/co-tracker/checkpoints/cotracker_stride_4_wind_8.pth')).cuda()

        # Segmentation model to get sky region since particle tracking doesn't predict well on sky (not catastrophic otherwise though)
        #sys.path.append("./third_party/DeepLabV3Plus-Pytorch")
        #import predict as seg_predict
        #self.seg_model=seg_predict.model

        # NOTE(review): the trailing `and 0` makes this condition always falsy, so the
        # SAM setup below never runs and `self.sam` / `self.resize_transform` are not
        # assigned here, although `get_seg` reads both.
        if not args.load_save and 0:
            print("making sam")
            from segment_anything import sam_model_registry, SamPredictor
            sam_checkpoint = "/home/cameronsmith/repos/scene_graph_extraction/sam_vit_h_4b8939.pth"
            model_type = "vit_h"
            device = "cuda"
            self.sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
            for param in self.sam.parameters(): param.requires_grad = False
            self.sam.to(device=device)
            from segment_anything.utils.transforms import ResizeLongestSide
            self.resize_transform = ResizeLongestSide(self.sam.image_encoder.img_size)
            print("made sam")

        # Our actual model
        fdim=128  # feature width shared by the heads below
        self.fmap_downproj = nn.Linear(512,fdim)
        self.depth_est = make_net([fdim,64,1])
        self.corr_weighter_perpoint = make_net([fdim*2,128,64,1])  # input is two concatenated fdim features
        self.static_est = make_net([fdim,64,1])
        self.mask_scorer = make_net([fdim,64,1])
        self.convex_weights_upsampler = make_net([fdim,64,1])

        self.n_rig=10
        n_mask_level=3
        self.mask_level_predictor = make_net([fdim,64,n_mask_level])
        self.rig_predictor = make_net([fdim,128,64,self.n_rig])
        # o_rep: tile the batch axis n_rig times; o_switch: fold an existing `o` axis into the batch axis
        self.o_rep = lambda x: repeat(x,"b t ... -> (b o) t ...",o=self.n_rig)
        self.o_switch = lambda x: repeat(x,"b t o ... -> (b o) t ...")

        # MiDaS backbone; its output conv is kept separately on `midas_out` and replaced by an identity
        self.midas = torch.hub.load("intel-isl/MiDaS", "MiDaS_small", pretrained=not self.args.scratch_net)
        self.midas_out=self.midas.scratch.output_conv
        self.midas.scratch.output_conv=nn.Identity()

        import conv_modules
        # Encoder takes 4 input channels (forward() feeds it RGB concatenated with depth)
        self.resnet_enc=nn.Sequential(conv_modules.PixelNeRFEncoder(in_ch=4),nn.Conv2d(512,fdim,3,padding=1))
        self.depth_conv=nn.Conv2d(fdim,1,3,padding=1)
        self.focal_conv=nn.Conv2d(fdim,1,3,padding=1)
        self.temperature_conv=nn.Conv2d(fdim,1,3,padding=1)
        affinity_dim = 16
        self.affinities_conv=nn.Conv2d(fdim,affinity_dim,3,padding=1)

        # Depth-as-variable vars in case that setting is used
        self.corr_weights,self.depth=None,None
        self.step=0  # incremented on every forward call

        # Learnable tensor of ones with shape (1, 6, 21, 1)
        self.mask_scores=torch.nn.Parameter(torch.ones(1,6,21,1),requires_grad=True)

        #self.upsample_factor=8
        #self.upsampler = nn.Sequential(nn.Conv2d(feature_channels, 256, 3, 1, 1), nn.ReLU(inplace=True), nn.Conv2d(256, self.upsample_factor ** 2 * 9, 1, 1, 0)).cuda()

    def forward(self, model_input, out=None):
        """Estimate depth, intrinsics and one rigid pose per point track.

        Steps, as implemented below: refine the input depth with a learned
        residual, optionally predict a focal length, weight every track
        correspondence by a learned confidence times its affinity to each
        source track, solve a weighted Procrustes problem per source track,
        and score the poses by reprojecting the frame-0 tracks.

        Args:
            model_input: dict of inputs. Reads "rgb", "depth_inp",
                "pred_tracks", "pred_visibility", "bwd_flow" and either
                "intrinsics" or, when ``args.use_gt_intrinsics`` is false,
                "org_ratio". Writes "fmap" and, in the latter case,
                "intrinsics". Assumes rgb is (b, t, 3, H, W) and pred_tracks
                is (b, t, P, 2) in [0, 1] with P a perfect square -- TODO
                confirm against the data loader.
            out: optional dict whose entries are carried into the result.
                Defaults to a fresh dict per call (previously a shared
                mutable default).

        Returns:
            A new dict with the entries of ``out`` plus the outputs assembled
            in the ``return`` statement.
        """
        out = {} if out is None else out
        self.step+=1

        imsize=model_input["rgb"].shape[-2:]
        b,n_trgt=model_input["rgb"].shape[:2]
        lowres_factor=1
        n_tracks=model_input["pred_tracks"].size(-2)
        gs=int(np.sqrt(n_tracks)) # grid size

        # Run optical flow and point track networks
        self.get_flow(model_input)

        # Est depth: input depth plus a strictly positive learned residual
        depth_inp = ch_fst(model_input["depth_inp"],imsize[0])
        rgbd = torch.cat((model_input["rgb"],depth_inp),2)
        model_input["fmap"] = F.interpolate(self.resnet_enc(rgbd.flatten(0,1)*.5+.5),imsize,mode="bilinear").unflatten(0,(b,n_trgt))
        res_depth = F.softplus(self.depth_conv(model_input["fmap"].flatten(0,1)).unflatten(0,(b,n_trgt))+1)
        depth = depth_inp + res_depth

        # Est intrinsics: one focal value per batch element, principal point fixed at .5
        if not self.args.use_gt_intrinsics:
            focal = self.focal_conv(model_input["fmap"].flatten(0,1)).unflatten(0,(b,n_trgt)).sigmoid().flatten(1,-1).mean(dim=-1)
            model_input["intrinsics"]=torch.eye(3)[None].float().to(depth).repeat(b,n_trgt,1,1)
            model_input["intrinsics"][...,0,2]=model_input["intrinsics"][...,1,2]=.5
            # NOTE(review): focal has shape (b,) and is assigned into a (b, t) slot, which
            # only broadcasts as intended for b == 1 -- confirm batching expectations.
            model_input["intrinsics"][...,0,0]=focal*model_input["org_ratio"]
            model_input["intrinsics"][...,1,1]=focal

        # Est affinity weights and similarities for each source point
        affinity_emb = F.normalize( self.affinities_conv(model_input["fmap"].flatten(0,1)).unflatten(0,(b,n_trgt)), dim=2 )
        aff_sources = rearrange(grid_samp(affinity_emb[:,[0]], model_input["pred_tracks"][:,[0]].unsqueeze(-2)),"b t c p 1 -> b t p c")
        aff_sim = torch.einsum('b t p c, b t q c -> b t p q', aff_sources, aff_sources) # from all source pix to all other source pix

        # Get corresp weights per track corresp
        track_feat   = rearrange(grid_samp(model_input["fmap"], model_input["pred_tracks"].unsqueeze(-2)),"b t c p 1 -> b t p c")
        corr_weights = self.corr_weighter_perpoint(torch.cat((track_feat[:,[0]].expand(-1,n_trgt,-1,-1),track_feat),-1)).sigmoid().clip(min=1e-4)

        corr_weights_obj = corr_weights.unsqueeze(2) * aff_sim.unsqueeze(-1)

        # Solve points: every 10th track, shared by all tracks. One column per track;
        # the column count was hard-coded to 42**2 and is now taken from the input,
        # which is identical for a 42x42 track grid.
        neighb_list = torch.arange(n_tracks)[::10][:,None].expand(-1,n_tracks)

        # Est poses per source point
        rds_track = geometry.get_world_rays(model_input["pred_tracks"],model_input["intrinsics"],None)[1]
        eye_surf_track = rds_track * grid_samp(depth,model_input["pred_tracks"].unsqueeze(-2)).squeeze(2)

        # Index tensors follow the device of the tensor they index instead of a hard-coded .cuda().
        # NOTE(review): both have a leading dimension of 1, so torch.gather returns
        # only batch element 0 -- b > 1 is effectively unsupported here.
        solve_idx = neighb_list.flatten()[None,None,:,None].expand(-1,n_trgt,-1,3).to(eye_surf_track.device)
        weight_idx = neighb_list.T[None,None,:,:,None].expand(-1,n_trgt,-1,-1,-1).to(corr_weights_obj.device)

        frame_pts = torch.gather(eye_surf_track,2,solve_idx).unflatten(-2,neighb_list.shape).permute(0,1,3,2,4)
        first_pts = torch.gather(eye_surf_track[:,[0]].expand(-1,n_trgt,-1,-1),2,solve_idx).unflatten(-2,neighb_list.shape).permute(0,1,3,2,4)
        solve_weights = torch.gather(corr_weights_obj,-2,weight_idx)
        poses = geometry.procrustes(frame_pts,first_pts,solve_weights)[1]

        # Reproject point tracks using est poses (move frame 0 to all other frames) (just use pose per solve point)
        point_track_surf_reproj = torch.einsum('btpij,btpj->btpi',poses.inverse(),hom(eye_surf_track[:,[0]].expand(-1,n_trgt,-1,-1)))[...,:3]
        point_track_reproj = project(point_track_surf_reproj,model_input["intrinsics"]).clip(0,1)
        point_track_loss = ( (point_track_reproj - model_input["pred_tracks"]) * model_input["pred_visibility"].unsqueeze(-1) ).square().mean()

        ## Supervise optical flow
        # Compress poses to a 7-vector (quaternion + translation) for sampling
        poses_lie = torch.cat((kornia.geometry.conversions.rotation_matrix_to_quaternion(poses[...,:3,:3],eps=1e-5),poses[...,:3,-1]),-1)

        # Design notes kept from the original author:
        # - Do n-body again, and it will miss some subtle motions
        #     - Though we might also consider penalizing e.g. the gradient of the flow maps and that might encourage object discovery since high gradient at boundaries
        # - but also include a second set of points from the raw depth lifted point tracks -- their se3 is just raw translation
        #     - Can also filter these out by only using them where the pose-induced flow is worse than some threshold and can visualize that threshold image
        #     - nbody should be the default code in the splatting repo
        # TODO predict convex upsampling weights, simplest sanity case is just use bilinear upsampling

        # Disabled experiment (never runs): per-pixel SE(3) from affinity to each solve point
        if 0:
            temperature = self.temperature_conv(model_input["fmap"].flatten(0,1)).unflatten(0,(b,n_trgt)).relu().flatten(1,-1).mean(dim=-1)+30
            # Get SE3 per pixel location using affinity embeddings to each source solve point
            aff_sim_allpix = torch.einsum('b t p c, b s c -> b t p s', ch_sec(F.interpolate(affinity_emb.flatten(0,1),scale_factor=lowres_factor).unflatten(0,(b,n_trgt))), aff_sources.squeeze(1)) # from all source pix to all other pix
            lie_perpix = torch.einsum('btsc,btps->btpc',poses_lie,(aff_sim_allpix*temperature).softmax(dim=-1))
            lie_perpix = ch_sec(F.interpolate(ch_fst(lie_perpix,imsize[0]//2).flatten(0,1),scale_factor=int(1/lowres_factor),mode="bilinear").unflatten(0,(b,n_trgt)))
            # Exponentiate to get pose per pixel
            pose_perpix = torch.eye(4)[None,None,None].expand(b,n_trgt,lie_perpix.size(-2),-1,-1).to(lie_perpix)
            pose_perpix[...,:3,:3] = kornia.geometry.conversions.quaternion_to_rotation_matrix(lie_perpix[...,:4])
            pose_perpix[...,:3,-1] = lie_perpix[...,4:]
            # Compute pose flow
            eye_surf = geometry.get_world_rays(model_input["x_pix"],model_input["intrinsics"],None)[1] * ch_sec(depth)
            adj_opt_flow = project( torch.einsum("btpij,btpj->btpi",pose_perpix[:,:-1].inverse()@pose_perpix[:,1:],hom(eye_surf[:,1:]))[...,:3], 
                                                    model_input["intrinsics"][:,1:] )-model_input["x_pix"][:,1:]
            out["pose_perpix"]=pose_perpix.unflatten(-3,imsize)
            out["lie_perpix"]=lie_perpix.unflatten(-2,imsize)
            out["flow_from_pose"]=adj_opt_flow

        # Note for static sanity case est where aff=1 will all be same se3 per point
        # NOTE(review): `gs//2 * gs //2` evaluates as ((gs//2)*gs)//2; confirm this is the intended track index.
        bg_pose = poses[:,:,gs//2 * gs //2] # todo figure out how to extract bg static pose with some reg

        # Regularization notes kept from the original author:
        # - Compress trajs with PCA for vis/regularization -- todo refactor pca method out
        # - It's always going to be some regularization, maybe just anneal the simple affinity weights and find a good value across many scenes empirically
        # - First reduce the dog size to 64x64 if too expensive
        # - Solve for the PCA of the se3 field:
        #     - Compute the pose-induced flow N times using up to N components
        #     - This way the network doesn't have to discover the correct optimal rigid affinities loss tradeoff but rather we're discretizing over them
        #     - optionally downweight the few-component solves or only start loss after N components even

        return out | {
            "res_depth":ch_sec(res_depth),
            "depth":ch_sec(depth),
            "poses_lie":poses_lie.unflatten(-2,(gs,gs)),
            "point_track_loss":point_track_loss*1e2,
            "corr_weights": ch_fst(corr_weights,gs),
            "poses":bg_pose,
            "flow_inp_": model_input["bwd_flow"],
            "intrinsics": model_input["intrinsics"],
            "affinity_emb": affinity_emb,
            "affinity_sim": aff_sim.unflatten(-2,(gs,gs)).unflatten(-1,(gs,gs)),
        }

    def forward_(self, model_input, out={}): 
        """Earlier draft of ``forward_``; superseded by the ``forward_`` defined after it.

        NOTE(review): the class body defines ``forward_`` a second time further
        down, and that later definition replaces this one, so this body is not
        reachable through ``FlowMap.forward_``. It also drops into ``pdb`` and
        uses ``rand_subsets``, ``affinity_mask`` and ``flow_from_pose``, none of
        which are assigned in this method. Candidate for removal.
        """
        self.step+=1

        imsize=model_input["rgb"].shape[-2:]
        (b,_),n_trgt=model_input["rgb"].shape[:2],model_input["rgb"].size(1)

        # Random subset used for procrustes solving (since expensive to run on full image)
        n_samp=1000
        rand_subset = torch.linspace(0,imsize[0]*imsize[1]-1,n_samp).long() if self.args.overfit else torch.randperm(imsize[0]*imsize[1])[:n_samp] 

        # Run optical flow and point track networks
        self.get_flow(model_input)

        corresp_uv = (model_input["x_pix"][:,:-1]+ch_sec(model_input["bwd_flow"]))

        # Get depth and flow-confidence weights
        depth,corr_weights = self.get_depth_and_flow_weights(model_input,corresp_uv)
        #depth=model_input["zoe_depth"]

        # Predict intrinsics
        #if not self.args.use_gt_intrinsics: model_input["intrinsics"] = self.intrinsics_est(model_input,depth,corresp_uv,rand_subset,corr_weights*ch_sec(is_static)[:,1:]) 

        # Estimate poses via procrustes and integrate relative poses through time
        corr_weights_obj = corr_weights.unsqueeze(2)

        # NOTE(review): leftover debugger breakpoint
        from pdb import set_trace as pdb_;pdb_() 
        # Get adjacent poses
        # NOTE(review): `rand_subsets` is not assigned in this method (only `rand_subset` is)
        eye_surf, poses= self.adj_pose_est(model_input| {k:self.o_rep(model_input[k]) for k in ["x_pix","intrinsics"]},
                                                            self.o_rep(corresp_uv),self.o_rep(depth),self.o_switch(corr_weights_obj),rand_subsets)
        # Add identity matrix to first pose
        poses = torch.cat((poses[...,:1,:,:].detach().clone(),poses),-3)
        poses[...,:1,:,:]=torch.eye(4).to(poses)

        # Pose-induced point tracks
        if "pred_tracks" in model_input:

            ## 3D identity-camera surface at each point track location
            eye_tracks = grid_samp(ch_fst(eye_surf,imsize[0]).flatten(0,1),self.o_rep(model_input["pred_tracks"]).flatten(0,1).unsqueeze(-2))
            eye_tracks = rearrange(eye_tracks,"(b t) c p 1 -> b p t c",t=n_trgt)

            ## Warp all frames to all other frames
            all_pair_poses = repeat(poses,"b t x y -> b s t x y",s=n_trgt).inverse() @ repeat(poses,"b t x y -> b t s x y",s=n_trgt)
            all_pair_surfs = torch.einsum("bsnij,bksj->bksni",all_pair_poses,hom(eye_tracks))[...,:3]

            all_pair_pix_uncomp = project(all_pair_surfs,model_input["intrinsics"][:,:1]).clip(0,1)

            # NOTE(review): `affinity_mask` is not assigned in this method
            affinity_sample = grid_samp(rearrange(affinity_mask,"b t o (x y) 1 -> (b o) t 1 x y",x=imsize[0]).flatten(0,1),self.o_rep(model_input["pred_tracks"]).flatten(0,1).unsqueeze(-2))
            affinity_sample = rearrange(affinity_sample,"(b t) c p 1 -> b p t c",t=n_trgt)
            out["all_pair_pix"] = (all_pair_pix_uncomp*affinity_sample.unsqueeze(-2)).unflatten(0,(b,self.n_rig)).sum(1)

        # Unpack and composite pose induced flow 
        # NOTE(review): `flow_from_pose` is not assigned in this method
        flow_from_pose_comp = ( rearrange(flow_from_pose,"(b o) c ... -> b c o ...",o=affinity_mask.size(2)) * affinity_mask[:,1:] ).sum(2)
        poses = rearrange(poses,"(b o) ... -> b o ...",o=affinity_mask.size(2))

        pts_canon = torch.einsum("brtij,brtxj->brtxi",poses,hom(eye_surf.unflatten(0,(b,self.n_rig))))[...,:3]

        return out | {
            "depth":depth,
            "zoe_depth":model_input["zoe_depth"],
            "zoe_d_loss":(1/(1e-5+model_input["zoe_depth"])-1/(1e-5+depth)).square().mean()*1e4,
            "corr_weights": corr_weights,
            "rig_masks": affinity_mask,
            "pts_canon": pts_canon,
            "poses":poses,
            "intrinsics": model_input["intrinsics"],
        }


    def forward_(self, model_input, out=None):
        """Mask-based variant: one rigid pose per mask, composited into pose-induced flow.

        Steps, as implemented below: run ``get_flow``, get depth and flow
        confidence from ``get_depth_and_flow_weights``, optionally estimate
        intrinsics, score the provided masks, sample solve points inside each
        mask, estimate adjacent-frame poses with ``adj_pose_est``, chain them
        through time, and composite the resulting flow with the mask weights.

        Args:
            model_input: dict of inputs. Reads "rgb", "x_pix", "bwd_flow",
                "fmap", "all_masks_2", "intrinsics" and, if present,
                "pred_tracks". "intrinsics" is overwritten when
                ``args.use_gt_intrinsics`` is false.
            out: optional dict whose entries are carried into the result.
                Defaults to a fresh dict per call (previously a shared
                mutable default, which the "pred_tracks" branch writes into).

        Returns:
            A new dict with the entries of ``out`` plus the outputs assembled
            in the ``return`` statement.
        """
        out = {} if out is None else out
        self.step+=1

        imsize=model_input["rgb"].shape[-2:]
        (b,_),n_trgt=model_input["rgb"].shape[:2],model_input["rgb"].size(1)

        # Random subset used for procrustes solving (since expensive to run on full image)
        n_samp=1000
        rand_subset = torch.linspace(0,imsize[0]*imsize[1]-1,n_samp).long() if self.args.overfit else torch.randperm(imsize[0]*imsize[1])[:n_samp] 

        # Run optical flow and point track networks
        self.get_flow(model_input)
        corresp_uv = (model_input["x_pix"][:,:-1]+ch_sec(model_input["bwd_flow"]))

        # Get depth and flow-confidence weights
        depth,corr_weights = self.get_depth_and_flow_weights(model_input,corresp_uv)

        # Predict intrinsics
        if not self.args.use_gt_intrinsics: model_input["intrinsics"] = self.intrinsics_est(model_input,depth,corresp_uv,rand_subset,corr_weights) 

        # Predict rigid-embedding weights -- todo make sure sums to 1

        # Affinity weights is just masks and we have a separate static mask for the global mask
        # second choose rand subset by masking x_pix (don't need to vectorize at first), just keep masks small loop over and pick linearly spaced points from masked xpix
        # third is need affinity masks -- take fmap * masks, simple tiny conv to map the masked image feature to some tensor that we can softmax over spatially to get affinity masks

        # Only mask level 2 is used; an all-ones mask is prepended as the global mask
        mask_levels=[2]#sorted([int(k[-1]) for k in model_input.keys() if "all_mask" in k])
        masks = [torch.ones_like(model_input["all_masks_%i"%mask_levels[0]][:,:1])[None]]+[model_input["all_masks_%d"%i][None] for i in mask_levels]
        masks_all=torch.cat(masks,2)
        is_static = ch_fst(self.static_est(ch_sec(model_input["fmap"])),imsize[0])

        #note for integration: since were trying to choose masks per frame we can probably use a small transformer or 3d conv to predict consistent N rigid bodies across the frames
        #instead of relying on the point tracks for integration
        # Max pool to get mask embeddings (dumb but simple)
        mask_embeddings = [ch_sec((model_input["fmap"].unsqueeze(2)*masks_.unsqueeze(3))).max(dim=-2)[0] for masks_ in masks]
        mask_scores = [self.mask_scorer(x).relu() for x in mask_embeddings]
        mask_spatial_scores = [mask_score.unsqueeze(-1)*spatial_mask for mask_score,spatial_mask in zip(mask_scores,masks)]
        level_scores = ch_fst(self.mask_level_predictor(ch_sec(model_input["fmap"])).softmax(dim=-1),imsize[0])

        # Current setting: keep only the first (all-ones) mask and treat everything as static
        mask_spatial_scores = mask_spatial_scores[:1]
        masks_all=masks_all[:,:,:1]
        is_static=torch.ones_like(is_static)

        # used for compositing flow
        composite_weights = [mask_spatial_score.flatten(-2,-1).softmax(dim=-2) * level_score 
                                for mask_spatial_score,level_score in zip(mask_spatial_scores, ch_sec(level_scores).unsqueeze(2).unbind(-1))]

        # Previously: affinity_mask=self.rig_predictor(ch_sec(model_input["fmap"])).softmax(-1).permute(0,1,3,2).unsqueeze(-1)

        # Estimate poses via procrustes and integrate relative poses through time
        # TODO make affinity masks based on object mask embedding (can encode mask as center feature + bounding box of mask?)
        corr_weights_obj = torch.cat(((corr_weights*ch_sec(is_static)[:,1:]).unsqueeze(2),
                                       corr_weights[:,:,None].expand(-1,-1,masks_all.size(2)-1,-1,-1)),2).clip(min=1e-4)

        # Get adjacent poses -- sample points from within each mask -- todo use different points per mask level
        points_permask = 500 
        rand_shuffle_noise=torch.rand_like(masks_all)*.001  # random noise to "shuffle" the sorting process
        rand_subsets = torch.sort(ch_sec((masks_all+rand_shuffle_noise).unsqueeze(-3)), dim=-2, descending=True)[1][...,:points_permask,:]

        eye_surf, poses= self.adj_pose_est(model_input| {k:self.o_rep(model_input[k]) for k in ["x_pix","intrinsics"]}, 
                                    self.o_rep(corresp_uv),self.o_rep(depth),self.o_switch(corr_weights_obj),rand_subsets)

        ## TODO integration -- use point tracks to integrate poses through time 
        ## Integrate poses through time
        for i in range(n_trgt-1,0,-1): poses = torch.cat((poses[...,:i,:,:],poses[...,[i-1],:,:]@poses[...,i:,:,:]),-3)
        # Add identity matrix to first pose
        poses = torch.cat((poses[...,:1,:,:].detach().clone(),poses),-3)
        poses[...,:1,:,:]=torch.eye(4).to(poses)

        # Pose-induced flow : transform all frames back one timestep and project for 2D locations, subtract uv grid for pose-induced optical flow
        flow_from_pose = warp(eye_surf[:,1:],poses[:,:-1].inverse()@poses[:,1:],model_input["intrinsics"][:,1:])-model_input["x_pix"][:,1:]

        # Pose-induced point tracks
        # NOTE(review): `affinity_mask` is not assigned anywhere in this method (its
        # assignment above is commented out), so unless a module-level name exists
        # this branch raises NameError whenever "pred_tracks" is present.
        if "pred_tracks" in model_input:

            ## 3D identity-camera surface at each point track location
            eye_tracks = grid_samp(ch_fst(eye_surf,imsize[0]).flatten(0,1),self.o_rep(model_input["pred_tracks"]).flatten(0,1).unsqueeze(-2))
            eye_tracks = rearrange(eye_tracks,"(b t) c p 1 -> b p t c",t=n_trgt)

            ## Warp all frames to all other frames
            all_pair_poses = repeat(poses,"b t x y -> b s t x y",s=n_trgt).inverse() @ repeat(poses,"b t x y -> b t s x y",s=n_trgt)
            all_pair_surfs = torch.einsum("bsnij,bksj->bksni",all_pair_poses,hom(eye_tracks))[...,:3]

            all_pair_pix_uncomp = project(all_pair_surfs,model_input["intrinsics"][:,:1]).clip(0,1)

            affinity_sample = grid_samp(rearrange(affinity_mask,"b t o (x y) 1 -> (b o) t 1 x y",x=imsize[0]).flatten(0,1),self.o_rep(model_input["pred_tracks"]).flatten(0,1).unsqueeze(-2))
            affinity_sample = rearrange(affinity_sample,"(b t) c p 1 -> b p t c",t=n_trgt)
            out["all_pair_pix"] = (all_pair_pix_uncomp*affinity_sample.unsqueeze(-2)).unflatten(0,(b,self.n_rig)).sum(1)

        # Unpack and composite pose induced flow 
        composite_weights_flat = torch.cat(composite_weights,2).unsqueeze(-1)

        flow_from_pose_comp = ( rearrange(flow_from_pose,"(b o) c ... -> b c o ...",o=composite_weights_flat.size(2)) * composite_weights_flat[:,:-1] ).sum(2)
        poses = rearrange(poses,"(b o) ... -> b o ...",o=composite_weights_flat.size(2))

        return out | {
            "depth":depth,
            "composite_weights":composite_weights,
            "level_scores":level_scores,
            "mask_spatial_scores":mask_spatial_scores,
            "is_static":is_static,
            # TODO move zoe depth loss to args loss
            "corr_weights": corr_weights,
            "flow_from_pose":flow_from_pose_comp,
            # NOTE(review): same tensor as "flow_from_pose"; confirm whether the uncomposited flow was intended
            "flow_from_pose_uncomp":flow_from_pose_comp,
            "poses":poses[:,0],
            "rig_poses":poses,
            "flow_inp_": model_input["bwd_flow"],
            "intrinsics": model_input["intrinsics"],
        }
    
    # Procrustes relative pose estimation
    def adj_pose_est(self, model_input, corresp_uv, depth, corr_weights, rand_subset):
        """Estimate a transform between each frame and its predecessor via weighted Procrustes.

        Args:
            model_input: dict; reads "x_pix", "intrinsics" and the height of "rgb".
            corresp_uv: per-pixel correspondence coordinates used to sample the
                surface of frames ``[:-1]``; assumed to be in [0, 1] as
                ``grid_samp`` expects -- TODO confirm.
            depth: depth multiplied onto the rays returned by ``geometry.get_world_rays``.
            corr_weights: per-correspondence weights passed to ``geometry.procrustes``.
            rand_subset: pixel indices to solve on. Either 1-D (shared by all
                batch entries and frames) or 5-D ``b t o x 1`` (per mask), in
                which case ``o`` is folded into the batch axis.

        Returns:
            ``(eye_surf, adj_transf)``: the ray-times-depth surface and the
            second output of ``geometry.procrustes``.
        """
        rds = geometry.get_world_rays(model_input["x_pix"],model_input["intrinsics"],None)[1]
        eye_surf=rds*depth

        # just use einops to simplify this with pattern b ... t x y  -> (b ...) t x y
        has_extra_lead_dim = len(rds.shape)!=4
        if has_extra_lead_dim: eye_surf,corresp_uv = [x.flatten(0,1) for x in [eye_surf,corresp_uv]]

        # Surface of the previous frames sampled at the correspondence locations
        corresp_surf = grid_samp(ch_fst(eye_surf,model_input["rgb"].size(-2))[:,:-1],corresp_uv.unsqueeze(2)).squeeze(-2).permute(0,1,3,2)

        if len(rand_subset.shape)==1:
            rand_subset = repeat(rand_subset,"x -> bo t x 1",bo=corresp_surf.size(0),t=corresp_surf.size(1))
        else:
            rand_subset = rearrange(rand_subset,"b t o x 1 -> (b o) t x 1")
        # Move the indices to the target device as int64 directly; the previous
        # `.to(corresp_surf).long()` round-tripped them through the float dtype,
        # which cannot represent large indices exactly at low precision.
        rand_subset = rand_subset.to(device=corresp_surf.device,dtype=torch.long)

        if rand_subset.size(1)==1: rand_subset=rand_subset.expand(-1,2,-1,-1)

        # The first time step is dropped to line up with eye_surf[:,1:]
        xyz_idx = rand_subset.expand(-1,-1,-1,3)[:,1:]
        adj_transf = geometry.procrustes(torch.gather(eye_surf[:,1:],2,xyz_idx),
                                         torch.gather(corresp_surf,  2,xyz_idx),
                                         torch.gather(corr_weights,  2,rand_subset[:len(corr_weights),1:]),
                                         )[1]

        if has_extra_lead_dim: eye_surf,adj_transf = [x.unflatten(0,rds.shape[:2]) for x in [eye_surf,adj_transf]]
        return eye_surf, adj_transf

    #def adj_pose_est(self, model_input, corresp_uv, depth, corr_weights, rand_subset):

    #    rds = geometry.get_world_rays(model_input["x_pix"],model_input["intrinsics"],None)[1]
    #    eye_surf=rds*depth

    #    # Add support for any number of dimensions, just take shape of b * t .... and flatten then unflatten it, 
    #    # useful not just for intrinsics but also for potential object centric formulations 
    #    # just use einops to simplify this with pattern b ... t x y  -> (b ...) t x y
    #    if len(rds.shape)!=4: eye_surf,corresp_uv = [x.flatten(0,-3) for x in [eye_surf,corresp_uv]]

    #    corresp_surf = grid_samp(ch_fst(eye_surf,model_input["rgb"].size(-2))[:,:-1],corresp_uv.unsqueeze(2)).squeeze(-2).permute(0,1,3,2)#.unflatten(0,rds.shape[:2])

    #    adj_transf = geometry.procrustes(eye_surf[:,1:,rand_subset], corresp_surf[:,:,rand_subset], corr_weights[:,:,rand_subset] )[1]

    #    if len(rds.shape)!=4: eye_surf,adj_transf = [x.unflatten(0,rds.shape[:-2]) for x in [eye_surf,adj_transf]]
    #    return eye_surf, adj_transf

    def get_seg(self,model_input):
        """Segment every frame with SAM from a regular grid of point prompts and cache the masks in `model_input`.

        Each frame of `model_input["rgb_large"]` (first two dims flattened together) is resized to
        `sam_sl x sam_sl` and passed to `self.sam` with a `grid_sl x grid_sl` grid of positive point
        prompts. The returned masks are resized to the resolution of `model_input["rgb"]`,
        de-duplicated per mask level with box NMS, and truncated / zero-padded to a fixed count per
        level so they can be stacked across frames.

        Writes, for every level `l` in `mask_levels`:
            model_input["all_masks_<l>"]: stacked per-frame padded masks (float, values 0/1).
            model_input["masks_vis_<l>"]: CPU visualization, masks tinted and added onto the resized frame.

        Returns:
            `model_input` (the same dict), on both the cached and the freshly computed path.
        """
        # Masks are cached in the scene dict, so SAM is only queried when no mask entry exists yet
        if any("all_masks" in k for k in model_input): return model_input

        grid_sl=16            # number of point prompts per image side
        buf=5                 # margin between the prompt grid and the border of the SAM input image
        sam_sl=1024           # side length frames are resized to before being passed to SAM
        iou_threshold=0.5     # box IoU above which two masks are treated as duplicates
        mask_levels=[2,1][:1] # indices along dim 1 of SAM's mask output that are kept
        max_masksperlevel=[20,60][:len(mask_levels)] # fixed number of masks stored per frame for each kept level
        lowres=model_input["rgb"].shape[-2:]

        sam_inp_imgs = F.interpolate(model_input["rgb_large"].flatten(0,1),(sam_sl,sam_sl)).to(torch.uint8)

        # Regular grid of prompt locations, shape (grid_sl*grid_sl, 1, 2). "ij" is what meshgrid did
        # implicitly before; both axes are identical so the set of prompted points is symmetric.
        axis=torch.linspace(buf,sam_sl-buf,grid_sl)
        query_crds = torch.stack(torch.meshgrid(axis,axis,indexing="ij"),-1).cuda().flatten(0,1)[:,None]

        # The prompts are the same for every frame, so build them once instead of once per frame
        point_coords = self.resize_transform.apply_coords_torch(query_crds, [sam_sl,sam_sl]).cuda()
        point_labels = torch.ones_like(query_crds[...,0])

        masks_vis_,all_masks_=[[[] for _ in mask_levels] for _ in range(2)]
        for img in sam_inp_imgs:
            batched_input = [{
                'image': img,
                'point_coords': point_coords,
                'point_labels': point_labels,
                'original_size': [sam_sl,sam_sl],
            }]

            print("querying sam")
            with torch.no_grad(): batched_output = self.sam(batched_input, multimask_output=True)
            sam_masks = batched_output[0]['masks']

            # Resize masks to the low-res frame size, keeping the two leading dims
            # (presumably prompt x mask level -- confirm against the SAM wrapper in use)
            all_masks = F.interpolate(sam_masks.flatten(0,1)[:,None].float(),lowres).bool().squeeze(1).unflatten(0,sam_masks.shape[:2])

            # Removing redundant masks: one bounding box per mask, near-empty masks (<10 px) get an all-zero box
            all_bboxes = torch.cat([
                ops.masks_to_boxes(mask[None]) if mask.sum()>=10 else torch.zeros(1,4,device=mask.device)
                for mask in all_masks.flatten(0,1)
            ]).unflatten(0,all_masks.shape[:2])

            # Filter out duplicate boxes within each level. The NMS scores are random, so which member
            # of an overlapping group survives is arbitrary.
            keep_idxs=[ops.nms(all_bboxes[:,i],torch.rand_like(all_bboxes[:,0,0]),iou_threshold) for i in range(all_bboxes.shape[1])]

            # TODO: a) expand keep_idxs to all timesteps in the trajectory (via max over time), and
            #       b) use point-track visibility to mask out where the query point wasn't in frame

            img_low=F.interpolate(img[None],lowres)[0]
            masks_low=all_masks.float()
            for j,(i,max_masks) in enumerate(zip(mask_levels,max_masksperlevel)):
                masks_i=masks_low[keep_idxs[i],i]
                # Truncate / zero-pad to exactly max_masks entries so frames can be stacked afterwards
                n_pad=max(0,max_masks-len(masks_i))
                padded=torch.cat((masks_i[:max_masks],masks_i.new_zeros(n_pad,*masks_i.shape[1:])))
                # Visualization on CPU: mask tinted with (30,144,255) and added onto the resized frame
                mask_vis = (padded[:,None,...,None].cpu().float()*torch.tensor([30, 144, 255])[None,None,None,None])+img_low.permute(1,2,0)[None,None].cpu()
                masks_vis_[j].append(mask_vis)
                all_masks_[j].append(padded)

        for j,level in enumerate(mask_levels):
            model_input["all_masks_%i"%level],model_input["masks_vis_%i"%level] = torch.stack(all_masks_[j]),torch.stack(masks_vis_[j])

        return model_input

        #mask_samps = grid_samp(model_input["all_masks"].flatten(0,1),model_input["pred_tracks"].flatten(0,1).unsqueeze(-2))
        #eye_tracks = rearrange(eye_tracks,"(b t) c p 1 -> b p t c",b=b)

        # Use point tracks to associate masks over time 

    # Predicts depth and flow weights from source depending on args (var, midas net, scratch net)
    def get_depth_and_flow_weights(self, model_input, corresp_uv):
        """Return depth and flow-confidence weights from whichever depth source is configured.

        Sources, in order of precedence:
          * `self.depth` is set: depth and weights are read straight from the optimized variables.
          * `self.args.scratch_net` is falsy: the depth head output is read as disparity, and the
            predicted weights are additionally down-weighted using depth heuristics.
          * otherwise: softplus depth head, the simpler setup meant for training from scratch.

        Both network paths also store the upsampled feature map in `model_input["fmap"]`.

        Returns:
            Tuple `(depth, corr_weights)`.
        """
        rgb=model_input["rgb"]
        b,n_trgt=rgb.shape[:2]
        imsize=rgb.shape[-2:]

        # Depth source from optimized variable
        if self.depth is not None:
            return self.depth.clip(min=.1), self.corr_weights.abs()

        all_frames=(b,n_trgt)
        frame_pairs=(b,n_trgt-1)

        def sample_at_corresp(src):
            # Bilinear lookup of `src` (every frame but the last) at corresp_uv, mapped through uv*2-1
            # for grid_sample; result is channels-last with the leading dim unflattened again
            grid=corresp_uv.flatten(0,1).unsqueeze(1)*2-1
            sampled=F.grid_sample(src[:,:-1].flatten(0,1),grid)
            return sampled.squeeze(-2).permute(0,2,1).unflatten(0,frame_pairs)

        from_scratch=self.args.scratch_net
        midas_feats=self.midas((rgb*.5+.5).flatten(0,1))
        fmap=F.interpolate(midas_feats,imsize,mode="bilinear").unflatten(0,all_frames)

        if not from_scratch:
            # Network output interpreted as disparity, with the stable-midas heuristics
            fmap=fmap/20
            model_input["fmap"]=fmap
            disparity=ch_sec(self.midas_out(midas_feats).unflatten(0,all_frames))
            depth=1e3/(disparity+1e-1)

            # Confidence predicted from the feature at each pixel paired with its correspondence's feature
            corresp_feat=sample_at_corresp(fmap)
            pair_feats=torch.cat((corresp_feat,ch_sec(fmap)[:,1:]),-1)
            corr_weights=self.corr_weighter_perpoint(pair_feats).sigmoid().clip(min=1e-4)

            # Prior-preserving heuristic: shrink weights where the two depths disagree strongly,
            # and where the larger of the two depths is far away
            corresp_depth=sample_at_corresp(ch_fst(depth,imsize[0]))
            trgt_depth=depth[:,1:]
            agreement=1-2*((corresp_depth-trgt_depth).abs().sigmoid()-.5)
            nearness=1-(torch.maximum(corresp_depth,trgt_depth)/3).tanh()
            corr_weights=corr_weights*agreement*nearness
        else:
            # Simpler parameterization: softplus keeps the head output positive, +1 offsets the depth
            model_input["fmap"]=fmap
            raw_depth=F.softplus(self.midas_out(midas_feats+20))
            depth=ch_sec(raw_depth.unflatten(0,all_frames))+1

            # Same confidence head, with its logits scaled by 1e-1 before the sigmoid
            corresp_feat=sample_at_corresp(fmap)
            pair_feats=torch.cat((corresp_feat,ch_sec(fmap)[:,1:]),-1)
            corr_weights=(self.corr_weighter_perpoint(pair_feats)*1e-1).sigmoid().clip(min=1e-4)

        return depth, corr_weights

    def intrinsics_est(self,model_input,depth,corresp_uv,rand_subset,corr_weights,imsize=None,xpix=None):
        """Estimate pinhole intrinsics by a soft search over candidate focal lengths.

        For each of 60 candidate focal lengths the pose of the first frame pair is solved with
        `adj_pose_est`, the flow induced by that pose is compared against `model_input["bwd_flow"]`
        on a random pixel subset, and the focal estimate is the softmin-weighted average of the
        candidates.

        Args:
            model_input: scene dict; reads "rgb", "x_pix", "org_ratio" and "bwd_flow".
            depth: depth estimates; only the first two frames are used.
            corresp_uv: flow correspondences; only the first frame pair is used.
            rand_subset: ignored -- a fresh random subset of 100 pixels is drawn internally.
            corr_weights: correspondence confidence weights; only the first frame pair is used.
            imsize: ignored -- always taken from `model_input["rgb"]`.
            xpix: unused.

        Returns:
            Intrinsics of shape (b, n_trgt, 3, 3) with principal point (.5, .5); the same estimated
            focal lengths are repeated for every frame.
        """

        imsize=model_input["rgb"].shape[-2:]
        (b,_),n_trgt=model_input["rgb"].shape[:2],model_input["rgb"].size(1)

        rand_subset = torch.randperm(imsize[0]*imsize[1])[:100] # for procrustes estimation, independent of ctxt/trgt rays
        # NOTE(review): the visible tail of adj_pose_est indexes rand_subset with two indices, whereas
        # this subset is 1-D -- confirm the two are still compatible.

        # Make candidate focal set - TODO shouldn't focals have a batch dim? only matters for generalization but still. Add it.
        # Built on depth's device so it matches the tensors it is combined with below
        focals=torch.linspace(.5,2,60,device=depth.device)
        focal_set = torch.stack([focals*model_input["org_ratio"][0],focals],1)

        # Make candidate intrinsics matrices corresponding to candidate focal set
        sample_intrinsics=torch.eye(3)[None,None,None].float().to(depth).repeat(b,len(focal_set),1,1,1)
        sample_intrinsics[...,0,2]=sample_intrinsics[...,1,2]=.5
        sample_intrinsics[...,0,0]=focal_set[None,:,None,0]
        sample_intrinsics[...,1,1]=focal_set[None,:,None,1]

        # Estimate pose for just first frame using each intrinsics
        model_input_ = model_input | {"intrinsics":sample_intrinsics,"x_pix":model_input["x_pix"][:,:1].unsqueeze(1)}
        eye_surf, adj_pose = self.adj_pose_est(model_input_,corresp_uv[:,None,:1].expand(-1,len(focal_set),-1,-1,-1),depth[:,:2],corr_weights[:,:1],rand_subset)

        # Compute pose-induced flow for each candidate intrinsic
        flow_from_pose = warp(eye_surf[:,:,1:],adj_pose,sample_intrinsics)-model_input_["x_pix"]

        # Pose induced flow error maps, weighted by confidence and summed over the random pixel subset
        flow_errs = ( (flow_from_pose.squeeze(2)-ch_sec(model_input["bwd_flow"][:,:1])).abs() * corr_weights[:,:1] )[:,:,rand_subset].sum(dim=[-2,-1])
        # Shift so the best candidate has error 0, then scale by 10 to sharpen the softmin below
        flow_errs = (flow_errs-flow_errs.min(dim=1,keepdim=True)[0])*1e1

        # Soft argmin on flow errors to choose candidate intrinsics
        est_focal = (F.softmin(flow_errs,dim=1).unsqueeze(-1)*focal_set[None]).sum(1)[:,None].expand(-1,n_trgt,-1)

        intrinsics=torch.eye(3)[None,None].float().to(depth).repeat(b,n_trgt,1,1)
        intrinsics[:,:,0,2]=intrinsics[:,:,1,2]=.5
        intrinsics[:,:,0,0]=est_focal[...,0]
        intrinsics[:,:,1,1]=est_focal[...,1]
        return intrinsics

    def get_flow(self,model_input):
        (b,_),n_trgt,imsize=model_input["rgb"].shape[:2],model_input["rgb"].size(1),model_input["rgb"].shape[-2:]

        # RAFT optical flow 
        def raft(x,y):
            """Flow for the image batches (x, y), normalized by image size and resized to `imsize`.

            Uses `self.raft_` when overfitting without `gm_flow`, and `self.gm_flow` otherwise.
            """

            print("doing raft",x.shape)

            # Flow nets are memory hungry: when overfitting, recursively halve batches larger than 4
            if self.args.overfit and len(x)>4:
                mid=len(x)//2
                first_half=raft(x[:mid],y[:mid])
                second_half=raft(x[mid:],y[mid:])
                return torch.cat((first_half,second_half))

            # RAFT is slow but only has to run once when overfitting; otherwise fall back to gm flow
            use_raft=self.args.overfit and not self.args.gm_flow
            if use_raft:
                out=self.raft_(x,y,num_flow_updates=32)[-1]
            else:
                # gm flow uses inputs in [0,255]
                x=(x*.5+.5)*255
                y=(y*.5+.5)*255
                preds=self.gm_flow(x,y,attn_splits_list=[2],corr_radius_list=[-1],prop_radius_list=[-1],pred_bidir_flow=False)
                out=preds["flow_preds"][-1]

            # Divide each flow channel by (size-1) of the last two input dims, taken in reversed order,
            # so the flow no longer depends on the input resolution
            extent=torch.tensor(x.shape[-2:][::-1])-1
            out=out/extent.to(x)[None,:,None,None]

            # Resize from the flow net's resolution to the working resolution
            return F.interpolate(out,imsize,mode="bilinear",antialias=True)

        with torch.no_grad():
            self.n_track_frames=1#4 if self.args.overfit else 1
            if "bwd_flow" not in model_input: # don't run if flow already in scene dict

                # Run optical flow net
                raft_inp_x,raft_inp_y = self.raft_transforms(model_input["rgb_large"][:,1:].flatten(0,1).to(torch.uint8), model_input["rgb_large"][:,:-1].flatten(0,1).to(torch.uint8))
                model_input["bwd_flow"] = raft(raft_inp_x,raft_inp_y,).unflatten(0,(b,n_trgt-1))

                # Run point tracking
                if self.args.point_track:
                    pred_tracks,visibilities=[],[]
                    for i,start_frame in enumerate(torch.linspace(0,n_trgt-1,self.n_track_frames).long()):
                        print("doing video tracking",start_frame)
                        pred_track,visibility = self.co_tracker(model_input["rgb_large"][:1], grid_size=42 if self.args.overfit else 20, 
                                grid_query_frame=start_frame, backward_tracking=True)
                        pred_track_norm = pred_track/(torch.tensor(model_input["rgb_large"].shape[-2:][::-1])-1)[None,None,None].cuda()

                        # Heuristically filter out bad tracks -- need a better point tracker -- based on color along tracks
                        #track_rgbs=grid_samp(model_input["rgb_large"][0]/255,pred_track_norm[0][:,None]).squeeze(-2).permute(0,2,1)
                        #visibility = visibility * (((track_rgbs[[start_frame]]-track_rgbs).abs().norm(dim=-1)*visibility)<.2)

                        pred_tracks.append(pred_track_norm);visibilities.append(visibility)

                        # Save point track visualization video 
                        if self.args.overfit or (not self.args.overfit and self.step%25==0):
                            print("writing point track video for visualization")
                            Visualizer(save_dir='./output/', pad_value=100).visualize( 
                                            video=model_input["rgb_large"], tracks=pred_track, 
                                            visibility=visibility, filename='tracks_%02d'%i, query_frame=start_frame);
                    model_input["pred_tracks"]=torch.cat(pred_tracks,2)
                    visibilities=torch.cat(visibilities,2)

                    # Pad batch dimension with dummy 0 values since point tracks don't support batched inference
                    model_input["pred_tracks"]=torch.cat((model_input["pred_tracks"],torch.ones_like(model_input["pred_tracks"]).expand(b-1,-1,-1,-1)))
                    model_input["pred_visibility"]=torch.cat((visibilities,torch.zeros_like(visibilities).expand(b-1,-1,-1)))

            # Add sky mask heuristic to mask out bad point tracks in outdoor scenes
            #if "not_sky" not in model_input and self.args.point_track: 
            #    model_input["not_sky"]= torch.stack([(self.seg_model(model_input["rgb_large"][:,i]/255).max(1)[1][:,None]!=10).float() for i in range(n_trgt)],1)
            #    not_sky = grid_samp(model_input["not_sky"].flatten(0,1),model_input["pred_tracks"].flatten(0,1).unsqueeze(1)).squeeze(1).squeeze(1).unflatten(0,(b,n_trgt))
            #    model_input["pred_visibility"]=model_input["pred_visibility"]*not_sky
