diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index d976f8296c6..c2245b8eae8 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -176,9 +176,11 @@ cc_library(
         "//tensorflow/core/kernels:cast_op",
         "//tensorflow/core/kernels:constant_op",
         "//tensorflow/core/kernels:control_flow_ops",
+        "//tensorflow/core/kernels:fifo_queue",
         "//tensorflow/core/kernels:identity_n_op",
         "//tensorflow/core/kernels:identity_op",
         "//tensorflow/core/kernels:no_op",
+        "//tensorflow/core/kernels:queue_op",
         "//tensorflow/core/kernels:resource_variable_ops",
         "//tensorflow/core/kernels:sendrecv_ops",
         "//tensorflow/core/kernels:shape_ops",
diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h
index 11e45d2823d..a605335a94f 100644
--- a/tensorflow/compiler/jit/xla_device_ops.h
+++ b/tensorflow/compiler/jit/xla_device_ops.h
@@ -23,9 +23,11 @@ limitations under the License.
 #include "tensorflow/core/kernels/cast_op.h"
 #include "tensorflow/core/kernels/constant_op.h"
 #include "tensorflow/core/kernels/control_flow_ops.h"
+#include "tensorflow/core/kernels/fifo_queue.h"
 #include "tensorflow/core/kernels/identity_n_op.h"
 #include "tensorflow/core/kernels/identity_op.h"
 #include "tensorflow/core/kernels/no_op.h"
+#include "tensorflow/core/kernels/queue_op.h"
 #include "tensorflow/core/kernels/resource_variable_ops.h"
 #include "tensorflow/core/kernels/sendrecv_ops.h"
 #include "tensorflow/core/kernels/shape_ops.h"
@@ -145,7 +147,32 @@ class XlaAssignVariableOp : public AsyncOpKernel {
                               .Device(DEVICE)                                  \
                               .HostMemory("input")                             \
                               .HostMemory("output"),                           \
-                          LoopCondOp);
+                          LoopCondOp);                                         \
+                                                                               \
+  REGISTER_KERNEL_BUILDER(                                                     \
+      Name("QueueEnqueueV2").Device(DEVICE).HostMemory("handle"), EnqueueOp);  \
+  REGISTER_KERNEL_BUILDER(                                                     \
+      Name("QueueDequeueV2").Device(DEVICE).HostMemory("handle"), DequeueOp);  \
+  REGISTER_KERNEL_BUILDER(                                                     \
+      Name("QueueCloseV2").Device(DEVICE).HostMemory("handle"), QueueCloseOp); \
+  REGISTER_KERNEL_BUILDER(Name("QueueSizeV2")                                  \
+                              .Device(DEVICE)                                  \
+                              .HostMemory("size")                              \
+                              .HostMemory("handle"),                           \
+                          QueueSizeOp);                                        \
+  REGISTER_KERNEL_BUILDER(                                                     \
+      Name("QueueIsClosedV2").Device(DEVICE).HostMemory("handle"),             \
+      QueueIsClosedOp);                                                        \
+                                                                               \
+  REGISTER_KERNEL_BUILDER(                                                     \
+      Name("FIFOQueueV2").Device(DEVICE).HostMemory("handle"), FIFOQueueOp);
+
+// TODO(phawkins): currently we do not register the QueueEnqueueMany,
+// QueueDequeueMany, or QueueDequeueUpTo kernels because they attempt to read
+// and write the tensors they access in order to concatenate them into a batch.
+// We would need either to call out to an XLA computation to perform the
+// concatenation, or we would need to refactor those kernels so the splitting
+// or merging is done in a separate operator that can be compiled.
 
 }  // namespace tensorflow
 
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index c1f65416b44..366822f0b74 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -371,6 +371,20 @@ tf_xla_py_test(
     ],
 )
 
