import torch,torchvision
from torch import nn
import kornia
import functools
from einops import rearrange, repeat
#import torchvision.ops as ops
from torch.nn import functional as F
import numpy as np
import sys, random, time, os
from copy import deepcopy
from matplotlib import cm
import wandb
from tqdm import tqdm
from typing import Callable, List, Optional, Tuple, Generator, Dict
from collections import defaultdict
#import torchvision.transforms as T
from torchcubicspline import(natural_cubic_spline_coeffs, NaturalCubicSpline)

import geometry,mlp_helpers

# Tensor helpers
def ch_sec(x):
    """Flatten the two trailing spatial axes and move channels last.

    Layout change: "... c x y -> ... (x y) c".
    """
    return rearrange(x, "... c x y -> ... (x y) c")


def ch_fst(src, x=None):
    """Inverse of ch_sec: "... (x y) c -> ... c x y".

    When x is omitted, it is taken as int(sqrt(N)) where N is the size of the
    flattened axis, i.e. a square spatial layout is assumed.
    """
    side = int(src.size(-2) ** (0.5)) if x is None else x
    return rearrange(src, "... (x y) c -> ... c x y", x=side)


def hom(x):
    """Append a trailing 1 to the last axis (homogeneous coordinates)."""
    return torch.cat((x, torch.ones_like(x[..., [0]])), -1)


def unhom(x):
    """Divide by the last coordinate and drop it; 1e-5 guards against division by zero."""
    return x[..., :-1] / (1e-5 + x[..., -1:])


def grid_samp_(x, y, pad, mode):
    """F.grid_sample with coordinates given in [0, 1] (rescaled here to [-1, 1])."""
    return F.grid_sample(x, y * 2 - 1, mode=mode, padding_mode=pad)


def grid_samp(x, y, pad="border", mode="bilinear"):
    """grid_samp_ that also accepts inputs with one extra leading axis.

    A 4-D x is sampled directly. Otherwise the first two axes of both x and y
    are merged for the call and restored on the result.
    """
    if len(x.shape) == 4:
        return grid_samp_(x, y, pad, mode)
    merged = grid_samp_(x.flatten(0, 1), y.flatten(0, 1), pad, mode)
    return merged.unflatten(0, x.shape[:2])


def project(crds, K):
    """Apply matrix K to every point in crds, then dehomogenize.

    K is indexed (b, ..., c, i, j) and crds (b, ..., c, k, j); the result is
    indexed (b, ..., c, k, i-1).  K is presumably a camera intrinsics matrix --
    confirm against callers.
    """
    return unhom(torch.einsum("b...cij,b...ckj->b...cki", K, crds))


def warp(crds, poses, K):
    """Transform points by poses (in homogeneous form), keep xyz, then project with K."""
    moved = torch.einsum("b...cij,b...ckj->b...cki", poses, hom(crds))[..., :3]
    return project(moved, K)


def shuffle(x):
    """Return x with its first axis randomly permuted.

    The permutation is moved to x's device only.  Casting it with `.to(x)`
    would also convert it to x's dtype, and a floating-point index tensor is
    rejected by tensor indexing.
    """
    perm = torch.randperm(len(x)).to(x.device)
    return x[perm]

class CanonRobotTrajPredTransfEuclidean(nn.Module):
    """Image-to-trajectory baseline: ResNet-18 trunk followed by an MLP head.

    The image is passed through torchvision's resnet18 with its final child
    (the classifier) removed, and the result is fed to `self.traj_pred`.

    Two flags are read from `args`, each treated as False when `args` is None
    or lacks the attribute:
      * noncanon6d -- the head emits 7 values (otherwise 9).
      * noncanon9d -- together with noncanon6d, selects the output format of
        forward().
    """

    def __init__(self, args=None):
        super().__init__()
        self.args = args

        # baseline model
        # NOTE: img_enc is not used by forward(); it is kept so the module's
        # registered parameters (and therefore its state_dict) are unchanged.
        self.img_enc = ResnetFPN(in_ch=3, use_first_pool=True)

        latent_dim = 256
        import torchvision.models as models
        resnet = models.resnet18(pretrained=True)
        # Drop the last child (the fc classifier) and keep the rest as the trunk.
        self.resnet = nn.Sequential(*list(resnet.children())[:-1])
        head_out = 7 if getattr(args, "noncanon6d", False) else 9
        # Layer widths handed to the project helper; 512 is the trunk width.
        self.traj_pred = mlp_helpers.make_net(
            [512, latent_dim, latent_dim, latent_dim, head_out]
        )

    def forward(self, model_input, track_idxs=None, out=None):
        """Predict trajectory parameters from `model_input["img"]`.

        Args:
            model_input: mapping with an "img" entry that is fed to the trunk.
            track_idxs: unused; kept for interface compatibility.
            out: unused; kept for interface compatibility.  The default is
                None rather than a shared mutable dict.

        Returns:
            If noncanon6d or noncanon9d is set: a dict with "pred_trans" (the
            first 3 head outputs) and "pred_rot" (the remaining ones).
            Otherwise a dict with "pred_tris", the 9 head outputs reshaped to
            (..., 3, 3).
        """
        # The trunk output has two trailing singleton spatial axes; drop them.
        latent = self.resnet(model_input["img"]).squeeze(-1).squeeze(-1)
        traj_pred = self.traj_pred(latent)

        split_output = (getattr(self.args, "noncanon6d", False)
                        or getattr(self.args, "noncanon9d", False))
        if split_output:
            return {"pred_trans": traj_pred[..., :3], "pred_rot": traj_pred[..., 3:]}
        return {"pred_tris": traj_pred.unflatten(-1, (3, 3))}
    

