import torch
import torch.nn as nn
import numpy as np
from torch.nn import functional as F

from einops import rearrange, repeat
from torch import einsum

# Tensor layout / resampling helpers
def ch_sec(x):
    """Flatten the two trailing spatial axes and move channels last.

    Layout change: ``... c x y -> ... (x y) c``.
    """
    return rearrange(x, "... c x y -> ... (x y) c")


def ch_fst(src, x=None):
    """Undo ``ch_sec``: ``... (x y) c -> ... c x y``.

    ``x`` is the length of the first spatial axis. When it is omitted the
    token axis is assumed to describe a square grid and ``x`` is taken as
    ``int(sqrt(number_of_tokens))``.
    """
    if x is None:
        x = int(src.size(-2) ** (.5))
    return rearrange(src, "... (x y) c -> ... c x y", x=x)


def scale_up(x):
    """Bilinearly upsample ``x`` by a factor of 2 via ``F.interpolate``."""
    return F.interpolate(x, scale_factor=2, mode="bilinear")

def make_net(dims):
    """Build a fully connected ReLU network.

    Args:
        dims: sequence of layer widths ``[in, hidden..., out]``. Each
            consecutive pair becomes one ``nn.Linear``; a ``nn.ReLU`` is
            placed between linear layers but not after the last one.
            Fewer than two entries yields an empty ``nn.Sequential``.

    Returns:
        ``nn.Sequential`` whose linear weights are Kaiming-normal initialised
        (fan-in, ReLU gain). Biases keep the default ``nn.Linear`` init.
    """
    def init_weights_normal(m):
        # isinstance (rather than an exact type comparison) also covers
        # subclasses of nn.Linear; every nn.Linear has a weight attribute.
        if isinstance(m, nn.Linear):
            nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_in')

    layers = []
    for fan_in, fan_out in zip(dims[:-1], dims[1:]):
        layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]

    # Drop the trailing ReLU so the output layer stays linear.
    net = nn.Sequential(*layers[:-1])
    # Initialise only after all layers exist so the RNG consumption order
    # (and therefore seeded weights) matches building-then-initialising.
    net.apply(init_weights_normal)
    return net

