Set allow_smaller_final_batch to true when num_epochs is set so the last items

in the data can be read without missing.
Change: 128396107
This commit is contained in:
A. Unique TensorFlower 2016-07-25 12:57:45 -08:00 committed by TensorFlower Gardener
parent 88ffd73ed9
commit ed281973d6
2 changed files with 11 additions and 3 deletions

View File

@ -192,6 +192,11 @@ def read_keyed_batch_examples(
enqueue_many = read_batch_size > 1 enqueue_many = read_batch_size > 1
if num_epochs is not None:
allow_smaller_final_batch = True
else:
allow_smaller_final_batch = False
# Setup batching queue given list of read example tensors. # Setup batching queue given list of read example tensors.
if randomize_input: if randomize_input:
if isinstance(batch_size, ops.Tensor): if isinstance(batch_size, ops.Tensor):
@ -201,11 +206,13 @@ def read_keyed_batch_examples(
queued_examples_with_keys = input_ops.shuffle_batch_join( queued_examples_with_keys = input_ops.shuffle_batch_join(
example_list, batch_size, capacity=queue_capacity, example_list, batch_size, capacity=queue_capacity,
min_after_dequeue=min_after_dequeue, min_after_dequeue=min_after_dequeue,
enqueue_many=enqueue_many, name=scope) enqueue_many=enqueue_many, name=scope,
allow_smaller_final_batch=allow_smaller_final_batch)
else: else:
queued_examples_with_keys = input_ops.batch_join( queued_examples_with_keys = input_ops.batch_join(
example_list, batch_size, capacity=queue_capacity, example_list, batch_size, capacity=queue_capacity,
enqueue_many=enqueue_many, name=scope) enqueue_many=enqueue_many, name=scope,
allow_smaller_final_batch=allow_smaller_final_batch)
if parse_fn and isinstance(queued_examples_with_keys, dict): if parse_fn and isinstance(queued_examples_with_keys, dict):
queued_keys = queued_examples_with_keys.pop(KEY_FEATURE_NAME) queued_keys = queued_examples_with_keys.pop(KEY_FEATURE_NAME)
return queued_keys, queued_examples_with_keys return queued_keys, queued_examples_with_keys

View File

@ -164,7 +164,7 @@ class GraphIOTest(tf.test.TestCase):
file_name_queue_name: "FIFOQueue", file_name_queue_name: "FIFOQueue",
"%s/read/TFRecordReader" % name: "TFRecordReader", "%s/read/TFRecordReader" % name: "TFRecordReader",
example_queue_name: "RandomShuffleQueue", example_queue_name: "RandomShuffleQueue",
name: "QueueDequeueMany", name: "QueueDequeueUpTo",
file_name_queue_limit_name: "Variable" file_name_queue_limit_name: "Variable"
}, g) }, g)
self.assertEqual( self.assertEqual(
@ -249,6 +249,7 @@ class GraphIOTest(tf.test.TestCase):
tf.train.start_queue_runners(session, coord=coord) tf.train.start_queue_runners(session, coord=coord)
self.assertAllEqual(session.run(inputs), [b"A", b"B", b"C"]) self.assertAllEqual(session.run(inputs), [b"A", b"B", b"C"])
self.assertAllEqual(session.run(inputs), [b"D", b"E"])
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
session.run(inputs) session.run(inputs)