import os
import wandb
from matplotlib import cm
from torchvision.utils import make_grid,draw_keypoints
import torch.nn.functional as F
import numpy as np
import torch
import matplotlib.pyplot as plt


def imsave(x, y=0, root="/nobackup/users/camsmith/img"):
    """Debug helper: save tensor ``x`` as the image ``<root>/tmp<y>.png``.

    ``x`` is detached and moved to the CPU before conversion to NumPy, so
    tensors that require grad or live on another device can be passed
    directly (a bare ``.numpy()`` refuses grad-attached tensors).
    ``root`` defaults to the original hard-coded scratch directory.
    """
    plt.imsave(os.path.join(root, "tmp%s.png" % y), x.detach().cpu().numpy())
from einops import rearrange, repeat
import piqa
import imageio

def ch_fst(src, x=None):
    """Unflatten a token axis into a channels-first 2-D grid.

    Applies the einops pattern ``... (x y) c -> ... c x y``. When ``x`` is not
    given, the grid is assumed square and ``x`` is taken as
    ``int(src.size(-2) ** 0.5)``; ``y`` is inferred by einops.
    """
    if x is None:
        x = int(src.size(-2) ** (.5))
    return rearrange(src, "... (x y) c -> ... c x y", x=x)
def ch_sec(x):
    """Flatten a channels-first grid ``... c x y`` into tokens ``... (x y) c``.

    This is the inverse rearrangement of ``ch_fst``.
    """
    pattern = "... c x y -> ... (x y) c"
    return rearrange(x, pattern)

def _depth_to_magma(depth, cmap):
    """Render a stack of depth maps as channels-first RGB images with ``cmap``.

    A 4-D ``depth`` has its first two axes merged before colormapping and split
    back afterwards, so the output keeps the input's leading axes. Each map
    ``m`` is drawn as ``cmap(min(m) / m)``: for positive depths the values lie
    in (0, 1], with the nearest point mapped to 1. The colormap's alpha channel
    is dropped and the result is a CPU tensor.

    NOTE(review): the ``squeeze(-2)`` suggests maps may carry a trailing
    singleton axis, i.e. (H, W, 1); for plain (H, W) maps it is a no-op unless
    W == 1 -- confirm the expected depth layout.
    """
    leading = depth.shape
    merged = len(leading) == 4
    if merged:
        depth = depth.flatten(0, 1)
    frames = []
    for m in depth:
        m = m.detach()  # .numpy() refuses tensors that require grad
        rgba = cmap(m.min().item() / m.cpu().numpy())
        frames.append(torch.from_numpy(rgba).squeeze(-2)[..., :3].permute(2, 0, 1))
    vis = torch.stack(frames)
    if merged:
        vis = vis.unflatten(0, leading[:2])
    return vis


def wandb_summary(loss, model_output, model_input, ground_truth, resolution,prefix="",suffix="",step=0):
    """Log image grids for one batch to Weights & Biases.

    Side effect: the depth entries of ``model_output`` and ``model_input`` are
    processed IN PLACE. These are keys containing ``"depth"`` but not
    ``"latents"``. Each gets a ``<key>_raw`` copy clipped to >= 0.01. Unless the
    key already contains ``"raw"``, it also gets a ``<key>vis`` magma
    visualisation.

    Grids logged (each only when its source key is present, except
    ``ref/rgb_gt`` which is always logged): est/masks, est/rgb, ref/rgb_gt,
    est/depth, est/depth_raw, ref/depth, ref/depth_mask. Every grid is logged
    under ``prefix + name + suffix`` as a ``wandb.Image`` clipped to [0, 1].

    Args:
        loss: unused.
        model_output: dict of predictions; ``masks``, ``rgb`` and ``depth`` are
            read when present.
        model_input: dict of inputs; ``depth`` and ``depth_mask`` are read when
            present.
        ground_truth: dict that must contain ``"rgb"``.
        resolution: unused; kept for call compatibility.
        prefix: string prepended to every logged key.
        suffix: string appended to every logged key.
        step: currently unused -- it is NOT forwarded to ``wandb.log``.

    Raises:
        KeyError: if ``ground_truth`` has no ``"rgb"`` entry.
    """
    nrow = 8
    # plt.get_cmap rather than matplotlib.cm.get_cmap, which newer Matplotlib
    # releases have deprecated and removed.
    magma = plt.get_cmap("magma")

    # Add clipped raw copies and colormapped visualisations of depth maps.
    for d in (model_output, model_input):
        for k, v in list(d.items()):  # snapshot: new keys are added below
            if "latents" in k or "depth" not in k:
                continue
            if len(v.shape):
                v = v.clip(min=.01)  # strictly positive, so min(m)/m is finite
            d[k + "_raw"] = v
            if "raw" not in k:
                d[k + "vis"] = _depth_to_magma(v, magma)

    wandb_out = {}
    if "masks" in model_output:
        masks = model_output["masks"]
        wandb_out["est/masks"] = make_grid(masks.flatten(0, 2).cpu().detach(), nrow=masks.size(2))
    if "rgb" in model_output:
        rgb = model_output["rgb"]
        wandb_out["est/rgb"] = make_grid(rgb.flatten(0, 1).cpu().detach(), nrow=rgb.size(1))
    wandb_out["ref/rgb_gt"] = make_grid(ground_truth["rgb"].cpu().detach(), nrow=nrow)
    if "depthvis" in model_output:
        wandb_out["est/depth"] = make_grid(model_output["depthvis"].flatten(0, 1).cpu().detach(),
                                           nrow=model_output["depth"].size(1))
    if "depth" in model_output:
        depth = model_output["depth"]
        wandb_out["est/depth_raw"] = make_grid(depth.flatten(0, 1)[:, None].cpu().detach(),
                                               nrow=depth.size(1), normalize=True)
    if "depth" in model_input:
        wandb_out["ref/depth"] = make_grid(model_input["depth"][:, None].cpu().detach(), nrow=nrow)
    if "depth_mask" in model_input:
        wandb_out["ref/depth_mask"] = make_grid(model_input["depth_mask"].unsqueeze(1).cpu().detach(),
                                                nrow=nrow, normalize=False)

    wandb.log({
        prefix + k + suffix: wandb.Image(v.permute(1, 2, 0).float().detach().clip(0, 1).cpu().numpy())
        for k, v in wandb_out.items()
    })

def pose_summary(loss, model_output, model_input, ground_truth, resolution,prefix=""):
    """Log camera positions from ``model_output["poses"]`` to W&B as a point cloud.

    The position of each pose is read as ``poses[:, :3, -1]`` (first three
    rows of the last column). Presumably the poses are homogeneous
    camera-to-world matrices, making this the translation -- TODO confirm.
    The tensor is detached before ``.numpy()`` so grad-attached poses work.

    Args:
        loss, model_input, ground_truth, resolution: unused; kept for call
            compatibility with ``wandb_summary``.
        model_output: dict that must contain ``"poses"``.
        prefix: string prepended to the logged key (default ``""`` keeps the
            key ``"camera positions"``).
    """
    positions = model_output["poses"][:, :3, -1].detach().cpu().numpy()
    # Only points are logged (no boxes).
    point_scene = wandb.Object3D({
        "type": "lidar/beta",
        "points": positions,
    })
    wandb.log({prefix + "camera positions": point_scene})


    
