import torch
import torch.nn as nn
import numpy as np
import attn_modules
from torch.nn import functional as F

from einops import rearrange, repeat
import conv_modules

# Channel-layout helpers
def ch_sec(x):
    """Flatten the two trailing axes into one and move channels last.

    Applies the einops pattern ``... c x y -> ... (x y) c``.
    """
    return rearrange(x,"... c x y -> ... (x y) c")


def ch_fst(src, x=None):
    """Undo ``ch_sec``: unflatten the token axis and move channels first.

    Applies the einops pattern ``... (x y) c -> ... c x y``. When ``x`` (the
    size of the first unflattened axis) is not given it is taken as
    ``int(sqrt(number_of_tokens))``, i.e. a square grid is assumed.
    """
    height = int(src.size(-2)**(.5)) if x is None else x
    return rearrange(src,"... (x y) c -> ... c x y",x=height)


def make_net(dims):
    """Build a ReLU multilayer perceptron from a list of layer widths.

    One ``nn.Linear`` is created for every consecutive pair of widths, with an
    ``nn.ReLU`` between successive linear layers. No activation follows the
    last linear layer. Linear weights are re-initialised with Kaiming-normal
    (fan-in, ReLU gain); biases keep the PyTorch default initialisation.

    Args:
        dims: Iterable of layer widths ``[in, hidden..., out]``. With fewer
            than two entries an empty ``nn.Sequential`` is returned.

    Returns:
        The assembled ``nn.Sequential``.
    """
    def init_weights_normal(m):
        # Every nn.Linear owns a weight, so no hasattr guard is needed.
        if isinstance(m, nn.Linear):
            nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_in')

    widths = list(dims)
    layers = []
    for fan_in, fan_out in zip(widths[:-1], widths[1:]):
        layers.append(nn.Linear(fan_in, fan_out))
        layers.append(nn.ReLU())
    # Drop the trailing ReLU so the output layer stays linear.
    net = nn.Sequential(*layers[:-1])
    net.apply(init_weights_normal)
    return net

