//@HEADER // ************************************************************************ // // Kokkos v. 4.0 // Copyright (2022) National Technology & Engineering // Solutions of Sandia, LLC (NTESS). // // Under the terms of Contract DE-NA0003525 with NTESS, // the U.S. Government retains certain rights in this software. // // Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions. // See https://kokkos.org/LICENSE for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //@HEADER #ifndef KOKKOS_THREADS_PARALLEL_FOR_TEAM_HPP #define KOKKOS_THREADS_PARALLEL_FOR_TEAM_HPP #include namespace Kokkos { namespace Impl { template class ParallelFor, Kokkos::Threads> { private: using Policy = Kokkos::Impl::TeamPolicyInternal; using WorkTag = typename Policy::work_tag; using Member = typename Policy::member_type; const FunctorType m_functor; const Policy m_policy; const size_t m_shared; template inline static std::enable_if_t::value && std::is_same::value> exec_team(const FunctorType &functor, Member member) { for (; member.valid_static(); member.next_static()) { functor(member); } } template inline static std::enable_if_t::value && std::is_same::value> exec_team(const FunctorType &functor, Member member) { const TagType t{}; for (; member.valid_static(); member.next_static()) { functor(t, member); } } template inline static std::enable_if_t::value && std::is_same::value> exec_team(const FunctorType &functor, Member member) { for (; member.valid_dynamic(); member.next_dynamic()) { functor(member); } } template inline static std::enable_if_t::value && std::is_same::value> exec_team(const FunctorType &functor, Member member) { const TagType t{}; for (; member.valid_dynamic(); member.next_dynamic()) { functor(t, member); } } static void exec(ThreadsInternal &instance, const void *arg) { const ParallelFor &self = *((const ParallelFor *)arg); ParallelFor::exec_team( self.m_functor, Member(&instance, self.m_policy, self.m_shared)); instance.barrier(); instance.fan_in(); } template Policy fix_policy(Policy policy) { if (policy.impl_vector_length() < 0) { policy.impl_set_vector_length(1); } if (policy.team_size() < 0) { policy.impl_set_team_size( policy.team_size_recommended(m_functor, ParallelForTag{})); } return policy; } public: inline void execute() const { ThreadsInternal::resize_scratch( 0, Policy::member_type::team_reduce_size() + m_shared); ThreadsInternal::start(&ParallelFor::exec, this); ThreadsInternal::fence(); } ParallelFor(const FunctorType &arg_functor, const Policy &arg_policy) : m_functor(arg_functor), m_policy(fix_policy(arg_policy)), m_shared(m_policy.scratch_size(0) + m_policy.scratch_size(1) + FunctorTeamShmemSize::value( arg_functor, m_policy.team_size())) {} }; } // namespace Impl } // namespace Kokkos #endif