import torch,torchvision
from torch import nn
import kornia
import functools
from einops import rearrange, repeat
#import torchvision.ops as ops
from torch.nn import functional as F
import numpy as np
import sys, random, time, os
from copy import deepcopy
from matplotlib import cm
import wandb
from tqdm import tqdm
from typing import Callable, List, Optional, Tuple, Generator, Dict
from collections import defaultdict
#import torchvision.transforms as T
from torchcubicspline import(natural_cubic_spline_coeffs, NaturalCubicSpline)

import geometry,mlp_helpers

# Tensor helpers (module-level callables; names and signatures are part of the file's API)
def ch_sec(x):
    """Flatten the two trailing spatial axes and move channels last.

    Layout change: ``... c x y -> ... (x y) c``.
    """
    return rearrange(x, "... c x y -> ... (x y) c")


def ch_fst(src, x=None):
    """Undo the layout produced by ``ch_sec``: ``... (x y) c -> ... c x y``.

    When ``x`` is None the token axis is treated as a square grid whose side
    is ``int(sqrt(number_of_tokens))``; pass ``x`` for non-square grids.
    """
    side = int(src.size(-2) ** (0.5)) if x is None else x
    return rearrange(src, "... (x y) c -> ... c x y", x=side)


def hom(x):
    """Append a trailing coordinate of ones (homogeneous coordinates)."""
    return torch.cat((x, torch.ones_like(x[..., [0]])), -1)


def unhom(x):
    """Divide by the last coordinate and drop it.

    A small constant (1e-5) is added to the divisor to avoid division by zero.
    """
    return x[..., :-1] / (1e-5 + x[..., -1:])


def grid_samp_(x, y, pad, mode):
    """``F.grid_sample`` for coordinates given in [0, 1] (remapped to [-1, 1])."""
    return F.grid_sample(x, y * 2 - 1, mode=mode, padding_mode=pad)


def grid_samp(x, y, pad="border", mode="bilinear"):
    """Sample ``x`` at ``y``; also accepts one extra leading axis on both inputs.

    A 4-D ``x`` is sampled directly. Otherwise the first two axes of ``x`` and
    ``y`` are merged for sampling and restored on the result.
    """
    if len(x.shape) == 4:
        return grid_samp_(x, y, pad, mode)
    merged = grid_samp_(x.flatten(0, 1), y.flatten(0, 1), pad, mode)
    return merged.unflatten(0, x.shape[:2])


def project(crds, K):
    """Apply the matrices ``K`` to the points ``crds`` and dehomogenize.

    Per the einsum signature, ``K`` is ``(b, ..., c, i, j)`` and ``crds`` is
    ``(b, ..., c, k, j)``; each of the ``k`` points is multiplied by the matrix.
    """
    return unhom(torch.einsum("b...cij,b...ckj->b...cki", K, crds))


def warp(crds, poses, K):
    """Transform ``crds`` by ``poses`` (in homogeneous form), then project with ``K``.

    Only the first three components of the transformed points are projected.
    """
    moved = torch.einsum("b...cij,b...ckj->b...cki", poses, hom(crds))[..., :3]
    return project(moved, K)


def shuffle(x):
    """Return ``x`` with its first axis randomly permuted.

    The permutation is moved to ``x``'s device only. Casting it to ``x``'s
    dtype as well (``.to(x)``) produced float indices for float tensors, which
    tensor indexing rejects.
    """
    return x[torch.randperm(len(x)).to(x.device)]


