from glob import glob
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision.utils import make_grid
from tqdm import tqdm

def compute_sampson_error(x1, x2, F):
    """
    Sampson (first-order) approximation of the epipolar error per correspondence.

    :param x1 (*, N, 2) points in the first image
    :param x2 (*, N, 2) matching points in the second image
    :param F (*, 3, 3) fundamental matrix, convention x2^T F x1 = 0
    :return (*, N) error per correspondence. No epsilon is added, so a zero
        denominator (e.g. an all-zero F) yields inf/nan.
    """
    def lift(pts):
        # Append a homogeneous coordinate of 1 to every 2-D point.
        return torch.cat([pts, torch.ones_like(pts[..., :1])], dim=-1)

    p1, p2 = lift(x1), lift(x2)
    line_in_2 = p1 @ F.transpose(-1, -2)  # each row is F @ p1, shape (*, N, 3)
    line_in_1 = p2 @ F                    # each row is F^T @ p2, shape (*, N, 3)
    algebraic = torch.sum(p2 * line_in_2, dim=-1)  # x2^T F x1, shape (*, N)
    grad_norm_sq = (
        line_in_2[..., 0].square()
        + line_in_2[..., 1].square()
        + line_in_1[..., 0].square()
        + line_in_1[..., 1].square()
    )
    return algebraic.square() / grad_norm_sq

# Directory the thresholded error grids are written to.
OUT_DIR = Path("/home/cameronsmith/repos/tmp/flow_errs")
# Filename suffix -> quantile at which the per-sequence error is thresholded.
QUANTILES = {"800": 0.8, "825": 0.825, "850": 0.850, "875": 0.875, "900": 0.9}
# Robust estimator for the fundamental matrix (cv2.FM_8POINT was also tried).
FUNDAMENTAL_METHOD = cv2.FM_LMEDS


def _normalized_pixel_grid(H, W, device="cpu"):
    """Return (H*W, 2) pixel-centre coordinates (x, y) mapped to [-1, 1], row-major."""
    yy, xx = torch.meshgrid(
        torch.arange(H, dtype=torch.float32, device=device),
        torch.arange(W, dtype=torch.float32, device=device),
        indexing="ij",
    )
    xx = 2 * (xx + 0.5) / W - 1
    yy = 2 * (yy + 0.5) / H - 1
    return torch.stack([xx, yy], dim=-1).flatten(0, 1)


def _frame_sampson_error(flow, x1, H, W):
    """Fit F between the grid and its flow-displaced copy; return the (H, W) Sampson error.

    :param flow (2, H, W) flow for one frame, float32 on CPU
    :param x1 (H*W, 2) grid from _normalized_pixel_grid(H, W)
    :return (H, W) error map, or None if cv2 could not estimate a 3x3 F
    """
    # NOTE(review): x1 is in normalised [-1, 1] units, so this assumes the flow
    # is too. If it is in pixels, it needs scaling by 2/W, 2/H (and by the 0.25
    # downsampling factor) — confirm against the flow producer.
    x2 = x1 + flow.permute(1, 2, 0).flatten(0, 1)  # (H*W, 2)
    F, _ = cv2.findFundamentalMat(x1.numpy(), x2.numpy(), FUNDAMENTAL_METHOD)
    if F is None or F.shape != (3, 3):
        return None
    F = torch.from_numpy(F.astype(np.float32))  # (3, 3)
    err = compute_sampson_error(x1, x2, F).reshape(H, W)
    # Rescale the normalised-unit error by ((H + W) / 2) ** 2. One normalised
    # unit spans W/2 (resp. H/2) pixels, so this is not an exact pixel
    # conversion. Only the printed quantile depends on it, because the saved
    # masks use relative (quantile) thresholds.
    fac = (H + W) / 2
    return err * fac ** 2


for flowpath in tqdm(glob("/data/cameron/monocular_ests/*/bwd_flow.pt")):
    print(flowpath)
    # Load onto the CPU as float32: the grid lives on the CPU, and cv2 needs
    # both point arrays to share one dtype.
    flows = torch.load(flowpath, map_location="cpu").float()
    # Downsample 4x spatially and keep every 4th frame.
    flows = torch.nn.functional.interpolate(flows, scale_factor=.25, mode="bilinear")[::4]

    # The grid depends only on the (shared) spatial size, so build it once.
    H, W = flows.shape[-2:]
    x1 = _normalized_pixel_grid(H, W)

    flow_errs = []
    for flow in tqdm(flows, leave=False):
        err = _frame_sampson_error(flow, x1, H, W)
        if err is None:
            print("  findFundamentalMat failed; skipping frame")
            continue
        flow_errs.append(err)

    if not flow_errs:
        print("  no usable frames; skipping", flowpath)
        continue

    all_errs = torch.stack(flow_errs)  # (T', H, W)
    # make_grid expands 1 channel to 3; keep only the first for a 2-D map.
    flow_err_grid = make_grid(all_errs[:, None]).permute(1, 2, 0)[..., 0]
    print(torch.quantile(all_errs, 0.9))

    OUT_DIR.mkdir(parents=True, exist_ok=True)
    seq_name = Path(flowpath).parent.name
    for suffix, q in QUANTILES.items():
        mask = flow_err_grid > torch.quantile(all_errs, q)
        plt.imsave(str(OUT_DIR / f"{seq_name}_{suffix}.png"), mask.numpy())
