import os
import torch,wandb
from tqdm import trange
import kornia
import time
from einops import rearrange
import args_setup,vis_scripts,geometry
import numpy as np
from copy import deepcopy
import piqa
from torchvision.utils import make_grid
import models
from torch.cuda.amp import autocast, GradScaler
import kornia
import matplotlib.pyplot as plt 
import nerfview
from gsplat import rasterization
# NOTE(review): throwaway CUDA linear-algebra call executed at import time. It is
# presumably a warm-up for the known PyTorch issue where the first torch.inverse on
# CUDA issued from a non-main thread (e.g. a viewer render thread) fails while
# lazily loading linalg — confirm. Because of the hard-coded device, importing this
# module requires "cuda:0" to be available.
torch.inverse(torch.ones((0, 0), device="cuda:0"))

def splat_loss_fn(model_out, gt, model_input,model,step,view_i):
    """Compute the weighted loss terms for one rendered training view.

    Args:
        model_out: dict holding "rgb", "gt_rgb", "depth" and optionally
            "render_flow" (4-D, permuted (0, 3, 1, 2) before comparison).
        gt, model, step: unused; kept so existing call sites keep working.
        model_input: dict providing "depth" (indexed ``[view_i, :, 0]``) and,
            when "render_flow" is present, "flow_inp_" (indexed at ``view_i - 1``).
        view_i: index of the rendered frame.

    Returns:
        dict mapping "metrics/rgb", "metrics/depth" and, if flow was rendered,
        "metrics/flow" to 0-d loss tensors. Depth is weighted by 1e-3; flow is
        weighted by 5e-3 and capped at the (detached) RGB loss.
    """
    rgb_err = model_out["rgb"] - model_out["gt_rgb"]
    rgb_loss = rgb_err.square().mean()

    depth_pred = model_out["depth"].flatten()
    depth_target = model_input["depth"][view_i, :, 0]
    depth_loss = (depth_pred - depth_target).square().mean() * 1e-3

    result = {"metrics/rgb": rgb_loss, "metrics/depth": depth_loss}

    if "render_flow" not in model_out:
        return result

    # Compare rendered flow against the stored flow of the previous frame.
    flow_pred = model_out["render_flow"].permute(0, 3, 1, 2)
    flow_target = model_input["flow_inp_"][[view_i - 1]]
    flow_loss = (flow_pred - flow_target).square().mean() * 5e-3
    # Never let the flow term outweigh the photometric term.
    result["metrics/flow"] = flow_loss.clip(max=rgb_loss.detach())
    return result

