import bpy, os
from collections import defaultdict
from tqdm import tqdm
import numpy as np
from numpy import ndarray
from typing import Dict, Tuple, List, Optional, Union
import trimesh
import fast_simplification
from scipy.spatial import KDTree

import argparse
import yaml
from box import Box
import os

from .log import new_entry, add_error, add_warning, new_log, end_log
from .raw_data import RawData

def load(filepath: str):
    old_objs = set(bpy.context.scene.objects)
    
    if not os.path.exists(filepath):
        raise ValueError(f'File {filepath} does not exist !')
    
    try:
        if filepath.endswith(".vrm"):
            # enable vrm addon and load vrm model
            bpy.ops.preferences.addon_enable(module='vrm')
            
            bpy.ops.import_scene.vrm(
                filepath=filepath,
                use_addon_preferences=True,
                extract_textures_into_folder=False,
                make_new_texture_folder=False,
                set_shading_type_to_material_on_import=False,
                set_view_transform_to_standard_on_import=True,
                set_armature_display_to_wire=True,
                set_armature_display_to_show_in_front=True,
                set_armature_bone_shape_to_default=True,
                disable_bake=True, # customized option for better performance
            )
        elif filepath.endswith(".obj"):
            bpy.ops.wm.obj_import(filepath=filepath)
        elif filepath.endswith(".fbx") or filepath.endswith(".FBX"):
            # end bone is removed using remove_dummy_bone
            bpy.ops.import_scene.fbx(filepath=filepath, ignore_leaf_bones=False, use_image_search=False)
        elif filepath.endswith(".glb") or filepath.endswith(".gltf"):
            bpy.ops.import_scene.gltf(filepath=filepath, import_pack_images=False)
        elif filepath.endswith(".dae"):
            bpy.ops.wm.collada_import(filepath=filepath)
        elif filepath.endswith(".blend"):
            with bpy.data.libraries.load(filepath) as (data_from, data_to):
                data_to.objects = data_from.objects
            for obj in data_to.objects:
                if obj is not None:
                    bpy.context.collection.objects.link(obj)
        else:
            raise ValueError(f"not suported type {filepath}")
    except:
        raise ValueError(f"failed to load {filepath}")

    armature = [x for x in set(bpy.context.scene.objects)-old_objs if x.type=="ARMATURE"]
    if len(armature)==0:
        return None
    if len(armature)>1:
        raise ValueError(f"multiple armatures found")
    armature = armature[0]
    
    armature.select_set(True)
    bpy.context.view_layer.objects.active = armature
    bpy.ops.object.mode_set(mode='EDIT')
    for bone in bpy.data.armatures[0].edit_bones:
        bone.roll = 0. # change all roll to 0. to prevent weird behaviour

    bpy.ops.object.mode_set(mode='OBJECT')
    armature.select_set(False)
    
    bpy.ops.object.select_all(action='DESELECT')
    return armature

# remove all data in bpy
def clean_bpy():
    # First try to purge orphan data
    try:
        bpy.ops.outliner.orphans_purge(do_local_ids=True, do_linked_ids=True, do_recursive=True)
    except Exception as e:
        print(f"Warning: Could not purge orphans: {e}")
        
    # Then remove all data by type
    data_types = [
        bpy.data.actions,
        bpy.data.armatures,
        bpy.data.cameras,
        bpy.data.collections,
        bpy.data.curves,
        bpy.data.images,
        bpy.data.lights,
        bpy.data.materials,
        bpy.data.meshes,
        bpy.data.objects,
        bpy.data.textures,
        bpy.data.worlds,
        bpy.data.node_groups
    ]
    
    for data_collection in data_types:
        try:
            for item in data_collection:
                try:
                    data_collection.remove(item)
                except Exception as e:
                    print(f"Warning: Could not remove {item.name} from {data_collection}: {e}")
        except Exception as e:
            print(f"Warning: Error processing {data_collection}: {e}")
            
    # Force garbage collection to free memory
    import gc
    gc.collect()

