/*
 * Copyright (c) 2017-2018, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef CONCURRENT_UNORDERED_MAP_CUH
#define CONCURRENT_UNORDERED_MAP_CUH

#include <cudf/detail/nvtx/ranges.hpp>
#include <hash/hash_allocator.cuh>
#include <hash/helper_functions.cuh>
#include <hash/managed.cuh>
#include <cudf/detail/utilities/hash_functions.cuh>
#include <cudf/detail/utilities/device_atomics.cuh>
#include <cudf/utilities/error.hpp>

#include <thrust/pair.h>
#include <thrust/count.h>

#include <cassert>
#include <functional>
#include <iostream>
#include <iterator>
#include <limits>
#include <memory>
#include <type_traits>
#include <vector>

namespace {
template <std::size_t N>
struct packed {
  using type = void;
};
template <>
struct packed<sizeof(uint64_t)> {
  using type = uint64_t;
};
template <>
struct packed<sizeof(uint32_t)> {
  using type = uint32_t;
};
template <typename pair_type>
using packed_t = typename packed<sizeof(pair_type)>::type;

/**
 * @brief Indicates if a pair type can be packed.
 *
 * When the size of the key,value pair being inserted into the hash table is
 * equal in size to a type where atomicCAS is natively supported, it is more
 * efficient to "pack" the pair and insert it with a single atomicCAS.
 *
 * @note Only integral key and value types may be packed because we use
 * bitwise equality comparison, which may not be valid for non-integral
 * types.
 *
 * @tparam pair_type The pair type in question
 * @return true If the pair type can be packed
 * @return false  If the pair type cannot be packed
 **/
template <typename pair_type,
          typename key_type   = typename pair_type::first_type,
          typename value_type = typename pair_type::second_type>
constexpr bool is_packable()
{
  return std::is_integral<key_type>::value and std::is_integral<value_type>::value and
         not std::is_void<packed_t<pair_type>>::value;
}

/**
 * @brief Allows viewing a pair in a packed representation
 *
 * Used as an optimization for inserting when a pair can be inserted with a
 * single atomicCAS
 **/
template <typename pair_type, typename Enable = void>
union pair_packer;

template <typename pair_type>
union pair_packer<pair_type, std::enable_if_t<is_packable<pair_type>()>> {
  using packed_type = packed_t<pair_type>;
  packed_type const packed;
  pair_type const pair;

  __device__ pair_packer(pair_type _pair) : pair{_pair} {}

  __device__ pair_packer(packed_type _packed) : packed{_packed} {}
};
}  // namespace

template <typename Key, typename Element, typename Equality> struct _is_used {
  using value_type = thrust::pair<Key, Element>;

  _is_used(Key const &unused, Equality const &equal)
      : m_unused_key(unused), m_equal(equal) {}

  __host__ __device__ bool operator()(value_type const &x) {
    return !m_equal(x.first, m_unused_key);
  }

  Key const m_unused_key;
  Equality const m_equal;
};

/**
 * Supports concurrent insert, but not concurrent insert and find.
 *
 * @note The user is responsible for the following stream semantics:
 * - Either the same stream should be used to create the map as is used by the kernels that access
 * it, or
 * - the stream used to create the map should be synchronized before it is accessed from a different
 * stream or from host code.
 *
 * TODO:
 *  - add constructor that takes pointer to hash_table to avoid allocations
 */
template <typename Key,
          typename Element,
          typename Hasher    = default_hash<Key>,
          typename Equality  = equal_to<Key>,
          typename Allocator = default_allocator<thrust::pair<Key, Element>>>
class concurrent_unordered_map {
 public:
  using size_type      = size_t;
  using hasher         = Hasher;
  using key_equal      = Equality;
  using allocator_type = Allocator;
  using key_type       = Key;
  using mapped_type    = Element;
  using value_type     = thrust::pair<Key, Element>;
  using iterator       = cycle_iterator_adapter<value_type*>;
  using const_iterator = const cycle_iterator_adapter<value_type*>;

