import torch
import torch.nn as nn
import numpy as np
import attn_modules
from torch.nn import functional as F

from einops import rearrange, repeat
import conv_modules

# Layout helpers: move the channel axis between image layout (..., C, H, W)
# and token layout (..., H*W, C).
def ch_sec(x):
    """Flatten the two trailing spatial axes into one token axis, channels last.

    ``(..., c, x, y) -> (..., x*y, c)``
    """
    return rearrange(x, "... c x y -> ... (x y) c")


def ch_fst(src, x=None):
    """Inverse of ``ch_sec``: unflatten tokens into ``x`` rows, channels first.

    ``(..., x*y, c) -> (..., c, x, y)``. When ``x`` is None the token grid is
    assumed square and ``x = int(sqrt(num_tokens))``.
    """
    rows = int(src.size(-2) ** (.5)) if x is None else x
    return rearrange(src, "... (x y) c -> ... c x y", x=rows)


def make_net(dims):
    """Build an MLP of ``nn.Linear`` layers separated by ``nn.ReLU``.

    Args:
        dims: layer widths ``[in, hidden..., out]``; ``len(dims) - 1`` Linear
            layers are created. No activation follows the last Linear.

    Returns:
        ``nn.Sequential`` whose Linear weights are Kaiming-normal initialised
        (``mode='fan_in'``, ReLU gain); biases keep PyTorch's default init.
        With fewer than two entries in ``dims`` the Sequential is empty
        (acts as identity).
    """
    def init_weights_normal(m):
        # isinstance (not `type(m) == ...`) so Linear subclasses are covered too.
        if isinstance(m, nn.Linear):
            nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_in')

    layers = []
    for d_in, d_out in zip(dims[:-1], dims[1:]):
        layers.append(nn.Linear(d_in, d_out))
        layers.append(nn.ReLU())
    # Drop the trailing ReLU so the network's output is unconstrained.
    net = nn.Sequential(*layers[:-1])
    net.apply(init_weights_normal)
    return net

# DDPM class
class MyDDPM(nn.Module):
    """Linear-beta noise schedule wrapped around a denoising ``network``.

    ``forward`` noises clean images directly to a given step, ``backward``
    queries the network, and ``generate_new_images`` performs a single
    denoising step starting from pure noise.

    Args:
        network: callable ``network(x, t)``; expected to return its estimate of
            the noise added to ``x`` (a tensor broadcastable against ``x``).
        n_steps: number of diffusion steps.
        min_beta, max_beta: endpoints of the linear beta schedule.
        device: device for the schedule tensors. ``None`` picks CUDA when it is
            available and CPU otherwise (the schedule used to be forced onto CUDA).
        imsl: image side length; recorded in ``image_chw`` as ``(1, imsl, imsl)``.
    """

    def __init__(self, network, n_steps=200, min_beta=10 ** -4, max_beta=0.02, device=None, imsl=28):
        super(MyDDPM, self).__init__()
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        image_chw = (1, imsl, imsl)
        self.n_steps = n_steps
        self.device = device
        self.image_chw = image_chw
        self.network = network

        # Number of steps is typically in the order of thousands.
        betas = torch.linspace(min_beta, max_beta, n_steps).to(device)
        alphas = 1 - betas
        # Non-persistent buffers: they follow .to()/.cuda() with the module but are
        # kept out of state_dict, so existing checkpoints load unchanged.
        self.register_buffer("betas", betas, persistent=False)
        self.register_buffer("alphas", alphas, persistent=False)
        # alpha_bars[i] = prod(alphas[:i + 1]).
        self.register_buffer("alpha_bars", torch.cumprod(alphas, dim=0), persistent=False)

    def forward(self, x0, t, eta=None):
        """Noise ``x0`` directly to step ``t``.

        NOTE: unlike the textbook DDPM q(x_t | x_0), ``x0`` is NOT scaled by
        ``sqrt(alpha_bar)``; the result is ``x0 + sqrt(1 - alpha_bar_t) * eta``
        (the textbook form is kept below as a comment).

        Args:
            x0: 4-D batch ``(n, c, h, w)``.
            t: step indices into the schedule with ``n`` elements, e.g. ``(n, 1)``.
            eta: optional noise shaped like ``x0``; standard normal if omitted.

        Returns:
            ``(noisy, scaled_eta)`` where ``scaled_eta`` is the noise actually added.
        """
        n, c, h, w = x0.shape
        a_bar = self.alpha_bars[t]

        if eta is None:
            # Sample on the input's device so it matches x0 (self.device may differ).
            eta = torch.randn(n, c, h, w, device=x0.device)

        # noisy = a_bar.sqrt().reshape(n, 1, 1, 1) * x0 + (1 - a_bar).sqrt().reshape(n, 1, 1, 1) * eta
        scaled_eta = (1 - a_bar).sqrt().reshape(n, 1, 1, 1) * eta
        noisy = x0 + scaled_eta
        return noisy, scaled_eta

    def backward(self, x, t):
        """Return the network's estimate of the noise in ``x`` at steps ``t``."""
        return self.network(x, t)

    def generate_new_images(self, n_samples=16, device=None, frames_per_gif=100, gif_name="sampling.gif", c=1, imsl=28, just_last=False):
        """Produce ``n_samples`` images with ONE denoising step from pure noise.

        Draws standard-normal ``x`` of shape ``(n_samples, c, imsl, imsl)``, asks the
        network for its noise estimate at the last step ``t = n_steps - 1`` and
        returns ``x - noise_estimate``. Runs under ``torch.no_grad()``.

        ``frames_per_gif``, ``gif_name`` and ``just_last`` are accepted for call-site
        compatibility but unused: no GIF is written and no frames are collected.

        Args:
            device: where to sample; defaults to the device of the noise schedule.
        """
        if device is None:
            device = self.alpha_bars.device
        t = self.n_steps - 1

        with torch.no_grad():
            # Starting from random noise.
            x = torch.randn(n_samples, c, imsl, imsl, device=device)
            time_tensor = torch.full((n_samples, 1), t, dtype=torch.long, device=device)
            noise_est = self.backward(x, time_tensor)
            # x = (1 / alpha_t.sqrt()) * (x - (1 - alpha_t) / (1 - alpha_t_bar).sqrt() * eta_theta)
            return x - noise_est

