// ----------------------------------------------------------------------------
// -                        Open3D: www.open3d.org                            -
// ----------------------------------------------------------------------------
// Copyright (c) 2018-2023 www.open3d.org
// SPDX-License-Identifier: MIT
// ----------------------------------------------------------------------------

#pragma once

#include "open3d/utility/Helper.isph"

// This C/C++ header is compatible with ISPC.
// Variadic macros, i.e. ellipsis (...) and __VA_ARGS__ are also supported.
#include "open3d/utility/Preprocessor.h"

/// OPEN3D_SPECIALIZED(T, ISPCFunction)
///
/// Expands to the name of the specialization of a template-like function for
/// the type T: the function name, an underscore and the type name are joined
/// via OPEN3D_CONCAT, e.g. OPEN3D_SPECIALIZED(float, Foo) names Foo_float
/// (assuming OPEN3D_CONCAT from Preprocessor.h pastes its two arguments).
#define OPEN3D_SPECIALIZED(T, ISPCFunction) \
    OPEN3D_CONCAT(OPEN3D_CONCAT(ISPCFunction, _), T)

/// OPEN3D_INSTANTIATE_TEMPLATE()
///
/// Instantiates a template-like function defined by the TEMPLATE(T) macro for
/// the following types:
/// - unsigned + signed {8,16,32,64} bit integers,
/// - float, double
///
/// TEMPLATE(T) must be defined at the point where this macro is expanded. It
/// is expanded once per type: first for the integer types (see
/// OPEN3D_INSTANTIATE_INTEGER_TEMPLATE()), then for the floating point types
/// (see OPEN3D_INSTANTIATE_FLOAT_TEMPLATE()).
///
/// Example:
///
/// \code
/// #define TEMPLATE(T)                                         \
///     static inline void OPEN3D_SPECIALIZED(T, TemplateFunc)( \
///             int64_t idx, /* more parameters */) {           \
///         /* Do something ... */                              \
///     }
///
/// OPEN3D_INSTANTIATE_TEMPLATE()
///
/// #undef TEMPLATE
/// \endcode
///
/// Note that every line of a multi-line TEMPLATE(T) body except the last one
/// needs a trailing backslash, and that comments inside the body must use the
/// block form, since a line comment would swallow the continuation.
#define OPEN3D_INSTANTIATE_TEMPLATE()     \
    OPEN3D_INSTANTIATE_INTEGER_TEMPLATE() \
    OPEN3D_INSTANTIATE_FLOAT_TEMPLATE()

/// OPEN3D_INSTANTIATE_TEMPLATE_WITH_BOOL()
///
/// Expands the TEMPLATE(T) macro for bool and then delegates to
/// OPEN3D_INSTANTIATE_TEMPLATE(), so the template-like function is
/// instantiated for:
/// - bool
/// - unsigned + signed {8,16,32,64} bit integers
/// - float, double
///
/// See OPEN3D_INSTANTIATE_TEMPLATE() for a usage example.
#define OPEN3D_INSTANTIATE_TEMPLATE_WITH_BOOL() \
    TEMPLATE(bool) OPEN3D_INSTANTIATE_TEMPLATE()

/// OPEN3D_INSTANTIATE_INTEGER_TEMPLATE()
///
/// Expands the TEMPLATE(T) macro once for each fixed-width integer type, i.e.
/// the unsigned and the signed type of 8, 16, 32 and 64 bits. Within each bit
/// width the unsigned type comes first.
///
/// See OPEN3D_INSTANTIATE_TEMPLATE() for more details.
#define OPEN3D_INSTANTIATE_INTEGER_TEMPLATE() \
    TEMPLATE(uint8_t) TEMPLATE(int8_t)        \
    TEMPLATE(uint16_t) TEMPLATE(int16_t)      \
    TEMPLATE(uint32_t) TEMPLATE(int32_t)      \
    TEMPLATE(uint64_t) TEMPLATE(int64_t)

/// OPEN3D_INSTANTIATE_INTEGER_TEMPLATE_WITH_BOOL()
///
/// Expands the TEMPLATE(T) macro for bool and then delegates to
/// OPEN3D_INSTANTIATE_INTEGER_TEMPLATE(), so the template-like function is
/// instantiated for:
/// - bool
/// - unsigned + signed {8,16,32,64} bit integers
///
/// See OPEN3D_INSTANTIATE_TEMPLATE() for more details.
#define OPEN3D_INSTANTIATE_INTEGER_TEMPLATE_WITH_BOOL() \
    TEMPLATE(bool) OPEN3D_INSTANTIATE_INTEGER_TEMPLATE()

