import torch,torchvision
from torch import nn
import kornia
import functools
from einops import rearrange, repeat
#import torchvision.ops as ops
import matplotlib.pyplot as plt 
from torch.nn import functional as F
import numpy as np
import sys, random, time, os
from copy import deepcopy
from matplotlib import cm
import wandb
from tqdm import tqdm
from typing import Callable, List, Optional, Tuple, Generator, Dict
from collections import defaultdict
#import torchvision.transforms as T
#from torchcubicspline import(natural_cubic_spline_coeffs, NaturalCubicSpline)

import geometry,mlp_helpers

# Tensor layout / projective-geometry helpers
def ch_sec(x):
    """Move channels last and merge the two spatial axes: ``... c x y -> ... (x y) c``."""
    return rearrange(x, "... c x y -> ... (x y) c")


def ch_fst(src, x=None):
    """Inverse of ``ch_sec``: ``... (x y) c -> ... c x y``.

    ``x`` is the length of the first spatial axis. When omitted, it is taken as
    the truncated square root of the token axis (i.e. a square grid is assumed).
    """
    if x is None:
        x = int(src.size(-2) ** (0.5))
    return rearrange(src, "... (x y) c -> ... c x y", x=x)


def hom(x):
    """Append a trailing coordinate of ones along the last axis."""
    return torch.cat((x, torch.ones_like(x[..., [0]])), -1)


def unhom(x):
    """Divide by the last coordinate (plus 1e-5 to avoid division by zero) and drop it."""
    return x[..., :-1] / (1e-5 + x[..., -1:])


def grid_samp_(x, y, pad, mode):
    """``F.grid_sample`` for sample locations ``y`` given in [0, 1] (remapped to [-1, 1])."""
    return F.grid_sample(x, y * 2 - 1, mode=mode, padding_mode=pad)


def grid_samp(x, y, pad="border", mode="bilinear"):
    """``grid_samp_`` that also accepts inputs with one extra leading axis.

    4-D ``x`` is sampled directly. Otherwise the first two axes of both ``x`` and
    ``y`` are merged for the call and restored on the result.
    """
    if len(x.shape) == 4:
        return grid_samp_(x, y, pad, mode)
    sampled = grid_samp_(x.flatten(0, 1), y.flatten(0, 1), pad, mode)
    return sampled.unflatten(0, x.shape[:2])


def project(crds, K):
    """Apply the matrices ``K`` (``b...cij``) to the points ``crds`` (``b...ckj``) and dehomogenise."""
    return unhom(torch.einsum("b...cij,b...ckj->b...cki", K, crds))


def warp(crds, poses, K):
    """Transform ``crds`` by ``poses`` in homogeneous form, keep the first 3 coords, then ``project`` with ``K``."""
    moved = torch.einsum("b...cij,b...ckj->b...cki", poses, hom(crds))[..., :3]
    return project(moved, K)


def shuffle(x):
    """Return ``x`` with its first axis randomly permuted.

    Only the *device* of ``x`` is copied onto the permutation. The previous
    ``.to(x)`` also copied the dtype, which turned the index tensor into a
    float tensor for float inputs and made the indexing raise.
    """
    return x[torch.randperm(len(x)).to(x.device)]

