Avoid a footgun in KernelArgsArray
Do not store the address of a `const T&`. Storing the address of a `const T&` means add_argument(42) does not work, which very counter-intuitive. PiperOrigin-RevId: 298405422 Change-Id: I769dfa8d7dad92b1e73b1a4f591768b4536cca39
This commit is contained in:
parent
c578ea6343
commit
847d6a8958
@ -76,6 +76,7 @@ limitations under the License.
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "absl/strings/string_view.h"
|
#include "absl/strings/string_view.h"
|
||||||
|
#include "tensorflow/core/platform/logging.h"
|
||||||
#include "tensorflow/stream_executor/device_memory.h"
|
#include "tensorflow/stream_executor/device_memory.h"
|
||||||
#include "tensorflow/stream_executor/kernel_cache_config.h"
|
#include "tensorflow/stream_executor/kernel_cache_config.h"
|
||||||
#include "tensorflow/stream_executor/lib/array_slice.h"
|
#include "tensorflow/stream_executor/lib/array_slice.h"
|
||||||
@ -392,19 +393,22 @@ class KernelArgsArrayBase {
|
|||||||
template <size_t kNumArgs>
|
template <size_t kNumArgs>
|
||||||
class KernelArgsArray : public KernelArgsArrayBase {
|
class KernelArgsArray : public KernelArgsArrayBase {
|
||||||
public:
|
public:
|
||||||
explicit KernelArgsArray()
|
static constexpr int kMaxGenericArgSize = 8;
|
||||||
: total_shared_memory_bytes_(0),
|
|
||||||
number_of_argument_addresses_(0),
|
|
||||||
number_of_shared_memory_arguments_(0) {}
|
|
||||||
|
|
||||||
// Adds an argument to the list.
|
// Adds an argument to the list.
|
||||||
//
|
|
||||||
// Note that the address of the argument is stored, so the input must not go
|
|
||||||
// out of scope before the instance of this class that calls this method does.
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void add_argument(const T &arg) {
|
void add_argument(const T &arg) {
|
||||||
argument_addresses_[number_of_argument_addresses_] =
|
static_assert(sizeof(T) <= kMaxGenericArgSize,
|
||||||
static_cast<const void *>(&arg);
|
"Please adjust kMaxGenericArgSize");
|
||||||
|
static_assert(std::is_pod<T>::value, "Only pod types supported!");
|
||||||
|
char *generic_arg_storage =
|
||||||
|
&generic_arguments_[number_of_generic_arguments_++ *
|
||||||
|
kMaxGenericArgSize];
|
||||||
|
|
||||||
|
CHECK_EQ(reinterpret_cast<uintptr_t>(generic_arg_storage) % alignof(T), 0);
|
||||||
|
std::memcpy(generic_arg_storage, &arg, sizeof(T));
|
||||||
|
|
||||||
|
argument_addresses_[number_of_argument_addresses_] = generic_arg_storage;
|
||||||
argument_sizes_[number_of_argument_addresses_] = sizeof(arg);
|
argument_sizes_[number_of_argument_addresses_] = sizeof(arg);
|
||||||
++number_of_argument_addresses_;
|
++number_of_argument_addresses_;
|
||||||
}
|
}
|
||||||
@ -463,6 +467,10 @@ class KernelArgsArray : public KernelArgsArrayBase {
|
|||||||
// Addresses for non-shared-memory arguments.
|
// Addresses for non-shared-memory arguments.
|
||||||
std::array<const void *, kNumArgs> argument_addresses_;
|
std::array<const void *, kNumArgs> argument_addresses_;
|
||||||
|
|
||||||
|
// Storage for arguments of templated type.
|
||||||
|
alignas(kMaxGenericArgSize)
|
||||||
|
std::array<char, kNumArgs * kMaxGenericArgSize> generic_arguments_;
|
||||||
|
|
||||||
// Sizes for non-shared-memory arguments.
|
// Sizes for non-shared-memory arguments.
|
||||||
std::array<size_t, kNumArgs> argument_sizes_;
|
std::array<size_t, kNumArgs> argument_sizes_;
|
||||||
|
|
||||||
@ -473,14 +481,17 @@ class KernelArgsArray : public KernelArgsArrayBase {
|
|||||||
std::array<size_t, kNumArgs> shared_memory_indices_;
|
std::array<size_t, kNumArgs> shared_memory_indices_;
|
||||||
|
|
||||||
// Total of all shared memory sizes.
|
// Total of all shared memory sizes.
|
||||||
size_t total_shared_memory_bytes_;
|
size_t total_shared_memory_bytes_ = 0;
|
||||||
|
|
||||||
// Number of significant entries in argument_addresses_ and argument_sizes_.
|
// Number of significant entries in argument_addresses_ and argument_sizes_.
|
||||||
size_t number_of_argument_addresses_;
|
size_t number_of_argument_addresses_ = 0;
|
||||||
|
|
||||||
// Number of significant entries in shared_memory_bytes_ and
|
// Number of significant entries in shared_memory_bytes_ and
|
||||||
// shared_memory_indices_.
|
// shared_memory_indices_.
|
||||||
size_t number_of_shared_memory_arguments_;
|
size_t number_of_shared_memory_arguments_ = 0;
|
||||||
|
|
||||||
|
// The number of generic arguments that have been added to generic_arguments_.
|
||||||
|
size_t number_of_generic_arguments_ = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Typed variant of KernelBase, like a typed device function pointer. See the
|
// Typed variant of KernelBase, like a typed device function pointer. See the
|
||||||
|
Loading…
Reference in New Issue
Block a user