import os,random,time,glob,sys

import numpy as np
from functools import partial
from tqdm import tqdm,trange
from einops import rearrange,repeat
from torch.nn import functional as F

import wandb
import torch

import models
import vis_scripts

from data.sun_rgbd import SUNRGBD
import getpass

def to_gpu(ob):
    """Recursively move ``ob`` to the GPU.

    A dict is rebuilt with every value converted (nested dicts recurse);
    any other object must provide a ``.cuda()`` method, whose result is
    returned.
    """
    if not isinstance(ob, dict):
        return ob.cuda()
    return {key: to_gpu(val) for key, val in ob.items()}

import argparse
parser = argparse.ArgumentParser(description='simple training job')
# logging parameters
parser.add_argument('-n','--name', type=str,default="",required=False,help="wandb training name")
parser.add_argument('-c','--init_ckpt', type=str,default=None,required=False,help="Checkpoint file to load. If a folder is given, the most recent .pt file in it is used")
parser.add_argument('-o','--online', default=False, action='store_true',help="log to wandb online (wandb is disabled otherwise)")
# data/training parameters
parser.add_argument('-d','--dataset', type=str,default="hydrant")
parser.add_argument('-b','--batch_size', type=int,default=1,help="number of videos/sequences per training step")
parser.add_argument('-v','--vid_len', type=int,default=6,help="video length or number of images per batch")
parser.add_argument('--n_workers',type=int,default=10,help="number of workers per dataloader")
parser.add_argument('--until_save',type=int,default=1000,help="number of steps until model save")
parser.add_argument('--lr',type=float,default=1e-4,help="learning rate")
parser.add_argument('--n_train_steps',type=int,default=int(1e8),help="total number of training steps")
parser.add_argument('--overfit', default=False, action='store_true',help="Whether to overfit on a single scene")
parser.add_argument('--no_shuffle', default=False, action='store_true',help="Disable dataset shuffling")
parser.add_argument('--until_img', type=int,default=50,help="Number of steps until image summary. ")
parser.add_argument('--overfit_size', type=int,default=10000000,help="Number of scenes to overfit on. ")
parser.add_argument('--load_save', default=False, action='store_true',help="Whether to load the previously saved data if overfitting (to avoid running flow again)")
# model parameters
parser.add_argument('--scratch_net', default=False, action='store_true',help="Whether to turn off the midas weight prior and use a network from scratch")
parser.add_argument('--fdim',type=int,default=512,help="latent code dimension")
parser.add_argument('--spatial_dims',type=int,default=100,help="max number of spatial grids to use (usually ~4 and set to 0 for just cnn)")
#parser.add_argument('--n_samples', type=int,default=64,help="Number of samples along ray")
# eval/vis 
parser.add_argument('--eval', default=False, action='store_true',help="whether to train or run evaluation")
parser.add_argument('--n_eval', type=int,default=int(1e8),help="Number of eval samples to run")
# save_imgs is on by default; `--save_imgs` is kept for backward compatibility,
# and `--no_save_imgs` (same dest) is the only way to turn it off.
parser.add_argument('--save_imgs', default=True, action='store_true',help="whether to save out the all-trajectory images (on by default)")
parser.add_argument('--no_save_imgs', dest='save_imgs', action='store_false',help="disable saving the all-trajectory images")

def ch_sec(x):
    """Channels-last, flattened grid: rearrange ``... c x y -> ... (x y) c``."""
    return rearrange(x, "... c x y -> ... (x y) c")


def ch_fst(src, x=None):
    """Inverse of ``ch_sec``: rearrange ``... (x y) c -> ... c x y``.

    ``x`` is the size of the first spatial axis. When omitted, the grid is
    treated as square and ``x`` is ``int(sqrt(n))``, where ``n`` is the size
    of the flattened token axis ``src.size(-2)``.
    """
    if x is None:
        x = int(src.size(-2) ** 0.5)
    return rearrange(src, "... (x y) c -> ... c x y", x=x)

