import torch

from glob import glob
from matplotlib import cm
import os,sys,shutil
import matplotlib.pyplot as plt 
from pathlib import Path
from torch.nn import functional as F
from PIL import Image
from tqdm import tqdm
import torchvision
import flow_vis_torch
import cv2

import numpy as np

# Inference-only script: disable autograd globally, once.
torch.set_grad_enabled(False)

import argparse
parser = argparse.ArgumentParser()
parser.add_argument('-i','--input_dir',  type=str,default="",required=False,help="rgb files")
parser.add_argument('-o','--output_dir', type=str,default="",required=False,help="where to save files")
parser.add_argument('--n_skip', type=int,default=0,help="Number of frames to skip between adjacent frames in dataloader. ")
args = parser.parse_args()

# NOTE(review): imgdir, args.output_dir and args.n_skip are not read anywhere
# in the visible script -- input/output locations below are hard-coded.
# Confirm whether these flags are still needed.
imgdir=Path(args.input_dir)

# Make the project's `models` package importable (sys is imported at the top of the file).
sys.path.append("/home/cameronsmith/repos/multivid_point_track_sfm")
import models

model = models.FlowMap().cuda()
ckpt_file = "/tmp/dog_norigablation_nohydrantpretrain/checkpoint.pt"
# strict=False tolerates key mismatches between checkpoint and model; report
# them so a partially loaded model does not go unnoticed.
load_result = model.load_state_dict(torch.load(ckpt_file)["model_state_dict"], strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(f"[{ckpt_file}] missing keys: {load_result.missing_keys}; "
          f"unexpected keys: {load_result.unexpected_keys}")
# Inference only: switch any dropout / batch-norm layers to evaluation behaviour.
model.eval()
# Outputs are tagged with the checkpoint's parent directory name,
# e.g. "dog_norigablation_nohydrantpretrain".
model_name = Path(ckpt_file).parent.name

# (H, W) resolution the cached frames and encoder features are resampled to.
imsize = (76, 134)

# sorted() makes the 100-video subset deterministic; raw glob() order is arbitrary.
for img_path in sorted(glob("/data/pets_co3d/dog/*/images/frame000001.jpg"))[:100]:
    # Path layout is /data/pets_co3d/dog/<vidname>/images/frame000001.jpg.
    vidname = Path(img_path).parents[1].name
    outdir = Path(os.path.join("/data/tmp/cameron_mlprodepthests/", vidname))

    imgs_path = outdir / (vidname + "_imgs.pt")
    if not imgs_path.exists():
        # Skip videos without cached frames instead of aborting the whole batch.
        print(f"skipping {vidname}: {imgs_path} not found")
        continue

    # Scale cached frames to [-1, 1] (assumes values in 0..255 -- TODO confirm)
    # and resize to imsize (F.interpolate's default mode, nearest).
    images = F.interpolate(torch.load(imgs_path) / 255 * 2 - 1, imsize)

    # The encoder is fed images mapped back to [0, 1]; its features are
    # bilinearly resized to imsize.
    fmap_out = F.interpolate(model.img_enc(images.cuda() * 0.5 + 0.5), imsize, mode="bilinear")
    # Depth map of the first frame only ([0, 0]); softplus(...) + 1 keeps it > 1.
    depth = F.softplus(model.depth_conv(fmap_out) + 1)[0, 0].cpu() + 1

    torch.save(depth, outdir / (model_name + "_depth_ests.pt"))
    depth_vis_path = outdir / (model_name + "_depth_vis.png")
    # Visualize inverse depth so nearer surfaces appear brighter.
    plt.imsave(depth_vis_path, (1 / (1e-1 + depth)).numpy())
    # NOTE(review): unlike the other outputs this filename has no "_" separator;
    # kept unchanged so existing downstream paths still match.
    plt.imsave(
        outdir / (model_name + "lowresimgvis.png"),
        images[0].permute(1, 2, 0).clip(-1, 1).numpy() * .5 + .5,
    )
    print(depth_vis_path)
