import os,io,shutil
import geometry
import wandb
from matplotlib import cm
import cv2
from tqdm import tqdm
import torchvision
import time
from torchvision.utils import make_grid,draw_keypoints
import torch.nn.functional as F
import kornia
import numpy as np
import torch
import flow_vis
import flow_vis_torch
import matplotlib.pyplot as plt
from einops import rearrange, repeat
import models
import piqa
import imageio
from PIL import Image
#import splines.quaternion
#from torchcubicspline import (natural_cubic_spline_coeffs, NaturalCubicSpline)
from scipy import spatial
import plotly.express as px
import plotly.graph_objects as go
from collections import defaultdict
import geometry
import data

def ch_fst(src, x=None):
    """Unflatten a flat spatial axis into a channel-first grid.

    Reshapes ``(..., x*y, c)`` into ``(..., c, x, y)``. When *x* is omitted the
    grid is assumed square and ``x`` is taken as ``int(sqrt(src.size(-2)))``.
    """
    rows = x if x is not None else int(src.size(-2) ** 0.5)
    return rearrange(src, "... (x y) c -> ... c x y", x=rows)


def ch_sec(x):
    """Flatten a channel-first grid back into a token axis: ``(..., c, x, y) -> (..., x*y, c)``."""
    return rearrange(x, "... c x y -> ... (x y) c")

def _pose_from_pred_rot(pred_rot, pred_trans):
    """Build predicted triangle vertices from a rotation prediction plus translation.

    ``pred_rot`` with last dim 6 is converted via ``geometry.rotation_6d_to_matrix``
    and then to a quaternion; last dim 4 is used as-is. Any other size raises
    ``ValueError``. The original code hit an unbound local in that case.
    """
    if pred_rot.size(-1) == 6:
        rot = kornia.geometry.conversions.rotation_matrix_to_quaternion(geometry.rotation_6d_to_matrix(pred_rot))
    elif pred_rot.size(-1) == 4:
        rot = pred_rot
    else:
        raise ValueError(
            f"pred_rot must have last dim 6 (6D rotation) or 4 (quaternion), got shape {tuple(pred_rot.shape)}"
        )
    return geometry.rot_trans_to_tris(rot.squeeze(1), pred_trans.squeeze(1))[0]


def _ensure_pose_outputs(model_output):
    """Fill ``pred_tris``/``transf`` (and ``rot``/``trans`` if missing) on *model_output*; return ``transf``.

    Mutates *model_output* in place (``wandb_summary`` passes its own local copy).
    ``transf`` is always derived from ``pred_tris``. The original code left it
    unbound whenever ``rot`` was already present, crashing at render time.
    """
    if "pred_tris" not in model_output:
        model_output["pred_tris"] = _pose_from_pred_rot(model_output["pred_rot"], model_output["pred_trans"])

    transf = geometry.tris_to_transf(model_output["pred_tris"])[0][None]
    model_output["transf"] = transf
    if "rot" not in model_output:
        model_output["rot"] = kornia.geometry.conversions.rotation_matrix_to_quaternion(transf[:, :3, :3])
        model_output["trans"] = transf[..., :3, -1]
    return transf


def _log_pose_metrics(model_output, model_input, prefix, step):
    """Log mean squared errors of rotation, translation and triangle vertices to wandb at *step*."""
    # NOTE(review): if "rot" holds quaternions, q and -q are the same rotation,
    # so this plain MSE can be large for a correct prediction. Consider a
    # sign-invariant metric.
    wandb.log(
        {
            # Previously these two values were swapped between the keys.
            # The key strings (including the missing "_" in rot_metric) are
            # kept unchanged for dashboard continuity.
            "metrics/trans_metric_" + prefix: (model_output["trans"] - model_input["trans"]).square().mean(),
            "metrics/rot_metric" + prefix: (model_output["rot"] - model_input["rot"]).square().mean(),
            "metrics/tri_vis_metric_" + prefix: (model_output["pred_tris"] - model_input["transf_tris"]).square().mean(),
        },
        step=step,
    )


