import os
from glob import glob

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision.utils import make_grid
from tqdm import tqdm

def compute_sampson_error(x1, x2, F, *, eps=0.0):
    """Per-correspondence Sampson (first-order geometric) epipolar error.

    Uses the convention ``x2^T F x1 = 0``: ``F`` maps points of the first view
    to epipolar lines in the second view.

    :param x1: (*, N, 2) points in the first view.
    :param x2: (*, N, 2) corresponding points in the second view.
    :param F: (*, 3, 3) fundamental matrix; its leading dims must broadcast
        with those of ``x1`` / ``x2`` under ``torch.matmul``.
    :param eps: added to the denominator. The default 0.0 reproduces the plain
        formula, which yields NaN (0/0) or inf wherever all four epipolar-line
        gradient terms vanish; pass a small positive value to avoid that.
    :returns: (*, N) squared Sampson distance, in squared units of the input
        coordinates.
    """
    # Homogeneous coordinates: append a 1 to every point.
    h1 = torch.cat([x1, torch.ones_like(x1[..., :1])], dim=-1)  # (*, N, 3)
    h2 = torch.cat([x2, torch.ones_like(x2[..., :1])], dim=-1)  # (*, N, 3)
    # Row-vector form of the epipolar lines: each row of d1 is F @ x1 (a line
    # in view 2), each row of d2 is F^T @ x2 (a line in view 1).
    d1 = torch.matmul(h1, F.transpose(-1, -2))  # (*, N, 3)
    d2 = torch.matmul(h2, F)  # (*, N, 3)
    # Algebraic epipolar residual x2^T F x1.
    z = (h2 * d1).sum(dim=-1)  # (*, N)
    denom = d1[..., 0] ** 2 + d1[..., 1] ** 2 + d2[..., 0] ** 2 + d2[..., 1] ** 2
    return z ** 2 / (denom + eps)

# Script: for each scene, fit a fundamental matrix to every FRAME_STRIDE-th
# (spatially downsampled) backward-flow field, compute the per-pixel Sampson
# error against it, and save both the raw error map and a mask of the pixels
# whose error exceeds a per-scene quantile threshold.

# To process every available scene instead of this fixed list, use
# SCENES = [p.split("/")[-2] for p in glob(f"{DATA_ROOT}/*/bwd_flow.pt")].
DATA_ROOT = "/data/cameron/monocular_ests"
SCENES = ["bear", "horns", "soapbox", "blackswan", "fern", "robotics"]
OUT_DIR = "/home/cameronsmith/repos/tmp/rig_threshs"
DOWNSAMPLE = 0.25  # spatial scale factor applied to the flow fields
FRAME_STRIDE = 8  # keep every FRAME_STRIDE-th flow field
ERR_QUANTILE = 0.82  # per-scene quantile of the error used as the threshold


def _rig_thresh_path(scene, frame_idx, kind):
    """Output PNG path for one frame: kind 0 = raw error map, 1 = thresholded mask."""
    return os.path.join(OUT_DIR, f"tmp{scene}_{frame_idx:02d}{kind:02d}.png")


os.makedirs(OUT_DIR, exist_ok=True)
device = "cpu"

for scene in tqdm(SCENES):
    flowpath = f"{DATA_ROOT}/{scene}/bwd_flow.pt"
    print(flowpath)

    # map_location: everything below (the coordinate grid, .numpy(), cv2) runs
    # on the CPU, so load there even if the tensor was saved from a GPU.
    flows = torch.load(flowpath, map_location=device)
    # Assumes flows is (T, 2, H, W) with (dx, dy) channels -- TODO confirm.
    # NOTE(review): flow values are not rescaled when the fields are resized,
    # which is only consistent if they are stored in the same normalized
    # [-1, 1] coordinates as x1 below -- confirm.
    flows = torch.nn.functional.interpolate(
        flows, scale_factor=DOWNSAMPLE, mode="bilinear"
    )[::FRAME_STRIDE]
    H, W = flows.shape[-2:]

    # Pixel-centre coordinates normalized to [-1, 1], flattened to (H*W, 2) as (x, y).
    yy, xx = torch.meshgrid(
        torch.arange(H, dtype=torch.float32, device=device),
        torch.arange(W, dtype=torch.float32, device=device),
        indexing="ij",
    )
    x1 = torch.stack(
        [2 * (xx + 0.5) / W - 1, 2 * (yy + 0.5) / H - 1], dim=-1
    ).flatten(0, 1)

    errs = []
    for i, flow in enumerate(tqdm(flows, leave=False)):
        # Displace every grid point by its flow vector to get its correspondence.
        x2 = x1 + flow.permute(1, 2, 0).flatten(0, 1)  # (H*W, 2)
        fmat, _ = cv2.findFundamentalMat(x1.numpy(), x2.numpy(), cv2.FM_LMEDS)
        if fmat is None:  # estimation failed; fail loudly instead of on .astype
            raise RuntimeError(
                f"cv2.findFundamentalMat found no fundamental matrix for "
                f"{flowpath}, frame {i}"
            )
        fmat = torch.from_numpy(fmat.astype(np.float32))  # (3, 3)
        # Scale the normalized-coordinate squared error by the squared mean
        # image side. A positive constant factor does not change which pixels
        # exceed the quantile threshold below (up to rounding).
        err = compute_sampson_error(x1, x2, fmat) * ((H + W) / 2) ** 2
        errs.append(err)

    # Per-scene threshold: the ERR_QUANTILE quantile of the error over all kept frames.
    thresh = torch.quantile(torch.stack(errs).flatten(), ERR_QUANTILE)
    for i, err in enumerate(tqdm(errs, leave=False)):
        err_img = err.reshape(H, W)
        raw_path = _rig_thresh_path(scene, i, 0)
        print(raw_path, thresh, err.median())
        plt.imsave(raw_path, err_img)
        plt.imsave(_rig_thresh_path(scene, i, 1), err_img > thresh)
