import torch,torchvision

import matplotlib.pyplot as plt
from torch import nn
import kornia
import functools
from einops import rearrange, repeat
#import torchvision.ops as ops
from torch.nn import functional as F
import torchvision.transforms as T

import numpy as np
import sys, random, time, os
from copy import deepcopy
from matplotlib import cm
import wandb
from tqdm import tqdm
from typing import Callable, List, Optional, Tuple, Generator, Dict
from collections import defaultdict
#import torchvision.transforms as T

sys.path.append("./third_party/co-tracker/")
from cotracker.utils.visualizer import Visualizer, read_video_from_path

import geometry

import torch_kmeans
from torch_kmeans import KMeans

# Tensor layout / projection helpers


def ch_sec(x):
    """Flatten the two trailing spatial axes and move channels last: ``... c x y -> ... (x y) c``."""
    return rearrange(x, "... c x y -> ... (x y) c")


def ch_fst(src, x=None):
    """Inverse layout of ``ch_sec``: ``... (x y) c -> ... c x y``.

    Args:
        src: tensor whose second-to-last axis is the flattened spatial axis.
        x: size of the first spatial axis; when ``None`` the flattened axis is
            treated as square and ``int(sqrt(src.size(-2)))`` is used.
    """
    if x is None:
        x = int(src.size(-2) ** (0.5))
    return rearrange(src, "... (x y) c -> ... c x y", x=x)


def hom(x):
    """Append a component of ones along the last axis (homogeneous coordinates)."""
    return torch.cat((x, torch.ones_like(x[..., [0]])), -1)


def unhom(x):
    """Divide all but the last component by the last one (plus 1e-5 to avoid division by zero)."""
    return x[..., :-1] / (1e-5 + x[..., -1:])


def grid_samp_(x, y, pad, mode):
    """``F.grid_sample`` wrapper; assumes ``y`` is in [0, 1] and moves it to [-1, 1]."""
    return F.grid_sample(x, y * 2 - 1, mode=mode, padding_mode=pad)


def grid_samp(x, y, pad="border", mode="bilinear"):
    """Sample ``x`` at ``y`` (coordinates in [0, 1]).

    A 4-D ``x`` is sampled directly. Any other rank is treated as having two
    leading batch axes, which are flattened for sampling and restored afterwards.
    """
    if len(x.shape) == 4:
        return grid_samp_(x, y, pad, mode)
    return grid_samp_(x.flatten(0, 1), y.flatten(0, 1), pad, mode).unflatten(0, x.shape[:2])


def project(crds, K):
    """Multiply ``crds`` by the matrices ``K`` and de-homogenise the result."""
    return unhom(torch.einsum("b...cij,b...ckj->b...cki", K, crds))


def warp(crds, poses, K):
    """Transform ``crds`` by ``poses`` in homogeneous form, keep the first three components, then project with ``K``."""
    return project(torch.einsum("b...cij,b...ckj->b...cki", poses, hom(crds))[..., :3], K)


def shuffle(x):
    """Return ``x`` with its first axis randomly permuted.

    The permutation stays an integer index tensor and is only moved to the
    device of ``x``; casting it to the dtype of ``x`` would make it unusable as
    an index whenever ``x`` is floating point.
    """
    return x[torch.randperm(len(x)).to(x.device)]

def make_net(dims):
    """Build a ReLU MLP whose layer widths are given by ``dims``.

    Consecutive entries of ``dims`` become ``nn.Linear`` layers with a ReLU
    between them (none after the last layer). Linear weights are
    re-initialised with Kaiming-normal (fan-in, ReLU gain).

    Args:
        dims: sequence of layer widths, input width first.

    Returns:
        An ``nn.Sequential``; it is empty when ``dims`` has fewer than two entries.
    """
    def _kaiming_init(module):
        if isinstance(module, nn.Linear):
            nn.init.kaiming_normal_(module.weight, a=0.0, nonlinearity="relu", mode="fan_in")

    # Create every linear layer first so the default initialisers draw from the
    # RNG before the Kaiming re-initialisation does.
    linears = [nn.Linear(n_in, n_out) for n_in, n_out in zip(dims[:-1], dims[1:])]

    modules = []
    for idx, linear in enumerate(linears):
        if idx:
            modules.append(nn.ReLU())
        modules.append(linear)

    mlp = nn.Sequential(*modules)
    mlp.apply(_kaiming_init)
    return mlp

# Module-level experiment switches read by FlowMap.
# static_solve: selects the single-solve branch of the pose-estimating forward_ variant, which shares one pose across all selected tracks.
static_solve = True
# use_depth_inp: only read when sizing the encoder input channels, where it is multiplied by 0, so it currently has no effect there.
use_depth_inp = True
# scratch_model: when True, the MiDaS-based forward_ variant divides its features by 1 instead of 100.
scratch_model=True

