//@HEADER // ************************************************************************ // // Kokkos v. 4.0 // Copyright (2022) National Technology & Engineering // Solutions of Sandia, LLC (NTESS). // // Under the terms of Contract DE-NA0003525 with NTESS, // the U.S. Government retains certain rights in this software. // // Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions. // See https://kokkos.org/LICENSE for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //@HEADER #ifndef KOKKOS_SYCL_HALF_HPP_ #define KOKKOS_SYCL_HALF_HPP_ #ifdef KOKKOS_IMPL_SYCL_HALF_TYPE_DEFINED #include #include namespace Kokkos { namespace Experimental { /************************** half conversions **********************************/ KOKKOS_INLINE_FUNCTION half_t cast_to_half(half_t val) { return val; } KOKKOS_INLINE_FUNCTION half_t cast_to_half(float val) { return half_t::impl_type(val); } KOKKOS_INLINE_FUNCTION half_t cast_to_half(double val) { return half_t::impl_type(val); } KOKKOS_INLINE_FUNCTION half_t cast_to_half(short val) { return half_t::impl_type(val); } KOKKOS_INLINE_FUNCTION half_t cast_to_half(unsigned short val) { return half_t::impl_type(val); } KOKKOS_INLINE_FUNCTION half_t cast_to_half(int val) { return half_t::impl_type(val); } KOKKOS_INLINE_FUNCTION half_t cast_to_half(unsigned int val) { return half_t::impl_type(val); } KOKKOS_INLINE_FUNCTION half_t cast_to_half(long long val) { return half_t::impl_type(val); } KOKKOS_INLINE_FUNCTION half_t cast_to_half(unsigned long long val) { return half_t::impl_type(val); } KOKKOS_INLINE_FUNCTION half_t cast_to_half(long val) { return half_t::impl_type(val); } KOKKOS_INLINE_FUNCTION half_t cast_to_half(unsigned long val) { return half_t::impl_type(val); } template KOKKOS_INLINE_FUNCTION std::enable_if_t::value, T> cast_from_half(half_t val) { return static_cast(half_t::impl_type(val)); } template KOKKOS_INLINE_FUNCTION std::enable_if_t::value, T> cast_from_half(half_t val) { return static_cast(half_t::impl_type(val)); } template KOKKOS_INLINE_FUNCTION std::enable_if_t::value, T> cast_from_half(half_t val) { return static_cast(half_t::impl_type(val)); } template KOKKOS_INLINE_FUNCTION std::enable_if_t::value, T> cast_from_half(half_t val) { return static_cast(half_t::impl_type(val)); } template KOKKOS_INLINE_FUNCTION std::enable_if_t::value, T> cast_from_half(half_t val) { return static_cast(half_t::impl_type(val)); } template KOKKOS_INLINE_FUNCTION std::enable_if_t::value, T> cast_from_half(half_t val) { return static_cast(half_t::impl_type(val)); } template KOKKOS_INLINE_FUNCTION std::enable_if_t::value, T> cast_from_half(half_t val) { return static_cast(half_t::impl_type(val)); } template KOKKOS_INLINE_FUNCTION std::enable_if_t::value, T> cast_from_half(half_t val) { return static_cast(half_t::impl_type(val)); } template KOKKOS_INLINE_FUNCTION std::enable_if_t::value, T> cast_from_half(half_t val) { return static_cast(half_t::impl_type(val)); } template KOKKOS_INLINE_FUNCTION std::enable_if_t::value, T> cast_from_half(half_t val) { return static_cast(half_t::impl_type(val)); } } // namespace Experimental template <> struct reduction_identity { KOKKOS_FORCEINLINE_FUNCTION constexpr static Kokkos::Experimental::half_t sum() noexcept { return Kokkos::Experimental::half_t::impl_type(0.0F); } KOKKOS_FORCEINLINE_FUNCTION constexpr static Kokkos::Experimental::half_t prod() noexcept { return Kokkos::Experimental::half_t::impl_type(1.0F); } KOKKOS_FORCEINLINE_FUNCTION constexpr static Kokkos::Experimental::half_t max() noexcept { return std::numeric_limits< Kokkos::Experimental::half_t::impl_type>::lowest(); } KOKKOS_FORCEINLINE_FUNCTION constexpr static Kokkos::Experimental::half_t min() noexcept { return std::numeric_limits::max(); } }; } // namespace Kokkos #endif // KOKKOS_IMPL_SYCL_HALF_TYPE_DEFINED #ifdef KOKKOS_IMPL_SYCL_BHALF_TYPE_DEFINED namespace Kokkos { namespace Experimental { /************************** bhalf conversions *********************************/ KOKKOS_INLINE_FUNCTION bhalf_t cast_to_bhalf(bhalf_t val) { return val; } KOKKOS_INLINE_FUNCTION bhalf_t cast_to_bhalf(float val) { return bhalf_t::impl_type(val); } KOKKOS_INLINE_FUNCTION bhalf_t cast_to_bhalf(double val) { return bhalf_t::impl_type(val); } KOKKOS_INLINE_FUNCTION bhalf_t cast_to_bhalf(short val) { return bhalf_t::impl_type(val); } KOKKOS_INLINE_FUNCTION bhalf_t cast_to_bhalf(unsigned short val) { return bhalf_t::impl_type(val); } KOKKOS_INLINE_FUNCTION bhalf_t cast_to_bhalf(int val) { return bhalf_t::impl_type(val); } KOKKOS_INLINE_FUNCTION bhalf_t cast_to_bhalf(unsigned int val) { return bhalf_t::impl_type(val); } KOKKOS_INLINE_FUNCTION bhalf_t cast_to_bhalf(long long val) { return bhalf_t::impl_type(val); } KOKKOS_INLINE_FUNCTION bhalf_t cast_to_bhalf(unsigned long long val) { return bhalf_t::impl_type(val); } KOKKOS_INLINE_FUNCTION bhalf_t cast_to_bhalf(long val) { return bhalf_t::impl_type(val); } KOKKOS_INLINE_FUNCTION bhalf_t cast_to_bhalf(unsigned long val) { return bhalf_t::impl_type(val); } template KOKKOS_INLINE_FUNCTION std::enable_if_t::value, T> cast_from_bhalf(bhalf_t val) { return static_cast(bhalf_t::impl_type(val)); } template KOKKOS_INLINE_FUNCTION std::enable_if_t::value, T> cast_from_bhalf(bhalf_t val) { return static_cast(bhalf_t::impl_type(val)); } template KOKKOS_INLINE_FUNCTION std::enable_if_t::value, T> cast_from_bhalf(bhalf_t val) { return static_cast(bhalf_t::impl_type(val)); } template KOKKOS_INLINE_FUNCTION std::enable_if_t::value, T> cast_from_bhalf(bhalf_t val) { return static_cast(bhalf_t::impl_type(val)); } template KOKKOS_INLINE_FUNCTION std::enable_if_t::value, T> cast_from_bhalf(bhalf_t val) { return static_cast(bhalf_t::impl_type(val)); } template KOKKOS_INLINE_FUNCTION std::enable_if_t::value, T> cast_from_bhalf(bhalf_t val) { return static_cast(bhalf_t::impl_type(val)); } template KOKKOS_INLINE_FUNCTION std::enable_if_t::value, T> cast_from_bhalf(bhalf_t val) { return static_cast(bhalf_t::impl_type(val)); } template KOKKOS_INLINE_FUNCTION std::enable_if_t::value, T> cast_from_bhalf(bhalf_t val) { return static_cast(bhalf_t::impl_type(val)); } template KOKKOS_INLINE_FUNCTION std::enable_if_t::value, T> cast_from_bhalf(bhalf_t val) { return static_cast(bhalf_t::impl_type(val)); } template KOKKOS_INLINE_FUNCTION std::enable_if_t::value, T> cast_from_bhalf(bhalf_t val) { return static_cast(bhalf_t::impl_type(val)); } } // namespace Experimental // sycl::bfloat16 doesn't have constexpr constructors so we return float template <> struct reduction_identity { KOKKOS_FORCEINLINE_FUNCTION constexpr static float sum() noexcept { return 0.f; } KOKKOS_FORCEINLINE_FUNCTION constexpr static float prod() noexcept { return 1.0f; } KOKKOS_FORCEINLINE_FUNCTION constexpr static float max() noexcept { return -0x7f7f; } KOKKOS_FORCEINLINE_FUNCTION constexpr static float min() noexcept { return 0x7f7f; } }; } // namespace Kokkos #endif // KOKKOS_IMPL_SYCL_BHALF_TYPE_DEFINED #endif