[TPU] Make NoncopyableBuffer::Clone take an optional alignment requirement.

PiperOrigin-RevId: 345895635
Change-Id: Idfe95565afb995a37bc279979ddbceace15021a4
This commit is contained in:
Ce Zheng 2020-12-05 17:16:28 -08:00 committed by TensorFlower Gardener
parent 1350613d6f
commit f31efcb824
2 changed files with 26 additions and 12 deletions

View File

@ -64,7 +64,8 @@ cc_library(
hdrs = ["noncopyable_buffer.h"],
visibility = ["//visibility:public"],
deps = [
"//tensorflow/core:lib",
"//tensorflow/core/platform:logging",
"//tensorflow/core/platform:platform_port",
"@com_google_absl//absl/base",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",

View File

@ -21,8 +21,8 @@ limitations under the License.
#include "absl/base/casts.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mem.h"
namespace tensorflow {
namespace tpu {
@ -38,7 +38,9 @@ class NoncopyableBuffer {
// will be filled by a subsequent function and want to avoid initialization
// cost. Size is specified in number of bytes.
explicit NoncopyableBuffer(size_t size)
: data_(new uint8_t[size]), buf_(data_.get()), size_(size) {}
: data_(static_cast<uint8_t*>(malloc(size)), free),
buf_(data_.get()),
size_(size) {}
// Allocates an owning buffer and initializes it with the specified data. Size
// is specified in number of uint32's.
@ -60,10 +62,10 @@ class NoncopyableBuffer {
// the memcpy until mutable access is requested. "buf" is not owned by this
// data structure, so it is the user's duty to ensure the live range of "buf"
// is longer than this data structure.
NoncopyableBuffer(const uint8_t* buf, uint64 size) // Size is in uint8's.
NoncopyableBuffer(const uint8_t* buf, size_t size) // Size is in uint8's.
: buf_(buf), size_(size) {}
NoncopyableBuffer(const uint32_t* buf,
uint64 size_in_u32s) // Size is in uint32_t's.
size_t size_in_u32s) // Size is in uint32_t's.
: buf_(buf), size_(size_in_u32s * sizeof(uint32_t)) {}
NoncopyableBuffer(const NoncopyableBuffer&) = delete;
@ -95,8 +97,10 @@ class NoncopyableBuffer {
bool owns_data() const { return data_ != nullptr; }
// Returns a copy of the object that owns its buffer.
NoncopyableBuffer Clone() const {
NoncopyableBuffer clone(size_);
NoncopyableBuffer Clone(size_t alignment = 1) const {
auto clone = alignment <= 1
? NoncopyableBuffer(size_)
: NoncopyableBuffer(AlignedAlloc(size_, alignment), size_);
memcpy(clone.data_.get(), buf_, size_);
return clone;
}
@ -104,17 +108,26 @@ class NoncopyableBuffer {
// Ensure that the buffer owns the data.
void EnsureDataOwned() {
if (data_ == nullptr) {
data_.reset(new uint8_t[size_]);
data_ = OwnedDataPtr(static_cast<uint8_t*>(malloc(size_)), free);
memcpy(data_.get(), buf_, size_);
buf_ = data_.get();
}
}
private:
using OwnedDataPtr = std::unique_ptr<uint8_t[], decltype(port::AlignedFree)*>;
NoncopyableBuffer(OwnedDataPtr data, size_t size)
: data_(std::move(data)), buf_(data_.get()), size_(size) {}
static OwnedDataPtr AlignedAlloc(size_t size, size_t alignment) {
return OwnedDataPtr(
static_cast<uint8_t*>(port::AlignedMalloc(size, alignment)),
port::AlignedFree);
}
// If data_ != nullptr then buf_ == data_.get()
std::unique_ptr<uint8_t[]> data_; // Owning data pointer.
const void* buf_; // Non-owning data pointer.
uint64 size_; // Size in number of bytes.
OwnedDataPtr data_ = {nullptr, free}; // Owning data pointer.
const void* buf_; // Non-owning data pointer.
size_t size_; // Size in number of bytes.
};
} // namespace tpu