import os
import torch,wandb
from tqdm import trange
from einops import rearrange
import args_setup,vis_scripts,geometry
from copy import deepcopy
import piqa
from torchvision.utils import make_grid
from models import ch_sec
from torch.cuda.amp import autocast, GradScaler

def train(run,train_dataset,until_img=25,until_vid=100,until_save=500,optim=None,single_data=None):
    losses_agg=[]

    train_dataloaders=[run.get_dataloader(dset) for dset in train_dataset] # get data

    if run.args.overfit: # overfitting on single scene - load from disk if requested, save to disk with flow otherwise
        single_data =model_input,ground_truth = [args_setup.to_gpu(x) for x in next(train_dataloaders[0])]
        if run.args.load_save and run.args.overfit and 1:
            model_input=torch.load("output/tmp.pt")
            single_data=model_input,ground_truth
            print("Loading saved data")
        with torch.no_grad():
            run.model.get_flow(model_input)
            #run.model.get_seg(model_input)
            #if "zoe_depth" not in model_input: model_input["zoe_depth"] = ch_sec(run.model.model_zoe_n.infer(model_input["rgb"].flatten(0,1)*.5+.5).unflatten(0,model_input["rgb"].shape[:2]))

        torch.save(model_input,"output/tmp.pt")
        if run.args.depth_var: # depth as variable opimization
            run.model.depth=torch.nn.Parameter(torch.ones_like(model_input["rgb"][:,:,:1]),requires_grad=True)
            run.model.corr_weights=torch.nn.Parameter(torch.ones_like(model_input["rgb"][:,1:,:1])*1e-2,requires_grad=True)
            run.model.depth_scale=1
            run.args.lr=1e-3

    optim = torch.optim.Adam(lr=run.args.lr, params=run.model.parameters())

    # Train loop
    for step in trange(run.args.n_train_steps, desc="Fitting"): # train until user interruption

        dset_i=step%len(train_dataloaders)
        train_dataloader=train_dataloaders[dset_i]

        # Get data
        if single_data is None:
            try: model_input, ground_truth = next(train_dataloader)
            except StopIteration:
                train_dataloaders[dset_i]=run.get_dataloader(train_dataset[dset_i])
                continue
            model_input, ground_truth = args_setup.to_gpu(model_input), args_setup.to_gpu(ground_truth)
        else: model_input, ground_truth = deepcopy(single_data)

        # Run model and calculate losses
        total_loss = 0.
        out=run.model(model_input)
        losses = args_setup.loss_fn(out, ground_truth, model_input,run.model,step)
        for loss_name, loss in losses.items():
            wandb.log({loss_name: loss.item()}, step=step)
            total_loss += loss/len(train_dataloaders)

        wandb.log({"trainer/global_step": step}, step=step)
        wandb.log({"loss": total_loss.item()}, step=step)
        if single_data is None: wandb.log({"epoch": (step*run.args.batch_size)/len(train_dataset)}, step=step)

        total_loss.backward(); 
        if step%len(train_dataloaders)==0 and step: optim.step();optim.zero_grad(); 

        # Image summaries and checkpoint
        losses_agg.append({k:v.detach().item() for k,v in losses.items()})
        with torch.no_grad(): 
            wandb_imgs=None
            if step%until_img in list(range(len(train_dataset))) and 1: wandb_imgs=vis_scripts.wandb_summary( 0, out, model_input, ground_truth, None,step=step)
            if run.args.viser:vis_scripts.viser_update(run.viser_server, losses_agg, out, model_input, ground_truth, None,step=step,wandb_imgs=wandb_imgs)

        if (step==0 or (not run.args.overfit and step%until_img==0)) and "pred_visibility" in model_input and run.args.point_track and 0: 
            for i in range(run.model.n_track_frames): wandb.log({'_vid/track_%01d'%i:wandb.Video("./output/tracks_%02d_pred_track.mp4"%i, format='mp4', fps=8)})

        if step%until_save == 0 and step: # save model
            print(f"Saving to {run.save_dir}"); torch.save({ 'step': step, 'model_state_dict': run.model.state_dict(), }, os.path.join(run.save_dir, f"checkpoint_{step}.pt")) 

        if run.args.export_poses and step%100==0 and run.args.overfit: # export poses+geom if overfitting on single scene
            #ros, rds = geometry.get_world_rays(model_input["x_pix"],model_input["intrinsics"], out["poses"])
            #out["world_crds"]=ros+rds*out["depth"].cuda()
            out["rgb"] = model_input["rgb"].flatten(0,1).flatten(-2,-1).permute(0,2,1)[None]
            out["intrinsics"] = model_input["intrinsics"]
            if "c2w" in model_input: out["gt_poses"],out["gt_intrinsics"]=model_input["c2w"],model_input["gt_intrinsics"]
            #out["rgb"]=model_input["rgb_large"];out["rgb_crds"]=ch_sec(model_input["rgb"]*.5+.5)
            torch.save({k:v[0].detach() for k,v in out.items() if type(v)==torch.Tensor and len(v.shape)>1} ,f"output/pose_exps/poses_{run.args.name}.pt")
            print("exported poses")

if __name__ == '__main__':
    run = args_setup.make_run()
    # `torch.autograd.detect_anomaly()` is a context manager: calling it bare
    # (outside `with`) only warns and never turns detection on. Set the mode
    # explicitly instead; use True only when debugging NaN/inf gradients,
    # since anomaly detection slows training considerably.
    torch.autograd.set_detect_anomaly(False)
    train(run, run.dataset, until_save=run.args.until_save, until_vid=100 if not run.args.overfit else 300, until_img=run.args.until_img)
