import numpy as np
import torch

from .broadcasting import broadcast_inputs
# group operations implemented in cuda
from .group_ops import (Act3, Act4, Adj, AdjT, Exp, FromVec, Inv, Jinv, Log,
                        Mul, ToMatrix, ToVec)


class LieGroupParameter(torch.Tensor):
    """Wrapper class for LieGroup"""

    from torch._C import _disabled_torch_function_impl

    __torch_function__ = _disabled_torch_function_impl

    def __new__(cls, group, requires_grad=True):
        data = torch.zeros(
            group.tangent_shape,
            device=group.data.device,
            dtype=group.data.dtype,
            requires_grad=True,
        )

        return torch.Tensor._make_subclass(cls, data, requires_grad)

    def __init__(self, group):
        self.group = group

    def retr(self):
        return self.group.retr(self)

    def log(self):
        return self.retr().log()

    def inv(self):
        return self.retr().inv()

    def adj(self, a):
        return self.retr().adj(a)

    def __mul__(self, other):
        if isinstance(other, LieGroupParameter):
            return self.retr() * other.retr()
        else:
            return self.retr() * other

    def add_(self, update, alpha):
        self.group = self.group.exp(alpha * update) * self.group

    def __getitem__(self, index):
        return self.retr().__getitem__(index)


class LieGroup:
    """Base class for Lie Group"""

    def __init__(self, data):
        self.data = data

    def __repr__(self):
        return "{}: size={}, device={}, dtype={}".format(
            self.group_name, self.shape, self.device, self.dtype
        )

    @property
    def shape(self):
        return self.data.shape[:-1]

    @property
    def device(self):
        return self.data.device

    @property
    def dtype(self):
        return self.data.dtype

    def vec(self):
        return self.apply_op(ToVec, self.data)

    @property
    def tangent_shape(self):
        return self.data.shape[:-1] + (self.manifold_dim,)

    @classmethod
    def Identity(cls, *batch_shape, **kwargs):
        """Construct identity element with batch shape"""

        if isinstance(batch_shape[0], tuple):
            batch_shape = batch_shape[0]

        elif isinstance(batch_shape[0], list):
            batch_shape = tuple(batch_shape[0])

        numel = np.prod(batch_shape)
        data = cls.id_elem.reshape(1, -1)

        if "device" in kwargs:
            data = data.to(kwargs["device"])

        if "dtype" in kwargs:
            data = data.type(kwargs["dtype"])

        data = data.repeat(numel, 1)
        return cls(data).view(batch_shape)

    @classmethod
    def IdentityLike(cls, G):
        return cls.Identity(G.shape, device=G.data.device, dtype=G.data.dtype)

    @classmethod
    def InitFromVec(cls, data):
        return cls(cls.apply_op(FromVec, data))

    @classmethod
    def Random(cls, *batch_shape, sigma=1.0, **kwargs):
        """Construct random element with batch_shape by random sampling in tangent space"""

        if isinstance(batch_shape[0], tuple):
            batch_shape = batch_shape[0]

        elif isinstance(batch_shape[0], list):
            batch_shape = tuple(batch_shape[0])

        tangent_shape = batch_shape + (cls.manifold_dim,)
        xi = torch.randn(tangent_shape, **kwargs)
        return cls.exp(sigma * xi)

    @classmethod
    def apply_op(cls, op, x, y=None):
        """Apply group operator"""
        inputs, out_shape = broadcast_inputs(x, y)

        data = op.apply(cls.group_id, *inputs)
        return data.view(out_shape + (-1,))

    @classmethod
    def exp(cls, x):
        """exponential map: x -> X"""
        return cls(cls.apply_op(Exp, x))

    def quaternion(self):
        """extract quaternion"""
        return self.apply_op(Quat, self.data)

    def log(self):
        """logarithm map"""
        return self.apply_op(Log, self.data)

    def inv(self):
        """group inverse"""
        return self.__class__(self.apply_op(Inv, self.data))

    def mul(self, other):
        """group multiplication"""
        return self.__class__(self.apply_op(Mul, self.data, other.data))

    def retr(self, a):
        """retraction: Exp(a) * X"""
        dX = self.__class__.apply_op(Exp, a)
        return self.__class__(self.apply_op(Mul, dX, self.data))

    def adj(self, a):
        """adjoint operator: b = A(X) * a"""
        return self.apply_op(Adj, self.data, a)

    def adjT(self, a):
        """transposed adjoint operator: b = a * A(X)"""
        return self.apply_op(AdjT, self.data, a)

    def Jinv(self, a):
        return self.apply_op(Jinv, self.data, a)

    def act(self, p):
        """action on a point cloud"""

        # action on point
        if p.shape[-1] == 3:
            return self.apply_op(Act3, self.data, p)

        # action on homogeneous point
        elif p.shape[-1] == 4:
            return self.apply_op(Act4, self.data, p)

    def matrix(self):
        """convert element to 4x4 matrix"""
        I = torch.eye(4, dtype=self.dtype, device=self.device)
        I = I.view([1] * (len(self.data.shape) - 1) + [4, 4])
        return self.__class__(self.data[..., None, :]).act(I).transpose(-1, -2)

    def translation(self):
        """extract translation component"""
        p = torch.as_tensor([0.0, 0.0, 0.0, 1.0], dtype=self.dtype, device=self.device)
        p = p.view(
            [1] * (len(self.data.shape) - 1)
            + [
                4,
            ]
        )
        return self.apply_op(Act4, self.data, p)

    def detach(self):
        return self.__class__(self.data.detach())

    def view(self, dims):
        data_reshaped = self.data.view(dims + (self.embedded_dim,))
        return self.__class__(data_reshaped)

    def __mul__(self, other):
        # group multiplication

        if isinstance(other, LieGroup):
            return self.mul(other)

        # action on point
        elif isinstance(other, torch.Tensor):
            return self.act(other)

    def __getitem__(self, index):
        return self.__class__(self.data[index])

    def __setitem__(self, index, item):
        self.data[index] = item.data

    def to(self, *args, **kwargs):
        return self.__class__(self.data.to(*args, **kwargs))

    def cpu(self):
        return self.__class__(self.data.cpu())

    def cuda(self):
        return self.__class__(self.data.cuda())

    def float(self, device):
        return self.__class__(self.data.float())

    def double(self, device):
        return self.__class__(self.data.double())

    def unbind(self, dim=0):
        return [self.__class__(x) for x in self.data.unbind(dim=dim)]


