import os,io,shutil
import wandb
from matplotlib import cm
import cv2
from tqdm import tqdm
import torchvision
import time
import torch.nn.functional as F
import numpy as np
import torch
import matplotlib.pyplot as plt
from einops import rearrange, repeat
import piqa
import imageio
from PIL import Image
from collections import defaultdict
import data

def ch_fst(src, x=None):
    """Unflatten the spatial axis and move channels first: ``... (x y) c -> ... c x y``.

    If ``x`` is not given, the flattened axis (``src.size(-2)``) is treated as a
    square grid and ``x`` is the integer part of its square root.
    """
    side = x
    if side is None:
        side = int(src.size(-2) ** 0.5)
    return rearrange(src, "... (x y) c -> ... c x y", x=side)


def ch_sec(x):
    """Flatten the two spatial axes and move channels last: ``... c x y -> ... (x y) c``."""
    return rearrange(x, "... c x y -> ... (x y) c")

def _figure_to_chw(fig):
    """Rasterise ``fig`` to an RGB float tensor laid out as (3, H, W).

    The figure is round-tripped through an in-memory PNG, so nothing is written
    to disk; ``plt.imread`` decodes PNG data to floats in [0, 1].
    """
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    buf.seek(0)
    img_arr = plt.imread(buf)
    # Keep RGB only (drop any alpha channel) and move channels first.
    return torch.from_numpy(img_arr[..., :3]).permute(2, 0, 1)


def wandb_summary(loss, model_output, model_input, ground_truth, resolution, prefix="", suffix="", step=0, losses_agg=None, dont_log=False):
    """Plot GT vs. predicted labels of a point set and log the plots to wandb.

    Two scatter plots are drawn over the same points ``model_input["x"]``: one
    coloured by ``model_input["y"]`` ("GT") and one by ``model_output["y"]``
    ("Pred"). A point is red when its value is > 0.5 and blue otherwise.

    Args:
        loss, ground_truth, resolution, suffix, step, losses_agg: not used by
            this function; accepted so existing call sites keep working.
        model_output: mapping with a tensor under "y" (predictions).
        model_input: mapping with tensors under "x" (point coordinates) and
            "y" (ground-truth values).
        prefix: string prepended to every logged key.
        dont_log: if True, build the images but do not call ``wandb.log``.

    Returns:
        dict mapping ``prefix + "regression_GT"`` / ``prefix + "regression_Pred"``
        to ``wandb.Image`` objects. A plot that fails to render is reported on
        stdout and omitted rather than raised.
    """
    wandb_out = {}

    # Coordinates shared by both plots. detach() allows inputs that require
    # grad; .T splits the columns into scatter's positional (x, y) arguments.
    # NOTE(review): assumes "x" squeezes to (N, 2) - a third column would be
    # consumed by scatter as the marker size.
    points = model_input["x"].detach().squeeze(1).cpu().numpy().T

    for label, src in [("GT", model_input), ("Pred", model_output)]:
        # Move to host once so that, for device tensors, the per-element
        # comparisons below don't each force a separate synchronisation.
        values = src["y"].detach().squeeze(-1).squeeze(-1).cpu()
        colors = ["red" if v > .5 else "blue" for v in values]

        fig = plt.figure()
        try:
            plt.scatter(*points, c=colors)
            try:
                wandb_out["regression_%s" % label] = _figure_to_chw(fig)
            except Exception as err:
                # Best-effort: a failed plot render should not abort the caller.
                print("failed saving/logging plot image: %r" % (err,))
        finally:
            # Always release the figure, even if scatter() raised.
            plt.close(fig)

    for k, v in wandb_out.items():
        print(k, v.shape)
    # wandb.Image expects HWC; clamp to the valid [0, 1] float range.
    wandb_imgdict = {
        prefix + k: wandb.Image(v.permute(1, 2, 0).float().detach().clip(0, 1).cpu().numpy())
        for k, v in wandb_out.items()
    }
    if not dont_log:
        wandb.log(wandb_imgdict)
    print("done logging images")
    return wandb_imgdict