class ResnetFPN(nn.Module): # from pixelnerf code
    """Torchvision ResNet wrapped as a convolutional feature encoder.

    Two entry points are provided:
      * forward()  -- returns the feature map of the deepest enabled stage.
      * forward_() -- resizes every stage's feature map to a common size,
        concatenates them along channels and applies a 1x1 conv (`self.out`)
        producing 512 channels.

    Args:
        backbone: name of a model constructor in `torchvision.models`.
        pretrained: forwarded to that constructor.
        num_layers: number of stages to run (1 = stem only ... 5 = through
            layer4); also selects `latent_size` from a fixed table.
        index_interp: stored on the instance; not used by the code in this class.
        index_padding: stored on the instance; not used by the code in this class.
        upsample_interp: interpolation mode used by forward_().
        feature_scale: if not 1.0, the input is resized by this factor first.
        use_first_pool: apply the backbone's maxpool before layer1.
        norm_type: "batch", "instance", "group" or "none".
        in_ch: number of input channels; if not 3, conv1 is replaced by a
            freshly initialised convolution with the same geometry.
    """

    # F.interpolate only accepts an explicit align_corners for these modes.
    _ALIGNABLE_MODES = ("linear", "bilinear", "bicubic", "trilinear")

    def __init__(
        self,
        backbone="resnet34",
        pretrained=True,
        num_layers=4,
        index_interp="bilinear",
        index_padding="border",
        upsample_interp="bilinear",
        feature_scale=1.0,
        use_first_pool=True,
        norm_type="batch",
        in_ch=3,
    ):
        super().__init__()

        def get_norm_layer(norm_type="instance", group_norm_groups=32):
            """Return a normalization-layer factory (or None for "none").

            Parameters:
                norm_type (str) -- batch | instance | group | none
                group_norm_groups (int) -- number of groups for GroupNorm
            BatchNorm uses learnable affine parameters and tracks running statistics.
            InstanceNorm uses neither.
            """
            if norm_type == "batch":
                return functools.partial(
                    nn.BatchNorm2d, affine=True, track_running_stats=True
                )
            if norm_type == "instance":
                return functools.partial(
                    nn.InstanceNorm2d, affine=False, track_running_stats=False
                )
            if norm_type == "group":
                return functools.partial(nn.GroupNorm, group_norm_groups)
            if norm_type == "none":
                return None
            raise NotImplementedError(
                "normalization layer [%s] is not found" % norm_type
            )

        self.feature_scale = feature_scale
        self.use_first_pool = use_first_pool
        norm_layer = get_norm_layer(norm_type)

        print("Using torchvision", backbone, "encoder")
        self.model = getattr(torchvision.models, backbone)(pretrained=pretrained, norm_layer=norm_layer)

        if in_ch != 3:
            stem = self.model.conv1
            self.model.conv1 = nn.Conv2d(
                in_ch,
                stem.weight.shape[0],
                stem.kernel_size,
                stem.stride,
                stem.padding,
                padding_mode=stem.padding_mode,
            )

        # Following 2 lines need to be uncommented for older configs
        self.model.fc = nn.Sequential()
        self.model.avgpool = nn.Sequential()
        self.latent_size = [0, 64, 128, 256, 512, 1024][num_layers]

        self.num_layers = num_layers
        self.index_interp = index_interp
        self.index_padding = index_padding
        self.upsample_interp = upsample_interp
        self.register_buffer("latent", torch.empty(1, 1, 1, 1), persistent=False)
        self.register_buffer(
            "latent_scaling", torch.empty(2, dtype=torch.float32), persistent=False
        )

        self.out = nn.Sequential(
            nn.Conv2d(self.latent_size, 512, 1),
        )

        self.out_dim = 64  # todo make arg
        # 1x1 convolutions of a top-down decoder.  Neither forward() nor
        # forward_() calls them; they stay registered so the module's
        # parameter set (and state_dict) is unchanged.
        self.combs_1 = nn.ModuleList(
            [nn.Conv2d(d1, d2, 1)
             for d1, d2 in [(256, 128), (128, 64), (64, 64), (64, 64), (64, self.out_dim)]]
        )
        self.combs_2 = nn.ModuleList(
            [nn.Conv2d(d, d, 1) for d in [128, 64, 64, 64, self.out_dim]]
        )
        self.last_conv_up = nn.Conv2d(in_ch, self.out_dim, 1)
        # Keep the original eager placement on the GPU, but do not crash on
        # CPU-only hosts where an unconditional .cuda() would raise.
        if torch.cuda.is_available():
            for head in (self.combs_1, self.combs_2, self.last_conv_up):
                head.cuda()

    def _rescale_input(self, x):
        """Resize x by `feature_scale` (bilinear when enlarging, area when shrinking)."""
        if self.feature_scale == 1.0:
            return x
        enlarging = self.feature_scale > 1.0
        return F.interpolate(
            x,
            scale_factor=self.feature_scale,
            mode="bilinear" if enlarging else "area",
            align_corners=True if enlarging else None,
            recompute_scale_factor=True,
        )

    def _run_stages(self, x):
        """Run layer1..layer4 as allowed by `num_layers`; return [x, stage outputs...]."""
        latents = [x]
        # Stage `depth` runs when num_layers > depth, for depth in 1..4.
        for depth in range(1, min(self.num_layers, 5)):
            if depth == 1 and self.use_first_pool:
                x = self.model.maxpool(x)
            x = getattr(self.model, "layer%d" % depth)(x)
            latents.append(x)
        return latents

    def forward(self, x, custom_size=None):
        """Return the feature map of the deepest enabled stage.

        Inputs with more than 4 axes have their first two axes merged for the
        encoder and restored on the output.  `custom_size` is accepted for
        interface compatibility and has no effect on the result.
        """
        if len(x.shape) > 4:
            return self(x.flatten(0, 1), custom_size).unflatten(0, x.shape[:2])

        x = self._rescale_input(x)

        # NOTE: unlike forward_(), the stem here pools *before* bn1/relu (and
        # pools again before layer1 when use_first_pool is set).  Left as is:
        # changing the order would change the features this method produces.
        x = self.model.conv1(x)
        x = self.model.maxpool(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)

        # Overriding here to just use last layer's features instead of all since doing transformer instead of pixel-aligned task
        return self._run_stages(x)[-1]

    def forward_(self, x, custom_size=None):
        """Return fused multi-scale features.

        Every stage's feature map is resized to `custom_size` (or, if None, to
        the spatial size of the stem output), the maps are concatenated along
        channels and passed through `self.out`.  The concatenated map is stored
        in the `latent` buffer, and `latent_scaling` is updated from its width
        and height.

        Inputs with more than 4 axes have their first two axes merged for the
        encoder and restored on the output.
        """
        if len(x.shape) > 4:
            # Recurse into forward_ itself; calling self(...) would dispatch to
            # forward() and return single-scale features instead.
            return self.forward_(x.flatten(0, 1), custom_size).unflatten(0, x.shape[:2])

        x = self._rescale_input(x)

        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)

        latents = self._run_stages(x)

        # align_corners must stay None for modes that do not interpolate
        # linearly (e.g. "nearest", "area"); F.interpolate raises otherwise.
        align_corners = True if self.upsample_interp in self._ALIGNABLE_MODES else None
        target_size = latents[0].shape[-2:] if custom_size is None else custom_size
        latents = [
            F.interpolate(
                feat,
                target_size,
                mode=self.upsample_interp,
                align_corners=align_corners,
            )
            for feat in latents
        ]
        self.latent = torch.cat(latents, dim=1)
        self.latent_scaling[0] = self.latent.shape[-1]
        self.latent_scaling[1] = self.latent.shape[-2]
        self.latent_scaling = self.latent_scaling / (self.latent_scaling - 1) * 2.0
        return self.out(self.latent)
