# -*- coding: utf-8 -*-
#
# This file is part of UniRig.
# 
# This file is derived from https://github.com/NeuralCarver/Michelangelo
#
# Copyright (c) https://github.com/NeuralCarver/Michelangelo original authors
# Copyright (c) 2025 VAST-AI-Research and contributors.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

import torch
import torch.nn as nn
from typing import Optional, Union
from einops import repeat
import math
from torch_cluster import fps
import random
import time
import numpy as np

from ..modules import checkpoint
from ..modules.embedder import FourierEmbedder
from ..modules.transformer_blocks import (
    ResidualCrossAttentionBlock,
    Transformer
)

from .tsal_base import ShapeAsLatentModule


class CrossAttentionEncoder(nn.Module):

    def __init__(self, *,
                 device: Optional[torch.device],
                 dtype: Optional[torch.dtype],
                 num_latents: int,
                 fourier_embedder: FourierEmbedder,
                 point_feats: int,
                 width: int,
                 heads: int,
                 layers: int,
                 init_scale: float = 0.25,
                 qkv_bias: bool = True,
                 flash: bool = False,
                 use_ln_post: bool = False,
                 use_checkpoint: bool = False,
                 query_method: bool = False,
                 use_full_input: bool = True,
                 token_num: int = 256,
                 no_query: bool=False):

        super().__init__()

        self.query_method = query_method
        self.token_num = token_num
        self.use_full_input = use_full_input

        self.use_checkpoint = use_checkpoint
        self.num_latents = num_latents

        if no_query:
            self.query = None
        else:
            self.query = nn.Parameter(torch.randn((num_latents, width), device=device, dtype=dtype) * 0.02)

        self.fourier_embedder = fourier_embedder
        self.input_proj = nn.Linear(self.fourier_embedder.out_dim + point_feats, width, device=device, dtype=dtype)
        self.cross_attn = ResidualCrossAttentionBlock(
            device=device,
            dtype=dtype,
            width=width,
            heads=heads,
            init_scale=init_scale,
            qkv_bias=qkv_bias,
            flash=flash,
        )

        self.self_attn = Transformer(
            device=device,
            dtype=dtype,
            n_ctx=num_latents,
            width=width,
            layers=layers,
            heads=heads,
            init_scale=init_scale,
            qkv_bias=qkv_bias,
            flash=flash,
            use_checkpoint=False
        )

        if use_ln_post:
            self.ln_post = nn.LayerNorm(width, dtype=dtype, device=device)
        else:
            self.ln_post = None

    def _forward(self, pc, feats):
        """

        Args:
            pc (torch.FloatTensor): [B, N, 3]
            feats (torch.FloatTensor or None): [B, N, C]

        Returns:

        """
        if self.query_method:
            token_num = self.num_latents
            bs = pc.shape[0] 
            data = self.fourier_embedder(pc) 
            if feats is not None: 
                data = torch.cat([data, feats], dim=-1)
            data = self.input_proj(data) 

            query = repeat(self.query, "m c -> b m c", b=bs) 

            latents = self.cross_attn(query, data)
            latents = self.self_attn(latents)

            if self.ln_post is not None:
                latents = self.ln_post(latents)

            pre_pc = None
        else:

            if isinstance(self.token_num, int):
                token_num = self.token_num
            else:
                token_num = random.choice(self.token_num)

            if self.training:
                rng = np.random.default_rng()
            else:
                rng = np.random.default_rng(seed=0)
            ind = rng.choice(pc.shape[1], token_num * 4, replace=token_num * 4 > pc.shape[1])

            pre_pc = pc[:,ind,:]
            pre_feats = feats[:,ind,:]


            B, N, D = pre_pc.shape           
            C = pre_feats.shape[-1]
            ###### fps
            pos = pre_pc.view(B*N, D)
            pos_feats = pre_feats.view(B*N, C)
            batch = torch.arange(B).to(pc.device)
            batch = torch.repeat_interleave(batch, N)

            idx = fps(pos, batch, ratio=1. / 4, random_start=self.training)

            sampled_pc = pos[idx]
            sampled_pc = sampled_pc.view(B, -1, 3)

            sampled_feats = pos_feats[idx]
            sampled_feats = sampled_feats.view(B, -1, C)

            ######
            if self.use_full_input:
                data = self.fourier_embedder(pc) 
            else:
                data = self.fourier_embedder(pre_pc) 

            if feats is not None: 
                if not self.use_full_input:
                    feats = pre_feats
                data = torch.cat([data, feats], dim=-1) 
            data = self.input_proj(data) 

            sampled_data = self.fourier_embedder(sampled_pc) 
            if feats is not None: 
                sampled_data = torch.cat([sampled_data, sampled_feats], dim=-1) 
            sampled_data = self.input_proj(sampled_data) 

            latents = self.cross_attn(sampled_data, data) 
            latents = self.self_attn(latents)

            if self.ln_post is not None:
                latents = self.ln_post(latents)

            pre_pc = torch.cat([pre_pc, pre_feats], dim=-1)

        return latents, pc, token_num, pre_pc

    def forward(self, pc: torch.FloatTensor, feats: Optional[torch.FloatTensor] = None):
        """

        Args:
            pc (torch.FloatTensor): [B, N, 3]
            feats (torch.FloatTensor or None): [B, N, C]

        Returns:
            dict
        """

        return checkpoint(self._forward, (pc, feats), self.parameters(), self.use_checkpoint)