class CanonRobotTrajPredTransfEuclidean(nn.Module):
    """Predicts a 3-channel, full-resolution map per link from an image plus per-pixel points.

    ``forward`` fuses (a) dense features of the concatenated ``img`` / ``points_cam``
    input with (b) one global code per link image, and decodes each
    (sample, link) pair with a small 1x1-conv head.
    """

    def __init__(self, args=None):
        super().__init__()
        self.args = args

        print("loading model")
        # Channel count of the concatenated image + point input fed to the encoder
        # (presumably 3 RGB + 3 point channels -- confirm against the data loader).
        ch_in = 6
        self.img_enc = ResnetFPN(in_ch=ch_in, out_dim=3, use_first_pool=True)

        latent_dim = 64

        # Refines [raw input, upsampled encoder features] at full resolution.
        self.hires_out = nn.Sequential(
            nn.Conv2d(ch_in + latent_dim, latent_dim, 3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(latent_dim, latent_dim, 3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(latent_dim, latent_dim, 3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(latent_dim, latent_dim, 1),
        )
        # Per-link decoding head: latent_dim -> 16 -> 3 channels, 1x1 convs only.
        self.joint_sdf_out = nn.Sequential(
            nn.Conv2d(latent_dim, 16, 1), nn.ReLU(inplace=True),
            nn.Conv2d(16, 3, 1),
        )
        # Not used by `forward`; kept so existing checkpoints keep loading.
        self.joint_sdf_pred = nn.Conv2d(latent_dim, 12 * 3, 1)  # todo make this link compositional instead of this unflattening thing
        # Projects the 256-d global link code to latent_dim.
        self.link_lin = nn.Conv2d(256, latent_dim, 1)  # todo make this link compositional instead of this unflattening thing

    def forward(self, model_input, track_idxs=None, out=None):  # single view calib resnet
        """Return ``out`` extended with ``"joint_sdfs"``.

        Args:
            model_input: dict with ``"img"`` (axis 1 is squeezed), ``"points_cam"``
                (concatenated with the image along the channel axis) and
                ``"link_imgs"`` (its first two axes are treated as sample and link).
            track_idxs: unused.
            out: optional dict merged into the result; it is not mutated.

        Returns:
            dict whose ``"joint_sdfs"`` entry has the two leading axes of
            ``link_imgs`` followed by 3 channels at the resolution of ``img``.
        """
        out = {} if out is None else out

        link_imgs = model_input["link_imgs"]
        # One global vector per link image -> (N, 256, 1, 1) -> latent_dim -> (B, L, latent_dim, 1, 1).
        link_codes = self.img_enc(link_imgs.float().flatten(0, 1), just_global=True)[..., None, None]
        link_enc = self.link_lin(link_codes).unflatten(0, link_imgs.shape[:2])

        hires_inp = torch.cat((model_input["img"].squeeze(1), model_input["points_cam"].float()), 1)
        img_feats = self.img_enc(hires_inp)

        # Bring encoder features back to the input resolution and refine them with the raw input.
        img_feats_up = F.interpolate(img_feats, model_input["img"].shape[-2:], mode="bilinear")
        full_res_feats = self.hires_out(torch.cat((hires_inp, img_feats_up), 1))

        # Broadcast-add each link code onto the dense features, then decode every (sample, link) pair.
        per_link = (full_res_feats[:, None] + link_enc).flatten(0, 1)
        link_dec_feats = self.joint_sdf_out(per_link).unflatten(0, link_enc.shape[:2])

        return out | {"joint_sdfs": link_dec_feats}

    def forward_(self, model_input, track_idxs=None, out=None):  # single view calib dino
        """Experimental DINO/ViT variant; not called by ``forward``.

        NOTE(review): this method reads ``self.feature_extractor``, ``self.vit``,
        ``self.seg_map``, ``self.cam_map`` and ``self.link_map``, none of which is
        created in ``__init__``, so it cannot run on an instance built by this
        class as-is. The leftover ``pdb`` breakpoints were removed.
        """
        out = {} if out is None else out

        # Dino feature extractor -> 16x16 patches -> ViT
        both_views = torch.cat((model_input["targ_img"], model_input["canon_img"]), 1).flatten(0, 1)
        dino_feats = self.feature_extractor(F.interpolate(both_views, (252, 252), mode="bilinear"), autoresize=True).unflatten(0, (-1, 2))
        vit_out = self.vit(dino_feats, None)[:, 1]

        # Make hires by upsampling and then conv
        canon_img = model_input["canon_img"].squeeze(1)
        feat_up = F.interpolate(vit_out, canon_img.shape[-2:], mode="bilinear")
        pointmap = self.hires_out(torch.cat([canon_img, feat_up], dim=1))  # computed but not returned

        targ_img = model_input["targ_img"].squeeze(1)
        img_feats = self.img_enc(targ_img)
        full_res_feats = self.hires_out(torch.cat((targ_img, F.interpolate(img_feats, model_input["targ_img"].shape[-2:], mode="bilinear")), 1))
        return out | {
            "segs": self.seg_map(full_res_feats * 5).softmax(dim=1),
            "pointmap_cam": self.cam_map(full_res_feats),
            "pointmap_link": self.link_map(full_res_feats),
        }
    
class CanonRobotTrajPredTransfEuclidean_(nn.Module):
    """Older single-view model (segmentation + camera/object pointmaps) plus experimental forwards.

    Only ``forward`` and the LAST ``forward_`` definition in this class body are
    actually bound: every earlier ``def forward_`` is replaced by the next one when
    the class body executes. The earlier variants are therefore unreachable and are
    kept only as reference.

    NOTE(review): both ``if 0`` / ``elif 0`` branches in ``__init__`` never run, so
    ``self.feature_extractor``, ``self.vit`` and ``self.moge`` are never created.
    ``self.dust3r``, ``self.cross_attns`` and ``self.pointmap_decode`` are only
    mentioned in comments or in ``forward_`` bodies, not assigned anywhere here.
    """
    def __init__(self, args=None):
        super().__init__()
        self.args = args

        # dust3r testing
        #   print("importing dus3tr")
        #   from dust3r.model import AsymmetricCroCo3DStereo
        #   model_name = "naver/DUSt3R_ViTLarge_BaseDecoder_512_dpt"
        #   print("loading dus3tr model")
        #   self.dust3r = AsymmetricCroCo3DStereo.from_pretrained(model_name).cuda()
        #   print("done dus3tr")

        print("loading model")
        # baseline model
        # `3 if 1 else 6` always evaluates to 3 input channels.
        self.img_enc = ResnetFPN(in_ch=3 if 1 else 6, out_dim=3,use_first_pool=True)#, nn.Conv2d(512, fdim * self.time_stride, 3, padding=1),)

        latent_dim=64
        #n_layer_transf=6
        #self.cross_attns = nn.ModuleList([mlp_helpers.CrossAttn_(latent_dim,latent_dim) for _ in range(n_layer_transf)])

        # Full-resolution refinement head over [3 image channels, 128 feature channels].
        # NOTE(review): ResnetFPN.forward in this file ends in a 64-channel conv, so the
        # 3 + 128 input width looks stale for the resnet `forward` path -- confirm.
        self.hires_out = nn.Sequential( 
                            nn.Conv2d(3 + 128, 64, 3, padding=1),      nn.ReLU(inplace=True), 
                            nn.Conv2d(64, 64, 3, padding=1),           nn.ReLU(inplace=True), 
                            nn.Conv2d(64, 64, 3, padding=1),           nn.ReLU(inplace=True), 
                            nn.Conv2d(64, 64, 1))
        # 1x1 output heads on the 64-channel refined features.
        self.fg_map = nn.Conv2d(64, 1, 1)
        self.seg_map = nn.Conv2d(64, 12+1, 1) # need to replace with per embodiment N obj/link
        self.obj_map = nn.Conv2d(64, 3, 1)
        self.cam_map = nn.Conv2d(64, 3, 1)
        # Disabled experiment: DINO feature extractor + DiT backbone (never executed).
        if 0:
            from dit import DiT
            import feature_extractors
            print("first dino")
            patch_res=32
            self.feature_extractor = feature_extractors.SpatialDino( freeze_weights=True, num_patches_x=patch_res, num_patches_y=patch_res).cuda()
            self.feature_dim = self.feature_extractor.feature_dim
            print("now vit")
            self.vit = DiT( in_channels=384, out_channels=128, width=patch_res, depth=8, hidden_size=512, max_num_images=2, P=1,).cuda()
            print("done laoding dit model")
            #self.unpatchify = nn.ConvTranspose2d(embed_dim, out_channels, kernel_size=16, stride=16)
            self.hires_out = nn.Sequential( nn.Conv2d(3 + 128, 64, 3, padding=1), nn.ReLU(inplace=True), 
                                nn.Conv2d(64, 64, 3, padding=1),           nn.ReLU(inplace=True), 
                                nn.Conv2d(64, 64, 3, padding=1),           nn.ReLU(inplace=True), 
                                nn.Conv2d(64, 64, 1))
            #self.hires_out = nn.Sequential( nn.Conv2d(3 + 128, 64, 3, padding=1), nn.ReLU(inplace=True), nn.Conv2d(64, 64, 3, padding=1),           nn.ReLU(inplace=True), nn.Conv2d(64, 3, 1))
        # Disabled experiment: MoGe backbone loaded from a machine-specific checkout (never executed).
        elif 0:
            import sys
            sys.path.append("/home/cameronsmith/repos/controll3r/MoGe")
            from moge.model.v1 import MoGeModel
            self.moge= MoGeModel.from_pretrained("Ruicheng/moge-vitl").cuda()

    def forward( self, model_input, track_idxs=None, out={}): # single view calib resnet
        """Return ``out`` merged with ``"segs"``, ``"cam_pointmap"`` and ``"obj_pointmap"``.

        Reads ``model_input["targ_img"]`` (axis 1 is squeezed), encodes it, upsamples
        the features to the image resolution, refines them with ``hires_out`` and
        applies the three 1x1 heads. ``segs`` is softmax-normalised over channels.
        ``track_idxs`` is unused. The default ``out`` dict is never mutated because
        ``out | {...}`` builds a new dict.
        """

        targ_img = model_input["targ_img"].squeeze(1)
        img_feats = self.img_enc( targ_img )#.unbind(1)
        full_res_feats = self.hires_out(torch.cat((targ_img, F.interpolate(img_feats,model_input["targ_img"].shape[-2:],mode="bilinear")),1))
        return out | {  
                        "segs": self.seg_map(full_res_feats).softmax(dim=1), 
                        "cam_pointmap": self.cam_map(full_res_feats), 
                        "obj_pointmap": self.obj_map(full_res_feats), 
                }
        #fg = self.fg_map(full_res_feats)
        #return out | { "fg": fg }
    def forward_( self, model_input, track_idxs=None, out={}): # single view calib moge
        """MoGe variant. Shadowed by the later ``forward_`` definitions (unreachable)."""

        # moge preprocess
        image = model_input["targ_img"].squeeze(1)
        original_height, original_width = image.shape[-2:]
        # `area` and `aspect_ratio` are computed but not used below.
        area = original_height * original_width
        aspect_ratio = original_width / original_height
        min_tokens, max_tokens = [1200, 2500]
        # (9 / 9) == 1, so this evaluates to max_tokens.
        num_tokens = int(min_tokens + (9 / 9) * (max_tokens - min_tokens))
        moge_out= self.moge(model_input["targ_img"].squeeze(1),num_tokens//2) # halving tokens for decreased mem

        return out | { "cam_pointmap": moge_out[0], "obj_pointmap": moge_out[1], "segs": moge_out[2].sigmoid(), }

        # Make hires by upsampling and then conv
        #feat_up = F.interpolate(vit_out,image.shape[-2:],mode="bilinear")
        #hires_feat = self.hires_out( torch.cat([model_input["targ_img"].squeeze(1), feat_up], dim=1) )
        #return out | { "cam_pointmap": self.cam_map(hires_feat), "obj_pointmap": self.obj_map(hires_feat), "segs": self.seg_map(hires_feat).sigmoid(), }
    def forward_( self, model_input, track_idxs=None, out={}): # single view calib dit
        """DINO + DiT single-view variant. Shadowed by the later ``forward_`` definitions (unreachable)."""

        # Dino feature extractor -> 16x16 patches -> ViT
        dino_feats = self.feature_extractor( F.interpolate(model_input["targ_img"][:,0],(252,252),mode="bilinear"),autoresize=True )
        vit_out = self.vit( dino_feats[:,None] ,None)[:,0]

        # Make hires by upsampling and then conv
        feat_up = F.interpolate(vit_out,model_input["targ_img"].squeeze(1).shape[-2:],mode="bilinear")
        hires_feat = self.hires_out( torch.cat([model_input["targ_img"].squeeze(1), feat_up], dim=1) )
        return out | { 
                    "cam_pointmap": self.cam_map(hires_feat), 
                    "obj_pointmap": self.obj_map(hires_feat), 
                    "segs": self.seg_map(hires_feat).sigmoid(), 
                    }


    def forward_( self, model_input, track_idxs=None, out={}): # timm transformer 
        """Two-image (target + canonical) DINO/ViT variant. Shadowed by later definitions (unreachable)."""

        # Dino feature extractor -> 16x16 patches -> ViT
        dino_feats = self.feature_extractor( F.interpolate(torch.cat((model_input["targ_img"],model_input["canon_img"]),1).flatten(0,1),(252,252),mode="bilinear"),autoresize=True ).unflatten(0,(-1,2))
        vit_out = self.vit( dino_feats ,None)[:,1]

        # Make hires by upsampling and then conv
        feat_up = F.interpolate(vit_out,model_input["canon_img"].squeeze(1).shape[-2:],mode="bilinear")
        pointmap = self.hires_out( torch.cat([model_input["canon_img"].squeeze(1), feat_up], dim=1) )
        return out | { "pointmap": pointmap, }

    def forward_( self, model_input, track_idxs=None, out={}): # one view concat resnet
        """Channel-concatenated two-image variant via ``ResnetFPN.forward_oneview``. Shadowed (unreachable)."""
        pointmap = self.img_enc.forward_oneview( torch.cat((model_input["targ_img"].squeeze(1),model_input["canon_img"].squeeze(1)),1) , upsample=True)#.unbind(1)
        return out | { "pointmap": pointmap, }
    def forward_( self, model_input, track_idxs=None, out={}): # two view resnet
        """Two-view variant via ``ResnetFPN.two_view``. Shadowed by later definitions (unreachable)."""
        pointmap = self.img_enc.two_view( model_input["canon_img"].squeeze(1),model_input["targ_img"].squeeze(1) )#.unbind(1)
        return out | { "pointmap": pointmap, }

    def forward_( self, model_input, track_idxs=None, out={}): # dust3r
        """DUSt3R variant; requires ``self.dust3r`` and CUDA. Shadowed by the last definition (unreachable)."""

        view1={"img":model_input["targ_img" ][:,0],"instance":torch.arange(len(model_input["targ_img"])).cuda()}
        view2={"img":model_input["canon_img"][:,0],"instance":torch.arange(len(model_input["targ_img"])).cuda()}
        dust3r_out=self.dust3r(view1,view2)
        # Move the last axis of the predicted points to position 1 (channel-first layout).
        pointmap=dust3r_out[1]["pts3d_in_other_view"].permute(0,3,1,2)

        # Image encoder produces image aligned representation
        #img_feats = self.img_enc( torch.cat((model_input["targ_img"],model_input["canon_img"]),1) )#.unbind(1)
        #tokens_targ,tokens_canon= ch_sec(F.interpolate(img_feats.flatten(0,1),(8,8),mode="bilinear").unflatten(0,img_feats.shape[:2])).unbind(1)

        ## Canon tokens cross attend to target img
        #for attn in self.cross_attns: tokens_canon = attn(tokens_targ,tokens_canon)

        ## Combine low-res attention tokens with hires feats: first merge at half res, then combine with full res
        #comb_feats = torch.cat((img_feats[:,1], F.interpolate(ch_fst(tokens_canon,8),img_feats.shape[-2:],mode="bilinear")),1)

        #pointmap = ch_fst(self.pointmap_decode(ch_sec(comb_feats)),comb_feats.size(-2))
        return out | {
                "pointmap": pointmap, 
        }
    def forward_( self, model_input, track_idxs=None, out={}):
        """Cross-attention variant; this is the ``forward_`` that is actually bound on the class.

        NOTE(review): reads ``self.cross_attns`` and ``self.pointmap_decode``, which
        ``__init__`` does not create (the ``cross_attns`` line there is commented out),
        so calling it on an instance built by this class would raise AttributeError.
        """

        # Image encoder produces image aligned representation
        img_feats = self.img_enc( torch.cat((model_input["targ_img"],model_input["canon_img"]),1) )#.unbind(1)
        # Downsample both views to an 8x8 grid and turn each into a token sequence.
        tokens_targ,tokens_canon= ch_sec(F.interpolate(img_feats.flatten(0,1),(8,8),mode="bilinear").unflatten(0,img_feats.shape[:2])).unbind(1)

        # Canon tokens cross attend to target img
        for attn in self.cross_attns: tokens_canon = attn(tokens_targ,tokens_canon)

        # Combine low-res attention tokens with hires feats: first merge at half res, then combine with full res
        comb_feats = torch.cat((img_feats[:,1], F.interpolate(ch_fst(tokens_canon,8),img_feats.shape[-2:],mode="bilinear")),1)

        pointmap = ch_fst(self.pointmap_decode(ch_sec(comb_feats)),comb_feats.size(-2))
        return out | {
                "pointmap": pointmap, 
        }

class CanonRobotTrajPredTransfBaselineNaive(nn.Module):
    """Baseline: learned motion tokens cross-attend to image tokens and are decoded to 9 values.

    The 9 decoded values are split into 3 translation values and 6 rotation values.
    """

    def __init__(self, args=None):
        super().__init__()
        self.args = args

        # baseline model
        self.img_enc = ResnetFPN(in_ch=3, use_first_pool=True)

        latent_dim = 256
        n_timesteps = 1
        # One learned query token per timestep.
        self.global_emb_latent = nn.Embedding(n_timesteps, latent_dim)

        n_layer_transf = 6
        self.cross_attns = nn.ModuleList([mlp_helpers.CrossAttn_(latent_dim, latent_dim) for _ in range(n_layer_transf)])

        self.traj_decode = mlp_helpers.make_net([latent_dim, latent_dim, latent_dim, latent_dim, 9])

    def _decode_motion(self, model_input):
        """Shared trunk: encode ``start_img[:, 0]``, attend, and decode the motion tokens."""
        # Image encoder to produce tokens (around 15x20 resolution)
        img_tokens = ch_sec(self.img_enc(model_input["start_img"][:, 0]))
        # Motion tokens that will be decoded into actions which attend to the image tokens
        motion_tokens = self.global_emb_latent.weight[None].expand(len(img_tokens), -1, -1)

        # Motion tokens attend to the image tokens
        for attn in self.cross_attns:
            motion_tokens = attn(img_tokens, motion_tokens)

        # Decode motion tokens into 3d gripper trajectory
        return self.traj_decode(motion_tokens)

    def forward(self, model_input, track_idxs=None, out=None):
        """Return ``out`` extended with ``"pred_tris"``.

        Args:
            model_input: dict providing ``"start_img"``; only index 0 of axis 1 is used.
            track_idxs: unused.
            out: optional dict merged into the result; it is not mutated.

        NOTE(review): the 6 rotation values are passed to
        ``geometry.rot_trans_to_tris`` as-is. Unreachable statements that used to
        follow the ``return`` (and referenced an undefined ``model_output``)
        converted a 6-D rotation to Euler angles first -- confirm which
        representation ``rot_trans_to_tris`` expects.
        """
        out = {} if out is None else out

        pred_traj = self._decode_motion(model_input).squeeze(1)
        trans, rot = pred_traj[..., :3], pred_traj[..., 3:]
        transf_tris = geometry.rot_trans_to_tris(rot, trans)[0]

        return out | {
            "pred_tris": transf_tris,
        }

    def forward_(self, model_input, track_idxs=None, out=None):
        """Alternate head: return the raw ``"pred_trans"`` (3) and ``"pred_rot"`` (6) values.

        Unlike ``forward``, the token axis is not squeezed.
        """
        out = {} if out is None else out

        pred_traj = self._decode_motion(model_input)

        return out | {
            "pred_trans": pred_traj[..., :3],
            "pred_rot": pred_traj[..., 3:],
        }
class CanonRobotTrajPredResnetBaseline(nn.Module):
    """ResNet-18 baseline that regresses 20 three-dimensional values from one image."""

    def __init__(self, args=None):
        super().__init__()
        self.args = args

        # baseline model
        import torchvision.models as models
        resnet = models.resnet18(pretrained=True)
        # Drop the final classification layer; keeps everything up to the global pool.
        self.resnet = nn.Sequential(*list(resnet.children())[:-1])

        # `make_net` lives in `mlp_helpers` (as used elsewhere in this file); the
        # bare name was undefined here and raised NameError at construction.
        self.traj_pred = mlp_helpers.make_net([512, 256, 128, 3 * 20])

    def forward(self, model_input, track_idxs=None, out=None):
        """Return ``out`` extended with ``"pred_traj"`` (last two axes are ``(20, 3)``).

        Args:
            model_input: dict; reads ``"canon_pc_filtered_img"`` unless
                ``self.args.noncanon`` is truthy, in which case ``"start_img_img"``
                is read. Only index 0 of axis 1 is used.
            track_idxs: unused.
            out: optional dict merged into the result; it is not mutated.

        NOTE(review): ``self.args`` must provide ``noncanon`` (the ``args=None``
        default does not). The key ``"start_img_img"`` may be a typo for
        ``"start_img"`` -- confirm against the data loader.
        """
        out = {} if out is None else out

        img_key = "canon_pc_filtered_img" if not self.args.noncanon else "start_img_img"
        # Remove the two trailing singleton spatial axes left by the global pool.
        latent = self.resnet(model_input[img_key][:, 0]).squeeze(-1).squeeze(-1)
        pred = self.traj_pred(latent).unflatten(-1, (20, 3))

        return out | {
            "pred_traj": pred,
        }

class ResnetFPN(nn.Module): # from pixelnerf code, modified for two-view learning
    """torchvision ResNet encoder with two decoding paths.

    * ``forward`` resizes every stage output to the first stage's resolution,
      concatenates them and maps the result to 64 channels (or, with
      ``just_global=True``, returns the spatial max of the last stage).
    * ``forward_oneview`` / ``two_view`` keep the per-stage pyramid and merge it
      top-down through ``combs_1`` / ``combs_2`` to ``out_dim`` channels.
    """

    # Modes for which F.interpolate accepts an explicit ``align_corners``.
    _ALIGNABLE_MODES = ("linear", "bilinear", "bicubic", "trilinear")

    def __init__(
        self,
        backbone="resnet34",
        pretrained=True,
        num_layers=4,
        index_interp="bilinear",
        index_padding="border",
        upsample_interp="bilinear",
        feature_scale=1.0,
        use_first_pool=True,
        norm_type="batch",
        in_ch=3,
        out_dim=64,
    ):
        """
        Args:
            backbone: name of a constructor in ``torchvision.models``.
            pretrained: forwarded to that constructor.
            num_layers: number of stages used (1..5); also selects ``latent_size``.
            index_interp, index_padding: stored on the instance, not read here.
            upsample_interp: ``F.interpolate`` mode used by ``forward``.
            feature_scale: optional rescaling of the input image.
            use_first_pool: apply the backbone max-pool before ``layer1``.
            norm_type: ``batch`` | ``instance`` | ``group`` | ``none``.
            in_ch: input channels; ``conv1`` is replaced (freshly initialised) when != 3.
            out_dim: channels produced by the top-down decoder.
        """
        super().__init__()

        def get_norm_layer(norm_type="instance", group_norm_groups=32):
            """Return a normalization layer
            Parameters:
                norm_type (str) -- the name of the normalization layer: batch | instance | group | none
            For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
            For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
            """
            if norm_type == "batch":
                norm_layer = functools.partial(
                    nn.BatchNorm2d, affine=True, track_running_stats=True
                )
            elif norm_type == "instance":
                norm_layer = functools.partial(
                    nn.InstanceNorm2d, affine=False, track_running_stats=False
                )
            elif norm_type == "group":
                norm_layer = functools.partial(nn.GroupNorm, group_norm_groups)
            elif norm_type == "none":
                norm_layer = None
            else:
                raise NotImplementedError(
                    "normalization layer [%s] is not found" % norm_type
                )
            return norm_layer

        self.feature_scale = feature_scale
        self.use_first_pool = use_first_pool
        norm_layer = get_norm_layer(norm_type)

        print("Using torchvision", backbone, "encoder")
        self.model = getattr(torchvision.models, backbone)(pretrained=pretrained, norm_layer=norm_layer)

        if in_ch != 3:
            # Swap the stem conv for one with the requested input width; every other
            # hyper-parameter is copied from the original conv1.
            self.model.conv1 = nn.Conv2d(
                in_ch,
                self.model.conv1.weight.shape[0],
                self.model.conv1.kernel_size,
                self.model.conv1.stride,
                self.model.conv1.padding,
                padding_mode=self.model.conv1.padding_mode,
            )

        # Following 2 lines need to be uncommented for older configs
        self.model.fc = nn.Sequential()
        self.model.avgpool = nn.Sequential()
        self.latent_size = [0, 64, 128, 256, 512, 1024][num_layers]

        self.num_layers = num_layers
        self.index_interp = index_interp
        self.index_padding = index_padding
        self.upsample_interp = upsample_interp
        # Placeholders overwritten on every `forward` call; excluded from the state dict.
        self.register_buffer("latent", torch.empty(1, 1, 1, 1), persistent=False)
        self.register_buffer(
            "latent_scaling", torch.empty(2, dtype=torch.float32), persistent=False
        )

        self.out = nn.Sequential(
            nn.Conv2d(self.latent_size, 64, 1),
        )

        # Top-down decoder used by `forward_oneview`. No hard-coded `.cuda()`: these
        # layers follow the parent module's device like every other submodule, so the
        # class can also be constructed on CPU-only machines.
        self.combs_1 = nn.ModuleList([nn.Conv2d(d1, d2, 1) for d1, d2 in [(256, 128), (128, 64), (64, 64), (64, 64), (64, out_dim)]])
        self.combs_2 = nn.ModuleList([nn.Conv2d(d, d, 1) for d in [128, 64, 64, 64, out_dim]])
        self.last_conv_up = nn.Conv2d(in_ch, out_dim, 1)

        # One fusion net per pyramid level (widths 3, 64, 64, 128, 256) for a 128-channel auxiliary map.
        self.aux_inject_nets = nn.ModuleList([mlp_helpers.make_net([128 + f, f, f]) for f in [256, 128, 64, 64, 3][::-1]])

        self.cross_attns = nn.ModuleList([mlp_helpers.CrossAttn_(256, 256) for _ in range(6)])

    def two_view(self, x, additional_view, custom_size=None): # two view
        """Decode ``x`` after its coarsest features cross-attend to those of ``additional_view``.

        ``custom_size`` is accepted for signature compatibility and is unused.
        """
        latent_x = self.forward_oneview(x)
        latent_additional = self.forward_oneview(additional_view)

        token_x = ch_sec(latent_x[-1])
        token_additional = ch_sec(latent_additional[-1])
        for attn in self.cross_attns:
            token_x = attn(token_additional, token_x)
        latent_x[-1] = ch_fst(token_x, latent_x[-1].size(-2))
        return self.forward_oneview(x, latents=latent_x) # upsampling

    def forward_oneview(self, x, latents=None, upsample=False, aux_latent=None): # one view, aux latent is e.g. dino 16x16 features we're going to upsample here
        """Encode ``x`` into a feature pyramid and/or decode a pyramid top-down.

        Args:
            x: image batch; inputs with an extra leading axis are flattened,
                processed and unflattened again.
            latents: precomputed pyramid (as returned with ``upsample=False``). When
                given, encoding is skipped and the pyramid is decoded. The list is
                modified in place if ``aux_latent`` is supplied.
            upsample: when ``latents`` is None, return the pyramid list if False,
                otherwise continue to the decoder.
            aux_latent: optional auxiliary feature map fused into every pyramid
                level; skipped when None.

        Returns:
            list of pyramid tensors (``[x, stem, layer1, ...]``) or the decoded map.
        """
        if len(x.shape) > 4:
            # Previously this referenced an undefined `custom_size` and dispatched to
            # `forward`; recurse into this method instead and restore the leading axes.
            lead = x.shape[:2]
            result = self.forward_oneview(x.flatten(0, 1), latents=latents, upsample=upsample, aux_latent=aux_latent)
            if isinstance(result, list):
                return [r.unflatten(0, lead) for r in result]
            return result.unflatten(0, lead)

        if self.feature_scale != 1.0:
            x = F.interpolate(
                x,
                scale_factor=self.feature_scale,
                mode="bilinear" if self.feature_scale > 1.0 else "area",
                align_corners=True if self.feature_scale > 1.0 else None,
                recompute_scale_factor=True,
            )
        if latents is None:
            latents = [x]

            # NOTE(review): unlike `forward`, the stem here max-pools before bn1/relu.
            x = self.model.conv1(x)
            x = self.model.maxpool(x)
            x = self.model.bn1(x)
            x = self.model.relu(x)

            latents.append(x)

            if self.num_layers > 1:
                if self.use_first_pool:
                    x = self.model.maxpool(x)
                x = self.model.layer1(x)
                latents.append(x)
            if self.num_layers > 2:
                x = self.model.layer2(x)
                latents.append(x)
            if self.num_layers > 3:
                x = self.model.layer3(x)
                latents.append(x)
            if self.num_layers > 4:
                x = self.model.layer4(x)
                latents.append(x)
            if not upsample:
                return latents

        # Fuse the auxiliary map into every level (only when one is provided; passing
        # None used to crash inside F.interpolate).
        if aux_latent is not None:
            for i, lat in enumerate(latents):
                aux = F.interpolate(aux_latent, lat.shape[-2:], mode="bilinear")
                fused = self.aux_inject_nets[i](ch_sec(torch.cat((aux, lat), 1)))
                latents[i] = ch_fst(fused, lat.size(-2))

        # Top-down merge: upsample, 1x1 conv, add the skip level, relu, 1x1 conv.
        # Index 2 of combs_1 / combs_2 is intentionally left as in the original (unused).
        up_latent = self.combs_2[0]((self.combs_1[0](F.interpolate(latents[-1], latents[-2].shape[-2:], mode="bilinear")) + latents[-2]).relu())
        up_latent = self.combs_2[1]((self.combs_1[1](F.interpolate(up_latent, latents[-3].shape[-2:], mode="bilinear")) + latents[-3]).relu())
        up_latent = self.combs_2[3]((self.combs_1[3](F.interpolate(up_latent, latents[-4].shape[-2:], mode="bilinear")) + latents[-4]).relu())
        up_latent = self.combs_2[4]((self.combs_1[4](F.interpolate(up_latent, latents[-5].shape[-2:], mode="bilinear")) + self.last_conv_up(latents[-5])).relu())
        return up_latent

    def forward(self, x, custom_size=None, just_global=False):
        """Encode ``x`` to a 64-channel map, or to a global vector.

        Args:
            x: image batch; inputs with an extra leading axis are flattened,
                processed and unflattened again.
            custom_size: optional output resolution; defaults to the first stage's.
            just_global: return the spatial maximum of the last stage instead of a map.

        Side effects: stores the concatenated features in ``self.latent`` and
        updates ``self.latent_scaling``.
        """
        if len(x.shape) > 4:
            # `just_global` is now forwarded; it used to be dropped silently here.
            return self(x.flatten(0, 1), custom_size, just_global).unflatten(0, x.shape[:2])

        if self.feature_scale != 1.0:
            x = F.interpolate(
                x,
                scale_factor=self.feature_scale,
                mode="bilinear" if self.feature_scale > 1.0 else "area",
                align_corners=True if self.feature_scale > 1.0 else None,
                recompute_scale_factor=True,
            )

        latents = []
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)

        latents.append(x)

        if self.num_layers > 1:
            if self.use_first_pool:
                x = self.model.maxpool(x)
            x = self.model.layer1(x)
            latents.append(x)
        if self.num_layers > 2:
            x = self.model.layer2(x)
            latents.append(x)
        if self.num_layers > 3:
            x = self.model.layer3(x)
            latents.append(x)
        if self.num_layers > 4:
            x = self.model.layer4(x)
            latents.append(x)
        if just_global:
            return x.flatten(-2, -1).max(dim=-1)[0]

        # The old test (`self.index_interp == "nearest "`, with a trailing space) never
        # matched, so align_corners was always True and non-interpolating modes such
        # as "nearest" raised. Interpolating modes keep align_corners=True as before.
        align_corners = True if self.upsample_interp in self._ALIGNABLE_MODES else None
        target_size = latents[0].shape[-2:] if custom_size is None else custom_size
        latents = [
            F.interpolate(lat, target_size, mode=self.upsample_interp, align_corners=align_corners)
            for lat in latents
        ]
        self.latent = torch.cat(latents, dim=1)
        self.latent_scaling[0] = self.latent.shape[-1]
        self.latent_scaling[1] = self.latent.shape[-2]
        self.latent_scaling = self.latent_scaling / (self.latent_scaling - 1) * 2.0
        return self.out(self.latent)
