//@HEADER // ************************************************************************ // // Kokkos v. 4.0 // Copyright (2022) National Technology & Engineering // Solutions of Sandia, LLC (NTESS). // // Under the terms of Contract DE-NA0003525 with NTESS, // the U.S. Government retains certain rights in this software. // // Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions. // See https://kokkos.org/LICENSE for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //@HEADER #ifndef KOKKOS_SIMD_COMMON_HPP #define KOKKOS_SIMD_COMMON_HPP #include #include namespace Kokkos { namespace Experimental { template class simd; template class simd_mask; class simd_alignment_vector_aligned {}; template struct simd_flags {}; inline constexpr simd_flags<> simd_flag_default{}; inline constexpr simd_flags simd_flag_aligned{}; using element_aligned_tag = simd_flags<>; using vector_aligned_tag = simd_flags; // class template declarations for const_where_expression and where_expression template class const_where_expression { protected: T& m_value; M const& m_mask; public: const_where_expression(M const& mask_arg, T const& value_arg) : m_value(const_cast(value_arg)), m_mask(mask_arg) {} KOKKOS_FORCEINLINE_FUNCTION T const& value() const { return this->m_value; } }; template class where_expression : public const_where_expression { using base_type = const_where_expression; public: where_expression(M const& mask_arg, T& value_arg) : base_type(mask_arg, value_arg) {} KOKKOS_FORCEINLINE_FUNCTION T& value() { return this->m_value; } }; // specializations of where expression templates for the case when the // mask type is bool, to allow generic code to use where() on both // SIMD types and non-SIMD builtin arithmetic types template class const_where_expression { protected: T& m_value; bool m_mask; public: KOKKOS_FORCEINLINE_FUNCTION const_where_expression(bool mask_arg, T const& value_arg) : m_value(const_cast(value_arg)), m_mask(mask_arg) {} KOKKOS_FORCEINLINE_FUNCTION T const& value() const { return this->m_value; } }; template class where_expression : public const_where_expression { using base_type = const_where_expression; public: KOKKOS_FORCEINLINE_FUNCTION where_expression(bool mask_arg, T& value_arg) : base_type(mask_arg, value_arg) {} KOKKOS_FORCEINLINE_FUNCTION T& value() { return this->m_value; } template , bool> = false> KOKKOS_FORCEINLINE_FUNCTION void operator=(U const& x) { if (this->m_mask) this->m_value = x; } }; template [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION where_expression, simd> where(typename simd::mask_type const& mask, simd& value) { return where_expression(mask, value); } template [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION const_where_expression, simd> where(typename simd::mask_type const& mask, simd const& value) { return const_where_expression(mask, value); } template [[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION where_expression where( bool mask, T& value) { return where_expression(mask, value); } template [[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION const_where_expression where( bool mask, T const& value) { return const_where_expression(mask, value); } // The code below provides: // operator@(simd, Arithmetic) // operator@(Arithmetic, simd) // operator@=(simd&, U&&) // operator@=(where_expression&, U&&) template , bool> = false> [[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION auto operator+( Experimental::simd const& lhs, U rhs) { using result_member = decltype(lhs[0] + rhs); return Experimental::simd(lhs) + Experimental::simd(rhs); } template , bool> = false> [[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION auto operator+( U lhs, Experimental::simd const& rhs) { using result_member = decltype(lhs + rhs[0]); return Experimental::simd(lhs) + Experimental::simd(rhs); } template KOKKOS_FORCEINLINE_FUNCTION simd& operator+=(simd& lhs, U&& rhs) { lhs = lhs + std::forward(rhs); return lhs; } template KOKKOS_FORCEINLINE_FUNCTION where_expression& operator+=( where_expression& lhs, U&& rhs) { lhs = lhs.value() + std::forward(rhs); return lhs; } template , bool> = false> [[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION auto operator-( Experimental::simd const& lhs, U rhs) { using result_member = decltype(lhs[0] - rhs); return Experimental::simd(lhs) - Experimental::simd(rhs); } template , bool> = false> [[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION auto operator-( U lhs, Experimental::simd const& rhs) { using result_member = decltype(lhs - rhs[0]); return Experimental::simd(lhs) - Experimental::simd(rhs); } template KOKKOS_FORCEINLINE_FUNCTION simd& operator-=(simd& lhs, U&& rhs) { lhs = lhs - std::forward(rhs); return lhs; } template KOKKOS_FORCEINLINE_FUNCTION where_expression& operator-=( where_expression& lhs, U&& rhs) { lhs = lhs.value() - std::forward(rhs); return lhs; } template , bool> = false> [[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION auto operator*( Experimental::simd const& lhs, U rhs) { using result_member = decltype(lhs[0] * rhs); return Experimental::simd(lhs) * Experimental::simd(rhs); } template , bool> = false> [[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION auto operator*( U lhs, Experimental::simd const& rhs) { using result_member = decltype(lhs * rhs[0]); return Experimental::simd(lhs) * Experimental::simd(rhs); } template KOKKOS_FORCEINLINE_FUNCTION simd& operator*=(simd& lhs, U&& rhs) { lhs = lhs * std::forward(rhs); return lhs; } template KOKKOS_FORCEINLINE_FUNCTION where_expression& operator*=( where_expression& lhs, U&& rhs) { lhs = lhs.value() * std::forward(rhs); return lhs; } template , bool> = false> [[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION auto operator/( Experimental::simd const& lhs, U rhs) { using result_member = decltype(lhs[0] / rhs); return Experimental::simd(lhs) / Experimental::simd(rhs); } template , bool> = false> [[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION auto operator/( U lhs, Experimental::simd const& rhs) { using result_member = decltype(lhs / rhs[0]); return Experimental::simd(lhs) / Experimental::simd(rhs); } template KOKKOS_FORCEINLINE_FUNCTION simd& operator/=(simd& lhs, U&& rhs) { lhs = lhs / std::forward(rhs); return lhs; } template KOKKOS_FORCEINLINE_FUNCTION where_expression& operator/=( where_expression& lhs, U&& rhs) { lhs = lhs.value() / std::forward(rhs); return lhs; } // implement mask reductions for type bool to allow generic code to accept // both simd and just double [[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION constexpr bool all_of(bool a) { return a; } [[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION constexpr bool any_of(bool a) { return a; } [[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION constexpr bool none_of(bool a) { return !a; } // fallback implementations of reductions across simd_mask: template [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION bool all_of( simd_mask const& a) { return a == simd_mask(true); } template [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION bool any_of( simd_mask const& a) { return a != simd_mask(false); } template [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION bool none_of( simd_mask const& a) { return a == simd_mask(false); } // A temporary device-callable implemenation of round half to nearest even template [[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION auto round_half_to_nearest_even( T const& x) { auto ceil = Kokkos::ceil(x); auto floor = Kokkos::floor(x); if (Kokkos::abs(ceil - x) == Kokkos::abs(floor - x)) { auto rem = Kokkos::remainder(ceil, 2.0); return (rem == 0) ? ceil : floor; } return Kokkos::round(x); } } // namespace Experimental } // namespace Kokkos #endif