class CanonRobotTrajPredPC(nn.Module):
    """Trajectory predictor over a point cloud using a cross-attending global latent.

    A single learned latent token attends to per-point features through a
    stack of cross-attention layers; the final latent is decoded by an MLP
    into a sequence of 3D positions returned under ``"pred_traj"``.

    NOTE(review): the feature/latent initialisation in ``forward`` was
    restored from code that had been commented out, which left ``forward``
    referencing undefined names. It assumes ``pos_enc(pc_pos)`` concatenated
    with the RGB values has 24 channels (the input width of ``pc_upproj``)
    — TODO confirm against ``mlp_helpers.PositionalEncodingNoFreqFactor``.
    """

    def __init__(self, args=None):
        super().__init__()
        self.args = args

        # pointcloud transformer
        latent_dim = 64
        n_timesteps = 20
        self.global_emb_latent = nn.Embedding(n_timesteps, latent_dim)

        self.pos_enc = mlp_helpers.PositionalEncodingNoFreqFactor(3, 3)
        n_layer_transf = 4
        self.pc_upproj = nn.Linear(24, latent_dim)
        # No device is forced here; move the whole module with .to()/.cuda().
        self.cross_attns = nn.ModuleList(
            [mlp_helpers.CrossAttn_(latent_dim, latent_dim) for _ in range(n_layer_transf)]
        )
        self.pc_net_mlps = torch.nn.ModuleList(
            [mlp_helpers.make_net([latent_dim, latent_dim, latent_dim]) for _ in range(n_layer_transf)]
        )
        self.global_latent_nets = torch.nn.ModuleList(
            [mlp_helpers.make_net([latent_dim, latent_dim, latent_dim]) for _ in range(n_layer_transf)]
        )
        # One xyz triple per timestep.
        self.traj_pred = mlp_helpers.make_net([latent_dim, latent_dim, 128, 3 * n_timesteps])

        # Not used by forward(); kept so the registered parameters stay the same.
        self.pc_net = nn.Sequential(
            *[mlp_helpers.make_net([latent_dim, latent_dim, latent_dim]) for _ in range(4)]
        )
        self.global_net = nn.Sequential(
            *[mlp_helpers.make_net([latent_dim, latent_dim, latent_dim]) for _ in range(4)]
        )

    def forward(self, model_input, track_idxs=None, out=None):
        """Predict the trajectory for a batch of point clouds.

        Args:
            model_input: mapping with ``"fps_canon_pc_pos"`` and
                ``"fps_canon_pc_rgb"`` (assumed ``(batch, points, 3)`` each
                — TODO confirm against the data loader).
            track_idxs: unused; kept for interface compatibility.
            out: optional dict merged into the result (not mutated).

        Returns:
            ``out`` extended with ``"pred_traj"``; the last axis of the MLP
            output is split into ``(timesteps, 3)``.
        """
        out = {} if out is None else out

        # Initial per-point feature: positionally encoded position plus rgb.
        pc_pos = model_input["fps_canon_pc_pos"]
        pc_rgb = model_input["fps_canon_pc_rgb"]
        pc_feat = self.pc_upproj(torch.cat((self.pos_enc(pc_pos), pc_rgb), -1))

        # Global latent token: embedding row 0 repeated per batch element.
        # Indexing the weight keeps this on whatever device the module is on.
        global_latent = self.global_emb_latent.weight[:1].expand(len(pc_pos), -1)[:, None]

        # The latent attends to the points, then both streams get a residual MLP.
        for pc_mlp, global_latent_mlp, attn in zip(
            self.pc_net_mlps, self.global_latent_nets, self.cross_attns
        ):
            global_latent = attn(pc_feat, global_latent)
            pc_feat = pc_mlp(pc_feat) + pc_feat
            global_latent = global_latent_mlp(global_latent) + global_latent

        # map latent to trajectory (very dumb, todo use implicit linear timestep query)
        pred = self.traj_pred(global_latent.squeeze(1)).unflatten(-1, (-1, 3))

        return out | {
            "pred_traj": pred,
        }