+tf_xla_py_test(
+    name = "fifo_queue_test",
+    size = "medium",
+    srcs = ["fifo_queue_test.py"],
+    deps = [
+        ":xla_test",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:data_flow_ops",
+        "//tensorflow/python:extra_py_tests_deps",
+        "//tensorflow/python:framework",
+        "//tensorflow/python:platform_test",
+    ],
+)
+
 tf_xla_py_test(
     name = "fft_test",
     size = "medium",
diff --git a/tensorflow/compiler/tests/fifo_queue_test.py b/tensorflow/compiler/tests/fifo_queue_test.py
new file mode 100644
index 00000000000..0f64cc87cde
--- /dev/null
+++ b/tensorflow/compiler/tests/fifo_queue_test.py
@@ -0,0 +1,201 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.ops.data_flow_ops.FIFOQueue."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+from six.moves import xrange  # pylint: disable=redefined-builtin
+
+from tensorflow.compiler.tests import xla_test
+from tensorflow.python.framework import dtypes as dtypes_lib
+from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.platform import test
+
+
+class FIFOQueueTest(xla_test.XLATestCase):
+
+  def testEnqueue(self):
+    with self.test_session(), self.test_scope():
+      q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
+      enqueue_op = q.enqueue((10.0,))
+      enqueue_op.run()
+
+  def testEnqueueWithShape(self):
+    with self.test_session(), self.test_scope():
+      q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, shapes=(3, 2))
+      enqueue_correct_op = q.enqueue(([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],))
+      enqueue_correct_op.run()
+      with self.assertRaises(ValueError):
+        q.enqueue(([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],))
+      self.assertEqual(1, q.size().eval())
+
+  def testMultipleDequeues(self):
+    with self.test_session(), self.test_scope():
+      q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()])
+      self.evaluate(q.enqueue([1]))
+      self.evaluate(q.enqueue([2]))
+      self.evaluate(q.enqueue([3]))
+      a, b, c = self.evaluate([q.dequeue(), q.dequeue(), q.dequeue()])
+      self.assertAllEqual(set([1, 2, 3]), set([a, b, c]))
+
+  def testQueuesDontShare(self):
+    with self.test_session(), self.test_scope():
+      q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()])
+      self.evaluate(q.enqueue(1))
+      q2 = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()])
+      self.evaluate(q2.enqueue(2))
+      self.assertAllEqual(self.evaluate(q2.dequeue()), 2)
+      self.assertAllEqual(self.evaluate(q.dequeue()), 1)
+
+  def testEnqueueDictWithoutNames(self):
+    with self.test_session(), self.test_scope():
+      q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
+      with self.assertRaisesRegexp(ValueError, "must have names"):
+        q.enqueue({"a": 12.0})
+
+  def testParallelEnqueue(self):
+    with self.test_session() as sess, self.test_scope():
+      q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
+      elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
+      enqueue_ops = [q.enqueue((x,)) for x in elems]
+      dequeued_t = q.dequeue()
+
+      # Run one producer thread for each element in elems.
+      def enqueue(enqueue_op):
+        sess.run(enqueue_op)
+
+      threads = [
+          self.checkedThread(target=enqueue, args=(e,)) for e in enqueue_ops
+      ]
+      for thread in threads:
+        thread.start()
+      for thread in threads:
+        thread.join()
+
+      # Dequeue every element using a single thread.
+      results = []
+      for _ in xrange(len(elems)):
+        results.append(dequeued_t.eval())
+      self.assertItemsEqual(elems, results)
+
+  def testParallelDequeue(self):
+    with self.test_session() as sess, self.test_scope():
+      q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
+      elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
+      enqueue_ops = [q.enqueue((x,)) for x in elems]
+      dequeued_t = q.dequeue()
+
+      # Enqueue every element using a single thread.
+      for enqueue_op in enqueue_ops:
+        enqueue_op.run()
+
+      # Run one consumer thread for each element in elems.
+      results = []
+
+      def dequeue():
+        results.append(sess.run(dequeued_t))
+
+      threads = [self.checkedThread(target=dequeue) for _ in enqueue_ops]
+      for thread in threads:
+        thread.start()
+      for thread in threads:
+        thread.join()
+      self.assertItemsEqual(elems, results)
+
+  def testDequeue(self):
+    with self.test_session(), self.test_scope():
+      q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
+      elems = [10.0, 20.0, 30.0]
+      enqueue_ops = [q.enqueue((x,)) for x in elems]
+      dequeued_t = q.dequeue()
+
+      for enqueue_op in enqueue_ops:
+        enqueue_op.run()
+
+      for i in xrange(len(elems)):
+        vals = dequeued_t.eval()
+        self.assertEqual([elems[i]], vals)
+
+  def testEnqueueAndBlockingDequeue(self):
+    with self.test_session() as sess, self.test_scope():
+      q = data_flow_ops.FIFOQueue(3, dtypes_lib.float32)
+      elems = [10.0, 20.0, 30.0]
+      enqueue_ops = [q.enqueue((x,)) for x in elems]
+      dequeued_t = q.dequeue()
+
+      def enqueue():
+        # The enqueue_ops should run after the dequeue op has blocked.
+        # TODO(mrry): Figure out how to do this without sleeping.
+        time.sleep(0.1)
+        for enqueue_op in enqueue_ops:
+          sess.run(enqueue_op)
+
+      results = []
+
+      def dequeue():
+        for _ in xrange(len(elems)):
+          results.append(sess.run(dequeued_t))
+
+      enqueue_thread = self.checkedThread(target=enqueue)
+      dequeue_thread = self.checkedThread(target=dequeue)
+      enqueue_thread.start()
+      dequeue_thread.start()
+      enqueue_thread.join()
+      dequeue_thread.join()
+
+      for elem, result in zip(elems, results):
+        self.assertEqual([elem], result)
+
+  def testMultiEnqueueAndDequeue(self):
+    with self.test_session() as sess, self.test_scope():
+      q = data_flow_ops.FIFOQueue(10, (dtypes_lib.int32, dtypes_lib.float32))
+      elems = [(5, 10.0), (10, 20.0), (15, 30.0)]
+      enqueue_ops = [q.enqueue((x, y)) for x, y in elems]
+      dequeued_t = q.dequeue()
+
+      for enqueue_op in enqueue_ops:
+        enqueue_op.run()
+
+      for i in xrange(len(elems)):
+        x_val, y_val = sess.run(dequeued_t)
+        x, y = elems[i]
+        self.assertEqual([x], x_val)
+        self.assertEqual([y], y_val)
+
+  def testQueueSizeEmpty(self):
+    with self.test_session(), self.test_scope():
+      q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
+      self.assertEqual([0], q.size().eval())
+
+  def testQueueSizeAfterEnqueueAndDequeue(self):
+    with self.test_session(), self.test_scope():
+      q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
+      enqueue_op = q.enqueue((10.0,))
+      dequeued_t = q.dequeue()
+      size = q.size()
+      self.assertEqual([], size.get_shape())
+
+      enqueue_op.run()
+      self.assertEqual(1, size.eval())
+      dequeued_t.op.run()
+      self.assertEqual(0, size.eval())
+
+
+if __name__ == "__main__":
+  test.main()
diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt
index 89db9ee2794..6e7423f85e3 100644
--- a/tensorflow/contrib/makefile/tf_op_files.txt
+++ b/tensorflow/contrib/makefile/tf_op_files.txt
@@ -92,6 +92,7 @@ tensorflow/core/kernels/reduction_ops_common.cc
 tensorflow/core/kernels/reduction_ops_any.cc
 tensorflow/core/kernels/reduction_ops_all.cc
 tensorflow/core/kernels/roll_op.cc
+tensorflow/core/kernels/queue_op.cc
 tensorflow/core/kernels/queue_ops.cc
 tensorflow/core/kernels/queue_base.cc
 tensorflow/core/kernels/pooling_ops_common.cc
diff --git a/tensorflow/core/framework/resource_op_kernel.h b/tensorflow/core/framework/resource_op_kernel.h
index 813ec6eed58..0a8da8b3bf0 100644
--- a/tensorflow/core/framework/resource_op_kernel.h
+++ b/tensorflow/core/framework/resource_op_kernel.h
@@ -43,9 +43,15 @@ template <typename T>
 class ResourceOpKernel : public OpKernel {
  public:
   explicit ResourceOpKernel(OpKernelConstruction* context) : OpKernel(context) {
-    OP_REQUIRES_OK(context,
-                   context->allocate_persistent(DT_STRING, TensorShape({2}),
-                                                &handle_, nullptr));
+    has_resource_type_ = (context->output_type(0) == DT_RESOURCE);
+    if (!has_resource_type_) {
+      // The resource variant of the op may be placed on non-CPU devices, but
+      // this allocation is always on the host. Fortunately we don't need it in
+      // the resource case.
+      OP_REQUIRES_OK(context,
+                     context->allocate_persistent(DT_STRING, TensorShape({2}),
+                                                  &handle_, nullptr));
+    }
   }
 
   // The resource is deleted from the resource manager only when it is private
@@ -89,12 +95,14 @@ class ResourceOpKernel : public OpKernel {
         return;
       }
 
-      auto h = handle_.AccessTensor(context)->template flat<string>();
-      h(0) = cinfo_.container();
-      h(1) = cinfo_.name();
+      if (!has_resource_type_) {
+        auto h = handle_.AccessTensor(context)->template flat<string>();
+        h(0) = cinfo_.container();
+        h(1) = cinfo_.name();
+      }
       resource_ = resource;
     }
