[TF] Update Variant object to store small objects inline.

PiperOrigin-RevId: 247047534
This commit is contained in:
Eugene Brevdo 2019-05-07 10:40:34 -07:00 committed by TensorFlower Gardener
parent b17b0103a9
commit 2f345d145e
8 changed files with 754 additions and 69 deletions

View File

@ -50,6 +50,14 @@ class DatasetVariantWrapper {
if (dataset_) dataset_->Ref(); if (dataset_) dataset_->Ref();
} }
// Move assignment. Guards against self-assignment, then swaps the owned
// dataset pointer with `other`; the previously held dataset_ is released
// when `other` is destroyed (its destructor Unrefs whatever it now holds).
DatasetVariantWrapper& operator=(DatasetVariantWrapper&& other) {
if (&other == this) return *this;
std::swap(dataset_, other.dataset_);
return *this;
}
DatasetVariantWrapper& operator=(const DatasetVariantWrapper& other) = delete;
~DatasetVariantWrapper() { ~DatasetVariantWrapper() {
if (dataset_) dataset_->Unref(); if (dataset_) dataset_->Unref();
} }
@ -75,7 +83,7 @@ class DatasetVariantWrapper {
} }
private: private:
DatasetBase* const dataset_; // Owns one reference. DatasetBase* dataset_; // Owns one reference.
}; };
const char kWrappedDatasetVariantTypeName[] = const char kWrappedDatasetVariantTypeName[] =

View File

@ -43,6 +43,7 @@ class OpKernelContext;
class Tensor; class Tensor;
class TensorBuffer; class TensorBuffer;
class TensorCApi; class TensorCApi;
class TensorCord;
class TensorDescription; class TensorDescription;
class TensorProto; class TensorProto;
class Var; class Var;
@ -607,6 +608,7 @@ class Tensor {
friend class DMAHelper; friend class DMAHelper;
friend class TensorCApi; friend class TensorCApi;
friend class TensorCord; // For access to buf_
friend class TensorReference; // For access to buf_ friend class TensorReference; // For access to buf_
friend class VariableOp; // For access to set_shape friend class VariableOp; // For access to set_shape
friend class AutoReloadVariableOp; // For access to set_shape friend class AutoReloadVariableOp; // For access to set_shape

View File

@ -23,9 +23,11 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
// Out-of-line destructor: clear() destroys any stored value (inline or
// heap) and zeroes the internal union.
Variant::~Variant() { clear(); }
bool Variant::Decode(VariantTensorData data) { bool Variant::Decode(VariantTensorData data) {
if (!is_empty()) { if (!is_empty()) {
return value_->Decode(std::move(data)); return GetValue()->Decode(std::move(data));
} }
return true; return true;
} }
@ -35,7 +37,7 @@ void* Variant::get() {
if (is_empty()) { if (is_empty()) {
return nullptr; return nullptr;
} }
return value_->RawPtr(); return GetValue()->RawPtr();
} }
template <> template <>
@ -43,7 +45,7 @@ const void* Variant::get() const {
if (is_empty()) { if (is_empty()) {
return nullptr; return nullptr;
} }
return value_->RawPtr(); return GetValue()->RawPtr();
} }
template <> template <>

View File

