import lietorch
import torch
from gradcheck import get_analytical_jacobian, gradcheck
from lietorch import SE3, SO3, RxSO3, Sim3

### forward tests ###


def make_homogeneous(p):
    """Append a trailing coordinate equal to 1 along the last dimension of ``p``.

    All leading dimensions are preserved; dtype and device follow ``p``.
    """
    # Slice (rather than index) so the ones tensor keeps the rank of p.
    ones = torch.ones_like(p[..., :1])
    return torch.cat((p, ones), -1)


def matv(A, b):
    """Batched matrix-vector product: ``A @ b`` with ``b`` treated as a column.

    Leading (batch) dimensions broadcast following ``torch.matmul`` rules.
    """
    column = b.unsqueeze(-1)
    return (A @ column).squeeze(-1)


def test_exp_log(Group, device="cuda"):
    """Verify that Log undoes Exp: Log(Exp(x)) == x.

    Uses a batch of small (0.2-scaled) random tangent vectors in double
    precision and requires agreement to within 1e-8.
    """
    batch_shape = (2, 3, 4, 5, 6, 7, Group.manifold_dim)
    x = torch.randn(*batch_shape, device=device).double().mul(0.2)

    roundtrip = Group.exp(x).log()

    assert torch.allclose(x, roundtrip, atol=1e-8), "should be identity"
    print("\t-", Group, "Passed exp-log test")


def test_inv(Group, device="cuda"):
    """Verify that X * X^{-1} is the identity, i.e. Log(X * X^{-1}) == 0.

    X is built from small (0.1-scaled) random tangent vectors in double precision.
    """
    tangent = torch.randn(2, 3, 4, 5, Group.manifold_dim, device=device).double()
    X = Group.exp(0.1 * tangent)

    residual = (X * X.inv()).log()
    zero = torch.zeros_like(residual)

    assert torch.allclose(residual, zero, atol=1e-8), "should be 0"
    print("\t-", Group, "Passed inv test")


def test_adj(Group, device="cuda"):
    """Verify the adjoint identity X * Exp(a) == Exp(Adj_X(a)) * X.

    Both sides are compared through Log(lhs * rhs^{-1}), which must vanish.
    """
    batch_shape = (2, 3, 4, 5, Group.manifold_dim)
    X = Group.exp(torch.randn(*batch_shape, device=device).double())
    a = torch.randn(*batch_shape, device=device).double()

    lhs = X * Group.exp(a)
    rhs = Group.exp(X.adj(a)) * X

    residual = (lhs * rhs.inv()).log()
    assert torch.allclose(residual, torch.zeros_like(residual), atol=1e-8), "should be 0"
    print("\t-", Group, "Passed adj test")


def test_act(Group, device="cuda"):
    """Verify X.act(p) matches applying X's matrix to the homogeneous point [p, 1].

    Only the first three coordinates of the matrix product are compared.
    """
    X = Group.exp(torch.randn(1, Group.manifold_dim, device=device).double())
    p = torch.randn(1, 3, device=device).double()

    via_act = X.act(p)
    via_matrix = matv(X.matrix(), make_homogeneous(p))[..., :3]

    assert torch.allclose(via_act, via_matrix, atol=1e-8), "should be 0"
    print("\t-", Group, "Passed act test")


### backward tests ###
def test_exp_log_grad(Group, device="cuda", tol=1e-8):
    """Check that the Jacobian of a -> Log(Exp(a)) is the identity matrix.

    Since Log(Exp(a)) == a, d/da must equal the D x D identity. It is evaluated
    at a = 0 and at a small (0.2-scaled) random tangent vector.

    Args:
        Group: lietorch group class under test.
        device: torch device string.
        tol: absolute tolerance for the comparison with the identity.
    """
    D = Group.manifold_dim
    eye = torch.eye(D, device=device, dtype=torch.float64)

    def fn(a):
        return Group.exp(a).log()

    points = [
        torch.zeros(1, D, device=device, dtype=torch.float64),
        0.2 * torch.randn(1, D, device=device, dtype=torch.float64),
    ]
    for a in points:
        # Enable grad after construction so `a` is a float64 *leaf* tensor
        # (calling .double() on a requires_grad tensor yields a non-leaf copy).
        a.requires_grad_()
        analytical, *_ = get_analytical_jacobian((a,), fn(a))
        assert torch.allclose(analytical[0], eye, atol=tol)

    print("\t-", Group, "Passed eye-grad test")