-    if (context->expected_output_dtype(0) == DT_RESOURCE) {
+    if (has_resource_type_) {
       OP_REQUIRES_OK(context, MakeResourceHandleToOutput(
                                   context, 0, cinfo_.container(), cinfo_.name(),
                                   MakeTypeIndex<T>()));
@@ -122,6 +130,9 @@ class ResourceOpKernel : public OpKernel {
   virtual Status VerifyResource(T* resource) { return Status::OK(); }
 
   PersistentTensor handle_ GUARDED_BY(mu_);
+
+  // Is the output of the operator of type DT_RESOURCE?
+  bool has_resource_type_;
 };
 }  // namespace tensorflow
 
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index cbe30cdca14..861fb1ef697 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -368,6 +368,7 @@ cc_library(
 
 cc_library(
     name = "queue_op",
+    srcs = ["queue_op.cc"],
     hdrs = ["queue_op.h"],
     deps = [
         ":queue_base",
@@ -1885,9 +1886,10 @@ cc_library(
     name = "fifo_queue",
     srcs = ["fifo_queue.cc"],
     hdrs = ["fifo_queue.h"],
-    visibility = ["//visibility:private"],
+    visibility = [":friends"],
     deps = [
         ":queue_base",
+        ":queue_op",
         ":typed_queue",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
@@ -5076,6 +5078,7 @@ filegroup(
         "padding_fifo_queue.cc",
         "padding_fifo_queue_op.cc",
         "queue_base.cc",
+        "queue_op.cc",
         "queue_ops.cc",
         "random_op.cc",
         "reduction_ops_all.cc",
diff --git a/tensorflow/core/kernels/fifo_queue.cc b/tensorflow/core/kernels/fifo_queue.cc
index a23478af5b5..d6e859f1aa0 100644
--- a/tensorflow/core/kernels/fifo_queue.cc
+++ b/tensorflow/core/kernels/fifo_queue.cc
@@ -366,4 +366,19 @@ Status FIFOQueue::MatchesNodeDef(const NodeDef& node_def) {
   return Status::OK();
 }
 
+// Defines a FIFOQueueOp, which produces a Queue (specifically, one
+// backed by FIFOQueue) that persists across different graph
+// executions, and sessions. Running this op produces a single-element
+// tensor of handles to Queues in the corresponding device.
+FIFOQueueOp::FIFOQueueOp(OpKernelConstruction* context)
+    : TypedQueueOp(context) {
+  OP_REQUIRES_OK(context, context->GetAttr("shapes", &component_shapes_));
+}
+
+Status FIFOQueueOp::CreateResource(QueueInterface** ret) {
+  FIFOQueue* queue = new FIFOQueue(capacity_, component_types_,
+                                   component_shapes_, cinfo_.name());
+  return CreateTypedQueue(queue, ret);
+}
+
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/fifo_queue.h b/tensorflow/core/kernels/fifo_queue.h
index f01d70924d0..697ee81c39b 100644
--- a/tensorflow/core/kernels/fifo_queue.h
+++ b/tensorflow/core/kernels/fifo_queue.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_KERNELS_FIFO_QUEUE_H_
-#define TENSORFLOW_KERNELS_FIFO_QUEUE_H_
+#ifndef TENSORFLOW_CORE_KERNELS_FIFO_QUEUE_H_
+#define TENSORFLOW_CORE_KERNELS_FIFO_QUEUE_H_
 
 #include <deque>
 #include <vector>
@@ -23,6 +23,7 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/queue_op.h"
 #include "tensorflow/core/kernels/typed_queue.h"
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/mutex.h"
@@ -69,6 +70,22 @@ class FIFOQueue : public TypedQueue<std::deque<PersistentTensor> > {
   TF_DISALLOW_COPY_AND_ASSIGN(FIFOQueue);
 };
 
+// Defines a FIFOQueueOp, which produces a Queue (specifically, one
+// backed by FIFOQueue) that persists across different graph
+// executions, and sessions. Running this op produces a single-element
+// tensor of handles to Queues in the corresponding device.
+class FIFOQueueOp : public TypedQueueOp {
+ public:
+  explicit FIFOQueueOp(OpKernelConstruction* context);
+
+ private:
+  Status CreateResource(QueueInterface** ret) override
+      EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+  std::vector<TensorShape> component_shapes_;
+  TF_DISALLOW_COPY_AND_ASSIGN(FIFOQueueOp);
+};
+
 }  // namespace tensorflow
 
-#endif  // TENSORFLOW_KERNELS_FIFO_QUEUE_H_
+#endif  // TENSORFLOW_CORE_KERNELS_FIFO_QUEUE_H_
diff --git a/tensorflow/core/kernels/fifo_queue_op.cc b/tensorflow/core/kernels/fifo_queue_op.cc
index b35bdbb2f01..80869768f18 100644
--- a/tensorflow/core/kernels/fifo_queue_op.cc
+++ b/tensorflow/core/kernels/fifo_queue_op.cc
@@ -13,50 +13,11 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-// See docs in ../ops/data_flow_ops.cc.
-
-#include <deque>
-#include <vector>
-
 #include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/resource_mgr.h"
-#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/framework/tensor_shape.h"
-#include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/kernels/fifo_queue.h"
-#include "tensorflow/core/kernels/queue_base.h"
-#include "tensorflow/core/kernels/queue_op.h"
-#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/platform/mutex.h"
-#include "tensorflow/core/platform/thread_annotations.h"
-#include "tensorflow/core/platform/types.h"
 
 namespace tensorflow {
 
-// Defines a FIFOQueueOp, which produces a Queue (specifically, one
-// backed by FIFOQueue) that persists across different graph
-// executions, and sessions. Running this op produces a single-element
-// tensor of handles to Queues in the corresponding device.
-class FIFOQueueOp : public TypedQueueOp {
- public:
-  explicit FIFOQueueOp(OpKernelConstruction* context) : TypedQueueOp(context) {
-    OP_REQUIRES_OK(context, context->GetAttr("shapes", &component_shapes_));
-  }
-
- private:
-  Status CreateResource(QueueInterface** ret) override
-      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
-    FIFOQueue* queue = new FIFOQueue(capacity_, component_types_,
-                                     component_shapes_, cinfo_.name());
-    return CreateTypedQueue(queue, ret);
-  }
-
-  std::vector<TensorShape> component_shapes_;
-  TF_DISALLOW_COPY_AND_ASSIGN(FIFOQueueOp);
-};
-
 REGISTER_KERNEL_BUILDER(Name("FIFOQueue").Device(DEVICE_CPU), FIFOQueueOp);
 REGISTER_KERNEL_BUILDER(Name("FIFOQueueV2").Device(DEVICE_CPU), FIFOQueueOp);
 
diff --git a/tensorflow/core/kernels/queue_op.cc b/tensorflow/core/kernels/queue_op.cc
new file mode 100644
index 00000000000..53f431ef3c7
--- /dev/null
+++ b/tensorflow/core/kernels/queue_op.cc
@@ -0,0 +1,367 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/queue_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/queue_interface.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+QueueOp::QueueOp(OpKernelConstruction* context) : ResourceOpKernel(context) {
+  OP_REQUIRES_OK(context, context->GetAttr("capacity", &capacity_));
+  if (capacity_ < 0) {
+    capacity_ = QueueBase::kUnbounded;
+  }
+  OP_REQUIRES_OK(context,
+                 context->GetAttr("component_types", &component_types_));
+}
+
+void QueueOp::Compute(OpKernelContext* context) {
+  ResourceOpKernel<QueueInterface>::Compute(context);
+  mutex_lock l(mu_);
+  if (resource_ && context->track_allocations()) {
+    context->record_persistent_memory_allocation(resource_->MemoryUsed());
+  }
+}
+
+Status QueueOp::VerifyResource(QueueInterface* queue) {
+  return queue->MatchesNodeDef(def());
+}
+
+
+QueueOpKernel::QueueOpKernel(OpKernelConstruction* context)
+    : AsyncOpKernel(context) {}
+
+void QueueOpKernel::ComputeAsync(OpKernelContext* ctx, DoneCallback callback) {
+  QueueInterface* queue;
+  if (ctx->input_dtype(0) == DT_RESOURCE) {
+    OP_REQUIRES_OK_ASYNC(
+        ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &queue), callback);
+  } else {
+    OP_REQUIRES_OK_ASYNC(ctx, GetResourceFromContext(ctx, "handle", &queue),
+                         callback);
+  }
+  ComputeAsync(ctx, queue, [callback, queue]() {
+    queue->Unref();
+    callback();
+  });
+}
+
+QueueAccessOpKernel::QueueAccessOpKernel(OpKernelConstruction* context)
+    : QueueOpKernel(context) {
+  OP_REQUIRES_OK(context, context->GetAttr("timeout_ms", &timeout_));
+  // TODO(keveman): Enable timeout.
+  OP_REQUIRES(context, timeout_ == -1,
+              errors::InvalidArgument("Timeout not supported yet."));
+}
+
+// Defines an EnqueueOp, the execution of which enqueues a tuple of
+// tensors in the given Queue.
+//
+// The op has 1 + k inputs, where k is the number of components in the
+// tuples stored in the given Queue:
+// - Input 0: queue handle.
+// - Input 1: 0th element of the tuple.
+// - ...
+// - Input (1+k): kth element of the tuple.
+EnqueueOp::EnqueueOp(OpKernelConstruction* context)
+    : QueueAccessOpKernel(context) {}
+
+void EnqueueOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+                             DoneCallback callback) {
+  DataTypeVector expected_inputs;
+  if (ctx->input_dtype(0) == DT_RESOURCE) {
+    expected_inputs.push_back(DT_RESOURCE);
+  } else {
+    expected_inputs.push_back(DT_STRING_REF);
+  }
+  for (DataType dt : queue->component_dtypes()) {
+    expected_inputs.push_back(dt);
+  }
+  OP_REQUIRES_OK_ASYNC(ctx, ctx->MatchSignature(expected_inputs, {}), callback);
+
+  QueueInterface::Tuple tuple;
+  OpInputList components;
+  OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("components", &components),
+                       callback);
+  for (const Tensor& Tcomponent : components) {
+    tuple.push_back(Tcomponent);
+  }
+
+  OP_REQUIRES_OK_ASYNC(ctx, queue->ValidateTuple(tuple), callback);
+  queue->TryEnqueue(tuple, ctx, callback);
+}
+
+// Defines an EnqueueManyOp, the execution of which slices each
+// component of a tuple of tensors along the 0th dimension, and
+// enqueues tuples of slices in the given Queue.
+//
+// The op has 1 + k inputs, where k is the number of components in the
+// tuples stored in the given Queue:
+// - Input 0: queue handle.
+// - Input 1: 0th element of the tuple.
+// - ...
+// - Input (1+k): kth element of the tuple.
+//
+// N.B. All tuple components must have the same size in the 0th
+// dimension.
+EnqueueManyOp::EnqueueManyOp(OpKernelConstruction* context)
+    : QueueAccessOpKernel(context) {}
+
+void EnqueueManyOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+                                 DoneCallback callback) {
+  DataTypeVector expected_inputs;
+  if (ctx->input_dtype(0) == DT_RESOURCE) {
+    expected_inputs.push_back(DT_RESOURCE);
+  } else {
+    expected_inputs.push_back(DT_STRING_REF);
+  }
+  for (DataType dt : queue->component_dtypes()) {
+    expected_inputs.push_back(dt);
+  }
+  OP_REQUIRES_OK_ASYNC(ctx, ctx->MatchSignature(expected_inputs, {}), callback);
+
+  QueueInterface::Tuple tuple;
+  OpInputList components;
+  OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("components", &components),
+                       callback);
+  for (const Tensor& Tcomponent : components) {
+    tuple.push_back(Tcomponent);
+  }
+
+  OP_REQUIRES_OK_ASYNC(ctx, queue->ValidateManyTuple(tuple), callback);
+  queue->TryEnqueueMany(tuple, ctx, callback);
+}
+
+EnqueueManyOp::~EnqueueManyOp() = default;
+
+// Defines a DequeueOp, the execution of which dequeues a tuple of
+// tensors from the given Queue.
+//
+// The op has one input, which is the handle of the appropriate
+// Queue. The op has k outputs, where k is the number of components in
+// the tuples stored in the given Queue, and output i is the ith
+// component of the dequeued tuple.
+DequeueOp::DequeueOp(OpKernelConstruction* context)
+    : QueueAccessOpKernel(context) {}
+
+void DequeueOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+                             DoneCallback callback) {
+  if (ctx->input_dtype(0) == DT_RESOURCE) {
+    OP_REQUIRES_OK_ASYNC(
+        ctx, ctx->MatchSignature({DT_RESOURCE}, queue->component_dtypes()),
+        callback);
+  } else {
+    OP_REQUIRES_OK_ASYNC(
+        ctx, ctx->MatchSignature({DT_STRING_REF}, queue->component_dtypes()),
+        callback);
+  }
+
+  queue->TryDequeue(ctx, [ctx, callback](const QueueInterface::Tuple& tuple) {
+    if (!ctx->status().ok()) {
+      callback();
+      return;
+    }
+    OpOutputList output_components;
+    OP_REQUIRES_OK_ASYNC(
+        ctx, ctx->output_list("components", &output_components), callback);
+    for (int i = 0; i < ctx->num_outputs(); ++i) {
+      output_components.set(i, tuple[i]);
+    }
+    callback();
+  });
+}
+
+DequeueOp::~DequeueOp() = default;
+
+// Defines a DequeueManyOp, the execution of which concatenates the
+// requested number of elements from the given Queue along the 0th
+// dimension, and emits the result as a single tuple of tensors.
+//
+// The op has two inputs:
+// - Input 0: the handle to a queue.
+// - Input 1: the number of elements to dequeue.
+//
+// The op has k outputs, where k is the number of components in the
+// tuples stored in the given Queue, and output i is the ith component
+// of the dequeued tuple.
+DequeueManyOp::DequeueManyOp(OpKernelConstruction* context)
+    : QueueAccessOpKernel(context) {}
+
+void DequeueManyOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+                                 DoneCallback callback) {
+  const Tensor& Tnum_elements = ctx->input(1);
+  int32 num_elements = Tnum_elements.flat<int32>()(0);
+
+  OP_REQUIRES_ASYNC(ctx, num_elements >= 0,
+                    errors::InvalidArgument("DequeueManyOp requested ",
+                                            num_elements, " < 0 elements"),
+                    callback);
+
+  if (ctx->input_dtype(0) == DT_RESOURCE) {
+    OP_REQUIRES_OK_ASYNC(
+        ctx,
+        ctx->MatchSignature({DT_RESOURCE, DT_INT32}, queue->component_dtypes()),
+        callback);
+  } else {
+    OP_REQUIRES_OK_ASYNC(ctx,
+                         ctx->MatchSignature({DT_STRING_REF, DT_INT32},
+                                             queue->component_dtypes()),
+                         callback);
+  }
+
+  queue->TryDequeueMany(
+      num_elements, ctx, false /* allow_small_batch */,
+      [ctx, callback](const QueueInterface::Tuple& tuple) {
+        if (!ctx->status().ok()) {
+          callback();
+          return;
+        }
+        OpOutputList output_components;
+        OP_REQUIRES_OK_ASYNC(
+            ctx, ctx->output_list("components", &output_components), callback);
+        for (int i = 0; i < ctx->num_outputs(); ++i) {
+          output_components.set(i, tuple[i]);
+        }
+        callback();
+      });
+}
+
+DequeueManyOp::~DequeueManyOp() = default;
+
+// Defines a DequeueUpToOp, the execution of which concatenates the
+// requested number of elements from the given Queue along the 0th
+// dimension, and emits the result as a single tuple of tensors.
+//
+// The difference between this op and DequeueMany is the handling when
+// the Queue is closed.  While the DequeueMany op will return if there
+// an error when there are less than num_elements elements left in the
+// closed queue, this op will return between 1 and
+// min(num_elements, elements_remaining_in_queue), and will not block.
+// If there are no elements left, then the standard DequeueMany error
+// is returned.
+//
+// This op only works if the underlying Queue implementation accepts
+// the allow_small_batch = true parameter to TryDequeueMany.
+// If it does not, an errors::Unimplemented exception is returned.
+//
+// The op has two inputs:
+// - Input 0: the handle to a queue.
+// - Input 1: the number of elements to dequeue.
+//
+// The op has k outputs, where k is the number of components in the
+// tuples stored in the given Queue, and output i is the ith component
+// of the dequeued tuple.
+//
+// The op has one attribute: allow_small_batch.  If the Queue supports
+// it, setting this to true causes the queue to return smaller
+// (possibly zero length) batches when it is closed, up to however
+// many elements are available when the op executes.  In this case,
+// the Queue does not block when closed.
+DequeueUpToOp::DequeueUpToOp(OpKernelConstruction* context)
+    : QueueAccessOpKernel(context) {}
+
+void DequeueUpToOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+                                 DoneCallback callback) {
+  const Tensor& Tnum_elements = ctx->input(1);
+  int32 num_elements = Tnum_elements.flat<int32>()(0);
+
+  OP_REQUIRES_ASYNC(ctx, num_elements >= 0,
+                    errors::InvalidArgument("DequeueUpToOp requested ",
+                                            num_elements, " < 0 elements"),
+                    callback);
+
+  if (ctx->input_dtype(0) == DT_RESOURCE) {
+    OP_REQUIRES_OK_ASYNC(
+        ctx,
+        ctx->MatchSignature({DT_RESOURCE, DT_INT32}, queue->component_dtypes()),
+        callback);
+  } else {
+    OP_REQUIRES_OK_ASYNC(ctx,
+                         ctx->MatchSignature({DT_STRING_REF, DT_INT32},
+                                             queue->component_dtypes()),
+                         callback);
+  }
+
+  queue->TryDequeueMany(
+      num_elements, ctx, true /* allow_small_batch */,
+      [ctx, callback](const QueueInterface::Tuple& tuple) {
+        if (!ctx->status().ok()) {
+          callback();
+          return;
+        }
+        OpOutputList output_components;
+        OP_REQUIRES_OK_ASYNC(
+            ctx, ctx->output_list("components", &output_components), callback);
+        for (int i = 0; i < ctx->num_outputs(); ++i) {
+          output_components.set(i, tuple[i]);
+        }
+        callback();
+      });
+}
+
+DequeueUpToOp::~DequeueUpToOp() = default;
+
+// Defines a QueueCloseOp, which closes the given Queue. Closing a
+// Queue signals that no more elements will be enqueued in it.
+//
+// The op has one input, which is the handle of the appropriate Queue.
+QueueCloseOp::QueueCloseOp(OpKernelConstruction* context)
+    : QueueOpKernel(context) {
+  OP_REQUIRES_OK(context, context->GetAttr("cancel_pending_enqueues",
+                                           &cancel_pending_enqueues_));
+}
+
+void QueueCloseOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+                                DoneCallback callback) {
+  queue->Close(ctx, cancel_pending_enqueues_, callback);
+}
+
+// Defines a QueueSizeOp, which computes the number of elements in the
+// given Queue, and emits it as an output tensor.
+//
+// The op has one input, which is the handle of the appropriate Queue;
+// and one output, which is a single-element tensor containing the current
+// size of that Queue.
+QueueSizeOp::QueueSizeOp(OpKernelConstruction* context)
+    : QueueOpKernel(context) {}
+
+void QueueSizeOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+                               DoneCallback callback) {
+  Tensor* Tqueue_size = nullptr;
+  OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &Tqueue_size));
+  Tqueue_size->flat<int32>().setConstant(queue->size());
+  callback();
+}
+
+QueueIsClosedOp::QueueIsClosedOp(OpKernelConstruction* context)
+    : QueueOpKernel(context) {}
+
+void QueueIsClosedOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+                                   DoneCallback callback) {
+  Tensor* Tqueue_is_closed = nullptr;
+  OP_REQUIRES_OK(ctx,
+                 ctx->allocate_output(0, TensorShape({}), &Tqueue_is_closed));
+  Tqueue_is_closed->flat<bool>().setConstant(queue->is_closed());
+  callback();
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/queue_op.h b/tensorflow/core/kernels/queue_op.h
index 6c19f9841cd..2efd838a5fb 100644
--- a/tensorflow/core/kernels/queue_op.h
+++ b/tensorflow/core/kernels/queue_op.h
@@ -13,12 +13,13 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_KERNELS_QUEUE_OP_H_
-#define TENSORFLOW_KERNELS_QUEUE_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_QUEUE_OP_H_
+#define TENSORFLOW_CORE_KERNELS_QUEUE_OP_H_
 
 #include <deque>
 
 #include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/queue_interface.h"
 #include "tensorflow/core/framework/resource_op_kernel.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/types.h"
@@ -32,22 +33,9 @@ namespace tensorflow {
 // Defines a QueueOp, an abstract class for Queue construction ops.
 class QueueOp : public ResourceOpKernel<QueueInterface> {
  public:
-  QueueOp(OpKernelConstruction* context) : ResourceOpKernel(context) {
-    OP_REQUIRES_OK(context, context->GetAttr("capacity", &capacity_));
-    if (capacity_ < 0) {
-      capacity_ = QueueBase::kUnbounded;
-    }
-    OP_REQUIRES_OK(context,
-                   context->GetAttr("component_types", &component_types_));
-  }
+  QueueOp(OpKernelConstruction* context);
 
-  void Compute(OpKernelContext* context) override {
-    ResourceOpKernel<QueueInterface>::Compute(context);
-    mutex_lock l(mu_);
-    if (resource_ && context->track_allocations()) {
-      context->record_persistent_memory_allocation(resource_->MemoryUsed());
-    }
-  }
+  void Compute(OpKernelContext* context) override;
 
  protected:
   // Variables accessible by subclasses
@@ -55,9 +43,7 @@ class QueueOp : public ResourceOpKernel<QueueInterface> {
   DataTypeVector component_types_;
 
  private:
-  Status VerifyResource(QueueInterface* queue) override {
-    return queue->MatchesNodeDef(def());
-  }
+  Status VerifyResource(QueueInterface* queue) override;
 };
 
 class TypedQueueOp : public QueueOp {
@@ -75,6 +61,211 @@ class TypedQueueOp : public QueueOp {
   }
 };
 
+// Queue manipulator kernels
+
+class QueueOpKernel : public AsyncOpKernel {
+ public:
+  explicit QueueOpKernel(OpKernelConstruction* context);
+
+  void ComputeAsync(OpKernelContext* ctx, DoneCallback callback) final;
+
+ protected:
+  virtual void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+                            DoneCallback callback) = 0;
+};
+
+class QueueAccessOpKernel : public QueueOpKernel {
+ public:
+  explicit QueueAccessOpKernel(OpKernelConstruction* context);
+
+ protected:
+  int64 timeout_;
+};
+
+// Defines an EnqueueOp, the execution of which enqueues a tuple of
+// tensors in the given Queue.
+//
+// The op has 1 + k inputs, where k is the number of components in the
+// tuples stored in the given Queue:
+// - Input 0: queue handle.
+// - Input 1: 0th element of the tuple.
+// - ...
+// - Input (1+k): kth element of the tuple.
+class EnqueueOp : public QueueAccessOpKernel {
+ public:
+  explicit EnqueueOp(OpKernelConstruction* context);
+
+ protected:
+  void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+                    DoneCallback callback) override;
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(EnqueueOp);
+};
+
+// Defines an EnqueueManyOp, the execution of which slices each
+// component of a tuple of tensors along the 0th dimension, and
+// enqueues tuples of slices in the given Queue.
+//
+// The op has 1 + k inputs, where k is the number of components in the
+// tuples stored in the given Queue:
+// - Input 0: queue handle.
+// - Input 1: 0th element of the tuple.
+// - ...
+// - Input (1+k): kth element of the tuple.
+//
+// N.B. All tuple components must have the same size in the 0th
+// dimension.
+class EnqueueManyOp : public QueueAccessOpKernel {
+ public:
+  explicit EnqueueManyOp(OpKernelConstruction* context);
+
+ protected:
+  void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+                    DoneCallback callback) override;
+
+  ~EnqueueManyOp() override;
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(EnqueueManyOp);
+};
+
+// Defines a DequeueOp, the execution of which dequeues a tuple of
+// tensors from the given Queue.
+//
+// The op has one input, which is the handle of the appropriate
+// Queue. The op has k outputs, where k is the number of components in
+// the tuples stored in the given Queue, and output i is the ith
+// component of the dequeued tuple.
+class DequeueOp : public QueueAccessOpKernel {
+ public:
+  explicit DequeueOp(OpKernelConstruction* context);
+
+ protected:
+  void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+                    DoneCallback callback) override;
+
+  ~DequeueOp() override;
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(DequeueOp);
+};
+
+// Defines a DequeueManyOp, the execution of which concatenates the
+// requested number of elements from the given Queue along the 0th
+// dimension, and emits the result as a single tuple of tensors.
+//
+// The op has two inputs:
+// - Input 0: the handle to a queue.
+// - Input 1: the number of elements to dequeue.
+//
+// The op has k outputs, where k is the number of components in the
+// tuples stored in the given Queue, and output i is the ith component
+// of the dequeued tuple.
+class DequeueManyOp : public QueueAccessOpKernel {
+ public:
+  explicit DequeueManyOp(OpKernelConstruction* context);
+
+ protected:
+  void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+                    DoneCallback callback) override;
+
+  ~DequeueManyOp() override;
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(DequeueManyOp);
+};
+
+// Defines a DequeueUpToOp, the execution of which concatenates the
+// requested number of elements from the given Queue along the 0th
+// dimension, and emits the result as a single tuple of tensors.
+//
+// The difference between this op and DequeueMany is the handling when
+// the Queue is closed.  While the DequeueMany op will return if there
+// an error when there are less than num_elements elements left in the
+// closed queue, this op will return between 1 and
+// min(num_elements, elements_remaining_in_queue), and will not block.
+// If there are no elements left, then the standard DequeueMany error
+// is returned.
+//
+// This op only works if the underlying Queue implementation accepts
+// the allow_small_batch = true parameter to TryDequeueMany.
+// If it does not, an errors::Unimplemented exception is returned.
+//
+// The op has two inputs:
+// - Input 0: the handle to a queue.
+// - Input 1: the number of elements to dequeue.
+//
+// The op has k outputs, where k is the number of components in the
+// tuples stored in the given Queue, and output i is the ith component
+// of the dequeued tuple.
+//
+// The op has one attribute: allow_small_batch.  If the Queue supports
+// it, setting this to true causes the queue to return smaller
+// (possibly zero length) batches when it is closed, up to however
+// many elements are available when the op executes.  In this case,
+// the Queue does not block when closed.
+class DequeueUpToOp : public QueueAccessOpKernel {
+ public:
+  explicit DequeueUpToOp(OpKernelConstruction* context);
+
+ protected:
+  void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+                    DoneCallback callback) override;
+
+  ~DequeueUpToOp() override;
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(DequeueUpToOp);
+};
+
+// Defines a QueueCloseOp, which closes the given Queue. Closing a
+// Queue signals that no more elements will be enqueued in it.
+//
+// The op has one input, which is the handle of the appropriate Queue.
+class QueueCloseOp : public QueueOpKernel {
+ public:
+  explicit QueueCloseOp(OpKernelConstruction* context);
+
+ protected:
+  void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+                    DoneCallback callback) override;
+
+ private:
+  bool cancel_pending_enqueues_;
+  TF_DISALLOW_COPY_AND_ASSIGN(QueueCloseOp);
+};
+
+// Defines a QueueSizeOp, which computes the number of elements in the
+// given Queue, and emits it as an output tensor.
+//
+// The op has one input, which is the handle of the appropriate Queue;
+// and one output, which is a single-element tensor containing the current
+// size of that Queue.
+class QueueSizeOp : public QueueOpKernel {
+ public:
+  explicit QueueSizeOp(OpKernelConstruction* context);
+
+ protected:
+  void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+                    DoneCallback callback) override;
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(QueueSizeOp);
+};
+
+class QueueIsClosedOp : public QueueOpKernel {
+ public:
+  explicit QueueIsClosedOp(OpKernelConstruction* context);
+
+ protected:
+  void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+                    DoneCallback callback) override;
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(QueueIsClosedOp);
+};
+
 }  // namespace tensorflow
 
