import torch

from glob import glob
from matplotlib import cm
import os,sys
import matplotlib.pyplot as plt 
from pathlib import Path
from torch.nn import functional as F
from PIL import Image
from tqdm import tqdm
import torchvision
import flow_vis_torch
import cv2

import numpy as np

# Inference-only script: switch autograd off for everything that follows.
torch.set_grad_enabled(False)

import argparse

# Command-line interface.
parser = argparse.ArgumentParser()
parser.add_argument(
    "-i", "--input_dir",
    type=str,
    default="",
    required=False,
    help="rgb files",
)
parser.add_argument(
    "-o", "--output_dir",
    type=str,
    default="",
    required=False,
    help="where to save files",
)
parser.add_argument(
    "--n_skip",
    type=int,
    default=0,
    help="Number of frames to skip between adjacent frames in dataloader. ",
)
args = parser.parse_args()

# Path views of the two directory arguments.
imgdir, outdir = Path(args.input_dir), Path(args.output_dir)

# Make sure the output directory exists before any stage writes into it.
os.makedirs(outdir, exist_ok=True)

# Make the local Video-Depth-Anything checkout importable, then pull in its model class.
sys.path.append("/home/cameronsmith/repos/Video-Depth-Anything/")
from video_depth_anything.video_depth import VideoDepthAnything

# Constructor keyword arguments for each encoder variant.
model_configs = {
    'vits': dict(encoder='vits', features=64, out_channels=[48, 96, 192, 384]),
    'vitl': dict(encoder='vitl', features=256, out_channels=[256, 512, 1024, 1024]),
}

# Build the "vitl" variant and load its checkpoint on the CPU with strict key matching.
video_depth_anything = VideoDepthAnything(**model_configs["vitl"])
video_depth_anything.load_state_dict(
    torch.load(
        '/home/cameronsmith/repos/Video-Depth-Anything/checkpoints/video_depth_anything_vitl.pth',
        map_location='cpu',
    ),
    strict=True,
)

# Move to the GPU and switch to eval mode.
video_depth_anything = video_depth_anything.cuda()
video_depth_anything = video_depth_anything.eval()

# Upper bound on the number of frames kept (applied after the --n_skip subsampling).
max_i = 40

# Collect the .jpg frames in name order, keep every (n_skip+1)-th one and cap at max_i.
# Paths stay plain strings because later stages derive sibling file names with str.replace.
#image_paths = sorted(imgdir.iterdir())[:][::args.n_skip+1][:max_i]
image_paths = sorted(glob(args.input_dir + "/*.jpg"))[::args.n_skip + 1][:max_i]
if not image_paths:
    # Without this the script dies later with an opaque "min() arg is an empty sequence".
    raise FileNotFoundError(f"No .jpg frames found in {args.input_dir!r}")


def _load_rgb_tensor(path):
    """Read one image file and return it as a 3-channel torchvision tensor."""
    # The context manager closes the file handle; convert("RGB") guarantees three
    # channels so grayscale/CMYK/RGBA frames can be stacked with the others.
    with Image.open(path) as img:
        return torchvision.transforms.ToTensor()(img.convert("RGB"))


images = [_load_rgb_tensor(path) for path in tqdm(image_paths, desc="Loading images")]

# Resize every frame to the smallest height/width in the set so they can be stacked.
miny, minx = min(x.size(1) for x in images), min(x.size(2) for x in images)
images = torch.stack([F.interpolate(x[None], (miny, minx))[0] for x in images])

# Large footage is resampled to 640x1024 (wider than tall) or 1024x640 (otherwise).
if images.size(-1) > 1000 or images.size(-2) > 1000:
    target_hw = [640, 1024] if images.size(-1) > images.size(-2) else [1024, 640]
    images = F.interpolate(images, target_hw, mode="bilinear")

