Set allow_smaller_final_batch to true when num_epochs is set so the last items
in the data can be read without missing. Change: 128396107
This commit is contained in:
parent
88ffd73ed9
commit
ed281973d6
@ -192,6 +192,11 @@ def read_keyed_batch_examples(
|
|||||||
|
|
||||||
enqueue_many = read_batch_size > 1
|
enqueue_many = read_batch_size > 1
|
||||||
|
|
||||||
|
if num_epochs is not None:
|
||||||
|
allow_smaller_final_batch = True
|
||||||
|
else:
|
||||||
|
allow_smaller_final_batch = False
|
||||||
|
|
||||||
# Setup batching queue given list of read example tensors.
|
# Setup batching queue given list of read example tensors.
|
||||||
if randomize_input:
|
if randomize_input:
|
||||||
if isinstance(batch_size, ops.Tensor):
|
if isinstance(batch_size, ops.Tensor):
|
||||||
@ -201,11 +206,13 @@ def read_keyed_batch_examples(
|
|||||||
queued_examples_with_keys = input_ops.shuffle_batch_join(
|
queued_examples_with_keys = input_ops.shuffle_batch_join(
|
||||||
example_list, batch_size, capacity=queue_capacity,
|
example_list, batch_size, capacity=queue_capacity,
|
||||||
min_after_dequeue=min_after_dequeue,
|
min_after_dequeue=min_after_dequeue,
|
||||||
enqueue_many=enqueue_many, name=scope)
|
enqueue_many=enqueue_many, name=scope,
|
||||||
|
allow_smaller_final_batch=allow_smaller_final_batch)
|
||||||
else:
|
else:
|
||||||
queued_examples_with_keys = input_ops.batch_join(
|
queued_examples_with_keys = input_ops.batch_join(
|
||||||
example_list, batch_size, capacity=queue_capacity,
|
example_list, batch_size, capacity=queue_capacity,
|
||||||
enqueue_many=enqueue_many, name=scope)
|
enqueue_many=enqueue_many, name=scope,
|
||||||
|
allow_smaller_final_batch=allow_smaller_final_batch)
|
||||||
if parse_fn and isinstance(queued_examples_with_keys, dict):
|
if parse_fn and isinstance(queued_examples_with_keys, dict):
|
||||||
queued_keys = queued_examples_with_keys.pop(KEY_FEATURE_NAME)
|
queued_keys = queued_examples_with_keys.pop(KEY_FEATURE_NAME)
|
||||||
return queued_keys, queued_examples_with_keys
|
return queued_keys, queued_examples_with_keys
|
||||||
|
@ -164,7 +164,7 @@ class GraphIOTest(tf.test.TestCase):
|
|||||||
file_name_queue_name: "FIFOQueue",
|
file_name_queue_name: "FIFOQueue",
|
||||||
"%s/read/TFRecordReader" % name: "TFRecordReader",
|
"%s/read/TFRecordReader" % name: "TFRecordReader",
|
||||||
example_queue_name: "RandomShuffleQueue",
|
example_queue_name: "RandomShuffleQueue",
|
||||||
name: "QueueDequeueMany",
|
name: "QueueDequeueUpTo",
|
||||||
file_name_queue_limit_name: "Variable"
|
file_name_queue_limit_name: "Variable"
|
||||||
}, g)
|
}, g)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
@ -249,6 +249,7 @@ class GraphIOTest(tf.test.TestCase):
|
|||||||
tf.train.start_queue_runners(session, coord=coord)
|
tf.train.start_queue_runners(session, coord=coord)
|
||||||
|
|
||||||
self.assertAllEqual(session.run(inputs), [b"A", b"B", b"C"])
|
self.assertAllEqual(session.run(inputs), [b"A", b"B", b"C"])
|
||||||
|
self.assertAllEqual(session.run(inputs), [b"D", b"E"])
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
session.run(inputs)
|
session.run(inputs)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user