[TF:XLA] Add partial implementation of tf.FIFOQueue for XLA devices (e.g., TPU).

The idea is to have a host-side queue of device tensors.

Operators dequeue_many, enqueue_many, and dequeue_up_to are not yet implemented because they require splitting/concatenating tensors, which will require calling into a compiled XLA computation.

Refactor queue operator implementations into libraries separate from the kernel registrations.

Add support for ResourceOpKernels that are placed on non-CPU devices. Add support for allocating host-memory tensors during OpKernel construction.

PiperOrigin-RevId: 202590292
Peter Hawkins 2018-06-28 20:24:15 -07:00 committed by TensorFlower Gardener
parent f04400f18f
commit 5083915489
13 changed files with 883 additions and 466 deletions
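
For orientation, a minimal usage sketch of what this change enables (a hedged illustration, not part of the commit). It is TF 1.x-style Python; the device string "/device:XLA_CPU:0" is purely illustrative. Only the single-element enqueue/dequeue ops plus size/close are registered for XLA devices, and the queue handle and size outputs live in host memory while the queued tensors stay on the device.

import tensorflow as tf

with tf.Graph().as_default():
  with tf.device("/device:XLA_CPU:0"):   # illustrative XLA device name
    q = tf.FIFOQueue(10, tf.float32)     # FIFOQueueV2; handle kept in host memory
    enqueue = q.enqueue((10.0,))         # QueueEnqueueV2: supported
    dequeue = q.dequeue()                # QueueDequeueV2: supported
    size = q.size()                      # QueueSizeV2: size emitted on the host

  with tf.Session() as sess:
    sess.run(enqueue)
    print(sess.run(size))     # prints 1
    print(sess.run(dequeue))  # prints 10.0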

View File

@@ -176,9 +176,11 @@ cc_library(
"//tensorflow/core/kernels:cast_op",
"//tensorflow/core/kernels:constant_op",
"//tensorflow/core/kernels:control_flow_ops",
"//tensorflow/core/kernels:fifo_queue",
"//tensorflow/core/kernels:identity_n_op",
"//tensorflow/core/kernels:identity_op",
"//tensorflow/core/kernels:no_op",
"//tensorflow/core/kernels:queue_op",
"//tensorflow/core/kernels:resource_variable_ops",
"//tensorflow/core/kernels:sendrecv_ops",
"//tensorflow/core/kernels:shape_ops",

View File

@@ -23,9 +23,11 @@ limitations under the License.
#include "tensorflow/core/kernels/cast_op.h"
#include "tensorflow/core/kernels/constant_op.h"
#include "tensorflow/core/kernels/control_flow_ops.h"
#include "tensorflow/core/kernels/fifo_queue.h"
#include "tensorflow/core/kernels/identity_n_op.h"
#include "tensorflow/core/kernels/identity_op.h"
#include "tensorflow/core/kernels/no_op.h"
#include "tensorflow/core/kernels/queue_op.h"
#include "tensorflow/core/kernels/resource_variable_ops.h"
#include "tensorflow/core/kernels/sendrecv_ops.h"
#include "tensorflow/core/kernels/shape_ops.h"
@@ -145,7 +147,32 @@ class XlaAssignVariableOp : public AsyncOpKernel {
.Device(DEVICE) \
.HostMemory("input") \
.HostMemory("output"), \
LoopCondOp);
LoopCondOp); \
\
REGISTER_KERNEL_BUILDER( \
Name("QueueEnqueueV2").Device(DEVICE).HostMemory("handle"), EnqueueOp); \
REGISTER_KERNEL_BUILDER( \
Name("QueueDequeueV2").Device(DEVICE).HostMemory("handle"), DequeueOp); \
REGISTER_KERNEL_BUILDER( \
Name("QueueCloseV2").Device(DEVICE).HostMemory("handle"), QueueCloseOp); \
REGISTER_KERNEL_BUILDER(Name("QueueSizeV2") \
.Device(DEVICE) \
.HostMemory("size") \
.HostMemory("handle"), \
QueueSizeOp); \
REGISTER_KERNEL_BUILDER( \
Name("QueueIsClosedV2").Device(DEVICE).HostMemory("handle"), \
QueueIsClosedOp); \
\
REGISTER_KERNEL_BUILDER( \
Name("FIFOQueueV2").Device(DEVICE).HostMemory("handle"), FIFOQueueOp);
// TODO(phawkins): currently we do not register the QueueEnqueueMany,
// QueueDequeueMany, or QueueDequeueUpTo kernels because they attempt to read
// and write the tensors they access in order to concatenate them into a batch.
// We would need either to call out to an XLA computation to perform the
// concatenation, or we would need to refactor those kernels so the splitting
// or merging is done in a separate operator that can be compiled.
} // namespace tensorflow
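
To make the TODO above concrete, a hedged sketch of what is and is not placeable on an XLA device after this change (device string illustrative; with soft placement enabled, the unregistered op would instead fall back to the CPU):

import tensorflow as tf

with tf.Graph().as_default():
  with tf.device("/device:XLA_CPU:0"):              # illustrative device name
    q = tf.FIFOQueue(10, tf.float32, shapes=())
    single = q.enqueue((1.0,))                      # QueueEnqueueV2: registered above
    batched = q.enqueue_many(([1.0, 2.0, 3.0],))    # QueueEnqueueManyV2: no XLA-device kernel

  config = tf.ConfigProto(allow_soft_placement=False)
  with tf.Session(config=config) as sess:
    sess.run(single)    # runs on the XLA device
    sess.run(batched)   # expected to fail placement (no registered kernel)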

View File

@@ -371,6 +371,20 @@ tf_xla_py_test(
],
)
tf_xla_py_test(
name = "fifo_queue_test",
size = "medium",
srcs = ["fifo_queue_test.py"],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:data_flow_ops",
"//tensorflow/python:extra_py_tests_deps",
"//tensorflow/python:framework",
"//tensorflow/python:platform_test",
],
)
tf_xla_py_test(
name = "fft_test",
size = "medium",

View File

@@ -0,0 +1,201 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.data_flow_ops.FIFOQueue."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes as dtypes_lib
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.platform import test
class FIFOQueueTest(xla_test.XLATestCase):
def testEnqueue(self):
with self.test_session(), self.test_scope():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
enqueue_op = q.enqueue((10.0,))
enqueue_op.run()
def testEnqueueWithShape(self):
with self.test_session(), self.test_scope():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, shapes=(3, 2))
enqueue_correct_op = q.enqueue(([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],))
enqueue_correct_op.run()
with self.assertRaises(ValueError):
q.enqueue(([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],))
self.assertEqual(1, q.size().eval())
def testMultipleDequeues(self):
with self.test_session(), self.test_scope():
q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()])
self.evaluate(q.enqueue([1]))
self.evaluate(q.enqueue([2]))
self.evaluate(q.enqueue([3]))
a, b, c = self.evaluate([q.dequeue(), q.dequeue(), q.dequeue()])
self.assertAllEqual(set([1, 2, 3]), set([a, b, c]))
def testQueuesDontShare(self):
with self.test_session(), self.test_scope():
q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()])
self.evaluate(q.enqueue(1))
q2 = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()])
self.evaluate(q2.enqueue(2))
self.assertAllEqual(self.evaluate(q2.dequeue()), 2)
self.assertAllEqual(self.evaluate(q.dequeue()), 1)
def testEnqueueDictWithoutNames(self):
with self.test_session(), self.test_scope():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
with self.assertRaisesRegexp(ValueError, "must have names"):
q.enqueue({"a": 12.0})
def testParallelEnqueue(self):
with self.test_session() as sess, self.test_scope():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
enqueue_ops = [q.enqueue((x,)) for x in elems]
dequeued_t = q.dequeue()
# Run one producer thread for each element in elems.
def enqueue(enqueue_op):
sess.run(enqueue_op)
threads = [
self.checkedThread(target=enqueue, args=(e,)) for e in enqueue_ops
]
for thread in threads:
thread.start()
for thread in threads:
thread.join()
# Dequeue every element using a single thread.
results = []
for _ in xrange(len(elems)):
results.append(dequeued_t.eval())
self.assertItemsEqual(elems, results)
def testParallelDequeue(self):
with self.test_session() as sess, self.test_scope():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
enqueue_ops = [q.enqueue((x,)) for x in elems]
dequeued_t = q.dequeue()
# Enqueue every element using a single thread.
for enqueue_op in enqueue_ops:
enqueue_op.run()
# Run one consumer thread for each element in elems.
results = []
def dequeue():
results.append(sess.run(dequeued_t))
threads = [self.checkedThread(target=dequeue) for _ in enqueue_ops]
for thread in threads:
thread.start()
for thread in threads:
thread.join()
self.assertItemsEqual(elems, results)
def testDequeue(self):
with self.test_session(), self.test_scope():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
elems = [10.0, 20.0, 30.0]
enqueue_ops = [q.enqueue((x,)) for x in elems]
dequeued_t = q.dequeue()
for enqueue_op in enqueue_ops:
enqueue_op.run()
for i in xrange(len(elems)):
vals = dequeued_t.eval()
self.assertEqual([elems[i]], vals)
def testEnqueueAndBlockingDequeue(self):
with self.test_session() as sess, self.test_scope():
q = data_flow_ops.FIFOQueue(3, dtypes_lib.float32)
elems = [10.0, 20.0, 30.0]
enqueue_ops = [q.enqueue((x,)) for x in elems]
dequeued_t = q.dequeue()
def enqueue():
# The enqueue_ops should run after the dequeue op has blocked.
# TODO(mrry): Figure out how to do this without sleeping.
time.sleep(0.1)
for enqueue_op in enqueue_ops:
sess.run(enqueue_op)
results = []
def dequeue():
for _ in xrange(len(elems)):
results.append(sess.run(dequeued_t))
enqueue_thread = self.checkedThread(target=enqueue)
dequeue_thread = self.checkedThread(target=dequeue)
enqueue_thread.start()
dequeue_thread.start()
enqueue_thread.join()
dequeue_thread.join()
for elem, result in zip(elems, results):
self.assertEqual([elem], result)
def testMultiEnqueueAndDequeue(self):
with self.test_session() as sess, self.test_scope():
q = data_flow_ops.FIFOQueue(10, (dtypes_lib.int32, dtypes_lib.float32))
elems = [(5, 10.0), (10, 20.0), (15, 30.0)]
enqueue_ops = [q.enqueue((x, y)) for x, y in elems]
dequeued_t = q.dequeue()
for enqueue_op in enqueue_ops:
enqueue_op.run()
for i in xrange(len(elems)):
x_val, y_val = sess.run(dequeued_t)
x, y = elems[i]
self.assertEqual([x], x_val)
self.assertEqual([y], y_val)
def testQueueSizeEmpty(self):
with self.test_session(), self.test_scope():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
self.assertEqual([0], q.size().eval())
def testQueueSizeAfterEnqueueAndDequeue(self):
with self.test_session(), self.test_scope():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
enqueue_op = q.enqueue((10.0,))
dequeued_t = q.dequeue()
size = q.size()
self.assertEqual([], size.get_shape())
enqueue_op.run()
self.assertEqual(1, size.eval())
dequeued_t.op.run()
self.assertEqual(0, size.eval())
if __name__ == "__main__":
test.main()

View File

@@ -92,6 +92,7 @@ tensorflow/core/kernels/reduction_ops_common.cc
tensorflow/core/kernels/reduction_ops_any.cc
tensorflow/core/kernels/reduction_ops_all.cc
tensorflow/core/kernels/roll_op.cc
tensorflow/core/kernels/queue_op.cc
tensorflow/core/kernels/queue_ops.cc
tensorflow/core/kernels/queue_base.cc
tensorflow/core/kernels/pooling_ops_common.cc

View File

@@ -43,9 +43,15 @@ template <typename T>
class ResourceOpKernel : public OpKernel {
public:
explicit ResourceOpKernel(OpKernelConstruction* context) : OpKernel(context) {
OP_REQUIRES_OK(context,
context->allocate_persistent(DT_STRING, TensorShape({2}),
&handle_, nullptr));
has_resource_type_ = (context->output_type(0) == DT_RESOURCE);
if (!has_resource_type_) {
// The resource variant of the op may be placed on non-CPU devices, but
// this allocation is always on the host. Fortunately we don't need it in
// the resource case.
OP_REQUIRES_OK(context,
context->allocate_persistent(DT_STRING, TensorShape({2}),
&handle_, nullptr));
}
}
// The resource is deleted from the resource manager only when it is private
@@ -89,12 +95,14 @@ class ResourceOpKernel : public OpKernel {
return;
}
auto h = handle_.AccessTensor(context)->template flat<string>();
h(0) = cinfo_.container();
h(1) = cinfo_.name();
if (!has_resource_type_) {
auto h = handle_.AccessTensor(context)->template flat<string>();
h(0) = cinfo_.container();
h(1) = cinfo_.name();
}
resource_ = resource;
}
if (context->expected_output_dtype(0) == DT_RESOURCE) {
if (has_resource_type_) {
OP_REQUIRES_OK(context, MakeResourceHandleToOutput(
context, 0, cinfo_.container(), cinfo_.name(),
MakeTypeIndex<T>()));
@@ -122,6 +130,9 @@ class ResourceOpKernel : public OpKernel {
virtual Status VerifyResource(T* resource) { return Status::OK(); }
PersistentTensor handle_ GUARDED_BY(mu_);
// Is the output of the operator of type DT_RESOURCE?
bool has_resource_type_;
};
} // namespace tensorflow
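
A brief, hedged Python-level illustration of why the DT_RESOURCE check above matters (assumption: data_flow_ops.FIFOQueue builds the V2, resource-handle op, as assumed here): the resource variant emits a scalar DT_RESOURCE handle rather than the 2-element DT_STRING ref handle, so the host-only string allocation can be skipped and the kernel can be placed on non-CPU devices.

from tensorflow.python.framework import dtypes
from tensorflow.python.ops import data_flow_ops

q = data_flow_ops.FIFOQueue(10, dtypes.float32)  # wraps FIFOQueueV2
# The V2 op's output is a resource handle, not a DT_STRING ref tensor.
assert q.queue_ref.dtype == dtypes.resource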

View File

@@ -368,6 +368,7 @@ cc_library(
cc_library(
name = "queue_op",
srcs = ["queue_op.cc"],
hdrs = ["queue_op.h"],
deps = [
":queue_base",
@@ -1885,9 +1886,10 @@ cc_library(
name = "fifo_queue",
srcs = ["fifo_queue.cc"],
hdrs = ["fifo_queue.h"],
visibility = ["//visibility:private"],
visibility = [":friends"],
deps = [
":queue_base",
":queue_op",
":typed_queue",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@@ -5076,6 +5078,7 @@ filegroup(
"padding_fifo_queue.cc",
"padding_fifo_queue_op.cc",
"queue_base.cc",
"queue_op.cc",
"queue_ops.cc",
"random_op.cc",
"reduction_ops_all.cc",

View File

@@ -366,4 +366,19 @@ Status FIFOQueue::MatchesNodeDef(const NodeDef& node_def) {
return Status::OK();
}
// Defines a FIFOQueueOp, which produces a Queue (specifically, one
// backed by FIFOQueue) that persists across different graph
// executions, and sessions. Running this op produces a single-element
// tensor of handles to Queues in the corresponding device.
FIFOQueueOp::FIFOQueueOp(OpKernelConstruction* context)
: TypedQueueOp(context) {
OP_REQUIRES_OK(context, context->GetAttr("shapes", &component_shapes_));
}
Status FIFOQueueOp::CreateResource(QueueInterface** ret) {
FIFOQueue* queue = new FIFOQueue(capacity_, component_types_,
component_shapes_, cinfo_.name());
return CreateTypedQueue(queue, ret);
}
} // namespace tensorflow
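
To make the "persists across different graph executions, and sessions" comment concrete, a hedged TF 1.x sketch (the shared_name value is illustrative): two FIFOQueue wrappers created with the same shared_name resolve to the same underlying queue resource.

import tensorflow as tf

with tf.Graph().as_default(), tf.Session() as sess:
  # Both wrappers run FIFOQueueV2 with the same shared_name, so their handles
  # refer to a single queue owned by the resource manager.
  q1 = tf.FIFOQueue(10, tf.float32, shared_name="shared_fifo")
  q2 = tf.FIFOQueue(10, tf.float32, shared_name="shared_fifo")
  sess.run(q1.enqueue((42.0,)))
  print(sess.run(q2.dequeue()))  # prints 42.0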

View File

@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_KERNELS_FIFO_QUEUE_H_
#define TENSORFLOW_KERNELS_FIFO_QUEUE_H_
#ifndef TENSORFLOW_CORE_KERNELS_FIFO_QUEUE_H_
#define TENSORFLOW_CORE_KERNELS_FIFO_QUEUE_H_
#include <deque>
#include <vector>
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/queue_op.h"
#include "tensorflow/core/kernels/typed_queue.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
@@ -69,6 +70,22 @@ class FIFOQueue : public TypedQueue<std::deque<PersistentTensor> > {
TF_DISALLOW_COPY_AND_ASSIGN(FIFOQueue);
};
// Defines a FIFOQueueOp, which produces a Queue (specifically, one
// backed by FIFOQueue) that persists across different graph
// executions, and sessions. Running this op produces a single-element
// tensor of handles to Queues in the corresponding device.
class FIFOQueueOp : public TypedQueueOp {
public:
explicit FIFOQueueOp(OpKernelConstruction* context);
private:
Status CreateResource(QueueInterface** ret) override
EXCLUSIVE_LOCKS_REQUIRED(mu_);
std::vector<TensorShape> component_shapes_;
TF_DISALLOW_COPY_AND_ASSIGN(FIFOQueueOp);
};
} // namespace tensorflow
#endif // TENSORFLOW_KERNELS_FIFO_QUEUE_H_
#endif // TENSORFLOW_CORE_KERNELS_FIFO_QUEUE_H_

View File

@@ -13,50 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// See docs in ../ops/data_flow_ops.cc.
#include <deque>
#include <vector>
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/fifo_queue.h"
#include "tensorflow/core/kernels/queue_base.h"
#include "tensorflow/core/kernels/queue_op.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
// Defines a FIFOQueueOp, which produces a Queue (specifically, one
// backed by FIFOQueue) that persists across different graph
// executions, and sessions. Running this op produces a single-element
// tensor of handles to Queues in the corresponding device.
class FIFOQueueOp : public TypedQueueOp {
public:
explicit FIFOQueueOp(OpKernelConstruction* context) : TypedQueueOp(context) {
OP_REQUIRES_OK(context, context->GetAttr("shapes", &component_shapes_));
}
private:
Status CreateResource(QueueInterface** ret) override
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
FIFOQueue* queue = new FIFOQueue(capacity_, component_types_,
component_shapes_, cinfo_.name());
return CreateTypedQueue(queue, ret);
}
std::vector<TensorShape> component_shapes_;
TF_DISALLOW_COPY_AND_ASSIGN(FIFOQueueOp);
};
REGISTER_KERNEL_BUILDER(Name("FIFOQueue").Device(DEVICE_CPU), FIFOQueueOp);
REGISTER_KERNEL_BUILDER(Name("FIFOQueueV2").Device(DEVICE_CPU), FIFOQueueOp);

View File

@@ -0,0 +1,367 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/queue_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/queue_interface.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
QueueOp::QueueOp(OpKernelConstruction* context) : ResourceOpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("capacity", &capacity_));
if (capacity_ < 0) {
capacity_ = QueueBase::kUnbounded;
}
OP_REQUIRES_OK(context,
context->GetAttr("component_types", &component_types_));
}
void QueueOp::Compute(OpKernelContext* context) {
ResourceOpKernel<QueueInterface>::Compute(context);
mutex_lock l(mu_);
if (resource_ && context->track_allocations()) {
context->record_persistent_memory_allocation(resource_->MemoryUsed());
}
}
Status QueueOp::VerifyResource(QueueInterface* queue) {
return queue->MatchesNodeDef(def());
}
QueueOpKernel::QueueOpKernel(OpKernelConstruction* context)
: AsyncOpKernel(context) {}
void QueueOpKernel::ComputeAsync(OpKernelContext* ctx, DoneCallback callback) {
QueueInterface* queue;
if (ctx->input_dtype(0) == DT_RESOURCE) {
OP_REQUIRES_OK_ASYNC(
ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &queue), callback);
} else {
OP_REQUIRES_OK_ASYNC(ctx, GetResourceFromContext(ctx, "handle", &queue),
callback);
}
ComputeAsync(ctx, queue, [callback, queue]() {
queue->Unref();
callback();
});
}
QueueAccessOpKernel::QueueAccessOpKernel(OpKernelConstruction* context)
: QueueOpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("timeout_ms", &timeout_));
// TODO(keveman): Enable timeout.
OP_REQUIRES(context, timeout_ == -1,
errors::InvalidArgument("Timeout not supported yet."));
}
// Defines an EnqueueOp, the execution of which enqueues a tuple of
// tensors in the given Queue.
//
// The op has 1 + k inputs, where k is the number of components in the
// tuples stored in the given Queue:
// - Input 0: queue handle.
// - Input 1: 0th element of the tuple.
// - ...
// - Input (1+k): kth element of the tuple.
EnqueueOp::EnqueueOp(OpKernelConstruction* context)
: QueueAccessOpKernel(context) {}
void EnqueueOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
DoneCallback callback) {
DataTypeVector expected_inputs;
if (ctx->input_dtype(0) == DT_RESOURCE) {
expected_inputs.push_back(DT_RESOURCE);
} else {
expected_inputs.push_back(DT_STRING_REF);
}
for (DataType dt : queue->component_dtypes()) {
expected_inputs.push_back(dt);
}
OP_REQUIRES_OK_ASYNC(ctx, ctx->MatchSignature(expected_inputs, {}), callback);
QueueInterface::Tuple tuple;
OpInputList components;
OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("components", &components),
callback);
for (const Tensor& Tcomponent : components) {
tuple.push_back(Tcomponent);
}
OP_REQUIRES_OK_ASYNC(ctx, queue->ValidateTuple(tuple), callback);
queue->TryEnqueue(tuple, ctx, callback);
}
// Defines an EnqueueManyOp, the execution of which slices each
// component of a tuple of tensors along the 0th dimension, and
// enqueues tuples of slices in the given Queue.
//
// The op has 1 + k inputs, where k is the number of components in the
// tuples stored in the given Queue:
// - Input 0: queue handle.
// - Input 1: 0th element of the tuple.
// - ...
// - Input (1+k): kth element of the tuple.
//
// N.B. All tuple components must have the same size in the 0th
// dimension.
EnqueueManyOp::EnqueueManyOp(OpKernelConstruction* context)
: QueueAccessOpKernel(context) {}
void EnqueueManyOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
DoneCallback callback) {
DataTypeVector expected_inputs;
if (ctx->input_dtype(0) == DT_RESOURCE) {
expected_inputs.push_back(DT_RESOURCE);
} else {
expected_inputs.push_back(DT_STRING_REF);
}
for (DataType dt : queue->component_dtypes()) {
expected_inputs.push_back(dt);
}
OP_REQUIRES_OK_ASYNC(ctx, ctx->MatchSignature(expected_inputs, {}), callback);
QueueInterface::Tuple tuple;
OpInputList components;
OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("components", &components),
callback);
for (const Tensor& Tcomponent : components) {
tuple.push_back(Tcomponent);
}
OP_REQUIRES_OK_ASYNC(ctx, queue->ValidateManyTuple(tuple), callback);
queue->TryEnqueueMany(tuple, ctx, callback);
}
EnqueueManyOp::~EnqueueManyOp() = default;
// Defines a DequeueOp, the execution of which dequeues a tuple of
// tensors from the given Queue.
//
// The op has one input, which is the handle of the appropriate
// Queue. The op has k outputs, where k is the number of components in
// the tuples stored in the given Queue, and output i is the ith
// component of the dequeued tuple.
DequeueOp::DequeueOp(OpKernelConstruction* context)
: QueueAccessOpKernel(context) {}
void DequeueOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
DoneCallback callback) {
if (ctx->input_dtype(0) == DT_RESOURCE) {
OP_REQUIRES_OK_ASYNC(
ctx, ctx->MatchSignature({DT_RESOURCE}, queue->component_dtypes()),
callback);
} else {
OP_REQUIRES_OK_ASYNC(
ctx, ctx->MatchSignature({DT_STRING_REF}, queue->component_dtypes()),
callback);
}
queue->TryDequeue(ctx, [ctx, callback](const QueueInterface::Tuple& tuple) {
if (!ctx->status().ok()) {
callback();
return;
}
OpOutputList output_components;
OP_REQUIRES_OK_ASYNC(
ctx, ctx->output_list("components", &output_components), callback);
for (int i = 0; i < ctx->num_outputs(); ++i) {
output_components.set(i, tuple[i]);
}
callback();
});
}
DequeueOp::~DequeueOp() = default;
// Defines a DequeueManyOp, the execution of which concatenates the
// requested number of elements from the given Queue along the 0th
// dimension, and emits the result as a single tuple of tensors.
//
// The op has two inputs:
// - Input 0: the handle to a queue.
// - Input 1: the number of elements to dequeue.
//
// The op has k outputs, where k is the number of components in the
// tuples stored in the given Queue, and output i is the ith component
// of the dequeued tuple.
DequeueManyOp::DequeueManyOp(OpKernelConstruction* context)
: QueueAccessOpKernel(context) {}
void DequeueManyOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
DoneCallback callback) {
const Tensor& Tnum_elements = ctx->input(1);
int32 num_elements = Tnum_elements.flat<int32>()(0);
OP_REQUIRES_ASYNC(ctx, num_elements >= 0,
errors::InvalidArgument("DequeueManyOp requested ",
num_elements, " < 0 elements"),
callback);
if (ctx->input_dtype(0) == DT_RESOURCE) {
OP_REQUIRES_OK_ASYNC(
ctx,
ctx->MatchSignature({DT_RESOURCE, DT_INT32}, queue->component_dtypes()),
callback);
} else {
OP_REQUIRES_OK_ASYNC(ctx,
ctx->MatchSignature({DT_STRING_REF, DT_INT32},
queue->component_dtypes()),
callback);
}
queue->TryDequeueMany(
num_elements, ctx, false /* allow_small_batch */,
[ctx, callback](const QueueInterface::Tuple& tuple) {
if (!ctx->status().ok()) {
callback();
return;
}
OpOutputList output_components;
OP_REQUIRES_OK_ASYNC(
ctx, ctx->output_list("components", &output_components), callback);
for (int i = 0; i < ctx->num_outputs(); ++i) {
output_components.set(i, tuple[i]);
}
callback();
});
}
DequeueManyOp::~DequeueManyOp() = default;
// Defines a DequeueUpToOp, the execution of which concatenates the
// requested number of elements from the given Queue along the 0th
// dimension, and emits the result as a single tuple of tensors.
//
// The difference between this op and DequeueMany is the handling when
// the Queue is closed. While the DequeueMany op will return an
// error when there are fewer than num_elements elements left in the
// closed queue, this op will return between 1 and
// min(num_elements, elements_remaining_in_queue), and will not block.
// If there are no elements left, then the standard DequeueMany error
// is returned.
//
// This op only works if the underlying Queue implementation accepts
// the allow_small_batch = true parameter to TryDequeueMany.
// If it does not, an errors::Unimplemented exception is returned.
//
// The op has two inputs:
// - Input 0: the handle to a queue.
// - Input 1: the number of elements to dequeue.
//
// The op has k outputs, where k is the number of components in the
// tuples stored in the given Queue, and output i is the ith component
// of the dequeued tuple.
//
// The op has one attribute: allow_small_batch. If the Queue supports
// it, setting this to true causes the queue to return smaller
// (possibly zero length) batches when it is closed, up to however
// many elements are available when the op executes. In this case,
// the Queue does not block when closed.
DequeueUpToOp::DequeueUpToOp(OpKernelConstruction* context)
: QueueAccessOpKernel(context) {}
void DequeueUpToOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
DoneCallback callback) {
const Tensor& Tnum_elements = ctx->input(1);
int32 num_elements = Tnum_elements.flat<int32>()(0);
OP_REQUIRES_ASYNC(ctx, num_elements >= 0,
errors::InvalidArgument("DequeueUpToOp requested ",
num_elements, " < 0 elements"),
callback);
if (ctx->input_dtype(0) == DT_RESOURCE) {
OP_REQUIRES_OK_ASYNC(
ctx,
ctx->MatchSignature({DT_RESOURCE, DT_INT32}, queue->component_dtypes()),
callback);
} else {
OP_REQUIRES_OK_ASYNC(ctx,
ctx->MatchSignature({DT_STRING_REF, DT_INT32},
queue->component_dtypes()),
callback);
}
queue->TryDequeueMany(
num_elements, ctx, true /* allow_small_batch */,
[ctx, callback](const QueueInterface::Tuple& tuple) {
if (!ctx->status().ok()) {
callback();
return;
}
OpOutputList output_components;
OP_REQUIRES_OK_ASYNC(
ctx, ctx->output_list("components", &output_components), callback);
for (int i = 0; i < ctx->num_outputs(); ++i) {
output_components.set(i, tuple[i]);
}
callback();
});
}
DequeueUpToOp::~DequeueUpToOp() = default;
// Defines a QueueCloseOp, which closes the given Queue. Closing a
// Queue signals that no more elements will be enqueued in it.
//
// The op has one input, which is the handle of the appropriate Queue.
QueueCloseOp::QueueCloseOp(OpKernelConstruction* context)
: QueueOpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("cancel_pending_enqueues",
&cancel_pending_enqueues_));
}
void QueueCloseOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
DoneCallback callback) {
queue->Close(ctx, cancel_pending_enqueues_, callback);
}
// Defines a QueueSizeOp, which computes the number of elements in the
// given Queue, and emits it as an output tensor.
//
// The op has one input, which is the handle of the appropriate Queue;
// and one output, which is a single-element tensor containing the current
// size of that Queue.
QueueSizeOp::QueueSizeOp(OpKernelConstruction* context)
: QueueOpKernel(context) {}
void QueueSizeOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
DoneCallback callback) {
Tensor* Tqueue_size = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &Tqueue_size));
Tqueue_size->flat<int32>().setConstant(queue->size());
callback();
}
QueueIsClosedOp::QueueIsClosedOp(OpKernelConstruction* context)
: QueueOpKernel(context) {}
void QueueIsClosedOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
DoneCallback callback) {
Tensor* Tqueue_is_closed = nullptr;
OP_REQUIRES_OK(ctx,
ctx->allocate_output(0, TensorShape({}), &Tqueue_is_closed));
Tqueue_is_closed->flat<bool>().setConstant(queue->is_closed());
callback();
}
} // namespace tensorflow
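
To illustrate the DequeueUpTo/DequeueMany distinction described in the comments above, a hedged TF 1.x sketch:

import tensorflow as tf

with tf.Graph().as_default(), tf.Session() as sess:
  q = tf.FIFOQueue(10, tf.float32, shapes=())
  sess.run(q.enqueue_many(([1.0, 2.0, 3.0],)))
  sess.run(q.close())

  # DequeueUpTo returns whatever remains once the queue is closed.
  print(sess.run(q.dequeue_up_to(5)))   # prints [1. 2. 3.]

  # DequeueMany insists on a full batch and raises OutOfRangeError instead.
  try:
    sess.run(q.dequeue_many(5))
  except tf.errors.OutOfRangeError:
    print("dequeue_many raised OutOfRangeError as expected")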

View File

@@ -13,12 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_KERNELS_QUEUE_OP_H_
#define TENSORFLOW_KERNELS_QUEUE_OP_H_
#ifndef TENSORFLOW_CORE_KERNELS_QUEUE_OP_H_
#define TENSORFLOW_CORE_KERNELS_QUEUE_OP_H_
#include <deque>
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/queue_interface.h"
#include "tensorflow/core/framework/resource_op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
@@ -32,22 +33,9 @@ namespace tensorflow {
// Defines a QueueOp, an abstract class for Queue construction ops.
class QueueOp : public ResourceOpKernel<QueueInterface> {
public:
QueueOp(OpKernelConstruction* context) : ResourceOpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("capacity", &capacity_));
if (capacity_ < 0) {
capacity_ = QueueBase::kUnbounded;
}
OP_REQUIRES_OK(context,
context->GetAttr("component_types", &component_types_));
}
QueueOp(OpKernelConstruction* context);
void Compute(OpKernelContext* context) override {
ResourceOpKernel<QueueInterface>::Compute(context);
mutex_lock l(mu_);
if (resource_ && context->track_allocations()) {
context->record_persistent_memory_allocation(resource_->MemoryUsed());
}
}
void Compute(OpKernelContext* context) override;
protected:
// Variables accessible by subclasses
@@ -55,9 +43,7 @@ class QueueOp : public ResourceOpKernel<QueueInterface> {
DataTypeVector component_types_;
private:
Status VerifyResource(QueueInterface* queue) override {
return queue->MatchesNodeDef(def());
}
Status VerifyResource(QueueInterface* queue) override;
};
class TypedQueueOp : public QueueOp {
@@ -75,6 +61,211 @@ class TypedQueueOp : public QueueOp {
}
};
// Queue manipulator kernels
class QueueOpKernel : public AsyncOpKernel {
public:
explicit QueueOpKernel(OpKernelConstruction* context);
void ComputeAsync(OpKernelContext* ctx, DoneCallback callback) final;
protected:
virtual void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
DoneCallback callback) = 0;
};
class QueueAccessOpKernel : public QueueOpKernel {
public:
explicit QueueAccessOpKernel(OpKernelConstruction* context);
protected:
int64 timeout_;
};
// Defines an EnqueueOp, the execution of which enqueues a tuple of
// tensors in the given Queue.
//
// The op has 1 + k inputs, where k is the number of components in the
// tuples stored in the given Queue:
// - Input 0: queue handle.
// - Input 1: 0th element of the tuple.
// - ...
// - Input (1+k): kth element of the tuple.
class EnqueueOp : public QueueAccessOpKernel {
public:
explicit EnqueueOp(OpKernelConstruction* context);
protected:
void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
DoneCallback callback) override;
private:
TF_DISALLOW_COPY_AND_ASSIGN(EnqueueOp);
};
// Defines an EnqueueManyOp, the execution of which slices each
// component of a tuple of tensors along the 0th dimension, and
// enqueues tuples of slices in the given Queue.
//
// The op has 1 + k inputs, where k is the number of components in the
// tuples stored in the given Queue:
// - Input 0: queue handle.
// - Input 1: 0th element of the tuple.
// - ...
// - Input (1+k): kth element of the tuple.
//
// N.B. All tuple components must have the same size in the 0th
// dimension.
class EnqueueManyOp : public QueueAccessOpKernel {
public:
explicit EnqueueManyOp(OpKernelConstruction* context);
protected:
void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
DoneCallback callback) override;
~EnqueueManyOp() override;
private:
TF_DISALLOW_COPY_AND_ASSIGN(EnqueueManyOp);
};
// Defines a DequeueOp, the execution of which dequeues a tuple of
// tensors from the given Queue.
//
// The op has one input, which is the handle of the appropriate
// Queue. The op has k outputs, where k is the number of components in
// the tuples stored in the given Queue, and output i is the ith
// component of the dequeued tuple.
class DequeueOp : public QueueAccessOpKernel {
public:
explicit DequeueOp(OpKernelConstruction* context);
protected:
void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
DoneCallback callback) override;
~DequeueOp() override;
private:
TF_DISALLOW_COPY_AND_ASSIGN(DequeueOp);
};
// Defines a DequeueManyOp, the execution of which concatenates the
// requested number of elements from the given Queue along the 0th
// dimension, and emits the result as a single tuple of tensors.
//
// The op has two inputs:
// - Input 0: the handle to a queue.
// - Input 1: the number of elements to dequeue.
//
// The op has k outputs, where k is the number of components in the
// tuples stored in the given Queue, and output i is the ith component
// of the dequeued tuple.
class DequeueManyOp : public QueueAccessOpKernel {
public:
explicit DequeueManyOp(OpKernelConstruction* context);
protected:
void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
DoneCallback callback) override;
~DequeueManyOp() override;
private:
TF_DISALLOW_COPY_AND_ASSIGN(DequeueManyOp);
};
// Defines a DequeueUpToOp, the execution of which concatenates the
// requested number of elements from the given Queue along the 0th
// dimension, and emits the result as a single tuple of tensors.
//
// The difference between this op and DequeueMany is the handling when
// the Queue is closed. While the DequeueMany op will return an
// error when there are fewer than num_elements elements left in the
// closed queue, this op will return between 1 and
// min(num_elements, elements_remaining_in_queue), and will not block.
// If there are no elements left, then the standard DequeueMany error
// is returned.
//
// This op only works if the underlying Queue implementation accepts
// the allow_small_batch = true parameter to TryDequeueMany.
// If it does not, an errors::Unimplemented exception is returned.
//
// The op has two inputs:
// - Input 0: the handle to a queue.
// - Input 1: the number of elements to dequeue.
//
// The op has k outputs, where k is the number of components in the
// tuples stored in the given Queue, and output i is the ith component
// of the dequeued tuple.
//
// The op has one attribute: allow_small_batch. If the Queue supports
// it, setting this to true causes the queue to return smaller
// (possibly zero length) batches when it is closed, up to however
// many elements are available when the op executes. In this case,
// the Queue does not block when closed.
class DequeueUpToOp : public QueueAccessOpKernel {
public:
explicit DequeueUpToOp(OpKernelConstruction* context);
protected:
void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
DoneCallback callback) override;
~DequeueUpToOp() override;
private:
TF_DISALLOW_COPY_AND_ASSIGN(DequeueUpToOp);
};
// Defines a QueueCloseOp, which closes the given Queue. Closing a
// Queue signals that no more elements will be enqueued in it.
//
// The op has one input, which is the handle of the appropriate Queue.
class QueueCloseOp : public QueueOpKernel {
public:
explicit QueueCloseOp(OpKernelConstruction* context);
protected:
void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
DoneCallback callback) override;
private:
bool cancel_pending_enqueues_;
TF_DISALLOW_COPY_AND_ASSIGN(QueueCloseOp);
};
// Defines a QueueSizeOp, which computes the number of elements in the
// given Queue, and emits it as an output tensor.
//
// The op has one input, which is the handle of the appropriate Queue;
// and one output, which is a single-element tensor containing the current
// size of that Queue.
class QueueSizeOp : public QueueOpKernel {
public:
explicit QueueSizeOp(OpKernelConstruction* context);
protected:
void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
DoneCallback callback) override;
private:
TF_DISALLOW_COPY_AND_ASSIGN(QueueSizeOp);
};
class QueueIsClosedOp : public QueueOpKernel {
public:
explicit QueueIsClosedOp(OpKernelConstruction* context);
protected:
void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
DoneCallback callback) override;
private:
TF_DISALLOW_COPY_AND_ASSIGN(QueueIsClosedOp);
};
} // namespace tensorflow
#endif // TENSORFLOW_KERNELS_QUEUE_OP_H_
#endif // TENSORFLOW_CORE_KERNELS_QUEUE_OP_H_

View File

@@ -13,437 +13,44 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// See docs in ../ops/data_flow_ops.cc.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/queue_interface.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/queue_op.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
class QueueOpKernel : public AsyncOpKernel {
public:
explicit QueueOpKernel(OpKernelConstruction* context)
: AsyncOpKernel(context) {}
void ComputeAsync(OpKernelContext* ctx, DoneCallback callback) final {
QueueInterface* queue;
if (ctx->input_dtype(0) == DT_RESOURCE) {
OP_REQUIRES_OK_ASYNC(
ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &queue), callback);
} else {
OP_REQUIRES_OK_ASYNC(ctx, GetResourceFromContext(ctx, "handle", &queue),
callback);
}
ComputeAsync(ctx, queue, [callback, queue]() {
queue->Unref();
callback();
});
}
protected:
virtual void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
DoneCallback callback) = 0;
};
class QueueAccessOpKernel : public QueueOpKernel {
public:
explicit QueueAccessOpKernel(OpKernelConstruction* context)
: QueueOpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("timeout_ms", &timeout_));
// TODO(keveman): Enable timeout.
OP_REQUIRES(context, timeout_ == -1,
errors::InvalidArgument("Timeout not supported yet."));
}
protected:
int64 timeout_;
};
// Defines an EnqueueOp, the execution of which enqueues a tuple of
// tensors in the given Queue.
//
// The op has 1 + k inputs, where k is the number of components in the
// tuples stored in the given Queue:
// - Input 0: queue handle.
// - Input 1: 0th element of the tuple.
// - ...
// - Input (1+k): kth element of the tuple.
class EnqueueOp : public QueueAccessOpKernel {
public:
explicit EnqueueOp(OpKernelConstruction* context)
: QueueAccessOpKernel(context) {}
protected:
void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
DoneCallback callback) override {
DataTypeVector expected_inputs;
if (ctx->input_dtype(0) == DT_RESOURCE) {
expected_inputs.push_back(DT_RESOURCE);
} else {
expected_inputs.push_back(DT_STRING_REF);
}
for (DataType dt : queue->component_dtypes()) {
expected_inputs.push_back(dt);
}
OP_REQUIRES_OK_ASYNC(ctx, ctx->MatchSignature(expected_inputs, {}),
callback);
QueueInterface::Tuple tuple;
OpInputList components;
OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("components", &components),
callback);
for (const Tensor& Tcomponent : components) {
tuple.push_back(Tcomponent);
}
OP_REQUIRES_OK_ASYNC(ctx, queue->ValidateTuple(tuple), callback);
queue->TryEnqueue(tuple, ctx, callback);
}
private:
TF_DISALLOW_COPY_AND_ASSIGN(EnqueueOp);
};
REGISTER_KERNEL_BUILDER(Name("QueueEnqueue").Device(DEVICE_CPU), EnqueueOp);
REGISTER_KERNEL_BUILDER(Name("QueueEnqueueV2").Device(DEVICE_CPU), EnqueueOp);
// Defines an EnqueueManyOp, the execution of which slices each
// component of a tuple of tensors along the 0th dimension, and
// enqueues tuples of slices in the given Queue.
//
// The op has 1 + k inputs, where k is the number of components in the
// tuples stored in the given Queue:
// - Input 0: queue handle.
// - Input 1: 0th element of the tuple.
// - ...
// - Input (1+k): kth element of the tuple.
//
// N.B. All tuple components must have the same size in the 0th
// dimension.
class EnqueueManyOp : public QueueAccessOpKernel {
public:
explicit EnqueueManyOp(OpKernelConstruction* context)
: QueueAccessOpKernel(context) {}
protected:
void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
DoneCallback callback) override {
DataTypeVector expected_inputs;
if (ctx->input_dtype(0) == DT_RESOURCE) {
expected_inputs.push_back(DT_RESOURCE);
} else {
expected_inputs.push_back(DT_STRING_REF);
}
for (DataType dt : queue->component_dtypes()) {
expected_inputs.push_back(dt);
}
OP_REQUIRES_OK_ASYNC(ctx, ctx->MatchSignature(expected_inputs, {}),
callback);
QueueInterface::Tuple tuple;
OpInputList components;
OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("components", &components),
callback);
for (const Tensor& Tcomponent : components) {
tuple.push_back(Tcomponent);
}
OP_REQUIRES_OK_ASYNC(ctx, queue->ValidateManyTuple(tuple), callback);
queue->TryEnqueueMany(tuple, ctx, callback);
}
~EnqueueManyOp() override {}
private:
TF_DISALLOW_COPY_AND_ASSIGN(EnqueueManyOp);
};
REGISTER_KERNEL_BUILDER(Name("QueueEnqueueMany").Device(DEVICE_CPU),
EnqueueManyOp);
REGISTER_KERNEL_BUILDER(Name("QueueEnqueueManyV2").Device(DEVICE_CPU),
EnqueueManyOp);
// Defines a DequeueOp, the execution of which dequeues a tuple of
// tensors from the given Queue.
//
// The op has one input, which is the handle of the appropriate
// Queue. The op has k outputs, where k is the number of components in
// the tuples stored in the given Queue, and output i is the ith
// component of the dequeued tuple.
class DequeueOp : public QueueAccessOpKernel {
public:
explicit DequeueOp(OpKernelConstruction* context)
: QueueAccessOpKernel(context) {}
protected:
void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
DoneCallback callback) override {
if (ctx->input_dtype(0) == DT_RESOURCE) {
OP_REQUIRES_OK_ASYNC(
ctx, ctx->MatchSignature({DT_RESOURCE}, queue->component_dtypes()),
callback);
} else {
OP_REQUIRES_OK_ASYNC(
ctx, ctx->MatchSignature({DT_STRING_REF}, queue->component_dtypes()),
callback);
}
queue->TryDequeue(ctx, [ctx, callback](const QueueInterface::Tuple& tuple) {
if (!ctx->status().ok()) {
callback();
return;
}
OpOutputList output_components;
OP_REQUIRES_OK_ASYNC(
ctx, ctx->output_list("components", &output_components), callback);
for (int i = 0; i < ctx->num_outputs(); ++i) {
output_components.set(i, tuple[i]);
}
callback();
});
}
~DequeueOp() override {}
private:
TF_DISALLOW_COPY_AND_ASSIGN(DequeueOp);
};
REGISTER_KERNEL_BUILDER(Name("QueueDequeue").Device(DEVICE_CPU), DequeueOp);
REGISTER_KERNEL_BUILDER(Name("QueueDequeueV2").Device(DEVICE_CPU), DequeueOp);
// Defines a DequeueManyOp, the execution of which concatenates the
// requested number of elements from the given Queue along the 0th
// dimension, and emits the result as a single tuple of tensors.
//
// The op has two inputs:
// - Input 0: the handle to a queue.
// - Input 1: the number of elements to dequeue.
//
// The op has k outputs, where k is the number of components in the
// tuples stored in the given Queue, and output i is the ith component
// of the dequeued tuple.
class DequeueManyOp : public QueueAccessOpKernel {
public:
explicit DequeueManyOp(OpKernelConstruction* context)
: QueueAccessOpKernel(context) {}
protected:
void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
DoneCallback callback) override {
const Tensor& Tnum_elements = ctx->input(1);
int32 num_elements = Tnum_elements.flat<int32>()(0);
OP_REQUIRES_ASYNC(ctx, num_elements >= 0,
errors::InvalidArgument("DequeueManyOp requested ",
num_elements, " < 0 elements"),
callback);
if (ctx->input_dtype(0) == DT_RESOURCE) {
OP_REQUIRES_OK_ASYNC(ctx,
ctx->MatchSignature({DT_RESOURCE, DT_INT32},
queue->component_dtypes()),
callback);
} else {
OP_REQUIRES_OK_ASYNC(ctx,
ctx->MatchSignature({DT_STRING_REF, DT_INT32},
queue->component_dtypes()),
callback);
}
queue->TryDequeueMany(
num_elements, ctx, false /* allow_small_batch */,
[ctx, callback](const QueueInterface::Tuple& tuple) {
if (!ctx->status().ok()) {
callback();
return;
}
OpOutputList output_components;
OP_REQUIRES_OK_ASYNC(
ctx, ctx->output_list("components", &output_components),
callback);
for (int i = 0; i < ctx->num_outputs(); ++i) {
output_components.set(i, tuple[i]);
}
callback();
});
}
~DequeueManyOp() override {}
private:
TF_DISALLOW_COPY_AND_ASSIGN(DequeueManyOp);
};
REGISTER_KERNEL_BUILDER(Name("QueueDequeueMany").Device(DEVICE_CPU),
DequeueManyOp);
REGISTER_KERNEL_BUILDER(Name("QueueDequeueManyV2").Device(DEVICE_CPU),
DequeueManyOp);
// Defines a DequeueUpToOp, the execution of which concatenates the
// requested number of elements from the given Queue along the 0th
// dimension, and emits the result as a single tuple of tensors.
//
// The difference between this op and DequeueMany is the handling when
// the Queue is closed. While the DequeueMany op will return an
// error when there are fewer than num_elements elements left in the
// closed queue, this op will return between 1 and
// min(num_elements, elements_remaining_in_queue), and will not block.
// If there are no elements left, then the standard DequeueMany error
// is returned.
//
// This op only works if the underlying Queue implementation accepts
// the allow_small_batch = true parameter to TryDequeueMany.
// If it does not, an errors::Unimplemented exception is returned.
//
// The op has two inputs:
// - Input 0: the handle to a queue.
// - Input 1: the number of elements to dequeue.
//
// The op has k outputs, where k is the number of components in the
// tuples stored in the given Queue, and output i is the ith component
// of the dequeued tuple.
//
// The op has one attribute: allow_small_batch. If the Queue supports
// it, setting this to true causes the queue to return smaller
// (possibly zero length) batches when it is closed, up to however
// many elements are available when the op executes. In this case,
// the Queue does not block when closed.
class DequeueUpToOp : public QueueAccessOpKernel {
public:
explicit DequeueUpToOp(OpKernelConstruction* context)
: QueueAccessOpKernel(context) {}
protected:
void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
DoneCallback callback) override {
const Tensor& Tnum_elements = ctx->input(1);
int32 num_elements = Tnum_elements.flat<int32>()(0);
OP_REQUIRES_ASYNC(ctx, num_elements >= 0,
errors::InvalidArgument("DequeueUpToOp requested ",
num_elements, " < 0 elements"),
callback);
if (ctx->input_dtype(0) == DT_RESOURCE) {
OP_REQUIRES_OK_ASYNC(ctx,
ctx->MatchSignature({DT_RESOURCE, DT_INT32},
queue->component_dtypes()),
callback);
} else {
OP_REQUIRES_OK_ASYNC(ctx,
ctx->MatchSignature({DT_STRING_REF, DT_INT32},
queue->component_dtypes()),
callback);
}
queue->TryDequeueMany(
num_elements, ctx, true /* allow_small_batch */,
[ctx, callback](const QueueInterface::Tuple& tuple) {
if (!ctx->status().ok()) {
callback();
return;
}
OpOutputList output_components;
OP_REQUIRES_OK_ASYNC(
ctx, ctx->output_list("components", &output_components),
callback);
for (int i = 0; i < ctx->num_outputs(); ++i) {
output_components.set(i, tuple[i]);
}
callback();
});
}
~DequeueUpToOp() override {}
private:
TF_DISALLOW_COPY_AND_ASSIGN(DequeueUpToOp);
};
REGISTER_KERNEL_BUILDER(Name("QueueDequeueUpTo").Device(DEVICE_CPU),
DequeueUpToOp);
REGISTER_KERNEL_BUILDER(Name("QueueDequeueUpToV2").Device(DEVICE_CPU),
DequeueUpToOp);
// Defines a QueueCloseOp, which closes the given Queue. Closing a
// Queue signals that no more elements will be enqueued in it.
//
// The op has one input, which is the handle of the appropriate Queue.
class QueueCloseOp : public QueueOpKernel {
public:
explicit QueueCloseOp(OpKernelConstruction* context)
: QueueOpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("cancel_pending_enqueues",
&cancel_pending_enqueues_));
}
protected:
void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
DoneCallback callback) override {
queue->Close(ctx, cancel_pending_enqueues_, callback);
}
private:
bool cancel_pending_enqueues_;
TF_DISALLOW_COPY_AND_ASSIGN(QueueCloseOp);
};
REGISTER_KERNEL_BUILDER(Name("QueueClose").Device(DEVICE_CPU), QueueCloseOp);
REGISTER_KERNEL_BUILDER(Name("QueueCloseV2").Device(DEVICE_CPU), QueueCloseOp);
// Defines a QueueSizeOp, which computes the number of elements in the
// given Queue, and emits it as an output tensor.
//
// The op has one input, which is the handle of the appropriate Queue;
// and one output, which is a single-element tensor containing the current
// size of that Queue.
class QueueSizeOp : public QueueOpKernel {
public:
explicit QueueSizeOp(OpKernelConstruction* context)
: QueueOpKernel(context) {}
protected:
void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
DoneCallback callback) override {
Tensor* Tqueue_size = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &Tqueue_size));
Tqueue_size->flat<int32>().setConstant(queue->size());
callback();
}
private:
TF_DISALLOW_COPY_AND_ASSIGN(QueueSizeOp);
};
REGISTER_KERNEL_BUILDER(Name("QueueSize").Device(DEVICE_CPU), QueueSizeOp);
REGISTER_KERNEL_BUILDER(Name("QueueSizeV2").Device(DEVICE_CPU), QueueSizeOp);
class QueueIsClosedOp : public QueueOpKernel {
public:
explicit QueueIsClosedOp(OpKernelConstruction* context)
: QueueOpKernel(context) {}
protected:
void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
DoneCallback callback) override {
Tensor* Tqueue_is_closed = nullptr;
OP_REQUIRES_OK(ctx,
ctx->allocate_output(0, TensorShape({}), &Tqueue_is_closed));
Tqueue_is_closed->flat<bool>().setConstant(queue->is_closed());
callback();
}
private:
TF_DISALLOW_COPY_AND_ASSIGN(QueueIsClosedOp);
};
REGISTER_KERNEL_BUILDER(Name("QueueIsClosed").Device(DEVICE_CPU),
QueueIsClosedOp);
REGISTER_KERNEL_BUILDER(Name("QueueIsClosedV2").Device(DEVICE_CPU),