def sinusoidal_embedding(n, d):
    """Return an ``(n, d)`` table of sinusoidal position embeddings.

    Row ``p`` holds ``sin(p * w_i)`` in even columns and ``cos(p * w_i)`` in
    odd columns, where ``w_i = 1 / 10000 ** (2 * j / d)`` for ``j = 2 * i``.

    NOTE: because the frequency is indexed by the even *column* ``j`` rather
    than the pair index ``i``, the exponent is ``4 * i / d``; this differs
    from the textbook ``2 * i / d`` schedule. It is kept as-is so existing
    embeddings are unchanged.

    Args:
        n: number of positions (rows).
        d: embedding width (columns); odd widths are supported, the final
           column then holds a sine term without a cosine partner.
    """
    # One frequency per (sin, cos) column pair, taken at the even column index.
    freqs = torch.tensor([1 / 10_000 ** (2 * j / d) for j in range(0, d, 2)]).unsqueeze(0)
    positions = torch.arange(n).reshape((n, 1))
    angles = positions * freqs  # (n, ceil(d / 2))

    embedding = torch.zeros(n, d)
    embedding[:, ::2] = torch.sin(angles)
    # For odd d there is one fewer cosine column than sine columns.
    embedding[:, 1::2] = torch.cos(angles[:, :d // 2])

    return embedding

class ResConvBlock(nn.Module):
    """Two 3x3 SiLU-activated convolutions plus a 3x3 convolutional shortcut.

    All convolutions map ``ch -> ch`` with stride 1 and padding 1, so the
    spatial size and channel count of the input are preserved.
    """

    def __init__(self, ch):
        super().__init__()
        same_conv = dict(kernel_size=3, stride=1, padding=1)
        self.conv1 = nn.Conv2d(ch, ch, **same_conv)
        self.conv2 = nn.Conv2d(ch, ch, **same_conv)
        # The shortcut is itself a learned 3x3 conv, not an identity.
        self.skip_connection = nn.Conv2d(ch, ch, **same_conv)

    def forward(self, x):
        """Return ``silu(conv2(silu(conv1(x)))) + skip_connection(x)``."""
        hidden = F.silu(self.conv1(x))
        hidden = F.silu(self.conv2(hidden))
        shortcut = self.skip_connection(x)
        return hidden + shortcut

class CrossAttn_(nn.Module):
    """Multi-head cross-attention followed by a GELU MLP, both with residuals.

    Queries are computed from ``y`` and keys/values from ``x`` (e.g. ``x`` is
    the image patches and ``y`` is the cls tokens). Both inputs are normalised
    with the same ``ln_1``; ``ln_2`` normalises the input of the MLP.

    Args:
        ch: feature width of ``x`` and ``y`` (last axis).
        heads: number of attention heads.
        dim_head: width of each head; attention logits are scaled by
            ``dim_head ** -0.5``.
    """

    def __init__(self, ch, heads=8, dim_head=64):
        super().__init__()
        inner_dim = dim_head * heads
        self.scale = dim_head ** -0.5
        self.heads = heads
        self.ch = ch

        self.to_q = nn.Linear(ch, inner_dim, bias=False)
        self.to_kv = nn.Linear(ch, inner_dim * 2, bias=False)
        self.proj = nn.Linear(inner_dim, ch)

        self.out = nn.Sequential(
            nn.Linear(ch, int(4*ch)),
            nn.GELU(),
            nn.Linear(int(4*ch), ch)
        )

        self.ln_1 = nn.LayerNorm([self.ch])
        self.ln_2 = nn.LayerNorm([self.ch])

    def _split_heads(self, t):
        """Reshape ``b n (h d) -> (b h) n d``."""
        b, n, width = t.shape
        h = self.heads
        return t.reshape(b, n, h, width // h).permute(0, 2, 1, 3).reshape(b * h, n, width // h)

    def _unmerge_heads(self, t):
        """Reshape ``(b h) n d -> b n h d``."""
        bh, n, d = t.shape
        h = self.heads
        return t.reshape(bh // h, h, n, d).permute(0, 2, 1, 3)

    # x is the image patches and y is the cls tokens, for ex.
    def forward(self, x, y, attn_mask=None,return_heads=False, return_attn=False,softmax_axis=-1):
        """Attend from ``y`` (queries) to ``x`` (keys/values).

        Args:
            x: key/value tokens, ``(..., n_x, ch)``.
            y: query tokens, ``(..., n_y, ch)``. When more than one leading
                batch axis is present, ``y`` is flattened the same way as
                ``x`` and results are unflattened with ``x``'s leading shape.
            attn_mask: accepted for interface compatibility but currently
                NOT applied anywhere in the computation.
            return_heads: if True, return the per-head attention output
                ``(..., n_y, heads, dim_head)`` before projection/residual/MLP.
            return_attn: if True, return ``(out, attn)`` where ``attn`` is
                ``(..., n_y, heads, n_x)``.
            softmax_axis: axis of the ``(batch*heads, n_y, n_x)`` logits the
                softmax is taken over (``-1`` normalises over keys).

        Returns:
            ``out`` of shape ``(..., n_y, ch)``, or the alternatives above.
        """
        if x.dim() > 3:
            # Collapse all leading axes into one batch axis, run the 3-D
            # path, then restore the leading axes on every returned tensor.
            lead_shape = x.shape[:-2]
            result = self(x.flatten(0,-3),y.flatten(0,-3),attn_mask,return_heads,return_attn,softmax_axis)
            if isinstance(result, tuple):
                # return_attn=True yields (out, attn); unflatten each one.
                return tuple(r.unflatten(0, lead_shape) for r in result)
            return result.unflatten(0, lead_shape)

        x_ln = self.ln_1(x)
        y_ln = self.ln_1(y)

        q = self._split_heads(self.to_q(y_ln))
        k, v = (self._split_heads(t) for t in self.to_kv(x_ln).chunk(2, dim=-1))

        # attention, what we cannot get enough of
        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
        attn = sim.softmax(dim=softmax_axis)

        out = einsum('b i j, b j d -> b i d', attn, v)

        per_head = self._unmerge_heads(out)  # b n h d
        if return_heads: return per_head

        out = per_head.flatten(2)  # b n (h d)

        out = self.proj(out) + y
        out = self.out(self.ln_2(out)) + out

        return (out, self._unmerge_heads(attn)) if return_attn else out

class MySceneRep(nn.Module):
    """Time-conditioned conv encoder/decoder with an attention bottleneck.

    The encoder applies three stride-2 convolutions (channels
    ``in_ch -> 64 -> 128 -> 256``), adding a projected time embedding after
    each one. The lowest-resolution features are processed as tokens by four
    attention blocks together with 8 learned register tokens, then decoded by
    repeated bilinear upsampling, 1x1 channel reduction and additive skip
    connections back to ``in_ch`` channels.

    Args:
        n_steps: number of rows in the (frozen) sinusoidal time embedding.
        time_emb_dim: width of the time embedding.
        imsl: image side length. NOTE: currently not used to size any layer.
        in_ch: number of input (and output) channels.

    NOTE: ``spatial_emb`` has a fixed ``8 * 8 = 64`` entries, so the
    bottleneck feature map must contain 64 positions for the addition in
    ``forward`` to broadcast (e.g. 64x64 inputs after three halvings).
    """

    def __init__(self, n_steps=1000, time_emb_dim=100,imsl=28,in_ch=1):
        super(MySceneRep, self).__init__()

        dims = [in_ch,64,128,256]

        # Sinusoidal time embedding (weights overwritten and frozen)
        self.time_embed = nn.Embedding(n_steps, time_emb_dim)
        self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
        self.time_embed.requires_grad_(False)

        # Downconvs
        self.down_convs      = nn.ModuleList([ nn.Conv2d(dims[d_i], dims[d_i+1], 4, 2, 1) 
                                                                    for d_i in range(len(dims)-1)])
        self.conv_blocks_enc = nn.ModuleList([ nn.Sequential(*[ResConvBlock(dims[d_i+1]) for _ in range(2)]) 
                                                                    for d_i in range(len(dims)-1)])
        self.time_encs       = nn.ModuleList([ nn.Sequential(nn.Linear(time_emb_dim, dims[d_i+1]), nn.SiLU(), nn.Linear(dims[d_i+1], dims[d_i+1])) 
                                                                    for d_i in range(len(dims)-1)])

        # Upconvs
        # 1x1 channel reductions, ordered deepest-first: 256->256, 256->128, 128->64, 64->in_ch
        self.downproj_dec    = nn.ModuleList([ nn.Conv2d((dims+dims[-1:])[d_i],(dims+dims[-1:])[d_i-1],1) for d_i in range(1,len(dims)+1) ][::-1])
        self.skip_lins       = nn.ModuleList([ nn.Conv2d(d,d,1) for d in dims[::-1] ])
        # NOTE(review): these blocks are constructed (shallowest-first) but are
        # not applied anywhere in forward(); kept so parameter names are unchanged.
        self.conv_blocks_dec = nn.ModuleList([ nn.Sequential(*[ResConvBlock(d) for _ in range(2)]) for d in dims])

        # Transformer self attention with registers processing after convs.
        # No device is forced here: submodules follow the parent module's
        # .to(device) / .cuda() like every other layer, which also keeps the
        # model constructible on CPU-only hosts.
        self.self_attns = nn.ModuleList([CrossAttn_(dims[-1],4,dims[-1]//2) for _ in range(4)])
        self.scene_tokens = nn.Embedding(8, dims[-1]) #learnable register embeddings
        self.spatial_emb = nn.Embedding(8*8, dims[-1]) # learnable spatial embedding

    def forward(self, x, t, y=None):
        """Encode ``x`` conditioned on step ``t`` and decode to input resolution.

        Args:
            x: input images; assumed ``(B, in_ch, H, W)`` with an 8x8
               bottleneck after three halvings -- TODO confirm with callers.
            t: integer indices into the time embedding table; assumed
               shape ``(B,)``.
            y: unused.

        Returns:
            Decoded feature map with ``in_ch`` channels at the resolution of ``x``.
        """

        time_emb = self.time_embed(t) # sinusoidal time embedding 

        # Downconvs encoding - downconv, add time info, then spatial-preserving conv blocks
        xs_down=[x]
        for downconv,conv_enc,time_enc in zip(self.down_convs,self.conv_blocks_enc,self.time_encs):
            xs_down.append( conv_enc( downconv(xs_down[-1])+time_enc(time_emb)[...,None,None] ) )

        # Low latent space thinking - self attns with registers
        mid_feats = ch_sec(xs_down[-1]) + self.spatial_emb.weight[None]
        mid_feats = torch.cat((mid_feats,self.scene_tokens.weight[None].expand(len(mid_feats),-1,-1)),1) # add registers
        for self_attn in self.self_attns: mid_feats = self_attn(mid_feats,mid_feats)
        # Drop the register tokens and restore the spatial grid. ch_sec flattens
        # (height, width) in that order, so the first grid axis is the height.
        mid_feats = ch_fst(mid_feats[:,:-self.scene_tokens.weight.size(0)], xs_down[-1].size(-2))

        # Design notes (not implemented):
        # - consider doing everything with convs
        # - render out 32x32 res with coordinate query attending to spatially corresponding
        #   feature at 16x16 res and 8x8 res and N registers (all downprojected to 32x32 feature dim)
        # - then standard skip-connect conv decode from 32x32 (popping encoder stack) to full res with downproj
        # - return input img - rendered img
        # - start with not using any transformer -- just use 32x32 res features to start

        # Upconvs skip-connection decoding - bilinearly upsample, reduce channels
        # with a 1x1 conv, then add the projected encoder feature of that level.
        curr_feat = mid_feats
        for i,(skip_feat,skip_lin,downproj) in enumerate(zip(xs_down[::-1],self.skip_lins,self.downproj_dec)):
            # The deepest level is already at bottleneck resolution: no upsampling there.
            if i: curr_feat = downproj(scale_up(curr_feat))
            curr_feat = curr_feat + skip_lin(skip_feat)

        return curr_feat

