tf.learn: Added functionality to run over data only once (num_epochs=1).
Speed up read_batch_examples by adding more reading ops before queueing (if reading is a slow operation, it previously created a bottleneck on a single op). Speed up read_batch_features by adding parser_num_threads parsing ops and a queue for aggregating them (previously only one parsing op was used, synchronously with training). Change: 123278248
This commit is contained in:
parent
2b744876db
commit
710edb74e6
@ -26,8 +26,9 @@ from tensorflow.python.training import input as input_ops
|
||||
|
||||
|
||||
def read_batch_examples(file_pattern, batch_size, reader,
|
||||
randomize_input=True, queue_capacity=10000,
|
||||
num_threads=1, name='dequeue_examples'):
|
||||
randomize_input=True, num_epochs=None,
|
||||
queue_capacity=10000, num_threads=1,
|
||||
name=None):
|
||||
"""Adds operations to read, queue, batch `Example` protos.
|
||||
|
||||
Given file pattern (or list of files), will setup a queue for file names,
|
||||
@ -46,6 +47,10 @@ def read_batch_examples(file_pattern, batch_size, reader,
|
||||
reader: A function or class that returns an object with
|
||||
`read` method, (filename tensor) -> (example tensor).
|
||||
randomize_input: Whether the input should be randomized.
|
||||
num_epochs: Integer specifying the number of times to read through the
|
||||
dataset. If `None`, cycles through the dataset forever.
|
||||
NOTE - If specified, creates a variable that must be initialized, so call
|
||||
`tf.initialize_all_variables()` as shown in the tests.
|
||||
queue_capacity: Capacity for input queue.
|
||||
num_threads: The number of threads enqueuing examples.
|
||||
name: Name of resulting op.
|
||||
@ -82,39 +87,47 @@ def read_batch_examples(file_pattern, batch_size, reader,
|
||||
(batch_size, queue_capacity))
|
||||
if (not num_threads) or (num_threads <= 0):
|
||||
raise ValueError('Invalid num_threads %s.' % num_threads)
|
||||
if (num_epochs is not None) and (num_epochs <= 0):
|
||||
raise ValueError('Invalid num_epochs %s.' % num_epochs)
|
||||
|
||||
with ops.name_scope(name) as scope:
|
||||
with ops.op_scope([file_pattern], name, 'read_batch_examples') as scope:
|
||||
# Setup filename queue with shuffling.
|
||||
with ops.name_scope('file_name_queue') as file_name_queue_scope:
|
||||
file_name_queue = input_ops.string_input_producer(
|
||||
constant_op.constant(file_names, name='input'),
|
||||
shuffle=randomize_input, name=file_name_queue_scope)
|
||||
shuffle=randomize_input, num_epochs=num_epochs,
|
||||
name=file_name_queue_scope)
|
||||
|
||||
# Create reader and set it to read from filename queue.
|
||||
# Create readers, one per thread and set them to read from filename queue.
|
||||
with ops.name_scope('read'):
|
||||
_, example_proto = reader().read(file_name_queue)
|
||||
example_list = []
|
||||
for _ in range(num_threads):
|
||||
_, example_proto = reader().read(file_name_queue)
|
||||
example_list.append([example_proto])
|
||||
|
||||
# Setup batching queue.
|
||||
# Setup batching queue given list of read example tensors.
|
||||
if randomize_input:
|
||||
if isinstance(batch_size, ops.Tensor):
|
||||
min_after_dequeue = int(queue_capacity * 0.4)
|
||||
else:
|
||||
min_after_dequeue = max(queue_capacity - (3 * batch_size), batch_size)
|
||||
examples = input_ops.shuffle_batch(
|
||||
[example_proto], batch_size, capacity=queue_capacity,
|
||||
num_threads=num_threads, min_after_dequeue=min_after_dequeue,
|
||||
examples = input_ops.shuffle_batch_join(
|
||||
example_list, batch_size, capacity=queue_capacity,
|
||||
min_after_dequeue=min_after_dequeue,
|
||||
name=scope)
|
||||
else:
|
||||
examples = input_ops.batch(
|
||||
[example_proto], batch_size, capacity=queue_capacity,
|
||||
num_threads=num_threads, name=scope)
|
||||
examples = input_ops.batch_join(
|
||||
example_list, batch_size, capacity=queue_capacity,
|
||||
name=scope)
|
||||
|
||||
return examples
|
||||
|
||||
|
||||
def read_batch_features(file_pattern, batch_size, features, reader,
|
||||
randomize_input=True, queue_capacity=10000,
|
||||
num_threads=1, name='dequeue_examples'):
|
||||
randomize_input=True, num_epochs=None,
|
||||
queue_capacity=10000, reader_num_threads=1,
|
||||
parser_num_threads=1,
|
||||
name=None):
|
||||
"""Adds operations to read, queue, batch and parse `Example` protos.
|
||||
|
||||
Given file pattern (or list of files), will setup a queue for file names,
|
||||
@ -136,8 +149,13 @@ def read_batch_features(file_pattern, batch_size, features, reader,
|
||||
reader: A function or class that returns an object with
|
||||
`read` method, (filename tensor) -> (example tensor).
|
||||
randomize_input: Whether the input should be randomized.
|
||||
num_epochs: Integer specifying the number of times to read through the
|
||||
dataset. If None, cycles through the dataset forever. NOTE - If specified,
|
||||
creates a variable that must be initialized, so call
|
||||
tf.initialize_all_variables() as shown in the tests.
|
||||
queue_capacity: Capacity for input queue.
|
||||
num_threads: The number of threads enqueuing examples.
|
||||
reader_num_threads: The number of threads to read examples.
|
||||
parser_num_threads: The number of threads to parse examples.
|
||||
name: Name of resulting op.
|
||||
|
||||
Returns:
|
||||
@ -146,17 +164,29 @@ def read_batch_features(file_pattern, batch_size, features, reader,
|
||||
Raises:
|
||||
ValueError: for invalid inputs.
|
||||
"""
|
||||
examples = read_batch_examples(
|
||||
file_pattern, batch_size, reader, randomize_input,
|
||||
queue_capacity, num_threads, name=name)
|
||||
with ops.op_scope([file_pattern], name, 'read_batch_features') as scope:
|
||||
examples = read_batch_examples(
|
||||
file_pattern, batch_size, reader, randomize_input=randomize_input,
|
||||
num_epochs=num_epochs, queue_capacity=queue_capacity,
|
||||
num_threads=reader_num_threads, name=scope)
|
||||
|
||||
# Parse features into tensors.
|
||||
return parsing_ops.parse_example(examples, features)
|
||||
# Parse features into tensors in many threads and put on the queue.
|
||||
features_list = []
|
||||
for _ in range(parser_num_threads):
|
||||
features_list.append(parsing_ops.parse_example(examples, features))
|
||||
return input_ops.batch_join(
|
||||
features_list,
|
||||
batch_size=batch_size,
|
||||
capacity=queue_capacity,
|
||||
enqueue_many=True,
|
||||
name='parse_example_batch_join')
|
||||
|
||||
|
||||
def read_batch_record_features(file_pattern, batch_size, features,
|
||||
randomize_input=True, queue_capacity=10000,
|
||||
num_threads=1, name='dequeue_record_examples'):
|
||||
randomize_input=True, num_epochs=None,
|
||||
queue_capacity=10000, reader_num_threads=1,
|
||||
parser_num_threads=1,
|
||||
name='dequeue_record_examples'):
|
||||
"""Reads TFRecord, queues, batches and parses `Example` proto.
|
||||
|
||||
See more detailed description in `read_examples`.
|
||||
@ -168,8 +198,13 @@ def read_batch_record_features(file_pattern, batch_size, features,
|
||||
features: A `dict` mapping feature keys to `FixedLenFeature` or
|
||||
`VarLenFeature` values.
|
||||
randomize_input: Whether the input should be randomized.
|
||||
num_epochs: Integer specifying the number of times to read through the
|
||||
dataset. If None, cycles through the dataset forever. NOTE - If specified,
|
||||
creates a variable that must be initialized, so call
|
||||
tf.initialize_all_variables() as shown in the tests.
|
||||
queue_capacity: Capacity for input queue.
|
||||
num_threads: The number of threads enqueuing examples.
|
||||
reader_num_threads: The number of threads to read examples.
|
||||
parser_num_threads: The number of threads to parse examples.
|
||||
name: Name of resulting op.
|
||||
|
||||
Returns:
|
||||
@ -181,5 +216,6 @@ def read_batch_record_features(file_pattern, batch_size, features,
|
||||
return read_batch_features(
|
||||
file_pattern=file_pattern, batch_size=batch_size, features=features,
|
||||
reader=io_ops.TFRecordReader,
|
||||
randomize_input=randomize_input,
|
||||
queue_capacity=queue_capacity, num_threads=num_threads, name=name)
|
||||
randomize_input=randomize_input, num_epochs=num_epochs,
|
||||
queue_capacity=queue_capacity, reader_num_threads=reader_num_threads,
|
||||
parser_num_threads=parser_num_threads, name=name)
|
||||
|
@ -17,10 +17,13 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import random
|
||||
import tempfile
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import gfile
|
||||
|
||||
@ -55,44 +58,83 @@ class GraphIOTest(tf.test.TestCase):
|
||||
|
||||
self.assertRaisesRegexp(
|
||||
ValueError, "No files match",
|
||||
tf.contrib.learn.io.read_batch_features,
|
||||
_INVALID_FILE_PATTERN, default_batch_size, None, tf.TFRecordReader,
|
||||
False, queue_capacity,
|
||||
num_threads, name)
|
||||
tf.contrib.learn.io.read_batch_examples,
|
||||
_INVALID_FILE_PATTERN, default_batch_size, tf.TFRecordReader,
|
||||
False, num_epochs=None, queue_capacity=queue_capacity,
|
||||
num_threads=num_threads, name=name)
|
||||
self.assertRaisesRegexp(
|
||||
ValueError, "Invalid batch_size",
|
||||
tf.contrib.learn.io.read_batch_features,
|
||||
_VALID_FILE_PATTERN, None, None, tf.TFRecordReader,
|
||||
False, queue_capacity, num_threads, name)
|
||||
tf.contrib.learn.io.read_batch_examples,
|
||||
_VALID_FILE_PATTERN, None, tf.TFRecordReader,
|
||||
False, num_epochs=None, queue_capacity=queue_capacity,
|
||||
num_threads=num_threads, name=name)
|
||||
self.assertRaisesRegexp(
|
||||
ValueError, "Invalid batch_size",
|
||||
tf.contrib.learn.io.read_batch_features,
|
||||
_VALID_FILE_PATTERN, -1, None, tf.TFRecordReader,
|
||||
False, queue_capacity, num_threads, name)
|
||||
tf.contrib.learn.io.read_batch_examples,
|
||||
_VALID_FILE_PATTERN, -1, tf.TFRecordReader,
|
||||
False, num_epochs=None, queue_capacity=queue_capacity,
|
||||
num_threads=num_threads, name=name)
|
||||
self.assertRaisesRegexp(
|
||||
ValueError, "Invalid queue_capacity",
|
||||
tf.contrib.learn.io.read_batch_features,
|
||||
_VALID_FILE_PATTERN, default_batch_size, None, tf.TFRecordReader,
|
||||
False, None, num_threads, name)
|
||||
tf.contrib.learn.io.read_batch_examples,
|
||||
_VALID_FILE_PATTERN, default_batch_size, tf.TFRecordReader,
|
||||
False, num_epochs=None, queue_capacity=None,
|
||||
num_threads=num_threads, name=name)
|
||||
self.assertRaisesRegexp(
|
||||
ValueError, "Invalid num_threads",
|
||||
tf.contrib.learn.io.read_batch_features,
|
||||
_VALID_FILE_PATTERN, default_batch_size, None, tf.TFRecordReader,
|
||||
False, queue_capacity, None,
|
||||
name)
|
||||
tf.contrib.learn.io.read_batch_examples,
|
||||
_VALID_FILE_PATTERN, default_batch_size, tf.TFRecordReader,
|
||||
False, num_epochs=None, queue_capacity=queue_capacity,
|
||||
num_threads=None, name=name)
|
||||
self.assertRaisesRegexp(
|
||||
ValueError, "Invalid num_threads",
|
||||
tf.contrib.learn.io.read_batch_features,
|
||||
_VALID_FILE_PATTERN, default_batch_size, None, tf.TFRecordReader,
|
||||
False, queue_capacity, -1,
|
||||
name)
|
||||
tf.contrib.learn.io.read_batch_examples,
|
||||
_VALID_FILE_PATTERN, default_batch_size, tf.TFRecordReader,
|
||||
False, num_epochs=None, queue_capacity=queue_capacity,
|
||||
num_threads=-1, name=name)
|
||||
self.assertRaisesRegexp(
|
||||
ValueError, "Invalid batch_size",
|
||||
tf.contrib.learn.io.read_batch_features,
|
||||
_VALID_FILE_PATTERN, queue_capacity + 1, None, tf.TFRecordReader,
|
||||
False, queue_capacity, 1, name)
|
||||
tf.contrib.learn.io.read_batch_examples,
|
||||
_VALID_FILE_PATTERN, queue_capacity + 1, tf.TFRecordReader,
|
||||
False, num_epochs=None, queue_capacity=queue_capacity,
|
||||
num_threads=1, name=name)
|
||||
self.assertRaisesRegexp(
|
||||
ValueError, "Invalid num_epochs",
|
||||
tf.contrib.learn.io.read_batch_examples,
|
||||
_VALID_FILE_PATTERN, default_batch_size, tf.TFRecordReader,
|
||||
False, num_epochs=-1, queue_capacity=queue_capacity, num_threads=1,
|
||||
name=name)
|
||||
|
||||
def test_batch_tf_record(self):
|
||||
def test_batch_record_features(self):
|
||||
batch_size = 17
|
||||
queue_capacity = 1234
|
||||
name = "my_batch"
|
||||
features = {"feature": tf.FixedLenFeature(shape=[0], dtype=tf.float32)}
|
||||
|
||||
with tf.Graph().as_default() as g, self.test_session(graph=g) as sess:
|
||||
features = tf.contrib.learn.io.read_batch_record_features(
|
||||
_VALID_FILE_PATTERN, batch_size, features, randomize_input=False,
|
||||
queue_capacity=queue_capacity, reader_num_threads=2,
|
||||
parser_num_threads=2, name=name)
|
||||
self.assertEquals("%s/parse_example_batch_join:0" % name,
|
||||
features["feature"].name)
|
||||
file_name_queue_name = "%s/file_name_queue" % name
|
||||
file_names_name = "%s/input" % file_name_queue_name
|
||||
example_queue_name = "%s/fifo_queue" % name
|
||||
parse_example_queue_name = "%s/parse_example_batch_join" % name
|
||||
op_nodes = test_util.assert_ops_in_graph({
|
||||
file_names_name: "Const",
|
||||
file_name_queue_name: "FIFOQueue",
|
||||
"%s/read/TFRecordReader" % name: "TFRecordReader",
|
||||
example_queue_name: "FIFOQueue",
|
||||
parse_example_queue_name: "QueueDequeueMany",
|
||||
name: "QueueDequeueMany"
|
||||
}, g)
|
||||
self.assertAllEqual(_FILE_NAMES, sess.run(["%s:0" % file_names_name])[0])
|
||||
self.assertEqual(
|
||||
queue_capacity, op_nodes[example_queue_name].attr["capacity"].i)
|
||||
|
||||
def test_one_epoch(self):
|
||||
batch_size = 17
|
||||
queue_capacity = 1234
|
||||
name = "my_batch"
|
||||
@ -100,20 +142,25 @@ class GraphIOTest(tf.test.TestCase):
|
||||
with tf.Graph().as_default() as g, self.test_session(graph=g) as sess:
|
||||
inputs = tf.contrib.learn.io.read_batch_examples(
|
||||
_VALID_FILE_PATTERN, batch_size,
|
||||
reader=tf.TFRecordReader, randomize_input=False,
|
||||
reader=tf.TFRecordReader, randomize_input=True,
|
||||
num_epochs=1,
|
||||
queue_capacity=queue_capacity, name=name)
|
||||
self.assertEquals("%s:0" % name, inputs.name)
|
||||
file_name_queue_name = "%s/file_name_queue" % name
|
||||
file_name_queue_limit_name = (
|
||||
"%s/limit_epochs/epochs" % file_name_queue_name)
|
||||
file_names_name = "%s/input" % file_name_queue_name
|
||||
example_queue_name = "%s/fifo_queue" % name
|
||||
example_queue_name = "%s/random_shuffle_queue" % name
|
||||
op_nodes = test_util.assert_ops_in_graph({
|
||||
file_names_name: "Const",
|
||||
file_name_queue_name: "FIFOQueue",
|
||||
"%s/read/TFRecordReader" % name: "TFRecordReader",
|
||||
example_queue_name: "FIFOQueue",
|
||||
name: "QueueDequeueMany"
|
||||
example_queue_name: "RandomShuffleQueue",
|
||||
name: "QueueDequeueMany",
|
||||
file_name_queue_limit_name: "Variable"
|
||||
}, g)
|
||||
self.assertAllEqual(_FILE_NAMES, sess.run(["%s:0" % file_names_name])[0])
|
||||
self.assertEqual(
|
||||
set(_FILE_NAMES), set(sess.run(["%s:0" % file_names_name])[0]))
|
||||
self.assertEqual(
|
||||
queue_capacity, op_nodes[example_queue_name].attr["capacity"].i)
|
||||
|
||||
@ -143,6 +190,34 @@ class GraphIOTest(tf.test.TestCase):
|
||||
self.assertEqual(
|
||||
queue_capacity, op_nodes[example_queue_name].attr["capacity"].i)
|
||||
|
||||
def test_read_csv(self):
|
||||
gfile.Glob = self._orig_glob
|
||||
tempdir = tempfile.mkdtemp()
|
||||
filename = os.path.join(tempdir, "file.csv")
|
||||
gfile.Open(filename, "w").write("ABC\nDEF\nGHK\n")
|
||||
|
||||
batch_size = 1
|
||||
queue_capacity = 5
|
||||
name = "my_batch"
|
||||
|
||||
with tf.Graph().as_default() as g, self.test_session(graph=g) as session:
|
||||
inputs = tf.contrib.learn.io.read_batch_examples(
|
||||
filename, batch_size,
|
||||
reader=tf.TextLineReader, randomize_input=False,
|
||||
num_epochs=1, queue_capacity=queue_capacity, name=name)
|
||||
session.run(tf.initialize_all_variables())
|
||||
|
||||
coord = tf.train.Coordinator()
|
||||
tf.train.start_queue_runners(session, coord=coord)
|
||||
|
||||
self.assertEqual(session.run(inputs), "ABC")
|
||||
self.assertEqual(session.run(inputs), "DEF")
|
||||
self.assertEqual(session.run(inputs), "GHK")
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
session.run(inputs)
|
||||
|
||||
coord.request_stop()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user