import torch, torchvision
from torch import nn
import kornia
import functools
from einops import rearrange, repeat
import torchvision.ops as ops
from torch.nn import functional as F
import numpy as np
import sys,random,time,os
from copy import deepcopy 
from matplotlib import cm
import wandb
from tqdm import tqdm
from typing import Callable, List, Optional, Tuple, Generator, Dict
from collections import defaultdict
import torchvision.transforms as T

sys.path.append("./third_party/co-tracker/")
from cotracker.utils.visualizer import Visualizer, read_video_from_path

import geometry

import torch_kmeans
from torch_kmeans import KMeans

# Tensor-layout and projective-geometry helpers used throughout the model code.


def ch_sec(x):
    """Channels-last and flatten spatial dims: ``... c x y -> ... (x y) c``."""
    return rearrange(x, "... c x y -> ... (x y) c")


def ch_fst(src, x=None):
    """Inverse of ``ch_sec``: ``... (x y) c -> ... c x y``.

    When ``x`` is not given, the pixel grid is assumed square with side
    ``int(sqrt(src.size(-2)))``.
    """
    if x is None:
        x = int(src.size(-2) ** (.5))
    return rearrange(src, "... (x y) c -> ... c x y", x=x)


def hom(x):
    """Append a coordinate of ones along the last dim (to homogeneous coordinates)."""
    ones = torch.ones_like(x[..., [0]])
    return torch.cat((x, ones), -1)


def unhom(x):
    """Divide by the last coordinate (guarded by +1e-5) and drop it."""
    denom = 1e-5 + x[..., -1:]
    return x[..., :-1] / denom


def grid_samp_(x, y):
    """Bilinear, border-padded ``F.grid_sample``; ``y`` is in [0, 1] and is remapped to [-1, 1]."""
    return F.grid_sample(x, y * 2 - 1, mode="bilinear", padding_mode="border")


def grid_samp(x, y):
    """``grid_samp_`` that also accepts 5-D input by folding the first two dims into the batch."""
    if len(x.shape) == 4:
        return grid_samp_(x, y)
    # todo use more general flatten recipe
    lead_dims = x.shape[:2]
    return grid_samp_(x.flatten(0, 1), y.flatten(0, 1)).unflatten(0, lead_dims)


def project(crds, K):
    """Multiply points ``crds`` (trailing dim j) by matrices ``K`` (trailing i x j) and dehomogenize."""
    projected = torch.einsum("b...cij,b...ckj->b...cki", K, crds)
    return unhom(projected)


def warp(crds, poses, K):
    """Apply homogeneous ``poses`` to ``crds``, keep the first 3 coords, then ``project`` with ``K``."""
    moved = torch.einsum("b...cij,b...ckj->b...cki", poses, hom(crds))[..., :3]
    return project(moved, K)

def make_net(dims):
    """Build an MLP of ``nn.Linear`` layers with a ReLU between consecutive layers.

    Args:
        dims: sequence of layer widths ``[in, hidden..., out]``. ``len(dims) - 1``
            linear layers are created, and the last one has no activation after it.
            With ``len(dims) <= 1`` the result is an empty (identity) ``nn.Sequential``.

    Returns:
        ``nn.Sequential`` whose linear weights are Kaiming-normal initialised
        (``mode='fan_in'``, ReLU gain). Biases keep PyTorch's default init.
    """
    def init_weights_normal(m):
        if isinstance(m, nn.Linear):
            nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_in')

    layers = []
    for d_in, d_out in zip(dims[:-1], dims[1:]):
        layers += [nn.Linear(d_in, d_out), nn.ReLU()]
    net = nn.Sequential(*layers[:-1])  # drop the trailing ReLU so the output is unbounded
    net.apply(init_weights_normal)
    return net

