import os
import cv2
import copy
import math
import argparse
import numpy as np
from time import time
from tqdm import tqdm
from easydict import EasyDict

import wandb

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP

from torcheval.metrics import FrechetInceptionDistance

from data import get_metadata, get_dataset, fix_legacy_dict
import unets, models, diffusion

from torchvision.utils import make_grid

def to_gpu(ob):
    """Move ``ob`` to the current CUDA device.

    Dicts are handled recursively: a new dict with the same keys is returned,
    each value moved via ``.cuda()`` (or recursed into, if it is itself a dict).
    Any non-dict value must provide a ``.cuda()`` method.
    """
    if not isinstance(ob, dict):
        return ob.cuda()
    return {key: to_gpu(val) for key, val in ob.items()}

class loss_logger:
    """Record per-step losses and keep an exponential moving average of them.

    Args:
        max_steps: total number of steps expected; only shown in the progress line.
    """

    def __init__(self, max_steps):
        self.max_steps = max_steps
        self.loss = []            # every value passed to log(), in order
        self.start_time = time()  # wall-clock reference for "Time elapsed"
        self.ema_loss = None      # None until the first log() call
        self.ema_w = 0.9          # weight kept from the previous EMA value

    def log(self, v, display=False):
        """Append loss ``v``, update the EMA and, if ``display``, print progress."""
        self.loss.append(v)
        if self.ema_loss is None:
            updated = v
        else:
            updated = self.ema_w * self.ema_loss + (1 - self.ema_w) * v
        self.ema_loss = updated
        if not display:
            return
        elapsed_hr = (time() - self.start_time) / 3600
        print(f"Steps: {len(self.loss)}/{self.max_steps} \t loss (ema): {self.ema_loss:.3f} \t Time elapsed: {elapsed_hr:.3f} hr")

# Per-key multipliers applied to each reconstruction loss before it is added to
# the total; keys not listed here get weight 1.
_LOSS_WEIGHTS = {"eps": 0.25, "seg": 2.0, "invdepth": 10.0}

# FID is slow (Inception passes over ~500 real frames plus the samples), so it
# is off by default. When enabled it runs on sampling steps that are also
# multiples of 550.
_COMPUTE_FID = False


def _str2bool(value):
    """Parse a CLI flag value: '', '0', 'false', 'no', 'off', 'none' (any case) mean False."""
    if isinstance(value, bool):
        return value
    return str(value).strip().lower() not in ("", "0", "false", "no", "off", "none")


