# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from math import isqrt
from typing import Literal, Optional
import torch
from einops import rearrange, repeat
from tqdm import tqdm

from depth_anything_3.specs import Gaussians
from depth_anything_3.utils.camera_trj_helpers import (
    interpolate_extrinsics,
    interpolate_intrinsics,
    render_dolly_zoom_path,
    render_stabilization_path,
    render_wander_path,
    render_wobble_inter_path,
)
from depth_anything_3.utils.geometry import affine_inverse, as_homogeneous, get_fov
from depth_anything_3.utils.logger import logger

# `gsplat` supplies the `rasterization` function called by `render_3dgs`. The import is
# optional at module load time: if it is missing, only a warning is logged here, so the
# name `rasterization` stays undefined and any later call to `render_3dgs` will fail.
try:
    from gsplat import rasterization
except ImportError:
    logger.warn(
        "Dependency `gsplat` is required for rendering 3DGS. "
        "Install via: pip install git+https://github.com/nerfstudio-project/"
        "gsplat.git@0b4dddf04cb687367602c01196913cde6a743d70"
    )


def render_3dgs(
    extrinsics: torch.Tensor,  # "batch_views 4 4", w2c
    intrinsics: torch.Tensor,  # "batch_views 3 3", normalized
    image_shape: tuple[int, int],
    gaussian: Gaussians,
    background_color: Optional[torch.Tensor] = None,  # "batch_views 3"
    use_sh: bool = True,
    num_view: int = 1,
    color_mode: Literal["RGB+D", "RGB+ED"] = "RGB+D",
    **kwargs,
) -> tuple[
    torch.Tensor,  # "batch_views 3 height width"
    torch.Tensor,  # "batch_views height width"
]:
    """Rasterize Gaussians into color and depth images with ``gsplat``.

    The flattened camera batch is split into ``batch_views // num_view`` scenes. Scene
    ``i`` is rendered with a single ``rasterization`` call that uses entry ``i`` (dim 0)
    of every Gaussian attribute together with that scene's ``num_view`` cameras.

    Args:
        extrinsics: World-to-camera matrices, "batch_views 4 4".
        intrinsics: Normalized intrinsics, "batch_views 3 3". Only the field of view
            returned by ``get_fov`` is used; the principal point is placed at the image
            center.
        image_shape: Output ``(height, width)`` in pixels.
        gaussian: Gaussian parameters (``means``, ``scales``, ``rotations``,
            ``opacities``, ``harmonics``), indexed per scene along dim 0.
        background_color: Per-view background color, "batch_views 3". Defaults to zeros.
        use_sh: If True, ``harmonics`` is read as "b g xyz n" spherical-harmonics
            coefficients and the SH degree is ``isqrt(n) - 1``. Otherwise the trailing
            dim is squeezed and a sigmoid maps the values into (0, 1) colors.
        num_view: Number of cameras belonging to each scene.
        color_mode: Forwarded to gsplat as ``render_mode``; the last rendered channel
            is returned as depth.
        **kwargs: Accepted so callers can forward extra options; ignored here.

    Returns:
        ``(color, depth)`` shaped "batch_views 3 height width" and
        "batch_views height width".
    """
    gaussian_means = gaussian.means
    gaussian_scales = gaussian.scales
    gaussian_quats = gaussian.rotations
    gaussian_opacities = gaussian.opacities
    gaussian_sh_coefficients = gaussian.harmonics
    b, _, _ = extrinsics.shape

    if background_color is None:
        # Black background matching the dtype/device of the color coefficients.
        background_color = gaussian_sh_coefficients.new_zeros((b, 3))

    sh_degree = None
    if use_sh:
        _, _, _, n = gaussian_sh_coefficients.shape
        sh_degree = isqrt(n) - 1
        colors = rearrange(gaussian_sh_coefficients, "b g xyz n -> b g n xyz").contiguous()
    else:
        # Plain colors: (b, g, c), squashed into (0, 1).
        colors = gaussian_sh_coefficients.squeeze(-1).sigmoid().contiguous()

    h, w = image_shape

    # Pixel focal lengths from the field of view: f = size / (2 * tan(fov / 2)).
    fov_x, fov_y = get_fov(intrinsics).unbind(dim=-1)
    focal_length_x = w / (2 * (0.5 * fov_x).tan())
    focal_length_y = h / (2 * (0.5 * fov_y).tan())

    # Each rasterization call handles one scene with `num_view` cameras.
    batch_scene = b // num_view

    # Pixel-space intrinsics for all scenes, built once. The tensor is freshly allocated
    # so the in-place writes below never target a broadcast view.
    pixel_Ks = gaussian_means.new_zeros((batch_scene, num_view, 3, 3))
    pixel_Ks[..., 0, 0] = focal_length_x.reshape(batch_scene, num_view)
    pixel_Ks[..., 1, 1] = focal_length_y.reshape(batch_scene, num_view)
    pixel_Ks[..., 0, 2] = w / 2.0
    pixel_Ks[..., 1, 2] = h / 2.0
    pixel_Ks[..., 2, 2] = 1.0

    # Group the flattened per-view inputs by scene: [scene, v, ...].
    scene_viewmats = rearrange(extrinsics.float(), "(b v) ... -> b v ...", v=num_view)
    scene_backgrounds = rearrange(background_color, "(b v) ... -> b v ...", v=num_view)

    all_images = []
    all_depths = []
    for i in range(batch_scene):
        render_colors, _, info = rasterization(
            means=gaussian_means[i],  # [N, 3]
            quats=gaussian_quats[i],  # [N, 4]
            scales=gaussian_scales[i],  # [N, 3]
            opacities=gaussian_opacities[i],  # [N,]
            colors=colors[i],  # [N, K, 3] with SH, [N, 3] otherwise
            viewmats=scene_viewmats[i],  # [v, 4, 4]
            Ks=pixel_Ks[i],  # [v, 3, 3]
            backgrounds=scene_backgrounds[i],  # [v, 3]
            render_mode=color_mode,
            width=w,
            height=h,
            packed=False,
            sh_degree=sh_degree,
        )

        # Best effort: keep the gradient of the projected means when autograd tracks
        # them; skipped (instead of swallowing an error) when no gradient is available.
        means2d = info.get("means2d")
        if isinstance(means2d, torch.Tensor) and means2d.requires_grad:
            means2d.retain_grad()

        all_images.append(rearrange(render_colors[..., :3], "v h w c -> v c h w"))
        all_depths.append(render_colors[..., -1])  # depth is the last channel

    # Concatenating per-scene [v, ...] results restores the flat "batch_views" order.
    return torch.cat(all_images, dim=0), torch.cat(all_depths, dim=0)