class CrossAttentionDecoder(nn.Module):

    def __init__(self, *,
                 device: Optional[torch.device],
                 dtype: Optional[torch.dtype],
                 num_latents: int,
                 out_channels: int,
                 fourier_embedder: FourierEmbedder,
                 width: int,
                 heads: int,
                 init_scale: float = 0.25,
                 qkv_bias: bool = True,
                 flash: bool = False,
                 use_checkpoint: bool = False,
                 mlp_width_scale: int = 4,
                 supervision_type: str = 'occupancy'):

        super().__init__()

        self.use_checkpoint = use_checkpoint
        self.fourier_embedder = fourier_embedder
        self.supervision_type = supervision_type

        self.query_proj = nn.Linear(self.fourier_embedder.out_dim, width, device=device, dtype=dtype)

        self.cross_attn_decoder = ResidualCrossAttentionBlock(
            device=device,
            dtype=dtype,
            n_data=num_latents,
            width=width,
            heads=heads,
            init_scale=init_scale,
            qkv_bias=qkv_bias,
            flash=flash,
            mlp_width_scale=mlp_width_scale,
        )

        self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
        self.output_proj = nn.Linear(width, out_channels, device=device, dtype=dtype)
        if self.supervision_type == 'occupancy-sdf':
            self.output_proj_sdf = nn.Linear(width, out_channels, device=device, dtype=dtype)



    def _forward(self, queries: torch.FloatTensor, latents: torch.FloatTensor):
        if next(self.query_proj.parameters()).dtype == torch.float16:
            queries = queries.half()
            latents = latents.half()
        # print(f"queries: {queries.dtype}, {queries.device}")
        # print(f"latents: {latents.dtype}, {latents.device}"z)
        queries = self.query_proj(self.fourier_embedder(queries))
        x = self.cross_attn_decoder(queries, latents)
        x = self.ln_post(x)
        x_1 = self.output_proj(x)
        if self.supervision_type == 'occupancy-sdf':
            x_2 = self.output_proj_sdf(x)
            return x_1, x_2
        else:
            return x_1

    def forward(self, queries: torch.FloatTensor, latents: torch.FloatTensor):
        return checkpoint(self._forward, (queries, latents), self.parameters(), self.use_checkpoint)


