Make tensorflow::tpu::NoncopyableBuffer allocate uint8 pointer rather than uint32 to allow unaligned U8/S8/U16/S16/BF16 arrays.
PiperOrigin-RevId: 341069760 Change-Id: I3df031ed970ec60a0e918768341b27eea1c954d7
This commit is contained in:
parent
e2286aba8e
commit
173e9bd169
@ -36,22 +36,23 @@ class NoncopyableBuffer {
|
||||
|
||||
// Allocate an owning buffer without initializing the data. Useful when it
|
||||
// will be filled by a subsequent function and want to avoid initialization
|
||||
// cost. Size is specified in number of uint32's.
|
||||
// cost. Size is specified in number of bytes.
|
||||
explicit NoncopyableBuffer(size_t size)
|
||||
: data_(new uint32[size]), buf_(data_.get()), size_(size) {}
|
||||
: data_(new uint8_t[size]), buf_(data_.get()), size_(size) {}
|
||||
|
||||
// Allocates an owning buffer and initializes it with the specified data. Size
|
||||
// is specified in number of uint32's.
|
||||
NoncopyableBuffer(size_t size, absl::optional<uint32> value)
|
||||
: NoncopyableBuffer(size) {
|
||||
NoncopyableBuffer(size_t size_in_u32s, absl::optional<uint32_t> value)
|
||||
: NoncopyableBuffer(size_in_u32s * sizeof(uint32_t)) {
|
||||
#ifndef MEMORY_SANITIZER
|
||||
if (!value.has_value()) {
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
uint32 v = value.value_or(0);
|
||||
for (int64 i = 0; i < size; ++i) {
|
||||
data_[i] = v;
|
||||
uint32_t* data_u32 = reinterpret_cast<uint32_t*>(data_.get());
|
||||
uint32_t v = value.value_or(0);
|
||||
for (uint32_t *p = data_u32, *e = data_u32 + size_in_u32s; p < e; ++p) {
|
||||
*p = v;
|
||||
}
|
||||
}
|
||||
|
||||
@ -59,12 +60,11 @@ class NoncopyableBuffer {
|
||||
// the memcpy until mutable access is requested. "buf" is not owned by this
|
||||
// data structure, so it is the user's duty to ensure the live range of "buf"
|
||||
// is longer than this data structure.
|
||||
NoncopyableBuffer(const uint8* buf, uint64 size) // Size is in uint8's.
|
||||
: buf_(buf), size_(size / sizeof(uint32)) {
|
||||
CHECK_EQ(size % sizeof(uint32), 0);
|
||||
}
|
||||
NoncopyableBuffer(const uint32* buf, uint64 size) // Size is in uint32's.
|
||||
NoncopyableBuffer(const uint8_t* buf, uint64 size) // Size is in uint8's.
|
||||
: buf_(buf), size_(size) {}
|
||||
NoncopyableBuffer(const uint32_t* buf,
|
||||
uint64 size_in_u32s) // Size is in uint32_t's.
|
||||
: buf_(buf), size_(size_in_u32s * sizeof(uint32_t)) {}
|
||||
|
||||
NoncopyableBuffer(const NoncopyableBuffer&) = delete;
|
||||
NoncopyableBuffer(NoncopyableBuffer&&) = default;
|
||||
@ -74,20 +74,22 @@ class NoncopyableBuffer {
|
||||
|
||||
// Ensure that the buffer owns the data and returns a mutable view into the
|
||||
// owned data for modification.
|
||||
absl::Span<uint32> mutable_data() {
|
||||
if (data_ == nullptr) {
|
||||
data_.reset(new uint32[size_]);
|
||||
memcpy(data_.get(), buf_, size_ * sizeof(uint32));
|
||||
buf_ = data_.get();
|
||||
}
|
||||
return absl::Span<uint32>(data_.get(), size_);
|
||||
template <typename T>
|
||||
absl::Span<T> mutable_data() {
|
||||
static_assert(std::is_arithmetic<T>::value, "Must be arithmetic type.");
|
||||
EnsureDataOwned();
|
||||
DCHECK_EQ(size_ % sizeof(T), 0);
|
||||
return absl::Span<T>(reinterpret_cast<T*>(data_.get()), size_ / sizeof(T));
|
||||
}
|
||||
|
||||
absl::Span<const uint32> const_data() const {
|
||||
return absl::Span<const uint32>(absl::bit_cast<uint32*>(buf_), size_);
|
||||
template <typename T>
|
||||
absl::Span<const T> const_data() const {
|
||||
static_assert(std::is_arithmetic<T>::value, "Must be arithmetic type.");
|
||||
DCHECK_EQ(size_ % sizeof(T), 0);
|
||||
return absl::Span<const T>(static_cast<const T*>(buf_), size_ / sizeof(T));
|
||||
}
|
||||
// Clone the content to a given buffer.
|
||||
void CloneTo(void* buf) { memcpy(buf, buf_, size_ * sizeof(uint32)); }
|
||||
void CloneTo(void* buf) { memcpy(buf, buf_, size_); }
|
||||
|
||||
// Return true if data is owned by this buffer (have been copied to `data_`).
|
||||
bool owns_data() const { return data_ != nullptr; }
|
||||
@ -95,15 +97,24 @@ class NoncopyableBuffer {
|
||||
// Returns a copy of the object that owns its buffer.
|
||||
NoncopyableBuffer Clone() const {
|
||||
NoncopyableBuffer clone(size_);
|
||||
memcpy(clone.data_.get(), buf_, size_ * sizeof(uint32));
|
||||
memcpy(clone.data_.get(), buf_, size_);
|
||||
return clone;
|
||||
}
|
||||
|
||||
// Ensure that the buffer owns the data.
|
||||
void EnsureDataOwned() {
|
||||
if (data_ == nullptr) {
|
||||
data_.reset(new uint8_t[size_]);
|
||||
memcpy(data_.get(), buf_, size_);
|
||||
buf_ = data_.get();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
// If data_ != nullptr then buf_ == data_.get()
|
||||
std::unique_ptr<uint32[]> data_; // Owning data pointer.
|
||||
const void* buf_; // Non-owning data pointer.
|
||||
uint64 size_; // Size in number of uint32's.
|
||||
std::unique_ptr<uint8_t[]> data_; // Owning data pointer.
|
||||
const void* buf_; // Non-owning data pointer.
|
||||
uint64 size_; // Size in number of bytes.
|
||||
};
|
||||
|
||||
} // namespace tpu
|
||||
|
@ -114,9 +114,9 @@ Status TpuTransferManager::TransferBuffersToInfeed(
|
||||
buffers_array.reserve(buffers.size());
|
||||
|
||||
for (int64_t i = 0; i < buffers.size(); ++i) {
|
||||
buffers_array.push_back(
|
||||
const_cast<unsigned int*>(buffers[i].const_data().data()));
|
||||
buffers_size.push_back(buffers[i].const_data().size());
|
||||
absl::Span<const uint32_t> span = buffers[i].const_data<uint32_t>();
|
||||
buffers_array.push_back(const_cast<uint32_t*>(span.data()));
|
||||
buffers_size.push_back(span.size());
|
||||
}
|
||||
|
||||
tpu::ExecutorApiFn()->TpuTransferManager_TransferBuffersToInfeedFn(
|
||||
@ -309,7 +309,8 @@ Status TpuTransferManager::LinearizeToBuffers(
|
||||
|
||||
for (int64_t i = 0; i < buffers_array_size; ++i) {
|
||||
tpu::NoncopyableBuffer buf(buffers_size[i]);
|
||||
memcpy(buf.mutable_data().data(), buffers_array[i], buffers_size[i]);
|
||||
memcpy(buf.mutable_data<uint8_t>().data(), buffers_array[i],
|
||||
buffers_size[i]);
|
||||
buffers->push_back(std::move(buf));
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user