def run_renderer_in_chunk_w_trj_mode(
    gaussians: Gaussians,
    extrinsics: torch.Tensor,  # world2cam, "batch view 4 4" | "batch view 3 4"
    intrinsics: torch.Tensor,  # unnormed intrinsics, "batch view 3 3"
    image_shape: tuple[int, int],
    chunk_size: Optional[int] = 8,
    trj_mode: Literal[
        "original",
        "smooth",
        "interpolate",
        "interpolate_smooth",
        "wander",
        "dolly_zoom",
        "extend",
        "wobble_inter",
    ] = "smooth",
    input_shape: Optional[tuple[int, int]] = None,
    enable_tqdm: Optional[bool] = False,
    **kwargs,
) -> tuple[
    torch.Tensor,  # color, "batch view 3 height width"
    torch.Tensor,  # depth, "batch view height width"
]:
    """Render Gaussians along a camera trajectory derived from the input cameras.

    The input cameras are converted to camera-to-world poses, a target trajectory is
    built according to ``trj_mode``, and the trajectory is rendered with ``render_3dgs``
    in chunks of at most ``chunk_size`` views.

    Args:
        gaussians: Gaussian parameters forwarded to ``render_3dgs``.
        extrinsics: World-to-camera matrices, "batch view 4 4" or "batch view 3 4".
        intrinsics: Pixel-space intrinsics, "batch view 3 3". A detached copy is
            normalized by dividing row 0 by the input width and row 1 by the input height.
        image_shape: Rendered ``(height, width)``.
        chunk_size: Maximum number of views per ``render_3dgs`` call; ``None`` renders
            all views in one call.
        trj_mode: Trajectory type.
            "original" uses the input cameras as-is.
            "smooth" passes them through ``render_stabilization_path``.
            "interpolate" inserts eased in-between cameras.
            "interpolate_smooth" interpolates, then smooths.
            "extend" interpolates and smooths, then inserts a wander and a dolly-zoom
            sequence at the middle frame.
            "wander" / "dolly_zoom" build the path from the first camera of each batch.
            "wobble_inter" delegates to ``render_wobble_inter_path``.
        input_shape: ``(height, width)`` the intrinsics refer to; defaults to
            ``image_shape``.
        enable_tqdm: Show a progress bar over chunks.
        **kwargs: Forwarded to ``render_3dgs``.

    Returns:
        ``(color, depth)`` shaped "batch view 3 height width" and
        "batch view height width", where ``view`` is the trajectory length.

    Raises:
        AssertionError: If there is at most one input view and ``trj_mode`` is neither
            "wander" nor "dolly_zoom", or if "extend" is used with a batch size other
            than 1.
        Exception: If ``trj_mode`` is not a supported mode.
    """
    cam2world = affine_inverse(as_homogeneous(extrinsics))
    in_h, in_w = image_shape if input_shape is None else input_shape
    intr_normed = intrinsics.clone().detach()
    intr_normed[..., 0, :] /= in_w
    intr_normed[..., 1, :] /= in_h

    # Raised explicitly (rather than via `assert`) so the check is not stripped under
    # `python -O`; without it, the interpolation modes would divide by zero segments.
    if extrinsics.shape[1] <= 1 and trj_mode not in ("wander", "dolly_zoom"):
        raise AssertionError("Please set trj_mode to 'wander' or 'dolly_zoom' when n_views=1")

    def _smooth_trj_fn_batch(raw_c2ws, k_size=50):
        # Smooth every batch element; fall back to the unsmoothed poses on any failure.
        try:
            smooth_c2ws = torch.stack(
                [render_stabilization_path(c2w_i, k_size) for c2w_i in raw_c2ws],
                dim=0,
            )
        except Exception as e:
            print(f"[DEBUG] Path smoothing failed with error: {e}.")
            smooth_c2ws = raw_c2ws
        return smooth_c2ws

    # get rendered trj
    if trj_mode == "original":
        tgt_c2w = cam2world
        tgt_intr = intr_normed
    elif trj_mode == "smooth":
        tgt_c2w = _smooth_trj_fn_batch(cam2world)
        tgt_intr = intr_normed
    elif trj_mode in ["interpolate", "interpolate_smooth", "extend"]:
        n_segments = cam2world.shape[1] - 1
        inter_len = 8
        total_len = n_segments * inter_len
        # NOTE(review): the threshold is 24 * 18 frames ("no more than 18s") but the cap
        # below targets 24 * 10 frames -- confirm which duration is intended.
        if total_len > 24 * 18:  # no more than 18s
            inter_len = max(1, 24 * 10 // n_segments)
        if total_len < 24 * 2:  # no less than 2s
            inter_len = max(1, 24 * 2 // n_segments)

        if inter_len > 2:
            t = torch.linspace(0, 1, inter_len, dtype=torch.float32, device=cam2world.device)
            # Cosine ease-in/ease-out remapping of [0, 1] onto [0, 1].
            t = (torch.cos(torch.pi * (t + 1)) + 1) / 2
            tgt_c2w_b = []
            tgt_intr_b = []
            for b_idx in range(cam2world.shape[0]):
                seg_c2ws = []
                seg_intrs = []
                for cur_idx in range(n_segments):
                    # Every segment after the first drops its first sample, which
                    # repeats the end pose of the previous segment.
                    first = 0 if cur_idx == 0 else 1
                    seg_c2ws.append(
                        interpolate_extrinsics(
                            cam2world[b_idx, cur_idx], cam2world[b_idx, cur_idx + 1], t
                        )[first:]
                    )
                    seg_intrs.append(
                        interpolate_intrinsics(
                            intr_normed[b_idx, cur_idx], intr_normed[b_idx, cur_idx + 1], t
                        )[first:]
                    )
                tgt_c2w_b.append(torch.cat(seg_c2ws))
                tgt_intr_b.append(torch.cat(seg_intrs))
            tgt_c2w = torch.stack(tgt_c2w_b)  # b v 4 4
            tgt_intr = torch.stack(tgt_intr_b)  # b v 3 3
        else:
            # Too few in-between frames to be worth interpolating: keep the inputs.
            tgt_c2w = cam2world
            tgt_intr = intr_normed
        if trj_mode in ["interpolate_smooth", "extend"]:
            tgt_c2w = _smooth_trj_fn_batch(tgt_c2w)
        if trj_mode == "extend":
            # apply dolly_zoom and wander in the middle frame
            if cam2world.shape[0] != 1:
                raise AssertionError("extend only supports for batch_size=1 currently.")
            mid_idx = tgt_c2w.shape[1] // 2
            n_extra_frames = max(36, min(60, mid_idx // 2))
            c2w_wd, intr_wd = render_wander_path(
                tgt_c2w[0, mid_idx],
                tgt_intr[0, mid_idx],
                h=in_h,
                w=in_w,
                num_frames=n_extra_frames,
                max_disp=24.0,
            )
            c2w_dz, intr_dz = render_dolly_zoom_path(
                tgt_c2w[0, mid_idx],
                tgt_intr[0, mid_idx],
                h=in_h,
                w=in_w,
                num_frames=n_extra_frames,
            )
            # Splice: first half, wander, dolly zoom, second half.
            tgt_c2w = torch.cat(
                [
                    tgt_c2w[:, :mid_idx],
                    c2w_wd.unsqueeze(0),
                    c2w_dz.unsqueeze(0),
                    tgt_c2w[:, mid_idx:],
                ],
                dim=1,
            )
            tgt_intr = torch.cat(
                [
                    tgt_intr[:, :mid_idx],
                    intr_wd.unsqueeze(0),
                    intr_dz.unsqueeze(0),
                    tgt_intr[:, mid_idx:],
                ],
                dim=1,
            )
    elif trj_mode in ["wander", "dolly_zoom"]:
        if trj_mode == "wander":
            render_fn = render_wander_path
            extra_kwargs = {"max_disp": 24.0}
        else:
            render_fn = render_dolly_zoom_path
            extra_kwargs = {"D_focus": 30.0, "max_disp": 2.0}
        # The path of each batch element is generated from its first camera only.
        c2w_per_batch = []
        intr_per_batch = []
        for b_idx in range(cam2world.shape[0]):
            c2w_i, intr_i = render_fn(
                cam2world[b_idx, 0], intr_normed[b_idx, 0], h=in_h, w=in_w, **extra_kwargs
            )
            c2w_per_batch.append(c2w_i)
            intr_per_batch.append(intr_i)
        tgt_c2w = torch.stack(c2w_per_batch)
        tgt_intr = torch.stack(intr_per_batch)
    elif trj_mode == "wobble_inter":
        tgt_c2w, tgt_intr = render_wobble_inter_path(
            cam2world=cam2world,
            intr_normed=intr_normed,
            inter_len=10,
            n_skip=3,
        )
    else:
        raise Exception(f"trj mode [{trj_mode}] is not implemented.")

    _, v = tgt_c2w.shape[:2]
    tgt_extr = affine_inverse(tgt_c2w)  # back to world-to-camera for the renderer
    if chunk_size is None:
        chunk_size = v
    chunk_size = min(v, chunk_size)
    all_colors = []
    all_depths = []
    for chunk_idx in tqdm(
        range(math.ceil(v / chunk_size)),
        desc="Rendering novel views",
        disable=(not enable_tqdm),
        leave=False,
    ):
        s = chunk_idx * chunk_size
        e = s + chunk_size
        extr_chunk = tgt_extr[:, s:e]
        intr_chunk = tgt_intr[:, s:e]
        cur_n_view = extr_chunk.shape[1]  # the last chunk may be shorter
        color, depth = render_3dgs(
            extrinsics=rearrange(extr_chunk, "b v ... -> (b v) ..."),  # w2c
            intrinsics=rearrange(intr_chunk, "b v ... -> (b v) ..."),  # normed
            image_shape=image_shape,
            gaussian=gaussians,
            num_view=cur_n_view,
            **kwargs,
        )
        all_colors.append(rearrange(color, "(b v) ... -> b v ...", v=cur_n_view))
        all_depths.append(rearrange(depth, "(b v) ... -> b v ...", v=cur_n_view))

    # Stitch the chunks back together along the view dimension.
    return torch.cat(all_colors, dim=1), torch.cat(all_depths, dim=1)
