import os,random,time,glob,sys

import numpy as np
from functools import partial
from tqdm import tqdm,trange
from einops import rearrange,repeat
from torch.nn import functional as F

import wandb
import torch

import models
import vis_scripts

from data.KITTI import KittiDataset
from data.co3d import Co3DNoCams
from data.walking_tours import WalkingTours
from data.pokemon_rooms import PokemonRooms
from data.realestate10k_dataio import RealEstate10k
from data.re10k_hires import DatasetRealEstate10k

from data.tanks2 import Tanks
from data.image_folder import ImageFolder
from data.davis import DAVIS
from data.LLFF import LLFF
from data.mip import Mip
import getpass

def to_gpu(ob):
    """Recursively move ``ob`` onto the GPU via ``.cuda()``.

    A dict (possibly nested) is rebuilt with the same keys and every leaf
    value moved; any non-dict object must itself provide ``.cuda()``.
    """
    if not isinstance(ob, dict):
        return ob.cuda()
    return {key: to_gpu(val) for key, val in ob.items()}

import argparse
parser = argparse.ArgumentParser(description='simple training job')
# logging parameters
parser.add_argument('-n','--name', type=str,default="",required=False,help="wandb training name")
parser.add_argument('-c','--init_ckpt', type=str,default=None,required=False,help="File for checkpoint loading. If folder specific, will use latest .pt file")
parser.add_argument('-o','--online', default=False, action='store_true')
# data/training parameters
parser.add_argument('-d','--dataset', type=str,default="hydrant")
parser.add_argument('--imgpath', type=str,default="")
parser.add_argument('-b','--batch_size', type=int,default=1,help="number of videos/sequences per training step")
parser.add_argument('-v','--vid_len', type=int,default=6,help="video length or number of images per batch")
parser.add_argument('--n_workers',type=int,default=100,help="number of workers per dataloader")
parser.add_argument('--until_save',type=int,default=500,help="number of steps until model save")
parser.add_argument('--lr',type=float,default=1e-4,help="learning rate")
parser.add_argument('--n_train_steps',type=int,default=int(1e8),help="number of training steps")
parser.add_argument('--overfit', default=False, action='store_true',help="Whether to overfit on a single scene")
parser.add_argument('--until_img', type=int,default=50,help="Number of steps until image summary. ")
parser.add_argument('--load_save', default=False, action='store_true',help="Whether to load the previously saved data if overfitting (to avoid running flow again)")
parser.add_argument('--seq_query', type=str,default=None,help="co3d sequence query for overfitting")
parser.add_argument('--category', type=str,default=None,help="co3d category overfitting")
# model parameters
parser.add_argument('--n_render_rays', type=int,default=1024,help="Num rays to volume render")
parser.add_argument('--n_skip', type=int,default=0,help="Number of frames to skip between adjacent frames in dataloader. ")
parser.add_argument('--depth_var', default=False, action='store_true',help="Whether to use depth-as-variable optimization instead of network finetuning")
parser.add_argument('--gm_flow', default=False, action='store_true',help="Whether to use gm_flow instead of raft")
parser.add_argument('--midas_invert', default=False, action='store_true',help="Whether to interpret model output as disparity (for directly using midas)")
parser.add_argument('--use_gt_intrinsics', default=False, action='store_true',help="Whether to use GT intrinsics instead of predicting them. Useful for pretraining scene rep.")
parser.add_argument('--point_track', default=False, action='store_true',help="Whether to use point tracking")
parser.add_argument('--pixelSplat', default=False, action='store_true',help="Whether to use pixelSplat rendering for photometric loss")
parser.add_argument('--scratch_net', default=False, action='store_true',help="Whether to turn off the midas weight prior and use a network from scratch")
parser.add_argument('--n_samples', type=int,default=64,help="Number of samples along ray")
# eval/vis
parser.add_argument('--export_poses', default=False, action='store_true',help="Export poses when overfitting")
parser.add_argument('--eval', default=False, action='store_true',help="whether to train or run evaluation")
parser.add_argument('--n_eval', type=int,default=int(1e8),help="Number of eval samples to run")
parser.add_argument('--save_ind', default=False, action='store_true',help="whether to save out each individual image (in rendering images) or just save the all-trajectory image")
# --save_imgs / --plot_poses default to True, so their store_true flags are no-ops;
# the --no_* switches below (same dest) are the way to turn them off.
parser.add_argument('--save_imgs', default=True, action='store_true',help="whether to save out the all-trajectory images")
parser.add_argument('--no_save_imgs', dest='save_imgs', action='store_false',help="disable saving the all-trajectory images (on by default)")
parser.add_argument('--plot_poses', default=True, action='store_true',help="whether to save out the all-trajectory poses")
parser.add_argument('--no_plot_poses', dest='plot_poses', action='store_false',help="disable saving the all-trajectory poses (on by default)")

