import os
import torch,wandb
from tqdm import trange
from einops import rearrange
import args_setup,vis_scripts
from copy import deepcopy
import piqa
from torchvision.utils import make_grid
from models import ch_sec
from torch.cuda.amp import autocast, GradScaler

def run_model(run,dataset,until_img=25,until_save=500,optim=None,single_data=None,n_opt=1000):
    """Optimize only the model's latent parameters on each validation batch.

    The dataset is switched with ``dataset.val()``, every parameter whose name
    does not contain ``'latent'`` has ``requires_grad`` turned off, and the
    weights of ``run.model.latents`` are re-initialized. For each batch from
    the dataloader the ``"rgb"`` loss is minimized for ``n_opt`` iterations
    with Adam (learning rate ``run.args.lr``); all losses are logged to wandb
    and a full-image summary is logged every 10 iterations.

    Args:
        run: Object providing ``model``, ``args.lr`` and ``get_dataloader``.
        dataset: Dataset object exposing ``val()``.
        until_img: Unused here; kept so existing callers keep working.
        until_save: Unused here; kept so existing callers keep working.
        optim: Unused here (a fresh Adam optimizer is created per batch);
            kept so existing callers keep working.
        single_data: Unused here; kept so existing callers keep working.
        n_opt: Number of optimization iterations per validation batch.

    Returns:
        None. The accumulated non-"rgb" loss is printed at the end.
    """
    prefix = "val"
    dataset.val()
    dataloader = run.get_dataloader(dataset)

    # Disable grad for everything except latents, then reset the latents.
    for name, param in run.model.named_parameters():
        param.requires_grad = 'latent' in name
    for latent in run.model.latents:
        torch.nn.init.normal_(latent.weight, mean=0, std=0.01)

    loss_all = 0
    n_batches = 0  # counted explicitly so an empty dataloader does not crash below

    for val_i, model_input in enumerate(dataloader):
        n_batches += 1
        model_input_ = args_setup.to_gpu(model_input)

        # One optimizer per batch: re-creating Adam on every step would throw
        # away its running moment estimates each iteration. Frozen parameters
        # never receive a gradient, so Adam leaves them untouched.
        latent_optim = torch.optim.Adam(lr=run.args.lr, params=run.model.parameters())

        for step in range(n_opt):
            # wandb drops values logged with a step lower than the current
            # one, so the per-batch step counter is offset by the batch index.
            log_step = val_i * n_opt + step

            # Work on a fresh copy each iteration (presumably the forward
            # pass / loss mutate the input dict -- TODO confirm).
            model_input = deepcopy(model_input_)
            total_loss = 0.
            out = run.model(model_input)
            losses = args_setup.loss_fn(out, model_input, model_input, run.model, step)
            for loss_name, loss in losses.items():
                wandb.log({"%s_%s" % (loss_name, prefix): loss.item()}, step=log_step)
                if loss_name == "rgb":
                    total_loss += loss
                else:
                    # Detach so the running sum does not keep every
                    # iteration's autograd graph alive.
                    loss_all += loss.detach()
            wandb.log({"loss_%s" % prefix: total_loss.item()}, step=log_step)
            print(val_i, total_loss)

            # Clear stale gradients before every backward pass so nothing
            # from a previous iteration leaks into this update.
            latent_optim.zero_grad()
            total_loss.backward()
            # The very first iteration is not applied, so the step-0 summary
            # below shows the freshly initialized latents.
            if step:
                latent_optim.step()

            # Image summaries
            if step % 10 == 0:
                with torch.no_grad():
                    out = run.model.render_full_img(model_input)
                    # NOTE(review): `step` is passed through unchanged; if
                    # wandb_summary forwards it to wandb.log, summaries for
                    # val_i > 0 may be dropped -- confirm and consider
                    # passing `log_step` instead.
                    vis_scripts.wandb_summary(0, out, model_input, model_input, None,
                                              step=step, prefix=str(val_i) + "_")

    # NOTE(review): loss_all sums the non-"rgb" losses over every iteration of
    # every batch but is averaged over batches only -- confirm this is intended.
    print("depth error: ", loss_all / max(n_batches, 1))

if __name__ == '__main__':
    # Script entry point: build the run from args_setup and launch the
    # latent optimization on its dataset.
    experiment = args_setup.make_run()
    run_model(
        experiment,
        experiment.dataset,
        until_save=experiment.args.until_save,
        until_img=experiment.args.until_img,
    )
