import os,time
import torch,wandb
from tqdm import trange
from einops import rearrange
import vis,geometry
from copy import deepcopy
import numpy as np
import piqa,kornia
from torchvision.utils import make_grid
from einops import rearrange, repeat
from models import ch_sec
from torch.cuda.amp import autocast, GradScaler
import torch.nn.functional as F
import getpass

import data,models
import viser,nerfview
from gsplat import rasterization
# NOTE(review): presumably a warm-up that forces PyTorch's CUDA linear-algebra backend to
# initialize eagerly on the main thread at import time. viewer_render_fn later calls
# .inverse() on CUDA tensors from the viewer's render callback, which may run on another
# thread — confirm this is the issue being worked around. The empty (0, 0) input makes the
# call essentially free. Note this requires a CUDA device to be present at import.
torch.inverse(torch.ones((0, 0), device="cuda:0"))

def to_gpu(ob, device="cuda"):
    """Recursively move a (possibly nested) batch to ``device``.

    Dict values that are not already tensors (lists, numpy arrays, scalars) are
    converted with ``torch.as_tensor``; nested dicts are handled recursively.
    Any non-dict ``ob`` is moved with ``.to(device)``.

    Args:
        ob: a tensor, or a dict whose values are tensors, tensor-convertible
            data, or further such dicts.
        device: target device; defaults to the current CUDA device.

    Returns:
        The same structure with every tensor on ``device``.
    """
    if isinstance(ob, dict):
        # Recurse into sub-dicts instead of feeding them to the tensor constructor, and
        # avoid re-copying values that are already tensors.
        return {
            k: to_gpu(v if isinstance(v, (dict, torch.Tensor)) else torch.as_tensor(v), device)
            for k, v in ob.items()
        }
    return ob.to(device)