class CanonRobotTrajPredTransfBaseline(nn.Module):
    """Image-to-trajectory baseline: learned motion tokens cross-attend to image tokens.

    An image encoder produces a feature map that is flattened into tokens.
    One learned token per timestep attends to those tokens through a stack of
    cross-attention layers, and each token is decoded into a 3D point.
    """

    def __init__(self, args=None):
        super().__init__()
        self.args = args

        # baseline model
        self.img_enc = ResnetFPN(in_ch=3, use_first_pool=True)

        latent_dim = 256
        n_timesteps = 20
        # One learned query token per predicted timestep.
        self.global_emb_latent = nn.Embedding(n_timesteps, latent_dim)

        n_layer_transf = 6
        self.cross_attns = nn.ModuleList(
            [mlp_helpers.CrossAttn_(latent_dim, latent_dim) for _ in range(n_layer_transf)]
        )

        self.traj_decode = mlp_helpers.make_net([latent_dim, latent_dim, latent_dim, latent_dim, 3])

    def forward(self, model_input, track_idxs=None, out=None):
        """Predict a trajectory from the first image of the selected input stream.

        Args:
            model_input: mapping holding ``"canon_pc_filtered_img"`` or, when
                ``args.noncanon`` is truthy, ``"start_img_img"``; index 0
                along axis 1 is used.
            track_idxs: unused; kept for interface compatibility.
            out: optional dict merged into the result (not mutated).

        Returns:
            ``out`` extended with ``"pred_traj"`` (one 3-vector per motion token).
        """
        out = {} if out is None else out

        # `args` defaults to None, so fall back to the canonical input when no
        # `noncanon` flag is available instead of raising AttributeError.
        use_noncanon = getattr(self.args, "noncanon", False)
        img_key = "start_img_img" if use_noncanon else "canon_pc_filtered_img"

        # Image encoder to produce tokens (around 15x20 resolution);
        # layout change "... c x y -> ... (x y) c".
        feat_map = self.img_enc(model_input[img_key][:, 0])
        img_tokens = feat_map.flatten(-2).transpose(-1, -2)

        # Motion tokens that will be decoded into actions which attend to the image tokens
        motion_tokens = self.global_emb_latent.weight[None].expand(len(img_tokens), -1, -1)

        # Motion tokens attend to the image tokens
        for attn in self.cross_attns:
            motion_tokens = attn(img_tokens, motion_tokens)

        # Decode motion tokens into 3d gripper trajectory
        pred_traj = self.traj_decode(motion_tokens)

        return out | {
            "pred_traj": pred_traj,
        }
class CanonRobotTrajPredResnetBaseline(nn.Module):
    """Image-to-trajectory baseline: ResNet-18 features regressed by an MLP.

    The torchvision ResNet-18 without its final child module produces a
    512-channel feature per image; an MLP maps it to 20 xyz triples.
    """

    def __init__(self, args=None):
        super().__init__()
        self.args = args

        # baseline model
        import torchvision.models as models
        resnet = models.resnet18(pretrained=True)
        # Drop the last child module and keep everything before it.
        self.resnet = nn.Sequential(*list(resnet.children())[:-1])

        # `make_net` lives in `mlp_helpers`; the bare name is not defined in
        # this module and raised NameError at construction time.
        self.traj_pred = mlp_helpers.make_net([512, 256, 128, 3 * 20])

    def forward(self, model_input, track_idxs=None, out=None):
        """Predict a trajectory from the first image of the selected input stream.

        Args:
            model_input: mapping holding ``"canon_pc_filtered_img"`` or, when
                ``args.noncanon`` is truthy, ``"start_img_img"``; index 0
                along axis 1 is used.
            track_idxs: unused; kept for interface compatibility.
            out: optional dict merged into the result (not mutated).

        Returns:
            ``out`` extended with ``"pred_traj"``; the last axis of the MLP
            output is split into ``(timesteps, 3)``.
        """
        out = {} if out is None else out

        # `args` defaults to None, so fall back to the canonical input when no
        # `noncanon` flag is available instead of raising AttributeError.
        use_noncanon = getattr(self.args, "noncanon", False)
        img_key = "start_img_img" if use_noncanon else "canon_pc_filtered_img"

        # Remove the two trailing singleton spatial axes left by the pooling.
        latent = self.resnet(model_input[img_key][:, 0]).squeeze(-1).squeeze(-1)
        pred = self.traj_pred(latent).unflatten(-1, (-1, 3))

        return out | {
            "pred_traj": pred,
        }

