import os,io,shutil
import geometry
import wandb
from matplotlib import cm
import cv2
from tqdm import tqdm
import torchvision
import time
from torchvision.utils import make_grid,draw_keypoints
import torch.nn.functional as F
import kornia
import numpy as np
import torch
import flow_vis
import flow_vis_torch
import matplotlib.pyplot as plt
from einops import rearrange, repeat
import models
import piqa
import imageio
from PIL import Image
#import splines.quaternion
#from torchcubicspline import (natural_cubic_spline_coeffs, NaturalCubicSpline)
from scipy import spatial
import plotly.express as px
import plotly.graph_objects as go
from collections import defaultdict
import viser.transforms as tf

def ch_fst(src, x=None):
    """Rearrange flattened channels-last tokens into a channels-first grid.

    ``(..., x*y, c) -> (..., c, x, y)``. When *x* is not given, the token grid is
    assumed to be square and its side is ``int(src.size(-2) ** .5)``; ``src.size``
    is only called in that case (so *src* must then expose a torch-style
    ``size(dim)`` method).
    """
    side = x
    if side is None:
        side = int(src.size(-2) ** .5)
    return rearrange(src, "... (x y) c -> ... c x y", x=side)


def ch_sec(x):
    """Flatten a channels-first grid into channels-last tokens: ``(..., c, x, y) -> (..., x*y, c)``."""
    return rearrange(x, "... c x y -> ... (x y) c")

def _traj_axis_limits(trajs, lows=(.5, -.4, .5), highs=(1.3, .4, 1.5)):
    """Return per-axis ``(lo, hi)`` plot limits that cover every trajectory.

    Each axis starts from the default window ``(lows[i], highs[i])`` and is widened
    to include the min/max coordinate of that axis across *all* trajectories, so
    no curve is clipped. Empty trajectories are ignored.

    Args:
        trajs: sequence of array-likes indexed as ``traj[..., axis]``.
        lows, highs: default lower/upper bound per axis (x, y, z).

    Returns:
        list of ``(lo, hi)`` float tuples, one per axis.
    """
    limits = []
    for axis, (lo, hi) in enumerate(zip(lows, highs)):
        coords = [np.asarray(t)[..., axis] for t in trajs]
        coords = [c for c in coords if c.size]  # skip empty trajectories
        if coords:
            lo = min(lo, min(float(c.min()) for c in coords))
            hi = max(hi, max(float(c.max()) for c in coords))
        limits.append((lo, hi))
    return limits


def _render_gripper_traj_plot(model_input, model_output):
    """Plot GT vs. predicted gripper trajectories in 3D and return the plot as a (3, H, W) tensor.

    Uses the first batch element of ``model_input["gripper_traj"]`` and
    ``model_output["pred_traj"]``. Columns 0, 1 and 2 are plotted as X, Y and Z.
    The figure is rendered to an in-memory PNG, so nothing is written to disk.
    """
    # detach(): pred_traj may require grad, and matplotlib's numpy conversion would raise.
    trajs = [
        ("GT_Gripper_Traj", model_input["gripper_traj"][0].detach().float().cpu().numpy()),
        ("Pred_Traj", model_output["pred_traj"][0].detach().float().cpu().numpy()),
    ]
    fig = plt.figure(dpi=100)
    try:
        ax = fig.add_subplot(111, projection='3d')
        for (label, traj), color in zip(trajs, ("red", "blue", "green")):
            ax.plot(traj[:, 0], traj[:, 1], traj[:, 2], c=color, label=label)
        # Limits are computed over all curves at once; setting them per curve
        # would let the last trajectory clip the earlier ones.
        x_lim, y_lim, z_lim = _traj_axis_limits([traj for _, traj in trajs])
        ax.set_xlim(x_lim)
        ax.set_ylim(y_lim)
        ax.set_zlim(z_lim)
        ax.set_xlabel("X")
        ax.set_ylabel("Y")
        ax.set_zlabel("Z")
        ax.set_title("GT_Gripper_Traj vs Pred_Traj")
        ax.legend()
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format="png")
        buf.seek(0)
        img_arr = plt.imread(buf, format="png")
    finally:
        plt.close(fig)  # always release the figure, even on failure
    return torch.from_numpy(np.ascontiguousarray(img_arr[..., :3])).permute(2, 0, 1)


def wandb_summary(loss, model_output, model_input, ground_truth, resolution, prefix="", suffix="", step=0, losses_agg=None):
    """Log reference images and a 3D gripper-trajectory plot to wandb.

    Every input image entry (key containing ``"img"``) is tiled with ``make_grid``.
    A 3D plot of the GT trajectory (``model_input["gripper_traj"]``) and the
    predicted trajectory (``model_output["pred_traj"]``) is added. Everything is
    logged in one ``wandb.log`` call under ``prefix + key``.

    Args:
        loss, ground_truth, resolution, suffix, step, losses_agg: not used by this
            function. They are kept for call-site compatibility (``ground_truth``
            is only batch-sliced).
        model_output, model_input, ground_truth: dicts whose values expose
            ``.shape``. Non-scalar values are sliced to their first batch element.
        prefix: string prepended to every logged key.

    Returns:
        dict mapping ``"ref/<name>"`` to the (C, H, W) tensors that were logged.
    """
    # Keep only the first batch element of each non-scalar entry to bound logging cost.
    model_output, model_input, ground_truth = [
        {k: (v[:1] if len(v.shape) else v) for k, v in x.items()}
        for x in (model_output, model_input, ground_truth)
    ]

    wandb_out = {}

    # Log images. flatten(0, 1) merges the two leading dims before tiling.
    # Presumably (batch, n_views, C, H, W) -- confirm against callers.
    for k, v in model_input.items():
        if "img" in k:
            wandb_out["ref/" + k] = make_grid(v.flatten(0, 1))

    # The trajectory plot is best-effort: a plotting failure must not abort training.
    try:
        wandb_out["ref/gripper_traj"] = _render_gripper_traj_plot(model_input, model_output)
    except Exception as err:
        print(f"failed saving/logging 3d plot image: {err!r}")

    wandb.log({
        prefix + k: wandb.Image(v.permute(1, 2, 0).float().detach().clip(0, 1).cpu().numpy())
        for k, v in wandb_out.items()
    })
    return wandb_out
