Factor out Variant wrapper for shared ptrs.
PiperOrigin-RevId: 269446584
This commit is contained in:
parent
5059b5968b
commit
8f9e0e304d
@ -869,6 +869,7 @@ tf_cuda_library(
|
||||
"framework/variant_encode_decode.h",
|
||||
"framework/variant_op_registry.h",
|
||||
"framework/variant_tensor_data.h",
|
||||
"framework/shared_ptr_variant.h",
|
||||
"framework/allocator_registry.h",
|
||||
"framework/attr_value_util.h",
|
||||
"framework/bfloat16.h",
|
||||
@ -2790,6 +2791,7 @@ FRAMEWORK_INTERNAL_PUBLIC_HEADERS = [
|
||||
"framework/resource_var.h",
|
||||
"framework/run_handler.h",
|
||||
"framework/run_handler_util.h",
|
||||
"framework/shared_ptr_variant.h",
|
||||
"framework/tensor_reference.h",
|
||||
"framework/tracking_allocator.h", # only needed for tests
|
||||
"framework/unique_tensor_references.h",
|
||||
|
74
tensorflow/core/framework/shared_ptr_variant.h
Normal file
74
tensorflow/core/framework/shared_ptr_variant.h
Normal file
@ -0,0 +1,74 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_FRAMEWORK_SHARED_PTR_VARIANT_H_
|
||||
#define TENSORFLOW_CORE_FRAMEWORK_SHARED_PTR_VARIANT_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
template <typename T>
|
||||
struct SharedPtrVariant {
|
||||
std::shared_ptr<T> shared_ptr;
|
||||
|
||||
SharedPtrVariant() : shared_ptr() {}
|
||||
|
||||
explicit SharedPtrVariant(std::shared_ptr<T>&& ptr)
|
||||
: shared_ptr(std::forward<decltype(ptr)>(ptr)) {
|
||||
VLOG(3) << "Creating shared_ptr of " << shared_ptr.get()
|
||||
<< " count is: " << shared_ptr.use_count();
|
||||
}
|
||||
|
||||
SharedPtrVariant(SharedPtrVariant&& rhs)
|
||||
: shared_ptr(std::move(rhs.shared_ptr)) {
|
||||
VLOG(3) << "Moving SharedPtrVariant of " << shared_ptr.get()
|
||||
<< " count is: " << shared_ptr.use_count();
|
||||
}
|
||||
|
||||
SharedPtrVariant& operator=(const SharedPtrVariant& rhs) = delete;
|
||||
|
||||
SharedPtrVariant& operator=(SharedPtrVariant&& rhs) {
|
||||
if (&rhs == this) return *this;
|
||||
std::swap(shared_ptr, rhs.shared_ptr);
|
||||
VLOG(3) << "Move-assign of SharedPtrVariant of " << shared_ptr.get()
|
||||
<< " count is: " << shared_ptr.use_count();
|
||||
return *this;
|
||||
}
|
||||
|
||||
SharedPtrVariant(const SharedPtrVariant& rhs) : shared_ptr(rhs.shared_ptr) {
|
||||
VLOG(3) << "Copying SharedPtrVariant of " << shared_ptr.get()
|
||||
<< " count is: " << shared_ptr.use_count();
|
||||
}
|
||||
|
||||
~SharedPtrVariant() {
|
||||
VLOG(3) << "Destroying SharedPtrVariant of " << shared_ptr.get()
|
||||
<< " count is: " << shared_ptr.use_count();
|
||||
}
|
||||
|
||||
void Encode(VariantTensorData*) const {
|
||||
// Not supported.
|
||||
}
|
||||
|
||||
bool Decode(const VariantTensorData&) {
|
||||
return false; // Not supported.
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_FRAMEWORK_SHARED_PTR_VARIANT_H_
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
#include "tensorflow/core/framework/shared_ptr_variant.h"
|
||||
#include "tensorflow/core/framework/variant.h"
|
||||
#include "tensorflow/core/framework/variant_encode_decode.h"
|
||||
#include "tensorflow/core/kernels/ops_util.h"
|
||||
@ -71,52 +72,7 @@ class Mutex : public ResourceBase {
|
||||
Mutex* mutex_;
|
||||
};
|
||||
|
||||
struct SharedLockReleaser {
|
||||
std::shared_ptr<LockReleaser> shared_lock;
|
||||
|
||||
SharedLockReleaser() : shared_lock() {}
|
||||
|
||||
explicit SharedLockReleaser(std::shared_ptr<LockReleaser>&& lock)
|
||||
: shared_lock(std::forward<decltype(lock)>(lock)) {
|
||||
VLOG(3) << "Creating shared_ptr of " << shared_lock.get()
|
||||
<< " count is: " << shared_lock.use_count();
|
||||
}
|
||||
|
||||
SharedLockReleaser(SharedLockReleaser&& rhs)
|
||||
: shared_lock(std::move(rhs.shared_lock)) {
|
||||
VLOG(3) << "Moving SharedLockReleaser of " << shared_lock.get()
|
||||
<< " count is: " << shared_lock.use_count();
|
||||
}
|
||||
|
||||
SharedLockReleaser& operator=(const SharedLockReleaser& rhs) = delete;
|
||||
|
||||
SharedLockReleaser& operator=(SharedLockReleaser&& rhs) {
|
||||
if (&rhs == this) return *this;
|
||||
std::swap(shared_lock, rhs.shared_lock);
|
||||
VLOG(3) << "Move-assign of SharedLockReleaser of " << shared_lock.get()
|
||||
<< " count is: " << shared_lock.use_count();
|
||||
return *this;
|
||||
}
|
||||
|
||||
SharedLockReleaser(const SharedLockReleaser& rhs)
|
||||
: shared_lock(rhs.shared_lock) {
|
||||
VLOG(3) << "Copying SharedLockReleaser of " << shared_lock.get()
|
||||
<< " count is: " << shared_lock.use_count();
|
||||
}
|
||||
|
||||
~SharedLockReleaser() {
|
||||
VLOG(3) << "Destroying SharedLockReleaser of " << shared_lock.get()
|
||||
<< " count is: " << shared_lock.use_count();
|
||||
}
|
||||
|
||||
void Encode(VariantTensorData*) const {
|
||||
// Not supported.
|
||||
}
|
||||
|
||||
bool Decode(const VariantTensorData&) {
|
||||
return false; // Not supported.
|
||||
}
|
||||
};
|
||||
typedef SharedPtrVariant<LockReleaser> SharedLockReleaser;
|
||||
|
||||
void AcquireAsync(
|
||||
OpKernelContext* c,
|
||||
@ -205,7 +161,7 @@ class MutexLockOp : public AsyncOpKernel {
|
||||
const Status& s,
|
||||
Mutex::SharedLockReleaser&& lock) {
|
||||
VLOG(2) << "Finished locking mutex " << mutex
|
||||
<< " with lock: " << lock.shared_lock.get()
|
||||
<< " with lock: " << lock.shared_ptr.get()
|
||||
<< " status: " << s.ToString();
|
||||
if (s.ok()) {
|
||||
variant->scalar<Variant>()() = std::move(lock);
|
||||
@ -242,7 +198,7 @@ class ConsumeMutexLockOp : public OpKernel {
|
||||
"Expected input to contain a SharedLockReleaser "
|
||||
"object, but saw variant: '",
|
||||
lock_t.scalar<Variant>()().DebugString(), "'"));
|
||||
const int use_count = lock->shared_lock.use_count();
|
||||
const int use_count = lock->shared_ptr.use_count();
|
||||
OP_REQUIRES(
|
||||
c, use_count == 1,
|
||||
errors::InvalidArgument("Expected use count of lock to be 1, but saw: ",
|
||||
|
Loading…
Reference in New Issue
Block a user