From 710edb74e65ae2593040c8fff4934eeeb49e676f Mon Sep 17 00:00:00 2001
From: Illia Polosukhin <ilblackdragon@gmail.com>
Date: Wed, 25 May 2016 18:06:43 -0800
Subject: [PATCH] tf.learn: Added functionality to run over data only once
 (num_epochs=1).

Speed up read_batch_example by adding more reading ops before queueing (if reading is slow operation, before it would create a bottleneck on one op).

Speed up read_batch_features by adding parsing_num_threads of parsing ops and queue for aggregating them (currently only one parsing op was used syncronously with training).
Change: 123278248
---
 .../contrib/learn/python/learn/io/graph_io.py |  88 ++++++++----
 .../learn/python/learn/io/graph_io_test.py    | 135 ++++++++++++++----
 2 files changed, 167 insertions(+), 56 deletions(-)

diff --git a/tensorflow/contrib/learn/python/learn/io/graph_io.py b/tensorflow/contrib/learn/python/learn/io/graph_io.py
index b9fffb2fb0c..bd1f4f3c0e6 100644
--- a/tensorflow/contrib/learn/python/learn/io/graph_io.py
+++ b/tensorflow/contrib/learn/python/learn/io/graph_io.py
@@ -26,8 +26,9 @@ from tensorflow.python.training import input as input_ops
 
 
 def read_batch_examples(file_pattern, batch_size, reader,
-                        randomize_input=True, queue_capacity=10000,
-                        num_threads=1, name='dequeue_examples'):
+                        randomize_input=True, num_epochs=None,
+                        queue_capacity=10000, num_threads=1,
+                        name=None):
   """Adds operations to read, queue, batch `Example` protos.
 
   Given file pattern (or list of files), will setup a queue for file names,
@@ -46,6 +47,10 @@ def read_batch_examples(file_pattern, batch_size, reader,
     reader: A function or class that returns an object with
       `read` method, (filename tensor) -> (example tensor).
     randomize_input: Whether the input should be randomized.
+    num_epochs: Integer specifying the number of times to read through the
+      dataset. If `None`, cycles through the dataset forever.
+      NOTE - If specified, creates a variable that must be initialized, so call
+      `tf.initialize_all_variables()` as shown in the tests.
     queue_capacity: Capacity for input queue.
     num_threads: The number of threads enqueuing examples.
     name: Name of resulting op.
@@ -82,39 +87,47 @@ def read_batch_examples(file_pattern, batch_size, reader,
         (batch_size, queue_capacity))
   if (not num_threads) or (num_threads <= 0):
     raise ValueError('Invalid num_threads %s.' % num_threads)
+  if (num_epochs is not None) and (num_epochs <= 0):
+    raise ValueError('Invalid num_epochs %s.' % num_epochs)
 
-  with ops.name_scope(name) as scope:
+  with ops.op_scope([file_pattern], name, 'read_batch_examples') as scope:
     # Setup filename queue with shuffling.
     with ops.name_scope('file_name_queue') as file_name_queue_scope:
       file_name_queue = input_ops.string_input_producer(
           constant_op.constant(file_names, name='input'),
-          shuffle=randomize_input, name=file_name_queue_scope)
+          shuffle=randomize_input, num_epochs=num_epochs,
+          name=file_name_queue_scope)
 
-    # Create reader and set it to read from filename queue.
+    # Create readers, one per thread and set them to read from filename queue.
     with ops.name_scope('read'):
-      _, example_proto = reader().read(file_name_queue)
+      example_list = []
+      for _ in range(num_threads):
+        _, example_proto = reader().read(file_name_queue)
+        example_list.append([example_proto])
 
-    # Setup batching queue.
+    # Setup batching queue given list of read example tensors.
     if randomize_input:
       if isinstance(batch_size, ops.Tensor):
         min_after_dequeue = int(queue_capacity * 0.4)
       else:
         min_after_dequeue = max(queue_capacity - (3 * batch_size), batch_size)
-      examples = input_ops.shuffle_batch(
-          [example_proto], batch_size, capacity=queue_capacity,
-          num_threads=num_threads, min_after_dequeue=min_after_dequeue,
+      examples = input_ops.shuffle_batch_join(
+          example_list, batch_size, capacity=queue_capacity,
+          min_after_dequeue=min_after_dequeue,
           name=scope)
     else:
-      examples = input_ops.batch(
-          [example_proto], batch_size, capacity=queue_capacity,
-          num_threads=num_threads, name=scope)
+      examples = input_ops.batch_join(
+          example_list, batch_size, capacity=queue_capacity,
+          name=scope)
 
     return examples
 
 
 def read_batch_features(file_pattern, batch_size, features, reader,
-                        randomize_input=True, queue_capacity=10000,
-                        num_threads=1, name='dequeue_examples'):
+                        randomize_input=True, num_epochs=None,
+                        queue_capacity=10000, reader_num_threads=1,
+                        parser_num_threads=1,
+                        name=None):
   """Adds operations to read, queue, batch and parse `Example` protos.
 
   Given file pattern (or list of files), will setup a queue for file names,