def train_flowmap(run,train_dataset,until_img=25,until_vid=100,until_save=500,optim=None,single_data=None):
    """Overfit ``run.model`` on a single batch for ``run.args.n_train_steps`` steps.

    Each step runs the model on the same batch, sums the terms returned by ``loss_fn``,
    backpropagates and steps the optimizer. Every loss term and the total are logged to wandb.
    Image summaries are logged every ``until_img`` steps. When ``run.args.viser`` is set, the
    viser GUI is updated every step. A checkpoint is written every ``until_save`` steps
    (never at step 0). When ``run.args.export_poses`` and ``run.args.overfit`` are set,
    multi-dimensional model outputs are exported every 100 steps.

    Args:
        run: namespace built by ``make_run`` (uses ``args``, ``model``, ``dataset``, ``save_dir``
            and, with ``--viser``, ``viser_server``).
        train_dataset: dataset whose ``[0][0]`` item is collated into the batch to fit.
        until_img: number of steps between image summaries.
        until_vid: currently unused; kept for call compatibility.
        until_save: number of steps between checkpoints.
        optim: optional optimizer. When None, an Adam over ``run.model.parameters()`` with
            ``run.args.lr`` is created.
        single_data: optional pre-collated batch to fit. When None, ``train_dataset[0][0]`` is
            collated and moved to the GPU. A provided batch is used as-is (assumed to be
            on the model's device already).
    """

    def loss_fn(model_out, gt, model_input,model,step):
        """Return the dict of weighted loss terms for one forward pass.

        Only terms whose inputs are present in ``model_out`` are added. ``gt`` and ``model``
        are accepted for interface symmetry but are unused. Also logs the current
        focal-length entries of ``model_input["intrinsics"]`` to wandb.
        """
        losses = {}

        # Small L2 regularizer on the residual depth, when the model predicts one.
        if "res_depth" in model_out:
            losses["metrics/res_depth_reg"] = model_out["res_depth"].square().mean()*1e-1

        if "track_reprojs" in model_out:
            visibility = model_input["pred_visibility"]
            n_frames = visibility.size(1)
            # Pairwise mask: entry (b, k, i, j) is the min of track k's visibility in frames i and j.
            vis_mask = torch.minimum(repeat(visibility, "b t k -> b k t s 1", s=n_frames),
                                     repeat(visibility, "b t k -> b k s t 1", s=n_frames))
            track_idxs = model_out["track_idxs"]
            pair_mask = vis_mask[:, track_idxs]
            tracks = repeat(model_input["pred_tracks"][:, :, track_idxs], "b t p c -> b p s t c", s=n_frames)
            losses["metrics/tracks_err"] = (tracks*pair_mask - model_out["track_reprojs"]*pair_mask).square().mean()*5e3

        if "flow_from_pose" in model_out:
            # Both flows are clipped to +-0.2 so outliers cannot dominate the Huber loss.
            losses["metrics/flow_from_pose"] = F.huber_loss(model_out["flow_from_pose"].clip(-.2,.2),
                                                            ch_sec(model_out["flow_inp_"]).clip(-.2,.2),
                                                            delta=1.5e-05)*3e7

        wandb.log({"est/fx": model_input["intrinsics"][0,0,0,0]},step=step)
        wandb.log({"est/fy": model_input["intrinsics"][0,0,1,1]},step=step)

        return losses

    losses_agg=[]
    # Honor a caller-supplied optimizer / batch; previously both arguments were overwritten.
    if optim is None:
        optim = torch.optim.Adam(lr=run.args.lr, params=run.model.parameters())
    if single_data is None:
        single_data = to_gpu(run.dataset.collate_fn([train_dataset[0][0]]))

    # Overfitting setup: the same batch serves as input and ground truth every step.
    model_input = ground_truth = single_data

    # Train loop
    for step in trange(run.args.n_train_steps, desc="Fitting"): # train until user interruption

        # Run model and calculate losses
        total_loss = 0.
        out=run.model(model_input)
        losses = loss_fn(out, ground_truth, model_input,run.model,step)
        for loss_name, loss in losses.items():
            wandb.log({loss_name: loss.item()}, step=step)
            total_loss += loss
        wandb.log({"loss": total_loss.item()}, step=step)

        total_loss.backward()
        optim.step()
        optim.zero_grad()

        # Image summaries and checkpoint
        if step==0:run.time=time.time()
        losses_agg.append({k:v.detach().item() for k,v in losses.items()}|{"time elapsed (seconds)":time.time()-run.time})
        with torch.no_grad():
            wandb_imgs=None
            if step%until_img==0: wandb_imgs=vis.wandb_summary( 0, out, model_input, ground_truth, None,step=step)
            if run.args.viser:vis.viser_update(run.viser_server, losses_agg, out, model_input, ground_truth, None,step=step,wandb_imgs=wandb_imgs)
        if step%until_save == 0 and step: # save model
            print(f"Saving to {run.save_dir}")
            torch.save({ 'step': step, 'model_state_dict': run.model.state_dict(), }, os.path.join(run.save_dir, f"checkpoint_{step}.pt"))

        if run.args.export_poses and step%100==0 and run.args.overfit: # export poses+geom if overfitting on single scene
            out["intrinsics"] = model_input["intrinsics"]
            if "c2w" in model_input: out["gt_poses"],out["gt_intrinsics"]=model_input["c2w"],model_input["gt_intrinsics"]
            # torch.save does not create parent directories.
            export_dir = os.path.join("output", "pose_exps")
            os.makedirs(export_dir, exist_ok=True)
            # Plain tensors only (type check, so Parameter subclasses are excluded as before), ndim > 1.
            torch.save({k:v.detach() for k,v in out.items() if type(v) is torch.Tensor and v.dim()>1},
                       os.path.join(export_dir, f"poses_{run.args.name}.pt"))
            print("exported poses")