class ResnetFPN(nn.Module): # from pixelnerf code
    """torchvision ResNet feature extractor.

    ``forward`` returns only the deepest enabled stage's feature map (used as
    transformer tokens by callers). ``forward_`` keeps the multi-scale
    variant: every stage output is resized to the first stage's resolution,
    concatenated along channels and passed through a 1x1 convolution.

    Inputs with more than four axes have their first two axes merged for the
    backbone and restored on the output.
    """

    # Interpolation modes for which F.interpolate accepts ``align_corners``.
    _ALIGNABLE_MODES = ("linear", "bilinear", "bicubic", "trilinear")

    def __init__(
        self,
        backbone="resnet34",
        pretrained=True,
        num_layers=4,
        index_interp="bilinear",
        index_padding="border",
        upsample_interp="bilinear",
        feature_scale=1.0,
        use_first_pool=True,
        norm_type="batch",
        in_ch=3,
    ):
        """Build the backbone and auxiliary layers.

        Args:
            backbone: name of a constructor in ``torchvision.models``.
            pretrained: forwarded to the torchvision constructor.
            num_layers: how many stages to run (stem counts as the first).
            index_interp, index_padding: stored on the instance only.
            upsample_interp: interpolation mode used by ``forward_``.
            feature_scale: input images are resized by this factor when != 1.
            use_first_pool: apply the backbone max-pool before ``layer1``.
            norm_type: ``batch`` | ``instance`` | ``group`` | ``none``.
            in_ch: input channels; ``conv1`` is replaced when this is not 3.

        Raises:
            NotImplementedError: for an unknown ``norm_type``.
        """
        super().__init__()

        def get_norm_layer(norm_type="instance", group_norm_groups=32):
            """Return a normalization layer
            Parameters:
                norm_type (str) -- the name of the normalization layer: batch | instance | group | none
            For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
            For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
            For GroupNorm, the number of groups is fixed to `group_norm_groups`.
            """
            if norm_type == "batch":
                norm_layer = functools.partial(
                    nn.BatchNorm2d, affine=True, track_running_stats=True
                )
            elif norm_type == "instance":
                norm_layer = functools.partial(
                    nn.InstanceNorm2d, affine=False, track_running_stats=False
                )
            elif norm_type == "group":
                norm_layer = functools.partial(nn.GroupNorm, group_norm_groups)
            elif norm_type == "none":
                norm_layer = None
            else:
                raise NotImplementedError(
                    "normalization layer [%s] is not found" % norm_type
                )
            return norm_layer

        self.feature_scale = feature_scale
        self.use_first_pool = use_first_pool
        norm_layer = get_norm_layer(norm_type)

        print("Using torchvision", backbone, "encoder")
        self.model = getattr(torchvision.models, backbone)(pretrained=pretrained, norm_layer=norm_layer)

        if in_ch != 3:
            # Same geometry as the stock stem conv, different input width
            # (the replacement layer is freshly initialised).
            self.model.conv1 = nn.Conv2d(
                in_ch,
                self.model.conv1.weight.shape[0],
                self.model.conv1.kernel_size,
                self.model.conv1.stride,
                self.model.conv1.padding,
                padding_mode=self.model.conv1.padding_mode,
            )

        # Following 2 lines need to be uncommented for older configs
        self.model.fc = nn.Sequential()
        self.model.avgpool = nn.Sequential()
        # Channel count of the concatenated latent built in forward_()
        # (presumably sized for resnet18/34 stage widths — confirm for other backbones).
        self.latent_size = [0, 64, 128, 256, 512, 1024][num_layers]

        self.num_layers = num_layers
        self.index_interp = index_interp
        self.index_padding = index_padding
        self.upsample_interp = upsample_interp
        self.register_buffer("latent", torch.empty(1, 1, 1, 1), persistent=False)
        self.register_buffer(
            "latent_scaling", torch.empty(2, dtype=torch.float32), persistent=False
        )

        self.out = nn.Sequential(
            nn.Conv2d(self.latent_size, 512, 1),
        )

        self.out_dim = 64  # todo make arg
        # These layers are not used by forward()/forward_(). They stay
        # registered so the module's parameters and state_dict keys do not
        # change. No device is forced here; move the module with .to()/.cuda().
        self.combs_1 = nn.ModuleList(
            [nn.Conv2d(d1, d2, 1) for d1, d2 in [(256, 128), (128, 64), (64, 64), (64, 64), (64, self.out_dim)]]
        )
        self.combs_2 = nn.ModuleList(
            [nn.Conv2d(d, d, 1) for d in [128, 64, 64, 64, self.out_dim]]
        )
        self.last_conv_up = nn.Conv2d(in_ch, self.out_dim, 1)

    def _rescale_input(self, x):
        """Resize the input by ``feature_scale`` (no-op when it equals 1.0)."""
        if self.feature_scale == 1.0:
            return x
        enlarging = self.feature_scale > 1.0
        return F.interpolate(
            x,
            scale_factor=self.feature_scale,
            mode="bilinear" if enlarging else "area",
            align_corners=True if enlarging else None,
            recompute_scale_factor=True,
        )

    def _residual_stages(self, x):
        """Run layer1..layer4 as enabled by ``num_layers``; return each stage output."""
        feats = []
        if self.num_layers > 1:
            if self.use_first_pool:
                x = self.model.maxpool(x)
            x = self.model.layer1(x)
            feats.append(x)
        for threshold, stage in ((2, "layer2"), (3, "layer3"), (4, "layer4")):
            if self.num_layers > threshold:
                x = getattr(self.model, stage)(x)
                feats.append(x)
        return feats

    def forward(self, x, custom_size=None):
        """Return the feature map of the deepest enabled stage.

        Args:
            x: image tensor; extra leading axis is merged into the batch.
            custom_size: unused here; only forwarded on the recursive call.
        """
        if len(x.shape) > 4:
            return self(x.flatten(0, 1), custom_size).unflatten(0, x.shape[:2])

        x = self._rescale_input(x)

        x = self.model.conv1(x)
        # NOTE(review): this pool runs before bn1/relu and in addition to the
        # `use_first_pool` pool below, unlike forward_(). Kept as-is because
        # changing it would alter the output resolution — confirm intent.
        x = self.model.maxpool(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)

        # Overriding here to just use last layer's features instead of all since doing transformer instead of pixel-aligned task
        stage_feats = self._residual_stages(x)
        return stage_feats[-1] if stage_feats else x

    def forward_(self, x, custom_size=None):
        """Return the multi-scale latent: all stage outputs resized, concatenated, projected.

        Also stores the concatenated map in ``self.latent`` and updates
        ``self.latent_scaling`` from its width/height.

        Args:
            x: image tensor; extra leading axis is merged into the batch.
            custom_size: target spatial size; defaults to the stem output size.
        """
        if len(x.shape) > 4:
            # Recurse into forward_ explicitly: calling `self(...)` would
            # dispatch to forward() and return last-stage features instead.
            return self.forward_(x.flatten(0, 1), custom_size).unflatten(0, x.shape[:2])

        x = self._rescale_input(x)

        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)

        latents = [x] + self._residual_stages(x)

        # F.interpolate only accepts align_corners for the linear-family
        # modes, so derive it from the mode that is actually used below.
        align_corners = True if self.upsample_interp in self._ALIGNABLE_MODES else None
        target_size = latents[0].shape[-2:] if custom_size is None else custom_size
        latents = [
            F.interpolate(
                feat,
                target_size,
                mode=self.upsample_interp,
                align_corners=align_corners,
            )
            for feat in latents
        ]
        self.latent = torch.cat(latents, dim=1)
        self.latent_scaling[0] = self.latent.shape[-1]
        self.latent_scaling[1] = self.latent.shape[-2]
        self.latent_scaling = self.latent_scaling / (self.latent_scaling - 1) * 2.0
        return self.out(self.latent)