def loss_fn(model_out, gt, model_input, model, step):
    """Compute the weighted training losses for one step.

    Ground truth is gathered from ``model_input`` at the pixel indices
    ``model_out["render_pix"]`` and compared with the rendered outputs.

    Args:
        model_out: dict with ``"render_pix"`` (indices into the flattened
            pixel axis), ``"rgb"`` and ``"depth"`` predictions.
        gt: unused.
        model_input: dict with ``"rgb"`` (channel axis before two spatial
            axes, as required by ``ch_sec``), ``"depth"`` and
            ``"depth_mask"`` (dims 1 and 2 are flattened into the pixel
            axis; presumably (B, H, W, ...) — TODO confirm).
        model: object whose ``latents`` iterable holds modules exposing a
            ``weight`` tensor.
        step: unused.

    Returns:
        dict with ``"rgb"`` (mean squared error x 5e1), ``"depth"`` (mean
        absolute difference of inverse depths, multiplied by the mask and
        averaged over all gathered pixels, masked ones included) and
        ``"gauss_latent"`` (sum of latent weight norms x 3e-4).
    """
    pix = model_out["render_pix"]

    def inv(depth):
        # Inverse depth with a small offset to avoid division by zero.
        return 1 / ((1e-3 + depth) + 1e-8)

    # Gather target colours at the rendered pixels; the inserted None axis
    # lines them up against the prediction's second dimension.
    rgb_target = ch_sec(model_input["rgb"])[:, None, pix, :]
    rgb_loss = (rgb_target - model_out["rgb"]).square().mean() * 5e1

    depth_target = model_input["depth"].flatten(1, 2)[:, None, pix]
    depth_valid = model_input["depth_mask"].flatten(1, 2)[:, None, pix]
    depth_err = (inv(depth_target) - inv(model_out["depth"])) * depth_valid
    depth_loss = depth_err.abs().mean()

    latent_loss = sum(latent.weight.norm() for latent in model.latents) * 3e-4

    return {"rgb": rgb_loss, "depth": depth_loss, "gauss_latent": latent_loss}

def _resolve_ckpt_file(ckpt_path):
    """Return the checkpoint file to load for ``ckpt_path``.

    ``ckpt_path`` may name a checkpoint file directly or a folder; for a
    folder, the ``*.pt`` file with the newest ``os.path.getctime`` is chosen.
    ``~`` is expanded once, and the expanded path is used both for the checks
    and for the returned value, so ``torch.load`` receives an openable path.

    Raises:
        FileNotFoundError: if ``ckpt_path`` is not a file and contains no
            ``*.pt`` files.
    """
    ckpt_path = os.path.expanduser(ckpt_path)
    if os.path.isfile(ckpt_path):
        return ckpt_path
    candidates = glob.glob(os.path.join(ckpt_path, "*.pt"))
    if not candidates:
        raise FileNotFoundError(f"no checkpoint file or *.pt files found at {ckpt_path!r}")
    return max(candidates, key=os.path.getctime)


def make_run(args=None, val=False):
    """Parse CLI arguments and set up logging, data and model for a run.

    Args:
        args: Optional list of argument strings passed to
            ``parser.parse_args``; ``None`` parses ``sys.argv``.
        val: Unused; kept so existing call sites keep working.

    Returns:
        An ``argparse.Namespace`` holding:
        ``save_dir`` (created on disk), ``dataset``,
        ``get_dataloader`` (returns an iterator over a new DataLoader for a
        given dataset), ``model`` (on the GPU, optionally restored from
        ``--init_ckpt``), ``args`` (parsed arguments) and ``wandb`` (the run).

    Raises:
        FileNotFoundError: if ``--init_ckpt`` names neither a file nor a
            folder containing ``*.pt`` files.
    """
    args = parser.parse_args(args)
    self = argparse.Namespace()  # plain attribute bag returned to the caller
    user = getpass.getuser()
    print(f"user={user}")
    if args.overfit:
        args.n_workers = 0

    # Wandb init: logging is disabled unless --online is passed.
    run = wandb.init(
        entity="cosmith",
        project="inference_as_optim",
        mode="online" if args.online else "disabled",
        name=args.name,
        dir="/tmp/wandb",
    )
    run.log_code(".")
    self.save_dir = os.path.join("/home/cameronsmith/logs", run.name)
    print(self.save_dir)
    os.makedirs(self.save_dir, exist_ok=True)

    # Make dataset; the per-loader batch is scaled by the visible GPU count.
    self.dataset = SUNRGBD(overfit_size=args.overfit_size)
    self.get_dataloader = lambda dataset: iter(torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size * torch.cuda.device_count(),
        num_workers=min(args.n_workers, args.batch_size),
        shuffle=not args.no_shuffle,
        pin_memory=False,
    ))

    # Make model and load checkpoint
    self.model = models.SceneLearner(args).cuda()
    if args.init_ckpt is not None:
        ckpt_file = _resolve_ckpt_file(args.init_ckpt)
        print(f"loading checkpoint {ckpt_file}")
        state_dict = torch.load(ckpt_file)["model_state_dict"]
        # strict=False: keys missing from / unexpected in the checkpoint are ignored.
        self.model.load_state_dict(state_dict, strict=False)
    self.args = args
    self.wandb = run
    return self