# DDPM class
class MyDDPM(nn.Module):
    """Diffusion wrapper around a noise-prediction network.

    Holds a linear beta schedule and the derived ``alphas`` / ``alpha_bars``
    and offers the closed-form forward (noising) step and the iterative
    reverse (sampling) loop.

    Args:
        network: Callable invoked as ``network(x, t)`` with ``x`` of shape
            ``(n, c, h, w)`` and ``t`` an integer tensor of shape ``(n, 1)``;
            it must return a tensor shaped like ``x``.
        n_steps: Number of diffusion steps.
        min_beta: First value of the linear beta schedule.
        max_beta: Last value of the linear beta schedule.
        device: Device for the schedule tensors. When ``None``, CUDA is used
            if available (the previous hard-coded behaviour) and CPU otherwise.
        image_chw: Stored on the instance; not read by this class.
    """

    def __init__(self, network, n_steps=200, min_beta=10 ** -4, max_beta=0.02, device=None, image_chw=(1, 28, 28)):
        super().__init__()
        self.n_steps = n_steps
        if device is None:
            schedule_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            self.device = schedule_device
        else:
            schedule_device = torch.device(device)
            self.device = device  # keep the caller's object as given
        self.image_chw = image_chw
        self.network = network
        # Number of steps is typically in the order of thousands
        self.betas = torch.linspace(min_beta, max_beta, n_steps).to(schedule_device)
        self.alphas = 1 - self.betas
        # Running product of alphas in one pass instead of one torch.prod per prefix.
        self.alpha_bars = torch.cumprod(self.alphas, dim=0)

    def forward_noise(self, x0, t, eta=None):
        """Jump directly to diffusion step ``t`` for every image in ``x0``.

        Args:
            x0: Clean images, shape ``(n, c, h, w)``.
            t: Integer tensor with one step index per image (``n`` elements).
            eta: Optional noise shaped like ``x0``; standard normal noise is
                drawn when omitted.

        Returns:
            ``sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eta``.
        """
        n, c, h, w = x0.shape
        a_bar = self.alpha_bars[t].to(x0.device)

        if eta is None:
            # Follow x0's device so a model built with device=None still works.
            eta = torch.randn(n, c, h, w).to(x0.device)

        noisy = a_bar.sqrt().reshape(n, 1, 1, 1) * x0 + (1 - a_bar).sqrt().reshape(n, 1, 1, 1) * eta
        return noisy

    def generate_new_images(self, n_samples=16, device=None, frames_per_gif=100, gif_name="sampling.gif", c=1, h=28, w=28):
        """Given a DDPM model, a number of samples to be generated and a device, returns some newly generated samples.

        Args:
            n_samples: Number of images to sample.
            device: Device to sample on; defaults to the device holding the
                beta schedule. The network must live on the same device.
            frames_per_gif: Number of evenly spaced steps at which the
                intermediate batch is recorded.
            gif_name: Accepted for interface compatibility; unused here.
            c, h, w: Channel count and spatial size of each sample.

        Returns:
            ``{"sampled_gen": frames}`` where ``frames`` stacks the recorded
            batches along a new leading axis; the last entry is the final sample.
        """
        device = self.betas.device if device is None else torch.device(device)
        alphas = self.alphas.to(device)
        alpha_bars = self.alpha_bars.to(device)
        sigmas = self.betas.to(device).sqrt()
        # A set gives O(1) membership instead of scanning the array each step.
        frame_idxs = set(np.linspace(0, self.n_steps, frames_per_gif).astype(np.uint).tolist())
        xs = []

        with torch.no_grad():
            # Starting from random noise
            x = torch.randn(n_samples, c, h, w).to(device)

            for idx, t in enumerate(reversed(range(self.n_steps))):
                # Estimating noise to be removed
                time_tensor = torch.full((n_samples, 1), t, dtype=torch.long, device=device)
                eta_theta = self.network(x, time_tensor)

                alpha_t = alphas[t]
                alpha_t_bar = alpha_bars[t]

                # Partially denoising the image
                x = (1 / alpha_t.sqrt()) * (x - (1 - alpha_t) / (1 - alpha_t_bar).sqrt() * eta_theta)

                if t > 0:
                    # Adding some more noise like in Langevin Dynamics fashion
                    z = torch.randn(n_samples, c, h, w).to(device)
                    x = x + sigmas[t] * z

                # Record the selected intermediate steps and always the final one
                if idx in frame_idxs or t == 0:
                    xs.append(x)

            vis = {"sampled_gen": torch.stack(xs)}
            return vis