def ch_sec(x):
    """Channels-last token layout: rearrange ``(..., c, x, y)`` into ``(..., x*y, c)``."""
    return rearrange(x, "... c x y -> ... (x y) c")


def ch_fst(src, x=None):
    """Inverse of ``ch_sec``: rearrange ``(..., x*y, c)`` back into ``(..., c, x, y)``.

    ``x`` is the size of the first spatial axis. When omitted it is taken as
    ``int(sqrt(N))`` with ``N = src.size(-2)``, i.e. the token grid is assumed
    square; pass ``x`` explicitly for non-square grids.
    """
    if x is None:
        x = int(src.size(-2) ** .5)
    return rearrange(src, "... (x y) c -> ... c x y", x=x)

def loss_fn(model_out, gt, model_input, model, step):
    """Collect the weighted loss terms present in ``model_out`` and log intrinsics.

    Each optional key of ``model_out`` contributes one entry to the result:
      * ``"res_depth"``        -> ``metrics/res_depth_reg``: mean of squares, weight 1e-1.
      * ``"point_track_loss"`` -> ``metrics/tracks_err``: weight 1e1.
      * ``"flow_from_pose"``   -> ``metrics/flow_from_pose``: Huber loss (delta 1.5e-5)
        between it and ``ch_sec(model_out["flow_inp_"])``, both clipped to
        [-0.2, 0.2], weight 1e6. Requires ``"flow_inp_"`` to be present too.

    Args:
        model_out: dict of model outputs (see above).
        gt: unused here; kept for the caller's signature.
        model_input: dict that must contain ``"intrinsics"`` and may contain
            ``"gt_intrinsics"``; entries [0,0,0,0] and [0,0,1,1] of each are
            logged to wandb as fx / fy.
        model: unused except by the commented-out point-tracking loss.
        step: step index passed to ``wandb.log``.

    Returns:
        dict mapping metric name -> weighted loss value (empty if no key matched).
    """
    losses = {}

    # Point tracking loss
    #if "all_pair_pix" in model_out and (step>20 or not model.args.overfit): 
    #    all_crds=repeat(model_input["pred_tracks"],"b t p c -> b p s t c",s=model_input["rgb"].size(1))
    #    vis_mask=torch.minimum(repeat(model_input["pred_visibility"],"b t k -> b k t s 1",s=model_input["rgb"].size(1)), 
    #                           repeat(model_input["pred_visibility"],"b t k -> b k s t 1",s=model_input["rgb"].size(1)))
    #    losses["metrics/tracks_err"]=F.huber_loss(all_crds*vis_mask,model_out["all_pair_pix"]*vis_mask,delta=2e-4)*2e6

    # Pose flow loss
    #if "zoe_d_loss" in model_out: losses["metrics/zoe_depth_loss"] = model_out["zoe_d_loss"] * (1 if step<500 else .5)
    if "res_depth" in model_out:
        losses["metrics/res_depth_reg"] = model_out["res_depth"].square().mean() * 1e-1
    if "point_track_loss" in model_out:
        losses["metrics/tracks_err"] = model_out["point_track_loss"] * 1e1
    if "flow_from_pose" in model_out:
        # ch_sec flattens the input flow to channels-last tokens so it matches
        # flow_from_pose; clipping bounds the influence of large flow values.
        losses["metrics/flow_from_pose"] = F.huber_loss(
            model_out["flow_from_pose"].clip(-.2, .2),
            ch_sec(model_out["flow_inp_"]).clip(-.2, .2),
            delta=1.5e-05,
        ) * 1e6
    # Rn just penalizing everything after first component, todo implement N-body or increasing component loss
    #if "eigens" in model_out: losses["pca_comps"] = model_out["eigens"][...,1:].square().mean()*1e2
    #if "affinity_sim" in model_out: losses["metrics/aff_reg"] = (1-model_out["affinity_sim"]).square().mean()*3e-2
    #if "affinity_emb" in model_out: losses["metrics/affinity_emb_nuc"] = torch.svd(ch_sec(model_out["affinity_emb"]).flatten(1,2))[1].sum()*1e-6
    #if "lie_perpix" in model_out: 
    #    losses["metrics/lieperpix_nucnorm"] = torch.svd(ch_sec(model_out["lie_perpix"]).flatten(1,2))[1].sum()*2e-6
    #    losses["metrics/lie_nucnorm"] = torch.svd(model_out["poses_lie"].flatten(1,2))[1].sum()*1e-6

    # Log estimated (and, when available, reference) intrinsics in a single call.
    intrinsics = model_input["intrinsics"]
    log_entry = {"est/fx": intrinsics[0, 0, 0, 0], "est/fy": intrinsics[0, 0, 1, 1]}
    if "gt_intrinsics" in model_input:
        gt_intrinsics = model_input["gt_intrinsics"]
        log_entry["ref/fx"] = gt_intrinsics[0, 0, 0, 0]
        log_entry["ref/fy"] = gt_intrinsics[0, 0, 1, 1]
    wandb.log(log_entry, step=step)
    print(losses)
    return losses