 public:
  /**
   * @brief Factory to construct a new concurrent unordered map.
   *
   * Returns a `std::unique_ptr` to a new concurrent unordered map object. The
   * map is non-owning and trivially copyable and should be passed by value into
   * kernels. The `unique_ptr` contains a custom deleter that will free the
   * map's contents.
   *
   * @note The implementation of this unordered_map uses sentinel values to
   * indicate an entry in the hash table that is empty, i.e., if a hash bucket
   * is empty, the pair residing there will be equal to (unused_key,
   * unused_element). As a result, attempting to insert a key equal to
   *`unused_key` results in undefined behavior.
   *
   * @note All allocations, kernels and copies in the constructor take place
   * on stream but the constructor does not synchronize the stream. It is the user's
   * responsibility to synchronize or use the same stream to access the map.
   *
   * @param capacity The maximum number of pairs the map may hold
   * @param unused_element The sentinel value to use for an empty value
   * @param unused_key The sentinel value to use for an empty key
   * @param hash_function The hash function to use for hashing keys
   * @param equal The equality comparison function for comparing if two keys are
   * equal
   * @param allocator The allocator to use for allocation the hash table's
   * storage
   * @param stream CUDA stream to use for device operations.
   **/
  static auto create(size_type capacity,
                     const mapped_type unused_element = std::numeric_limits<mapped_type>::max(),
                     const key_type unused_key        = std::numeric_limits<key_type>::max(),
                     const Hasher& hash_function      = hasher(),
                     const Equality& equal            = key_equal(),
                     const allocator_type& allocator  = allocator_type(),
                     cudaStream_t stream              = 0)
  {
    CUDF_FUNC_RANGE();
    using Self = concurrent_unordered_map<Key, Element, Hasher, Equality, Allocator>;

    // Note: need `(*p).destroy` instead of `p->destroy` here
    // due to compiler bug: https://github.com/rapidsai/cudf/pull/5692
    auto deleter = [stream](Self* p) { (*p).destroy(stream); };

    return std::unique_ptr<Self, std::function<void(Self*)>>{
      new Self(capacity, unused_element, unused_key, hash_function, equal, allocator, stream),
      deleter};
  }

  /**
   * @brief Returns an iterator to the first element in the map
   *
   * @note `__device__` code that calls this function should either run in the
   * same stream as `create()`, or the accessing stream either be running on the
   * same stream as create(), or the accessing stream should be appropriately
   * synchronized with the creating stream.
   *
   * @returns iterator to the first element in the map.
   **/
  __device__ iterator begin()
  {
    return iterator(m_hashtbl_values, m_hashtbl_values + m_capacity, m_hashtbl_values);
  }

  /**
   * @brief Returns a constant iterator to the first element in the map
   *
   * @note `__device__` code that calls this function should either run in the
   * same stream as `create()`, or the accessing stream either be running on the
   * same stream as create(), or the accessing stream should be appropriately
   * synchronized with the creating stream.
   *
   * @returns constant iterator to the first element in the map.
   **/
  __device__ const_iterator begin() const
  {
    return const_iterator(m_hashtbl_values, m_hashtbl_values + m_capacity, m_hashtbl_values);
  }

  /**
   * @brief Returns an iterator to the one past the last element in the map
   *
   * @note `__device__` code that calls this function should either run in the
   * same stream as `create()`, or the accessing stream either be running on the
   * same stream as create(), or the accessing stream should be appropriately
   * synchronized with the creating stream.
   *
   * @returns iterator to the one past the last element in the map.
   **/
  __device__ iterator end()
  {
    return iterator(m_hashtbl_values, m_hashtbl_values + m_capacity, m_hashtbl_values + m_capacity);
  }

  /**
   * @brief Returns a constant iterator to the one past the last element in the map
   *
   * @note When called in a device code, user should make sure that it should
   * either be running on the same stream as create(), or the accessing stream
   * should be appropriately synchronized with the creating stream.
   *
   * @returns constant iterator to the one past the last element in the map.
   **/
  __device__ const_iterator end() const
  {
    return const_iterator(
      m_hashtbl_values, m_hashtbl_values + m_capacity, m_hashtbl_values + m_capacity);
  }
  __host__ __device__ inline value_type* data() const { return m_hashtbl_values; }

  __host__ __device__ inline key_type get_unused_key() const { return m_unused_key; }

