Factor out Variant wrapper for shared ptrs.
PiperOrigin-RevId: 269446584
This commit is contained in:
parent
5059b5968b
commit
8f9e0e304d
@ -869,6 +869,7 @@ tf_cuda_library(
|
|||||||
"framework/variant_encode_decode.h",
|
"framework/variant_encode_decode.h",
|
||||||
"framework/variant_op_registry.h",
|
"framework/variant_op_registry.h",
|
||||||
"framework/variant_tensor_data.h",
|
"framework/variant_tensor_data.h",
|
||||||
|
"framework/shared_ptr_variant.h",
|
||||||
"framework/allocator_registry.h",
|
"framework/allocator_registry.h",
|
||||||
"framework/attr_value_util.h",
|
"framework/attr_value_util.h",
|
||||||
"framework/bfloat16.h",
|
"framework/bfloat16.h",
|
||||||
@ -2790,6 +2791,7 @@ FRAMEWORK_INTERNAL_PUBLIC_HEADERS = [
|
|||||||
"framework/resource_var.h",
|
"framework/resource_var.h",
|
||||||
"framework/run_handler.h",
|
"framework/run_handler.h",
|
||||||
"framework/run_handler_util.h",
|
"framework/run_handler_util.h",
|
||||||
|
"framework/shared_ptr_variant.h",
|
||||||
"framework/tensor_reference.h",
|
"framework/tensor_reference.h",
|
||||||
"framework/tracking_allocator.h", # only needed for tests
|
"framework/tracking_allocator.h", # only needed for tests
|
||||||
"framework/unique_tensor_references.h",
|
"framework/unique_tensor_references.h",
|
||||||
|
74
tensorflow/core/framework/shared_ptr_variant.h
Normal file
74
tensorflow/core/framework/shared_ptr_variant.h
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_CORE_FRAMEWORK_SHARED_PTR_VARIANT_H_
|
||||||
|
#define TENSORFLOW_CORE_FRAMEWORK_SHARED_PTR_VARIANT_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct SharedPtrVariant {
|
||||||
|
std::shared_ptr<T> shared_ptr;
|
||||||
|
|
||||||
|
SharedPtrVariant() : shared_ptr() {}
|
||||||
|
|
||||||
|
explicit SharedPtrVariant(std::shared_ptr<T>&& ptr)
|
||||||
|
: shared_ptr(std::forward<decltype(ptr)>(ptr)) {
|
||||||
|
VLOG(3) << "Creating shared_ptr of " << shared_ptr.get()
|
||||||
|
<< " count is: " << shared_ptr.use_count();
|
||||||
|
}
|
||||||
|
|
||||||
|
SharedPtrVariant(SharedPtrVariant&& rhs)
|
||||||
|
: shared_ptr(std::move(rhs.shared_ptr)) {
|
||||||
|
VLOG(3) << "Moving SharedPtrVariant of " << shared_ptr.get()
|
||||||
|
<< " count is: " << shared_ptr.use_count();
|
||||||
|
}
|
||||||
|
|
||||||
|
SharedPtrVariant& operator=(const SharedPtrVariant& rhs) = delete;
|
||||||
|
|
||||||
|
SharedPtrVariant& operator=(SharedPtrVariant&& rhs) {
|
||||||
|
if (&rhs == this) return *this;
|
||||||
|
std::swap(shared_ptr, rhs.shared_ptr);
|
||||||
|
VLOG(3) << "Move-assign of SharedPtrVariant of " << shared_ptr.get()
|
||||||
|
<< " count is: " << shared_ptr.use_count();
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
SharedPtrVariant(const SharedPtrVariant& rhs) : shared_ptr(rhs.shared_ptr) {
|
||||||
|
VLOG(3) << "Copying SharedPtrVariant of " << shared_ptr.get()
|
||||||
|
<< " count is: " << shared_ptr.use_count();
|
||||||
|
}
|
||||||
|
|
||||||
|
~SharedPtrVariant() {
|
||||||
|
VLOG(3) << "Destroying SharedPtrVariant of " << shared_ptr.get()
|
||||||
|
<< " count is: " << shared_ptr.use_count();
|
||||||
|
}
|
||||||
|
|
||||||
|
void Encode(VariantTensorData*) const {
|
||||||
|
// Not supported.
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Decode(const VariantTensorData&) {
|
||||||
|
return false; // Not supported.
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_CORE_FRAMEWORK_SHARED_PTR_VARIANT_H_
|
@ -20,6 +20,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
#include "tensorflow/core/framework/resource_mgr.h"
|
#include "tensorflow/core/framework/resource_mgr.h"
|
||||||
|
#include "tensorflow/core/framework/shared_ptr_variant.h"
|
||||||
#include "tensorflow/core/framework/variant.h"
|
#include "tensorflow/core/framework/variant.h"
|
||||||
#include "tensorflow/core/framework/variant_encode_decode.h"
|
#include "tensorflow/core/framework/variant_encode_decode.h"
|
||||||
#include "tensorflow/core/kernels/ops_util.h"
|
#include "tensorflow/core/kernels/ops_util.h"
|
||||||
@ -71,52 +72,7 @@ class Mutex : public ResourceBase {
|
|||||||
Mutex* mutex_;
|
Mutex* mutex_;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct SharedLockReleaser {
|
typedef SharedPtrVariant<LockReleaser> SharedLockReleaser;
|
||||||
std::shared_ptr<LockReleaser> shared_lock;
|
|
||||||
|
|
||||||
SharedLockReleaser() : shared_lock() {}
|
|
||||||
|
|
||||||
explicit SharedLockReleaser(std::shared_ptr<LockReleaser>&& lock)
|
|
||||||
: shared_lock(std::forward<decltype(lock)>(lock)) {
|
|
||||||
VLOG(3) << "Creating shared_ptr of " << shared_lock.get()
|
|
||||||
<< " count is: " << shared_lock.use_count();
|
|
||||||
}
|
|
||||||
|
|
||||||
SharedLockReleaser(SharedLockReleaser&& rhs)
|
|
||||||
: shared_lock(std::move(rhs.shared_lock)) {
|
|
||||||
VLOG(3) << "Moving SharedLockReleaser of " << shared_lock.get()
|
|
||||||
<< " count is: " << shared_lock.use_count();
|
|
||||||
}
|
|
||||||
|
|
||||||
SharedLockReleaser& operator=(const SharedLockReleaser& rhs) = delete;
|
|
||||||
|
|
||||||
SharedLockReleaser& operator=(SharedLockReleaser&& rhs) {
|
|
||||||
if (&rhs == this) return *this;
|
|
||||||
std::swap(shared_lock, rhs.shared_lock);
|
|
||||||
VLOG(3) << "Move-assign of SharedLockReleaser of " << shared_lock.get()
|
|
||||||
<< " count is: " << shared_lock.use_count();
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
SharedLockReleaser(const SharedLockReleaser& rhs)
|
|
||||||
: shared_lock(rhs.shared_lock) {
|
|
||||||
VLOG(3) << "Copying SharedLockReleaser of " << shared_lock.get()
|
|
||||||
<< " count is: " << shared_lock.use_count();
|
|
||||||
}
|
|
||||||
|
|
||||||
~SharedLockReleaser() {
|
|
||||||
VLOG(3) << "Destroying SharedLockReleaser of " << shared_lock.get()
|
|
||||||
<< " count is: " << shared_lock.use_count();
|
|
||||||
}
|
|
||||||
|
|
||||||
void Encode(VariantTensorData*) const {
|
|
||||||
// Not supported.
|
|
||||||
}
|
|
||||||
|
|
||||||
bool Decode(const VariantTensorData&) {
|
|
||||||
return false; // Not supported.
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
void AcquireAsync(
|
void AcquireAsync(
|
||||||
OpKernelContext* c,
|
OpKernelContext* c,
|
||||||
@ -205,7 +161,7 @@ class MutexLockOp : public AsyncOpKernel {
|
|||||||
const Status& s,
|
const Status& s,
|
||||||
Mutex::SharedLockReleaser&& lock) {
|
Mutex::SharedLockReleaser&& lock) {
|
||||||
VLOG(2) << "Finished locking mutex " << mutex
|
VLOG(2) << "Finished locking mutex " << mutex
|
||||||
<< " with lock: " << lock.shared_lock.get()
|
<< " with lock: " << lock.shared_ptr.get()
|
||||||
<< " status: " << s.ToString();
|
<< " status: " << s.ToString();
|
||||||
if (s.ok()) {
|
if (s.ok()) {
|
||||||
variant->scalar<Variant>()() = std::move(lock);
|
variant->scalar<Variant>()() = std::move(lock);
|
||||||
@ -242,7 +198,7 @@ class ConsumeMutexLockOp : public OpKernel {
|
|||||||
"Expected input to contain a SharedLockReleaser "
|
"Expected input to contain a SharedLockReleaser "
|
||||||
"object, but saw variant: '",
|
"object, but saw variant: '",
|
||||||
lock_t.scalar<Variant>()().DebugString(), "'"));
|
lock_t.scalar<Variant>()().DebugString(), "'"));
|
||||||
const int use_count = lock->shared_lock.use_count();
|
const int use_count = lock->shared_ptr.use_count();
|
||||||
OP_REQUIRES(
|
OP_REQUIRES(
|
||||||
c, use_count == 1,
|
c, use_count == 1,
|
||||||
errors::InvalidArgument("Expected use count of lock to be 1, but saw: ",
|
errors::InvalidArgument("Expected use count of lock to be 1, but saw: ",
|
||||||
|
Loading…
x
Reference in New Issue
Block a user