def train_splat(run):
    """Fit Gaussian-splat parameters to a scene saved at ``run.args.splat_src``.

    The loaded scene dict is merged into the first dataset batch. Each step renders one random
    view with ``geometry.do_render`` and optimizes the splat means/scales/quats/opacities
    against RGB, depth and (for views after the first) flow targets. Runs for
    ``run.args.n_train_steps`` steps and logs every loss term to wandb.

    If ``run.splat_vars`` is None (as set at startup), it is initialized from the loaded scene
    with ``geometry.format_splat_vars``. Without this, the parameter lookup below fails.
    """

    def splat_loss_fn(model_out, gt, model_input,model,step,view_i,scene):
        """Return weighted RGB, depth and (for view_i > 0) flow losses for one rendered view.

        The flow term is clipped to at most the detached RGB loss so it cannot dominate.
        """
        losses={}

        # gt_rgb appears to be CHW in [-1, 1]; map it to [0, 1] and HWC to match the render.
        losses["metrics/rgb"]=(model_out["rgb"]-(model_out["gt_rgb"]*.5+.5).permute(1,2,0)).square().mean()
        losses["metrics/depth"]= (model_out["depth"].flatten()-scene["depth"][0,view_i,:,0]).square().mean()*1e-2
        # flow_inp_ index view_i-1 holds the flow for view_i, so view 0 has no flow target.
        if "render_flow" in model_out and view_i: losses["metrics/flow"]= ( (model_out["render_flow"].permute(0,3,1,2)-model_input["flow_inp_"][0,[view_i-1]]).square().mean()*2e-3 ).clip(max=losses["metrics/rgb"].detach())

        return losses

    scene = torch.load(run.args.splat_src)

    model_input = ground_truth = to_gpu(run.dataset.collate_fn([run.dataset[0][0]]))
    model_input |= scene

    # Optionally fit at a higher resolution than pose optimization used (which may be very
    # low-res to save compute). Disabled; set to True to enable.
    use_large_res = False
    if use_large_res:
        imsize_low=scene["flow_inp_"].shape[-2:]
        imsize_large=model_input["bwd_flow_large"].shape[-2:]
        scene["world_crds"]=models.ch_sec(F.interpolate(models.ch_fst(scene["world_crds"],imsize_low[0]).flatten(0,1),imsize_large,mode="nearest")[None])
        scene["rgb_crds"]=  models.ch_sec(model_input["rgb_large"]/255*2-1)
        scene["depth"]=model_input["depth"]=ground_truth["depth"]=model_input["depth_inp_large"]
        scene["rgb"]=model_input["rgb"]=ground_truth["rgb"]=model_input["rgb_large"]/255 *2-1
        scene["flow_inp_"]=model_input["flow_inp_"]=ground_truth["flow_inp_"]=model_input["bwd_flow_large"]

    # run.splat_vars starts out None; without initialization the params lookup below would raise.
    # NOTE(review): assumes format_splat_vars returns trainable leaf tensors under
    # "means"/"scales"/"quats"/"opacities" — confirm against geometry.
    if run.splat_vars is None:
        run.splat_vars = geometry.format_splat_vars(scene)

    params = [
        # name, value, lr
        ("means3d",   run.splat_vars["means"],    1.6e-4),
        ("scales",    run.splat_vars["scales"],    2e-3),
        ("quats",     run.splat_vars["quats"],     1e-3),
        ("opacities", run.splat_vars["opacities"], 2e-2),
    ]
    # One Adam per parameter group; every base learning rate above is scaled by 0.1.
    optimizers = [ torch.optim.Adam( [{"params": param, "lr": lr*1e-1, "name": name}],) for name, param, lr in params ]

    losses_agg=[]

    gt_rgbs=model_input["rgb"][0]
    imsize=scene["flow_inp_"].shape[-2:]
    # Scale the stored intrinsics by image width (row 0) and height (row 1).
    Ks=scene["intrinsics"][0,:1,:3,:3].detach().clone()
    Ks[:,0]*=imsize[1]
    Ks[:,1]*=imsize[0]

    save_ckpts = False  # checkpointing is currently disabled for splat fitting

    # Train loop
    for step in trange(run.args.n_train_steps, desc="Fitting"): # train until user interruption

        # todo pick n views to render if n render, sort them
        view_i=np.random.randint(len(gt_rgbs))

        render, alphas, meta = geometry.do_render(torch.eye(4).cuda(),view_i,imsize,Ks,run.splat_vars)
        # NOTE(review): "alphas" takes render channel 3, which is also the first channel of
        # "render_flow" (3:5), while the `alphas` returned by do_render is unused. This looks
        # like it should come from `alphas` — confirm against geometry.do_render.
        out ={ "rgb":render[0,...,:3], "render_flow":render[...,3:5],"depth":render[0,...,-1], "alphas":render[0,...,3], "gt_rgb":gt_rgbs[view_i], }

        # Calculate losses
        total_loss = 0.
        losses = splat_loss_fn(out, ground_truth, model_input,None,step,view_i,scene)
        for loss_name, loss in losses.items():
            wandb.log({loss_name: loss.item()}, step=step)
            total_loss += loss
        wandb.log({"loss": total_loss.item()}, step=step)

        total_loss.backward()
        for optimizer in optimizers:
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)

        # Image summaries and checkpoint
        if step==0:run.time=time.time()
        losses_agg.append({k:v.detach().item() for k,v in losses.items()}|{"time elapsed (seconds)":time.time()-run.time})
        with torch.no_grad():
            wandb_imgs=None
            if step%run.args.until_img==0: wandb_imgs=vis.wandb_summary_splat( 0, out, model_input, ground_truth, None,step=step,view_i=view_i)
            if run.args.viser:vis.viser_update(run.viser_server, losses_agg, out, model_input, ground_truth, None,step=step,wandb_imgs=wandb_imgs)

        # Uses run.args.until_save (the bare name `until_save` is not defined in this scope).
        if save_ckpts and step%run.args.until_save == 0 and step: # save model
            print(f"Saving to {run.save_dir}")
            torch.save({ 'step': step, 'model_state_dict': run.model.state_dict(), }, os.path.join(run.save_dir, f"checkpoint_{step}.pt"))

