Add a queue for the output tensors from parse example
Change: 129279917
This commit is contained in:
parent
f261f1a572
commit
bdad5cdcbe
@ -20,12 +20,17 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.ops import data_flow_ops
|
||||||
from tensorflow.python.ops import io_ops
|
from tensorflow.python.ops import io_ops
|
||||||
|
from tensorflow.python.ops import logging_ops
|
||||||
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import parsing_ops
|
from tensorflow.python.ops import parsing_ops
|
||||||
from tensorflow.python.platform import gfile
|
from tensorflow.python.platform import gfile
|
||||||
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
from tensorflow.python.training import input as input_ops
|
from tensorflow.python.training import input as input_ops
|
||||||
|
from tensorflow.python.training import queue_runner
|
||||||
|
|
||||||
# Default name for key in the feature dict.
|
# Default name for key in the feature dict.
|
||||||
KEY_FEATURE_NAME = '__key__'
|
KEY_FEATURE_NAME = '__key__'
|
||||||
@ -219,11 +224,18 @@ def read_keyed_batch_examples(
|
|||||||
return queued_examples_with_keys
|
return queued_examples_with_keys
|
||||||
|
|
||||||
|
|
||||||
def read_keyed_batch_features(
|
def read_keyed_batch_features(file_pattern,
|
||||||
file_pattern, batch_size, features, reader,
|
batch_size,
|
||||||
randomize_input=True, num_epochs=None,
|
features,
|
||||||
queue_capacity=10000, reader_num_threads=1,
|
reader,
|
||||||
parser_num_threads=1, name=None):
|
randomize_input=True,
|
||||||
|
num_epochs=None,
|
||||||
|
queue_capacity=10000,
|
||||||
|
reader_num_threads=1,
|
||||||
|
feature_queue_capacity=100,
|
||||||
|
num_queue_runners=2,
|
||||||
|
parser_num_threads=None,
|
||||||
|
name=None):
|
||||||
"""Adds operations to read, queue, batch and parse `Example` protos.
|
"""Adds operations to read, queue, batch and parse `Example` protos.
|
||||||
|
|
||||||
Given file pattern (or list of files), will setup a queue for file names,
|
Given file pattern (or list of files), will setup a queue for file names,
|
||||||
@ -251,7 +263,12 @@ def read_keyed_batch_features(
|
|||||||
tf.initialize_local_variables() as shown in the tests.
|
tf.initialize_local_variables() as shown in the tests.
|
||||||
queue_capacity: Capacity for input queue.
|
queue_capacity: Capacity for input queue.
|
||||||
reader_num_threads: The number of threads to read examples.
|
reader_num_threads: The number of threads to read examples.
|
||||||
parser_num_threads: The number of threads to parse examples.
|
feature_queue_capacity: Capacity of the parsed features queue.
|
||||||
|
num_queue_runners: Number of queue runners to start for the feature queue.
|
||||||
|
Adding multiple queue runners for the parsed example queue helps maintain
|
||||||
|
a full queue when the subsequent computations overall are cheaper than
|
||||||
|
parsing.
|
||||||
|
parser_num_threads: (Deprecated) The number of threads to parse examples.
|
||||||
name: Name of resulting op.
|
name: Name of resulting op.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -261,6 +278,11 @@ def read_keyed_batch_features(
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError: for invalid inputs.
|
ValueError: for invalid inputs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if parser_num_threads:
|
||||||
|
# TODO(sibyl-Aix6ihai): Remove on Sept 3 2016.
|
||||||
|
logging.warning('parser_num_threads is deprecated, it will be removed on'
|
||||||
|
'Sept 3 2016')
|
||||||
with ops.op_scope([file_pattern], name, 'read_batch_features') as scope:
|
with ops.op_scope([file_pattern], name, 'read_batch_features') as scope:
|
||||||
keys, examples = read_keyed_batch_examples(
|
keys, examples = read_keyed_batch_examples(
|
||||||
file_pattern, batch_size, reader, randomize_input=randomize_input,
|
file_pattern, batch_size, reader, randomize_input=randomize_input,
|
||||||
@ -268,24 +290,66 @@ def read_keyed_batch_features(
|
|||||||
num_threads=reader_num_threads, read_batch_size=batch_size,
|
num_threads=reader_num_threads, read_batch_size=batch_size,
|
||||||
name=scope)
|
name=scope)
|
||||||
|
|
||||||
if parser_num_threads == 1:
|
# Parse the example.
|
||||||
# Avoid queue overhead for single thread
|
feature_map = parsing_ops.parse_example(examples, features)
|
||||||
return keys, parsing_ops.parse_example(examples, features)
|
|
||||||
|
|
||||||
# Parse features into tensors in many threads and put on the queue.
|
# Let's also add preprocessed tensors into the queue types for each item of
|
||||||
features_list = []
|
# the queue.
|
||||||
for _ in range(parser_num_threads):
|
tensors_to_enqueue = []
|
||||||
feature_dict = parsing_ops.parse_example(examples, features)
|
# Each entry contains the key, and a boolean which indicates whether the
|
||||||
feature_dict[KEY_FEATURE_NAME] = keys
|
# tensor was a sparse tensor.
|
||||||
features_list.append(feature_dict)
|
tensors_mapping = []
|
||||||
queued_features = input_ops.batch_join(
|
# TODO(sibyl-Aix6ihai): Most of the functionality here is about pushing sparse
|
||||||
features_list,
|
# tensors into a queue. This could be taken care of somewhere else so others
|
||||||
batch_size=batch_size,
|
# can reuse it. Also, QueueBase may be extended to handle sparse tensors
|
||||||
capacity=queue_capacity,
|
# directly.
|
||||||
enqueue_many=True,
|
for key, tensor in feature_map.iteritems():
|
||||||
name='parse_example_batch_join')
|
if isinstance(tensor, ops.SparseTensor):
|
||||||
queued_keys = queued_features.pop(KEY_FEATURE_NAME)
|
tensors_mapping.append((key, True))
|
||||||
return queued_keys, queued_features
|
tensors_to_enqueue.extend([tensor.indices, tensor.values, tensor.shape])
|
||||||
|
else:
|
||||||
|
tensors_mapping.append((key, False))
|
||||||
|
tensors_to_enqueue.append(tensor)
|
||||||
|
tensors_to_enqueue.append(keys)
|
||||||
|
|
||||||
|
queue_dtypes = [x.dtype for x in tensors_to_enqueue]
|
||||||
|
input_queue = data_flow_ops.FIFOQueue(feature_queue_capacity, queue_dtypes)
|
||||||
|
|
||||||
|
# Add a summary op to debug if our feature queue is full or not.
|
||||||
|
logging_ops.scalar_summary('queue/parsed_features/%s/fraction_of_%d_full' %
|
||||||
|
(input_queue.name, feature_queue_capacity),
|
||||||
|
math_ops.cast(input_queue.size(), dtypes.float32)
|
||||||
|
* (1. / feature_queue_capacity))
|
||||||
|
|
||||||
|
# Add multiple queue runners so that the queue is always full. Adding more
|
||||||
|
# than two queue-runners may hog the cpu on the worker to fill up the queue.
|
||||||
|
for _ in range(num_queue_runners):
|
||||||
|
queue_runner.add_queue_runner(
|
||||||
|
queue_runner.QueueRunner(input_queue, [input_queue.enqueue(
|
||||||
|
tensors_to_enqueue)]))
|
||||||
|
|
||||||
|
dequeued_tensors = input_queue.dequeue()
|
||||||
|
|
||||||
|
# Reset shapes on dequeued tensors.
|
||||||
|
for i in range(len(tensors_to_enqueue)):
|
||||||
|
dequeued_tensors[i].set_shape(tensors_to_enqueue[i].get_shape())
|
||||||
|
|
||||||
|
# Recreate feature mapping according to the original dictionary.
|
||||||
|
dequeued_feature_map = {}
|
||||||
|
index = 0
|
||||||
|
for key, is_sparse_tensor in tensors_mapping:
|
||||||
|
if is_sparse_tensor:
|
||||||
|
# Three tensors are (indices, values, shape).
|
||||||
|
dequeued_feature_map[key] = ops.SparseTensor(
|
||||||
|
dequeued_tensors[index], dequeued_tensors[index + 1],
|
||||||
|
dequeued_tensors[index + 2])
|
||||||
|
index += 3
|
||||||
|
else:
|
||||||
|
dequeued_feature_map[key] = dequeued_tensors[index]
|
||||||
|
index += 1
|
||||||
|
dequeued_keys = dequeued_tensors[-1]
|
||||||
|
|
||||||
|
return dequeued_keys, dequeued_feature_map
|
||||||
|
|
||||||
|
|
||||||
def read_batch_features(file_pattern, batch_size, features, reader,
|
def read_batch_features(file_pattern, batch_size, features, reader,
|
||||||
|
Loading…
Reference in New Issue
Block a user