diff --git a/tensorflow/core/framework/dataset.cc b/tensorflow/core/framework/dataset.cc index b0533fbc508..703c4a7e356 100644 --- a/tensorflow/core/framework/dataset.cc +++ b/tensorflow/core/framework/dataset.cc @@ -50,6 +50,14 @@ class DatasetVariantWrapper { if (dataset_) dataset_->Ref(); } + DatasetVariantWrapper& operator=(DatasetVariantWrapper&& other) { + if (&other == this) return *this; + std::swap(dataset_, other.dataset_); + return *this; + } + + DatasetVariantWrapper& operator=(const DatasetVariantWrapper& other) = delete; + ~DatasetVariantWrapper() { if (dataset_) dataset_->Unref(); } @@ -75,7 +83,7 @@ class DatasetVariantWrapper { } private: - DatasetBase* const dataset_; // Owns one reference. + DatasetBase* dataset_; // Owns one reference. }; const char kWrappedDatasetVariantTypeName[] = diff --git a/tensorflow/core/framework/tensor.h b/tensorflow/core/framework/tensor.h index 1c74ce2ca21..edbdc29db0c 100644 --- a/tensorflow/core/framework/tensor.h +++ b/tensorflow/core/framework/tensor.h @@ -43,6 +43,7 @@ class OpKernelContext; class Tensor; class TensorBuffer; class TensorCApi; +class TensorCord; class TensorDescription; class TensorProto; class Var; @@ -607,6 +608,7 @@ class Tensor { friend class DMAHelper; friend class TensorCApi; + friend class TensorCord; // For access to buf_ friend class TensorReference; // For access to buf_ friend class VariableOp; // For access to set_shape friend class AutoReloadVariableOp; // For access to set_shape diff --git a/tensorflow/core/framework/variant.cc b/tensorflow/core/framework/variant.cc index d43e3c72ece..e61afeada90 100644 --- a/tensorflow/core/framework/variant.cc +++ b/tensorflow/core/framework/variant.cc @@ -23,9 +23,11 @@ limitations under the License. namespace tensorflow { +Variant::~Variant() { clear(); } + bool Variant::Decode(VariantTensorData data) { if (!is_empty()) { - return value_->Decode(std::move(data)); + return GetValue()->Decode(std::move(data)); } return true; } @@ -35,7 +37,7 @@ void* Variant::get() { if (is_empty()) { return nullptr; } - return value_->RawPtr(); + return GetValue()->RawPtr(); } template <> @@ -43,7 +45,7 @@ const void* Variant::get() const { if (is_empty()) { return nullptr; } - return value_->RawPtr(); + return GetValue()->RawPtr(); } template <> diff --git a/tensorflow/core/framework/variant.h b/tensorflow/core/framework/variant.h index 10eabbc85fd..fa95bc83447 100644 --- a/tensorflow/core/framework/variant.h +++ b/tensorflow/core/framework/variant.h @@ -23,6 +23,7 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/core/framework/type_index.h" #include "tensorflow/core/framework/variant_tensor_data.h" #include "tensorflow/core/lib/core/status.h" @@ -68,7 +69,7 @@ void EncodeVariant(const T& value, string* buf); // // string TypeName() const; // void Encode(VariantTensorData* data) const; -// void Decode(VariantTensorData data); +// bool Decode(VariantTensorData data); // // Simple POD types can elide the Encode/Decode functions, they are provided by // helper methods. @@ -149,39 +150,57 @@ void EncodeVariant(const T& value, string* buf); // class Variant { public: - constexpr Variant() noexcept = default; + Variant() noexcept : is_inline_(false) {} - Variant(const Variant& other) - : value_(other.is_empty() ? std::unique_ptr() - : other.value_->Clone()) {} + ~Variant(); - Variant(Variant&& other) noexcept = default; + Variant(const Variant& other); + Variant(Variant&& other) noexcept; + + // Make sure that the type is CopyConstructible and not a + // tensorflow::Variant object itself. We want the copy constructor to be + // chosen for the tensorflow::Variant case. + template ::type, + typename std::enable_if::value && + std::is_move_constructible::value, + void>::type* = nullptr> + Variant(T&& value); - // Make sure that the type is CopyConstructible and not a tensorflow::Variant - // object itself. We want the copy constructor to be chosen for the - // tensorflow::Variant case. template ::type, typename std::enable_if::value && std::is_copy_constructible::value, void>::type* = nullptr> - Variant(T&& value) // NOLINT - : value_(new Value(in_place, std::forward(value))) {} + Variant(const T& value); + + template ::type, + typename std::enable_if::value && + std::is_copy_constructible::value, + void>::type* = nullptr> + Variant& operator=(const T& value); + + template ::type, + typename std::enable_if::value && + std::is_move_constructible::value, + void>::type* = nullptr> + Variant& operator=(T&& value); Variant& operator=(const Variant& rhs) { + if (&rhs == this) return *this; Variant(rhs).swap(*this); return *this; } Variant& operator=(Variant&& rhs) noexcept { + if (&rhs == this) return *this; Variant(std::move(rhs)).swap(*this); return *this; } - bool is_empty() const { return value_ == nullptr; } + bool is_empty() const { return GetValue() == nullptr; } - void clear() noexcept { value_.reset(); } + void clear() noexcept; - void swap(Variant& other) noexcept { value_.swap(other.value_); } + void swap(Variant& other) noexcept; // Note, unlike TypeName(), TypeId() does not return the TypeIndex // of the original type when a TensorValueDataProto is stored as the @@ -191,12 +210,13 @@ class Variant { if (is_empty()) { return VoidTypeIndex; } - return value_->TypeId(); + return GetValue()->TypeId(); } string DebugString() const { - return strings::StrCat("VariantDebugString(), ">"); + return strings::StrCat( + "VariantDebugString(), ">"); } // Returns a pointer to the stored value if it is type T, or nullptr @@ -205,7 +225,7 @@ class Variant { T* get() { const TypeIndex TTypeIndex = MakeTypeIndex(); if (is_empty() || (TTypeIndex != TypeId())) return nullptr; - return std::addressof(static_cast*>(value_.get())->value); + return std::addressof(static_cast*>(GetValue())->value); } // Returns a pointer to the stored value if it is type T, or nullptr @@ -215,7 +235,7 @@ class Variant { const TypeIndex TTypeIndex = MakeTypeIndex(); if (is_empty() || (TTypeIndex != TypeId())) return nullptr; return std::addressof( - static_cast*>(value_.get())->value); + static_cast*>(GetValue())->value); } // Returns TypeNameVariant(value). @@ -227,13 +247,13 @@ class Variant { if (is_empty()) { return ""; } - return value_->TypeName(); + return GetValue()->TypeName(); } // Serialize the contents of the stored object into `data`. void Encode(VariantTensorData* data) const { if (!is_empty()) { - value_->Encode(data); + GetValue()->Encode(data); } } @@ -243,26 +263,36 @@ class Variant { // Helper methods to directly serialize/deserialize from strings. void Encode(string* buf) const { if (!is_empty()) { - value_->Encode(buf); + GetValue()->Encode(buf); } } bool Decode(string buf) { if (!is_empty()) { - return value_->Decode(std::move(buf)); + return GetValue()->Decode(std::move(buf)); } return true; } + template + static constexpr bool CanInlineType() { + return ((sizeof(Value) <= InlineValue::kMaxValueSize) && + (alignof(Value) <= kMaxInlineValueAlignSize)); + } + private: struct in_place_t {}; - static constexpr in_place_t in_place{}; + static constexpr in_place_t kInPlace{}; struct ValueInterface { virtual ~ValueInterface() = default; virtual TypeIndex TypeId() const = 0; virtual void* RawPtr() = 0; virtual const void* RawPtr() const = 0; - virtual std::unique_ptr Clone() const = 0; + virtual ValueInterface* Clone() const = 0; + virtual void CloneInto(ValueInterface* memory) const = 0; + virtual void DefaultConstructIn(ValueInterface* memory) const = 0; + virtual void Swap(ValueInterface* memory) = 0; + virtual void MoveTo(ValueInterface* memory) = 0; virtual string TypeName() const = 0; virtual string DebugString() const = 0; virtual void Encode(VariantTensorData* data) const = 0; @@ -277,6 +307,10 @@ class Variant { explicit Value(in_place_t /*tag*/, Args&&... args) : value(std::forward(args)...) {} + // NOTE(ebrevdo): Destructor must be explicitly defined for CUDA to happily + // build `alignof(Variant)`. + ~Value() final = default; + TypeIndex TypeId() const override { const TypeIndex value_type_index = MakeTypeIndex::type>(); @@ -287,8 +321,33 @@ class Variant { const void* RawPtr() const override { return &value; } - std::unique_ptr Clone() const override { - return std::unique_ptr(new Value(in_place, value)); + ValueInterface* Clone() const override { + // NOTE: Use placement new here because we override `operator delete`, + // and need to match the call to `port::Free()` with a call to + // `port::Malloc()`. + auto* clone = static_cast(port::Malloc(sizeof(Value))); + new (clone) Value(kInPlace, value); + return clone; + } + + void DefaultConstructIn(ValueInterface* memory) const override { + new (memory) Value(kInPlace, T()); + } + + void MoveTo(ValueInterface* memory) override { + CHECK(TypeId() == memory->TypeId()) + << TypeId().name() << " vs. " << memory->TypeId().name(); + static_cast(memory)->value = std::move(value); + } + + void CloneInto(ValueInterface* memory) const override { + new (memory) Value(kInPlace, value); + } + + void Swap(ValueInterface* memory) override { + CHECK(TypeId() == memory->TypeId()) + << TypeId().name() << " vs. " << memory->TypeId().name(); + std::swap(value, static_cast(memory)->value); } string TypeName() const override { return TypeNameVariant(value); } @@ -307,14 +366,368 @@ class Variant { bool Decode(string buf) override { return DecodeVariant(&buf, &value); } + // We override operator delete in order to selectively free memory + // depending on if Value is stored inline or on the heap: + // + // Value is stored inline if its size <= InlineValue::kMaxValueSize and + // its alignment <= kMaxInlineValueAlignSize. This check is performed by + // CanInlineType(). + // + // We only need to call its destructor in this case and then overwrite + // the inline memory with zeros. Variant::clear() does this. + // Thus, in the inline case, the delete operator does nothing (calling + // delete on the memory location calls the destructor only). + // + // If !CanInlineType(), then it is stored as a pointer inside HeapValue. + // The memory buffer it resides in on the heap was allocated with + // port::Malloc, and it should be deallocated via port::Free. + // + // operator delete is stored in the vtable since ~ValueInterface is a + // virtual destructor; furthermore it has access to VT and can calculate + // CanInlineType(). + static void operator delete(void* ptr); + + static void operator delete(void*, void*) { + // Some compilers require an overridden class-specific deallocation + // function, which will be called if placement `new` throws an + // exception. + } + T value; }; + static constexpr int kMaxInlineValueAlignSize = alignof(Value); + + using HeapValue = std::unique_ptr; + + struct InlineValue { + // We try to size InlineValue so that sizeof(Variant) <= 64 and it can fit + // into the aligned space of a TensorBuffer. + static constexpr int kMaxValueSize = (64 - /*some extra padding=*/16); + + typedef char ValueDataArray[kMaxValueSize]; + alignas(kMaxInlineValueAlignSize) ValueDataArray value_data; + bool has_value = false; + + explicit InlineValue() {} + + InlineValue(const InlineValue& other) noexcept + : has_value(other.has_value) { + if (other.has_value) { + other.AsValueInterface()->CloneInto(AsValueInterface()); + } + } + + InlineValue(InlineValue&& other) noexcept : has_value(other.has_value) { + if (other.has_value) { + other.AsValueInterface()->DefaultConstructIn(AsValueInterface()); + other.AsValueInterface()->MoveTo(AsValueInterface()); + other.Cleanup(); + } + } + + void Cleanup() { + // **NOTE** This must be a no-op if the memory representation of + // InlineValue is all zeros, in order to properly interact with + // HeapOrInline::ResetMemory(). + if (has_value) { + // This doesn't actually delete anything on the heap; the delete + // operator of Value is overridden to do nothing for inline + // values; the side-effect of delete is that the virtual destructor is + // called. + // + // We leave it to callers to overwrite the data buffer in value_data + // with new objects. + delete AsValueInterface(); + } + has_value = false; + } + + InlineValue& operator=(const InlineValue& other) { + if (&other == this) return *this; + Cleanup(); + if (other.has_value) { + other.AsValueInterface()->CloneInto(AsValueInterface()); + } + has_value = other.has_value; + return *this; + } + + InlineValue& operator=(InlineValue&& other) { + if (&other == this) return *this; + if (other.has_value) { + if (has_value && AsValueInterface()->TypeId() == + other.AsValueInterface()->TypeId()) { + other.AsValueInterface()->Swap(AsValueInterface()); + } else { + if (has_value) { + if (AsValueInterface()->TypeId() != + other.AsValueInterface()->TypeId()) { + Cleanup(); + other.AsValueInterface()->DefaultConstructIn(AsValueInterface()); + } + other.AsValueInterface()->MoveTo(AsValueInterface()); + } else { + other.AsValueInterface()->DefaultConstructIn(AsValueInterface()); + other.AsValueInterface()->MoveTo(AsValueInterface()); + } + other.Cleanup(); + has_value = true; + } + } else { + Cleanup(); + } + return *this; + } + + ValueInterface* AsValueInterface() { + return reinterpret_cast(value_data); + } + + const ValueInterface* AsValueInterface() const { + return reinterpret_cast(value_data); + } + + // **WARNING** This must be a no-op when the byte-representation of + // InlineValue is all zeros. + ~InlineValue() { Cleanup(); } + }; // value_ can point to any type T as wrapped by a ValueInterface. // The only real requirement is that T is default-constructible. - std::unique_ptr value_; + union HeapOrInline { + HeapOrInline() { ResetMemory(); } + explicit HeapOrInline(HeapValue&& v) : heap_value(std::move(v)) {} + explicit HeapOrInline(InlineValue&& v) : inline_value(std::move(v)) {} + ~HeapOrInline() {} // Taken care of by owner. + + // This must be called when modifying which element of HeapOrInline is + // being used, because the destructor of the new class may be called + // while the memory is still a representation of the old class. + // **WARNING** This code assumes that the destructors of HeapValue and + // InlineValue are no-ops when the internal representation is zeros. + // + // Example of when this is needed: + // value.heap_value = HeapValue(...); + // // Segfault. This calls InlineValue::Cleanup on value.inline_value + // // but the internal memory representation is that of HeapValue. + // value.inline_value = InlineValue(); + // + // The correct way to do this: + // value.heap_value = HeapValue(...); + // value.ResetMemory(); + // value.inline_value = InlineValue(); + void ResetMemory(); + + HeapValue heap_value; + InlineValue inline_value; + } value_; + bool is_inline_; + + bool IsInlineValue() const { return is_inline_; } + + ValueInterface* GetValue() { + if (IsInlineValue()) { + return value_.inline_value.AsValueInterface(); + } else { + return value_.heap_value.get(); + } + } + + const ValueInterface* GetValue() const { + if (IsInlineValue()) { + return value_.inline_value.AsValueInterface(); + } else { + return value_.heap_value.get(); + } + } + + // PRECONDITION: Called on construction or clear() has been called before + // this method. + template + void InsertValueMove(T&& value) { + if (is_inline_) { + Value* inline_value_data = + reinterpret_cast*>(value_.inline_value.value_data); + new (inline_value_data) Value(kInPlace, std::forward(value)); + value_.inline_value.has_value = true; + } else { + auto* moved = static_cast*>(port::Malloc(sizeof(Value))); + new (moved) Value(kInPlace, std::forward(value)); + value_.heap_value = HeapValue(moved); + } + } + + // PRECONDITION: Called on construction or clear() has been called before + // this method. + template + void InsertValueCopy(const T& value) { + if (is_inline_) { + Value* inline_value_data = + reinterpret_cast*>(value_.inline_value.value_data); + new (inline_value_data) Value(kInPlace, value); + value_.inline_value.has_value = true; + } else { + auto* moved = static_cast*>(port::Malloc(sizeof(Value))); + new (moved) Value(kInPlace, value); + value_.heap_value = HeapValue(moved); + } + } }; +// Make sure that a Variant object can reside in a 64-byte aligned Tensor +// buffer. +static_assert(sizeof(Variant) <= 64, + "Expected internal representation to be 64 bytes."); + +inline Variant::Variant(const Variant& other) : is_inline_(other.is_inline_) { + if (!other.is_empty()) { + if (other.IsInlineValue()) { + value_.inline_value = InlineValue(); + other.GetValue()->CloneInto(GetValue()); + value_.inline_value.has_value = true; + } else { + value_.heap_value = HeapValue(other.GetValue()->Clone()); + is_inline_ = false; + } + } +} + +inline Variant::Variant(Variant&& other) noexcept + : is_inline_(other.is_inline_) { + if (!other.is_empty()) { + if (other.IsInlineValue()) { + value_.inline_value = InlineValue(); + other.GetValue()->DefaultConstructIn(GetValue()); + other.GetValue()->MoveTo(GetValue()); + value_.inline_value.has_value = true; + other.value_.ResetMemory(); + other.is_inline_ = false; + } else { + value_.heap_value = std::move(other.value_.heap_value); + other.value_.ResetMemory(); + is_inline_ = false; + } + } +} + +template +void Variant::Value::operator delete(void* ptr) { + if (!CanInlineType()) port::Free(ptr); +} + +template ::value && + std::is_move_constructible::value, + void>::type*> +inline Variant::Variant(T&& value) : is_inline_(CanInlineType()) { + InsertValueMove(std::forward(value)); +} + +template ::value && + std::is_copy_constructible::value, + void>::type*> +inline Variant::Variant(const T& value) : is_inline_(CanInlineType()) { + InsertValueCopy(value); +} + +template ::value && + std::is_move_constructible::value, + void>::type*> +inline Variant& Variant::operator=(T&& value) { + clear(); + is_inline_ = CanInlineType(); + InsertValueMove(std::forward(value)); + return *this; +} + +template ::value && + std::is_copy_constructible::value, + void>::type*> +inline Variant& Variant::operator=(const T& value) { + clear(); + is_inline_ = CanInlineType(); + InsertValueCopy(value); + return *this; +} + +inline void Variant::HeapOrInline::ResetMemory() { + memset( // NOLINT: not TriviallyCopyable + this, 0, sizeof(Variant::HeapOrInline)); +} + +inline void Variant::clear() noexcept { + if (!is_empty()) { + if (IsInlineValue()) { + value_.inline_value.~InlineValue(); + } else { + value_.heap_value.~HeapValue(); + } + value_.ResetMemory(); + } + is_inline_ = false; +} + +inline void Variant::swap(Variant& other) noexcept { + if (is_empty()) { + if (other.IsInlineValue()) { + value_.ResetMemory(); + value_.inline_value = std::move(other.value_.inline_value); + other.value_.ResetMemory(); + other.value_.heap_value = HeapValue(); + is_inline_ = true; + other.is_inline_ = false; + } else { + value_.ResetMemory(); + value_.heap_value = std::move(other.value_.heap_value); + other.value_.ResetMemory(); + other.value_.heap_value = HeapValue(); + is_inline_ = false; + other.is_inline_ = false; + } + } else if (other.is_empty()) { + if (IsInlineValue()) { + other.value_.ResetMemory(); + other.value_.inline_value = std::move(value_.inline_value); + value_.ResetMemory(); + value_.heap_value = HeapValue(); + other.is_inline_ = true; + is_inline_ = false; + } else { + other.value_.ResetMemory(); + other.value_.heap_value = std::move(value_.heap_value); + value_.ResetMemory(); + value_.heap_value = HeapValue(); + other.is_inline_ = false; + is_inline_ = false; + } + } else { // Both Variants have values. + if (other.IsInlineValue() && IsInlineValue()) { + std::swap(value_.inline_value, other.value_.inline_value); + } else if (!other.IsInlineValue() && !IsInlineValue()) { + std::swap(value_.heap_value, other.value_.heap_value); + } else if (other.IsInlineValue() && !IsInlineValue()) { + HeapValue v = std::move(value_.heap_value); + value_.ResetMemory(); + value_.inline_value = std::move(other.value_.inline_value); + other.value_.ResetMemory(); + other.value_.heap_value = std::move(v); + is_inline_ = true; + other.is_inline_ = false; + } else { // !other.IsInlineValue() && IsInlineValue() + HeapValue v = std::move(other.value_.heap_value); + other.value_.ResetMemory(); + other.value_.inline_value = std::move(value_.inline_value); + value_.ResetMemory(); + value_.heap_value = std::move(v); + is_inline_ = false; + other.is_inline_ = true; + } + } +} + template <> void* Variant::get(); diff --git a/tensorflow/core/framework/variant_tensor_data.h b/tensorflow/core/framework/variant_tensor_data.h index d98cf6b5e1f..8c654ccec82 100644 --- a/tensorflow/core/framework/variant_tensor_data.h +++ b/tensorflow/core/framework/variant_tensor_data.h @@ -62,6 +62,10 @@ class VariantTensorData { return GetMetadata(value, PODResolver()); } + string& metadata_string() { return metadata_; } + + const string& metadata_string() const { return metadata_; } + // Tensors contained within objects being serialized. int tensors_size() const; const Tensor& tensors(int index) const; diff --git a/tensorflow/core/framework/variant_test.cc b/tensorflow/core/framework/variant_test.cc index 8947f93887a..096143c6eb6 100644 --- a/tensorflow/core/framework/variant_test.cc +++ b/tensorflow/core/framework/variant_test.cc @@ -13,15 +13,18 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include - #include "tensorflow/core/framework/variant.h" -#include "tensorflow/core/framework/variant_encode_decode.h" -#include "tensorflow/core/framework/variant_tensor_data.h" + +#include + +#include #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/framework/variant_encode_decode.h" +#include "tensorflow/core/framework/variant_tensor_data.h" +#include "tensorflow/core/kernels/ops_testutil.h" #include "tensorflow/core/lib/core/coding.h" #include "tensorflow/core/platform/test.h" @@ -29,17 +32,133 @@ namespace tensorflow { namespace { -template +template struct Wrapper { T value; + char big[BIG ? 256 : 0]; string TypeName() const { return "POD"; } }; -using Int = Wrapper; -using Float = Wrapper; +template +using Int = Wrapper; + +template +using Float = Wrapper; + +template +class DeleteCounter { + public: + DeleteCounter() : big_{}, counter_(nullptr) {} + explicit DeleteCounter(int* counter) : big_{}, counter_(counter) {} + ~DeleteCounter() { + if (counter_) ++*counter_; + } + // Need custom move operations because int* just gets copied on move, but we + // need to clear counter_ on move. + DeleteCounter& operator=(const DeleteCounter& rhs) = default; + DeleteCounter& operator=(DeleteCounter&& rhs) { + if (this == &rhs) return *this; + counter_ = rhs.counter_; + rhs.counter_ = nullptr; + return *this; + } + DeleteCounter(DeleteCounter&& rhs) { + counter_ = rhs.counter_; + rhs.counter_ = nullptr; + } + DeleteCounter(const DeleteCounter& rhs) = default; + char big_[BIG ? 256 : 0]; + int* counter_; + + string TypeName() const { return "DeleteCounter"; } + void Encode(VariantTensorData* data) const {} + bool Decode(VariantTensorData data) { return false; } +}; } // end namespace +TEST(VariantTest, MoveAndCopyBetweenBigAndSmall) { + Variant x; + int deleted_big = 0; + int deleted_small = 0; + x = DeleteCounter(&deleted_big); + EXPECT_EQ(deleted_big, 0); + x = DeleteCounter(&deleted_small); + EXPECT_EQ(deleted_big, 1); + EXPECT_EQ(deleted_small, 0); + x = DeleteCounter(&deleted_big); + EXPECT_EQ(deleted_big, 1); + EXPECT_EQ(deleted_small, 1); + x.clear(); + EXPECT_EQ(deleted_big, 2); + EXPECT_EQ(deleted_small, 1); + DeleteCounter big(&deleted_big); + DeleteCounter small(&deleted_small); + EXPECT_EQ(deleted_big, 2); + EXPECT_EQ(deleted_small, 1); + x = big; + EXPECT_EQ(deleted_big, 2); + EXPECT_EQ(deleted_small, 1); + x = small; + EXPECT_EQ(deleted_big, 3); + EXPECT_EQ(deleted_small, 1); + x = std::move(big); + EXPECT_EQ(deleted_big, 3); + EXPECT_EQ(deleted_small, 2); + x = std::move(small); + EXPECT_EQ(deleted_big, 4); + EXPECT_EQ(deleted_small, 2); + x.clear(); + EXPECT_EQ(deleted_big, 4); + EXPECT_EQ(deleted_small, 3); +} + +TEST(VariantTest, MoveAndCopyBetweenBigAndSmallVariants) { + int deleted_big = 0; + int deleted_small = 0; + { + Variant x = DeleteCounter(&deleted_big); + Variant y = DeleteCounter(&deleted_small); + EXPECT_EQ(deleted_big, 0); + EXPECT_EQ(deleted_small, 0); + x = y; + EXPECT_EQ(deleted_big, 1); + EXPECT_EQ(deleted_small, 0); + x = x; + EXPECT_EQ(deleted_big, 1); + EXPECT_EQ(deleted_small, 0); + EXPECT_NE(x.get>(), nullptr); + EXPECT_NE(y.get>(), nullptr); + x = std::move(y); + EXPECT_EQ(deleted_small, 1); + EXPECT_NE(x.get>(), nullptr); + } + EXPECT_EQ(deleted_big, 1); + EXPECT_EQ(deleted_small, 2); + + deleted_big = 0; + deleted_small = 0; + { + Variant x = DeleteCounter(&deleted_small); + Variant y = DeleteCounter(&deleted_big); + EXPECT_EQ(deleted_big, 0); + EXPECT_EQ(deleted_small, 0); + x = y; + EXPECT_EQ(deleted_big, 0); + EXPECT_EQ(deleted_small, 1); + x = x; + EXPECT_EQ(deleted_big, 0); + EXPECT_EQ(deleted_small, 1); + EXPECT_NE(x.get>(), nullptr); + EXPECT_NE(y.get>(), nullptr); + x = std::move(y); + EXPECT_EQ(deleted_big, 1); + EXPECT_NE(x.get>(), nullptr); + } + EXPECT_EQ(deleted_big, 2); + EXPECT_EQ(deleted_small, 1); +} + TEST(VariantTest, Int) { Variant x; EXPECT_EQ(x.get(), nullptr); @@ -49,45 +168,125 @@ TEST(VariantTest, Int) { EXPECT_EQ(x.TypeName(), "int"); } -TEST(VariantTest, Basic) { +struct MayCreateAlignmentDifficulties { + int a; + __m128 b; +}; + +bool M128AllEqual(const __m128& a, const __m128& b) { + return _mm_movemask_ps(_mm_cmpeq_ps(a, b)) == 0xf; +} + +TEST(VariantTest, NotAlignable) { + Variant x; + EXPECT_EQ(x.get(), nullptr); + __m128 v = _mm_set_ps(1.0, 2.0, 3.0, 4.0); + x = MayCreateAlignmentDifficulties{-1, v}; + EXPECT_NE(x.get(), nullptr); + auto* x_val = x.get(); + // check that *x_val == x + Variant y = x; + EXPECT_EQ(x_val->a, -1); + EXPECT_TRUE(M128AllEqual(x_val->b, v)); + auto* y_val = y.get(); + EXPECT_EQ(y_val->a, -1); + EXPECT_TRUE(M128AllEqual(y_val->b, v)); + Variant z = std::move(y); + auto* z_val = z.get(); + EXPECT_EQ(z_val->a, -1); + EXPECT_TRUE(M128AllEqual(z_val->b, v)); +} + +template +void TestBasic() { Variant x; EXPECT_EQ(x.get(), nullptr); - x = Int{42}; + x = Int{42}; EXPECT_NE(x.get(), nullptr); - EXPECT_NE(x.get(), nullptr); - EXPECT_EQ(x.get()->value, 42); + EXPECT_NE(x.get>(), nullptr); + EXPECT_EQ(x.get>()->value, 42); EXPECT_EQ(x.TypeName(), "POD"); } -TEST(VariantTest, ConstGet) { +TEST(VariantTest, Basic) { TestBasic(); } + +TEST(VariantTest, BasicBig) { TestBasic(); } + +template +void TestConstGet() { Variant x; EXPECT_EQ(x.get(), nullptr); - x = Int{42}; + x = Int{42}; const Variant y = x; EXPECT_NE(y.get(), nullptr); - EXPECT_NE(y.get(), nullptr); - EXPECT_EQ(y.get()->value, 42); + EXPECT_NE(y.get>(), nullptr); + EXPECT_EQ(y.get>()->value, 42); } -TEST(VariantTest, Clear) { +TEST(VariantTest, ConstGet) { TestConstGet(); } + +TEST(VariantTest, ConstGetBig) { TestConstGet(); } + +template +void TestClear() { Variant x; EXPECT_EQ(x.get(), nullptr); - x = Int{42}; + x = Int{42}; EXPECT_NE(x.get(), nullptr); - EXPECT_NE(x.get(), nullptr); - EXPECT_EQ(x.get()->value, 42); + EXPECT_NE(x.get>(), nullptr); + EXPECT_EQ(x.get>()->value, 42); x.clear(); EXPECT_EQ(x.get(), nullptr); } +TEST(VariantTest, Clear) { TestClear(); } + +TEST(VariantTest, ClearBig) { TestClear(); } + +template +void TestClearDeletes() { + Variant x; + EXPECT_EQ(x.get(), nullptr); + + int deleted_count = 0; + using DC = DeleteCounter; + DC dc(&deleted_count); + EXPECT_EQ(deleted_count, 0); + x = dc; + EXPECT_EQ(deleted_count, 0); + + EXPECT_NE(x.get(), nullptr); + EXPECT_NE(x.get(), nullptr); + + x.clear(); + EXPECT_EQ(x.get(), nullptr); + EXPECT_EQ(deleted_count, 1); + + x = dc; + EXPECT_EQ(deleted_count, 1); + + Variant y = x; + EXPECT_EQ(deleted_count, 1); + + x.clear(); + EXPECT_EQ(deleted_count, 2); + + y.clear(); + EXPECT_EQ(deleted_count, 3); +} + +TEST(VariantTest, ClearDeletesOnHeap) { TestClearDeletes(); } + +TEST(VariantTest, ClearDeletesOnStack) { TestClearDeletes(); } + TEST(VariantTest, Tensor) { Variant x; Tensor t(DT_FLOAT, {}); @@ -101,6 +300,16 @@ TEST(VariantTest, Tensor) { EXPECT_EQ(x.TypeName(), "tensorflow::Tensor"); } +TEST(VariantTest, NontrivialTensorVariantCopy) { + Tensor variants(DT_VARIANT, {}); + Tensor t(true); + test::FillValues(&variants, gtl::ArraySlice({t})); + const Tensor* t_c = variants.flat()(0).get(); + EXPECT_EQ(t_c->dtype(), t.dtype()); + EXPECT_EQ(t_c->shape(), t.shape()); + EXPECT_EQ(t_c->scalar()(), t.scalar()()); +} + TEST(VariantTest, TensorProto) { Variant x; TensorProto t; @@ -114,31 +323,41 @@ TEST(VariantTest, TensorProto) { EXPECT_EQ(x.get()->tensor_shape().unknown_rank(), true); } -TEST(VariantTest, CopyValue) { +template +void TestCopyValue() { Variant x, y; - x = Int{10}; + x = Int{10}; y = x; - EXPECT_EQ(x.get()->value, 10); - EXPECT_EQ(x.get()->value, y.get()->value); + EXPECT_EQ(x.get>()->value, 10); + EXPECT_EQ(x.get>()->value, y.get>()->value); } -TEST(VariantTest, MoveValue) { +TEST(VariantTest, CopyValue) { TestCopyValue(); } + +TEST(VariantTest, CopyValueBig) { TestCopyValue(); } + +template +void TestMoveValue() { Variant x; x = []() -> Variant { Variant y; - y = Int{10}; + y = Int{10}; return y; }(); - EXPECT_EQ(x.get()->value, 10); + EXPECT_EQ(x.get>()->value, 10); } +TEST(VariantTest, MoveValue) { TestMoveValue(); } + +TEST(VariantTest, MoveValueBig) { TestMoveValue(); } + TEST(VariantTest, TypeMismatch) { Variant x; - x = Int{10}; + x = Int{10}; EXPECT_EQ(x.get(), nullptr); EXPECT_EQ(x.get(), nullptr); - EXPECT_NE(x.get(), nullptr); + EXPECT_NE(x.get>(), nullptr); } struct TensorList { @@ -206,19 +425,26 @@ TEST(VariantTest, TensorListTest) { "Variant")); } -TEST(VariantTest, VariantArray) { +template +void TestVariantArray() { Variant x[2]; - x[0] = Int{2}; - x[1] = Float{2.0f}; + x[0] = Int{2}; + x[1] = Float{2.0f}; - EXPECT_EQ(x[0].get()->value, 2); - EXPECT_EQ(x[1].get()->value, 2.0f); + EXPECT_EQ(x[0].get>()->value, 2); + EXPECT_EQ(x[1].get>()->value, 2.0f); } -TEST(VariantTest, PodUpdate) { +TEST(VariantTest, VariantArray) { TestVariantArray(); } + +TEST(VariantTest, VariantArrayBig) { TestVariantArray(); } + +template +void PodUpdateTest() { struct Pod { int x; float y; + char big[BIG ? 256 : 0]; string TypeName() const { return "POD"; } }; @@ -232,10 +458,16 @@ TEST(VariantTest, PodUpdate) { EXPECT_EQ(x.get()->x, 30); } -TEST(VariantTest, EncodeDecodePod) { +TEST(VariantTest, PodUpdate) { PodUpdateTest(); } + +TEST(VariantTest, PodUpdateBig) { PodUpdateTest(); } + +template +void TestEncodeDecodePod() { struct Pod { int x; float y; + char big[BIG ? 256 : 0]; string TypeName() const { return "POD"; } }; @@ -247,14 +479,17 @@ TEST(VariantTest, EncodeDecodePod) { VariantTensorData serialized; x.Encode(&serialized); - Variant y; - y = Pod(); + Variant y = Pod{}; y.Decode(serialized); EXPECT_EQ(p.x, y.get()->x); EXPECT_EQ(p.y, y.get()->y); } +TEST(VariantTest, EncodeDecodePod) { TestEncodeDecodePod(); } + +TEST(VariantTest, EncodeDecodePodBig) { TestEncodeDecodePod(); } + TEST(VariantTest, EncodeDecodeTensor) { Variant x; Tensor t(DT_INT32, {}); diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc index 65b13b7c1eb..59b90db917b 100644 --- a/tensorflow/core/kernels/data/iterator_ops.cc +++ b/tensorflow/core/kernels/data/iterator_ops.cc @@ -237,6 +237,8 @@ class IteratorResource : public ResourceBase { // destroyed, essentially triggering the iterator deletion. class Deleter { public: + Deleter() : deleter_() {} + Deleter(ResourceHandle handle, ResourceMgr* resource_manager) : deleter_(std::make_shared(handle, resource_manager)) {} @@ -248,6 +250,10 @@ class IteratorResource : public ResourceBase { VLOG(3) << "IteratorResource::Deleter copy constructor called."; } + Deleter& operator=(const Deleter& rhs) = delete; + + Deleter& operator=(Deleter&& rhs) = default; + virtual ~Deleter() { VLOG(3) << "IteratorResource::Deleter destructor called."; } @@ -358,6 +364,9 @@ class IteratorStateVariant { Decode(*other.data_); } } + IteratorStateVariant& operator=(IteratorStateVariant&& other) = default; + IteratorStateVariant& operator=(const IteratorStateVariant& other) = delete; + // Initializes this object with the current state of the iterator so // that it can be written on the next call to Encode(). Status InitializeFromIterator(OpKernelContext* ctx, diff --git a/tensorflow/core/kernels/mutex_ops.cc b/tensorflow/core/kernels/mutex_ops.cc index 2f4a5e9aa03..0cc29b42d93 100644 --- a/tensorflow/core/kernels/mutex_ops.cc +++ b/tensorflow/core/kernels/mutex_ops.cc @@ -74,6 +74,8 @@ class Mutex : public ResourceBase { struct SharedLockReleaser { std::shared_ptr shared_lock; + SharedLockReleaser() : shared_lock() {} + explicit SharedLockReleaser(std::shared_ptr&& lock) : shared_lock(std::forward(lock)) { VLOG(3) << "Creating shared_ptr of " << shared_lock.get() @@ -86,6 +88,16 @@ class Mutex : public ResourceBase { << " count is: " << shared_lock.use_count(); } + SharedLockReleaser& operator=(const SharedLockReleaser& rhs) = delete; + + SharedLockReleaser& operator=(SharedLockReleaser&& rhs) { + if (&rhs == this) return *this; + std::swap(shared_lock, rhs.shared_lock); + VLOG(3) << "Move-assign of SharedLockReleaser of " << shared_lock.get() + << " count is: " << shared_lock.use_count(); + return *this; + } + SharedLockReleaser(const SharedLockReleaser& rhs) : shared_lock(rhs.shared_lock) { VLOG(3) << "Copying SharedLockReleaser of " << shared_lock.get()