"""Code for pixelnerf and alternatives."""
import torch, torchvision
from torch import nn
from einops import rearrange, repeat
from torch.nn import functional as F
import numpy as np
import sys,random,time,os
from copy import deepcopy 
import matplotlib.pyplot as plt;
from matplotlib import cm
import wandb
from tqdm import tqdm
from typing import Callable, List, Optional, Tuple, Generator, Dict
import conv_modules
from collections import defaultdict
from attn_modules import CrossAttn_,PositionalEncodingNoFreqFactor

# Tensor helpers
def ch_sec(x):
    """Flatten the two trailing spatial axes and move channels last (``... c x y -> ... (x y) c``)."""
    return rearrange(x,"... c x y -> ... (x y) c")

def ch_fst(src,x=None):
    """Undo ``ch_sec``: ``... (x y) c -> ... c x y``.

    When ``x`` is omitted the flattened axis is assumed to be square and
    ``x`` is taken as the truncated square root of its length.
    """
    if x is None:
        x = int(src.size(-2)**(.5))
    return rearrange(src,"... (x y) c -> ... c x y",x=x)

def hom(x):
    """Append a trailing coordinate of ones along the last axis."""
    ones = torch.ones_like(x[...,[0]])
    return torch.cat((x,ones),-1)

def unhom(x):
    """Divide all but the last coordinate by the last one (plus 1e-5 to avoid division by zero)."""
    denom = 1e-5+x[...,-1:]
    return x[...,:-1]/denom

def interp(x,y):
    """Bilinearly resize ``x`` to size ``y`` with ``align_corners=True``."""
    return F.interpolate(x,y,mode="bilinear",align_corners=True)

def grid_samp_(x,y):
    """Bilinear ``grid_sample`` of ``x``; assumes ``y`` is in [0,1] and moves it to [-1,1]."""
    grid = y*2-1
    return F.grid_sample(x,grid,mode="bilinear",align_corners=True)

def grid_samp(x,y):
    """``grid_samp_`` for 4-D ``x``; otherwise merge the two leading axes of ``x`` and ``y``, sample, and split them back."""
    if len(x.shape)==4:
        return grid_samp_(x,y)
    # todo use more general flatten recipe
    lead = x.shape[:2]
    return grid_samp_(x.flatten(0,1),y.flatten(0,1)).unflatten(0,lead)

def project(crds,K):
    """Multiply each point in ``crds`` by the matrix ``K`` (presumably camera intrinsics) and de-homogenise."""
    return unhom(torch.einsum("b...cij,b...ckj->b...cki",K, crds))

def warp(crds,poses,K):
    """Homogenise ``crds``, apply ``poses``, keep the first three coordinates, then ``project`` with ``K``."""
    moved = torch.einsum("b...cij,b...ckj->b...cki",poses,hom(crds))[...,:3]
    return project( moved, K )

def make_xpix(h,w,device="cuda"):
  """Return an ``(h, w, 2)`` grid of normalised ``(x, y)`` pixel coordinates.

  Entry ``[row, col]`` holds ``(col / (w-1), row / (h-1))``, so both
  coordinates span [0,1] and the grid flattens in row-major ``(h, w)``
  order, matching how callers flatten/unflatten images.

  Args:
    h: image height in pixels.
    w: image width in pixels.
    device: device for the returned tensor; defaults to ``"cuda"`` to
      preserve the previous behaviour.

  Returns:
    float32 tensor of shape ``(h, w, 2)`` on ``device``.
  """
  # mgrid over (rows, cols) gives an (h, w, 2) array of (row, col) indices.
  uv = np.mgrid[0 : h, 0 : w].astype(float).transpose(1, 2, 0)
  # Flip the last axis so each entry reads (col, row) == (x, y).
  uv = torch.from_numpy(np.flip(uv, axis=-1).copy()).float()
  # Clamp divisors so a size-1 axis maps to 0 instead of 0/0 = NaN.
  uv = uv / torch.tensor([max(w-1, 1), max(h-1, 1)])  # uv in [0,1]
  return uv.to(device)

def make_net(dims):
    """Build an MLP ``Linear -> ReLU -> ... -> Linear`` with layer widths ``dims``.

    Consecutive entries of ``dims`` give each linear layer's input and
    output size. There is no activation after the last layer. Linear
    weights are re-initialised with Kaiming-normal (fan-in, ReLU gain);
    biases keep PyTorch's default initialisation.
    """
    # Construct every Linear first so the RNG is consumed in the same order
    # as a build-then-init scheme.
    linears = [nn.Linear(n_in, n_out) for n_in, n_out in zip(dims[:-1], dims[1:])]

    stack = []
    for pos, linear in enumerate(linears):
        if pos > 0:
            stack.append(nn.ReLU())
        stack.append(linear)
    net = nn.Sequential(*stack)

    for linear in linears:
        nn.init.kaiming_normal_(linear.weight, a=0.0, nonlinearity='relu', mode='fan_in')
    return net