-#endif  // TENSORFLOW_KERNELS_QUEUE_OP_H_
+#endif  // TENSORFLOW_CORE_KERNELS_QUEUE_OP_H_
diff --git a/tensorflow/core/kernels/queue_ops.cc b/tensorflow/core/kernels/queue_ops.cc
index 46a02854d73..c4d404259ba 100644
--- a/tensorflow/core/kernels/queue_ops.cc
+++ b/tensorflow/core/kernels/queue_ops.cc
@@ -13,437 +13,44 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-// See docs in ../ops/data_flow_ops.cc.
-
 #include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/queue_interface.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/queue_op.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/types.h"
 
 namespace tensorflow {
 
-class QueueOpKernel : public AsyncOpKernel {
- public:
-  explicit QueueOpKernel(OpKernelConstruction* context)
-      : AsyncOpKernel(context) {}
-
-  void ComputeAsync(OpKernelContext* ctx, DoneCallback callback) final {
-    QueueInterface* queue;
-    if (ctx->input_dtype(0) == DT_RESOURCE) {
-      OP_REQUIRES_OK_ASYNC(
-          ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &queue), callback);
-    } else {
-      OP_REQUIRES_OK_ASYNC(ctx, GetResourceFromContext(ctx, "handle", &queue),
-                           callback);
-    }
-    ComputeAsync(ctx, queue, [callback, queue]() {
-      queue->Unref();
-      callback();
-    });
-  }
-
- protected:
-  virtual void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
-                            DoneCallback callback) = 0;
-};
-
-class QueueAccessOpKernel : public QueueOpKernel {
- public:
-  explicit QueueAccessOpKernel(OpKernelConstruction* context)
-      : QueueOpKernel(context) {
-    OP_REQUIRES_OK(context, context->GetAttr("timeout_ms", &timeout_));
-    // TODO(keveman): Enable timeout.
-    OP_REQUIRES(context, timeout_ == -1,
-                errors::InvalidArgument("Timeout not supported yet."));
-  }
-
- protected:
-  int64 timeout_;
-};
-
-// Defines an EnqueueOp, the execution of which enqueues a tuple of
-// tensors in the given Queue.
-//
-// The op has 1 + k inputs, where k is the number of components in the
-// tuples stored in the given Queue:
-// - Input 0: queue handle.
-// - Input 1: 0th element of the tuple.
-// - ...
-// - Input (1+k): kth element of the tuple.
-class EnqueueOp : public QueueAccessOpKernel {
- public:
-  explicit EnqueueOp(OpKernelConstruction* context)
-      : QueueAccessOpKernel(context) {}
-
- protected:
-  void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
-                    DoneCallback callback) override {
-    DataTypeVector expected_inputs;
-    if (ctx->input_dtype(0) == DT_RESOURCE) {
-      expected_inputs.push_back(DT_RESOURCE);
-    } else {
-      expected_inputs.push_back(DT_STRING_REF);
-    }
-    for (DataType dt : queue->component_dtypes()) {
-      expected_inputs.push_back(dt);
-    }
-    OP_REQUIRES_OK_ASYNC(ctx, ctx->MatchSignature(expected_inputs, {}),
-                         callback);
-
-    QueueInterface::Tuple tuple;
-    OpInputList components;
-    OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("components", &components),
-                         callback);
-    for (const Tensor& Tcomponent : components) {
-      tuple.push_back(Tcomponent);
-    }
-
-    OP_REQUIRES_OK_ASYNC(ctx, queue->ValidateTuple(tuple), callback);
-    queue->TryEnqueue(tuple, ctx, callback);
-  }
-
- private:
-  TF_DISALLOW_COPY_AND_ASSIGN(EnqueueOp);
-};
-
 REGISTER_KERNEL_BUILDER(Name("QueueEnqueue").Device(DEVICE_CPU), EnqueueOp);
 REGISTER_KERNEL_BUILDER(Name("QueueEnqueueV2").Device(DEVICE_CPU), EnqueueOp);
 