def test_inv_log_grad(Group, device="cuda", tol=1e-8, *, strict=False):
    """Compare analytical and numerical Jacobians of a -> Log((Exp(a) * X)^{-1}).

    The Jacobian is evaluated at a = 0, with X a fixed random group element.

    Args:
        Group: lietorch group class under test.
        device: torch device string.
        tol: absolute tolerance for the Jacobian comparison.
        strict: if True, a mismatch raises AssertionError. By default a mismatch
            is non-fatal: both Jacobians and a FAILED line are printed instead
            of a (false) "Passed" message.
    """
    D = Group.manifold_dim
    X = Group.exp(0.2 * torch.randn(1, D, device=device, dtype=torch.float64))

    def fn(a):
        return (Group.exp(a) * X).inv().log()

    a = torch.zeros(1, D, device=device, dtype=torch.float64, requires_grad=True)
    analytical, numerical = gradcheck(fn, [a], eps=1e-4)

    ok = torch.allclose(analytical[0], numerical[0], atol=tol)
    if strict:
        assert ok, "analytical and numerical inv-log Jacobians differ"

    if not ok:
        print(analytical[0])
        print(numerical[0])
        print("\t-", Group, "FAILED inv-grad test")
        return

    print("\t-", Group, "Passed inv-grad test")


def test_adj_grad(Group, device="cuda", tol=1e-8):
    """Compare analytical and numerical Jacobians of (a, b) -> Adj_{Exp(a) * X}(b).

    Jacobians are taken w.r.t. a (at a = 0) and w.r.t. the tangent vector b.

    Args:
        Group: lietorch group class under test.
        device: torch device string.
        tol: absolute tolerance for both Jacobian comparisons.
    """
    D = Group.manifold_dim
    X = Group.exp(0.5 * torch.randn(1, D, device=device, dtype=torch.float64))

    def fn(a, b):
        return (Group.exp(a) * X).adj(b)

    # Build float64 leaf tensors directly instead of `.double()`-ing a
    # requires_grad float32 tensor (which produces a non-leaf).
    a = torch.zeros(1, D, device=device, dtype=torch.float64, requires_grad=True)
    b = torch.randn(1, D, device=device, dtype=torch.float64, requires_grad=True)

    analytical, numerical = gradcheck(fn, [a, b], eps=1e-4)
    assert torch.allclose(analytical[0], numerical[0], atol=tol)
    assert torch.allclose(analytical[1], numerical[1], atol=tol)

    print("\t-", Group, "Passed adj-grad test")


def test_adjT_grad(Group, device="cuda", tol=1e-8):
    """Compare analytical and numerical Jacobians of (a, b) -> AdjT_{Exp(a) * X}(b).

    Jacobians are taken w.r.t. a (at a = 0) and w.r.t. b.

    Args:
        Group: lietorch group class under test.
        device: torch device string.
        tol: absolute tolerance for both Jacobian comparisons.
    """
    D = Group.manifold_dim
    X = Group.exp(0.5 * torch.randn(1, D, device=device, dtype=torch.float64))

    def fn(a, b):
        return (Group.exp(a) * X).adjT(b)

    # float64 leaf tensors, created directly with the target dtype.
    a = torch.zeros(1, D, device=device, dtype=torch.float64, requires_grad=True)
    b = torch.randn(1, D, device=device, dtype=torch.float64, requires_grad=True)

    analytical, numerical = gradcheck(fn, [a, b], eps=1e-4)

    assert torch.allclose(analytical[0], numerical[0], atol=tol)
    assert torch.allclose(analytical[1], numerical[1], atol=tol)

    print("\t-", Group, "Passed adjT-grad test")


