import torch
from glob import glob
import os,sys
import matplotlib.pyplot as plt 
from pathlib import Path
from torch.nn import functional as F
from PIL import Image
from tqdm import tqdm
import torchvision
import flow_vis_torch
import cv2

import numpy as np

# The whole pipeline only runs inference, so autograd is switched off process-wide.
torch.set_grad_enabled(False)

import argparse

# ---- command line -----------------------------------------------------------
#   -i / --input_dir   directory holding the RGB frames
#   -o / --output_dir  directory receiving cached .pt files and visualizations
#   --n_skip           frames dropped between two consecutive kept frames
parser = argparse.ArgumentParser()
parser.add_argument(
    "-i",
    "--input_dir",
    help="rgb files",
    type=str,
    default="",
    required=False,
)
parser.add_argument(
    "-o",
    "--output_dir",
    help="where to save files",
    type=str,
    default="",
    required=False,
)
parser.add_argument(
    "--n_skip",
    help="Number of frames to skip between adjacent frames in dataloader. ",
    type=int,
    default=0,
)
args = parser.parse_args()

imgdir = Path(args.input_dir)
outdir = Path(args.output_dir)
outdir.mkdir(parents=True, exist_ok=True)

# Hard cap on the number of frames processed (applied after frame skipping).
max_i = 4000
# Keep every (n_skip + 1)-th regular file; sub-directories are not frames.
image_paths = sorted(p for p in imgdir.iterdir() if p.is_file())[:: args.n_skip + 1][:max_i]


def _load_rgb_tensor(path):
    """Read one frame as a float (3, H, W) tensor in [0, 1].

    The image is converted to RGB so grayscale / palette / RGBA files also yield three
    channels, and the file handle is closed as soon as the pixels are decoded.
    """
    with Image.open(path) as im:
        return torchvision.transforms.ToTensor()(im.convert("RGB"))


images = torch.stack([_load_rgb_tensor(path) for path in tqdm(image_paths, desc="Loading images")])

# Frames larger than 1000 px on either side are resized to 640x1024 (landscape) or
# 1024x640 (portrait / square). The aspect ratio is not preserved.
height, width = images.shape[-2:]
if height > 1000 or width > 1000:
    images = F.interpolate(images, [640, 1024] if width > height else [1024, 640], mode="bilinear")

