import cv2
import copy
import math
import argparse
import numpy as np
from time import time
from tqdm import tqdm
from easydict import EasyDict

import wandb

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from einops import rearrange


def unsqueeze3x(x):
    """Append four trailing singleton axes, e.g. shape ``(b,)`` -> ``(b, 1, 1, 1, 1)``.

    The name is historical (kept for compatibility with callers): four axes are
    added, so per-sample scalars broadcast against 5-D tensors.
    """
    return x[..., None, None, None, None]


# Added to denominators in the diffusion formulas to keep divisions finite.
small_const = 1e-5

class GaussianDiffusion:
    """Gaussian diffusion process with 1) Cosine schedule for beta values (https://arxiv.org/abs/2102.09672)
    2) L_simple training objective from https://arxiv.org/abs/2006.11239.

    Samples are passed around as dicts of tensors keyed ``"noised_<name>"`` (current
    noisy state) and ``"eps_<name>"`` (true or predicted noise). Entries whose key does
    not contain ``"rgb"`` are stored as ``(b, o, c)`` and are temporarily rearranged to
    ``(b, 1, c, o, 1)`` so that per-sample scalars expanded by ``unsqueeze3x`` broadcast.
    """

    def __init__(self, timesteps=1000, device="cuda:0"):
        """Build a cosine schedule with ``timesteps`` steps and precompute its scalars on ``device``."""
        self.timesteps = timesteps
        self.device = device
        # Cosine schedule: alpha_bar(t) = cos^2(((t / T + 0.008) / 1.008) * pi / 2).
        self.alpha_bar_scheduler = (
            lambda t: math.cos((t / self.timesteps + 0.008) / 1.008 * math.pi / 2) ** 2
        )
        self.scalars = self.get_all_scalars(
            self.alpha_bar_scheduler, self.timesteps, self.device
        )

        self.clamp_x0 = lambda x: x.clamp(-1, 1)
        # x0 = (x_t - sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_bar_t), clamped to [-1, 1];
        # small_const keeps the division finite.
        self.get_x0_from_xt_eps = lambda xt, eps, t, scalars: (
            self.clamp_x0(
                1 / (small_const + unsqueeze3x(scalars.alpha_bar[t].sqrt()))
                * (xt - unsqueeze3x((1 - scalars.alpha_bar[t]).sqrt()) * eps)
            )
        )
        # Posterior mean of q(x_{t-1} | x_t, x0), written with alpha_t / alpha_bar_t only:
        #   sqrt(alpha_bar_t) * beta_t / ((1 - alpha_bar_t) * sqrt(alpha_t)) * x0
        #   + (alpha_t - alpha_bar_t) / ((1 - alpha_bar_t) * sqrt(alpha_t)) * x_t
        self.get_pred_mean_from_x0_xt = lambda xt, x0, t, scalars: (
            unsqueeze3x(
                (scalars.alpha_bar[t].sqrt() * scalars.beta[t])
                / (small_const + (1 - scalars.alpha_bar[t]) * scalars.alpha[t].sqrt())
            ) * x0
            + unsqueeze3x(
                (scalars.alpha[t] - scalars.alpha_bar[t])
                / (small_const + (1 - scalars.alpha_bar[t]) * scalars.alpha[t].sqrt())
            ) * xt
        )

    def get_all_scalars(self, alpha_bar_scheduler, timesteps, device, betas=None):
        """Compute every per-timestep scalar of a schedule.

        Args:
            alpha_bar_scheduler: callable ``t -> alpha_bar(t)``; only used when ``betas`` is None.
            timesteps: number of betas to build (ignored when ``betas`` is given).
            device: device the generated betas are moved to (ignored when ``betas`` is given).
            betas: optional precomputed 1-D tensor of betas.

        Returns:
            EasyDict of float32 tensors: ``beta``, ``beta_log``, ``alpha``, ``alpha_bar``
            (cumulative product of alpha), ``beta_tilde`` (posterior variance; entry 0 is a
            copy of entry 1) and ``beta_tilde_log``.
        """
        all_scalars = {}
        if betas is None:
            # beta_t = 1 - alpha_bar(t + 1) / alpha_bar(t), hard-capped at beta_max = 0.999.
            all_scalars["beta"] = torch.from_numpy(
                np.array(
                    [
                        min(1 - alpha_bar_scheduler(t + 1) / alpha_bar_scheduler(t), 0.999)
                        for t in range(timesteps)
                    ]
                )
            ).to(device)
        else:
            all_scalars["beta"] = betas
        all_scalars["beta_log"] = torch.log(all_scalars["beta"])
        all_scalars["alpha"] = 1 - all_scalars["beta"]
        all_scalars["alpha_bar"] = torch.cumprod(all_scalars["alpha"], dim=0)
        # beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t) is only defined for t >= 1;
        # step 0 reuses the value of step 1.
        beta_tilde = (
            all_scalars["beta"][1:]
            * (1 - all_scalars["alpha_bar"][:-1])
            / (1 - all_scalars["alpha_bar"][1:])
        )
        all_scalars["beta_tilde"] = torch.cat([beta_tilde[0:1], beta_tilde])
        all_scalars["beta_tilde_log"] = torch.log(all_scalars["beta_tilde"])
        return EasyDict({k: v.float() for k, v in all_scalars.items()})

    def sample_from_forward_process(self, x0, t):
        """Noise ``x0`` straight to timestep ``t`` of the forward process.

        Returns ``(xt, eps)`` with ``xt = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps``
        (cast to float32) and ``eps`` the Gaussian noise used; training regresses against
        this particular realization of ``eps``.
        """
        eps = torch.randn_like(x0)
        alpha_bar_t = self.scalars.alpha_bar[t]
        xt = unsqueeze3x(alpha_bar_t.sqrt()) * x0 + unsqueeze3x((1 - alpha_bar_t).sqrt()) * eps
        return xt.float(), eps

    # ------------------------------------------------------------------ helpers

    def _subsampled_schedule(self, timesteps):
        """Return ``(timesteps, new_timesteps, scalars)`` for ``timesteps`` evenly spaced sampling steps."""
        timesteps = timesteps or self.timesteps
        new_timesteps = np.linspace(0, self.timesteps - 1, num=timesteps, endpoint=True, dtype=int)
        alpha_bar = self.scalars["alpha_bar"][new_timesteps]
        # Re-derive betas so that the cumulative product reproduces alpha_bar on the sub-grid.
        new_betas = 1 - (alpha_bar / torch.nn.functional.pad(alpha_bar, [1, 0], value=1.0)[:-1])
        scalars = self.get_all_scalars(self.alpha_bar_scheduler, timesteps, self.device, new_betas)
        return timesteps, new_timesteps, scalars

    @staticmethod
    def _select_eps(final, model_out):
        """Pick the model's ``eps_<name>`` outputs for every key of ``final`` (``noised_`` -> ``eps_``)."""
        pred_epsilon = {}
        for key in final:
            eps_key = key.replace("noised", "eps")
            if eps_key in model_out:
                pred_epsilon[eps_key] = model_out[eps_key]
        return pred_epsilon

    @staticmethod
    def _to_5d(tensors):
        """Rearrange non-``rgb`` entries from ``(b, o, c)`` to ``(b, 1, c, o, 1)``; ``rgb`` entries pass through."""
        return {k: v if "rgb" in k else rearrange(v, "b o c -> b 1 c o 1") for k, v in tensors.items()}

    @staticmethod
    def _from_5d(tensors, detach=False):
        """Inverse of ``_to_5d``, optionally detaching every tensor first."""
        out = {}
        for k, v in tensors.items():
            if detach:
                v = v.detach()
            out[k] = v if "rgb" in k else rearrange(v, "b 1 c o 1 -> b o c")
        return out

    def _predict_x0_and_mean(self, final, pred_epsilon, sub_t, scalars):
        """Return ``(pred_x0, pred_mean)`` dicts keyed ``eps_<name>`` from 5-D ``final``/``pred_epsilon``."""
        pred_x0 = {
            k: self.get_x0_from_xt_eps(final[k.replace("eps", "noised")], v, sub_t, scalars)
            for k, v in pred_epsilon.items()
        }
        pred_mean = {
            k: self.get_pred_mean_from_x0_xt(final[k.replace("eps", "noised")], v, sub_t, scalars)
            for k, v in pred_x0.items()
        }
        return pred_x0, pred_mean

    @staticmethod
    def _ddim_step(pred_x0, pred_epsilon, sub_t, scalars):
        """Deterministic DDIM update to the previous sub-step, keyed ``noised_<name>``."""
        alpha_bar_prev = unsqueeze3x(scalars["alpha_bar"][sub_t - 1])
        return {
            k.replace("eps", "noised"): alpha_bar_prev.sqrt() * pred_x0[k]
            + (1 - alpha_bar_prev).sqrt() * pred_epsilon[k]
            for k in pred_x0
        }

    # ------------------------------------------------------------------ samplers

    def diff_sample_from_reverse_process(
        self, model, xT, timesteps=None, model_kwargs=None, ddim=False, scene_graph=False, cameras=None, conditioning=None, use_direct=False, model_input=None, teacher_forcing=True,
    ):
        """Differentiable reverse process (no ``torch.no_grad``) using deterministic DDIM updates.

        Args:
            model: must provide ``denoise_scene_graph(inputs, t) -> (model_out, latent)``
                where ``model_out`` contains ``eps_<name>`` noise predictions.
            xT: dict of starting noise tensors keyed ``noised_<name>``; it is not modified.
            timesteps: number of sampling steps, evenly sub-sampled from the training
                schedule (defaults to all of them).
            model_kwargs, ddim, scene_graph, cameras, use_direct: unused; kept for
                interface compatibility (DDIM updates are always used).
            conditioning: dict merged into the model input at every step (None -> nothing).
            model_input: ground-truth dict. With teacher forcing, its ``scene_graph`` and
                ``cameras`` entries (rank-3, ``b o c``) are re-noised to the current
                timestep at every step and fed to the model instead of the running sample.
            teacher_forcing: enables teacher forcing; only effective when ``model_input``
                contains ``scene_graph``.

        Returns:
            ``(samples, intermediates, eps_losses, scene_graph_latent)``:

            - ``samples`` is keyed ``<name>``.
            - ``intermediates`` maps ``<name>`` to the per-step predicted means, stacked
              along dim 2.
            - ``eps_losses`` maps ``epsloss_eps_<name>`` to the per-step MSE between true
              and predicted noise (zero tensors without teacher forcing), stacked along
              dim 0.
            - ``scene_graph_latent`` is the model's latent from the final step.
        """
        conditioning = conditioning or {}
        # Shallow copy: teacher forcing writes keys into `final`, which must not leak into the caller's xT.
        final = dict(xT)
        timesteps, new_timesteps, scalars = self._subsampled_schedule(timesteps)

        teacher_forcing = teacher_forcing and model_input is not None and "scene_graph" in model_input
        print("using teacher forcing :", teacher_forcing)

        eps_loss_all = []
        intermeds = {k[7:]: [] for k in final}  # strip the "noised_" prefix
        scene_graph_latent = None
        for i, t in zip(np.arange(timesteps)[::-1], new_timesteps[::-1]):
            if teacher_forcing:
                # Replace the running sample with freshly noised ground truth at step t.
                for name in ("scene_graph", "cameras"):
                    xt_name, eps_name = self.sample_from_forward_process(
                        rearrange(model_input[name], "b o c -> b 1 c o 1"), t
                    )
                    final["noised_" + name] = rearrange(xt_name, "b 1 c o 1 -> b o c")
                    final["eps_" + name] = rearrange(eps_name, "b 1 c o 1 -> b o c")

            batch = len(list(final.values())[0])
            current_t = torch.tensor([t] * batch, device=self.device)  # index into the full schedule
            current_sub_t = torch.tensor([i] * batch, device=self.device)  # index into the sub-sampled schedule

            model_out, scene_graph_latent = model.denoise_scene_graph(final | conditioning, current_t)
            pred_epsilon = self._select_eps(final, model_out)

            # With teacher forcing `final` also holds the true eps_<name>, so this is the eps-regression loss.
            eps_loss_all.append({
                k: (final[k] - model_out[k]).square().mean() if teacher_forcing else torch.zeros_like(v)
                for k, v in pred_epsilon.items()
            })

            final, pred_epsilon = self._to_5d(final), self._to_5d(pred_epsilon)
            pred_x0, pred_mean = self._predict_x0_and_mean(final, pred_epsilon, current_sub_t, scalars)
            if i == 0:
                final = pred_mean  # last step: return the mean, keys stay "eps_<name>"
            else:
                final = self._ddim_step(pred_x0, pred_epsilon, current_sub_t, scalars)
            final = self._from_5d(final)
            for k, v in pred_mean.items():
                intermeds.setdefault(k[4:], []).append(v)  # strip the "eps_" prefix

        samples = {k[4:]: v for k, v in final.items()}
        stacked = {k: torch.stack(v, 2) for k, v in intermeds.items() if v}
        eps_losses = {
            "epsloss_" + k: torch.stack([x[k][None] for x in eps_loss_all])
            for k in eps_loss_all[0].keys()
        }
        return samples, stacked, eps_losses, scene_graph_latent

    def sample_from_reverse_process(
        self, model, xT, timesteps=None, model_kwargs=None, ddim=False, scene_graph=False, cameras=None, conditioning=None, use_direct=False,
    ):
        """Sample by iterating the reverse process under ``torch.no_grad``.

        Args:
            model: called as ``model(inputs, t, use_skip=True, sample_rand_global=...,
                clip_global_latent=..., **model_kwargs, teacher_forcing=False)`` and must
                return a dict containing ``eps_<name>`` noise predictions.
            xT: dict of starting noise tensors keyed ``noised_<name>``.
            timesteps: number of sampling steps, evenly sub-sampled from the training
                schedule (defaults to all of them).
            model_kwargs: extra keyword arguments forwarded to ``model`` (None -> none).
            ddim: IGNORED. Ancestral sampling (posterior mean plus sqrt(beta_tilde)-scaled
                noise) is always used; the override notes DDIM is "only use for N<10 samples".
            scene_graph, cameras: unused; kept for interface compatibility.
            conditioning: dict merged into the model input at every step (None -> nothing).
            use_direct: when true, the model is asked for ``sample_rand_global`` at t > 900
                and ``clip_global_latent`` at t > 500.

        Returns:
            ``(samples, intermediates)``. ``samples`` is keyed ``<name>``. ``intermediates``
            maps ``<name>`` to the predicted means recorded at up to 10 evenly spaced steps.
            Non-``rgb`` entries are kept in the ``(b, 1, c, o, 1)`` layout and stacked along
            dim 2.

        Side effects: puts ``model`` in eval mode (not restored) and prints progress.
        """
        ddim = False  # NOTE: deliberate override of the argument (see docstring)
        model_kwargs = model_kwargs or {}
        conditioning = conditioning or {}
        model.eval()

        final = xT
        timesteps, new_timesteps, scalars = self._subsampled_schedule(timesteps)

        print("starting diff")
        vanilla = not use_direct
        # Sub-step indices at which predicted means are recorded (at most 10, evenly spaced).
        record_steps = {
            int(s) for s in np.round(np.linspace(0, timesteps - 1, min(10, timesteps))).astype(int)
        }
        intermeds = {k[7:]: [] for k in list(final.keys()) + ["noised_cameras", "noised_scene_graph"]}
        # The former "direct regression" branch was unreachable (its flag was always forced
        # to False) and has been removed; behavior is unchanged.
        with torch.no_grad():
            for i, t in zip(np.arange(timesteps)[::-1], new_timesteps[::-1]):
                batch = len(list(final.values())[0])
                current_t = torch.tensor([t] * batch, device=self.device)  # index into the full schedule
                current_sub_t = torch.tensor([i] * batch, device=self.device)  # index into the sub-sampled schedule

                model_out = model(final | conditioning, current_t, use_skip=True, sample_rand_global=t > 900 and not vanilla,
                                  clip_global_latent=t > 500 and not vanilla and 1, **model_kwargs, teacher_forcing=False)
                pred_epsilon = self._select_eps(final, model_out)

                # Derive mu_t from x_t and predicted x0 (more stable than from x_t and eps).
                final, pred_epsilon = self._to_5d(final), self._to_5d(pred_epsilon)
                pred_x0, pred_mean = self._predict_x0_and_mean(final, pred_epsilon, current_sub_t, scalars)

                if i == 0:
                    final = pred_mean  # last step: return the mean, keys stay "eps_<name>"
                elif ddim:
                    final = self._ddim_step(pred_x0, pred_epsilon, current_sub_t, scalars)
                else:
                    noise_std = unsqueeze3x(scalars.beta_tilde[current_sub_t].sqrt())
                    final = {
                        k.replace("eps", "noised"): v + noise_std * torch.randn_like(v)
                        for k, v in pred_mean.items()
                    }

                final = self._from_5d(final, detach=True)
                if i in record_steps:
                    for k, v in pred_mean.items():
                        intermeds.setdefault(k[4:], []).append(v)  # strip the "eps_" prefix

        return {k[4:]: v for k, v in final.items()}, {k: torch.stack(v, 2) for k, v in intermeds.items() if v}
