import os
import torch,wandb
from tqdm import trange
from einops import rearrange
import args_setup,vis_scripts
from copy import deepcopy
import piqa
from torchvision.utils import make_grid
from models import ch_sec
from torch.cuda.amp import autocast, GradScaler

def _load_overfit_batch(run, dataset, cache_path="output/tmp.pt"):
    """Return the fixed GPU batch used when ``run.args.overfit`` is set.

    If ``run.args.load_save`` is true, the batch is read back from ``cache_path``.
    Otherwise the first batch of a fresh training dataloader is moved to the GPU
    and written to ``cache_path`` so that a later run can reload the same scene.
    """
    if run.args.load_save:
        batch = torch.load(cache_path)
        print("Loading saved data")
        return batch
    dataset.train()
    # NOTE(review): assumes the loader yields the model-input mapping directly,
    # exactly as the main loop in run_model consumes it — confirm.
    batch = args_setup.to_gpu(next(run.get_dataloader(dataset)))
    os.makedirs(os.path.dirname(cache_path) or ".", exist_ok=True)
    torch.save(batch, cache_path)
    return batch


def run_model(run,dataset,until_img=25,until_save=500,optim=None,single_data=None):
    """Train ``run.model`` on ``dataset``, logging to wandb and saving checkpoints.

    Args:
        run: experiment handle. Reads ``run.args`` (``overfit``, ``load_save``,
            ``lr``, ``n_train_steps``), ``run.model`` and ``run.save_dir``, and
            calls ``run.get_dataloader(dataset)`` to obtain a batch iterator.
        dataset: dataset exposing ``train()`` / ``val()`` mode switches.
        until_img: log an image summary every ``until_img`` steps (and on every
            validation step).
        until_save: save a checkpoint every ``until_save`` steps (never at step 0).
        optim: optional optimizer. When None, Adam over ``run.model``'s
            parameters with ``lr=run.args.lr`` is created.
        single_data: optional fixed batch (mapping of tensors) reused, deep-copied,
            at every step instead of drawing from the dataloader. Overridden by
            the cached/first batch when ``run.args.overfit`` is set.
    """
    if run.args.overfit:
        # Overfit on a single scene: load it from disk if requested, else draw it
        # from the loader and cache it to disk.
        single_data = _load_overfit_batch(run, dataset)

    if optim is None:
        optim = torch.optim.Adam(lr=run.args.lr, params=run.model.parameters())

    # The last n_eval steps of every until_eval-step cycle are validation steps.
    # With these values validation effectively never runs.
    until_eval, n_eval = 1000000000, 5

    # Grad mode is toggled globally below; restore the caller's setting on exit.
    grad_was_enabled = torch.is_grad_enabled()
    try:
        for step in trange(run.args.n_train_steps, desc="Fitting"):

            # Train/val switch. Step 0 always lands in one of these branches,
            # so `prefix` and `dataloader` are defined before first use.
            phase = step % until_eval
            if phase == until_eval - n_eval:
                print(step, "switching to val")
                prefix = "val"
                dataset.val()
                dataloader = run.get_dataloader(dataset)
                torch.set_grad_enabled(False)
            elif phase == 0:
                print(step, "switching to train")
                prefix = "train"
                dataset.train()
                dataloader = run.get_dataloader(dataset)
                torch.set_grad_enabled(True)

            # Get data
            if single_data is None:
                try:
                    model_input = next(dataloader)
                except StopIteration:
                    # Loader exhausted: start a new pass and use this step anyway,
                    # rather than silently skipping its training/vis/checkpoint.
                    dataloader = run.get_dataloader(dataset)
                    try:
                        model_input = next(dataloader)
                    except StopIteration:
                        raise RuntimeError(
                            "dataloader produced no batches") from None
                model_input = args_setup.to_gpu(model_input)
            else:
                model_input = deepcopy(single_data)

            # Run model and calculate losses
            out = run.model(model_input)
            losses = args_setup.loss_fn(out, model_input, model_input, run.model, step)
            total_loss = 0.
            metrics = {}
            for loss_name, loss in losses.items():
                metrics["%s_%s" % (loss_name, prefix)] = loss.item()
                total_loss += loss
            metrics["loss_%s" % prefix] = total_loss.item()
            wandb.log(metrics, step=step)

            if torch.is_grad_enabled():
                total_loss.backward()
                if step:  # the optimizer update is skipped on the very first step
                    optim.step()
                # Always clear grads so step 0's gradients don't leak into step 1.
                optim.zero_grad()

            # Image summaries
            if step % until_img == 0 or prefix == "val":
                with torch.no_grad():
                    if "latents" in model_input:
                        del model_input["latents"]
                    # Reduce batch size for tractable visualization.
                    for k, v in model_input.items():
                        model_input[k] = v[:4]
                    out = run.model.render_full_img(model_input)
                    vis_scripts.wandb_summary(0, out, model_input, model_input, None,
                                              step=step, prefix=prefix + "_")

            # Checkpoint
            if step % until_save == 0 and step:
                print(f"Saving to {run.save_dir}")
                torch.save({'step': step, 'model_state_dict': run.model.state_dict()},
                           os.path.join(run.save_dir, f"checkpoint_{step}.pt"))
    finally:
        torch.set_grad_enabled(grad_was_enabled)

if __name__ == '__main__':
    # Build the experiment from the command line, then train with the
    # checkpoint / image-summary cadence taken from its arguments.
    run = args_setup.make_run()
    cli_args = run.args
    run_model(
        run,
        run.dataset,
        until_save=cli_args.until_save,
        until_img=cli_args.until_img,
    )