# Round both spatial dims UP to the next multiple of 16 (ceil division); presumably needed
# by GMFlow's 1/8-resolution features split 2x2 (attn_splits_list=[2]) - confirm.
height, width = images.shape[-2:]
padded_hw = (-(-height // 16) * 16, -(-width // 16) * 16)
if padded_hw != (height, width):
    images = F.interpolate(images, padded_hw, mode="bilinear")

# Stored as 0..255 floats; the later stages rescale as they need.
images = images * 255
torch.save(images, outdir / "imgs.pt")
images = images.cuda()

# Dino features: per-frame FeatUp-upsampled DINOv2 features at 1/4 image resolution,
# PCA-reduced to a few channels and cached to dino_feats.pt (flip `or 0` to force a rerun).
if not os.path.exists(outdir/"dino_feats.pt") or 0:
    upsampler = torch.hub.load("mhamilton723/FeatUp", 'dinov2', use_norm=True).cuda()

    model_hw = (torch.tensor(images.shape[-2:]) // 14 * 14).tolist()  # crop size divisible by the ViT patch size 14
    feat_hw = (torch.tensor(images.shape[-2:]) // 4).tolist()  # resolution of the stored features

    dino_feats = []
    for img in tqdm(images, "dino feats est"):
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            feats = upsampler(F.interpolate(img[None], model_hw) / 256)
            dino_feats.append(F.interpolate(feats, feat_hw)[0])
    # torch.linalg.svd has no bfloat16 kernel, so make sure the PCA input is float32.
    dino_feats = torch.stack(dino_feats).float()

    # pca
    n_components = 9  # just use first 9 components, probably only need first 3 tbh
    n_frames, n_channels, feat_h, feat_w = dino_feats.shape
    data = dino_feats.permute(0, 2, 3, 1).reshape(-1, n_channels)  # (T*h*w, C)
    centered_data = data - data.mean(dim=0, keepdim=True)
    _, _, Vt = torch.linalg.svd(centered_data, full_matrices=False)
    reduced_data = centered_data @ Vt[:n_components].T  # (T*h*w, n_components)
    reduced_data = reduced_data.reshape(n_frames, feat_h, feat_w, n_components).permute(0, 3, 1, 2)

    # Divide by the largest magnitude so every component lies in [-1, 1].
    reduced_data = reduced_data / reduced_data.abs().max()

    torch.save(reduced_data.cpu(), outdir/"dino_feats.pt")
    # make_grid(normalize=True) min/max-rescales, so the preview ignores the scale above.
    dino_vis = torchvision.utils.make_grid(reduced_data[:, :3], normalize=True).permute(1, 2, 0)
    plt.imsave(outdir/"dino_vis.png", dino_vis.cpu().numpy())
    print("Saved dino vis to ", outdir/"dino_vis.png")

# Optical flow (GMFlow; the RAFT alternative is kept commented out below)
if not os.path.exists(outdir/"bwd_flow.pt") or 0:

    # GMFlow repo (also holding the checkpoint); override with $GMFLOW_ROOT.
    gmflow_root = Path(os.environ.get("GMFLOW_ROOT", "/home/cameronsmith/repos/refactor_flowmap_and_splat/gmflow/"))
    sys.path.append(str(gmflow_root))
    from gmflow.gmflow import GMFlow
    gm_flow = GMFlow(feature_channels=128, num_scales=1, upsample_factor=8, num_head=1, attention_type="swin", ffn_dim_expansion=4, num_transformer_layers=6,).cuda()
    checkpoint = torch.load(gmflow_root / "gmflow-scale1-mixdata-train320x576-4c3a6e9a.pth")
    weights = checkpoint['model'] if 'model' in checkpoint else checkpoint
    gm_flow.load_state_dict(weights, strict=False)
    gm_flow.eval()  # inference mode for any train/eval-dependent layers
    for param in gm_flow.parameters(): param.requires_grad = False

    #from torchvision.models.optical_flow import Raft_Large_Weights
    #raft_transforms = Raft_Large_Weights.DEFAULT.transforms()
    #from torchvision.models.optical_flow import raft_large
    #raft = raft_large(weights=Raft_Large_Weights.DEFAULT, progress=True).cuda()

    # Pixel flow is divided by (W-1, H-1): a displacement across the whole image is ~1.
    # This is the same normalization the cached point tracks use - NOT [-1, 1] units.
    flow_scale = (torch.tensor(images.shape[-2:][::-1]) - 1).to(images)[None, :, None, None]

    bwd_flow = []
    # x is frame t+1 and y is frame t, so each prediction is the backward flow t+1 -> t.
    for x, y in tqdm(zip(images[1:], images[:-1]), "Optical flow est", total=len(images) - 1):
        #raft_inp_x,raft_inp_y = raft_transforms(x[None].to(torch.uint8), y[None].to(torch.uint8))
        #bwd_flow_=tmp= raft(raft_inp_x,raft_inp_y,num_flow_updates=32)[-1]

        flow_px = gm_flow(x[None], y[None], attn_splits_list=[2], corr_radius_list=[-1], prop_radius_list=[-1], pred_bidir_flow=False)["flow_preds"][-1]
        bwd_flow.append(flow_px / flow_scale)
    bwd_flow = torch.cat(bwd_flow)

    # save flow and vis
    torch.save(bwd_flow.cpu(), outdir/"bwd_flow.pt")
    flow_vis = flow_vis_torch.flow_to_color(torchvision.utils.make_grid(bwd_flow)).permute(1, 2, 0).cpu().numpy()/255
    plt.imsave(outdir/"flow_vis.png", flow_vis)
    print("Saved flow vis to ", outdir/"flow_vis.png")
else:
    bwd_flow = torch.load(outdir/"bwd_flow.pt").cuda()[:max_i-1]

# Flow rigid thresholding for nonrigid motion approximation.
# For every backward-flow frame, greedily split the pixels into `n_rig_groups` groups that each
# agree with one epipolar geometry: fit a fundamental matrix (LMedS) to the pixels not yet
# explained, take those with Sampson error below the `inlier_quantile` quantile as the next
# group, and repeat on the rest. Saved as a bool tensor (T-1, n_rig_groups, h, w) at 1/4 res.
if not os.path.exists(outdir/"rig_flow_masks.pt") or 0:

    def compute_sampson_error(x1, x2, F_):
        """Squared Sampson distance of correspondences x1 <-> x2 under fundamental matrix F_.

        :param x1 (*, N, 2) points in the first view
        :param x2 (*, N, 2) corresponding points in the second view
        :param F_ (*, 3, 3) fundamental matrix with x2^T F_ x1 = 0
        :return (*, N) error per correspondence
        """
        h1 = torch.cat([x1, torch.ones_like(x1[..., :1])], dim=-1)
        h2 = torch.cat([x2, torch.ones_like(x2[..., :1])], dim=-1)
        d1 = torch.matmul(h1, F_.transpose(-1, -2))  # F x1, (*, N, 3)
        d2 = torch.matmul(h2, F_)  # F^T x2, (*, N, 3)
        z = (h2 * d1).sum(dim=-1)  # x2^T F x1, (*, N)
        return z ** 2 / (d1[..., 0] ** 2 + d1[..., 1] ** 2 + d2[..., 0] ** 2 + d2[..., 1] ** 2)

    n_rig_groups = 4  # masks produced per frame
    inlier_quantile = 0.8  # fraction of the still-unexplained pixels each group takes

    bwd_flow_low = torch.nn.functional.interpolate(bwd_flow, scale_factor=.25, mode="bilinear")
    #bwd_flow_low=torch.nn.functional.interpolate(bwd_flow,scale_factor=.1,mode="bilinear")

    H, W = bwd_flow_low.shape[-2:]
    # Pixel centres in [-1, 1] coordinates, flattened to (H*W, 2) in (x, y) order.
    yy, xx = torch.meshgrid(torch.arange(H, dtype=torch.float32), torch.arange(W, dtype=torch.float32), indexing="ij")
    x1 = torch.stack([2 * (xx + 0.5) / W - 1, 2 * (yy + 0.5) / H - 1], dim=-1).flatten(0, 1)

    # Per-frame / per-group Sampson error images, for debugging.
    err_dir = outdir / "rig_flow_errs"
    os.makedirs(err_dir, exist_ok=True)

    masks_all = []
    for i, flow in enumerate(tqdm(bwd_flow_low, leave=False)):
        # bwd_flow is pixel flow divided by (W-1, H-1), a unit that spans only half of the
        # [-1, 1] range x1 lives in, so it is doubled to land in x1's coordinates.
        x2 = x1 + 2 * flow.permute(1, 2, 0).cpu().flatten(0, 1)  # (H*W, 2)

        remaining = torch.ones(H * W, dtype=torch.bool)  # pixels not assigned to a group yet
        masks = []
        for j in range(n_rig_groups):
            n_left = int(remaining.sum())
            # Skip the fit when too few pixels remain for a fundamental matrix; cv2 may also fail.
            F_ = cv2.findFundamentalMat(x1[remaining].numpy(), x2[remaining].numpy(), cv2.FM_LMEDS)[0] if n_left >= 8 else None
            if F_ is None or F_.shape != (3, 3):
                if n_left:
                    print("F is none")
                # All leftover pixels become this group, so every frame still yields exactly
                # n_rig_groups masks (torch.stack below needs equal counts).
                masks.append(remaining.reshape(H, W))
                remaining = torch.zeros_like(remaining)
                continue

            err = compute_sampson_error(x1, x2, torch.from_numpy(F_).float()) * ((H + W) / 2) ** 2
            err[~remaining] = 0
            thresh = torch.quantile(err[remaining], inlier_quantile)
            inliers = (err < thresh) & remaining
            masks.append(inliers.reshape(H, W))
            plt.imsave(err_dir / ("%02d%02d.png" % (i, j)), err.reshape(H, W).numpy())
            # Everything not claimed by this group (including ties at thresh) stays in play.
            remaining = remaining & ~inliers

        masks_all.append(torch.stack(masks))
    masks_all = torch.stack(masks_all)
    torch.save(masks_all.cpu(), outdir/"rig_flow_masks.pt")

#if not os.path.exists(outdir/"rig_flow_masks.pt") or 1: 
#    flow_errs=[]
#
#    def compute_sampson_error(x1, x2, F_):
#        """
#        :param x1 (*, N, 2)
#        :param x2 (*, N, 2)
#        :param F (*, 3, 3)
#        """
#        h1 = torch.cat([x1, torch.ones_like(x1[..., :1])], dim=-1)
#        h2 = torch.cat([x2, torch.ones_like(x2[..., :1])], dim=-1)
#        d1 = torch.matmul(h1, F_.transpose(-1, -2))  # (B, N, 3)
#        d2 = torch.matmul(h2, F_)  # (B, N, 3)
#        z = (h2 * d1).sum(dim=-1)  # (B, N)
#        err = z ** 2 / ( d1[..., 0] ** 2 + d1[..., 1] ** 2 + d2[..., 0] ** 2 + d2[..., 1] ** 2)
#        return err
#
#    bwd_flow_low=torch.nn.functional.interpolate(bwd_flow,scale_factor=.25,mode="bilinear")
#    #bwd_flow_low=torch.nn.functional.interpolate(bwd_flow,scale_factor=.1,mode="bilinear")
#    for i in tqdm(range(len(bwd_flow)),leave=False):
#        flow=bwd_flow_low[i].permute(1,2,0).cpu()
#
#        H, W = bwd_flow_low[0].shape[-2:]
#
#        device="cpu"
#        yy, xx = torch.meshgrid(
#            torch.arange(H, dtype=torch.float32, device=device),
#            torch.arange(W, dtype=torch.float32, device=device),
#            indexing="ij",
#        )
#        xx = 2 * (xx + 0.5) / W - 1
#        yy = 2 * (yy + 0.5) / H - 1
#        x1 = torch.stack([xx, yy], dim=-1).flatten(0,1)
#
#        F_   = torch.zeros(3, 3, dtype=torch.float32)
#        err = torch.zeros(H, W, dtype=torch.float32)
#
#        #flow = flow.permute(1, 2, 0)  # (H, W, 2)
#        x2 = x1 + flow.flatten(0,1)  # (H*W, 2)
#        F_, _ = cv2.findFundamentalMat(x1.numpy(), x2.numpy(), cv2.FM_LMEDS)
#        F_ = torch.from_numpy(F_.astype(np.float32))  # (3, 3)
#        err = compute_sampson_error(x1, x2, F_).reshape(H, W)
#        fac = (H + W) / 2
#        err = err * fac ** 2
#
#        flow_errs.append(err)
#
#    flow_errs=torch.stack(flow_errs)
#    flow_thresh = torch.quantile(flow_errs, 0.85)
#    flow_rig_masks=torch.nn.functional.interpolate(flow_errs[:,None].float(),bwd_flow.shape[-2:],mode="bilinear")<flow_thresh
#    torch.save(flow_rig_masks.cpu(),outdir/"rig_flow_masks.pt")
#    flow_rig_vis=torchvision.utils.make_grid(flow_rig_masks)
#    plt.imsave(outdir/"rig_flow_vis.png",flow_rig_vis.permute(1,2,0).float().numpy())
#    print("Saved flow vis to ",outdir/"rig_flow_vis.png")

# Pointtracks (cotracker3)
from cotracker.utils.visualizer import Visualizer
video = images[None]
# (W-1, H-1): converts pixel track coordinates to the normalized form cached on disk.
track_scale = (torch.tensor(images.shape[-2:][::-1]) - 1).cuda()
# After this section pred_tracks_all is in PIXEL coordinates in both branches, stacked as
# (n_query_frames, B, T, N, 2); the cache file stores the normalized tracks.
if not os.path.exists(outdir/"pred_tracks.pt") or 0: # TODO use sky mask as depth>thresh to mask out sky points

    pred_tracks_all, pred_visibility_all = [], []

    cotracker = torch.hub.load("facebookresearch/co-tracker", "cotracker3_online").cuda()

    def run_online(clip):
        """Track a grid queried on the first frame of `clip` through the whole clip."""
        if clip.shape[1] <= cotracker.step:
            raise ValueError("clip of %d frames is too short for the online tracker (step=%d)" % (clip.shape[1], cotracker.step))
        cotracker(video_chunk=clip, is_first_step=True, grid_size=42)
        for ind in range(0, clip.shape[1] - cotracker.step, cotracker.step):
            tracks, visibility = cotracker(video_chunk=clip[:, ind : ind + cotracker.step * 2])  # tracks: B T N 2
        return tracks, visibility

    # run cotracker from 3 query frames spread over the video
    start_frames = torch.linspace(cotracker.step + 2, len(images) - (cotracker.step + 2), 3).int().tolist()
    for i, start_frame in enumerate(tqdm(start_frames, "Cotracker point tracking")):
        # The online tracker only tracks forward from its first frame, so backward tracking is
        # done on the time-reversed prefix. Both halves start at start_frame, hence share the
        # same query points; the duplicated start_frame is dropped from the left half.
        tracks_left, visibility_left = run_online(torch.flip(video[:, :start_frame + 1], dims=[1]))
        tracks_left, visibility_left = torch.flip(tracks_left, dims=[1]), torch.flip(visibility_left, dims=[1])
        tracks_right, visibility_right = run_online(video[:, start_frame:])
        pred_tracks = torch.cat((tracks_left[:, :-1], tracks_right), 1)
        pred_visibility = torch.cat((visibility_left[:, :-1], visibility_right), 1)
        Visualizer(save_dir=outdir, pad_value=120, linewidth=3).visualize(video, pred_tracks, pred_visibility,filename='tracks_%02d'%i,query_frame=start_frame)

        pred_tracks_all.append(pred_tracks)
        pred_visibility_all.append(pred_visibility)
    # Offline cotracker (better but expensive for long videos)
    #cotracker = torch.hub.load("facebookresearch/co-tracker", "cotracker3_offline").cuda()
    #for i,start_frame in enumerate(tqdm(torch.linspace(0,len(images)-1, 4).int().tolist(),"Cotracker point tracking")):
    #    pred_tracks, pred_visibility = cotracker(images[None], grid_size=42, backward_tracking=True, grid_query_frame=start_frame)
    #    Visualizer(save_dir=outdir, pad_value=120, linewidth=3).visualize(images[None], pred_tracks, pred_visibility,filename='tracks_%02d'%i,query_frame=start_frame)
    #    pred_tracks_all.append(pred_tracks)
    #    pred_visibility_all.append(pred_visibility)
    pred_tracks_all, pred_visibility_all = torch.stack(pred_tracks_all), torch.stack(pred_visibility_all)
    torch.save(((pred_tracks_all / track_scale).cpu(), pred_visibility_all.cpu()), outdir/"pred_tracks.pt")
else:
    # Cached layout is (n_query_frames, B, T, ...): trim the time axis, not the query axis.
    pred_tracks_norm, pred_visibility_all = [x[:, :, :max_i].cuda() for x in torch.load(outdir/"pred_tracks.pt")]
    pred_tracks_all = pred_tracks_norm * track_scale

# Depth maps and focal (ml-pro)
# depth_ests: one Depth Pro depth map per frame, resized to the working resolution of `images`.
# focal: median predicted focal length (px) divided by the width of the image fed to Depth Pro
#        (presumably the original, un-resized frame width - confirm against depth_pro's transform).
# Both names are defined whether the cache is computed or loaded.
if not os.path.exists(outdir/"depth_ests.pt") or 0:
    import depth_pro
    depth_pro.depth_pro.DEFAULT_MONODEPTH_CONFIG_DICT.checkpoint_uri="./output/depth_pro.pt"
    model, transform = depth_pro.create_model_and_transforms()
    model.eval()
    model=model.cuda()
    depth_ests,focal_ests=[],[]
    for image_path in tqdm(image_paths,desc="Doing depth est"):
        # Depth Pro reads the original file itself; f_px comes from depth_pro.load_rgb
        # (presumably EXIF-derived and possibly None - confirm).
        image, _, f_px = depth_pro.load_rgb(image_path)
        image = transform(image)

        # Run inference.
        prediction = model.infer(image.cuda(), f_px=f_px)
        depth_ests.append(F.interpolate(prediction["depth"][None,None],images.shape[-2:],mode="bilinear")[0,0].cpu())
        focal_ests.append(torch.tensor([prediction["focallength_px"]]))

    depth_ests=torch.stack(depth_ests)
    focal_ests=torch.stack(focal_ests)
    # `image` is the last transformed frame; all frames are assumed to share one size.
    focal = focal_ests.median().item()/image.size(-1)

    torch.save((depth_ests,focal),outdir/"depth_ests.pt")
    # Preview shows inverse depth (nearer = brighter); pad with the median so borders blend in.
    depth_vis=1/(1e-2+torchvision.utils.make_grid(depth_ests[:,None],pad_value=depth_ests.median())[0].numpy())
    plt.imsave(outdir/"depth_vis.png",depth_vis)
    print("Saved depth vis to ",outdir/"depth_vis.png")
else:
    depth_ests, focal = torch.load(outdir/"depth_ests.pt")
    depth_ests = depth_ests.cuda()[:max_i]

# Side-by-side per-frame visualization (RGB | depth | flow | tracks) at half resolution,
# written to outdir/all_vis/NNN.png.
tracks_vis = Visualizer(save_dir=outdir, pad_value=120, linewidth=3).draw_tracks_on_video(video, pred_tracks_all[0], pred_visibility_all[0])[0,-len(images):].cpu()
# Frame 0 has no backward flow: prepend a zero-flow frame so every column has len(images) frames.
flow_vis = flow_vis_torch.flow_to_color(torch.cat((torch.zeros_like(bwd_flow[:1]),bwd_flow))).cpu()
# Depth is colored by min_depth / depth (1 at the nearest pixel of the whole sequence).
depth_np = depth_ests[:,None].cpu().numpy()
magma = plt.get_cmap('magma')  # matplotlib.cm.get_cmap was removed in Matplotlib 3.9
depth_vis = torch.from_numpy(magma(depth_np.min().item()/depth_np)).squeeze(1).permute(0,3,1,2)*255
#rig_masks_vis=F.interpolate( (masks_all.permute(0,2,1,3).flatten(2,3))[:,None].expand(-1,3,-1,-1).float(), (images.size(-2),images.size(-1)*masks_all.size(1)),mode="bilinear" )*255
#rig_masks_vis = torch.cat((torch.zeros_like(rig_masks_vis[:1]),rig_masks_vis)).cpu()
#flow_thresh_vis=flow_rig_masks.expand(-1,3,-1,-1)*255
#flow_thresh_vis = torch.cat((torch.zeros_like(flow_thresh_vis[:1]),flow_thresh_vis)).cpu()
#dino_vis_large = F.interpolate(reduced_data[:,:3]/max(reduced_data[:,:3].min(),reduced_data[:,:3].max())*.5+.5,images.shape[-2:],mode="bilinear").clip(0,1)*255
#all_vis = F.interpolate(torch.cat((images.cpu(),depth_vis[:,:3],flow_vis,tracks_vis,dino_vis_large.cpu()),-1),scale_factor=.5)
#all_vis = F.interpolate(torch.cat((images.cpu(),depth_vis[:,:3],flow_vis,tracks_vis,rig_masks_vis),-1),scale_factor=.5)
all_vis = F.interpolate(torch.cat((images.cpu(),depth_vis[:,:3],flow_vis,tracks_vis),-1),scale_factor=.5)
os.makedirs(outdir/"all_vis",exist_ok=True)
for i, img in enumerate(all_vis):
    # clip: float round-off from interpolation can exceed 1, which plt.imsave rejects for RGB floats
    plt.imsave(outdir/("all_vis/%03d.png"%i), (img.permute(1,2,0).numpy()/255).clip(0, 1))
