from collections import defaultdict
from dataclasses import dataclass
import numpy as np
from numpy import ndarray

from typing import Dict, Union, List, Tuple

from .order import Order
from .raw_data import RawData
from .exporter import Exporter

from ..tokenizer.spec import TokenizeInput
from .utils import linear_blend_skinning

import trimesh


@dataclass
class Asset(Exporter):
    '''
    Dataclass to handle data parsed from raw data.
    '''
    
    # data class
    cls: str
    
    # where is this asset from
    path: str
    
    # data file name
    data_name: str
    
    # vertices of the mesh, shape (N, 3), float32
    vertices: ndarray
    
    # normals of vertices, shape (N, 3), float32
    vertex_normals: ndarray
    
    # faces of mesh, shape (F, 3), face id starts from 0 to F-1, int64
    faces: ndarray
    
    # face normal of mesh, shape (F, 3), float32
    face_normals: ndarray
    
    # joints of bones, shape (J, 3), float32
    joints: Union[ndarray, None]=None
    
    # tails of joints, shape (J, 3), float32
    tails: Union[ndarray, None]=None
    
    # skinning of joints, shape (N, J), float32
    skin: Union[ndarray, None]=None
    
    # whether the joint has skin, bool
    no_skin: Union[ndarray, None]=None
    
    # vertex groups
    vertex_groups: Union[Dict[str, ndarray], None]=None
    
    # parents of joints, None represents no parent(a root joint)
    # make sure parent[k] < k
    parents: Union[List[Union[int, None]], None]=None
    
    # names of joints
    names: Union[List[str], None]=None
    
    # sampled vertices, shape (N, 3)
    sampled_vertices: Union[ndarray, None]=None
    
    # sampled normals, shape (N, 3)
    sampled_normals: Union[ndarray, None]=None
    
    # sampled vertex groups, every vertex group should be (N, J)
    sampled_vertex_groups: Union[Dict[str, ndarray], None]=None
    
    # {id: part}, part==None -> a spring token
    parts_bias: Union[Dict[int, Union[str, None]], None]=None
    
    # local coordinate, shape (J, 4, 4)
    matrix_local: Union[ndarray, None]=None
    
    # pose matrix for skinning loss calculation, shape (J, 4, 4)
    pose_matrix: Union[ndarray, None]=None
    
    meta: Union[Dict[str, ...], None]=None
    
    @property
    def N(self):
        '''
        number of vertices
        '''
        return self.vertices.shape[0]
    
    @property
    def F(self):
        '''
        number of faces
        '''
        return self.faces.shape[0]
    
    @property
    def J(self):
        '''
        number of joints
        '''
        return self.joints.shape[0]
    
    def get_matrix(self, matrix_basis: ndarray, matrix_local: Union[ndarray, None]=None):
        '''
        get matrix
        
        matrix_basis: (J, 4, 4)
        '''
        if matrix_local is None:
            assert self.joints is not None
            matrix_local = self.matrix_local
            if matrix_local is None:
                matrix_local = np.zeros((self.J, 4, 4))
                matrix_local[:, 0, 0] = 1.
                matrix_local[:, 1, 1] = 1.
                matrix_local[:, 2, 2] = 1.
                matrix_local[:, 3, 3] = 1.
                for i in range(self.J):
                    matrix_local[i, :3, 3] = self.joints[i]
        
        matrix = np.zeros((self.J, 4, 4))
        for i in range(self.J):
            if i==0:
                matrix[i] = matrix_local[i] @ matrix_basis[i]
            else:
                pid = self.parents[i]
                matrix_parent = matrix[pid]
                matrix_local_parent = matrix_local[pid]
                
                matrix[i] = (
                    matrix_parent @
                    (np.linalg.inv(matrix_local_parent) @ matrix_local[i]) @
                    matrix_basis[i]
                )
        return matrix
    
    def apply_matrix_basis(self, matrix_basis: ndarray):
        '''
        apply a pose to armature
        
        matrix_basis: (J, 4, 4)
        '''
        matrix_local = self.matrix_local
        if matrix_local is None:
            matrix_local = np.zeros((self.J, 4, 4))
            matrix_local[:, 0, 0] = 1.
            matrix_local[:, 1, 1] = 1.
            matrix_local[:, 2, 2] = 1.
            matrix_local[:, 3, 3] = 1.
            for i in range(self.J):
                matrix_local[i, :3, 3] = self.joints[i]
        
        matrix = self.get_matrix(matrix_basis=matrix_basis, matrix_local=matrix_local)
        self.joints = matrix[:, :3, 3]
        vertices = linear_blend_skinning(self.vertices, matrix_local, matrix, self.skin, pad=1, value=1.)
        # update matrix_local
        self.matrix_local = matrix

        # change tails
        if self.tails is not None:
            t_skin = np.eye(self.J)
            self.tails = linear_blend_skinning(self.tails, matrix_local, matrix, t_skin, pad=1, value=1.)
        # in accordance with trimesh's normals
        mesh = trimesh.Trimesh(vertices=vertices, faces=self.faces, process=False)
        self.vertices = vertices
        self.vertex_normals = mesh.vertex_normals.copy()
        self.face_normals = mesh.face_normals.copy()
    
    def set_order_by_names(self, new_names: List[str]):
        assert len(new_names) == len(self.names)
        name_to_id = {name: id for (id, name) in enumerate(self.names)}
        new_name_to_id = {name: id for (id, name) in enumerate(new_names)}
        perm = []
        new_parents = []
        for (new_id, name) in enumerate(new_names):
            perm.append(name_to_id[name])
            pid = self.parents[name_to_id[name]]
            if new_id == 0:
                assert pid is None, 'first bone is not root bone'
            else:
                pname = self.names[pid]
                pid = new_name_to_id[pname]
                assert pid < new_id, 'new order does not form a tree'
            new_parents.append(pid)
        
        if self.joints is not None:
            self.joints = self.joints[perm]
        self.parents = new_parents
        if self.tails is not None:
            self.tails = self.tails[perm]
        if self.skin is not None:
            self.skin = self.skin[:, perm]
        if self.no_skin is not None:
            self.no_skin = self.no_skin[perm]
        if self.matrix_local is not None:
            self.matrix_local = self.matrix_local[perm]
        self.names = new_names
    
    def set_order(self, order: Order):
        if self.names is None or self.parents is None:
            return
        new_names, self.parts_bias = order.arrange_names(cls=self.cls, names=self.names, parents=self.parents)
        self.set_order_by_names(new_names=new_names)
    
    def collapse(self, keep: List[str]):
        dsu = [i for i in range(self.J)]
        
        def find(x: int) -> int:
            if dsu[x] == x:
                return x
            y = find(dsu[x])
            dsu[x] = y
            return y
        
        def merge(x: int, y: int):
            dsu[find(x)] = find(y)

        if self.tails is not None:
            new_tails = self.tails.copy()
        else:
            new_tails = None
        if self.skin is not None:
            new_skin = self.skin.copy()
        else:
            new_skin = None

        if self.no_skin is not None:
            new_no_skin = self.no_skin.copy()
        else:
            new_no_skin = None
        
        if self.matrix_local is not None:
            matrix_local = self.matrix_local.copy()
        else:
            matrix_local = None
        new_names = []
        new_parents = []
        perm = []
        new_name_to_id = {}
        tot = 0
        for (i, name) in enumerate(self.names):
            if name in keep:
                new_names.append(name)
                new_name_to_id[name] = tot
                tot += 1
                perm.append(i)
                pid = self.parents[i]
                if pid is None:
                    new_parents.append(None)
                else:
                    pid = find(pid)
                    new_parents.append(new_name_to_id[self.names[pid]])
                continue
            assert i != 0, 'cannot remove root'
            id = find(i)
            pid = find(self.parents[id])
            # be careful !
            # do not copy tail here because you dont know which child to inherit from
            if new_skin is not None:
                new_skin[:, pid] += new_skin[:, id]
            if new_no_skin is not None:
                new_no_skin[pid] &= new_no_skin[id]
            merge(id, pid)
        
        if new_tails is not None:
            new_tails = new_tails[perm]
        if new_skin is not None:
            new_skin = new_skin[:, perm]
        if new_no_skin is not None:
            new_no_skin = new_no_skin[perm]
        if matrix_local is not None:
            matrix_local = matrix_local[perm]
        
        if self.joints is not None:
            self.joints = self.joints[perm]
        self.parents        = new_parents
        self.tails          = new_tails
        self.skin           = new_skin
        self.no_skin        = new_no_skin
        self.names          = new_names
        self.matrix_local   = matrix_local
    
    @staticmethod
    def from_raw_data(
        raw_data: RawData,
        cls: str,
        path: str,
        data_name: str,
    ) -> 'Asset':
        '''
        Return an asset initialized from raw data and do transform.
        '''
        return Asset(
            cls=cls,
            path=path,
            data_name=data_name,
            vertices=raw_data.vertices,
            vertex_normals=raw_data.vertex_normals,
            faces=raw_data.faces,
            face_normals=raw_data.face_normals,
            joints=raw_data.joints,
            tails=raw_data.tails,
            skin=raw_data.skin,
            no_skin=raw_data.no_skin,
            parents=raw_data.parents,
            names=raw_data.names,
            matrix_local=raw_data.matrix_local,
            meta={},
        )
    
    def get_tokenize_input(self) -> TokenizeInput:
        children = defaultdict(list)
        
        for (id, p) in enumerate(self.parents):
            if p is not None:
                children[p].append(id)
        bones = []
        branch = []
        is_leaf = []
        last = None
        for i in range(self.J):
            is_leaf.append(len(children[i])==0)
            if i == 0:
                bones.append(np.concatenate([self.joints[i], self.joints[i]]))
                branch.append(False)
            else:
                pid = self.parents[i]
                bones.append(np.concatenate([self.joints[pid], self.joints[i]]))
                branch.append(pid!=last)
            last = i
        bones = np.stack(bones)
        branch = np.array(branch, dtype=bool)
        is_leaf = np.array(is_leaf, dtype=bool)
        return TokenizeInput(
            bones=bones,
            tails=self.tails,
            branch=branch,
            is_leaf=is_leaf,
            no_skin=self.no_skin,
            cls=self.cls,
            parts_bias=self.parts_bias,
        )
    
    def export_pc(self, path: str, with_normal: bool=True, normal_size=0.01):
        '''
        export point cloud
        '''
        vertices    = self.vertices
        normals     = self.vertex_normals
        if self.sampled_vertices is not None:
            vertices    = self.sampled_vertices
            normals     = self.sampled_normals
        if with_normal == False:
            normals = None
        self._export_pc(vertices=vertices, path=path, vertex_normals=normals, normal_size=normal_size)
    
    def export_mesh(self, path: str):
        '''
        export mesh
        '''
        self._export_mesh(vertices=self.vertices, faces=self.faces, path=path)
    
    def export_skeleton(self, path: str):
        '''
        export spring
        '''
        self._export_skeleton(joints=self.joints, parents=self.parents, path=path)
    
    def export_skeleton_sequence(self, path: str):
        '''
        export spring
        '''
        self._export_skeleton_sequence(joints=self.joints, parents=self.parents, path=path)
    
    def export_fbx(
        self,
        path: str,
        vertex_group_name: str,
        extrude_size: float=0.03,
        group_per_vertex: int=-1,
        add_root: bool=False,
        do_not_normalize: bool=False,
        use_extrude_bone: bool=True,
        use_connect_unique_child: bool=True,
        extrude_from_parent: bool=True,
        use_tail: bool=False,
        use_origin: bool=False,
    ):
        '''
        export the whole model with skining
        '''
        self._export_fbx(
            path=path,
            vertices=self.vertices if use_origin else self.sampled_vertices,
            joints=self.joints,
            skin=self.sampled_vertex_groups[vertex_group_name],
            parents=self.parents,
            names=self.names,
            faces=self.faces if use_origin else None,
            extrude_size=extrude_size,
            group_per_vertex=group_per_vertex,
            add_root=add_root,
            do_not_normalize=do_not_normalize,
            use_extrude_bone=use_extrude_bone,
            use_connect_unique_child=use_connect_unique_child,
            extrude_from_parent=extrude_from_parent,
            tails=self.tails if use_tail else None,
        )

    def export_render(self, path: str, resolution: Tuple[int, int]=[256, 256], use_tail: bool=False):
        if use_tail:
            assert self.tails is not None
            self._export_render(
                path=path,
                vertices=self.vertices,
                faces=self.faces,
                bones=np.concatenate([self.joints, self.tails], axis=-1),
                resolution=resolution,
            )
        else:
            pjoints = self.joints[self.parents[1:]]
            self._export_render(
                path=path,
                vertices=self.vertices,
                faces=self.faces,
                bones=np.concatenate([pjoints, self.joints[1:]], axis=-1),
                resolution=resolution,
            )