@@ -136,8 +149,13 @@ def read_batch_features(file_pattern, batch_size, features, reader,
     reader: A function or class that returns an object with
       `read` method, (filename tensor) -> (example tensor).
     randomize_input: Whether the input should be randomized.
+    num_epochs: Integer specifying the number of times to read through the
+      dataset. If None, cycles through the dataset forever. NOTE - If specified,
+      creates a variable that must be initialized, so call
+      tf.initialize_all_variables() as shown in the tests.
     queue_capacity: Capacity for input queue.
-    num_threads: The number of threads enqueuing examples.
+    reader_num_threads: The number of threads to read examples.
+    parser_num_threads: The number of threads to parse examples.
     name: Name of resulting op.
 
   Returns:
@@ -146,17 +164,29 @@ def read_batch_features(file_pattern, batch_size, features, reader,
   Raises:
     ValueError: for invalid inputs.
   """
-  examples = read_batch_examples(
-      file_pattern, batch_size, reader, randomize_input,
-      queue_capacity, num_threads, name=name)
+  with ops.op_scope([file_pattern], name, 'read_batch_features') as scope:
+    examples = read_batch_examples(
+        file_pattern, batch_size, reader, randomize_input=randomize_input,
+        num_epochs=num_epochs, queue_capacity=queue_capacity,
+        num_threads=reader_num_threads, name=scope)
 
-  # Parse features into tensors.
-  return parsing_ops.parse_example(examples, features)
+    # Parse features into tensors in many threads and put on the queue.
+    features_list = []
+    for _ in range(parser_num_threads):
+      features_list.append(parsing_ops.parse_example(examples, features))
+    return input_ops.batch_join(
+        features_list,
+        batch_size=batch_size,
+        capacity=queue_capacity,
+        enqueue_many=True,
+        name='parse_example_batch_join')
 
 
 def read_batch_record_features(file_pattern, batch_size, features,
-                               randomize_input=True, queue_capacity=10000,
-                               num_threads=1, name='dequeue_record_examples'):
+                               randomize_input=True, num_epochs=None,
+                               queue_capacity=10000, reader_num_threads=1,
+                               parser_num_threads=1,
+                               name='dequeue_record_examples'):
   """Reads TFRecord, queues, batches and parses `Example` proto.
 
   See more detailed description in `read_examples`.
@@ -168,8 +198,13 @@ def read_batch_record_features(file_pattern, batch_size, features,
     features: A `dict` mapping feature keys to `FixedLenFeature` or
       `VarLenFeature` values.
     randomize_input: Whether the input should be randomized.
+    num_epochs: Integer specifying the number of times to read through the
+      dataset. If None, cycles through the dataset forever. NOTE - If specified,
+      creates a variable that must be initialized, so call
+      tf.initialize_all_variables() as shown in the tests.
     queue_capacity: Capacity for input queue.
-    num_threads: The number of threads enqueuing examples.
+    reader_num_threads: The number of threads to read examples.
+    parser_num_threads: The number of threads to parse examples.
     name: Name of resulting op.
 
   Returns:
@@ -181,5 +216,6 @@ def read_batch_record_features(file_pattern, batch_size, features,
   return read_batch_features(
       file_pattern=file_pattern, batch_size=batch_size, features=features,
       reader=io_ops.TFRecordReader,
-      randomize_input=randomize_input,
-      queue_capacity=queue_capacity, num_threads=num_threads, name=name)
+      randomize_input=randomize_input, num_epochs=num_epochs,
+      queue_capacity=queue_capacity, reader_num_threads=reader_num_threads,
+      parser_num_threads=parser_num_threads, name=name)
diff --git a/tensorflow/contrib/learn/python/learn/io/graph_io_test.py b/tensorflow/contrib/learn/python/learn/io/graph_io_test.py
index 175c29ac4ed..a3d70164ca5 100644
--- a/tensorflow/contrib/learn/python/learn/io/graph_io_test.py
+++ b/tensorflow/contrib/learn/python/learn/io/graph_io_test.py
@@ -17,10 +17,13 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import os
 import random
+import tempfile
 
 import tensorflow as tf
 
+from tensorflow.python.framework import errors
 from tensorflow.python.framework import test_util
 from tensorflow.python.platform import gfile
 
@@ -55,44 +58,83 @@ class GraphIOTest(tf.test.TestCase):
 
     self.assertRaisesRegexp(
         ValueError, "No files match",
-        tf.contrib.learn.io.read_batch_features,
-        _INVALID_FILE_PATTERN, default_batch_size, None, tf.TFRecordReader,
-        False, queue_capacity,
-        num_threads, name)
+        tf.contrib.learn.io.read_batch_examples,
+        _INVALID_FILE_PATTERN, default_batch_size, tf.TFRecordReader,
+        False, num_epochs=None, queue_capacity=queue_capacity,
+        num_threads=num_threads, name=name)
     self.assertRaisesRegexp(
         ValueError, "Invalid batch_size",
-        tf.contrib.learn.io.read_batch_features,
-        _VALID_FILE_PATTERN, None, None, tf.TFRecordReader,
-        False, queue_capacity, num_threads, name)
+        tf.contrib.learn.io.read_batch_examples,
+        _VALID_FILE_PATTERN, None, tf.TFRecordReader,
+        False, num_epochs=None, queue_capacity=queue_capacity,
+        num_threads=num_threads, name=name)
     self.assertRaisesRegexp(
         ValueError, "Invalid batch_size",
-        tf.contrib.learn.io.read_batch_features,
-        _VALID_FILE_PATTERN, -1, None, tf.TFRecordReader,
-        False, queue_capacity, num_threads, name)
+        tf.contrib.learn.io.read_batch_examples,
+        _VALID_FILE_PATTERN, -1, tf.TFRecordReader,
+        False, num_epochs=None, queue_capacity=queue_capacity,
+        num_threads=num_threads, name=name)
     self.assertRaisesRegexp(
         ValueError, "Invalid queue_capacity",
-        tf.contrib.learn.io.read_batch_features,
-        _VALID_FILE_PATTERN, default_batch_size, None, tf.TFRecordReader,
-        False, None, num_threads, name)
+        tf.contrib.learn.io.read_batch_examples,
+        _VALID_FILE_PATTERN, default_batch_size, tf.TFRecordReader,
+        False, num_epochs=None, queue_capacity=None,
+        num_threads=num_threads, name=name)
     self.assertRaisesRegexp(
         ValueError, "Invalid num_threads",
-        tf.contrib.learn.io.read_batch_features,
-        _VALID_FILE_PATTERN, default_batch_size, None, tf.TFRecordReader,
-        False, queue_capacity, None,
-        name)
+        tf.contrib.learn.io.read_batch_examples,
+        _VALID_FILE_PATTERN, default_batch_size, tf.TFRecordReader,
+        False, num_epochs=None, queue_capacity=queue_capacity,
+        num_threads=None, name=name)
     self.assertRaisesRegexp(
         ValueError, "Invalid num_threads",
-        tf.contrib.learn.io.read_batch_features,
-        _VALID_FILE_PATTERN, default_batch_size, None, tf.TFRecordReader,
-        False, queue_capacity, -1,
-        name)
+        tf.contrib.learn.io.read_batch_examples,
+        _VALID_FILE_PATTERN, default_batch_size, tf.TFRecordReader,
+        False, num_epochs=None, queue_capacity=queue_capacity,
+        num_threads=-1, name=name)
     self.assertRaisesRegexp(
         ValueError, "Invalid batch_size",
-        tf.contrib.learn.io.read_batch_features,
-        _VALID_FILE_PATTERN, queue_capacity + 1, None, tf.TFRecordReader,
-        False, queue_capacity, 1, name)
+        tf.contrib.learn.io.read_batch_examples,
+        _VALID_FILE_PATTERN, queue_capacity + 1, tf.TFRecordReader,
+        False, num_epochs=None, queue_capacity=queue_capacity,
+        num_threads=1, name=name)
+    self.assertRaisesRegexp(
+        ValueError, "Invalid num_epochs",
+        tf.contrib.learn.io.read_batch_examples,
+        _VALID_FILE_PATTERN, default_batch_size, tf.TFRecordReader,
+        False, num_epochs=-1, queue_capacity=queue_capacity, num_threads=1,
+        name=name)
 