def _make_datasets(args, val):
    """Build the list of datasets selected by ``args.dataset``."""
    name = args.dataset
    n_trgt = args.vid_len + 1
    # Names routed to the scene-level loaders below must never fall into the
    # mixture, even if they contain "all" (e.g. a Tanks scene like "Ballroom").
    is_scene_dataset = any(tag in name for tag in ("tanks", "llff", "mip"))
    # "allcat" is a single Co3D configuration, not the multi-dataset mixture.
    if "all" in name and name != "allcat" and not is_scene_dataset:
        return [
            RealEstate10k(imsl=128, num_query_views=n_trgt, val=val, n_skip=9),
            KittiDataset(num_context=1, num_trgt=n_trgt, val=val, n_skip=0),
            Co3DNoCams(num_trgt=n_trgt, num_cat=10 if "small" not in name else 1, n_skip=2, val=val),
        ]

    if name in ["hydrant", "10cat", "allcat"]:
        num_cat = 1 if name == "hydrant" else 10 if name == "10cat" else 30
        dataset = Co3DNoCams(num_trgt=n_trgt, num_cat=num_cat, n_skip=args.n_skip, val=val,
                             seq_query=args.seq_query, category=args.category)
    elif name == "realestate":
        dataset = RealEstate10k(imsl=128, num_ctxt_views=2, num_query_views=n_trgt, val=val, n_skip=args.n_skip)
    elif name == "dog":
        dataset = ImageFolder(path=args.imgpath, num_trgt=n_trgt, n_skip=args.n_skip)
    elif name == "davis":
        dataset = DAVIS(path=args.imgpath, num_trgt=n_trgt, n_skip=args.n_skip)
    elif "tanks" in name:
        # (128,184) was the disabled higher-resolution alternative.
        dataset = Tanks(low_res=(92, 128), num_trgt=n_trgt, n_skip=args.n_skip, scene=name.split("_")[-1])
    elif "llff" in name:
        dataset = LLFF(low_res=(92, 128), num_trgt=n_trgt, n_skip=args.n_skip, scene=name.split("_")[1])
    elif "mip" in name:
        # (128,184) was the disabled higher-resolution alternative.
        dataset = Mip(low_res=(92, 128), num_trgt=n_trgt, n_skip=args.n_skip, scene=name.split("_")[-1])
    else:
        # Fallback for any other name, including "kitti".
        # NOTE(review): "re10khires" also lands here (DatasetRealEstate10k is
        # imported but never constructed) — confirm that is intended.
        dataset = KittiDataset(num_context=1, num_trgt=n_trgt, low_res=(76, 250), val=val, n_skip=args.n_skip)
    return [dataset]