# Snap both sides down to a multiple of 16.
images = F.interpolate(images, (images.size(-2) // 16 * 16, images.size(-1) // 16 * 16), mode="bilinear")

# Move to the GPU and rescale by 255.
images = images.cuda() * 255

# Cache the preprocessed frames (change the trailing 0 to 1 to force a rewrite).
# The guard previously tested "img.pt" while the file written is "imgs.pt", so it never hit.
# NOTE(review): with the guard working, an imgs.pt left over from a run with different
# settings is kept as-is — delete it (or force a rewrite) when changing --n_skip.
if not os.path.exists(outdir / "imgs.pt") or 0:
    torch.save(images.cpu(), outdir / "imgs.pt")

# Estimate pinhole intrinsics from the first frame with GeoCalib and save them.
from geocalib import GeoCalib

first_frame_path = image_paths[0]

device = "cpu" if not torch.cuda.is_available() else "cuda"
model = GeoCalib().to(device)
image = model.load_image(first_frame_path).to(device)
result = model.calibrate(image, camera_model="pinhole")

# Divide the calibrated focal entry by the frame width (axis 1 of the decoded image array).
# presumably f[0,0] is the horizontal focal length in pixels — TODO confirm against GeoCalib
frame_width = np.array(Image.open(first_frame_path)).shape[1]
focal = result["camera"].f[0, 0] / frame_width
torch.save(focal.cpu(), outdir / "intrinsics.pt")

# Run Video-Depth-Anything over the whole clip.
# The frames are handed over as a uint8 NumPy array with the channel axis moved last.
# presumably the positional 24 is the clip frame rate — TODO confirm against infer_video_depth
frames_uint8 = images.to(torch.uint8).permute(0, 2, 3, 1).cpu().numpy()
depth_ests, fps = video_depth_anything.infer_video_depth(frames_uint8, 24, input_size=518, device="cuda", fp32=True)

# Store the reciprocal of the network output (the 1e-3 keeps the division finite).
# presumably the network predicts inverse depth, so this yields depth — TODO confirm
depth_ests = torch.from_numpy(1 / (1e-3 + depth_ests))
torch.save(depth_ests, outdir / "video_depth_ests.pt")

# Preview image: first and last frame in one grid, displayed as the reciprocal again.
depth_grid = torchvision.utils.make_grid(depth_ests[[0, -1], None], pad_value=depth_ests.median())[0].numpy()
depth_vis = 1 / (1e-2 + depth_grid)
plt.imsave(outdir / "depth_vis.png", depth_vis)
print("Saved depth vis to ", outdir / "depth_vis.png")

# Deliberate early stop: nothing below this line runs.
# This used to be a bare undefined name (`zz`) that halted the script with a NameError
# traceback and a failing exit status; exit explicitly and cleanly instead.
# Remove this line to re-enable the remaining stages.
sys.exit(0)

# Optional DAVIS extras: segmentation masks plus PCA-reduced DINOv2 (FeatUp) features.
# Runs only when the input path mentions "DAVIS" and the matching "Annotations" directory exists.
if "DAVIS" in args.input_dir and os.path.exists(args.input_dir.replace("1080p","Annotations")):
    # Annotation path = frame path with "1080p"->"Annotations" and "jpg"->"png"; only channel 0 is kept.
    seg_images = [
        torchvision.transforms.ToTensor()(plt.imread(path.replace("1080p","Annotations").replace("jpg","png")))[0]
        for path in tqdm(image_paths, desc="Loading images")
    ]
    # Mirror the resize steps that were applied to the RGB frames.
    seg_images = [F.interpolate(x[None,None],(miny,minx))[0] for x in seg_images]
    seg_images = torch.stack( seg_images )
    if images.size(-1)>1000 or images.size(-2)>1000:
        seg_images = F.interpolate(seg_images,[640, 1024][::[-1,1][images.size(-1)>images.size(-2)]],mode="bilinear")
    seg_images = F.interpolate(seg_images,(seg_images.size(-2)//16 * 16,seg_images.size(-1)//16 * 16),mode="bilinear")
    # Saved as a boolean mask: any non-zero value counts as foreground.
    # The guard previously tested "seg_img.pt" while the file written is "seg_imgs.pt".
    if not os.path.exists(outdir/"seg_imgs.pt") or 0: torch.save((seg_images!=0).cpu(),outdir/"seg_imgs.pt")

    # Dino features for baseline exp
    if not os.path.exists(outdir/"dino_feats.pt") or 0:
        upsampler = torch.hub.load("mhamilton723/FeatUp", 'dinov2', use_norm=True).cuda()

        dino_feats=[]
        for img in tqdm(images,"dino feats est"):
            frame_hw = torch.tensor(img.shape[1:])
            size_mult14 = (frame_hw//14*14).tolist()  # input side lengths rounded down to a multiple of 14
            size_quarter = (frame_hw//4).tolist()     # features are stored at 1/4 of the frame size
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                # NOTE(review): frames were scaled by 255 earlier but are divided by 256 here — confirm intended.
                feats = upsampler(F.interpolate(img[None],size_mult14)/256)
                dino_feats.append( F.interpolate(feats,size_quarter)[0] )
        dino_feats=torch.stack(dino_feats)

        # pca
        n_components=8 # keep the first 8 components, probably only need first 3 tbh
        data = dino_feats.permute(0,2,3,1).flatten(0,-2)  # one row per pixel, one column per feature channel
        data_mean = data.mean(dim=0, keepdim=True)
        centered_data = data - data_mean
        U, S, Vt = torch.linalg.svd(centered_data, full_matrices=False)
        reduced_data = centered_data @ Vt[:n_components].T
        reduced_data = reduced_data.unflatten(0,dino_feats.permute(0,2,3,1)[...,0].shape).permute(0,3,1,2)

        # NOTE(review): max(min, max) always equals the maximum; if the intent was to scale
        # into [-1, 1] this should probably be reduced_data.abs().max() — left unchanged.
        reduced_data = reduced_data / max(reduced_data.min(),reduced_data.max())

        torch.save(reduced_data,outdir/"dino_feats.pt")
        dino_vis=torchvision.utils.make_grid(reduced_data[:,:3],normalize=True).permute(1,2,0)
        plt.imsave(outdir/"dino_vis.png",dino_vis.cpu().numpy())
        # This message previously reported depth_vis.png although the file written is dino_vis.png.
        print("Saved dino vis to ",outdir/"dino_vis.png")

# Camera intrinsics (geocalib or dataset providing)
# NOTE: both branches below are currently switched off by their trailing `and 0`.
if "Bimanual" in str(imgdir) and 0: # tri
    # Focal entry [0,0] of the .npy stored next to the first frame, divided by the frame width.
    #focal = torch.tensor([np.load(str(imgdir).replace("rgb","intrinsics")+"/0000000000.npy")[0,0]/np.array(Image.open(image_paths[0])).shape[1]])
    focal = torch.tensor([np.load(image_paths[0].replace("jpg","npy"))[0,0]/np.array(Image.open(image_paths[0])).shape[1]])
    torch.save(focal.cpu(),outdir/"intrinsics.pt")
    # Also load in depths since we know it's TRI data
    # The guard previously tested "depth.pt" while the file written is "depths.pt".
    if not os.path.exists(outdir/"depths.pt") or 0:
        #depths = [torch.from_numpy(np.load(str(path).replace("rgb","depth").replace("jpg","npz"))["data"]) for path in tqdm(image_paths, desc="Loading depths")]
        depths = [torch.from_numpy(np.load(path.replace("jpg","npz"))["data"]) for path in tqdm(image_paths, desc="Loading depths")]
        # Same two-step resize as the RGB frames: common minimum size, then the working resolution.
        depths = [F.interpolate(depth[None,None],(miny,minx))[0,0] for depth in depths]
        depths = torch.stack( depths )
        depths= F.interpolate(depths[:,None],(images.size(-2)//16 * 16,images.size(-1)//16 * 16),mode="bilinear").squeeze(1)

        # Preview: colour-map min/(depth+1) with "magma" for the first and last frame.
        depth_vis=depths[:,None].cpu().numpy()+1
        # plt.get_cmap replaces matplotlib.cm.get_cmap, which was removed in Matplotlib 3.9.
        depth_vis = torch.from_numpy(plt.get_cmap('magma')(depth_vis.min().item()/depth_vis)).squeeze(1).permute(0,3,1,2)*255
        depth_vis = torchvision.utils.make_grid(depth_vis[[0,-1]])[:3].permute(1,2,0).cpu().numpy()/255
        plt.imsave(outdir/"depth_vis.png",depth_vis)

        torch.save(depths.cpu(),outdir/"depths.pt")
elif not os.path.exists(outdir/"intrinsics.pt") and 0:  # geocalib
    # Same GeoCalib estimate as the one done earlier in the script, guarded by the cache file.
    from geocalib import GeoCalib
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = GeoCalib().to(device)
    image = model.load_image(image_paths[0]).to(device)
    result = model.calibrate(image, camera_model="pinhole")
    focal = result["camera"].f[0,0]/np.array(Image.open(image_paths[0])).shape[1]
    torch.save(focal.cpu(),outdir/"intrinsics.pt")

# Optical flow (GMFlow; the commented-out lines below are a torchvision RAFT alternative)
# Computes flow for every consecutive frame pair unless bwd_flow.pt is already cached.
if not os.path.exists(outdir/"bwd_flow.pt") or 0:

    sys.path.append("/home/cameronsmith/repos/refactor_flowmap_and_splat/gmflow/")
    from gmflow.gmflow import GMFlow
    gm_flow = GMFlow(feature_channels=128, num_scales=1, upsample_factor=8, num_head=1, attention_type="swin", ffn_dim_expansion=4, num_transformer_layers=6,).cuda()
    checkpoint = torch.load("/home/cameronsmith/repos/refactor_flowmap_and_splat/gmflow/gmflow-scale1-mixdata-train320x576-4c3a6e9a.pth")
    # Checkpoints may wrap the weights under a 'model' key.
    weights = checkpoint['model'] if 'model' in checkpoint else checkpoint
    gm_flow.load_state_dict(weights, strict=False)
    for param in gm_flow.parameters(): param.requires_grad = False
    # NOTE(review): gm_flow is never put in eval() mode — confirm whether that matters for this model.

    #from torchvision.models.optical_flow import Raft_Large_Weights
    #raft_transforms = Raft_Large_Weights.DEFAULT.transforms()
    #from torchvision.models.optical_flow import raft_large
    #raft = raft_large(weights=Raft_Large_Weights.DEFAULT, progress=True).cuda()

    # (W-1, H-1) with the frames' dtype/device, shaped (1, 2, 1, 1).
    # Loop-invariant, so built once instead of once per frame pair.
    flow_scale = (torch.tensor(images.shape[-2:][::-1])-1).to(images)[None,:,None,None]

    bwd_flow=[]
    # Each pair is (frame t+1, frame t), i.e. flow from the later frame to the earlier one.
    # zip() has no length, so give tqdm the pair count explicitly.
    for x,y in tqdm(zip(images[1:],images[:-1]),"Optical flow est",total=len(images)-1):
        #raft_inp_x,raft_inp_y = raft_transforms(x[None].to(torch.uint8), y[None].to(torch.uint8))
        #bwd_flow_=tmp= raft(raft_inp_x,raft_inp_y,num_flow_updates=32)[-1]

        bwd_flow_=gm_flow(x[None],y[None],attn_splits_list=[2],corr_radius_list=[-1],prop_radius_list=[-1],pred_bidir_flow=False)["flow_preds"][-1]
        # Express the flow as a fraction of (W-1, H-1). This is NOT a [-1, 1] normalization:
        # a displacement across the whole image maps to magnitude 1.
        # assumes channel 0 is the x displacement and channel 1 the y displacement — TODO confirm
        bwd_flow.append( bwd_flow_/flow_scale )
    bwd_flow=torch.cat(bwd_flow)

    # save flow and vis
    torch.save(bwd_flow.cpu(),outdir/"bwd_flow.pt")
    flow_vis = flow_vis_torch.flow_to_color(torchvision.utils.make_grid(bwd_flow[[0,-1]])).permute(1,2,0).cpu().numpy()/255
    plt.imsave(outdir/"flow_vis.png",flow_vis)
    print("Saved flow vis to ",outdir/"flow_vis.png")
else:
    # Reuse the cached flow, trimmed to the max_i-1 pairs of the current frame cap.
    bwd_flow = torch.load(outdir/"bwd_flow.pt").cuda()[:max_i-1]

def compute_sampson_error(x1, x2, F_):
    """Return the Sampson (first-order epipolar) error of each correspondence.

    :param x1 (*, N, 2) points in the first view
    :param x2 (*, N, 2) matching points in the second view
    :param F_ (*, 3, 3) fundamental matrix, applied as x2^T F_ x1
    :return (*, N) squared epipolar residual divided by the squared norms of the
        first two components of F_ x1 and F_^T x2. No epsilon is added, so a zero
        denominator yields inf/nan.
    """
    # Homogeneous coordinates: append a 1 to every point.
    h1 = torch.cat([x1, torch.ones_like(x1[..., :1])], dim=-1)
    h2 = torch.cat([x2, torch.ones_like(x2[..., :1])], dim=-1)
    d1 = torch.matmul(h1, F_.transpose(-1, -2))  # (B, N, 3), rows are F_ @ h1
    d2 = torch.matmul(h2, F_)  # (B, N, 3), rows are F_^T @ h2
    z = (h2 * d1).sum(dim=-1)  # (B, N), the residual h2^T F_ h1
    err = z ** 2 / ( d1[..., 0] ** 2 + d1[..., 1] ** 2 + d2[..., 0] ** 2 + d2[..., 1] ** 2)
    return err


# Flow rigid thresholding for nonrigid motion approximation
# For every frame pair: fit a fundamental matrix to the flow correspondences and mark
# pixels whose Sampson error is below a global quantile.
if not os.path.exists(outdir/"rig_flow_masks.pt") or 0:
    flow_errs=[]  # never filled; kept only so the name stays defined

    # Work on a quarter-resolution copy of the flow.
    bwd_flow_low=torch.nn.functional.interpolate(bwd_flow,scale_factor=.25,mode="bilinear")

    H, W = bwd_flow_low[0].shape[-2:]
    device="cpu"
    # Pixel-centre grid mapped to [-1, 1] on both axes, flattened to (H*W, 2) as (x, y).
    yy, xx = torch.meshgrid( torch.arange(H, dtype=torch.float32, device=device), torch.arange(W, dtype=torch.float32, device=device), indexing="ij",)
    xx = 2 * (xx + 0.5) / W - 1
    yy = 2 * (yy + 0.5) / H - 1
    x1 = torch.stack([xx, yy], dim=-1).flatten(0,1)
    x1_np = x1.numpy()  # the source points are identical for every pair

    # NOTE(review): x1 spans [-1, 1] (2 units per image side) while the flow saved by this
    # script is divided by (W-1, H-1) (1 unit per image side). Confirm whether the flow
    # should be scaled by 2 before being added to x1.
    errs=[]
    for i,flow_chw in enumerate(tqdm(bwd_flow_low,leave=False,desc="Estimating fundamental matrices")):
        flow=flow_chw.permute(1,2,0).cpu()

        x2 = x1 + flow.flatten(0,1)  # (H*W, 2)

        F_ = cv2.findFundamentalMat(x1_np, x2.numpy(), cv2.FM_LMEDS)[0]  # (3, 3)
        if F_ is None:
            # OpenCV returns None when no model can be estimated; fail with a readable
            # message instead of a TypeError raised inside torch.from_numpy.
            raise RuntimeError(f"cv2.findFundamentalMat found no fundamental matrix for flow pair {i}")
        # Scale the error by the squared mean side length of the low-res grid.
        errs.append( compute_sampson_error(x1, x2, torch.from_numpy(F_).float()) * ((H + W) / 2) ** 2 )

    # One global threshold: the 82nd percentile of all errors over all pairs.
    all_errs=torch.stack(errs)
    thresh=torch.quantile(all_errs.flatten(),.82)
    # True where the error is below the threshold; shape (num_pairs, 1, H, W).
    masks_all=(all_errs<thresh).reshape(len(errs),1,H,W)
    #mask_vis_path = outdir/"rig_mask_vis.png"
    #print("Saved rig mask vis to ",mask_vis_path)
    #plt.imsave(mask_vis_path,torchvision.utils.make_grid(masks_all).permute(1,2,0).float().numpy())
    torch.save(masks_all.cpu(),outdir/"rig_flow_masks.pt")
else:
    # Reuse the cached masks, trimmed to the max_i-1 pairs of the current frame cap.
    masks_all = torch.load(outdir/"rig_flow_masks.pt").cuda()[:max_i-1]

# Pointtracks (cotracker3)
# Dense grid tracking with the offline CoTracker3 model from several query frames;
# results are cached in pred_tracks_offline.pt.
from cotracker.utils.visualizer import Visualizer

video = images[None]  # frames with a leading batch axis, as passed to the tracker
# use 64 and 8 frames for dense track e.g. for pretty 4d reconstruction, use 42 for faster processing for large-scale learning
grid_size = 64
n_track_frames = 8
tracks_cache = outdir/"pred_tracks_offline.pt"
if not os.path.exists(tracks_cache) or 0: # TODO use sky mask as depth>thresh to mask out sky points

    # Offline cotracker (better but expensive for long videos)
    cotracker = torch.hub.load("facebookresearch/co-tracker", "cotracker3_offline").cuda()

    # Query frames: n_track_frames indices spread evenly from the first to the last frame.
    query_frames = torch.linspace(0, len(images)-1, n_track_frames).int().tolist()
    # (W-1, H-1) broadcast against the trailing coordinate axis of the tracks.
    track_scale = (torch.tensor(images.shape[-2:][::-1])-1)[None,None,None].cuda()

    pred_tracks_all = []
    pred_visibility_all = []
    for i, start_frame in enumerate(tqdm(query_frames, "Cotracker point tracking")):
        pred_tracks, pred_visibility = cotracker(
            video,
            grid_size=grid_size,
            backward_tracking=True,
            grid_query_frame=start_frame,
        )
        #vis = Visualizer(save_dir=outdir, pad_value=120, linewidth=3).visualize(images[None], pred_tracks, pred_visibility,filename='tracks_%02d'%i,query_frame=start_frame)
        pred_tracks_norm = pred_tracks / track_scale
        pred_tracks_all.append(pred_tracks_norm)
        pred_visibility_all.append(pred_visibility)

    # Stack the per-query results along a new leading axis and cache them on the CPU.
    pred_tracks_all = torch.stack(pred_tracks_all)
    pred_visibility_all = torch.stack(pred_visibility_all)
    torch.save((pred_tracks_all.cpu(), pred_visibility_all.cpu()), tracks_cache)

    #cotracker = torch.hub.load("facebookresearch/co-tracker", "cotracker3_online").cuda()
    ## run cotracker on N frames
    ##for i,start_frame in enumerate(tqdm(torch.linspace(cotracker.step+2,len(images)-(cotracker.step+2), 3).int().tolist(),"Cotracker point tracking")):
    #for i,start_frame in enumerate(tqdm(torch.linspace(0,len(images)-cotracker.step, len(images)//8).int().tolist(),"Cotracker point tracking")):
    #    # manually separate into backward and forward videos since online tracker doesn't support backward tracking and start frame
    #    video_left,video_right = video[:,:start_frame],video[:,start_frame:]
    #    if video_left.size(1)<=cotracker.step: video_right=video
    #    if video_right.size(1)<=cotracker.step: video_left=video
    #    video_left=torch.flip(video_left,dims=[1])
    #    # left
    #    if video_left.size(1)>cotracker.step:
    #        cotracker(video_chunk=video_left, is_first_step=True, grid_size=grid_size)
    #        for ind in range(0, video_left.shape[1] - cotracker.step, cotracker.step):
    #            pred_tracks_left, pred_visibility_left = cotracker( video_chunk=video_left[:, ind : ind + cotracker.step * 2])  # B T N 2,  B T N 1
    #        pred_tracks_left,pred_visibility_left = torch.flip(pred_tracks_left,dims=[1]),torch.flip(pred_visibility_left,dims=[1])
    #    # right
    #    if video_right.size(1)>cotracker.step:
    #        cotracker(video_chunk=video_right, is_first_step=True, grid_size=grid_size)
    #        for ind in range(0, video_right.shape[1] - cotracker.step, cotracker.step):
    #            pred_tracks_right, pred_visibility_right = cotracker( video_chunk=video_right[:, ind : ind + cotracker.step * 2])  # B T N 2,  B T N 1

    #    if video_left.size(1)<=cotracker.step: pred_tracks, pred_visibility = pred_tracks_right, pred_visibility_right
    #    elif video_right.size(1)<=cotracker.step: pred_tracks,pred_visibility  = pred_tracks_left, pred_visibility_left
    #    else: pred_tracks,pred_visibility  = torch.cat((pred_tracks_left,pred_tracks_right),1),torch.cat((pred_visibility_left,pred_visibility_right),1)
    #    # concate them
    #    #print("saving vid")
    #    #if i%2==0:vis = Visualizer(save_dir=outdir, pad_value=120, linewidth=3).visualize(video, pred_tracks, pred_visibility,filename='tracks_more_%02d'%i,query_frame=start_frame)
    #    #print("done vid")

    #    #print(pred_tracks.shape)
    #    pred_tracks_norm = pred_tracks/(torch.tensor(images.shape[-2:][::-1])-1)[None,None,None].cuda()
    #    pred_tracks_all.append(pred_tracks_norm)
    #    pred_visibility_all.append(pred_visibility)
    #pred_tracks_all,pred_visibility_all = torch.stack(pred_tracks_all),torch.stack(pred_visibility_all)
    #torch.save((pred_tracks_all.cpu(),pred_visibility_all.cpu()),outdir/"pred_tracks_more.pt")
if 1:
    # Depth maps and focal (ml-pro)
    # Per-frame Depth Pro inference; the (depths, focal) pair is cached in depth_ests.pt.
    depth_cache = outdir/"depth_ests.pt"
    if not os.path.exists(depth_cache) or 0:
        import depth_pro

        # Point Depth Pro at the local checkpoint before the model is built.
        depth_pro.depth_pro.DEFAULT_MONODEPTH_CONFIG_DICT.checkpoint_uri="./output/depth_pro.pt"
        model, transform = depth_pro.create_model_and_transforms()
        model.eval()
        model = model.cuda()

        depth_ests = []
        focal_ests = []
        for image_path in tqdm(image_paths, desc="Doing depth est"):
            image, _, f_px = depth_pro.load_rgb(image_path)
            image = transform(image)

            # Run inference.
            prediction = model.infer(image.cuda(), f_px=f_px)

            # Resample the predicted depth to the working resolution of `images`, keep it on the CPU.
            depth_resized = F.interpolate(prediction["depth"][None, None], images.shape[-2:], mode="bilinear")
            depth_ests.append(depth_resized[0, 0].cpu())
            focal_ests.append(torch.tensor([prediction["focallength_px"]]))

        depth_ests, focal_ests = torch.stack(depth_ests), torch.stack(focal_ests)
        # Median focal length over all frames divided by the last axis of the last transformed image.
        focal = focal_ests.median().item() / image.size(-1)

        torch.save((depth_ests, focal), depth_cache)

        # Preview grid of every frame, displayed as the reciprocal of the stored values.
        depth_grid = torchvision.utils.make_grid(depth_ests[:, None], pad_value=depth_ests.median())[0].numpy()
        depth_vis = 1 / (1e-2 + depth_grid)
        plt.imsave(outdir/"depth_vis.png", depth_vis)
        print("Saved depth vis to ", outdir/"depth_vis.png")
    else:
        # Cached run: only the depth maps (element 0 of the saved tuple) are restored.
        depth_ests = torch.load(depth_cache)[0].cuda()[:max_i]

