Factor out Variant wrapper for shared ptrs.

PiperOrigin-RevId: 269446584
This commit is contained in:
A. Unique TensorFlower 2019-09-16 16:24:32 -07:00 committed by TensorFlower Gardener
parent 5059b5968b
commit 8f9e0e304d
3 changed files with 80 additions and 48 deletions

View File

@@ -869,6 +869,7 @@ tf_cuda_library(
"framework/variant_encode_decode.h",
"framework/variant_op_registry.h",
"framework/variant_tensor_data.h",
"framework/shared_ptr_variant.h",
"framework/allocator_registry.h",
"framework/attr_value_util.h",
"framework/bfloat16.h",
@@ -2790,6 +2791,7 @@ FRAMEWORK_INTERNAL_PUBLIC_HEADERS = [
"framework/resource_var.h",
"framework/run_handler.h",
"framework/run_handler_util.h",
"framework/shared_ptr_variant.h",
"framework/tensor_reference.h",
"framework/tracking_allocator.h", # only needed for tests
"framework/unique_tensor_references.h",

View File

@@ -0,0 +1,74 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_FRAMEWORK_SHARED_PTR_VARIANT_H_
#define TENSORFLOW_CORE_FRAMEWORK_SHARED_PTR_VARIANT_H_
#include <memory>
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
// Wraps a std::shared_ptr<T> so it can be stored in a tensorflow::Variant.
// Copy/move of the wrapper follows shared_ptr semantics (copies bump the
// refcount).  Encode/Decode are intentionally unsupported: the pointee is a
// process-local object and cannot be serialized.
template <typename T>
struct SharedPtrVariant {
  std::shared_ptr<T> shared_ptr;

  SharedPtrVariant() : shared_ptr() {}

  explicit SharedPtrVariant(std::shared_ptr<T>&& ptr)
      // `ptr` is a named rvalue-reference parameter, so std::move (not
      // std::forward) is the idiomatic way to transfer ownership from it.
      : shared_ptr(std::move(ptr)) {
    VLOG(3) << "Creating shared_ptr of " << shared_ptr.get()
            << " count is: " << shared_ptr.use_count();
  }

  // noexcept is required so containers / Variant storage use the move
  // (std::move_if_noexcept) instead of falling back to a refcount-bumping
  // copy on reallocation.  shared_ptr's move never throws.
  SharedPtrVariant(SharedPtrVariant&& rhs) noexcept
      : shared_ptr(std::move(rhs.shared_ptr)) {
    VLOG(3) << "Moving SharedPtrVariant of " << shared_ptr.get()
            << " count is: " << shared_ptr.use_count();
  }

  SharedPtrVariant& operator=(const SharedPtrVariant& rhs) = delete;

  SharedPtrVariant& operator=(SharedPtrVariant&& rhs) noexcept {
    if (&rhs == this) return *this;
    // Swap (rather than overwrite) so our previous pointee, if any, is
    // released when `rhs` is destroyed rather than immediately here.
    std::swap(shared_ptr, rhs.shared_ptr);
    VLOG(3) << "Move-assign of SharedPtrVariant of " << shared_ptr.get()
            << " count is: " << shared_ptr.use_count();
    return *this;
  }

  SharedPtrVariant(const SharedPtrVariant& rhs) : shared_ptr(rhs.shared_ptr) {
    VLOG(3) << "Copying SharedPtrVariant of " << shared_ptr.get()
            << " count is: " << shared_ptr.use_count();
  }

  ~SharedPtrVariant() {
    VLOG(3) << "Destroying SharedPtrVariant of " << shared_ptr.get()
            << " count is: " << shared_ptr.use_count();
  }

  void Encode(VariantTensorData*) const {
    // Not supported.
  }

  bool Decode(const VariantTensorData&) {
    return false;  // Not supported.
  }
};
} // namespace tensorflow
#endif // TENSORFLOW_CORE_FRAMEWORK_SHARED_PTR_VARIANT_H_

View File

@@ -20,6 +20,7 @@ limitations under the License.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/shared_ptr_variant.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/kernels/ops_util.h"
@@ -71,52 +72,7 @@ class Mutex : public ResourceBase {
Mutex* mutex_;
};
// Variant-compatible wrapper around a shared_ptr<LockReleaser>.
// (Removed by this commit in favor of the generic SharedPtrVariant<T>.)
// Encode/Decode are no-ops because a lock handle is process-local and
// cannot be serialized.
struct SharedLockReleaser {
std::shared_ptr<LockReleaser> shared_lock;
SharedLockReleaser() : shared_lock() {}
// Takes ownership of `lock`; `lock` is a named rvalue-ref parameter, so
// forward<decltype(lock)> is equivalent to std::move here.
explicit SharedLockReleaser(std::shared_ptr<LockReleaser>&& lock)
: shared_lock(std::forward<decltype(lock)>(lock)) {
VLOG(3) << "Creating shared_ptr of " << shared_lock.get()
<< " count is: " << shared_lock.use_count();
}
SharedLockReleaser(SharedLockReleaser&& rhs)
: shared_lock(std::move(rhs.shared_lock)) {
VLOG(3) << "Moving SharedLockReleaser of " << shared_lock.get()
<< " count is: " << shared_lock.use_count();
}
// Copy-assignment deleted: copies must go through the copy constructor.
SharedLockReleaser& operator=(const SharedLockReleaser& rhs) = delete;
SharedLockReleaser& operator=(SharedLockReleaser&& rhs) {
if (&rhs == this) return *this;
// Swap so our previous lock (if any) is released when rhs dies.
std::swap(shared_lock, rhs.shared_lock);
VLOG(3) << "Move-assign of SharedLockReleaser of " << shared_lock.get()
<< " count is: " << shared_lock.use_count();
return *this;
}
SharedLockReleaser(const SharedLockReleaser& rhs)
: shared_lock(rhs.shared_lock) {
VLOG(3) << "Copying SharedLockReleaser of " << shared_lock.get()
<< " count is: " << shared_lock.use_count();
}
~SharedLockReleaser() {
VLOG(3) << "Destroying SharedLockReleaser of " << shared_lock.get()
<< " count is: " << shared_lock.use_count();
}
// Serialization is intentionally unsupported for lock handles.
void Encode(VariantTensorData*) const {
// Not supported.
}
bool Decode(const VariantTensorData&) {
return false;  // Not supported.
}
};
typedef SharedPtrVariant<LockReleaser> SharedLockReleaser;
void AcquireAsync(
OpKernelContext* c,
@@ -205,7 +161,7 @@ class MutexLockOp : public AsyncOpKernel {
const Status& s,
Mutex::SharedLockReleaser&& lock) {
VLOG(2) << "Finished locking mutex " << mutex
<< " with lock: " << lock.shared_lock.get()
<< " with lock: " << lock.shared_ptr.get()
<< " status: " << s.ToString();
if (s.ok()) {
variant->scalar<Variant>()() = std::move(lock);
@@ -242,7 +198,7 @@ class ConsumeMutexLockOp : public OpKernel {
"Expected input to contain a SharedLockReleaser "
"object, but saw variant: '",
lock_t.scalar<Variant>()().DebugString(), "'"));
const int use_count = lock->shared_lock.use_count();
const int use_count = lock->shared_ptr.use_count();
OP_REQUIRES(
c, use_count == 1,
errors::InvalidArgument("Expected use count of lock to be 1, but saw: ",