[TPU] Make NoncopyableBuffer::Clone take an optional alignment requirement.
PiperOrigin-RevId: 345895635
Change-Id: Idfe95565afb995a37bc279979ddbceace15021a4
This commit is contained in:
parent 1350613d6f
commit f31efcb824
@ -64,7 +64,8 @@ cc_library(
|
|||||||
hdrs = ["noncopyable_buffer.h"],
|
hdrs = ["noncopyable_buffer.h"],
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core/platform:logging",
|
||||||
|
"//tensorflow/core/platform:platform_port",
|
||||||
"@com_google_absl//absl/base",
|
"@com_google_absl//absl/base",
|
||||||
"@com_google_absl//absl/types:optional",
|
"@com_google_absl//absl/types:optional",
|
||||||
"@com_google_absl//absl/types:span",
|
"@com_google_absl//absl/types:span",
|
||||||
|
@ -21,8 +21,8 @@ limitations under the License.
|
|||||||
#include "absl/base/casts.h"
|
#include "absl/base/casts.h"
|
||||||
#include "absl/types/optional.h"
|
#include "absl/types/optional.h"
|
||||||
#include "absl/types/span.h"
|
#include "absl/types/span.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/mem.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace tpu {
|
namespace tpu {
|
||||||
@ -38,7 +38,9 @@ class NoncopyableBuffer {
|
|||||||
// will be filled by a subsequent function and want to avoid initialization
|
// will be filled by a subsequent function and want to avoid initialization
|
||||||
// cost. Size is specified in number of bytes.
|
// cost. Size is specified in number of bytes.
|
||||||
explicit NoncopyableBuffer(size_t size)
|
explicit NoncopyableBuffer(size_t size)
|
||||||
: data_(new uint8_t[size]), buf_(data_.get()), size_(size) {}
|
: data_(static_cast<uint8_t*>(malloc(size)), free),
|
||||||
|
buf_(data_.get()),
|
||||||
|
size_(size) {}
|
||||||
|
|
||||||
// Allocates an owning buffer and initializes it with the specified data. Size
|
// Allocates an owning buffer and initializes it with the specified data. Size
|
||||||
// is specified in number of uint32's.
|
// is specified in number of uint32's.
|
||||||
@ -60,10 +62,10 @@ class NoncopyableBuffer {
|
|||||||
// the memcpy until mutable access is requested. "buf" is not owned by this
|
// the memcpy until mutable access is requested. "buf" is not owned by this
|
||||||
// data structure, so it is the user's duty to ensure the live range of "buf"
|
// data structure, so it is the user's duty to ensure the live range of "buf"
|
||||||
// is longer than this data structure.
|
// is longer than this data structure.
|
||||||
NoncopyableBuffer(const uint8_t* buf, uint64 size) // Size is in uint8's.
|
NoncopyableBuffer(const uint8_t* buf, size_t size) // Size is in uint8's.
|
||||||
: buf_(buf), size_(size) {}
|
: buf_(buf), size_(size) {}
|
||||||
NoncopyableBuffer(const uint32_t* buf,
|
NoncopyableBuffer(const uint32_t* buf,
|
||||||
uint64 size_in_u32s) // Size is in uint32_t's.
|
size_t size_in_u32s) // Size is in uint32_t's.
|
||||||
: buf_(buf), size_(size_in_u32s * sizeof(uint32_t)) {}
|
: buf_(buf), size_(size_in_u32s * sizeof(uint32_t)) {}
|
||||||
|
|
||||||
NoncopyableBuffer(const NoncopyableBuffer&) = delete;
|
NoncopyableBuffer(const NoncopyableBuffer&) = delete;
|
||||||
@ -95,8 +97,10 @@ class NoncopyableBuffer {
|
|||||||
bool owns_data() const { return data_ != nullptr; }
|
bool owns_data() const { return data_ != nullptr; }
|
||||||
|
|
||||||
// Returns a copy of the object that owns its buffer.
|
// Returns a copy of the object that owns its buffer.
|
||||||
NoncopyableBuffer Clone() const {
|
NoncopyableBuffer Clone(size_t alignment = 1) const {
|
||||||
NoncopyableBuffer clone(size_);
|
auto clone = alignment <= 1
|
||||||
|
? NoncopyableBuffer(size_)
|
||||||
|
: NoncopyableBuffer(AlignedAlloc(size_, alignment), size_);
|
||||||
memcpy(clone.data_.get(), buf_, size_);
|
memcpy(clone.data_.get(), buf_, size_);
|
||||||
return clone;
|
return clone;
|
||||||
}
|
}
|
||||||
@ -104,17 +108,26 @@ class NoncopyableBuffer {
|
|||||||
// Ensure that the buffer owns the data.
|
// Ensure that the buffer owns the data.
|
||||||
void EnsureDataOwned() {
|
void EnsureDataOwned() {
|
||||||
if (data_ == nullptr) {
|
if (data_ == nullptr) {
|
||||||
data_.reset(new uint8_t[size_]);
|
data_ = OwnedDataPtr(static_cast<uint8_t*>(malloc(size_)), free);
|
||||||
memcpy(data_.get(), buf_, size_);
|
memcpy(data_.get(), buf_, size_);
|
||||||
buf_ = data_.get();
|
buf_ = data_.get();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
using OwnedDataPtr = std::unique_ptr<uint8_t[], decltype(port::AlignedFree)*>;
|
||||||
|
NoncopyableBuffer(OwnedDataPtr data, size_t size)
|
||||||
|
: data_(std::move(data)), buf_(data_.get()), size_(size) {}
|
||||||
|
|
||||||
|
static OwnedDataPtr AlignedAlloc(size_t size, size_t alignment) {
|
||||||
|
return OwnedDataPtr(
|
||||||
|
static_cast<uint8_t*>(port::AlignedMalloc(size, alignment)),
|
||||||
|
port::AlignedFree);
|
||||||
|
}
|
||||||
// If data_ != nullptr then buf_ == data_.get()
|
// If data_ != nullptr then buf_ == data_.get()
|
||||||
std::unique_ptr<uint8_t[]> data_; // Owning data pointer.
|
OwnedDataPtr data_ = {nullptr, free}; // Owning data pointer.
|
||||||
const void* buf_; // Non-owning data pointer.
|
const void* buf_; // Non-owning data pointer.
|
||||||
uint64 size_; // Size in number of bytes.
|
size_t size_; // Size in number of bytes.
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace tpu
|
} // namespace tpu
|
||||||
|
Loading…
Reference in New Issue
Block a user