class ShapeAsLatentPerceiver(ShapeAsLatentModule):
    def __init__(self, *,
                 device: Optional[torch.device],
                 dtype: Optional[torch.dtype],
                 num_latents: int,
                 point_feats: int = 0,
                 embed_dim: int = 0,
                 num_freqs: int = 8,
                 include_pi: bool = True,
                 width: int,
                 heads: int,
                 num_encoder_layers: int,
                 num_decoder_layers: int,
                 decoder_width: Optional[int] = None,
                 init_scale: float = 0.25,
                 qkv_bias: bool = True,
                 flash: bool = False,
                 use_ln_post: bool = False,
                 use_checkpoint: bool = False,
                 supervision_type: str = 'occupancy',
                 query_method: bool = False,
                 token_num: int = 256,
                 grad_type: str = "numerical",
                 grad_interval: float = 0.005,
                 use_full_input: bool = True,
                 freeze_encoder: bool = False,
                 decoder_mlp_width_scale: int = 4,
                 residual_kl: bool = False,
                 ):

        super().__init__()

        self.use_checkpoint = use_checkpoint

        self.num_latents = num_latents
        assert grad_type in ["numerical", "analytical"]
        self.grad_type = grad_type
        self.grad_interval = grad_interval
        self.supervision_type = supervision_type
        self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)

        init_scale = init_scale * math.sqrt(1.0 / width)
        self.encoder = CrossAttentionEncoder(
            device=device,
            dtype=dtype,
            fourier_embedder=self.fourier_embedder,
            num_latents=num_latents,
            point_feats=point_feats,
            width=width,
            heads=heads,
            layers=num_encoder_layers,
            init_scale=init_scale,
            qkv_bias=qkv_bias,
            flash=flash,
            use_ln_post=use_ln_post,
            use_checkpoint=use_checkpoint,
            query_method=query_method,
            use_full_input=use_full_input,
            token_num=token_num
        )

        self.embed_dim = embed_dim
        self.residual_kl = residual_kl
        if decoder_width is None:
            decoder_width = width
        if embed_dim > 0:
            # VAE embed
            self.pre_kl = nn.Linear(width, embed_dim * 2, device=device, dtype=dtype)
            self.post_kl = nn.Linear(embed_dim, decoder_width, device=device, dtype=dtype)
            self.latent_shape = (num_latents, embed_dim)
            if self.residual_kl:
                assert self.post_kl.out_features % self.post_kl.in_features == 0
                assert self.pre_kl.in_features % self.pre_kl.out_features == 0 
        else:
            self.latent_shape = (num_latents, width)

        self.transformer = Transformer(
            device=device,
            dtype=dtype,
            n_ctx=num_latents,
            width=decoder_width,
            layers=num_decoder_layers,
            heads=heads,
            init_scale=init_scale,
            qkv_bias=qkv_bias,
            flash=flash,
            use_checkpoint=use_checkpoint
        )

        # geometry decoder
        self.geo_decoder = CrossAttentionDecoder(
            device=device,
            dtype=dtype,
            fourier_embedder=self.fourier_embedder,
            out_channels=1,
            num_latents=num_latents,
            width=decoder_width,
            heads=heads,
            init_scale=init_scale,
            qkv_bias=qkv_bias,
            flash=flash,
            use_checkpoint=use_checkpoint,
            supervision_type=supervision_type,
            mlp_width_scale=decoder_mlp_width_scale
        )

        if freeze_encoder:
            for p in self.encoder.parameters():
                p.requires_grad = False
            for p in self.pre_kl.parameters():
                p.requires_grad = False
            print("freeze encoder and pre kl")

    def forward(self,
                pc: torch.FloatTensor,
                feats: torch.FloatTensor,
                volume_queries: torch.FloatTensor,
                sample_posterior: bool = True):
        """

        Args:
            pc (torch.FloatTensor): [B, N, 3]
            feats (torch.FloatTensor or None): [B, N, C]
            volume_queries (torch.FloatTensor): [B, P, 3]
            sample_posterior (bool):

        Returns:
            logits (torch.FloatTensor): [B, P]
            center_pos (torch.FloatTensor): [B, M, 3]
            posterior (DiagonalGaussianDistribution or None).

        """

        latents, center_pos, posterior = self.encode(pc, feats, sample_posterior=sample_posterior)

        latents = self.decode(latents)
        logits = self.query_geometry(volume_queries, latents)

        return logits, center_pos, posterior