class FlowMap(nn.Module):
    def __init__(self, args=None):
        """Create the image encoder and the convolutional / MLP prediction heads.

        Args:
            args: stored on ``self.args``; not otherwise read by this constructor.

        NOTE(review): ``ResnetFPN`` is not defined in the visible part of this
        file; presumably it is defined or imported elsewhere — confirm.
        NOTE(review): submodules are created in a fixed order, which determines
        both parameter registration order and the RNG draws used for
        initialisation; keep the order when editing.
        """
        super().__init__()
        self.args = args

        # Our actual model
        self.time_stride = 1#args.time_stride
        # Input channel count of every conv head below; presumably matches the
        # encoder's output channels — TODO confirm against ResnetFPN.
        fdim = 64#32
        self.depth_est = make_net([fdim, fdim, 1])
        self.depth_conv = nn.Conv2d(fdim, 1, 3, padding=1)
        # Channel count of the per-pixel affinity embedding.
        affinity_dim = 6
        self.affinities_conv = nn.Conv2d(fdim, affinity_dim, 3, padding=1)
        self.general_confidence_conv = nn.Conv2d(fdim, 1, 3, padding=1)
        self.general_confidence_conv_static = nn.Conv2d(fdim, 1, 3, padding=1)

        self.corr_weighter_perpoint = make_net([fdim * 2, 16, 1])

        #self.img_enc = nn.Sequential( ResnetFPN(in_ch=3 * self.time_stride, use_first_pool=True), nn.Conv2d(512, fdim * self.time_stride, 3, padding=1),)
        # use_depth_inp is multiplied by 0, so the encoder always receives 3 channels per stacked frame.
        self.img_enc = ResnetFPN(in_ch=(3+use_depth_inp*0) * self.time_stride, use_first_pool=True)#, nn.Conv2d(512, fdim * self.time_stride, 3, padding=1),)
        #self.img_enc = smp.Unet( encoder_name="mobileone_s2",encoder_weights="imagenet", in_channels=3+use_depth_inp,classes=fdim) # from s0 to s4 from 4-13M param   
        
        # Counter incremented by the forward passes whenever gradients are enabled.
        self.step=0

        # MiDaS backbone is disabled; self.midas / self.midas_out are therefore not created here.
        #self.midas = torch.hub.load("intel-isl/MiDaS", "MiDaS_small", pretrained=not scratch_model)
        #self.midas_out=self.midas.scratch.output_conv
        #self.midas.scratch.output_conv=nn.Identity()

    def forward_(self, model_input, track_idxs=None, out=None):  # point track based contrastive redo
        """Point-track contrastive step: regress each track's second-frame position from feature affinities.

        Affinity embeddings are sampled at the track locations of both frames.
        The similarity of every first-frame track to every second-frame track
        is turned into normalised weights, and the weighted mean of the
        second-frame coordinates is compared with the true coordinate.

        Args:
            model_input: dict read for ``"rgb"``, ``"pred_tracks"`` and ``"pred_visibility"``.
                NOTE(review): assumes rgb is (batch, 2, 3, H, W) in [-1, 1] and track
                coordinates lie in [0, 1]; exactly two frames are required because the
                per-track embeddings are unbound into two along dim 1 — confirm.
            track_idxs: indices of the tracks to use; when ``None`` a random subset
                of at most 10000 tracks is drawn.
            out: optional dict merged (without mutation) into the returned dict.

        Returns:
            ``out`` extended with ``"imgs"``, ``"contrastive_loss"``, ``"aff_sim"``,
            ``"affinity_emb"`` and ``"imgs_premask"``.
        """
        out = {} if out is None else out
        if torch.is_grad_enabled():
            self.step += 1

        rgb = model_input["rgb"]
        tracks = model_input["pred_tracks"]
        imsize = rgb.shape[-2:]

        # Only draw random tracks when the caller did not choose them.
        if track_idxs is None:
            track_idxs = torch.randperm(tracks.size(-2))[:10000]

        def random_color_jitter(frames):
            # ColorJitter is applied after mapping x -> x*.5+.5 and the result is mapped back with *2-1.
            color_transform = T.ColorJitter(
                brightness=np.random.uniform(0, .1),
                contrast=np.random.uniform(0, .2),
                saturation=np.random.uniform(0, .2),
                hue=np.random.uniform(0.0, 0.1),
            )
            return color_transform(frames * .5 + .5) * 2 - 1

        # No masking is applied in this variant, so both names refer to the same jittered frames.
        img_inp = img_inp_premask = random_color_jitter(rgb)

        # General feature map backbone prediction (FPN), upsampled to the input resolution
        fmap = F.interpolate(
            self.img_enc(img_inp.flatten(0, 1) * 0.5 + 0.5), imsize, mode="bilinear",
        ).unflatten(0, img_inp.shape[:2])

        # Per-pixel affinity embedding, unit-normalised over channels so dot products are cosine similarities
        affinity_emb = self.affinities_conv(fmap.flatten(0, 1)).unflatten(0, fmap.shape[:2])
        affinity_emb = F.normalize(affinity_emb, dim=2)

        aff_emb_pertrack = ch_sec(grid_samp(affinity_emb, tracks.unsqueeze(-2)))[:, :, track_idxs]
        src_feats, warp_feats = aff_emb_pertrack.unbind(1)

        # A track only contributes when it is visible in every frame.
        mask = model_input["pred_visibility"][:, :, track_idxs].min(dim=1)[0][..., None]

        aff_sim = torch.einsum("b p c, b q c -> b p q", src_feats, warp_feats)
        # Shift cosine similarity from [-1, 1] to strictly positive weights.
        aff_sim = aff_sim * .5 + .5 + .0001

        # NOTE doing weighted sum instead of softmax here
        attn_weights = aff_sim / aff_sim.sum(dim=-1, keepdim=True)

        warp_crds = tracks[:, 1, track_idxs]
        pred_crds = torch.einsum("b p q, b q d -> b p d", attn_weights, warp_crds)
        loss = ((pred_crds - warp_crds) * mask).square().mean() * 1e2

        return out | {
            "imgs": img_inp,
            "contrastive_loss": loss,
            "aff_sim": aff_sim,
            "affinity_emb": affinity_emb,
            "imgs_premask": img_inp_premask,
        }
    def forward(self, model_input, track_idxs=None, out=None):  # image based contrastive just self attention
        """Self-similarity contrastive step on a randomly warped, colour-jittered view of the input.

        The input frame and its augmented copy are concatenated along dim 1 and
        encoded in a single encoder pass. Unit-norm embeddings taken from index
        1 onwards along dim 1 are compared against themselves at a fixed subset
        of pixels, and cross-entropy (logits scaled by 30) pushes every sampled
        pixel to match only itself.

        Args:
            model_input: dict read for ``"rgb"`` and ``"x_pix"``.
                NOTE(review): assumes rgb is (batch, 1, 3, H, W) in [-1, 1] and x_pix is
                (batch, 1, H*W, 2) in [0, 1] — confirm against the data loader.
            track_idxs: unused by this variant.
            out: optional dict merged (without mutation) into the returned dict.

        Returns:
            ``out`` extended with ``"imgs"``, ``"contrastive_loss"``, ``"aff_sim"``,
            ``"aff_sim_grid"``, ``"affinity_emb"`` and ``"fmap"``.
        """
        out = {} if out is None else out
        if torch.is_grad_enabled():
            self.step += 1

        rgb = model_input["rgb"]
        x_pix = model_input["x_pix"]
        imsize = rgb.shape[-2:]
        b = rgb.shape[0]
        # Fixed, evenly spaced subset of pixel indices used for the contrastive logits
        rand_subset = torch.linspace(0, x_pix.size(-2) - 1, 4000).long()

        def random_affine_warp(frames):
            # One random scale / flip / translation per batch element, expressed as a
            # 2x3 affine matrix acting on pixel coordinates rescaled to [-1, 1].
            affine_matrices = []
            for _ in range(len(frames)):
                scale, tx, ty, flip_x, flip_y = (
                    np.clip(1 / (np.random.rand() + 1), .5, 1),
                    np.random.rand(),
                    np.random.rand(),
                    np.random.choice([-1, 1]),
                    np.random.choice([-1, 1]),
                )
                max_shift = 1 - scale  # translation is clipped to this magnitude
                affine_matrices.append(torch.tensor([
                    [scale * flip_x, 0, np.clip(tx, -max_shift, max_shift)],
                    [0, scale * flip_y, np.clip(ty, -max_shift, max_shift)],
                ]))
            affine = torch.stack(affine_matrices).to(device=x_pix.device, dtype=torch.float32)
            warp_crds = torch.einsum('bij,btxj->btxi', affine, hom(x_pix) * 2 - 1)
            warped = ch_fst(grid_samp(frames, warp_crds.unsqueeze(-2) * .5 + .5), imsize[0]).squeeze(-3)
            return warped, warp_crds

        def random_color_jitter(frames):
            # ColorJitter is applied after mapping x -> x*.5+.5 and the result is mapped back with *2-1.
            color_transform = T.ColorJitter(
                brightness=np.random.uniform(0, .1),
                contrast=np.random.uniform(0, .2),
                saturation=np.random.uniform(0, .2),
                hue=np.random.uniform(0.0, 0.1),
            )
            return color_transform(frames * .5 + .5) * 2 - 1

        # The warp coordinates are not needed here because the loss is pure self-similarity.
        img_warp, _ = random_affine_warp(rgb)
        img_warp = random_color_jitter(img_warp)

        # The encoder runs once, on the original and augmented frames together.
        img_inp = torch.cat((rgb, img_warp), 1)
        fmap = F.interpolate(
            self.img_enc(img_inp.flatten(0, 1) * 0.5 + 0.5), imsize, mode="bilinear",
        ).unflatten(0, img_inp.shape[:2])

        # Per-pixel affinity embedding, unit-normalised over channels so dot products are cosine similarities
        affinity_emb = self.affinities_conv(fmap.flatten(0, 1)).unflatten(0, fmap.shape[:2])
        affinity_emb = F.normalize(affinity_emb, dim=2)

        src_feats = ch_sec(affinity_emb[:, 1:]).squeeze(1)[:, rand_subset]

        # Self-similarity logits; the correct "class" of sampled pixel i is i itself.
        aff_sim = torch.einsum("b p c, b q c -> b p q", src_feats, src_feats) * 30
        targets = torch.arange(len(rand_subset), device=aff_sim.device).repeat(b)
        loss = F.cross_entropy(aff_sim.flatten(0, -2), targets)

        # just for vis: softmax affinity from an 8x8 grid of source pixels to a 64x64 downsample of frame 0
        with torch.no_grad():
            dsl_aff = F.interpolate(affinity_emb[:, 0], (64, 64))
            aff_sim_grid = torch.einsum(
                "b p c, b q c -> b p q", ch_sec(dsl_aff[..., ::8, ::8]), ch_sec(dsl_aff),
            ).softmax(dim=-1).unflatten(-2, (8, 8)).unflatten(-1, (64, 64))

        return out | {
            "imgs": img_inp,
            "contrastive_loss": loss,
            "aff_sim": aff_sim,
            "aff_sim_grid": aff_sim_grid,
            "affinity_emb": affinity_emb,
            "fmap": fmap,
        }
    def forward_(self, model_input, track_idxs=None, out=None):  # image based contrastive sire like
        """Cross-view contrastive step between a frame and its randomly warped, colour-jittered copy.

        Both views are encoded together. Embeddings of the original frame are
        resampled at the warp coordinates so they line up pixel-for-pixel with
        the augmented view, and cross-entropy over cosine similarities asks
        every sampled augmented-view pixel to match the resampled original
        pixel with the same index.

        Args:
            model_input: dict read for ``"rgb"`` and ``"x_pix"``.
                NOTE(review): assumes rgb is (batch, 1, 3, H, W) in [-1, 1] and x_pix is
                (batch, 1, H*W, 2) in [0, 1] — confirm against the data loader.
            track_idxs: unused by this variant.
            out: optional dict merged (without mutation) into the returned dict.

        Returns:
            ``out`` extended with ``"imgs"``, ``"imgs_premask"``, ``"contrastive_loss"``,
            ``"aff_sim"``, ``"affinity_emb"`` and ``"fmap"``.
        """
        out = {} if out is None else out
        if torch.is_grad_enabled():
            self.step += 1

        rgb = model_input["rgb"]
        x_pix = model_input["x_pix"]
        imsize = rgb.shape[-2:]
        b = rgb.shape[0]
        # Fixed, evenly spaced subset of pixel indices used for the contrastive logits
        rand_subset = torch.linspace(0, x_pix.size(-2) - 1, 4000).long()

        def random_affine_warp(frames):
            # One random scale / flip / translation per batch element, expressed as a
            # 2x3 affine matrix acting on pixel coordinates rescaled to [-1, 1].
            affine_matrices = []
            for _ in range(len(frames)):
                scale, tx, ty, flip_x, flip_y = (
                    np.clip(1 / (np.random.rand() + 1), .5, 1),
                    np.random.rand(),
                    np.random.rand(),
                    np.random.choice([-1, 1]),
                    np.random.choice([-1, 1]),
                )
                max_shift = 1 - scale  # translation is clipped to this magnitude
                affine_matrices.append(torch.tensor([
                    [scale * flip_x, 0, np.clip(tx, -max_shift, max_shift)],
                    [0, scale * flip_y, np.clip(ty, -max_shift, max_shift)],
                ]))
            affine = torch.stack(affine_matrices).to(device=x_pix.device, dtype=torch.float32)
            warp_crds = torch.einsum('bij,btxj->btxi', affine, hom(x_pix) * 2 - 1)
            warped = ch_fst(grid_samp(frames, warp_crds.unsqueeze(-2) * .5 + .5), imsize[0]).squeeze(-3)
            return warped, warp_crds

        def random_color_jitter(frames):
            # ColorJitter is applied after mapping x -> x*.5+.5 and the result is mapped back with *2-1.
            color_transform = T.ColorJitter(
                brightness=np.random.uniform(0, .1),
                contrast=np.random.uniform(0, .2),
                saturation=np.random.uniform(0, .2),
                hue=np.random.uniform(0.0, 0.1),
            )
            return color_transform(frames * .5 + .5) * 2 - 1

        # TODO predict affine matrix per batch elem not for entire batch
        img_warp, warp_crds = random_affine_warp(rgb)
        img_warp = random_color_jitter(img_warp)

        # No masking is applied in this variant, so both names refer to the same stacked views.
        img_inp = img_inp_premask = torch.cat((rgb, img_warp), 1)

        # General feature map backbone prediction (FPN), upsampled to the input resolution
        fmap = F.interpolate(
            self.img_enc(img_inp.flatten(0, 1) * 0.5 + 0.5), imsize, mode="bilinear",
        ).unflatten(0, img_inp.shape[:2])

        # Per-pixel affinity embedding, unit-normalised over channels so dot products are cosine similarities
        affinity_emb = self.affinities_conv(fmap.flatten(0, 1)).unflatten(0, fmap.shape[:2])
        affinity_emb = F.normalize(affinity_emb, dim=2)

        # Resample the original frame's embeddings at the warp coordinates.
        feat_warp = ch_fst(grid_samp(affinity_emb[:, :1], warp_crds.unsqueeze(-2) * .5 + .5), imsize[0]).squeeze(-3)
        src_feats = ch_sec(affinity_emb[:, 1:]).squeeze(1)[:, rand_subset]
        warp_feats = ch_sec(feat_warp).squeeze(1)[:, rand_subset]

        aff_sim = torch.einsum("b p c, b q c -> b p q", src_feats, warp_feats)

        # The correct "class" of sampled pixel i is pixel i of the other view.
        targets = torch.arange(len(rand_subset), device=aff_sim.device).repeat(b)
        loss = F.cross_entropy(aff_sim.flatten(0, -2), targets)

        return out | {
            "imgs": img_inp,
            "imgs_premask": img_inp_premask,
            "contrastive_loss": loss,
            "aff_sim": aff_sim,
            "affinity_emb": affinity_emb,
            "fmap": fmap,
        }

    def forward_(self, model_input, track_idxs=None, out=None):  # image based contrastive
        """Cross-view coordinate-regression step between a frame and its warped, colour-jittered copy.

        Both views are encoded together. ReLU-ed embeddings of the original
        frame are resampled at the warp coordinates and compared with the
        augmented view's embeddings. A softmax over those similarities forms a
        weighted mean of the warp coordinates, which is penalised (squared
        error, scaled by 10) against each pixel's own warp coordinate.

        Args:
            model_input: dict read for ``"rgb"`` and ``"x_pix"``.
                NOTE(review): assumes rgb is (batch, 1, 3, H, W) in [-1, 1] and x_pix is
                (batch, 1, H*W, 2) in [0, 1] — confirm against the data loader.
            track_idxs: unused by this variant.
            out: optional dict merged (without mutation) into the returned dict.

        Returns:
            ``out`` extended with ``"imgs"``, ``"imgs_premask"``, ``"contrastive_loss"``,
            ``"aff_sim"``, ``"affinity_emb"`` and ``"fmap"``.
        """
        out = {} if out is None else out
        if torch.is_grad_enabled():
            self.step += 1

        rgb = model_input["rgb"]
        x_pix = model_input["x_pix"]
        imsize = rgb.shape[-2:]
        # Fixed, evenly spaced subset of pixel indices used for the attention
        rand_subset = torch.linspace(0, x_pix.size(-2) - 1, 4000).long()

        def random_affine_warp(frames):
            # One random scale / flip / translation per batch element, expressed as a
            # 2x3 affine matrix acting on pixel coordinates rescaled to [-1, 1].
            affine_matrices = []
            for _ in range(len(frames)):
                scale, tx, ty, flip_x, flip_y = (
                    np.clip(1 / (np.random.rand() + 1), .5, 1),
                    np.random.rand(),
                    np.random.rand(),
                    np.random.choice([-1, 1]),
                    np.random.choice([-1, 1]),
                )
                max_shift = 1 - scale  # translation is clipped to this magnitude
                affine_matrices.append(torch.tensor([
                    [scale * flip_x, 0, np.clip(tx, -max_shift, max_shift)],
                    [0, scale * flip_y, np.clip(ty, -max_shift, max_shift)],
                ]))
            affine = torch.stack(affine_matrices).to(device=x_pix.device, dtype=torch.float32)
            warp_crds = torch.einsum('bij,btxj->btxi', affine, hom(x_pix) * 2 - 1)
            warped = ch_fst(grid_samp(frames, warp_crds.unsqueeze(-2) * .5 + .5), imsize[0]).squeeze(-3)
            return warped, warp_crds

        def random_color_jitter(frames):
            # ColorJitter is applied after mapping x -> x*.5+.5 and the result is mapped back with *2-1.
            color_transform = T.ColorJitter(
                brightness=np.random.uniform(0, .1),
                contrast=np.random.uniform(0, .2),
                saturation=np.random.uniform(0, .2),
                hue=np.random.uniform(0.0, 0.1),
            )
            return color_transform(frames * .5 + .5) * 2 - 1

        # TODO predict affine matrix per batch elem not for entire batch
        img_warp, warp_crds = random_affine_warp(rgb)
        img_warp = random_color_jitter(img_warp)

        # No masking is applied in this variant, so both names refer to the same stacked views.
        img_inp = img_inp_premask = torch.cat((rgb, img_warp), 1)

        # General feature map backbone prediction (FPN), upsampled to the input resolution
        fmap = F.interpolate(
            self.img_enc(img_inp.flatten(0, 1) * 0.5 + 0.5), imsize, mode="bilinear",
        ).unflatten(0, img_inp.shape[:2])

        # Non-negative (ReLU) affinity embedding; no normalisation in this variant
        affinity_emb = self.affinities_conv(fmap.flatten(0, 1)).unflatten(0, fmap.shape[:2])
        affinity_emb = affinity_emb.relu()

        # Resample the original frame's embeddings at the warp coordinates.
        feat_warp = ch_fst(grid_samp(affinity_emb[:, :1], warp_crds.unsqueeze(-2) * .5 + .5), imsize[0]).squeeze(-3)
        src_feats = ch_sec(affinity_emb[:, 1:]).squeeze(1)[:, rand_subset]
        warp_feats = ch_sec(feat_warp).squeeze(1)[:, rand_subset]

        aff_sim = torch.einsum("b p c, b q c -> b p q", src_feats, warp_feats)

        # Attention-weighted mean of the warp coordinates should land on each pixel's own coordinate.
        attn_weights = F.softmax(aff_sim, dim=-1)  # (B, P, Q)
        target_crds = warp_crds[:, 0, rand_subset]
        pred_crds = torch.einsum("b p q, b q d -> b p d", attn_weights, target_crds)
        loss = (pred_crds - target_crds).square().mean() * 1e1

        return out | {
            "imgs": img_inp,
            "imgs_premask": img_inp_premask,
            "contrastive_loss": loss,
            "aff_sim": aff_sim,
            "affinity_emb": affinity_emb,
            "fmap": fmap,
        }

    def forward_(self, model_input, track_idxs=None, out=None):  # point track based contrastive
        """Point-track contrastive step across all frames.

        Unit-norm affinity embeddings are sampled at the selected tracks in
        every frame. Embeddings from one randomly chosen frame are compared
        with the embeddings of all tracks in all frames. The softmax of the
        similarities (scaled by 5) is pulled towards the identity assignment
        with a squared error that is weighted by track visibility in both the
        source frame and the compared frame.

        Side effects:
            Stores the upsampled encoder feature map in ``model_input["fmap"]``.

        Args:
            model_input: dict read for ``"rgb"``, ``"pred_tracks"`` and ``"pred_visibility"``.
                NOTE(review): assumes rgb is (batch, frames, 3, H, W) in [-1, 1] and
                track coordinates lie in [0, 1] — confirm against the data loader.
            track_idxs: indices of the tracks to use; when ``None`` a random subset
                of at most 200 tracks is drawn.
            out: optional dict merged (without mutation) into the returned dict.

        Returns:
            ``out`` extended with ``"contrastive_loss"``, ``"aff_sim"``,
            ``"affinity_emb"`` and ``"affinity_emb_unnorm"``.
        """
        out = {} if out is None else out
        if torch.is_grad_enabled():
            self.step += 1

        rgb = model_input["rgb"]
        tracks = model_input["pred_tracks"]
        visibility = model_input["pred_visibility"]
        imsize = rgb.shape[-2:]
        b, n_trgt = rgb.shape[:2]

        # pick n random points if not provided
        if track_idxs is None:
            track_idxs = torch.randperm(tracks.size(-2))[:200]

        # General feature map backbone prediction: stack time_stride frames into channels for the encoder
        img_inp = rearrange(rgb, "b (t s) c x y -> b t (s c) x y", s=self.time_stride)
        # FPN
        fmap_out = F.interpolate(self.img_enc(img_inp.flatten(0, 1) * 0.5 + 0.5), imsize, mode="bilinear")
        model_input["fmap"] = rearrange(fmap_out, "(b t) (s c) x y -> b (t s) c x y", s=self.time_stride, b=b)

        # Per-pixel affinity embedding, raw and unit-normalised over channels
        affinity_emb_unnorm = self.affinities_conv(model_input["fmap"].flatten(0, 1)).unflatten(0, (b, n_trgt))
        affinity_emb = F.normalize(affinity_emb_unnorm, dim=2)
        aff_emb_pertrack_allframe = ch_sec(grid_samp(affinity_emb, tracks.unsqueeze(-2)))[:, :, track_idxs]

        rand_frame = np.random.randint(n_trgt)  # the frame we'll treat as source features
        aff_emb_src_feats = aff_emb_pertrack_allframe[:, rand_frame]

        # Similarity from every source-frame track to every track in every frame
        aff_sim = torch.einsum(
            "b p c, b q c -> b p q", aff_emb_src_feats, aff_emb_pertrack_allframe.flatten(1, 2),
        ).unflatten(-1, aff_emb_pertrack_allframe.shape[1:3])

        # Identity assignment target, created on the same device as the similarities
        target = torch.eye(len(track_idxs), device=aff_sim.device)[None, :, None].expand(b, -1, n_trgt, -1)
        # A pair only counts when the track is visible in the source frame and in the compared frame.
        vis_weight = visibility[:, rand_frame, track_idxs][..., None, None] * visibility[:, :, track_idxs][:, None]
        loss = ((aff_sim * 5).softmax(dim=-1) - target).square() * vis_weight

        return out | {
            "contrastive_loss": loss.mean() * 1e3,
            "aff_sim": aff_sim,
            "affinity_emb": affinity_emb,
            "affinity_emb_unnorm": affinity_emb_unnorm,
        }

    def forward_( self, model_input, track_idxs=None, out={}):  # point track based simpler
        """Estimate depth and camera poses from point tracks and score them by track reprojection.

        Depth is predicted per frame, tracks are lifted to 3-D along their
        camera rays, poses are solved with ``geometry.efficient_procrustes``,
        and the selected tracks are reprojected between all frame pairs to form
        ``point_track_loss``.

        Side effects:
            May write ``model_input["midas_feats"]`` and ``model_input["fmap"]``;
            prints a message in the static direct-solve branch.

        Args:
            model_input: dict read for ``"rgb"``, ``"pred_tracks"``, ``"pred_visibility"``,
                ``"rig_flow_masks"`` and ``"intrinsics"``.
                NOTE(review): presumably rgb is (batch, frames, 3, H, W) and track
                coordinates lie in [0, 1] — confirm against the data loader.
            track_idxs: indices of the tracks to solve for / supervise; when ``None``
                a random subset of at most 200 tracks is drawn.
            out: dict merged into the returned dict (it is not mutated here).

        Returns:
            ``out`` extended with ``"worldcrds_pertrack"``, ``"rgb_pertrack"``,
            ``"rig_pertrack"``, ``"poses_all"``, ``"point_track_loss"``,
            ``"point_track_reproj"``, ``"corr_weights"``, ``"aff_sim"``,
            ``"aff_sim_grid"``, ``"affinity_emb"``, ``"affinity_emb_unnorm"``,
            ``"aff_emb_pertrack"`` and ``"depth"``.
        """
        if torch.is_grad_enabled():self.step+=1

        imsize = model_input["rgb"].shape[-2:]
        (b, _), n_trgt = model_input["rgb"].shape[:2], model_input["rgb"].size(1)

        # pick n random points to solve for/supervise if not provided
        if track_idxs is None: track_idxs = torch.randperm(model_input["pred_tracks"].size(-2))[:200]

        # sample rigid flow masks per track and aggregate over time (if ever dynamic then dynamic) to get an estimate if a track is rigid
        rig_samp = ( grid_samp( model_input["rig_flow_masks"][:, :, [0]], model_input["pred_tracks"][:, 1:].unsqueeze(-2),) .squeeze(2) .squeeze(-1) .round())
        # Frames where the track is not visible are treated as rigid (1) so they do not affect the min over time
        rig_samp = rig_samp_allframe = torch.where( model_input["pred_visibility"][:, 1:], rig_samp, torch.ones_like(rig_samp))
        rig_samp = rig_samp.min(dim=1)[0]

        # General feature map backbone prediction
        img_inp = model_input["rgb"]# if not use_depth_inp else torch.cat((model_input["rgb"],depth_inp.log()-1),2)
        img_inp = rearrange( img_inp, "b (t s) c x y -> b t (s c) x y", s=self.time_stride)
        # NOTE(review): self.midas / self.midas_out are only assigned in commented-out lines of the
        # __init__ visible in this file, so this always-taken branch presumably raises AttributeError
        # unless they are attached elsewhere — confirm.
        if 1:
            # Features are recomputed whenever gradients are enabled; otherwise a cached "fmap" is reused
            if "fmap" not in model_input or torch.is_grad_enabled():
                model_input["midas_feats"]=midas_feats = self.midas(F.interpolate((model_input["rgb"]*.5+.5).flatten(0,1),(imsize[0]//32*32,imsize[1]//32*32),mode="bilinear"))
                model_input["fmap"]=fmap_out=F.interpolate(midas_feats,imsize,mode="bilinear").unflatten(0,(b,n_trgt))/(100 if not scratch_model else 1)
            # Depth is 1e3 divided by the (offset) output of the midas head
            depth = res_depth = 1e3/(F.interpolate(self.midas_out(model_input["midas_feats"]),imsize,mode="bilinear").unflatten(0,(b,n_trgt))+1e-1)
        else :
            if "fmap" not in model_input or torch.is_grad_enabled():
                # FPN
                # NOTE(review): the encoder is called twice here and the first result is overwritten below
                model_input["fmap"] = fmap_out = self.img_enc(img_inp.flatten(0, 1) * 0.5 + 0.5)
                fmap_out = F.interpolate( self.img_enc(img_inp.flatten(0, 1) * 0.5 + 0.5), imsize, mode="bilinear",)
                model_input["fmap"] = rearrange( fmap_out, "(b t) (s c) x y -> b (t s) c x y", s=self.time_stride, b=b)
            depth = res_depth = F.softplus( self.depth_conv(model_input["fmap"].flatten(0, 1)).unflatten(0, (b, n_trgt)) + 1)+1 
        # Use gt depth as input if desired (disabled by the trailing "and 0")
        if "depth_inp" in model_input and 0:
            depth = ch_fst(model_input["depth_inp"],imsize[0])*1e5
            depth_mask = (depth!=0).float()
            model_input["pred_visibility"] *= grid_samp( depth_mask, model_input["pred_tracks"].unsqueeze(-2),mode="nearest",pad="zeros").squeeze(2).squeeze(-1).bool()

        # Lift point track into 2.5D image-aligned surface
        rds_track = geometry.get_world_rays( model_input["pred_tracks"], model_input["intrinsics"], None)[1]
        eye_surf_track = rds_track * grid_samp( depth, model_input["pred_tracks"].unsqueeze(-2)).squeeze(2)

        # Est affinity weights and similarities for each source point -- these are correspondence weights from each point to each other point
        affinity_emb_unnorm = self.affinities_conv( model_input["fmap"].flatten(0, 1)).unflatten(0, (b, n_trgt))

        # Take affinity emb as mean over frames masked by visibility
        affinity_emb = F.normalize(affinity_emb_unnorm, dim=2)
        aff_emb_pertrack_allframe = ch_sec( grid_samp(affinity_emb, model_input["pred_tracks"].unsqueeze(-2)))
        aff_emb_pertrack = ( aff_emb_pertrack_allframe * model_input["pred_visibility"].unsqueeze(-1)).sum(dim=1) / model_input["pred_visibility"].unsqueeze(-1).sum(dim=1).clip( min=1)
        aff_sim = torch.einsum( "b p c, b q c -> b p q", aff_emb_pertrack[:, track_idxs], aff_emb_pertrack)  # from all source pix to all other source pix

        #if 1: aff_sim=torch.ones_like(aff_sim);print("completely rigid est") 
        aff_sim_rig = torch.where( rig_samp.bool()[:, track_idxs, None].expand(-1, -1, aff_sim.size(-1)), torch.ones_like(aff_sim), aff_sim,) # replace points in rigid mask with 1s

        # Predict general correspondence weights -- how reliable is this track generally at each frame
        general_conf = ( self.general_confidence_conv(model_input["fmap"].flatten(0, 1)) .unflatten(0, (b, n_trgt)) .sigmoid() .clip(min=1e-4))
        # While self.step is below 500 the predicted confidence is replaced by ones
        if self.step<500 or 0: general_conf = torch.ones_like(general_conf)
        general_conf_track = grid_samp( general_conf, model_input["pred_tracks"].unsqueeze(-2)).squeeze(2) * model_input["pred_visibility"].unsqueeze(-1)

        if static_solve: # same code as below case where rigidities=1 but more efficient since only one solve
            if 0:
                poses = geometry.efficient_procrustes( eye_surf_track[:, None, 1:, ], eye_surf_track[:, None, :-1], (general_conf_track[:, None, :-1]*rig_samp[:,None,None,:,None]).clip(min=1e-4),)[1]
                for i in range(n_trgt - 1, 0, -1): poses = torch.cat( (poses[:, :, :i], poses[:, :, [i - 1]] @ poses[:, :, i:]), -3)  # aggregate adjacent poses
            else:
                # Solve every later frame directly against frame 0 instead of chaining adjacent-frame poses
                poses = geometry.efficient_procrustes( eye_surf_track[:, None, 1:, ], eye_surf_track[:, None, :1].expand(-1,-1,n_trgt-1,-1,-1), (general_conf_track[:, None, :-1]*rig_samp[:,None,None,:,None]).clip(min=1e-4),)[1]
                print("doing direct pose regression static")
            poses = poses.expand( -1, len(track_idxs), -1, -1, -1)  # if static solve, just use single pose as pose for all points
        else:
            solve_stride = ( model_input["pred_tracks"].size(-2) // 3000)  # use every nth point in the solve
            poses = geometry.efficient_procrustes( eye_surf_track[:, None, 1:,  ::solve_stride].expand(-1, aff_sim.size(1), -1, -1, -1), 
                                                   eye_surf_track[:, None, :-1, ::solve_stride].expand(-1, aff_sim.size(1), -1, -1, -1),
                                             ( general_conf_track[:, None, :-1, ::solve_stride].expand(-1, aff_sim.size(1), -1, -1, -1) * 
                                                          aff_sim_rig[:, :, None,   ::solve_stride, None]).clip(min=1e-4),)[1]
            for i in range(n_trgt - 1, 0, -1): poses = torch.cat( (poses[:, :, :i], poses[:, :, [i - 1]] @ poses[:, :, i:]), -3)  # aggregate adjacent poses
        poses = torch.cat( ( torch.eye(4).to(poses)[None, None, None].expand(poses.size(0), poses.size(1), -1, -1, -1), poses,), -3,)  # add identity for starting pose

        # Compute point track reprojection
        poses_all_to_all = repeat( poses.inverse(), "b p t x y -> b p s t x y", s=n_trgt) @ repeat(poses, "b p t x y -> b p t s x y", s=n_trgt)
        point_track_surf_reproj = torch.einsum( "bpstij,bstpj->bstpi", poses_all_to_all, hom( repeat( eye_surf_track[:, :, track_idxs], "b t p c -> b t s p c", s=n_trgt)),)[..., :3]
        point_track_reproj = project( point_track_surf_reproj, model_input["intrinsics"]).clip(0, 1)
        point_track_loss = ( ( ( point_track_reproj - model_input["pred_tracks"][:, None, :, track_idxs]) * model_input["pred_visibility"][:, None, :, track_idxs, None]).square().flatten().mean())

        # for vis of affsim on regular grid
        with torch.no_grad():
            # NOTE(review): the rearrange below requires the number of tracks to be a multiple of track_sl*track_sl (42*42 here) — confirm
            track_sl,dsl=(64,4) if 0 else (42,3)
            src_tracks_0=rearrange(aff_emb_pertrack,"b (x y s) c -> b s c x y",y=track_sl,x=track_sl)[:,0]
            aff_sim_grid = torch.einsum( "b p c, b q c -> b p q", ch_sec(src_tracks_0[...,::dsl,::dsl]), ch_sec(src_tracks_0)).unflatten(1,(track_sl//dsl,track_sl//dsl)).unflatten(-1,(track_sl,track_sl))  # from all source pix to all other source pix
            # mean color and crds per track for vis
            rgb_pertrack = ch_sec( grid_samp(model_input["rgb"], model_input["pred_tracks"].unsqueeze(-2)))
            rgb_pertrack = ( rgb_pertrack * model_input["pred_visibility"].unsqueeze(-1)).sum(dim=1) / model_input["pred_visibility"].unsqueeze(-1).sum(dim=1).clip( min=1)
            worldcrds_pertrack = torch.einsum( "bptij,btpj->btpi", poses, hom(eye_surf_track[:, :, track_idxs]))[..., :3]
            worldcrds_pertrack = ( worldcrds_pertrack * model_input["pred_visibility"][:, :, track_idxs].unsqueeze(-1)
                                    ).sum(dim=1) / model_input["pred_visibility"][:, :, track_idxs].unsqueeze( -1).sum( dim=1).clip( min=1)

        return out | {
            "worldcrds_pertrack": worldcrds_pertrack,
            "rgb_pertrack": rgb_pertrack,
            "rig_pertrack": rig_samp,
            "poses_all": poses,
            #"depth_inp": model_input["depth_inp"],
            "point_track_loss": point_track_loss,
            "point_track_reproj": point_track_reproj[:, 0],
            "corr_weights": general_conf,
            "aff_sim": aff_sim,
            "aff_sim_grid": aff_sim_grid,
            "affinity_emb": affinity_emb,
            "affinity_emb_unnorm": affinity_emb_unnorm,
            "aff_emb_pertrack": aff_emb_pertrack,
            "depth": ch_sec(depth),
        }

    def forward_( self, model_input, track_idxs=None, out=None):  # point track based, solving for pose per point, given n points to track for
        """Point-track based pose estimation: one static solve, then one weighted solve per query track.

        NOTE(review): another `forward_` is defined further down at the same indentation; if both
        belong to the same class, that later definition replaces this one.

        Args:
            model_input: dict read for "rgb", "pred_tracks", "pred_visibility", "rig_flow_masks" and
                "intrinsics"; "fmap" and "depth_inp" are optional. "fmap" is (re)written whenever it
                is missing or grad is enabled, and "pred_visibility" is multiplied in place when
                "depth_inp" is present.
            track_idxs: indices of the tracks to solve per-point poses for; 200 random tracks if None.
            out: optional dict, updated in place with the static-solve outputs.

        Returns:
            A new dict containing everything in `out` plus the per-track outputs assembled at the end.
        """
        # `out={}` as a default was shared between calls and mutated by `out |= ...` below,
        # keeping stale tensors alive; build a fresh dict per call instead.
        if out is None: out = {}
        if torch.is_grad_enabled(): self.step += 1

        rgb = model_input["rgb"]
        tracks = model_input["pred_tracks"]
        track_grid = tracks.unsqueeze(-2)  # sampling locations handed to grid_samp
        n_tracks = tracks.size(-2)

        imsize = rgb.shape[-2:]
        b, n_trgt = rgb.shape[:2]
        n_samp = imsize[0] * imsize[1]  # min(40000,imsize[0]*imsize[1])
        rand_subset = torch.linspace(0, n_tracks - 1, n_samp).long()

        # sample rigid flow masks per track and aggregate over time (if ever dynamic then dynamic) to get an estimate if a track is rigid
        rig_samp = grid_samp(model_input["rig_flow_masks"][:, :, [0]], track_grid[:, 1:]).squeeze(2).squeeze(-1).round()
        rig_samp = rig_samp_allframe = torch.where(model_input["pred_visibility"][:, 1:], rig_samp, torch.ones_like(rig_samp))
        rig_samp = rig_samp.min(dim=1)[0]

        # pick n random points if not provided
        if track_idxs is None: track_idxs = torch.randperm(n_tracks)[:200]

        # General feature map backbone prediction
        img_inp = rearrange(rgb, "b (t s) c x y -> b t (s c) x y", s=self.time_stride)
        if "fmap" not in model_input or torch.is_grad_enabled():
            # FPN. The encoder used to be evaluated twice on the same input with the first result
            # thrown away; run it once and upsample that single result.
            enc_out = self.img_enc(img_inp.flatten(0, 1) * 0.5 + 0.5)
            fmap_out = F.interpolate(enc_out, imsize, mode="bilinear")
            model_input["fmap"] = rearrange(fmap_out, "(b t) (s c) x y -> b (t s) c x y", s=self.time_stride, b=b)
        fmap_flat = model_input["fmap"].flatten(0, 1)

        depth = res_depth = F.softplus(self.depth_conv(fmap_flat).unflatten(0, (b, n_trgt)) + 1) + 1
        if "depth_inp" in model_input:
            use_residual_depth = False  # experiment toggle (was `if 0:`)
            if use_residual_depth:
                # res depth
                depth_inp = ch_fst(model_input["depth_inp"], imsize[0])
                depth = depth_inp + (res_depth - 1)
            else:
                # use depth as gt inp; tracks landing on zero-depth pixels are marked invisible
                depth = ch_fst(model_input["depth_inp"], imsize[0]) * 1e5
                depth_mask = (depth != 0).float()
                model_input["pred_visibility"] *= grid_samp(depth_mask, track_grid, mode="nearest", pad="zeros").squeeze(2).squeeze(-1).bool()
        vis = model_input["pred_visibility"]  # aliased only after the in-place update above
        vis_w = vis.unsqueeze(-1)

        # Lift point track into 2.5D image-aligned surface
        rds_track = geometry.get_world_rays(tracks, model_input["intrinsics"], None)[1]
        eye_surf_track = rds_track * grid_samp(depth, track_grid).squeeze(2)

        def accumulate(rel):  # aggregate adjacent poses
            for i in range(n_trgt - 1, 0, -1):
                rel = torch.cat((rel[:, :, :i], rel[:, :, [i - 1]] @ rel[:, :, i:]), -3)
            return rel

        def with_identity_start(p):  # add identity for starting pose
            eye = torch.eye(4).to(p)[None, None, None].expand(p.size(0), p.size(1), -1, -1, -1)
            return torch.cat((eye, p), -3)

        def reproject(p):  # reproject the selected tracks between every pair of frames, in [0,1] image crds
            all_to_all = repeat(p.inverse(), "b p t x y -> b p s t x y", s=n_trgt) @ repeat(p, "b p t x y -> b p t s x y", s=n_trgt)
            surf = torch.einsum("bpstij,bstpj->bstpi", all_to_all, hom(repeat(eye_surf_track[:, :, track_idxs], "b t p c -> b t s p c", s=n_trgt)))[..., :3]
            return project(surf, model_input["intrinsics"]).clip(0, 1)

        # Do static solve first
        general_conf_static = self.general_confidence_conv_static(fmap_flat).unflatten(0, (b, n_trgt)).sigmoid().clip(min=1e-4)
        general_conf_track_static = grid_samp(general_conf_static, track_grid).squeeze(2) * vis_w

        chain_static = True  # experiment toggle (was `if 1:`); False solves every frame against frame 0
        if chain_static:
            static_poses = geometry.efficient_procrustes(
                eye_surf_track[:, None, 1:, rand_subset],
                eye_surf_track[:, None, :-1, rand_subset],
                general_conf_track_static[:, None, :-1, rand_subset].clip(min=1e-4),
            )[1]
            static_poses = with_identity_start(accumulate(static_poses))
        else:
            static_poses = geometry.efficient_procrustes(
                eye_surf_track[:, None, :, rand_subset],
                eye_surf_track[:, None, :1, rand_subset].expand(-1, -1, n_trgt, -1, -1),
                general_conf_track_static[:, None, :, rand_subset].clip(min=1e-4),
            )[1]
            print("doing direct pose regression static")
        static_poses = static_poses.expand(-1, len(track_idxs), -1, -1, -1)  # if static solve, just use single pose as pose for all points

        # Compute static point track reprojection; the loss is masked by visibility and the rigid mask
        static_point_track_reproj = reproject(static_poses)
        vis_and_rig_mask = vis * torch.cat((torch.ones_like(rig_samp_allframe[:, :1]), rig_samp_allframe), 1)
        static_residual = (static_point_track_reproj - tracks[:, None, :, track_idxs]) * vis_and_rig_mask[:, None, :, track_idxs, None]
        static_point_track_loss = static_residual.square().flatten().mean()

        out |= {
            "corr_weights_static": general_conf_static,
            "point_track_loss_static": static_point_track_loss,
            "point_track_reproj_static": static_point_track_reproj[:, 0],
            "poses": static_poses[:, 0],
            "res_depth": ch_sec(res_depth),
            "depth": ch_sec(depth),
        }

        # Est affinity weights and similarities for each source point -- these are correspondence weights from each point to each other point
        affinity_emb_unnorm = self.affinities_conv(fmap_flat).unflatten(0, (b, n_trgt))

        # Take affinity emb as mean over frames masked by visibility
        affinity_emb = F.normalize(affinity_emb_unnorm, dim=2)
        aff_emb_pertrack_allframe = ch_sec(grid_samp(affinity_emb, track_grid))
        aff_emb_pertrack = (aff_emb_pertrack_allframe * vis_w).sum(dim=1) / vis_w.sum(dim=1).clip(min=1)
        aff_sim = torch.einsum("b p c, b q c -> b p q", aff_emb_pertrack[:, track_idxs], aff_emb_pertrack)  # from all source pix to all other source pix

        # Predict general correspondence weights -- how reliable is this track generally at each frame
        general_conf = self.general_confidence_conv(fmap_flat).unflatten(0, (b, n_trgt)).sigmoid().clip(min=1e-4)
        general_conf_track = grid_samp(general_conf, track_grid).squeeze(2) * vis_w

        # use every nth point in the solve; clamped to 1 because fewer than 3000 tracks used to give
        # a stride of 0, which is not a valid slice step
        solve_stride = max(1, n_tracks // 3000)
        n_query = aff_sim.size(1)
        aff_sim_rig = torch.where(rig_samp.bool()[:, track_idxs, None].expand(-1, -1, aff_sim.size(-1)), torch.ones_like(aff_sim), aff_sim)  # replace points in rigid mask with 1s
        solve_weights = (general_conf_track[:, None, :-1, ::solve_stride].expand(-1, n_query, -1, -1, -1) * aff_sim_rig[:, :, None, ::solve_stride, None]).clip(min=1e-4)

        chain_dynamic = True  # experiment toggle (was `if 1:`); False solves every frame against frame 0
        if chain_dynamic:
            poses = geometry.efficient_procrustes(
                eye_surf_track[:, None, 1:, ::solve_stride].expand(-1, n_query, -1, -1, -1),
                eye_surf_track[:, None, :-1, ::solve_stride].expand(-1, n_query, -1, -1, -1),
                solve_weights,
            )[1]
            poses = accumulate(poses)
        else:
            poses = geometry.efficient_procrustes(
                eye_surf_track[:, None, 1:, ::solve_stride].expand(-1, n_query, -1, -1, -1),
                eye_surf_track[:, None, :1, ::solve_stride].expand(-1, n_query, n_trgt - 1, -1, -1),
                solve_weights,
            )[1]
            print("doing direct pose regression dyn")
        poses = with_identity_start(poses)

        # Compute point track reprojection
        point_track_reproj = reproject(poses)
        residual = (point_track_reproj - tracks[:, None, :, track_idxs]) * vis[:, None, :, track_idxs, None]
        point_track_loss = residual.square().flatten().mean()

        # mean color and crds per track for vis
        rgb_pertrack = ch_sec(grid_samp(rgb, track_grid))
        rgb_pertrack = (rgb_pertrack * vis_w).sum(dim=1) / vis_w.sum(dim=1).clip(min=1)
        sel_vis_w = vis[:, :, track_idxs].unsqueeze(-1)
        worldcrds_pertrack = torch.einsum("bptij,btpj->btpi", poses, hom(eye_surf_track[:, :, track_idxs]))[..., :3]
        worldcrds_pertrack = (worldcrds_pertrack * sel_vis_w).sum(dim=1) / sel_vis_w.sum(dim=1).clip(min=1)

        # for vis of affsim on regular grid
        with torch.no_grad():
            track_sl = 64  # the track axis is reshaped to a track_sl x track_sl grid below
            src_tracks_0 = rearrange(aff_emb_pertrack, "b (x y s) c -> b s c x y", y=track_sl, x=track_sl)[:, 0]
            aff_sim_grid = torch.einsum("b p c, b q c -> b p q", ch_sec(src_tracks_0[..., ::4, ::4]), ch_sec(src_tracks_0))
            aff_sim_grid = aff_sim_grid.unflatten(1, (track_sl // 4, track_sl // 4)).unflatten(-1, (track_sl, track_sl))  # from all source pix to all other source pix

        return out | {
            "worldcrds_pertrack": worldcrds_pertrack,
            "rgb_pertrack": rgb_pertrack,
            "rig_pertrack": rig_samp,
            "poses_all": poses,
            "point_track_loss": point_track_loss,
            "point_track_reproj": point_track_reproj[:, 0],
            "corr_weights": general_conf,
            "aff_sim": aff_sim,
            "aff_sim_grid": aff_sim_grid,
            "affinity_emb": affinity_emb,
            "affinity_emb_unnorm": affinity_emb_unnorm,
            "aff_emb_pertrack": aff_emb_pertrack,
            "depth": ch_sec(depth),
        }

    def forward_( self, model_input, track_idxs=None, out=None):  # flowcam like 2 frame est
        """Flow-based static pose estimate between adjacent frames (flowcam-like).

        Args:
            model_input: dict read for "rgb", "x_pix", "bwd_flow" and "intrinsics"; "fmap" is
                optional and is (re)written whenever it is missing or grad is enabled.
            track_idxs: accepted for signature compatibility; not used by this variant.
            out: optional dict, updated in place and returned.

        Returns:
            `out`, extended with "corr_weights", "flow_from_pose_static_loss", "flow_from_pose",
            "poses" and "depth".
        """
        # `out={}` as a default was shared between calls and mutated by `out |= ...` below.
        if out is None: out = {}
        if torch.is_grad_enabled(): self.step += 1

        rgb = model_input["rgb"]
        imsize = rgb.shape[-2:]
        b, n_trgt = rgb.shape[:2]
        n_samp = 1000  # number of random pixels used in the procrustes solve
        rand_subset = torch.randperm(imsize[0] * imsize[1])[:n_samp]

        # General feature map backbone prediction (FPN)
        img_inp = rearrange(rgb, "b (t s) c x y -> b t (s c) x y", s=self.time_stride)
        if "fmap" not in model_input or torch.is_grad_enabled():
            fmap_out = self.img_enc(img_inp.flatten(0, 1) * 0.5 + 0.5)
            model_input["fmap"] = rearrange(fmap_out, "(b t) (s c) x y -> b (t s) c x y", s=self.time_stride, b=b)
        # Depth is decoded outside the cache check: it used to be assigned only inside it, so a
        # cached "fmap" with grad disabled left `depth` unbound further down.
        # NOTE: the never-executed MiDaS backbone alternative (`if 1: ... else: ...`) was dropped.
        depth = F.softplus(self.depth_conv(model_input["fmap"].flatten(0, 1)).unflatten(0, (b, n_trgt)) + 1) + 1

        # Lift depth map into point cloud
        corresp_uv = model_input["x_pix"][:, :-1] + ch_sec(model_input["bwd_flow"])
        rds = geometry.get_world_rays(model_input["x_pix"], model_input["intrinsics"], None)[1]
        eye_surf = rds * ch_sec(depth)

        # corresp_uv is moved from [0,1] to [-1,1] for F.grid_sample; shared by both lookups below
        sample_grid = corresp_uv.flatten(0, 1).unsqueeze(1) * 2 - 1

        def sample_prev(maps):  # sample per-frame maps (last frame dropped) at the flow correspondences
            sampled = F.grid_sample(maps[:, :-1].flatten(0, 1), sample_grid)
            return sampled.squeeze(-2).permute(0, 2, 1).unflatten(0, (b, n_trgt - 1))

        corresp_surf = sample_prev(ch_fst(eye_surf, imsize[0]))

        # Do static solve first
        corresp_feat = sample_prev(model_input["fmap"])
        corr_weights = self.corr_weighter_perpoint(torch.cat((corresp_feat, ch_sec(model_input["fmap"])[:, 1:]), -1)).sigmoid().clip(min=1e-4)
        if self.step < 500: corr_weights = torch.ones_like(corr_weights)  # uniform weights for the first 500 steps

        # Estimate poses via procrustes
        adj_transf = geometry.procrustes(eye_surf[:, 1:, rand_subset], corresp_surf[:, :, rand_subset], corr_weights[:, :, rand_subset])[1]

        # Integrate poses, then prepend identity for the starting frame
        poses = adj_transf
        for i in range(n_trgt - 1, 0, -1):
            poses = torch.cat((poses[:, :i], poses[:, [i - 1]] @ poses[:, i:]), 1)
        poses = torch.cat((torch.eye(4).to(poses)[None, None].expand(poses.size(0), -1, -1, -1), poses), 1)

        # Compute pose-induced optical flow
        adj_opt_flow = warp(eye_surf[:, 1:], poses[:, :-1].inverse() @ poses[:, 1:], model_input["intrinsics"][:, 1:]) - model_input["x_pix"][:, 1:]
        flow_from_pose_static_loss = (adj_opt_flow.clip(-.2, .2) - ch_sec(model_input["bwd_flow"]).clip(-.2, .2)).square().mean()

        out |= {
            "corr_weights": ch_fst(corr_weights, imsize[0]),
            "flow_from_pose_static_loss": flow_from_pose_static_loss,
            "flow_from_pose": adj_opt_flow,
            "poses": poses,
            "depth": ch_sec(depth),
        }
        # The per-track affinity solve that used to follow this return was unreachable and read
        # names never defined in this method (rig_samp, eye_surf_track), so it has been removed.
        print("just doing static");return out



class ResnetFPN(nn.Module): # from pixelnerf code
    """torchvision ResNet backbone with two decoding heads.

    `forward` merges the backbone stages top-down with 1x1 convs and returns a map at the input
    resolution with `self.out_dim` channels. `forward_` is the concatenation variant: every stage
    is resized to the first stage's resolution, concatenated and passed through `self.out`.

    All sub-modules are created on the default device; move the module with `.to(device)` /
    `.cuda()` like any other `nn.Module`.
    """

    def __init__(
        self,
        backbone="resnet34",
        pretrained=True,
        num_layers=4,
        index_interp="bilinear",
        index_padding="border",
        upsample_interp="bilinear",
        feature_scale=1.0,
        use_first_pool=True,
        norm_type="batch",
        in_ch=3,
    ):
        """Build the backbone and the decoder convs.

        Parameters:
            backbone (str)        -- name of a constructor in `torchvision.models`
            pretrained (bool)     -- forwarded to that constructor
            num_layers (int)      -- how many backbone stages are evaluated (1..5)
            index_interp / index_padding -- stored on the instance only
            upsample_interp (str) -- interpolation mode used by `forward_`
            feature_scale (float) -- input is rescaled by this factor when != 1.0
            use_first_pool (bool) -- apply the backbone max-pool before layer1
            norm_type (str)       -- batch | instance | group | none
            in_ch (int)           -- input channels; conv1 is replaced when != 3
        """
        super().__init__()

        def get_norm_layer(norm_type="instance", group_norm_groups=32):
            """Return a normalization layer
            Parameters:
                norm_type (str) -- the name of the normalization layer: batch | instance | group | none
            For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
            For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
            """
            if norm_type == "batch":
                return functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
            if norm_type == "instance":
                return functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
            if norm_type == "group":
                return functools.partial(nn.GroupNorm, group_norm_groups)
            if norm_type == "none":
                return None
            raise NotImplementedError("normalization layer [%s] is not found" % norm_type)

        self.feature_scale = feature_scale
        self.use_first_pool = use_first_pool
        norm_layer = get_norm_layer(norm_type)

        print("Using torchvision", backbone, "encoder")
        self.model = getattr(torchvision.models, backbone)(pretrained=pretrained, norm_layer=norm_layer)

        if in_ch != 3:
            # Swap the stem conv for one with the requested input channels, same geometry otherwise.
            stem = self.model.conv1
            self.model.conv1 = nn.Conv2d(
                in_ch,
                stem.weight.shape[0],
                stem.kernel_size,
                stem.stride,
                stem.padding,
                padding_mode=stem.padding_mode,
            )

        # Following 2 lines need to be uncommented for older configs
        self.model.fc = nn.Sequential()
        self.model.avgpool = nn.Sequential()
        self.latent_size = [0, 64, 128, 256, 512, 1024][num_layers]

        self.num_layers = num_layers
        self.index_interp = index_interp
        self.index_padding = index_padding
        self.upsample_interp = upsample_interp
        self.register_buffer("latent", torch.empty(1, 1, 1, 1), persistent=False)
        self.register_buffer("latent_scaling", torch.empty(2, dtype=torch.float32), persistent=False)

        # Head used by `forward_` (concatenated stages -> 512 channels).
        self.out = nn.Sequential(
            nn.Conv2d(self.latent_size, 512, 1),
        )

        # Heads used by `forward` (top-down merge).
        # These used to be created with a hard-coded `.cuda()`, which fails on hosts without CUDA
        # and left them on a different device than `self.model`/`self.out`; device placement is now
        # left to the caller's `.to(...)`/`.cuda()` on the module.
        # NOTE(review): the widths 256/128/64/64 presumably match the stage widths of
        # resnet18/34 with num_layers=4 -- confirm before using another backbone.
        self.out_dim=64#32 # todo make arg
        self.combs_1 = nn.ModuleList([ nn.Conv2d(d1, d2, 1) for d1,d2 in [(256,128),(128,64),(64,64),(64,64),(64,self.out_dim)]])
        self.combs_2 = nn.ModuleList([ nn.Conv2d(d, d, 1) for d in [128,64,64,64,self.out_dim]])
        self.last_conv_up=nn.Conv2d(in_ch, self.out_dim, 1)

    def _rescale_input(self, x):
        """Resize `x` by `feature_scale` (bilinear when enlarging, area when shrinking)."""
        if self.feature_scale == 1.0:
            return x
        enlarging = self.feature_scale > 1.0
        return F.interpolate(
            x,
            scale_factor=self.feature_scale,
            mode="bilinear" if enlarging else "area",
            align_corners=True if enlarging else None,
            recompute_scale_factor=True,
        )

    def _backbone_latents(self, x):
        """Run the stem and the first `num_layers - 1` residual stages; return every stage output."""
        x = self.model.relu(self.model.bn1(self.model.conv1(x)))
        latents = [x]
        if self.num_layers > 1:
            if self.use_first_pool:
                x = self.model.maxpool(x)
            x = self.model.layer1(x)
            latents.append(x)
        # Stage attributes are looked up lazily so stages beyond num_layers are never touched.
        for stage_idx, stage_name in enumerate(("layer2", "layer3", "layer4"), start=2):
            if self.num_layers > stage_idx:
                x = getattr(self.model, stage_name)(x)
                latents.append(x)
        return latents

    def forward(self, x, custom_size=None):
        """Top-down merge of the backbone stages; returns `out_dim` channels at input resolution.

        Inputs with more than 4 dims have their first two dims flattened for the pass and restored
        afterwards. `custom_size` is accepted but not used by this head. The merge indexes
        `latents[-5]`, so it needs exactly the input plus four backbone outputs (num_layers == 4).
        """
        if len(x.shape) > 4: return self(x.flatten(0, 1), custom_size).unflatten(0, x.shape[:2])

        x = self._rescale_input(x)
        latents = [x] + self._backbone_latents(x)

        def merge(idx, coarse, skip):
            # upsample coarse map to the skip's size, project, add the skip, relu, then mix
            upsampled = F.interpolate(coarse, skip.shape[-2:], mode="bilinear")
            return self.combs_2[idx]((self.combs_1[idx](upsampled) + skip).relu())

        up_latent = merge(0, latents[-1], latents[-2])
        up_latent = merge(1, up_latent, latents[-3])
        # NOTE(review): index 2 of combs_1/combs_2 is never used here; the indices are kept as
        # they were so existing parameter names / checkpoints keep their meaning.
        up_latent = merge(3, up_latent, latents[-4])

        # Finest level: the raw (rescaled) input goes through last_conv_up to reach out_dim channels.
        full_res = latents[-5]
        upsampled = F.interpolate(up_latent, full_res.shape[-2:], mode="bilinear")
        return self.combs_2[4]((self.combs_1[4](upsampled) + self.last_conv_up(full_res)).relu())

    def forward_(self, x, custom_size=None):
        """Concatenation head: resize all stage outputs to one size, concatenate, apply `self.out`.

        Stages are resized to `custom_size`, or to the first stage's resolution when it is None.
        Side effects: stores the concatenated map in `self.latent` and refreshes
        `self.latent_scaling` from its width/height.
        """
        if len(x.shape) > 4:
            # Recurse into this head explicitly; `self(...)` would dispatch to `forward`, i.e. a
            # different network than the one the 4-D path below evaluates.
            return self.forward_(x.flatten(0, 1), custom_size).unflatten(0, x.shape[:2])

        x = self._rescale_input(x)
        latents = self._backbone_latents(x)

        # F.interpolate only accepts align_corners for the (bi/tri)linear and bicubic modes. The
        # previous check compared `index_interp` against "nearest " (trailing space), so it always
        # passed True and nearest/area upsampling raised.
        takes_align_corners = self.upsample_interp in ("linear", "bilinear", "bicubic", "trilinear")
        align_corners = True if takes_align_corners else None
        target_size = latents[0].shape[-2:] if custom_size is None else custom_size
        latents = [
            F.interpolate(latent, target_size, mode=self.upsample_interp, align_corners=align_corners)
            for latent in latents
        ]

        self.latent = torch.cat(latents, dim=1)
        self.latent_scaling[0] = self.latent.shape[-1]
        self.latent_scaling[1] = self.latent.shape[-2]
        self.latent_scaling = self.latent_scaling / (self.latent_scaling - 1) * 2.0
        return self.out(self.latent)
