import os
import torch,wandb
from tqdm import trange
from einops import rearrange
import args_setup,vis_scripts,geometry
import numpy as np
from copy import deepcopy
import piqa
from torchvision.utils import make_grid
import models
from torch.cuda.amp import autocast, GradScaler
import kornia
import nerfview
from gsplat import rasterization

def splat_loss_fn(model_out, gt, model_input,model,step):
    """Compute the training losses for one splat render.

    Args:
        model_out: dict holding the rendered image under ``"rgb"`` and the
            target image under ``"gt_rgb"``; the two must be broadcastable.
        gt, model_input, model, step: accepted to match the caller's call
            signature; not used by this loss.

    Returns:
        dict mapping loss name to a scalar tensor. Currently only
        ``"metrics/rgb"``: the mean squared error between rendered and
        target RGB.
    """
    # No per-step print here: printing device tensors forces a host sync on
    # every training step and floods stdout.
    rgb_err = model_out["rgb"] - model_out["gt_rgb"]
    return {"metrics/rgb": rgb_err.square().mean()}

def train(run,train_dataset,until_img=25,until_vid=100,until_save=500,optim=None,single_data=None):
    """Fit Gaussian splats to a precomputed scene, logging progress to wandb.

    The scene is loaded with ``torch.load(run.args.splat_src)``. One Gaussian
    is created per entry of ``scene["world_crds"]`` (flattened over its first
    two dims) and coloured from ``scene["rgb"]``. Each step renders one
    randomly chosen view with gsplat, compares it to that view's ground-truth
    RGB via ``splat_loss_fn``, and takes an Adam step on the means, scales,
    quaternions and opacities.

    Args:
        run: run object; this function reads ``run.args.splat_src``,
            ``run.args.viser``, ``run.args.n_train_steps`` and, when
            ``run.args.viser`` is set, ``run.viser_server``.
        train_dataset, until_vid, optim, single_data: accepted but not read.
        until_img: not read; image summaries are logged every 100 steps.
        until_save: not read; checkpointing is currently disabled (see the
            note at the end of the loop).
    """
    losses_agg=[]

    scene = torch.load(run.args.splat_src)
    stride=1

    # One Gaussian per scene point.
    # NOTE(review): `colors` is an nn.Parameter but is not in `params` below,
    # so it is never optimized (its .grad just accumulates). Confirm whether
    # colours are meant to be learned.
    means =   torch.nn.Parameter(scene["world_crds"].flatten(0,1)[::stride])
    colors=   torch.nn.Parameter(scene["rgb"].flatten(0,1)[::stride]*.5+.5 )
    quats =   torch.nn.Parameter(torch.ones(len(means),4).cuda() )
    opacities=torch.nn.Parameter(torch.ones(len(means)).cuda()*.1 )
    scales=   torch.nn.Parameter(torch.ones(len(means),3).cuda()*.01 )

    imsize=scene["flow_inp_"].shape[-2:]
    # Ground-truth images, mapped from [-1, 1] to [0, 1].
    gt_rgbs=scene["rgb"].unflatten(1,imsize).clip(-1,1)*.5+.5
    # Rows 0/1 of K are scaled by width/height (intrinsics look normalized --
    # confirm). clone() so the in-place scaling does not write back into
    # scene["intrinsics"] through the slice view.
    Ks=scene["intrinsics"][:1,:3,:3].clone()
    Ks[:,0]*=imsize[1]
    Ks[:,1]*=imsize[0]

    params = [
        # name, value, lr
        ("means3d",   means,    1.6e-4),
        ("scales",    scales,    1e-3),
        ("quats",     quats,     1e-3),
        ("opacities", opacities, 5e-2),
    ]
    # One Adam optimizer per parameter group; every listed lr is scaled by 0.1.
    optimizers = [
        torch.optim.Adam(
            [{"params": param, "lr": lr*1e-1, "name": name}],
        )
        for name, param, lr in params
    ]

    def do_render(pose,imsize,K):
        """Render all splats from world-to-camera matrix `pose`.

        `pose` is passed to gsplat as the view matrix (assumed 4x4), so gsplat
        moves both the means AND the covariances into camera space. Composing
        rotations by multiplying quaternions element-wise would be wrong.
        `imsize` is (height, width); `K` is a [1, 3, 3] intrinsics batch.
        Returns gsplat's (render_colors, render_alphas, meta); with
        "RGB+ED" the render has RGB in channels 0-2 and expected depth in 3.
        """
        return rasterization(
            means, quats, scales.clip(max=.02), opacities, colors,
            pose[None], K, imsize[1], imsize[0], render_mode="RGB+ED",
        )

    # TODO still: add plots to viser viewer (plot loss and images and poses), optimize poses, use default opt/prune
    def viewer_render_fn(camera_state, img_wh):
        """nerfview render callback: render the viewer camera to an HxWx3 array."""
        with torch.no_grad():
            c2w = torch.from_numpy(camera_state.c2w).float().cuda()
            K = torch.from_numpy(camera_state.get_K(img_wh)).float().cuda()
            # img_wh is (width, height); do_render expects (height, width).
            render, _, _ = do_render(c2w.inverse(), img_wh[::-1], K[None])
            return render[0,...,:3].cpu().numpy()

    if run.args.viser:
        viewer = nerfview.Viewer( server=run.viser_server, render_fn=viewer_render_fn, mode="rendering",)

    # Train loop: runs for a fixed number of steps (run.args.n_train_steps).
    backprop_every = 1
    for step in trange(run.args.n_train_steps, desc="Fitting"):

        # Render one randomly chosen training view.
        view_i=np.random.randint(len(gt_rgbs))
        render, alphas, _ = do_render(scene["poses"][view_i].inverse(),imsize,Ks)
        out ={
            "rgb":render[0,...,:3],
            "depth":render[0,...,3],
            # gsplat returns alphas as [C, H, W, 1]; drop camera and channel dims.
            "alphas":alphas[0,...,0],
            "gt_rgb":gt_rgbs[view_i],
        }

        # Calculate losses
        ground_truth,model_input=scene,scene
        losses = splat_loss_fn(out, ground_truth, model_input,None,step)
        total_loss = sum(losses.values())
        loss_vals = {k: v.item() for k, v in losses.items()}
        wandb.log({**loss_vals, "trainer/global_step": step, "loss": total_loss.item()}, step=step)

        total_loss.backward()
        # Gradients accumulate across steps; apply them every `backprop_every` steps.
        if (step + 1) % backprop_every == 0:
            for optimizer in optimizers:
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)

        # Image summaries
        losses_agg.append(loss_vals)
        with torch.no_grad():
            wandb_imgs=None
            if step%100==0: wandb_imgs=vis_scripts.wandb_summary_splat( 0, out, model_input, ground_truth, None,step=step)
            #if run.args.viser:vis_scripts.viser_update(run.viser_server, losses_agg, out, model_input, ground_truth, None,step=step,wandb_imgs=wandb_imgs)

        # NOTE(review): checkpointing is disabled. The previous (unreachable)
        # save path stored run.model.state_dict(), but this function builds no
        # model -- saving would need to store the splat tensors
        # (means/quats/scales/opacities/colors) every `until_save` steps.

if __name__ == '__main__':
    run = args_setup.make_run()
    # Autograd anomaly detection is off: it increases runtime. To enable it,
    # uncomment the next line. (Calling torch.autograd.detect_anomaly() by
    # itself only constructs the context manager without entering it, so it
    # never enabled anything -- it just emitted a misleading warning.)
    #torch.autograd.set_detect_anomaly(True)
    train(run,None,until_save=run.args.until_save, until_vid=100 if not run.args.overfit else 300, until_img=run.args.until_img)
