From bdad5cdcbe0bfcc0b235110a9f76ae80053bc80d Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Wed, 3 Aug 2016 16:38:23 -0800
Subject: [PATCH] Add a queue for the output tensors from parse example Change:
 129279917

---
 .../learn/python/learn/learn_io/graph_io.py   | 112 ++++++++++++++----
 1 file changed, 88 insertions(+), 24 deletions(-)

diff --git a/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py b/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py
index 1709e428fc2..bf5e62cb4c0 100644
--- a/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py
+++ b/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py
@@ -20,12 +20,17 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.ops import data_flow_ops
 from tensorflow.python.ops import io_ops
+from tensorflow.python.ops import logging_ops
+from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import parsing_ops
 from tensorflow.python.platform import gfile
+from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training import input as input_ops
-
+from tensorflow.python.training import queue_runner
 
 # Default name for key in the feature dict.
 KEY_FEATURE_NAME = '__key__'
@@ -219,11 +224,18 @@ def read_keyed_batch_examples(
     return queued_examples_with_keys
 
 
-def read_keyed_batch_features(
-    file_pattern, batch_size, features, reader,
-    randomize_input=True, num_epochs=None,
-    queue_capacity=10000, reader_num_threads=1,
-    parser_num_threads=1, name=None):
+def read_keyed_batch_features(file_pattern,
+                              batch_size,
+                              features,
+                              reader,
+                              randomize_input=True,
+                              num_epochs=None,
+                              queue_capacity=10000,
+                              reader_num_threads=1,
+                              feature_queue_capacity=100,
+                              num_queue_runners=2,
+                              parser_num_threads=None,
+                              name=None):
   """Adds operations to read, queue, batch and parse `Example` protos.
 
   Given file pattern (or list of files), will setup a queue for file names,
@@ -251,7 +263,12 @@ def read_keyed_batch_features(
       tf.initialize_local_variables() as shown in the tests.
     queue_capacity: Capacity for input queue.
     reader_num_threads: The number of threads to read examples.
-    parser_num_threads: The number of threads to parse examples.
+    feature_queue_capacity: Capacity of the parsed features queue.
+    num_queue_runners: Number of queue runners to start for the feature queue.
+      Adding multiple queue runners for the parsed example queue helps maintain
+      a full queue when the subsequent computations overall are cheaper than
+      parsing.
+    parser_num_threads: (Deprecated) The number of threads to parse examples.
     name: Name of resulting op.
 
   Returns:
@@ -261,6 +278,11 @@ def read_keyed_batch_features(
   Raises:
     ValueError: for invalid inputs.
   """
+
+  if parser_num_threads:
+    # TODO(sibyl-Aix6ihai): Remove on Sept 3 2016.
+    logging.warning('parser_num_threads is deprecated, it will be removed on'
+                    'Sept 3 2016')
   with ops.op_scope([file_pattern], name, 'read_batch_features') as scope:
     keys, examples = read_keyed_batch_examples(
         file_pattern, batch_size, reader, randomize_input=randomize_input,
@@ -268,24 +290,66 @@ def read_keyed_batch_features(
         num_threads=reader_num_threads, read_batch_size=batch_size,
         name=scope)
 
-    if parser_num_threads == 1:
-      # Avoid queue overhead for single thread
-      return keys, parsing_ops.parse_example(examples, features)
+    # Parse the example.
+    feature_map = parsing_ops.parse_example(examples, features)
 
-    # Parse features into tensors in many threads and put on the queue.
-    features_list = []
-    for _ in range(parser_num_threads):
-      feature_dict = parsing_ops.parse_example(examples, features)
-      feature_dict[KEY_FEATURE_NAME] = keys
-      features_list.append(feature_dict)
-    queued_features = input_ops.batch_join(
-        features_list,
-        batch_size=batch_size,
-        capacity=queue_capacity,
-        enqueue_many=True,
-        name='parse_example_batch_join')
-    queued_keys = queued_features.pop(KEY_FEATURE_NAME)
-    return queued_keys, queued_features
+    # Collect the parsed tensors to push through the queue. A sparse tensor is
+    # flattened into its three component tensors (indices, values, shape).
+    tensors_to_enqueue = []
+    # Each entry contains the key, and a boolean which indicates whether the
+    # tensor was a sparse tensor.
+    tensors_mapping = []
+    # TODO(sibyl-Aix6ihai): Most of the functionality here is about pushing sparse
+    # tensors into a queue. This could be taken care of somewhere else so others
+    # can reuse it. Also, QueueBase may be extended to handle sparse tensors
+    # directly.
+    for key, tensor in feature_map.iteritems():
+      if isinstance(tensor, ops.SparseTensor):
+        tensors_mapping.append((key, True))
+        tensors_to_enqueue.extend([tensor.indices, tensor.values, tensor.shape])
+      else:
+        tensors_mapping.append((key, False))
+        tensors_to_enqueue.append(tensor)
+    tensors_to_enqueue.append(keys)
+
+    queue_dtypes = [x.dtype for x in tensors_to_enqueue]
+    input_queue = data_flow_ops.FIFOQueue(feature_queue_capacity, queue_dtypes)
+
+    # Add a summary op to debug if our feature queue is full or not.
+    logging_ops.scalar_summary('queue/parsed_features/%s/fraction_of_%d_full' %
+                               (input_queue.name, feature_queue_capacity),
+                               math_ops.cast(input_queue.size(), dtypes.float32)
+                               * (1. / feature_queue_capacity))
+
+    # Add multiple queue runners so that the queue is always full. Adding more
+    # than two queue-runners may hog the cpu on the worker to fill up the queue.
+    for _ in range(num_queue_runners):
+      queue_runner.add_queue_runner(
+          queue_runner.QueueRunner(input_queue, [input_queue.enqueue(
+              tensors_to_enqueue)]))
+
+    dequeued_tensors = input_queue.dequeue()
+
+    # Reset shapes on dequeued tensors.
+    for i in range(len(tensors_to_enqueue)):
+      dequeued_tensors[i].set_shape(tensors_to_enqueue[i].get_shape())
+
+    # Recreate feature mapping according to the original dictionary.
+    dequeued_feature_map = {}
+    index = 0
+    for key, is_sparse_tensor in tensors_mapping:
+      if is_sparse_tensor:
+        # Three tensors are (indices, values, shape).
+        dequeued_feature_map[key] = ops.SparseTensor(
+            dequeued_tensors[index], dequeued_tensors[index + 1],
+            dequeued_tensors[index + 2])
+        index += 3
+      else:
+        dequeued_feature_map[key] = dequeued_tensors[index]
+        index += 1
+    dequeued_keys = dequeued_tensors[-1]
+
+    return dequeued_keys, dequeued_feature_map
 
 
 def read_batch_features(file_pattern, batch_size, features, reader,