import os
from glob import glob

import matplotlib.pyplot as plt
import torch
from PIL import Image

import depth_pro

# Directory that receives, per input image, an inverse-depth visualization
# (.png) and the raw depth tensor (.pt).
OUTPUT_DIR = "/data/cameron_depth_data_storage"

# Glob patterns selecting the images to run depth estimation on.
INPUT_GLOBS = [
    "/home/cameronsmith/robotics_demo/*",
    "/data/nerf_llff_data/horns/images/*",
    "/home/cameronsmith/demo_images/*",
    "/data/DAVIS/1080p/*/*",
]

# Load model and preprocessing transform, put the model in eval mode and
# move it to the GPU.
model, transform = depth_pro.create_model_and_transforms()
model.eval()
model = model.cuda()

# plt.imsave / torch.save do not create missing parent directories, so make
# sure the output directory exists before the first save.
os.makedirs(OUTPUT_DIR, exist_ok=True)

for pattern in INPUT_GLOBS:
    # sorted() makes the processing order (and printed progress) deterministic.
    for image_path in sorted(glob(pattern)):
        # Wildcards can also match subdirectories; load_rgb only reads files.
        if not os.path.isfile(image_path):
            continue
        print(image_path)

        # Load and preprocess the image. f_px is the focal length returned by
        # depth_pro.load_rgb (presumably from EXIF and possibly None when
        # unavailable -- TODO confirm); it is forwarded unchanged to infer().
        image, _, f_px = depth_pro.load_rgb(image_path)
        image = transform(image)

        # Run inference on the GPU and bring the results back to the CPU once.
        prediction = model.infer(image.cuda(), f_px=f_px)
        depth = prediction["depth"].cpu()
        focal_px = int(prediction["focallength_px"].item())

        # Flatten the full source path into one file name so outputs from
        # different input directories do not collide.
        savepath = image_path.replace("/", "-")
        # Save 1/depth (inverse depth) as a colormapped image for inspection.
        plt.imsave("%s/%s.png" % (OUTPUT_DIR, savepath), 1 / depth.numpy())
        # Save the raw depth tensor; the focal length (pixels, truncated to an
        # int) is encoded in the file name.
        torch.save(depth, "%s/%s_focal_%04d.pt" % (OUTPUT_DIR, savepath, focal_px))
