Refactor resource ops.

There are many ops producing different kinds of resources, and the same code pattern is repeated in several places. This change factors the shared handle-management boilerplate into a common ResourceOpKernel base class.

Change: 141323281
Authored by A. Unique TensorFlower on 2016-12-07 10:13:37 -08:00; committed by TensorFlower Gardener.
parent db2a81a82c
commit 8f706abe75
11 changed files with 413 additions and 225 deletions
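
For context, the duplicated pattern this change removes looks roughly like the following sketch (illustrative only; MyResource and MyResourceOp are hypothetical names, not code from this commit). Each resource-producing kernel kept its own handle tensor, lazily created the resource through ResourceMgr::LookupOrCreate() on the first Compute() call, and emitted the container/name pair as a string ref tensor:

// Hypothetical sketch of the pre-refactor boilerplate repeated per op.
class MyResourceOp : public OpKernel {
 public:
  explicit MyResourceOp(OpKernelConstruction* context) : OpKernel(context) {
    // Every op allocated a persistent {container, name} handle tensor.
    OP_REQUIRES_OK(context,
                   context->allocate_persistent(DT_STRING, TensorShape({2}),
                                                &handle_, nullptr));
  }

  void Compute(OpKernelContext* ctx) override {
    mutex_lock l(mu_);
    if (!handle_set_) {
      OP_REQUIRES_OK(ctx, cinfo_.Init(ctx->resource_manager(), def()));
      MyResource* resource;
      // Every op repeated this lookup-or-create dance.
      OP_REQUIRES_OK(ctx,
                     cinfo_.resource_manager()->LookupOrCreate<MyResource>(
                         cinfo_.container(), cinfo_.name(), &resource,
                         [](MyResource** ret) {
                           *ret = new MyResource;
                           return Status::OK();
                         }));
      resource->Unref();
      auto h = handle_.AccessTensor(ctx)->flat<string>();
      h(0) = cinfo_.container();
      h(1) = cinfo_.name();
      handle_set_ = true;
    }
    ctx->set_output_ref(0, &mu_, handle_.AccessTensor(ctx));
  }

 private:
  mutex mu_;
  bool handle_set_ GUARDED_BY(mu_) = false;
  PersistentTensor handle_ GUARDED_BY(mu_);
  ContainerInfo cinfo_;
};

The new ResourceOpKernel<T> introduced below absorbs everything here except construction of the resource itself.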

@@ -323,6 +323,7 @@ tf_cuda_library(
         "framework/reader_op_kernel.h",
         "framework/register_types.h",
         "framework/resource_mgr.h",
+        "framework/resource_op_kernel.h",
         "framework/selective_registration.h",
         "framework/session_state.h",
         "framework/shape_inference.h",
@@ -1705,6 +1706,7 @@ tf_cc_tests(
         "framework/partial_tensor_shape_test.cc",
         # "framework/rendezvous_test.cc",  # flaky b/30476344
         "framework/resource_mgr_test.cc",
+        "framework/resource_op_kernel_test.cc",
         "framework/shape_inference_test.cc",
         "framework/shape_inference_testutil_test.cc",
         "framework/tensor_shape_test.cc",

@@ -1,55 +0,0 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/reader_op_kernel.h"
namespace tensorflow {
ReaderOpKernel::ReaderOpKernel(OpKernelConstruction* context)
: OpKernel(context), have_handle_(false) {
OP_REQUIRES_OK(context, context->allocate_persistent(
tensorflow::DT_STRING,
tensorflow::TensorShape({2}), &handle_, nullptr));
}
ReaderOpKernel::~ReaderOpKernel() {
if (have_handle_ && cinfo_.resource_is_private_to_kernel()) {
TF_CHECK_OK(cinfo_.resource_manager()->Delete<ReaderInterface>(
cinfo_.container(), cinfo_.name()));
}
}
void ReaderOpKernel::Compute(OpKernelContext* ctx) {
mutex_lock l(mu_);
if (!have_handle_) {
OP_REQUIRES_OK(ctx, cinfo_.Init(ctx->resource_manager(), def(), false));
ReaderInterface* reader;
OP_REQUIRES_OK(ctx,
cinfo_.resource_manager()->LookupOrCreate<ReaderInterface>(
cinfo_.container(), cinfo_.name(), &reader,
[this](ReaderInterface** ret) {
*ret = factory_();
return Status::OK();
}));
reader->Unref();
auto h = handle_.AccessTensor(ctx)->flat<string>();
h(0) = cinfo_.container();
h(1) = cinfo_.name();
have_handle_ = true;
}
ctx->set_output_ref(0, &mu_, handle_.AccessTensor(ctx));
}
} // namespace tensorflow

@@ -22,35 +22,44 @@ limitations under the License.
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/reader_interface.h"
 #include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/resource_op_kernel.h"
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/types.h"
 
 namespace tensorflow {
 
-// Implementation for ops providing a Reader.
-class ReaderOpKernel : public OpKernel {
- public:
-  explicit ReaderOpKernel(OpKernelConstruction* context);
-  ~ReaderOpKernel() override;
+// NOTE: This is now a very thin layer over ResourceOpKernel.
+// TODO(sjhwang): Remove dependencies to this class, then delete this.
 
-  void Compute(OpKernelContext* context) override;
+// Implementation for ops providing a Reader.
+class ReaderOpKernel : public ResourceOpKernel<ReaderInterface> {
+ public:
+  using ResourceOpKernel::ResourceOpKernel;
 
   // Must be called by descendants before the first call to Compute()
   // (typically called during construction). factory must return a
   // ReaderInterface descendant allocated with new that ReaderOpKernel
   // will take ownership of.
-  void SetReaderFactory(std::function<ReaderInterface*()> factory) {
+  void SetReaderFactory(std::function<ReaderInterface*()> factory)
+      LOCKS_EXCLUDED(mu_) {
     mutex_lock l(mu_);
-    DCHECK(!have_handle_);
+    DCHECK(resource_ == nullptr);
    factory_ = factory;
   }
 
  private:
-  mutex mu_;
-  bool have_handle_ GUARDED_BY(mu_);
-  PersistentTensor handle_ GUARDED_BY(mu_);
-  ContainerInfo cinfo_;
-  std::function<ReaderInterface*()> factory_;
+  Status CreateResource(ReaderInterface** reader)
+      EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
+    *reader = factory_();
+    if (*reader == nullptr) {
+      return errors::ResourceExhausted("Failed to allocate reader");
+    }
+    std::function<ReaderInterface*()> temp = nullptr;
+    factory_.swap(temp);
+    return Status::OK();
+  }
+
+  std::function<ReaderInterface*()> factory_ GUARDED_BY(mu_);
 };
 
 }  // namespace tensorflow
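
As an illustration of the new contract, a concrete reader op now only supplies a factory. A minimal sketch (MyTextReader and MyTextReaderOp are hypothetical names, not part of this commit):

// Hypothetical reader op built on the refactored ReaderOpKernel.
class MyTextReaderOp : public ReaderOpKernel {
 public:
  explicit MyTextReaderOp(OpKernelConstruction* context)
      : ReaderOpKernel(context) {
    Env* env = context->env();
    // The factory runs at most once, inside the base class's CreateResource().
    SetReaderFactory([this, env]() { return new MyTextReader(env, name()); });
  }
};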

@@ -0,0 +1,119 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_FRAMEWORK_RESOURCE_OP_KERNEL_H_
#define TENSORFLOW_FRAMEWORK_RESOURCE_OP_KERNEL_H_

#include <string>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

// ResourceOpKernel<T> is a virtual base class for resource ops implementing
// interface type T. The deriving op looks up the resource name (determined by
// ContainerInfo) and creates a new resource if necessary.
//
// Requirements:
//   - The op must be marked as stateful.
//   - The op must have `container` and `shared_name` attributes. An empty
//     `container` means the default container is used; an empty `shared_name`
//     means the resource is private to the kernel.
//   - Subclasses must override CreateResource().
//   - Subclasses are encouraged to override VerifyResource().
template <typename T>
class ResourceOpKernel : public OpKernel {
 public:
  explicit ResourceOpKernel(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context,
                   context->allocate_persistent(DT_STRING, TensorShape({2}),
                                                &handle_, nullptr));
  }

  // The resource is deleted from the resource manager only when it is private
  // to the kernel. Ideally the resource should be deleted when it is no longer
  // held by anyone, but that would break backward compatibility.
  ~ResourceOpKernel() override {
    if (resource_ != nullptr) {
      resource_->Unref();
      if (cinfo_.resource_is_private_to_kernel()) {
        TF_CHECK_OK(cinfo_.resource_manager()->template Delete<T>(
            cinfo_.container(), cinfo_.name()));
      }
    }
  }

  void Compute(OpKernelContext* context) override LOCKS_EXCLUDED(mu_) {
    mutex_lock l(mu_);
    if (resource_ == nullptr) {
      ResourceMgr* mgr = context->resource_manager();
      OP_REQUIRES_OK(context, cinfo_.Init(mgr, def()));

      T* resource;
      OP_REQUIRES_OK(
          context,
          mgr->LookupOrCreate<T>(cinfo_.container(), cinfo_.name(), &resource,
                                 [this](T** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
                                   Status s = CreateResource(ret);
                                   if (!s.ok() && *ret != nullptr) {
                                     CHECK((*ret)->Unref());
                                   }
                                   return s;
                                 }));

      Status s = VerifyResource(resource);
      if (TF_PREDICT_FALSE(!s.ok())) {
        resource->Unref();
        context->SetStatus(s);
        return;
      }

      auto h = handle_.AccessTensor(context)->template flat<string>();
      h(0) = cinfo_.container();
      h(1) = cinfo_.name();
      resource_ = resource;
    }
    context->set_output_ref(0, &mu_, handle_.AccessTensor(context));
  }

 protected:
  // Variables accessible from subclasses.
  mutex mu_;
  ContainerInfo cinfo_ GUARDED_BY(mu_);
  T* resource_ GUARDED_BY(mu_) = nullptr;

 private:
  // Must return a T descendant allocated with new that ResourceOpKernel will
  // take ownership of.
  virtual Status CreateResource(T** resource) EXCLUSIVE_LOCKS_REQUIRED(mu_) = 0;

  // During the first Compute(), the resource is either created or looked up
  // using shared_name. In the latter case, the resource found should be
  // verified for compatibility with this op's configuration. The verification
  // may fail in cases such as two graphs asking queues of the same shared name
  // to have inconsistent capacities.
  virtual Status VerifyResource(T* resource) { return Status::OK(); }

  PersistentTensor handle_ GUARDED_BY(mu_);
};

}  // namespace tensorflow

#endif  // TENSORFLOW_FRAMEWORK_RESOURCE_OP_KERNEL_H_
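
To make the intended usage concrete, a minimal subclass might look like this (a sketch under assumed names; CounterResource and CounterOp are hypothetical, not part of this commit):

// Hypothetical resource type.
class CounterResource : public ResourceBase {
 public:
  string DebugString() override { return "Counter"; }
  int64 value = 0;
};

// Hypothetical op producing a CounterResource handle.
class CounterOp : public ResourceOpKernel<CounterResource> {
 public:
  using ResourceOpKernel::ResourceOpKernel;

 private:
  // ResourceOpKernel takes ownership of the resource allocated here.
  Status CreateResource(CounterResource** resource)
      EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
    *resource = new CounterResource;
    return Status::OK();
  }
  // VerifyResource is left defaulted: any looked-up CounterResource is
  // accepted. Override it when shared resources carry configuration that
  // must match across ops.
};

// Per the requirements above: stateful, with container/shared_name attrs.
REGISTER_OP("CounterOp")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .Output("handle: Ref(string)")
    .SetIsStateful();
REGISTER_KERNEL_BUILDER(Name("CounterOp").Device(DEVICE_CPU), CounterOp);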

@@ -0,0 +1,202 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/resource_op_kernel.h"
#include <memory>
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/public/version.h"
namespace tensorflow {
namespace {
// Stub DeviceBase subclass which only returns allocators.
class StubDevice : public DeviceBase {
public:
StubDevice() : DeviceBase(nullptr) {}
Allocator* GetAllocator(AllocatorAttributes) override {
return cpu_allocator();
}
};
// Stub resource for testing resource op kernel.
class StubResource : public ResourceBase {
public:
string DebugString() override { return ""; }
int code;
};
class StubResourceOpKernel : public ResourceOpKernel<StubResource> {
public:
using ResourceOpKernel::ResourceOpKernel;
StubResource* resource() LOCKS_EXCLUDED(mu_) {
mutex_lock lock(mu_);
return resource_;
}
private:
Status CreateResource(StubResource** resource) override {
*resource = CHECK_NOTNULL(new StubResource);
return GetNodeAttr(def(), "code", &(*resource)->code);
}
Status VerifyResource(StubResource* resource) override {
int code;
TF_RETURN_IF_ERROR(GetNodeAttr(def(), "code", &code));
if (code != resource->code) {
return errors::InvalidArgument("stub has code ", resource->code,
" but requested code ", code);
}
return Status::OK();
}
};
REGISTER_OP("StubResourceOp")
.Attr("code: int")
.Attr("container: string = ''")
.Attr("shared_name: string = ''")
.Output("output: Ref(string)");
REGISTER_KERNEL_BUILDER(Name("StubResourceOp").Device(DEVICE_CPU),
StubResourceOpKernel);
class ResourceOpKernelTest : public ::testing::Test {
protected:
std::unique_ptr<StubResourceOpKernel> CreateOp(int code,
const string& shared_name) {
NodeDef node_def;
TF_CHECK_OK(
NodeDefBuilder(strings::StrCat("test-node", count_++), "StubResourceOp")
.Attr("code", code)
.Attr("shared_name", shared_name)
.Finalize(&node_def));
Status status;
std::unique_ptr<OpKernel> op(CreateOpKernel(
DEVICE_CPU, &device_, device_.GetAllocator(AllocatorAttributes()),
node_def, TF_GRAPH_DEF_VERSION, &status));
TF_EXPECT_OK(status) << status;
EXPECT_TRUE(op != nullptr);
// Downcast to StubResourceOpKernel to call resource() later.
std::unique_ptr<StubResourceOpKernel> resource_op(
dynamic_cast<StubResourceOpKernel*>(op.get()));
EXPECT_TRUE(resource_op != nullptr);
if (resource_op != nullptr) {
op.release();
}
return resource_op;
}
Status RunOpKernel(OpKernel* op) {
OpKernelContext::Params params;
params.device = &device_;
params.resource_manager = &mgr_;
params.op_kernel = op;
OpKernelContext context(&params);
op->Compute(&context);
return context.status();
}
StubDevice device_;
ResourceMgr mgr_;
int count_ = 0;
};
TEST_F(ResourceOpKernelTest, PrivateResource) {
// Empty shared_name means private resource.
const int code = -100;
auto op = CreateOp(code, "");
ASSERT_TRUE(op != nullptr);
TF_EXPECT_OK(RunOpKernel(op.get()));
// Default non-shared name provided from ContainerInfo.
const string key = "_0_" + op->name();
StubResource* resource;
TF_ASSERT_OK(
mgr_.Lookup<StubResource>(mgr_.default_container(), key, &resource));
EXPECT_EQ(op->resource(), resource); // Check resource identity.
EXPECT_EQ(code, resource->code); // Check resource stored information.
resource->Unref();
// Destroy the op kernel. Expect the resource to be released.
op = nullptr;
Status s =
mgr_.Lookup<StubResource>(mgr_.default_container(), key, &resource);
EXPECT_FALSE(s.ok());
}
TEST_F(ResourceOpKernelTest, SharedResource) {
const string shared_name = "shared_stub";
const int code = -201;
auto op = CreateOp(code, shared_name);
ASSERT_TRUE(op != nullptr);
TF_EXPECT_OK(RunOpKernel(op.get()));
StubResource* resource;
TF_ASSERT_OK(mgr_.Lookup<StubResource>(mgr_.default_container(), shared_name,
&resource));
EXPECT_EQ(op->resource(), resource); // Check resource identity.
EXPECT_EQ(code, resource->code); // Check resource stored information.
resource->Unref();
// Destroy the op kernel. Expect the resource not to be released.
op = nullptr;
TF_ASSERT_OK(mgr_.Lookup<StubResource>(mgr_.default_container(), shared_name,
&resource));
resource->Unref();
}
TEST_F(ResourceOpKernelTest, LookupShared) {
auto op1 = CreateOp(-333, "shared_stub");
auto op2 = CreateOp(-333, "shared_stub");
ASSERT_TRUE(op1 != nullptr);
ASSERT_TRUE(op2 != nullptr);
TF_EXPECT_OK(RunOpKernel(op1.get()));
TF_EXPECT_OK(RunOpKernel(op2.get()));
EXPECT_EQ(op1->resource(), op2->resource());
}
TEST_F(ResourceOpKernelTest, VerifyResource) {
auto op1 = CreateOp(-444, "shared_stub");
auto op2 = CreateOp(0, "shared_stub"); // Different resource code.
ASSERT_TRUE(op1 != nullptr);
ASSERT_TRUE(op2 != nullptr);
TF_EXPECT_OK(RunOpKernel(op1.get()));
EXPECT_FALSE(RunOpKernel(op2.get()).ok());
EXPECT_TRUE(op1->resource() != nullptr);
EXPECT_TRUE(op2->resource() == nullptr);
}
} // namespace
} // namespace tensorflow

@@ -21,6 +21,7 @@ limitations under the License.
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/resource_op_kernel.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/types.h"
@@ -432,14 +433,10 @@ class Barrier : public ResourceBase {
   TF_DISALLOW_COPY_AND_ASSIGN(Barrier);
 };
 
-class BarrierOp : public OpKernel {
+class BarrierOp : public ResourceOpKernel<Barrier> {
  public:
   explicit BarrierOp(OpKernelConstruction* context)
-      : OpKernel(context), barrier_handle_set_(false) {
-    OP_REQUIRES_OK(context,
-                   context->allocate_persistent(tensorflow::DT_STRING,
-                                                tensorflow::TensorShape({2}),
-                                                &barrier_handle_, nullptr));
+      : ResourceOpKernel(context) {
     OP_REQUIRES_OK(
         context, context->GetAttr("component_types", &value_component_types_));
     OP_REQUIRES_OK(context,
@@ -458,34 +455,19 @@ class BarrierOp : public OpKernel {
                     "limited capacity."));
   }
 
-  ~BarrierOp() override {
-    // If the barrier object was not shared, delete it.
-    if (barrier_handle_set_ && cinfo_.resource_is_private_to_kernel()) {
-      TF_CHECK_OK(cinfo_.resource_manager()->Delete<Barrier>(cinfo_.container(),
-                                                             cinfo_.name()));
-    }
-  }
-
-  void Compute(OpKernelContext* ctx) override {
-    mutex_lock l(mu_);
-    if (!barrier_handle_set_) {
-      OP_REQUIRES_OK(ctx, SetBarrierHandle(ctx));
-    }
-    ctx->set_output_ref(0, &mu_, barrier_handle_.AccessTensor(ctx));
-  }
-
  private:
-  Status SetBarrierHandle(OpKernelContext* ctx) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
-    TF_RETURN_IF_ERROR(cinfo_.Init(ctx->resource_manager(), def()));
-    Barrier* barrier = nullptr;
-    auto creator = [this](Barrier** ret) {
-      *ret = new Barrier(value_component_types_, value_component_shapes_,
-                         cinfo_.name());
-      return (*ret)->Initialize();
-    };
-    TF_RETURN_IF_ERROR(cinfo_.resource_manager()->LookupOrCreate<Barrier>(
-        cinfo_.container(), cinfo_.name(), &barrier, creator));
-    core::ScopedUnref unref_me(barrier);
+  Status CreateResource(Barrier** barrier) override
+      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+    *barrier = new Barrier(value_component_types_, value_component_shapes_,
+                           cinfo_.name());
+    if (*barrier == nullptr) {
+      return errors::ResourceExhausted("Failed to allocate barrier");
+    }
+    return (*barrier)->Initialize();
+  }
 
+  Status VerifyResource(Barrier* barrier) override
+      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
     if (barrier->component_types() != value_component_types_) {
       return errors::InvalidArgument(
           "Shared barrier '", cinfo_.name(), "' has component types ",
@@ -500,20 +482,11 @@ class BarrierOp : public OpKernel {
                 " but requested component shapes were ",
                 TensorShapeUtils::ShapeListString(value_component_shapes_));
     }
-    auto h = barrier_handle_.AccessTensor(ctx)->flat<string>();
-    h(0) = cinfo_.container();
-    h(1) = cinfo_.name();
-    barrier_handle_set_ = true;
     return Status::OK();
   }
 
   DataTypeVector value_component_types_;
   std::vector<TensorShape> value_component_shapes_;
-  ContainerInfo cinfo_;
-  mutex mu_;
-  PersistentTensor barrier_handle_ GUARDED_BY(mu_);
-  bool barrier_handle_set_ GUARDED_BY(mu_);
 
   TF_DISALLOW_COPY_AND_ASSIGN(BarrierOp);
 };

@@ -39,28 +39,20 @@ namespace tensorflow {
 // backed by FIFOQueue) that persists across different graph
 // executions, and sessions. Running this op produces a single-element
 // tensor of handles to Queues in the corresponding device.
-class FIFOQueueOp : public QueueOp {
+class FIFOQueueOp : public TypedQueueOp {
  public:
-  explicit FIFOQueueOp(OpKernelConstruction* context) : QueueOp(context) {
+  explicit FIFOQueueOp(OpKernelConstruction* context) : TypedQueueOp(context) {
     OP_REQUIRES_OK(context, context->GetAttr("shapes", &component_shapes_));
   }
 
- protected:
-  CreatorCallback GetCreator() const override {
-    return [this](QueueInterface** ret) {
-      FIFOQueue* queue = new FIFOQueue(capacity_, component_types_,
-                                       component_shapes_, cinfo_.name());
-      Status s = queue->Initialize();
-      if (s.ok()) {
-        *ret = queue;
-      } else {
-        queue->Unref();
-      }
-      return s;
-    };
-  }
-
  private:
+  Status CreateResource(QueueInterface** ret) override
+      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+    FIFOQueue* queue = new FIFOQueue(capacity_, component_types_,
+                                     component_shapes_, cinfo_.name());
+    return CreateTypedQueue(queue, ret);
+  }
+
   std::vector<TensorShape> component_shapes_;
   TF_DISALLOW_COPY_AND_ASSIGN(FIFOQueueOp);
 };

@@ -40,10 +40,10 @@ namespace tensorflow {
 // backed by PaddingFIFOQueue) that persists across different graph
 // executions, and sessions. Running this op produces a single-element
 // tensor of handles to Queues in the corresponding device.
-class PaddingFIFOQueueOp : public QueueOp {
+class PaddingFIFOQueueOp : public TypedQueueOp {
  public:
   explicit PaddingFIFOQueueOp(OpKernelConstruction* context)
-      : QueueOp(context) {
+      : TypedQueueOp(context) {
     OP_REQUIRES_OK(context, context->GetAttr("shapes", &component_shapes_));
     for (const auto& shape : component_shapes_) {
       OP_REQUIRES(context, shape.dims() >= 0,
@@ -52,22 +52,14 @@ class PaddingFIFOQueueOp : public QueueOp {
     }
   }
 
- protected:
-  CreatorCallback GetCreator() const override {
-    return [this](QueueInterface** ret) {
-      PaddingFIFOQueue* queue = new PaddingFIFOQueue(
-          capacity_, component_types_, component_shapes_, cinfo_.name());
-      Status s = queue->Initialize();
-      if (s.ok()) {
-        *ret = queue;
-      } else {
-        queue->Unref();
-      }
-      return s;
-    };
-  }
-
  private:
+  Status CreateResource(QueueInterface** ret) override
+      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+    PaddingFIFOQueue* queue = new PaddingFIFOQueue(
+        capacity_, component_types_, component_shapes_, cinfo_.name());
+    return CreateTypedQueue(queue, ret);
+  }
+
   std::vector<PartialTensorShape> component_shapes_;
   TF_DISALLOW_COPY_AND_ASSIGN(PaddingFIFOQueueOp);
@@ -38,9 +38,10 @@ namespace tensorflow {
 // backed by PriorityQueue) that persists across different graph
 // executions, and sessions. Running this op produces a single-element
 // tensor of handles to Queues in the corresponding device.
-class PriorityQueueOp : public QueueOp {
+class PriorityQueueOp : public TypedQueueOp {
  public:
-  explicit PriorityQueueOp(OpKernelConstruction* context) : QueueOp(context) {
+  explicit PriorityQueueOp(OpKernelConstruction* context)
+      : TypedQueueOp(context) {
     OP_REQUIRES_OK(context, context->GetAttr("shapes", &component_shapes_));
     component_types_.insert(component_types_.begin(), DT_INT64);
     if (!component_shapes_.empty()) {
@@ -48,19 +49,12 @@ class PriorityQueueOp : public QueueOp {
     }
   }
 
- protected:
-  CreatorCallback GetCreator() const override {
-    return [this](QueueInterface** ret) {
-      PriorityQueue* queue = new PriorityQueue(
-          capacity_, component_types_, component_shapes_, cinfo_.name());
-      Status s = queue->Initialize();
-      if (s.ok()) {
-        *ret = queue;
-      } else {
-        queue->Unref();
-      }
-      return s;
-    };
-  }
-
  private:
+  Status CreateResource(QueueInterface** ret) override
+      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+    PriorityQueue* queue = new PriorityQueue(capacity_, component_types_,
+                                             component_shapes_, cinfo_.name());
+    return CreateTypedQueue(queue, ret);
+  }
+
   std::vector<TensorShape> component_shapes_;

@@ -19,28 +19,21 @@ limitations under the License.
 #include <deque>
 
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/resource_op_kernel.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/kernels/queue_base.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/thread_annotations.h"
 #include "tensorflow/core/platform/types.h"
 
 namespace tensorflow {
 
 // Defines a QueueOp, an abstract class for Queue construction ops.
-class QueueOp : public OpKernel {
+class QueueOp : public ResourceOpKernel<QueueInterface> {
  public:
-  QueueOp(OpKernelConstruction* context)
-      : OpKernel(context), queue_handle_set_(false) {
+  QueueOp(OpKernelConstruction* context) : ResourceOpKernel(context) {
     OP_REQUIRES_OK(context, context->GetAttr("capacity", &capacity_));
-    OP_REQUIRES_OK(context,
-                   context->allocate_persistent(DT_STRING, TensorShape({2}),
-                                                &queue_handle_, nullptr));
     if (capacity_ < 0) {
       capacity_ = QueueBase::kUnbounded;
     }
@@ -48,55 +41,30 @@ class QueueOp : public OpKernel {
                    context->GetAttr("component_types", &component_types_));
   }
 
-  void Compute(OpKernelContext* ctx) override {
-    mutex_lock l(mu_);
-    if (!queue_handle_set_) {
-      OP_REQUIRES_OK(ctx, SetQueueHandle(ctx));
-    }
-    ctx->set_output_ref(0, &mu_, queue_handle_.AccessTensor(ctx));
-  }
-
- protected:
-  ~QueueOp() override {
-    // If the queue object was not shared, delete it.
-    if (queue_handle_set_ && cinfo_.resource_is_private_to_kernel()) {
-      TF_CHECK_OK(cinfo_.resource_manager()->Delete<QueueInterface>(
-          cinfo_.container(), cinfo_.name()));
-    }
-  }
-
- protected:
-  typedef std::function<Status(QueueInterface**)> CreatorCallback;
-
-  // Subclasses must override this
-  virtual CreatorCallback GetCreator() const = 0;
-
+ protected:
   // Variables accessible by subclasses
   int32 capacity_;
   DataTypeVector component_types_;
-  ContainerInfo cinfo_;
 
  private:
-  Status SetQueueHandle(OpKernelContext* ctx) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
-    TF_RETURN_IF_ERROR(cinfo_.Init(ctx->resource_manager(), def()));
-    CreatorCallback creator = GetCreator();
-    QueueInterface* queue;
-    TF_RETURN_IF_ERROR(
-        cinfo_.resource_manager()->LookupOrCreate<QueueInterface>(
-            cinfo_.container(), cinfo_.name(), &queue, creator));
-    core::ScopedUnref unref_me(queue);
-    // Verify that the shared queue is compatible with the requested arguments.
-    TF_RETURN_IF_ERROR(queue->MatchesNodeDef(def()));
-    auto h = queue_handle_.AccessTensor(ctx)->flat<string>();
-    h(0) = cinfo_.container();
-    h(1) = cinfo_.name();
-    queue_handle_set_ = true;
-    return Status::OK();
-  }
-
-  mutex mu_;
-  PersistentTensor queue_handle_ GUARDED_BY(mu_);
-  bool queue_handle_set_ GUARDED_BY(mu_);
+  Status VerifyResource(QueueInterface* queue) override {
+    return queue->MatchesNodeDef(def());
+  }
 };
 
+class TypedQueueOp : public QueueOp {
+ public:
+  using QueueOp::QueueOp;
+
+ protected:
+  template <typename TypedQueue>
+  Status CreateTypedQueue(TypedQueue* queue, QueueInterface** ret) {
+    if (queue == nullptr) {
+      return errors::ResourceExhausted("Failed to allocate queue.");
+    }
+    *ret = queue;
+    return queue->Initialize();
+  }
+};
+
 }  // namespace tensorflow
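
On the consuming side nothing changes: a kernel taking the Ref(string) handle these ops output can still recover the queue through the resource manager. A rough sketch (the kernel itself is hypothetical; GetResourceFromContext is the existing helper from resource_mgr.h):

// Hypothetical consumer of a queue handle produced by a QueueOp subclass.
class MyDequeueOp : public OpKernel {
 public:
  explicit MyDequeueOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* ctx) override {
    QueueInterface* queue;
    // Input "handle" is the Ref(string) tensor set by ResourceOpKernel.
    OP_REQUIRES_OK(ctx, GetResourceFromContext(ctx, "handle", &queue));
    core::ScopedUnref unref_me(queue);
    // ... use the queue, e.g. via queue->TryDequeue(...).
  }
};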

@@ -459,10 +459,10 @@ Status RandomShuffleQueue::MatchesNodeDef(const NodeDef& node_def) {
 // backed by RandomShuffleQueue) that persists across different graph
 // executions, and sessions. Running this op produces a single-element
 // tensor of handles to Queues in the corresponding device.
-class RandomShuffleQueueOp : public QueueOp {
+class RandomShuffleQueueOp : public TypedQueueOp {
  public:
   explicit RandomShuffleQueueOp(OpKernelConstruction* context)
-      : QueueOp(context) {
+      : TypedQueueOp(context) {
     OP_REQUIRES_OK(context,
                    context->GetAttr("min_after_dequeue", &min_after_dequeue_));
     OP_REQUIRES(context, min_after_dequeue_ >= 0,
@@ -478,23 +478,15 @@ class RandomShuffleQueueOp : public QueueOp {
     OP_REQUIRES_OK(context, context->GetAttr("shapes", &component_shapes_));
   }
 
- protected:
-  CreatorCallback GetCreator() const override {
-    return [this](QueueInterface** ret) {
-      auto* q = new RandomShuffleQueue(capacity_, min_after_dequeue_, seed_,
-                                       seed2_, component_types_,
-                                       component_shapes_, cinfo_.name());
-      Status s = q->Initialize();
-      if (s.ok()) {
-        *ret = q;
-      } else {
-        q->Unref();
-      }
-      return s;
-    };
-  }
-
  private:
+  Status CreateResource(QueueInterface** ret) override
+      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+    RandomShuffleQueue* queue = new RandomShuffleQueue(
+        capacity_, min_after_dequeue_, seed_, seed2_, component_types_,
+        component_shapes_, cinfo_.name());
+    return CreateTypedQueue(queue, ret);
+  }
+
   int32 min_after_dequeue_;
   int64 seed_;
   int64 seed2_;