def test_act_grad(Group, device="cuda", tol=1e-8):
    """Compare analytical and numerical Jacobians of (a, b) -> (X * Exp(a)).act(b).

    X is a random element drawn with a large (5x) tangent scale; b is a single
    3-D point. Jacobians are taken w.r.t. a (at a = 0) and w.r.t. b.

    Args:
        Group: lietorch group class under test.
        device: torch device string.
        tol: absolute tolerance for both Jacobian comparisons.
    """
    D = Group.manifold_dim
    X = Group.exp(5 * torch.randn(1, D, device=device, dtype=torch.float64))

    def fn(a, b):
        return (X * Group.exp(a)).act(b)

    # float64 leaf tensors, created directly with the target dtype.
    a = torch.zeros(1, D, device=device, dtype=torch.float64, requires_grad=True)
    b = torch.randn(1, 3, device=device, dtype=torch.float64, requires_grad=True)

    analytical, numerical = gradcheck(fn, [a, b], eps=1e-4)

    assert torch.allclose(analytical[0], numerical[0], atol=tol)
    assert torch.allclose(analytical[1], numerical[1], atol=tol)

    print("\t-", Group, "Passed act-grad test")


def test_matrix_grad(Group, device="cuda", tol=1e-6):
    """Compare analytical and numerical Jacobians of a -> (Exp(a) * X).matrix() at a = 0.

    Args:
        Group: lietorch group class under test.
        device: torch device string.
        tol: absolute tolerance for the Jacobian comparison.
    """
    D = Group.manifold_dim
    X = Group.exp(torch.randn(1, D, device=device, dtype=torch.float64))

    def fn(a):
        return (Group.exp(a) * X).matrix()

    a = torch.zeros(1, D, device=device, dtype=torch.float64, requires_grad=True)
    analytical, numerical = gradcheck(fn, [a], eps=1e-4)
    assert torch.allclose(analytical[0], numerical[0], atol=tol)

    print("\t-", Group, "Passed matrix-grad test")


def extract_translation_grad(Group, device="cuda", tol=1e-8):
    """Compare analytical and numerical Jacobians of a -> (Exp(a) * X).translation().

    The Jacobian is evaluated at a = 0, with X drawn using a large (5x) tangent
    scale. The name lacks a ``test_`` prefix, so test collectors keyed on that
    prefix will not pick it up; it is invoked explicitly by the script driver.

    Args:
        Group: lietorch group class under test.
        device: torch device string.
        tol: absolute tolerance for the Jacobian comparison.
    """
    D = Group.manifold_dim
    X = Group.exp(5 * torch.randn(1, D, device=device, dtype=torch.float64))

    def fn(a):
        return (Group.exp(a) * X).translation()

    a = torch.zeros(1, D, device=device, dtype=torch.float64, requires_grad=True)

    analytical, numerical = gradcheck(fn, [a], eps=1e-4)

    assert torch.allclose(analytical[0], numerical[0], atol=tol)
    print("\t-", Group, "Passed translation grad test")


def test_vec_grad(Group, device="cuda", tol=1e-6):
    """Compare analytical and numerical Jacobians of a -> (Exp(a) * X).vec() at a = 0.

    Args:
        Group: lietorch group class under test.
        device: torch device string.
        tol: absolute tolerance for the Jacobian comparison.
    """
    D = Group.manifold_dim
    X = Group.exp(5 * torch.randn(1, D, device=device, dtype=torch.float64))

    def fn(a):
        return (Group.exp(a) * X).vec()

    # float64 leaf tensor, created directly with the target dtype.
    a = torch.zeros(1, D, device=device, dtype=torch.float64, requires_grad=True)

    analytical, numerical = gradcheck(fn, [a], eps=1e-4)

    assert torch.allclose(analytical[0], numerical[0], atol=tol)
    print("\t-", Group, "Passed tovec grad test")


