[tf.data] Reenable rebatch_dataset_test in TAP and remove some tests from it to reduce its size. The removed tests exercised logic that was important to verify in the old rewrite-based RebatchDataset. Now that rebatching is done by a dedicated RebatchDataset kernel, they are no longer relevant.
PiperOrigin-RevId: 304296546 Change-Id: I673c5fb88bad993b92d760cd8832271716401e25
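
For context, a minimal sketch of the rebatching behavior these tests exercised, using the internal distribute._RebatchDataset API that appears throughout the diff below (an internal, unstable TF API; the numbers mirror the expected outputs in the tests):

    from tensorflow.python.data.experimental.ops import distribute
    from tensorflow.python.data.ops import dataset_ops

    # 32 "global" batches of size 32 each.
    dataset = dataset_ops.Dataset.range(1024).batch(32, drop_remainder=True)
    # Rebatching across 4 replicas splits each batch of 32 into 4 consecutive
    # minibatches of 8, yielding 128 batches in total.
    rebatched = distribute._RebatchDataset(dataset, num_replicas=4)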
parent a0422e404e
commit 7a96514f55
@@ -580,16 +580,9 @@ tf_py_test(
    name = "rebatch_dataset_test",
    size = "small",
    srcs = ["rebatch_dataset_test.py"],
    tags = [
        "manual",  # TODO(b/152215379)
        "notap",
    ],
    deps = [
        "//tensorflow/core:protos_all_py",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:image_ops",
        "//tensorflow/python:parsing_ops",
        "//tensorflow/python/data/experimental/ops:readers",
        "//tensorflow/python/data/kernel_tests:test_base",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/data/util:nest",

@@ -17,30 +17,17 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from absl.testing import parameterized
import numpy as np

from tensorflow.core.example import example_pb2
from tensorflow.core.example import feature_pb2
from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.experimental.ops import distribute
from tensorflow.python.data.experimental.ops import grouping
from tensorflow.python.data.experimental.ops import readers
from tensorflow.python.data.experimental.ops import scan_ops
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.framework import combinations
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.lib.io import python_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import image_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import test

@@ -206,290 +193,6 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
                       for i in range(0, 128, 8)]
    self.assertDatasetProduces(rebatched_dataset, expected_output)

  @combinations.generate(test_base.default_test_combinations())
  def testMapAndBatch(self):
    dataset = dataset_ops.Dataset.range(1024).apply(
        batching.map_and_batch(math_ops.square, 32))
    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
    self.assertEqual([[None]],
                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
    expected_output = [[k**2 for k in range(i, i + 8)]  # pylint: disable=g-complex-comprehension
                       for i in range(0, 1024, 8)]
    self.assertDatasetProduces(rebatched_dataset, expected_output)

  @combinations.generate(test_base.default_test_combinations())
  def testMapAndBatchWithCapturedInput(self):
    captured_t = variables.Variable(42)
    dataset = dataset_ops.Dataset.range(1024).apply(
        batching.map_and_batch(lambda x: captured_t, 32))
    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
    self.assertEqual([[None]],
                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
    expected_output = [[42 for _ in range(i, i + 8)]  # pylint: disable=g-complex-comprehension
                       for i in range(0, 1024, 8)]
    self.evaluate(variables.global_variables_initializer())
    self.assertDatasetProduces(
        rebatched_dataset, expected_output, requires_initialization=True)

  @combinations.generate(test_base.default_test_combinations())
  def testPaddedBatch(self):
    dataset = dataset_ops.Dataset.range(128).batch(
        4, drop_remainder=True).padded_batch(
            8, padded_shapes=[5])
    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
    # Each element is a list of 8 elements in which each element is a list of 5
    # elements; the first four are numbers and the last one is a padded zero.
    expected_output = [[[j, j + 1, j + 2, j + 3, 0]  # pylint: disable=g-complex-comprehension
                        for j in range(i, i + 32, 4)]  # generates 8 elements
                       for i in range(0, 128, 32)]
    self.assertDatasetProduces(dataset, expected_output)
    self.assertEqual([[None, 5]],
                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
    # Each element is a list of 2 elements in which each element is a list of 5
    # elements; the first four are numbers and the last one is a padded zero.
    expected_output = [[[j, j + 1, j + 2, j + 3, 0]  # pylint: disable=g-complex-comprehension
                        for j in range(i, i + 8, 4)]  # generates 2 elements
                       for i in range(0, 128, 8)]
    self.assertDatasetProduces(rebatched_dataset, expected_output)

  @combinations.generate(test_base.default_test_combinations())
  def testConcatenate(self):
    dataset1 = dataset_ops.Dataset.range(64).batch(8)
    dataset2 = dataset_ops.Dataset.range(32).batch(8)
    dataset = dataset1.concatenate(dataset2)
    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
    self.assertEqual([[None]],
                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
    expected_output = ([[i, i + 1] for i in range(0, 64, 2)] +
                       [[i, i + 1] for i in range(0, 32, 2)])
    self.assertDatasetProduces(rebatched_dataset, expected_output)

  @combinations.generate(test_base.default_test_combinations())
  def testConcatenateDifferentShapes(self):
    dataset1 = dataset_ops.Dataset.range(64).batch(16)
    dataset2 = dataset_ops.Dataset.range(32).batch(8)
    dataset = dataset1.concatenate(dataset2)
    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
    self.assertEqual([[None]],
                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
    expected_output = ([[i, i + 1, i + 2, i + 3] for i in range(0, 64, 4)] +
                       [[i, i + 1] for i in range(0, 32, 2)])
    self.assertDatasetProduces(rebatched_dataset, expected_output)

  @combinations.generate(test_base.default_test_combinations())
  def testZip(self):
    dataset1 = dataset_ops.Dataset.range(64).batch(8)
    dataset2 = dataset_ops.Dataset.range(32).batch(8)
    dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
    self.assertEqual([[None], [None]],
                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
    expected_output = [([i, i + 1], [i, i + 1]) for i in range(0, 32, 2)]
    self.assertDatasetProduces(rebatched_dataset, expected_output)

  @combinations.generate(test_base.default_test_combinations())
  def testZipDifferentShapes(self):
    dataset1 = dataset_ops.Dataset.range(64).batch(16)
    dataset2 = dataset_ops.Dataset.range(32).batch(8)
    dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
    self.assertEqual([[None], [None]],
                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
    expected_output = [([2 * i, 2 * i + 1, 2 * i + 2, 2 * i + 3], [i, i + 1])
                       for i in range(0, 32, 2)]
    self.assertDatasetProduces(rebatched_dataset, expected_output)

  @combinations.generate(test_base.default_test_combinations())
  def testFlatMapBatching(self):
    dataset = dataset_ops.Dataset.range(2).flat_map(
        lambda _: dataset_ops.Dataset.range(32).batch(  # pylint: disable=g-long-lambda
            32))
    # Two elements where each element is range(32).
    expected_output = [[k for k in range(32)] for _ in range(2)]  # pylint: disable=g-complex-comprehension
    self.assertDatasetProduces(dataset, expected_output)

    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
    self.assertEqual([[None]],
                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
    # Two elements where each element is a list of 4 elements, each of which
    # is a list of 8.
    expected_output = [[k for k in range(i, i + 8)]  # pylint: disable=g-complex-comprehension
                       for _ in range(2)
                       for i in range(0, 32, 8)]  # generates 4 elements
    self.assertDatasetProduces(rebatched_dataset, expected_output)

  @combinations.generate(test_base.default_test_combinations())
  def testInterleaveBatching(self):
    dataset = dataset_ops.Dataset.range(2).interleave(
        lambda _: dataset_ops.Dataset.range(32).batch(  # pylint: disable=g-long-lambda
            32),
        cycle_length=2)
    # Two elements where each element is range(32).
    expected_output = [[k for k in range(32)] for _ in range(2)]  # pylint: disable=g-complex-comprehension
    self.assertDatasetProduces(dataset, expected_output)

    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
    self.assertEqual([[None]],
                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
    expected_output = [[k for k in range(i, i + 8)] for i in range(0, 32, 8)]
    expected_output += expected_output
    self.assertDatasetProduces(rebatched_dataset, expected_output)

  @combinations.generate(test_base.default_test_combinations())
  def testParallelInterleaveBatching(self):
    dataset = dataset_ops.Dataset.range(2).interleave(
        lambda _: dataset_ops.Dataset.range(32).batch(  # pylint: disable=g-long-lambda
            32),
        cycle_length=2,
        num_parallel_calls=2)
    # Two elements where each element is range(32).
    expected_output = [[k for k in range(32)] for _ in range(2)]  # pylint: disable=g-complex-comprehension
    self.assertDatasetProduces(dataset, expected_output)

    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
    self.assertEqual([[None]],
                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
    expected_output = [[k for k in range(i, i + 8)] for i in range(0, 32, 8)]
    expected_output += expected_output
    self.assertDatasetProduces(rebatched_dataset, expected_output)

  @combinations.generate(test_base.default_test_combinations())
  def testGroupByWindowStaticBatch(self):
    dataset = dataset_ops.Dataset.from_tensor_slices(
        [[array_ops.constant(i, dtype=dtypes.int64)] * 3 for i in range(40)])
    reduce_fn = lambda bucket_id, ds: ds.batch(  # pylint: disable=g-long-lambda
        batch_size=10)
    dataset = dataset.apply(
        grouping.group_by_window(
            key_func=lambda x: x[0] % 4, reduce_func=reduce_fn, window_size=10))
    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=2)

    self.assertEqual([[None, 3]],
                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
    # pylint: disable=g-complex-comprehension
    expected_output = [[[j + i * 4 + k * 20] * 3
                        for i in range(5)]
                       for j in range(4)
                       for k in range(2)]
    self.assertDatasetProduces(rebatched_dataset, expected_output)

  @combinations.generate(test_base.default_test_combinations())
  def testGroupByWindowDynamicBatch(self):
    # {0, 1, 0, 1, ...}
    dataset = dataset_ops.Dataset.range(40).map(lambda x: x % 2)

    def reduce_fn(key, ds):
      # key == 0 -> .batch(5)
      # key == 1 -> .batch(10)
      return ds.batch(batch_size=(key + 1) * 5)

    dataset = dataset.apply(
        grouping.group_by_window(
            key_func=lambda x: x, reduce_func=reduce_fn, window_size=10))
    dataset = distribute._RebatchDataset(dataset, num_replicas=2)

    self.assertEqual([[None]], [ts.as_list() for ts in _flat_shapes(dataset)])

    # The batches of 5 (value == 0) will be split into minibatches of (3, 2) and
    # the batches of 10 (value == 1) into minibatches of (5, 5).
    # [(batch_size, value), ...]
    pairs = [(3, 0), (2, 0), (3, 0), (2, 0), (5, 1), (5, 1)]
    pairs = pairs * 2
    expected_output = [[value] * batch_size for batch_size, value in pairs]
    self.assertDatasetProduces(dataset, expected_output)

  @combinations.generate(test_base.default_test_combinations())
  def testGroupByWindowDynamicBatchWithPartialBatch(self):
    # {0, 1, 0, 1, ...}
    dataset = dataset_ops.Dataset.range(40).map(lambda x: x % 2)

    def reduce_fn(key, ds):
      # key == 0 -> .batch(5)
      # key == 1 -> .batch(10)
      return ds.batch(batch_size=(key + 1) * 5)

    dataset = dataset.apply(
        grouping.group_by_window(
            key_func=lambda x: x, reduce_func=reduce_fn, window_size=11))
    dataset = distribute._RebatchDataset(dataset, num_replicas=2)

    self.assertEqual([[None]], [ts.as_list() for ts in _flat_shapes(dataset)])

    pairs = [(3, 0), (2, 0), (3, 0), (2, 0), (1, 0), (0, 0), (5, 1), (5, 1),
             (1, 1), (0, 1), (3, 0), (2, 0), (2, 0), (2, 0), (5, 1), (4, 1)]
    expected_output = [[value] * batch_size for batch_size, value in pairs]
    self.assertDatasetProduces(dataset, expected_output)

  @combinations.generate(test_base.default_test_combinations())
  def testGroupByWindowDynamicBatchWithPartialBatchWithDropRemainder(self):
    # This test exercises nested batch functionality, dynamic batch size,
    # and drop_remainder=True together.
    dataset = dataset_ops.Dataset.range(40).map(lambda x: x % 2)

    def reduce_fn(key, ds):
      # key == 0 -> .batch(5)
      # key == 1 -> .batch(10)
      return ds.batch(batch_size=(key + 1) * 5, drop_remainder=True)

    dataset = dataset.apply(
        grouping.group_by_window(
            key_func=lambda x: x, reduce_func=reduce_fn, window_size=11))
    dataset = distribute._RebatchDataset(dataset, num_replicas=2)

    self.assertEqual([[None]], [ts.as_list() for ts in _flat_shapes(dataset)])

    # The batches of 5 (value == 0) will be split into minibatches of (3, 2) and
    # the batches of 10 (value == 1) into minibatches of (5, 5).
    # [(batch_size, value), ...]
    pairs = [(3, 0), (2, 0), (3, 0), (2, 0), (5, 1), (5, 1), (3, 0), (2, 0)]
    expected_output = [[value] * batch_size for batch_size, value in pairs]
    self.assertDatasetProduces(dataset, expected_output)

  @combinations.generate(test_base.default_test_combinations())
  def testScanAfterBatch(self):
    dataset = dataset_ops.Dataset.range(40).batch(10).apply(
        scan_ops.scan(np.int64(2), lambda state, value: (state, value * state)))
    dataset = distribute._RebatchDataset(dataset, num_replicas=2)

    self.assertEqual([[None]],
                     [ts.as_list() for ts in _flat_shapes(dataset)])
    expected_output = [[i * 2 for i in range(j*5, (j+1)*5)] for j in range(8)]  # pylint: disable=g-complex-comprehension
    self.assertDatasetProduces(dataset, expected_output)

  @combinations.generate(test_base.default_test_combinations())
  def testMakeBatchedFeaturesDataset(self):
    # Set up
    fn = os.path.join(self.get_temp_dir(), "tf_record.txt")
    writer = python_io.TFRecordWriter(fn)
    for i in range(1024):
      writer.write(
          example_pb2.Example(
              features=feature_pb2.Features(
                  feature={
                      "value":
                          feature_pb2.Feature(
                              int64_list=feature_pb2.Int64List(value=[i]))
                  })).SerializeToString())
    writer.close()

    dataset = readers.make_batched_features_dataset(
        file_pattern=fn,
        batch_size=32,
        features={"value": parsing_ops.FixedLenFeature([], dtypes.int64)},
        shuffle=False,
        num_epochs=1,
        drop_final_batch=False)

    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)

    self.assertEqual([[None]],
                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])

    expected_output = [{
        "value": [k for k in range(i, i + 8)]
    } for i in range(0, 1024, 8)]  # pylint: disable=g-complex-comprehension
    self.assertDatasetProduces(rebatched_dataset, expected_output)

  @combinations.generate(test_base.default_test_combinations())
  def testRaggedTensorDataset(self):
    # Set up a dataset that produces ragged tensors with a static batch size.
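
A note on the dynamic-batch expectations above: the (batch_size, value) pairs encode how a partial batch is split across replicas. Below is a small standalone sketch of that splitting rule, inferred from the expected outputs in this diff (illustrative only, not code from this change):

    def rebatch_sizes(batch_size, num_replicas):
      # Each replica takes up to ceil(batch_size / num_replicas) elements;
      # later replicas receive whatever remains, possibly an empty batch.
      per_replica = -(-batch_size // num_replicas)  # ceiling division
      sizes = []
      remaining = batch_size
      for _ in range(num_replicas):
        take = min(per_replica, remaining)
        sizes.append(take)
        remaining -= take
      return sizes

    print(rebatch_sizes(5, 2))   # [3, 2] -- matches the (3, 0), (2, 0) pairs
    print(rebatch_sizes(10, 2))  # [5, 5] -- matches the (5, 1), (5, 1) pairs
    print(rebatch_sizes(1, 2))   # [1, 0] -- the partial-batch case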