  __host__ __device__ inline mapped_type get_unused_element() const { return m_unused_element; }

  __host__ __device__ inline key_equal get_key_equal() const { return m_equal; }

  __host__ __device__ inline size_type capacity() const { return m_capacity; }

 private:
  /**
   * @brief Enumeration of the possible results of attempting to insert into
   *a hash bucket
   **/
  enum class insert_result {
    CONTINUE,  ///< Insert did not succeed, continue trying to insert
               ///< (collision)
    SUCCESS,   ///< New pair inserted successfully
    DUPLICATE  ///< Insert did not succeed, key is already present
  };

  /**
   * @brief Specialization for value types that can be packed.
   *
   * When the size of the key,value pair being inserted is equal in size to
   *a type where atomicCAS is natively supported, this optimization path
   *will insert the pair in a single atomicCAS operation.
   **/
  template <typename pair_type = value_type>
  __device__ std::enable_if_t<is_packable<pair_type>(), insert_result> attempt_insert(
    value_type* insert_location, value_type const& insert_pair)
  {
    pair_packer<pair_type> const unused{thrust::make_pair(m_unused_key, m_unused_element)};
    pair_packer<pair_type> const new_pair{insert_pair};
    pair_packer<pair_type> const old{
      atomicCAS(reinterpret_cast<typename pair_packer<pair_type>::packed_type*>(insert_location),
                unused.packed,
                new_pair.packed)};

    if (old.packed == unused.packed) { return insert_result::SUCCESS; }

    if (m_equal(old.pair.first, insert_pair.first)) { return insert_result::DUPLICATE; }
    return insert_result::CONTINUE;
  }

  /**
   * @brief Atempts to insert a key,value pair at the specified hash bucket.
   *
   * @param[in] insert_location Pointer to hash bucket to attempt insert
   * @param[in] insert_pair The pair to insert
   * @return Enum indicating result of insert attempt.
   **/
  template <typename pair_type = value_type>
  __device__ std::enable_if_t<not is_packable<pair_type>(), insert_result> attempt_insert(
    value_type* const __restrict__ insert_location, value_type const& insert_pair)
  {
    key_type const old_key{atomicCAS(&(insert_location->first), m_unused_key, insert_pair.first)};

    // Hash bucket empty
    if (m_equal(m_unused_key, old_key)) {
      insert_location->second = insert_pair.second;
      return insert_result::SUCCESS;
    }

    // Key already exists
    if (m_equal(old_key, insert_pair.first)) { return insert_result::DUPLICATE; }

    return insert_result::CONTINUE;
  }

 public:
  /**
   * @brief Attempts to insert a key, value pair into the map.
   *
   * Returns an iterator, boolean pair.
   *
   * If the new key already present in the map, the iterator points to
   * the location of the existing key and the boolean is `false` indicating
   * that the insert did not succeed.
   *
   * If the new key was not present, the iterator points to the location
   *where the insert occured and the boolean is `true` indicating that the
   *insert succeeded.
   *
   * @param insert_pair The key and value pair to insert
   * @return Iterator, Boolean pair. Iterator is to the location of the
   *newly inserted pair, or the existing pair that prevented the insert.
   *Boolean indicates insert success.
   **/
  __device__ thrust::pair<iterator, bool> insert(value_type const& insert_pair)
  {
    const size_type key_hash{m_hf(insert_pair.first)};
    size_type index{key_hash % m_capacity};

    insert_result status{insert_result::CONTINUE};

    value_type* current_bucket{nullptr};

    while (status == insert_result::CONTINUE) {
      current_bucket = &m_hashtbl_values[index];
      status         = attempt_insert(current_bucket, insert_pair);
      index          = (index + 1) % m_capacity;
    }

    bool const insert_success = (status == insert_result::SUCCESS) ? true : false;

    return thrust::make_pair(
      iterator(m_hashtbl_values, m_hashtbl_values + m_capacity, current_bucket), insert_success);
  }

