import os
import torch,wandb
from tqdm import trange
from einops import rearrange
import args_setup,vis_scripts,geometry
import numpy as np
from copy import deepcopy
import piqa
from torchvision.utils import make_grid
import models, geometry
from torch.cuda.amp import autocast, GradScaler
import kornia
import nerfview
from gsplat import rasterization

def splat_loss_fn(model_out, gt, model_input,model,step):
    """Compute the photometric loss terms for one splatting step.

    Args:
        model_out: dict holding the rendered ``"rgb"`` tensor and the target
            ``"gt_rgb"`` tensor; the two must be broadcast-compatible.
        gt, model_input, model, step: unused; kept so the call signature used
            by ``train`` stays valid.

    Returns:
        dict mapping ``"metrics/rgb"`` to the mean squared RGB error
        (a scalar tensor that still carries gradients).
    """
    # No per-step print here: printing a CUDA tensor forces a device sync and
    # floods stdout on every training iteration.
    return {"metrics/rgb": (model_out["rgb"] - model_out["gt_rgb"]).square().mean()}

def _quat_mul(q1, q2):
    """Return the Hamilton product ``q1 * q2`` of quaternions stored as (w, x, y, z).

    Both inputs are tensors whose last dimension has size 4; leading dimensions
    broadcast. Composing two rotations requires this product -- an element-wise
    multiply of the coefficients does not correspond to any rotation.
    """
    w1, x1, y1, z1 = q1.unbind(-1)
    w2, x2, y2, z2 = q2.unbind(-1)
    return torch.stack(
        (
            w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
            w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
            w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
            w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,
        ),
        dim=-1,
    )


