import torch

from glob import glob
from matplotlib import cm
import os,sys,shutil
import matplotlib.pyplot as plt 
from pathlib import Path
from torch.nn import functional as F
from PIL import Image
from tqdm import tqdm
import torchvision
import flow_vis_torch
import cv2

import numpy as np

# Inference-only script: switch autograd off globally.
torch.set_grad_enabled(False)

import argparse

# Command-line options.
# NOTE(review): none of these options (nor `imgdir` below) is referenced by
# the rest of the visible script, which uses hard-coded paths -- confirm
# whether they are still needed.
parser = argparse.ArgumentParser()
parser.add_argument(
    "-i",
    "--input_dir",
    type=str,
    default="",
    required=False,
    help="rgb files",
)
parser.add_argument(
    "-o",
    "--output_dir",
    type=str,
    default="",
    required=False,
    help="where to save files",
)
parser.add_argument(
    "--n_skip",
    type=int,
    default=0,
    help="Number of frames to skip between adjacent frames in dataloader. ",
)
args = parser.parse_args()

imgdir = Path(args.input_dir)

import depth_pro

# Redirect Depth Pro's default config to the local checkpoint file; this must
# happen before the model is created so the factory picks it up.
depth_pro.depth_pro.DEFAULT_MONODEPTH_CONFIG_DICT.checkpoint_uri = (
    "./output/depth_pro.pt"
)

# Build the network together with its matching input transform, then switch
# to evaluation mode and move the weights to the GPU.
model, transform = depth_pro.create_model_and_transforms()
model = model.eval().cuda()

# First frame of every dog video; only the first MAX_VIDEOS matches are used.
# NOTE(review): glob() does not guarantee an ordering, so which videos end up
# in that slice depends on the filesystem -- wrap in sorted() if a
# reproducible subset is required.
FIRST_FRAME_GLOB = "/data/pets_co3d/dog/*/images/frame000001.jpg"
OUTPUT_ROOT = Path("/data/tmp/cameron_mlprodepthests/")
MAX_VIDEOS = 100

# Frames with a side longer than MAX_SIDE are resized to LONG_SIDE x SHORT_SIDE
# (landscape or portrait as appropriate); afterwards both sides are rounded
# down to a multiple of SIZE_MULTIPLE.
MAX_SIDE = 1000
LONG_SIDE, SHORT_SIDE = 1024, 640
SIZE_MULTIPLE = 16


def fit_to_model_grid(images: torch.Tensor) -> torch.Tensor:
    """Resize an (N, C, H, W) batch to the working resolution.

    If either spatial side exceeds ``MAX_SIDE`` the batch is bilinearly resized
    to ``SHORT_SIDE x LONG_SIDE`` when it is wider than tall, otherwise to
    ``LONG_SIDE x SHORT_SIDE``.  The result is then bilinearly resized so that
    height and width are rounded down to multiples of ``SIZE_MULTIPLE``.

    Args:
        images: float tensor whose last two dimensions are height and width.

    Returns:
        The resized batch; the leading dimensions are unchanged.
    """
    height, width = images.shape[-2:]
    if width > MAX_SIDE or height > MAX_SIDE:
        if width > height:
            target = (SHORT_SIDE, LONG_SIDE)
        else:
            target = (LONG_SIDE, SHORT_SIDE)
        images = F.interpolate(images, target, mode="bilinear")

    height, width = images.shape[-2:]
    snapped = (
        height // SIZE_MULTIPLE * SIZE_MULTIPLE,
        width // SIZE_MULTIPLE * SIZE_MULTIPLE,
    )
    return F.interpolate(images, snapped, mode="bilinear")


def load_image_batch(paths):
    """Load image files into one batch at the working resolution.

    Each file is converted with ``ToTensor``, all images are resized (default
    interpolation mode) to the smallest height and width found among them so
    they can be stacked, and the stack is passed through
    ``fit_to_model_grid``.

    Args:
        paths: iterable of image file paths.

    Returns:
        Tensor of shape (len(paths), C, H, W), as produced by ``ToTensor``
        followed by the resizes described above.
    """
    to_tensor = torchvision.transforms.ToTensor()
    frames = []
    for path in tqdm(paths, desc="Loading images"):
        # Close the file handle as soon as the pixels have been converted.
        with Image.open(path) as img:
            frames.append(to_tensor(img))

    min_h = min(frame.size(1) for frame in frames)
    min_w = min(frame.size(2) for frame in frames)
    batch = torch.stack(
        [F.interpolate(frame[None], (min_h, min_w))[0] for frame in frames]
    )
    return fit_to_model_grid(batch)


for img_path in glob(FIRST_FRAME_GLOB)[:MAX_VIDEOS]:

    # The video name is the directory two levels above the frame file
    # (<video>/images/frame000001.jpg), independent of the dataset root depth.
    vidname = Path(img_path).parents[1].name
    outdir = OUTPUT_ROOT / vidname
    outdir.mkdir(parents=True, exist_ok=True)

    # Resized frame, scaled by 255 on the GPU as before.
    images = load_image_batch([img_path]).cuda() * 255

    # NOTE(review): this is a byte-for-byte copy of the source .jpg stored
    # under a .png name; the content is not re-encoded.
    shutil.copy(img_path, outdir / (vidname + "_img.png"))
    print(outdir)
    torch.save(images.cpu(), outdir / (vidname + "_imgs.pt"))

    # Depth map (ml-pro).  The image is reloaded through depth_pro's own
    # loader and transform rather than reusing `images`.
    # f_px is presumably the focal length in pixels -- confirm against the
    # depth_pro documentation.
    image, _, f_px = depth_pro.load_rgb(img_path)
    prediction = model.infer(transform(image).cuda(), f_px=f_px)

    # Bring the predicted depth to the same H x W as the saved frame tensor.
    depth = F.interpolate(
        prediction["depth"][None, None], images.shape[-2:], mode="bilinear"
    )[0, 0].cpu()
    depth_ests = torch.stack([depth])
    torch.save(depth_ests, outdir / "depth_ests.pt")

    # Visualise the reciprocal of depth; 1e-2 keeps the division finite at
    # zero depth.  make_grid documents pad_value as a float, so convert the
    # median tensor explicitly.
    depth_grid = torchvision.utils.make_grid(
        depth_ests[:, None], pad_value=float(depth_ests.median())
    )[0]
    depth_vis = 1 / (1e-2 + depth_grid.numpy())
    vis_path = outdir / (vidname + "_depth_vis.png")
    plt.imsave(vis_path, depth_vis)
    # Report the path that was actually written.
    print("Saved depth vis to ", vis_path)