def test_fromvec_grad(Group, device="cuda", tol=1e-6):
    """Compare analytical and numerical Jacobians of InitFromVec(...).vec().

    A raw random vector of size ``Group.embedded_dim`` is first mapped to a
    valid embedding for the group:

    - the quaternion part is normalized to unit length;
    - the scale entry (RxSO3, Sim3) is exponentiated so it is positive.

    The per-group layouts used below are SO3 = [q(4)], RxSO3 = [q(4), s(1)],
    SE3 = [t(3), q(4)] and Sim3 = [t(3), q(4), s(1)].

    Args:
        Group: lietorch group class under test.
        device: torch device string.
        tol: absolute tolerance for the Jacobian comparison.
    """

    def fn(a):
        if Group == SO3:
            a = a / a.norm(dim=-1, keepdim=True)

        elif Group == RxSO3:
            q, s = a.split([4, 1], dim=-1)
            q = q / q.norm(dim=-1, keepdim=True)
            a = torch.cat([q, s.exp()], dim=-1)

        elif Group == SE3:
            t, q = a.split([3, 4], dim=-1)
            q = q / q.norm(dim=-1, keepdim=True)
            a = torch.cat([t, q], dim=-1)

        elif Group == Sim3:
            t, q, s = a.split([3, 4, 1], dim=-1)
            q = q / q.norm(dim=-1, keepdim=True)
            a = torch.cat([t, q, s.exp()], dim=-1)

        return Group.InitFromVec(a).vec()

    D = Group.embedded_dim
    # float64 leaf tensor, created directly with the target dtype.
    a = torch.randn(1, 2, D, device=device, dtype=torch.float64, requires_grad=True)

    analytical, numerical = gradcheck(fn, [a], eps=1e-4)

    assert torch.allclose(analytical[0], numerical[0], atol=tol)
    print("\t-", Group, "Passed fromvec grad test")


def scale(device="cuda"):
    """Compare analytical and numerical Jacobians of (a, s) -> Log(Exp(a).scale(s)) for SE3.

    Jacobians are taken w.r.t. the se3 tangent vector a and the scalar s.
    """

    def fn(a, s):
        # Keep the return value of `scale`. Discarding it (as before) left the
        # output independent of s, so the s-Jacobian check was meaningless.
        X = SE3.exp(a).scale(s)
        return X.log()

    # float64 leaf tensors, created directly with the target dtype.
    s = torch.rand(1, device=device, dtype=torch.float64, requires_grad=True)
    a = torch.randn(1, 6, device=device, dtype=torch.float64, requires_grad=True)

    analytical, numerical = gradcheck(fn, [a, s], eps=1e-3)

    assert torch.allclose(analytical[0], numerical[0], atol=1e-8)
    assert torch.allclose(analytical[1], numerical[1], atol=1e-8)

    print("\t-", "Passed se3-to-sim3 test")


if __name__ == "__main__":
    # Run the same forward/backward suite on every available device:
    # CPU first, then GPU (skipped when CUDA is unavailable instead of crashing).
    groups = [SO3, RxSO3, SE3, Sim3]

    # (device string passed to the tests, label used in the banners)
    targets = [("cpu", "CPU")]
    if torch.cuda.is_available():
        targets.append(("cuda", "GPU"))
    else:
        print("CUDA is not available; skipping GPU tests.")

    for device, label in targets:
        print(f"Testing lietorch forward pass ({label}) ...")
        for Group in groups:
            test_exp_log(Group, device=device)
            test_inv(Group, device=device)
            test_adj(Group, device=device)
            test_act(Group, device=device)

        print(f"Testing lietorch backward pass ({label})...")
        for Group in groups:
            # Sim3 uses a looser tolerance for the exp/log and inv/log Jacobian checks.
            tol = 1e-3 if Group is Sim3 else 1e-8

            test_exp_log_grad(Group, device=device, tol=tol)
            test_inv_log_grad(Group, device=device, tol=tol)
            test_adj_grad(Group, device=device)
            test_adjT_grad(Group, device=device)
            test_act_grad(Group, device=device)
            test_matrix_grad(Group, device=device)
            extract_translation_grad(Group, device=device)
            test_vec_grad(Group, device=device)
            test_fromvec_grad(Group, device=device)