def get_arranged_bones(armature):
    matrix_world = armature.matrix_world
    arranged_bones = []
    root = armature.pose.bones[0]
    while root.parent is not None:
        root = root.parent
    Q = [root]
    rot = np.array(matrix_world)[:3, :3]
    
    # dfs and sort
    while len(Q) != 0:
        b = Q.pop(0)
        arranged_bones.append(b)
        children = []
        for cb in b.children:
            head = rot @ np.array(b.head)
            children.append((cb, head[0], head[1], head[2]))
        children = sorted(children, key=lambda x: (x[3], x[1], x[2]))
        _c = [x[0] for x in children]
        Q = _c + Q
    return arranged_bones

def process_mesh():
    meshes = []
    for v in bpy.data.objects:
        if v.type == 'MESH':
            meshes.append(v)
    
    _dict_mesh = {}
    for obj in meshes:
        m = np.array(obj.matrix_world)
        matrix_world_rot = m[:3, :3]
        matrix_world_bias = m[:3, 3]
        rot = matrix_world_rot
        total_vertices = len(obj.data.vertices)
        vertex = np.zeros((4, total_vertices))
        vertex_normal = np.zeros((total_vertices, 3))
        obj_verts = obj.data.vertices
        faces = []
        normals = []
        
        for v in obj_verts:
            vertex_normal[v.index] = rot @ np.array(v.normal) # be careful !
            vv = rot @ v.co
            vv = np.array(vv) + matrix_world_bias
            vertex[0:3, v.index] = vv
            vertex[3][v.index] = 1 # affine coordinate
        
        for polygon in obj.data.polygons:
            edges = polygon.edge_keys
            nodes = []
            adj = {}
            for edge in edges:
                if adj.get(edge[0]) is None:
                    adj[edge[0]] = []
                adj[edge[0]].append(edge[1])
                if adj.get(edge[1]) is None:
                    adj[edge[1]] = []
                adj[edge[1]].append(edge[0])
                nodes.append(edge[0])
                nodes.append(edge[1])
            normal = polygon.normal
            nodes = list(set(sorted(nodes)))
            first = nodes[0]
            loop = []
            now = first
            vis = {}
            while True:
                loop.append(now)
                vis[now] = True
                if vis.get(adj[now][0]) is None:
                    now = adj[now][0]
                elif vis.get(adj[now][1]) is None:
                    now = adj[now][1]
                else:
                    break
            for (second, third) in zip(loop[1:], loop[2:]):
                faces.append((first + 1, second + 1, third + 1)) # the cursed +1
                normals.append(rot @ normal) # and the cursed normal of BLENDER

        correct_faces = []
        for (i, face) in enumerate(faces):
            normal = normals[i]
            v0 = face[0] - 1
            v1 = face[1] - 1
            v2 = face[2] - 1
            v = np.cross(
                vertex[:3, v1] - vertex[:3, v0],
                vertex[:3, v2] - vertex[:3, v0],
            )
            if (v*normal).sum() > 0:
                correct_faces.append(face)
            else:
                correct_faces.append((face[0], face[2], face[1]))
        if len(correct_faces) > 0:
            _dict_mesh[obj.name] = {
                'vertex': vertex,
                'face': correct_faces,
            }
    
    vertex = np.concatenate([_dict_mesh[name]['vertex'] for name in _dict_mesh], axis=1)[:3, :].transpose()
    
    total_faces = 0
    now_bias = 0
    for name in _dict_mesh:
        total_faces += len(_dict_mesh[name]['face'])
    faces = np.zeros((total_faces, 3), dtype=np.int64)
    tot = 0
    for name in _dict_mesh:
        f = np.array(_dict_mesh[name]['face'], dtype=np.int64)
        faces[tot:tot+f.shape[0]] = f + now_bias
        now_bias += _dict_mesh[name]['vertex'].shape[1]
        tot += f.shape[0]
    
    return vertex, faces