  /**
   * @brief Searches the map for the specified key.
   *
   * @note `find` is not threadsafe with `insert`. I.e., it is not safe to
   *do concurrent `insert` and `find` operations.
   *
   * @param k The key to search for
   * @return An iterator to the key if it exists, else map.end()
   **/
  __device__ const_iterator find(key_type const& k) const
  {
    size_type const key_hash = m_hf(k);
    size_type index          = key_hash % m_capacity;

    value_type* current_bucket = &m_hashtbl_values[index];

    while (true) {
      key_type const existing_key = current_bucket->first;

      if (m_equal(m_unused_key, existing_key)) { return this->end(); }

      if (m_equal(k, existing_key)) {
        return const_iterator(m_hashtbl_values, m_hashtbl_values + m_capacity, current_bucket);
      }

      index          = (index + 1) % m_capacity;
      current_bucket = &m_hashtbl_values[index];
    }
  }

  /**
   * @brief Searches the map for the specified key.
   *
   * @note `find` is not threadsafe with `insert`. I.e., it is not safe to
   *do concurrent `insert` and `find` operations.
   *
   * @param k The key to search for
   * @return An iterator to the key if it exists, else map.end()
   **/
  __device__ iterator find(key_type const& k)
  {
    size_type const key_hash = m_hf(k);
    size_type index          = key_hash % m_capacity;

    value_type* current_bucket = &m_hashtbl_values[index];

    while (true) {
      key_type const existing_key = current_bucket->first;

      if (m_equal(m_unused_key, existing_key)) { return this->end(); }

      if (m_equal(k, existing_key)) {
        return iterator(m_hashtbl_values, m_hashtbl_values + m_capacity, current_bucket);
      }

      index          = (index + 1) % m_capacity;
      current_bucket = &m_hashtbl_values[index];
    }
  }

  /**
   * @brief Searches the map for the specified key.
   *
   * This version of the find function specifies a hashing function and an
   * equality comparison.  This allows the caller to use different functions
   * for insert and find (for example, when you want to insert keys from
   * one table and use find to match keys from a different table with the
   * keys from the first table).
   *
   * @note `find` is not threadsafe with `insert`. I.e., it is not safe to
   * do concurrent `insert` and `find` operations.
   *
   * @tparam find_hasher     Type of hashing function
   * @tparam find_key_equal  Type of equality comparison
   *
   * @param k         The key to search for
   * @param f_hash    The hashing function to use to hash this key
   * @param f_equal   The equality function to use to compare this key with the
   *                  contents of the hash table
   * @return An iterator to the key if it exists, else map.end()
   **/
  template <typename find_hasher, typename find_key_equal>
  __device__ const_iterator find(key_type const& k,
                                 find_hasher f_hash,
                                 find_key_equal f_equal) const
  {
    size_type const key_hash = f_hash(k);
    size_type index          = key_hash % m_capacity;

    value_type* current_bucket = &m_hashtbl_values[index];

    while (true) {
      key_type const existing_key = current_bucket->first;

      if (m_equal(m_unused_key, existing_key)) { return this->end(); }

      if (f_equal(k, existing_key)) {
        return const_iterator(m_hashtbl_values, m_hashtbl_values + m_capacity, current_bucket);
      }

      index          = (index + 1) % m_capacity;
      current_bucket = &m_hashtbl_values[index];
    }
  }

  gdf_error assign_async(const concurrent_unordered_map& other, cudaStream_t stream = 0)
  {
    if (other.m_capacity <= m_capacity) {
      m_capacity = other.m_capacity;
    } else {
      m_allocator.deallocate(m_hashtbl_values, m_capacity, stream);
      m_capacity = other.m_capacity;
      m_capacity = other.m_capacity;

      m_hashtbl_values = m_allocator.allocate(m_capacity, stream);
    }
    CUDA_TRY(cudaMemcpyAsync(m_hashtbl_values,
                             other.m_hashtbl_values,
                             m_capacity * sizeof(value_type),
                             cudaMemcpyDefault,
                             stream));
    return GDF_SUCCESS;
  }

  void clear_async(cudaStream_t stream = 0)
  {
    constexpr int block_size = 128;
    init_hashtbl<<<((m_capacity + block_size - 1) / block_size), block_size, 0, stream>>>(
      m_hashtbl_values, m_capacity, m_unused_key, m_unused_element);
  }