/// OPEN3D_INSTANTIATE_FLOAT_TEMPLATE()
///
/// Expands the TEMPLATE(T) macro once for float and once for double, in that
/// order.
///
/// See OPEN3D_INSTANTIATE_TEMPLATE() for more details.
#define OPEN3D_INSTANTIATE_FLOAT_TEMPLATE() TEMPLATE(float) TEMPLATE(double)

/// OPEN3D_EXPORT_VECTORIZED(ISPCKernel, ISPCFunction, ...)
///
/// Defines a kernel named ISPCKernel which calls the function ISPCFunction.
/// The variadic arguments are forwarded, together with the kernel and the
/// function name, to one of the OPEN3D_EXPORT_VECTORIZED_<N> macros. The
/// variant is picked by OPEN3D_OVERLOAD (presumably from the number of
/// variadic arguments; see Preprocessor.h).
///
/// Use the OPEN3D_VECTORIZED macro to call the kernel in the C++
/// source file.
#define OPEN3D_EXPORT_VECTORIZED(ISPCKernel, ISPCFunction, ...) \
    OPEN3D_OVERLOAD(OPEN3D_EXPORT_VECTORIZED_, __VA_ARGS__)(    \
            ISPCKernel, ISPCFunction, __VA_ARGS__)

/// Internal helper macro.
///
/// Exports the kernel for a single type T: both the kernel name and the
/// function name are resolved to their specialization for T through
/// OPEN3D_SPECIALIZED and then passed, along with the variadic arguments, to
/// OPEN3D_EXPORT_VECTORIZED.
#define OPEN3D_EXPORT_OVERLOADED_(T, ISPCKernel, ISPCFunction, ...) \
    OPEN3D_EXPORT_VECTORIZED(                                       \
            OPEN3D_SPECIALIZED(T, ISPCKernel),                      \
            OPEN3D_SPECIALIZED(T, ISPCFunction), __VA_ARGS__)

/// OPEN3D_EXPORT_TEMPLATE_VECTORIZED(ISPCKernel, ISPCFunction, ...)
///
/// Defines one kernel per supported type, each calling the specialization of
/// the provided template-like function for that type. The kernels are emitted
/// through OPEN3D_EXPORT_OVERLOADED_ for bool, the unsigned + signed
/// {8,16,32,64} bit integers, float and double.
///
/// Use the OPEN3D_TEMPLATE_VECTORIZED macro to call the kernel in the
/// C++ source file.
///
/// A specialization of the function is referenced for every type listed
/// above, including bool. To define the function, use either
/// - OPEN3D_INSTANTIATE_TEMPLATE_WITH_BOOL()
/// - OPEN3D_INSTANTIATE_TEMPLATE() + custom bool specialization
#define OPEN3D_EXPORT_TEMPLATE_VECTORIZED(ISPCKernel, ISPCFunction, ...) \
    OPEN3D_EXPORT_OVERLOADED_(bool, ISPCKernel, ISPCFunction, __VA_ARGS__) \
    OPEN3D_EXPORT_OVERLOADED_(uint8_t, ISPCKernel, ISPCFunction, __VA_ARGS__) \
    OPEN3D_EXPORT_OVERLOADED_(int8_t, ISPCKernel, ISPCFunction, __VA_ARGS__) \
    OPEN3D_EXPORT_OVERLOADED_(uint16_t, ISPCKernel, ISPCFunction, __VA_ARGS__) \
    OPEN3D_EXPORT_OVERLOADED_(int16_t, ISPCKernel, ISPCFunction, __VA_ARGS__) \
    OPEN3D_EXPORT_OVERLOADED_(uint32_t, ISPCKernel, ISPCFunction, __VA_ARGS__) \
    OPEN3D_EXPORT_OVERLOADED_(int32_t, ISPCKernel, ISPCFunction, __VA_ARGS__) \
    OPEN3D_EXPORT_OVERLOADED_(uint64_t, ISPCKernel, ISPCFunction, __VA_ARGS__) \
    OPEN3D_EXPORT_OVERLOADED_(int64_t, ISPCKernel, ISPCFunction, __VA_ARGS__) \
    OPEN3D_EXPORT_OVERLOADED_(float, ISPCKernel, ISPCFunction, __VA_ARGS__) \
    OPEN3D_EXPORT_OVERLOADED_(double, ISPCKernel, ISPCFunction, __VA_ARGS__)

