Make tensorflow::tpu::NoncopyableBuffer allocate uint8 pointer rather than uint32 to allow unaligned U8/S8/U16/S16/BF16 arrays.
PiperOrigin-RevId: 341069760 Change-Id: I3df031ed970ec60a0e918768341b27eea1c954d7
This commit is contained in:
parent
e2286aba8e
commit
173e9bd169
@ -36,22 +36,23 @@ class NoncopyableBuffer {
|
|||||||
|
|
||||||
// Allocate an owning buffer without initializing the data. Useful when it
|
// Allocate an owning buffer without initializing the data. Useful when it
|
||||||
// will be filled by a subsequent function and you want to avoid initialization
|
// will be filled by a subsequent function and you want to avoid initialization
|
||||||
// cost. Size is specified in number of uint32's.
|
// cost. Size is specified in number of bytes.
|
||||||
explicit NoncopyableBuffer(size_t size)
|
explicit NoncopyableBuffer(size_t size)
|
||||||
: data_(new uint32[size]), buf_(data_.get()), size_(size) {}
|
: data_(new uint8_t[size]), buf_(data_.get()), size_(size) {}
|
||||||
|
|
||||||
// Allocates an owning buffer and initializes it with the specified data. Size
|
// Allocates an owning buffer and initializes it with the specified data. Size
|
||||||
// is specified in number of uint32's.
|
// is specified in number of uint32's.
|
||||||
NoncopyableBuffer(size_t size, absl::optional<uint32> value)
|
NoncopyableBuffer(size_t size_in_u32s, absl::optional<uint32_t> value)
|
||||||
: NoncopyableBuffer(size) {
|
: NoncopyableBuffer(size_in_u32s * sizeof(uint32_t)) {
|
||||||
#ifndef MEMORY_SANITIZER
|
#ifndef MEMORY_SANITIZER
|
||||||
if (!value.has_value()) {
|
if (!value.has_value()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
uint32 v = value.value_or(0);
|
uint32_t* data_u32 = reinterpret_cast<uint32_t*>(data_.get());
|
||||||
for (int64 i = 0; i < size; ++i) {
|
uint32_t v = value.value_or(0);
|
||||||
data_[i] = v;
|
for (uint32_t *p = data_u32, *e = data_u32 + size_in_u32s; p < e; ++p) {
|
||||||
|
*p = v;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -59,12 +60,11 @@ class NoncopyableBuffer {
|
|||||||
// the memcpy until mutable access is requested. "buf" is not owned by this
|
// the memcpy until mutable access is requested. "buf" is not owned by this
|
||||||
// data structure, so it is the user's duty to ensure the live range of "buf"
|
// data structure, so it is the user's duty to ensure the live range of "buf"
|
||||||
// is longer than this data structure.
|
// is longer than this data structure.
|
||||||
NoncopyableBuffer(const uint8* buf, uint64 size) // Size is in uint8's.
|
NoncopyableBuffer(const uint8_t* buf, uint64 size) // Size is in uint8's.
|
||||||
: buf_(buf), size_(size / sizeof(uint32)) {
|
|
||||||
CHECK_EQ(size % sizeof(uint32), 0);
|
|
||||||
}
|
|
||||||
NoncopyableBuffer(const uint32* buf, uint64 size) // Size is in uint32's.
|
|
||||||
: buf_(buf), size_(size) {}
|
: buf_(buf), size_(size) {}
|
||||||
|
NoncopyableBuffer(const uint32_t* buf,
|
||||||
|
uint64 size_in_u32s) // Size is in uint32_t's.
|
||||||
|
: buf_(buf), size_(size_in_u32s * sizeof(uint32_t)) {}
|
||||||
|
|
||||||
NoncopyableBuffer(const NoncopyableBuffer&) = delete;
|
NoncopyableBuffer(const NoncopyableBuffer&) = delete;
|
||||||
NoncopyableBuffer(NoncopyableBuffer&&) = default;
|
NoncopyableBuffer(NoncopyableBuffer&&) = default;
|
||||||
@ -74,20 +74,22 @@ class NoncopyableBuffer {
|
|||||||
|
|
||||||
// Ensures that the buffer owns the data and returns a mutable view into the
|
// Ensures that the buffer owns the data and returns a mutable view into the
|
||||||
// owned data for modification.
|
// owned data for modification.
|
||||||
absl::Span<uint32> mutable_data() {
|
template <typename T>
|
||||||
if (data_ == nullptr) {
|
absl::Span<T> mutable_data() {
|
||||||
data_.reset(new uint32[size_]);
|
static_assert(std::is_arithmetic<T>::value, "Must be arithmetic type.");
|
||||||
memcpy(data_.get(), buf_, size_ * sizeof(uint32));
|
EnsureDataOwned();
|
||||||
buf_ = data_.get();
|
DCHECK_EQ(size_ % sizeof(T), 0);
|
||||||
}
|
return absl::Span<T>(reinterpret_cast<T*>(data_.get()), size_ / sizeof(T));
|
||||||
return absl::Span<uint32>(data_.get(), size_);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Span<const uint32> const_data() const {
|
template <typename T>
|
||||||
return absl::Span<const uint32>(absl::bit_cast<uint32*>(buf_), size_);
|
absl::Span<const T> const_data() const {
|
||||||
|
static_assert(std::is_arithmetic<T>::value, "Must be arithmetic type.");
|
||||||
|
DCHECK_EQ(size_ % sizeof(T), 0);
|
||||||
|
return absl::Span<const T>(static_cast<const T*>(buf_), size_ / sizeof(T));
|
||||||
}
|
}
|
||||||
// Clone the content to a given buffer.
|
// Clone the content to a given buffer.
|
||||||
void CloneTo(void* buf) { memcpy(buf, buf_, size_ * sizeof(uint32)); }
|
void CloneTo(void* buf) { memcpy(buf, buf_, size_); }
|
||||||
|
|
||||||
// Returns true if data is owned by this buffer (has been copied to `data_`).
|
// Returns true if data is owned by this buffer (has been copied to `data_`).
|
||||||
bool owns_data() const { return data_ != nullptr; }
|
bool owns_data() const { return data_ != nullptr; }
|
||||||
@ -95,15 +97,24 @@ class NoncopyableBuffer {
|
|||||||
// Returns a copy of the object that owns its buffer.
|
// Returns a copy of the object that owns its buffer.
|
||||||
NoncopyableBuffer Clone() const {
|
NoncopyableBuffer Clone() const {
|
||||||
NoncopyableBuffer clone(size_);
|
NoncopyableBuffer clone(size_);
|
||||||
memcpy(clone.data_.get(), buf_, size_ * sizeof(uint32));
|
memcpy(clone.data_.get(), buf_, size_);
|
||||||
return clone;
|
return clone;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Ensure that the buffer owns the data.
|
||||||
|
void EnsureDataOwned() {
|
||||||
|
if (data_ == nullptr) {
|
||||||
|
data_.reset(new uint8_t[size_]);
|
||||||
|
memcpy(data_.get(), buf_, size_);
|
||||||
|
buf_ = data_.get();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// If data_ != nullptr then buf_ == data_.get()
|
// If data_ != nullptr then buf_ == data_.get()
|
||||||
std::unique_ptr<uint32[]> data_; // Owning data pointer.
|
std::unique_ptr<uint8_t[]> data_; // Owning data pointer.
|
||||||
const void* buf_; // Non-owning data pointer.
|
const void* buf_; // Non-owning data pointer.
|
||||||
uint64 size_; // Size in number of uint32's.
|
uint64 size_; // Size in number of bytes.
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace tpu
|
} // namespace tpu
|
||||||
|
@ -114,9 +114,9 @@ Status TpuTransferManager::TransferBuffersToInfeed(
|
|||||||
buffers_array.reserve(buffers.size());
|
buffers_array.reserve(buffers.size());
|
||||||
|
|
||||||
for (int64_t i = 0; i < buffers.size(); ++i) {
|
for (int64_t i = 0; i < buffers.size(); ++i) {
|
||||||
buffers_array.push_back(
|
absl::Span<const uint32_t> span = buffers[i].const_data<uint32_t>();
|
||||||
const_cast<unsigned int*>(buffers[i].const_data().data()));
|
buffers_array.push_back(const_cast<uint32_t*>(span.data()));
|
||||||
buffers_size.push_back(buffers[i].const_data().size());
|
buffers_size.push_back(span.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
tpu::ExecutorApiFn()->TpuTransferManager_TransferBuffersToInfeedFn(
|
tpu::ExecutorApiFn()->TpuTransferManager_TransferBuffersToInfeedFn(
|
||||||
@ -309,7 +309,8 @@ Status TpuTransferManager::LinearizeToBuffers(
|
|||||||
|
|
||||||
for (int64_t i = 0; i < buffers_array_size; ++i) {
|
for (int64_t i = 0; i < buffers_array_size; ++i) {
|
||||||
tpu::NoncopyableBuffer buf(buffers_size[i]);
|
tpu::NoncopyableBuffer buf(buffers_size[i]);
|
||||||
memcpy(buf.mutable_data().data(), buffers_array[i], buffers_size[i]);
|
memcpy(buf.mutable_data<uint8_t>().data(), buffers_array[i],
|
||||||
|
buffers_size[i]);
|
||||||
buffers->push_back(std::move(buf));
|
buffers->push_back(std::move(buf));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user