import torch
from glob import glob
import os,sys
import matplotlib.pyplot as plt 
from pathlib import Path
from torch.nn import functional as F
from PIL import Image
from tqdm import tqdm
import torchvision
import flow_vis_torch
import cv2

import numpy as np

# This script only runs inference, so autograd is disabled globally.
torch.set_grad_enabled(False)

import argparse
parser = argparse.ArgumentParser()
parser.add_argument('-i','--input_dir',  type=str,default="",required=False,help="rgb files")
parser.add_argument('-o','--output_dir', type=str,default="",required=False,help="where to save files")
parser.add_argument('--n_skip', type=int,default=0,help="Number of frames to skip between adjacent frames in dataloader. ")
args = parser.parse_args()

imgdir=Path(args.input_dir)
outdir=Path(args.output_dir)

os.makedirs(outdir,exist_ok = True)

# Hard cap on the number of frames that are loaded and processed.
max_i=4000
# Frames are taken in sorted filename order, keeping every (n_skip+1)-th entry.
image_paths = sorted(imgdir.iterdir())[::args.n_skip+1][:max_i]


def _load_rgb_tensor(path, to_tensor=torchvision.transforms.ToTensor()):
    """Load one frame as a 3-channel tensor (ToTensor output), closing the file afterwards."""
    with Image.open(path) as frame:
        # Force 3 channels so RGBA / grayscale frames stack with RGB ones and
        # match the 3-channel depth and track visualisations built later.
        return to_tensor(frame.convert("RGB"))


images = torch.stack([_load_rgb_tensor(path) for path in tqdm(image_paths, desc="Loading images")])

# Downscale large frames to 640x1024 (landscape) or 1024x640 (portrait / square).
if images.size(-1)>1000 or images.size(-2)>1000:
    target_hw = [640, 1024] if images.size(-1)>images.size(-2) else [1024, 640]
    images = F.interpolate(images,target_hw,mode="bilinear")

# Resize so both sides are multiples of 16. The previous `x + x % 16` only
# achieved that when x % 16 was 0 or 8 (e.g. 20 -> 24); `(-x) % 16` is the
# distance up to the next multiple of 16.
# NOTE: a depth_ests.pt cached by an earlier run was resized to the old frame
# size; delete it if this changes the frame size for your data.
height, width = images.shape[-2:]
images = F.interpolate(images,(height+(-height)%16, width+(-width)%16),mode="bilinear")

# Frames are kept on the GPU in the 0..255 range; a CPU copy is written to disk.
images=images.cuda()*255
torch.save(images.cpu(),outdir/"imgs.pt")

# Pointtracks (cotracker3)
from cotracker.utils.visualizer import Visualizer
video = images[None]  # add a leading batch dimension for the tracker
grid_size=64


def _track_online(tracker, clip, grid_size):
    """Run the online tracker over `clip` and return the output of its last call.

    The tracker is initialised with one `is_first_step=True` call and then fed
    overlapping windows of `2 * tracker.step` frames, advancing by
    `tracker.step` frames each time. `clip` must contain more than
    `tracker.step` frames, otherwise no window is processed and
    `(None, None)` is returned.

    NOTE(review): the last call's output is presumably cumulative over every
    frame seen so far -- the concatenation below relies on that; confirm
    against the CoTracker online API.
    """
    tracker(video_chunk=clip, is_first_step=True, grid_size=grid_size)
    tracks = visibility = None
    for ind in range(0, clip.shape[1] - tracker.step, tracker.step):
        tracks, visibility = tracker(video_chunk=clip[:, ind : ind + tracker.step * 2])
    return tracks, visibility


