Avoid a footgun in KernelArgsArray

Do not store the address of a `const T&`.  Storing the address of a `const T&`
means add_argument(42) does not work, which is very counter-intuitive.

PiperOrigin-RevId: 298405422
Change-Id: I769dfa8d7dad92b1e73b1a4f591768b4536cca39
This commit is contained in:
Sanjoy Das 2020-03-02 11:42:25 -08:00 committed by TensorFlower Gardener
parent c578ea6343
commit 847d6a8958

View File

@ -76,6 +76,7 @@ limitations under the License.
#include <vector>
#include "absl/strings/string_view.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/kernel_cache_config.h"
#include "tensorflow/stream_executor/lib/array_slice.h"
@ -392,19 +393,22 @@ class KernelArgsArrayBase {
template <size_t kNumArgs>
class KernelArgsArray : public KernelArgsArrayBase {
public:
// Constructs an empty argument array: the shared-memory byte total and the
// counts of argument addresses and shared-memory arguments all start at zero.
explicit KernelArgsArray()
: total_shared_memory_bytes_(0),
number_of_argument_addresses_(0),
number_of_shared_memory_arguments_(0) {}
static constexpr int kMaxGenericArgSize = 8;
// Adds an argument to the list.
//
// The value of `arg` is copied byte-wise into a kMaxGenericArgSize-byte slot
// of generic_arguments_, and that slot's address together with sizeof(arg) is
// recorded at index number_of_argument_addresses_. T must be a POD type no
// larger than kMaxGenericArgSize (both enforced at compile time).
//
// NOTE(review): an earlier note here said the address of the argument is
// stored, so the input must not go out of scope before this object does. The
// memcpy below copies the value instead, so that requirement no longer appears
// to apply (e.g. add_argument(42)) -- confirm against the actual file.
template <typename T>
void add_argument(const T &arg) {
// NOTE(review): this store is overwritten further down, before the index is
// advanced; in this diff view it looks like the removed pre-change line --
// confirm.
argument_addresses_[number_of_argument_addresses_] =
static_cast<const void *>(&arg);
// Each slot is kMaxGenericArgSize bytes wide, so T has to fit in one slot.
static_assert(sizeof(T) <= kMaxGenericArgSize,
"Please adjust kMaxGenericArgSize");
// The value is duplicated with memcpy below, hence the POD restriction.
static_assert(std::is_pod<T>::value, "Only pod types supported!");
// Take the next unused slot; the post-increment bumps the generic-argument
// count. NOTE(review): no bounds check against kNumArgs is visible here.
char *generic_arg_storage =
&generic_arguments_[number_of_generic_arguments_++ *
kMaxGenericArgSize];
// Runtime check that the slot address is a multiple of alignof(T).
CHECK_EQ(reinterpret_cast<uintptr_t>(generic_arg_storage) % alignof(T), 0);
// Copy the argument's bytes into the slot owned by this object.
std::memcpy(generic_arg_storage, &arg, sizeof(T));
// Record the slot (not the caller's object) and the argument size, then
// advance to the next argument index.
argument_addresses_[number_of_argument_addresses_] = generic_arg_storage;
argument_sizes_[number_of_argument_addresses_] = sizeof(arg);
++number_of_argument_addresses_;
}
@ -463,6 +467,10 @@ class KernelArgsArray : public KernelArgsArrayBase {
// Addresses for non-shared-memory arguments.
std::array<const void *, kNumArgs> argument_addresses_;
// Storage for arguments of templated type.
alignas(kMaxGenericArgSize)
std::array<char, kNumArgs * kMaxGenericArgSize> generic_arguments_;
// Sizes for non-shared-memory arguments.
std::array<size_t, kNumArgs> argument_sizes_;
@ -473,14 +481,17 @@ class KernelArgsArray : public KernelArgsArrayBase {
std::array<size_t, kNumArgs> shared_memory_indices_;
// Total of all shared memory sizes.
size_t total_shared_memory_bytes_;
size_t total_shared_memory_bytes_ = 0;
// Number of significant entries in argument_addresses_ and argument_sizes_.
size_t number_of_argument_addresses_;
size_t number_of_argument_addresses_ = 0;
// Number of significant entries in shared_memory_bytes_ and
// shared_memory_indices_.
size_t number_of_shared_memory_arguments_;
size_t number_of_shared_memory_arguments_ = 0;
// The number of generic arguments that have been added to generic_arguments_.
size_t number_of_generic_arguments_ = 0;
};
// Typed variant of KernelBase, like a typed device function pointer. See the