[TPU] Make NoncopyableBuffer::Clone take an optional alignment requirement.
PiperOrigin-RevId: 345895635 Change-Id: Idfe95565afb995a37bc279979ddbceace15021a4
This commit is contained in:
parent
1350613d6f
commit
f31efcb824
@ -64,7 +64,8 @@ cc_library(
|
||||
hdrs = ["noncopyable_buffer.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/platform:logging",
|
||||
"//tensorflow/core/platform:platform_port",
|
||||
"@com_google_absl//absl/base",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@com_google_absl//absl/types:span",
|
||||
|
@ -21,8 +21,8 @@ limitations under the License.
|
||||
#include "absl/base/casts.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/mem.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tpu {
|
||||
@ -38,7 +38,9 @@ class NoncopyableBuffer {
|
||||
// will be filled by a subsequent function and want to avoid initialization
|
||||
// cost. Size is specified in number of bytes.
|
||||
explicit NoncopyableBuffer(size_t size)
|
||||
: data_(new uint8_t[size]), buf_(data_.get()), size_(size) {}
|
||||
: data_(static_cast<uint8_t*>(malloc(size)), free),
|
||||
buf_(data_.get()),
|
||||
size_(size) {}
|
||||
|
||||
// Allocates an owning buffer and initializes it with the specified data. Size
|
||||
// is specified in number of uint32's.
|
||||
@ -60,10 +62,10 @@ class NoncopyableBuffer {
|
||||
// the memcpy until mutable access is requested. "buf" is not owned by this
|
||||
// data structure, so it is the user's duty to ensure the live range of "buf"
|
||||
// is longer than this data structure.
|
||||
NoncopyableBuffer(const uint8_t* buf, uint64 size) // Size is in uint8's.
|
||||
NoncopyableBuffer(const uint8_t* buf, size_t size) // Size is in uint8's.
|
||||
: buf_(buf), size_(size) {}
|
||||
NoncopyableBuffer(const uint32_t* buf,
|
||||
uint64 size_in_u32s) // Size is in uint32_t's.
|
||||
size_t size_in_u32s) // Size is in uint32_t's.
|
||||
: buf_(buf), size_(size_in_u32s * sizeof(uint32_t)) {}
|
||||
|
||||
NoncopyableBuffer(const NoncopyableBuffer&) = delete;
|
||||
@ -95,8 +97,10 @@ class NoncopyableBuffer {
|
||||
bool owns_data() const { return data_ != nullptr; }
|
||||
|
||||
// Returns a copy of the object that owns its buffer.
|
||||
NoncopyableBuffer Clone() const {
|
||||
NoncopyableBuffer clone(size_);
|
||||
NoncopyableBuffer Clone(size_t alignment = 1) const {
|
||||
auto clone = alignment <= 1
|
||||
? NoncopyableBuffer(size_)
|
||||
: NoncopyableBuffer(AlignedAlloc(size_, alignment), size_);
|
||||
memcpy(clone.data_.get(), buf_, size_);
|
||||
return clone;
|
||||
}
|
||||
@ -104,17 +108,26 @@ class NoncopyableBuffer {
|
||||
// Ensure that the buffer owns the data.
|
||||
void EnsureDataOwned() {
|
||||
if (data_ == nullptr) {
|
||||
data_.reset(new uint8_t[size_]);
|
||||
data_ = OwnedDataPtr(static_cast<uint8_t*>(malloc(size_)), free);
|
||||
memcpy(data_.get(), buf_, size_);
|
||||
buf_ = data_.get();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
using OwnedDataPtr = std::unique_ptr<uint8_t[], decltype(port::AlignedFree)*>;
|
||||
NoncopyableBuffer(OwnedDataPtr data, size_t size)
|
||||
: data_(std::move(data)), buf_(data_.get()), size_(size) {}
|
||||
|
||||
static OwnedDataPtr AlignedAlloc(size_t size, size_t alignment) {
|
||||
return OwnedDataPtr(
|
||||
static_cast<uint8_t*>(port::AlignedMalloc(size, alignment)),
|
||||
port::AlignedFree);
|
||||
}
|
||||
// If data_ != nullptr then buf_ == data_.get()
|
||||
std::unique_ptr<uint8_t[]> data_; // Owning data pointer.
|
||||
OwnedDataPtr data_ = {nullptr, free}; // Owning data pointer.
|
||||
const void* buf_; // Non-owning data pointer.
|
||||
uint64 size_; // Size in number of bytes.
|
||||
size_t size_; // Size in number of bytes.
|
||||
};
|
||||
|
||||
} // namespace tpu
|
||||
|
Loading…
Reference in New Issue
Block a user