//@HEADER // ************************************************************************ // // Kokkos v. 4.0 // Copyright (2022) National Technology & Engineering // Solutions of Sandia, LLC (NTESS). // // Under the terms of Contract DE-NA0003525 with NTESS, // the U.S. Government retains certain rights in this software. // // Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions. // See https://kokkos.org/LICENSE for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //@HEADER #ifndef KOKKOS_HIP_PARALLEL_FOR_RANGE_HPP #define KOKKOS_HIP_PARALLEL_FOR_RANGE_HPP #include #include #include namespace Kokkos { namespace Impl { template class ParallelFor, Kokkos::HIP> { public: using Policy = Kokkos::RangePolicy; private: using Member = typename Policy::member_type; using WorkTag = typename Policy::work_tag; using LaunchBounds = typename Policy::launch_bounds; const FunctorType m_functor; const Policy m_policy; template inline __device__ std::enable_if_t::value> exec_range( const Member i) const { m_functor(i); } template inline __device__ std::enable_if_t::value> exec_range( const Member i) const { m_functor(TagType(), i); } public: using functor_type = FunctorType; ParallelFor() = delete; ParallelFor(ParallelFor const&) = default; ParallelFor& operator=(ParallelFor const&) = delete; inline __device__ void operator()() const { const Member work_stride = blockDim.y * gridDim.x; const Member work_end = m_policy.end(); for (Member iwork = m_policy.begin() + threadIdx.y + blockDim.y * blockIdx.x; iwork < work_end; iwork = iwork < work_end - work_stride ? iwork + work_stride : work_end) { this->template exec_range(iwork); } } inline void execute() const { const typename Policy::index_type nwork = m_policy.end() - m_policy.begin(); using DriverType = ParallelFor; const int block_size = Kokkos::Impl::hip_get_preferred_blocksize(); const dim3 block(1, block_size, 1); const dim3 grid( typename Policy::index_type((nwork + block.y - 1) / block.y), 1, 1); if (block_size == 0) { Kokkos::Impl::throw_runtime_exception( std::string("Kokkos::Impl::ParallelFor< HIP > could not find a " "valid execution configuration.")); } Kokkos::Impl::hip_parallel_launch( *this, grid, block, 0, m_policy.space().impl_internal_space_instance(), false); } ParallelFor(const FunctorType& arg_functor, const Policy& arg_policy) : m_functor(arg_functor), m_policy(arg_policy) {} }; } // namespace Impl } // namespace Kokkos #endif