class FlowMap(nn.Module):
    def __init__(self, args):
        """Create the FlowMap sub-networks.

        Args:
            args: config object. ``args.time_stride`` is the number of consecutive
                frames stacked channel-wise into each encoder input.
        """
        super().__init__()
        self.args = args
        self.time_stride = args.time_stride
        feat_dim = 32  # per-frame feature width

        # NOTE: keep the construction order below fixed; it determines the RNG
        # draws used for initial weights and the module registration order.
        self.fmap_downproj = nn.Linear(512, feat_dim)
        self.fmap_downproj2 = nn.Linear(feat_dim, 9)
        self.depth_est = make_net([feat_dim, feat_dim, 1])
        self.corr_weighter_perpoint = make_net([2 * feat_dim, 16, 1])

        # Number of soft rigid-body segments ("rigs"); rig_predictor emits one logit per rig.
        self.n_rig = 2
        self.rig_predictor = make_net([feat_dim, feat_dim, self.n_rig])

        # Backbone: encoder over time_stride stacked RGB frames, then a 3x3 head
        # mapping its 512 channels to feat_dim channels per stacked frame.
        backbone = PixelNeRFEncoder(in_ch=3 * self.time_stride, use_first_pool=True)
        backbone_head = nn.Conv2d(512, feat_dim * self.time_stride, 3, padding=1)
        self.resnet_enc = nn.Sequential(backbone, backbone_head)
        self.depth_conv = nn.Conv2d(feat_dim, 1, 3, padding=1)

        aff_dim = 16
        self.affinities_conv = nn.Conv2d(feat_dim, aff_dim, 3, padding=1)
        self.general_confidence_conv = nn.Conv2d(feat_dim, 1, 3, padding=1)

    def forward(self, model_input, out=None): # point track based
        """Point-track-supervised FlowMap pass (work in progress).

        This pass does the following:
        - Predicts a residual on top of ``model_input["depth_inp"]``.
        - Solves one weighted Procrustes pose per rig between adjacent frames,
          using the same rig-bucket solve as the rasac-bucket variant.
        - Chains those poses to frame 0 and blends them per pixel.
        - Scores the per-track poses against ``model_input["pred_tracks"]``.

        Args:
            model_input: dict read for "rgb", "depth_inp", "x_pix", "bwd_flow",
                "intrinsics", "pred_tracks" and "pred_visibility". "fmap" is
                written back into it.
            out: optional dict to populate ("track_idxs", "track_reprojs"). A fresh
                dict is used when omitted, so no state leaks between calls.

        Returns:
            ``out`` merged with the prediction dict, including "point_track_loss".
        """
        out = {} if out is None else out

        imsize = model_input["rgb"].shape[-2:]
        b, n_trgt = model_input["rgb"].shape[:2]
        depth_inp = ch_fst(model_input["depth_inp"], imsize[0])
        tracks = model_input["pred_tracks"]

        # General feature map backbone prediction
        img_inp = rearrange(model_input["rgb"], "b (t s) c x y -> b t (s c) x y", s=self.time_stride)
        fmap_out = F.interpolate(self.resnet_enc(img_inp.flatten(0, 1) * .5 + .5), imsize, mode="bilinear")
        fmap = model_input["fmap"] = rearrange(fmap_out, "(b t) (s c) x y -> b (t s) c x y", s=self.time_stride, b=b)

        # Pred depth as a positive (softplus) residual on the input depth
        res_depth = F.softplus(self.depth_conv(fmap.flatten(0, 1)).unflatten(0, (b, n_trgt)) + 1) / 2
        depth = depth_inp + res_depth

        # Lift the dense depth map into a camera-space point cloud and gather each
        # pixel's flow correspondence in the previous frame
        corresp_uv = model_input["x_pix"][:, :-1] + ch_sec(model_input["bwd_flow"])
        rds = geometry.get_world_rays(model_input["x_pix"], model_input["intrinsics"], None)[1]
        eye_surf = rds * ch_sec(depth)
        corresp_surf = F.grid_sample(ch_fst(eye_surf, imsize[0])[:, :-1].flatten(0, 1), corresp_uv.flatten(0, 1).unsqueeze(1) * 2 - 1).squeeze(-2).permute(0, 2, 1).unflatten(0, (b, n_trgt - 1))

        # Lift point track into 2.5D image-aligned surface
        rds_track = geometry.get_world_rays(tracks, model_input["intrinsics"], None)[1]
        eye_surf_track = rds_track * grid_samp(depth, tracks.unsqueeze(-2)).squeeze(2)

        # General correspondence confidence and point-to-point affinity embeddings.
        # TODO: not consumed yet -- intended for the local (unfold-based) per-track pose solve.
        general_conf = self.general_confidence_conv(fmap.flatten(0, 1)).unflatten(0, (b, n_trgt)).sigmoid().clip(min=1e-4)
        affinity_emb = F.normalize(self.affinities_conv(fmap.flatten(0, 1)).unflatten(0, (b, n_trgt)), dim=2)

        # Per-pixel correspondence weights from the features at both ends of the flow
        corresp_feat = F.grid_sample(fmap[:, :-1].flatten(0, 1), corresp_uv.flatten(0, 1).unsqueeze(1) * 2 - 1).squeeze(-2).permute(0, 2, 1).unflatten(0, (b, n_trgt - 1))
        corr_weights = self.corr_weighter_perpoint(torch.cat((corresp_feat, ch_sec(fmap)[:, 1:]), -1)).sigmoid().clip(min=1e-4)
        # grid_sample zero-pads, so an all-zero feature marks a flow target outside the image
        corr_weights = corr_weights * (~(corresp_feat == 0).all(dim=-1).unsqueeze(-1)).float()

        # Soft rigid-body segmentation: one softmax channel per rig
        rig_masks = rearrange(self.rig_predictor(ch_sec(fmap)).softmax(-1), "b t x o -> (b o) t x 1")

        # Weighted Procrustes per rig between adjacent frames, then chain the relative
        # poses so every frame's pose is expressed relative to frame 0.
        # TODO: replace with N_track local solves (local sampling instead of all-to-all).
        poses = geometry.efficient_procrustes(eye_surf[:, 1:].expand(self.n_rig, -1, -1, -1), corresp_surf.expand(self.n_rig, -1, -1, -1),
                                              (corr_weights * rig_masks[:, 1:]).clip(min=1e-4),)[1]
        for i in range(n_trgt - 1, 0, -1):
            poses = torch.cat((poses[:, :i], poses[:, [i - 1]] @ poses[:, i:]), 1)
        poses = torch.cat((torch.eye(4).to(poses)[None, None].expand(poses.size(0), -1, -1, -1), poses), 1)  # add identity for starting pose

        # Blend rig poses per pixel in (quaternion, translation) space, then back to 4x4.
        # The identity tile is materialised with repeat (not expand) because it is written in place.
        poses_lie = torch.cat((kornia.geometry.conversions.rotation_matrix_to_quaternion(poses[..., :3, :3], eps=1e-4), poses[..., :3, -1]), -1)
        lie_perpix = (rig_masks * poses_lie.unsqueeze(-2)).unflatten(0, (b, self.n_rig)).sum(1)
        pose_perpix = torch.eye(4, dtype=lie_perpix.dtype, device=lie_perpix.device).repeat(b, n_trgt, lie_perpix.size(-2), 1, 1)
        pose_perpix[..., :3, :3] = kornia.geometry.conversions.quaternion_to_rotation_matrix(lie_perpix[..., :4])
        pose_perpix[..., :3, -1] = lie_perpix[..., 4:]

        # Optical flow is just reprojection from frame i to frame i-1 (no need to recompute, just for vis)
        adj_opt_flow = project(torch.einsum("btpij,btpj->btpi", pose_perpix[:, :-1].inverse() @ pose_perpix[:, 1:], hom(eye_surf[:, 1:]))[..., :3], model_input["intrinsics"][:, 1:]) - model_input["x_pix"][:, 1:]

        # Per-track poses: sample the per-pixel pose field at each track's location in each frame
        pose_tracks = rearrange(F.grid_sample(ch_fst(pose_perpix.flatten(-2, -1), imsize[0]).flatten(0, 1), tracks.flatten(0, 1).unsqueeze(-2) * 2 - 1, padding_mode="border").unflatten(0, (b, n_trgt)), "b t (x y) p 1 -> b p t x y", x=4)

        # Main loss signal -- carry each track's frame-0 surface point into every frame
        # with that track's pose and compare to the observed (visible) track positions
        track_poses = rearrange(pose_tracks, "b p t x y -> b t p x y")
        point_track_surf_reproj = torch.einsum('btpij,btpj->btpi', track_poses.inverse(), hom(eye_surf_track[:, [0]].expand(-1, n_trgt, -1, -1)))[..., :3]
        point_track_reproj = project(point_track_surf_reproj, model_input["intrinsics"]).clip(0, 1)
        point_track_loss = ((point_track_reproj - tracks) * model_input["pred_visibility"].unsqueeze(-1)).square().mean()

        ## Pose-induced point tracks from every frame to every other frame
        eye_tracks = rearrange(F.grid_sample(ch_fst(eye_surf, imsize[0]).flatten(0, 1), tracks.flatten(0, 1).unsqueeze(-2) * 2 - 1, padding_mode="border",
            ).unflatten(0, (b, n_trgt)), "b t c p 1 -> b p t c",)
        out["track_idxs"] = torch.randperm(tracks.size(-2))[:int(8000 // (n_trgt ** 2 / 20 ** 2))]  # choose random smaller subset to combat quadratic complexity
        sub_pose_tracks = pose_tracks[:, out["track_idxs"]]
        pose_all_to_all_perpix = repeat(sub_pose_tracks, "b p t x y -> b p s t x y", s=n_trgt).inverse() @ repeat(sub_pose_tracks, "b p t x y -> b p t s x y", s=n_trgt)
        track_surf_reprojs = torch.einsum("bksnij,bksj->bksni", pose_all_to_all_perpix, hom(eye_tracks[:, out["track_idxs"]]))[..., :3]
        out["track_reprojs"] = unhom(torch.einsum("bij,bksnj->bksni", model_input["intrinsics"][:, 0], track_surf_reprojs)).clip(0, 1)

        return out | {
            "rig_masks": rearrange(rig_masks, "(b o) t xy 1 -> b t o xy 1", o=self.n_rig),
            "res_depth": ch_sec(res_depth),
            "depth": ch_sec(depth),
            "flow_from_pose": adj_opt_flow,
            "poses": poses,
            "depth_inp": model_input["depth_inp"],
            "corr_weights": ch_fst(corr_weights, imsize[0]),
            "flow_inp_": model_input["bwd_flow"],
            "world_crds": torch.einsum("btpij,btpj->btpi", pose_perpix, hom(eye_surf))[..., :3],
            "rgb_crds": ch_sec(model_input["rgb"]),
            "lie_crds": (rig_masks.flatten(1, 2)[..., None] * poses_lie[:, None]).sum(0),
            "lie_perpix": lie_perpix.unflatten(-2, imsize),
            "point_track_loss": point_track_loss,
        }



    def forward_(self, model_input, out=None): # mlp rigid rasac bucket
        """Rigid-bucket FlowMap pass solved with ``geometry.efficient_procrustes``.

        NOTE(review): ``forward_`` is defined again later in this class body, so this
        definition is shadowed and is not reachable as ``FlowMap.forward_``.

        This pass does the following:
        - Runs the backbone on the RGB frames.
        - Adds a residual to the input depth and lifts it to a camera-space point cloud.
        - Builds flow correspondences and learned correspondence weights.
        - Predicts soft rig masks and solves a per-rig weighted Procrustes between
          adjacent frames.
        - Chains the poses to frame 0 and blends them per pixel.

        Args:
            model_input: dict read for "rgb", "depth_inp", "x_pix", "bwd_flow",
                "intrinsics" and optionally "pred_tracks". "fmap" is written into it.
            out: optional dict to populate. A fresh dict is used when omitted.
        """
        out = {} if out is None else out

        imsize = model_input["rgb"].shape[-2:]
        b, n_trgt = model_input["rgb"].shape[:2]
        depth_inp = ch_fst(model_input["depth_inp"], imsize[0])

        # Backbone features from time_stride-stacked RGB frames
        img_inp = rearrange(model_input["rgb"], "b (t s) c x y -> b t (s c) x y", s=self.time_stride)
        fmap_out = F.interpolate(self.resnet_enc(img_inp.flatten(0, 1) * .5 + .5), imsize, mode="bilinear")
        fmap = model_input["fmap"] = rearrange(fmap_out, "(b t) (s c) x y -> b (t s) c x y", s=self.time_stride, b=b)

        res_depth = F.softplus(self.depth_conv(fmap.flatten(0, 1)).unflatten(0, (b, n_trgt)) + 1) / 2
        depth = depth_inp + res_depth

        # Lift depth map into point cloud
        corresp_uv = model_input["x_pix"][:, :-1] + ch_sec(model_input["bwd_flow"])
        rds = geometry.get_world_rays(model_input["x_pix"], model_input["intrinsics"], None)[1]
        eye_surf = rds * ch_sec(depth)
        corresp_surf = F.grid_sample(ch_fst(eye_surf, imsize[0])[:, :-1].flatten(0, 1), corresp_uv.flatten(0, 1).unsqueeze(1) * 2 - 1).squeeze(-2).permute(0, 2, 1).unflatten(0, (b, n_trgt - 1))

        # Estimate correspondence weights
        corresp_feat = F.grid_sample(fmap[:, :-1].flatten(0, 1), corresp_uv.flatten(0, 1).unsqueeze(1) * 2 - 1).squeeze(-2).permute(0, 2, 1).unflatten(0, (b, n_trgt - 1))
        corr_weights = self.corr_weighter_perpoint(torch.cat((corresp_feat, ch_sec(fmap)[:, 1:]), -1)).sigmoid().clip(min=1e-4)
        corr_weights = corr_weights * (~(corresp_feat == 0).all(dim=-1).unsqueeze(-1)).float()  # mask out out-of-border points

        # Soft rig masks and per-rig weighted Procrustes between adjacent frames
        rig_masks = rearrange(self.rig_predictor(ch_sec(fmap)).softmax(-1), "b t x o -> (b o) t x 1")
        poses = geometry.efficient_procrustes(eye_surf[:, 1:].expand(self.n_rig, -1, -1, -1), corresp_surf.expand(self.n_rig, -1, -1, -1),
                                              (corr_weights * rig_masks[:, 1:]).clip(min=1e-4),)[1]
        for i in range(n_trgt - 1, 0, -1):
            poses = torch.cat((poses[:, :i], poses[:, [i - 1]] @ poses[:, i:]), 1)
        poses = torch.cat((torch.eye(4).to(poses)[None, None].expand(poses.size(0), -1, -1, -1), poses), 1)  # add identity for starting pose

        # Composite the poses per pixel (quaternion + translation blend, then back to 4x4).
        # repeat (not expand) so the in-place writes below target distinct memory.
        poses_lie = torch.cat((kornia.geometry.conversions.rotation_matrix_to_quaternion(poses[..., :3, :3], eps=1e-4), poses[..., :3, -1]), -1)
        lie_perpix = (rig_masks * poses_lie.unsqueeze(-2)).unflatten(0, (b, self.n_rig)).sum(1)
        pose_perpix = torch.eye(4, dtype=lie_perpix.dtype, device=lie_perpix.device).repeat(b, n_trgt, lie_perpix.size(-2), 1, 1)
        pose_perpix[..., :3, :3] = kornia.geometry.conversions.quaternion_to_rotation_matrix(lie_perpix[..., :4])
        pose_perpix[..., :3, -1] = lie_perpix[..., 4:]

        # Compute pose flow (frame i reprojected into frame i-1)
        adj_opt_flow = project(torch.einsum("btpij,btpj->btpi", pose_perpix[:, :-1].inverse() @ pose_perpix[:, 1:], hom(eye_surf[:, 1:]))[..., :3], model_input["intrinsics"][:, 1:]) - model_input["x_pix"][:, 1:]

        ## Compute pose induced point tracks
        if "pred_tracks" in model_input:
            tracks = model_input["pred_tracks"]
            eye_tracks = rearrange(F.grid_sample(ch_fst(eye_surf, imsize[0]).flatten(0, 1), tracks.flatten(0, 1).unsqueeze(-2) * 2 - 1, padding_mode="border",
                ).unflatten(0, (b, n_trgt)), "b t c p 1 -> b p t c",)
            out["track_idxs"] = torch.randperm(tracks.size(-2))[:int(8000 // (n_trgt ** 2 / 20 ** 2))]  # choose random smaller subset to combat quadratic complexity
            pose_tracks = rearrange(F.grid_sample(ch_fst(pose_perpix.flatten(-2, -1), imsize[0]).flatten(0, 1), tracks.flatten(0, 1).unsqueeze(-2) * 2 - 1, padding_mode="border").unflatten(0, (b, n_trgt)), "b t (x y) p 1 -> b p t x y", x=4)
            sub_pose_tracks = pose_tracks[:, out["track_idxs"]]
            pose_all_to_all_perpix = repeat(sub_pose_tracks, "b p t x y -> b p s t x y", s=n_trgt).inverse() @ repeat(sub_pose_tracks, "b p t x y -> b p t s x y", s=n_trgt)
            track_surf_reprojs = torch.einsum("bksnij,bksj->bksni", pose_all_to_all_perpix, hom(eye_tracks[:, out["track_idxs"]]))[..., :3]
            out["track_reprojs"] = unhom(torch.einsum("bij,bksnj->bksni", model_input["intrinsics"][:, 0], track_surf_reprojs)).clip(0, 1)
        else:
            print("skipping point tracks sup")

        return out | {
            "rig_masks": rearrange(rig_masks, "(b o) t xy 1 -> b t o xy 1", o=self.n_rig),
            "res_depth": ch_sec(res_depth),
            "depth": ch_sec(depth),
            "flow_from_pose": adj_opt_flow,
            "poses": poses,
            "depth_inp": model_input["depth_inp"],
            "corr_weights": ch_fst(corr_weights, imsize[0]),
            "flow_inp_": model_input["bwd_flow"],
            "world_crds": torch.einsum("btpij,btpj->btpi", pose_perpix, hom(eye_surf))[..., :3],
            "rgb_crds": ch_sec(model_input["rgb"]),
            "lie_crds": (rig_masks.flatten(1, 2)[..., None] * poses_lie[:, None]).sum(0),
            "lie_perpix": lie_perpix.unflatten(-2, imsize),
        }




    def forward_(self, model_input, out=None): # static est
        """Single-rigid-body (static scene) FlowMap pass.

        NOTE(review): ``forward_`` is redefined later in this class body, so this
        definition is shadowed and is not reachable as ``FlowMap.forward_``.

        This pass does the following:
        - Encodes RGB plus scaled input depth under fp16 autocast.
        - Adds a softplus residual to the input depth and lifts it to points.
        - Weights the flow correspondences by learned weights and by
          ``model_input["rig_flow_masks"]``.
        - Solves one weighted Procrustes per adjacent frame pair on ``n_samp`` evenly
          spaced pixels and chains the results to frame 0.

        NOTE(review): the encoder input here has 4 channels per frame (RGB + depth),
        while ``__init__`` builds the encoder with ``in_ch=3*time_stride``. Confirm
        this variant still matches the encoder before re-enabling it.

        Args:
            model_input: dict read for "rgb", "depth_inp", "x_pix", "bwd_flow",
                "intrinsics", "rig_flow_masks" and optionally "pred_tracks". "fmap"
                is written into it.
            out: optional dict to populate. A fresh dict is used when omitted.
        """
        out = {} if out is None else out

        imsize = model_input["rgb"].shape[-2:]
        b, n_trgt = model_input["rgb"].shape[:2]
        n_samp = 1000
        rand_subset = torch.linspace(0, imsize[0] * imsize[1] - 1, n_samp).long()  # evenly spaced (not random) pixel indices

        # Est depth: backbone on RGB + scaled input depth
        depth_inp = ch_fst(model_input["depth_inp"], imsize[0])
        img_inp = torch.cat((model_input["rgb"], depth_inp / 10 - 1), 2)
        img_inp = rearrange(img_inp, "b (t s) c x y -> b t (s c) x y", s=self.time_stride)

        with torch.autocast(device_type='cuda', dtype=torch.float16):
            fmap_out = F.interpolate(self.resnet_enc(img_inp.flatten(0, 1) * .5 + .5), imsize, mode="bilinear")

        fmap = model_input["fmap"] = rearrange(fmap_out.float(), "(b t) (s c) x y -> b (t s) c x y", s=self.time_stride, b=b)

        res_depth = F.softplus(self.depth_conv(fmap.flatten(0, 1)).unflatten(0, (b, n_trgt)) + 1)
        depth = depth_inp + res_depth

        # Lift depth map into point cloud
        corresp_uv = model_input["x_pix"][:, :-1] + ch_sec(model_input["bwd_flow"])
        rds = geometry.get_world_rays(model_input["x_pix"], model_input["intrinsics"], None)[1]
        eye_surf = rds * ch_sec(depth)
        corresp_surf = F.grid_sample(ch_fst(eye_surf, imsize[0])[:, :-1].flatten(0, 1), corresp_uv.flatten(0, 1).unsqueeze(1) * 2 - 1).squeeze(-2).permute(0, 2, 1).unflatten(0, (b, n_trgt - 1))

        # Estimate correspondence weights; mask out out-of-border points and apply the rig mask
        corresp_feat = F.grid_sample(fmap[:, :-1].flatten(0, 1), corresp_uv.flatten(0, 1).unsqueeze(1) * 2 - 1).squeeze(-2).permute(0, 2, 1).unflatten(0, (b, n_trgt - 1))
        corr_weights = self.corr_weighter_perpoint(torch.cat((corresp_feat, ch_sec(fmap)[:, 1:]), -1)).sigmoid().clip(min=1e-4)
        corr_weights = corr_weights * (~(corresp_feat == 0).all(dim=-1).unsqueeze(-1)).float() * ch_sec(model_input["rig_flow_masks"])

        # Estimate poses via procrustes and integrate to frame 0
        poses = geometry.procrustes(eye_surf[:, 1:, rand_subset], corresp_surf[:, :, rand_subset], corr_weights[:, :, rand_subset],)[1]
        for i in range(n_trgt - 1, 0, -1):
            poses = torch.cat((poses[:, :i], poses[:, [i - 1]] @ poses[:, i:]), 1)
        poses = torch.cat((torch.eye(4).to(poses)[None, None].expand(poses.size(0), -1, -1, -1), poses), 1)  # add identity for starting pose

        # Compute pose flow (frame i reprojected into frame i-1)
        adj_opt_flow = project(torch.einsum("btij,btpj->btpi", poses[:, :-1].inverse() @ poses[:, 1:], hom(eye_surf[:, 1:]))[..., :3], model_input["intrinsics"][:, 1:]) - model_input["x_pix"][:, 1:]

        ## Compute pose induced point tracks
        if "pred_tracks" in model_input:
            tracks = model_input["pred_tracks"]
            eye_tracks = rearrange(F.grid_sample(ch_fst(eye_surf, imsize[0]).flatten(0, 1), tracks.flatten(0, 1).unsqueeze(-2) * 2 - 1,
                                                 padding_mode="border",).unflatten(0, (b, n_trgt)), "b t c p 1 -> b p t c",)
            pose_all_to_all = repeat(poses, "b t x y -> b s t x y", s=n_trgt).inverse() @ repeat(poses, "b t x y -> b t s x y", s=n_trgt)
            out["track_idxs"] = torch.randperm(tracks.size(-2))[:int(8000 // (n_trgt ** 2 / 20 ** 2))]  # choose random smaller subset to combat quadratic complexity
            track_surf_reprojs = torch.einsum("bsnij,bksj->bksni", pose_all_to_all, hom(eye_tracks[:, out["track_idxs"]]))[..., :3]
            out["track_reprojs"] = unhom(torch.einsum("bij,bksnj->bksni", model_input["intrinsics"][:, 0], track_surf_reprojs)).clip(0, 1)

        return out | {
            "res_depth": ch_sec(res_depth),
            "depth": ch_sec(depth),
            "flow_from_pose": adj_opt_flow,
            "poses": poses,
            "depth_inp": model_input["depth_inp"],
            "corr_weights": ch_fst(corr_weights, imsize[0]),
            "flow_inp_": model_input["bwd_flow"],
            "world_crds": torch.einsum("btij,btpj->btpi", poses, hom(eye_surf))[..., :3],
            "rgb_crds": ch_sec(model_input["rgb"]),
        }

    
    def forward_(self, model_input, out=None): # mlp rigid bucket
        """Rigid-bucket FlowMap pass using ``geometry.procrustes`` on a fixed pixel subset.

        NOTE(review): ``forward_`` is defined again further down in this class body,
        so this definition is shadowed and is not reachable as ``FlowMap.forward_``.

        This pass does the following:
        - Encodes RGB plus scaled input depth.
        - Adds a residual to the input depth and lifts it to points.
        - Builds flow correspondences and learned correspondence weights.
        - Predicts soft rig masks and solves a per-rig weighted Procrustes on
          ``n_samp`` evenly spaced pixels.
        - Chains the poses to frame 0 and blends them per pixel.

        NOTE(review): the encoder input has 4 channels per frame here, while
        ``__init__`` uses ``in_ch=3*time_stride``. Confirm before re-enabling.

        Args:
            model_input: dict read for "rgb", "depth_inp", "x_pix", "bwd_flow",
                "intrinsics" and optionally "pred_tracks". "fmap" is written into it.
            out: optional dict to populate. A fresh dict is used when omitted.
        """
        out = {} if out is None else out

        imsize = model_input["rgb"].shape[-2:]
        b, n_trgt = model_input["rgb"].shape[:2]
        n_samp = 500
        rand_subset = torch.linspace(0, imsize[0] * imsize[1] - 1, n_samp).long()  # evenly spaced (not random) pixel indices

        # Est depth: backbone on RGB + scaled input depth
        depth_inp = ch_fst(model_input["depth_inp"], imsize[0])
        img_inp = torch.cat((model_input["rgb"], depth_inp / 10 - 1), 2)
        img_inp = rearrange(img_inp, "b (t s) c x y -> b t (s c) x y", s=self.time_stride)
        fmap_out = F.interpolate(self.resnet_enc(img_inp.flatten(0, 1) * .5 + .5), imsize, mode="bilinear")
        fmap = model_input["fmap"] = rearrange(fmap_out, "(b t) (s c) x y -> b (t s) c x y", s=self.time_stride, b=b)

        res_depth = F.softplus(self.depth_conv(fmap.flatten(0, 1)).unflatten(0, (b, n_trgt)) + 1) / 2
        depth = depth_inp + res_depth

        # Lift depth map into point cloud
        corresp_uv = model_input["x_pix"][:, :-1] + ch_sec(model_input["bwd_flow"])
        rds = geometry.get_world_rays(model_input["x_pix"], model_input["intrinsics"], None)[1]
        eye_surf = rds * ch_sec(depth)
        corresp_surf = F.grid_sample(ch_fst(eye_surf, imsize[0])[:, :-1].flatten(0, 1), corresp_uv.flatten(0, 1).unsqueeze(1) * 2 - 1).squeeze(-2).permute(0, 2, 1).unflatten(0, (b, n_trgt - 1))

        # Estimate correspondence weights
        corresp_feat = F.grid_sample(fmap[:, :-1].flatten(0, 1), corresp_uv.flatten(0, 1).unsqueeze(1) * 2 - 1).squeeze(-2).permute(0, 2, 1).unflatten(0, (b, n_trgt - 1))
        corr_weights = self.corr_weighter_perpoint(torch.cat((corresp_feat, ch_sec(fmap)[:, 1:]), -1)).sigmoid().clip(min=1e-4)
        corr_weights = corr_weights * (~(corresp_feat == 0).all(dim=-1).unsqueeze(-1)).float()  # mask out out-of-border points

        # Predict soft rigid masks (TODO: try dino features / kmeans init as offsets)
        rig_masks = rearrange(self.rig_predictor(ch_sec(fmap)).softmax(-1), "b t x o -> (b o) t x 1")

        # Estimate poses via procrustes and integrate to frame 0
        poses = geometry.procrustes(eye_surf[:, 1:, rand_subset].expand(self.n_rig, -1, -1, -1), corresp_surf[:, :, rand_subset].expand(self.n_rig, -1, -1, -1),
                                    (corr_weights[:, :, rand_subset] * rig_masks[:, 1:, rand_subset]).clip(min=1e-4),)[1]
        for i in range(n_trgt - 1, 0, -1):
            poses = torch.cat((poses[:, :i], poses[:, [i - 1]] @ poses[:, i:]), 1)
        poses = torch.cat((torch.eye(4).to(poses)[None, None].expand(poses.size(0), -1, -1, -1), poses), 1)  # add identity for starting pose

        # Composite the poses per pixel (quaternion + translation blend, then back to 4x4).
        # repeat (not expand) so the in-place writes below target distinct memory.
        poses_lie = torch.cat((kornia.geometry.conversions.rotation_matrix_to_quaternion(poses[..., :3, :3], eps=1e-4), poses[..., :3, -1]), -1)
        lie_perpix = (rig_masks * poses_lie.unsqueeze(-2)).unflatten(0, (b, self.n_rig)).sum(1)
        pose_perpix = torch.eye(4, dtype=lie_perpix.dtype, device=lie_perpix.device).repeat(b, n_trgt, lie_perpix.size(-2), 1, 1)
        pose_perpix[..., :3, :3] = kornia.geometry.conversions.quaternion_to_rotation_matrix(lie_perpix[..., :4])
        pose_perpix[..., :3, -1] = lie_perpix[..., 4:]

        # Compute pose flow (frame i reprojected into frame i-1)
        adj_opt_flow = project(torch.einsum("btpij,btpj->btpi", pose_perpix[:, :-1].inverse() @ pose_perpix[:, 1:], hom(eye_surf[:, 1:]))[..., :3],
                               model_input["intrinsics"][:, 1:]) - model_input["x_pix"][:, 1:]

        ## Compute pose induced point tracks
        if "pred_tracks" in model_input:
            tracks = model_input["pred_tracks"]
            eye_tracks = rearrange(F.grid_sample(ch_fst(eye_surf, imsize[0]).flatten(0, 1), tracks.flatten(0, 1).unsqueeze(-2) * 2 - 1, padding_mode="border",
                ).unflatten(0, (b, n_trgt)), "b t c p 1 -> b p t c",)
            out["track_idxs"] = torch.randperm(tracks.size(-2))[:int(8000 // (n_trgt ** 2 / 20 ** 2))]  # choose random smaller subset to combat quadratic complexity
            pose_tracks = rearrange(F.grid_sample(ch_fst(pose_perpix.flatten(-2, -1), imsize[0]).flatten(0, 1), tracks.flatten(0, 1).unsqueeze(-2) * 2 - 1, padding_mode="border").unflatten(0, (b, n_trgt)), "b t (x y) p 1 -> b p t x y", x=4)
            sub_pose_tracks = pose_tracks[:, out["track_idxs"]]
            pose_all_to_all_perpix = repeat(sub_pose_tracks, "b p t x y -> b p s t x y", s=n_trgt).inverse() @ repeat(sub_pose_tracks, "b p t x y -> b p t s x y", s=n_trgt)
            track_surf_reprojs = torch.einsum("bksnij,bksj->bksni", pose_all_to_all_perpix, hom(eye_tracks[:, out["track_idxs"]]))[..., :3]
            out["track_reprojs"] = unhom(torch.einsum("bij,bksnj->bksni", model_input["intrinsics"][:, 0], track_surf_reprojs)).clip(0, 1)

        return out | {
            "rig_masks": rearrange(rig_masks, "(b o) t xy 1 -> b t o xy 1", o=self.n_rig),
            "res_depth": ch_sec(res_depth),
            "depth": ch_sec(depth),
            "flow_from_pose": adj_opt_flow,
            "poses": poses,
            "depth_inp": model_input["depth_inp"],
            "corr_weights": ch_fst(corr_weights, imsize[0]),
            "flow_inp_": model_input["bwd_flow"],
            "world_crds": torch.einsum("btpij,btpj->btpi", pose_perpix, hom(eye_surf))[..., :3],
            "rgb_crds": ch_sec(model_input["rgb"]),
            "lie_crds": (rig_masks.flatten(1, 2)[..., None] * poses_lie[:, None]).sum(0),
            "lie_perpix": lie_perpix.unflatten(-2, imsize),
        }

    def forward_(self, model_input, out=None): # mlp rigid bucket, old
        """Legacy forward pass: MLP-predicted soft rigid masks + per-rig weighted Procrustes.

        Predicts a positive residual on top of the input depth, lifts depth to a camera-space
        point cloud, estimates per-correspondence weights and a soft assignment of every pixel
        to one of ``self.n_rig`` rigid bodies, solves one weighted Procrustes problem per rigid
        body and frame pair, chains those into trajectories, and blends the rigid trajectories
        per pixel in (quaternion, translation) space.

        Args:
            model_input: dict of tensors. Reads ``"rgb"``, ``"bwd_flow"``, ``"depth_inp"``,
                ``"x_pix"``, ``"intrinsics"`` and optionally ``"pred_tracks"``; WRITES
                ``model_input["fmap"]`` (the caller's dict is mutated). Exact shapes are
                set by the data pipeline -- presumably ``"rgb"`` is (b, t, 3, H, W); TODO confirm.
            out: optional dict. ``"track_reprojs"`` is written into it in place when tracks
                are given, and it is merged into the returned dict. A fresh dict is used when
                omitted.

        Returns:
            A new dict: ``out`` merged with depth, pose, mask and coordinate outputs. When
            ``"pred_tracks"`` is absent, ``"lie_crds"`` / ``"world_crds"`` / ``"rgb_crds"``
            contain only the per-pixel coordinates (no appended track coordinates).
        """
        # Fresh dict per call: the former ``out={}`` default was mutated below, so
        # "track_reprojs" from one call leaked into later calls.
        out = {} if out is None else out

        imsize=model_input["rgb"].shape[-2:]
        low_imres=(64,64) # resolution the rigid masks are resampled to before pose blending
        (b,_),n_trgt=model_input["rgb"].shape[:2],model_input["rgb"].size(1)

        # Est depth. A zero flow is prepended along the frame axis so the backward flow lines
        # up with the rgb / depth frames it is concatenated with.
        flow_inp = torch.cat((torch.zeros_like(model_input["bwd_flow"][:,:1]),model_input["bwd_flow"]*5e2),1)
        depth_inp = ch_fst(model_input["depth_inp"],imsize[0])
        rgbdf = torch.cat((model_input["rgb"],depth_inp,flow_inp),2)

        # Fold `time_stride` consecutive frames into the channel axis for the encoder, then unfold.
        rgbdf=rearrange(rgbdf,"b (t s) c x y -> b t (s c) x y",s=self.time_stride)
        fmap_out = F.interpolate(self.resnet_enc(rgbdf.flatten(0,1)*.5+.5),imsize,mode="bilinear")
        model_input["fmap"] = rearrange(fmap_out,"(b t) (s c) x y -> b (t s) c x y",s=self.time_stride,b=b)

        # softplus(.+1) keeps the residual strictly positive.
        res_depth = F.softplus(self.depth_conv(model_input["fmap"].flatten(0,1)).unflatten(0,(b,n_trgt))+1)
        depth = depth_inp + res_depth

        # Lift depth map into point cloud; sample frames [:-1] at x_pix + backward flow.
        corresp_uv = (model_input["x_pix"][:,:-1]+ch_sec(model_input["bwd_flow"]))
        rds = geometry.get_world_rays(model_input["x_pix"],model_input["intrinsics"],None)[1]
        eye_surf=rds*ch_sec(depth)
        corresp_surf = F.grid_sample(ch_fst(eye_surf,imsize[0])[:,:-1].flatten(0,1),corresp_uv.flatten(0,1).unsqueeze(1)*2-1).squeeze(-2).permute(0,2,1).unflatten(0,(b,n_trgt-1))

        # Estimate correspondence weights from features at both ends; sigmoid, floored at 1e-4.
        corresp_feat = F.grid_sample(model_input["fmap"][:,:-1].flatten(0,1),corresp_uv.flatten(0,1).unsqueeze(1)*2-1).squeeze(-2).permute(0,2,1).unflatten(0,(b,n_trgt-1))
        corr_weights = self.corr_weighter_perpoint(torch.cat((corresp_feat,ch_sec(model_input["fmap"])[:,1:]),-1)).sigmoid().clip(min=1e-4)

        # Estimate rigid masks -- with n_rig=1 reduces exactly to static flowmap code.
        # Softmax over the rig axis; affinity_mask_unnorm keeps full resolution (despite the
        # name it is softmax-normalized), affinity_mask is resampled to low_imres.
        affinity_mask=affinity_mask_unnorm=rearrange( (2e0*self.rig_predictor(ch_sec(model_input["fmap"]))).softmax(-1), "b t x o -> (b o) t x 1")
        affinity_mask=ch_sec(F.interpolate(ch_fst(affinity_mask,imsize[0]).flatten(0,1),low_imres).unflatten(0,affinity_mask.shape[:2]))

        # Fixed int(500**.5)^2 grid of sample locations inside (0,1), shared by every rig/frame.
        get_xpix = lambda s,e,n: torch.stack(torch.meshgrid(*[torch.linspace(s+.01,e-.01,n) for _ in range(2)])).flatten(1,2)
        sample_locs = repeat(get_xpix(0,1,int(500**.5)),"c p -> (b o) t p 1 c",o=self.n_rig,b=b,t=n_trgt-1).to(eye_surf.device)

        # Estimate poses via weighted procrustes (weights = corr_weights * rig mask) per rig.
        poses = geometry.procrustes(
            rearrange(grid_samp(ch_fst(eye_surf[:,1:],imsize[0]),rearrange(sample_locs,"(b s) t p 1 c -> b t (p s) 1 c",b=b)),"b t c (p s) 1 -> (b s) t p c",p=sample_locs.size(2)),
            rearrange(grid_samp(ch_fst(corresp_surf,imsize[0]),  rearrange(sample_locs,"(b s) t p 1 c -> b t (p s) 1 c",b=b)),"b t c (p s) 1 -> (b s) t p c",p=sample_locs.size(2)),
           (rearrange(grid_samp(ch_fst(corr_weights,imsize[0]),  rearrange(sample_locs,"(b s) t p 1 c -> b t (p s) 1 c",b=b)),"b t c (p s) 1 -> (b s) t p c",p=sample_locs.size(2))*
            grid_samp(ch_fst(affinity_mask_unnorm[:,1:],low_imres[0]),sample_locs).squeeze(2)).clip(min=1e-4)
                        )[1]
        # Integrate: after the loop poses[:, j] = P[:, 0] @ P[:, 1] @ ... @ P[:, j] (cumulative product).
        for i in range(n_trgt-1,0,-1): poses = torch.cat((poses[:,:i],poses[:,[i-1]]@poses[:,i:]),1)
        poses = torch.cat((torch.eye(4).to(poses)[None,None].expand(poses.size(0),-1,-1,-1),poses),1) # add for starting pose

        # Per-pixel trajectories: blend rig poses as 7-vectors (quaternion, translation) with the
        # rig masks, upsample to full res, then convert back to 4x4 matrices.
        # NOTE(review): the blended quaternions are not renormalized here -- confirm kornia handles it.
        poses_lie = torch.cat((kornia.geometry.conversions.rotation_matrix_to_quaternion(poses[...,:3,:3],eps=1e-5),poses[...,:3,-1]),-1)
        lie_perpix = (affinity_mask*poses_lie.unsqueeze(-2)).unflatten(0,(b,self.n_rig)).sum(1)
        lie_perpix = ch_sec(F.interpolate(ch_fst(lie_perpix,low_imres[0]).flatten(0,1),imsize,mode="bilinear").unflatten(0,(b,n_trgt))) # upsample to hires (just bilinear rn)

        # Materialized (not an expanded view) so the in-place writes below are legal on any device.
        pose_perpix = torch.eye(4,dtype=lie_perpix.dtype,device=lie_perpix.device).repeat(b,n_trgt,lie_perpix.size(-2),1,1)
        pose_perpix[...,:3,:3] = kornia.geometry.conversions.quaternion_to_rotation_matrix(lie_perpix[...,:4])
        pose_perpix[...,:3,-1] = lie_perpix[...,4:]

        # Compute pose flow
        adj_opt_flow = project( torch.einsum("btpij,btpj->btpi",pose_perpix[:,:-1].inverse()@pose_perpix[:,1:],hom(eye_surf[:,1:]))[...,:3],
                                                model_input["intrinsics"][:,1:] )-model_input["x_pix"][:,1:]

        # Pose of the rig with the largest total mask mass.
        # NOTE(review): argmax runs over the flattened (b o) axis, so with b>1 a single batch item wins.
        bg_pose = poses[affinity_mask.flatten(1,-1).sum(-1).max(dim=0)[1]][None] # take max

        world_crds = torch.einsum("btpij,btpj->btpi",pose_perpix,hom(eye_surf))[...,:3]
        rgb_crds = ch_sec(model_input["rgb"])
        lie_crds = lie_perpix.unsqueeze(-2)

        ## Compute pose induced point tracks -- TODO subsample points with random fixed point buget
        if "pred_tracks" in model_input:
            track_grid = model_input["pred_tracks"].flatten(0,1).unsqueeze(-2)*2-1 # [0,1] -> [-1,1] for grid_sample
            eye_tracks=rearrange(F.grid_sample(ch_fst(eye_surf,imsize[0]).flatten(0,1),track_grid,padding_mode="border",
                ).unflatten(0,(b,n_trgt)), "b t c p 1 -> b p t c",)
            pose_tracks=rearrange(F.grid_sample(ch_fst(pose_perpix.flatten(-2,-1),imsize[0]).flatten(0,1),track_grid,
                                    padding_mode="border").unflatten(0,(b,n_trgt)), "b t (x y) p 1 -> b p t x y",x=4)
            pose_all_to_all_perpix = repeat(pose_tracks,"b p t x y -> b p s t x y",s=n_trgt).inverse()@repeat(pose_tracks,"b p t x y -> b p t s x y",s=n_trgt)
            track_surf_reprojs = torch.einsum("bksnij,bksj->bksni",pose_all_to_all_perpix,hom(eye_tracks))[...,:3]
            out["track_reprojs"] = unhom(torch.einsum("bij,bksnj->bksni",model_input["intrinsics"][:,0],track_surf_reprojs)).clip(0,1)

            world_crds_tracks = torch.einsum("btpij,btpj->btpi",pose_tracks.permute(0,2,1,3,4),hom(eye_tracks.permute(0,2,1,3)))[...,:3]
            rgb_tracks=F.grid_sample(model_input["rgb"].flatten(0,1),track_grid,
                                     padding_mode="border").unflatten(0,(b,n_trgt)).permute(0,1,3,2,4).squeeze(-1)
            # Quaternion [1,0,0,0] followed by each track's offset from frame t to frame 0.
            lie_tracks = torch.cat((torch.tensor([1,0,0,0]).float().to(eye_tracks.device)[None,None,None].expand(b,n_trgt,rgb_tracks.size(-2),-1),
                                    (eye_tracks[:,:,[0]]-eye_tracks).permute(0,2,1,3)
                                    ),-1).unsqueeze(-2)

            # Append track coordinates after the per-pixel ones.
            lie_crds = torch.cat((lie_crds,lie_tracks),-3)
            world_crds = torch.cat((world_crds,world_crds_tracks),-2)
            rgb_crds = torch.cat((rgb_crds,rgb_tracks),-2)

        return out | {
            "rig_masks":rearrange(affinity_mask,"(b o) t xy 1 -> b t o xy 1",o=self.n_rig),
            "res_depth":ch_sec(res_depth),
            "depth":ch_sec(depth),
            "flow_from_pose":adj_opt_flow,
            "lie_crds":lie_crds,
            "world_crds":world_crds,
            "rgb_crds":rgb_crds,
            "lie_perpix":lie_perpix.unflatten(-2,imsize),
            "pose_perpix":pose_perpix.unflatten(-3,imsize),
            "eye_surf":eye_surf,
            "poses_all":poses,
            "poses":bg_pose,
            "depth_inp":model_input["depth_inp"],
            "corr_weights": ch_fst(corr_weights,imsize[0]),
            "flow_inp_": model_input["bwd_flow"],
        }

class PixelNeRFEncoder(nn.Module):
    """Multi-scale torchvision ResNet feature encoder (pixelNeRF-style).

    Runs the stem and the first ``num_layers - 1`` residual stages of a torchvision
    backbone, resizes every intermediate activation to the stem's resolution (or to
    ``custom_size``), concatenates them along channels and projects the result to 512
    channels with a 1x1 convolution.
    """

    def __init__(
        self,
        backbone="resnet34",
        pretrained=True,
        num_layers=4,
        index_interp="bilinear",
        index_padding="border",
        upsample_interp="bilinear",
        feature_scale=1.0,
        use_first_pool=True,
        norm_type="batch",
        in_ch=3,
    ):
        """
        Args:
            backbone: name of a ``torchvision.models`` constructor, e.g. ``"resnet34"``.
            pretrained: forwarded to the torchvision constructor.
            num_layers: number of feature maps kept: the stem plus ``num_layers - 1`` of
                ``layer1``..``layer4``.
            index_interp, index_padding: stored on the instance; not used by ``forward``.
            upsample_interp: ``F.interpolate`` mode used to resize all feature maps.
            feature_scale: the input is rescaled by this factor first (skipped at 1.0).
            use_first_pool: apply the backbone max-pool before ``layer1``.
            norm_type: ``"batch"``, ``"instance"``, ``"group"`` or ``"none"``.
            in_ch: input channels; when not 3 the stem conv is replaced by a freshly
                initialised one (any pretrained stem weights are discarded).
        """
        super().__init__()

        def get_norm_layer(norm_type="instance", group_norm_groups=32):
            """Return a normalization layer
            Parameters:
                norm_type (str) -- the name of the normalization layer: batch | instance | group | none
            For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
            For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
            For "none", None is passed through to the torchvision constructor.
            """
            if norm_type == "batch":
                norm_layer = functools.partial(
                    nn.BatchNorm2d, affine=True, track_running_stats=True
                )
            elif norm_type == "instance":
                norm_layer = functools.partial(
                    nn.InstanceNorm2d, affine=False, track_running_stats=False
                )
            elif norm_type == "group":
                norm_layer = functools.partial(nn.GroupNorm, group_norm_groups)
            elif norm_type == "none":
                norm_layer = None
            else:
                raise NotImplementedError("normalization layer [%s] is not found" % norm_type)
            return norm_layer

        self.feature_scale = feature_scale
        self.use_first_pool = use_first_pool
        norm_layer = get_norm_layer(norm_type)

        print("Using torchvision", backbone, "encoder")
        self.model = getattr(torchvision.models, backbone)(
            pretrained=pretrained, norm_layer=norm_layer
        )

        if in_ch != 3:
            # Same geometry as the original stem, but a new (randomly initialised) layer.
            self.model.conv1 = nn.Conv2d(
                in_ch,
                self.model.conv1.weight.shape[0],
                self.model.conv1.kernel_size,
                self.model.conv1.stride,
                self.model.conv1.padding,
                padding_mode=self.model.conv1.padding_mode,
            )

        # Drop the classification head: empty Sequentials pass their input through unchanged.
        self.model.fc = nn.Sequential()
        self.model.avgpool = nn.Sequential()
        # Total channels of the concatenated feature maps, indexed by num_layers.
        # NOTE(review): these totals assume resnet18/34 stage widths; other backbones differ.
        self.latent_size = [0, 64, 128, 256, 512, 1024][num_layers]

        self.num_layers = num_layers
        self.index_interp = index_interp
        self.index_padding = index_padding
        self.upsample_interp = upsample_interp
        self.register_buffer("latent", torch.empty(1, 1, 1, 1), persistent=False)
        self.register_buffer(
            "latent_scaling", torch.empty(2, dtype=torch.float32), persistent=False
        )

        self.out = nn.Sequential(
            nn.Conv2d(self.latent_size, 512, 1),
        )

    def forward(self, x, custom_size=None):
        """Encode images into a 512-channel feature map.

        Args:
            x: image batch (N, C, H, W). A 5-D input (B, T, C, H, W) is flattened over its
                first two dims, encoded, and unflattened back.
            custom_size: optional output spatial size; defaults to the stem's resolution.

        Returns:
            Tensor with 512 channels at the chosen resolution. As a side effect caches
            ``self.latents``, ``self.latent`` and ``self.latent_scaling``.
        """
        if len(x.shape)>4: return self(x.flatten(0,1),custom_size).unflatten(0,x.shape[:2])

        if self.feature_scale != 1.0:
            x = F.interpolate(
                x,
                scale_factor=self.feature_scale,
                mode="bilinear" if self.feature_scale > 1.0 else "area",
                align_corners=True if self.feature_scale > 1.0 else None,
                recompute_scale_factor=True,
            )
        x = x.to(device=self.latent.device)
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)

        latents = [x]
        if self.num_layers > 1:
            if self.use_first_pool:
                x = self.model.maxpool(x)
            x = self.model.layer1(x)
            latents.append(x)
        if self.num_layers > 2:
            x = self.model.layer2(x)
            latents.append(x)
        if self.num_layers > 3:
            x = self.model.layer3(x)
            latents.append(x)
        if self.num_layers > 4:
            x = self.model.layer4(x)
            latents.append(x)

        self.latents = latents
        # F.interpolate accepts align_corners only for the interpolating modes; passing True
        # with e.g. "nearest" or "area" raises. (The previous check compared index_interp,
        # not the mode actually used, against "nearest " with a trailing space, so it
        # never matched.)
        align_corners = True if self.upsample_interp in ("linear", "bilinear", "bicubic", "trilinear") else None
        latent_sz = latents[0].shape[-2:]
        for i in range(len(latents)):
            latents[i] = F.interpolate(
                latents[i],
                latent_sz if custom_size is None else custom_size,
                mode=self.upsample_interp,
                align_corners=align_corners,
            )
        self.latent = torch.cat(latents, dim=1)
        self.latent_scaling[0] = self.latent.shape[-1]
        self.latent_scaling[1] = self.latent.shape[-2]
        self.latent_scaling = self.latent_scaling / (self.latent_scaling - 1) * 2.0
        return self.out(self.latent)