# Data/args setup and run
import argparse
parser = argparse.ArgumentParser(description='simple training job')
# Flags that default to True use argparse.BooleanOptionalAction so they can be switched off
# with --no-<flag>. (action='store_true' with default=True could never yield False.)
# logging parameters
parser.add_argument('-n','--name', type=str,default="",required=False,help="wandb training name")
parser.add_argument('-c','--init_ckpt', type=str,default=None,required=False,help="File for checkpoint loading. If folder specific, will use latest .pt file")
parser.add_argument('-o','--online', default=False, action='store_true')
parser.add_argument('--viser', default=False, action='store_true')
parser.add_argument('--save_opt_vis', default=False, action='store_true')
# data/training parameters
parser.add_argument('-d','--dataset', type=str,default="hydrant")
parser.add_argument('--imgpath', type=str,default="")
parser.add_argument('-b','--batch_size', type=int,default=1,help="number of videos/sequences per training step")
parser.add_argument('-v','--vid_len', type=int,default=6,help="video length or number of images per batch")
parser.add_argument('--n_workers',type=int,default=0,help="number of workers per dataloader")
parser.add_argument('--until_save',type=int,default=500,help="number of steps until model save")
parser.add_argument('--lr',type=float,default=1e-4,help="learning rate")
parser.add_argument('--n_train_steps',type=int,default=int(1e8),help="number of training steps")
parser.add_argument('--overfit', default=True, action=argparse.BooleanOptionalAction,help="Whether to overfit on a single scene")
parser.add_argument('--until_img', type=int,default=50,help="Number of steps until image summary. ")
parser.add_argument('--sf', type=float,default=1,help="Image resolution scale factor (fractional is cheaper)")
parser.add_argument('--load_save', default=False, action='store_true',help="Whether to load the previously saved data if overfitting (to avoid running flow again)")
parser.add_argument('--splat_src', type=str,default=None,required=False,help="splat src pt file")
# model parameters
parser.add_argument('--time_stride', type=int,default=1,help="Number of frames to flatten in the encoder like a temporal convolution (to save memory). ")
parser.add_argument('--n_skip', type=int,default=1,help="Number of frames to skip between adjacent frames in dataloader. ")
parser.add_argument('--use_gt_intrinsics', default=True, action=argparse.BooleanOptionalAction,help="Whether to use GT intrinsics instead of predicting them. Useful for pretraining scene rep.")
parser.add_argument('--point_track', default=True, action=argparse.BooleanOptionalAction,help="Whether to use point tracking")
parser.add_argument('--export_poses', default=False, action='store_true',help="Export poses when overfitting")

