import os
import wandb
from matplotlib import cm
from torchvision.utils import make_grid,draw_keypoints
import torch.nn.functional as F
import numpy as np
import torch
import matplotlib.pyplot as plt


def imsave(x, y=0):
    """Debug helper: dump ``x`` to /nobackup/users/camsmith/img/tmp<y>.png.

    ``x`` is moved to the CPU and converted to a NumPy array before being
    passed to ``plt.imsave``; presumably it is an (H, W) or (H, W, C) image
    tensor -- TODO confirm against callers.
    """
    target = "/nobackup/users/camsmith/img/tmp%s.png" % y
    plt.imsave(target, x.cpu().numpy())

from einops import rearrange, repeat
import piqa
import imageio

def ch_fst(src, x=None):
    """Unflatten a token sequence into a channels-first grid.

    Rearranges ``src`` from ``(..., x*y, c)`` to ``(..., c, x, y)``.

    Args:
        src: tensor whose second-to-last axis holds the flattened grid and
            whose last axis holds channels.
        x: size of the first spatial axis. Defaults to
            ``int(src.size(-2) ** .5)``, i.e. a square grid is presumably
            intended.
    """
    if x is None:
        x = int(src.size(-2) ** (.5))
    return rearrange(src, "... (x y) c -> ... c x y", x=x)

def ch_sec(x):
    """Flatten a channels-first grid into a channels-last token sequence.

    Rearranges ``x`` from ``(..., c, x, y)`` to ``(..., x*y, c)``; this is the
    reverse of the pattern used by ``ch_fst``.
    """
    pattern = "... c x y -> ... (x y) c"
    return rearrange(x, pattern)

def wandb_summary(vis_dict, prefix="",suffix=""):
    """Tile known visualization tensors into image grids and log them to wandb.

    Recognised keys of ``vis_dict`` that are present are each turned into one
    grid with ``torchvision.utils.make_grid`` and logged as a ``wandb.Image``;
    any other keys are ignored.

    Args:
        vis_dict: mapping from visualization name to an image tensor, presumably
            (B, C, H, W) -- TODO confirm against callers. Handled keys:
            ``"sampled_gen"`` / ``"model_est"`` are mapped with ``x * .5 + .5``
            (i.e. [-1, 1] -> [0, 1], presumably the model output range) and
            logged under ``est/``; ``"imgs_raw"`` / ``"noisy_imgs"`` are
            min-max normalized per image by ``make_grid`` and logged under
            ``ref/``.
        prefix: string prepended to every logged key.
        suffix: string appended to every logged key.
    """
    # (key in vis_dict, logged key, normalize each image inside make_grid?)
    # Order matches the order in which entries are added to the log dict.
    grid_specs = (
        ("sampled_gen", "est/sampled_gen", False),
        ("model_est", "est/model_est", False),
        ("imgs_raw", "ref/imgs_raw", True),
        ("noisy_imgs", "ref/noisy_imgs", True),
    )

    wandb_out = {}
    for src_key, out_key, per_image_norm in grid_specs:
        if src_key not in vis_dict:
            continue
        # Detach before moving to CPU so the host copy is not autograd-tracked.
        imgs = vis_dict[src_key].detach().cpu()
        if per_image_norm:
            wandb_out[out_key] = make_grid(imgs, normalize=True, scale_each=True)
        else:
            wandb_out[out_key] = make_grid(imgs * .5 + .5, normalize=False)

    # make_grid yields (C, H, W); permute to channels-last for wandb.Image and
    # clip so out-of-range "est" values do not wrap or saturate oddly.
    # NOTE(review): wandb.log is called even when no key matched (empty dict),
    # exactly as before -- confirm whether callers rely on that step advance.
    wandb.log({
        prefix + k + suffix: wandb.Image(v.permute(1, 2, 0).float().clip(0, 1).numpy())
        for k, v in wandb_out.items()
    })
