import os
from glob import glob

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision.utils import make_grid
from tqdm import tqdm

def compute_sampson_error(x1, x2, F):
    """
    Squared Sampson error of correspondences x1 <-> x2 under fundamental matrix F.

    With homogeneous points p1 = (x1, 1) and p2 = (x2, 1) the value is

        (p2^T F p1)^2 / ((F p1)_0^2 + (F p1)_1^2 + (F^T p2)_0^2 + (F^T p2)_1^2)

    i.e. the epipolar constraint residual normalized by its gradient; it is not
    square-rooted. Leading batch dimensions broadcast through torch.matmul.

    :param x1 (*, N, 2) points in the first image
    :param x2 (*, N, 2) corresponding points in the second image
    :param F (*, 3, 3) fundamental matrix with convention p2^T F p1 = 0
    :return (*, N) per-correspondence error
    """
    def _homogenize(pts):
        # Append a trailing 1 to every 2-D point.
        return torch.cat([pts, torch.ones_like(pts[..., :1])], dim=-1)

    p1 = _homogenize(x1)
    p2 = _homogenize(x2)

    line_in_2 = p1 @ F.transpose(-1, -2)  # rows are F p1
    line_in_1 = p2 @ F                    # rows are F^T p2

    residual = (p2 * line_in_2).sum(dim=-1)
    denom = (
        line_in_2[..., 0] ** 2
        + line_in_2[..., 1] ** 2
        + line_in_1[..., 0] ** 2
        + line_in_1[..., 1] ** 2
    )
    return residual ** 2 / denom

# Glob of per-sequence backward-flow tensors to process.
FLOW_GLOB = "/data/cameron/monocular_ests/*/bwd_flow.pt"
# Directory the per-round error images are written to.
OUT_DIR = "/home/cameronsmith/repos/tmp/recursive_flow_errs"
# Spatial resize factor applied to every flow field before fitting.
FLOW_SCALE = 0.25
# Only every FRAME_STRIDE-th flow frame is processed.
FRAME_STRIDE = 4
# Number of recursive fundamental-matrix fits per frame.
NUM_ROUNDS = 4
# After each round only the pixels whose error is above this quantile (of the
# pixels still unexplained) remain unexplained for the next round.
KEEP_QUANTILE = 0.8
# Stop fitting a frame once fewer correspondences than this remain
# (the 8-point minimum for a fundamental matrix).
MIN_CORRESPONDENCES = 8


def _pixel_center_grid(H, W, device="cpu"):
    """Return the (H*W, 2) row-major grid of pixel-centre (x, y) coords scaled to [-1, 1]."""
    yy, xx = torch.meshgrid(
        torch.arange(H, dtype=torch.float32, device=device),
        torch.arange(W, dtype=torch.float32, device=device),
        indexing="ij",
    )
    xx = 2 * (xx + 0.5) / W - 1
    yy = 2 * (yy + 0.5) / H - 1
    return torch.stack([xx, yy], dim=-1).flatten(0, 1)


def _fit_fundamental(x1, x2):
    """
    Fit a fundamental matrix to (N, 2) correspondences x1 <-> x2 with LMedS.

    :return (3, 3) float32 tensor, or None when cv2 does not return a single 3x3 model
    """
    F, _ = cv2.findFundamentalMat(x1.numpy(), x2.numpy(), cv2.FM_LMEDS)
    if F is None or F.shape != (3, 3):
        return None
    return torch.from_numpy(F.astype(np.float32))


# For every sequence and sampled frame: repeatedly fit F to the pixels not yet
# explained by an earlier fit, save that round's Sampson-error map, and keep
# only the highest-error pixels for the next round.
os.makedirs(OUT_DIR, exist_ok=True)

for flowpath in tqdm(glob(FLOW_GLOB)):
    tqdm.write(flowpath)
    seq_name = flowpath.split("/")[-2]

    # Everything below runs on CPU (cv2 needs .numpy()), so load there directly.
    flows = torch.load(flowpath, map_location="cpu")
    # Drop frames before resizing: interpolate treats each frame independently,
    # so this matches resizing every frame and slicing afterwards.
    # NOTE(review): assumes flows is (T, 2, H, W) — confirm.
    flows = torch.nn.functional.interpolate(
        flows[::FRAME_STRIDE], scale_factor=FLOW_SCALE, mode="bilinear"
    )

    H, W = flows.shape[-2:]
    x1 = _pixel_center_grid(H, W)

    for i, flow in enumerate(tqdm(flows, leave=False)):
        # NOTE(review): x1 is in normalized [-1, 1] coordinates and the flow is
        # not rescaled after the resize, so this assumes the stored flow is in
        # the same normalized units — confirm.
        x2 = x1 + flow.permute(1, 2, 0).flatten(0, 1)  # (H*W, 2)
        to_be_explained = torch.ones(H * W, dtype=torch.bool)

        for j in range(NUM_ROUNDS):
            n_left = int(to_be_explained.sum())
            if n_left < MIN_CORRESPONDENCES:
                tqdm.write(
                    "%s frame %d: only %d unexplained pixels left, stopping at round %d"
                    % (seq_name, i, n_left, j)
                )
                break

            F = _fit_fundamental(x1[to_be_explained], x2[to_be_explained])
            if F is None:
                tqdm.write(
                    "%s frame %d: no fundamental matrix found, stopping at round %d"
                    % (seq_name, i, j)
                )
                break

            # Positive constant rescale of the normalized-coordinate error; it
            # changes neither the quantile mask below nor plt.imsave's default
            # min/max colour scaling.
            err = compute_sampson_error(x1, x2, F) * ((H + W) / 2) ** 2
            # Zero already-explained pixels; since thresh >= 0 they can never
            # re-enter the unexplained set.
            err[~to_be_explained] = 0
            thresh = torch.quantile(err[to_be_explained], KEEP_QUANTILE)
            to_be_explained = err > thresh

            sp = os.path.join(OUT_DIR, "tmp%s_%02d%02d.png" % (seq_name, i, j))
            tqdm.write(sp)
            plt.imsave(sp, err.reshape(H, W))