def make_run(args=None,val=False):
    """Parse CLI arguments and build the run namespace used by the training loops.

    Initializes wandb (online only with ``--online``) and creates ``/tmp/<name>`` as the
    save directory. Optionally starts a viser server, builds the image-folder dataset and the
    FlowMap model, and loads a checkpoint when ``--init_ckpt`` is given.

    Args:
        args: optional argument list forwarded to ``parser.parse_args`` (None means sys.argv).
        val: unused; kept for call compatibility.

    Returns:
        argparse.Namespace with ``args``, ``wandb``, ``save_dir``, ``dataset``, ``model`` and,
        when ``--viser`` is set, ``viser_server``.

    Raises:
        FileNotFoundError: if ``--init_ckpt`` is not a file and no ``*.pt`` file exists in it.
    """
    import glob  # needed to find the newest checkpoint; not imported at module level

    args = parser.parse_args(args)
    self = argparse.Namespace()
    user = getpass.getuser()
    print(f"user={user}")

    # Wandb init
    run = wandb.init(entity="cameronsmithbusiness",project="biasing",mode="online" if args.online else "disabled",name=args.name,dir="/tmp/wandb")
    wandb.run.log_code(".")
    self.save_dir = "/tmp/"+args.name
    os.makedirs(self.save_dir,exist_ok=True)
    wandb.save(os.path.join(self.save_dir, "checkpoint*"))
    wandb.save(os.path.join(self.save_dir, "video*"))

    self.args=args
    self.wandb=run
    if args.viser: self.viser_server=viser.ViserServer()

    # Make dataset
    self.dataset = data.ImageFolder(path=args.imgpath,num_trgt=args.vid_len+1,n_skip = args.n_skip,sf=args.sf)
    # Make model and load checkpoint
    self.model = models.FlowMap(args).cuda()
    if args.init_ckpt is not None:
        # Expand "~" once and use the expanded path for both the check and the load.
        ckpt_path = os.path.expanduser(args.init_ckpt)
        if os.path.isfile(ckpt_path):
            ckpt_file = ckpt_path
        else:
            # Treat the path as a folder and pick its most recently created .pt file.
            candidates = glob.glob(os.path.join(ckpt_path, "*.pt"))
            if not candidates:
                raise FileNotFoundError(f"No .pt checkpoint found at {args.init_ckpt}")
            ckpt_file = max(candidates, key=os.path.getctime)
        # strict=False: checkpoint keys that do not match the current model are tolerated.
        self.model.load_state_dict(torch.load(ckpt_file)["model_state_dict"],strict=False)
    return self

run = make_run()
# Splat parameters start unset; viewer_render_fn renders a blank frame while this is None.
run.splat_vars = None


def viewer_render_fn(camera_state, img_wh):
    """nerfview render callback: rasterize ``run.splat_vars`` from the viewer camera.

    Args:
        camera_state: nerfview camera state (provides ``c2w`` and ``get_K(img_wh)``).
        img_wh: requested image size as (width, height).

    Returns:
        A numpy image, presumably (H, W, 3) (the first 3 render channels). When no splat
        parameters are set yet, an all-zero (H, W, 3) array is returned.
    """
    # The GUI timestep control is optional. Fall back to frame 0 when it is missing or its
    # value is not an int-convertible number, without swallowing unrelated errors like
    # KeyboardInterrupt (which the previous bare `except:` did).
    try:
        timestep = int(run.viser_server.gui_timestep.value)
    except (AttributeError, TypeError, ValueError, OverflowError):
        timestep = 0
    if run.splat_vars is None:
        print("skipping splat vars, not set")
        return torch.zeros(*list(img_wh)[::-1],3).cpu().numpy()
    with torch.no_grad():
        # Invert camera-to-world into the view matrix passed to do_render; size is (H, W).
        view_mat = torch.from_numpy(camera_state.c2w).float().cuda().inverse()
        Ks = torch.from_numpy(camera_state.get_K(img_wh)).float().cuda()[None]
        render, alphas, meta = geometry.do_render(view_mat, timestep, img_wh[::-1], Ks, splat_vars=run.splat_vars)
    return render[0,...,:3].cpu().numpy()

# Attach the live viewer (if requested), then dispatch to flow-map or splat training.
if run.args.viser:
    run.viser_server.run = run
    nerfview.Viewer(server=run.viser_server, render_fn=viewer_render_fn, mode="rendering")

# NOTE: per the PyTorch docs, anomaly detection is a debugging aid that slows execution;
# consider disabling it for long runs.
torch.autograd.set_detect_anomaly(True)
if run.args.splat_src is not None:
    train_splat(run)
else:
    vid_every = 300 if run.args.overfit else 100
    train_flowmap(
        run,
        run.dataset,
        until_save=run.args.until_save,
        until_vid=vid_every,
        until_img=run.args.until_img,
    )
