"""Code for pixelnerf and alternatives."""
import torch, torchvision
from torch import nn
from einops import rearrange, repeat
from torch.nn import functional as F
import numpy as np
import sys,random,time,os
from copy import deepcopy 
import matplotlib.pyplot as plt;
from matplotlib import cm
import wandb
from tqdm import tqdm
from typing import Callable, List, Optional, Tuple, Generator, Dict
import conv_modules
from collections import defaultdict
from attn_modules import CrossAttn_,PositionalEncodingNoFreqFactor

# Layout helpers: convert between channel-first grids and channel-last token sequences.
def ch_sec(x):
    """Flatten the two trailing spatial dims and move channels last: ``(..., c, x, y) -> (..., x*y, c)``."""
    return rearrange(x, "... c x y -> ... (x y) c")


def ch_fst(src, x=None):
    """Inverse of ``ch_sec``: ``(..., x*y, c) -> (..., c, x, y)``.

    ``x`` is the size of the first spatial dim of the output. When omitted, the token
    count is assumed to be a perfect square and ``x = int(sqrt(n_tokens))``; einops
    raises if ``x`` does not divide the token count.
    """
    if x is None:
        x = int(src.size(-2) ** (.5))
    return rearrange(src, "... (x y) c -> ... c x y", x=x)

