[TF:XLA] Add partial implementation of tf.FIFOQueue for XLA devices (e.g., TPU).
The idea is to have a host-side queue of device tensors. The dequeue_many, enqueue_many, and dequeue_up_to operators are not yet implemented because they require splitting or concatenating tensors, which would require calling into a compiled XLA computation. Refactor the queue operator implementations into libraries separate from the kernel registrations. Add support for ResourceOpKernels that are placed on non-CPU devices. Add support for allocating host-memory tensors during OpKernel construction.
PiperOrigin-RevId: 202590292
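A minimal usage sketch, mirroring the patterns in the new fifo_queue_test.py added in this change (it assumes an xla_test.XLATestCase test method, where test_session() and test_scope() place the ops on the XLA device):

    # Sketch only; see fifo_queue_test.py below for the full tests.
    from tensorflow.python.framework import dtypes as dtypes_lib
    from tensorflow.python.ops import data_flow_ops

    with self.test_session(), self.test_scope():
      q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)  # host-side queue of device tensors
      q.enqueue((10.0,)).run()                  # QueueEnqueueV2, placed on the XLA device
      self.assertEqual(1, q.size().eval())      # QueueSizeV2; handle and size live in host memory
      self.assertAllEqual(10.0, q.dequeue().eval())  # QueueDequeueV2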
This commit is contained in:
parent f04400f18f
commit 5083915489
@@ -176,9 +176,11 @@ cc_library(
"//tensorflow/core/kernels:cast_op",
"//tensorflow/core/kernels:constant_op",
"//tensorflow/core/kernels:control_flow_ops",
"//tensorflow/core/kernels:fifo_queue",
"//tensorflow/core/kernels:identity_n_op",
"//tensorflow/core/kernels:identity_op",
"//tensorflow/core/kernels:no_op",
"//tensorflow/core/kernels:queue_op",
"//tensorflow/core/kernels:resource_variable_ops",
"//tensorflow/core/kernels:sendrecv_ops",
"//tensorflow/core/kernels:shape_ops",
@@ -23,9 +23,11 @@ limitations under the License.
#include "tensorflow/core/kernels/cast_op.h"
#include "tensorflow/core/kernels/constant_op.h"
#include "tensorflow/core/kernels/control_flow_ops.h"
#include "tensorflow/core/kernels/fifo_queue.h"
#include "tensorflow/core/kernels/identity_n_op.h"
#include "tensorflow/core/kernels/identity_op.h"
#include "tensorflow/core/kernels/no_op.h"
#include "tensorflow/core/kernels/queue_op.h"
#include "tensorflow/core/kernels/resource_variable_ops.h"
#include "tensorflow/core/kernels/sendrecv_ops.h"
#include "tensorflow/core/kernels/shape_ops.h"
@@ -145,7 +147,32 @@ class XlaAssignVariableOp : public AsyncOpKernel {
.Device(DEVICE) \
.HostMemory("input") \
.HostMemory("output"), \
LoopCondOp);
LoopCondOp); \
\
REGISTER_KERNEL_BUILDER( \
Name("QueueEnqueueV2").Device(DEVICE).HostMemory("handle"), EnqueueOp); \
REGISTER_KERNEL_BUILDER( \
Name("QueueDequeueV2").Device(DEVICE).HostMemory("handle"), DequeueOp); \
REGISTER_KERNEL_BUILDER( \
Name("QueueCloseV2").Device(DEVICE).HostMemory("handle"), QueueCloseOp); \
REGISTER_KERNEL_BUILDER(Name("QueueSizeV2") \
.Device(DEVICE) \
.HostMemory("size") \
.HostMemory("handle"), \
QueueSizeOp); \
REGISTER_KERNEL_BUILDER( \
Name("QueueIsClosedV2").Device(DEVICE).HostMemory("handle"), \
QueueIsClosedOp); \
\
REGISTER_KERNEL_BUILDER( \
Name("FIFOQueueV2").Device(DEVICE).HostMemory("handle"), FIFOQueueOp);

// TODO(phawkins): currently we do not register the QueueEnqueueMany,
// QueueDequeueMany, or QueueDequeueUpTo kernels because they attempt to read
// and write the tensors they access in order to concatenate them into a batch.
// We would need either to call out to an XLA computation to perform the
// concatenation, or we would need to refactor those kernels so the splitting
// or merging is done in a separate operator that can be compiled.

} // namespace tensorflow
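Because QueueEnqueueMany, QueueDequeueMany, and QueueDequeueUpTo are not registered for XLA devices, callers that need batched behaviour can fall back to per-element enqueue/dequeue for now; a hedged sketch, not part of this change, assuming a queue q and a session sess as in the tests below:

    # Per-element fallback while enqueue_many/dequeue_many are unavailable on XLA devices.
    values = [10.0, 20.0, 30.0]
    for v in values:
      sess.run(q.enqueue((v,)))                      # one QueueEnqueueV2 per element
    batch = [sess.run(q.dequeue()) for _ in values]  # one QueueDequeueV2 per element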
@@ -371,6 +371,20 @@ tf_xla_py_test(
],
)

tf_xla_py_test(
name = "fifo_queue_test",
size = "medium",
srcs = ["fifo_queue_test.py"],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:data_flow_ops",
"//tensorflow/python:extra_py_tests_deps",
"//tensorflow/python:framework",
"//tensorflow/python:platform_test",
],
)

tf_xla_py_test(
name = "fft_test",
size = "medium",
tensorflow/compiler/tests/fifo_queue_test.py (new file, 201 lines added)
@@ -0,0 +1,201 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.data_flow_ops.FIFOQueue."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time

from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes as dtypes_lib
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.platform import test


class FIFOQueueTest(xla_test.XLATestCase):

  def testEnqueue(self):
    with self.test_session(), self.test_scope():
      q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
      enqueue_op = q.enqueue((10.0,))
      enqueue_op.run()

  def testEnqueueWithShape(self):
    with self.test_session(), self.test_scope():
      q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, shapes=(3, 2))
      enqueue_correct_op = q.enqueue(([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],))
      enqueue_correct_op.run()
      with self.assertRaises(ValueError):
        q.enqueue(([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],))
      self.assertEqual(1, q.size().eval())

  def testMultipleDequeues(self):
    with self.test_session(), self.test_scope():
      q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()])
      self.evaluate(q.enqueue([1]))
      self.evaluate(q.enqueue([2]))
      self.evaluate(q.enqueue([3]))
      a, b, c = self.evaluate([q.dequeue(), q.dequeue(), q.dequeue()])
      self.assertAllEqual(set([1, 2, 3]), set([a, b, c]))

  def testQueuesDontShare(self):
    with self.test_session(), self.test_scope():
      q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()])
      self.evaluate(q.enqueue(1))
      q2 = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()])
      self.evaluate(q2.enqueue(2))
      self.assertAllEqual(self.evaluate(q2.dequeue()), 2)
      self.assertAllEqual(self.evaluate(q.dequeue()), 1)

  def testEnqueueDictWithoutNames(self):
    with self.test_session(), self.test_scope():
      q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
      with self.assertRaisesRegexp(ValueError, "must have names"):
        q.enqueue({"a": 12.0})

  def testParallelEnqueue(self):
    with self.test_session() as sess, self.test_scope():
      q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
      elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
      enqueue_ops = [q.enqueue((x,)) for x in elems]
      dequeued_t = q.dequeue()

      # Run one producer thread for each element in elems.
      def enqueue(enqueue_op):
        sess.run(enqueue_op)

      threads = [
          self.checkedThread(target=enqueue, args=(e,)) for e in enqueue_ops
      ]
      for thread in threads:
        thread.start()
      for thread in threads:
        thread.join()

      # Dequeue every element using a single thread.
      results = []
      for _ in xrange(len(elems)):
        results.append(dequeued_t.eval())
      self.assertItemsEqual(elems, results)

  def testParallelDequeue(self):
    with self.test_session() as sess, self.test_scope():
      q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
      elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
      enqueue_ops = [q.enqueue((x,)) for x in elems]
      dequeued_t = q.dequeue()

      # Enqueue every element using a single thread.
      for enqueue_op in enqueue_ops:
        enqueue_op.run()

      # Run one consumer thread for each element in elems.
      results = []

      def dequeue():
        results.append(sess.run(dequeued_t))

      threads = [self.checkedThread(target=dequeue) for _ in enqueue_ops]
      for thread in threads:
        thread.start()
      for thread in threads:
        thread.join()
      self.assertItemsEqual(elems, results)

  def testDequeue(self):
    with self.test_session(), self.test_scope():
      q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
      elems = [10.0, 20.0, 30.0]
      enqueue_ops = [q.enqueue((x,)) for x in elems]
      dequeued_t = q.dequeue()

      for enqueue_op in enqueue_ops:
        enqueue_op.run()

      for i in xrange(len(elems)):
        vals = dequeued_t.eval()
        self.assertEqual([elems[i]], vals)

  def testEnqueueAndBlockingDequeue(self):
    with self.test_session() as sess, self.test_scope():
      q = data_flow_ops.FIFOQueue(3, dtypes_lib.float32)
      elems = [10.0, 20.0, 30.0]
      enqueue_ops = [q.enqueue((x,)) for x in elems]
      dequeued_t = q.dequeue()

      def enqueue():
        # The enqueue_ops should run after the dequeue op has blocked.
        # TODO(mrry): Figure out how to do this without sleeping.
        time.sleep(0.1)
        for enqueue_op in enqueue_ops:
          sess.run(enqueue_op)

      results = []

      def dequeue():
        for _ in xrange(len(elems)):
          results.append(sess.run(dequeued_t))

      enqueue_thread = self.checkedThread(target=enqueue)
      dequeue_thread = self.checkedThread(target=dequeue)
      enqueue_thread.start()
      dequeue_thread.start()
      enqueue_thread.join()
      dequeue_thread.join()

      for elem, result in zip(elems, results):
        self.assertEqual([elem], result)

  def testMultiEnqueueAndDequeue(self):
    with self.test_session() as sess, self.test_scope():
      q = data_flow_ops.FIFOQueue(10, (dtypes_lib.int32, dtypes_lib.float32))
      elems = [(5, 10.0), (10, 20.0), (15, 30.0)]
      enqueue_ops = [q.enqueue((x, y)) for x, y in elems]
      dequeued_t = q.dequeue()

      for enqueue_op in enqueue_ops:
        enqueue_op.run()

      for i in xrange(len(elems)):
        x_val, y_val = sess.run(dequeued_t)
        x, y = elems[i]
        self.assertEqual([x], x_val)
        self.assertEqual([y], y_val)

  def testQueueSizeEmpty(self):
    with self.test_session(), self.test_scope():
      q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
      self.assertEqual([0], q.size().eval())

  def testQueueSizeAfterEnqueueAndDequeue(self):
    with self.test_session(), self.test_scope():
      q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
      enqueue_op = q.enqueue((10.0,))
      dequeued_t = q.dequeue()
      size = q.size()
      self.assertEqual([], size.get_shape())

      enqueue_op.run()
      self.assertEqual(1, size.eval())
      dequeued_t.op.run()
      self.assertEqual(0, size.eval())


if __name__ == "__main__":
  test.main()
@@ -92,6 +92,7 @@ tensorflow/core/kernels/reduction_ops_common.cc
tensorflow/core/kernels/reduction_ops_any.cc
tensorflow/core/kernels/reduction_ops_all.cc
tensorflow/core/kernels/roll_op.cc
tensorflow/core/kernels/queue_op.cc
tensorflow/core/kernels/queue_ops.cc
tensorflow/core/kernels/queue_base.cc
tensorflow/core/kernels/pooling_ops_common.cc
@@ -43,9 +43,15 @@ template <typename T>
class ResourceOpKernel : public OpKernel {
public:
explicit ResourceOpKernel(OpKernelConstruction* context) : OpKernel(context) {
OP_REQUIRES_OK(context,
context->allocate_persistent(DT_STRING, TensorShape({2}),
&handle_, nullptr));
has_resource_type_ = (context->output_type(0) == DT_RESOURCE);
if (!has_resource_type_) {
// The resource variant of the op may be placed on non-CPU devices, but
// this allocation is always on the host. Fortunately we don't need it in
// the resource case.
OP_REQUIRES_OK(context,
context->allocate_persistent(DT_STRING, TensorShape({2}),
&handle_, nullptr));
}
}

// The resource is deleted from the resource manager only when it is private
@@ -89,12 +95,14 @@ class ResourceOpKernel : public OpKernel {
return;
}

auto h = handle_.AccessTensor(context)->template flat<string>();
h(0) = cinfo_.container();
h(1) = cinfo_.name();
if (!has_resource_type_) {
auto h = handle_.AccessTensor(context)->template flat<string>();
h(0) = cinfo_.container();
h(1) = cinfo_.name();
}
resource_ = resource;
}
if (context->expected_output_dtype(0) == DT_RESOURCE) {
if (has_resource_type_) {
OP_REQUIRES_OK(context, MakeResourceHandleToOutput(
context, 0, cinfo_.container(), cinfo_.name(),
MakeTypeIndex<T>()));
@@ -122,6 +130,9 @@ class ResourceOpKernel : public OpKernel {
virtual Status VerifyResource(T* resource) { return Status::OK(); }

PersistentTensor handle_ GUARDED_BY(mu_);

// Is the output of the operator of type DT_RESOURCE?
bool has_resource_type_;
};
} // namespace tensorflow
@@ -368,6 +368,7 @@ cc_library(

cc_library(
name = "queue_op",
srcs = ["queue_op.cc"],
hdrs = ["queue_op.h"],
deps = [
":queue_base",
@@ -1885,9 +1886,10 @@ cc_library(
name = "fifo_queue",
srcs = ["fifo_queue.cc"],
hdrs = ["fifo_queue.h"],
visibility = ["//visibility:private"],
visibility = [":friends"],
deps = [
":queue_base",
":queue_op",
":typed_queue",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@@ -5076,6 +5078,7 @@ filegroup(
"padding_fifo_queue.cc",
"padding_fifo_queue_op.cc",
"queue_base.cc",
"queue_op.cc",
"queue_ops.cc",
"random_op.cc",
"reduction_ops_all.cc",
@@ -366,4 +366,19 @@ Status FIFOQueue::MatchesNodeDef(const NodeDef& node_def) {
return Status::OK();
}

// Defines a FIFOQueueOp, which produces a Queue (specifically, one
// backed by FIFOQueue) that persists across different graph
// executions, and sessions. Running this op produces a single-element
// tensor of handles to Queues in the corresponding device.
FIFOQueueOp::FIFOQueueOp(OpKernelConstruction* context)
: TypedQueueOp(context) {
OP_REQUIRES_OK(context, context->GetAttr("shapes", &component_shapes_));
}

Status FIFOQueueOp::CreateResource(QueueInterface** ret) {
FIFOQueue* queue = new FIFOQueue(capacity_, component_types_,
component_shapes_, cinfo_.name());
return CreateTypedQueue(queue, ret);
}

} // namespace tensorflow
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_KERNELS_FIFO_QUEUE_H_
#define TENSORFLOW_KERNELS_FIFO_QUEUE_H_
#ifndef TENSORFLOW_CORE_KERNELS_FIFO_QUEUE_H_
#define TENSORFLOW_CORE_KERNELS_FIFO_QUEUE_H_

#include <deque>
#include <vector>
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/queue_op.h"
#include "tensorflow/core/kernels/typed_queue.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
@@ -69,6 +70,22 @@ class FIFOQueue : public TypedQueue<std::deque<PersistentTensor> > {
TF_DISALLOW_COPY_AND_ASSIGN(FIFOQueue);
};

// Defines a FIFOQueueOp, which produces a Queue (specifically, one
// backed by FIFOQueue) that persists across different graph
// executions, and sessions. Running this op produces a single-element
// tensor of handles to Queues in the corresponding device.
class FIFOQueueOp : public TypedQueueOp {
public:
explicit FIFOQueueOp(OpKernelConstruction* context);

private:
Status CreateResource(QueueInterface** ret) override
EXCLUSIVE_LOCKS_REQUIRED(mu_);

std::vector<TensorShape> component_shapes_;
TF_DISALLOW_COPY_AND_ASSIGN(FIFOQueueOp);
};

} // namespace tensorflow

#endif // TENSORFLOW_KERNELS_FIFO_QUEUE_H_
#endif // TENSORFLOW_CORE_KERNELS_FIFO_QUEUE_H_
@@ -13,50 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/data_flow_ops.cc.

#include <deque>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/fifo_queue.h"
#include "tensorflow/core/kernels/queue_base.h"
#include "tensorflow/core/kernels/queue_op.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

// Defines a FIFOQueueOp, which produces a Queue (specifically, one
// backed by FIFOQueue) that persists across different graph
// executions, and sessions. Running this op produces a single-element
// tensor of handles to Queues in the corresponding device.
class FIFOQueueOp : public TypedQueueOp {
public:
explicit FIFOQueueOp(OpKernelConstruction* context) : TypedQueueOp(context) {
OP_REQUIRES_OK(context, context->GetAttr("shapes", &component_shapes_));
}

private:
Status CreateResource(QueueInterface** ret) override
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
FIFOQueue* queue = new FIFOQueue(capacity_, component_types_,
component_shapes_, cinfo_.name());
return CreateTypedQueue(queue, ret);
}

std::vector<TensorShape> component_shapes_;
TF_DISALLOW_COPY_AND_ASSIGN(FIFOQueueOp);
};

REGISTER_KERNEL_BUILDER(Name("FIFOQueue").Device(DEVICE_CPU), FIFOQueueOp);
REGISTER_KERNEL_BUILDER(Name("FIFOQueueV2").Device(DEVICE_CPU), FIFOQueueOp);
tensorflow/core/kernels/queue_op.cc (new file, 367 lines added)
@@ -0,0 +1,367 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/queue_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/queue_interface.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

QueueOp::QueueOp(OpKernelConstruction* context) : ResourceOpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("capacity", &capacity_));
if (capacity_ < 0) {
capacity_ = QueueBase::kUnbounded;
}
OP_REQUIRES_OK(context,
context->GetAttr("component_types", &component_types_));
}

void QueueOp::Compute(OpKernelContext* context) {
ResourceOpKernel<QueueInterface>::Compute(context);
mutex_lock l(mu_);
if (resource_ && context->track_allocations()) {
context->record_persistent_memory_allocation(resource_->MemoryUsed());
}
}

Status QueueOp::VerifyResource(QueueInterface* queue) {
return queue->MatchesNodeDef(def());
}


QueueOpKernel::QueueOpKernel(OpKernelConstruction* context)
: AsyncOpKernel(context) {}

void QueueOpKernel::ComputeAsync(OpKernelContext* ctx, DoneCallback callback) {
QueueInterface* queue;
if (ctx->input_dtype(0) == DT_RESOURCE) {
OP_REQUIRES_OK_ASYNC(
ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &queue), callback);
} else {
OP_REQUIRES_OK_ASYNC(ctx, GetResourceFromContext(ctx, "handle", &queue),
callback);
}
ComputeAsync(ctx, queue, [callback, queue]() {
queue->Unref();
callback();
});
}

QueueAccessOpKernel::QueueAccessOpKernel(OpKernelConstruction* context)
: QueueOpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("timeout_ms", &timeout_));
// TODO(keveman): Enable timeout.
OP_REQUIRES(context, timeout_ == -1,
errors::InvalidArgument("Timeout not supported yet."));
}

// Defines an EnqueueOp, the execution of which enqueues a tuple of
// tensors in the given Queue.
//
// The op has 1 + k inputs, where k is the number of components in the
// tuples stored in the given Queue:
// - Input 0: queue handle.
// - Input 1: 0th element of the tuple.
// - ...
// - Input (1+k): kth element of the tuple.
EnqueueOp::EnqueueOp(OpKernelConstruction* context)
: QueueAccessOpKernel(context) {}

void EnqueueOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
DoneCallback callback) {
DataTypeVector expected_inputs;
if (ctx->input_dtype(0) == DT_RESOURCE) {
expected_inputs.push_back(DT_RESOURCE);
} else {
expected_inputs.push_back(DT_STRING_REF);
}
for (DataType dt : queue->component_dtypes()) {
expected_inputs.push_back(dt);
}
OP_REQUIRES_OK_ASYNC(ctx, ctx->MatchSignature(expected_inputs, {}), callback);

QueueInterface::Tuple tuple;
OpInputList components;
OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("components", &components),
callback);
for (const Tensor& Tcomponent : components) {
tuple.push_back(Tcomponent);
}

OP_REQUIRES_OK_ASYNC(ctx, queue->ValidateTuple(tuple), callback);
queue->TryEnqueue(tuple, ctx, callback);
}

// Defines an EnqueueManyOp, the execution of which slices each
// component of a tuple of tensors along the 0th dimension, and
// enqueues tuples of slices in the given Queue.
//
// The op has 1 + k inputs, where k is the number of components in the
// tuples stored in the given Queue:
// - Input 0: queue handle.
// - Input 1: 0th element of the tuple.
// - ...
// - Input (1+k): kth element of the tuple.
//
// N.B. All tuple components must have the same size in the 0th
// dimension.
EnqueueManyOp::EnqueueManyOp(OpKernelConstruction* context)
: QueueAccessOpKernel(context) {}

void EnqueueManyOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
DoneCallback callback) {
DataTypeVector expected_inputs;
if (ctx->input_dtype(0) == DT_RESOURCE) {
expected_inputs.push_back(DT_RESOURCE);
} else {
expected_inputs.push_back(DT_STRING_REF);
}
for (DataType dt : queue->component_dtypes()) {
expected_inputs.push_back(dt);
}
OP_REQUIRES_OK_ASYNC(ctx, ctx->MatchSignature(expected_inputs, {}), callback);

QueueInterface::Tuple tuple;
OpInputList components;
OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("components", &components),
callback);
for (const Tensor& Tcomponent : components) {
tuple.push_back(Tcomponent);
}

OP_REQUIRES_OK_ASYNC(ctx, queue->ValidateManyTuple(tuple), callback);
queue->TryEnqueueMany(tuple, ctx, callback);
}

EnqueueManyOp::~EnqueueManyOp() = default;

// Defines a DequeueOp, the execution of which dequeues a tuple of
// tensors from the given Queue.
//
// The op has one input, which is the handle of the appropriate
// Queue. The op has k outputs, where k is the number of components in
// the tuples stored in the given Queue, and output i is the ith
// component of the dequeued tuple.
DequeueOp::DequeueOp(OpKernelConstruction* context)
: QueueAccessOpKernel(context) {}

void DequeueOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
DoneCallback callback) {
if (ctx->input_dtype(0) == DT_RESOURCE) {
OP_REQUIRES_OK_ASYNC(
ctx, ctx->MatchSignature({DT_RESOURCE}, queue->component_dtypes()),
callback);
} else {
OP_REQUIRES_OK_ASYNC(
ctx, ctx->MatchSignature({DT_STRING_REF}, queue->component_dtypes()),
callback);
}

queue->TryDequeue(ctx, [ctx, callback](const QueueInterface::Tuple& tuple) {
if (!ctx->status().ok()) {
callback();
return;
}
OpOutputList output_components;
OP_REQUIRES_OK_ASYNC(
ctx, ctx->output_list("components", &output_components), callback);
for (int i = 0; i < ctx->num_outputs(); ++i) {
output_components.set(i, tuple[i]);
}
callback();
});
}

DequeueOp::~DequeueOp() = default;

// Defines a DequeueManyOp, the execution of which concatenates the
// requested number of elements from the given Queue along the 0th
// dimension, and emits the result as a single tuple of tensors.
//
// The op has two inputs:
// - Input 0: the handle to a queue.
// - Input 1: the number of elements to dequeue.
//
// The op has k outputs, where k is the number of components in the
// tuples stored in the given Queue, and output i is the ith component
// of the dequeued tuple.
DequeueManyOp::DequeueManyOp(OpKernelConstruction* context)
: QueueAccessOpKernel(context) {}

void DequeueManyOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
DoneCallback callback) {
const Tensor& Tnum_elements = ctx->input(1);
int32 num_elements = Tnum_elements.flat<int32>()(0);

OP_REQUIRES_ASYNC(ctx, num_elements >= 0,
errors::InvalidArgument("DequeueManyOp requested ",
num_elements, " < 0 elements"),
callback);

if (ctx->input_dtype(0) == DT_RESOURCE) {
OP_REQUIRES_OK_ASYNC(
ctx,
ctx->MatchSignature({DT_RESOURCE, DT_INT32}, queue->component_dtypes()),
callback);
} else {
OP_REQUIRES_OK_ASYNC(ctx,
ctx->MatchSignature({DT_STRING_REF, DT_INT32},
queue->component_dtypes()),
callback);
}

queue->TryDequeueMany(
num_elements, ctx, false /* allow_small_batch */,
[ctx, callback](const QueueInterface::Tuple& tuple) {
if (!ctx->status().ok()) {
callback();
return;
}
OpOutputList output_components;
OP_REQUIRES_OK_ASYNC(
ctx, ctx->output_list("components", &output_components), callback);
for (int i = 0; i < ctx->num_outputs(); ++i) {
output_components.set(i, tuple[i]);
}
callback();
});
}

DequeueManyOp::~DequeueManyOp() = default;

// Defines a DequeueUpToOp, the execution of which concatenates the
// requested number of elements from the given Queue along the 0th
// dimension, and emits the result as a single tuple of tensors.
//
// The difference between this op and DequeueMany is the handling when
// the Queue is closed. While the DequeueMany op will return if there
// an error when there are less than num_elements elements left in the
// closed queue, this op will return between 1 and
// min(num_elements, elements_remaining_in_queue), and will not block.
// If there are no elements left, then the standard DequeueMany error
// is returned.
//
// This op only works if the underlying Queue implementation accepts
// the allow_small_batch = true parameter to TryDequeueMany.
// If it does not, an errors::Unimplemented exception is returned.
//
// The op has two inputs:
// - Input 0: the handle to a queue.
// - Input 1: the number of elements to dequeue.
//
// The op has k outputs, where k is the number of components in the
// tuples stored in the given Queue, and output i is the ith component
// of the dequeued tuple.
//
// The op has one attribute: allow_small_batch. If the Queue supports
// it, setting this to true causes the queue to return smaller
// (possibly zero length) batches when it is closed, up to however
// many elements are available when the op executes. In this case,
// the Queue does not block when closed.
DequeueUpToOp::DequeueUpToOp(OpKernelConstruction* context)
: QueueAccessOpKernel(context) {}

void DequeueUpToOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
DoneCallback callback) {
const Tensor& Tnum_elements = ctx->input(1);
int32 num_elements = Tnum_elements.flat<int32>()(0);

OP_REQUIRES_ASYNC(ctx, num_elements >= 0,
errors::InvalidArgument("DequeueUpToOp requested ",
num_elements, " < 0 elements"),
callback);

if (ctx->input_dtype(0) == DT_RESOURCE) {
OP_REQUIRES_OK_ASYNC(
ctx,
ctx->MatchSignature({DT_RESOURCE, DT_INT32}, queue->component_dtypes()),
callback);
} else {
OP_REQUIRES_OK_ASYNC(ctx,
ctx->MatchSignature({DT_STRING_REF, DT_INT32},
queue->component_dtypes()),
callback);
}

queue->TryDequeueMany(
num_elements, ctx, true /* allow_small_batch */,
[ctx, callback](const QueueInterface::Tuple& tuple) {
if (!ctx->status().ok()) {
callback();
return;
}
OpOutputList output_components;
OP_REQUIRES_OK_ASYNC(
ctx, ctx->output_list("components", &output_components), callback);
for (int i = 0; i < ctx->num_outputs(); ++i) {
output_components.set(i, tuple[i]);
}
callback();
});
}

DequeueUpToOp::~DequeueUpToOp() = default;

// Defines a QueueCloseOp, which closes the given Queue. Closing a
// Queue signals that no more elements will be enqueued in it.
//
// The op has one input, which is the handle of the appropriate Queue.
QueueCloseOp::QueueCloseOp(OpKernelConstruction* context)
: QueueOpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("cancel_pending_enqueues",
&cancel_pending_enqueues_));
}

void QueueCloseOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
DoneCallback callback) {
queue->Close(ctx, cancel_pending_enqueues_, callback);
}

// Defines a QueueSizeOp, which computes the number of elements in the
// given Queue, and emits it as an output tensor.
//
// The op has one input, which is the handle of the appropriate Queue;
// and one output, which is a single-element tensor containing the current
// size of that Queue.
QueueSizeOp::QueueSizeOp(OpKernelConstruction* context)
: QueueOpKernel(context) {}

void QueueSizeOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
DoneCallback callback) {
Tensor* Tqueue_size = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &Tqueue_size));
Tqueue_size->flat<int32>().setConstant(queue->size());
callback();
}

QueueIsClosedOp::QueueIsClosedOp(OpKernelConstruction* context)
: QueueOpKernel(context) {}

void QueueIsClosedOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
DoneCallback callback) {
Tensor* Tqueue_is_closed = nullptr;
OP_REQUIRES_OK(ctx,
ctx->allocate_output(0, TensorShape({}), &Tqueue_is_closed));
Tqueue_is_closed->flat<bool>().setConstant(queue->is_closed());
callback();
}

} // namespace tensorflow
@@ -13,12 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_KERNELS_QUEUE_OP_H_
#define TENSORFLOW_KERNELS_QUEUE_OP_H_
#ifndef TENSORFLOW_CORE_KERNELS_QUEUE_OP_H_
#define TENSORFLOW_CORE_KERNELS_QUEUE_OP_H_

#include <deque>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/queue_interface.h"
#include "tensorflow/core/framework/resource_op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
@@ -32,22 +33,9 @@ namespace tensorflow {
// Defines a QueueOp, an abstract class for Queue construction ops.
class QueueOp : public ResourceOpKernel<QueueInterface> {
public:
QueueOp(OpKernelConstruction* context) : ResourceOpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("capacity", &capacity_));
if (capacity_ < 0) {
capacity_ = QueueBase::kUnbounded;
}
OP_REQUIRES_OK(context,
context->GetAttr("component_types", &component_types_));
}
QueueOp(OpKernelConstruction* context);

void Compute(OpKernelContext* context) override {
ResourceOpKernel<QueueInterface>::Compute(context);
mutex_lock l(mu_);
if (resource_ && context->track_allocations()) {
context->record_persistent_memory_allocation(resource_->MemoryUsed());
}
}
void Compute(OpKernelContext* context) override;

protected:
// Variables accessible by subclasses
@@ -55,9 +43,7 @@ class QueueOp : public ResourceOpKernel<QueueInterface> {
DataTypeVector component_types_;

private:
Status VerifyResource(QueueInterface* queue) override {
return queue->MatchesNodeDef(def());
}
Status VerifyResource(QueueInterface* queue) override;
};

class TypedQueueOp : public QueueOp {
@@ -75,6 +61,211 @@ class TypedQueueOp : public QueueOp {
}
};

// Queue manipulator kernels

class QueueOpKernel : public AsyncOpKernel {
public:
explicit QueueOpKernel(OpKernelConstruction* context);

void ComputeAsync(OpKernelContext* ctx, DoneCallback callback) final;

protected:
virtual void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
DoneCallback callback) = 0;
};

class QueueAccessOpKernel : public QueueOpKernel {
public:
explicit QueueAccessOpKernel(OpKernelConstruction* context);

protected:
int64 timeout_;
};

// Defines an EnqueueOp, the execution of which enqueues a tuple of
// tensors in the given Queue.
//
// The op has 1 + k inputs, where k is the number of components in the
// tuples stored in the given Queue:
// - Input 0: queue handle.
// - Input 1: 0th element of the tuple.
// - ...
// - Input (1+k): kth element of the tuple.
class EnqueueOp : public QueueAccessOpKernel {
public:
explicit EnqueueOp(OpKernelConstruction* context);

protected:
void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
DoneCallback callback) override;

private:
TF_DISALLOW_COPY_AND_ASSIGN(EnqueueOp);
};

// Defines an EnqueueManyOp, the execution of which slices each
// component of a tuple of tensors along the 0th dimension, and
// enqueues tuples of slices in the given Queue.
//
// The op has 1 + k inputs, where k is the number of components in the
// tuples stored in the given Queue:
// - Input 0: queue handle.
// - Input 1: 0th element of the tuple.
// - ...
// - Input (1+k): kth element of the tuple.
//
// N.B. All tuple components must have the same size in the 0th
// dimension.
class EnqueueManyOp : public QueueAccessOpKernel {
public:
explicit EnqueueManyOp(OpKernelConstruction* context);

protected:
void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
DoneCallback callback) override;

~EnqueueManyOp() override;

private:
TF_DISALLOW_COPY_AND_ASSIGN(EnqueueManyOp);
};

// Defines a DequeueOp, the execution of which dequeues a tuple of
// tensors from the given Queue.
//
// The op has one input, which is the handle of the appropriate
// Queue. The op has k outputs, where k is the number of components in
// the tuples stored in the given Queue, and output i is the ith
// component of the dequeued tuple.
class DequeueOp : public QueueAccessOpKernel {
public:
explicit DequeueOp(OpKernelConstruction* context);

protected:
void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
DoneCallback callback) override;

~DequeueOp() override;

private:
TF_DISALLOW_COPY_AND_ASSIGN(DequeueOp);
};

// Defines a DequeueManyOp, the execution of which concatenates the
// requested number of elements from the given Queue along the 0th
// dimension, and emits the result as a single tuple of tensors.
//
// The op has two inputs:
// - Input 0: the handle to a queue.
// - Input 1: the number of elements to dequeue.
//
// The op has k outputs, where k is the number of components in the
// tuples stored in the given Queue, and output i is the ith component
// of the dequeued tuple.
class DequeueManyOp : public QueueAccessOpKernel {
public:
explicit DequeueManyOp(OpKernelConstruction* context);

protected:
void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
DoneCallback callback) override;

~DequeueManyOp() override;

private:
TF_DISALLOW_COPY_AND_ASSIGN(DequeueManyOp);
};

// Defines a DequeueUpToOp, the execution of which concatenates the
// requested number of elements from the given Queue along the 0th
// dimension, and emits the result as a single tuple of tensors.
//
// The difference between this op and DequeueMany is the handling when
// the Queue is closed. While the DequeueMany op will return if there
// an error when there are less than num_elements elements left in the
// closed queue, this op will return between 1 and
// min(num_elements, elements_remaining_in_queue), and will not block.
// If there are no elements left, then the standard DequeueMany error
// is returned.
//
// This op only works if the underlying Queue implementation accepts
// the allow_small_batch = true parameter to TryDequeueMany.
// If it does not, an errors::Unimplemented exception is returned.
//
// The op has two inputs:
// - Input 0: the handle to a queue.
// - Input 1: the number of elements to dequeue.
//
// The op has k outputs, where k is the number of components in the
// tuples stored in the given Queue, and output i is the ith component
// of the dequeued tuple.
//
// The op has one attribute: allow_small_batch. If the Queue supports
// it, setting this to true causes the queue to return smaller
// (possibly zero length) batches when it is closed, up to however
// many elements are available when the op executes. In this case,
// the Queue does not block when closed.
class DequeueUpToOp : public QueueAccessOpKernel {
public:
explicit DequeueUpToOp(OpKernelConstruction* context);

protected:
void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
DoneCallback callback) override;

~DequeueUpToOp() override;

private:
TF_DISALLOW_COPY_AND_ASSIGN(DequeueUpToOp);
};

// Defines a QueueCloseOp, which closes the given Queue. Closing a
// Queue signals that no more elements will be enqueued in it.
//
// The op has one input, which is the handle of the appropriate Queue.
class QueueCloseOp : public QueueOpKernel {
public:
explicit QueueCloseOp(OpKernelConstruction* context);

protected:
void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
DoneCallback callback) override;

private:
bool cancel_pending_enqueues_;
TF_DISALLOW_COPY_AND_ASSIGN(QueueCloseOp);
};

// Defines a QueueSizeOp, which computes the number of elements in the
// given Queue, and emits it as an output tensor.
//
// The op has one input, which is the handle of the appropriate Queue;
// and one output, which is a single-element tensor containing the current
// size of that Queue.
class QueueSizeOp : public QueueOpKernel {
public:
explicit QueueSizeOp(OpKernelConstruction* context);

protected:
void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
DoneCallback callback) override;

private:
TF_DISALLOW_COPY_AND_ASSIGN(QueueSizeOp);
};

class QueueIsClosedOp : public QueueOpKernel {
public:
explicit QueueIsClosedOp(OpKernelConstruction* context);

protected:
void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
DoneCallback callback) override;

private:
TF_DISALLOW_COPY_AND_ASSIGN(QueueIsClosedOp);
};

} // namespace tensorflow

#endif // TENSORFLOW_KERNELS_QUEUE_OP_H_
#endif // TENSORFLOW_CORE_KERNELS_QUEUE_OP_H_
@@ -13,437 +13,44 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/data_flow_ops.cc.

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/queue_interface.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/queue_op.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

class QueueOpKernel : public AsyncOpKernel {
public:
explicit QueueOpKernel(OpKernelConstruction* context)
: AsyncOpKernel(context) {}

void ComputeAsync(OpKernelContext* ctx, DoneCallback callback) final {
QueueInterface* queue;
if (ctx->input_dtype(0) == DT_RESOURCE) {
OP_REQUIRES_OK_ASYNC(
ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &queue), callback);
} else {
OP_REQUIRES_OK_ASYNC(ctx, GetResourceFromContext(ctx, "handle", &queue),
callback);
}
ComputeAsync(ctx, queue, [callback, queue]() {
queue->Unref();
callback();
});
}

protected:
virtual void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
DoneCallback callback) = 0;
};

class QueueAccessOpKernel : public QueueOpKernel {
public:
explicit QueueAccessOpKernel(OpKernelConstruction* context)
: QueueOpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("timeout_ms", &timeout_));
// TODO(keveman): Enable timeout.
OP_REQUIRES(context, timeout_ == -1,
errors::InvalidArgument("Timeout not supported yet."));
}

protected:
int64 timeout_;
};

// Defines an EnqueueOp, the execution of which enqueues a tuple of
// tensors in the given Queue.
//
// The op has 1 + k inputs, where k is the number of components in the
// tuples stored in the given Queue:
// - Input 0: queue handle.
// - Input 1: 0th element of the tuple.
// - ...
// - Input (1+k): kth element of the tuple.
class EnqueueOp : public QueueAccessOpKernel {
public:
explicit EnqueueOp(OpKernelConstruction* context)
: QueueAccessOpKernel(context) {}

protected:
void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
DoneCallback callback) override {
DataTypeVector expected_inputs;
if (ctx->input_dtype(0) == DT_RESOURCE) {
expected_inputs.push_back(DT_RESOURCE);
} else {
expected_inputs.push_back(DT_STRING_REF);
}
for (DataType dt : queue->component_dtypes()) {
expected_inputs.push_back(dt);
}
OP_REQUIRES_OK_ASYNC(ctx, ctx->MatchSignature(expected_inputs, {}),
callback);

QueueInterface::Tuple tuple;
OpInputList components;
OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("components", &components),
callback);
for (const Tensor& Tcomponent : components) {
tuple.push_back(Tcomponent);
}

OP_REQUIRES_OK_ASYNC(ctx, queue->ValidateTuple(tuple), callback);
queue->TryEnqueue(tuple, ctx, callback);
}

private:
TF_DISALLOW_COPY_AND_ASSIGN(EnqueueOp);
};

REGISTER_KERNEL_BUILDER(Name("QueueEnqueue").Device(DEVICE_CPU), EnqueueOp);
REGISTER_KERNEL_BUILDER(Name("QueueEnqueueV2").Device(DEVICE_CPU), EnqueueOp);

// Defines an EnqueueManyOp, the execution of which slices each
// component of a tuple of tensors along the 0th dimension, and
// enqueues tuples of slices in the given Queue.
//
// The op has 1 + k inputs, where k is the number of components in the
// tuples stored in the given Queue:
// - Input 0: queue handle.
// - Input 1: 0th element of the tuple.
// - ...
// - Input (1+k): kth element of the tuple.
//
// N.B. All tuple components must have the same size in the 0th
// dimension.
class EnqueueManyOp : public QueueAccessOpKernel {
public:
explicit EnqueueManyOp(OpKernelConstruction* context)
: QueueAccessOpKernel(context) {}

protected:
void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
DoneCallback callback) override {
DataTypeVector expected_inputs;
if (ctx->input_dtype(0) == DT_RESOURCE) {
expected_inputs.push_back(DT_RESOURCE);
} else {
expected_inputs.push_back(DT_STRING_REF);
}
for (DataType dt : queue->component_dtypes()) {
expected_inputs.push_back(dt);
}
OP_REQUIRES_OK_ASYNC(ctx, ctx->MatchSignature(expected_inputs, {}),
callback);

QueueInterface::Tuple tuple;
OpInputList components;
OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("components", &components),
callback);
for (const Tensor& Tcomponent : components) {
tuple.push_back(Tcomponent);
}

OP_REQUIRES_OK_ASYNC(ctx, queue->ValidateManyTuple(tuple), callback);
queue->TryEnqueueMany(tuple, ctx, callback);
}

~EnqueueManyOp() override {}

private:
TF_DISALLOW_COPY_AND_ASSIGN(EnqueueManyOp);
};

REGISTER_KERNEL_BUILDER(Name("QueueEnqueueMany").Device(DEVICE_CPU),
EnqueueManyOp);
REGISTER_KERNEL_BUILDER(Name("QueueEnqueueManyV2").Device(DEVICE_CPU),
EnqueueManyOp);

// Defines a DequeueOp, the execution of which dequeues a tuple of
// tensors from the given Queue.
//
// The op has one input, which is the handle of the appropriate
// Queue. The op has k outputs, where k is the number of components in
// the tuples stored in the given Queue, and output i is the ith
// component of the dequeued tuple.
class DequeueOp : public QueueAccessOpKernel {
public:
explicit DequeueOp(OpKernelConstruction* context)
: QueueAccessOpKernel(context) {}

protected:
void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
DoneCallback callback) override {
if (ctx->input_dtype(0) == DT_RESOURCE) {
OP_REQUIRES_OK_ASYNC(
ctx, ctx->MatchSignature({DT_RESOURCE}, queue->component_dtypes()),
callback);
} else {
OP_REQUIRES_OK_ASYNC(
ctx, ctx->MatchSignature({DT_STRING_REF}, queue->component_dtypes()),
callback);
}

queue->TryDequeue(ctx, [ctx, callback](const QueueInterface::Tuple& tuple) {
if (!ctx->status().ok()) {
callback();
return;
}
OpOutputList output_components;
OP_REQUIRES_OK_ASYNC(
ctx, ctx->output_list("components", &output_components), callback);
for (int i = 0; i < ctx->num_outputs(); ++i) {
output_components.set(i, tuple[i]);
}
callback();
});
}

~DequeueOp() override {}

private:
TF_DISALLOW_COPY_AND_ASSIGN(DequeueOp);
};

REGISTER_KERNEL_BUILDER(Name("QueueDequeue").Device(DEVICE_CPU), DequeueOp);
REGISTER_KERNEL_BUILDER(Name("QueueDequeueV2").Device(DEVICE_CPU), DequeueOp);

// Defines a DequeueManyOp, the execution of which concatenates the
// requested number of elements from the given Queue along the 0th
// dimension, and emits the result as a single tuple of tensors.
//
// The op has two inputs:
// - Input 0: the handle to a queue.
// - Input 1: the number of elements to dequeue.
//
// The op has k outputs, where k is the number of components in the
// tuples stored in the given Queue, and output i is the ith component
// of the dequeued tuple.
class DequeueManyOp : public QueueAccessOpKernel {
public:
explicit DequeueManyOp(OpKernelConstruction* context)
: QueueAccessOpKernel(context) {}

protected:
void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
DoneCallback callback) override {
const Tensor& Tnum_elements = ctx->input(1);
int32 num_elements = Tnum_elements.flat<int32>()(0);

OP_REQUIRES_ASYNC(ctx, num_elements >= 0,
errors::InvalidArgument("DequeueManyOp requested ",
num_elements, " < 0 elements"),
callback);

if (ctx->input_dtype(0) == DT_RESOURCE) {
OP_REQUIRES_OK_ASYNC(ctx,
ctx->MatchSignature({DT_RESOURCE, DT_INT32},
queue->component_dtypes()),
callback);
} else {
OP_REQUIRES_OK_ASYNC(ctx,
ctx->MatchSignature({DT_STRING_REF, DT_INT32},
queue->component_dtypes()),
callback);
}

queue->TryDequeueMany(
num_elements, ctx, false /* allow_small_batch */,
[ctx, callback](const QueueInterface::Tuple& tuple) {
if (!ctx->status().ok()) {
callback();
return;
}
OpOutputList output_components;
OP_REQUIRES_OK_ASYNC(
ctx, ctx->output_list("components", &output_components),
callback);
for (int i = 0; i < ctx->num_outputs(); ++i) {
output_components.set(i, tuple[i]);
}
callback();
});
}

~DequeueManyOp() override {}

private:
TF_DISALLOW_COPY_AND_ASSIGN(DequeueManyOp);
};

REGISTER_KERNEL_BUILDER(Name("QueueDequeueMany").Device(DEVICE_CPU),
DequeueManyOp);
REGISTER_KERNEL_BUILDER(Name("QueueDequeueManyV2").Device(DEVICE_CPU),
DequeueManyOp);

// Defines a DequeueUpToOp, the execution of which concatenates the
// requested number of elements from the given Queue along the 0th
// dimension, and emits the result as a single tuple of tensors.
//
// The difference between this op and DequeueMany is the handling when
// the Queue is closed. While the DequeueMany op will return if there
// an error when there are less than num_elements elements left in the
// closed queue, this op will return between 1 and
// min(num_elements, elements_remaining_in_queue), and will not block.
// If there are no elements left, then the standard DequeueMany error
// is returned.
//
// This op only works if the underlying Queue implementation accepts
// the allow_small_batch = true parameter to TryDequeueMany.
// If it does not, an errors::Unimplemented exception is returned.
//
// The op has two inputs:
// - Input 0: the handle to a queue.
// - Input 1: the number of elements to dequeue.
//
// The op has k outputs, where k is the number of components in the
// tuples stored in the given Queue, and output i is the ith component
// of the dequeued tuple.
//
// The op has one attribute: allow_small_batch. If the Queue supports
// it, setting this to true causes the queue to return smaller
// (possibly zero length) batches when it is closed, up to however
// many elements are available when the op executes. In this case,
// the Queue does not block when closed.
class DequeueUpToOp : public QueueAccessOpKernel {
public:
explicit DequeueUpToOp(OpKernelConstruction* context)
: QueueAccessOpKernel(context) {}

protected:
void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
DoneCallback callback) override {
const Tensor& Tnum_elements = ctx->input(1);
int32 num_elements = Tnum_elements.flat<int32>()(0);

OP_REQUIRES_ASYNC(ctx, num_elements >= 0,
errors::InvalidArgument("DequeueUpToOp requested ",
num_elements, " < 0 elements"),
callback);

if (ctx->input_dtype(0) == DT_RESOURCE) {
OP_REQUIRES_OK_ASYNC(ctx,
ctx->MatchSignature({DT_RESOURCE, DT_INT32},
queue->component_dtypes()),
callback);
} else {
OP_REQUIRES_OK_ASYNC(ctx,
ctx->MatchSignature({DT_STRING_REF, DT_INT32},
queue->component_dtypes()),
callback);
}

queue->TryDequeueMany(
num_elements, ctx, true /* allow_small_batch */,
[ctx, callback](const QueueInterface::Tuple& tuple) {
if (!ctx->status().ok()) {
callback();
return;
}
OpOutputList output_components;
OP_REQUIRES_OK_ASYNC(
ctx, ctx->output_list("components", &output_components),
callback);
for (int i = 0; i < ctx->num_outputs(); ++i) {
output_components.set(i, tuple[i]);
}
callback();
});
}

~DequeueUpToOp() override {}

private:
TF_DISALLOW_COPY_AND_ASSIGN(DequeueUpToOp);
};

REGISTER_KERNEL_BUILDER(Name("QueueDequeueUpTo").Device(DEVICE_CPU),
DequeueUpToOp);
REGISTER_KERNEL_BUILDER(Name("QueueDequeueUpToV2").Device(DEVICE_CPU),
DequeueUpToOp);

// Defines a QueueCloseOp, which closes the given Queue. Closing a
// Queue signals that no more elements will be enqueued in it.
//
// The op has one input, which is the handle of the appropriate Queue.
class QueueCloseOp : public QueueOpKernel {
public:
explicit QueueCloseOp(OpKernelConstruction* context)
: QueueOpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("cancel_pending_enqueues",
&cancel_pending_enqueues_));
}

protected:
void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
DoneCallback callback) override {
queue->Close(ctx, cancel_pending_enqueues_, callback);
}

private:
bool cancel_pending_enqueues_;
TF_DISALLOW_COPY_AND_ASSIGN(QueueCloseOp);
};

REGISTER_KERNEL_BUILDER(Name("QueueClose").Device(DEVICE_CPU), QueueCloseOp);
REGISTER_KERNEL_BUILDER(Name("QueueCloseV2").Device(DEVICE_CPU), QueueCloseOp);

// Defines a QueueSizeOp, which computes the number of elements in the
// given Queue, and emits it as an output tensor.
//
// The op has one input, which is the handle of the appropriate Queue;
// and one output, which is a single-element tensor containing the current
// size of that Queue.
class QueueSizeOp : public QueueOpKernel {
public:
explicit QueueSizeOp(OpKernelConstruction* context)
: QueueOpKernel(context) {}

protected:
void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
DoneCallback callback) override {
Tensor* Tqueue_size = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &Tqueue_size));
Tqueue_size->flat<int32>().setConstant(queue->size());
callback();
}

private:
TF_DISALLOW_COPY_AND_ASSIGN(QueueSizeOp);
};

REGISTER_KERNEL_BUILDER(Name("QueueSize").Device(DEVICE_CPU), QueueSizeOp);
REGISTER_KERNEL_BUILDER(Name("QueueSizeV2").Device(DEVICE_CPU), QueueSizeOp);

class QueueIsClosedOp : public QueueOpKernel {
public:
explicit QueueIsClosedOp(OpKernelConstruction* context)
: QueueOpKernel(context) {}

protected:
void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
DoneCallback callback) override {
Tensor* Tqueue_is_closed = nullptr;
OP_REQUIRES_OK(ctx,
ctx->allocate_output(0, TensorShape({}), &Tqueue_is_closed));
Tqueue_is_closed->flat<bool>().setConstant(queue->is_closed());
callback();
}

private:
TF_DISALLOW_COPY_AND_ASSIGN(QueueIsClosedOp);
};

REGISTER_KERNEL_BUILDER(Name("QueueIsClosed").Device(DEVICE_CPU),
QueueIsClosedOp);
REGISTER_KERNEL_BUILDER(Name("QueueIsClosedV2").Device(DEVICE_CPU),