def sinusoidal_embedding(n, d):
    """Return an ``(n, d)`` table of sinusoidal position embeddings.

    Row ``t`` holds ``sin(t * w_k)`` in the even columns and ``cos(t * w_k)``
    in the odd columns, where ``w_k = 1 / 10000 ** (2 * j / d)`` is evaluated
    at the even column indices ``j = 0, 2, 4, ...``; each sine/cosine pair
    shares one frequency.

    Args:
        n: Number of positions (rows).
        d: Embedding width (columns). Odd widths are supported; the last
            column then holds a sine with no cosine partner.

    Returns:
        A float tensor of shape ``(n, d)``.
    """
    embedding = torch.zeros(n, d)
    wk = torch.tensor([1 / 10_000 ** (2 * j / d) for j in range(d)])
    wk = wk.reshape((1, d))
    t = torch.arange(n).reshape((n, 1))
    pair_freqs = wk[:, ::2]
    embedding[:, ::2] = torch.sin(t * pair_freqs)
    # There are only d // 2 cosine columns, one fewer than the sine columns
    # when d is odd, so trim the frequencies to match.
    embedding[:, 1::2] = torch.cos(t * pair_freqs[:, :d // 2])

    return embedding

class MySceneRep(nn.Module):
    """U-Net style denoiser with skip connections and per-stage time conditioning.

    The encoder has three stages of three ``MyBlock``s each (channels
    1 -> 20 -> 40 -> 80), with a stride-2 convolution before stages two and
    three. A bottleneck of three 80-channel blocks follows a further stride-2
    convolution. The decoder concatenates each encoder output with the
    bilinearly resized running feature map and reduces it back to one channel.

    The LayerNorm shapes are derived from the hard-coded ``res = (28, 28)`` and
    ``dims = [1, 20, 40, 80]``, so as written the network expects
    single-channel 28x28 inputs.
    """

    def __init__(self, n_steps=1000, time_emb_dim=100):
        """Build the embedding table, encoder, bottleneck and decoder.

        Args:
            n_steps: Number of rows in the timestep embedding table.
            time_emb_dim: Width of the sinusoidal timestep embedding.
        """
        super(MySceneRep, self).__init__()
        # Sinusoidal embedding: weights are overwritten with the fixed table and frozen
        self.time_embed = nn.Embedding(n_steps, time_emb_dim)
        self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
        self.time_embed.requires_grad_(False)

        res=(28,28)

        # Downconvs: stage i runs at resolution res // 2**i; its first block (j == 0)
        # maps dims[i] -> dims[i+1], the other two keep dims[i+1] channels
        dims = [1,20,40,80]
        self.conv_blocks_down = nn.ModuleList([
            nn.Sequential(*[
                    MyBlock((dims[i+int(j>0)],  res[0]//2**i, res[1]//2**i), dims[i+int(j>0)], dims[i+1])
                     for j in range(3)]
                )
            for i in range(3)])
        # Entry 0 is a no-op (the first stage runs at full resolution); entries 1..3 halve
        # the resolution at 20, 40 and 80 channels. Entry 3 feeds the bottleneck.
        self.down_convs = nn.ModuleList([nn.Identity()]+[nn.Conv2d(dim, dim, 4, 2, 1) for dim in dims[1:4]])
    
        # Layer/Dimension-specific time projections, output widths
        # [1, 20, 40, 80, 160, 80, 40, 2]: indices 0..2 serve the encoder stages,
        # 3 the bottleneck, 4..6 the decoder stages
        self.time_embeddings = nn.ModuleList([ nn.Sequential(nn.Linear(time_emb_dim, d), nn.SiLU(), nn.Linear(d, d))
                                                                            for d in dims+[d*2 for d in dims[::-1]]])
        # The last entry is replaced by a no-op; forward() never reaches it because the
        # decoder zip stops after three stages
        self.time_embeddings[-1]=nn.Identity()

        # Bottleneck (replace with transformer)
        # NOTE(review): te_mid is created here but forward() conditions the bottleneck with
        # time_embeddings[3] instead, so te_mid's parameters are unused in this class.
        self.te_mid = nn.Sequential(nn.Linear(time_emb_dim, 80), nn.SiLU(), nn.Linear(80, 80))#self._make_te(time_emb_dim, 80)
        self.b_mid = nn.Sequential(
            MyBlock((80, 3, 3), 80, 80),
            MyBlock((80, 3, 3), 80, 80),
            MyBlock((80, 3, 3), 80, 80)
        )

        #self.self_attns = nn.ModuleList([attn_modules.CrossAttn_(fdim,4,fdim//2) for _ in range(4)]).cuda()
        #self.scene_tokens = nn.Embedding(8, fdim) #learnable register embeddings
        #self.spatial_emb = nn.Embedding(256*256, fdim)#learnable spatial embedding

        # Upconvs for skip-connected decoding: stage i takes the 2*dims[i+1]-channel
        # concatenation (skip + upsampled features) and ends with dims[i] channels,
        # giving 160 -> 40, 80 -> 20 and 40 -> 1 for i = 2, 1, 0
        self.conv_blocks_up = nn.ModuleList([
            nn.Sequential(*[
                    MyBlock((2*dims[i+int(j==0)],  res[0]//2**i, res[1]//2**i), dims[i+int(j==0)]*2, dims[i]*(2 if j<2 else 1))
                     for j in range(3)]
                )
            for i in [2,1,0]])

    def forward(self, x, t):
        """Run the network on a batch of images at the given timesteps.

        Args:
            x: Input batch; the first block normalises over shape (1, 28, 28).
            t: Integer timestep indices looked up in the embedding table; the
                embedding is reshaped to ``(len(x), -1, 1, 1)`` per stage, so
                there must be one index per batch element.

        Returns:
            The output of the last decoder stage, which has one channel and
            the spatial size of the first encoder stage.
        """
        t = self.time_embed(t) # sin embedding of time
        n = len(x)

        # Convs down - encoder; zip stops at the three conv stages, so only
        # down_convs[0:3] and time_embeddings[0:3] are consumed here
        xs_down=[x]
        for conv_block, down_conv, time_emb in zip(self.conv_blocks_down,self.down_convs,self.time_embeddings):
            xs_down.append( conv_block(down_conv(xs_down[-1]) + time_emb(t).reshape(n,-1,1,1)) )

        # low latent space thinking - replace with transformer
        out_mid = self.b_mid(self.down_convs[3](xs_down[-1]) + self.time_embeddings[3](t).reshape(n, -1, 1, 1))  # (N, 80, 3, 3)

        # Convs up - decoder; skips are visited deepest-first, and the raw input
        # x (last element of xs_down[::-1]) is never used because zip stops after three stages
        curr_feat = out_mid
        for conv_block,x_prev,time_emb in zip(self.conv_blocks_up,xs_down[::-1],self.time_embeddings[4:]):
            # Resize the running features to the skip's resolution, then stack skip first
            cat_feat = torch.cat((x_prev, F.interpolate(curr_feat,x_prev.shape[-2:],mode="bilinear")), dim=1)
            curr_feat = conv_block(cat_feat+time_emb(t).reshape(n,-1,1,1))

        return curr_feat

class MySceneRep_(nn.Module):
    """Experimental denoiser: conv encoder, token attention, MLP decoder.

    The image, concatenated with an 8-channel time map, goes through
    ``conv_modules.PixelNeRFEncoder``. The half-resolution features plus eight
    learned register tokens then pass through four
    ``attn_modules.CrossAttn_`` layers used as self-attention. The attended
    features are resized, concatenated with the encoder features and decoded
    per pixel by two MLPs.

    NOTE(review): the encoder and attention modules are defined in other
    files; their exact output shapes are assumed from how they are consumed
    here (a 4-D ``fdim``-channel feature map and a same-shape token sequence).
    """

    def __init__(self, n_steps=1000, time_emb_dim=100):
        """Build the embedding table and all sub-networks.

        Args:
            n_steps: Number of rows in the timestep embedding table.
            time_emb_dim: Width of the sinusoidal timestep embedding.
        """
        # The previous super(MySceneRep, self) named a different class and
        # raised TypeError on every construction.
        super().__init__()

        fdim = 128

        # Sinusoidal embedding: weights are overwritten with the fixed table and frozen
        self.time_embed = nn.Embedding(n_steps, time_emb_dim)
        self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
        self.time_embed.requires_grad_(False)

        # First half
        # time_lins[0] yields the 8-channel image-resolution time map; the rest yield
        # fdim-wide vectors. Input width follows time_emb_dim rather than a literal 100.
        self.time_lins = nn.ModuleList([nn.Linear(time_emb_dim,fdim if i else 8) for i in range(6)])
        self.conv_enc = conv_modules.PixelNeRFEncoder(in_ch=1+8,fdim_out=fdim)
        # No forced .cuda(): the attention layers move with the rest of the model.
        self.self_attns = nn.ModuleList([attn_modules.CrossAttn_(fdim,4,fdim//2) for _ in range(4)])
        self.scene_tokens = nn.Embedding(8, fdim)#learnable register embeddings
        # NOTE(review): down_convs is not used by forward().
        self.down_convs = nn.ModuleList([nn.Conv2d(1 if i==0 else fdim,fdim,3,2,padding=1) for i in range(2)])
        self.spatial_emb = nn.Embedding(256*256, fdim)#learnable spatial embedding
        self.comb_feats = make_net([256,128,64,32])
        self.img_dec = make_net([32,32,32,1])

    def forward(self, x, t):
        """Run the network on a batch of images at the given timesteps.

        Args:
            x: Image batch ``(N, C, H, W)``; the encoder is built with
                ``in_ch=1+8``, so presumably ``C == 1`` -- TODO confirm.
            t: Integer timestep indices, shape ``(N, 1)`` (axis 1 is squeezed).

        Returns:
            A one-channel map with the spatial size of ``x``.
        """
        t_emb = self.time_embed(t).squeeze(1)

        # Broadcast the 8-d time code over every pixel and append it as extra channels
        time_emb_hires=self.time_lins[0](t_emb)[...,None,None].expand(-1,-1,x.size(-2),x.size(-1))
        conv_feats = self.conv_enc(torch.cat((x,time_emb_hires),1))

        # downconvs - todo use pretrained resnet here like in pixelnerf features instead of just raw rgb

        # transformer self attns
        feats2 = F.interpolate(conv_feats,scale_factor=.5)
        # Resample the learned 256x256 positional grid to the token resolution
        pos_emb_low = F.interpolate(ch_fst(self.spatial_emb.weight)[None],feats2.shape[-2:])
        # Pixel tokens first, the eight register tokens appended last
        feats3= torch.cat((ch_sec(feats2+pos_emb_low+self.time_lins[1](t_emb)[...,None,None]),
                           self.scene_tokens.weight[None].expand(len(x),-1,-1)),1)
        for i,attn in enumerate(self.self_attns): feats3 = attn(feats3,feats3)+self.time_lins[i+1](t_emb)[:,None]
        # Drop the register tokens and restore the spatial layout
        feats3= ch_fst(feats3[:,:-self.scene_tokens.weight.size(0)],feats2.size(-2))

        # decoder upconvs
        low_and_hi = ch_sec(torch.cat([F.interpolate(feats3,conv_feats.shape[-2:]),conv_feats],1))
        # Unflatten with the encoder map's height. low_and_hi.size(-2) is the
        # token count H*W, which would produce an (H*W) x 1 grid.
        feats4 = ch_fst(self.comb_feats(low_and_hi),conv_feats.size(-2))
        dec = ch_fst(self.img_dec(ch_sec(F.interpolate(feats4,x.shape[-2:]))),x.size(-2))

        return dec

class MyBlock(nn.Module):
    """Optional LayerNorm followed by two convolution + activation stages.

    Args:
        shape: Normalised shape handed to ``nn.LayerNorm``.
        in_c: Input channels of the first convolution.
        out_c: Output channels of both convolutions.
        kernel_size, stride, padding: Shared by both convolutions.
        activation: Module applied after each convolution; ``nn.SiLU()`` when
            ``None``.
        normalize: When false, the LayerNorm is skipped in ``forward``.
    """

    def __init__(self, shape, in_c, out_c, kernel_size=3, stride=1, padding=1, activation=None, normalize=True):
        super().__init__()
        conv_args = (kernel_size, stride, padding)
        self.ln = nn.LayerNorm(shape)
        self.conv1 = nn.Conv2d(in_c, out_c, *conv_args)
        self.conv2 = nn.Conv2d(out_c, out_c, *conv_args)
        if activation is None:
            activation = nn.SiLU()
        self.activation = activation
        self.normalize = normalize

    def forward(self, x):
        """Apply (norm ->) conv1 -> act -> conv2 -> act and return the result."""
        if self.normalize:
            x = self.ln(x)
        for conv in (self.conv1, self.conv2):
            x = self.activation(conv(x))
        return x

class MyUNet(nn.Module):
    """Three-level U-Net noise predictor with a time projection per stage.

    Every stage is three ``MyBlock``s. The encoder works at 28, 14 and 7
    pixels, the bottleneck at 3, and the decoder mirrors the encoder while
    concatenating the matching encoder output. The LayerNorm shapes are fixed,
    so the network expects ``(N, 1, 28, 28)`` inputs.
    """

    def __init__(self, n_steps=1000, time_emb_dim=100):
        super().__init__()

        # Sinusoidal embedding: weights are overwritten with the fixed table and frozen
        self.time_embed = nn.Embedding(n_steps, time_emb_dim)
        self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
        self.time_embed.requires_grad_(False)

        # Encoder
        self.te1 = self._make_te(time_emb_dim, 1)
        self.b1 = self._make_stage(28, (1, 10, 10, 10))
        self.down1 = nn.Conv2d(10, 10, 4, 2, 1)

        self.te2 = self._make_te(time_emb_dim, 10)
        self.b2 = self._make_stage(14, (10, 20, 20, 20))
        self.down2 = nn.Conv2d(20, 20, 4, 2, 1)

        self.te3 = self._make_te(time_emb_dim, 20)
        self.b3 = self._make_stage(7, (20, 40, 40, 40))
        # 7 -> 6 -> 3 pixels
        self.down3 = nn.Sequential(
            nn.Conv2d(40, 40, 2, 1),
            nn.SiLU(),
            nn.Conv2d(40, 40, 4, 2, 1),
        )

        # Bottleneck
        self.te_mid = self._make_te(time_emb_dim, 40)
        self.b_mid = self._make_stage(3, (40, 20, 20, 40))

        # Decoder
        # 3 -> 6 -> 7 pixels
        self.up1 = nn.Sequential(
            nn.ConvTranspose2d(40, 40, 4, 2, 1),
            nn.SiLU(),
            nn.ConvTranspose2d(40, 40, 2, 1),
        )

        self.te4 = self._make_te(time_emb_dim, 80)
        self.b4 = self._make_stage(7, (80, 40, 20, 20))

        self.up2 = nn.ConvTranspose2d(20, 20, 4, 2, 1)
        self.te5 = self._make_te(time_emb_dim, 40)
        self.b5 = self._make_stage(14, (40, 20, 10, 10))

        self.up3 = nn.ConvTranspose2d(10, 10, 4, 2, 1)
        self.te_out = self._make_te(time_emb_dim, 20)
        # The final block of the output stage skips its LayerNorm
        self.b_out = self._make_stage(28, (20, 10, 10, 10), normalize_last=False)

        self.conv_out = nn.Conv2d(10, 1, 3, 1, 1)

    def forward(self, x, t):
        """Predict a one-channel map for images ``x`` at timesteps ``t``.

        Args:
            x: Image batch of shape (N, 1, 28, 28).
            t: Integer timestep indices, one per batch element.

        Returns:
            Tensor of shape (N, 1, 28, 28).
        """
        t = self.time_embed(t)
        n = len(x)

        def conditioned(feats, te):
            # Add the stage's time projection as a per-channel bias
            return feats + te(t).reshape(n, -1, 1, 1)

        out1 = self.b1(conditioned(x, self.te1))  # (N, 10, 28, 28)
        out2 = self.b2(conditioned(self.down1(out1), self.te2))  # (N, 20, 14, 14)
        out3 = self.b3(conditioned(self.down2(out2), self.te3))  # (N, 40, 7, 7)

        out_mid = self.b_mid(conditioned(self.down3(out3), self.te_mid))  # (N, 40, 3, 3)

        skip4 = torch.cat((out3, self.up1(out_mid)), dim=1)  # (N, 80, 7, 7)
        out4 = self.b4(conditioned(skip4, self.te4))  # (N, 20, 7, 7)

        skip5 = torch.cat((out2, self.up2(out4)), dim=1)  # (N, 40, 14, 14)
        out5 = self.b5(conditioned(skip5, self.te5))  # (N, 10, 14, 14)

        skip_out = torch.cat((out1, self.up3(out5)), dim=1)  # (N, 20, 28, 28)
        out = self.b_out(conditioned(skip_out, self.te_out))  # (N, 10, 28, 28)

        return self.conv_out(out)  # (N, 1, 28, 28)

    @staticmethod
    def _make_stage(side, channels, normalize_last=True):
        """Chain MyBlocks over consecutive channel pairs at side x side resolution."""
        last = len(channels) - 2
        blocks = []
        for k, (c_in, c_out) in enumerate(zip(channels, channels[1:])):
            normalize = normalize_last or k != last
            blocks.append(MyBlock((c_in, side, side), c_in, c_out, normalize=normalize))
        return nn.Sequential(*blocks)

    def _make_te(self, dim_in, dim_out):
        """Two-layer MLP projecting the time embedding to ``dim_out`` channels."""
        layers = [nn.Linear(dim_in, dim_out), nn.SiLU(), nn.Linear(dim_out, dim_out)]
        return nn.Sequential(*layers)
