import math
import os
from glob import glob

import torch
import matplotlib.pyplot as plt
import cv2
import numpy as np
from tqdm import tqdm
from torchvision.utils import make_grid
from einops import rearrange, repeat
import flow_vis_torch

def ch_fst(src, x=None):
    """Unflatten a token axis into a channels-first spatial grid.

    Equivalent to the einops pattern ``"... (x y) c -> ... c x y"``.

    :param src: tensor of shape ``(..., x*y, c)``.
    :param x: grid height; defaults to the integer square root of the token
        count (i.e. a square grid).
    :return: tensor of shape ``(..., c, x, y)`` with ``y = tokens // x``.
    :raises RuntimeError: if the token count is not divisible by ``x``.
    """
    n_tokens = src.size(-2)
    if x is None:
        # Exact for every int, unlike int(n ** 0.5), which depends on float rounding.
        x = math.isqrt(n_tokens)
    y = n_tokens // x if x else 0
    # (..., x*y, c) -> (..., x, y, c) -> (..., c, x, y)
    return src.unflatten(-2, (x, y)).movedim(-1, -3)


def ch_sec(x):
    """Flatten a channels-first grid back to tokens: inverse of ``ch_fst``.

    Equivalent to the einops pattern ``"... c x y -> ... (x y) c"``.

    :param x: tensor of shape ``(..., c, x, y)``.
    :return: tensor of shape ``(..., x*y, c)``.
    """
    # (..., c, x, y) -> (..., x, y, c) -> (..., x*y, c)
    return x.movedim(-3, -1).flatten(-3, -2)

def compute_sampson_error(x1, x2, F):
    """
    First-order (Sampson) approximation of the geometric epipolar error.

    With homogeneous points p1 = (x1, 1) and p2 = (x2, 1):
        err = (p2^T F p1)^2 / ((F p1)_0^2 + (F p1)_1^2 + (F^T p2)_0^2 + (F^T p2)_1^2)

    :param x1 (*, N, 2) points in the first view
    :param x2 (*, N, 2) corresponding points in the second view
    :param F (*, 3, 3) fundamental matrix
    :return (*, N) per-correspondence error; the denominator has no epsilon,
        so a zero gradient yields inf or nan
    """
    def homogenize(pts):
        return torch.cat([pts, torch.ones_like(pts[..., :1])], dim=-1)

    p1, p2 = homogenize(x1), homogenize(x2)
    line_in_2 = p1 @ F.transpose(-1, -2)  # rows are F p1 (epipolar lines in view 2)
    line_in_1 = p2 @ F                    # rows are F^T p2 (epipolar lines in view 1)
    algebraic = (p2 * line_in_2).sum(dim=-1)  # p2^T F p1
    grad_sq = line_in_2[..., :2].square().sum(dim=-1) + line_in_1[..., :2].square().sum(dim=-1)
    return algebraic.square() / grad_sq

# Load the point tracks for one scene. (An earlier version looped over every
# scene via glob("/data/cameron/monocular_ests/*/bwd_flow.pt").)
scene = "horns"
flowpath = f"/data/cameron/monocular_ests/{scene}/pred_tracks_more.pt"
tracks, pred_vis = torch.load(flowpath)
# Select entry `query_f` of the leading axis (presumably the query frame the
# tracks were seeded from), then the first entry of the next axis.
# NOTE(review): presumably this leaves tracks as (T, N, 2) and pred_vis as a
# boolean (T, N) with N == sl * sl — confirm against the producer of the .pt file.
query_f = 6
tracks = tracks[query_f, 0]
pred_vis = pred_vis[query_f, 0]

# Plot tracks as flow images, one image per frame laid out in a single row.
sl = 64  # side of the square grid the tracks are unflattened onto
low_res = (sl, sl)
# uv[r, c] = (c / (W - 1), r / (H - 1)): grid coordinates normalized to [0, 1].
rows, cols = torch.meshgrid(torch.arange(low_res[0]), torch.arange(low_res[1]), indexing="ij")
uv = torch.stack((cols, rows), dim=-1) / torch.tensor([low_res[1] - 1, low_res[0] - 1])
# Track position minus the normalized coordinate of its grid slot, masked by
# visibility. NOTE(review): assumes tracks are in normalized [0, 1] coordinates.
vis_grid = ch_fst(pred_vis.unsqueeze(-1), sl)
track_grid = ch_fst(tracks * pred_vis.unsqueeze(-1), sl)
flow = (track_grid - uv.permute(2, 0, 1)[None]) * vis_grid
flow_img = flow_vis_torch.flow_to_color(make_grid(flow, nrow=1000)) / 255
plt.imsave("/home/cameronsmith/repos/tmp/vis/flow_vis.png", flow_img.permute(1, 2, 0).numpy())

# Epipolar consistency of every frame against reference frame `i`: fit a
# fundamental matrix (LMedS) to the tracks visible in both frames, then score
# every track by its Sampson error under that matrix.
i = 0
min_f_points = 8  # cv2.FM_LMEDS requires at least 8 correspondences
errs = []
for j in tqdm(range(len(tracks))):
    x1 = tracks[i]
    x2 = tracks[j]
    # Only tracks visible in both frames are used to fit F and are scored.
    valid_mask = torch.logical_and(pred_vis[i], pred_vis[j])
    no_err = torch.zeros_like(x1[..., 0])
    if j == i or int(valid_mask.sum()) < min_f_points:
        # A frame paired with itself is degenerate for F estimation, and too few
        # shared tracks cannot constrain F: report zero error for this frame.
        errs.append(no_err)
        continue
    F, _ = cv2.findFundamentalMat(x1[valid_mask].numpy(), x2[valid_mask].numpy(), cv2.FM_LMEDS)
    if F is None:  # cv2 returns None when it cannot estimate F
        errs.append(no_err)
        continue
    F = torch.from_numpy(F).to(x1.dtype)  # (3, 3), same dtype as the tracks
    err = compute_sampson_error(x1, x2, F)
    # torch.where rather than multiplication so inf/nan at masked tracks becomes 0.
    errs.append(torch.where(valid_mask, err, no_err))
errs = torch.stack(errs)  # one row of per-track errors per frame j, already masked
err_img = make_grid(ch_fst(errs.unsqueeze(-1), sl), nrow=1000)[0]
plt.imsave("/home/cameronsmith/repos/tmp/vis/err.png", err_img.numpy())
# Largest error each track reaches over all frames.
track_err_agg = errs.max(dim=0).values
plt.imsave("/home/cameronsmith/repos/tmp/vis/agg_err.png", track_err_agg.unflatten(0, (sl, sl)).numpy())
# NOTE: this quantile also counts the zero entries of masked-out tracks/frames.
print(torch.quantile(errs, 0.9))
os.makedirs("tmperrs", exist_ok=True)
torch.save(track_err_agg, f"tmperrs/{scene}.pt")