-// Defines an EnqueueManyOp, the execution of which slices each
-// component of a tuple of tensors along the 0th dimension, and
-// enqueues tuples of slices in the given Queue.
-//
-// The op has 1 + k inputs, where k is the number of components in the
-// tuples stored in the given Queue:
-// - Input 0: queue handle.
-// - Input 1: 0th element of the tuple.
-// - ...
-// - Input (1+k): kth element of the tuple.
-//
-// N.B. All tuple components must have the same size in the 0th
-// dimension.
-class EnqueueManyOp : public QueueAccessOpKernel {
- public:
-  explicit EnqueueManyOp(OpKernelConstruction* context)
-      : QueueAccessOpKernel(context) {}
-
- protected:
-  void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
-                    DoneCallback callback) override {
-    DataTypeVector expected_inputs;
-    if (ctx->input_dtype(0) == DT_RESOURCE) {
-      expected_inputs.push_back(DT_RESOURCE);
-    } else {
-      expected_inputs.push_back(DT_STRING_REF);
-    }
-    for (DataType dt : queue->component_dtypes()) {
-      expected_inputs.push_back(dt);
-    }
-    OP_REQUIRES_OK_ASYNC(ctx, ctx->MatchSignature(expected_inputs, {}),
-                         callback);
-
-    QueueInterface::Tuple tuple;
-    OpInputList components;
-    OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("components", &components),
-                         callback);
-    for (const Tensor& Tcomponent : components) {
-      tuple.push_back(Tcomponent);
-    }
-
-    OP_REQUIRES_OK_ASYNC(ctx, queue->ValidateManyTuple(tuple), callback);
-    queue->TryEnqueueMany(tuple, ctx, callback);
-  }
-
-  ~EnqueueManyOp() override {}
-
- private:
-  TF_DISALLOW_COPY_AND_ASSIGN(EnqueueManyOp);
-};
-
 REGISTER_KERNEL_BUILDER(Name("QueueEnqueueMany").Device(DEVICE_CPU),
                         EnqueueManyOp);
 REGISTER_KERNEL_BUILDER(Name("QueueEnqueueManyV2").Device(DEVICE_CPU),
                         EnqueueManyOp);
 