  void print() const
  {
    for (size_type i = 0; i < m_capacity; ++i) {
      std::cout << i << ": " << m_hashtbl_values[i].first << "," << m_hashtbl_values[i].second
                << std::endl;
    }
  }

  size_t size() const
  {
    return thrust::count_if(thrust::device, m_hashtbl_values, m_hashtbl_values + m_capacity,
                            _is_used<Key, Element, Equality>(get_unused_key(), get_key_equal()));
  }

  gdf_error prefetch(const int dev_id, cudaStream_t stream = 0)
  {
    cudaPointerAttributes hashtbl_values_ptr_attributes;
    cudaError_t status = cudaPointerGetAttributes(&hashtbl_values_ptr_attributes, m_hashtbl_values);

    if (cudaSuccess == status && isPtrManaged(hashtbl_values_ptr_attributes)) {
      CUDA_TRY(
        cudaMemPrefetchAsync(m_hashtbl_values, m_capacity * sizeof(value_type), dev_id, stream));
    }
    CUDA_TRY(cudaMemPrefetchAsync(this, sizeof(*this), dev_id, stream));

    return GDF_SUCCESS;
  }

  /**
   * @brief Frees the contents of the map and destroys the map object.
   *
   * This function is invoked as the deleter of the `std::unique_ptr` returned
   * from the `create()` factory function.
   *
   * @param stream CUDA stream to use for device operations.
   **/
  void destroy(cudaStream_t stream = 0)
  {
    m_allocator.deallocate(m_hashtbl_values, m_capacity, stream);
    delete this;
  }

  concurrent_unordered_map()                                = delete;
  concurrent_unordered_map(concurrent_unordered_map const&) = default;
  concurrent_unordered_map(concurrent_unordered_map&&)      = default;
  concurrent_unordered_map& operator=(concurrent_unordered_map const&) = default;
  concurrent_unordered_map& operator=(concurrent_unordered_map&&) = default;
  ~concurrent_unordered_map()                                     = default;

 private:
  hasher m_hf;
  key_equal m_equal;
  mapped_type m_unused_element;
  key_type m_unused_key;
  allocator_type m_allocator;
  size_type m_capacity;
  value_type* m_hashtbl_values;

  /**
   * @brief Private constructor used by `create` factory function.
   *
   * @param capacity The desired m_capacity of the hash table
   * @param unused_element The sentinel value to use for an empty value
   * @param unused_key The sentinel value to use for an empty key
   * @param hash_function The hash function to use for hashing keys
   * @param equal The equality comparison function for comparing if two keys
   *are equal
   * @param allocator The allocator to use for allocation the hash table's
   * storage
   * @param stream CUDA stream to use for device operations.
   **/
  concurrent_unordered_map(size_type capacity,
                           const mapped_type unused_element,
                           const key_type unused_key,
                           const Hasher& hash_function,
                           const Equality& equal,
                           const allocator_type& allocator,
                           cudaStream_t stream = 0)
    : m_hf(hash_function),
      m_equal(equal),
      m_allocator(allocator),
      m_capacity(capacity),
      m_unused_element(unused_element),
      m_unused_key(unused_key)
  {
    m_hashtbl_values         = m_allocator.allocate(m_capacity, stream);
    constexpr int block_size = 128;
    {
      cudaPointerAttributes hashtbl_values_ptr_attributes;
      cudaError_t status =
        cudaPointerGetAttributes(&hashtbl_values_ptr_attributes, m_hashtbl_values);

      if (cudaSuccess == status && isPtrManaged(hashtbl_values_ptr_attributes)) {
        int dev_id = 0;
        CUDA_TRY(cudaGetDevice(&dev_id));
        CUDA_TRY(
          cudaMemPrefetchAsync(m_hashtbl_values, m_capacity * sizeof(value_type), dev_id, stream));
      }
    }

    init_hashtbl<<<((m_capacity + block_size - 1) / block_size), block_size, 0, stream>>>(
      m_hashtbl_values, m_capacity, m_unused_key, m_unused_element);
    CUDA_TRY(cudaGetLastError());
  }
};

#endif  // CONCURRENT_UNORDERED_MAP_CUH
