import cuda_corr
import torch


class CorrLayer(torch.autograd.Function):
    """Autograd wrapper around the ``cuda_corr`` correlation kernels.

    Only ``fmap1`` and ``fmap2`` receive gradients; every other forward
    argument gets ``None`` from :meth:`backward`.
    """

    @staticmethod
    def forward(ctx, fmap1, fmap2, coords, ii, jj, radius, dropout):
        """Compute the correlation by calling ``cuda_corr.forward``.

        Args:
            ctx: autograd context; stores the tensors and settings that
                ``backward`` needs.
            fmap1: first feature map (differentiable input).
            fmap2: second feature map (differentiable input).
            coords: coordinates forwarded to the kernel. ``backward`` indexes
                this tensor along dim 1 with a mask of length ``len(ii)``.
            ii: index tensor forwarded to the kernel
                (presumably per-edge indices into ``fmap1`` -- TODO confirm).
            jj: index tensor forwarded to the kernel
                (presumably per-edge indices into ``fmap2`` -- TODO confirm).
            radius: correlation radius forwarded to the kernel.
            dropout: probability with which each entry is *kept* during the
                backward pass; values ``>= 1`` disable the sub-sampling.

        Returns:
            The single tensor returned by ``cuda_corr.forward``.
        """
        ctx.save_for_backward(fmap1, fmap2, coords, ii, jj)
        ctx.radius = radius
        ctx.dropout = dropout
        (corr,) = cuda_corr.forward(fmap1, fmap2, coords, ii, jj, radius)

        return corr

    @staticmethod
    def backward(ctx, grad):
        """Back-propagate ``grad`` to ``fmap1`` and ``fmap2``.

        When ``ctx.dropout < 1`` only a random subset of entries is passed to
        the backward kernel: each entry is kept independently with
        probability ``ctx.dropout``. The surviving gradients are not rescaled
        here.

        Args:
            ctx: autograd context populated by ``forward``.
            grad: gradient of the loss with respect to the forward output.

        Returns:
            A 7-tuple matching the forward arguments: gradients for ``fmap1``
            and ``fmap2`` followed by ``None`` for the five remaining inputs.
        """
        fmap1, fmap2, coords, ii, jj = ctx.saved_tensors

        if ctx.dropout < 1:
            # Build the mask on the device the indices live on. A bare "cuda"
            # resolves to the *current* CUDA device, which breaks the indexing
            # below whenever the tensors sit on a different GPU.
            keep = torch.rand(len(ii), device=ii.device) < ctx.dropout
            coords = coords[:, keep]
            grad = grad[:, keep]
            ii = ii[keep]
            jj = jj[keep]

        fmap1_grad, fmap2_grad = cuda_corr.backward(
            fmap1, fmap2, coords, ii, jj, grad, ctx.radius
        )

        return fmap1_grad, fmap2_grad, None, None, None, None, None


class PatchLayer(torch.autograd.Function):
    """Autograd wrapper around the ``cuda_corr`` patch extraction kernels.

    Gradients flow only to ``net``; ``coords`` and ``radius`` receive ``None``.
    """

    @staticmethod
    def forward(ctx, net, coords, radius):
        """Extract patches from ``net`` via ``cuda_corr.patchify_forward``.

        Args:
            ctx: autograd context; keeps ``net``, ``coords`` and ``radius``
                for the backward pass.
            net: tensor the patches are taken from.
            coords: patch coordinates forwarded to the kernel.
            radius: patch radius forwarded to the kernel.

        Returns:
            The single tensor produced by the forward kernel.
        """
        ctx.save_for_backward(net, coords)
        ctx.radius = radius

        # The extension returns a one-element sequence; unpack it strictly.
        outputs = cuda_corr.patchify_forward(net, coords, radius)
        (patches,) = outputs
        return patches

    @staticmethod
    def backward(ctx, grad):
        """Propagate the patch gradient back to ``net``.

        Args:
            ctx: autograd context populated by ``forward``.
            grad: gradient of the loss with respect to the extracted patches.

        Returns:
            A 3-tuple matching the forward arguments: the gradient for
            ``net``, then ``None`` for ``coords`` and ``radius``.
        """
        net, coords = ctx.saved_tensors
        radius = ctx.radius

        (net_grad,) = cuda_corr.patchify_backward(net, coords, grad, radius)
        return net_grad, None, None


def patchify(net, coords, radius, mode="bilinear"):
    """Extract patches from ``net`` around ``coords``.

    Args:
        net: tensor the patches are taken from.
        coords: patch coordinates; the last dimension is unpacked as
            ``(x, y)``.
        radius: patch radius handed to ``PatchLayer``.
        mode: ``"bilinear"`` blends the four neighbouring integer-aligned
            windows using the fractional part of ``coords``; any other value
            returns the raw ``PatchLayer`` output unchanged.

    Returns:
        The interpolated patches when ``mode == "bilinear"``, otherwise the
        raw patches.
    """
    raw = PatchLayer.apply(net, coords, radius)

    if mode != "bilinear":
        return raw

    # Sub-pixel offsets, reshaped to broadcast over the trailing patch dims.
    frac = (coords - coords.floor()).to(net.device)
    dx, dy = frac[:, :, None, None, None].unbind(dim=-1)

    # NOTE(review): the shifted slices below only line up if the raw patches
    # are one element larger than ``side`` in each spatial dim -- confirm
    # against the kernel.
    side = 2 * radius + 1
    top_left = raw[..., :side, :side]
    top_right = raw[..., :side, 1:]
    bottom_left = raw[..., 1:, :side]
    bottom_right = raw[..., 1:, 1:]

    x00 = (1 - dy) * (1 - dx) * top_left
    x01 = (1 - dy) * dx * top_right
    x10 = dy * (1 - dx) * bottom_left
    x11 = dy * dx * bottom_right

    return x00 + x01 + x10 + x11


def corr(fmap1, fmap2, coords, ii, jj, radius=1, dropout=1):
    """Functional entry point for ``CorrLayer``.

    Forwards every argument to ``CorrLayer.apply`` in order. With the default
    ``dropout=1`` the backward pass uses all entries.
    """
    layer_args = (fmap1, fmap2, coords, ii, jj, radius, dropout)
    return CorrLayer.apply(*layer_args)
