Set allow_smaller_final_batch to true when num_epochs is set so the last items
in the data can be read without missing. Change: 128396107
This commit is contained in:
		
							parent
							
								
									88ffd73ed9
								
							
						
					
					
						commit
						ed281973d6
					
				| @ -192,6 +192,11 @@ def read_keyed_batch_examples( | |||||||
| 
 | 
 | ||||||
|     enqueue_many = read_batch_size > 1 |     enqueue_many = read_batch_size > 1 | ||||||
| 
 | 
 | ||||||
|  |     if num_epochs is not None: | ||||||
|  |       allow_smaller_final_batch = True | ||||||
|  |     else: | ||||||
|  |       allow_smaller_final_batch = False | ||||||
|  | 
 | ||||||
|     # Setup batching queue given list of read example tensors. |     # Setup batching queue given list of read example tensors. | ||||||
|     if randomize_input: |     if randomize_input: | ||||||
|       if isinstance(batch_size, ops.Tensor): |       if isinstance(batch_size, ops.Tensor): | ||||||
| @ -201,11 +206,13 @@ def read_keyed_batch_examples( | |||||||
|       queued_examples_with_keys = input_ops.shuffle_batch_join( |       queued_examples_with_keys = input_ops.shuffle_batch_join( | ||||||
|           example_list, batch_size, capacity=queue_capacity, |           example_list, batch_size, capacity=queue_capacity, | ||||||
|           min_after_dequeue=min_after_dequeue, |           min_after_dequeue=min_after_dequeue, | ||||||
|           enqueue_many=enqueue_many, name=scope) |           enqueue_many=enqueue_many, name=scope, | ||||||
|  |           allow_smaller_final_batch=allow_smaller_final_batch) | ||||||
|     else: |     else: | ||||||
|       queued_examples_with_keys = input_ops.batch_join( |       queued_examples_with_keys = input_ops.batch_join( | ||||||
|           example_list, batch_size, capacity=queue_capacity, |           example_list, batch_size, capacity=queue_capacity, | ||||||
|           enqueue_many=enqueue_many, name=scope) |           enqueue_many=enqueue_many, name=scope, | ||||||
|  |           allow_smaller_final_batch=allow_smaller_final_batch) | ||||||
|     if parse_fn and isinstance(queued_examples_with_keys, dict): |     if parse_fn and isinstance(queued_examples_with_keys, dict): | ||||||
|       queued_keys = queued_examples_with_keys.pop(KEY_FEATURE_NAME) |       queued_keys = queued_examples_with_keys.pop(KEY_FEATURE_NAME) | ||||||
|       return queued_keys, queued_examples_with_keys |       return queued_keys, queued_examples_with_keys | ||||||
|  | |||||||
| @ -164,7 +164,7 @@ class GraphIOTest(tf.test.TestCase): | |||||||
|           file_name_queue_name: "FIFOQueue", |           file_name_queue_name: "FIFOQueue", | ||||||
|           "%s/read/TFRecordReader" % name: "TFRecordReader", |           "%s/read/TFRecordReader" % name: "TFRecordReader", | ||||||
|           example_queue_name: "RandomShuffleQueue", |           example_queue_name: "RandomShuffleQueue", | ||||||
|           name: "QueueDequeueMany", |           name: "QueueDequeueUpTo", | ||||||
|           file_name_queue_limit_name: "Variable" |           file_name_queue_limit_name: "Variable" | ||||||
|       }, g) |       }, g) | ||||||
|       self.assertEqual( |       self.assertEqual( | ||||||
| @ -249,6 +249,7 @@ class GraphIOTest(tf.test.TestCase): | |||||||
|       tf.train.start_queue_runners(session, coord=coord) |       tf.train.start_queue_runners(session, coord=coord) | ||||||
| 
 | 
 | ||||||
|       self.assertAllEqual(session.run(inputs), [b"A", b"B", b"C"]) |       self.assertAllEqual(session.run(inputs), [b"A", b"B", b"C"]) | ||||||
|  |       self.assertAllEqual(session.run(inputs), [b"D", b"E"]) | ||||||
|       with self.assertRaises(errors.OutOfRangeError): |       with self.assertRaises(errors.OutOfRangeError): | ||||||
|         session.run(inputs) |         session.run(inputs) | ||||||
| 
 | 
 | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user