def sample_N_images(
    N,
    model,
    diffusion,
    xT_=None,
    sampling_steps=250,
    batch_size=64,
    num_channels=3,
    image_size=32,
    n_frames=1,
    num_classes=None,
    args=None,
    scene_graph=False,
    cameras=None,
    conditioning=None,
    use_direct=False,
):
    """Sample from a diffusion model in batches until at least ``N`` samples were drawn.

    Args:
        N: Number of samples to generate (must be positive).
        model: Diffusion model, forwarded to ``diffusion.sample_from_reverse_process``.
        diffusion: Diffusion process providing ``sample_from_reverse_process``.
        xT_: Template dict ``{name: tensor}``. Each batch draws fresh standard-normal
            noise of shape ``(batch_size, *tensor.shape[1:])``, keyed ``"noised_<name>"``.
        sampling_steps: Number of reverse-process sampling steps.
        batch_size: Samples drawn per reverse-process call (must be positive).
        num_channels, image_size, n_frames, num_classes: Unused; kept for backward
            compatibility (shapes come from ``xT_``).
        args: Namespace providing ``device`` and ``ddim``.
        scene_graph, cameras, conditioning, use_direct: Forwarded to
            ``diffusion.sample_from_reverse_process``.

    Returns:
        ``(samples, intermediates)`` exactly as returned by
        ``diffusion.sample_from_reverse_process`` for the LAST batch only; earlier batches
        are not accumulated.

    Raises:
        ValueError: if ``N`` or ``batch_size`` is not positive, or ``xT_`` is None.
    """
    if N <= 0:
        raise ValueError(f"N must be positive, got {N}")
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    if xT_ is None:
        raise ValueError("xT_ (a template dict of tensors) is required")

    num_processes = 1  # single-process sampling; no cross-rank gathering
    num_batches = math.ceil(N / (batch_size * num_processes))
    gen_images, gen_intermeds = None, None
    with tqdm(total=num_batches) as pbar:
        for _ in range(num_batches):
            xT = {
                "noised_" + k: torch.randn(batch_size, *v.shape[1:]).float().to(args.device)
                for k, v in xT_.items()
            }
            # Unconditional sampling: no class label is passed to the model.
            gen_images, gen_intermeds = diffusion.sample_from_reverse_process(
                model, xT, sampling_steps, {"y": None}, args.ddim,
                scene_graph=scene_graph, cameras=cameras, conditioning=conditioning, use_direct=use_direct,
            )
            pbar.update(1)
    return gen_images, gen_intermeds
