import torch
from torchvision.utils import make_grid,draw_keypoints

from glob import glob
from matplotlib import cm
from einops import rearrange, repeat
import os,sys,shutil
import matplotlib.pyplot as plt 
from pathlib import Path
from torch.nn import functional as F
from PIL import Image
from tqdm import tqdm
import torchvision
import flow_vis_torch
import cv2

import numpy as np

# Inference-only script: disable autograd globally (one call is sufficient).
torch.set_grad_enabled(False)

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('-i', '--input_dir', type=str, default="", required=False, help="rgb files")
parser.add_argument('-o', '--output_dir', type=str, default="", required=False, help="where to save files")
parser.add_argument('--n_skip', type=int, default=0, help="Number of frames to skip between adjacent frames in dataloader. ")
args = parser.parse_args()

# NOTE(review): imgdir is not read anywhere else in this script; the input
# image path is hard-coded further down.
imgdir = Path(args.input_dir)

# Make the project-local `models` package importable (hard-coded checkout path).
sys.path.append("/home/cameronsmith/repos/multivid_point_track_sfm")
import models

model = models.FlowMap().cuda()
# Alternative checkpoint used previously: "/tmp/dog_norigablation_nohydrantpretrain/checkpoint.pt"
ckpt_file = "/data/cameron/checkpoints/dogs_from_hydrant/checkpoint.pt"

# Load tensors onto the CPU first so a checkpoint saved on another GPU index
# still loads; load_state_dict copies the values into the model's CUDA params.
checkpoint = torch.load(ckpt_file, map_location="cpu")
load_result = model.load_state_dict(checkpoint["model_state_dict"], strict=False)
# strict=False silently skips mismatched keys; report them so a partial load is visible.
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        f"[load_state_dict] {len(load_result.missing_keys)} missing / "
        f"{len(load_result.unexpected_keys)} unexpected keys (strict=False)"
    )

# Name output files after the run directory that holds the checkpoint.
# The previous split("/")[2] only matched "/tmp/<run>/checkpoint.pt" and
# yielded "cameron" for the path above.
model_name = Path(ckpt_file).parent.name

# An earlier version looped over glob("/data/pets_co3d/dog/*/images/frame000001.jpg")[:100].

# Honour --output_dir when given; otherwise fall back to the original scratch location.
outdir = Path(args.output_dir) if args.output_dir else Path("/data/tmp/cameron_mlprodepthests/tmp")
# plt.imsave does not create parent directories.
outdir.mkdir(parents=True, exist_ok=True)

# Target (height, width) the network input is resized to.
imsize = [76, 134]
# Alternative test image: "/data/pets_co3d/dog/1053_46223_41938/images/frame000001.jpg"
path = "/home/cameronsmith/babka.jpg"


def _load_image_pm1(img_path, size):
    """Read an image with matplotlib and return a (1, 3, H, W) float tensor in [-1, 1].

    plt.imread returns integer pixels for Pillow-decoded formats (e.g. JPEG) but
    float pixels already in [0, 1] for PNG, so integers are scaled by their dtype
    maximum and floats are used as-is. Grayscale inputs are replicated to three
    channels and any alpha channel is dropped. The result is resized to ``size``
    with F.interpolate's default (nearest) mode.
    """
    pixels = np.asarray(plt.imread(img_path))
    if pixels.ndim == 2:
        pixels = pixels[..., None]
    if pixels.shape[-1] in (1, 2):  # grayscale, optionally with alpha
        pixels = np.repeat(pixels[..., :1], 3, axis=-1)
    pixels = pixels[..., :3]  # drop alpha (RGBA -> RGB)
    if np.issubdtype(pixels.dtype, np.integer):
        pixels = pixels.astype(np.float32) / np.iinfo(pixels.dtype).max
    else:
        pixels = pixels.astype(np.float32)
    tensor = torch.from_numpy(pixels).permute(2, 0, 1)[None] * 2 - 1
    return F.interpolate(tensor, size)


images = _load_image_pm1(path, imsize)

# Undo the [-1, 1] mapping from load time, so the encoder sees pixels in [0, 1].
encoder_input = images.cuda() * 0.5 + 0.5
# Encode, then bilinearly resize the feature map to imsize.
fmap_out = F.interpolate(model.img_enc(encoder_input), imsize, mode="bilinear")

# Depth head: softplus(raw + 1) is positive, so the final "+ 1" keeps depth > 1.
depth_logits = model.depth_conv(fmap_out) + 1
depth = F.softplus(depth_logits)[0, 0].cpu() + 1

# Affinity head: per-pixel embedding, L2-normalised along the channel dim (dim=1).
affinity_emb_unnorm = model.affinities_conv(fmap_out)
affinity_emb = F.normalize(affinity_emb_unnorm, dim=1)

# PCA of the per-pixel affinity embeddings, so the leading components can be
# rendered as RGB images.
features = affinity_emb
B, C, H, W = features.shape
# Flatten the spatial dims to (B, C, H*W); reshape also accepts non-contiguous input.
features = features.reshape(B, C, H * W)
# Centre each channel over all pixels.
features_mean = features.mean(dim=2, keepdim=True)
features = features - features_mean
# Sample covariance across channels: (B, C, C).
covariance = torch.bmm(features, features.transpose(1, 2)) / (H * W - 1)
# torch.svd is deprecated; torch.linalg.svd returns (U, S, Vh) with S in
# descending order, so the leading columns of U are the top principal axes.
U, S, Vh = torch.linalg.svd(covariance)
# Project the centred data onto the top principal components: (B, k, H*W).
num_components = 6
transformed_features = torch.bmm(U[:, :, :num_components].transpose(1, 2), features)

def _components_to_grid(components):
    """Fold flattened (1, c, H*W) components back to (1, c, H, W) and grid them.

    make_grid(normalize=True) rescales the values to [0, 1] for display.
    """
    spatial = rearrange(components.detach(), "1 c (x y) -> 1 c x y", x=imsize[0], y=imsize[1])
    return make_grid(spatial, normalize=True)


def _save_vis(suffix, array):
    """Write ``array`` as an image to ``outdir / (model_name + suffix)``."""
    plt.imsave(outdir / (model_name + suffix), array)


# First three and remaining principal components, shown as RGB images.
aff_img = _components_to_grid(transformed_features[:, :3])
aff_img2 = _components_to_grid(transformed_features[:, 3:])

# Inverse depth (offset by 0.1) for visualisation.
_save_vis("_depth_vis.png", 1 / (1e-1 + depth))
_save_vis("_aff_vis.png", aff_img.permute(1, 2, 0).cpu().numpy())
_save_vis("_aff_vis2.png", aff_img2.permute(1, 2, 0).cpu().numpy())
# Resized network input, mapped from [-1, 1] back to [0, 1].
_save_vis("lowresimgvis.png", images[0].permute(1, 2, 0).clip(-1, 1).numpy() * .5 + .5)
print(outdir / (model_name + "_depth_vis.png"))
