//@HEADER // ************************************************************************ // // Kokkos v. 4.0 // Copyright (2022) National Technology & Engineering // Solutions of Sandia, LLC (NTESS). // // Under the terms of Contract DE-NA0003525 with NTESS, // the U.S. Government retains certain rights in this software. // // Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions. // See https://kokkos.org/LICENSE for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //@HEADER #ifndef KOKKOS_SIMD_SCALAR_HPP #define KOKKOS_SIMD_SCALAR_HPP #include #include #include #include #ifdef KOKKOS_SIMD_COMMON_MATH_HPP #error \ "Kokkos_SIMD_Scalar.hpp must be included before Kokkos_SIMD_Common_Math.hpp!" #endif namespace Kokkos { namespace Experimental { namespace simd_abi { class scalar {}; } // namespace simd_abi template class simd_mask { bool m_value; public: using value_type = bool; using simd_type = simd; using abi_type = simd_abi::scalar; using reference = value_type&; KOKKOS_DEFAULTED_FUNCTION simd_mask() = default; KOKKOS_FORCEINLINE_FUNCTION static constexpr std::size_t size() { return 1; } KOKKOS_FORCEINLINE_FUNCTION explicit simd_mask(value_type value) : m_value(value) {} template < class G, std::enable_if_t>, bool> = false> KOKKOS_FORCEINLINE_FUNCTION constexpr explicit simd_mask(G&& gen) noexcept : m_value(gen(0)) {} template KOKKOS_FORCEINLINE_FUNCTION simd_mask( simd_mask const& other) : m_value(static_cast(other)) {} KOKKOS_FORCEINLINE_FUNCTION constexpr explicit operator bool() const { return m_value; } KOKKOS_FORCEINLINE_FUNCTION reference operator[](std::size_t) { return m_value; } KOKKOS_FORCEINLINE_FUNCTION value_type operator[](std::size_t) const { return m_value; } KOKKOS_FORCEINLINE_FUNCTION simd_mask operator||(simd_mask const& other) const { return simd_mask(m_value || other.m_value); } KOKKOS_FORCEINLINE_FUNCTION simd_mask operator&&(simd_mask const& other) const { return simd_mask(m_value && other.m_value); } KOKKOS_FORCEINLINE_FUNCTION simd_mask operator!() const { return simd_mask(!m_value); } KOKKOS_FORCEINLINE_FUNCTION bool operator==(simd_mask const& other) const { return m_value == other.m_value; } KOKKOS_FORCEINLINE_FUNCTION bool operator!=(simd_mask const& other) const { return m_value != other.m_value; } }; template class simd { T m_value; public: using value_type = T; using abi_type = simd_abi::scalar; using mask_type = simd_mask; using reference = value_type&; KOKKOS_DEFAULTED_FUNCTION simd() = default; KOKKOS_DEFAULTED_FUNCTION simd(simd const&) = default; KOKKOS_DEFAULTED_FUNCTION simd(simd&&) = default; KOKKOS_DEFAULTED_FUNCTION simd& operator=(simd const&) = default; KOKKOS_DEFAULTED_FUNCTION simd& operator=(simd&&) = default; KOKKOS_FORCEINLINE_FUNCTION static constexpr std::size_t size() { return 1; } template , bool> = false> KOKKOS_FORCEINLINE_FUNCTION simd(U&& value) : m_value(value) {} template , bool> = false> KOKKOS_FORCEINLINE_FUNCTION explicit simd(simd const& other) : m_value(static_cast(other)) {} template ()); } std::is_invocable_r_v>, bool> = false> KOKKOS_FORCEINLINE_FUNCTION constexpr explicit simd(G&& gen) noexcept : m_value(gen(0)) {} KOKKOS_FORCEINLINE_FUNCTION constexpr explicit operator T() const { return m_value; } KOKKOS_FORCEINLINE_FUNCTION void copy_from(T const* ptr, element_aligned_tag) { m_value = *ptr; } KOKKOS_FORCEINLINE_FUNCTION void copy_from(T const* ptr, vector_aligned_tag) { m_value = *ptr; } KOKKOS_FORCEINLINE_FUNCTION void copy_to(T* ptr, element_aligned_tag) const { *ptr = m_value; } KOKKOS_FORCEINLINE_FUNCTION void copy_to(T* ptr, vector_aligned_tag) const { *ptr = m_value; } KOKKOS_FORCEINLINE_FUNCTION reference operator[](std::size_t) { return m_value; } KOKKOS_FORCEINLINE_FUNCTION value_type operator[](std::size_t) const { return m_value; } KOKKOS_FORCEINLINE_FUNCTION simd operator-() const noexcept { return simd(-m_value); } [[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION friend constexpr simd operator*( simd const& lhs, simd const& rhs) noexcept { return simd(lhs.m_value * rhs.m_value); } [[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION friend constexpr simd operator/( simd const& lhs, simd const& rhs) noexcept { return simd(lhs.m_value / rhs.m_value); } [[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION friend constexpr simd operator+( simd const& lhs, simd const& rhs) noexcept { return simd(lhs.m_value + rhs.m_value); } [[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION friend constexpr simd operator-( simd const& lhs, simd const& rhs) noexcept { return simd(lhs.m_value - rhs.m_value); } [[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION friend constexpr simd operator>>( simd const& lhs, int rhs) noexcept { return simd(lhs.m_value >> rhs); } [[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION friend constexpr simd operator>>( simd const& lhs, simd const& rhs) noexcept { return simd(lhs.m_value >> rhs.m_value); } [[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION friend constexpr simd operator<<( simd const& lhs, int rhs) noexcept { return simd(lhs.m_value << rhs); } [[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION friend constexpr simd operator<<( simd const& lhs, simd const& rhs) noexcept { return simd(lhs.m_value << rhs.m_value); } [[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION friend constexpr simd operator&( simd const& lhs, simd const& rhs) noexcept { return lhs.m_value & rhs.m_value; } [[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION friend constexpr simd operator|( simd const& lhs, simd const& rhs) noexcept { return lhs.m_value | rhs.m_value; } [[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION friend constexpr mask_type operator<(simd const& lhs, simd const& rhs) noexcept { return mask_type(lhs.m_value < rhs.m_value); } [[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION friend constexpr mask_type operator>(simd const& lhs, simd const& rhs) noexcept { return mask_type(lhs.m_value > rhs.m_value); } [[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION friend constexpr mask_type operator<=(simd const& lhs, simd const& rhs) noexcept { return mask_type(lhs.m_value <= rhs.m_value); } [[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION friend constexpr mask_type operator>=(simd const& lhs, simd const& rhs) noexcept { return mask_type(lhs.m_value >= rhs.m_value); } [[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION friend constexpr mask_type operator==(simd const& lhs, simd const& rhs) noexcept { return mask_type(lhs.m_value == rhs.m_value); } [[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION friend constexpr mask_type operator!=(simd const& lhs, simd const& rhs) noexcept { return mask_type(lhs.m_value != rhs.m_value); } }; } // namespace Experimental template [[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION Experimental::simd abs(Experimental::simd const& a) { if constexpr (std::is_signed_v) { return (a < 0 ? -a : a); } return a; } template [[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION auto floor( Experimental::simd const& a) { using data_type = std::conditional_t, T, double>; return Experimental::simd( Kokkos::floor(static_cast(a[0]))); } template [[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION auto ceil( Experimental::simd const& a) { using data_type = std::conditional_t, T, double>; return Experimental::simd( Kokkos::ceil(static_cast(a[0]))); } template [[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION auto round( Experimental::simd const& a) { using data_type = std::conditional_t, T, double>; return Experimental::simd( Experimental::round_half_to_nearest_even(static_cast(a[0]))); } template [[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION auto trunc( Experimental::simd const& a) { using data_type = std::conditional_t, T, double>; return Experimental::simd( Kokkos::trunc(static_cast(a[0]))); } template [[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION Experimental::simd sqrt(Experimental::simd const& a) { return Experimental::simd( std::sqrt(static_cast(a))); } template KOKKOS_FORCEINLINE_FUNCTION Experimental::simd fma(Experimental::simd const& x, Experimental::simd const& y, Experimental::simd const& z) { return Experimental::simd( (static_cast(x) * static_cast(y)) + static_cast(z)); } template [[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION Experimental::simd copysign(Experimental::simd const& a, Experimental::simd const& b) { return std::copysign(static_cast(a), static_cast(b)); } namespace Experimental { template KOKKOS_FORCEINLINE_FUNCTION simd condition( desul::Impl::dont_deduce_this_parameter_t< simd_mask> const& a, simd const& b, simd const& c) { return simd(static_cast(a) ? static_cast(b) : static_cast(c)); } template class const_where_expression, simd> { public: using abi_type = simd_abi::scalar; using value_type = simd; using mask_type = simd_mask; protected: value_type& m_value; mask_type const& m_mask; public: KOKKOS_FORCEINLINE_FUNCTION const_where_expression(mask_type const& mask_arg, value_type const& value_arg) : m_value(const_cast(value_arg)), m_mask(mask_arg) {} KOKKOS_FORCEINLINE_FUNCTION void copy_to(T* mem, element_aligned_tag) const { if (static_cast(m_mask)) *mem = static_cast(m_value); } KOKKOS_FORCEINLINE_FUNCTION void copy_to(T* mem, vector_aligned_tag) const { if (static_cast(m_mask)) *mem = static_cast(m_value); } template KOKKOS_FORCEINLINE_FUNCTION std::enable_if_t> scatter_to(T* mem, simd const& index) const { if (static_cast(m_mask)) mem[static_cast(index)] = static_cast(m_value); } [[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION value_type const& impl_get_value() const { return m_value; } [[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION mask_type const& impl_get_mask() const { return m_mask; } }; template class where_expression, simd> : public const_where_expression, simd> { using base_type = const_where_expression, simd>; public: using typename base_type::value_type; KOKKOS_FORCEINLINE_FUNCTION where_expression(simd_mask const& mask_arg, simd& value_arg) : base_type(mask_arg, value_arg) {} KOKKOS_FORCEINLINE_FUNCTION void copy_from(T const* mem, element_aligned_tag) { if (static_cast(this->m_mask)) this->m_value = *mem; } KOKKOS_FORCEINLINE_FUNCTION void copy_from(T const* mem, vector_aligned_tag) { if (static_cast(this->m_mask)) this->m_value = *mem; } template KOKKOS_FORCEINLINE_FUNCTION std::enable_if_t> gather_from(T const* mem, simd const& index) { if (static_cast(this->m_mask)) this->m_value = mem[static_cast(index)]; } template >, bool> = false> KOKKOS_FORCEINLINE_FUNCTION void operator=(U&& x) { if (static_cast(this->m_mask)) this->m_value = static_cast>(std::forward(x)); } }; template [[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION where_expression, simd> where(typename simd< T, Kokkos::Experimental::simd_abi::scalar>::mask_type const& mask, simd& value) { return where_expression(mask, value); } template [[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION const_where_expression, simd> where(typename simd< T, Kokkos::Experimental::simd_abi::scalar>::mask_type const& mask, simd const& value) { return const_where_expression(mask, value); } template [[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION bool all_of( simd_mask const& a) { return a == simd_mask(true); } template [[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION bool any_of( simd_mask const& a) { return a != simd_mask(false); } template [[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION bool none_of( simd_mask const& a) { return a == simd_mask(false); } template [[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION T reduce(const_where_expression, simd> const& x, T identity_element, std::plus<>) { return static_cast(x.impl_get_mask()) ? static_cast(x.impl_get_value()) : identity_element; } template [[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION T hmax(const_where_expression, simd> const& x) { return static_cast(x.impl_get_mask()) ? static_cast(x.impl_get_value()) : Kokkos::reduction_identity::max(); } template [[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION T hmin(const_where_expression, simd> const& x) { return static_cast(x.impl_get_mask()) ? static_cast(x.impl_get_value()) : Kokkos::reduction_identity::min(); } } // namespace Experimental } // namespace Kokkos #endif