def train(run,train_dataset,until_img=25,until_vid=100,until_save=500,optim=None,single_data=None):
    """Fit 3D Gaussian splats to a preprocessed scene and log progress.

    The scene dict is loaded with ``torch.load(run.args.splat_src)``. One Gaussian
    is seeded per pixel (every ``stride``-th) by scaling camera ray directions by
    ``scene["depth"]`` and mapping them to world space with ``scene["pose_perpix"]``.
    Each step renders a random view with gsplat, computes ``splat_loss_fn``, logs
    to wandb, and updates means, scales, quats and opacities with Adam.

    Args:
        run: object from ``args_setup.make_run()``; this function reads
            ``run.args`` (``splat_src``, ``n_train_steps``, ``viser``),
            ``run.viser_server``, ``run.save_dir`` and ``run.model``.
        train_dataset, until_img, until_vid, optim, single_data: accepted for
            call compatibility; unused here.
        until_save: checkpoint interval in steps (checkpointing is currently
            switched off, see ``save_checkpoints`` below).
    """
    losses_agg = []

    scene = torch.load(run.args.splat_src)
    stride = 1  # keep every `stride`-th pixel as a Gaussian

    imsize = scene["flow_inp_"].shape[-2:]  # (H, W)

    # Pixel coordinates as (x, y), normalized to [0, 1], shaped (1, 1, H*W, 2).
    uv = np.mgrid[0 : imsize[0], 0 : imsize[1]].astype(float).transpose(1, 2, 0)
    uv = torch.from_numpy(np.flip(uv, axis=-1).copy()).long()
    uv = uv / torch.tensor([imsize[1] - 1, imsize[0] - 1])
    x_pix = uv[None].flatten(1, 2)[None]

    # Scale ray directions by depth, then map to world space with per-pixel poses.
    rds = geometry.get_world_rays(x_pix[None].cuda(), scene["intrinsics"][None], None)[1][0, 0]
    eye_surf = rds * scene["depth"]
    world_crds = torch.einsum(
        "tpij,tpj->tpi", scene["pose_perpix"].flatten(1, 2), models.hom(eye_surf)
    )[..., :3]

    means = torch.nn.Parameter(world_crds.flatten(0, 1)[::stride])
    # NOTE(review): `colors` is a Parameter but is not in `params` below, so it is
    # never updated (its .grad only accumulates) -- confirm colors should stay fixed.
    colors = torch.nn.Parameter(scene["rgb"].flatten(0, 1)[::stride] * .5 + .5)
    quats = torch.nn.Parameter(torch.ones(len(means), 4).cuda())
    opacities = torch.nn.Parameter(torch.ones(len(means)).cuda() * .1)
    scales = torch.nn.Parameter(torch.ones(len(means), 3).cuda() * .01)

    gt_rgbs = scene["rgb"].unflatten(1, imsize).clip(-1, 1) * .5 + .5
    # clone(): the slice is a view, so scaling it in place would silently rewrite
    # scene["intrinsics"], which is later handed to vis_scripts as model_input.
    Ks = scene["intrinsics"][:1, :3, :3].clone()
    # Presumably the stored intrinsics are normalized by image size -- TODO confirm.
    Ks[:, 0] *= imsize[1]  # first row scaled by width
    Ks[:, 1] *= imsize[0]  # second row scaled by height

    params = [
        # name, value, lr
        ("means3d",   means,    1.6e-4),
        ("scales",    scales,    1e-3),
        ("quats",     quats,     1e-3),
        ("opacities", opacities, 5e-2),
    ]
    optimizers = [
        torch.optim.Adam([{"params": param, "lr": lr * 1e-1, "name": name}])
        for name, param, lr in params
    ]

    def do_render(pose, timestep, imsize, K):
        """Rasterize all Gaussians for view matrix `pose` (4x4) relative to `timestep`.

        Returns gsplat's (render_colors, render_alphas, meta); render_colors holds
        RGB plus expected depth ("RGB+ED").
        """
        pose_perpix = scene["pose_perpix"][[timestep]].inverse().expand(
            len(scene["pose_perpix"]), -1, -1, -1, -1
        )
        means_i = torch.einsum("pij,pj->pi", pose_perpix.flatten(0, 2), models.hom(means))[..., :3]
        means_i = torch.einsum("ij,kj->ki", pose, models.hom(means_i))[..., :3]
        # Rotate each Gaussian by the view rotation via the Hamilton product; an
        # element-wise product zeroed the x/y/z components at the identity pose.
        # Assumes kornia returns (w, x, y, z), matching gsplat's quats -- verify on upgrade.
        # NOTE(review): the per-pixel transform above moves means but not quats -- confirm.
        q_pose = kornia.geometry.conversions.rotation_matrix_to_quaternion(pose[:3, :3])
        quats_i = _quat_mul(q_pose[None], quats)
        return rasterization(
            means_i, quats_i, scales.clip(max=.02), opacities, colors,
            torch.eye(4).cuda()[None], K, imsize[1], imsize[0], render_mode="RGB+ED",
        )

    # TODO still: add plots to viser viewer (plot loss and images and poses), optimize poses, use default opt/prune
    def viewer_render_fn(camera_state, img_wh):
        """nerfview callback: render the current splats from the viewer camera as an HxWx3 array."""
        with torch.no_grad():
            c2w = torch.from_numpy(camera_state.c2w).float().cuda()
            K = torch.from_numpy(camera_state.get_K(img_wh)).float().cuda()
            viewmat = c2w.inverse()
            render, _, _ = do_render(viewmat, 0, img_wh[::-1], K[None])  # todo add timestep from slider
            return render[0, ..., :3].cpu().numpy()

    nerfview.Viewer(
        server=run.viser_server,
        render_fn=viewer_render_fn,
        mode="rendering",
    )

    backprop_every = 1  # optimizer step every N backward passes (gradient accumulation)
    save_checkpoints = False  # see NOTE at the checkpoint branch below
    for step in trange(run.args.n_train_steps, desc="Fitting"):
        view_i = np.random.randint(len(gt_rgbs))

        render, alphas, meta = do_render(torch.eye(4).cuda(), view_i, imsize, Ks)
        out = {
            "rgb": render[0, ..., :3],
            "depth": render[0, ..., 3],
            # gsplat returns alphas as (C, H, W, 1); previously the depth channel was stored here.
            "alphas": alphas[0, ..., 0],
            "gt_rgb": gt_rgbs[view_i],
            "poses": scene["gt_poses"][None],  # todo use opt poses when finished w that
        }

        # Losses
        ground_truth, model_input = scene, scene
        losses = splat_loss_fn(out, ground_truth, model_input, None, step)
        loss_values = {name: loss.item() for name, loss in losses.items()}
        total_loss = sum(losses.values())
        wandb.log(
            {**loss_values, "trainer/global_step": step, "loss": total_loss.item()},
            step=step,
        )

        total_loss.backward()
        if (step + 1) % backprop_every == 0:
            for optimizer in optimizers:
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)

        # Image summaries
        losses_agg.append(loss_values)
        with torch.no_grad():
            wandb_imgs = None
            if step % 100 == 0:
                wandb_imgs = vis_scripts.wandb_summary_splat(
                    0, out, model_input, ground_truth, None, step=step, losses_agg=losses_agg
                )
            if run.args.viser:
                vis_scripts.viser_update(
                    run.viser_server, losses_agg, out, model_input, ground_truth, None,
                    step=step, wandb_imgs=wandb_imgs,
                )

        # NOTE(review): checkpointing is switched off. It saves run.model, which this
        # splat trainer never builds or trains; save the Gaussian parameters
        # instead before enabling it.
        if save_checkpoints and step and step % until_save == 0:
            print(f"Saving to {run.save_dir}")
            torch.save(
                {'step': step, 'model_state_dict': run.model.state_dict()},
                os.path.join(run.save_dir, f"checkpoint_{step}.pt"),
            )

if __name__ == '__main__':
    run = args_setup.make_run()
    # Autograd anomaly detection is off (it slows training considerably). To turn
    # it on while debugging, call torch.autograd.set_detect_anomaly(True) here.
    # A bare `torch.autograd.detect_anomaly()` call only constructs the context
    # manager without entering it, so it enables nothing and only emits a
    # misleading "enabled" warning.
    train(run,None,until_save=run.args.until_save, until_vid=100 if not run.args.overfit else 300, until_img=run.args.until_img)
