import wandb
import argparse
import os

import vis_scripts

# Import of libraries
import random
import imageio
import numpy as np
import args_setup

from tqdm.auto import tqdm
import matplotlib.pyplot as plt

import einops
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader

from torchvision.transforms import Compose, ToTensor, Lambda
from torchvision.datasets.mnist import MNIST, FashionMNIST


# Build the experiment context via the project's setup helper. The rest of
# this script reads `run.model`, `run.loader`, `run.args` and `run.save_dir`
# from the returned object.
run = args_setup.make_run()

# Bare string expression: it has no runtime effect.
"""# Optional visualizations"""

# Adam optimizer over every model parameter, with the configured learning rate.
optim = Adam(params=run.model.parameters(), lr=run.args.lr)

# TODO: I think we should go back to the U-Net: refactor it code-wise however you want, add your
# transformer in the low-res bottleneck, remove the layernorms, etc., and make sure we can still
# get similar results and stable training.

# ---------------------------------------------------------------------------
# Training loop
# ---------------------------------------------------------------------------
# `global_step` counts optimizer updates across all epochs. It drives the
# wandb step axis, every periodic action below, and the checkpoint names.
global_step = 0
n_epochs = 1000000

# Periods, in global steps, for: logging image summaries, additionally sampling
# new images, computing FID, and writing a checkpoint. Each action fires when
# `global_step % period == 1`, so all of them also fire on the first step.
until_vis, until_gen, until_calc_fid, until_save = 50, 200, 1000, 3000

# FID evaluation is switched off; set to True to enable it.
compute_fid = False

for epoch in tqdm(range(n_epochs), desc="Training progress", colour="#00ff00"):
    # Dataset-size-weighted mean of the per-batch losses for this epoch.
    epoch_loss = 0.0
    for batch in tqdm(run.loader, leave=False, desc=f"Epoch {epoch + 1}/{n_epochs}", colour="#005500"):
        batch = args_setup.to_gpu(batch)
        imgs = batch["rgb"]
        batch_size = len(imgs)
        global_step += 1

        # Pick noise for each image in the batch and a random timestep per
        # image, then run the forward (noising) process.
        eta = torch.randn_like(imgs).cuda()
        t = torch.randint(0, run.model.n_steps, (batch_size,)).cuda()
        noisy_imgs, eta_ = run.model(imgs, t, eta)

        # Estimate the noise that was added; `t` is passed as one row per image.
        model_est = run.model.network(noisy_imgs, t.reshape(batch_size, -1))

        # Optimize the MSE between the returned noise target and the estimate.
        loss = (model_est - eta_).square().mean()
        optim.zero_grad()
        loss.backward()
        optim.step()

        # Read the scalar once and reuse it for both bookkeeping and logging.
        loss_value = loss.item()
        epoch_loss += loss_value * batch_size / len(run.loader.dataset)
        log_dict = {"denoise_loss": loss_value, "epoch": epoch}

        if global_step % until_vis == 1:  # Visualizations
            # Note: the "model_est" entry is the noisy image minus the
            # predicted noise, not the raw network output.
            vis = {
                "imgs_raw": imgs,
                "noisy_imgs": noisy_imgs,
                "model_est": noisy_imgs - model_est,
            }
            # Sampling is only checked on visualization steps.
            if global_step % until_gen == 1:
                vis["sampled_gen"] = run.model.generate_new_images(
                    imsl=run.args.imsl, c=run.args.ch, n_samples=run.args.n_gen
                )
            vis_scripts.wandb_summary(vis)

        if compute_fid and global_step % until_calc_fid == 1:  # FID comp
            print("Calculating FID")
            from itertools import islice

            from torcheval.metrics import FrechetInceptionDistance

            fid = FrechetInceptionDistance()

            # Take consecutive batches from ONE pass over the loader. Calling
            # `next(iter(loader))` per batch would rebuild the iterator each
            # time and return the same first batch for an unshuffled loader.
            n_real_batches = 500 // run.args.batch_size
            real_imgs = torch.cat([b["rgb"] for b in islice(run.loader, n_real_batches)])
            print("got real images")

            gen_imgs = torch.cat([
                run.model.generate_new_images(
                    n_samples=run.args.n_gen, just_last=True, imsl=run.args.imsl, c=run.args.ch
                )
            ]).flatten(0, 1)
            print("generated images")

            # Expand the channel axis to 3 and rescale by *0.5 + 0.5
            # (generated images are clipped to [-1, 1] first, giving [0, 1]).
            # NOTE(review): assumes real images are already in [-1, 1] and on
            # the same device as the metric -- confirm before enabling.
            fid.update(real_imgs.expand(-1, 3, -1, -1) * .5 + .5, is_real=True)
            fid.update(gen_imgs.expand(-1, 3, -1, -1).clip(-1, 1) * .5 + .5, is_real=False)
            fid_score = fid.compute()
            print(f"FID: {fid_score:.3f}")
            log_dict["FID"] = fid_score

        wandb.log(log_dict, step=global_step)

        if global_step % until_save == 1:  # Checkpoint model
            print(f"Saving to {run.save_dir}")
            # Name and tag the checkpoint with the global step. The per-epoch
            # batch index restarts at 0 every epoch, so using it would
            # mislabel checkpoints and could overwrite earlier ones.
            torch.save(
                {"step": global_step, "model_state_dict": run.model.state_dict()},
                os.path.join(run.save_dir, f"checkpoint_{global_step}.pt"),
            )

    print(f"Loss at epoch {epoch + 1}: {epoch_loss:.3f}")