-  def test_batch_tf_record(self):
+  def test_batch_record_features(self):
+    batch_size = 17
+    queue_capacity = 1234
+    name = "my_batch"
+    features = {"feature": tf.FixedLenFeature(shape=[0], dtype=tf.float32)}
+
+    with tf.Graph().as_default() as g, self.test_session(graph=g) as sess:
+      features = tf.contrib.learn.io.read_batch_record_features(
+          _VALID_FILE_PATTERN, batch_size, features, randomize_input=False,
+          queue_capacity=queue_capacity, reader_num_threads=2,
+          parser_num_threads=2, name=name)
+      self.assertEquals("%s/parse_example_batch_join:0" % name,
+                        features["feature"].name)
+      file_name_queue_name = "%s/file_name_queue" % name
+      file_names_name = "%s/input" % file_name_queue_name
+      example_queue_name = "%s/fifo_queue" % name
+      parse_example_queue_name = "%s/parse_example_batch_join" % name
+      op_nodes = test_util.assert_ops_in_graph({
+          file_names_name: "Const",
+          file_name_queue_name: "FIFOQueue",
+          "%s/read/TFRecordReader" % name: "TFRecordReader",
+          example_queue_name: "FIFOQueue",
+          parse_example_queue_name: "QueueDequeueMany",
+          name: "QueueDequeueMany"
+      }, g)
+      self.assertAllEqual(_FILE_NAMES, sess.run(["%s:0" % file_names_name])[0])
+      self.assertEqual(
+          queue_capacity, op_nodes[example_queue_name].attr["capacity"].i)
+
+  def test_one_epoch(self):
     batch_size = 17
     queue_capacity = 1234
     name = "my_batch"