def sinusoidal_embedding(n, d):
    """Return an ``(n, d)`` float32 table of sinusoidal time embeddings.

    Row ``t`` holds ``sin(t * w_k)`` in the even columns and ``cos(t * w_k)`` in
    the odd columns, where the k-th sin/cos pair uses ``w_k = 1 / 10000 ** (4k / d)``.

    NOTE(review): this frequency schedule decays twice as fast as the Transformer
    encoding's ``1 / 10000 ** (2k / d)``. It is kept unchanged so the produced
    embeddings stay identical. Odd ``d`` is supported: the final column is a sin.
    """
    embedding = torch.zeros(n, d)
    # One frequency per column index j (only even j are used below); computed in
    # float64 like the original Python-float arithmetic, then cast to float32.
    j = torch.arange(d, dtype=torch.float64)
    wk = (1 / 10_000 ** (2 * j / d)).float().reshape(1, d)
    pair_freqs = wk[:, ::2]  # ceil(d / 2) frequencies, one per sin column
    t = torch.arange(n).reshape(n, 1)
    embedding[:, ::2] = torch.sin(t * pair_freqs)
    # With odd d there is one fewer cos column than sin columns.
    embedding[:, 1::2] = torch.cos(t * pair_freqs[:, : d // 2])

    return embedding

class MySceneRep(nn.Module):
    """U-Net style predictor with a self-attention bottleneck and register tokens.

    Encoder: three stages of ``MyBlock``s (channels ``in_ch -> 64 -> 128 -> 256``);
    stages 1 and 2 are preceded by a stride-2 conv. Bottleneck: one more stride-2
    conv, the ``imsl//8 x imsl//8`` map is flattened to tokens, a learned spatial
    embedding is added, 8 learned register tokens are appended, and four
    ``attn_modules.CrossAttn_`` layers are applied with the tokens as both inputs.
    Decoder: three stages that bilinearly upsample, concatenate the matching encoder
    feature and apply ``MyBlock``s. An MLP of the sinusoidal timestep embedding is
    added at every stage. Output goes through ``tanh``.

    Args:
        n_steps: number of timesteps (rows of the fixed sinusoidal table).
        time_emb_dim: width of the sinusoidal timestep embedding.
        imsl: input side length; inputs are square ``imsl x imsl``.
        in_ch: number of input and output channels.
    """

    def __init__(self, n_steps=1000, time_emb_dim=100, imsl=28, in_ch=1):
        super(MySceneRep, self).__init__()

        res = (imsl, imsl)

        # Fixed (non-trainable) sinusoidal timestep table.
        self.time_embed = nn.Embedding(n_steps, time_emb_dim)
        self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
        self.time_embed.requires_grad_(False)

        # Encoder: stage i has three MyBlocks whose LayerNorm shape uses resolution
        # imsl // 2**i; the first block of a stage maps dims[i] -> dims[i + 1] channels.
        # dims = [in_ch, 10, 20, 40]
        dims = [in_ch, 64, 128, 256]
        self.conv_blocks_down = nn.ModuleList([
            nn.Sequential(*[
                MyBlock((dims[i + int(j > 0)], res[0] // 2 ** i, res[1] // 2 ** i), dims[i + int(j > 0)], dims[i + 1])
                for j in range(3)
            ])
            for i in range(3)
        ])
        # down_convs[0] is a no-op (stage 0 keeps full resolution); [1] and [2] halve
        # the resolution before stages 1 and 2; [3] halves again into the bottleneck.
        self.down_convs = nn.ModuleList([nn.Identity()] + [nn.Conv2d(dim, dim, 4, 2, 1) for dim in dims[1:4]])

        # One time MLP per injection point: widths `dims` for the 3 encoder stages and
        # the bottleneck, then doubled widths for the decoder stages. The final entry
        # (2 * in_ch) is not used by forward().
        self.time_embeddings = nn.ModuleList([nn.Sequential(nn.Linear(time_emb_dim, d), nn.SiLU(), nn.Linear(d, d)) for d in dims + [d * 2 for d in dims[::-1]]])

        # Bottleneck self-attention. No .cuda() here: the rest of the model is built on
        # CPU, so the caller moves the whole module to its device anyway.
        self.self_attns = nn.ModuleList([attn_modules.CrossAttn_(dims[-1], 4, dims[-1] // 2) for _ in range(4)])
        self.scene_tokens = nn.Embedding(8, dims[-1])  # learnable register embeddings
        self.spatial_emb = nn.Embedding((res[0] // 2 ** 3) * (res[1] // 2 ** 3), dims[-1])  # learnable spatial embedding

        # Decoder stages for i = 2, 1, 0; the first block of each stage consumes the
        # concatenation of the encoder feature and the upsampled previous output.
        dims_ = [dims[1]] + dims[1:]
        self.conv_blocks_up = nn.ModuleList([
            nn.Sequential(*[
                MyBlock((2 * dims_[i + int(j == 0)], res[0] // 2 ** i, res[1] // 2 ** i), dims_[i + int(j == 0)] * 2, dims_[i] * (2 if j < 2 else 1))
                for j in range(3)
            ])
            for i in [2, 1, 0]
        ])
        self.conv_out = nn.Conv2d(dims[1], in_ch, 3, 1, 1)

    def forward(self, x, t):
        """Run the network on images ``x`` at timesteps ``t``.

        Args:
            x: ``(n, in_ch, imsl, imsl)`` batch (shape required by the first LayerNorm).
            t: integer timestep indices with ``n`` elements, e.g. ``(n,)`` or ``(n, 1)``.

        Returns:
            ``(n, in_ch, imsl, imsl)`` tensor squashed into [-1, 1] by ``tanh``.
        """
        t = self.time_embed(t)
        n = len(x)

        # Encoder: downconv, add time info, then spatial-preserving conv blocks.
        xs_down = [x]
        for conv_block, down_conv, time_emb in zip(self.conv_blocks_down, self.down_convs, self.time_embeddings):
            xs_down.append(conv_block(down_conv(xs_down[-1]) + time_emb(t).reshape(n, -1, 1, 1)))

        # Bottleneck: self-attention over spatial tokens plus register tokens.
        n_regs = self.scene_tokens.weight.size(0)
        # reshape(n, 1, -1) broadcasts the time vector over all tokens for t of
        # shape (n,) as well as (n, 1).
        mid_feats = (ch_sec(self.down_convs[3](xs_down[-1]))
                     + self.time_embeddings[3](t).reshape(n, 1, -1)
                     + self.spatial_emb.weight[None])
        mid_feats = torch.cat((mid_feats, self.scene_tokens.weight[None].expand(n, -1, -1)), 1)  # add registers
        for self_attn in self.self_attns:
            mid_feats = self_attn(mid_feats, mid_feats)
        # Drop the registers and restore the (C, H, W) layout of the bottleneck.
        mid_feats = ch_fst(mid_feats[:, :-n_regs], xs_down[-1].size(-1) // 2)

        # Decoder: bilinearly upsample, concatenate with the skip, then conv blocks.
        curr_feat = mid_feats
        for conv_block, x_prev, time_emb in zip(self.conv_blocks_up, xs_down[::-1], self.time_embeddings[4:]):
            cat_feat = torch.cat((x_prev, F.interpolate(curr_feat, x_prev.shape[-2:], mode="bilinear")), dim=1)
            curr_feat = conv_block(cat_feat + time_emb(t).reshape(n, -1, 1, 1))

        return self.conv_out(curr_feat).tanh()

class MySceneRep_(nn.Module):
    """Experimental variant: conv encoder + token self-attention + per-pixel MLP decoder.

    The image (plus 8 broadcast time channels) goes through
    ``conv_modules.PixelNeRFEncoder``; its features are downsampled by 2, given a
    learned spatial embedding (a 256x256 table resized to the feature map),
    concatenated with 8 register tokens and passed through four
    ``attn_modules.CrossAttn_`` layers. Attention output and encoder features are
    fused per pixel by MLPs (``make_net``) into a 1-channel image.

    NOTE: ``__init__`` previously called ``super(MySceneRep, self)``, which raises
    TypeError for this unrelated class, so it could never be constructed.

    Args:
        n_steps: number of timesteps (rows of the fixed sinusoidal table).
        time_emb_dim: width of the sinusoidal timestep embedding.
    """

    def __init__(self, n_steps=1000, time_emb_dim=100):
        super(MySceneRep_, self).__init__()

        fdim = 128

        # Fixed (non-trainable) sinusoidal timestep table.
        self.time_embed = nn.Embedding(n_steps, time_emb_dim)
        self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
        self.time_embed.requires_grad_(False)

        # time_lins[0] -> 8 channels concatenated to the input image; time_lins[1:5]
        # are used around the attention layers; time_lins[5] is currently unused.
        self.time_lins = nn.ModuleList([nn.Linear(time_emb_dim, fdim if i else 8) for i in range(6)])
        self.conv_enc = conv_modules.PixelNeRFEncoder(in_ch=1 + 8, fdim_out=fdim)
        # No .cuda(): the caller moves the whole module to its device.
        self.self_attns = nn.ModuleList([attn_modules.CrossAttn_(fdim, 4, fdim // 2) for _ in range(4)])
        self.scene_tokens = nn.Embedding(8, fdim)  # learnable register embeddings
        self.down_convs = nn.ModuleList([nn.Conv2d(1 if i == 0 else fdim, fdim, 3, 2, padding=1) for i in range(2)])  # currently unused by forward()
        self.spatial_emb = nn.Embedding(256 * 256, fdim)  # learnable spatial embedding (256x256 grid)
        # Input width = attention features (fdim) + encoder features; assumes the
        # encoder also outputs fdim channels — TODO confirm against PixelNeRFEncoder.
        self.comb_feats = make_net([2 * fdim, 128, 64, 32])
        self.img_dec = make_net([32, 32, 32, 1])

    def forward(self, x, t):
        """Predict a 1-channel image for ``x`` at timesteps ``t``.

        Args:
            x: ``(n, 1, H, W)`` images — one channel, since ``conv_enc`` expects
                ``1 + 8`` input channels (image + 8 broadcast time channels).
            t: timestep indices such that ``time_embed(t).squeeze(1)`` is
                ``(n, time_emb_dim)``, e.g. shape ``(n, 1)``.

        Returns:
            ``(n, 1, H, W)`` tensor (no output nonlinearity).
        """
        t_emb = self.time_embed(t).squeeze(1)

        # Inject time as 8 extra full-resolution channels before the conv encoder.
        time_emb_hires = self.time_lins[0](t_emb)[..., None, None].expand(-1, -1, x.size(-2), x.size(-1))
        conv_feats = self.conv_enc(torch.cat((x, time_emb_hires), 1))

        # TODO: use a pretrained ResNet here (as in PixelNeRF) instead of raw RGB.

        # Transformer self-attention over half-resolution tokens plus registers.
        feats2 = F.interpolate(conv_feats, scale_factor=.5)
        pos_emb_low = F.interpolate(ch_fst(self.spatial_emb.weight)[None], feats2.shape[-2:])
        feats3 = torch.cat((ch_sec(feats2 + pos_emb_low + self.time_lins[1](t_emb)[..., None, None]),
                            self.scene_tokens.weight[None].expand(len(x), -1, -1)), 1)
        for i, attn in enumerate(self.self_attns):
            feats3 = attn(feats3, feats3) + self.time_lins[i + 1](t_emb)[:, None]
        feats3 = ch_fst(feats3[:, :-self.scene_tokens.weight.size(0)], feats2.size(-2))

        # Decoder: fuse upsampled attention features with encoder features per pixel.
        low_and_hi = ch_sec(torch.cat([F.interpolate(feats3, conv_feats.shape[-2:]), conv_feats], 1))
        # Unflatten using the feature-map HEIGHT. low_and_hi.size(-2) is the token
        # count H*W, which previously produced an (H*W, 1) map instead of (H, W).
        feats4 = ch_fst(self.comb_feats(low_and_hi), conv_feats.size(-2))
        dec = ch_fst(self.img_dec(ch_sec(F.interpolate(feats4, x.shape[-2:]))), x.size(-2))

        return dec

class MyBlock(nn.Module):
    """Double-conv block: [LayerNorm] -> conv -> act -> conv -> [act].

    Args:
        shape: ``normalized_shape`` for ``nn.LayerNorm``; must equal the trailing
            ``(C, H, W)`` of the input.
        in_c, out_c: input / output channels.
        kernel_size, stride, padding: forwarded to both convolutions.
        activation: activation module; defaults to ``nn.SiLU()``. The same module
            is used after both convolutions.
        normalize: apply the LayerNorm (it is created either way, so the set of
            parameters does not depend on this flag).
        last_act: default for applying the activation after the second conv.
    """

    def __init__(self, shape, in_c, out_c, kernel_size=3, stride=1, padding=1, activation=None, normalize=True, last_act=True):
        super(MyBlock, self).__init__()
        self.ln = nn.LayerNorm(shape)
        self.last_act = last_act
        self.conv1 = nn.Conv2d(in_c, out_c, kernel_size, stride, padding)
        self.conv2 = nn.Conv2d(out_c, out_c, kernel_size, stride, padding)
        self.activation = nn.SiLU() if activation is None else activation
        self.normalize = normalize

    def forward(self, x, last_act=None):
        """Apply the block to ``x``.

        Args:
            x: input whose trailing dims match ``shape``.
            last_act: per-call override for the final activation; ``None`` (the
                default) uses ``self.last_act``. This argument used to be accepted
                but silently ignored.
        """
        apply_last = self.last_act if last_act is None else last_act
        out = self.ln(x) if self.normalize else x
        out = self.activation(self.conv1(out))
        out = self.conv2(out)
        if apply_last:
            out = self.activation(out)
        return out

class MyUNet(nn.Module):
    """Small fixed-size U-Net for 1-channel 28x28 images with timestep conditioning.

    Three encoder stages (10, 20, 40 channels at 28, 14, 7 px), a 3x3 bottleneck,
    and three decoder stages with skip concatenation. A per-stage MLP of the fixed
    sinusoidal timestep embedding is broadcast-added to each stage's input.

    Args:
        n_steps: number of timesteps (rows of the fixed sinusoidal table).
        time_emb_dim: width of the sinusoidal timestep embedding.
    """

    def __init__(self, n_steps=1000, time_emb_dim=100):
        super(MyUNet, self).__init__()

        # Fixed (non-trainable) sinusoidal timestep table.
        self.time_embed = nn.Embedding(n_steps, time_emb_dim)
        self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
        self.time_embed.requires_grad_(False)

        # First half (encoder)
        self.te1 = self._make_te(time_emb_dim, 1)
        self.b1 = nn.Sequential(
            MyBlock((1, 28, 28), 1, 10),
            MyBlock((10, 28, 28), 10, 10),
            MyBlock((10, 28, 28), 10, 10)
        )
        self.down1 = nn.Conv2d(10, 10, 4, 2, 1)  # 28 -> 14

        self.te2 = self._make_te(time_emb_dim, 10)
        self.b2 = nn.Sequential(
            MyBlock((10, 14, 14), 10, 20),
            MyBlock((20, 14, 14), 20, 20),
            MyBlock((20, 14, 14), 20, 20)
        )
        self.down2 = nn.Conv2d(20, 20, 4, 2, 1)  # 14 -> 7

        self.te3 = self._make_te(time_emb_dim, 20)
        self.b3 = nn.Sequential(
            MyBlock((20, 7, 7), 20, 40),
            MyBlock((40, 7, 7), 40, 40),
            MyBlock((40, 7, 7), 40, 40)
        )
        self.down3 = nn.Sequential(
            nn.Conv2d(40, 40, 2, 1),  # 7 -> 6
            nn.SiLU(),
            nn.Conv2d(40, 40, 4, 2, 1)  # 6 -> 3
        )

        # Bottleneck
        self.te_mid = self._make_te(time_emb_dim, 40)
        self.b_mid = nn.Sequential(
            MyBlock((40, 3, 3), 40, 20),
            MyBlock((20, 3, 3), 20, 20),
            MyBlock((20, 3, 3), 20, 40)
        )

        # Second half (decoder)
        self.up1 = nn.Sequential(
            nn.ConvTranspose2d(40, 40, 4, 2, 1),  # 3 -> 6
            nn.SiLU(),
            nn.ConvTranspose2d(40, 40, 2, 1)  # 6 -> 7
        )

        self.te4 = self._make_te(time_emb_dim, 80)
        self.b4 = nn.Sequential(
            MyBlock((80, 7, 7), 80, 40),
            MyBlock((40, 7, 7), 40, 20),
            MyBlock((20, 7, 7), 20, 20)
        )

        self.up2 = nn.ConvTranspose2d(20, 20, 4, 2, 1)  # 7 -> 14
        self.te5 = self._make_te(time_emb_dim, 40)
        self.b5 = nn.Sequential(
            MyBlock((40, 14, 14), 40, 20),
            MyBlock((20, 14, 14), 20, 10),
            MyBlock((10, 14, 14), 10, 10)
        )

        self.up3 = nn.ConvTranspose2d(10, 10, 4, 2, 1)  # 14 -> 28
        self.te_out = self._make_te(time_emb_dim, 20)
        self.b_out = nn.Sequential(
            MyBlock((20, 28, 28), 20, 10),
            MyBlock((10, 28, 28), 10, 10),
            MyBlock((10, 28, 28), 10, 10, normalize=False)
        )

        self.conv_out = nn.Conv2d(10, 1, 3, 1, 1)

    def forward(self, x, t):
        """Run the U-Net on images ``x`` at timesteps ``t``.

        Args:
            x: ``(N, 1, 28, 28)`` images. One channel: ``te1`` outputs 1 channel and
                ``b1``'s LayerNorm expects ``(1, 28, 28)``.
            t: integer timestep indices with ``N`` elements, e.g. ``(N,)`` or ``(N, 1)``.

        Returns:
            ``(N, 1, 28, 28)`` output of ``conv_out`` (no final nonlinearity).
        """
        t = self.time_embed(t)
        n = len(x)

        def _time(te):
            # Per-stage time vector broadcast over the spatial dims.
            return te(t).reshape(n, -1, 1, 1)

        out1 = self.b1(x + _time(self.te1))  # (N, 10, 28, 28)
        out2 = self.b2(self.down1(out1) + _time(self.te2))  # (N, 20, 14, 14)
        out3 = self.b3(self.down2(out2) + _time(self.te3))  # (N, 40, 7, 7)

        out_mid = self.b_mid(self.down3(out3) + _time(self.te_mid))  # (N, 40, 3, 3)

        out4 = torch.cat((out3, self.up1(out_mid)), dim=1)  # (N, 80, 7, 7)
        out4 = self.b4(out4 + _time(self.te4))  # (N, 20, 7, 7)

        out5 = torch.cat((out2, self.up2(out4)), dim=1)  # (N, 40, 14, 14)
        out5 = self.b5(out5 + _time(self.te5))  # (N, 10, 14, 14)

        out = torch.cat((out1, self.up3(out5)), dim=1)  # (N, 20, 28, 28)
        out = self.b_out(out + _time(self.te_out))  # (N, 10, 28, 28)

        return self.conv_out(out)  # (N, 1, 28, 28)

    def _make_te(self, dim_in, dim_out):
        """Return the time-embedding MLP ``Linear -> SiLU -> Linear`` (dim_in -> dim_out)."""
        return nn.Sequential(
            nn.Linear(dim_in, dim_out),
            nn.SiLU(),
            nn.Linear(dim_out, dim_out)
        )