def process_armature(
    armature,
    arranged_bones,
) -> Tuple[np.ndarray, np.ndarray]:
    matrix_world = armature.matrix_world
    index = {}

    for (id, pbone) in enumerate(arranged_bones):
        index[pbone.name] = id
    
    root = armature.pose.bones[0]
    while root.parent is not None:
        root = root.parent
    m = np.array(matrix_world.to_4x4())
    scale_inv = np.linalg.inv(np.diag(matrix_world.to_scale()))
    rot = m[:3, :3]
    bias = m[:3, 3]
    
    s = []
    bpy.ops.object.editmode_toggle()
    edit_bones = armature.data.edit_bones
    
    J = len(arranged_bones)
    joints = np.zeros((J, 3), dtype=np.float32)
    tails = np.zeros((J, 3), dtype=np.float32)
    parents = []
    name_to_id = {}
    names = []
    matrix_local_stack = np.zeros((J, 4, 4), dtype=np.float32)
    for (id, pbone) in enumerate(arranged_bones):
        name = pbone.name
        names.append(name)
        matrix_local = np.array(pbone.bone.matrix_local)
        use_inherit_rotation = pbone.bone.use_inherit_rotation
        if use_inherit_rotation == False:
            add_warning(f"use_inherit_rotation of bone {name} is False !")
        head = rot @ matrix_local[0:3, 3] + bias
        s.append(head)
        edit_bone = edit_bones.get(name)
        tail = rot @ np.array(edit_bone.tail) + bias
        
        name_to_id[name] = id
        joints[id] = head
        tails[id] = tail
        parents.append(None if pbone.parent not in arranged_bones else name_to_id[pbone.parent.name])
        # remove scale part
        matrix_local[:, 3:4] = m @ matrix_local[:, 3:4]
        matrix_local[:3, :3] = scale_inv @ matrix_local[:3, :3]
        matrix_local_stack[id] = matrix_local
    bpy.ops.object.editmode_toggle()
    
    return joints, tails, parents, names, matrix_local_stack

def save_raw_data(
    path: str,
    vertices: ndarray,
    faces: ndarray,
    joints: Union[ndarray, None],
    tails: Union[ndarray, None],
    parents: Union[List[Union[int, None]], None],
    names: Union[List[str], None],
    matrix_local: Union[ndarray, None],
    target_count: int,
):
    mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
    vertices = np.array(mesh.vertices, dtype=np.float32)
    faces = np.array(mesh.faces, dtype=np.int64)
    if faces.shape[0] > target_count:
        vertices, faces = fast_simplification.simplify(vertices, faces, target_count=target_count)
    mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
    
    new_vertices = np.array(mesh.vertices, dtype=np.float32)
    new_vertex_normals = np.array(mesh.vertex_normals, dtype=np.float32)
    new_faces = np.array(mesh.faces, dtype=np.int64)
    new_face_normals = np.array(mesh.face_normals, dtype=np.float32)
    if joints is not None:
        new_joints = np.array(joints, dtype=np.float32)
    else:
        new_joints = None
    raw_data = RawData(
        vertices=new_vertices,
        vertex_normals=new_vertex_normals,
        faces=new_faces,
        face_normals=new_face_normals,
        joints=new_joints,
        tails=tails,
        skin=None,
        no_skin=None,
        parents=parents,
        names=names,
        matrix_local=matrix_local,
    )
    raw_data.check()
    raw_data.save(path=path)

def extract_builtin(
    output_folder: str,
    target_count: int,
    num_runs: int,
    id: int,
    time: str,
    files: List[Union[str, str]],
):
    log_path = "./logs"
    log_path = os.path.join(log_path, time)

    num_files = len(files)
    gap = num_files // num_runs
    start = gap * id
    end = gap * (id + 1)
    if id+1==num_runs:
        end = num_files
    
    files = sorted(files)
    if end!=-1:
        files = files[:end]
    new_log(log_path, f"extract_builtin_{start}_{end}")
    tot = 0
    for file in tqdm(files[start:]):
        input_file = file[0]
        output_dir = file[1]
        clean_bpy()
        new_entry(input_file)
        try:
            print(f"Now processing {input_file}...")
            
            armature = load(input_file)
            
            print('save to:', output_dir)
            os.makedirs(output_dir, exist_ok=True)
            
            vertices, faces = process_mesh()
            if armature is not None:
                arranged_bones = get_arranged_bones(armature)
                joints, tails, parents, names, matrix_local = process_armature(armature, arranged_bones)
                
            else:
                joints = None
                tails = None
                parents = None
                names = None
                matrix_local = None
            
            save_file = os.path.join(output_dir, 'raw_data.npz')
            save_raw_data(
                path=save_file,
                vertices=vertices,
                faces=faces-1,
                joints=joints,
                tails=tails,
                parents=parents,
                names=names,
                matrix_local=matrix_local,
                target_count=target_count,
            )
            
            tot += 1

        except ValueError as e:
            add_error(str(e))
            print(f"ValueError: {str(e)}")
        except RuntimeError as e:
            add_error(str(e))
            print(f"RuntimeError: {str(e)}")
        except TimeoutError as e:
            add_error("time out")
            print("TimeoutError: Processing timed out")
        except Exception as e:
            add_error(f"Unexpected error: {str(e)}")
            print(f"Unexpected error: {str(e)}")
    end_log()
    print(f"{tot} models processed")