# `or 1` forces recomputation even when a cached result exists.
if not os.path.exists(outdir/"pred_tracks_more.pt") or 1: # TODO use sky mask as depth>thresh to mask out sky points 

    pred_tracks_all,pred_visibility_all=[],[]

    cotracker = torch.hub.load("facebookresearch/co-tracker", "cotracker3_online").cuda()

    # Start frames are spread evenly over the video, one per 8 frames.
    start_frames = torch.linspace(0,len(images)-cotracker.step, len(images)//8).int().tolist()
    if len(images)<=cotracker.step or not start_frames:
        # Previously this surfaced as a NameError on pred_tracks_right or as an
        # empty torch.stack; fail with an actionable message instead.
        raise ValueError(
            "Need more than %d frames (and at least 8) for point tracking, got %d"
            % (cotracker.step, len(images))
        )

    # (width-1, height-1), broadcast over the leading dims, to normalise pixel coordinates.
    pixel_extent = (torch.tensor(images.shape[-2:][::-1])-1)[None,None,None].cuda()

    for start_frame in tqdm(start_frames,"Cotracker point tracking"):
        # manually separate into backward and forward videos since online tracker doesn't support backward tracking and start frame
        video_left,video_right = video[:,:start_frame],video[:,start_frame:]

        if video_left.size(1)<=cotracker.step:
            # Too few frames before the start frame: track forward over the whole video.
            pred_tracks,pred_visibility = _track_online(cotracker, video, grid_size)
        elif video_right.size(1)<=cotracker.step:
            # Too few frames after the start frame: track backward over the whole
            # video, then flip the result back into chronological order.
            tracks_rev,visibility_rev = _track_online(cotracker, torch.flip(video,dims=[1]), grid_size)
            pred_tracks,pred_visibility = torch.flip(tracks_rev,dims=[1]),torch.flip(visibility_rev,dims=[1])
        else:
            # Track backward on the time-reversed left part and forward on the
            # right part, then concatenate the two along the time axis.
            tracks_rev,visibility_rev = _track_online(cotracker, torch.flip(video_left,dims=[1]), grid_size)
            pred_tracks_left,pred_visibility_left = torch.flip(tracks_rev,dims=[1]),torch.flip(visibility_rev,dims=[1])
            pred_tracks_right,pred_visibility_right = _track_online(cotracker, video_right, grid_size)
            pred_tracks = torch.cat((pred_tracks_left,pred_tracks_right),1)
            pred_visibility = torch.cat((pred_visibility_left,pred_visibility_right),1)

        # Store coordinates divided by (width-1, height-1).
        pred_tracks_all.append(pred_tracks/pixel_extent)
        pred_visibility_all.append(pred_visibility)

    # Stack over start frames: a new leading axis indexes the start frame.
    pred_tracks_all,pred_visibility_all = torch.stack(pred_tracks_all),torch.stack(pred_visibility_all)
    torch.save((pred_tracks_all.cpu(),pred_visibility_all.cpu()),outdir/"pred_tracks_more.pt")
else:
    # NOTE(review): [:max_i] slices the leading (start-frame) axis of the cached
    # tensors, not the time axis -- confirm this is the intended truncation.
    pred_tracks_all,pred_visibility_all = [x[:max_i].cuda() for x in torch.load(outdir/"pred_tracks_more.pt")]

# Depth maps and focal length (depth_pro), cached in <outdir>/depth_ests.pt
depth_cache_path = outdir/"depth_ests.pt"
# The trailing `or 0` is a manual switch: change it to 1 to ignore the cache.
if not os.path.exists(depth_cache_path) or 0:
    import depth_pro
    # Point the default config at the local checkpoint before building the model.
    depth_pro.depth_pro.DEFAULT_MONODEPTH_CONFIG_DICT.checkpoint_uri="./output/depth_pro.pt"
    model, transform = depth_pro.create_model_and_transforms()
    model.eval()
    model = model.cuda()

    frame_hw = images.shape[-2:]
    depth_ests, focal_ests = [], []
    for image_path in tqdm(image_paths, desc="Doing depth est"):
        image, _, f_px = depth_pro.load_rgb(image_path)
        image = transform(image)

        # Run inference.
        prediction = model.infer(image.cuda(), f_px=f_px)

        # Resize the predicted depth to the (possibly rescaled) frame size used
        # for tracking, and keep the result on the CPU.
        depth_map = prediction["depth"][None, None]
        depth_resized = F.interpolate(depth_map, frame_hw, mode="bilinear")
        depth_ests.append(depth_resized[0, 0].cpu())
        focal_ests.append(torch.tensor([prediction["focallength_px"]]))

    depth_ests = torch.stack(depth_ests)
    focal_ests = torch.stack(focal_ests)
    # Median focal length over all frames, divided by the last dimension
    # (presumably the width) of the last transformed input image.
    focal = focal_ests.median().item() / image.size(-1)

    torch.save((depth_ests, focal), depth_cache_path)

    # Quick-look image: a grid of inverse depth, padded with the median depth.
    depth_grid = torchvision.utils.make_grid(depth_ests[:, None], pad_value=depth_ests.median())[0].numpy()
    depth_vis = 1 / (1e-2 + depth_grid)
    depth_vis_path = outdir/"depth_vis.png"
    plt.imsave(depth_vis_path, depth_vis)
    print("Saved depth vis to ", depth_vis_path)
else:
    # Only the depth maps are read back from the cache; the stored focal value is not loaded here.
    depth_ests = torch.load(depth_cache_path)[0].cuda()[:max_i]

from matplotlib import cm
from einops import rearrange

# Side-by-side visualisation per frame: RGB | colour-mapped depth | point tracks.

# Undo the normalisation: scale track coordinates back by (width-1, height-1).
pixel_extent = (torch.tensor(images.shape[-2:][::-1])-1)[None,None,None].cuda()
pred_tracks_all_unnorm = pred_tracks_all*pixel_extent
# Merge the start-frame axis `s` into the point axis so all tracks are drawn together.
pred_tracks_all_unnorm = rearrange(pred_tracks_all_unnorm,"s 1 t p c -> 1 t (p s) c")
pred_visibility_all = rearrange(pred_visibility_all,"s 1 t p -> 1 t (p s)")
# Take the first batch entry and the last len(images) frames of the drawn video.
tracks_vis = Visualizer(save_dir=outdir, pad_value=120, linewidth=3).draw_tracks_on_video(video, pred_tracks_all_unnorm, pred_visibility_all)[0,-len(images):].cpu()

depth_vis=depth_ests[:,None].cpu().numpy()
# `matplotlib.cm.get_cmap` was removed in Matplotlib 3.9; `plt.get_cmap` is the
# equivalent lookup that works on both old and new releases.
magma = plt.get_cmap('magma')
# Colour-map min_depth / depth (1 at the nearest point, smaller farther away),
# then move the RGBA channel axis to position 1 and scale to 0..255.
depth_vis = torch.from_numpy(magma(depth_vis.min().item()/depth_vis)).squeeze(1).permute(0,3,1,2)*255

# Concatenate along the width (dropping the colormap's alpha channel) and halve the resolution.
all_vis = F.interpolate(torch.cat((images.cpu(),depth_vis[:,:3],tracks_vis),-1),scale_factor=.5)
os.makedirs(outdir/"all_vis2",exist_ok=True)
for i,img in enumerate(all_vis):
    # imsave expects channel-last data in 0..1 for float input.
    plt.imsave(outdir/("all_vis2/%03d.png"%i),img.permute(1,2,0).numpy()/255)
