From 7a96514f55f36578734a0352b56e2ab43bfb9226 Mon Sep 17 00:00:00 2001 From: Rachel Lim Date: Wed, 1 Apr 2020 17:34:13 -0700 Subject: [PATCH] [tf.data] Reenable rebatch_dataset_test in TAP and remove some tests from it to reduce its size. These tests exercise logic that was important to check for in the old rewrite-based RebatchDataset. Now that rebatching is done with its own RebatchDataset kernel, they are no longer relevant. PiperOrigin-RevId: 304296546 Change-Id: I673c5fb88bad993b92d760cd8832271716401e25 --- .../data/experimental/kernel_tests/BUILD | 7 - .../kernel_tests/rebatch_dataset_test.py | 297 ------------------ 2 files changed, 304 deletions(-) diff --git a/tensorflow/python/data/experimental/kernel_tests/BUILD b/tensorflow/python/data/experimental/kernel_tests/BUILD index a8e87326743..6930a7db30c 100644 --- a/tensorflow/python/data/experimental/kernel_tests/BUILD +++ b/tensorflow/python/data/experimental/kernel_tests/BUILD @@ -580,16 +580,9 @@ tf_py_test( name = "rebatch_dataset_test", size = "small", srcs = ["rebatch_dataset_test.py"], - tags = [ - "manual", # TODO(b/152215379) - "notap", - ], deps = [ - "//tensorflow/core:protos_all_py", "//tensorflow/python:client_testlib", "//tensorflow/python:image_ops", - "//tensorflow/python:parsing_ops", - "//tensorflow/python/data/experimental/ops:readers", "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/util:nest", diff --git a/tensorflow/python/data/experimental/kernel_tests/rebatch_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/rebatch_dataset_test.py index 054b926bed9..61d0e5eb0bb 100644 --- a/tensorflow/python/data/experimental/kernel_tests/rebatch_dataset_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/rebatch_dataset_test.py @@ -17,30 +17,17 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os from absl.testing import parameterized import numpy as np -from tensorflow.core.example import example_pb2 -from tensorflow.core.example import feature_pb2 -from tensorflow.python.data.experimental.ops import batching from tensorflow.python.data.experimental.ops import distribute -from tensorflow.python.data.experimental.ops import grouping -from tensorflow.python.data.experimental.ops import readers -from tensorflow.python.data.experimental.ops import scan_ops from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest from tensorflow.python.framework import combinations -from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_shape -from tensorflow.python.lib.io import python_io -from tensorflow.python.ops import array_ops from tensorflow.python.ops import image_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import parsing_ops -from tensorflow.python.ops import variables from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.platform import test @@ -206,290 +193,6 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): for i in range(0, 128, 8)] self.assertDatasetProduces(rebatched_dataset, expected_output) - @combinations.generate(test_base.default_test_combinations()) - def testMapAndBatch(self): - dataset = dataset_ops.Dataset.range(1024).apply( - batching.map_and_batch(math_ops.square, 32)) - rebatched_dataset = distribute._RebatchDataset(dataset, 
-                                                   num_replicas=4)
-    self.assertEqual([[None]],
-                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
-    expected_output = [[k**2 for k in range(i, i + 8)]  # pylint: disable=g-complex-comprehension
-                       for i in range(0, 1024, 8)]
-    self.assertDatasetProduces(rebatched_dataset, expected_output)
-
-  @combinations.generate(test_base.default_test_combinations())
-  def testMapAndBatchWithCapturedInput(self):
-    captured_t = variables.Variable(42)
-    dataset = dataset_ops.Dataset.range(1024).apply(
-        batching.map_and_batch(lambda x: captured_t, 32))
-    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
-    self.assertEqual([[None]],
-                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
-    expected_output = [[42 for _ in range(i, i + 8)]  # pylint: disable=g-complex-comprehension
-                       for i in range(0, 1024, 8)]
-    self.evaluate(variables.global_variables_initializer())
-    self.assertDatasetProduces(
-        rebatched_dataset, expected_output, requires_initialization=True)
-
-  @combinations.generate(test_base.default_test_combinations())
-  def testPaddedBatch(self):
-    dataset = dataset_ops.Dataset.range(128).batch(
-        4, drop_remainder=True).padded_batch(
-            8, padded_shapes=[5])
-    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
-    # Each element is a list of 8 elements in which each element is a list of 5
-    # elements, first four are numbers and the last one is a padded zero.
-    expected_output = [[[j, j + 1, j + 2, j + 3, 0]  # pylint: disable=g-complex-comprehension
-                        for j in range(i, i + 32, 4)]  # generates 8 elements
-                       for i in range(0, 128, 32)]
-    self.assertDatasetProduces(dataset, expected_output)
-    self.assertEqual([[None, 5]],
-                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
-    # Each element is a list of 2 elements in which each element is a list of 5
-    # elements, first four are numbers and the last one is a padded zero.
- expected_output = [[[j, j + 1, j + 2, j + 3, 0] # pylint: disable=g-complex-comprehension - for j in range(i, i + 8, 4)] # generates 2 elements - for i in range(0, 128, 8)] - self.assertDatasetProduces(rebatched_dataset, expected_output) - - @combinations.generate(test_base.default_test_combinations()) - def testConcatenate(self): - dataset1 = dataset_ops.Dataset.range(64).batch(8) - dataset2 = dataset_ops.Dataset.range(32).batch(8) - dataset = dataset1.concatenate(dataset2) - rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4) - self.assertEqual([[None]], - [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) - expected_output = ([[i, i + 1] for i in range(0, 64, 2)] + - [[i, i + 1] for i in range(0, 32, 2)]) - self.assertDatasetProduces(rebatched_dataset, expected_output) - - @combinations.generate(test_base.default_test_combinations()) - def testConcatenateDifferentShapes(self): - dataset1 = dataset_ops.Dataset.range(64).batch(16) - dataset2 = dataset_ops.Dataset.range(32).batch(8) - dataset = dataset1.concatenate(dataset2) - rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4) - self.assertEqual([[None]], - [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) - expected_output = ([[i, i + 1, i + 2, i + 3] for i in range(0, 64, 4)] + - [[i, i + 1] for i in range(0, 32, 2)]) - self.assertDatasetProduces(rebatched_dataset, expected_output) - - @combinations.generate(test_base.default_test_combinations()) - def testZip(self): - dataset1 = dataset_ops.Dataset.range(64).batch(8) - dataset2 = dataset_ops.Dataset.range(32).batch(8) - dataset = dataset_ops.Dataset.zip((dataset1, dataset2)) - rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4) - self.assertEqual([[None], [None]], - [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) - expected_output = [([i, i + 1], [i, i + 1]) for i in range(0, 32, 2)] - self.assertDatasetProduces(rebatched_dataset, expected_output) - - @combinations.generate(test_base.default_test_combinations()) - def testZipDifferentShapes(self): - dataset1 = dataset_ops.Dataset.range(64).batch(16) - dataset2 = dataset_ops.Dataset.range(32).batch(8) - dataset = dataset_ops.Dataset.zip((dataset1, dataset2)) - rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4) - self.assertEqual([[None], [None]], - [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) - expected_output = [([2 * i, 2 * i + 1, 2 * i + 2, 2 * i + 3], [i, i + 1]) - for i in range(0, 32, 2)] - self.assertDatasetProduces(rebatched_dataset, expected_output) - - @combinations.generate(test_base.default_test_combinations()) - def testFlatMapBatching(self): - dataset = dataset_ops.Dataset.range(2).flat_map( - lambda _: dataset_ops.Dataset.range(32).batch( # pylint: disable=g-long-lambda - 32)) - # Two elements where each element is range(32) - expected_output = [[k for k in range(32)] for _ in range(2)] # pylint: disable=g-complex-comprehension - self.assertDatasetProduces(dataset, expected_output) - - rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4) - self.assertEqual([[None]], - [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) - # Two elements where each element is a list of 4 elements where each element - # is a list of 8. 
- expected_output = [[k for k in range(i, i + 8)] # pylint: disable=g-complex-comprehension - for _ in range(2) - for i in range(0, 32, 8)] # generates 4 elements - self.assertDatasetProduces(rebatched_dataset, expected_output) - - @combinations.generate(test_base.default_test_combinations()) - def testInterleaveBatching(self): - dataset = dataset_ops.Dataset.range(2).interleave( - lambda _: dataset_ops.Dataset.range(32).batch( # pylint: disable=g-long-lambda - 32), - cycle_length=2) - # Two elements where each element is range(32) - expected_output = [[k for k in range(32)] for _ in range(2)] # pylint: disable=g-complex-comprehension - self.assertDatasetProduces(dataset, expected_output) - - rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4) - self.assertEqual([[None]], - [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) - expected_output = [[k for k in range(i, i + 8)] for i in range(0, 32, 8)] - expected_output += expected_output - self.assertDatasetProduces(rebatched_dataset, expected_output) - - @combinations.generate(test_base.default_test_combinations()) - def testParallelInterleaveBatching(self): - dataset = dataset_ops.Dataset.range(2).interleave( - lambda _: dataset_ops.Dataset.range(32).batch( # pylint: disable=g-long-lambda - 32), - cycle_length=2, - num_parallel_calls=2) - # Two elements where each element is range(32) - expected_output = [[k for k in range(32)] for _ in range(2)] # pylint: disable=g-complex-comprehension - self.assertDatasetProduces(dataset, expected_output) - - rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4) - self.assertEqual([[None]], - [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) - expected_output = [[k for k in range(i, i + 8)] for i in range(0, 32, 8)] - expected_output += expected_output - self.assertDatasetProduces(rebatched_dataset, expected_output) - - @combinations.generate(test_base.default_test_combinations()) - def testGroupByWindowStaticBatch(self): - dataset = dataset_ops.Dataset.from_tensor_slices( - [[array_ops.constant(i, dtype=dtypes.int64)] * 3 for i in range(40)]) - reduce_fn = lambda bucket_id, ds: ds.batch( # pylint: disable=g-long-lambda - batch_size=10) - dataset = dataset.apply( - grouping.group_by_window( - key_func=lambda x: x[0] % 4, reduce_func=reduce_fn, window_size=10)) - rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=2) - - self.assertEqual([[None, 3]], - [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) - # pylint: disable=g-complex-comprehension - expected_output = [[[j + i * 4 + k * 20] * 3 - for i in range(5)] - for j in range(4) - for k in range(2)] - self.assertDatasetProduces(rebatched_dataset, expected_output) - - @combinations.generate(test_base.default_test_combinations()) - def testGroupByWindowDynamicBatch(self): - # {0, 1, 0, 1, ...} - dataset = dataset_ops.Dataset.range(40).map(lambda x: x % 2) - - def reduce_fn(key, ds): - # key == 0 -> .batch(5) - # key == 1 -> .batch(10) - return ds.batch(batch_size=(key + 1) * 5) - - dataset = dataset.apply( - grouping.group_by_window( - key_func=lambda x: x, reduce_func=reduce_fn, window_size=10)) - dataset = distribute._RebatchDataset(dataset, num_replicas=2) - - self.assertEqual([[None]], [ts.as_list() for ts in _flat_shapes(dataset)]) - - # The batches of 5 (value == 0) will be split into minibatches of (3, 2) and - # the batches of 10 (value == 1) split into minibatches of (5, 5) - # [(batch_size, value), ...] 
- pairs = [(3, 0), (2, 0), (3, 0), (2, 0), (5, 1), (5, 1)] - pairs = pairs * 2 - expected_output = [[value] * batch_size for batch_size, value in pairs] - self.assertDatasetProduces(dataset, expected_output) - - @combinations.generate(test_base.default_test_combinations()) - def testGroupByWindowDynamicBatchWithPartialBatch(self): - # {0, 1, 0, 1, ...} - dataset = dataset_ops.Dataset.range(40).map(lambda x: x % 2) - - def reduce_fn(key, ds): - # key == 0 -> .batch(5) - # key == 1 -> .batch(10) - return ds.batch(batch_size=(key + 1) * 5) - - dataset = dataset.apply( - grouping.group_by_window( - key_func=lambda x: x, reduce_func=reduce_fn, window_size=11)) - dataset = distribute._RebatchDataset(dataset, num_replicas=2) - - self.assertEqual([[None]], [ts.as_list() for ts in _flat_shapes(dataset)]) - - pairs = [(3, 0), (2, 0), (3, 0), (2, 0), (1, 0), (0, 0), (5, 1), (5, 1), - (1, 1), (0, 1), (3, 0), (2, 0), (2, 0), (2, 0), (5, 1), (4, 1)] - expected_output = [[value] * batch_size for batch_size, value in pairs] - self.assertDatasetProduces(dataset, expected_output) - - @combinations.generate(test_base.default_test_combinations()) - def testGroupByWindowDynamicBatchWithPartialBatchWithDropRemainder(self): - # This test exercises nested batch functionality, dynamic batch size - # and drop_remainder=True together. - dataset = dataset_ops.Dataset.range(40).map(lambda x: x % 2) - - def reduce_fn(key, ds): - # key == 0 -> .batch(5) - # key == 1 -> .batch(10) - return ds.batch(batch_size=(key + 1) * 5, drop_remainder=True) - - dataset = dataset.apply( - grouping.group_by_window( - key_func=lambda x: x, reduce_func=reduce_fn, window_size=11)) - dataset = distribute._RebatchDataset(dataset, num_replicas=2) - - self.assertEqual([[None]], [ts.as_list() for ts in _flat_shapes(dataset)]) - - # The batches of 5 (value == 0) will be split into minibatches of (3, 2) and - # the batches of 10 (value == 1) split into minibatches of (5, 5) - # [(batch_size, value), ...] 
- pairs = [(3, 0), (2, 0), (3, 0), (2, 0), (5, 1), (5, 1), (3, 0), (2, 0)] - expected_output = [[value] * batch_size for batch_size, value in pairs] - self.assertDatasetProduces(dataset, expected_output) - - @combinations.generate(test_base.default_test_combinations()) - def testScanAfterBatch(self): - dataset = dataset_ops.Dataset.range(40).batch(10).apply( - scan_ops.scan(np.int64(2), lambda state, value: (state, value * state))) - dataset = distribute._RebatchDataset(dataset, num_replicas=2) - - self.assertEqual([[None]], - [ts.as_list() for ts in _flat_shapes(dataset)]) - expected_output = [[i * 2 for i in range(j*5, (j+1)*5)] for j in range(8)] # pylint: disable=g-complex-comprehension - self.assertDatasetProduces(dataset, expected_output) - - @combinations.generate(test_base.default_test_combinations()) - def testMakeBatchedFeaturesDataset(self): - # Set up - fn = os.path.join(self.get_temp_dir(), "tf_record.txt") - writer = python_io.TFRecordWriter(fn) - for i in range(1024): - writer.write( - example_pb2.Example( - features=feature_pb2.Features( - feature={ - "value": - feature_pb2.Feature( - int64_list=feature_pb2.Int64List(value=[i])) - })).SerializeToString()) - writer.close() - - dataset = readers.make_batched_features_dataset( - file_pattern=fn, - batch_size=32, - features={"value": parsing_ops.FixedLenFeature([], dtypes.int64)}, - shuffle=False, - num_epochs=1, - drop_final_batch=False) - - rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4) - - self.assertEqual([[None]], - [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) - - expected_output = [{ - "value": [k for k in range(i, i + 8)] - } for i in range(0, 1024, 8)] # pylint: disable=g-complex-comprehension - self.assertDatasetProduces(rebatched_dataset, expected_output) - @combinations.generate(test_base.default_test_combinations()) def testRaggedTensorDataset(self): # Set up a dataset that produces ragged tensors with a static batch size.