-// Defines a DequeueOp, the execution of which dequeues a tuple of
-// tensors from the given Queue.
-//
-// The op has one input, which is the handle of the appropriate
-// Queue. The op has k outputs, where k is the number of components in
-// the tuples stored in the given Queue, and output i is the ith
-// component of the dequeued tuple.
-class DequeueOp : public QueueAccessOpKernel {
- public:
-  explicit DequeueOp(OpKernelConstruction* context)
-      : QueueAccessOpKernel(context) {}
-
- protected:
-  void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
-                    DoneCallback callback) override {
-    if (ctx->input_dtype(0) == DT_RESOURCE) {
-      OP_REQUIRES_OK_ASYNC(
-          ctx, ctx->MatchSignature({DT_RESOURCE}, queue->component_dtypes()),
-          callback);
-    } else {
-      OP_REQUIRES_OK_ASYNC(
-          ctx, ctx->MatchSignature({DT_STRING_REF}, queue->component_dtypes()),
-          callback);
-    }
-
-    queue->TryDequeue(ctx, [ctx, callback](const QueueInterface::Tuple& tuple) {
-      if (!ctx->status().ok()) {
-        callback();
-        return;
-      }
-      OpOutputList output_components;
-      OP_REQUIRES_OK_ASYNC(
-          ctx, ctx->output_list("components", &output_components), callback);
-      for (int i = 0; i < ctx->num_outputs(); ++i) {
-        output_components.set(i, tuple[i]);
-      }
-      callback();
-    });
-  }
-
-  ~DequeueOp() override {}
-
- private:
-  TF_DISALLOW_COPY_AND_ASSIGN(DequeueOp);
-};
-
 REGISTER_KERNEL_BUILDER(Name("QueueDequeue").Device(DEVICE_CPU), DequeueOp);
 REGISTER_KERNEL_BUILDER(Name("QueueDequeueV2").Device(DEVICE_CPU), DequeueOp);
 
