Refactor resource ops. There are many ops producing different kinds of resources, and the same code pattern is repeated in several places.
Change: 141323281
This commit is contained in:
parent
db2a81a82c
commit
8f706abe75
@ -323,6 +323,7 @@ tf_cuda_library(
|
||||
"framework/reader_op_kernel.h",
|
||||
"framework/register_types.h",
|
||||
"framework/resource_mgr.h",
|
||||
"framework/resource_op_kernel.h",
|
||||
"framework/selective_registration.h",
|
||||
"framework/session_state.h",
|
||||
"framework/shape_inference.h",
|
||||
@ -1705,6 +1706,7 @@ tf_cc_tests(
|
||||
"framework/partial_tensor_shape_test.cc",
|
||||
# "framework/rendezvous_test.cc", # flaky b/30476344
|
||||
"framework/resource_mgr_test.cc",
|
||||
"framework/resource_op_kernel_test.cc",
|
||||
"framework/shape_inference_test.cc",
|
||||
"framework/shape_inference_testutil_test.cc",
|
||||
"framework/tensor_shape_test.cc",
|
||||
|
@ -1,55 +0,0 @@
|
||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/reader_op_kernel.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
ReaderOpKernel::ReaderOpKernel(OpKernelConstruction* context)
|
||||
: OpKernel(context), have_handle_(false) {
|
||||
OP_REQUIRES_OK(context, context->allocate_persistent(
|
||||
tensorflow::DT_STRING,
|
||||
tensorflow::TensorShape({2}), &handle_, nullptr));
|
||||
}
|
||||
|
||||
ReaderOpKernel::~ReaderOpKernel() {
|
||||
if (have_handle_ && cinfo_.resource_is_private_to_kernel()) {
|
||||
TF_CHECK_OK(cinfo_.resource_manager()->Delete<ReaderInterface>(
|
||||
cinfo_.container(), cinfo_.name()));
|
||||
}
|
||||
}
|
||||
|
||||
void ReaderOpKernel::Compute(OpKernelContext* ctx) {
|
||||
mutex_lock l(mu_);
|
||||
if (!have_handle_) {
|
||||
OP_REQUIRES_OK(ctx, cinfo_.Init(ctx->resource_manager(), def(), false));
|
||||
ReaderInterface* reader;
|
||||
OP_REQUIRES_OK(ctx,
|
||||
cinfo_.resource_manager()->LookupOrCreate<ReaderInterface>(
|
||||
cinfo_.container(), cinfo_.name(), &reader,
|
||||
[this](ReaderInterface** ret) {
|
||||
*ret = factory_();
|
||||
return Status::OK();
|
||||
}));
|
||||
reader->Unref();
|
||||
auto h = handle_.AccessTensor(ctx)->flat<string>();
|
||||
h(0) = cinfo_.container();
|
||||
h(1) = cinfo_.name();
|
||||
have_handle_ = true;
|
||||
}
|
||||
ctx->set_output_ref(0, &mu_, handle_.AccessTensor(ctx));
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
@ -22,35 +22,44 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/reader_interface.h"
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
#include "tensorflow/core/framework/resource_op_kernel.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Implementation for ops providing a Reader.
|
||||
class ReaderOpKernel : public OpKernel {
|
||||
public:
|
||||
explicit ReaderOpKernel(OpKernelConstruction* context);
|
||||
~ReaderOpKernel() override;
|
||||
// NOTE: This is now a very thin layer over ResourceOpKernel.
|
||||
// TODO(sjhwang): Remove dependencies to this class, then delete this.
|
||||
|
||||
void Compute(OpKernelContext* context) override;
|
||||
// Implementation for ops providing a Reader.
|
||||
class ReaderOpKernel : public ResourceOpKernel<ReaderInterface> {
|
||||
public:
|
||||
using ResourceOpKernel::ResourceOpKernel;
|
||||
|
||||
// Must be called by descendants before the first call to Compute()
|
||||
// (typically called during construction). factory must return a
|
||||
// ReaderInterface descendant allocated with new that ReaderOpKernel
|
||||
// will take ownership of.
|
||||
void SetReaderFactory(std::function<ReaderInterface*()> factory) {
|
||||
void SetReaderFactory(std::function<ReaderInterface*()> factory)
|
||||
LOCKS_EXCLUDED(mu_) {
|
||||
mutex_lock l(mu_);
|
||||
DCHECK(!have_handle_);
|
||||
DCHECK(resource_ == nullptr);
|
||||
factory_ = factory;
|
||||
}
|
||||
|
||||
private:
|
||||
mutex mu_;
|
||||
bool have_handle_ GUARDED_BY(mu_);
|
||||
PersistentTensor handle_ GUARDED_BY(mu_);
|
||||
ContainerInfo cinfo_;
|
||||
std::function<ReaderInterface*()> factory_;
|
||||
Status CreateResource(ReaderInterface** reader)
|
||||
EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
|
||||
*reader = factory_();
|
||||
if (*reader == nullptr) {
|
||||
return errors::ResourceExhausted("Failed to allocate reader");
|
||||
}
|
||||
std::function<ReaderInterface*()> temp = nullptr;
|
||||
factory_.swap(temp);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::function<ReaderInterface*()> factory_ GUARDED_BY(mu_);
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
119
tensorflow/core/framework/resource_op_kernel.h
Normal file
119
tensorflow/core/framework/resource_op_kernel.h
Normal file
@ -0,0 +1,119 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_FRAMEWORK_RESOURCE_OP_KERNEL_H_
|
||||
#define TENSORFLOW_FRAMEWORK_RESOURCE_OP_KERNEL_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/thread_annotations.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// ResourceOpKernel<T> is a virtual base class for resource op implementing
|
||||
// interface type T. The inherited op looks up the resource name (determined by
|
||||
// ContainerInfo), and creates a new resource if necessary.
|
||||
//
|
||||
// Requirements:
|
||||
// - Op must be marked as stateful.
|
||||
// - Op must have `container` and `shared_name` attributes. Empty `container`
|
||||
// means using the default container. Empty `shared_name` means private
|
||||
// resource.
|
||||
// - Subclass must override CreateResource().
|
||||
// - Subclass is encouraged to override VerifyResource().
|
||||
template <typename T>
|
||||
class ResourceOpKernel : public OpKernel {
|
||||
public:
|
||||
explicit ResourceOpKernel(OpKernelConstruction* context) : OpKernel(context) {
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_persistent(DT_STRING, TensorShape({2}),
|
||||
&handle_, nullptr));
|
||||
}
|
||||
|
||||
// The resource is deleted from the resource manager only when it is private
|
||||
// to kernel. Ideally the resource should be deleted when it is no longer held
|
||||
// by anyone, but it would break backward compatibility.
|
||||
~ResourceOpKernel() override {
|
||||
if (resource_ != nullptr) {
|
||||
resource_->Unref();
|
||||
if (cinfo_.resource_is_private_to_kernel()) {
|
||||
TF_CHECK_OK(cinfo_.resource_manager()->template Delete<T>(
|
||||
cinfo_.container(), cinfo_.name()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* context) override LOCKS_EXCLUDED(mu_) {
|
||||
mutex_lock l(mu_);
|
||||
if (resource_ == nullptr) {
|
||||
ResourceMgr* mgr = context->resource_manager();
|
||||
OP_REQUIRES_OK(context, cinfo_.Init(mgr, def()));
|
||||
|
||||
T* resource;
|
||||
OP_REQUIRES_OK(
|
||||
context,
|
||||
mgr->LookupOrCreate<T>(cinfo_.container(), cinfo_.name(), &resource,
|
||||
[this](T** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
Status s = CreateResource(ret);
|
||||
if (!s.ok() && *ret != nullptr) {
|
||||
CHECK((*ret)->Unref());
|
||||
}
|
||||
return s;
|
||||
}));
|
||||
|
||||
Status s = VerifyResource(resource);
|
||||
if (TF_PREDICT_FALSE(!s.ok())) {
|
||||
resource->Unref();
|
||||
context->SetStatus(s);
|
||||
return;
|
||||
}
|
||||
|
||||
auto h = handle_.AccessTensor(context)->template flat<string>();
|
||||
h(0) = cinfo_.container();
|
||||
h(1) = cinfo_.name();
|
||||
resource_ = resource;
|
||||
}
|
||||
context->set_output_ref(0, &mu_, handle_.AccessTensor(context));
|
||||
}
|
||||
|
||||
protected:
|
||||
// Variables accessible from subclasses.
|
||||
mutex mu_;
|
||||
ContainerInfo cinfo_ GUARDED_BY(mu_);
|
||||
T* resource_ GUARDED_BY(mu_) = nullptr;
|
||||
|
||||
private:
|
||||
// Must return a T descendant allocated with new that ResourceOpKernel will
|
||||
// take ownership of.
|
||||
virtual Status CreateResource(T** resource) EXCLUSIVE_LOCKS_REQUIRED(mu_) = 0;
|
||||
|
||||
// During the first Compute(), resource is either created or looked up using
|
||||
// shared_name. In the latter case, the resource found should be verified if
|
||||
// it is compatible with this op's configuration. The verification may fail in
|
||||
// cases such as two graphs asking queues of the same shared name to have
|
||||
// inconsistent capacities.
|
||||
virtual Status VerifyResource(T* resource) { return Status::OK(); }
|
||||
|
||||
PersistentTensor handle_ GUARDED_BY(mu_);
|
||||
};
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_FRAMEWORK_RESOURCE_OP_KERNEL_H_
|
202
tensorflow/core/framework/resource_op_kernel_test.cc
Normal file
202
tensorflow/core/framework/resource_op_kernel_test.cc
Normal file
@ -0,0 +1,202 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/resource_op_kernel.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/core/framework/allocator.h"
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/thread_annotations.h"
|
||||
#include "tensorflow/core/public/version.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
// Stub DeviceBase subclass which only returns allocators.
|
||||
class StubDevice : public DeviceBase {
|
||||
public:
|
||||
StubDevice() : DeviceBase(nullptr) {}
|
||||
|
||||
Allocator* GetAllocator(AllocatorAttributes) override {
|
||||
return cpu_allocator();
|
||||
}
|
||||
};
|
||||
|
||||
// Stub resource for testing resource op kernel.
|
||||
class StubResource : public ResourceBase {
|
||||
public:
|
||||
string DebugString() override { return ""; }
|
||||
int code;
|
||||
};
|
||||
|
||||
class StubResourceOpKernel : public ResourceOpKernel<StubResource> {
|
||||
public:
|
||||
using ResourceOpKernel::ResourceOpKernel;
|
||||
|
||||
StubResource* resource() LOCKS_EXCLUDED(mu_) {
|
||||
mutex_lock lock(mu_);
|
||||
return resource_;
|
||||
}
|
||||
|
||||
private:
|
||||
Status CreateResource(StubResource** resource) override {
|
||||
*resource = CHECK_NOTNULL(new StubResource);
|
||||
return GetNodeAttr(def(), "code", &(*resource)->code);
|
||||
}
|
||||
|
||||
Status VerifyResource(StubResource* resource) override {
|
||||
int code;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(def(), "code", &code));
|
||||
if (code != resource->code) {
|
||||
return errors::InvalidArgument("stub has code ", resource->code,
|
||||
" but requested code ", code);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_OP("StubResourceOp")
|
||||
.Attr("code: int")
|
||||
.Attr("container: string = ''")
|
||||
.Attr("shared_name: string = ''")
|
||||
.Output("output: Ref(string)");
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("StubResourceOp").Device(DEVICE_CPU),
|
||||
StubResourceOpKernel);
|
||||
|
||||
class ResourceOpKernelTest : public ::testing::Test {
|
||||
protected:
|
||||
std::unique_ptr<StubResourceOpKernel> CreateOp(int code,
|
||||
const string& shared_name) {
|
||||
NodeDef node_def;
|
||||
TF_CHECK_OK(
|
||||
NodeDefBuilder(strings::StrCat("test-node", count_++), "StubResourceOp")
|
||||
.Attr("code", code)
|
||||
.Attr("shared_name", shared_name)
|
||||
.Finalize(&node_def));
|
||||
Status status;
|
||||
std::unique_ptr<OpKernel> op(CreateOpKernel(
|
||||
DEVICE_CPU, &device_, device_.GetAllocator(AllocatorAttributes()),
|
||||
node_def, TF_GRAPH_DEF_VERSION, &status));
|
||||
TF_EXPECT_OK(status) << status;
|
||||
EXPECT_TRUE(op != nullptr);
|
||||
|
||||
// Downcast to StubResourceOpKernel to call resource() later.
|
||||
std::unique_ptr<StubResourceOpKernel> resource_op(
|
||||
dynamic_cast<StubResourceOpKernel*>(op.get()));
|
||||
EXPECT_TRUE(resource_op != nullptr);
|
||||
if (resource_op != nullptr) {
|
||||
op.release();
|
||||
}
|
||||
return resource_op;
|
||||
}
|
||||
|
||||
Status RunOpKernel(OpKernel* op) {
|
||||
OpKernelContext::Params params;
|
||||
|
||||
params.device = &device_;
|
||||
params.resource_manager = &mgr_;
|
||||
params.op_kernel = op;
|
||||
|
||||
OpKernelContext context(¶ms);
|
||||
op->Compute(&context);
|
||||
return context.status();
|
||||
}
|
||||
|
||||
StubDevice device_;
|
||||
ResourceMgr mgr_;
|
||||
int count_ = 0;
|
||||
};
|
||||
|
||||
TEST_F(ResourceOpKernelTest, PrivateResource) {
|
||||
// Empty shared_name means private resource.
|
||||
const int code = -100;
|
||||
auto op = CreateOp(code, "");
|
||||
ASSERT_TRUE(op != nullptr);
|
||||
TF_EXPECT_OK(RunOpKernel(op.get()));
|
||||
|
||||
// Default non-shared name provided from ContainerInfo.
|
||||
const string key = "_0_" + op->name();
|
||||
|
||||
StubResource* resource;
|
||||
TF_ASSERT_OK(
|
||||
mgr_.Lookup<StubResource>(mgr_.default_container(), key, &resource));
|
||||
EXPECT_EQ(op->resource(), resource); // Check resource identity.
|
||||
EXPECT_EQ(code, resource->code); // Check resource stored information.
|
||||
resource->Unref();
|
||||
|
||||
// Destroy the op kernel. Expect the resource to be released.
|
||||
op = nullptr;
|
||||
Status s =
|
||||
mgr_.Lookup<StubResource>(mgr_.default_container(), key, &resource);
|
||||
|
||||
EXPECT_FALSE(s.ok());
|
||||
}
|
||||
|
||||
TEST_F(ResourceOpKernelTest, SharedResource) {
|
||||
const string shared_name = "shared_stub";
|
||||
const int code = -201;
|
||||
auto op = CreateOp(code, shared_name);
|
||||
ASSERT_TRUE(op != nullptr);
|
||||
TF_EXPECT_OK(RunOpKernel(op.get()));
|
||||
|
||||
StubResource* resource;
|
||||
TF_ASSERT_OK(mgr_.Lookup<StubResource>(mgr_.default_container(), shared_name,
|
||||
&resource));
|
||||
EXPECT_EQ(op->resource(), resource); // Check resource identity.
|
||||
EXPECT_EQ(code, resource->code); // Check resource stored information.
|
||||
resource->Unref();
|
||||
|
||||
// Destroy the op kernel. Expect the resource not to be released.
|
||||
op = nullptr;
|
||||
TF_ASSERT_OK(mgr_.Lookup<StubResource>(mgr_.default_container(), shared_name,
|
||||
&resource));
|
||||
resource->Unref();
|
||||
}
|
||||
|
||||
TEST_F(ResourceOpKernelTest, LookupShared) {
|
||||
auto op1 = CreateOp(-333, "shared_stub");
|
||||
auto op2 = CreateOp(-333, "shared_stub");
|
||||
ASSERT_TRUE(op1 != nullptr);
|
||||
ASSERT_TRUE(op2 != nullptr);
|
||||
|
||||
TF_EXPECT_OK(RunOpKernel(op1.get()));
|
||||
TF_EXPECT_OK(RunOpKernel(op2.get()));
|
||||
EXPECT_EQ(op1->resource(), op2->resource());
|
||||
}
|
||||
|
||||
TEST_F(ResourceOpKernelTest, VerifyResource) {
|
||||
auto op1 = CreateOp(-444, "shared_stub");
|
||||
auto op2 = CreateOp(0, "shared_stub"); // Different resource code.
|
||||
ASSERT_TRUE(op1 != nullptr);
|
||||
ASSERT_TRUE(op2 != nullptr);
|
||||
|
||||
TF_EXPECT_OK(RunOpKernel(op1.get()));
|
||||
EXPECT_FALSE(RunOpKernel(op2.get()).ok());
|
||||
EXPECT_TRUE(op1->resource() != nullptr);
|
||||
EXPECT_TRUE(op2->resource() == nullptr);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
#include "tensorflow/core/framework/resource_op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
@ -432,14 +433,10 @@ class Barrier : public ResourceBase {
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(Barrier);
|
||||
};
|
||||
|
||||
class BarrierOp : public OpKernel {
|
||||
class BarrierOp : public ResourceOpKernel<Barrier> {
|
||||
public:
|
||||
explicit BarrierOp(OpKernelConstruction* context)
|
||||
: OpKernel(context), barrier_handle_set_(false) {
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_persistent(tensorflow::DT_STRING,
|
||||
tensorflow::TensorShape({2}),
|
||||
&barrier_handle_, nullptr));
|
||||
: ResourceOpKernel(context) {
|
||||
OP_REQUIRES_OK(
|
||||
context, context->GetAttr("component_types", &value_component_types_));
|
||||
OP_REQUIRES_OK(context,
|
||||
@ -458,34 +455,19 @@ class BarrierOp : public OpKernel {
|
||||
"limited capacity."));
|
||||
}
|
||||
|
||||
~BarrierOp() override {
|
||||
// If the barrier object was not shared, delete it.
|
||||
if (barrier_handle_set_ && cinfo_.resource_is_private_to_kernel()) {
|
||||
TF_CHECK_OK(cinfo_.resource_manager()->Delete<Barrier>(cinfo_.container(),
|
||||
cinfo_.name()));
|
||||
}
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
mutex_lock l(mu_);
|
||||
if (!barrier_handle_set_) {
|
||||
OP_REQUIRES_OK(ctx, SetBarrierHandle(ctx));
|
||||
}
|
||||
ctx->set_output_ref(0, &mu_, barrier_handle_.AccessTensor(ctx));
|
||||
}
|
||||
|
||||
private:
|
||||
Status SetBarrierHandle(OpKernelContext* ctx) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
TF_RETURN_IF_ERROR(cinfo_.Init(ctx->resource_manager(), def()));
|
||||
Barrier* barrier = nullptr;
|
||||
auto creator = [this](Barrier** ret) {
|
||||
*ret = new Barrier(value_component_types_, value_component_shapes_,
|
||||
Status CreateResource(Barrier** barrier) override
|
||||
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
*barrier = new Barrier(value_component_types_, value_component_shapes_,
|
||||
cinfo_.name());
|
||||
return (*ret)->Initialize();
|
||||
};
|
||||
TF_RETURN_IF_ERROR(cinfo_.resource_manager()->LookupOrCreate<Barrier>(
|
||||
cinfo_.container(), cinfo_.name(), &barrier, creator));
|
||||
core::ScopedUnref unref_me(barrier);
|
||||
if (*barrier == nullptr) {
|
||||
return errors::ResourceExhausted("Failed to allocate barrier");
|
||||
}
|
||||
return (*barrier)->Initialize();
|
||||
}
|
||||
|
||||
Status VerifyResource(Barrier* barrier) override
|
||||
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
if (barrier->component_types() != value_component_types_) {
|
||||
return errors::InvalidArgument(
|
||||
"Shared barrier '", cinfo_.name(), "' has component types ",
|
||||
@ -500,20 +482,11 @@ class BarrierOp : public OpKernel {
|
||||
" but requested component shapes were ",
|
||||
TensorShapeUtils::ShapeListString(value_component_shapes_));
|
||||
}
|
||||
auto h = barrier_handle_.AccessTensor(ctx)->flat<string>();
|
||||
h(0) = cinfo_.container();
|
||||
h(1) = cinfo_.name();
|
||||
barrier_handle_set_ = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
DataTypeVector value_component_types_;
|
||||
std::vector<TensorShape> value_component_shapes_;
|
||||
ContainerInfo cinfo_;
|
||||
|
||||
mutex mu_;
|
||||
PersistentTensor barrier_handle_ GUARDED_BY(mu_);
|
||||
bool barrier_handle_set_ GUARDED_BY(mu_);
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(BarrierOp);
|
||||
};
|
||||
|
@ -39,28 +39,20 @@ namespace tensorflow {
|
||||
// backed by FIFOQueue) that persists across different graph
|
||||
// executions, and sessions. Running this op produces a single-element
|
||||
// tensor of handles to Queues in the corresponding device.
|
||||
class FIFOQueueOp : public QueueOp {
|
||||
class FIFOQueueOp : public TypedQueueOp {
|
||||
public:
|
||||
explicit FIFOQueueOp(OpKernelConstruction* context) : QueueOp(context) {
|
||||
explicit FIFOQueueOp(OpKernelConstruction* context) : TypedQueueOp(context) {
|
||||
OP_REQUIRES_OK(context, context->GetAttr("shapes", &component_shapes_));
|
||||
}
|
||||
|
||||
protected:
|
||||
CreatorCallback GetCreator() const override {
|
||||
return [this](QueueInterface** ret) {
|
||||
private:
|
||||
Status CreateResource(QueueInterface** ret) override
|
||||
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
FIFOQueue* queue = new FIFOQueue(capacity_, component_types_,
|
||||
component_shapes_, cinfo_.name());
|
||||
Status s = queue->Initialize();
|
||||
if (s.ok()) {
|
||||
*ret = queue;
|
||||
} else {
|
||||
queue->Unref();
|
||||
}
|
||||
return s;
|
||||
};
|
||||
return CreateTypedQueue(queue, ret);
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<TensorShape> component_shapes_;
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(FIFOQueueOp);
|
||||
};
|
||||
|
@ -40,10 +40,10 @@ namespace tensorflow {
|
||||
// backed by PaddingFIFOQueue) that persists across different graph
|
||||
// executions, and sessions. Running this op produces a single-element
|
||||
// tensor of handles to Queues in the corresponding device.
|
||||
class PaddingFIFOQueueOp : public QueueOp {
|
||||
class PaddingFIFOQueueOp : public TypedQueueOp {
|
||||
public:
|
||||
explicit PaddingFIFOQueueOp(OpKernelConstruction* context)
|
||||
: QueueOp(context) {
|
||||
: TypedQueueOp(context) {
|
||||
OP_REQUIRES_OK(context, context->GetAttr("shapes", &component_shapes_));
|
||||
for (const auto& shape : component_shapes_) {
|
||||
OP_REQUIRES(context, shape.dims() >= 0,
|
||||
@ -52,22 +52,14 @@ class PaddingFIFOQueueOp : public QueueOp {
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
CreatorCallback GetCreator() const override {
|
||||
return [this](QueueInterface** ret) {
|
||||
private:
|
||||
Status CreateResource(QueueInterface** ret) override
|
||||
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
PaddingFIFOQueue* queue = new PaddingFIFOQueue(
|
||||
capacity_, component_types_, component_shapes_, cinfo_.name());
|
||||
Status s = queue->Initialize();
|
||||
if (s.ok()) {
|
||||
*ret = queue;
|
||||
} else {
|
||||
queue->Unref();
|
||||
}
|
||||
return s;
|
||||
};
|
||||
return CreateTypedQueue(queue, ret);
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<PartialTensorShape> component_shapes_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(PaddingFIFOQueueOp);
|
||||
|
@ -38,9 +38,10 @@ namespace tensorflow {
|
||||
// backed by PriorityQueue) that persists across different graph
|
||||
// executions, and sessions. Running this op produces a single-element
|
||||
// tensor of handles to Queues in the corresponding device.
|
||||
class PriorityQueueOp : public QueueOp {
|
||||
class PriorityQueueOp : public TypedQueueOp {
|
||||
public:
|
||||
explicit PriorityQueueOp(OpKernelConstruction* context) : QueueOp(context) {
|
||||
explicit PriorityQueueOp(OpKernelConstruction* context)
|
||||
: TypedQueueOp(context) {
|
||||
OP_REQUIRES_OK(context, context->GetAttr("shapes", &component_shapes_));
|
||||
component_types_.insert(component_types_.begin(), DT_INT64);
|
||||
if (!component_shapes_.empty()) {
|
||||
@ -48,19 +49,12 @@ class PriorityQueueOp : public QueueOp {
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
CreatorCallback GetCreator() const override {
|
||||
return [this](QueueInterface** ret) {
|
||||
PriorityQueue* queue = new PriorityQueue(
|
||||
capacity_, component_types_, component_shapes_, cinfo_.name());
|
||||
Status s = queue->Initialize();
|
||||
if (s.ok()) {
|
||||
*ret = queue;
|
||||
} else {
|
||||
queue->Unref();
|
||||
}
|
||||
return s;
|
||||
};
|
||||
private:
|
||||
Status CreateResource(QueueInterface** ret) override
|
||||
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
PriorityQueue* queue = new PriorityQueue(capacity_, component_types_,
|
||||
component_shapes_, cinfo_.name());
|
||||
return CreateTypedQueue(queue, ret);
|
||||
}
|
||||
|
||||
std::vector<TensorShape> component_shapes_;
|
||||
|
@ -19,28 +19,21 @@ limitations under the License.
|
||||
#include <deque>
|
||||
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
#include "tensorflow/core/framework/resource_op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/kernels/queue_base.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/thread_annotations.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Defines a QueueOp, an abstract class for Queue construction ops.
|
||||
class QueueOp : public OpKernel {
|
||||
class QueueOp : public ResourceOpKernel<QueueInterface> {
|
||||
public:
|
||||
QueueOp(OpKernelConstruction* context)
|
||||
: OpKernel(context), queue_handle_set_(false) {
|
||||
QueueOp(OpKernelConstruction* context) : ResourceOpKernel(context) {
|
||||
OP_REQUIRES_OK(context, context->GetAttr("capacity", &capacity_));
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_persistent(DT_STRING, TensorShape({2}),
|
||||
&queue_handle_, nullptr));
|
||||
if (capacity_ < 0) {
|
||||
capacity_ = QueueBase::kUnbounded;
|
||||
}
|
||||
@ -48,55 +41,30 @@ class QueueOp : public OpKernel {
|
||||
context->GetAttr("component_types", &component_types_));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
mutex_lock l(mu_);
|
||||
if (!queue_handle_set_) {
|
||||
OP_REQUIRES_OK(ctx, SetQueueHandle(ctx));
|
||||
}
|
||||
ctx->set_output_ref(0, &mu_, queue_handle_.AccessTensor(ctx));
|
||||
}
|
||||
|
||||
protected:
|
||||
~QueueOp() override {
|
||||
// If the queue object was not shared, delete it.
|
||||
if (queue_handle_set_ && cinfo_.resource_is_private_to_kernel()) {
|
||||
TF_CHECK_OK(cinfo_.resource_manager()->Delete<QueueInterface>(
|
||||
cinfo_.container(), cinfo_.name()));
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
typedef std::function<Status(QueueInterface**)> CreatorCallback;
|
||||
|
||||
// Subclasses must override this
|
||||
virtual CreatorCallback GetCreator() const = 0;
|
||||
|
||||
// Variables accessible by subclasses
|
||||
int32 capacity_;
|
||||
DataTypeVector component_types_;
|
||||
ContainerInfo cinfo_;
|
||||
|
||||
private:
|
||||
Status SetQueueHandle(OpKernelContext* ctx) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
TF_RETURN_IF_ERROR(cinfo_.Init(ctx->resource_manager(), def()));
|
||||
CreatorCallback creator = GetCreator();
|
||||
QueueInterface* queue;
|
||||
TF_RETURN_IF_ERROR(
|
||||
cinfo_.resource_manager()->LookupOrCreate<QueueInterface>(
|
||||
cinfo_.container(), cinfo_.name(), &queue, creator));
|
||||
core::ScopedUnref unref_me(queue);
|
||||
// Verify that the shared queue is compatible with the requested arguments.
|
||||
TF_RETURN_IF_ERROR(queue->MatchesNodeDef(def()));
|
||||
auto h = queue_handle_.AccessTensor(ctx)->flat<string>();
|
||||
h(0) = cinfo_.container();
|
||||
h(1) = cinfo_.name();
|
||||
queue_handle_set_ = true;
|
||||
return Status::OK();
|
||||
Status VerifyResource(QueueInterface* queue) override {
|
||||
return queue->MatchesNodeDef(def());
|
||||
}
|
||||
};
|
||||
|
||||
mutex mu_;
|
||||
PersistentTensor queue_handle_ GUARDED_BY(mu_);
|
||||
bool queue_handle_set_ GUARDED_BY(mu_);
|
||||
class TypedQueueOp : public QueueOp {
|
||||
public:
|
||||
using QueueOp::QueueOp;
|
||||
|
||||
protected:
|
||||
template <typename TypedQueue>
|
||||
Status CreateTypedQueue(TypedQueue* queue, QueueInterface** ret) {
|
||||
if (queue == nullptr) {
|
||||
return errors::ResourceExhausted("Failed to allocate queue.");
|
||||
}
|
||||
*ret = queue;
|
||||
return queue->Initialize();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -459,10 +459,10 @@ Status RandomShuffleQueue::MatchesNodeDef(const NodeDef& node_def) {
|
||||
// backed by RandomShuffleQueue) that persists across different graph
|
||||
// executions, and sessions. Running this op produces a single-element
|
||||
// tensor of handles to Queues in the corresponding device.
|
||||
class RandomShuffleQueueOp : public QueueOp {
|
||||
class RandomShuffleQueueOp : public TypedQueueOp {
|
||||
public:
|
||||
explicit RandomShuffleQueueOp(OpKernelConstruction* context)
|
||||
: QueueOp(context) {
|
||||
: TypedQueueOp(context) {
|
||||
OP_REQUIRES_OK(context,
|
||||
context->GetAttr("min_after_dequeue", &min_after_dequeue_));
|
||||
OP_REQUIRES(context, min_after_dequeue_ >= 0,
|
||||
@ -478,23 +478,15 @@ class RandomShuffleQueueOp : public QueueOp {
|
||||
OP_REQUIRES_OK(context, context->GetAttr("shapes", &component_shapes_));
|
||||
}
|
||||
|
||||
protected:
|
||||
CreatorCallback GetCreator() const override {
|
||||
return [this](QueueInterface** ret) {
|
||||
auto* q = new RandomShuffleQueue(capacity_, min_after_dequeue_, seed_,
|
||||
seed2_, component_types_,
|
||||
private:
|
||||
Status CreateResource(QueueInterface** ret) override
|
||||
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
RandomShuffleQueue* queue = new RandomShuffleQueue(
|
||||
capacity_, min_after_dequeue_, seed_, seed2_, component_types_,
|
||||
component_shapes_, cinfo_.name());
|
||||
Status s = q->Initialize();
|
||||
if (s.ok()) {
|
||||
*ret = q;
|
||||
} else {
|
||||
q->Unref();
|
||||
}
|
||||
return s;
|
||||
};
|
||||
return CreateTypedQueue(queue, ret);
|
||||
}
|
||||
|
||||
private:
|
||||
int32 min_after_dequeue_;
|
||||
int64 seed_;
|
||||
int64 seed2_;
|
||||
|
Loading…
Reference in New Issue
Block a user