class SceneLearner(nn.Module):
    """Decode RGB and depth at queried pixels from a stack of latent grids.

    ``forward`` is the active path: it samples per-item learnable latent
    grids (one per entry of ``self.spatial_dims``) plus a CNN feature map
    at the queried pixel coordinates, masks them so that output level ``j``
    only sees latent levels ``i <= j``, runs the cross-attention decoder and
    predicts depth and RGB. ``render_full_img`` calls ``forward`` over pixel
    chunks to produce whole images. ``forward_autodecoder_global`` and
    ``forward_unet`` are alternative paths kept alongside it.
    """

    def __init__(self, args):
        """Build encoders, latent tables, decoder and output heads.

        Args:
            args: config object; this constructor reads ``args.scratch_net``,
                ``args.fdim`` and ``args.spatial_dims``.
        """
        super().__init__()
        self.args=args

        # Our actual model
        fdim=64
        self.fmap_downproj = nn.Linear(512,fdim)
        # NOTE(review): this head is replaced further down by one built with
        # args.fdim inputs; only the later assignment survives.
        self.depth_est = make_net([fdim,64,1])
        self.corr_weighter_perpoint = make_net([fdim*2,128,64,1])

        # Keep a handle to the original MiDaS output head in self.midas_out,
        # then replace both its element [1] and the backbone's output_conv
        # with one shared Identity module.
        self.midas = torch.hub.load("intel-isl/MiDaS", "MiDaS_small", pretrained=not self.args.scratch_net)
        self.midas_out=self.midas.scratch.output_conv
        self.midas_out[1]=self.midas.scratch.output_conv=nn.Identity()

        self.img_enc = conv_modules.PixelNeRFEncoder()

        fdim=args.fdim
        # Latent embeddings for different spatial resolutions: each table row
        # stores an (s x s) grid of fdim-dimensional features, flattened.
        # NOTE(review): 4010 rows — presumably the number of dataset items; confirm.
        self.spatial_dims = [1,4,8,16][:args.spatial_dims]
        self.latents = nn.ModuleList([nn.Embedding(4010, fdim*s**2) for s in self.spatial_dims]).cuda()
        for x in self.latents: nn.init.normal_(x.weight, mean=0, std=0.01)

        #self.decoder = CrossAttn_(512,4,128)
        n_freq=6
        self.posenc = PositionalEncodingNoFreqFactor(2,n_freq)
        self.pix_embed = nn.Linear(n_freq*4+2,fdim)
        self.decoder = nn.ModuleList([CrossAttn_(fdim,3,128) for _ in range(4)]).cuda()
        self.depth_est = make_net([fdim,64,1])
        self.rgb_est = make_net([fdim,64,3])

        #self.forward = self.forward_unet
        #self.forward = self.forward_autodecoder

        # todo: iterative render, cnn, iterative dataset,

    # Render out image iteratively
    def render_full_img(self, model_input):
        """Render every pixel by calling ``forward`` on 16 pixel chunks without gradients.

        Args:
            model_input: dict passed through to ``forward``; ``"rgb"`` supplies
                the image size via its last two dimensions.

        Returns:
            dict with ``"depth"``, ``"rgb"`` and ``"masks"`` whose flattened
            pixel axis has been reshaped back to the image size.
        """
        imsize=model_input["rgb"].shape[-2:]
        # Collect per-chunk outputs and concatenate once at the end instead of
        # re-concatenating the growing result on every iteration.
        pieces = defaultdict(list)
        with torch.no_grad():
            for pix in torch.arange(imsize[0]*imsize[1]).chunk(16):
                out = self(model_input,render_pix=pix)
                for key in ("rgb","depth","masks"):
                    pieces[key].append(out[key])
            # The pixel axis is second-to-last for rgb and last for depth/masks.
            rgb_all   = torch.cat(pieces["rgb"],-2)
            depth_all = torch.cat(pieces["depth"],-1)
            masks_all = torch.cat(pieces["masks"],-1)

        return {
            "depth":depth_all.unflatten(-1,imsize),
            "rgb":ch_fst(rgb_all,imsize[0]),
            "masks":rearrange(masks_all,"b l1 l2 1 (x y) -> b l1 l2 1 x y",x=imsize[0]),
            }

    def forward(self, model_input, out=None, render_pix=None):
        """Decode depth and RGB at a subset of pixels.

        Args:
            model_input: dict with ``"rgb"`` (its length is the batch size
                and its last two dims the image size) and ``"idx"`` (row
                indices into the latent tables, squeezed on the last axis).
                If ``"latents"`` is absent it is computed and stored into
                ``model_input`` so later calls reuse it.
            out: optional dict whose entries are merged into the result; it
                is not modified.
            render_pix: flat pixel indices to decode; defaults to 2000
                random pixels.

        Returns:
            ``out`` merged with ``"depth"`` (b, level, pixel), ``"rgb"``
            (b, level, pixel, 3), ``"masks"`` (b, level, latent, 1, pixel)
            and ``"render_pix"``.
        """
        out = {} if out is None else out

        imsize=model_input["rgb"].shape[-2:]
        b,n_latent = len(model_input["rgb"]),len(self.spatial_dims)+1
        if render_pix is None: render_pix = torch.randperm(imsize[0]*imsize[1])[:2000]

        # Query coordinates for decoding image
        xpix = make_xpix(*imsize).flatten(0,1)[render_pix]
        crds = repeat(self.pix_embed(self.posenc(xpix)),"xy c -> b xy l 1 c",l=n_latent,b=b)

        # Index latents and then tile to half image resolution for decoding
        if "latents" not in model_input:
            latents = [ch_fst(latent(model_input["idx"].squeeze(-1)).unflatten(-1,(s**2,-1)),s) for s,latent in zip(self.spatial_dims,self.latents)] # learnable spatial latents
            latents.append( self.img_enc(model_input["rgb"]) ) # cnn latents
            model_input["latents"] = latents
        # Sample every latent grid at the query pixels -> (b, xy, l, c)
        latents = torch.stack([ch_sec(grid_samp(latent,xpix[None].expand(b,-1,-1,-1))) for latent in model_input["latents"]],2)

        # Masks -- spatial resolution bottlenecks and then random masking
        # Lower-triangular ones: row j keeps latent levels i <= j.
        lowres_masks = repeat(torch.tril(torch.ones(n_latent,n_latent)).cuda(),"x y -> b x y 1 1 1",b=b)
        hires_masks  = rearrange( grid_samp(lowres_masks.flatten(0,2),xpix[None,None].expand(b*n_latent**2,-1,-1,-1)) , "(b l1  l2) 1 1 xy -> b xy l1 l2 1",b=b,l1=n_latent)

        # next up is adding cnn to latent grid
        # todo render out only N pixels (very easy) to make more tractable and build outer loop rendering entire image and add CNN latent level

        # Create masked latent feature grids
        latents = repeat(latents,"b xy l c -> b xy l2 l c",l2=latents.size(2))
        latents = latents * hires_masks

        # Render image feature map at half res
        fmap = crds
        for cross_attn in self.decoder: fmap = cross_attn(latents,fmap)
        fmap = rearrange(fmap,"b xy l1 1 c -> b l1 xy c")

        # Decode into rgb/depth
        depth = F.softplus(self.depth_est(fmap)).squeeze(-1)
        rgb = torch.sigmoid(self.rgb_est(fmap))

        return out | {
            "depth":depth,
            "rgb":rgb,
            "masks":hires_masks.permute(0,2,3,4,1),
            "render_pix":render_pix.cuda(),
        }

    def forward_autodecoder_global(self, model_input, out=None):
        """Alternative path: decode a full image from one global latent per item.

        NOTE(review): this reads ``self.latent_codes``, which ``__init__``
        does not assign; unless it is set elsewhere, calling this method
        raises ``AttributeError``.

        Args:
            model_input: dict with ``"rgb"`` and ``"idx"``.
            out: optional dict whose entries are merged into the result.

        Returns:
            ``out`` merged with ``"depth"`` and ``"rgb"``.
        """
        out = {} if out is None else out

        imsize=model_input["rgb"].shape[-2:]
        latent_codes = self.latent_codes(model_input["idx"].squeeze(-1))

        crds = self.posenc(make_xpix(*imsize).flatten(0,1)[None].expand(len(model_input["rgb"]),-1,-1).cuda())

        # Render image feature map at half res
        fmap = self.pix_embed(crds.unflatten(1,imsize)[:,1::2,1::2].flatten(1,2))
        for cross_attn in self.decoder: fmap = cross_attn(latent_codes[:,None],fmap)

        # Upsample and decode
        fmap = F.interpolate(ch_fst(fmap,imsize[0]//2),imsize,mode="bilinear")
        depth = F.softplus(ch_fst(self.depth_est(torch.cat((crds,ch_sec(fmap)),-1)),imsize[0])).squeeze(1)
        rgb = torch.sigmoid(ch_fst(self.rgb_est(torch.cat((crds,ch_sec(fmap)),-1)),imsize[0]))

        return out | {
            "depth":depth,
            "rgb":rgb,
        }

    def forward_unet(self, model_input, out=None):
        """Alternative path: predict depth directly from the CNN feature map.

        The encoder output is projected from 512 to 64 channels, upsampled
        to the image size and passed through ``depth_est``.

        NOTE(review): ``depth_est`` as finally assigned in ``__init__`` takes
        ``args.fdim`` inputs, so this path only lines up when
        ``args.fdim == 64``.

        Args:
            model_input: dict with ``"rgb"``.
            out: optional dict whose entries are merged into the result.

        Returns:
            ``out`` merged with ``"depth"``.
        """
        out = {} if out is None else out

        imsize=model_input["rgb"].shape[-2:]

        fmap = self.img_enc(model_input["rgb"])
        fmap = ch_fst(self.fmap_downproj(ch_sec(fmap)),fmap.size(-2))
        fmap = F.interpolate(fmap,imsize,mode="bilinear")
        depth = F.softplus(ch_fst(self.depth_est(ch_sec(fmap)),imsize[0])).squeeze(1)

        return out | {
            "depth":depth,
        }
