import torch

from glob import glob
from matplotlib import cm
import os,sys,shutil
import matplotlib.pyplot as plt 
from pathlib import Path
from torch.nn import functional as F
from PIL import Image
from tqdm import tqdm
import torchvision
import flow_vis_torch
from tqdm import tqdm
import cv2

import numpy as np

# Evaluation-only script: disable autograd once, globally.
# (The original called this a second time after argument parsing; one call suffices.)
torch.set_grad_enabled(False)

import argparse
parser = argparse.ArgumentParser()
parser.add_argument('-i', '--input_dir', type=str, default="", required=False, help="rgb files")
parser.add_argument('-o', '--output_dir', type=str, default="", required=False, help="where to save files")
parser.add_argument('--n_skip', type=int, default=0, help="Number of frames to skip between adjacent frames in dataloader. ")
args = parser.parse_args()

# NOTE(review): in the visible code neither `args` nor `imgdir` is read by the
# evaluation loop (its data paths are hard-coded); kept for compatibility.
imgdir = Path(args.input_dir)

# Make the project's packages importable from its checkout location.
import sys;sys.path.append("/home/cameronsmith/repos/multivid_point_track_sfm")
import models,data,geometry
def to_gpu(ob):
    """Move ``ob`` onto the GPU via its ``.cuda()`` method.

    A dict is processed value by value and a new dict is returned:
    nested dicts are recursed into, tensors are moved directly, and any
    other value (list, scalar, array, ...) is first converted with
    ``torch.tensor``. A non-dict argument must itself provide ``.cuda()``.
    """
    if not isinstance(ob, dict):
        return ob.cuda()
    moved = {}
    for key, value in ob.items():
        # Only leaves need tensor conversion; wrapping a dict in torch.tensor
        # would raise, and re-wrapping a tensor needlessly copies (with a warning).
        if not isinstance(value, dict) and not torch.is_tensor(value):
            value = torch.tensor(value)
        moved[key] = to_gpu(value)
    return moved

# Compare MegaSAM camera trajectories with ground-truth poses.
# For each "<name>_droid.npz" prediction, take the translation column
# ([:3, -1]) of every cam_c2w matrix, take the matching translations from
# "<GT_DIR>/<name>/poses.pt", align them with geometry.numpy_procrustes, and
# record the mean of the squared coordinate differences. The average over all
# sequences is printed at the end.
PRED_DIR = Path("/home/cameronsmith/repos/mega-sam/outputs")
GT_DIR = Path("/data/cameron/monocular_ests/pets_dogs")
PRED_SUFFIX = "_droid.npz"

errs_all, num_est = [], 0
for pred_path in sorted(PRED_DIR.glob("*" + PRED_SUFFIX)):
    # Strip only the trailing suffix so the name always corresponds to the
    # file that was actually globbed (split("_droid") broke on repeated "_droid").
    dogname = pred_path.name[: -len(PRED_SUFFIX)]
    print("num est ", num_est)
    # Load the globbed file itself; the context manager closes the npz archive.
    with np.load(pred_path) as pred:
        pos = torch.from_numpy(pred["cam_c2w"])[:, :3, -1]
    gt_pos = torch.load(GT_DIR / dogname / "poses.pt")[:, :3, -1]
    # Compare only the frames both trajectories have; previously only the GT
    # was truncated, so a shorter GT produced a shape mismatch.
    n_frames = min(len(pos), len(gt_pos))
    pos, gt_pos = pos[:n_frames], gt_pos[:n_frames]
    # NOTE(review): element [1] is used as the GT positions aligned onto `pos`
    # — confirm against geometry.numpy_procrustes's return order.
    gt_pos_aligned = geometry.numpy_procrustes(pos, gt_pos)[1]
    err = (gt_pos_aligned - pos).square().mean()
    errs_all.append(err)
    print(dogname, err)
    num_est += 1

if num_est:
    print(sum(errs_all) / num_est)
else:
    # Avoid ZeroDivisionError when nothing matched the glob.
    print("no predictions matching *%s found in %s" % (PRED_SUFFIX, PRED_DIR))