def sinusoidal_embedding(n, d):
    """Return an ``(n, d)`` sinusoidal position/time-step embedding table.

    Row ``t`` holds ``sin(t * w_p)`` in column ``2p`` and ``cos(t * w_p)`` in column
    ``2p + 1``, with ``w_p = 1 / 10000 ** (4p / d)``.

    NOTE: the textbook Transformer table uses ``10000 ** (2p / d)``. This table's
    frequencies fall off twice as fast, which is kept as-is to preserve existing behavior.

    Odd ``d`` is supported: the last column is then a sine with no cosine partner.
    """
    embedding = torch.zeros(n, d)
    # Same Python-float computation as taking every other entry of the full-width table.
    freqs = torch.tensor([1 / 10_000 ** (2 * j / d) for j in range(0, d, 2)])[None, :]
    t = torch.arange(n).reshape((n, 1))
    embedding[:, ::2] = torch.sin(t * freqs)
    # There are only floor(d/2) cosine columns, one fewer than sine columns when d is odd.
    embedding[:, 1::2] = torch.cos(t * freqs[:, : d // 2])
    return embedding
class MyBlock(nn.Module):
    """Optional LayerNorm followed by two ``Conv2d -> activation`` stages.

    Args:
        shape: ``normalized_shape`` for the LayerNorm (in this file, the (C, H, W)
            shape of the block's input).
        in_c: channels entering the first convolution.
        out_c: channels produced by both convolutions.
        kernel_size, stride, padding: forwarded unchanged to both ``nn.Conv2d`` layers.
        activation: module applied after each convolution; ``nn.SiLU()`` when None.
        normalize: when False, the LayerNorm is built but skipped in ``forward``.
    """

    def __init__(self, shape, in_c, out_c, kernel_size=3, stride=1, padding=1, activation=None, normalize=True):
        super().__init__()
        conv_args = (kernel_size, stride, padding)
        self.ln = nn.LayerNorm(shape)
        self.conv1 = nn.Conv2d(in_c, out_c, *conv_args)
        self.conv2 = nn.Conv2d(out_c, out_c, *conv_args)
        if activation is None:
            activation = nn.SiLU()
        self.activation = activation
        self.normalize = normalize

    def forward(self, x):
        h = x
        if self.normalize:
            h = self.ln(h)
        # The same activation instance follows each convolution.
        for conv in (self.conv1, self.conv2):
            h = self.activation(conv(h))
        return h

def make_net(dims):
    """Build an MLP ``Linear -> ReLU -> ... -> Linear`` through the widths in ``dims``.

    ``dims[i] -> dims[i + 1]`` is one Linear layer, and a ReLU sits between consecutive
    Linears but not after the last one. Linear weights get Kaiming-normal init
    (fan_in, ReLU gain); biases keep PyTorch's default init. Fewer than two entries
    in ``dims`` yields an empty ``nn.Sequential``.

    The Sequential indices (``0``, ``2``, ...) are stable, so state_dict keys are
    ``"0.weight"``, ``"2.weight"``, and so on.
    """
    def init_weights_normal(m):
        if isinstance(m, nn.Linear):
            nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_in')

    layers = []
    for d_in, d_out in zip(dims[:-1], dims[1:]):
        layers += [nn.Linear(d_in, d_out), nn.ReLU()]
    # Drop the trailing ReLU so the network's output is unconstrained.
    net = nn.Sequential(*layers[:-1])
    net.apply(init_weights_normal)
    return net

class DDPM(nn.Module):
    """DDPM wrapper around ``DenoisingSceneLearnerUnet`` with a linear beta schedule.

    ``betas = linspace(min_beta, max_beta, n_steps)`` and ``alphas = 1 - betas``.
    ``alpha_bars[t]`` is the product of ``alphas[0..t]`` (inclusive). These schedule
    tensors are plain attributes created on the CPU; they are moved to the data's
    device where needed.

    Args:
        args: experiment config, stored as ``self.args`` (not read by this class).
        n_steps: number of diffusion steps; also sizes the U-Net's time-embedding table.
        min_beta, max_beta: endpoints of the linear beta schedule.
    """

    def __init__(self, args, n_steps=1000, min_beta=10 ** -4, max_beta=0.02):
        super().__init__()
        self.args = args
        self.n_steps = n_steps
        self.betas = torch.linspace(min_beta, max_beta, n_steps)
        self.alphas = 1 - self.betas
        # Running product in one pass (previously n_steps separate prods, O(n_steps^2)).
        self.alpha_bars = torch.cumprod(self.alphas, dim=0)

        # The time-embedding table must have a row for every step index we can sample.
        self.model = DenoisingSceneLearnerUnet(n_steps=n_steps)

    def forward(self, model_input, eta=None):
        """Noise a batch to a random step per sample and predict the added noise.

        Args:
            model_input: dict whose ``"rgb"`` entry is the clean batch (batch first;
                the schedule is broadcast over 3 trailing dims, i.e. 4-D input).
            eta: optional noise tensor shaped like ``model_input["rgb"]``. When None,
                it is sampled from a standard normal.

        Returns:
            dict with ``"noised_img"`` (the noised batch), ``"eta"`` (the noise used)
            and ``"eta_est"`` (the model's noise prediction).
        """
        x0 = model_input["rgb"]
        b = x0.size(0)
        device = x0.device

        # Jump straight to step t using the closed form:
        # x_t = sqrt(a_bar_t) * x0 + sqrt(1 - a_bar_t) * eta.
        # Steps are drawn on the CPU generator (as before) and then moved.
        t = torch.randint(0, self.n_steps, (b,)).to(device)
        a_bar = self.alpha_bars.to(device)[t][:, None, None, None]

        if eta is None:
            eta = torch.randn_like(x0)

        noisy = a_bar.sqrt() * x0 + (1 - a_bar).sqrt() * eta
        eta_theta = self.model(noisy, t)

        return {"noised_img": noisy, "eta": eta, "eta_est": eta_theta}

    def generate_new_images(self, n_samples=4, c=1, h=28, w=28):
        """Run the reverse chain from pure noise and return intermediate samples.

        Sampling happens on the device of ``self.model``'s parameters, without gradients.

        Returns:
            Tensor ``(n_frames, n_samples, c, h, w)``. It holds up to 10 snapshots at
            evenly spaced reverse iterations; the first is taken after the first
            denoising step and the last is the final sample.
        """
        device = next(self.model.parameters()).device
        # Spacing over [0, n_steps - 1] (the loop's index range) guarantees that the
        # final sample is captured; spacing to n_steps would never be reached.
        frame_idxs = set(np.linspace(0, self.n_steps - 1, 10).astype(int).tolist())
        intermed_gens = []

        with torch.no_grad():
            # Start from pure noise (drawn on the CPU generator, then moved).
            x = torch.randn(n_samples, c, h, w).to(device)

            for idx, t in enumerate(reversed(range(self.n_steps))):
                time_tensor = torch.full((n_samples, 1), t, dtype=torch.long, device=device)
                eta_theta = self.model(x, time_tensor)

                alpha_t = self.alphas[t]
                alpha_t_bar = self.alpha_bars[t]

                # Remove the predicted noise component.
                x = (1 / alpha_t.sqrt()) * (x - (1 - alpha_t) / (1 - alpha_t_bar).sqrt() * eta_theta)

                if t > 0:
                    # Re-inject noise with sigma_t = sqrt(beta_t); skipped at the last step.
                    z = torch.randn(n_samples, c, h, w).to(device)
                    x = x + self.betas[t].sqrt() * z

                if idx in frame_idxs:
                    intermed_gens.append(x)
        return torch.stack(intermed_gens)

class DenoisingSceneLearnerUnet(nn.Module):
    """Small time-conditioned U-Net that predicts noise for single-channel 28x28 images.

    The LayerNorm shapes inside each ``MyBlock`` hard-code the spatial sizes
    (28 -> 14 -> 7 -> 3 and back). At every level, a learned projection of the
    frozen sinusoidal time embedding is added as a per-channel bias before the
    conv stack.

    Args:
        n_steps: number of rows in the time-embedding table (valid steps are
            ``0 .. n_steps - 1``).
        time_emb_dim: width of the sinusoidal time embedding.
    """

    def __init__(self, n_steps=1000, time_emb_dim=100):
        super().__init__()

        # Sinusoidal embedding: a fixed lookup table, excluded from training.
        self.time_embed = nn.Embedding(n_steps, time_emb_dim)
        self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
        self.time_embed.requires_grad_(False)

        # First half
        self.te1 = self._make_te(time_emb_dim, 1)
        self.b1 = nn.Sequential(
            MyBlock((1, 28, 28), 1, 10),
            MyBlock((10, 28, 28), 10, 10),
            MyBlock((10, 28, 28), 10, 10)
        )
        self.down1 = nn.Conv2d(10, 10, 4, 2, 1)

        self.te2 = self._make_te(time_emb_dim, 10)
        self.b2 = nn.Sequential(
            MyBlock((10, 14, 14), 10, 20),
            MyBlock((20, 14, 14), 20, 20),
            MyBlock((20, 14, 14), 20, 20)
        )
        self.down2 = nn.Conv2d(20, 20, 4, 2, 1)

        self.te3 = self._make_te(time_emb_dim, 20)
        self.b3 = nn.Sequential(
            MyBlock((20, 7, 7), 20, 40),
            MyBlock((40, 7, 7), 40, 40),
            MyBlock((40, 7, 7), 40, 40)
        )
        # 7x7 -> 6x6 (kernel 2, no padding) -> 3x3 (stride 2)
        self.down3 = nn.Sequential(
            nn.Conv2d(40, 40, 2, 1),
            nn.SiLU(),
            nn.Conv2d(40, 40, 4, 2, 1)
        )

        # Bottleneck
        self.te_mid = self._make_te(time_emb_dim, 40)
        self.b_mid = nn.Sequential(
            MyBlock((40, 3, 3), 40, 20),
            MyBlock((20, 3, 3), 20, 20),
            MyBlock((20, 3, 3), 20, 40)
        )

        # Second half
        # 3x3 -> 6x6 -> 7x7, mirroring down3
        self.up1 = nn.Sequential(
            nn.ConvTranspose2d(40, 40, 4, 2, 1),
            nn.SiLU(),
            nn.ConvTranspose2d(40, 40, 2, 1)
        )

        self.te4 = self._make_te(time_emb_dim, 80)
        self.b4 = nn.Sequential(
            MyBlock((80, 7, 7), 80, 40),
            MyBlock((40, 7, 7), 40, 20),
            MyBlock((20, 7, 7), 20, 20)
        )

        self.up2 = nn.ConvTranspose2d(20, 20, 4, 2, 1)
        self.te5 = self._make_te(time_emb_dim, 40)
        self.b5 = nn.Sequential(
            MyBlock((40, 14, 14), 40, 20),
            MyBlock((20, 14, 14), 20, 10),
            MyBlock((10, 14, 14), 10, 10)
        )

        self.up3 = nn.ConvTranspose2d(10, 10, 4, 2, 1)
        self.te_out = self._make_te(time_emb_dim, 20)
        self.b_out = nn.Sequential(
            MyBlock((20, 28, 28), 20, 10),
            MyBlock((10, 28, 28), 10, 10),
            MyBlock((10, 28, 28), 10, 10, normalize=False)
        )

        self.conv_out = nn.Conv2d(10, 1, 3, 1, 1)

    def forward(self, x, t):
        """Predict the noise in ``x`` at diffusion step(s) ``t``.

        Args:
            x: image batch of shape (N, 1, 28, 28). The single channel is fixed by
                the first block's LayerNorm((1, 28, 28)) and Conv2d(1, 10, ...).
            t: integer step indices, one per sample. Shapes (N,) and (N, 1) both work,
                because each time projection is reshaped to (N, C, 1, 1).

        Returns:
            Tensor of shape (N, 1, 28, 28).
        """
        t = self.time_embed(t)
        n = len(x)
        out1 = self.b1(x + self._time_bias(self.te1, t, n))  # (N, 10, 28, 28)
        out2 = self.b2(self.down1(out1) + self._time_bias(self.te2, t, n))  # (N, 20, 14, 14)
        out3 = self.b3(self.down2(out2) + self._time_bias(self.te3, t, n))  # (N, 40, 7, 7)

        out_mid = self.b_mid(self.down3(out3) + self._time_bias(self.te_mid, t, n))  # (N, 40, 3, 3)

        out4 = torch.cat((out3, self.up1(out_mid)), dim=1)  # (N, 80, 7, 7)
        out4 = self.b4(out4 + self._time_bias(self.te4, t, n))  # (N, 20, 7, 7)

        out5 = torch.cat((out2, self.up2(out4)), dim=1)  # (N, 40, 14, 14)
        out5 = self.b5(out5 + self._time_bias(self.te5, t, n))  # (N, 10, 14, 14)

        out = torch.cat((out1, self.up3(out5)), dim=1)  # (N, 20, 28, 28)
        out = self.b_out(out + self._time_bias(self.te_out, t, n))  # (N, 10, 28, 28)

        out = self.conv_out(out)  # (N, 1, 28, 28)

        return out

    @staticmethod
    def _time_bias(te, t, n):
        """Project the time embedding with ``te`` and shape it as an (n, C, 1, 1) bias."""
        return te(t).reshape(n, -1, 1, 1)

    def _make_te(self, dim_in, dim_out):
        """Two-layer MLP (Linear, SiLU, Linear) mapping the time embedding to ``dim_out`` channels."""
        return nn.Sequential(
            nn.Linear(dim_in, dim_out),
            nn.SiLU(),
            nn.Linear(dim_out, dim_out)
        )


class SceneLearner(nn.Module):
    """Decode per-pixel RGB and depth from per-example latents plus CNN image features.

    The learnable latents are ``nn.Embedding`` tables indexed by ``model_input["idx"]``,
    one per entry of ``spatial_dims``. Each table is read as an s x s grid of
    ``args.fdim``-wide features. ``forward`` is the active path;
    ``forward_autodecoder_global`` and ``forward_unet`` are alternative paths (see the
    NOTE(review) on each).
    """

    def __init__(self, args):
        super().__init__()
        self.args=args

        # Our actual model
        fdim=64
        self.fmap_downproj = nn.Linear(512,fdim)
        # NOTE(review): self.depth_est is re-created below with fdim=args.fdim, which
        # replaces this instance. It is kept so that RNG consumption (and therefore the
        # init of every later layer) is unchanged.
        self.depth_est = make_net([fdim,64,1])
        self.corr_weighter_perpoint = make_net([fdim*2,128,64,1])

        self.midas = torch.hub.load("intel-isl/MiDaS", "MiDaS_small", pretrained=not self.args.scratch_net)
        # Keep MiDaS' original output head as self.midas_out. The chained assignment
        # below puts one shared nn.Identity() both at index 1 of that saved head and
        # in place of midas.scratch.output_conv.
        self.midas_out=self.midas.scratch.output_conv
        self.midas_out[1]=self.midas.scratch.output_conv=nn.Identity()

        self.img_enc = conv_modules.PixelNeRFEncoder()

        fdim=args.fdim
        # Latent embeddings for different spatial resolutions (s x s grids of fdim features)
        self.spatial_dims = [1,4,8,16][:args.spatial_dims]
        self.latents = nn.ModuleList([nn.Embedding(4010, fdim*s**2) for s in self.spatial_dims]).cuda()
        for x in self.latents: nn.init.normal_(x.weight, mean=0, std=0.01)

        #self.decoder = CrossAttn_(512,4,128)
        n_freq=6
        self.posenc = PositionalEncodingNoFreqFactor(2,n_freq)
        self.pix_embed = nn.Linear(n_freq*4+2,fdim)
        self.decoder = nn.ModuleList([CrossAttn_(fdim,3,128) for _ in range(4)]).cuda()
        self.depth_est = make_net([fdim,64,1])
        self.rgb_est = make_net([fdim,64,3])

    # Render out image iteratively
    def render_full_img(self, model_input):
        """Render every pixel without gradients, calling ``forward`` on 16 pixel chunks.

        NOTE: ``forward`` caches ``model_input["latents"]`` on first use. If they are
        not already present, they are therefore computed here under ``no_grad``, reused
        for the remaining chunks, and left in ``model_input`` after this call.

        Returns:
            dict with ``"depth"`` (b, L, H, W), ``"rgb"`` (b, L, 3, H, W) and
            ``"masks"`` (b, L, L, 1, H, W), where L is the number of latent levels.
        """
        imsize=model_input["rgb"].shape[-2:]
        chunks = {"rgb": [], "depth": [], "masks": []}
        with torch.no_grad():
            for pix in torch.arange(imsize[0]*imsize[1]).chunk(16):
                out = self(model_input,render_pix=pix)
                for key, parts in chunks.items():
                    parts.append(out[key])

        # Pixel axis: -2 for rgb (..., P, 3); -1 for depth (..., P) and masks (..., 1, P).
        rgb = torch.cat(chunks["rgb"], -2)
        depth = torch.cat(chunks["depth"], -1)
        masks = torch.cat(chunks["masks"], -1)

        return {
            "depth":depth.unflatten(-1,imsize),
            "rgb":ch_fst(rgb,imsize[0]),
            "masks":rearrange(masks,"b l1 l2 1 (x y) -> b l1 l2 1 x y",x=imsize[0]),
            }

    def forward(self, model_input, out=None, render_pix=None):
        """Decode RGB and depth at a set of pixels, once per latent level.

        Args:
            model_input: dict with ``"rgb"`` (batch-first images; last two dims are
                H, W) and ``"idx"`` (per-example indices into the latent embeddings).
                A ``"latents"`` entry is reused if present; otherwise it is computed
                here and written back into ``model_input``.
            out: optional dict merged into the result (the keys below take precedence).
            render_pix: flat pixel indices to decode. Defaults to up to 2000 randomly
                chosen pixels.

        Returns:
            ``out`` merged with ``"depth"`` (b, L, P), ``"rgb"`` (b, L, P, 3),
            ``"masks"`` (b, L, L, 1, P) and ``"render_pix"``, where
            ``L = len(spatial_dims) + 1`` and P is the number of decoded pixels.
            Output level ``j`` only sees latent levels ``0..j`` (the spatial grids in
            increasing size, then the CNN features last).
        """
        out = {} if out is None else out

        imsize=model_input["rgb"].shape[-2:]
        b,n_latent = len(model_input["rgb"]),len(self.spatial_dims)+1
        if render_pix is None: render_pix = torch.randperm(imsize[0]*imsize[1])[:2000]

        # Query coordinates for decoding image
        xpix = make_xpix(*imsize).flatten(0,1)[render_pix]
        crds = repeat(self.pix_embed(self.posenc(xpix)),"xy c -> b xy l 1 c",l=n_latent,b=b)

        # Build (once) the latent grids, then sample every grid at the query pixels
        if "latents" not in model_input:
            latents = [ch_fst(latent(model_input["idx"].squeeze(-1)).unflatten(-1,(s**2,-1)),s) for s,latent in zip(self.spatial_dims,self.latents)] # learnable spatial latents
            latents.append( self.img_enc(model_input["rgb"]) ) # cnn latents
            model_input["latents"] = latents
        latents = torch.stack([ch_sec(grid_samp(latent,xpix[None].expand(b,-1,-1,-1))) for latent in model_input["latents"]],2)

        # Level masks: output level j may use latent level i only if i <= j
        lowres_masks = repeat(torch.stack([torch.tensor([float(i<=j) for i in range(n_latent)]) for j in range(n_latent)]).cuda(),"x y -> b x y 1 1 1",b=b)
        hires_masks  = rearrange( grid_samp(lowres_masks.flatten(0,2),xpix[None,None].expand(b*n_latent**2,-1,-1,-1)) , "(b l1  l2) 1 1 xy -> b xy l1 l2 1",b=b,l1=n_latent)

        # Create masked latent feature grids: one masked copy of all latents per output level
        latents = repeat(latents,"b xy l c -> b xy l2 l c",l2=latents.size(2))
        latents = latents * hires_masks

        # Decode features at the query pixels with the stack of cross-attention layers
        fmap = crds
        for cross_attn in self.decoder: fmap = cross_attn(latents,fmap)
        fmap = rearrange(fmap,"b xy l1 1 c -> b l1 xy c")

        # Decode into rgb/depth
        depth = F.softplus(self.depth_est(fmap)).squeeze(-1)
        rgb = torch.sigmoid(self.rgb_est(fmap))

        return out | {
            "depth":depth,
            "rgb":rgb,
            "masks":hires_masks.permute(0,2,3,4,1),
            "render_pix":render_pix.cuda(),
        }

    def forward_autodecoder_global(self, model_input, out=None):
        """Alternative path: a global latent code decoded at half resolution, then upsampled.

        NOTE(review): ``self.latent_codes`` is not created in ``__init__``, so this raises
        AttributeError unless it is attached elsewhere. ``depth_est`` / ``rgb_est`` also
        expect ``args.fdim`` inputs but receive the positional-encoding width plus
        ``fdim`` here. This path looks stale; confirm before use.
        """
        out = {} if out is None else out

        imsize=model_input["rgb"].shape[-2:]
        latent_codes = self.latent_codes(model_input["idx"].squeeze(-1))

        crds = self.posenc(make_xpix(*imsize).flatten(0,1)[None].expand(len(model_input["rgb"]),-1,-1).cuda())

        # Render image feature map at half res (every other row/column)
        fmap = self.pix_embed(crds.unflatten(1,imsize)[:,1::2,1::2].flatten(1,2))
        for cross_attn in self.decoder: fmap = cross_attn(latent_codes[:,None],fmap)

        # Upsample and decode
        fmap = F.interpolate(ch_fst(fmap,imsize[0]//2),imsize,mode="bilinear")
        depth = F.softplus(ch_fst(self.depth_est(torch.cat((crds,ch_sec(fmap)),-1)),imsize[0])).squeeze(1)
        rgb = torch.sigmoid(ch_fst(self.rgb_est(torch.cat((crds,ch_sec(fmap)),-1)),imsize[0]))

        return out | {
            "depth":depth,
            "rgb":rgb,
        }

    def forward_unet(self, model_input, out=None):
        """Alternative path: depth straight from the CNN encoder's features.

        NOTE(review): ``fmap_downproj`` outputs 64 features (and presumably expects a
        512-channel ``img_enc`` output — confirm), while ``depth_est`` is built for
        ``args.fdim`` inputs. This only works when ``args.fdim == 64``.
        """
        out = {} if out is None else out

        imsize=model_input["rgb"].shape[-2:]

        fmap = self.img_enc(model_input["rgb"])
        fmap = ch_fst(self.fmap_downproj(ch_sec(fmap)),fmap.size(-2))
        fmap = F.interpolate(fmap,imsize,mode="bilinear")
        depth = F.softplus(ch_fst(self.depth_est(ch_sec(fmap)),imsize[0])).squeeze(1)

        return out | {
            "depth":depth,
        }
