From 50839154899377f89367d851f6d1e2c45fcd6431 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 28 Jun 2018 20:24:15 -0700 Subject: [PATCH] [TF:XLA] Add partial implementation of tf.FIFOQueue for XLA devices (e.g., TPU). The idea is to have a host-side queue of device tensors. Operators dequeue_many, enqueue_many, and dequeue_up_to are not yet implemented because they require splitting/concatenating tensors, which will require calling into a compiled XLA computation. Refactor queue operator implementations into libraries separate from the kernel registrations. Add support for ResourceOpKernels that are placed on non-CPU devices. Add support for allocating host-memory tensors during OpKernel construction. PiperOrigin-RevId: 202590292 --- tensorflow/compiler/jit/BUILD | 2 + tensorflow/compiler/jit/xla_device_ops.h | 29 +- tensorflow/compiler/tests/BUILD | 14 + tensorflow/compiler/tests/fifo_queue_test.py | 201 +++++++++ tensorflow/contrib/makefile/tf_op_files.txt | 1 + .../core/framework/resource_op_kernel.h | 25 +- tensorflow/core/kernels/BUILD | 5 +- tensorflow/core/kernels/fifo_queue.cc | 15 + tensorflow/core/kernels/fifo_queue.h | 23 +- tensorflow/core/kernels/fifo_queue_op.cc | 39 -- tensorflow/core/kernels/queue_op.cc | 367 ++++++++++++++++ tensorflow/core/kernels/queue_op.h | 233 ++++++++++- tensorflow/core/kernels/queue_ops.cc | 395 +----------------- 13 files changed, 883 insertions(+), 466 deletions(-) create mode 100644 tensorflow/compiler/tests/fifo_queue_test.py create mode 100644 tensorflow/core/kernels/queue_op.cc diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index d976f8296c6..c2245b8eae8 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -176,9 +176,11 @@ cc_library( "//tensorflow/core/kernels:cast_op", "//tensorflow/core/kernels:constant_op", "//tensorflow/core/kernels:control_flow_ops", + "//tensorflow/core/kernels:fifo_queue", "//tensorflow/core/kernels:identity_n_op", "//tensorflow/core/kernels:identity_op", "//tensorflow/core/kernels:no_op", + "//tensorflow/core/kernels:queue_op", "//tensorflow/core/kernels:resource_variable_ops", "//tensorflow/core/kernels:sendrecv_ops", "//tensorflow/core/kernels:shape_ops", diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h index 11e45d2823d..a605335a94f 100644 --- a/tensorflow/compiler/jit/xla_device_ops.h +++ b/tensorflow/compiler/jit/xla_device_ops.h @@ -23,9 +23,11 @@ limitations under the License.
#include "tensorflow/core/kernels/cast_op.h" #include "tensorflow/core/kernels/constant_op.h" #include "tensorflow/core/kernels/control_flow_ops.h" +#include "tensorflow/core/kernels/fifo_queue.h" #include "tensorflow/core/kernels/identity_n_op.h" #include "tensorflow/core/kernels/identity_op.h" #include "tensorflow/core/kernels/no_op.h" +#include "tensorflow/core/kernels/queue_op.h" #include "tensorflow/core/kernels/resource_variable_ops.h" #include "tensorflow/core/kernels/sendrecv_ops.h" #include "tensorflow/core/kernels/shape_ops.h" @@ -145,7 +147,32 @@ class XlaAssignVariableOp : public AsyncOpKernel { .Device(DEVICE) \ .HostMemory("input") \ .HostMemory("output"), \ - LoopCondOp); + LoopCondOp); \ + \ + REGISTER_KERNEL_BUILDER( \ + Name("QueueEnqueueV2").Device(DEVICE).HostMemory("handle"), EnqueueOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("QueueDequeueV2").Device(DEVICE).HostMemory("handle"), DequeueOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("QueueCloseV2").Device(DEVICE).HostMemory("handle"), QueueCloseOp); \ + REGISTER_KERNEL_BUILDER(Name("QueueSizeV2") \ + .Device(DEVICE) \ + .HostMemory("size") \ + .HostMemory("handle"), \ + QueueSizeOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("QueueIsClosedV2").Device(DEVICE).HostMemory("handle"), \ + QueueIsClosedOp); \ + \ + REGISTER_KERNEL_BUILDER( \ + Name("FIFOQueueV2").Device(DEVICE).HostMemory("handle"), FIFOQueueOp); + +// TODO(phawkins): currently we do not register the QueueEnqueueMany, +// QueueDequeueMany, or QueueDequeueUpTo kernels because they attempt to read +// and write the tensors they access in order to concatenate them into a batch. +// We would need either to call out to an XLA computation to perform the +// concatenation, or we would need to refactor those kernels so the splitting +// or merging is done in a separate operator that can be compiled. } // namespace tensorflow diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index c1f65416b44..366822f0b74 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -371,6 +371,20 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "fifo_queue_test", + size = "medium", + srcs = ["fifo_queue_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:data_flow_ops", + "//tensorflow/python:extra_py_tests_deps", + "//tensorflow/python:framework", + "//tensorflow/python:platform_test", + ], +) + tf_xla_py_test( name = "fft_test", size = "medium", diff --git a/tensorflow/compiler/tests/fifo_queue_test.py b/tensorflow/compiler/tests/fifo_queue_test.py new file mode 100644 index 00000000000..0f64cc87cde --- /dev/null +++ b/tensorflow/compiler/tests/fifo_queue_test.py @@ -0,0 +1,201 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for tensorflow.ops.data_flow_ops.FIFOQueue.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time + +from six.moves import xrange # pylint: disable=redefined-builtin + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import dtypes as dtypes_lib +from tensorflow.python.ops import data_flow_ops +from tensorflow.python.platform import test + + +class FIFOQueueTest(xla_test.XLATestCase): + + def testEnqueue(self): + with self.test_session(), self.test_scope(): + q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) + enqueue_op = q.enqueue((10.0,)) + enqueue_op.run() + + def testEnqueueWithShape(self): + with self.test_session(), self.test_scope(): + q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, shapes=(3, 2)) + enqueue_correct_op = q.enqueue(([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],)) + enqueue_correct_op.run() + with self.assertRaises(ValueError): + q.enqueue(([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],)) + self.assertEqual(1, q.size().eval()) + + def testMultipleDequeues(self): + with self.test_session(), self.test_scope(): + q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()]) + self.evaluate(q.enqueue([1])) + self.evaluate(q.enqueue([2])) + self.evaluate(q.enqueue([3])) + a, b, c = self.evaluate([q.dequeue(), q.dequeue(), q.dequeue()]) + self.assertAllEqual(set([1, 2, 3]), set([a, b, c])) + + def testQueuesDontShare(self): + with self.test_session(), self.test_scope(): + q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()]) + self.evaluate(q.enqueue(1)) + q2 = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()]) + self.evaluate(q2.enqueue(2)) + self.assertAllEqual(self.evaluate(q2.dequeue()), 2) + self.assertAllEqual(self.evaluate(q.dequeue()), 1) + + def testEnqueueDictWithoutNames(self): + with self.test_session(), self.test_scope(): + q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) + with self.assertRaisesRegexp(ValueError, "must have names"): + q.enqueue({"a": 12.0}) + + def testParallelEnqueue(self): + with self.test_session() as sess, self.test_scope(): + q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) + elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0] + enqueue_ops = [q.enqueue((x,)) for x in elems] + dequeued_t = q.dequeue() + + # Run one producer thread for each element in elems. + def enqueue(enqueue_op): + sess.run(enqueue_op) + + threads = [ + self.checkedThread(target=enqueue, args=(e,)) for e in enqueue_ops + ] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + # Dequeue every element using a single thread. + results = [] + for _ in xrange(len(elems)): + results.append(dequeued_t.eval()) + self.assertItemsEqual(elems, results) + + def testParallelDequeue(self): + with self.test_session() as sess, self.test_scope(): + q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) + elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0] + enqueue_ops = [q.enqueue((x,)) for x in elems] + dequeued_t = q.dequeue() + + # Enqueue every element using a single thread. + for enqueue_op in enqueue_ops: + enqueue_op.run() + + # Run one consumer thread for each element in elems. 
+ results = [] + + def dequeue(): + results.append(sess.run(dequeued_t)) + + threads = [self.checkedThread(target=dequeue) for _ in enqueue_ops] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + self.assertItemsEqual(elems, results) + + def testDequeue(self): + with self.test_session(), self.test_scope(): + q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) + elems = [10.0, 20.0, 30.0] + enqueue_ops = [q.enqueue((x,)) for x in elems] + dequeued_t = q.dequeue() + + for enqueue_op in enqueue_ops: + enqueue_op.run() + + for i in xrange(len(elems)): + vals = dequeued_t.eval() + self.assertEqual([elems[i]], vals) + + def testEnqueueAndBlockingDequeue(self): + with self.test_session() as sess, self.test_scope(): + q = data_flow_ops.FIFOQueue(3, dtypes_lib.float32) + elems = [10.0, 20.0, 30.0] + enqueue_ops = [q.enqueue((x,)) for x in elems] + dequeued_t = q.dequeue() + + def enqueue(): + # The enqueue_ops should run after the dequeue op has blocked. + # TODO(mrry): Figure out how to do this without sleeping. + time.sleep(0.1) + for enqueue_op in enqueue_ops: + sess.run(enqueue_op) + + results = [] + + def dequeue(): + for _ in xrange(len(elems)): + results.append(sess.run(dequeued_t)) + + enqueue_thread = self.checkedThread(target=enqueue) + dequeue_thread = self.checkedThread(target=dequeue) + enqueue_thread.start() + dequeue_thread.start() + enqueue_thread.join() + dequeue_thread.join() + + for elem, result in zip(elems, results): + self.assertEqual([elem], result) + + def testMultiEnqueueAndDequeue(self): + with self.test_session() as sess, self.test_scope(): + q = data_flow_ops.FIFOQueue(10, (dtypes_lib.int32, dtypes_lib.float32)) + elems = [(5, 10.0), (10, 20.0), (15, 30.0)] + enqueue_ops = [q.enqueue((x, y)) for x, y in elems] + dequeued_t = q.dequeue() + + for enqueue_op in enqueue_ops: + enqueue_op.run() + + for i in xrange(len(elems)): + x_val, y_val = sess.run(dequeued_t) + x, y = elems[i] + self.assertEqual([x], x_val) + self.assertEqual([y], y_val) + + def testQueueSizeEmpty(self): + with self.test_session(), self.test_scope(): + q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) + self.assertEqual([0], q.size().eval()) + + def testQueueSizeAfterEnqueueAndDequeue(self): + with self.test_session(), self.test_scope(): + q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) + enqueue_op = q.enqueue((10.0,)) + dequeued_t = q.dequeue() + size = q.size() + self.assertEqual([], size.get_shape()) + + enqueue_op.run() + self.assertEqual(1, size.eval()) + dequeued_t.op.run() + self.assertEqual(0, size.eval()) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt index 89db9ee2794..6e7423f85e3 100644 --- a/tensorflow/contrib/makefile/tf_op_files.txt +++ b/tensorflow/contrib/makefile/tf_op_files.txt @@ -92,6 +92,7 @@ tensorflow/core/kernels/reduction_ops_common.cc tensorflow/core/kernels/reduction_ops_any.cc tensorflow/core/kernels/reduction_ops_all.cc tensorflow/core/kernels/roll_op.cc +tensorflow/core/kernels/queue_op.cc tensorflow/core/kernels/queue_ops.cc tensorflow/core/kernels/queue_base.cc tensorflow/core/kernels/pooling_ops_common.cc diff --git a/tensorflow/core/framework/resource_op_kernel.h b/tensorflow/core/framework/resource_op_kernel.h index 813ec6eed58..0a8da8b3bf0 100644 --- a/tensorflow/core/framework/resource_op_kernel.h +++ b/tensorflow/core/framework/resource_op_kernel.h @@ -43,9 +43,15 @@ template class ResourceOpKernel : public OpKernel { public: 
explicit ResourceOpKernel(OpKernelConstruction* context) : OpKernel(context) { - OP_REQUIRES_OK(context, - context->allocate_persistent(DT_STRING, TensorShape({2}), - &handle_, nullptr)); + has_resource_type_ = (context->output_type(0) == DT_RESOURCE); + if (!has_resource_type_) { + // The resource variant of the op may be placed on non-CPU devices, but + // this allocation is always on the host. Fortunately we don't need it in + // the resource case. + OP_REQUIRES_OK(context, + context->allocate_persistent(DT_STRING, TensorShape({2}), + &handle_, nullptr)); + } } // The resource is deleted from the resource manager only when it is private @@ -89,12 +95,14 @@ class ResourceOpKernel : public OpKernel { return; } - auto h = handle_.AccessTensor(context)->template flat<string>(); - h(0) = cinfo_.container(); - h(1) = cinfo_.name(); + if (!has_resource_type_) { + auto h = handle_.AccessTensor(context)->template flat<string>(); + h(0) = cinfo_.container(); + h(1) = cinfo_.name(); + } resource_ = resource; } - if (context->expected_output_dtype(0) == DT_RESOURCE) { + if (has_resource_type_) { OP_REQUIRES_OK(context, MakeResourceHandleToOutput( context, 0, cinfo_.container(), cinfo_.name(), MakeTypeIndex<T>())); @@ -122,6 +130,9 @@ class ResourceOpKernel : public OpKernel { virtual Status VerifyResource(T* resource) { return Status::OK(); } PersistentTensor handle_ GUARDED_BY(mu_); + + // Is the output of the operator of type DT_RESOURCE? + bool has_resource_type_; }; } // namespace tensorflow diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index cbe30cdca14..861fb1ef697 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -368,6 +368,7 @@ cc_library( cc_library( name = "queue_op", + srcs = ["queue_op.cc"], hdrs = ["queue_op.h"], deps = [ ":queue_base", @@ -1885,9 +1886,10 @@ cc_library( name = "fifo_queue", srcs = ["fifo_queue.cc"], hdrs = ["fifo_queue.h"], - visibility = ["//visibility:private"], + visibility = [":friends"], deps = [ ":queue_base", + ":queue_op", ":typed_queue", "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -5076,6 +5078,7 @@ filegroup( "padding_fifo_queue.cc", "padding_fifo_queue_op.cc", "queue_base.cc", + "queue_op.cc", "queue_ops.cc", "random_op.cc", "reduction_ops_all.cc", diff --git a/tensorflow/core/kernels/fifo_queue.cc b/tensorflow/core/kernels/fifo_queue.cc index a23478af5b5..d6e859f1aa0 100644 --- a/tensorflow/core/kernels/fifo_queue.cc +++ b/tensorflow/core/kernels/fifo_queue.cc @@ -366,4 +366,19 @@ Status FIFOQueue::MatchesNodeDef(const NodeDef& node_def) { return Status::OK(); } +// Defines a FIFOQueueOp, which produces a Queue (specifically, one +// backed by FIFOQueue) that persists across different graph +// executions, and sessions. Running this op produces a single-element +// tensor of handles to Queues in the corresponding device.
+FIFOQueueOp::FIFOQueueOp(OpKernelConstruction* context) + : TypedQueueOp(context) { + OP_REQUIRES_OK(context, context->GetAttr("shapes", &component_shapes_)); +} + +Status FIFOQueueOp::CreateResource(QueueInterface** ret) { + FIFOQueue* queue = new FIFOQueue(capacity_, component_types_, + component_shapes_, cinfo_.name()); + return CreateTypedQueue(queue, ret); +} + } // namespace tensorflow diff --git a/tensorflow/core/kernels/fifo_queue.h b/tensorflow/core/kernels/fifo_queue.h index f01d70924d0..697ee81c39b 100644 --- a/tensorflow/core/kernels/fifo_queue.h +++ b/tensorflow/core/kernels/fifo_queue.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_KERNELS_FIFO_QUEUE_H_ -#define TENSORFLOW_KERNELS_FIFO_QUEUE_H_ +#ifndef TENSORFLOW_CORE_KERNELS_FIFO_QUEUE_H_ +#define TENSORFLOW_CORE_KERNELS_FIFO_QUEUE_H_ #include <deque> #include <vector> @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/queue_op.h" #include "tensorflow/core/kernels/typed_queue.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" @@ -69,6 +70,22 @@ class FIFOQueue : public TypedQueue<std::deque<PersistentTuple> > { TF_DISALLOW_COPY_AND_ASSIGN(FIFOQueue); }; +// Defines a FIFOQueueOp, which produces a Queue (specifically, one +// backed by FIFOQueue) that persists across different graph +// executions, and sessions. Running this op produces a single-element +// tensor of handles to Queues in the corresponding device. +class FIFOQueueOp : public TypedQueueOp { + public: + explicit FIFOQueueOp(OpKernelConstruction* context); + + private: + Status CreateResource(QueueInterface** ret) override + EXCLUSIVE_LOCKS_REQUIRED(mu_); + + std::vector<TensorShape> component_shapes_; + TF_DISALLOW_COPY_AND_ASSIGN(FIFOQueueOp); +}; + } // namespace tensorflow -#endif // TENSORFLOW_KERNELS_FIFO_QUEUE_H_ +#endif // TENSORFLOW_CORE_KERNELS_FIFO_QUEUE_H_ diff --git a/tensorflow/core/kernels/fifo_queue_op.cc b/tensorflow/core/kernels/fifo_queue_op.cc index b35bdbb2f01..80869768f18 100644 --- a/tensorflow/core/kernels/fifo_queue_op.cc +++ b/tensorflow/core/kernels/fifo_queue_op.cc @@ -13,50 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// See docs in ../ops/data_flow_ops.cc. - -#include <deque> -#include <vector> - #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/resource_mgr.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/fifo_queue.h" -#include "tensorflow/core/kernels/queue_base.h" -#include "tensorflow/core/kernels/queue_op.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/thread_annotations.h" -#include "tensorflow/core/platform/types.h" namespace tensorflow { -// Defines a FIFOQueueOp, which produces a Queue (specifically, one -// backed by FIFOQueue) that persists across different graph -// executions, and sessions.
Running this op produces a single-element -// tensor of handles to Queues in the corresponding device. -class FIFOQueueOp : public TypedQueueOp { - public: - explicit FIFOQueueOp(OpKernelConstruction* context) : TypedQueueOp(context) { - OP_REQUIRES_OK(context, context->GetAttr("shapes", &component_shapes_)); - } - - private: - Status CreateResource(QueueInterface** ret) override - EXCLUSIVE_LOCKS_REQUIRED(mu_) { - FIFOQueue* queue = new FIFOQueue(capacity_, component_types_, - component_shapes_, cinfo_.name()); - return CreateTypedQueue(queue, ret); - } - - std::vector component_shapes_; - TF_DISALLOW_COPY_AND_ASSIGN(FIFOQueueOp); -}; - REGISTER_KERNEL_BUILDER(Name("FIFOQueue").Device(DEVICE_CPU), FIFOQueueOp); REGISTER_KERNEL_BUILDER(Name("FIFOQueueV2").Device(DEVICE_CPU), FIFOQueueOp); diff --git a/tensorflow/core/kernels/queue_op.cc b/tensorflow/core/kernels/queue_op.cc new file mode 100644 index 00000000000..53f431ef3c7 --- /dev/null +++ b/tensorflow/core/kernels/queue_op.cc @@ -0,0 +1,367 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/kernels/queue_op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/queue_interface.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +QueueOp::QueueOp(OpKernelConstruction* context) : ResourceOpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("capacity", &capacity_)); + if (capacity_ < 0) { + capacity_ = QueueBase::kUnbounded; + } + OP_REQUIRES_OK(context, + context->GetAttr("component_types", &component_types_)); +} + +void QueueOp::Compute(OpKernelContext* context) { + ResourceOpKernel::Compute(context); + mutex_lock l(mu_); + if (resource_ && context->track_allocations()) { + context->record_persistent_memory_allocation(resource_->MemoryUsed()); + } +} + +Status QueueOp::VerifyResource(QueueInterface* queue) { + return queue->MatchesNodeDef(def()); +} + + +QueueOpKernel::QueueOpKernel(OpKernelConstruction* context) + : AsyncOpKernel(context) {} + +void QueueOpKernel::ComputeAsync(OpKernelContext* ctx, DoneCallback callback) { + QueueInterface* queue; + if (ctx->input_dtype(0) == DT_RESOURCE) { + OP_REQUIRES_OK_ASYNC( + ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &queue), callback); + } else { + OP_REQUIRES_OK_ASYNC(ctx, GetResourceFromContext(ctx, "handle", &queue), + callback); + } + ComputeAsync(ctx, queue, [callback, queue]() { + queue->Unref(); + callback(); + }); +} + +QueueAccessOpKernel::QueueAccessOpKernel(OpKernelConstruction* context) + : QueueOpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("timeout_ms", &timeout_)); + // TODO(keveman): Enable timeout. 
+ OP_REQUIRES(context, timeout_ == -1, + errors::InvalidArgument("Timeout not supported yet.")); +} + +// Defines an EnqueueOp, the execution of which enqueues a tuple of +// tensors in the given Queue. +// +// The op has 1 + k inputs, where k is the number of components in the +// tuples stored in the given Queue: +// - Input 0: queue handle. +// - Input 1: 0th element of the tuple. +// - ... +// - Input (1+k): kth element of the tuple. +EnqueueOp::EnqueueOp(OpKernelConstruction* context) + : QueueAccessOpKernel(context) {} + +void EnqueueOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, + DoneCallback callback) { + DataTypeVector expected_inputs; + if (ctx->input_dtype(0) == DT_RESOURCE) { + expected_inputs.push_back(DT_RESOURCE); + } else { + expected_inputs.push_back(DT_STRING_REF); + } + for (DataType dt : queue->component_dtypes()) { + expected_inputs.push_back(dt); + } + OP_REQUIRES_OK_ASYNC(ctx, ctx->MatchSignature(expected_inputs, {}), callback); + + QueueInterface::Tuple tuple; + OpInputList components; + OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("components", &components), + callback); + for (const Tensor& Tcomponent : components) { + tuple.push_back(Tcomponent); + } + + OP_REQUIRES_OK_ASYNC(ctx, queue->ValidateTuple(tuple), callback); + queue->TryEnqueue(tuple, ctx, callback); +} + +// Defines an EnqueueManyOp, the execution of which slices each +// component of a tuple of tensors along the 0th dimension, and +// enqueues tuples of slices in the given Queue. +// +// The op has 1 + k inputs, where k is the number of components in the +// tuples stored in the given Queue: +// - Input 0: queue handle. +// - Input 1: 0th element of the tuple. +// - ... +// - Input (1+k): kth element of the tuple. +// +// N.B. All tuple components must have the same size in the 0th +// dimension. +EnqueueManyOp::EnqueueManyOp(OpKernelConstruction* context) + : QueueAccessOpKernel(context) {} + +void EnqueueManyOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, + DoneCallback callback) { + DataTypeVector expected_inputs; + if (ctx->input_dtype(0) == DT_RESOURCE) { + expected_inputs.push_back(DT_RESOURCE); + } else { + expected_inputs.push_back(DT_STRING_REF); + } + for (DataType dt : queue->component_dtypes()) { + expected_inputs.push_back(dt); + } + OP_REQUIRES_OK_ASYNC(ctx, ctx->MatchSignature(expected_inputs, {}), callback); + + QueueInterface::Tuple tuple; + OpInputList components; + OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("components", &components), + callback); + for (const Tensor& Tcomponent : components) { + tuple.push_back(Tcomponent); + } + + OP_REQUIRES_OK_ASYNC(ctx, queue->ValidateManyTuple(tuple), callback); + queue->TryEnqueueMany(tuple, ctx, callback); +} + +EnqueueManyOp::~EnqueueManyOp() = default; + +// Defines a DequeueOp, the execution of which dequeues a tuple of +// tensors from the given Queue. +// +// The op has one input, which is the handle of the appropriate +// Queue. The op has k outputs, where k is the number of components in +// the tuples stored in the given Queue, and output i is the ith +// component of the dequeued tuple. 
+DequeueOp::DequeueOp(OpKernelConstruction* context) + : QueueAccessOpKernel(context) {} + +void DequeueOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, + DoneCallback callback) { + if (ctx->input_dtype(0) == DT_RESOURCE) { + OP_REQUIRES_OK_ASYNC( + ctx, ctx->MatchSignature({DT_RESOURCE}, queue->component_dtypes()), + callback); + } else { + OP_REQUIRES_OK_ASYNC( + ctx, ctx->MatchSignature({DT_STRING_REF}, queue->component_dtypes()), + callback); + } + + queue->TryDequeue(ctx, [ctx, callback](const QueueInterface::Tuple& tuple) { + if (!ctx->status().ok()) { + callback(); + return; + } + OpOutputList output_components; + OP_REQUIRES_OK_ASYNC( + ctx, ctx->output_list("components", &output_components), callback); + for (int i = 0; i < ctx->num_outputs(); ++i) { + output_components.set(i, tuple[i]); + } + callback(); + }); +} + +DequeueOp::~DequeueOp() = default; + +// Defines a DequeueManyOp, the execution of which concatenates the +// requested number of elements from the given Queue along the 0th +// dimension, and emits the result as a single tuple of tensors. +// +// The op has two inputs: +// - Input 0: the handle to a queue. +// - Input 1: the number of elements to dequeue. +// +// The op has k outputs, where k is the number of components in the +// tuples stored in the given Queue, and output i is the ith component +// of the dequeued tuple. +DequeueManyOp::DequeueManyOp(OpKernelConstruction* context) + : QueueAccessOpKernel(context) {} + +void DequeueManyOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, + DoneCallback callback) { + const Tensor& Tnum_elements = ctx->input(1); + int32 num_elements = Tnum_elements.flat()(0); + + OP_REQUIRES_ASYNC(ctx, num_elements >= 0, + errors::InvalidArgument("DequeueManyOp requested ", + num_elements, " < 0 elements"), + callback); + + if (ctx->input_dtype(0) == DT_RESOURCE) { + OP_REQUIRES_OK_ASYNC( + ctx, + ctx->MatchSignature({DT_RESOURCE, DT_INT32}, queue->component_dtypes()), + callback); + } else { + OP_REQUIRES_OK_ASYNC(ctx, + ctx->MatchSignature({DT_STRING_REF, DT_INT32}, + queue->component_dtypes()), + callback); + } + + queue->TryDequeueMany( + num_elements, ctx, false /* allow_small_batch */, + [ctx, callback](const QueueInterface::Tuple& tuple) { + if (!ctx->status().ok()) { + callback(); + return; + } + OpOutputList output_components; + OP_REQUIRES_OK_ASYNC( + ctx, ctx->output_list("components", &output_components), callback); + for (int i = 0; i < ctx->num_outputs(); ++i) { + output_components.set(i, tuple[i]); + } + callback(); + }); +} + +DequeueManyOp::~DequeueManyOp() = default; + +// Defines a DequeueUpToOp, the execution of which concatenates the +// requested number of elements from the given Queue along the 0th +// dimension, and emits the result as a single tuple of tensors. +// +// The difference between this op and DequeueMany is the handling when +// the Queue is closed. While the DequeueMany op will return if there +// an error when there are less than num_elements elements left in the +// closed queue, this op will return between 1 and +// min(num_elements, elements_remaining_in_queue), and will not block. +// If there are no elements left, then the standard DequeueMany error +// is returned. +// +// This op only works if the underlying Queue implementation accepts +// the allow_small_batch = true parameter to TryDequeueMany. +// If it does not, an errors::Unimplemented exception is returned. +// +// The op has two inputs: +// - Input 0: the handle to a queue. 
+// - Input 1: the number of elements to dequeue. +// +// The op has k outputs, where k is the number of components in the +// tuples stored in the given Queue, and output i is the ith component +// of the dequeued tuple. +// +// The op has one attribute: allow_small_batch. If the Queue supports +// it, setting this to true causes the queue to return smaller +// (possibly zero length) batches when it is closed, up to however +// many elements are available when the op executes. In this case, +// the Queue does not block when closed. +DequeueUpToOp::DequeueUpToOp(OpKernelConstruction* context) + : QueueAccessOpKernel(context) {} + +void DequeueUpToOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, + DoneCallback callback) { + const Tensor& Tnum_elements = ctx->input(1); + int32 num_elements = Tnum_elements.flat()(0); + + OP_REQUIRES_ASYNC(ctx, num_elements >= 0, + errors::InvalidArgument("DequeueUpToOp requested ", + num_elements, " < 0 elements"), + callback); + + if (ctx->input_dtype(0) == DT_RESOURCE) { + OP_REQUIRES_OK_ASYNC( + ctx, + ctx->MatchSignature({DT_RESOURCE, DT_INT32}, queue->component_dtypes()), + callback); + } else { + OP_REQUIRES_OK_ASYNC(ctx, + ctx->MatchSignature({DT_STRING_REF, DT_INT32}, + queue->component_dtypes()), + callback); + } + + queue->TryDequeueMany( + num_elements, ctx, true /* allow_small_batch */, + [ctx, callback](const QueueInterface::Tuple& tuple) { + if (!ctx->status().ok()) { + callback(); + return; + } + OpOutputList output_components; + OP_REQUIRES_OK_ASYNC( + ctx, ctx->output_list("components", &output_components), callback); + for (int i = 0; i < ctx->num_outputs(); ++i) { + output_components.set(i, tuple[i]); + } + callback(); + }); +} + +DequeueUpToOp::~DequeueUpToOp() = default; + +// Defines a QueueCloseOp, which closes the given Queue. Closing a +// Queue signals that no more elements will be enqueued in it. +// +// The op has one input, which is the handle of the appropriate Queue. +QueueCloseOp::QueueCloseOp(OpKernelConstruction* context) + : QueueOpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("cancel_pending_enqueues", + &cancel_pending_enqueues_)); +} + +void QueueCloseOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, + DoneCallback callback) { + queue->Close(ctx, cancel_pending_enqueues_, callback); +} + +// Defines a QueueSizeOp, which computes the number of elements in the +// given Queue, and emits it as an output tensor. +// +// The op has one input, which is the handle of the appropriate Queue; +// and one output, which is a single-element tensor containing the current +// size of that Queue. 
+QueueSizeOp::QueueSizeOp(OpKernelConstruction* context) + : QueueOpKernel(context) {} + +void QueueSizeOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, + DoneCallback callback) { + Tensor* Tqueue_size = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &Tqueue_size)); + Tqueue_size->flat().setConstant(queue->size()); + callback(); +} + +QueueIsClosedOp::QueueIsClosedOp(OpKernelConstruction* context) + : QueueOpKernel(context) {} + +void QueueIsClosedOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, + DoneCallback callback) { + Tensor* Tqueue_is_closed = nullptr; + OP_REQUIRES_OK(ctx, + ctx->allocate_output(0, TensorShape({}), &Tqueue_is_closed)); + Tqueue_is_closed->flat().setConstant(queue->is_closed()); + callback(); +} + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/queue_op.h b/tensorflow/core/kernels/queue_op.h index 6c19f9841cd..2efd838a5fb 100644 --- a/tensorflow/core/kernels/queue_op.h +++ b/tensorflow/core/kernels/queue_op.h @@ -13,12 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_KERNELS_QUEUE_OP_H_ -#define TENSORFLOW_KERNELS_QUEUE_OP_H_ +#ifndef TENSORFLOW_CORE_KERNELS_QUEUE_OP_H_ +#define TENSORFLOW_CORE_KERNELS_QUEUE_OP_H_ #include #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/queue_interface.h" #include "tensorflow/core/framework/resource_op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.h" @@ -32,22 +33,9 @@ namespace tensorflow { // Defines a QueueOp, an abstract class for Queue construction ops. class QueueOp : public ResourceOpKernel { public: - QueueOp(OpKernelConstruction* context) : ResourceOpKernel(context) { - OP_REQUIRES_OK(context, context->GetAttr("capacity", &capacity_)); - if (capacity_ < 0) { - capacity_ = QueueBase::kUnbounded; - } - OP_REQUIRES_OK(context, - context->GetAttr("component_types", &component_types_)); - } + QueueOp(OpKernelConstruction* context); - void Compute(OpKernelContext* context) override { - ResourceOpKernel::Compute(context); - mutex_lock l(mu_); - if (resource_ && context->track_allocations()) { - context->record_persistent_memory_allocation(resource_->MemoryUsed()); - } - } + void Compute(OpKernelContext* context) override; protected: // Variables accessible by subclasses @@ -55,9 +43,7 @@ class QueueOp : public ResourceOpKernel { DataTypeVector component_types_; private: - Status VerifyResource(QueueInterface* queue) override { - return queue->MatchesNodeDef(def()); - } + Status VerifyResource(QueueInterface* queue) override; }; class TypedQueueOp : public QueueOp { @@ -75,6 +61,211 @@ class TypedQueueOp : public QueueOp { } }; +// Queue manipulator kernels + +class QueueOpKernel : public AsyncOpKernel { + public: + explicit QueueOpKernel(OpKernelConstruction* context); + + void ComputeAsync(OpKernelContext* ctx, DoneCallback callback) final; + + protected: + virtual void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, + DoneCallback callback) = 0; +}; + +class QueueAccessOpKernel : public QueueOpKernel { + public: + explicit QueueAccessOpKernel(OpKernelConstruction* context); + + protected: + int64 timeout_; +}; + +// Defines an EnqueueOp, the execution of which enqueues a tuple of +// tensors in the given Queue. 
+// +// The op has 1 + k inputs, where k is the number of components in the +// tuples stored in the given Queue: +// - Input 0: queue handle. +// - Input 1: 0th element of the tuple. +// - ... +// - Input (1+k): kth element of the tuple. +class EnqueueOp : public QueueAccessOpKernel { + public: + explicit EnqueueOp(OpKernelConstruction* context); + + protected: + void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, + DoneCallback callback) override; + + private: + TF_DISALLOW_COPY_AND_ASSIGN(EnqueueOp); +}; + +// Defines an EnqueueManyOp, the execution of which slices each +// component of a tuple of tensors along the 0th dimension, and +// enqueues tuples of slices in the given Queue. +// +// The op has 1 + k inputs, where k is the number of components in the +// tuples stored in the given Queue: +// - Input 0: queue handle. +// - Input 1: 0th element of the tuple. +// - ... +// - Input (1+k): kth element of the tuple. +// +// N.B. All tuple components must have the same size in the 0th +// dimension. +class EnqueueManyOp : public QueueAccessOpKernel { + public: + explicit EnqueueManyOp(OpKernelConstruction* context); + + protected: + void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, + DoneCallback callback) override; + + ~EnqueueManyOp() override; + + private: + TF_DISALLOW_COPY_AND_ASSIGN(EnqueueManyOp); +}; + +// Defines a DequeueOp, the execution of which dequeues a tuple of +// tensors from the given Queue. +// +// The op has one input, which is the handle of the appropriate +// Queue. The op has k outputs, where k is the number of components in +// the tuples stored in the given Queue, and output i is the ith +// component of the dequeued tuple. +class DequeueOp : public QueueAccessOpKernel { + public: + explicit DequeueOp(OpKernelConstruction* context); + + protected: + void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, + DoneCallback callback) override; + + ~DequeueOp() override; + + private: + TF_DISALLOW_COPY_AND_ASSIGN(DequeueOp); +}; + +// Defines a DequeueManyOp, the execution of which concatenates the +// requested number of elements from the given Queue along the 0th +// dimension, and emits the result as a single tuple of tensors. +// +// The op has two inputs: +// - Input 0: the handle to a queue. +// - Input 1: the number of elements to dequeue. +// +// The op has k outputs, where k is the number of components in the +// tuples stored in the given Queue, and output i is the ith component +// of the dequeued tuple. +class DequeueManyOp : public QueueAccessOpKernel { + public: + explicit DequeueManyOp(OpKernelConstruction* context); + + protected: + void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, + DoneCallback callback) override; + + ~DequeueManyOp() override; + + private: + TF_DISALLOW_COPY_AND_ASSIGN(DequeueManyOp); +}; + +// Defines a DequeueUpToOp, the execution of which concatenates the +// requested number of elements from the given Queue along the 0th +// dimension, and emits the result as a single tuple of tensors. +// +// The difference between this op and DequeueMany is the handling when +// the Queue is closed. While the DequeueMany op will return if there +// an error when there are less than num_elements elements left in the +// closed queue, this op will return between 1 and +// min(num_elements, elements_remaining_in_queue), and will not block. +// If there are no elements left, then the standard DequeueMany error +// is returned. 
+// +// This op only works if the underlying Queue implementation accepts +// the allow_small_batch = true parameter to TryDequeueMany. +// If it does not, an errors::Unimplemented exception is returned. +// +// The op has two inputs: +// - Input 0: the handle to a queue. +// - Input 1: the number of elements to dequeue. +// +// The op has k outputs, where k is the number of components in the +// tuples stored in the given Queue, and output i is the ith component +// of the dequeued tuple. +// +// The op has one attribute: allow_small_batch. If the Queue supports +// it, setting this to true causes the queue to return smaller +// (possibly zero length) batches when it is closed, up to however +// many elements are available when the op executes. In this case, +// the Queue does not block when closed. +class DequeueUpToOp : public QueueAccessOpKernel { + public: + explicit DequeueUpToOp(OpKernelConstruction* context); + + protected: + void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, + DoneCallback callback) override; + + ~DequeueUpToOp() override; + + private: + TF_DISALLOW_COPY_AND_ASSIGN(DequeueUpToOp); +}; + +// Defines a QueueCloseOp, which closes the given Queue. Closing a +// Queue signals that no more elements will be enqueued in it. +// +// The op has one input, which is the handle of the appropriate Queue. +class QueueCloseOp : public QueueOpKernel { + public: + explicit QueueCloseOp(OpKernelConstruction* context); + + protected: + void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, + DoneCallback callback) override; + + private: + bool cancel_pending_enqueues_; + TF_DISALLOW_COPY_AND_ASSIGN(QueueCloseOp); +}; + +// Defines a QueueSizeOp, which computes the number of elements in the +// given Queue, and emits it as an output tensor. +// +// The op has one input, which is the handle of the appropriate Queue; +// and one output, which is a single-element tensor containing the current +// size of that Queue. +class QueueSizeOp : public QueueOpKernel { + public: + explicit QueueSizeOp(OpKernelConstruction* context); + + protected: + void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, + DoneCallback callback) override; + + private: + TF_DISALLOW_COPY_AND_ASSIGN(QueueSizeOp); +}; + +class QueueIsClosedOp : public QueueOpKernel { + public: + explicit QueueIsClosedOp(OpKernelConstruction* context); + + protected: + void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, + DoneCallback callback) override; + + private: + TF_DISALLOW_COPY_AND_ASSIGN(QueueIsClosedOp); +}; + } // namespace tensorflow -#endif // TENSORFLOW_KERNELS_QUEUE_OP_H_ +#endif // TENSORFLOW_CORE_KERNELS_QUEUE_OP_H_ diff --git a/tensorflow/core/kernels/queue_ops.cc b/tensorflow/core/kernels/queue_ops.cc index 46a02854d73..c4d404259ba 100644 --- a/tensorflow/core/kernels/queue_ops.cc +++ b/tensorflow/core/kernels/queue_ops.cc @@ -13,437 +13,44 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// See docs in ../ops/data_flow_ops.cc. 
- #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/queue_interface.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/queue_op.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { -class QueueOpKernel : public AsyncOpKernel { - public: - explicit QueueOpKernel(OpKernelConstruction* context) - : AsyncOpKernel(context) {} - - void ComputeAsync(OpKernelContext* ctx, DoneCallback callback) final { - QueueInterface* queue; - if (ctx->input_dtype(0) == DT_RESOURCE) { - OP_REQUIRES_OK_ASYNC( - ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &queue), callback); - } else { - OP_REQUIRES_OK_ASYNC(ctx, GetResourceFromContext(ctx, "handle", &queue), - callback); - } - ComputeAsync(ctx, queue, [callback, queue]() { - queue->Unref(); - callback(); - }); - } - - protected: - virtual void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, - DoneCallback callback) = 0; -}; - -class QueueAccessOpKernel : public QueueOpKernel { - public: - explicit QueueAccessOpKernel(OpKernelConstruction* context) - : QueueOpKernel(context) { - OP_REQUIRES_OK(context, context->GetAttr("timeout_ms", &timeout_)); - // TODO(keveman): Enable timeout. - OP_REQUIRES(context, timeout_ == -1, - errors::InvalidArgument("Timeout not supported yet.")); - } - - protected: - int64 timeout_; -}; - -// Defines an EnqueueOp, the execution of which enqueues a tuple of -// tensors in the given Queue. -// -// The op has 1 + k inputs, where k is the number of components in the -// tuples stored in the given Queue: -// - Input 0: queue handle. -// - Input 1: 0th element of the tuple. -// - ... -// - Input (1+k): kth element of the tuple. -class EnqueueOp : public QueueAccessOpKernel { - public: - explicit EnqueueOp(OpKernelConstruction* context) - : QueueAccessOpKernel(context) {} - - protected: - void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, - DoneCallback callback) override { - DataTypeVector expected_inputs; - if (ctx->input_dtype(0) == DT_RESOURCE) { - expected_inputs.push_back(DT_RESOURCE); - } else { - expected_inputs.push_back(DT_STRING_REF); - } - for (DataType dt : queue->component_dtypes()) { - expected_inputs.push_back(dt); - } - OP_REQUIRES_OK_ASYNC(ctx, ctx->MatchSignature(expected_inputs, {}), - callback); - - QueueInterface::Tuple tuple; - OpInputList components; - OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("components", &components), - callback); - for (const Tensor& Tcomponent : components) { - tuple.push_back(Tcomponent); - } - - OP_REQUIRES_OK_ASYNC(ctx, queue->ValidateTuple(tuple), callback); - queue->TryEnqueue(tuple, ctx, callback); - } - - private: - TF_DISALLOW_COPY_AND_ASSIGN(EnqueueOp); -}; - REGISTER_KERNEL_BUILDER(Name("QueueEnqueue").Device(DEVICE_CPU), EnqueueOp); REGISTER_KERNEL_BUILDER(Name("QueueEnqueueV2").Device(DEVICE_CPU), EnqueueOp); -// Defines an EnqueueManyOp, the execution of which slices each -// component of a tuple of tensors along the 0th dimension, and -// enqueues tuples of slices in the given Queue. -// -// The op has 1 + k inputs, where k is the number of components in the -// tuples stored in the given Queue: -// - Input 0: queue handle. -// - Input 1: 0th element of the tuple. -// - ... -// - Input (1+k): kth element of the tuple. -// -// N.B. 
All tuple components must have the same size in the 0th -// dimension. -class EnqueueManyOp : public QueueAccessOpKernel { - public: - explicit EnqueueManyOp(OpKernelConstruction* context) - : QueueAccessOpKernel(context) {} - - protected: - void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, - DoneCallback callback) override { - DataTypeVector expected_inputs; - if (ctx->input_dtype(0) == DT_RESOURCE) { - expected_inputs.push_back(DT_RESOURCE); - } else { - expected_inputs.push_back(DT_STRING_REF); - } - for (DataType dt : queue->component_dtypes()) { - expected_inputs.push_back(dt); - } - OP_REQUIRES_OK_ASYNC(ctx, ctx->MatchSignature(expected_inputs, {}), - callback); - - QueueInterface::Tuple tuple; - OpInputList components; - OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("components", &components), - callback); - for (const Tensor& Tcomponent : components) { - tuple.push_back(Tcomponent); - } - - OP_REQUIRES_OK_ASYNC(ctx, queue->ValidateManyTuple(tuple), callback); - queue->TryEnqueueMany(tuple, ctx, callback); - } - - ~EnqueueManyOp() override {} - - private: - TF_DISALLOW_COPY_AND_ASSIGN(EnqueueManyOp); -}; - REGISTER_KERNEL_BUILDER(Name("QueueEnqueueMany").Device(DEVICE_CPU), EnqueueManyOp); REGISTER_KERNEL_BUILDER(Name("QueueEnqueueManyV2").Device(DEVICE_CPU), EnqueueManyOp); -// Defines a DequeueOp, the execution of which dequeues a tuple of -// tensors from the given Queue. -// -// The op has one input, which is the handle of the appropriate -// Queue. The op has k outputs, where k is the number of components in -// the tuples stored in the given Queue, and output i is the ith -// component of the dequeued tuple. -class DequeueOp : public QueueAccessOpKernel { - public: - explicit DequeueOp(OpKernelConstruction* context) - : QueueAccessOpKernel(context) {} - - protected: - void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, - DoneCallback callback) override { - if (ctx->input_dtype(0) == DT_RESOURCE) { - OP_REQUIRES_OK_ASYNC( - ctx, ctx->MatchSignature({DT_RESOURCE}, queue->component_dtypes()), - callback); - } else { - OP_REQUIRES_OK_ASYNC( - ctx, ctx->MatchSignature({DT_STRING_REF}, queue->component_dtypes()), - callback); - } - - queue->TryDequeue(ctx, [ctx, callback](const QueueInterface::Tuple& tuple) { - if (!ctx->status().ok()) { - callback(); - return; - } - OpOutputList output_components; - OP_REQUIRES_OK_ASYNC( - ctx, ctx->output_list("components", &output_components), callback); - for (int i = 0; i < ctx->num_outputs(); ++i) { - output_components.set(i, tuple[i]); - } - callback(); - }); - } - - ~DequeueOp() override {} - - private: - TF_DISALLOW_COPY_AND_ASSIGN(DequeueOp); -}; - REGISTER_KERNEL_BUILDER(Name("QueueDequeue").Device(DEVICE_CPU), DequeueOp); REGISTER_KERNEL_BUILDER(Name("QueueDequeueV2").Device(DEVICE_CPU), DequeueOp); -// Defines a DequeueManyOp, the execution of which concatenates the -// requested number of elements from the given Queue along the 0th -// dimension, and emits the result as a single tuple of tensors. -// -// The op has two inputs: -// - Input 0: the handle to a queue. -// - Input 1: the number of elements to dequeue. -// -// The op has k outputs, where k is the number of components in the -// tuples stored in the given Queue, and output i is the ith component -// of the dequeued tuple. 
-class DequeueManyOp : public QueueAccessOpKernel { - public: - explicit DequeueManyOp(OpKernelConstruction* context) - : QueueAccessOpKernel(context) {} - - protected: - void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, - DoneCallback callback) override { - const Tensor& Tnum_elements = ctx->input(1); - int32 num_elements = Tnum_elements.flat()(0); - - OP_REQUIRES_ASYNC(ctx, num_elements >= 0, - errors::InvalidArgument("DequeueManyOp requested ", - num_elements, " < 0 elements"), - callback); - - if (ctx->input_dtype(0) == DT_RESOURCE) { - OP_REQUIRES_OK_ASYNC(ctx, - ctx->MatchSignature({DT_RESOURCE, DT_INT32}, - queue->component_dtypes()), - callback); - } else { - OP_REQUIRES_OK_ASYNC(ctx, - ctx->MatchSignature({DT_STRING_REF, DT_INT32}, - queue->component_dtypes()), - callback); - } - - queue->TryDequeueMany( - num_elements, ctx, false /* allow_small_batch */, - [ctx, callback](const QueueInterface::Tuple& tuple) { - if (!ctx->status().ok()) { - callback(); - return; - } - OpOutputList output_components; - OP_REQUIRES_OK_ASYNC( - ctx, ctx->output_list("components", &output_components), - callback); - for (int i = 0; i < ctx->num_outputs(); ++i) { - output_components.set(i, tuple[i]); - } - callback(); - }); - } - - ~DequeueManyOp() override {} - - private: - TF_DISALLOW_COPY_AND_ASSIGN(DequeueManyOp); -}; - REGISTER_KERNEL_BUILDER(Name("QueueDequeueMany").Device(DEVICE_CPU), DequeueManyOp); REGISTER_KERNEL_BUILDER(Name("QueueDequeueManyV2").Device(DEVICE_CPU), DequeueManyOp); -// Defines a DequeueUpToOp, the execution of which concatenates the -// requested number of elements from the given Queue along the 0th -// dimension, and emits the result as a single tuple of tensors. -// -// The difference between this op and DequeueMany is the handling when -// the Queue is closed. While the DequeueMany op will return if there -// an error when there are less than num_elements elements left in the -// closed queue, this op will return between 1 and -// min(num_elements, elements_remaining_in_queue), and will not block. -// If there are no elements left, then the standard DequeueMany error -// is returned. -// -// This op only works if the underlying Queue implementation accepts -// the allow_small_batch = true parameter to TryDequeueMany. -// If it does not, an errors::Unimplemented exception is returned. -// -// The op has two inputs: -// - Input 0: the handle to a queue. -// - Input 1: the number of elements to dequeue. -// -// The op has k outputs, where k is the number of components in the -// tuples stored in the given Queue, and output i is the ith component -// of the dequeued tuple. -// -// The op has one attribute: allow_small_batch. If the Queue supports -// it, setting this to true causes the queue to return smaller -// (possibly zero length) batches when it is closed, up to however -// many elements are available when the op executes. In this case, -// the Queue does not block when closed. 
-class DequeueUpToOp : public QueueAccessOpKernel { - public: - explicit DequeueUpToOp(OpKernelConstruction* context) - : QueueAccessOpKernel(context) {} - - protected: - void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, - DoneCallback callback) override { - const Tensor& Tnum_elements = ctx->input(1); - int32 num_elements = Tnum_elements.flat()(0); - - OP_REQUIRES_ASYNC(ctx, num_elements >= 0, - errors::InvalidArgument("DequeueUpToOp requested ", - num_elements, " < 0 elements"), - callback); - - if (ctx->input_dtype(0) == DT_RESOURCE) { - OP_REQUIRES_OK_ASYNC(ctx, - ctx->MatchSignature({DT_RESOURCE, DT_INT32}, - queue->component_dtypes()), - callback); - } else { - OP_REQUIRES_OK_ASYNC(ctx, - ctx->MatchSignature({DT_STRING_REF, DT_INT32}, - queue->component_dtypes()), - callback); - } - - queue->TryDequeueMany( - num_elements, ctx, true /* allow_small_batch */, - [ctx, callback](const QueueInterface::Tuple& tuple) { - if (!ctx->status().ok()) { - callback(); - return; - } - OpOutputList output_components; - OP_REQUIRES_OK_ASYNC( - ctx, ctx->output_list("components", &output_components), - callback); - for (int i = 0; i < ctx->num_outputs(); ++i) { - output_components.set(i, tuple[i]); - } - callback(); - }); - } - - ~DequeueUpToOp() override {} - - private: - TF_DISALLOW_COPY_AND_ASSIGN(DequeueUpToOp); -}; - REGISTER_KERNEL_BUILDER(Name("QueueDequeueUpTo").Device(DEVICE_CPU), DequeueUpToOp); REGISTER_KERNEL_BUILDER(Name("QueueDequeueUpToV2").Device(DEVICE_CPU), DequeueUpToOp); -// Defines a QueueCloseOp, which closes the given Queue. Closing a -// Queue signals that no more elements will be enqueued in it. -// -// The op has one input, which is the handle of the appropriate Queue. -class QueueCloseOp : public QueueOpKernel { - public: - explicit QueueCloseOp(OpKernelConstruction* context) - : QueueOpKernel(context) { - OP_REQUIRES_OK(context, context->GetAttr("cancel_pending_enqueues", - &cancel_pending_enqueues_)); - } - - protected: - void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, - DoneCallback callback) override { - queue->Close(ctx, cancel_pending_enqueues_, callback); - } - - private: - bool cancel_pending_enqueues_; - TF_DISALLOW_COPY_AND_ASSIGN(QueueCloseOp); -}; - REGISTER_KERNEL_BUILDER(Name("QueueClose").Device(DEVICE_CPU), QueueCloseOp); REGISTER_KERNEL_BUILDER(Name("QueueCloseV2").Device(DEVICE_CPU), QueueCloseOp); -// Defines a QueueSizeOp, which computes the number of elements in the -// given Queue, and emits it as an output tensor. -// -// The op has one input, which is the handle of the appropriate Queue; -// and one output, which is a single-element tensor containing the current -// size of that Queue. 
-class QueueSizeOp : public QueueOpKernel { - public: - explicit QueueSizeOp(OpKernelConstruction* context) - : QueueOpKernel(context) {} - - protected: - void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, - DoneCallback callback) override { - Tensor* Tqueue_size = nullptr; - OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &Tqueue_size)); - Tqueue_size->flat().setConstant(queue->size()); - callback(); - } - - private: - TF_DISALLOW_COPY_AND_ASSIGN(QueueSizeOp); -}; - REGISTER_KERNEL_BUILDER(Name("QueueSize").Device(DEVICE_CPU), QueueSizeOp); REGISTER_KERNEL_BUILDER(Name("QueueSizeV2").Device(DEVICE_CPU), QueueSizeOp); -class QueueIsClosedOp : public QueueOpKernel { - public: - explicit QueueIsClosedOp(OpKernelConstruction* context) - : QueueOpKernel(context) {} - - protected: - void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, - DoneCallback callback) override { - Tensor* Tqueue_is_closed = nullptr; - OP_REQUIRES_OK(ctx, - ctx->allocate_output(0, TensorShape({}), &Tqueue_is_closed)); - Tqueue_is_closed->flat().setConstant(queue->is_closed()); - callback(); - } - - private: - TF_DISALLOW_COPY_AND_ASSIGN(QueueIsClosedOp); -}; - REGISTER_KERNEL_BUILDER(Name("QueueIsClosed").Device(DEVICE_CPU), QueueIsClosedOp); REGISTER_KERNEL_BUILDER(Name("QueueIsClosedV2").Device(DEVICE_CPU),
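
For reference, a minimal usage sketch of what this change enables, written against the TF 1.x Python API used by fifo_queue_test.py above. It is not part of the patch; the XLA_CPU device string and session setup are illustrative assumptions, and it uses only the queue ops that xla_device_ops.h registers for XLA devices (FIFOQueueV2, QueueEnqueueV2, QueueDequeueV2, QueueSizeV2, QueueCloseV2, QueueIsClosedV2).

# Illustrative sketch only; assumes a TF 1.x build with XLA enabled and an
# XLA_CPU device available. Mirrors the single-element enqueue/dequeue usage
# exercised by the test above.
import tensorflow as tf

with tf.Graph().as_default():
  with tf.device("/device:XLA_CPU:0"):
    # The queue handle stays in host memory (HostMemory("handle") in the
    # registrations above); the queued element tensors live on the device.
    q = tf.FIFOQueue(capacity=10, dtypes=[tf.float32], shapes=[()])
    enqueue = q.enqueue((10.0,))
    dequeue = q.dequeue()
    size = q.size()
    close = q.close()
    # q.enqueue_many / q.dequeue_many / q.dequeue_up_to are deliberately not
    # registered for XLA devices (see the TODO in xla_device_ops.h) and would
    # fail to find a kernel on this device.

  with tf.Session() as sess:
    sess.run(enqueue)
    print(sess.run(size))     # -> 1
    print(sess.run(dequeue))  # -> 10.0
    sess.run(close)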