Make tensorflow::tpu::NoncopyableBuffer allocate uint8 pointer rather than uint32 to allow unaligned U8/S8/U16/S16/BF16 arrays.

PiperOrigin-RevId: 341069760
Change-Id: I3df031ed970ec60a0e918768341b27eea1c954d7
This commit is contained in:
Ce Zheng 2020-11-06 10:05:32 -08:00 committed by TensorFlower Gardener
parent e2286aba8e
commit 173e9bd169
2 changed files with 42 additions and 30 deletions

View File

@ -36,22 +36,23 @@ class NoncopyableBuffer {
// Allocate an owning buffer without initializing the data. Useful when it
// will be filled by a subsequent function and want to avoid initialization
// cost. Size is specified in number of uint32's.
// cost. Size is specified in number of bytes.
explicit NoncopyableBuffer(size_t size)
: data_(new uint32[size]), buf_(data_.get()), size_(size) {}
: data_(new uint8_t[size]), buf_(data_.get()), size_(size) {}
// Allocates an owning buffer and initializes it with the specified data. Size
// is specified in number of uint32's.
NoncopyableBuffer(size_t size, absl::optional<uint32> value)
: NoncopyableBuffer(size) {
NoncopyableBuffer(size_t size_in_u32s, absl::optional<uint32_t> value)
: NoncopyableBuffer(size_in_u32s * sizeof(uint32_t)) {
#ifndef MEMORY_SANITIZER
if (!value.has_value()) {
return;
}
#endif
uint32 v = value.value_or(0);
for (int64 i = 0; i < size; ++i) {
data_[i] = v;
uint32_t* data_u32 = reinterpret_cast<uint32_t*>(data_.get());
uint32_t v = value.value_or(0);
for (uint32_t *p = data_u32, *e = data_u32 + size_in_u32s; p < e; ++p) {
*p = v;
}
}
@ -59,12 +60,11 @@ class NoncopyableBuffer {
// the memcpy until mutable access is requested. "buf" is not owned by this
// data structure, so it is the user's duty to ensure the live range of "buf"
// is longer than this data structure.
NoncopyableBuffer(const uint8* buf, uint64 size) // Size is in uint8's.
: buf_(buf), size_(size / sizeof(uint32)) {
CHECK_EQ(size % sizeof(uint32), 0);
}
NoncopyableBuffer(const uint32* buf, uint64 size) // Size is in uint32's.
NoncopyableBuffer(const uint8_t* buf, uint64 size) // Size is in uint8's.
: buf_(buf), size_(size) {}
NoncopyableBuffer(const uint32_t* buf,
uint64 size_in_u32s) // Size is in uint32_t's.
: buf_(buf), size_(size_in_u32s * sizeof(uint32_t)) {}
NoncopyableBuffer(const NoncopyableBuffer&) = delete;
NoncopyableBuffer(NoncopyableBuffer&&) = default;
@ -74,20 +74,22 @@ class NoncopyableBuffer {
// Ensure that the buffer owns the data and returns a mutable view into the
// owned data for modification.
absl::Span<uint32> mutable_data() {
if (data_ == nullptr) {
data_.reset(new uint32[size_]);
memcpy(data_.get(), buf_, size_ * sizeof(uint32));
buf_ = data_.get();
}
return absl::Span<uint32>(data_.get(), size_);
template <typename T>
absl::Span<T> mutable_data() {
static_assert(std::is_arithmetic<T>::value, "Must be arithmetic type.");
EnsureDataOwned();
DCHECK_EQ(size_ % sizeof(T), 0);
return absl::Span<T>(reinterpret_cast<T*>(data_.get()), size_ / sizeof(T));
}
absl::Span<const uint32> const_data() const {
return absl::Span<const uint32>(absl::bit_cast<uint32*>(buf_), size_);
template <typename T>
absl::Span<const T> const_data() const {
static_assert(std::is_arithmetic<T>::value, "Must be arithmetic type.");
DCHECK_EQ(size_ % sizeof(T), 0);
return absl::Span<const T>(static_cast<const T*>(buf_), size_ / sizeof(T));
}
// Clone the content to a given buffer.
void CloneTo(void* buf) { memcpy(buf, buf_, size_ * sizeof(uint32)); }
void CloneTo(void* buf) { memcpy(buf, buf_, size_); }
// Return true if data is owned by this buffer (have been copied to `data_`).
bool owns_data() const { return data_ != nullptr; }
@ -95,15 +97,24 @@ class NoncopyableBuffer {
// Returns a copy of the object that owns its buffer.
NoncopyableBuffer Clone() const {
NoncopyableBuffer clone(size_);
memcpy(clone.data_.get(), buf_, size_ * sizeof(uint32));
memcpy(clone.data_.get(), buf_, size_);
return clone;
}
// Ensure that the buffer owns the data.
void EnsureDataOwned() {
if (data_ == nullptr) {
data_.reset(new uint8_t[size_]);
memcpy(data_.get(), buf_, size_);
buf_ = data_.get();
}
}
private:
// If data_ != nullptr then buf_ == data_.get()
std::unique_ptr<uint32[]> data_; // Owning data pointer.
const void* buf_; // Non-owning data pointer.
uint64 size_; // Size in number of uint32's.
std::unique_ptr<uint8_t[]> data_; // Owning data pointer.
const void* buf_; // Non-owning data pointer.
uint64 size_; // Size in number of bytes.
};
} // namespace tpu

View File

@ -114,9 +114,9 @@ Status TpuTransferManager::TransferBuffersToInfeed(
buffers_array.reserve(buffers.size());
for (int64_t i = 0; i < buffers.size(); ++i) {
buffers_array.push_back(
const_cast<unsigned int*>(buffers[i].const_data().data()));
buffers_size.push_back(buffers[i].const_data().size());
absl::Span<const uint32_t> span = buffers[i].const_data<uint32_t>();
buffers_array.push_back(const_cast<uint32_t*>(span.data()));
buffers_size.push_back(span.size());
}
tpu::ExecutorApiFn()->TpuTransferManager_TransferBuffersToInfeedFn(
@ -309,7 +309,8 @@ Status TpuTransferManager::LinearizeToBuffers(
for (int64_t i = 0; i < buffers_array_size; ++i) {
tpu::NoncopyableBuffer buf(buffers_size[i]);
memcpy(buf.mutable_data().data(), buffers_array[i], buffers_size[i]);
memcpy(buf.mutable_data<uint8_t>().data(), buffers_array[i],
buffers_size[i]);
buffers->push_back(std::move(buf));
}