def _parse_args():
    """Build the command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser("Minimal implementation of diffusion models")
    # diffusion model
    parser.add_argument("--arch", type=str, help="Neural network architecture")
    parser.add_argument("--class-cond", action="store_true", default=False, help="train class-conditioned diffusion model")
    parser.add_argument("--diffusion-steps", type=int, default=1000, help="Number of timesteps in diffusion process")
    parser.add_argument("--sampling-steps", type=int, default=250, help="Number of timesteps used when sampling")
    parser.add_argument("--ddim", action="store_true", default=False, help="Sampling using DDIM update step")
    parser.add_argument("--no-diffusion", action="store_true", default=False, help="whether to use diffusion target")
    parser.add_argument("--no-labels", action="store_true", default=False, help="whether to use the gt labels")
    parser.add_argument("--no-gt-labels", action="store_true", default=False, help="whether to use gt labels target if dataset provides it")
    # dataset
    parser.add_argument("--dataset", type=str)
    parser.add_argument("--data-dir", type=str, default="./dataset/")
    # optimizer
    parser.add_argument("--batch-size", type=int, default=128, help="batch-size per gpu")
    parser.add_argument("--lr", type=float, default=0.0001)
    parser.add_argument("--epochs", type=int, default=500)
    parser.add_argument("--ema_w", type=float, default=0.9995)
    # sampling/finetuning
    parser.add_argument("--pretrained-ckpt", type=str, help="Pretrained model ckpt")
    parser.add_argument("--delete-keys", nargs="+", help="Checkpoint keys to drop before loading (e.g. shape mismatches)")
    parser.add_argument("--sampling-only", action="store_true", default=False, help="No training, just sample images (will save them in --save-dir)")
    parser.add_argument("--num-sampled-images", type=int, default=50000, help="Number of images required to sample from the model")
    # misc
    parser.add_argument("--save-dir", type=str, default="./trained_models/")
    parser.add_argument("--name", type=str, default="test")
    parser.add_argument("--n_workers", default=8, type=int)
    parser.add_argument("--local_rank", default=0, type=int)
    parser.add_argument("--seed", default=112233, type=int)
    # Boolean options accept a value (`--online 1`, `--online false`) or none
    # (`--online`). Without a type, any non-empty string -- including "False" --
    # would be truthy.
    parser.add_argument("--online", type=_str2bool, nargs="?", const=True, default=False, help="log to wandb online")
    parser.add_argument("--no_sampling", type=_str2bool, nargs="?", const=True, default=False, help="skip periodic sampling during training")
    return parser.parse_args()


def _load_pretrained(model, args):
    """Load ``args.pretrained_ckpt`` into ``model`` (non-strict), dropping ``args.delete_keys`` first."""
    print(f"Loading pretrained model from {args.pretrained_ckpt}")
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    d = fix_legacy_dict(torch.load(args.pretrained_ckpt, map_location=args.device))
    dm = model.state_dict()
    for k in args.delete_keys or ():
        print(f"Deleting key {k} becuase its shape in ckpt ({d[k].shape}) doesn't match " + f"with shape in model ({dm[k].shape})")
        del d[k]
    model.load_state_dict(d, strict=False)
    print("Mismatched keys in ckpt and model: ", set(d.keys()) ^ set(dm.keys()))
    print(f"Loaded pretrained model from {args.pretrained_ckpt}")


def _compute_loss(model_input, model_out, model, args):
    """Return ``(total_loss, log_dict)`` for one batch.

    Every key present in both ``model_input`` and ``model_out`` contributes a
    mean-squared error, scaled by its entry in ``_LOSS_WEIGHTS`` (default 1).
    With ``--no-labels`` only keys containing "eps" are used. A penalty of
    ``1e2 * mean(model.scene_codes.weight ** 2)`` is always added. ``log_dict``
    holds the unweighted per-key losses and the penalty as Python floats.
    """
    loss = 0
    log_dict = {}
    for k in model_input:
        if k not in model_out:
            continue
        if args.no_labels and "eps" not in k:
            continue
        loss_k = (model_input[k] - model_out[k]).square().mean()
        log_dict[f"metrics/{k}_loss"] = loss_k.item()
        # Weight the per-key term, not the running total.
        loss = loss + _LOSS_WEIGHTS.get(k, 1.0) * loss_k
    # presumably per-scene latent codes held in an embedding-like module -- verify in the model
    latent_penalty = model.scene_codes.weight.square().mean() * 1e2
    # Log a float: keeping the tensor would hold the autograd graph until the next step.
    log_dict["latent_penalty"] = latent_penalty.item()
    return loss + latent_penalty, log_dict


def _log_image_grids(model_input, model_out, args, global_step):
    """Log input/output tensors as wandb image grids (first 16 images per key).

    Each tensor's first two axes are merged with ``flatten(0, 1)``; the grids use
    ``model_out["eps"].size(1)`` images per row. Two-channel tensors are split
    into one grid per channel. "cameras" and "idx" are skipped before any tensor
    op is applied to them. Assumes the remaining values are (batch, frames, C, H, W)
    tensors -- TODO confirm against the dataset.
    """
    print("Logging images")
    img_grids = []
    for pref, dic in (("input/", model_input), ("output/", model_out)):
        for k, v in dic.items():
            if k in ("cameras", "idx"):
                continue
            if k not in ("rgb", "eps", "noised_rgb") and args.no_labels:
                continue
            v = v.flatten(0, 1)[:16]  # only the first 16 images, to bound upload size
            if v.size(1) == 2:
                parts = [("0", v[:, [0]]), ("1", v[:, [1]])]
            else:
                parts = [("", v)]
            for suf, v_ in parts:
                grid = make_grid(v_.cpu().detach(), nrow=model_out["eps"].size(1), normalize=True, scale_each=False)
                img_grids.append((pref + k + suf, grid))
    wandb.log(
        {"img/" + name: wandb.Image(grid.permute(1, 2, 0).float().detach().clip(0, 1).cpu().numpy()) for name, grid in img_grids},
        step=global_step,
    )


def _log_samples(model, diffusion_, model_input, metadata, args, global_step):
    """Draw samples with ``diffusion.sample_N_images`` and log grids and a video to wandb.

    Returns ``(sampled_images, n_frames)``. ``sampled_images`` appears to be
    (N, frames, C, H, W) in [-1, 1], given the ``flatten(0, 1)``/5-D permute below
    -- TODO confirm against sample_N_images.
    """
    print("doing sampling")
    n_frames = model_input["noised_rgb"].size(1)
    sampled_images, intermed_samples = diffusion.sample_N_images(
        8, model, diffusion_, model_input["rgb"], args.sampling_steps, 8, metadata.num_channels,
        metadata.image_size, n_frames, metadata.num_classes, args,
    )
    sampled_grid = make_grid(sampled_images.flatten(0, 1) * .5 + .5, nrow=n_frames, normalize=False)
    wandb.log({"sampled_images": wandb.Image(sampled_grid.permute(1, 2, 0).float().detach().clip(0, 1).cpu().numpy())}, step=global_step)
    # Clamp to [0, 1] and scale by 255 so the uint8 cast cannot wrap around
    # (x * 256 turns 1.0 into 256 -> 0).
    video = ((sampled_images * .5 + .5).clamp(0, 1).permute(1, 2, 0, 3, 4).flatten(2, 3).cpu().detach() * 255).to(torch.uint8).numpy()
    wandb.log({"vid/output/sampled_vid": wandb.Video(video, fps=6, format="mp4")}, step=global_step)
    intermed_grid = make_grid(intermed_samples.flatten(0, 2) * .5 + .5, nrow=intermed_samples.size(2), normalize=False)
    wandb.log({"intermed_images": wandb.Image(intermed_grid.permute(1, 2, 0).float().detach().clip(0, 1).cpu().numpy())}, step=global_step)
    print("done sampling")
    return sampled_images, n_frames


def _compute_fid(train_loader, sampled_images, n_frames, args):
    """Return the FID between ~500 real training frames and ``sampled_images``."""
    print("Calculating FID")
    fid = FrechetInceptionDistance(device=args.device)
    n_batches = max(1, 500 // (args.batch_size * n_frames))  # never request zero batches
    batches = iter(train_loader)  # one iterator; iter() per batch would restart the workers each time
    real_imgs = torch.cat([next(batches)["rgb"] for _ in range(n_batches)])
    print("got real images")
    # Merge (batch, frame) first so dim 1 is the channel axis for the grayscale check.
    real_imgs = real_imgs.flatten(0, 1).to(args.device)
    fake_imgs = sampled_images.flatten(0, 1).to(args.device)
    if real_imgs.size(1) == 1:
        real_imgs = real_imgs.expand(-1, 3, -1, -1)
        fake_imgs = fake_imgs.expand(-1, 3, -1, -1)
    if real_imgs.min() < 0:
        real_imgs = real_imgs * .5 + .5
    if fake_imgs.min() < 0:
        fake_imgs = fake_imgs * .5 + .5
    fid.update(real_imgs.clip(0, 1), is_real=True)
    fid.update(fake_imgs.clip(0, 1), is_real=False)
    fid_score = fid.compute().item()
    print(f"FID: {fid_score:.3f}")
    return fid_score


@torch.no_grad()
def _update_ema(ema_dict, model, ema_w):
    """In place: ema = ema_w * ema + (1 - ema_w) * current, for floating-point entries.

    Non-floating entries (integer buffers) are copied, so they keep their dtype
    instead of being silently promoted to float.
    """
    new_dict = model.state_dict()
    for k, ema_v in ema_dict.items():
        new_v = new_dict[k]
        if ema_v.dtype.is_floating_point:
            ema_v.mul_(ema_w).add_(new_v, alpha=1 - ema_w)
        else:
            ema_v.copy_(new_v)


def main():
    """Parse CLI arguments and train the diffusion model, logging to wandb.

    Each step noises ``model_input["rgb"]`` with the forward diffusion process,
    runs the model, and minimizes per-key MSE between ``model_input`` and
    ``model_out`` (weighted by ``_LOSS_WEIGHTS``) plus a penalty on
    ``model.scene_codes``. Image grids and samples are logged periodically, an
    EMA copy of the weights is kept in ``args.ema_dict``, and the raw (non-EMA)
    weights are saved every 1000 steps.
    """
    args = _parse_args()
    metadata = get_metadata(args.dataset)
    torch.backends.cudnn.benchmark = True
    args.device = "cuda:{}".format(args.local_rank)
    torch.cuda.set_device(args.device)
    torch.manual_seed(args.seed + args.local_rank)
    np.random.seed(args.seed + args.local_rank)
    if args.local_rank == 0:
        print(args)

    wandb.init(
        entity="cameronsmithbusiness",
        project="biasing",
        mode="online" if args.online else "disabled",
        name=args.name,
        dir="/tmp/wandb",
    )

    # Create model and diffusion process
    model = models.__dict__[args.arch](
        image_size=metadata.image_size,
        in_channels=metadata.num_channels,
        out_channels=metadata.num_channels,
        num_classes=metadata.num_classes if args.class_cond else None,
    ).to(args.device)
    diffusion_ = diffusion.GaussianDiffusion(args.diffusion_steps, args.device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)

    if args.pretrained_ckpt:
        _load_pretrained(model, args)

    # Load dataset
    train_set = get_dataset(args.dataset, args.data_dir, metadata)
    train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, sampler=None, num_workers=args.n_workers, pin_memory=True)
    if args.local_rank == 0:
        print(f"Training dataset loaded: Number of batches: {len(train_loader)}, Number of images: {len(train_set)}")
    logger = loss_logger(len(train_loader) * args.epochs)

    # EMA copy of the weights, updated in place after every optimizer step.
    args.ema_dict = copy.deepcopy(model.state_dict())

    global_step = -1
    for _epoch in range(args.epochs):
        model.train()
        for step, model_input in tqdm(enumerate(train_loader), total=len(train_loader)):
            global_step += 1
            model_input = to_gpu(model_input)

            # Forward diffusion: noise the clean frames at a random timestep per sample.
            # NOTE(review): "rgb" is expected in the [-1, 1] pixel range.
            images = model_input["rgb"]
            t = torch.randint(diffusion_.timesteps, (len(images),), dtype=torch.int64).to(args.device)
            xt, eps = diffusion_.sample_from_forward_process(images, t)
            model_input["noised_rgb"] = xt
            model_input["eps"] = eps
            if args.no_diffusion:
                # Plain regression: clean input, zero noise target, timestep 0.
                model_input["noised_rgb"] = model_input["rgb"]
                model_input["eps"] = torch.zeros_like(model_input["eps"])
                t = torch.zeros_like(t)

            # Modification for NVS task - remove when doing actual video diffusion:
            # only the first and last entries along dim 1 keep a noise target.
            model_input["eps"][:, 1:-1] = 0

            model_out = model(model_input, t)
            loss, log_dict = _compute_loss(model_input, model_out, model, args)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if global_step % 100 == 0 or (global_step % 30 == 0 and global_step < 100):
                _log_image_grids(model_input, model_out, args, global_step)

            if global_step % 100 == 0 and not args.no_diffusion and not args.no_sampling:
                sampled_images, n_frames = _log_samples(model, diffusion_, model_input, metadata, args, global_step)
                # sample_N_images is defined elsewhere and may leave the model in
                # eval mode; restore train mode for the rest of the epoch.
                model.train()
                if _COMPUTE_FID and global_step % 550 == 0:
                    log_dict["metrics/FID"] = _compute_fid(train_loader, sampled_images, n_frames, args)

            log_dict["metrics/denoising_loss"] = loss.item()
            wandb.log(log_dict, step=global_step)

            _update_ema(args.ema_dict, model, args.ema_w)
            logger.log(loss.item(), display=not step % 100)

            if global_step % 1000 == 0 and global_step:
                print("saving model")
                os.makedirs(args.save_dir, exist_ok=True)
                torch.save(model.state_dict(), os.path.join(args.save_dir, f"{args.name}_{args.dataset}-epoch_{args.epochs}.pt"))

if __name__ == "__main__":
    # Script entry point: start training only when run directly, not on import.
    main()
