import torch
from torch import nn, FloatTensor, LongTensor, Tensor
import torch.nn.functional as F
import numpy as np
from typing import Dict, List
import math
import torch_scatter
from flash_attn.modules.mha import MHA

from .spec import ModelSpec, ModelInput
from .parse_encoder import MAP_MESH_ENCODER, get_mesh_encoder


class FrequencyPositionalEmbedding(nn.Module):
    """The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
    each feature dimension of `x[..., i]` into:
        [
            sin(x[..., i]),
            sin(f_1*x[..., i]),
            sin(f_2*x[..., i]),
            ...
            sin(f_N * x[..., i]),
            cos(x[..., i]),
            cos(f_1*x[..., i]),
            cos(f_2*x[..., i]),
            ...
            cos(f_N * x[..., i]),
            x[..., i]     # only present if include_input is True.
        ], here f_i is the frequency.

    Denote the space is [0 / num_freqs, 1 / num_freqs, 2 / num_freqs, 3 / num_freqs, ..., (num_freqs - 1) / num_freqs].
    If logspace is True, then the frequency f_i is [2^(0 / num_freqs), ..., 2^(i / num_freqs), ...];
    Otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)].

    Args:
        num_freqs (int): the number of frequencies, default is 6;
        logspace (bool): if True, the frequencies are f_i = 2^i;
            otherwise, they are linearly spaced between 1.0 and 2^(num_freqs - 1);
        input_dim (int): the input dimension, default is 3;
        include_input (bool): whether to append the raw input to the embedding, default is True;
        include_pi (bool): whether to multiply every frequency by pi, default is True.

    Attributes:
        frequencies (torch.Tensor): a tensor of shape (num_freqs,) holding the
            frequencies described above (scaled by pi when include_pi is True);

        out_dim (int): the embedding size, if include_input is True, it is input_dim * (num_freqs * 2 + 1),
            otherwise, it is input_dim * num_freqs * 2.
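
    Example (illustrative; with the defaults num_freqs=6, input_dim=3,
    include_input=True, the embedding size is 3 * (6 * 2 + 1) = 39):

        >>> embed = FrequencyPositionalEmbedding()
        >>> embed(torch.rand(2, 5, 3)).shape
        torch.Size([2, 5, 39])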

    """

    def __init__(
        self,
        num_freqs: int = 6,
        logspace: bool = True,
        input_dim: int = 3,
        include_input: bool = True,
        include_pi: bool = True,
    ) -> None:
        """The initialization"""

        super().__init__()

        if logspace:
            frequencies = 2.0 ** torch.arange(num_freqs, dtype=torch.float32)
        else:
            frequencies = torch.linspace(
                1.0, 2.0 ** (num_freqs - 1), num_freqs, dtype=torch.float32
            )

        if include_pi:
            frequencies *= torch.pi

        self.register_buffer("frequencies", frequencies, persistent=False)
        self.include_input = include_input
        self.num_freqs = num_freqs

        self.out_dim = self._get_dims(input_dim)

    def _get_dims(self, input_dim):
        # When num_freqs == 0, forward returns the input unchanged, so the
        # output dimension falls back to input_dim.
        temp = 1 if self.include_input or self.num_freqs == 0 else 0
        out_dim = input_dim * (self.num_freqs * 2 + temp)

        return out_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward process.

        Args:
            x: tensor of shape [..., dim]

        Returns:
            embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)]
                where temp is 1 if include_input is True and 0 otherwise.
        """

        if self.num_freqs > 0:
            embed = (x[..., None].contiguous() * self.frequencies.to(device=x.device)).view(
                *x.shape[:-1], -1
            )
            if self.include_input:
                return torch.cat((x, embed.sin(), embed.cos()), dim=-1)
            else:
                return torch.cat((embed.sin(), embed.cos()), dim=-1)
        else:
            return x

class ResidualCrossAttn(nn.Module):
    def __init__(self, feat_dim: int, num_heads: int):
        super().__init__()
        assert feat_dim % num_heads == 0, "feat_dim must be divisible by num_heads"

        self.norm1 = nn.LayerNorm(feat_dim)
        self.norm2 = nn.LayerNorm(feat_dim)
        # flash-attn MHA in cross-attention mode (replaces the earlier
        # nn.MultiheadAttention; batch-first (B, S, D) layout)
        self.attention = MHA(embed_dim=feat_dim, num_heads=num_heads, cross_attn=True)
        self.ffn = nn.Sequential(
            nn.Linear(feat_dim, feat_dim * 4),
            nn.GELU(),
            nn.Linear(feat_dim * 4, feat_dim),
        )
        
    def forward(self, q, kv):
        # Post-norm residual block: cross-attention followed by a feed-forward
        # network, each wrapped in a residual connection and LayerNorm.
        residual = q
        attn_output = self.attention(q, x_kv=kv)
        x = self.norm1(residual + attn_output)
        x = self.norm2(x + self.ffn(x))
        return x
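
# Illustrative shapes for ResidualCrossAttn: the query sequence attends to a
# separate key/value sequence and keeps its own length, e.g.
#   block = ResidualCrossAttn(feat_dim=256, num_heads=8)
#   q  = torch.rand(2, 100, 256)    # (B, N_q, feat_dim)
#   kv = torch.rand(2, 40, 256)     # (B, N_kv, feat_dim)
#   block(q, kv).shape              # torch.Size([2, 100, 256])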

class BoneEncoder(nn.Module):
    def __init__(
        self,
        feat_bone_dim: int,
        feat_dim: int,
        embed_dim: int,
        num_heads: int,
        num_attn: int,
    ):
        super().__init__()
        self.feat_bone_dim = feat_bone_dim
        self.feat_dim = feat_dim
        self.num_heads = num_heads
        self.num_attn = num_attn
        
        self.position_embed = FrequencyPositionalEmbedding(input_dim=self.feat_bone_dim)

        self.bone_encoder = nn.Sequential(
            self.position_embed,
            nn.Linear(self.position_embed.out_dim, embed_dim),
            nn.LayerNorm(embed_dim),
            nn.GELU(),
            nn.Linear(embed_dim, embed_dim * 4),
            nn.LayerNorm(embed_dim * 4),
            nn.GELU(),
            nn.Linear(embed_dim * 4, feat_dim),
            nn.LayerNorm(feat_dim),
            nn.GELU(),
        )
        self.attn = nn.ModuleList()
        for _ in range(self.num_attn):
            self.attn.append(ResidualCrossAttn(feat_dim, self.num_heads))

    def forward(
        self,
        base_bone: FloatTensor,
        num_bones: LongTensor,
        parents: LongTensor,
        min_coord: FloatTensor,
        global_latents: FloatTensor,
    ):
        # base_bone: (B, J, C); num_bones and parents are part of the call
        # signature but unused here.
        B = base_bone.shape[0]
        J = base_bone.shape[1]
        # Shift bones by each sample's minimum vertex coordinate before the
        # positional embedding.
        x = self.bone_encoder(
            (base_bone - min_coord[:, None, :]).reshape(-1, base_bone.shape[-1])
        ).reshape(B, J, -1)

        # Bone tokens attend to themselves concatenated with the global shape latents.
        latents = torch.cat([x, global_latents], dim=1)

        for attn in self.attn:
            x = attn(x, latents)
        return x

class SkinweightPred(nn.Module):
    def __init__(self, in_dim, mlp_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, mlp_dim),
            nn.LayerNorm(mlp_dim),
            nn.GELU(),
            nn.Linear(mlp_dim, mlp_dim),
            nn.LayerNorm(mlp_dim),
            nn.GELU(),
            nn.Linear(mlp_dim, mlp_dim),
            nn.LayerNorm(mlp_dim),
            nn.GELU(),
            nn.Linear(mlp_dim, mlp_dim),
            nn.LayerNorm(mlp_dim),
            nn.GELU(),
            nn.Linear(mlp_dim, 1),
        )

    def forward(self, x):
        return self.net(x)
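
# SkinweightPred maps each fused per-(vertex, bone) feature vector to a scalar
# logit. Illustrative shapes: (B * N * J, in_dim) -> (B * N * J, 1); the logits
# over a vertex's valid bones are softmax-normalized downstream.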

class UniRigSkin(ModelSpec):
    
    def process_fn(self, batch: List[ModelInput]) -> List[Dict]:
        # Zero-pad every asset's voxel skin to the batch-wide maximum bone
        # count and record cumulative vertex offsets for point-cloud batching.
        max_bones = 0
        for b in batch:
            max_bones = max(max_bones, b.asset.J)
        res = []
        current_offset = 0
        for b in batch:
            vertex_groups = b.asset.sampled_vertex_groups
            current_offset += b.vertices.shape[0]
            # (N, J)
            voxel_skin = vertex_groups['voxel_skin']
            
            voxel_skin = np.pad(voxel_skin, ((0, 0), (0, max_bones-b.asset.J)), 'constant', constant_values=0.0)
            
            res.append({
                'voxel_skin': voxel_skin,
                'offset': current_offset,
            })
        return res
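
    # Illustrative process_fn output for two assets with J=3 and J=5 bones and
    # N0/N1 sampled vertices (so max_bones=5):
    #   res[0] = {'voxel_skin': (N0, 5), last two columns zero, 'offset': N0}
    #   res[1] = {'voxel_skin': (N1, 5),                        'offset': N0 + N1}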
    
    def __init__(self, mesh_encoder, global_encoder, **kwargs):
        super().__init__()
        
        self.num_train_vertex       = kwargs['num_train_vertex']
        self.feat_dim               = kwargs['feat_dim']
        self.num_heads              = kwargs['num_heads']
        self.grid_size              = kwargs['grid_size']
        self.mlp_dim                = kwargs['mlp_dim']
        self.num_bone_attn          = kwargs['num_bone_attn']
        self.num_mesh_bone_attn     = kwargs['num_mesh_bone_attn']
        self.bone_embed_dim         = kwargs['bone_embed_dim']
        self.voxel_mask             = kwargs.get('voxel_mask', 2)

        self.mesh_encoder = get_mesh_encoder(**mesh_encoder)
        self.global_encoder = get_mesh_encoder(**global_encoder)
        if isinstance(self.mesh_encoder, MAP_MESH_ENCODER.ptv3obj):
            self.feat_map = nn.Sequential(
                nn.Linear(mesh_encoder['enc_channels'][-1], self.feat_dim),
                nn.LayerNorm(self.feat_dim),
                nn.GELU(),
            )
        else:
            raise NotImplementedError()
        if isinstance(self.global_encoder, MAP_MESH_ENCODER.michelangelo_encoder):
            self.out_proj = nn.Sequential(
                nn.Linear(self.global_encoder.width, self.feat_dim),
                nn.LayerNorm(self.feat_dim),
                nn.GELU(),
            )
        else:
            raise NotImplementedError()

        self.bone_encoder = BoneEncoder(
            feat_bone_dim=3,
            feat_dim=self.feat_dim,
            embed_dim=self.bone_embed_dim,
            num_heads=self.num_heads,
            num_attn=self.num_bone_attn,
        )
        
        self.downscale = nn.Sequential(
            nn.Linear(2 * self.num_heads, self.num_heads),
            nn.LayerNorm(self.num_heads),
            nn.GELU(),
        )
        self.skinweight_pred = SkinweightPred(
            self.num_heads,
            self.mlp_dim,
        )
        
        self.mesh_bone_attn = nn.ModuleList()
        self.mesh_bone_attn.extend([
            ResidualCrossAttn(self.feat_dim, self.num_heads) for _ in range(self.num_mesh_bone_attn)
        ])

        self.qmesh = nn.Linear(self.feat_dim, self.feat_dim * self.num_heads)
        self.kmesh = nn.Linear(self.feat_dim, self.feat_dim * self.num_heads)
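        # Note: qmesh/kmesh give every head a full feat_dim-sized query/key
        # (later reshaped to (B, num_heads, *, feat_dim)) instead of splitting
        # feat_dim across heads as in standard multi-head attention.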

        self.voxel_skin_embed = nn.Linear(1, self.num_heads)
        self.voxel_skin_norm = nn.LayerNorm(self.num_heads)
        self.attn_skin_norm = nn.LayerNorm(self.num_heads)

    def encode_mesh_cond(self, vertices: FloatTensor, normals: FloatTensor) -> FloatTensor:
        assert not torch.isnan(vertices).any()
        assert not torch.isnan(normals).any()
        if isinstance(self.global_encoder, MAP_MESH_ENCODER.michelangelo_encoder):
            if vertices.dim() == 3:
                _, latents, _, _ = self.global_encoder.encode_latents(pc=vertices, feats=normals)
            else:
                _, latents, _, _ = self.global_encoder.encode_latents(
                    pc=vertices.unsqueeze(0), feats=normals.unsqueeze(0)
                )
            latents = self.out_proj(latents)
            return latents
        else:
            raise NotImplementedError()

    def _get_predict(self, batch: Dict) -> FloatTensor:
        '''
        Return the predicted skin weights (B, N_pred, J) together with the
        vertex indices they were evaluated at.
        '''
        
        num_bones: Tensor = batch['num_bones']
        vertices: FloatTensor = batch['vertices'] # (B, N, 3)
        normals: FloatTensor = batch['normals']
        joints: FloatTensor = batch['joints']
        tails: FloatTensor = batch['tails']
        voxel_skin: FloatTensor = batch['voxel_skin']
        parents: LongTensor = batch['parents']
        
        # turn inputs' dtype into model's dtype
        dtype = next(self.parameters()).dtype
        vertices = vertices.type(dtype)
        normals = normals.type(dtype)
        joints = joints.type(dtype)
        tails = tails.type(dtype)
        voxel_skin = voxel_skin.type(dtype)
        
        B = vertices.shape[0]
        N = vertices.shape[1]
        J = joints.shape[1]
        
        assert vertices.dim() == 3
        assert normals.dim() == 3
        
        part_offset = torch.tensor([(i+1)*N for i in range(B)], dtype=torch.int64, device=vertices.device)
        idx_ptr = torch.nn.functional.pad(part_offset, (1, 0), value=0)
        min_coord = torch_scatter.segment_csr(vertices.reshape(-1, 3), idx_ptr, reduce="min")
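        # segment_csr with idx_ptr = [0, N, 2N, ..., B*N] takes the minimum over
        # each sample's N vertices, so min_coord has shape (B, 3).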

        # Vertex chunking: during training, score one random subset of
        # num_train_vertex vertices; at inference, cover all N vertices in
        # sequential chunks so every vertex gets a prediction.
        pack = []
        if self.training:
            train_indices = torch.randperm(N)[:self.num_train_vertex]
            pack.append(train_indices)
        else:
            for i in range((N + self.num_train_vertex - 1) // self.num_train_vertex):
                pack.append(torch.arange(i * self.num_train_vertex, min((i + 1) * self.num_train_vertex, N)))
        
        # (B, seq_len, feat_dim)
        global_latents = self.encode_mesh_cond(vertices, normals)
        bone_feat = self.bone_encoder(
            base_bone=joints,
            num_bones=num_bones,
            parents=parents,
            min_coord=min_coord,
            global_latents=global_latents,
        )
        
        if isinstance(self.mesh_encoder, MAP_MESH_ENCODER.ptv3obj):
            feat = torch.cat([vertices, normals, torch.zeros_like(vertices)], dim=-1)
            ptv3_input = {
                'coord': vertices.reshape(-1, 3),
                'feat': feat.reshape(-1, 9),
                'offset': torch.tensor(batch['offset']),
                'grid_size': self.grid_size,
            }
            if not self.training:
                # must cast to float32 to avoid sparse-conv precision bugs
                with torch.autocast(device_type='cuda', dtype=torch.float32):
                    mesh_feat = self.mesh_encoder(ptv3_input).feat
                    mesh_feat = self.feat_map(mesh_feat).view(B, N, self.feat_dim)
            else:
                mesh_feat = self.mesh_encoder(ptv3_input).feat
                mesh_feat = self.feat_map(mesh_feat).view(B, N, self.feat_dim)
            mesh_feat = mesh_feat.type(dtype)
        else:
            raise NotImplementedError()

        # (B, J + seq_len, feat_dim)
        latents = torch.cat([bone_feat, global_latents], dim=1)
        # (B, N, feat_dim)
        for block in self.mesh_bone_attn:
            mesh_feat = block(
                q=mesh_feat,
                kv=latents,
            )

        # project bone features to per-head keys: (B, num_heads, J, feat_dim)
        bone_feat = self.kmesh(bone_feat).view(B, J, self.num_heads, self.feat_dim).transpose(1, 2)

        skin_pred_list = []
        if not self.training:
            skin_mask = voxel_skin.clone()
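            # Accumulate each child bone's voxel-skin support into its direct
            # parent (assumes bones are topologically ordered with parent index
            # < child index), so the mask keeps parents of supported children
            # non-zero.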
            for b in range(B):
                num = num_bones[b]
                for i in range(num):
                    p = parents[b, i]
                    if p < 0:
                        continue
                    skin_mask[b, :, p] += skin_mask[b, :, i]
        for indices in pack:
            cur_N = len(indices)
            # project mesh features to per-head queries: (B, num_heads, cur_N, feat_dim)
            cur_mesh_feat = self.qmesh(mesh_feat[:, indices]).view(B, cur_N, self.num_heads, self.feat_dim).transpose(1, 2)

            # attn_weight shape: (B * num_heads, cur_N, J), softmax over bones J
            attn_weight = F.softmax(torch.bmm(
                cur_mesh_feat.reshape(B * self.num_heads, cur_N, -1),
                bone_feat.transpose(-2, -1).reshape(B * self.num_heads, -1, J)
            ) / math.sqrt(self.feat_dim), dim=-1, dtype=dtype)
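            # i.e. attn[b, h, n, j] = softmax_j(<q[b, h, n], k[b, h, j]> / sqrt(feat_dim))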
            # (B * num_heads, cur_N, J) -> (B, cur_N, J, num_heads)
            attn_weight = attn_weight.reshape(B, self.num_heads, cur_N, J).permute(0, 2, 3, 1)
            attn_weight = self.attn_skin_norm(attn_weight)
            
            # Fuse attention evidence with the embedded voxel-skin prior:
            # concatenate along the head dimension (2 * num_heads) and project
            # back to num_heads.
            embed_voxel_skin = self.voxel_skin_embed(voxel_skin[:, indices].reshape(B, cur_N, J, 1))
            embed_voxel_skin = self.voxel_skin_norm(embed_voxel_skin)

            attn_weight = torch.cat([attn_weight, embed_voxel_skin], dim=-1)
            attn_weight = self.downscale(attn_weight)
        
            # (B, cur_N, J, num_heads) -> (B, cur_N, J)
            skin_pred = torch.zeros(B, cur_N, J, device=attn_weight.device, dtype=dtype)
            for i in range(B):
                # (cur_N * num_bones[i], num_heads)
                input_features = attn_weight[i, :, :num_bones[i], :].reshape(-1, attn_weight.shape[-1])

                pred = self.skinweight_pred(input_features).reshape(cur_N, num_bones[i])
                skin_pred[i, :, :num_bones[i]] = F.softmax(pred, dim=-1)
            skin_pred_list.append(skin_pred)
        skin_pred_all = torch.cat(skin_pred_list, dim=1)
        if not self.training:
            # skin_mask is only built at inference time, so the masking must be
            # guarded as well: apply it (raised to the voxel_mask power) and
            # renormalize over each asset's valid bones.
            for i in range(B):
                n = num_bones[i]
                skin_pred_all[i, :, :n] = skin_pred_all[i, :, :n] * torch.pow(skin_mask[i, :, :n], self.voxel_mask)
                skin_pred_all[i, :, :n] = skin_pred_all[i, :, :n] / skin_pred_all[i, :, :n].sum(dim=-1, keepdim=True)
        return skin_pred_all, torch.cat(pack, dim=0)
    
    def predict_step(self, batch: Dict):
        with torch.no_grad():
            num_bones: Tensor = batch['num_bones']
            
            skin_pred, _ = self._get_predict(batch=batch)
            outputs = []
            for i in range(skin_pred.shape[0]):
                outputs.append(skin_pred[i, :, :num_bones[i]])
            return outputs
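
# Minimal inference sketch (illustrative; assumes a constructed `model:
# UniRigSkin` and a collated `batch` dict carrying the keys read in
# `_get_predict`, e.g. 'vertices', 'normals', 'joints', 'tails', 'voxel_skin',
# 'parents', 'num_bones', 'offset'):
#
#   model.eval()
#   skins = model.predict_step(batch)    # list of (N, num_bones[i]) tensors
#   # each row is a weight distribution over that asset's bones (sums to 1)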