-// Defines a DequeueManyOp, the execution of which concatenates the
-// requested number of elements from the given Queue along the 0th
-// dimension, and emits the result as a single tuple of tensors.
-//
-// The op has two inputs:
-// - Input 0: the handle to a queue.
-// - Input 1: the number of elements to dequeue.
-//
-// The op has k outputs, where k is the number of components in the
-// tuples stored in the given Queue, and output i is the ith component
-// of the dequeued tuple.
-class DequeueManyOp : public QueueAccessOpKernel {
- public:
-  explicit DequeueManyOp(OpKernelConstruction* context)
-      : QueueAccessOpKernel(context) {}
-
- protected:
-  void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
-                    DoneCallback callback) override {
-    const Tensor& Tnum_elements = ctx->input(1);
-    int32 num_elements = Tnum_elements.flat<int32>()(0);
-
-    OP_REQUIRES_ASYNC(ctx, num_elements >= 0,
-                      errors::InvalidArgument("DequeueManyOp requested ",
-                                              num_elements, " < 0 elements"),
-                      callback);
-
-    if (ctx->input_dtype(0) == DT_RESOURCE) {
-      OP_REQUIRES_OK_ASYNC(ctx,
-                           ctx->MatchSignature({DT_RESOURCE, DT_INT32},
-                                               queue->component_dtypes()),
-                           callback);
-    } else {
-      OP_REQUIRES_OK_ASYNC(ctx,
-                           ctx->MatchSignature({DT_STRING_REF, DT_INT32},
-                                               queue->component_dtypes()),
-                           callback);
-    }
-
-    queue->TryDequeueMany(
-        num_elements, ctx, false /* allow_small_batch */,
-        [ctx, callback](const QueueInterface::Tuple& tuple) {
-          if (!ctx->status().ok()) {
-            callback();
-            return;
-          }
-          OpOutputList output_components;
-          OP_REQUIRES_OK_ASYNC(
-              ctx, ctx->output_list("components", &output_components),
-              callback);
-          for (int i = 0; i < ctx->num_outputs(); ++i) {
-            output_components.set(i, tuple[i]);
-          }
-          callback();
-        });
-  }
-
-  ~DequeueManyOp() override {}
-
- private:
-  TF_DISALLOW_COPY_AND_ASSIGN(DequeueManyOp);
-};
-
 REGISTER_KERNEL_BUILDER(Name("QueueDequeueMany").Device(DEVICE_CPU),
                         DequeueManyOp);
 REGISTER_KERNEL_BUILDER(Name("QueueDequeueManyV2").Device(DEVICE_CPU),
                         DequeueManyOp);
 