@@ -100,20 +142,25 @@ class GraphIOTest(tf.test.TestCase):
     with tf.Graph().as_default() as g, self.test_session(graph=g) as sess:
       inputs = tf.contrib.learn.io.read_batch_examples(
           _VALID_FILE_PATTERN, batch_size,
-          reader=tf.TFRecordReader, randomize_input=False,
+          reader=tf.TFRecordReader, randomize_input=True,
+          num_epochs=1,
           queue_capacity=queue_capacity, name=name)
       self.assertEquals("%s:0" % name, inputs.name)
       file_name_queue_name = "%s/file_name_queue" % name
+      file_name_queue_limit_name = (
+          "%s/limit_epochs/epochs" % file_name_queue_name)
       file_names_name = "%s/input" % file_name_queue_name
-      example_queue_name = "%s/fifo_queue" % name
+      example_queue_name = "%s/random_shuffle_queue" % name
       op_nodes = test_util.assert_ops_in_graph({
           file_names_name: "Const",
           file_name_queue_name: "FIFOQueue",
           "%s/read/TFRecordReader" % name: "TFRecordReader",
-          example_queue_name: "FIFOQueue",
-          name: "QueueDequeueMany"
+          example_queue_name: "RandomShuffleQueue",
+          name: "QueueDequeueMany",
+          file_name_queue_limit_name: "Variable"
       }, g)
-      self.assertAllEqual(_FILE_NAMES, sess.run(["%s:0" % file_names_name])[0])
+      self.assertEqual(
+          set(_FILE_NAMES), set(sess.run(["%s:0" % file_names_name])[0]))
       self.assertEqual(
           queue_capacity, op_nodes[example_queue_name].attr["capacity"].i)
 
@@ -143,6 +190,34 @@ class GraphIOTest(tf.test.TestCase):
       self.assertEqual(
           queue_capacity, op_nodes[example_queue_name].attr["capacity"].i)
 
+  def test_read_csv(self):
+    gfile.Glob = self._orig_glob
+    tempdir = tempfile.mkdtemp()
+    filename = os.path.join(tempdir, "file.csv")
+    gfile.Open(filename, "w").write("ABC\nDEF\nGHK\n")
+
+    batch_size = 1
+    queue_capacity = 5
+    name = "my_batch"
+
+    with tf.Graph().as_default() as g, self.test_session(graph=g) as session:
+      inputs = tf.contrib.learn.io.read_batch_examples(
+          filename, batch_size,
+          reader=tf.TextLineReader, randomize_input=False,
+          num_epochs=1, queue_capacity=queue_capacity, name=name)
+      session.run(tf.initialize_all_variables())
+
+      coord = tf.train.Coordinator()
+      tf.train.start_queue_runners(session, coord=coord)
+
+      self.assertEqual(session.run(inputs), "ABC")
+      self.assertEqual(session.run(inputs), "DEF")
+      self.assertEqual(session.run(inputs), "GHK")
+      with self.assertRaises(errors.OutOfRangeError):
+        session.run(inputs)
+
+      coord.request_stop()
+
 
 if __name__ == "__main__":
   tf.test.main()