def train(run,train_dataset,until_img=25,until_vid=100,until_save=500,optim=None,single_data=None):
    """Fit a time-varying Gaussian-splat scene to the data in ``run.args.splat_src``.

    One Gaussian is created per entry of ``scene["world_crds"]`` (flattened over
    its first two axes), coloured from ``scene["rgb_crds"]``. Motion is modelled
    by ``lie_perpix``: per-pixel quaternion (first 4 values) + translation
    (remaining 3 values) poses initialised from ``scene["lie_crds"]``. For the
    first 1000 steps each iteration renders a random frame with gsplat, scores it
    with ``splat_loss_fn`` (RGB, depth and, for frames after the first, optical
    flow towards the previous frame), logs to wandb and takes one Adam step per
    parameter group. Visual summaries / the viser viewer are refreshed every step.

    Args:
        run: run object; reads ``run.args`` (``splat_src``, ``viser``,
            ``n_train_steps``) and ``run.viser_server``.
        train_dataset, until_img, until_vid, optim, single_data: accepted for
            call-site compatibility; not used here.
        until_save: checkpoint interval (the checkpoint branch is disabled).
    """
    losses_agg=[]

    scene = torch.load(run.args.splat_src)

    def get_pose_perpix(lie_perpix_):
        """Turn (len(scene["pose_perpix"]), N, 7) quaternion+translation params into (..., N, 4, 4) poses."""
        pose_perpix = torch.eye(4)[None,None].expand(len(scene["pose_perpix"]),lie_perpix_.size(-2),-1,-1).to(lie_perpix_)
        pose_perpix[...,:3,:3] = kornia.geometry.conversions.quaternion_to_rotation_matrix(lie_perpix_[...,:4])
        pose_perpix[...,:3,-1] = lie_perpix_[...,4:]
        return pose_perpix

    means =      torch.nn.Parameter(scene["world_crds"].flatten(0,1))
    # NOTE(review): colors is a Parameter but is not listed in `params` below, so it
    # is never updated (its .grad just accumulates) — confirm this is intended.
    colors=      torch.nn.Parameter(scene["rgb_crds"].flatten(0,1)*.5+.5 )
    quats =      torch.nn.Parameter(torch.ones(len(means),4).cuda() )
    opacities=   torch.nn.Parameter(torch.ones(len(means)).cuda()*.05 )
    scales=      torch.nn.Parameter(torch.ones(len(means),3).cuda()*.01 )
    lie_perpix = torch.nn.Parameter(scene["lie_crds"].flatten(1,2))

    def means_at_timestep(timestep, pose_perpix):
        """Map every Gaussian centre through the inverse of its pose at ``timestep``.

        ``pose_perpix`` comes from get_pose_perpix; the selected frame's poses are
        tiled over the leading axis so there is one 4x4 matrix per Gaussian.
        """
        inv_poses = pose_perpix[[timestep]].inverse().expand(len(scene["pose_perpix"]),-1,-1,-1,-1)#@scene["pose_perpix"]
        return torch.einsum("pij,pj->pi",inv_poses.flatten(0,2),models.hom(means))[...,:3]

    imsize=scene["flow_inp_"].shape[-2:]
    gt_rgbs=scene["rgb"].unflatten(1,imsize).clip(-1,1)*.5+.5
    # NOTE(review): basic slicing makes Ks a view, so the two scalings below also
    # rewrite scene["intrinsics"][0] in place. The flow branch later projects with
    # model_input["intrinsics"][:1] (model_input is scene), i.e. with these
    # pixel-space values, and vis_scripts receives the mutated scene as well.
    # Kept as-is to preserve that behaviour; only add .clone() here together with
    # an update of those consumers.
    Ks=scene["intrinsics"][:1,:3,:3]
    Ks[:,0]*=imsize[1]
    Ks[:,1]*=imsize[0]

    params = [
        # name, value, lr
        ("means3d",   means,    1.6e-4),
        ("scales",    scales,    1e-3),
        ("quats",     quats,     1e-3),
        ("opacities", opacities, 5e-2),
        ("lie_perpix", lie_perpix, 1e-5),
    ]
    # One Adam per parameter group; every base lr above is scaled by 0.1.
    optimizers = [
        torch.optim.Adam(
            [{"params": param, "lr": lr*1e-1, "name": name}],
        )
        for name, param, lr in params
    ]

    # Mark weird neg depth from tracks, not sure exactly where coming from:
    # a Gaussian is flagged if it lands at z < .1 at any timestep. This is pure
    # mask bookkeeping, so no autograd graph is built and the poses (which do not
    # depend on the timestep) are computed once.
    with torch.no_grad():
        pose_perpix = get_pose_perpix(lie_perpix)
        bad_pt = torch.zeros(len(means), dtype=torch.bool, device=means.device)
        for timestep in range(len(gt_rgbs)):
            bad_pt |= means_at_timestep(timestep, pose_perpix)[...,2] < .1

    def quat_mul(a, b):
        """Hamilton product a*b of (w, x, y, z) quaternions, broadcasting leading dims."""
        aw, ax, ay, az = a.unbind(-1)
        bw, bx, by, bz = b.unbind(-1)
        return torch.stack([
            aw*bw - ax*bx - ay*by - az*bz,
            aw*bx + ax*bw + ay*bz - az*by,
            aw*by - ax*bz + ay*bw + az*bx,
            aw*bz + ax*by - ay*bx + az*bw,
        ], -1)

    def do_render(pose,timestep,imsize,K,colors_=None):
        """Render all Gaussians at ``timestep`` after applying the 4x4 ``pose``.

        The viewer passes the inverse of its camera-to-world matrix; training uses
        the identity. Returns gsplat's ``(render, alphas, meta)``, where ``render``
        holds the colour channels followed by depth (``render_mode="RGB+D"``).
        """
        if colors_ is None: colors_=colors
        means_i = means_at_timestep(timestep, get_pose_perpix(lie_perpix))
        means_i=torch.einsum("ij,kj->ki",pose,models.hom(means_i))[...,:3]
        # Park flagged points at (1000, 1000, 1000) in camera space, presumably to
        # push them out of the rendered view.
        means_i=torch.where(bad_pt[:,None],means_i.new_tensor(1000.),means_i)
        # Rotations compose via the quaternion (Hamilton) product, not element-wise
        # multiplication: the latter zeroes x/y/z for an identity pose and freezes
        # every Gaussian's orientation. Assumes kornia returns (w, x, y, z), the
        # order gsplat expects.
        q_pose=kornia.geometry.conversions.rotation_matrix_to_quaternion(pose[:3,:3])
        quats_i=quat_mul(q_pose[None],quats)
        return rasterization( means_i, quats_i, scales.clip(max=.1), opacities, colors_, torch.eye(4).cuda()[None], K, imsize[1], imsize[0],render_mode="RGB+D",)

    def viewer_render_fn(camera_state, img_wh):
        """nerfview callback: render the GUI camera at the GUI timestep to an HxWx3 numpy image."""
        with torch.no_grad():
            try: colors_={"rgb":colors,"lie_rot":lie_perpix.flatten(0,1)[...,:3].clip(0,1),"lie_trans":(lie_perpix.flatten(0,1)[...,-3:]*2+.5).clip(0,1)}[run.viser_server.gui_color_choice.value]
            except (AttributeError, KeyError):
                # GUI control missing or unknown choice: fall back to plain colours.
                print("error in color set")
                colors_=colors
            viewmat = torch.from_numpy(camera_state.c2w).float().cuda().inverse()
            K = torch.from_numpy(camera_state.get_K(img_wh)).float().cuda()
            try: timestep=int(run.viser_server.gui_timestep.value)
            except (AttributeError, TypeError, ValueError): timestep=0

            #locking_ # todo make option to use this or not -- basically make transformation relative to any point but ideally a static point, could think about median transformation or pca on transformations
            #locking_pose=scene["poses"]#get_pose_perpix(torch.median(lie_perpix[timestep],dim=0)[0][None,None].expand(*lie_perpix.shape[:2],-1))[timestep,10]
            #locking_pose=get_pose_perpix(lie_perpix)[timestep,10]
            #viewmat = viewmat @ locking_pose

            render, alphas, meta = do_render(viewmat,timestep,img_wh[::-1],K[None],colors_=colors_)
            return render[0,...,:3].cpu().numpy()
    if run.args.viser: nerfview.Viewer( server=run.viser_server, render_fn=viewer_render_fn, mode="rendering",)

    identity = torch.eye(4).cuda()

    # Train loop
    for step in trange(run.args.n_train_steps, desc="Fitting"): # train until user interruption

        # Optimisation only runs for the first 1000 steps; afterwards the loop only
        # refreshes the summaries / viewer with the last rendered frame.
        if step<1000:
            view_i=np.random.randint(len(gt_rgbs))
            ground_truth,model_input=scene,scene

            render, alphas, meta = do_render(identity,view_i,imsize,Ks)
            # gsplat returns alphas with a trailing singleton channel.
            out ={ "rgb":render[0,...,:3], "depth":render[0,...,3], "alphas":alphas[0,...,0], "gt_rgb":gt_rgbs[view_i], }

            # Generate an optical flow per-pixel and render it as rgb channels TODO render once using N-D splat instead of doing again below
            if view_i!=0:
                pose_perpix = get_pose_perpix(lie_perpix)
                pos_i   = means_at_timestep(view_i,   pose_perpix)
                pos_adj = means_at_timestep(view_i-1, pose_perpix)
                pos_i_2d   = models.project(pos_i  [None,None],model_input["intrinsics"][:1,None])[0,0]
                pos_adj_2d = models.project(pos_adj[None,None],model_input["intrinsics"][:1,None])[0,0]
                flow_2d = pos_adj_2d-pos_i_2d
                render_flow = do_render(identity,view_i,imsize,Ks,colors_=flow_2d)[0][...,:2]
                out["render_flow"]=render_flow

            # Calculate losses
            losses = splat_loss_fn(out, ground_truth, model_input,None,step,view_i)
            total_loss = sum(losses.values())

            log = {loss_name: loss.item() for loss_name, loss in losses.items()}
            log["trainer/global_step"] = step
            log["loss"] = total_loss.item()
            wandb.log(log, step=step)

            total_loss.backward()
            for optimizer in optimizers:
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)

            # Image summaries and checkpoint
            losses_agg.append({k:v.item() for k,v in losses.items()})
        with torch.no_grad():
            wandb_imgs=None
            if step%100==0: wandb_imgs=vis_scripts.wandb_summary_splat( 0, out, model_input, ground_truth, None,step=step,view_i=view_i)
            if run.args.viser:vis_scripts.viser_update(run.viser_server, losses_agg, out, model_input, ground_truth, None,step=step,wandb_imgs=wandb_imgs)

        # NOTE(review): disabled via `and 0`. It would save run.model, which this
        # function never trains — the splat parameters above are not saved.
        if step%until_save == 0 and step and 0: # save model
            print(f"Saving to {run.save_dir}"); torch.save({ 'step': step, 'model_state_dict': run.model.state_dict(), }, os.path.join(run.save_dir, f"checkpoint_{step}.pt"))

if __name__ == '__main__':
    run = args_setup.make_run()
    # Autograd anomaly detection is off. Note that `torch.autograd.detect_anomaly`
    # is a context manager: calling it bare never enables anything and only emits
    # an "Anomaly Detection has been enabled" warning. To debug NaNs (slow):
    #torch.autograd.set_detect_anomaly(True)
    train(run,None,until_save=run.args.until_save, until_vid=100 if not run.args.overfit else 300, until_img=run.args.until_img)
