STT-tensorflow/tensorflow/stream_executor/tpu/noncopyable_buffer.h
Ce Zheng f31efcb824 [TPU] Make NoncopyableBuffer::Clone take an optional alignment requirement.
PiperOrigin-RevId: 345895635
Change-Id: Idfe95565afb995a37bc279979ddbceace15021a4
2020-12-05 17:21:39 -08:00

137 lines
5.1 KiB
C++

/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_NONCOPYABLE_BUFFER_H_
#define TENSORFLOW_STREAM_EXECUTOR_TPU_NONCOPYABLE_BUFFER_H_
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <type_traits>
#include <utility>
#include "absl/base/casts.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mem.h"
namespace tensorflow {
namespace tpu {
// Move-only byte buffer with optional ownership of the underlying data.
//
// A buffer is either:
//   * owning     - `data_` holds the allocation and `buf_ == data_.get()`, or
//   * non-owning - `data_` is null and `buf_` points at caller-provided memory.
//
// For a non-owning buffer it is the caller's responsibility to keep the
// viewed memory alive for at least as long as the buffer is used.
//
// The move operations are defaulted: they transfer `data_` but copy `buf_`
// and `size_` verbatim, so a moved-from buffer that used to own its data still
// carries a pointer to storage that now belongs to the destination. Only
// destroy or reassign a moved-from buffer.
class NoncopyableBuffer {
 public:
  // Creates an empty, non-owning buffer: null data pointer and zero size.
  NoncopyableBuffer() = default;

  // Allocates an owning buffer of `size` bytes without initializing the
  // contents. Useful when the buffer will be filled by a subsequent function
  // and the cost of initialization should be avoided.
  explicit NoncopyableBuffer(size_t size)
      : data_(static_cast<uint8_t*>(malloc(size)), free),
        buf_(data_.get()),
        size_(size) {}

  // Allocates an owning buffer of `size_in_u32s` uint32 elements and, if
  // `value` is set, fills every element with it. When `value` is empty the
  // contents are left uninitialized, except under MEMORY_SANITIZER where the
  // buffer is zero-filled instead.
  NoncopyableBuffer(size_t size_in_u32s, absl::optional<uint32_t> value)
      : NoncopyableBuffer(size_in_u32s * sizeof(uint32_t)) {
#ifndef MEMORY_SANITIZER
    if (!value.has_value()) {
      return;
    }
#endif
    uint32_t* const first = reinterpret_cast<uint32_t*>(data_.get());
    const uint32_t fill = value.value_or(0);
    for (uint32_t *p = first, *e = first + size_in_u32s; p < e; ++p) {
      *p = fill;
    }
  }

  // Wraps `buf` without copying it into owned storage; the copy is deferred
  // until mutable access is requested (see EnsureDataOwned()). `buf` is not
  // owned by this object, so the caller must make sure it outlives the buffer.
  // `size` is in bytes.
  NoncopyableBuffer(const uint8_t* buf, size_t size)
      : buf_(buf), size_(size) {}

  // Same as above for uint32 data; `size_in_u32s` is the element count.
  NoncopyableBuffer(const uint32_t* buf, size_t size_in_u32s)
      : buf_(buf), size_(size_in_u32s * sizeof(uint32_t)) {}

  NoncopyableBuffer(const NoncopyableBuffer&) = delete;
  NoncopyableBuffer(NoncopyableBuffer&&) = default;
  NoncopyableBuffer& operator=(const NoncopyableBuffer&) = delete;
  NoncopyableBuffer& operator=(NoncopyableBuffer&&) = default;

  // Ensures the buffer owns its data (copying it if necessary) and returns a
  // mutable view of the owned bytes reinterpreted as elements of type T.
  // The byte size must be a multiple of sizeof(T) (DCHECK-ed).
  template <typename T>
  absl::Span<T> mutable_data() {
    static_assert(std::is_arithmetic<T>::value, "Must be arithmetic type.");
    EnsureDataOwned();
    DCHECK_EQ(size_ % sizeof(T), 0);
    return absl::Span<T>(reinterpret_cast<T*>(data_.get()), size_ / sizeof(T));
  }

  // Returns a read-only view of the bytes reinterpreted as elements of type T.
  // Never copies; works for both owning and non-owning buffers. The byte size
  // must be a multiple of sizeof(T) (DCHECK-ed).
  template <typename T>
  absl::Span<const T> const_data() const {
    static_assert(std::is_arithmetic<T>::value, "Must be arithmetic type.");
    DCHECK_EQ(size_ % sizeof(T), 0);
    return absl::Span<const T>(static_cast<const T*>(buf_), size_ / sizeof(T));
  }

  // Copies the buffer's bytes into `buf`, which must have room for them.
  // Does nothing for an empty buffer.
  void CloneTo(void* buf) const {
    // memcpy requires valid pointers even for a zero-length copy, so skip the
    // call entirely when there is nothing to copy.
    if (size_ > 0) {
      memcpy(buf, buf_, size_);
    }
  }

  // Returns true if the data is owned by this buffer (i.e. it has been
  // allocated by, or copied into, `data_`).
  bool owns_data() const { return data_ != nullptr; }

  // Returns a copy of this buffer that owns its storage. When `alignment` is
  // greater than 1 the copy is allocated with port::AlignedMalloc using that
  // alignment; otherwise plain malloc is used.
  NoncopyableBuffer Clone(size_t alignment = 1) const {
    auto clone = alignment <= 1
                     ? NoncopyableBuffer(size_)
                     : NoncopyableBuffer(AlignedAlloc(size_, alignment), size_);
    // A zero-byte allocation may legitimately be null, and `buf_` may be null
    // for an empty buffer; avoid handing null pointers to memcpy.
    if (size_ > 0) {
      memcpy(clone.data_.get(), buf_, size_);
    }
    return clone;
  }

  // Ensures that the buffer owns the data: if it is currently a non-owning
  // view, allocates storage with malloc, copies the viewed bytes into it and
  // repoints `buf_` at the owned copy. No-op if the data is already owned.
  void EnsureDataOwned() {
    if (data_ != nullptr) {
      return;
    }
    data_ = OwnedDataPtr(static_cast<uint8_t*>(malloc(size_)), free);
    if (size_ > 0) {
      memcpy(data_.get(), buf_, size_);
    }
    buf_ = data_.get();
  }

 private:
  // Owning pointer whose deleter is chosen at run time: `free` for malloc-ed
  // storage, port::AlignedFree for storage from port::AlignedMalloc.
  using OwnedDataPtr = std::unique_ptr<uint8_t[], decltype(port::AlignedFree)*>;

  // Adopts an already allocated block of `size` bytes.
  NoncopyableBuffer(OwnedDataPtr data, size_t size)
      : data_(std::move(data)), buf_(data_.get()), size_(size) {}

  // Allocates `size` bytes with the requested alignment and pairs the
  // allocation with the matching deallocation function.
  static OwnedDataPtr AlignedAlloc(size_t size, size_t alignment) {
    return OwnedDataPtr(
        static_cast<uint8_t*>(port::AlignedMalloc(size, alignment)),
        port::AlignedFree);
  }

  // If data_ != nullptr then buf_ == data_.get().
  // NOTE: `data_` must stay declared before `buf_`; constructors initialize
  // `buf_` from `data_.get()`.
  OwnedDataPtr data_ = {nullptr, free};  // Owning data pointer.
  const void* buf_ = nullptr;            // Non-owning data pointer.
  size_t size_ = 0;                      // Size in number of bytes.
};
} // namespace tpu
} // namespace tensorflow
#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_NONCOPYABLE_BUFFER_H_