diff --git a/tensorflow/stream_executor/kernel.h b/tensorflow/stream_executor/kernel.h index d3d386bca4a..514021f6304 100644 --- a/tensorflow/stream_executor/kernel.h +++ b/tensorflow/stream_executor/kernel.h @@ -76,6 +76,7 @@ limitations under the License. #include #include "absl/strings/string_view.h" +#include "tensorflow/core/platform/logging.h" #include "tensorflow/stream_executor/device_memory.h" #include "tensorflow/stream_executor/kernel_cache_config.h" #include "tensorflow/stream_executor/lib/array_slice.h" @@ -392,19 +393,22 @@ class KernelArgsArrayBase { template class KernelArgsArray : public KernelArgsArrayBase { public: - explicit KernelArgsArray() - : total_shared_memory_bytes_(0), - number_of_argument_addresses_(0), - number_of_shared_memory_arguments_(0) {} + static constexpr int kMaxGenericArgSize = 8; // Adds an argument to the list. - // - // Note that the address of the argument is stored, so the input must not go - // out of scope before the instance of this class that calls this method does. template void add_argument(const T &arg) { - argument_addresses_[number_of_argument_addresses_] = - static_cast(&arg); + static_assert(sizeof(T) <= kMaxGenericArgSize, + "Please adjust kMaxGenericArgSize"); + static_assert(std::is_pod::value, "Only pod types supported!"); + char *generic_arg_storage = + &generic_arguments_[number_of_generic_arguments_++ * + kMaxGenericArgSize]; + + CHECK_EQ(reinterpret_cast(generic_arg_storage) % alignof(T), 0); + std::memcpy(generic_arg_storage, &arg, sizeof(T)); + + argument_addresses_[number_of_argument_addresses_] = generic_arg_storage; argument_sizes_[number_of_argument_addresses_] = sizeof(arg); ++number_of_argument_addresses_; } @@ -463,6 +467,10 @@ class KernelArgsArray : public KernelArgsArrayBase { // Addresses for non-shared-memory arguments. std::array argument_addresses_; + // Storage for arguments of templated type. + alignas(kMaxGenericArgSize) + std::array generic_arguments_; + // Sizes for non-shared-memory arguments. std::array argument_sizes_; @@ -473,14 +481,17 @@ class KernelArgsArray : public KernelArgsArrayBase { std::array shared_memory_indices_; // Total of all shared memory sizes. - size_t total_shared_memory_bytes_; + size_t total_shared_memory_bytes_ = 0; // Number of significant entries in argument_addresses_ and argument_sizes_. - size_t number_of_argument_addresses_; + size_t number_of_argument_addresses_ = 0; // Number of significant entries in shared_memory_bytes_ and // shared_memory_indices_. - size_t number_of_shared_memory_arguments_; + size_t number_of_shared_memory_arguments_ = 0; + + // The number of generic arguments that have been added to generic_arguments_. + size_t number_of_generic_arguments_ = 0; }; // Typed variant of KernelBase, like a typed device function pointer. See the