import torch

from glob import glob
from matplotlib import cm
import os,sys,shutil
import matplotlib.pyplot as plt 
from pathlib import Path
from torch.nn import functional as F
from PIL import Image
from tqdm import tqdm
import torchvision
import flow_vis_torch
import cv2

import numpy as np

# Evaluation only: disable autograd globally for the whole script.
torch.set_grad_enabled(False)

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "-i", "--input_dir", type=str, default="", required=False, help="rgb files"
)
parser.add_argument(
    "-o", "--output_dir", type=str, default="", required=False, help="where to save files"
)
parser.add_argument(
    "--n_skip",
    type=int,
    default=0,
    help="Number of frames to skip between adjacent frames in dataloader. ",
)
args = parser.parse_args()

# NOTE(review): `args` and `imgdir` are not read by the evaluation loop in
# this file, which uses hard-coded dataset paths -- confirm they are needed.
imgdir = Path(args.input_dir)

# Make the project's `models` package importable; this checkout path is
# machine-specific.
sys.path.append("/home/cameronsmith/repos/multivid_point_track_sfm")
import models  # NOTE(review): not referenced below; confirm it is still needed.

# Each name selects a "<model_name>_depth_ests.pt" file inside every video's
# output folder.
model_names = [
    "dog_norigablation",
    "dog_norigablation_nohydrantpretrain",
    "dog_juststatic_ablation",
    "dogs_from_hydrant",
    "dog_testing_scratchfpn",
]

# Additive offset applied to depth maps before fitting and plotting. In the
# inverse-depth plot it keeps 1/depth finite where depth is 0; in the affine
# fit a constant offset is absorbed by the fitted shift.
DEPTH_EPS = 1e-1
# Maximum number of videos evaluated per model.
MAX_VIDEOS = 100
# Per-video folders holding the saved depth tensors and comparison images.
EST_ROOT = Path("/data/tmp/cameron_mlprodepthests/")


def _affine_aligned_mse(target, estimate):
    """Return the MSE between ``target`` and ``estimate`` after an affine fit.

    ``estimate`` is mapped to ``scale * estimate + shift``, where scale and
    shift are the least-squares solution (``torch.linalg.lstsq``) that best
    matches ``target``. Both tensors are flattened, so they only need the same
    number of elements. Returns a Python float.
    """
    target = target.reshape(-1)
    estimate = estimate.reshape(-1)
    # Columns [estimate, 1] -> solution rows are [scale, shift].
    design = torch.stack([estimate, torch.ones_like(estimate)], dim=1)
    scale, shift = torch.linalg.lstsq(design, target[:, None]).solution[:, 0]
    return (target - (scale * estimate + shift)).square().mean().item()


# Same first frame of each video for every model. sorted(): glob order is
# filesystem-dependent, so without it the evaluated subset is not reproducible.
frame_paths = sorted(glob("/data/pets_co3d/dog/*/images/frame000001.jpg"))[:MAX_VIDEOS]

for model_name in model_names:
    err_all = 0.0
    n_est = 0
    for img_path in frame_paths:
        # .../dog/<vidname>/images/frame000001.jpg -> <vidname>
        vidname = Path(img_path).parents[1].name
        outdir = EST_ROOT / vidname

        # map_location="cpu": the .numpy() calls below require CPU tensors.
        our_depth = torch.load(
            outdir / (model_name + "_depth_ests.pt"), map_location="cpu"
        )
        # The reference tensor gets a batch dim via [None] and is resized to
        # our_depth's spatial size (default nearest-neighbour); [0, 0] keeps
        # channel 0, leaving a 2-D map.
        gt_depth = F.interpolate(
            torch.load(outdir / "depth_ests.pt", map_location="cpu")[None],
            our_depth.shape[-2:],
        )[0, 0]

        err_all += _affine_aligned_mse(DEPTH_EPS + gt_depth, DEPTH_EPS + our_depth)
        n_est += 1

        # Prediction stacked above reference along dim 0.
        comparison = torch.cat((our_depth, gt_depth)).numpy()
        # Scratch preview, overwritten for every video.
        plt.imsave("/home/cameronsmith/tmp.png", comparison)
        plt.imsave(
            outdir / (model_name + "_depth_comparison.png"),
            1 / (DEPTH_EPS + comparison),
        )

    if n_est == 0:
        # Avoid ZeroDivisionError when no videos matched the glob.
        print(model_name, "no depth estimates found")
        continue
    print(model_name, err_all / n_est)