@ -23,6 +23,7 @@ limitations under the License.
#include <unordered_map> #include <unordered_map>
#include <utility> #include <utility>
#include "absl/memory/memory.h"
#include "tensorflow/core/framework/type_index.h" #include "tensorflow/core/framework/type_index.h"
#include "tensorflow/core/framework/variant_tensor_data.h" #include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status.h"
@ -68,7 +69,7 @@ void EncodeVariant(const T& value, string* buf);
// //
// string TypeName() const; // string TypeName() const;
// void Encode(VariantTensorData* data) const; // void Encode(VariantTensorData* data) const;
// void Decode(VariantTensorData data); // bool Decode(VariantTensorData data);
// //
// Simple POD types can elide the Encode/Decode functions, they are provided by // Simple POD types can elide the Encode/Decode functions, they are provided by
// helper methods. // helper methods.
@ -149,39 +150,57 @@ void EncodeVariant(const T& value, string* buf);
// //
class Variant { class Variant {
public: public:
constexpr Variant() noexcept = default; Variant() noexcept : is_inline_(false) {}
Variant(const Variant& other) ~Variant();
: value_(other.is_empty() ? std::unique_ptr<ValueInterface>()
: other.value_->Clone()) {}
Variant(Variant&& other) noexcept = default; Variant(const Variant& other);
Variant(Variant&& other) noexcept;
// Make sure that the type is CopyConstructible and not a
// tensorflow::Variant object itself. We want the copy constructor to be
// chosen for the tensorflow::Variant case.
template <typename T, typename VT = typename std::decay<T>::type,
typename std::enable_if<!std::is_same<Variant, VT>::value &&
std::is_move_constructible<VT>::value,
void>::type* = nullptr>
Variant(T&& value);
// Make sure that the type is CopyConstructible and not a tensorflow::Variant
// object itself. We want the copy constructor to be chosen for the
// tensorflow::Variant case.
template <typename T, typename VT = typename std::decay<T>::type, template <typename T, typename VT = typename std::decay<T>::type,
typename std::enable_if<!std::is_same<Variant, VT>::value && typename std::enable_if<!std::is_same<Variant, VT>::value &&
std::is_copy_constructible<VT>::value, std::is_copy_constructible<VT>::value,
void>::type* = nullptr> void>::type* = nullptr>
Variant(T&& value) // NOLINT Variant(const T& value);
: value_(new Value<VT>(in_place, std::forward<T>(value))) {}
template <typename T, typename VT = typename std::decay<T>::type,
typename std::enable_if<!std::is_same<Variant, VT>::value &&
std::is_copy_constructible<VT>::value,
void>::type* = nullptr>
Variant& operator=(const T& value);
template <typename T, typename VT = typename std::decay<T>::type,
typename std::enable_if<!std::is_same<Variant, VT>::value &&
std::is_move_constructible<VT>::value,
void>::type* = nullptr>
Variant& operator=(T&& value);
Variant& operator=(const Variant& rhs) { Variant& operator=(const Variant& rhs) {
if (&rhs == this) return *this;
Variant(rhs).swap(*this); Variant(rhs).swap(*this);
return *this; return *this;
} }
Variant& operator=(Variant&& rhs) noexcept { Variant& operator=(Variant&& rhs) noexcept {
if (&rhs == this) return *this;
Variant(std::move(rhs)).swap(*this); Variant(std::move(rhs)).swap(*this);
return *this; return *this;
} }
bool is_empty() const { return value_ == nullptr; } bool is_empty() const { return GetValue() == nullptr; }
void clear() noexcept { value_.reset(); } void clear() noexcept;
void swap(Variant& other) noexcept { value_.swap(other.value_); } void swap(Variant& other) noexcept;
// Note, unlike TypeName(), TypeId() does not return the TypeIndex // Note, unlike TypeName(), TypeId() does not return the TypeIndex
// of the original type when a TensorValueDataProto is stored as the // of the original type when a TensorValueDataProto is stored as the
@ -191,12 +210,13 @@ class Variant {
if (is_empty()) { if (is_empty()) {
return VoidTypeIndex; return VoidTypeIndex;
} }
return value_->TypeId(); return GetValue()->TypeId();
} }
string DebugString() const { string DebugString() const {
return strings::StrCat("Variant<type: ", TypeName(), return strings::StrCat(
" value: ", value_->DebugString(), ">"); "Variant<type: ", TypeName(),
" value: ", is_empty() ? "[empty]" : GetValue()->DebugString(), ">");
} }
// Returns a pointer to the stored value if it is type T, or nullptr // Returns a pointer to the stored value if it is type T, or nullptr
@ -205,7 +225,7 @@ class Variant {
T* get() { T* get() {
const TypeIndex TTypeIndex = MakeTypeIndex<T>(); const TypeIndex TTypeIndex = MakeTypeIndex<T>();
if (is_empty() || (TTypeIndex != TypeId())) return nullptr; if (is_empty() || (TTypeIndex != TypeId())) return nullptr;
return std::addressof(static_cast<Variant::Value<T>*>(value_.get())->value); return std::addressof(static_cast<Variant::Value<T>*>(GetValue())->value);
} }
// Returns a pointer to the stored value if it is type T, or nullptr // Returns a pointer to the stored value if it is type T, or nullptr
@ -215,7 +235,7 @@ class Variant {
const TypeIndex TTypeIndex = MakeTypeIndex<T>(); const TypeIndex TTypeIndex = MakeTypeIndex<T>();
if (is_empty() || (TTypeIndex != TypeId())) return nullptr; if (is_empty() || (TTypeIndex != TypeId())) return nullptr;
return std::addressof( return std::addressof(
static_cast<const Variant::Value<T>*>(value_.get())->value); static_cast<const Variant::Value<T>*>(GetValue())->value);
} }
// Returns TypeNameVariant(value). // Returns TypeNameVariant(value).
@ -227,13 +247,13 @@ class Variant {
if (is_empty()) { if (is_empty()) {
return ""; return "";
} }
return value_->TypeName(); return GetValue()->TypeName();
} }
// Serialize the contents of the stored object into `data`. // Serialize the contents of the stored object into `data`.
void Encode(VariantTensorData* data) const { void Encode(VariantTensorData* data) const {
if (!is_empty()) { if (!is_empty()) {
value_->Encode(data); GetValue()->Encode(data);
} }
} }
@ -243,26 +263,36 @@ class Variant {
// Helper methods to directly serialize/deserialize from strings. // Helper methods to directly serialize/deserialize from strings.
void Encode(string* buf) const { void Encode(string* buf) const {
if (!is_empty()) { if (!is_empty()) {
value_->Encode(buf); GetValue()->Encode(buf);
} }
} }
bool Decode(string buf) { bool Decode(string buf) {
if (!is_empty()) { if (!is_empty()) {
return value_->Decode(std::move(buf)); return GetValue()->Decode(std::move(buf));
} }
return true; return true;
} }
// True iff Value<VT> fits the inline buffer in both size and alignment.
// Such values are stored inside the Variant object itself; all others are
// heap-allocated.
template <typename VT>
static constexpr bool CanInlineType() {
return ((sizeof(Value<VT>) <= InlineValue::kMaxValueSize) &&
(alignof(Value<VT>) <= kMaxInlineValueAlignSize));
}
private: private:
struct in_place_t {}; struct in_place_t {};
static constexpr in_place_t in_place{}; static constexpr in_place_t kInPlace{};
struct ValueInterface { struct ValueInterface {
virtual ~ValueInterface() = default; virtual ~ValueInterface() = default;
virtual TypeIndex TypeId() const = 0; virtual TypeIndex TypeId() const = 0;
virtual void* RawPtr() = 0; virtual void* RawPtr() = 0;
virtual const void* RawPtr() const = 0; virtual const void* RawPtr() const = 0;
virtual std::unique_ptr<ValueInterface> Clone() const = 0; virtual ValueInterface* Clone() const = 0;
virtual void CloneInto(ValueInterface* memory) const = 0;
virtual void DefaultConstructIn(ValueInterface* memory) const = 0;
virtual void Swap(ValueInterface* memory) = 0;
virtual void MoveTo(ValueInterface* memory) = 0;
virtual string TypeName() const = 0; virtual string TypeName() const = 0;
virtual string DebugString() const = 0; virtual string DebugString() const = 0;
virtual void Encode(VariantTensorData* data) const = 0; virtual void Encode(VariantTensorData* data) const = 0;
@ -277,6 +307,10 @@ class Variant {
explicit Value(in_place_t /*tag*/, Args&&... args) explicit Value(in_place_t /*tag*/, Args&&... args)
: value(std::forward<Args>(args)...) {} : value(std::forward<Args>(args)...) {}
// NOTE(ebrevdo): Destructor must be explicitly defined for CUDA to happily
// build `alignof(Variant<void*>)`.
~Value() final = default;
TypeIndex TypeId() const override { TypeIndex TypeId() const override {
const TypeIndex value_type_index = const TypeIndex value_type_index =
MakeTypeIndex<typename std::decay<T>::type>(); MakeTypeIndex<typename std::decay<T>::type>();
@ -287,8 +321,33 @@ class Variant {
const void* RawPtr() const override { return &value; } const void* RawPtr() const override { return &value; }
std::unique_ptr<ValueInterface> Clone() const override { ValueInterface* Clone() const override {
return std::unique_ptr<ValueInterface>(new Value(in_place, value)); // NOTE: Use placement new here because we override `operator delete`,
// and need to match the call to `port::Free()` with a call to
// `port::Malloc()`.
auto* clone = static_cast<Value*>(port::Malloc(sizeof(Value)));
new (clone) Value(kInPlace, value);
return clone;
}
void DefaultConstructIn(ValueInterface* memory) const override {
new (memory) Value(kInPlace, T());
}
void MoveTo(ValueInterface* memory) override {
CHECK(TypeId() == memory->TypeId())
<< TypeId().name() << " vs. " << memory->TypeId().name();
static_cast<Value*>(memory)->value = std::move(value);
}
void CloneInto(ValueInterface* memory) const override {
new (memory) Value(kInPlace, value);
}
void Swap(ValueInterface* memory) override {
CHECK(TypeId() == memory->TypeId())
<< TypeId().name() << " vs. " << memory->TypeId().name();
std::swap(value, static_cast<Value*>(memory)->value);
} }
string TypeName() const override { return TypeNameVariant(value); } string TypeName() const override { return TypeNameVariant(value); }
@ -307,14 +366,368 @@ class Variant {
bool Decode(string buf) override { return DecodeVariant(&buf, &value); } bool Decode(string buf) override { return DecodeVariant(&buf, &value); }
// We override operator delete in order to selectively free memory
// depending on if Value<VT> is stored inline or on the heap:
//
// Value<VT> is stored inline if its size <= InlineValue::kMaxValueSize and
// its alignment <= kMaxInlineValueAlignSize. This check is performed by
// CanInlineType<VT>().
//
// We only need to call its destructor in this case and then overwrite
// the inline memory with zeros. Variant::clear() does this.
// Thus, in the inline case, the delete operator does nothing (calling
// delete on the memory location calls the destructor only).
//
// If !CanInlineType<VT>(), then it is stored as a pointer inside HeapValue.
// The memory buffer it resides in on the heap was allocated with
// port::Malloc, and it should be deallocated via port::Free.
//
// operator delete is stored in the vtable since ~ValueInterface is a
// virtual destructor; furthermore it has access to VT and can calculate
// CanInlineType<VT>().
static void operator delete(void* ptr);
static void operator delete(void*, void*) {
// Some compilers require an overridden class-specific deallocation
// function, which will be called if placement `new` throws an
// exception.
}
T value; T value;
}; };
static constexpr int kMaxInlineValueAlignSize = alignof(Value<void*>);
using HeapValue = std::unique_ptr<ValueInterface>;
// Fixed-size, suitably aligned buffer that stores a Value<VT> object in
// place (no heap allocation) for types accepted by CanInlineType<VT>().
// INVARIANT: an all-zero byte representation must behave as "empty" so
// that HeapOrInline::ResetMemory() interacts safely with the destructor.
struct InlineValue {
// We try to size InlineValue so that sizeof(Variant) <= 64 and it can fit
// into the aligned space of a TensorBuffer.
static constexpr int kMaxValueSize = (64 - /*some extra padding=*/16);
typedef char ValueDataArray[kMaxValueSize];
alignas(kMaxInlineValueAlignSize) ValueDataArray value_data;
bool has_value = false;
explicit InlineValue() {}
// Copy: clone the wrapped object directly into this buffer via the
// virtual CloneInto().
InlineValue(const InlineValue& other) noexcept
: has_value(other.has_value) {
if (other.has_value) {
other.AsValueInterface()->CloneInto(AsValueInterface());
}
}
// Move: default-construct a value of the same dynamic type in this
// buffer, move the payload into it, then destroy `other`'s copy.
InlineValue(InlineValue&& other) noexcept : has_value(other.has_value) {
if (other.has_value) {
other.AsValueInterface()->DefaultConstructIn(AsValueInterface());
other.AsValueInterface()->MoveTo(AsValueInterface());
other.Cleanup();
}
}
void Cleanup() {
// **NOTE** This must be a no-op if the memory representation of
// InlineValue is all zeros, in order to properly interact with
// HeapOrInline::ResetMemory().
if (has_value) {
// This doesn't actually delete anything on the heap; the delete
// operator of Value<VT> is overridden to do nothing for inline
// values; the side-effect of delete is that the virtual destructor is
// called.
//
// We leave it to callers to overwrite the data buffer in value_data
// with new objects.
delete AsValueInterface();
}
has_value = false;
}
// Copy assignment: destroy the current payload, then clone other's.
InlineValue& operator=(const InlineValue& other) {
if (&other == this) return *this;
Cleanup();
if (other.has_value) {
other.AsValueInterface()->CloneInto(AsValueInterface());
}
has_value = other.has_value;
return *this;
}
InlineValue& operator=(InlineValue&& other) {
if (&other == this) return *this;
if (other.has_value) {
// Same dynamic type on both sides: exchange payloads in place.
// NOTE: after Swap(), `other` keeps this object's old payload.
if (has_value && AsValueInterface()->TypeId() ==
other.AsValueInterface()->TypeId()) {
other.AsValueInterface()->Swap(AsValueInterface());
} else {
if (has_value) {
// Reaching here with has_value implies the types differ (the
// equal-type case was handled above), so the inner TypeId check
// always destroys and re-constructs with other's type.
if (AsValueInterface()->TypeId() !=
other.AsValueInterface()->TypeId()) {
Cleanup();
other.AsValueInterface()->DefaultConstructIn(AsValueInterface());
}
other.AsValueInterface()->MoveTo(AsValueInterface());
} else {
other.AsValueInterface()->DefaultConstructIn(AsValueInterface());
other.AsValueInterface()->MoveTo(AsValueInterface());
}
other.Cleanup();
has_value = true;
}
} else {
Cleanup();
}
return *this;
}
// Reinterprets the raw buffer as the polymorphic value stored in it.
ValueInterface* AsValueInterface() {
return reinterpret_cast<ValueInterface*>(value_data);
}
const ValueInterface* AsValueInterface() const {
return reinterpret_cast<const ValueInterface*>(value_data);
}
// **WARNING** This must be a no-op when the byte-representation of
// InlineValue is all zeros.
~InlineValue() { Cleanup(); }
};
// value_ can point to any type T as wrapped by a ValueInterface. // value_ can point to any type T as wrapped by a ValueInterface.
// The only real requirement is that T is default-constructible. // The only real requirement is that T is default-constructible.
std::unique_ptr<ValueInterface> value_; union HeapOrInline {
HeapOrInline() { ResetMemory(); }
explicit HeapOrInline(HeapValue&& v) : heap_value(std::move(v)) {}
explicit HeapOrInline(InlineValue&& v) : inline_value(std::move(v)) {}
~HeapOrInline() {} // Taken care of by owner.
// This must be called when modifying which element of HeapOrInline is
// being used, because the destructor of the new class may be called
// while the memory is still a representation of the old class.
// **WARNING** This code assumes that the destructors of HeapValue and
// InlineValue are no-ops when the internal representation is zeros.
//
// Example of when this is needed:
// value.heap_value = HeapValue(...);
// // Segfault. This calls InlineValue::Cleanup on value.inline_value
// // but the internal memory representation is that of HeapValue.
// value.inline_value = InlineValue();
//
// The correct way to do this:
// value.heap_value = HeapValue(...);
// value.ResetMemory();
// value.inline_value = InlineValue();
void ResetMemory();
HeapValue heap_value;
InlineValue inline_value;
} value_;
bool is_inline_;
bool IsInlineValue() const { return is_inline_; }
// Returns the stored value's vtable interface, dispatching between the
// inline buffer and the heap pointer. When is_inline_ is false and no
// value is stored, heap_value is null, so this returns nullptr — which is
// exactly the is_empty() signal.
ValueInterface* GetValue() {
if (IsInlineValue()) {
return value_.inline_value.AsValueInterface();
} else {
return value_.heap_value.get();
}
}
// Const overload of GetValue(); same dispatch between inline and heap.
const ValueInterface* GetValue() const {
if (IsInlineValue()) {
return value_.inline_value.AsValueInterface();
} else {
return value_.heap_value.get();
}
}
// PRECONDITION: Called on construction or clear() has been called before
// this method.
// Constructs a Value<VT> holding `value` (moved) in the location selected
// by is_inline_: placement-new into the inline buffer, or into memory from
// port::Malloc (paired with Value<VT>::operator delete, which calls
// port::Free for non-inlinable types).
template <typename T, typename VT>
void InsertValueMove(T&& value) {
if (is_inline_) {
Value<VT>* inline_value_data =
reinterpret_cast<Value<VT>*>(value_.inline_value.value_data);
new (inline_value_data) Value<VT>(kInPlace, std::forward<T>(value));
value_.inline_value.has_value = true;
} else {
auto* moved = static_cast<Value<VT>*>(port::Malloc(sizeof(Value<VT>)));
new (moved) Value<VT>(kInPlace, std::forward<T>(value));
value_.heap_value = HeapValue(moved);
}
}
// PRECONDITION: Called on construction or clear() has been called before
// this method.
// Copy counterpart of InsertValueMove(): copy-constructs a Value<VT> in
// the inline buffer or in port::Malloc'd heap memory.
template <typename T, typename VT>
void InsertValueCopy(const T& value) {
if (is_inline_) {
Value<VT>* inline_value_data =
reinterpret_cast<Value<VT>*>(value_.inline_value.value_data);
new (inline_value_data) Value<VT>(kInPlace, value);
value_.inline_value.has_value = true;
} else {
auto* moved = static_cast<Value<VT>*>(port::Malloc(sizeof(Value<VT>)));
new (moved) Value<VT>(kInPlace, value);
value_.heap_value = HeapValue(moved);
}
}
}; };
// Make sure that a Variant object can reside in a 64-byte aligned Tensor
// buffer.
// NOTE: the condition is `<=`, so the message says "at most" — the old
// wording ("to be 64 bytes") contradicted the check.
static_assert(sizeof(Variant) <= 64,
              "Expected internal representation to be at most 64 bytes.");
// Copy constructor. An inline payload is cloned directly into this
// Variant's inline buffer via CloneInto(); a heap payload gets a fresh
// Clone()d ValueInterface wrapped in a new HeapValue.
inline Variant::Variant(const Variant& other) : is_inline_(other.is_inline_) {
if (!other.is_empty()) {
if (other.IsInlineValue()) {
// Activate the inline_value union member before writing into it.
value_.inline_value = InlineValue();
other.GetValue()->CloneInto(GetValue());
value_.inline_value.has_value = true;
} else {
value_.heap_value = HeapValue(other.GetValue()->Clone());
is_inline_ = false;
}
}
}
// Move constructor. Inline payloads cannot be pointer-stolen: a
// default-constructed value of the same dynamic type is built in this
// buffer and the payload is MoveTo()d into it, after which `other` is
// zeroed out and marked non-inline (i.e. empty). Heap payloads simply
// transfer the unique_ptr.
inline Variant::Variant(Variant&& other) noexcept
: is_inline_(other.is_inline_) {
if (!other.is_empty()) {
if (other.IsInlineValue()) {
value_.inline_value = InlineValue();
other.GetValue()->DefaultConstructIn(GetValue());
other.GetValue()->MoveTo(GetValue());
value_.inline_value.has_value = true;
other.value_.ResetMemory();
other.is_inline_ = false;
} else {
value_.heap_value = std::move(other.value_.heap_value);
other.value_.ResetMemory();
is_inline_ = false;
}
}
}
// Class-specific delete for Value<VT>. Inline values live inside the
// Variant's own buffer, so `delete` must only run the (virtual) destructor
// and free nothing; heap values were allocated with port::Malloc and are
// released with port::Free.
template <typename VT>
void Variant::Value<VT>::operator delete(void* ptr) {
if (!CanInlineType<VT>()) port::Free(ptr);
}
// Move-converting constructor: wraps `value` in a Value<VT>, stored inline
// when CanInlineType<VT>() holds, otherwise on the heap.
template <typename T, typename VT,
typename std::enable_if<!std::is_same<Variant, VT>::value &&
std::is_move_constructible<VT>::value,
void>::type*>
inline Variant::Variant(T&& value) : is_inline_(CanInlineType<VT>()) {
InsertValueMove<T, VT>(std::forward<T>(value));
}
// Copy-converting constructor: same storage selection, copies `value`.
template <typename T, typename VT,
typename std::enable_if<!std::is_same<Variant, VT>::value &&
std::is_copy_constructible<VT>::value,
void>::type*>
inline Variant::Variant(const T& value) : is_inline_(CanInlineType<VT>()) {
InsertValueCopy<T, VT>(value);
}
// Move-converting assignment: destroys the current value via clear(),
// re-selects inline vs. heap storage for VT, then moves `value` in.
template <typename T, typename VT,
typename std::enable_if<!std::is_same<Variant, VT>::value &&
std::is_move_constructible<VT>::value,
void>::type*>
inline Variant& Variant::operator=(T&& value) {
clear();
is_inline_ = CanInlineType<VT>();
InsertValueMove<T, VT>(std::forward<T>(value));
return *this;
}
// Copy-converting assignment: same as above but copies `value`.
template <typename T, typename VT,
typename std::enable_if<!std::is_same<Variant, VT>::value &&
std::is_copy_constructible<VT>::value,
void>::type*>
inline Variant& Variant::operator=(const T& value) {
clear();
is_inline_ = CanInlineType<VT>();
InsertValueCopy<T, VT>(value);
return *this;
}
// Zeroes the union's bytes. Safe because both HeapValue (a null
// unique_ptr) and InlineValue (has_value == false) are designed so that
// the all-zero representation means "empty, destructor is a no-op".
inline void Variant::HeapOrInline::ResetMemory() {
memset( // NOLINT: not TriviallyCopyable
this, 0, sizeof(Variant::HeapOrInline));
}
// Destroys any held value and returns the Variant to the empty state.
inline void Variant::clear() noexcept {
if (!is_empty()) {
if (IsInlineValue()) {
// Runs the inline payload's destructor via InlineValue::Cleanup().
value_.inline_value.~InlineValue();
} else {
// unique_ptr destructor deletes the heap payload; the overridden
// operator delete routes the deallocation to port::Free.
value_.heap_value.~HeapValue();
}
// Zero the union so later destructor calls on either member are no-ops.
value_.ResetMemory();
}
is_inline_ = false;
}
// Exchanges the contents of *this and `other`, handling every combination
// of empty/non-empty and inline/heap storage. Each transfer resets the
// vacated union memory to zeros before re-activating a member, per the
// HeapOrInline::ResetMemory() contract; mixed inline/heap swaps stage the
// heap pointer in a local while the inline payload moves across.
inline void Variant::swap(Variant& other) noexcept {
if (is_empty()) {
// *this is empty: take other's value, leaving other empty.
if (other.IsInlineValue()) {
value_.ResetMemory();
value_.inline_value = std::move(other.value_.inline_value);
other.value_.ResetMemory();
other.value_.heap_value = HeapValue();
is_inline_ = true;
other.is_inline_ = false;
} else {
value_.ResetMemory();
value_.heap_value = std::move(other.value_.heap_value);
other.value_.ResetMemory();
other.value_.heap_value = HeapValue();
is_inline_ = false;
other.is_inline_ = false;
}
} else if (other.is_empty()) {
// Mirror image: other is empty and takes our value.
if (IsInlineValue()) {
other.value_.ResetMemory();
other.value_.inline_value = std::move(value_.inline_value);
value_.ResetMemory();
value_.heap_value = HeapValue();
other.is_inline_ = true;
is_inline_ = false;
} else {
other.value_.ResetMemory();
other.value_.heap_value = std::move(value_.heap_value);
value_.ResetMemory();
value_.heap_value = HeapValue();
other.is_inline_ = false;
is_inline_ = false;
}
} else { // Both Variants have values.
if (other.IsInlineValue() && IsInlineValue()) {
std::swap(value_.inline_value, other.value_.inline_value);
} else if (!other.IsInlineValue() && !IsInlineValue()) {
std::swap(value_.heap_value, other.value_.heap_value);
} else if (other.IsInlineValue() && !IsInlineValue()) {
// Heap here, inline there: park our heap pointer, move the inline
// payload in, then hand the parked pointer to `other`.
HeapValue v = std::move(value_.heap_value);
value_.ResetMemory();
value_.inline_value = std::move(other.value_.inline_value);
other.value_.ResetMemory();
other.value_.heap_value = std::move(v);
is_inline_ = true;
other.is_inline_ = false;
} else { // !other.IsInlineValue() && IsInlineValue()
HeapValue v = std::move(other.value_.heap_value);
other.value_.ResetMemory();
other.value_.inline_value = std::move(value_.inline_value);
value_.ResetMemory();
value_.heap_value = std::move(v);
is_inline_ = false;
other.is_inline_ = true;
}
}
}
template <> template <>
void* Variant::get(); void* Variant::get();

View File

@ -62,6 +62,10 @@ class VariantTensorData {
return GetMetadata<T>(value, PODResolver<T>()); return GetMetadata<T>(value, PODResolver<T>());
} }
// Direct access (mutable and const) to the underlying metadata string.
string& metadata_string() { return metadata_; }
const string& metadata_string() const { return metadata_; }
// Tensors contained within objects being serialized. // Tensors contained within objects being serialized.
int tensors_size() const; int tensors_size() const;
const Tensor& tensors(int index) const; const Tensor& tensors(int index) const;

View File

@ -13,15 +13,18 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include <vector>
#include "tensorflow/core/framework/variant.h" #include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/framework/variant_tensor_data.h" #include <xmmintrin.h>
#include <vector>
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/lib/core/coding.h" #include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test.h"
@ -29,17 +32,133 @@ namespace tensorflow {
namespace { namespace {
template <typename T> template <typename T, bool BIG>
struct Wrapper { struct Wrapper {
T value; T value;
char big[BIG ? 256 : 0];
string TypeName() const { return "POD"; } string TypeName() const { return "POD"; }
}; };
using Int = Wrapper<int>; template <bool BIG>
using Float = Wrapper<float>; using Int = Wrapper<int, BIG>;
template <bool BIG>
using Float = Wrapper<float, BIG>;
// Test helper that bumps an external counter each time an instance is
// destroyed with a live counter pointer. BIG pads the object with a 256
// byte array so the same tests can exercise both the heap path and the
// inline-storage path of Variant. (The BIG=false case relies on the
// zero-length-array compiler extension — presumably intentional; confirm
// on non-GCC/Clang toolchains.)
template <bool BIG>
class DeleteCounter {
 public:
  DeleteCounter() : big_{}, counter_(nullptr) {}
  explicit DeleteCounter(int* counter) : big_{}, counter_(counter) {}
  ~DeleteCounter() {
    if (counter_ != nullptr) ++*counter_;
  }
  // Copies deliberately share the counter pointer; moves must steal it,
  // otherwise the moved-from temporary would count an extra delete on
  // destruction.
  DeleteCounter(const DeleteCounter& from) = default;
  DeleteCounter& operator=(const DeleteCounter& from) = default;
  DeleteCounter(DeleteCounter&& from) {
    counter_ = from.counter_;
    from.counter_ = nullptr;
  }
  DeleteCounter& operator=(DeleteCounter&& from) {
    if (&from != this) {
      counter_ = from.counter_;
      from.counter_ = nullptr;
    }
    return *this;
  }
  char big_[BIG ? 256 : 0];
  int* counter_;
  // Minimal Variant interface so DeleteCounter can be stored in a Variant.
  string TypeName() const { return "DeleteCounter"; }
  void Encode(VariantTensorData* data) const {}
  bool Decode(VariantTensorData data) { return false; }
};
} // end namespace } // end namespace
// Verifies destructor accounting when one Variant is repeatedly assigned
// values that alternate between the heap (BIG) and inline (small)
// representations: each replacement must destroy exactly the previously
// held value, and clear() must destroy the current one.
TEST(VariantTest, MoveAndCopyBetweenBigAndSmall) {
Variant x;
int deleted_big = 0;
int deleted_small = 0;
// Assigning a temporary moves it in; the counter pointer is stolen from
// the temporary, so its destruction does not count.
x = DeleteCounter</*BIG=*/true>(&deleted_big);
EXPECT_EQ(deleted_big, 0);
x = DeleteCounter</*BIG=*/false>(&deleted_small);
EXPECT_EQ(deleted_big, 1);
EXPECT_EQ(deleted_small, 0);
x = DeleteCounter</*BIG=*/true>(&deleted_big);
EXPECT_EQ(deleted_big, 1);
EXPECT_EQ(deleted_small, 1);
x.clear();
EXPECT_EQ(deleted_big, 2);
EXPECT_EQ(deleted_small, 1);
// Named (lvalue) sources: copies share the counter, moves steal it.
DeleteCounter</*BIG=*/true> big(&deleted_big);
DeleteCounter</*BIG=*/false> small(&deleted_small);
EXPECT_EQ(deleted_big, 2);
EXPECT_EQ(deleted_small, 1);
x = big;
EXPECT_EQ(deleted_big, 2);
EXPECT_EQ(deleted_small, 1);
x = small;
EXPECT_EQ(deleted_big, 3);
EXPECT_EQ(deleted_small, 1);
x = std::move(big);
EXPECT_EQ(deleted_big, 3);
EXPECT_EQ(deleted_small, 2);
x = std::move(small);
EXPECT_EQ(deleted_big, 4);
EXPECT_EQ(deleted_small, 2);
x.clear();
EXPECT_EQ(deleted_big, 4);
EXPECT_EQ(deleted_small, 3);
}
// Same destructor accounting, but via Variant-to-Variant copy/move
// (including self-assignment) across mixed heap/inline representations,
// in both directions.
TEST(VariantTest, MoveAndCopyBetweenBigAndSmallVariants) {
int deleted_big = 0;
int deleted_small = 0;
{
// x holds the heap (BIG) value, y the inline (small) one.
Variant x = DeleteCounter</*BIG=*/true>(&deleted_big);
Variant y = DeleteCounter</*BIG=*/false>(&deleted_small);
EXPECT_EQ(deleted_big, 0);
EXPECT_EQ(deleted_small, 0);
x = y;
EXPECT_EQ(deleted_big, 1);
EXPECT_EQ(deleted_small, 0);
// Self-assignment must not destroy the held value.
x = x;
EXPECT_EQ(deleted_big, 1);
EXPECT_EQ(deleted_small, 0);
EXPECT_NE(x.get<DeleteCounter<false>>(), nullptr);
EXPECT_NE(y.get<DeleteCounter<false>>(), nullptr);
x = std::move(y);
EXPECT_EQ(deleted_small, 1);
EXPECT_NE(x.get<DeleteCounter<false>>(), nullptr);
}
EXPECT_EQ(deleted_big, 1);
EXPECT_EQ(deleted_small, 2);
deleted_big = 0;
deleted_small = 0;
{
// Mirror case: x starts inline (small), y on the heap (BIG).
Variant x = DeleteCounter</*BIG=*/false>(&deleted_small);
Variant y = DeleteCounter</*BIG=*/true>(&deleted_big);
EXPECT_EQ(deleted_big, 0);
EXPECT_EQ(deleted_small, 0);
x = y;
EXPECT_EQ(deleted_big, 0);
EXPECT_EQ(deleted_small, 1);
x = x;
EXPECT_EQ(deleted_big, 0);
EXPECT_EQ(deleted_small, 1);
EXPECT_NE(x.get<DeleteCounter<true>>(), nullptr);
EXPECT_NE(y.get<DeleteCounter<true>>(), nullptr);
x = std::move(y);
EXPECT_EQ(deleted_big, 1);
EXPECT_NE(x.get<DeleteCounter<true>>(), nullptr);
}
EXPECT_EQ(deleted_big, 2);
EXPECT_EQ(deleted_small, 1);
}
TEST(VariantTest, Int) { TEST(VariantTest, Int) {
Variant x; Variant x;
EXPECT_EQ(x.get<void>(), nullptr); EXPECT_EQ(x.get<void>(), nullptr);
@ -49,45 +168,125 @@ TEST(VariantTest, Int) {
EXPECT_EQ(x.TypeName(), "int"); EXPECT_EQ(x.TypeName(), "int");
} }
TEST(VariantTest, Basic) { struct MayCreateAlignmentDifficulties {
int a;
__m128 b;
};
// Returns true iff every one of the four float lanes of `a` equals the
// corresponding lane of `b` (lane-wise compare, then check that all four
// mask bits are set).
bool M128AllEqual(const __m128& a, const __m128& b) {
  const __m128 lane_eq = _mm_cmpeq_ps(a, b);
  return _mm_movemask_ps(lane_eq) == 0xf;
}
// Stores a struct containing an __m128 member in a Variant and checks the
// payload survives copy and move intact. (__m128's alignment presumably
// exceeds kMaxInlineValueAlignSize, forcing the heap path — TODO confirm.)
TEST(VariantTest, NotAlignable) {
Variant x;
EXPECT_EQ(x.get<void>(), nullptr);
__m128 v = _mm_set_ps(1.0, 2.0, 3.0, 4.0);
x = MayCreateAlignmentDifficulties{-1, v};
EXPECT_NE(x.get<void>(), nullptr);
auto* x_val = x.get<MayCreateAlignmentDifficulties>();
// check that *x_val == x
Variant y = x;
EXPECT_EQ(x_val->a, -1);
EXPECT_TRUE(M128AllEqual(x_val->b, v));
auto* y_val = y.get<MayCreateAlignmentDifficulties>();
EXPECT_EQ(y_val->a, -1);
EXPECT_TRUE(M128AllEqual(y_val->b, v));
Variant z = std::move(y);
auto* z_val = z.get<MayCreateAlignmentDifficulties>();
EXPECT_EQ(z_val->a, -1);
EXPECT_TRUE(M128AllEqual(z_val->b, v));
}
template <bool BIG>
void TestBasic() {
Variant x; Variant x;
EXPECT_EQ(x.get<void>(), nullptr); EXPECT_EQ(x.get<void>(), nullptr);
x = Int{42}; x = Int<BIG>{42};
EXPECT_NE(x.get<void>(), nullptr); EXPECT_NE(x.get<void>(), nullptr);
EXPECT_NE(x.get<Int>(), nullptr); EXPECT_NE(x.get<Int<BIG>>(), nullptr);
EXPECT_EQ(x.get<Int>()->value, 42); EXPECT_EQ(x.get<Int<BIG>>()->value, 42);
EXPECT_EQ(x.TypeName(), "POD"); EXPECT_EQ(x.TypeName(), "POD");
} }
TEST(VariantTest, ConstGet) { TEST(VariantTest, Basic) { TestBasic<false>(); }
TEST(VariantTest, BasicBig) { TestBasic<true>(); }
template <bool BIG>
void TestConstGet() {
Variant x; Variant x;
EXPECT_EQ(x.get<void>(), nullptr); EXPECT_EQ(x.get<void>(), nullptr);
x = Int{42}; x = Int<BIG>{42};
const Variant y = x; const Variant y = x;
EXPECT_NE(y.get<void>(), nullptr); EXPECT_NE(y.get<void>(), nullptr);
EXPECT_NE(y.get<Int>(), nullptr); EXPECT_NE(y.get<Int<BIG>>(), nullptr);
EXPECT_EQ(y.get<Int>()->value, 42); EXPECT_EQ(y.get<Int<BIG>>()->value, 42);
} }
TEST(VariantTest, Clear) { TEST(VariantTest, ConstGet) { TestConstGet<false>(); }
TEST(VariantTest, ConstGetBig) { TestConstGet<true>(); }
template <bool BIG>
void TestClear() {
Variant x; Variant x;
EXPECT_EQ(x.get<void>(), nullptr); EXPECT_EQ(x.get<void>(), nullptr);
x = Int{42}; x = Int<BIG>{42};
EXPECT_NE(x.get<void>(), nullptr); EXPECT_NE(x.get<void>(), nullptr);
EXPECT_NE(x.get<Int>(), nullptr); EXPECT_NE(x.get<Int<BIG>>(), nullptr);
EXPECT_EQ(x.get<Int>()->value, 42); EXPECT_EQ(x.get<Int<BIG>>()->value, 42);
x.clear(); x.clear();
EXPECT_EQ(x.get<void>(), nullptr); EXPECT_EQ(x.get<void>(), nullptr);
} }
TEST(VariantTest, Clear) { TestClear<false>(); }
TEST(VariantTest, ClearBig) { TestClear<true>(); }
// Checks that clear() runs the stored value's destructor exactly once,
// for both the heap (BIG=true) and inline (BIG=false) representations,
// and that a copied Variant owns an independent value.
template <bool BIG>
void TestClearDeletes() {
Variant x;
EXPECT_EQ(x.get<void>(), nullptr);
int deleted_count = 0;
using DC = DeleteCounter<BIG>;
DC dc(&deleted_count);
EXPECT_EQ(deleted_count, 0);
x = dc;
EXPECT_EQ(deleted_count, 0);
EXPECT_NE(x.get<void>(), nullptr);
EXPECT_NE(x.get<DC>(), nullptr);
x.clear();
EXPECT_EQ(x.get<void>(), nullptr);
EXPECT_EQ(deleted_count, 1);
x = dc;
EXPECT_EQ(deleted_count, 1);
// The copy in y shares the counter pointer, so clearing x and y should
// each add one delete.
Variant y = x;
EXPECT_EQ(deleted_count, 1);
x.clear();
EXPECT_EQ(deleted_count, 2);
y.clear();
EXPECT_EQ(deleted_count, 3);
}
TEST(VariantTest, ClearDeletesOnHeap) { TestClearDeletes</*BIG=*/true>(); }
TEST(VariantTest, ClearDeletesOnStack) { TestClearDeletes</*BIG=*/false>(); }
TEST(VariantTest, Tensor) { TEST(VariantTest, Tensor) {
Variant x; Variant x;
Tensor t(DT_FLOAT, {}); Tensor t(DT_FLOAT, {});
@ -101,6 +300,16 @@ TEST(VariantTest, Tensor) {
EXPECT_EQ(x.TypeName(), "tensorflow::Tensor"); EXPECT_EQ(x.TypeName(), "tensorflow::Tensor");
} }
// Copying a non-trivially-copyable value (a Tensor) into a DT_VARIANT
// tensor's elements must preserve the wrapped value: reading the element
// back yields a Tensor with the same dtype, shape, and contents.
TEST(VariantTest, NontrivialTensorVariantCopy) {
  Tensor variants(DT_VARIANT, {});  // scalar variant tensor (one element)
  Tensor t(true);                   // scalar bool tensor holding `true`
  test::FillValues<Variant>(&variants, gtl::ArraySlice<Variant>({t}));
  // Extract the wrapped Tensor back out of the single variant element.
  const Tensor* t_c = variants.flat<Variant>()(0).get<Tensor>();
  EXPECT_EQ(t_c->dtype(), t.dtype());
  EXPECT_EQ(t_c->shape(), t.shape());
  EXPECT_EQ(t_c->scalar<bool>()(), t.scalar<bool>()());
}
TEST(VariantTest, TensorProto) { TEST(VariantTest, TensorProto) {
Variant x; Variant x;
TensorProto t; TensorProto t;
@ -114,31 +323,41 @@ TEST(VariantTest, TensorProto) {
EXPECT_EQ(x.get<TensorProto>()->tensor_shape().unknown_rank(), true); EXPECT_EQ(x.get<TensorProto>()->tensor_shape().unknown_rank(), true);
} }
TEST(VariantTest, CopyValue) { template <bool BIG>
void TestCopyValue() {
Variant x, y; Variant x, y;
x = Int{10}; x = Int<BIG>{10};
y = x; y = x;
EXPECT_EQ(x.get<Int>()->value, 10); EXPECT_EQ(x.get<Int<BIG>>()->value, 10);
EXPECT_EQ(x.get<Int>()->value, y.get<Int>()->value); EXPECT_EQ(x.get<Int<BIG>>()->value, y.get<Int<BIG>>()->value);
} }
TEST(VariantTest, MoveValue) { TEST(VariantTest, CopyValue) { TestCopyValue<false>(); }
TEST(VariantTest, CopyValueBig) { TestCopyValue<true>(); }
template <bool BIG>
void TestMoveValue() {
Variant x; Variant x;
x = []() -> Variant { x = []() -> Variant {
Variant y; Variant y;
y = Int{10}; y = Int<BIG>{10};
return y; return y;
}(); }();
EXPECT_EQ(x.get<Int>()->value, 10); EXPECT_EQ(x.get<Int<BIG>>()->value, 10);
} }
TEST(VariantTest, MoveValue) { TestMoveValue<false>(); }
TEST(VariantTest, MoveValueBig) { TestMoveValue<true>(); }
TEST(VariantTest, TypeMismatch) { TEST(VariantTest, TypeMismatch) {
Variant x; Variant x;
x = Int{10}; x = Int<false>{10};
EXPECT_EQ(x.get<float>(), nullptr); EXPECT_EQ(x.get<float>(), nullptr);
EXPECT_EQ(x.get<int>(), nullptr); EXPECT_EQ(x.get<int>(), nullptr);
EXPECT_NE(x.get<Int>(), nullptr); EXPECT_NE(x.get<Int<false>>(), nullptr);
} }
struct TensorList { struct TensorList {
@ -206,19 +425,26 @@ TEST(VariantTest, TensorListTest) {
"Variant<type: TensorList value: ", data.DebugString(), ">")); "Variant<type: TensorList value: ", data.DebugString(), ">"));
} }
TEST(VariantTest, VariantArray) { template <bool BIG>
void TestVariantArray() {
Variant x[2]; Variant x[2];
x[0] = Int{2}; x[0] = Int<BIG>{2};
x[1] = Float{2.0f}; x[1] = Float<BIG>{2.0f};
EXPECT_EQ(x[0].get<Int>()->value, 2); EXPECT_EQ(x[0].get<Int<BIG>>()->value, 2);
EXPECT_EQ(x[1].get<Float>()->value, 2.0f); EXPECT_EQ(x[1].get<Float<BIG>>()->value, 2.0f);
} }
TEST(VariantTest, PodUpdate) { TEST(VariantTest, VariantArray) { TestVariantArray<false>(); }
TEST(VariantTest, VariantArrayBig) { TestVariantArray<true>(); }
template <bool BIG>
void PodUpdateTest() {
struct Pod { struct Pod {
int x; int x;
float y; float y;
char big[BIG ? 256 : 0];
string TypeName() const { return "POD"; } string TypeName() const { return "POD"; }
}; };
@ -232,10 +458,16 @@ TEST(VariantTest, PodUpdate) {
EXPECT_EQ(x.get<Pod>()->x, 30); EXPECT_EQ(x.get<Pod>()->x, 30);
} }
TEST(VariantTest, EncodeDecodePod) { TEST(VariantTest, PodUpdate) { PodUpdateTest<false>(); }
TEST(VariantTest, PodUpdateBig) { PodUpdateTest<true>(); }
template <bool BIG>
void TestEncodeDecodePod() {
struct Pod { struct Pod {
int x; int x;
float y; float y;
char big[BIG ? 256 : 0];
string TypeName() const { return "POD"; } string TypeName() const { return "POD"; }
}; };
@ -247,14 +479,17 @@ TEST(VariantTest, EncodeDecodePod) {
VariantTensorData serialized; VariantTensorData serialized;
x.Encode(&serialized); x.Encode(&serialized);
Variant y; Variant y = Pod{};
y = Pod();
y.Decode(serialized); y.Decode(serialized);
EXPECT_EQ(p.x, y.get<Pod>()->x); EXPECT_EQ(p.x, y.get<Pod>()->x);
EXPECT_EQ(p.y, y.get<Pod>()->y); EXPECT_EQ(p.y, y.get<Pod>()->y);
} }
TEST(VariantTest, EncodeDecodePod) { TestEncodeDecodePod<false>(); }
TEST(VariantTest, EncodeDecodePodBig) { TestEncodeDecodePod<true>(); }
TEST(VariantTest, EncodeDecodeTensor) { TEST(VariantTest, EncodeDecodeTensor) {
Variant x; Variant x;
Tensor t(DT_INT32, {}); Tensor t(DT_INT32, {});

View File

@ -237,6 +237,8 @@ class IteratorResource : public ResourceBase {
// destroyed, essentially triggering the iterator deletion. // destroyed, essentially triggering the iterator deletion.
class Deleter { class Deleter {
public: public:
Deleter() : deleter_() {}
Deleter(ResourceHandle handle, ResourceMgr* resource_manager) Deleter(ResourceHandle handle, ResourceMgr* resource_manager)
: deleter_(std::make_shared<Helper>(handle, resource_manager)) {} : deleter_(std::make_shared<Helper>(handle, resource_manager)) {}
@ -248,6 +250,10 @@ class IteratorResource : public ResourceBase {
VLOG(3) << "IteratorResource::Deleter copy constructor called."; VLOG(3) << "IteratorResource::Deleter copy constructor called.";
} }
Deleter& operator=(const Deleter& rhs) = delete;
Deleter& operator=(Deleter&& rhs) = default;
virtual ~Deleter() { virtual ~Deleter() {
VLOG(3) << "IteratorResource::Deleter destructor called."; VLOG(3) << "IteratorResource::Deleter destructor called.";
} }
@ -358,6 +364,9 @@ class IteratorStateVariant {
Decode(*other.data_); Decode(*other.data_);
} }
} }
IteratorStateVariant& operator=(IteratorStateVariant&& other) = default;
IteratorStateVariant& operator=(const IteratorStateVariant& other) = delete;
// Initializes this object with the current state of the iterator so // Initializes this object with the current state of the iterator so
// that it can be written on the next call to Encode(). // that it can be written on the next call to Encode().
Status InitializeFromIterator(OpKernelContext* ctx, Status InitializeFromIterator(OpKernelContext* ctx,

View File

@ -74,6 +74,8 @@ class Mutex : public ResourceBase {
struct SharedLockReleaser { struct SharedLockReleaser {
std::shared_ptr<LockReleaser> shared_lock; std::shared_ptr<LockReleaser> shared_lock;
SharedLockReleaser() : shared_lock() {}
explicit SharedLockReleaser(std::shared_ptr<LockReleaser>&& lock) explicit SharedLockReleaser(std::shared_ptr<LockReleaser>&& lock)
: shared_lock(std::forward<decltype(lock)>(lock)) { : shared_lock(std::forward<decltype(lock)>(lock)) {
VLOG(3) << "Creating shared_ptr of " << shared_lock.get() VLOG(3) << "Creating shared_ptr of " << shared_lock.get()
@ -86,6 +88,16 @@ class Mutex : public ResourceBase {
<< " count is: " << shared_lock.use_count(); << " count is: " << shared_lock.use_count();
} }
  // Copy-assignment is disallowed; transferring the lock is move-only.
  SharedLockReleaser& operator=(const SharedLockReleaser& rhs) = delete;
  // Move-assignment via swap: the lock previously held by *this is handed
  // to `rhs` and released whenever `rhs` is destroyed, rather than being
  // dropped immediately here.  Self-assignment is a no-op.
  // NOTE(review): after the swap, `shared_lock` is rhs's former lock, so
  // the VLOG line reports the newly acquired pointer and its use count.
  SharedLockReleaser& operator=(SharedLockReleaser&& rhs) {
    if (&rhs == this) return *this;
    std::swap(shared_lock, rhs.shared_lock);
    VLOG(3) << "Move-assign of SharedLockReleaser of " << shared_lock.get()
            << " count is: " << shared_lock.use_count();
    return *this;
  }
SharedLockReleaser(const SharedLockReleaser& rhs) SharedLockReleaser(const SharedLockReleaser& rhs)
: shared_lock(rhs.shared_lock) { : shared_lock(rhs.shared_lock) {
VLOG(3) << "Copying SharedLockReleaser of " << shared_lock.get() VLOG(3) << "Copying SharedLockReleaser of " << shared_lock.get()