import numpy as np
from numpy import ndarray

from typing import Dict, Tuple, Union, List

from .spec import TokenizerSpec, TokenizeInput, DetokenzeOutput, TokenizerConfig
from .spec import make_skeleton
from ..data.order import get_order

class TokenizerPart(TokenizerSpec):
    """Skeleton tokenizer that also emits class and body-part tokens.

    Vocabulary layout built in ``__init__`` (increasing id order):

    * ``[0, num_discrete)``: quantized coordinate values;
    * ``branch``, ``bos``, ``eos``, ``pad``;
    * ``spring``: part marker used when a part name is ``None``;
    * one token per entry of ``config.parts_token_id`` (shifted past the above);
    * ``cls_none``: used when the class is missing or unknown;
    * one token per entry of ``config.cls_token_id`` (shifted past the above).
    """

    def __init__(
        self,
        config: TokenizerConfig,
    ):
        super().__init__()

        self._num_discrete      = config.num_discrete
        self._continuous_range  = config.continuous_range
        # Copies: the ids are shifted in place below and the config must stay untouched.
        self.cls_token_id       = config.cls_token_id.copy()
        self.parts_token_id     = config.parts_token_id.copy()
        self.order              = get_order(config.order_config)
        _offset                 = config.num_discrete

        # Structural tokens come right after the coordinate bins.
        self.token_id_branch = _offset + 0
        self.token_id_bos    = _offset + 1
        self.token_id_eos    = _offset + 2
        self.token_id_pad    = _offset + 3
        _offset += 4

        self.token_id_spring   = _offset + 0
        _offset += 1

        # `None` is reserved: a None part is encoded with the spring token.
        assert None not in self.parts_token_id
        # NOTE(review): advancing the offset by len(...) assumes the config ids
        # are 0..len-1; otherwise token ranges could overlap — confirm.
        for i in self.parts_token_id:
            self.parts_token_id[i] += _offset
        _offset += len(self.parts_token_id)

        self.token_id_cls_none = _offset + 0
        _offset += 1

        for i in self.cls_token_id:
            self.cls_token_id[i] += _offset
        _offset += len(self.cls_token_id)

        self._vocab_size = _offset

        self.parts_token_id_name = list(self.parts_token_id)

        # Reverse maps used by detokenize; their key sets double as O(1)
        # membership tests for part/spring tokens and class tokens.
        self.part_token_to_name = {v: k for k, v in self.parts_token_id.items()}
        assert len(self.part_token_to_name) == len(self.parts_token_id), 'names with same token found in parts_token_id'
        self.part_token_to_name[self.token_id_spring] = None

        self.cls_token_to_name = {v: k for k, v in self.cls_token_id.items()}
        assert len(self.cls_token_to_name) == len(self.cls_token_id), 'names with same token found in cls_token_id'

    def cls_name_to_token(self, cls: str) -> int:
        """Return the token for class name ``cls``, or ``cls_none`` if unknown."""
        if cls not in self.cls_token_id:
            return self.token_id_cls_none
        return self.cls_token_id[cls]

    def part_name_to_token(self, part: str) -> int:
        """Return the token for part name ``part``; asserts the name is known."""
        assert part in self.parts_token_id, f"do not find part name    `{part}` in tokenizer"
        return self.parts_token_id[part]

    def tokenize(self, input: TokenizeInput) -> ndarray:
        """Encode a skeleton into a 1-D int64 token array.

        Layout: ``bos``, one class token, then for each bone an optional part
        token (taken from ``input.parts_bias``; ``None`` maps to ``spring``),
        followed by either ``branch`` + 6 coordinate tokens (``bones[i, 0:6]``)
        or 3 coordinate tokens (``bones[i, 3:6]``), and finally ``eos``.
        Leaf flags and tails are not encoded by this tokenizer.

        NOTE(review): judging from ``detokenize``, columns 0:3 of ``input.bones``
        hold the parent joint position and 3:6 the joint itself — confirm.
        """
        num_bones = input.num_bones
        branch = input.branch
        bones = discretize(
            t=input.bones,
            continuous_range=self.continuous_range,
            num_discrete=self.num_discrete,
        )

        tokens = [self.token_id_bos]
        # Check None explicitly so it always maps to cls_none.
        if input.cls is None:
            tokens.append(self.token_id_cls_none)
        else:
            tokens.append(self.cls_name_to_token(input.cls))
        for i in range(num_bones):
            # Part marker: emitted before the bone that starts the part.
            if i in input.parts_bias:
                part = input.parts_bias[i]
                if part is None:
                    tokens.append(self.token_id_spring)
                else:
                    assert part in self.parts_token_id, f"do not find part name {part} in tokenizer {self.__class__}"
                    tokens.append(self.parts_token_id[part])
            if branch[i]:
                # A branch restates its parent position explicitly.
                tokens.append(self.token_id_branch)
                tokens.extend(bones[i, 0:6])
            else:
                # Chain continuation: the parent is implied by the previous joint.
                tokens.extend(bones[i, 3:6])
        tokens.append(self.token_id_eos)
        return np.array(tokens, dtype=np.int64)

    def detokenize(self, ids: ndarray, **kwargs) -> DetokenzeOutput:
        """Decode a token sequence back into a skeleton.

        The sequence must start with ``bos``. Trailing ``pad`` tokens are
        stripped, and the token before them must be ``eos``.

        Coordinate groups are read as written by :meth:`tokenize`:

        * After a ``branch`` token, six coordinates give the parent position
          and then the joint position.
        * Otherwise three coordinates give the joint position. Its parent is
          the previously decoded joint; the very first joint is its own parent.

        A bone's tail is recorded only when the next joint continues its chain.
        The remaining tails are left to ``make_skeleton``, which is called with
        ``extrude_tail_for_leaf`` and ``extrude_tail_for_branch`` enabled.

        Args:
            ids: 1-D array of token ids.
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            DetokenzeOutput whose ``tokens`` field holds ``ids`` with the
            bos/eos/pad tokens removed.

        Raises:
            ValueError: if the sequence is empty, lacks bos/eos, contains an
                incomplete coordinate group or an unknown token, or decodes
                to no joints.
        """
        assert isinstance(ids, ndarray), 'expect ids to be ndarray'
        if ids.shape[0] == 0 or ids[0] != self.token_id_bos:
            raise ValueError("first token is not bos")
        trailing_pad = 0
        # Terminates before index 0 because ids[0] is bos, not pad.
        while trailing_pad < ids.shape[0] and ids[-trailing_pad-1] == self.token_id_pad:
            trailing_pad += 1
        if ids[-1-trailing_pad] != self.token_id_eos:
            raise ValueError("last token is not eos")
        ids = ids[1:-1-trailing_pad]

        def decode(segment: ndarray) -> ndarray:
            # Three coordinate tokens -> one continuous position.
            return undiscretize(
                t=segment,
                continuous_range=self.continuous_range,
                num_discrete=self.num_discrete,
            )

        joints = []
        p_joints = []
        tails_dict = {}
        parts = []
        cls = None  # stays None when the sequence carries no class token
        i = 0
        is_branch = False
        last_joint = None
        num_bones = 0
        n = len(ids)
        while i < n:
            token = ids[i]
            if token < self.num_discrete:
                width = 6 if is_branch else 3
                if i + width > n:
                    raise ValueError(
                        f"incomplete coordinate group at position {i}: "
                        f"expected {width} tokens, got {n - i}"
                    )
                if is_branch:
                    p_joints.append(decode(ids[i:i+3]))
                    current_joint = decode(ids[i+3:i+6])
                else:
                    current_joint = decode(ids[i:i+3])
                    if not p_joints:
                        # Root: it is its own parent.
                        p_joints.append(current_joint)
                    else:
                        # Invariant: last_joint is only cleared by a branch
                        # token, which forces the next group to be a branch.
                        assert last_joint is not None
                        p_joints.append(last_joint)
                joints.append(current_joint)
                i += width
                # The previous bone's tail is this joint when the chain continues.
                if last_joint is not None:
                    tails_dict[num_bones-1] = current_joint
                last_joint = current_joint
                num_bones += 1
                is_branch = False
            elif token == self.token_id_branch:
                is_branch = True
                last_joint = None
                i += 1
            elif token in self.part_token_to_name:
                # Covers both named part tokens and the spring (None) token.
                parts.append(self.part_token_to_name[token])
                i += 1
            elif token in self.cls_token_to_name:
                cls = token
                i += 1
            elif token == self.token_id_cls_none:
                cls = None
                i += 1
            else:
                raise ValueError(f"unexpected token found: {token}")
        if not joints:
            raise ValueError("no bones found in tokens")
        joints = np.stack(joints)
        p_joints = np.stack(p_joints)
        # leaf is ignored in this tokenizer so need to extrude tails for leaf and branch
        bones, tails, available_bones_id, parents = make_skeleton(
            joints=joints,
            p_joints=p_joints,
            tails_dict=tails_dict,
            convert_leaf_bones_to_tails=False,
            extrude_tail_for_leaf=True,
            extrude_tail_for_branch=True,
        )
        bones = bones[available_bones_id]
        tails = tails[available_bones_id]
        # Map the class token id back to its name (None if absent/unknown).
        cls = self.cls_token_to_name.get(cls)
        if self.order is not None:
            names = self.order.make_names(cls=cls, parts=parts, num_bones=num_bones)
        else:
            names = [f"bone_{i}" for i in range(num_bones)]
        return DetokenzeOutput(
            tokens=ids,
            parents=parents,
            bones=bones,
            tails=tails,
            no_skin=None,
            cls=cls,
            parts=parts,
            names=names,
            continuous_range=self.continuous_range,
        )

    def get_require_parts(self) -> List[str]:
        """Return the part names known to this tokenizer, in config order."""
        return self.parts_token_id_name

    @property
    def vocab_size(self):
        """Total number of token ids (coordinates + special + part + class)."""
        return self._vocab_size

    @property
    def pad(self):
        """Padding token id."""
        return self.token_id_pad

    @property
    def bos(self):
        """Begin-of-sequence token id."""
        return self.token_id_bos

    @property
    def eos(self):
        """End-of-sequence token id."""
        return self.token_id_eos

    @property
    def num_discrete(self):
        """Number of coordinate quantization bins."""
        return self._num_discrete

    @property
    def continuous_range(self) -> Tuple[float, float]:
        """``(lo, hi)`` range that coordinates are quantized over."""
        return self._continuous_range

