import torch

from glob import glob
from matplotlib import cm
import os,sys,shutil
import matplotlib.pyplot as plt 
from pathlib import Path
from torch.nn import functional as F
from PIL import Image
from tqdm import tqdm
import torchvision
import flow_vis_torch
from tqdm import tqdm
import cv2

import numpy as np

# Inference-only script: disable autograd once, globally.
# (A second, redundant set_grad_enabled(False) call was removed.)
torch.set_grad_enabled(False)

import argparse
parser = argparse.ArgumentParser()
parser.add_argument('-i', '--input_dir', type=str, default="", required=False, help="rgb files")
parser.add_argument('-o', '--output_dir', type=str, default="", required=False, help="where to save files")
parser.add_argument('--n_skip', type=int, default=0, help="Number of frames to skip between adjacent frames in dataloader. ")
args = parser.parse_args()

# NOTE(review): imgdir, args.output_dir and args.n_skip are not used anywhere
# in the visible script -- the evaluation below reads hard-coded paths.
imgdir = Path(args.input_dir)

import sys;sys.path.append("/home/cameronsmith/repos/multivid_point_track_sfm")
import models,data,geometry
def to_gpu(ob, device="cuda"):
    """Recursively convert ``ob`` to torch tensors on ``device``.

    Dicts (including nested dicts) are traversed and a new dict with the same
    keys is returned. Every other value becomes a tensor: existing tensors are
    moved with ``.to(device)``; anything else (arrays, lists, scalars) is
    copied with ``torch.tensor`` first.

    Args:
        ob: a dict of tensor-like values, or a single tensor-like value.
        device: target device; defaults to "cuda" (the previous behavior).

    Returns:
        The same structure with every leaf as a tensor on ``device``.
    """
    if isinstance(ob, dict):
        # Recurse on the raw value: the old ``to_gpu(torch.tensor(v))`` crashed
        # on nested dicts and needlessly copied values that were already tensors.
        return {k: to_gpu(v, device) for k, v in ob.items()}
    if not isinstance(ob, torch.Tensor):
        ob = torch.tensor(ob)
    return ob.to(device)

model = models.FlowMap().cuda()
ckpt_file = "/tmp/dog_testing_scratchfpn/checkpoint.pt"
# map_location="cpu" makes loading independent of the GPU the checkpoint was
# saved from; load_state_dict then copies the weights into the CUDA params.
# NOTE: torch.load unpickles the file -- only load trusted checkpoints.
state_dict = torch.load(ckpt_file, map_location="cpu")["model_state_dict"]
# strict=False tolerates key mismatches silently; report them so that a
# partially-initialized model does not go unnoticed.
incompatible = model.load_state_dict(state_dict, strict=False)
if incompatible.missing_keys or incompatible.unexpected_keys:
    print("checkpoint key mismatch: missing=%s unexpected=%s"
          % (incompatible.missing_keys, incompatible.unexpected_keys))
del state_dict
# Name of the directory containing the checkpoint (presumably the run name).
model_name = Path(ckpt_file).parent.name
# NOTE(review): the model is never switched to eval mode; if FlowMap uses
# dropout/batch-norm, model.eval() is probably wanted here -- confirm.

# Root folder with one sub-directory per dog sequence, each holding the frames
# read by data.ImageFolder and a ground-truth "poses.pt".
data_root = "/data/cameron/monocular_ests/pets_dogs"
n_cluster = 3  # number of pose hypotheses produced by cluster_and_represent

errs_all, num_est = [], 0
# sorted(): glob() returns entries in arbitrary filesystem order, which made
# the "first 100 sequences" subset non-reproducible.
for dog_dir in tqdm(sorted(glob(data_root + "/*"))[:100]):
    print("num est ", num_est)
    dogname = os.path.basename(dog_dir)
    try:
        dataset = data.ImageFolder(num_trgt=30, path="%s/%s/" % (data_root, dogname))
        batch = to_gpu(dataset[0][0])
        batch = {k: v[None] for k, v in batch.items()}  # prepend a batch dim
        # Keep the first T GT poses (T = batch["rgb"].size(1)) and take
        # [:3, -1] of each -- presumably the camera translation; confirm.
        gt_pos = torch.load("%s/%s/poses.pt" % (data_root, dogname),
                            map_location="cpu")[:batch["rgb"].size(1), :3, -1]
    except Exception as err:
        # Skip the sequence entirely. Previously execution fell through and
        # re-scored the previous dog's batch/gt_pos (or hit a NameError).
        print("skipping ", dogname, repr(err))
        continue

    out = model.forward_allpts(batch)

    pose_clusters, _ = geometry.cluster_and_represent(
        out["poses_all"][0], n_clusters=n_cluster, return_labels=True)  # cluster poses
    # Score each cluster's translation track against the Procrustes-aligned GT
    # and keep the best-matching hypothesis for this sequence.
    errs = []
    for pos in pose_clusters[:, :, :3, -1].detach().cpu():
        gt_pos_aligned = geometry.numpy_procrustes(pos, gt_pos)[1]
        errs.append((gt_pos_aligned - pos).square().mean())
    errs_all.append(min(errs))
    num_est += 1

if num_est:
    print(sum(errs_all) / num_est)
else:
    print("no sequences evaluated")