/// OPEN3D_EXPORT_VECTORIZED(...) variant for one forwarded argument.
///
/// Defines the exported kernel ISPCKernel(start, end, arg1). The kernel runs
/// an ISPC foreach loop over the index range start ... end and calls
/// ISPCFunction(idx, arg1) for each index. type1 is the declared type of the
/// kernel parameter arg1.
#define OPEN3D_EXPORT_VECTORIZED_1(ISPCKernel, ISPCFunction, type1) \
    export void ISPCKernel(uniform int64_t start, \
                           uniform int64_t end, \
                           type1 arg1) { \
        foreach (idx = start ... end) { \
            ISPCFunction(idx, arg1); \
        } \
    }

/// OPEN3D_EXPORT_VECTORIZED(...) variant for two forwarded arguments.
///
/// Defines the exported kernel ISPCKernel(start, end, arg1, arg2). The kernel
/// runs an ISPC foreach loop over the index range start ... end and calls
/// ISPCFunction(idx, arg1, arg2) for each index. typeN is the declared type
/// of the kernel parameter argN.
#define OPEN3D_EXPORT_VECTORIZED_2(ISPCKernel, ISPCFunction, type1, type2) \
    export void ISPCKernel(uniform int64_t start, \
                           uniform int64_t end, \
                           type1 arg1, \
                           type2 arg2) { \
        foreach (idx = start ... end) { \
            ISPCFunction(idx, arg1, arg2); \
        } \
    }

/// OPEN3D_EXPORT_VECTORIZED(...) variant for three forwarded arguments.
///
/// Defines the exported kernel ISPCKernel(start, end, arg1, ..., arg3). The
/// kernel runs an ISPC foreach loop over the index range start ... end and
/// calls ISPCFunction(idx, arg1, ..., arg3) for each index. typeN is the
/// declared type of the kernel parameter argN.
#define OPEN3D_EXPORT_VECTORIZED_3( \
        ISPCKernel, ISPCFunction, type1, type2, type3) \
    export void ISPCKernel(uniform int64_t start, \
                           uniform int64_t end, \
                           type1 arg1, \
                           type2 arg2, \
                           type3 arg3) { \
        foreach (idx = start ... end) { \
            ISPCFunction(idx, arg1, arg2, arg3); \
        } \
    }

/// OPEN3D_EXPORT_VECTORIZED(...) variant for four forwarded arguments.
///
/// Defines the exported kernel ISPCKernel(start, end, arg1, ..., arg4). The
/// kernel runs an ISPC foreach loop over the index range start ... end and
/// calls ISPCFunction(idx, arg1, ..., arg4) for each index. typeN is the
/// declared type of the kernel parameter argN.
#define OPEN3D_EXPORT_VECTORIZED_4( \
        ISPCKernel, ISPCFunction, type1, type2, type3, type4) \
    export void ISPCKernel(uniform int64_t start, \
                           uniform int64_t end, \
                           type1 arg1, \
                           type2 arg2, \
                           type3 arg3, \
                           type4 arg4) { \
        foreach (idx = start ... end) { \
            ISPCFunction(idx, arg1, arg2, arg3, arg4); \
        } \
    }

/// OPEN3D_EXPORT_VECTORIZED(...) variant for five forwarded arguments.
///
/// Defines the exported kernel ISPCKernel(start, end, arg1, ..., arg5). The
/// kernel runs an ISPC foreach loop over the index range start ... end and
/// calls ISPCFunction(idx, arg1, ..., arg5) for each index. typeN is the
/// declared type of the kernel parameter argN.
#define OPEN3D_EXPORT_VECTORIZED_5( \
        ISPCKernel, ISPCFunction, type1, type2, type3, type4, type5) \
    export void ISPCKernel(uniform int64_t start, \
                           uniform int64_t end, \
                           type1 arg1, \
                           type2 arg2, \
                           type3 arg3, \
                           type4 arg4, \
                           type5 arg5) { \
        foreach (idx = start ... end) { \
            ISPCFunction(idx, arg1, arg2, arg3, arg4, arg5); \
        } \
    }