def discretize(
    t: ndarray,
    continuous_range: Tuple[float, float],
    num_discrete: int,
) -> ndarray:
    """Quantize continuous values into integer bins ``[0, num_discrete - 1]``.

    Each value is mapped to ``round((t - lo) / (hi - lo) * num_discrete)``.
    The result is clipped to the valid bin range, so out-of-range inputs
    saturate at the first or last bin.

    NOTE(review): ``undiscretize`` maps bin ``k`` to the bin center
    ``(k + 0.5) / num_discrete``, but this function rounds without first
    subtracting 0.5. The round trip is therefore biased toward ``hi`` by half
    a bin on average. The bias is left unchanged because fixing it would
    change the token ids that existing models were trained on — confirm
    before changing.

    Args:
        t: array-like of continuous values (any shape).
        continuous_range: ``(lo, hi)`` with ``hi > lo``.
        num_discrete: number of bins.

    Returns:
        int64 array with the same shape as ``t``.

    Raises:
        ValueError: if ``hi <= lo``, which would divide by zero and produce
            NaN before the integer cast.
    """
    lo, hi = continuous_range
    if not hi > lo:
        raise ValueError(f"expected hi > lo in continuous_range, got {continuous_range}")
    t = (np.asarray(t) - lo) / (hi - lo)
    # In-place multiply keeps the working dtype (e.g. float32) unchanged.
    t *= num_discrete
    return np.clip(np.round(t), 0, num_discrete - 1).astype(np.int64)

def undiscretize(
    t: ndarray,
    continuous_range: Tuple[float, float],
    num_discrete: int,
) -> ndarray:
    """Map integer bin indices back to continuous values at the bin centers.

    Bin ``k`` becomes ``lo + (k + 0.5) / num_discrete * (hi - lo)``, and the
    arithmetic is carried out in float32.

    Args:
        t: ndarray of bin indices.
        continuous_range: ``(lo, hi)`` with ``hi >= lo``.
        num_discrete: number of bins used when quantizing.

    Returns:
        float32 array with the same shape as ``t``; ``t`` itself is not modified.
    """
    low, high = continuous_range
    assert high >= low
    # astype returns a fresh float32 copy, so the in-place updates below
    # never touch the caller's array.
    centers = t.astype(np.float32)
    centers += 0.5
    centers /= num_discrete
    span = high - low
    return centers * span + low
