import torch

from glob import glob
from matplotlib import cm
import os,sys
import matplotlib.pyplot as plt 
from pathlib import Path
from torch.nn import functional as F
from PIL import Image
from tqdm import tqdm
import torchvision
import flow_vis_torch
import cv2

import numpy as np

torch.set_grad_enabled(False)  # inference-only script: no autograd graphs anywhere below

# Video-Depth-Anything model (disabled). It is referenced only by the `if 0:` depth branch
# inside the directory loop; re-enable both together.
#sys.path.append("/home/cameronsmith/repos/Video-Depth-Anything/")
#from video_depth_anything.video_depth import VideoDepthAnything
#model_configs = {
#    'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
#    'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
#}
#video_depth_anything = VideoDepthAnything(**model_configs["vitl"])
#video_depth_anything.load_state_dict(torch.load(f'/home/cameronsmith/repos/Video-Depth-Anything/checkpoints/video_depth_anything_vitl.pth', map_location='cpu'), strict=True)
#video_depth_anything = video_depth_anything.cuda().eval()

# Single-image camera calibration (used for the focal-length estimate).
from geocalib import GeoCalib
device = "cuda" if torch.cuda.is_available() else "cpu"
geocalibmodel = GeoCalib().to(device)

# GMFlow optical flow. Put in eval mode so no train-time behaviour leaks into inference.
sys.path.append("/home/cameronsmith/repos/refactor_flowmap_and_splat/gmflow/")
from gmflow.gmflow import GMFlow
gm_flow = GMFlow(feature_channels=128, num_scales=1, upsample_factor=8, num_head=1, attention_type="swin", ffn_dim_expansion=4, num_transformer_layers=6,).cuda().eval()
# Load the checkpoint on CPU; load_state_dict copies the tensors onto the model's (CUDA) parameters.
checkpoint = torch.load("/home/cameronsmith/repos/refactor_flowmap_and_splat/gmflow/gmflow-scale1-mixdata-train320x576-4c3a6e9a.pth", map_location="cpu")
weights = checkpoint['model'] if 'model' in checkpoint else checkpoint
# strict=False tolerates key mismatches. Report missing keys so that a wrong or renamed
# checkpoint does not silently leave parts of the network randomly initialised.
load_result = gm_flow.load_state_dict(weights, strict=False)
if load_result.missing_keys:
    print(f"WARNING: GMFlow checkpoint is missing {len(load_result.missing_keys)} keys, e.g. {load_result.missing_keys[:5]}")
for param in gm_flow.parameters(): param.requires_grad = False

# CoTracker3 (offline) point tracker from torch.hub.
from cotracker.utils.visualizer import Visualizer
cotracker = torch.hub.load("facebookresearch/co-tracker", "cotracker3_offline").cuda().eval()
# Models are created once above; loop over the input directories here (rather than the
# reverse) so that every model is loaded only once.
outdir=Path("/data/cameron/monocular_ests/test")  # NOTE: rebound per directory inside the loop

for input_dir in glob("/data/cameron/LBM11/flattened_robotdirs_wrist/*/*"):
    # One output directory per input directory, named after the full input path.
    outdir="/data/cameron/monocular_ests/robot_real_moredirs/"+input_dir.replace("/","-")
