import os
import torch,wandb
from tqdm import trange
from einops import rearrange
import args_setup,vis_scripts
from copy import deepcopy
import piqa
from torchvision.utils import make_grid
from models import ch_sec
from torch.cuda.amp import autocast, GradScaler

def run_model(run, dataset, until_img=25, until_save=500, optim=None, single_data=None,
              until_eval=100, n_eval=5, until_gen=100):
    """Train ``run.model`` on ``dataset`` with periodic validation, logging and checkpoints.

    Each cycle of ``until_eval`` steps runs ``until_eval - n_eval`` training steps
    followed by ``n_eval`` validation steps (autograd disabled, no parameter
    updates). Per-loss values and their sum are logged to wandb under
    ``<loss_name>_<train|val>`` / ``loss_<train|val>``. Image summaries are written
    every ``until_img`` steps and on every validation step. A checkpoint with the
    model state dict is saved every ``until_save`` steps; step 0 is skipped.

    Args:
        run: Experiment container. This function reads ``run.args.lr``,
            ``run.args.n_train_steps``, ``run.model`` and ``run.save_dir``, and calls
            ``run.get_dataloader(dataset)``, whose result is consumed with ``next()``.
        dataset: Object exposing ``train()`` / ``val()`` split switches.
        until_img: Interval (in steps) between image summaries.
        until_save: Interval (in steps) between checkpoints.
        optim: Optional optimizer over ``run.model``'s parameters. When ``None``,
            an Adam optimizer with learning rate ``run.args.lr`` is created.
        single_data: Unused; kept for call-site compatibility.
        until_eval: Length of one train+val cycle, in steps.
        n_eval: Number of validation steps at the end of each cycle.
        until_gen: Interval (in steps) at which ``run.model.generate_new_images()``
            is sampled. It is only checked on steps that also write image summaries.
    """
    if optim is None:
        optim = torch.optim.Adam(lr=run.args.lr, params=run.model.parameters())
    # Start from clean gradients so nothing stale leaks into the first update.
    optim.zero_grad()

    # Grad mode is toggled globally below. Remember the caller's setting so it is
    # restored even if training stops (or raises) during a validation phase.
    grad_was_enabled = torch.is_grad_enabled()
    try:
        for step in trange(run.args.n_train_steps, desc="Fitting"):

            # Train/val switch: the last n_eval steps of every cycle are validation.
            if step % until_eval == until_eval - n_eval:
                print(step, "switching to val")
                prefix = "val"
                dataset.val()
                dataloader = run.get_dataloader(dataset)
                torch.set_grad_enabled(False)
            elif step % until_eval == 0:
                print(step, "switching to train")
                prefix = "train"
                dataset.train()
                dataloader = run.get_dataloader(dataset)
                torch.set_grad_enabled(True)

            # Get data. On exhaustion, rebuild the loader and skip this step entirely.
            try:
                model_input = next(dataloader)
            except StopIteration:
                dataloader = run.get_dataloader(dataset)
                continue
            model_input = args_setup.to_gpu(model_input)

            # Run model and calculate losses; log everything for this step in one call.
            out = run.model(model_input)
            losses = args_setup.loss_fn(out, model_input, model_input, run.model, step)
            total_loss = 0.
            log_row = {}
            for loss_name, loss in losses.items():
                log_row["%s_%s" % (loss_name, prefix)] = loss.item()
                total_loss = total_loss + loss
            log_row["loss_%s" % prefix] = total_loss.item()
            wandb.log(log_row, step=step)

            if torch.is_grad_enabled():
                total_loss.backward()
                # No parameter update on step 0. Gradients are cleared on every step
                # so step 0's gradients are not folded into the step-1 update.
                if step:
                    optim.step()
                optim.zero_grad()

            # Image summaries (every until_img steps and on all validation steps).
            if step % until_img == 0 or prefix == "val":
                with torch.no_grad():
                    if step % until_gen == 0:
                        print("sampling model")
                        out["sampled_gen"] = run.model.generate_new_images()
                    vis_scripts.wandb_summary(0, out, model_input, model_input, None,
                                              step=step, prefix=prefix + "_")

            # Checkpoint (model weights only).
            if step and step % until_save == 0:
                print(f"Saving to {run.save_dir}")
                torch.save({'step': step, 'model_state_dict': run.model.state_dict()},
                           os.path.join(run.save_dir, f"checkpoint_{step}.pt"))
    finally:
        torch.set_grad_enabled(grad_was_enabled)

if __name__ == '__main__':
    # Script entry point: build the experiment from args_setup, then train it on the run's own dataset.
    run = args_setup.make_run()
    run_model(
        run,
        run.dataset,
        until_img=run.args.until_img,
        until_save=run.args.until_save,
    )
