[TPU] Make NoncopyableBuffer::Clone take an optional alignment requirement.

PiperOrigin-RevId: 345895635
Change-Id: Idfe95565afb995a37bc279979ddbceace15021a4
This commit is contained in:
Ce Zheng 2020-12-05 17:16:28 -08:00 committed by TensorFlower Gardener
parent 1350613d6f
commit f31efcb824
2 changed files with 26 additions and 12 deletions

View File

@@ -64,7 +64,8 @@ cc_library(
hdrs = ["noncopyable_buffer.h"], hdrs = ["noncopyable_buffer.h"],
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
"//tensorflow/core:lib", "//tensorflow/core/platform:logging",
"//tensorflow/core/platform:platform_port",
"@com_google_absl//absl/base", "@com_google_absl//absl/base",
"@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span", "@com_google_absl//absl/types:span",

View File

@@ -21,8 +21,8 @@ limitations under the License.
#include "absl/base/casts.h" #include "absl/base/casts.h"
#include "absl/types/optional.h" #include "absl/types/optional.h"
#include "absl/types/span.h" #include "absl/types/span.h"
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h" #include "tensorflow/core/platform/mem.h"
namespace tensorflow { namespace tensorflow {
namespace tpu { namespace tpu {
@@ -38,7 +38,9 @@ class NoncopyableBuffer {
// will be filled by a subsequent function and want to avoid initialization // will be filled by a subsequent function and want to avoid initialization
// cost. Size is specified in number of bytes. // cost. Size is specified in number of bytes.
explicit NoncopyableBuffer(size_t size) explicit NoncopyableBuffer(size_t size)
: data_(new uint8_t[size]), buf_(data_.get()), size_(size) {} : data_(static_cast<uint8_t*>(malloc(size)), free),
buf_(data_.get()),
size_(size) {}
// Allocates an owning buffer and initializes it with the specified data. Size // Allocates an owning buffer and initializes it with the specified data. Size
// is specified in number of uint32's. // is specified in number of uint32's.
@@ -60,10 +62,10 @@ class NoncopyableBuffer {
// the memcpy until mutable access is requested. "buf" is not owned by this // the memcpy until mutable access is requested. "buf" is not owned by this
// data structure, so it is the user's duty to ensure the live range of "buf" // data structure, so it is the user's duty to ensure the live range of "buf"
// is longer than this data structure. // is longer than this data structure.
NoncopyableBuffer(const uint8_t* buf, uint64 size) // Size is in uint8's. NoncopyableBuffer(const uint8_t* buf, size_t size) // Size is in uint8's.
: buf_(buf), size_(size) {} : buf_(buf), size_(size) {}
NoncopyableBuffer(const uint32_t* buf, NoncopyableBuffer(const uint32_t* buf,
uint64 size_in_u32s) // Size is in uint32_t's. size_t size_in_u32s) // Size is in uint32_t's.
: buf_(buf), size_(size_in_u32s * sizeof(uint32_t)) {} : buf_(buf), size_(size_in_u32s * sizeof(uint32_t)) {}
NoncopyableBuffer(const NoncopyableBuffer&) = delete; NoncopyableBuffer(const NoncopyableBuffer&) = delete;
@@ -95,8 +97,10 @@ class NoncopyableBuffer {
bool owns_data() const { return data_ != nullptr; } bool owns_data() const { return data_ != nullptr; }
// Returns a copy of the object that owns its buffer. // Returns a copy of the object that owns its buffer.
NoncopyableBuffer Clone() const { NoncopyableBuffer Clone(size_t alignment = 1) const {
NoncopyableBuffer clone(size_); auto clone = alignment <= 1
? NoncopyableBuffer(size_)
: NoncopyableBuffer(AlignedAlloc(size_, alignment), size_);
memcpy(clone.data_.get(), buf_, size_); memcpy(clone.data_.get(), buf_, size_);
return clone; return clone;
} }
@@ -104,17 +108,26 @@ class NoncopyableBuffer {
// Ensure that the buffer owns the data. // Ensure that the buffer owns the data.
void EnsureDataOwned() { void EnsureDataOwned() {
if (data_ == nullptr) { if (data_ == nullptr) {
data_.reset(new uint8_t[size_]); data_ = OwnedDataPtr(static_cast<uint8_t*>(malloc(size_)), free);
memcpy(data_.get(), buf_, size_); memcpy(data_.get(), buf_, size_);
buf_ = data_.get(); buf_ = data_.get();
} }
} }
private: private:
using OwnedDataPtr = std::unique_ptr<uint8_t[], decltype(port::AlignedFree)*>;
NoncopyableBuffer(OwnedDataPtr data, size_t size)
: data_(std::move(data)), buf_(data_.get()), size_(size) {}
static OwnedDataPtr AlignedAlloc(size_t size, size_t alignment) {
return OwnedDataPtr(
static_cast<uint8_t*>(port::AlignedMalloc(size, alignment)),
port::AlignedFree);
}
// If data_ != nullptr then buf_ == data_.get() // If data_ != nullptr then buf_ == data_.get()
std::unique_ptr<uint8_t[]> data_; // Owning data pointer. OwnedDataPtr data_ = {nullptr, free}; // Owning data pointer.
const void* buf_; // Non-owning data pointer. const void* buf_; // Non-owning data pointer.
uint64 size_; // Size in number of bytes. size_t size_; // Size in number of bytes.
}; };
} // namespace tpu } // namespace tpu