-// Defines a DequeueUpToOp, the execution of which concatenates the
-// requested number of elements from the given Queue along the 0th
-// dimension, and emits the result as a single tuple of tensors.
-//
-// The difference between this op and DequeueMany is the handling when
-// the Queue is closed.  While the DequeueMany op will return if there
-// an error when there are less than num_elements elements left in the
-// closed queue, this op will return between 1 and
-// min(num_elements, elements_remaining_in_queue), and will not block.
-// If there are no elements left, then the standard DequeueMany error
-// is returned.
-//
-// This op only works if the underlying Queue implementation accepts
-// the allow_small_batch = true parameter to TryDequeueMany.
-// If it does not, an errors::Unimplemented exception is returned.
-//
-// The op has two inputs:
-// - Input 0: the handle to a queue.
-// - Input 1: the number of elements to dequeue.
-//
-// The op has k outputs, where k is the number of components in the
-// tuples stored in the given Queue, and output i is the ith component
-// of the dequeued tuple.
-//
-// The op has one attribute: allow_small_batch.  If the Queue supports
-// it, setting this to true causes the queue to return smaller
-// (possibly zero length) batches when it is closed, up to however
-// many elements are available when the op executes.  In this case,
-// the Queue does not block when closed.
-class DequeueUpToOp : public QueueAccessOpKernel {
- public:
-  explicit DequeueUpToOp(OpKernelConstruction* context)
-      : QueueAccessOpKernel(context) {}
-
- protected:
-  void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
-                    DoneCallback callback) override {
-    const Tensor& Tnum_elements = ctx->input(1);
-    int32 num_elements = Tnum_elements.flat<int32>()(0);
-
-    OP_REQUIRES_ASYNC(ctx, num_elements >= 0,
-                      errors::InvalidArgument("DequeueUpToOp requested ",
-                                              num_elements, " < 0 elements"),
-                      callback);
-
-    if (ctx->input_dtype(0) == DT_RESOURCE) {
-      OP_REQUIRES_OK_ASYNC(ctx,
-                           ctx->MatchSignature({DT_RESOURCE, DT_INT32},
-                                               queue->component_dtypes()),
-                           callback);
-    } else {
-      OP_REQUIRES_OK_ASYNC(ctx,
-                           ctx->MatchSignature({DT_STRING_REF, DT_INT32},
-                                               queue->component_dtypes()),
-                           callback);
-    }
-
-    queue->TryDequeueMany(
-        num_elements, ctx, true /* allow_small_batch */,
-        [ctx, callback](const QueueInterface::Tuple& tuple) {
-          if (!ctx->status().ok()) {
-            callback();
-            return;
-          }
-          OpOutputList output_components;
-          OP_REQUIRES_OK_ASYNC(
-              ctx, ctx->output_list("components", &output_components),
-              callback);
-          for (int i = 0; i < ctx->num_outputs(); ++i) {
-            output_components.set(i, tuple[i]);
-          }
-          callback();
-        });
-  }
-
-  ~DequeueUpToOp() override {}
-
- private:
-  TF_DISALLOW_COPY_AND_ASSIGN(DequeueUpToOp);
-};
-
 REGISTER_KERNEL_BUILDER(Name("QueueDequeueUpTo").Device(DEVICE_CPU),
                         DequeueUpToOp);
 REGISTER_KERNEL_BUILDER(Name("QueueDequeueUpToV2").Device(DEVICE_CPU),
                         DequeueUpToOp);
 
-// Defines a QueueCloseOp, which closes the given Queue. Closing a
-// Queue signals that no more elements will be enqueued in it.
-//
-// The op has one input, which is the handle of the appropriate Queue.
-class QueueCloseOp : public QueueOpKernel {
- public:
-  explicit QueueCloseOp(OpKernelConstruction* context)
-      : QueueOpKernel(context) {
-    OP_REQUIRES_OK(context, context->GetAttr("cancel_pending_enqueues",
-                                             &cancel_pending_enqueues_));
-  }
-
- protected:
-  void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
-                    DoneCallback callback) override {
-    queue->Close(ctx, cancel_pending_enqueues_, callback);
-  }
-
- private:
-  bool cancel_pending_enqueues_;
-  TF_DISALLOW_COPY_AND_ASSIGN(QueueCloseOp);
-};
-
 REGISTER_KERNEL_BUILDER(Name("QueueClose").Device(DEVICE_CPU), QueueCloseOp);
 REGISTER_KERNEL_BUILDER(Name("QueueCloseV2").Device(DEVICE_CPU), QueueCloseOp);
 
-// Defines a QueueSizeOp, which computes the number of elements in the
-// given Queue, and emits it as an output tensor.
-//
-// The op has one input, which is the handle of the appropriate Queue;
-// and one output, which is a single-element tensor containing the current
-// size of that Queue.
-class QueueSizeOp : public QueueOpKernel {
- public:
-  explicit QueueSizeOp(OpKernelConstruction* context)
-      : QueueOpKernel(context) {}
-
- protected:
-  void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
-                    DoneCallback callback) override {
-    Tensor* Tqueue_size = nullptr;
-    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &Tqueue_size));
-    Tqueue_size->flat<int32>().setConstant(queue->size());
-    callback();
-  }
-
- private:
-  TF_DISALLOW_COPY_AND_ASSIGN(QueueSizeOp);
-};
-
 REGISTER_KERNEL_BUILDER(Name("QueueSize").Device(DEVICE_CPU), QueueSizeOp);
 REGISTER_KERNEL_BUILDER(Name("QueueSizeV2").Device(DEVICE_CPU), QueueSizeOp);
 
-class QueueIsClosedOp : public QueueOpKernel {
- public:
-  explicit QueueIsClosedOp(OpKernelConstruction* context)
-      : QueueOpKernel(context) {}
-
- protected:
-  void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
-                    DoneCallback callback) override {
-    Tensor* Tqueue_is_closed = nullptr;
-    OP_REQUIRES_OK(ctx,
-                   ctx->allocate_output(0, TensorShape({}), &Tqueue_is_closed));
-    Tqueue_is_closed->flat<bool>().setConstant(queue->is_closed());
-    callback();
-  }
-
- private:
-  TF_DISALLOW_COPY_AND_ASSIGN(QueueIsClosedOp);
-};
-
 REGISTER_KERNEL_BUILDER(Name("QueueIsClosed").Device(DEVICE_CPU),
                         QueueIsClosedOp);
 REGISTER_KERNEL_BUILDER(Name("QueueIsClosedV2").Device(DEVICE_CPU),