From 847d6a89587de8233705287339b2af7099eb5565 Mon Sep 17 00:00:00 2001 From: Sanjoy Das Date: Mon, 2 Mar 2020 11:42:25 -0800 Subject: [PATCH] Avoid a footgun in KernelArgsArray Do not store the address of a `const T&`. Storing the address of a `const T&` means add_argument(42) does not work, which very counter-intuitive. PiperOrigin-RevId: 298405422 Change-Id: I769dfa8d7dad92b1e73b1a4f591768b4536cca39 --- tensorflow/stream_executor/kernel.h | 35 +++++++++++++++++++---------- 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/tensorflow/stream_executor/kernel.h b/tensorflow/stream_executor/kernel.h index d3d386bca4a..514021f6304 100644 --- a/tensorflow/stream_executor/kernel.h +++ b/tensorflow/stream_executor/kernel.h @@ -76,6 +76,7 @@ limitations under the License. #include #include "absl/strings/string_view.h" +#include "tensorflow/core/platform/logging.h" #include "tensorflow/stream_executor/device_memory.h" #include "tensorflow/stream_executor/kernel_cache_config.h" #include "tensorflow/stream_executor/lib/array_slice.h" @@ -392,19 +393,22 @@ class KernelArgsArrayBase { template class KernelArgsArray : public KernelArgsArrayBase { public: - explicit KernelArgsArray() - : total_shared_memory_bytes_(0), - number_of_argument_addresses_(0), - number_of_shared_memory_arguments_(0) {} + static constexpr int kMaxGenericArgSize = 8; // Adds an argument to the list. - // - // Note that the address of the argument is stored, so the input must not go - // out of scope before the instance of this class that calls this method does. template void add_argument(const T &arg) { - argument_addresses_[number_of_argument_addresses_] = - static_cast(&arg); + static_assert(sizeof(T) <= kMaxGenericArgSize, + "Please adjust kMaxGenericArgSize"); + static_assert(std::is_pod::value, "Only pod types supported!"); + char *generic_arg_storage = + &generic_arguments_[number_of_generic_arguments_++ * + kMaxGenericArgSize]; + + CHECK_EQ(reinterpret_cast(generic_arg_storage) % alignof(T), 0); + std::memcpy(generic_arg_storage, &arg, sizeof(T)); + + argument_addresses_[number_of_argument_addresses_] = generic_arg_storage; argument_sizes_[number_of_argument_addresses_] = sizeof(arg); ++number_of_argument_addresses_; } @@ -463,6 +467,10 @@ class KernelArgsArray : public KernelArgsArrayBase { // Addresses for non-shared-memory arguments. std::array argument_addresses_; + // Storage for arguments of templated type. + alignas(kMaxGenericArgSize) + std::array generic_arguments_; + // Sizes for non-shared-memory arguments. std::array argument_sizes_; @@ -473,14 +481,17 @@ class KernelArgsArray : public KernelArgsArrayBase { std::array shared_memory_indices_; // Total of all shared memory sizes. - size_t total_shared_memory_bytes_; + size_t total_shared_memory_bytes_ = 0; // Number of significant entries in argument_addresses_ and argument_sizes_. - size_t number_of_argument_addresses_; + size_t number_of_argument_addresses_ = 0; // Number of significant entries in shared_memory_bytes_ and // shared_memory_indices_. - size_t number_of_shared_memory_arguments_; + size_t number_of_shared_memory_arguments_ = 0; + + // The number of generic arguments that have been added to generic_arguments_. + size_t number_of_generic_arguments_ = 0; }; // Typed variant of KernelBase, like a typed device function pointer. See the