def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')

def nullable_string(val):
    if not val:
        return None
    return val

def get_files(
    data_name: str,
    input_dataset_dir: str,
    output_dataset_dir: str,
    inputs: Union[str, None]=None,
    require_suffix: List[str]=['obj','fbx','FBX','dae','glb','gltf','vrm'],
    force_override: bool=False,
    warning: bool=True,
) -> List[Tuple[str, str]]:
    
    files = [] # (input_file, output_dir)
    if inputs is not None: # specified input file(s)
        vis = {}
        inputs = inputs.split(',')
        for file in inputs:
            file_name = file.removeprefix("./")
            # remove suffix
            file_name = '.'.join(file_name.split('.')[:-1])
            output_dir = os.path.join(output_dataset_dir, file_name)
            raw_data_npz = os.path.join(output_dir, data_name)
            if not force_override and os.path.exists(raw_data_npz):
                continue
            if warning and output_dir in vis:
                print(f"\033[33mWARNING: duplicate output directory: {output_dir}, you need to rename prefix of files to avoid ambiguity\033[0m")
            vis[output_dir] = True
            files.append((file, output_dir))
    else:
        vis = {}
        for root, dirs, f in os.walk(input_dataset_dir):
            for file in f:
                if file.split('.')[-1] in require_suffix:
                    file_name = file.removeprefix("./")
                    # remove suffix
                    file_name = '.'.join(file_name.split('.')[:-1])
                    
                    output_dir = os.path.join(output_dataset_dir, os.path.relpath(root, input_dataset_dir), file_name)
                    raw_data_npz = os.path.join(output_dir, data_name)
                    
                    # Check if all required files exist
                    if not force_override and os.path.exists(raw_data_npz):
                        continue
                    if warning and output_dir in vis:
                        print(f"\033[33mWARNING: duplicate output directory: {output_dir}, you need to rename prefix of files to avoid ambiguity\033[0m")
                    vis[output_dir] = True
                    files.append((os.path.join(root, file), output_dir))

    return files

def parse():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, required=True)
    parser.add_argument('--require_suffix', type=str, required=True)
    parser.add_argument('--faces_target_count', type=int, required=True)
    parser.add_argument('--num_runs', type=int, required=True)
    parser.add_argument('--force_override', type=str2bool, required=True)
    parser.add_argument('--id', type=int, required=True)
    parser.add_argument('--time', type=str, required=True)

    parser.add_argument('--input', type=nullable_string, required=False, default=None)
    parser.add_argument('--input_dir', type=nullable_string, required=False, default=None)
    parser.add_argument('--output_dir', type=nullable_string, required=False, default=None)
    return parser.parse_args()

if __name__ == "__main__":
    args = parse()
    
    config = Box(yaml.safe_load(open(args.config, "r")))
    
    num_runs        = args.num_runs
    id              = args.id
    timestamp       = args.time
    require_suffix  = args.require_suffix.split(',')
    force_override  = args.force_override
    target_count    = args.faces_target_count
    
    if args.input_dir:
        config.input_dataset_dir = args.input_dir
    if args.output_dir:
        config.output_dataset_dir = args.output_dir
    
    assert config.input_dataset_dir is not None or args.input is None, 'you cannot specify both input and input_dir'

    files = get_files(
        data_name='raw_data.npz',
        inputs=args.input,
        input_dataset_dir=config.input_dataset_dir,
        output_dataset_dir=config.output_dataset_dir,
        require_suffix=require_suffix,
        force_override=force_override,
        warning=True,
    )

    extract_builtin(
        output_folder=config.output_dataset_dir,
        target_count=target_count,
        num_runs=num_runs,
        id=id,
        time=timestamp,
        files=files,
    )