class SO3(LieGroup):
    group_name = "SO3"
    group_id = 1
    manifold_dim = 3
    embedded_dim = 4

    # unit quaternion
    id_elem = torch.as_tensor([0.0, 0.0, 0.0, 1.0])

    def __init__(self, data):
        if isinstance(data, SE3):
            data = data.data[..., 3:7]

        super(SO3, self).__init__(data)


class RxSO3(LieGroup):
    group_name = "RxSO3"
    group_id = 2
    manifold_dim = 4
    embedded_dim = 5

    # unit quaternion
    id_elem = torch.as_tensor([0.0, 0.0, 0.0, 1.0, 1.0])

    def __init__(self, data):
        if isinstance(data, Sim3):
            data = data.data[..., 3:8]

        super(RxSO3, self).__init__(data)


class SE3(LieGroup):
    group_name = "SE3"
    group_id = 3
    manifold_dim = 6
    embedded_dim = 7

    # translation, unit quaternion
    id_elem = torch.as_tensor([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0])

    def __init__(self, data):
        if isinstance(data, SO3):
            translation = torch.zeros_like(data.data[..., :3])
            data = torch.cat([translation, data.data], -1)

        super(SE3, self).__init__(data)

    def scale(self, s):
        t, q = self.data.split([3, 4], -1)
        t = t * s.unsqueeze(-1)
        return SE3(torch.cat([t, q], dim=-1))


class Sim3(LieGroup):
    group_name = "Sim3"
    group_id = 4
    manifold_dim = 7
    embedded_dim = 8

    # translation, unit quaternion, scale
    id_elem = torch.as_tensor([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0])

    def __init__(self, data):

        if isinstance(data, SO3):
            scale = torch.ones_like(SO3.data[..., :1])
            translation = torch.zeros_like(SO3.data[..., :3])
            data = torch.cat([translation, SO3.data, scale], -1)

        elif isinstance(data, SE3):
            scale = torch.ones_like(data.data[..., :1])
            data = torch.cat([data.data, scale], -1)

        elif isinstance(data, Sim3):
            data = data.data

        super(Sim3, self).__init__(data)


def cat(group_list, dim):
    """Concatenate groups along dimension"""
    data = torch.cat([X.data for X in group_list], dim=dim)
    return group_list[0].__class__(data)


def stack(group_list, dim):
    """Concatenate groups along dimension"""
    data = torch.stack([X.data for X in group_list], dim=dim)
    return group_list[0].__class__(data)