#for input_dir in ["/data/cameron/LBM11/BimanualHangMugsOnMugHolderFromDryingRack_riverway_real//BimanualHangMugsOnMugHolderFromDryingRack_riverway_real_2024-09-30T15-22-31-04-00_000066/rgb/wrist_right_minus/"]:
#for input_dir in glob("/data/cameron/LBM11/*/*/rgb/wrist_right_minus/"):
#    outdir="/data/cameron/monocular_ests/robot_grippercam_test/"+input_dir.replace("/","-")
    os.makedirs(outdir,exist_ok = True)
    outdir=Path(outdir)

    # Use the first max_i frames (sorted by filename); skip directories with fewer frames.
    max_i=40
    #image_paths = sorted(list(glob(str("/data/cameron/LBM11/BimanualHangMugsOnMugHolderFromDryingRack_riverway_real/BimanualHangMugsOnMugHolderFromDryingRack_riverway_real_2024-09-30T15-22-31-04-00_000000/rgb/scene_left/*.jpg"))))[:max_i]
    image_paths = sorted(glob(str(input_dir)+"/*.jpg"))[:max_i]
    if len(image_paths)<max_i:continue
    #image_paths = sorted(list(glob(str("/data/DAVIS/1080p/bear/*.jpg"))))[:max_i]

    # Load frames as float RGB tensors in [0, 1]. The context manager closes each file handle.
    # convert("RGB") guarantees 3 channels even for grayscale/RGBA JPEGs, so the frames stack
    # and GMFlow/CoTracker receive RGB input.
    images = []
    for path in tqdm(image_paths, desc="Loading images"):
        with Image.open(path) as im:
            images.append(torchvision.transforms.ToTensor()(im.convert("RGB")))
    # Bring every frame to the smallest height/width in the sequence (F.interpolate default: nearest).
    miny,minx=min(x.size(1) for x in images),min(x.size(2) for x in images)
    images = torch.stack([F.interpolate(x[None],(miny,minx))[0] for x in images])
    # Cap the resolution: frames larger than 1000px on a side are resized to 640x1024 (H x W)
    # when landscape, or 1024x640 when portrait or square.
    if images.size(-1)>1000 or images.size(-2)>1000:
        target_hw = (640, 1024) if images.size(-1) > images.size(-2) else (1024, 640)
        images = F.interpolate(images, target_hw, mode="bilinear")
    # Round H and W down to multiples of 16 (presumably required by GMFlow with
    # upsample_factor=8 and attn_splits=2 — TODO confirm).
    images= F.interpolate(images,(images.size(-2)//16 * 16,images.size(-1)//16 * 16),mode="bilinear")
    images=images.cuda()*255  # downstream models are fed float frames in [0, 255]
    if not os.path.exists(outdir/"imgs.pt") or 0: torch.save(images.cpu(),outdir/"imgs.pt")

    # Video-Depth-Anything depth (disabled; `video_depth_anything` is commented out at the top).
    if 0:
        if not os.path.exists(outdir/"depth_ests.pt") or 0:
            depth_ests, fps = video_depth_anything.infer_video_depth(images.to(torch.uint8).permute(0,2,3,1).cpu().numpy(), 24, input_size=518, device="cuda", fp32=True)
            depth_ests=torch.from_numpy(1/(1e-3+depth_ests))
            torch.save(depth_ests,outdir/"depth_ests.pt")
            depth_vis=1/(1e-2+torchvision.utils.make_grid(depth_ests[[0,-1],None],pad_value=depth_ests.median())[0].numpy())
            plt.imsave(outdir/"depth_vis.png",depth_vis)
            print("Saved depth vis to ",outdir/"depth_vis.png")

    # Focal length from GeoCalib on the first frame, normalized by that frame's original
    # (unresized) width.
    if not os.path.exists(outdir/"intrinsics.pt") or 0:  # geocalib
        image = geocalibmodel.load_image(image_paths[0]).cuda()
        result = geocalibmodel.calibrate(image, camera_model="pinhole")
        # PIL's size is (width, height); reading it avoids decoding the whole image.
        with Image.open(image_paths[0]) as first_im:
            orig_width = first_im.size[0]
        focal = result["camera"].f[0,0]/orig_width
        torch.save(focal.cpu(),outdir/"intrinsics.pt")

    # Backward optical flow (GMFlow): entry i is the flow from frame i+1 to frame i.
    if not os.path.exists(outdir/"bwd_flow.pt") or 0:
        # Divide x-displacements by (W-1) and y-displacements by (H-1). The full image span
        # therefore maps to 1 (a [0,1]-style scale, not [-1,1]). GMFlow output presumably in
        # pixels — TODO confirm. Hoisted: the same divisor is used for every frame pair.
        flow_scale = (torch.tensor(images.shape[-2:][::-1])-1).to(images)[None,:,None,None]
        bwd_flow=[]
        for x,y in tqdm(zip(images[1:],images[:-1]),"Optical flow est"):
            bwd_flow_=gm_flow(x[None],y[None],attn_splits_list=[2],corr_radius_list=[-1],prop_radius_list=[-1],pred_bidir_flow=False)["flow_preds"][-1]
            bwd_flow.append(bwd_flow_/flow_scale)
        bwd_flow=torch.cat(bwd_flow)

        # save flow and vis (first and last frame pair)
        torch.save(bwd_flow.cpu(),outdir/"bwd_flow.pt")
        flow_vis = flow_vis_torch.flow_to_color(torchvision.utils.make_grid(bwd_flow[[0,-1]])).permute(1,2,0).cpu().numpy()/255
        plt.imsave(outdir/"flow_vis.png",flow_vis)
        print("Saved flow vis to ",outdir/"flow_vis.png")
    else:
        bwd_flow = torch.load(outdir/"bwd_flow.pt").cuda()[:max_i-1]

    # Rigid-motion masks from flow. For each frame pair, fit a fundamental matrix to the
    # (4x downsampled) flow correspondences, and mark pixels whose Sampson error is below a
    # global quantile threshold as rigid. The rest approximate non-rigid motion.
    if not os.path.exists(outdir/"rig_flow_masks.pt") or 0:

        def compute_sampson_error(x1, x2, F_):
            """First-order (Sampson) epipolar error of correspondences x1 <-> x2 under F_.

            :param x1 (*, N, 2) points in the first view
            :param x2 (*, N, 2) points in the second view
            :param F (*, 3, 3) fundamental matrix with [x2;1]^T F [x1;1] = 0
            :return (*, N) Sampson error per correspondence
            """
            h1 = torch.cat([x1, torch.ones_like(x1[..., :1])], dim=-1)
            h2 = torch.cat([x2, torch.ones_like(x2[..., :1])], dim=-1)
            d1 = torch.matmul(h1, F_.transpose(-1, -2))  # rows are (F h1)^T, (B, N, 3)
            d2 = torch.matmul(h2, F_)  # rows are (F^T h2)^T, (B, N, 3)
            z = (h2 * d1).sum(dim=-1)  # algebraic error h2^T F h1, (B, N)
            denom = d1[..., 0] ** 2 + d1[..., 1] ** 2 + d2[..., 0] ** 2 + d2[..., 1] ** 2
            # Guard 0/0: a single NaN would make the quantile threshold below NaN.
            return z ** 2 / denom.clamp_min(1e-12)

        # Downsample once and move to CPU (cv2 needs numpy arrays).
        bwd_flow_low=F.interpolate(bwd_flow,scale_factor=.25,mode="bilinear").cpu()

        H, W = bwd_flow_low.shape[-2:]
        # Pixel-centre coordinates of the low-res grid in [-1, 1] (align_corners=False convention).
        # Tensors are created on the CPU by default; the module-level `device` is left untouched.
        yy, xx = torch.meshgrid( torch.arange(H, dtype=torch.float32), torch.arange(W, dtype=torch.float32), indexing="ij",)
        xx = 2 * (xx + 0.5) / W - 1
        yy = 2 * (yy + 0.5) / H - 1
        x1 = torch.stack([xx, yy], dim=-1).flatten(0,1)  # (H*W, 2)
        # bwd_flow stores displacements divided by (size-1) of the full-res frames, so a full
        # image span is 1, whereas x1 spans 2. Convert to x1's units before adding:
        # pixels = flow * (size-1); [-1,1] units = 2 * pixels / size.
        full_wh = torch.tensor(bwd_flow.shape[-2:][::-1], dtype=torch.float32)  # (W, H)
        flow_to_ndc = 2 * (full_wh - 1) / full_wh
        errs=[]
        for i in tqdm(range(len(bwd_flow)),leave=False,desc="Estimating fundamental matrices"):
            flow=bwd_flow_low[i].permute(1,2,0)  # (H, W, 2)

            x2 = x1 + flow.flatten(0,1) * flow_to_ndc  # (H*W, 2)

            F_ = cv2.findFundamentalMat(x1.numpy(), x2.numpy(), cv2.FM_LMEDS)[0]  # (3, 3) or None
            if F_ is None or F_.shape != (3, 3):
                # No model found (e.g. degenerate / near-zero flow): treat the pair as fully
                # rigid instead of crashing on torch.from_numpy(None).
                print(f"WARNING: fundamental matrix estimation failed for pair {i} in {outdir}")
                errs.append(torch.zeros(H * W))
                continue
            # The scale factor only roughly converts to squared pixels. It does not affect
            # the masks, because the threshold below is a quantile.
            errs.append( compute_sampson_error(x1, x2, torch.from_numpy(F_).float()) * ((H + W) / 2) ** 2 )
        # One global threshold over all pairs: ~82% of pixels are marked rigid (err < thresh).
        thresh=torch.quantile(torch.stack(errs).flatten(),.82)
        masks_all=torch.stack([(err<thresh).reshape(H,W)[None] for err in errs])  # (T-1, 1, H, W) bool
        #mask_vis_path = outdir/"rig_mask_vis.png"
        #print("Saved rig mask vis to ",mask_vis_path)
        #plt.imsave(mask_vis_path,torchvision.utils.make_grid(masks_all).permute(1,2,0).float().numpy())
        torch.save(masks_all.cpu(),outdir/"rig_flow_masks.pt")
        # Match the cached branch below, which loads masks_all onto CUDA.
        masks_all = masks_all.cuda()
    else:
        masks_all = torch.load(outdir/"rig_flow_masks.pt").cuda()[:max_i-1]

    # Point tracks from CoTracker3 (offline mode).
    video = images[None]
    # grid_size: CoTracker's query-grid size. n_track_frames: number of evenly spaced frames
    # used as query frames. Suggested: 64 with 8 frames for dense tracks (e.g. pretty 4D
    # reconstruction); 42 for faster processing in large-scale learning.
    grid_size,n_track_frames=64,2
    if not os.path.exists(outdir/"pred_tracks_offline.pt") or 0: # TODO use sky mask as depth>thresh to mask out sky points
        # Divide track x by (W-1) and y by (H-1): same normalization as the saved bwd_flow.
        track_norm = (torch.tensor(images.shape[-2:][::-1])-1)[None,None,None].cuda()
        query_frames = torch.linspace(0,len(images)-1, n_track_frames).int().tolist()

        pred_tracks_all,pred_visibility_all=[],[]
        # Offline CoTracker (better but expensive for long videos): one call per query frame.
        for i,start_frame in enumerate(tqdm(query_frames,"Cotracker point tracking")):
            pred_tracks, pred_visibility = cotracker(images[None], grid_size=grid_size, backward_tracking=True, grid_query_frame=start_frame)
            #vis = Visualizer(save_dir=outdir, pad_value=120, linewidth=3).visualize(images[None], pred_tracks, pred_visibility,filename='tracks_%02d'%i,query_frame=start_frame)
            pred_tracks_norm = pred_tracks/track_norm
            pred_tracks_all.append(pred_tracks_norm)
            pred_visibility_all.append(pred_visibility)
        pred_tracks_all = torch.stack(pred_tracks_all)
        pred_visibility_all = torch.stack(pred_visibility_all)
        torch.save((pred_tracks_all.cpu(),pred_visibility_all.cpu()),outdir/"pred_tracks_offline.pt")
    # NOTE(review): the disabled cache branch below reads "pred_tracks_more.pt", not the
    # "pred_tracks_offline.pt" written above — fix the filename before re-enabling it.
    #else:
    #    pred_tracks_all,pred_visibility_all = [x[:max_i].cuda() for x in torch.load(outdir/"pred_tracks_more.pt")]
    #    #pred_tracks_all=pred_tracks_norm*(torch.tensor(images.shape[-2:][::-1])-1)[None,None,None].cuda()