def _resolve_checkpoint(path):
    """Return ``path`` if it is a file, else the newest (by ctime) ``*.pt`` inside it."""
    # Expand "~" once and use the expanded path everywhere; torch.load and glob
    # do not expand it themselves.
    path = os.path.expanduser(path)
    if os.path.isfile(path):
        return path
    candidates = glob.glob(os.path.join(path, "*.pt"))
    if not candidates:
        raise FileNotFoundError(f"No checkpoint file or *.pt files found at {path!r}")
    return max(candidates, key=os.path.getctime)


def make_run(args=None,val=False):
    """Parse CLI args and set up wandb, datasets, a dataloader factory and the model.

    Args:
        args: argument tokens forwarded to ``parser.parse_args`` (``None`` reads
            the process command line).
        val: forwarded as ``val=`` to the dataset constructors that accept it
            (Co3D, RealEstate10k, KITTI).

    Returns:
        ``argparse.Namespace`` with ``save_dir``, ``dataset`` (list of datasets),
        ``get_dataloader`` (dataset -> iterator over a DataLoader), ``model``,
        ``args`` (the parsed, possibly adjusted arguments) and ``wandb`` (the run).

    Raises:
        FileNotFoundError: if ``--init_ckpt`` is neither a file nor a folder
            containing ``*.pt`` files.
    """
    args = parser.parse_args(args)
    state = argparse.Namespace()
    user = getpass.getuser()
    print(f"user={user}")
    # Dataset-specific frame-skip defaults, applied only when --n_skip was left at 0.
    if args.n_skip == 0 and args.dataset in ["realestate", "re10khires"]:
        args.n_skip = 9
    if args.n_skip == 0 and args.dataset in ["10cat", "hydrant"] and not args.overfit:
        args.n_skip = 2
    if args.overfit:
        args.n_workers = 0
    if args.init_ckpt == "best":
        # Hard-coded, user-specific checkpoint folder.
        args.init_ckpt = "/home/camsmith/logs/full_render_test_alldata_hires/"

    # Wandb init
    #run = wandb.init(entity="scene-representation-group",project="dyn-pixelNeRF",mode="online" if args.online else "disabled",name=args.name,dir=f"/nobackup/nvme1/{user}/wandb")
    run = wandb.init(entity="cameronsmithbusiness", project="biasing",
                     mode="online" if args.online else "disabled", name=args.name, dir="/tmp/wandb")
    wandb.run.log_code(".")
    state.save_dir = "/tmp"  # os.path.join(os.environ.get('LOGDIR', ""), run.name)
    print(state.save_dir)
    os.makedirs(state.save_dir, exist_ok=True)
    wandb.save(os.path.join(state.save_dir, "checkpoint*"))
    wandb.save(os.path.join(state.save_dir, "video*"))

    state.dataset = _make_datasets(args, val)

    def get_dataloader(dataset):
        """Return an iterator over a shuffled, pinned-memory DataLoader for ``dataset``."""
        # Batch size is scaled by the number of visible CUDA devices; worker
        # count is capped at args.batch_size (and is 0 when overfitting).
        return iter(torch.utils.data.DataLoader(
            dataset,
            batch_size=args.batch_size * torch.cuda.device_count(),
            num_workers=min(args.n_workers, args.batch_size),
            shuffle=True,
            pin_memory=True,
        ))

    state.get_dataloader = get_dataloader

    # Make model and load checkpoint; strict=False tolerates missing/unexpected keys.
    state.model = models.FlowMap(args).cuda()
    if args.init_ckpt is not None:
        ckpt_file = _resolve_checkpoint(args.init_ckpt)
        state.model.load_state_dict(torch.load(ckpt_file)["model_state_dict"], strict=False)
    state.args = args
    state.wandb = run
    return state