/// OPEN3D_EXPORT_VECTORIZED(...) variant for six forwarded arguments.
///
/// Defines the exported kernel ISPCKernel(start, end, arg1, ..., arg6). The
/// kernel runs an ISPC foreach loop over the index range start ... end and
/// calls ISPCFunction(idx, arg1, ..., arg6) for each index. typeN is the
/// declared type of the kernel parameter argN.
#define OPEN3D_EXPORT_VECTORIZED_6( \
        ISPCKernel, ISPCFunction, type1, type2, type3, type4, type5, type6) \
    export void ISPCKernel(uniform int64_t start, \
                           uniform int64_t end, \
                           type1 arg1, \
                           type2 arg2, \
                           type3 arg3, \
                           type4 arg4, \
                           type5 arg5, \
                           type6 arg6) { \
        foreach (idx = start ... end) { \
            ISPCFunction(idx, arg1, arg2, arg3, arg4, arg5, arg6); \
        } \
    }

/// OPEN3D_EXPORT_VECTORIZED(...) variant for seven forwarded arguments.
///
/// Defines the exported kernel ISPCKernel(start, end, arg1, ..., arg7). The
/// kernel runs an ISPC foreach loop over the index range start ... end and
/// calls ISPCFunction(idx, arg1, ..., arg7) for each index. typeN is the
/// declared type of the kernel parameter argN.
#define OPEN3D_EXPORT_VECTORIZED_7( \
        ISPCKernel, ISPCFunction, \
        type1, type2, type3, type4, type5, type6, type7) \
    export void ISPCKernel(uniform int64_t start, \
                           uniform int64_t end, \
                           type1 arg1, \
                           type2 arg2, \
                           type3 arg3, \
                           type4 arg4, \
                           type5 arg5, \
                           type6 arg6, \
                           type7 arg7) { \
        foreach (idx = start ... end) { \
            ISPCFunction(idx, arg1, arg2, arg3, arg4, arg5, arg6, arg7); \
        } \
    }

/// OPEN3D_EXPORT_VECTORIZED(...) variant for eight forwarded arguments.
///
/// Defines the exported kernel ISPCKernel(start, end, arg1, ..., arg8). The
/// kernel runs an ISPC foreach loop over the index range start ... end and
/// calls ISPCFunction(idx, arg1, ..., arg8) for each index. typeN is the
/// declared type of the kernel parameter argN.
#define OPEN3D_EXPORT_VECTORIZED_8( \
        ISPCKernel, ISPCFunction, \
        type1, type2, type3, type4, type5, type6, type7, type8) \
    export void ISPCKernel(uniform int64_t start, \
                           uniform int64_t end, \
                           type1 arg1, \
                           type2 arg2, \
                           type3 arg3, \
                           type4 arg4, \
                           type5 arg5, \
                           type6 arg6, \
                           type7 arg7, \
                           type8 arg8) { \
        foreach (idx = start ... end) { \
            ISPCFunction(idx, arg1, arg2, arg3, arg4, \
                         arg5, arg6, arg7, arg8); \
        } \
    }

/// OPEN3D_EXPORT_VECTORIZED(...) variant for nine forwarded arguments.
///
/// Defines the exported kernel ISPCKernel(start, end, arg1, ..., arg9). The
/// kernel runs an ISPC foreach loop over the index range start ... end and
/// calls ISPCFunction(idx, arg1, ..., arg9) for each index. typeN is the
/// declared type of the kernel parameter argN.
#define OPEN3D_EXPORT_VECTORIZED_9( \
        ISPCKernel, ISPCFunction, \
        type1, type2, type3, type4, type5, type6, type7, type8, type9) \
    export void ISPCKernel(uniform int64_t start, \
                           uniform int64_t end, \
                           type1 arg1, \
                           type2 arg2, \
                           type3 arg3, \
                           type4 arg4, \
                           type5 arg5, \
                           type6 arg6, \
                           type7 arg7, \
                           type8 arg8, \
                           type9 arg9) { \
        foreach (idx = start ... end) { \
            ISPCFunction(idx, arg1, arg2, arg3, arg4, arg5, \
                         arg6, arg7, arg8, arg9); \
        } \
    }
