[TF] Update Variant object to store small objects inline.
PiperOrigin-RevId: 247047534
This commit is contained in:
parent
b17b0103a9
commit
2f345d145e
@ -50,6 +50,14 @@ class DatasetVariantWrapper {
|
|||||||
if (dataset_) dataset_->Ref();
|
if (dataset_) dataset_->Ref();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
DatasetVariantWrapper& operator=(DatasetVariantWrapper&& other) {
|
||||||
|
if (&other == this) return *this;
|
||||||
|
std::swap(dataset_, other.dataset_);
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
DatasetVariantWrapper& operator=(const DatasetVariantWrapper& other) = delete;
|
||||||
|
|
||||||
~DatasetVariantWrapper() {
|
~DatasetVariantWrapper() {
|
||||||
if (dataset_) dataset_->Unref();
|
if (dataset_) dataset_->Unref();
|
||||||
}
|
}
|
||||||
@ -75,7 +83,7 @@ class DatasetVariantWrapper {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
DatasetBase* const dataset_; // Owns one reference.
|
DatasetBase* dataset_; // Owns one reference.
|
||||||
};
|
};
|
||||||
|
|
||||||
const char kWrappedDatasetVariantTypeName[] =
|
const char kWrappedDatasetVariantTypeName[] =
|
||||||
|
@ -43,6 +43,7 @@ class OpKernelContext;
|
|||||||
class Tensor;
|
class Tensor;
|
||||||
class TensorBuffer;
|
class TensorBuffer;
|
||||||
class TensorCApi;
|
class TensorCApi;
|
||||||
|
class TensorCord;
|
||||||
class TensorDescription;
|
class TensorDescription;
|
||||||
class TensorProto;
|
class TensorProto;
|
||||||
class Var;
|
class Var;
|
||||||
@ -607,6 +608,7 @@ class Tensor {
|
|||||||
|
|
||||||
friend class DMAHelper;
|
friend class DMAHelper;
|
||||||
friend class TensorCApi;
|
friend class TensorCApi;
|
||||||
|
friend class TensorCord; // For access to buf_
|
||||||
friend class TensorReference; // For access to buf_
|
friend class TensorReference; // For access to buf_
|
||||||
friend class VariableOp; // For access to set_shape
|
friend class VariableOp; // For access to set_shape
|
||||||
friend class AutoReloadVariableOp; // For access to set_shape
|
friend class AutoReloadVariableOp; // For access to set_shape
|
||||||
|
@ -23,9 +23,11 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
|
Variant::~Variant() { clear(); }
|
||||||
|
|
||||||
bool Variant::Decode(VariantTensorData data) {
|
bool Variant::Decode(VariantTensorData data) {
|
||||||
if (!is_empty()) {
|
if (!is_empty()) {
|
||||||
return value_->Decode(std::move(data));
|
return GetValue()->Decode(std::move(data));
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
@ -35,7 +37,7 @@ void* Variant::get() {
|
|||||||
if (is_empty()) {
|
if (is_empty()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
return value_->RawPtr();
|
return GetValue()->RawPtr();
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
@ -43,7 +45,7 @@ const void* Variant::get() const {
|
|||||||
if (is_empty()) {
|
if (is_empty()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
return value_->RawPtr();
|
return GetValue()->RawPtr();
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
|||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
|
#include "absl/memory/memory.h"
|
||||||
#include "tensorflow/core/framework/type_index.h"
|
#include "tensorflow/core/framework/type_index.h"
|
||||||
#include "tensorflow/core/framework/variant_tensor_data.h"
|
#include "tensorflow/core/framework/variant_tensor_data.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
@ -68,7 +69,7 @@ void EncodeVariant(const T& value, string* buf);
|
|||||||
//
|
//
|
||||||
// string TypeName() const;
|
// string TypeName() const;
|
||||||
// void Encode(VariantTensorData* data) const;
|
// void Encode(VariantTensorData* data) const;
|
||||||
// void Decode(VariantTensorData data);
|
// bool Decode(VariantTensorData data);
|
||||||
//
|
//
|
||||||
// Simple POD types can elide the Encode/Decode functions, they are provided by
|
// Simple POD types can elide the Encode/Decode functions, they are provided by
|
||||||
// helper methods.
|
// helper methods.
|
||||||
@ -149,39 +150,57 @@ void EncodeVariant(const T& value, string* buf);
|
|||||||
//
|
//
|
||||||
class Variant {
|
class Variant {
|
||||||
public:
|
public:
|
||||||
constexpr Variant() noexcept = default;
|
Variant() noexcept : is_inline_(false) {}
|
||||||
|
|
||||||
Variant(const Variant& other)
|
~Variant();
|
||||||
: value_(other.is_empty() ? std::unique_ptr<ValueInterface>()
|
|
||||||
: other.value_->Clone()) {}
|
|
||||||
|
|
||||||
Variant(Variant&& other) noexcept = default;
|
Variant(const Variant& other);
|
||||||
|
Variant(Variant&& other) noexcept;
|
||||||
|
|
||||||
|
// Make sure that the type is CopyConstructible and not a
|
||||||
|
// tensorflow::Variant object itself. We want the copy constructor to be
|
||||||
|
// chosen for the tensorflow::Variant case.
|
||||||
|
template <typename T, typename VT = typename std::decay<T>::type,
|
||||||
|
typename std::enable_if<!std::is_same<Variant, VT>::value &&
|
||||||
|
std::is_move_constructible<VT>::value,
|
||||||
|
void>::type* = nullptr>
|
||||||
|
Variant(T&& value);
|
||||||
|
|
||||||
// Make sure that the type is CopyConstructible and not a tensorflow::Variant
|
|
||||||
// object itself. We want the copy constructor to be chosen for the
|
|
||||||
// tensorflow::Variant case.
|
|
||||||
template <typename T, typename VT = typename std::decay<T>::type,
|
template <typename T, typename VT = typename std::decay<T>::type,
|
||||||
typename std::enable_if<!std::is_same<Variant, VT>::value &&
|
typename std::enable_if<!std::is_same<Variant, VT>::value &&
|
||||||
std::is_copy_constructible<VT>::value,
|
std::is_copy_constructible<VT>::value,
|
||||||
void>::type* = nullptr>
|
void>::type* = nullptr>
|
||||||
Variant(T&& value) // NOLINT
|
Variant(const T& value);
|
||||||
: value_(new Value<VT>(in_place, std::forward<T>(value))) {}
|
|
||||||
|
template <typename T, typename VT = typename std::decay<T>::type,
|
||||||
|
typename std::enable_if<!std::is_same<Variant, VT>::value &&
|
||||||
|
std::is_copy_constructible<VT>::value,
|
||||||
|
void>::type* = nullptr>
|
||||||
|
Variant& operator=(const T& value);
|
||||||
|
|
||||||
|
template <typename T, typename VT = typename std::decay<T>::type,
|
||||||
|
typename std::enable_if<!std::is_same<Variant, VT>::value &&
|
||||||
|
std::is_move_constructible<VT>::value,
|
||||||
|
void>::type* = nullptr>
|
||||||
|
Variant& operator=(T&& value);
|
||||||
|
|
||||||
Variant& operator=(const Variant& rhs) {
|
Variant& operator=(const Variant& rhs) {
|
||||||
|
if (&rhs == this) return *this;
|
||||||
Variant(rhs).swap(*this);
|
Variant(rhs).swap(*this);
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
Variant& operator=(Variant&& rhs) noexcept {
|
Variant& operator=(Variant&& rhs) noexcept {
|
||||||
|
if (&rhs == this) return *this;
|
||||||
Variant(std::move(rhs)).swap(*this);
|
Variant(std::move(rhs)).swap(*this);
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool is_empty() const { return value_ == nullptr; }
|
bool is_empty() const { return GetValue() == nullptr; }
|
||||||
|
|
||||||
void clear() noexcept { value_.reset(); }
|
void clear() noexcept;
|
||||||
|
|
||||||
void swap(Variant& other) noexcept { value_.swap(other.value_); }
|
void swap(Variant& other) noexcept;
|
||||||
|
|
||||||
// Note, unlike TypeName(), TypeId() does not return the TypeIndex
|
// Note, unlike TypeName(), TypeId() does not return the TypeIndex
|
||||||
// of the original type when a TensorValueDataProto is stored as the
|
// of the original type when a TensorValueDataProto is stored as the
|
||||||
@ -191,12 +210,13 @@ class Variant {
|
|||||||
if (is_empty()) {
|
if (is_empty()) {
|
||||||
return VoidTypeIndex;
|
return VoidTypeIndex;
|
||||||
}
|
}
|
||||||
return value_->TypeId();
|
return GetValue()->TypeId();
|
||||||
}
|
}
|
||||||
|
|
||||||
string DebugString() const {
|
string DebugString() const {
|
||||||
return strings::StrCat("Variant<type: ", TypeName(),
|
return strings::StrCat(
|
||||||
" value: ", value_->DebugString(), ">");
|
"Variant<type: ", TypeName(),
|
||||||
|
" value: ", is_empty() ? "[empty]" : GetValue()->DebugString(), ">");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns a pointer to the stored value if it is type T, or nullptr
|
// Returns a pointer to the stored value if it is type T, or nullptr
|
||||||
@ -205,7 +225,7 @@ class Variant {
|
|||||||
T* get() {
|
T* get() {
|
||||||
const TypeIndex TTypeIndex = MakeTypeIndex<T>();
|
const TypeIndex TTypeIndex = MakeTypeIndex<T>();
|
||||||
if (is_empty() || (TTypeIndex != TypeId())) return nullptr;
|
if (is_empty() || (TTypeIndex != TypeId())) return nullptr;
|
||||||
return std::addressof(static_cast<Variant::Value<T>*>(value_.get())->value);
|
return std::addressof(static_cast<Variant::Value<T>*>(GetValue())->value);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns a pointer to the stored value if it is type T, or nullptr
|
// Returns a pointer to the stored value if it is type T, or nullptr
|
||||||
@ -215,7 +235,7 @@ class Variant {
|
|||||||
const TypeIndex TTypeIndex = MakeTypeIndex<T>();
|
const TypeIndex TTypeIndex = MakeTypeIndex<T>();
|
||||||
if (is_empty() || (TTypeIndex != TypeId())) return nullptr;
|
if (is_empty() || (TTypeIndex != TypeId())) return nullptr;
|
||||||
return std::addressof(
|
return std::addressof(
|
||||||
static_cast<const Variant::Value<T>*>(value_.get())->value);
|
static_cast<const Variant::Value<T>*>(GetValue())->value);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns TypeNameVariant(value).
|
// Returns TypeNameVariant(value).
|
||||||
@ -227,13 +247,13 @@ class Variant {
|
|||||||
if (is_empty()) {
|
if (is_empty()) {
|
||||||
return "";
|
return "";
|
||||||
}
|
}
|
||||||
return value_->TypeName();
|
return GetValue()->TypeName();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Serialize the contents of the stored object into `data`.
|
// Serialize the contents of the stored object into `data`.
|
||||||
void Encode(VariantTensorData* data) const {
|
void Encode(VariantTensorData* data) const {
|
||||||
if (!is_empty()) {
|
if (!is_empty()) {
|
||||||
value_->Encode(data);
|
GetValue()->Encode(data);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -243,26 +263,36 @@ class Variant {
|
|||||||
// Helper methods to directly serialize/deserialize from strings.
|
// Helper methods to directly serialize/deserialize from strings.
|
||||||
void Encode(string* buf) const {
|
void Encode(string* buf) const {
|
||||||
if (!is_empty()) {
|
if (!is_empty()) {
|
||||||
value_->Encode(buf);
|
GetValue()->Encode(buf);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
bool Decode(string buf) {
|
bool Decode(string buf) {
|
||||||
if (!is_empty()) {
|
if (!is_empty()) {
|
||||||
return value_->Decode(std::move(buf));
|
return GetValue()->Decode(std::move(buf));
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename VT>
|
||||||
|
static constexpr bool CanInlineType() {
|
||||||
|
return ((sizeof(Value<VT>) <= InlineValue::kMaxValueSize) &&
|
||||||
|
(alignof(Value<VT>) <= kMaxInlineValueAlignSize));
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
struct in_place_t {};
|
struct in_place_t {};
|
||||||
static constexpr in_place_t in_place{};
|
static constexpr in_place_t kInPlace{};
|
||||||
|
|
||||||
struct ValueInterface {
|
struct ValueInterface {
|
||||||
virtual ~ValueInterface() = default;
|
virtual ~ValueInterface() = default;
|
||||||
virtual TypeIndex TypeId() const = 0;
|
virtual TypeIndex TypeId() const = 0;
|
||||||
virtual void* RawPtr() = 0;
|
virtual void* RawPtr() = 0;
|
||||||
virtual const void* RawPtr() const = 0;
|
virtual const void* RawPtr() const = 0;
|
||||||
virtual std::unique_ptr<ValueInterface> Clone() const = 0;
|
virtual ValueInterface* Clone() const = 0;
|
||||||
|
virtual void CloneInto(ValueInterface* memory) const = 0;
|
||||||
|
virtual void DefaultConstructIn(ValueInterface* memory) const = 0;
|
||||||
|
virtual void Swap(ValueInterface* memory) = 0;
|
||||||
|
virtual void MoveTo(ValueInterface* memory) = 0;
|
||||||
virtual string TypeName() const = 0;
|
virtual string TypeName() const = 0;
|
||||||
virtual string DebugString() const = 0;
|
virtual string DebugString() const = 0;
|
||||||
virtual void Encode(VariantTensorData* data) const = 0;
|
virtual void Encode(VariantTensorData* data) const = 0;
|
||||||
@ -277,6 +307,10 @@ class Variant {
|
|||||||
explicit Value(in_place_t /*tag*/, Args&&... args)
|
explicit Value(in_place_t /*tag*/, Args&&... args)
|
||||||
: value(std::forward<Args>(args)...) {}
|
: value(std::forward<Args>(args)...) {}
|
||||||
|
|
||||||
|
// NOTE(ebrevdo): Destructor must be explicitly defined for CUDA to happily
|
||||||
|
// build `alignof(Variant<void*>)`.
|
||||||
|
~Value() final = default;
|
||||||
|
|
||||||
TypeIndex TypeId() const override {
|
TypeIndex TypeId() const override {
|
||||||
const TypeIndex value_type_index =
|
const TypeIndex value_type_index =
|
||||||
MakeTypeIndex<typename std::decay<T>::type>();
|
MakeTypeIndex<typename std::decay<T>::type>();
|
||||||
@ -287,8 +321,33 @@ class Variant {
|
|||||||
|
|
||||||
const void* RawPtr() const override { return &value; }
|
const void* RawPtr() const override { return &value; }
|
||||||
|
|
||||||
std::unique_ptr<ValueInterface> Clone() const override {
|
ValueInterface* Clone() const override {
|
||||||
return std::unique_ptr<ValueInterface>(new Value(in_place, value));
|
// NOTE: Use placement new here because we override `operator delete`,
|
||||||
|
// and need to match the call to `port::Free()` with a call to
|
||||||
|
// `port::Malloc()`.
|
||||||
|
auto* clone = static_cast<Value*>(port::Malloc(sizeof(Value)));
|
||||||
|
new (clone) Value(kInPlace, value);
|
||||||
|
return clone;
|
||||||
|
}
|
||||||
|
|
||||||
|
void DefaultConstructIn(ValueInterface* memory) const override {
|
||||||
|
new (memory) Value(kInPlace, T());
|
||||||
|
}
|
||||||
|
|
||||||
|
void MoveTo(ValueInterface* memory) override {
|
||||||
|
CHECK(TypeId() == memory->TypeId())
|
||||||
|
<< TypeId().name() << " vs. " << memory->TypeId().name();
|
||||||
|
static_cast<Value*>(memory)->value = std::move(value);
|
||||||
|
}
|
||||||
|
|
||||||
|
void CloneInto(ValueInterface* memory) const override {
|
||||||
|
new (memory) Value(kInPlace, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Swap(ValueInterface* memory) override {
|
||||||
|
CHECK(TypeId() == memory->TypeId())
|
||||||
|
<< TypeId().name() << " vs. " << memory->TypeId().name();
|
||||||
|
std::swap(value, static_cast<Value*>(memory)->value);
|
||||||
}
|
}
|
||||||
|
|
||||||
string TypeName() const override { return TypeNameVariant(value); }
|
string TypeName() const override { return TypeNameVariant(value); }
|
||||||
@ -307,14 +366,368 @@ class Variant {
|
|||||||
|
|
||||||
bool Decode(string buf) override { return DecodeVariant(&buf, &value); }
|
bool Decode(string buf) override { return DecodeVariant(&buf, &value); }
|
||||||
|
|
||||||
|
// We override operator delete in order to selectively free memory
|
||||||
|
// depending on if Value<VT> is stored inline or on the heap:
|
||||||
|
//
|
||||||
|
// Value<VT> is stored inline if its size <= InlineValue::kMaxValueSize and
|
||||||
|
// its alignment <= kMaxInlineValueAlignSize. This check is performed by
|
||||||
|
// CanInlineType<VT>().
|
||||||
|
//
|
||||||
|
// We only need to call its destructor in this case and then overwrite
|
||||||
|
// the inline memory with zeros. Variant::clear() does this.
|
||||||
|
// Thus, in the inline case, the delete operator does nothing (calling
|
||||||
|
// delete on the memory location calls the destructor only).
|
||||||
|
//
|
||||||
|
// If !CanInlineType<VT>(), then it is stored as a pointer inside HeapValue.
|
||||||
|
// The memory buffer it resides in on the heap was allocated with
|
||||||
|
// port::Malloc, and it should be deallocated via port::Free.
|
||||||
|
//
|
||||||
|
// operator delete is stored in the vtable since ~ValueInterface is a
|
||||||
|
// virtual destructor; furthermore it has access to VT and can calculate
|
||||||
|
// CanInlineType<VT>().
|
||||||
|
static void operator delete(void* ptr);
|
||||||
|
|
||||||
|
static void operator delete(void*, void*) {
|
||||||
|
// Some compilers require an overridden class-specific deallocation
|
||||||
|
// function, which will be called if placement `new` throws an
|
||||||
|
// exception.
|
||||||
|
}
|
||||||
|
|
||||||
T value;
|
T value;
|
||||||
};
|
};
|
||||||
|
static constexpr int kMaxInlineValueAlignSize = alignof(Value<void*>);
|
||||||
|
|
||||||
|
using HeapValue = std::unique_ptr<ValueInterface>;
|
||||||
|
|
||||||
|
struct InlineValue {
|
||||||
|
// We try to size InlineValue so that sizeof(Variant) <= 64 and it can fit
|
||||||
|
// into the aligned space of a TensorBuffer.
|
||||||
|
static constexpr int kMaxValueSize = (64 - /*some extra padding=*/16);
|
||||||
|
|
||||||
|
typedef char ValueDataArray[kMaxValueSize];
|
||||||
|
alignas(kMaxInlineValueAlignSize) ValueDataArray value_data;
|
||||||
|
bool has_value = false;
|
||||||
|
|
||||||
|
explicit InlineValue() {}
|
||||||
|
|
||||||
|
InlineValue(const InlineValue& other) noexcept
|
||||||
|
: has_value(other.has_value) {
|
||||||
|
if (other.has_value) {
|
||||||
|
other.AsValueInterface()->CloneInto(AsValueInterface());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
InlineValue(InlineValue&& other) noexcept : has_value(other.has_value) {
|
||||||
|
if (other.has_value) {
|
||||||
|
other.AsValueInterface()->DefaultConstructIn(AsValueInterface());
|
||||||
|
other.AsValueInterface()->MoveTo(AsValueInterface());
|
||||||
|
other.Cleanup();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Cleanup() {
|
||||||
|
// **NOTE** This must be a no-op if the memory representation of
|
||||||
|
// InlineValue is all zeros, in order to properly interact with
|
||||||
|
// HeapOrInline::ResetMemory().
|
||||||
|
if (has_value) {
|
||||||
|
// This doesn't actually delete anything on the heap; the delete
|
||||||
|
// operator of Value<VT> is overridden to do nothing for inline
|
||||||
|
// values; the side-effect of delete is that the virtual destructor is
|
||||||
|
// called.
|
||||||
|
//
|
||||||
|
// We leave it to callers to overwrite the data buffer in value_data
|
||||||
|
// with new objects.
|
||||||
|
delete AsValueInterface();
|
||||||
|
}
|
||||||
|
has_value = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
InlineValue& operator=(const InlineValue& other) {
|
||||||
|
if (&other == this) return *this;
|
||||||
|
Cleanup();
|
||||||
|
if (other.has_value) {
|
||||||
|
other.AsValueInterface()->CloneInto(AsValueInterface());
|
||||||
|
}
|
||||||
|
has_value = other.has_value;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
InlineValue& operator=(InlineValue&& other) {
|
||||||
|
if (&other == this) return *this;
|
||||||
|
if (other.has_value) {
|
||||||
|
if (has_value && AsValueInterface()->TypeId() ==
|
||||||
|
other.AsValueInterface()->TypeId()) {
|
||||||
|
other.AsValueInterface()->Swap(AsValueInterface());
|
||||||
|
} else {
|
||||||
|
if (has_value) {
|
||||||
|
if (AsValueInterface()->TypeId() !=
|
||||||
|
other.AsValueInterface()->TypeId()) {
|
||||||
|
Cleanup();
|
||||||
|
other.AsValueInterface()->DefaultConstructIn(AsValueInterface());
|
||||||
|
}
|
||||||
|
other.AsValueInterface()->MoveTo(AsValueInterface());
|
||||||
|
} else {
|
||||||
|
other.AsValueInterface()->DefaultConstructIn(AsValueInterface());
|
||||||
|
other.AsValueInterface()->MoveTo(AsValueInterface());
|
||||||
|
}
|
||||||
|
other.Cleanup();
|
||||||
|
has_value = true;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Cleanup();
|
||||||
|
}
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
ValueInterface* AsValueInterface() {
|
||||||
|
return reinterpret_cast<ValueInterface*>(value_data);
|
||||||
|
}
|
||||||
|
|
||||||
|
const ValueInterface* AsValueInterface() const {
|
||||||
|
return reinterpret_cast<const ValueInterface*>(value_data);
|
||||||
|
}
|
||||||
|
|
||||||
|
// **WARNING** This must be a no-op when the byte-representation of
|
||||||
|
// InlineValue is all zeros.
|
||||||
|
~InlineValue() { Cleanup(); }
|
||||||
|
};
|
||||||
|
|
||||||
// value_ can point to any type T as wrapped by a ValueInterface.
|
// value_ can point to any type T as wrapped by a ValueInterface.
|
||||||
// The only real requirement is that T is default-constructible.
|
// The only real requirement is that T is default-constructible.
|
||||||
std::unique_ptr<ValueInterface> value_;
|
union HeapOrInline {
|
||||||
|
HeapOrInline() { ResetMemory(); }
|
||||||
|
explicit HeapOrInline(HeapValue&& v) : heap_value(std::move(v)) {}
|
||||||
|
explicit HeapOrInline(InlineValue&& v) : inline_value(std::move(v)) {}
|
||||||
|
~HeapOrInline() {} // Taken care of by owner.
|
||||||
|
|
||||||
|
// This must be called when modifying which element of HeapOrInline is
|
||||||
|
// being used, because the destructor of the new class may be called
|
||||||
|
// while the memory is still a representation of the old class.
|
||||||
|
// **WARNING** This code assumes that the destructors of HeapValue and
|
||||||
|
// InlineValue are no-ops when the internal representation is zeros.
|
||||||
|
//
|
||||||
|
// Example of when this is needed:
|
||||||
|
// value.heap_value = HeapValue(...);
|
||||||
|
// // Segfault. This calls InlineValue::Cleanup on value.inline_value
|
||||||
|
// // but the internal memory representation is that of HeapValue.
|
||||||
|
// value.inline_value = InlineValue();
|
||||||
|
//
|
||||||
|
// The correct way to do this:
|
||||||
|
// value.heap_value = HeapValue(...);
|
||||||
|
// value.ResetMemory();
|
||||||
|
// value.inline_value = InlineValue();
|
||||||
|
void ResetMemory();
|
||||||
|
|
||||||
|
HeapValue heap_value;
|
||||||
|
InlineValue inline_value;
|
||||||
|
} value_;
|
||||||
|
bool is_inline_;
|
||||||
|
|
||||||
|
bool IsInlineValue() const { return is_inline_; }
|
||||||
|
|
||||||
|
ValueInterface* GetValue() {
|
||||||
|
if (IsInlineValue()) {
|
||||||
|
return value_.inline_value.AsValueInterface();
|
||||||
|
} else {
|
||||||
|
return value_.heap_value.get();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const ValueInterface* GetValue() const {
|
||||||
|
if (IsInlineValue()) {
|
||||||
|
return value_.inline_value.AsValueInterface();
|
||||||
|
} else {
|
||||||
|
return value_.heap_value.get();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// PRECONDITION: Called on construction or clear() has been called before
|
||||||
|
// this method.
|
||||||
|
template <typename T, typename VT>
|
||||||
|
void InsertValueMove(T&& value) {
|
||||||
|
if (is_inline_) {
|
||||||
|
Value<VT>* inline_value_data =
|
||||||
|
reinterpret_cast<Value<VT>*>(value_.inline_value.value_data);
|
||||||
|
new (inline_value_data) Value<VT>(kInPlace, std::forward<T>(value));
|
||||||
|
value_.inline_value.has_value = true;
|
||||||
|
} else {
|
||||||
|
auto* moved = static_cast<Value<VT>*>(port::Malloc(sizeof(Value<VT>)));
|
||||||
|
new (moved) Value<VT>(kInPlace, std::forward<T>(value));
|
||||||
|
value_.heap_value = HeapValue(moved);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// PRECONDITION: Called on construction or clear() has been called before
|
||||||
|
// this method.
|
||||||
|
template <typename T, typename VT>
|
||||||
|
void InsertValueCopy(const T& value) {
|
||||||
|
if (is_inline_) {
|
||||||
|
Value<VT>* inline_value_data =
|
||||||
|
reinterpret_cast<Value<VT>*>(value_.inline_value.value_data);
|
||||||
|
new (inline_value_data) Value<VT>(kInPlace, value);
|
||||||
|
value_.inline_value.has_value = true;
|
||||||
|
} else {
|
||||||
|
auto* moved = static_cast<Value<VT>*>(port::Malloc(sizeof(Value<VT>)));
|
||||||
|
new (moved) Value<VT>(kInPlace, value);
|
||||||
|
value_.heap_value = HeapValue(moved);
|
||||||
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Make sure that a Variant object can reside in a 64-byte aligned Tensor
|
||||||
|
// buffer.
|
||||||
|
static_assert(sizeof(Variant) <= 64,
|
||||||
|
"Expected internal representation to be 64 bytes.");
|
||||||
|
|
||||||
|
inline Variant::Variant(const Variant& other) : is_inline_(other.is_inline_) {
|
||||||
|
if (!other.is_empty()) {
|
||||||
|
if (other.IsInlineValue()) {
|
||||||
|
value_.inline_value = InlineValue();
|
||||||
|
other.GetValue()->CloneInto(GetValue());
|
||||||
|
value_.inline_value.has_value = true;
|
||||||
|
} else {
|
||||||
|
value_.heap_value = HeapValue(other.GetValue()->Clone());
|
||||||
|
is_inline_ = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline Variant::Variant(Variant&& other) noexcept
|
||||||
|
: is_inline_(other.is_inline_) {
|
||||||
|
if (!other.is_empty()) {
|
||||||
|
if (other.IsInlineValue()) {
|
||||||
|
value_.inline_value = InlineValue();
|
||||||
|
other.GetValue()->DefaultConstructIn(GetValue());
|
||||||
|
other.GetValue()->MoveTo(GetValue());
|
||||||
|
value_.inline_value.has_value = true;
|
||||||
|
other.value_.ResetMemory();
|
||||||
|
other.is_inline_ = false;
|
||||||
|
} else {
|
||||||
|
value_.heap_value = std::move(other.value_.heap_value);
|
||||||
|
other.value_.ResetMemory();
|
||||||
|
is_inline_ = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename VT>
|
||||||
|
void Variant::Value<VT>::operator delete(void* ptr) {
|
||||||
|
if (!CanInlineType<VT>()) port::Free(ptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename VT,
|
||||||
|
typename std::enable_if<!std::is_same<Variant, VT>::value &&
|
||||||
|
std::is_move_constructible<VT>::value,
|
||||||
|
void>::type*>
|
||||||
|
inline Variant::Variant(T&& value) : is_inline_(CanInlineType<VT>()) {
|
||||||
|
InsertValueMove<T, VT>(std::forward<T>(value));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename VT,
|
||||||
|
typename std::enable_if<!std::is_same<Variant, VT>::value &&
|
||||||
|
std::is_copy_constructible<VT>::value,
|
||||||
|
void>::type*>
|
||||||
|
inline Variant::Variant(const T& value) : is_inline_(CanInlineType<VT>()) {
|
||||||
|
InsertValueCopy<T, VT>(value);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename VT,
|
||||||
|
typename std::enable_if<!std::is_same<Variant, VT>::value &&
|
||||||
|
std::is_move_constructible<VT>::value,
|
||||||
|
void>::type*>
|
||||||
|
inline Variant& Variant::operator=(T&& value) {
|
||||||
|
clear();
|
||||||
|
is_inline_ = CanInlineType<VT>();
|
||||||
|
InsertValueMove<T, VT>(std::forward<T>(value));
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename VT,
|
||||||
|
typename std::enable_if<!std::is_same<Variant, VT>::value &&
|
||||||
|
std::is_copy_constructible<VT>::value,
|
||||||
|
void>::type*>
|
||||||
|
inline Variant& Variant::operator=(const T& value) {
|
||||||
|
clear();
|
||||||
|
is_inline_ = CanInlineType<VT>();
|
||||||
|
InsertValueCopy<T, VT>(value);
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void Variant::HeapOrInline::ResetMemory() {
|
||||||
|
memset( // NOLINT: not TriviallyCopyable
|
||||||
|
this, 0, sizeof(Variant::HeapOrInline));
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void Variant::clear() noexcept {
|
||||||
|
if (!is_empty()) {
|
||||||
|
if (IsInlineValue()) {
|
||||||
|
value_.inline_value.~InlineValue();
|
||||||
|
} else {
|
||||||
|
value_.heap_value.~HeapValue();
|
||||||
|
}
|
||||||
|
value_.ResetMemory();
|
||||||
|
}
|
||||||
|
is_inline_ = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void Variant::swap(Variant& other) noexcept {
|
||||||
|
if (is_empty()) {
|
||||||
|
if (other.IsInlineValue()) {
|
||||||
|
value_.ResetMemory();
|
||||||
|
value_.inline_value = std::move(other.value_.inline_value);
|
||||||
|
other.value_.ResetMemory();
|
||||||
|
other.value_.heap_value = HeapValue();
|
||||||
|
is_inline_ = true;
|
||||||
|
other.is_inline_ = false;
|
||||||
|
} else {
|
||||||
|
value_.ResetMemory();
|
||||||
|
value_.heap_value = std::move(other.value_.heap_value);
|
||||||
|
other.value_.ResetMemory();
|
||||||
|
other.value_.heap_value = HeapValue();
|
||||||
|
is_inline_ = false;
|
||||||
|
other.is_inline_ = false;
|
||||||
|
}
|
||||||
|
} else if (other.is_empty()) {
|
||||||
|
if (IsInlineValue()) {
|
||||||
|
other.value_.ResetMemory();
|
||||||
|
other.value_.inline_value = std::move(value_.inline_value);
|
||||||
|
value_.ResetMemory();
|
||||||
|
value_.heap_value = HeapValue();
|
||||||
|
other.is_inline_ = true;
|
||||||
|
is_inline_ = false;
|
||||||
|
} else {
|
||||||
|
other.value_.ResetMemory();
|
||||||
|
other.value_.heap_value = std::move(value_.heap_value);
|
||||||
|
value_.ResetMemory();
|
||||||
|
value_.heap_value = HeapValue();
|
||||||
|
other.is_inline_ = false;
|
||||||
|
is_inline_ = false;
|
||||||
|
}
|
||||||
|
} else { // Both Variants have values.
|
||||||
|
if (other.IsInlineValue() && IsInlineValue()) {
|
||||||
|
std::swap(value_.inline_value, other.value_.inline_value);
|
||||||
|
} else if (!other.IsInlineValue() && !IsInlineValue()) {
|
||||||
|
std::swap(value_.heap_value, other.value_.heap_value);
|
||||||
|
} else if (other.IsInlineValue() && !IsInlineValue()) {
|
||||||
|
HeapValue v = std::move(value_.heap_value);
|
||||||
|
value_.ResetMemory();
|
||||||
|
value_.inline_value = std::move(other.value_.inline_value);
|
||||||
|
other.value_.ResetMemory();
|
||||||
|
other.value_.heap_value = std::move(v);
|
||||||
|
is_inline_ = true;
|
||||||
|
other.is_inline_ = false;
|
||||||
|
} else { // !other.IsInlineValue() && IsInlineValue()
|
||||||
|
HeapValue v = std::move(other.value_.heap_value);
|
||||||
|
other.value_.ResetMemory();
|
||||||
|
other.value_.inline_value = std::move(value_.inline_value);
|
||||||
|
value_.ResetMemory();
|
||||||
|
value_.heap_value = std::move(v);
|
||||||
|
is_inline_ = false;
|
||||||
|
other.is_inline_ = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
void* Variant::get();
|
void* Variant::get();
|
||||||
|
|
||||||
|
@ -62,6 +62,10 @@ class VariantTensorData {
|
|||||||
return GetMetadata<T>(value, PODResolver<T>());
|
return GetMetadata<T>(value, PODResolver<T>());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
string& metadata_string() { return metadata_; }
|
||||||
|
|
||||||
|
const string& metadata_string() const { return metadata_; }
|
||||||
|
|
||||||
// Tensors contained within objects being serialized.
|
// Tensors contained within objects being serialized.
|
||||||
int tensors_size() const;
|
int tensors_size() const;
|
||||||
const Tensor& tensors(int index) const;
|
const Tensor& tensors(int index) const;
|
||||||
|
@ -13,15 +13,18 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include "tensorflow/core/framework/variant.h"
|
#include "tensorflow/core/framework/variant.h"
|
||||||
#include "tensorflow/core/framework/variant_encode_decode.h"
|
|
||||||
#include "tensorflow/core/framework/variant_tensor_data.h"
|
#include <xmmintrin.h>
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
#include "tensorflow/core/framework/tensor.pb.h"
|
#include "tensorflow/core/framework/tensor.pb.h"
|
||||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||||
|
#include "tensorflow/core/framework/variant_encode_decode.h"
|
||||||
|
#include "tensorflow/core/framework/variant_tensor_data.h"
|
||||||
|
#include "tensorflow/core/kernels/ops_testutil.h"
|
||||||
#include "tensorflow/core/lib/core/coding.h"
|
#include "tensorflow/core/lib/core/coding.h"
|
||||||
#include "tensorflow/core/platform/test.h"
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
|
||||||
@ -29,17 +32,133 @@ namespace tensorflow {
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
template <typename T>
|
template <typename T, bool BIG>
|
||||||
struct Wrapper {
|
struct Wrapper {
|
||||||
T value;
|
T value;
|
||||||
|
char big[BIG ? 256 : 0];
|
||||||
string TypeName() const { return "POD"; }
|
string TypeName() const { return "POD"; }
|
||||||
};
|
};
|
||||||
|
|
||||||
using Int = Wrapper<int>;
|
template <bool BIG>
|
||||||
using Float = Wrapper<float>;
|
using Int = Wrapper<int, BIG>;
|
||||||
|
|
||||||
|
template <bool BIG>
|
||||||
|
using Float = Wrapper<float, BIG>;
|
||||||
|
|
||||||
|
template <bool BIG>
|
||||||
|
class DeleteCounter {
|
||||||
|
public:
|
||||||
|
DeleteCounter() : big_{}, counter_(nullptr) {}
|
||||||
|
explicit DeleteCounter(int* counter) : big_{}, counter_(counter) {}
|
||||||
|
~DeleteCounter() {
|
||||||
|
if (counter_) ++*counter_;
|
||||||
|
}
|
||||||
|
// Need custom move operations because int* just gets copied on move, but we
|
||||||
|
// need to clear counter_ on move.
|
||||||
|
DeleteCounter& operator=(const DeleteCounter& rhs) = default;
|
||||||
|
DeleteCounter& operator=(DeleteCounter&& rhs) {
|
||||||
|
if (this == &rhs) return *this;
|
||||||
|
counter_ = rhs.counter_;
|
||||||
|
rhs.counter_ = nullptr;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
DeleteCounter(DeleteCounter&& rhs) {
|
||||||
|
counter_ = rhs.counter_;
|
||||||
|
rhs.counter_ = nullptr;
|
||||||
|
}
|
||||||
|
DeleteCounter(const DeleteCounter& rhs) = default;
|
||||||
|
char big_[BIG ? 256 : 0];
|
||||||
|
int* counter_;
|
||||||
|
|
||||||
|
string TypeName() const { return "DeleteCounter"; }
|
||||||
|
void Encode(VariantTensorData* data) const {}
|
||||||
|
bool Decode(VariantTensorData data) { return false; }
|
||||||
|
};
|
||||||
|
|
||||||
} // end namespace
|
} // end namespace
|
||||||
|
|
||||||
|
TEST(VariantTest, MoveAndCopyBetweenBigAndSmall) {
|
||||||
|
Variant x;
|
||||||
|
int deleted_big = 0;
|
||||||
|
int deleted_small = 0;
|
||||||
|
x = DeleteCounter</*BIG=*/true>(&deleted_big);
|
||||||
|
EXPECT_EQ(deleted_big, 0);
|
||||||
|
x = DeleteCounter</*BIG=*/false>(&deleted_small);
|
||||||
|
EXPECT_EQ(deleted_big, 1);
|
||||||
|
EXPECT_EQ(deleted_small, 0);
|
||||||
|
x = DeleteCounter</*BIG=*/true>(&deleted_big);
|
||||||
|
EXPECT_EQ(deleted_big, 1);
|
||||||
|
EXPECT_EQ(deleted_small, 1);
|
||||||
|
x.clear();
|
||||||
|
EXPECT_EQ(deleted_big, 2);
|
||||||
|
EXPECT_EQ(deleted_small, 1);
|
||||||
|
DeleteCounter</*BIG=*/true> big(&deleted_big);
|
||||||
|
DeleteCounter</*BIG=*/false> small(&deleted_small);
|
||||||
|
EXPECT_EQ(deleted_big, 2);
|
||||||
|
EXPECT_EQ(deleted_small, 1);
|
||||||
|
x = big;
|
||||||
|
EXPECT_EQ(deleted_big, 2);
|
||||||
|
EXPECT_EQ(deleted_small, 1);
|
||||||
|
x = small;
|
||||||
|
EXPECT_EQ(deleted_big, 3);
|
||||||
|
EXPECT_EQ(deleted_small, 1);
|
||||||
|
x = std::move(big);
|
||||||
|
EXPECT_EQ(deleted_big, 3);
|
||||||
|
EXPECT_EQ(deleted_small, 2);
|
||||||
|
x = std::move(small);
|
||||||
|
EXPECT_EQ(deleted_big, 4);
|
||||||
|
EXPECT_EQ(deleted_small, 2);
|
||||||
|
x.clear();
|
||||||
|
EXPECT_EQ(deleted_big, 4);
|
||||||
|
EXPECT_EQ(deleted_small, 3);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(VariantTest, MoveAndCopyBetweenBigAndSmallVariants) {
|
||||||
|
int deleted_big = 0;
|
||||||
|
int deleted_small = 0;
|
||||||
|
{
|
||||||
|
Variant x = DeleteCounter</*BIG=*/true>(&deleted_big);
|
||||||
|
Variant y = DeleteCounter</*BIG=*/false>(&deleted_small);
|
||||||
|
EXPECT_EQ(deleted_big, 0);
|
||||||
|
EXPECT_EQ(deleted_small, 0);
|
||||||
|
x = y;
|
||||||
|
EXPECT_EQ(deleted_big, 1);
|
||||||
|
EXPECT_EQ(deleted_small, 0);
|
||||||
|
x = x;
|
||||||
|
EXPECT_EQ(deleted_big, 1);
|
||||||
|
EXPECT_EQ(deleted_small, 0);
|
||||||
|
EXPECT_NE(x.get<DeleteCounter<false>>(), nullptr);
|
||||||
|
EXPECT_NE(y.get<DeleteCounter<false>>(), nullptr);
|
||||||
|
x = std::move(y);
|
||||||
|
EXPECT_EQ(deleted_small, 1);
|
||||||
|
EXPECT_NE(x.get<DeleteCounter<false>>(), nullptr);
|
||||||
|
}
|
||||||
|
EXPECT_EQ(deleted_big, 1);
|
||||||
|
EXPECT_EQ(deleted_small, 2);
|
||||||
|
|
||||||
|
deleted_big = 0;
|
||||||
|
deleted_small = 0;
|
||||||
|
{
|
||||||
|
Variant x = DeleteCounter</*BIG=*/false>(&deleted_small);
|
||||||
|
Variant y = DeleteCounter</*BIG=*/true>(&deleted_big);
|
||||||
|
EXPECT_EQ(deleted_big, 0);
|
||||||
|
EXPECT_EQ(deleted_small, 0);
|
||||||
|
x = y;
|
||||||
|
EXPECT_EQ(deleted_big, 0);
|
||||||
|
EXPECT_EQ(deleted_small, 1);
|
||||||
|
x = x;
|
||||||
|
EXPECT_EQ(deleted_big, 0);
|
||||||
|
EXPECT_EQ(deleted_small, 1);
|
||||||
|
EXPECT_NE(x.get<DeleteCounter<true>>(), nullptr);
|
||||||
|
EXPECT_NE(y.get<DeleteCounter<true>>(), nullptr);
|
||||||
|
x = std::move(y);
|
||||||
|
EXPECT_EQ(deleted_big, 1);
|
||||||
|
EXPECT_NE(x.get<DeleteCounter<true>>(), nullptr);
|
||||||
|
}
|
||||||
|
EXPECT_EQ(deleted_big, 2);
|
||||||
|
EXPECT_EQ(deleted_small, 1);
|
||||||
|
}
|
||||||
|
|
||||||
TEST(VariantTest, Int) {
|
TEST(VariantTest, Int) {
|
||||||
Variant x;
|
Variant x;
|
||||||
EXPECT_EQ(x.get<void>(), nullptr);
|
EXPECT_EQ(x.get<void>(), nullptr);
|
||||||
@ -49,45 +168,125 @@ TEST(VariantTest, Int) {
|
|||||||
EXPECT_EQ(x.TypeName(), "int");
|
EXPECT_EQ(x.TypeName(), "int");
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(VariantTest, Basic) {
|
struct MayCreateAlignmentDifficulties {
|
||||||
|
int a;
|
||||||
|
__m128 b;
|
||||||
|
};
|
||||||
|
|
||||||
|
bool M128AllEqual(const __m128& a, const __m128& b) {
|
||||||
|
return _mm_movemask_ps(_mm_cmpeq_ps(a, b)) == 0xf;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(VariantTest, NotAlignable) {
|
||||||
|
Variant x;
|
||||||
|
EXPECT_EQ(x.get<void>(), nullptr);
|
||||||
|
__m128 v = _mm_set_ps(1.0, 2.0, 3.0, 4.0);
|
||||||
|
x = MayCreateAlignmentDifficulties{-1, v};
|
||||||
|
EXPECT_NE(x.get<void>(), nullptr);
|
||||||
|
auto* x_val = x.get<MayCreateAlignmentDifficulties>();
|
||||||
|
// check that *x_val == x
|
||||||
|
Variant y = x;
|
||||||
|
EXPECT_EQ(x_val->a, -1);
|
||||||
|
EXPECT_TRUE(M128AllEqual(x_val->b, v));
|
||||||
|
auto* y_val = y.get<MayCreateAlignmentDifficulties>();
|
||||||
|
EXPECT_EQ(y_val->a, -1);
|
||||||
|
EXPECT_TRUE(M128AllEqual(y_val->b, v));
|
||||||
|
Variant z = std::move(y);
|
||||||
|
auto* z_val = z.get<MayCreateAlignmentDifficulties>();
|
||||||
|
EXPECT_EQ(z_val->a, -1);
|
||||||
|
EXPECT_TRUE(M128AllEqual(z_val->b, v));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <bool BIG>
|
||||||
|
void TestBasic() {
|
||||||
Variant x;
|
Variant x;
|
||||||
EXPECT_EQ(x.get<void>(), nullptr);
|
EXPECT_EQ(x.get<void>(), nullptr);
|
||||||
|
|
||||||
x = Int{42};
|
x = Int<BIG>{42};
|
||||||
|
|
||||||
EXPECT_NE(x.get<void>(), nullptr);
|
EXPECT_NE(x.get<void>(), nullptr);
|
||||||
EXPECT_NE(x.get<Int>(), nullptr);
|
EXPECT_NE(x.get<Int<BIG>>(), nullptr);
|
||||||
EXPECT_EQ(x.get<Int>()->value, 42);
|
EXPECT_EQ(x.get<Int<BIG>>()->value, 42);
|
||||||
EXPECT_EQ(x.TypeName(), "POD");
|
EXPECT_EQ(x.TypeName(), "POD");
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(VariantTest, ConstGet) {
|
TEST(VariantTest, Basic) { TestBasic<false>(); }
|
||||||
|
|
||||||
|
TEST(VariantTest, BasicBig) { TestBasic<true>(); }
|
||||||
|
|
||||||
|
template <bool BIG>
|
||||||
|
void TestConstGet() {
|
||||||
Variant x;
|
Variant x;
|
||||||
EXPECT_EQ(x.get<void>(), nullptr);
|
EXPECT_EQ(x.get<void>(), nullptr);
|
||||||
|
|
||||||
x = Int{42};
|
x = Int<BIG>{42};
|
||||||
|
|
||||||
const Variant y = x;
|
const Variant y = x;
|
||||||
|
|
||||||
EXPECT_NE(y.get<void>(), nullptr);
|
EXPECT_NE(y.get<void>(), nullptr);
|
||||||
EXPECT_NE(y.get<Int>(), nullptr);
|
EXPECT_NE(y.get<Int<BIG>>(), nullptr);
|
||||||
EXPECT_EQ(y.get<Int>()->value, 42);
|
EXPECT_EQ(y.get<Int<BIG>>()->value, 42);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(VariantTest, Clear) {
|
TEST(VariantTest, ConstGet) { TestConstGet<false>(); }
|
||||||
|
|
||||||
|
TEST(VariantTest, ConstGetBig) { TestConstGet<true>(); }
|
||||||
|
|
||||||
|
template <bool BIG>
|
||||||
|
void TestClear() {
|
||||||
Variant x;
|
Variant x;
|
||||||
EXPECT_EQ(x.get<void>(), nullptr);
|
EXPECT_EQ(x.get<void>(), nullptr);
|
||||||
|
|
||||||
x = Int{42};
|
x = Int<BIG>{42};
|
||||||
|
|
||||||
EXPECT_NE(x.get<void>(), nullptr);
|
EXPECT_NE(x.get<void>(), nullptr);
|
||||||
EXPECT_NE(x.get<Int>(), nullptr);
|
EXPECT_NE(x.get<Int<BIG>>(), nullptr);
|
||||||
EXPECT_EQ(x.get<Int>()->value, 42);
|
EXPECT_EQ(x.get<Int<BIG>>()->value, 42);
|
||||||
|
|
||||||
x.clear();
|
x.clear();
|
||||||
EXPECT_EQ(x.get<void>(), nullptr);
|
EXPECT_EQ(x.get<void>(), nullptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(VariantTest, Clear) { TestClear<false>(); }
|
||||||
|
|
||||||
|
TEST(VariantTest, ClearBig) { TestClear<true>(); }
|
||||||
|
|
||||||
|
template <bool BIG>
|
||||||
|
void TestClearDeletes() {
|
||||||
|
Variant x;
|
||||||
|
EXPECT_EQ(x.get<void>(), nullptr);
|
||||||
|
|
||||||
|
int deleted_count = 0;
|
||||||
|
using DC = DeleteCounter<BIG>;
|
||||||
|
DC dc(&deleted_count);
|
||||||
|
EXPECT_EQ(deleted_count, 0);
|
||||||
|
x = dc;
|
||||||
|
EXPECT_EQ(deleted_count, 0);
|
||||||
|
|
||||||
|
EXPECT_NE(x.get<void>(), nullptr);
|
||||||
|
EXPECT_NE(x.get<DC>(), nullptr);
|
||||||
|
|
||||||
|
x.clear();
|
||||||
|
EXPECT_EQ(x.get<void>(), nullptr);
|
||||||
|
EXPECT_EQ(deleted_count, 1);
|
||||||
|
|
||||||
|
x = dc;
|
||||||
|
EXPECT_EQ(deleted_count, 1);
|
||||||
|
|
||||||
|
Variant y = x;
|
||||||
|
EXPECT_EQ(deleted_count, 1);
|
||||||
|
|
||||||
|
x.clear();
|
||||||
|
EXPECT_EQ(deleted_count, 2);
|
||||||
|
|
||||||
|
y.clear();
|
||||||
|
EXPECT_EQ(deleted_count, 3);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(VariantTest, ClearDeletesOnHeap) { TestClearDeletes</*BIG=*/true>(); }
|
||||||
|
|
||||||
|
TEST(VariantTest, ClearDeletesOnStack) { TestClearDeletes</*BIG=*/false>(); }
|
||||||
|
|
||||||
TEST(VariantTest, Tensor) {
|
TEST(VariantTest, Tensor) {
|
||||||
Variant x;
|
Variant x;
|
||||||
Tensor t(DT_FLOAT, {});
|
Tensor t(DT_FLOAT, {});
|
||||||
@ -101,6 +300,16 @@ TEST(VariantTest, Tensor) {
|
|||||||
EXPECT_EQ(x.TypeName(), "tensorflow::Tensor");
|
EXPECT_EQ(x.TypeName(), "tensorflow::Tensor");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(VariantTest, NontrivialTensorVariantCopy) {
|
||||||
|
Tensor variants(DT_VARIANT, {});
|
||||||
|
Tensor t(true);
|
||||||
|
test::FillValues<Variant>(&variants, gtl::ArraySlice<Variant>({t}));
|
||||||
|
const Tensor* t_c = variants.flat<Variant>()(0).get<Tensor>();
|
||||||
|
EXPECT_EQ(t_c->dtype(), t.dtype());
|
||||||
|
EXPECT_EQ(t_c->shape(), t.shape());
|
||||||
|
EXPECT_EQ(t_c->scalar<bool>()(), t.scalar<bool>()());
|
||||||
|
}
|
||||||
|
|
||||||
TEST(VariantTest, TensorProto) {
|
TEST(VariantTest, TensorProto) {
|
||||||
Variant x;
|
Variant x;
|
||||||
TensorProto t;
|
TensorProto t;
|
||||||
@ -114,31 +323,41 @@ TEST(VariantTest, TensorProto) {
|
|||||||
EXPECT_EQ(x.get<TensorProto>()->tensor_shape().unknown_rank(), true);
|
EXPECT_EQ(x.get<TensorProto>()->tensor_shape().unknown_rank(), true);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(VariantTest, CopyValue) {
|
template <bool BIG>
|
||||||
|
void TestCopyValue() {
|
||||||
Variant x, y;
|
Variant x, y;
|
||||||
x = Int{10};
|
x = Int<BIG>{10};
|
||||||
y = x;
|
y = x;
|
||||||
|
|
||||||
EXPECT_EQ(x.get<Int>()->value, 10);
|
EXPECT_EQ(x.get<Int<BIG>>()->value, 10);
|
||||||
EXPECT_EQ(x.get<Int>()->value, y.get<Int>()->value);
|
EXPECT_EQ(x.get<Int<BIG>>()->value, y.get<Int<BIG>>()->value);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(VariantTest, MoveValue) {
|
TEST(VariantTest, CopyValue) { TestCopyValue<false>(); }
|
||||||
|
|
||||||
|
TEST(VariantTest, CopyValueBig) { TestCopyValue<true>(); }
|
||||||
|
|
||||||
|
template <bool BIG>
|
||||||
|
void TestMoveValue() {
|
||||||
Variant x;
|
Variant x;
|
||||||
x = []() -> Variant {
|
x = []() -> Variant {
|
||||||
Variant y;
|
Variant y;
|
||||||
y = Int{10};
|
y = Int<BIG>{10};
|
||||||
return y;
|
return y;
|
||||||
}();
|
}();
|
||||||
EXPECT_EQ(x.get<Int>()->value, 10);
|
EXPECT_EQ(x.get<Int<BIG>>()->value, 10);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(VariantTest, MoveValue) { TestMoveValue<false>(); }
|
||||||
|
|
||||||
|
TEST(VariantTest, MoveValueBig) { TestMoveValue<true>(); }
|
||||||
|
|
||||||
TEST(VariantTest, TypeMismatch) {
|
TEST(VariantTest, TypeMismatch) {
|
||||||
Variant x;
|
Variant x;
|
||||||
x = Int{10};
|
x = Int<false>{10};
|
||||||
EXPECT_EQ(x.get<float>(), nullptr);
|
EXPECT_EQ(x.get<float>(), nullptr);
|
||||||
EXPECT_EQ(x.get<int>(), nullptr);
|
EXPECT_EQ(x.get<int>(), nullptr);
|
||||||
EXPECT_NE(x.get<Int>(), nullptr);
|
EXPECT_NE(x.get<Int<false>>(), nullptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct TensorList {
|
struct TensorList {
|
||||||
@ -206,19 +425,26 @@ TEST(VariantTest, TensorListTest) {
|
|||||||
"Variant<type: TensorList value: ", data.DebugString(), ">"));
|
"Variant<type: TensorList value: ", data.DebugString(), ">"));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(VariantTest, VariantArray) {
|
template <bool BIG>
|
||||||
|
void TestVariantArray() {
|
||||||
Variant x[2];
|
Variant x[2];
|
||||||
x[0] = Int{2};
|
x[0] = Int<BIG>{2};
|
||||||
x[1] = Float{2.0f};
|
x[1] = Float<BIG>{2.0f};
|
||||||
|
|
||||||
EXPECT_EQ(x[0].get<Int>()->value, 2);
|
EXPECT_EQ(x[0].get<Int<BIG>>()->value, 2);
|
||||||
EXPECT_EQ(x[1].get<Float>()->value, 2.0f);
|
EXPECT_EQ(x[1].get<Float<BIG>>()->value, 2.0f);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(VariantTest, PodUpdate) {
|
TEST(VariantTest, VariantArray) { TestVariantArray<false>(); }
|
||||||
|
|
||||||
|
TEST(VariantTest, VariantArrayBig) { TestVariantArray<true>(); }
|
||||||
|
|
||||||
|
template <bool BIG>
|
||||||
|
void PodUpdateTest() {
|
||||||
struct Pod {
|
struct Pod {
|
||||||
int x;
|
int x;
|
||||||
float y;
|
float y;
|
||||||
|
char big[BIG ? 256 : 0];
|
||||||
|
|
||||||
string TypeName() const { return "POD"; }
|
string TypeName() const { return "POD"; }
|
||||||
};
|
};
|
||||||
@ -232,10 +458,16 @@ TEST(VariantTest, PodUpdate) {
|
|||||||
EXPECT_EQ(x.get<Pod>()->x, 30);
|
EXPECT_EQ(x.get<Pod>()->x, 30);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(VariantTest, EncodeDecodePod) {
|
TEST(VariantTest, PodUpdate) { PodUpdateTest<false>(); }
|
||||||
|
|
||||||
|
TEST(VariantTest, PodUpdateBig) { PodUpdateTest<true>(); }
|
||||||
|
|
||||||
|
template <bool BIG>
|
||||||
|
void TestEncodeDecodePod() {
|
||||||
struct Pod {
|
struct Pod {
|
||||||
int x;
|
int x;
|
||||||
float y;
|
float y;
|
||||||
|
char big[BIG ? 256 : 0];
|
||||||
|
|
||||||
string TypeName() const { return "POD"; }
|
string TypeName() const { return "POD"; }
|
||||||
};
|
};
|
||||||
@ -247,14 +479,17 @@ TEST(VariantTest, EncodeDecodePod) {
|
|||||||
VariantTensorData serialized;
|
VariantTensorData serialized;
|
||||||
x.Encode(&serialized);
|
x.Encode(&serialized);
|
||||||
|
|
||||||
Variant y;
|
Variant y = Pod{};
|
||||||
y = Pod();
|
|
||||||
y.Decode(serialized);
|
y.Decode(serialized);
|
||||||
|
|
||||||
EXPECT_EQ(p.x, y.get<Pod>()->x);
|
EXPECT_EQ(p.x, y.get<Pod>()->x);
|
||||||
EXPECT_EQ(p.y, y.get<Pod>()->y);
|
EXPECT_EQ(p.y, y.get<Pod>()->y);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(VariantTest, EncodeDecodePod) { TestEncodeDecodePod<false>(); }
|
||||||
|
|
||||||
|
TEST(VariantTest, EncodeDecodePodBig) { TestEncodeDecodePod<true>(); }
|
||||||
|
|
||||||
TEST(VariantTest, EncodeDecodeTensor) {
|
TEST(VariantTest, EncodeDecodeTensor) {
|
||||||
Variant x;
|
Variant x;
|
||||||
Tensor t(DT_INT32, {});
|
Tensor t(DT_INT32, {});
|
||||||
|
@ -237,6 +237,8 @@ class IteratorResource : public ResourceBase {
|
|||||||
// destroyed, essentially triggering the iterator deletion.
|
// destroyed, essentially triggering the iterator deletion.
|
||||||
class Deleter {
|
class Deleter {
|
||||||
public:
|
public:
|
||||||
|
Deleter() : deleter_() {}
|
||||||
|
|
||||||
Deleter(ResourceHandle handle, ResourceMgr* resource_manager)
|
Deleter(ResourceHandle handle, ResourceMgr* resource_manager)
|
||||||
: deleter_(std::make_shared<Helper>(handle, resource_manager)) {}
|
: deleter_(std::make_shared<Helper>(handle, resource_manager)) {}
|
||||||
|
|
||||||
@ -248,6 +250,10 @@ class IteratorResource : public ResourceBase {
|
|||||||
VLOG(3) << "IteratorResource::Deleter copy constructor called.";
|
VLOG(3) << "IteratorResource::Deleter copy constructor called.";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Deleter& operator=(const Deleter& rhs) = delete;
|
||||||
|
|
||||||
|
Deleter& operator=(Deleter&& rhs) = default;
|
||||||
|
|
||||||
virtual ~Deleter() {
|
virtual ~Deleter() {
|
||||||
VLOG(3) << "IteratorResource::Deleter destructor called.";
|
VLOG(3) << "IteratorResource::Deleter destructor called.";
|
||||||
}
|
}
|
||||||
@ -358,6 +364,9 @@ class IteratorStateVariant {
|
|||||||
Decode(*other.data_);
|
Decode(*other.data_);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
IteratorStateVariant& operator=(IteratorStateVariant&& other) = default;
|
||||||
|
IteratorStateVariant& operator=(const IteratorStateVariant& other) = delete;
|
||||||
|
|
||||||
// Initializes this object with the current state of the iterator so
|
// Initializes this object with the current state of the iterator so
|
||||||
// that it can be written on the next call to Encode().
|
// that it can be written on the next call to Encode().
|
||||||
Status InitializeFromIterator(OpKernelContext* ctx,
|
Status InitializeFromIterator(OpKernelContext* ctx,
|
||||||
|
@ -74,6 +74,8 @@ class Mutex : public ResourceBase {
|
|||||||
struct SharedLockReleaser {
|
struct SharedLockReleaser {
|
||||||
std::shared_ptr<LockReleaser> shared_lock;
|
std::shared_ptr<LockReleaser> shared_lock;
|
||||||
|
|
||||||
|
SharedLockReleaser() : shared_lock() {}
|
||||||
|
|
||||||
explicit SharedLockReleaser(std::shared_ptr<LockReleaser>&& lock)
|
explicit SharedLockReleaser(std::shared_ptr<LockReleaser>&& lock)
|
||||||
: shared_lock(std::forward<decltype(lock)>(lock)) {
|
: shared_lock(std::forward<decltype(lock)>(lock)) {
|
||||||
VLOG(3) << "Creating shared_ptr of " << shared_lock.get()
|
VLOG(3) << "Creating shared_ptr of " << shared_lock.get()
|
||||||
@ -86,6 +88,16 @@ class Mutex : public ResourceBase {
|
|||||||
<< " count is: " << shared_lock.use_count();
|
<< " count is: " << shared_lock.use_count();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
SharedLockReleaser& operator=(const SharedLockReleaser& rhs) = delete;
|
||||||
|
|
||||||
|
SharedLockReleaser& operator=(SharedLockReleaser&& rhs) {
|
||||||
|
if (&rhs == this) return *this;
|
||||||
|
std::swap(shared_lock, rhs.shared_lock);
|
||||||
|
VLOG(3) << "Move-assign of SharedLockReleaser of " << shared_lock.get()
|
||||||
|
<< " count is: " << shared_lock.use_count();
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
SharedLockReleaser(const SharedLockReleaser& rhs)
|
SharedLockReleaser(const SharedLockReleaser& rhs)
|
||||||
: shared_lock(rhs.shared_lock) {
|
: shared_lock(rhs.shared_lock) {
|
||||||
VLOG(3) << "Copying SharedLockReleaser of " << shared_lock.get()
|
VLOG(3) << "Copying SharedLockReleaser of " << shared_lock.get()
|
||||||
|
Loading…
Reference in New Issue
Block a user