class AlignedShapeLatentPerceiver(ShapeAsLatentPerceiver):

    def __init__(self, *,
                 device: Optional[torch.device],
                 dtype: Optional[str],
                 num_latents: int,
                 point_feats: int = 0,
                 embed_dim: int = 0,
                 num_freqs: int = 8,
                 include_pi: bool = True,
                 width: int,
                 heads: int,
                 num_encoder_layers: int,
                 num_decoder_layers: int,
                 decoder_width: Optional[int] = None,
                 init_scale: float = 0.25,
                 qkv_bias: bool = True,
                 flash: bool = False,
                 use_ln_post: bool = False,
                 use_checkpoint: bool = False,
                 supervision_type: str = 'occupancy',
                 grad_type: str = "numerical",
                 grad_interval: float = 0.005,
                 query_method: bool = False,
                 use_full_input: bool = True,
                 token_num: int = 256,
                 freeze_encoder: bool = False,
                 decoder_mlp_width_scale: int = 4,
                 residual_kl: bool = False,
                ):

        MAP_DTYPE = {
            'float32': torch.float32,
            'float16': torch.float16,
            'bfloat16': torch.bfloat16,
        }
        if dtype is not None:
            dtype = MAP_DTYPE[dtype]
        super().__init__(
            device=device,
            dtype=dtype,
            num_latents=1 + num_latents,
            point_feats=point_feats,
            embed_dim=embed_dim,
            num_freqs=num_freqs,
            include_pi=include_pi,
            width=width,
            decoder_width=decoder_width,
            heads=heads,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            init_scale=init_scale,
            qkv_bias=qkv_bias,
            flash=flash,
            use_ln_post=use_ln_post,
            use_checkpoint=use_checkpoint,
            supervision_type=supervision_type,
            grad_type=grad_type,
            grad_interval=grad_interval,
            query_method=query_method,
            token_num=token_num,
            use_full_input=use_full_input,
            freeze_encoder=freeze_encoder,
            decoder_mlp_width_scale=decoder_mlp_width_scale,
            residual_kl=residual_kl,
        )

        self.width = width

    def encode(self,
               pc: torch.FloatTensor,
               feats: Optional[torch.FloatTensor] = None,
               sample_posterior: bool = True,
               only_shape: bool=False):
        """

        Args:
            pc (torch.FloatTensor): [B, N, 3]
            feats (torch.FloatTensor or None): [B, N, c]
            sample_posterior (bool):

        Returns:
            shape_embed (torch.FloatTensor)
            kl_embed (torch.FloatTensor):
            posterior (DiagonalGaussianDistribution or None):
        """

        shape_embed, latents, token_num, pre_pc = self.encode_latents(pc, feats)
        if only_shape:
            return shape_embed
        kl_embed, posterior = self.encode_kl_embed(latents, sample_posterior)

        return shape_embed, kl_embed, posterior, token_num, pre_pc

    def encode_latents(self,
                       pc: torch.FloatTensor,
                       feats: Optional[torch.FloatTensor] = None):

        x, _, token_num, pre_pc = self.encoder(pc, feats)

        shape_embed = x[:, 0]
        # latents = x[:, 1:]
        # use all tokens
        latents = x

        return shape_embed, latents, token_num, pre_pc

    def forward(self,
                pc: torch.FloatTensor,
                feats: torch.FloatTensor,
                volume_queries: torch.FloatTensor,
                sample_posterior: bool = True):
        raise NotImplementedError()

#####################################################
# a simplified verstion of perceiver encoder
#####################################################

class ShapeAsLatentPerceiverEncoder(ShapeAsLatentModule):
    def __init__(self, *,
                 device: Optional[torch.device],
                 dtype: Optional[Union[torch.dtype, str]],
                 num_latents: int,
                 point_feats: int = 0,
                 embed_dim: int = 0,
                 num_freqs: int = 8,
                 include_pi: bool = True,
                 width: int,
                 heads: int,
                 num_encoder_layers: int,
                 init_scale: float = 0.25,
                 qkv_bias: bool = True,
                 flash: bool = False,
                 use_ln_post: bool = False,
                 use_checkpoint: bool = False,
                 supervision_type: str = 'occupancy',
                 query_method: bool = False,
                 token_num: int = 256,
                 grad_type: str = "numerical",
                 grad_interval: float = 0.005,
                 use_full_input: bool = True,
                 freeze_encoder: bool = False,
                 residual_kl: bool = False,
                 ):

        super().__init__()


        MAP_DTYPE = {
            'float32': torch.float32,
            'float16': torch.float16,
            'bfloat16': torch.bfloat16,
        }

        if dtype is not None and isinstance(dtype, str):
            dtype = MAP_DTYPE[dtype]

        self.use_checkpoint = use_checkpoint

        self.num_latents = num_latents
        assert grad_type in ["numerical", "analytical"]
        self.grad_type = grad_type
        self.grad_interval = grad_interval
        self.supervision_type = supervision_type
        self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)

        init_scale = init_scale * math.sqrt(1.0 / width)
        self.encoder = CrossAttentionEncoder(
            device=device,
            dtype=dtype,
            fourier_embedder=self.fourier_embedder,
            num_latents=num_latents,
            point_feats=point_feats,
            width=width,
            heads=heads,
            layers=num_encoder_layers,
            init_scale=init_scale,
            qkv_bias=qkv_bias,
            flash=flash,
            use_ln_post=use_ln_post,
            use_checkpoint=use_checkpoint,
            query_method=query_method,
            use_full_input=use_full_input,
            token_num=token_num,
            no_query=True,
        )

        self.embed_dim = embed_dim
        self.residual_kl = residual_kl
        if freeze_encoder:
            for p in self.encoder.parameters():
                p.requires_grad = False
            print("freeze encoder")
        self.width = width

    def encode_latents(self,
                       pc: torch.FloatTensor,
                       feats: Optional[torch.FloatTensor] = None):

        x, _, token_num, pre_pc = self.encoder(pc, feats)

        shape_embed = x[:, 0]
        latents = x

        return shape_embed, latents, token_num, pre_pc

    def forward(self):
        raise NotImplementedError()