Avoid a footgun in KernelArgsArray
Do not store the address of a `const T&`. Storing the address of a `const T&` means add_argument(42) does not work, which is very counter-intuitive. PiperOrigin-RevId: 298405422 Change-Id: I769dfa8d7dad92b1e73b1a4f591768b4536cca39
This commit is contained in:
parent
c578ea6343
commit
847d6a8958
@ -76,6 +76,7 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/stream_executor/device_memory.h"
|
||||
#include "tensorflow/stream_executor/kernel_cache_config.h"
|
||||
#include "tensorflow/stream_executor/lib/array_slice.h"
|
||||
@ -392,19 +393,22 @@ class KernelArgsArrayBase {
|
||||
template <size_t kNumArgs>
|
||||
class KernelArgsArray : public KernelArgsArrayBase {
|
||||
public:
|
||||
explicit KernelArgsArray()
|
||||
: total_shared_memory_bytes_(0),
|
||||
number_of_argument_addresses_(0),
|
||||
number_of_shared_memory_arguments_(0) {}
|
||||
static constexpr int kMaxGenericArgSize = 8;
|
||||
|
||||
// Adds an argument to the list.
|
||||
//
|
||||
// Note that the address of the argument is stored, so the input must not go
|
||||
// out of scope before the instance of this class that calls this method does.
|
||||
template <typename T>
|
||||
void add_argument(const T &arg) {
|
||||
argument_addresses_[number_of_argument_addresses_] =
|
||||
static_cast<const void *>(&arg);
|
||||
static_assert(sizeof(T) <= kMaxGenericArgSize,
|
||||
"Please adjust kMaxGenericArgSize");
|
||||
static_assert(std::is_pod<T>::value, "Only pod types supported!");
|
||||
char *generic_arg_storage =
|
||||
&generic_arguments_[number_of_generic_arguments_++ *
|
||||
kMaxGenericArgSize];
|
||||
|
||||
CHECK_EQ(reinterpret_cast<uintptr_t>(generic_arg_storage) % alignof(T), 0);
|
||||
std::memcpy(generic_arg_storage, &arg, sizeof(T));
|
||||
|
||||
argument_addresses_[number_of_argument_addresses_] = generic_arg_storage;
|
||||
argument_sizes_[number_of_argument_addresses_] = sizeof(arg);
|
||||
++number_of_argument_addresses_;
|
||||
}
|
||||
@ -463,6 +467,10 @@ class KernelArgsArray : public KernelArgsArrayBase {
|
||||
// Addresses for non-shared-memory arguments.
|
||||
std::array<const void *, kNumArgs> argument_addresses_;
|
||||
|
||||
// Storage for arguments of templated type.
|
||||
alignas(kMaxGenericArgSize)
|
||||
std::array<char, kNumArgs * kMaxGenericArgSize> generic_arguments_;
|
||||
|
||||
// Sizes for non-shared-memory arguments.
|
||||
std::array<size_t, kNumArgs> argument_sizes_;
|
||||
|
||||
@ -473,14 +481,17 @@ class KernelArgsArray : public KernelArgsArrayBase {
|
||||
std::array<size_t, kNumArgs> shared_memory_indices_;
|
||||
|
||||
// Total of all shared memory sizes.
|
||||
size_t total_shared_memory_bytes_;
|
||||
size_t total_shared_memory_bytes_ = 0;
|
||||
|
||||
// Number of significant entries in argument_addresses_ and argument_sizes_.
|
||||
size_t number_of_argument_addresses_;
|
||||
size_t number_of_argument_addresses_ = 0;
|
||||
|
||||
// Number of significant entries in shared_memory_bytes_ and
|
||||
// shared_memory_indices_.
|
||||
size_t number_of_shared_memory_arguments_;
|
||||
size_t number_of_shared_memory_arguments_ = 0;
|
||||
|
||||
// The number of generic arguments that have been added to generic_arguments_.
|
||||
size_t number_of_generic_arguments_ = 0;
|
||||
};
|
||||
|
||||
// Typed variant of KernelBase, like a typed device function pointer. See the
|
||||
|
Loading…
Reference in New Issue
Block a user