From 8f706abe75391bc5bd877761e244b2cf57816326 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" <gardener@tensorflow.org> Date: Wed, 7 Dec 2016 10:13:37 -0800 Subject: [PATCH] Refactor resource ops. There are many ops producing different kinds of resources, and the same code pattern is repeated in several places. Change: 141323281 --- tensorflow/core/BUILD | 2 + tensorflow/core/framework/reader_op_kernel.cc | 55 ----- tensorflow/core/framework/reader_op_kernel.h | 35 +-- .../core/framework/resource_op_kernel.h | 119 +++++++++++ .../core/framework/resource_op_kernel_test.cc | 202 ++++++++++++++++++ tensorflow/core/kernels/barrier_ops.cc | 57 ++--- tensorflow/core/kernels/fifo_queue_op.cc | 24 +-- .../core/kernels/padding_fifo_queue_op.cc | 24 +-- tensorflow/core/kernels/priority_queue_op.cc | 24 +-- tensorflow/core/kernels/queue_op.h | 70 ++---- .../core/kernels/random_shuffle_queue_op.cc | 26 +-- 11 files changed, 413 insertions(+), 225 deletions(-) delete mode 100644 tensorflow/core/framework/reader_op_kernel.cc create mode 100644 tensorflow/core/framework/resource_op_kernel.h create mode 100644 tensorflow/core/framework/resource_op_kernel_test.cc diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 81a517c5068..92519e269e8 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -323,6 +323,7 @@ tf_cuda_library( "framework/reader_op_kernel.h", "framework/register_types.h", "framework/resource_mgr.h", + "framework/resource_op_kernel.h", "framework/selective_registration.h", "framework/session_state.h", "framework/shape_inference.h", @@ -1705,6 +1706,7 @@ tf_cc_tests( "framework/partial_tensor_shape_test.cc", # "framework/rendezvous_test.cc", # flaky b/30476344 "framework/resource_mgr_test.cc", + "framework/resource_op_kernel_test.cc", "framework/shape_inference_test.cc", "framework/shape_inference_testutil_test.cc", "framework/tensor_shape_test.cc", diff --git a/tensorflow/core/framework/reader_op_kernel.cc b/tensorflow/core/framework/reader_op_kernel.cc deleted file mode 100644 index 44df86b479c..00000000000 --- a/tensorflow/core/framework/reader_op_kernel.cc +++ /dev/null @@ -1,55 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/framework/reader_op_kernel.h" - -namespace tensorflow { - -ReaderOpKernel::ReaderOpKernel(OpKernelConstruction* context) - : OpKernel(context), have_handle_(false) { - OP_REQUIRES_OK(context, context->allocate_persistent( - tensorflow::DT_STRING, - tensorflow::TensorShape({2}), &handle_, nullptr)); -} - -ReaderOpKernel::~ReaderOpKernel() { - if (have_handle_ && cinfo_.resource_is_private_to_kernel()) { - TF_CHECK_OK(cinfo_.resource_manager()->Delete<ReaderInterface>( - cinfo_.container(), cinfo_.name())); - } -} - -void ReaderOpKernel::Compute(OpKernelContext* ctx) { - mutex_lock l(mu_); - if (!have_handle_) { - OP_REQUIRES_OK(ctx, cinfo_.Init(ctx->resource_manager(), def(), false)); - ReaderInterface* reader; - OP_REQUIRES_OK(ctx, - cinfo_.resource_manager()->LookupOrCreate<ReaderInterface>( - cinfo_.container(), cinfo_.name(), &reader, - [this](ReaderInterface** ret) { - *ret = factory_(); - return Status::OK(); - })); - reader->Unref(); - auto h = handle_.AccessTensor(ctx)->flat<string>(); - h(0) = cinfo_.container(); - h(1) = cinfo_.name(); - have_handle_ = true; - } - ctx->set_output_ref(0, &mu_, handle_.AccessTensor(ctx)); -} - -} // namespace tensorflow diff --git a/tensorflow/core/framework/reader_op_kernel.h b/tensorflow/core/framework/reader_op_kernel.h index 755af123ffd..502b98f13d9 100644 --- a/tensorflow/core/framework/reader_op_kernel.h +++ b/tensorflow/core/framework/reader_op_kernel.h @@ -22,35 +22,44 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/reader_interface.h" #include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/resource_op_kernel.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { -// Implementation for ops providing a Reader. -class ReaderOpKernel : public OpKernel { - public: - explicit ReaderOpKernel(OpKernelConstruction* context); - ~ReaderOpKernel() override; +// NOTE: This is now a very thin layer over ResourceOpKernel. +// TODO(sjhwang): Remove dependencies to this class, then delete this. - void Compute(OpKernelContext* context) override; +// Implementation for ops providing a Reader. +class ReaderOpKernel : public ResourceOpKernel<ReaderInterface> { + public: + using ResourceOpKernel::ResourceOpKernel; // Must be called by descendants before the first call to Compute() // (typically called during construction). factory must return a // ReaderInterface descendant allocated with new that ReaderOpKernel // will take ownership of. 
- void SetReaderFactory(std::function<ReaderInterface*()> factory) { + void SetReaderFactory(std::function<ReaderInterface*()> factory) + LOCKS_EXCLUDED(mu_) { mutex_lock l(mu_); - DCHECK(!have_handle_); + DCHECK(resource_ == nullptr); factory_ = factory; } private: - mutex mu_; - bool have_handle_ GUARDED_BY(mu_); - PersistentTensor handle_ GUARDED_BY(mu_); - ContainerInfo cinfo_; - std::function<ReaderInterface*()> factory_; + Status CreateResource(ReaderInterface** reader) + EXCLUSIVE_LOCKS_REQUIRED(mu_) override { + *reader = factory_(); + if (*reader == nullptr) { + return errors::ResourceExhausted("Failed to allocate reader"); + } + std::function<ReaderInterface*()> temp = nullptr; + factory_.swap(temp); + return Status::OK(); + } + + std::function<ReaderInterface*()> factory_ GUARDED_BY(mu_); }; } // namespace tensorflow diff --git a/tensorflow/core/framework/resource_op_kernel.h b/tensorflow/core/framework/resource_op_kernel.h new file mode 100644 index 00000000000..09f81ed0583 --- /dev/null +++ b/tensorflow/core/framework/resource_op_kernel.h @@ -0,0 +1,119 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_FRAMEWORK_RESOURCE_OP_KERNEL_H_ +#define TENSORFLOW_FRAMEWORK_RESOURCE_OP_KERNEL_H_ + +#include <string> + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// ResourceOpKernel<T> is a virtual base class for resource op implementing +// interface type T. The inherited op looks up the resource name (determined by +// ContainerInfo), and creates a new resource if necessary. +// +// Requirements: +// - Op must be marked as stateful. +// - Op must have `container` and `shared_name` attributes. Empty `container` +// means using the default container. Empty `shared_name` means private +// resource. +// - Subclass must override CreateResource(). +// - Subclass is encouraged to override VerifyResource(). +template <typename T> +class ResourceOpKernel : public OpKernel { + public: + explicit ResourceOpKernel(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, + context->allocate_persistent(DT_STRING, TensorShape({2}), + &handle_, nullptr)); + } + + // The resource is deleted from the resource manager only when it is private + // to kernel. Ideally the resource should be deleted when it is no longer held + // by anyone, but it would break backward compatibility. 
+ ~ResourceOpKernel() override { + if (resource_ != nullptr) { + resource_->Unref(); + if (cinfo_.resource_is_private_to_kernel()) { + TF_CHECK_OK(cinfo_.resource_manager()->template Delete<T>( + cinfo_.container(), cinfo_.name())); + } + } + } + + void Compute(OpKernelContext* context) override LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + if (resource_ == nullptr) { + ResourceMgr* mgr = context->resource_manager(); + OP_REQUIRES_OK(context, cinfo_.Init(mgr, def())); + + T* resource; + OP_REQUIRES_OK( + context, + mgr->LookupOrCreate<T>(cinfo_.container(), cinfo_.name(), &resource, + [this](T** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + Status s = CreateResource(ret); + if (!s.ok() && *ret != nullptr) { + CHECK((*ret)->Unref()); + } + return s; + })); + + Status s = VerifyResource(resource); + if (TF_PREDICT_FALSE(!s.ok())) { + resource->Unref(); + context->SetStatus(s); + return; + } + + auto h = handle_.AccessTensor(context)->template flat<string>(); + h(0) = cinfo_.container(); + h(1) = cinfo_.name(); + resource_ = resource; + } + context->set_output_ref(0, &mu_, handle_.AccessTensor(context)); + } + + protected: + // Variables accessible from subclasses. + mutex mu_; + ContainerInfo cinfo_ GUARDED_BY(mu_); + T* resource_ GUARDED_BY(mu_) = nullptr; + + private: + // Must return a T descendant allocated with new that ResourceOpKernel will + // take ownership of. + virtual Status CreateResource(T** resource) EXCLUSIVE_LOCKS_REQUIRED(mu_) = 0; + + // During the first Compute(), resource is either created or looked up using + // shared_name. In the latter case, the resource found should be verified if + // it is compatible with this op's configuration. The verification may fail in + // cases such as two graphs asking queues of the same shared name to have + // inconsistent capacities. + virtual Status VerifyResource(T* resource) { return Status::OK(); } + + PersistentTensor handle_ GUARDED_BY(mu_); +}; +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_RESOURCE_OP_KERNEL_H_ diff --git a/tensorflow/core/framework/resource_op_kernel_test.cc b/tensorflow/core/framework/resource_op_kernel_test.cc new file mode 100644 index 00000000000..c1e503dc576 --- /dev/null +++ b/tensorflow/core/framework/resource_op_kernel_test.cc @@ -0,0 +1,202 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/resource_op_kernel.h" + +#include <memory> + +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/public/version.h" + +namespace tensorflow { +namespace { + +// Stub DeviceBase subclass which only returns allocators. +class StubDevice : public DeviceBase { + public: + StubDevice() : DeviceBase(nullptr) {} + + Allocator* GetAllocator(AllocatorAttributes) override { + return cpu_allocator(); + } +}; + +// Stub resource for testing resource op kernel. +class StubResource : public ResourceBase { + public: + string DebugString() override { return ""; } + int code; +}; + +class StubResourceOpKernel : public ResourceOpKernel<StubResource> { + public: + using ResourceOpKernel::ResourceOpKernel; + + StubResource* resource() LOCKS_EXCLUDED(mu_) { + mutex_lock lock(mu_); + return resource_; + } + + private: + Status CreateResource(StubResource** resource) override { + *resource = CHECK_NOTNULL(new StubResource); + return GetNodeAttr(def(), "code", &(*resource)->code); + } + + Status VerifyResource(StubResource* resource) override { + int code; + TF_RETURN_IF_ERROR(GetNodeAttr(def(), "code", &code)); + if (code != resource->code) { + return errors::InvalidArgument("stub has code ", resource->code, + " but requested code ", code); + } + return Status::OK(); + } +}; + +REGISTER_OP("StubResourceOp") + .Attr("code: int") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .Output("output: Ref(string)"); + +REGISTER_KERNEL_BUILDER(Name("StubResourceOp").Device(DEVICE_CPU), + StubResourceOpKernel); + +class ResourceOpKernelTest : public ::testing::Test { + protected: + std::unique_ptr<StubResourceOpKernel> CreateOp(int code, + const string& shared_name) { + NodeDef node_def; + TF_CHECK_OK( + NodeDefBuilder(strings::StrCat("test-node", count_++), "StubResourceOp") + .Attr("code", code) + .Attr("shared_name", shared_name) + .Finalize(&node_def)); + Status status; + std::unique_ptr<OpKernel> op(CreateOpKernel( + DEVICE_CPU, &device_, device_.GetAllocator(AllocatorAttributes()), + node_def, TF_GRAPH_DEF_VERSION, &status)); + TF_EXPECT_OK(status) << status; + EXPECT_TRUE(op != nullptr); + + // Downcast to StubResourceOpKernel to call resource() later. + std::unique_ptr<StubResourceOpKernel> resource_op( + dynamic_cast<StubResourceOpKernel*>(op.get())); + EXPECT_TRUE(resource_op != nullptr); + if (resource_op != nullptr) { + op.release(); + } + return resource_op; + } + + Status RunOpKernel(OpKernel* op) { + OpKernelContext::Params params; + + params.device = &device_; + params.resource_manager = &mgr_; + params.op_kernel = op; + + OpKernelContext context(¶ms); + op->Compute(&context); + return context.status(); + } + + StubDevice device_; + ResourceMgr mgr_; + int count_ = 0; +}; + +TEST_F(ResourceOpKernelTest, PrivateResource) { + // Empty shared_name means private resource. 
+ const int code = -100; + auto op = CreateOp(code, ""); + ASSERT_TRUE(op != nullptr); + TF_EXPECT_OK(RunOpKernel(op.get())); + + // Default non-shared name provided from ContainerInfo. + const string key = "_0_" + op->name(); + + StubResource* resource; + TF_ASSERT_OK( + mgr_.Lookup<StubResource>(mgr_.default_container(), key, &resource)); + EXPECT_EQ(op->resource(), resource); // Check resource identity. + EXPECT_EQ(code, resource->code); // Check resource stored information. + resource->Unref(); + + // Destroy the op kernel. Expect the resource to be released. + op = nullptr; + Status s = + mgr_.Lookup<StubResource>(mgr_.default_container(), key, &resource); + + EXPECT_FALSE(s.ok()); +} + +TEST_F(ResourceOpKernelTest, SharedResource) { + const string shared_name = "shared_stub"; + const int code = -201; + auto op = CreateOp(code, shared_name); + ASSERT_TRUE(op != nullptr); + TF_EXPECT_OK(RunOpKernel(op.get())); + + StubResource* resource; + TF_ASSERT_OK(mgr_.Lookup<StubResource>(mgr_.default_container(), shared_name, + &resource)); + EXPECT_EQ(op->resource(), resource); // Check resource identity. + EXPECT_EQ(code, resource->code); // Check resource stored information. + resource->Unref(); + + // Destroy the op kernel. Expect the resource not to be released. + op = nullptr; + TF_ASSERT_OK(mgr_.Lookup<StubResource>(mgr_.default_container(), shared_name, + &resource)); + resource->Unref(); +} + +TEST_F(ResourceOpKernelTest, LookupShared) { + auto op1 = CreateOp(-333, "shared_stub"); + auto op2 = CreateOp(-333, "shared_stub"); + ASSERT_TRUE(op1 != nullptr); + ASSERT_TRUE(op2 != nullptr); + + TF_EXPECT_OK(RunOpKernel(op1.get())); + TF_EXPECT_OK(RunOpKernel(op2.get())); + EXPECT_EQ(op1->resource(), op2->resource()); +} + +TEST_F(ResourceOpKernelTest, VerifyResource) { + auto op1 = CreateOp(-444, "shared_stub"); + auto op2 = CreateOp(0, "shared_stub"); // Different resource code. + ASSERT_TRUE(op1 != nullptr); + ASSERT_TRUE(op2 != nullptr); + + TF_EXPECT_OK(RunOpKernel(op1.get())); + EXPECT_FALSE(RunOpKernel(op2.get()).ok()); + EXPECT_TRUE(op1->resource() != nullptr); + EXPECT_TRUE(op2->resource() == nullptr); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/kernels/barrier_ops.cc b/tensorflow/core/kernels/barrier_ops.cc index 84f57517605..03880b98273 100644 --- a/tensorflow/core/kernels/barrier_ops.cc +++ b/tensorflow/core/kernels/barrier_ops.cc @@ -21,6 +21,7 @@ limitations under the License. 
#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/resource_op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.h" @@ -432,14 +433,10 @@ class Barrier : public ResourceBase { TF_DISALLOW_COPY_AND_ASSIGN(Barrier); }; -class BarrierOp : public OpKernel { +class BarrierOp : public ResourceOpKernel<Barrier> { public: explicit BarrierOp(OpKernelConstruction* context) - : OpKernel(context), barrier_handle_set_(false) { - OP_REQUIRES_OK(context, - context->allocate_persistent(tensorflow::DT_STRING, - tensorflow::TensorShape({2}), - &barrier_handle_, nullptr)); + : ResourceOpKernel(context) { OP_REQUIRES_OK( context, context->GetAttr("component_types", &value_component_types_)); OP_REQUIRES_OK(context, @@ -458,34 +455,19 @@ class BarrierOp : public OpKernel { "limited capacity.")); } - ~BarrierOp() override { - // If the barrier object was not shared, delete it. - if (barrier_handle_set_ && cinfo_.resource_is_private_to_kernel()) { - TF_CHECK_OK(cinfo_.resource_manager()->Delete<Barrier>(cinfo_.container(), - cinfo_.name())); - } - } - - void Compute(OpKernelContext* ctx) override { - mutex_lock l(mu_); - if (!barrier_handle_set_) { - OP_REQUIRES_OK(ctx, SetBarrierHandle(ctx)); - } - ctx->set_output_ref(0, &mu_, barrier_handle_.AccessTensor(ctx)); - } - private: - Status SetBarrierHandle(OpKernelContext* ctx) EXCLUSIVE_LOCKS_REQUIRED(mu_) { - TF_RETURN_IF_ERROR(cinfo_.Init(ctx->resource_manager(), def())); - Barrier* barrier = nullptr; - auto creator = [this](Barrier** ret) { - *ret = new Barrier(value_component_types_, value_component_shapes_, - cinfo_.name()); - return (*ret)->Initialize(); - }; - TF_RETURN_IF_ERROR(cinfo_.resource_manager()->LookupOrCreate<Barrier>( - cinfo_.container(), cinfo_.name(), &barrier, creator)); - core::ScopedUnref unref_me(barrier); + Status CreateResource(Barrier** barrier) override + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + *barrier = new Barrier(value_component_types_, value_component_shapes_, + cinfo_.name()); + if (*barrier == nullptr) { + return errors::ResourceExhausted("Failed to allocate barrier"); + } + return (*barrier)->Initialize(); + } + + Status VerifyResource(Barrier* barrier) override + EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (barrier->component_types() != value_component_types_) { return errors::InvalidArgument( "Shared barrier '", cinfo_.name(), "' has component types ", @@ -500,20 +482,11 @@ class BarrierOp : public OpKernel { " but requested component shapes were ", TensorShapeUtils::ShapeListString(value_component_shapes_)); } - auto h = barrier_handle_.AccessTensor(ctx)->flat<string>(); - h(0) = cinfo_.container(); - h(1) = cinfo_.name(); - barrier_handle_set_ = true; return Status::OK(); } DataTypeVector value_component_types_; std::vector<TensorShape> value_component_shapes_; - ContainerInfo cinfo_; - - mutex mu_; - PersistentTensor barrier_handle_ GUARDED_BY(mu_); - bool barrier_handle_set_ GUARDED_BY(mu_); TF_DISALLOW_COPY_AND_ASSIGN(BarrierOp); }; diff --git a/tensorflow/core/kernels/fifo_queue_op.cc b/tensorflow/core/kernels/fifo_queue_op.cc index c710eff704c..31df6642540 100644 --- a/tensorflow/core/kernels/fifo_queue_op.cc +++ b/tensorflow/core/kernels/fifo_queue_op.cc @@ -39,28 +39,20 @@ namespace tensorflow { // backed by FIFOQueue) that persists across different graph // executions, and sessions. 
Running this op produces a single-element // tensor of handles to Queues in the corresponding device. -class FIFOQueueOp : public QueueOp { +class FIFOQueueOp : public TypedQueueOp { public: - explicit FIFOQueueOp(OpKernelConstruction* context) : QueueOp(context) { + explicit FIFOQueueOp(OpKernelConstruction* context) : TypedQueueOp(context) { OP_REQUIRES_OK(context, context->GetAttr("shapes", &component_shapes_)); } - protected: - CreatorCallback GetCreator() const override { - return [this](QueueInterface** ret) { - FIFOQueue* queue = new FIFOQueue(capacity_, component_types_, - component_shapes_, cinfo_.name()); - Status s = queue->Initialize(); - if (s.ok()) { - *ret = queue; - } else { - queue->Unref(); - } - return s; - }; + private: + Status CreateResource(QueueInterface** ret) override + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + FIFOQueue* queue = new FIFOQueue(capacity_, component_types_, + component_shapes_, cinfo_.name()); + return CreateTypedQueue(queue, ret); } - private: std::vector<TensorShape> component_shapes_; TF_DISALLOW_COPY_AND_ASSIGN(FIFOQueueOp); }; diff --git a/tensorflow/core/kernels/padding_fifo_queue_op.cc b/tensorflow/core/kernels/padding_fifo_queue_op.cc index c3c4a3e10c3..b87b2b90b8e 100644 --- a/tensorflow/core/kernels/padding_fifo_queue_op.cc +++ b/tensorflow/core/kernels/padding_fifo_queue_op.cc @@ -40,10 +40,10 @@ namespace tensorflow { // backed by PaddingFIFOQueue) that persists across different graph // executions, and sessions. Running this op produces a single-element // tensor of handles to Queues in the corresponding device. -class PaddingFIFOQueueOp : public QueueOp { +class PaddingFIFOQueueOp : public TypedQueueOp { public: explicit PaddingFIFOQueueOp(OpKernelConstruction* context) - : QueueOp(context) { + : TypedQueueOp(context) { OP_REQUIRES_OK(context, context->GetAttr("shapes", &component_shapes_)); for (const auto& shape : component_shapes_) { OP_REQUIRES(context, shape.dims() >= 0, @@ -52,22 +52,14 @@ class PaddingFIFOQueueOp : public QueueOp { } } - protected: - CreatorCallback GetCreator() const override { - return [this](QueueInterface** ret) { - PaddingFIFOQueue* queue = new PaddingFIFOQueue( - capacity_, component_types_, component_shapes_, cinfo_.name()); - Status s = queue->Initialize(); - if (s.ok()) { - *ret = queue; - } else { - queue->Unref(); - } - return s; - }; + private: + Status CreateResource(QueueInterface** ret) override + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + PaddingFIFOQueue* queue = new PaddingFIFOQueue( + capacity_, component_types_, component_shapes_, cinfo_.name()); + return CreateTypedQueue(queue, ret); } - private: std::vector<PartialTensorShape> component_shapes_; TF_DISALLOW_COPY_AND_ASSIGN(PaddingFIFOQueueOp); diff --git a/tensorflow/core/kernels/priority_queue_op.cc b/tensorflow/core/kernels/priority_queue_op.cc index d928145cda7..eb4b4f289f0 100644 --- a/tensorflow/core/kernels/priority_queue_op.cc +++ b/tensorflow/core/kernels/priority_queue_op.cc @@ -38,9 +38,10 @@ namespace tensorflow { // backed by PriorityQueue) that persists across different graph // executions, and sessions. Running this op produces a single-element // tensor of handles to Queues in the corresponding device. 
-class PriorityQueueOp : public QueueOp { +class PriorityQueueOp : public TypedQueueOp { public: - explicit PriorityQueueOp(OpKernelConstruction* context) : QueueOp(context) { + explicit PriorityQueueOp(OpKernelConstruction* context) + : TypedQueueOp(context) { OP_REQUIRES_OK(context, context->GetAttr("shapes", &component_shapes_)); component_types_.insert(component_types_.begin(), DT_INT64); if (!component_shapes_.empty()) { @@ -48,19 +49,12 @@ class PriorityQueueOp : public QueueOp { } } - protected: - CreatorCallback GetCreator() const override { - return [this](QueueInterface** ret) { - PriorityQueue* queue = new PriorityQueue( - capacity_, component_types_, component_shapes_, cinfo_.name()); - Status s = queue->Initialize(); - if (s.ok()) { - *ret = queue; - } else { - queue->Unref(); - } - return s; - }; + private: + Status CreateResource(QueueInterface** ret) override + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + PriorityQueue* queue = new PriorityQueue(capacity_, component_types_, + component_shapes_, cinfo_.name()); + return CreateTypedQueue(queue, ret); } std::vector<TensorShape> component_shapes_; diff --git a/tensorflow/core/kernels/queue_op.h b/tensorflow/core/kernels/queue_op.h index 7694827854c..e13ea46e56f 100644 --- a/tensorflow/core/kernels/queue_op.h +++ b/tensorflow/core/kernels/queue_op.h @@ -19,28 +19,21 @@ limitations under the License. #include <deque> #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/resource_op_kernel.h" #include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/queue_base.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { // Defines a QueueOp, an abstract class for Queue construction ops. -class QueueOp : public OpKernel { +class QueueOp : public ResourceOpKernel<QueueInterface> { public: - QueueOp(OpKernelConstruction* context) - : OpKernel(context), queue_handle_set_(false) { + QueueOp(OpKernelConstruction* context) : ResourceOpKernel(context) { OP_REQUIRES_OK(context, context->GetAttr("capacity", &capacity_)); - OP_REQUIRES_OK(context, - context->allocate_persistent(DT_STRING, TensorShape({2}), - &queue_handle_, nullptr)); if (capacity_ < 0) { capacity_ = QueueBase::kUnbounded; } @@ -48,55 +41,30 @@ class QueueOp : public OpKernel { context->GetAttr("component_types", &component_types_)); } - void Compute(OpKernelContext* ctx) override { - mutex_lock l(mu_); - if (!queue_handle_set_) { - OP_REQUIRES_OK(ctx, SetQueueHandle(ctx)); - } - ctx->set_output_ref(0, &mu_, queue_handle_.AccessTensor(ctx)); - } - protected: - ~QueueOp() override { - // If the queue object was not shared, delete it. 
- if (queue_handle_set_ && cinfo_.resource_is_private_to_kernel()) { - TF_CHECK_OK(cinfo_.resource_manager()->Delete<QueueInterface>( - cinfo_.container(), cinfo_.name())); - } - } - - protected: - typedef std::function<Status(QueueInterface**)> CreatorCallback; - - // Subclasses must override this - virtual CreatorCallback GetCreator() const = 0; - // Variables accessible by subclasses int32 capacity_; DataTypeVector component_types_; - ContainerInfo cinfo_; private: - Status SetQueueHandle(OpKernelContext* ctx) EXCLUSIVE_LOCKS_REQUIRED(mu_) { - TF_RETURN_IF_ERROR(cinfo_.Init(ctx->resource_manager(), def())); - CreatorCallback creator = GetCreator(); - QueueInterface* queue; - TF_RETURN_IF_ERROR( - cinfo_.resource_manager()->LookupOrCreate<QueueInterface>( - cinfo_.container(), cinfo_.name(), &queue, creator)); - core::ScopedUnref unref_me(queue); - // Verify that the shared queue is compatible with the requested arguments. - TF_RETURN_IF_ERROR(queue->MatchesNodeDef(def())); - auto h = queue_handle_.AccessTensor(ctx)->flat<string>(); - h(0) = cinfo_.container(); - h(1) = cinfo_.name(); - queue_handle_set_ = true; - return Status::OK(); + Status VerifyResource(QueueInterface* queue) override { + return queue->MatchesNodeDef(def()); } +}; - mutex mu_; - PersistentTensor queue_handle_ GUARDED_BY(mu_); - bool queue_handle_set_ GUARDED_BY(mu_); +class TypedQueueOp : public QueueOp { + public: + using QueueOp::QueueOp; + + protected: + template <typename TypedQueue> + Status CreateTypedQueue(TypedQueue* queue, QueueInterface** ret) { + if (queue == nullptr) { + return errors::ResourceExhausted("Failed to allocate queue."); + } + *ret = queue; + return queue->Initialize(); + } }; } // namespace tensorflow diff --git a/tensorflow/core/kernels/random_shuffle_queue_op.cc b/tensorflow/core/kernels/random_shuffle_queue_op.cc index 6b3b792814f..064d8b9c748 100644 --- a/tensorflow/core/kernels/random_shuffle_queue_op.cc +++ b/tensorflow/core/kernels/random_shuffle_queue_op.cc @@ -459,10 +459,10 @@ Status RandomShuffleQueue::MatchesNodeDef(const NodeDef& node_def) { // backed by RandomShuffleQueue) that persists across different graph // executions, and sessions. Running this op produces a single-element // tensor of handles to Queues in the corresponding device. -class RandomShuffleQueueOp : public QueueOp { +class RandomShuffleQueueOp : public TypedQueueOp { public: explicit RandomShuffleQueueOp(OpKernelConstruction* context) - : QueueOp(context) { + : TypedQueueOp(context) { OP_REQUIRES_OK(context, context->GetAttr("min_after_dequeue", &min_after_dequeue_)); OP_REQUIRES(context, min_after_dequeue_ >= 0, @@ -478,23 +478,15 @@ class RandomShuffleQueueOp : public QueueOp { OP_REQUIRES_OK(context, context->GetAttr("shapes", &component_shapes_)); } - protected: - CreatorCallback GetCreator() const override { - return [this](QueueInterface** ret) { - auto* q = new RandomShuffleQueue(capacity_, min_after_dequeue_, seed_, - seed2_, component_types_, - component_shapes_, cinfo_.name()); - Status s = q->Initialize(); - if (s.ok()) { - *ret = q; - } else { - q->Unref(); - } - return s; - }; + private: + Status CreateResource(QueueInterface** ret) override + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + RandomShuffleQueue* queue = new RandomShuffleQueue( + capacity_, min_after_dequeue_, seed_, seed2_, component_types_, + component_shapes_, cinfo_.name()); + return CreateTypedQueue(queue, ret); } - private: int32 min_after_dequeue_; int64 seed_; int64 seed2_;
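A minimal sketch (illustrative only, not part of the diff above) of the pattern this refactor establishes: a resource-producing op now subclasses ResourceOpKernel<T>, overrides CreateResource(), and optionally overrides VerifyResource() to check that a looked-up shared resource is compatible with this op's attributes. The CounterResource/CounterOp names and the "start" attribute below are hypothetical; the calls mirror the new resource_op_kernel.h, its test, and the BarrierOp/queue changes.

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/resource_op_kernel.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

// Hypothetical resource type; any ResourceBase subclass works.
class CounterResource : public ResourceBase {
 public:
  explicit CounterResource(int64 start) : start(start), value(start) {}
  string DebugString() override { return "CounterResource"; }
  const int64 start;
  int64 value;
};

class CounterOp : public ResourceOpKernel<CounterResource> {
 public:
  explicit CounterOp(OpKernelConstruction* context) : ResourceOpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("start", &start_));
  }

 private:
  // Runs at most once, under mu_, when no resource exists yet for this op's
  // (container, shared_name); the returned object is owned by the resource
  // manager.
  Status CreateResource(CounterResource** resource)
      EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
    *resource = new CounterResource(start_);
    return Status::OK();
  }

  // Runs when Compute() finds a resource another kernel already created under
  // the same shared_name; reject it if its configuration does not match this
  // op's attributes.
  Status VerifyResource(CounterResource* resource) override {
    if (resource->start != start_) {
      return errors::InvalidArgument("Shared counter has start ",
                                     resource->start,
                                     " but requested start ", start_);
    }
    return Status::OK();
  }

  int64 start_;
};

// As resource_op_kernel.h requires, the op is stateful and exposes `container`
// and `shared_name` attributes; the single output is the string handle written
// by ResourceOpKernel::Compute().
REGISTER_OP("CounterOp")
    .Attr("start: int")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .Output("handle: Ref(string)")
    .SetIsStateful();

REGISTER_KERNEL_BUILDER(Name("CounterOp").Device(DEVICE_CPU), CounterOp);

}  // namespace tensorflow

Deletion semantics are unchanged by the refactor: the base destructor unrefs the resource and removes it from the ResourceMgr only when it is private to the kernel, exactly as the deleted per-op destructors did.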