def _render_tri_plot(unit_tri, pred_tri, gt_tri):
    """Plot unit (red), predicted (blue) and GT (green) triangles in 3D; return the RGB image as a channel-first tensor."""
    fig = plt.figure(figsize=(10, 10))
    try:
        ax = fig.add_subplot(111, projection='3d')
        ax.plot_trisurf(unit_tri[:, 0], unit_tri[:, 1], unit_tri[:, 2], color='red', alpha=0.5)
        ax.plot_trisurf(pred_tri[:, 0], pred_tri[:, 1], pred_tri[:, 2], color='blue', alpha=0.5)
        ax.plot_trisurf(gt_tri[:, 0], gt_tri[:, 1], gt_tri[:, 2], color='green', alpha=0.5)
        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')
        ax.set_title('Unit (Red) and Pred (Blue) and GT Transformed Tri (Green)')
        ax.set_box_aspect([1, 1, 1])
        # Round-trip through an in-memory PNG rather than a fixed temp file on disk,
        # which depended on "output/img/" existing and could collide between runs.
        buf = io.BytesIO()
        fig.savefig(buf, format="png")
        buf.seek(0)
        img_arr = plt.imread(buf)
    finally:
        plt.close(fig)
    # Drop the alpha channel and move channels first.
    return torch.from_numpy(img_arr[..., :3]).permute(2, 0, 1)


def wandb_summary(loss, model_output, model_input, ground_truth, resolution, prefix="", suffix="", step=0, losses_agg=None):
    """Log pose metrics and diagnostic images for the first batch element to wandb.

    Only the first element of every non-scalar tensor in the three input dicts
    is used (``v[:1]``); the caller's dicts are not mutated.

    Logs:
      * ``metrics/*`` MSE values for translation, rotation and triangle vertices (at ``step``);
      * every ``model_input`` entry whose key contains ``"img"`` as an image grid;
      * a 3D plot of unit / predicted / ground-truth triangles (best-effort);
      * ``data.render_cam`` of the predicted transform.

    ``loss``, ``resolution``, ``suffix`` and ``losses_agg`` are unused; they are
    kept for call-site compatibility. ``ground_truth`` is sliced but otherwise unused.

    Returns:
        dict mapping log key (without ``prefix``) to the channel-first image tensor that was logged.

    Raises:
        ValueError: if ``pred_tris`` must be derived and ``pred_rot`` has an unsupported last dim.
    """
    model_output, model_input, ground_truth = [
        {k: (v[:1] if len(v.shape) else v) for k, v in d.items()}
        for d in (model_output, model_input, ground_truth)
    ]

    wandb_out = {}

    # Reference images from the input batch.
    for k, v in model_input.items():
        if "img" in k:
            wandb_out["ref/" + k] = make_grid(v.flatten(0, 1))

    transf = _ensure_pose_outputs(model_output)
    _log_pose_metrics(model_output, model_input, prefix, step)

    # Homogeneous unit triangle; only the first three columns are plotted.
    unit_tri = np.array([[0, 0, 0, 1], [1, 0, 0, 1], [0, 1, 0, 1]], dtype=np.float32)
    pred_tri = model_output["pred_tris"].detach().cpu().numpy()[0]
    gt_tri = model_input["transf_tris"].detach().cpu().numpy()[0]
    try:
        wandb_out["ref/tri_pred_tris"] = _render_tri_plot(unit_tri, pred_tri, gt_tri)
        print("logging images", len(wandb_out))
    except Exception as err:  # best-effort: a failed plot must not abort logging
        print("failed saving/logging 3d plot image", err)

    # Re-render the mesh with the predicted camera; detach so .numpy() works on grad-tracking tensors.
    wandb_out["est/cam_rerender_est"] = data.render_cam(transf[0].detach().cpu().numpy()).permute(2, 0, 1)

    for k, v in wandb_out.items():
        print(k, v.shape)
    # NOTE(review): deliberately no step= here, matching the original. Per the
    # wandb docs, log() with an explicit step defaults to commit=False, so this
    # call is what commits the row. Confirm before adding step=.
    wandb.log({prefix + k: wandb.Image(v.permute(1, 2, 0).float().detach().clip(0, 1).cpu().numpy()) for k, v in wandb_out.items()})
    print("done logging images")
    return wandb_out
