[tf.data] Re-enable rebatch_dataset_test in TAP and remove some tests from it to reduce its size. The removed tests exercised logic that was important to check for the old rewrite-based RebatchDataset; now that rebatching is implemented by its own RebatchDataset kernel, they are no longer relevant.

PiperOrigin-RevId: 304296546
Change-Id: I673c5fb88bad993b92d760cd8832271716401e25
Rachel Lim 2020-04-01 17:34:13 -07:00 committed by TensorFlower Gardener
parent a0422e404e
commit 7a96514f55
2 changed files with 0 additions and 304 deletions
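
For context, every removed test follows the same pattern: build a batched dataset, wrap it in the experimental distribute._RebatchDataset transformation, and assert that each global batch is split evenly across replicas. A minimal sketch of that pattern (the batch sizes and range are illustrative, and it assumes TF 2.x eager execution):

from tensorflow.python.data.experimental.ops import distribute
from tensorflow.python.data.ops import dataset_ops

# Four global batches of size 8: [0..7], [8..15], [16..23], [24..31].
dataset = dataset_ops.Dataset.range(32).batch(8, drop_remainder=True)

# Rebatching for 4 replicas splits each global batch of 8 into
# four per-replica minibatches of 2, preserving element order.
rebatched = distribute._RebatchDataset(dataset, num_replicas=4)

for minibatch in rebatched:
  print(minibatch.numpy())  # [0 1], [2 3], ..., [30 31]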


@@ -580,16 +580,9 @@ tf_py_test(
name = "rebatch_dataset_test",
size = "small",
srcs = ["rebatch_dataset_test.py"],
tags = [
"manual", # TODO(b/152215379)
"notap",
],
deps = [
"//tensorflow/core:protos_all_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python:image_ops",
"//tensorflow/python:parsing_ops",
"//tensorflow/python/data/experimental/ops:readers",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/util:nest",


@@ -17,30 +17,17 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl.testing import parameterized
import numpy as np
from tensorflow.core.example import example_pb2
from tensorflow.core.example import feature_pb2
from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.experimental.ops import distribute
from tensorflow.python.data.experimental.ops import grouping
from tensorflow.python.data.experimental.ops import readers
from tensorflow.python.data.experimental.ops import scan_ops
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.framework import combinations
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.lib.io import python_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import image_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import test
@@ -206,290 +193,6 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
for i in range(0, 128, 8)]
self.assertDatasetProduces(rebatched_dataset, expected_output)
@combinations.generate(test_base.default_test_combinations())
def testMapAndBatch(self):
dataset = dataset_ops.Dataset.range(1024).apply(
batching.map_and_batch(math_ops.square, 32))
rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
self.assertEqual([[None]],
[ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
expected_output = [[k**2 for k in range(i, i + 8)] # pylint: disable=g-complex-comprehension
for i in range(0, 1024, 8)]
self.assertDatasetProduces(rebatched_dataset, expected_output)
@combinations.generate(test_base.default_test_combinations())
def testMapAndBatchWithCapturedInput(self):
captured_t = variables.Variable(42)
dataset = dataset_ops.Dataset.range(1024).apply(
batching.map_and_batch(lambda x: captured_t, 32))
rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
self.assertEqual([[None]],
[ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
expected_output = [[42 for _ in range(i, i + 8)] # pylint: disable=g-complex-comprehension
for i in range(0, 1024, 8)]
self.evaluate(variables.global_variables_initializer())
self.assertDatasetProduces(
rebatched_dataset, expected_output, requires_initialization=True)
@combinations.generate(test_base.default_test_combinations())
def testPaddedBatch(self):
dataset = dataset_ops.Dataset.range(128).batch(
4, drop_remainder=True).padded_batch(
8, padded_shapes=[5])
rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
# Each element is a list of 8 elements in which each element is a list of 5
# elements, first four are numbers and the last one is a padded zero.
expected_output = [[[j, j + 1, j + 2, j + 3, 0] # pylint: disable=g-complex-comprehension
for j in range(i, i + 32, 4)] # generates 8 elements
for i in range(0, 128, 32)]
self.assertDatasetProduces(dataset, expected_output)
self.assertEqual([[None, 5]],
[ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
# Each element is a list of 2 elements in which each element is a list of 5
# elements, first four are numbers and the last one is a padded zero.
expected_output = [[[j, j + 1, j + 2, j + 3, 0] # pylint: disable=g-complex-comprehension
for j in range(i, i + 8, 4)] # generates 2 elements
for i in range(0, 128, 8)]
self.assertDatasetProduces(rebatched_dataset, expected_output)
@combinations.generate(test_base.default_test_combinations())
def testConcatenate(self):
dataset1 = dataset_ops.Dataset.range(64).batch(8)
dataset2 = dataset_ops.Dataset.range(32).batch(8)
dataset = dataset1.concatenate(dataset2)
rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
self.assertEqual([[None]],
[ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
expected_output = ([[i, i + 1] for i in range(0, 64, 2)] +
[[i, i + 1] for i in range(0, 32, 2)])
self.assertDatasetProduces(rebatched_dataset, expected_output)
@combinations.generate(test_base.default_test_combinations())
def testConcatenateDifferentShapes(self):
dataset1 = dataset_ops.Dataset.range(64).batch(16)
dataset2 = dataset_ops.Dataset.range(32).batch(8)
dataset = dataset1.concatenate(dataset2)
rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
self.assertEqual([[None]],
[ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
expected_output = ([[i, i + 1, i + 2, i + 3] for i in range(0, 64, 4)] +
[[i, i + 1] for i in range(0, 32, 2)])
self.assertDatasetProduces(rebatched_dataset, expected_output)
@combinations.generate(test_base.default_test_combinations())
def testZip(self):
dataset1 = dataset_ops.Dataset.range(64).batch(8)
dataset2 = dataset_ops.Dataset.range(32).batch(8)
dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
self.assertEqual([[None], [None]],
[ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
expected_output = [([i, i + 1], [i, i + 1]) for i in range(0, 32, 2)]
self.assertDatasetProduces(rebatched_dataset, expected_output)
@combinations.generate(test_base.default_test_combinations())
def testZipDifferentShapes(self):
dataset1 = dataset_ops.Dataset.range(64).batch(16)
dataset2 = dataset_ops.Dataset.range(32).batch(8)
dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
self.assertEqual([[None], [None]],
[ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
expected_output = [([2 * i, 2 * i + 1, 2 * i + 2, 2 * i + 3], [i, i + 1])
for i in range(0, 32, 2)]
self.assertDatasetProduces(rebatched_dataset, expected_output)
@combinations.generate(test_base.default_test_combinations())
def testFlatMapBatching(self):
dataset = dataset_ops.Dataset.range(2).flat_map(
lambda _: dataset_ops.Dataset.range(32).batch( # pylint: disable=g-long-lambda
32))
# Two elements where each element is range(32)
expected_output = [[k for k in range(32)] for _ in range(2)] # pylint: disable=g-complex-comprehension
self.assertDatasetProduces(dataset, expected_output)
rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
self.assertEqual([[None]],
[ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
# Two elements where each element is a list of 4 elements where each element
# is a list of 8.
expected_output = [[k for k in range(i, i + 8)] # pylint: disable=g-complex-comprehension
for _ in range(2)
for i in range(0, 32, 8)] # generates 4 elements
self.assertDatasetProduces(rebatched_dataset, expected_output)
@combinations.generate(test_base.default_test_combinations())
def testInterleaveBatching(self):
dataset = dataset_ops.Dataset.range(2).interleave(
lambda _: dataset_ops.Dataset.range(32).batch( # pylint: disable=g-long-lambda
32),
cycle_length=2)
# Two elements where each element is range(32)
expected_output = [[k for k in range(32)] for _ in range(2)] # pylint: disable=g-complex-comprehension
self.assertDatasetProduces(dataset, expected_output)
rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
self.assertEqual([[None]],
[ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
expected_output = [[k for k in range(i, i + 8)] for i in range(0, 32, 8)]
expected_output += expected_output
self.assertDatasetProduces(rebatched_dataset, expected_output)
@combinations.generate(test_base.default_test_combinations())
def testParallelInterleaveBatching(self):
dataset = dataset_ops.Dataset.range(2).interleave(
lambda _: dataset_ops.Dataset.range(32).batch( # pylint: disable=g-long-lambda
32),
cycle_length=2,
num_parallel_calls=2)
# Two elements where each element is range(32)
expected_output = [[k for k in range(32)] for _ in range(2)] # pylint: disable=g-complex-comprehension
self.assertDatasetProduces(dataset, expected_output)
rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
self.assertEqual([[None]],
[ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
expected_output = [[k for k in range(i, i + 8)] for i in range(0, 32, 8)]
expected_output += expected_output
self.assertDatasetProduces(rebatched_dataset, expected_output)
@combinations.generate(test_base.default_test_combinations())
def testGroupByWindowStaticBatch(self):
dataset = dataset_ops.Dataset.from_tensor_slices(
[[array_ops.constant(i, dtype=dtypes.int64)] * 3 for i in range(40)])
reduce_fn = lambda bucket_id, ds: ds.batch( # pylint: disable=g-long-lambda
batch_size=10)
dataset = dataset.apply(
grouping.group_by_window(
key_func=lambda x: x[0] % 4, reduce_func=reduce_fn, window_size=10))
rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=2)
self.assertEqual([[None, 3]],
[ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
# pylint: disable=g-complex-comprehension
expected_output = [[[j + i * 4 + k * 20] * 3
for i in range(5)]
for j in range(4)
for k in range(2)]
self.assertDatasetProduces(rebatched_dataset, expected_output)
@combinations.generate(test_base.default_test_combinations())
def testGroupByWindowDynamicBatch(self):
# {0, 1, 0, 1, ...}
dataset = dataset_ops.Dataset.range(40).map(lambda x: x % 2)
def reduce_fn(key, ds):
# key == 0 -> .batch(5)
# key == 1 -> .batch(10)
return ds.batch(batch_size=(key + 1) * 5)
dataset = dataset.apply(
grouping.group_by_window(
key_func=lambda x: x, reduce_func=reduce_fn, window_size=10))
dataset = distribute._RebatchDataset(dataset, num_replicas=2)
self.assertEqual([[None]], [ts.as_list() for ts in _flat_shapes(dataset)])
# The batches of 5 (value == 0) will be split into minibatches of (3, 2) and
# the batches of 10 (value == 1) split into minibatches of (5, 5)
# [(batch_size, value), ...]
pairs = [(3, 0), (2, 0), (3, 0), (2, 0), (5, 1), (5, 1)]
pairs = pairs * 2
expected_output = [[value] * batch_size for batch_size, value in pairs]
self.assertDatasetProduces(dataset, expected_output)
@combinations.generate(test_base.default_test_combinations())
def testGroupByWindowDynamicBatchWithPartialBatch(self):
# {0, 1, 0, 1, ...}
dataset = dataset_ops.Dataset.range(40).map(lambda x: x % 2)
def reduce_fn(key, ds):
# key == 0 -> .batch(5)
# key == 1 -> .batch(10)
return ds.batch(batch_size=(key + 1) * 5)
dataset = dataset.apply(
grouping.group_by_window(
key_func=lambda x: x, reduce_func=reduce_fn, window_size=11))
dataset = distribute._RebatchDataset(dataset, num_replicas=2)
self.assertEqual([[None]], [ts.as_list() for ts in _flat_shapes(dataset)])
pairs = [(3, 0), (2, 0), (3, 0), (2, 0), (1, 0), (0, 0), (5, 1), (5, 1),
(1, 1), (0, 1), (3, 0), (2, 0), (2, 0), (2, 0), (5, 1), (4, 1)]
expected_output = [[value] * batch_size for batch_size, value in pairs]
self.assertDatasetProduces(dataset, expected_output)
@combinations.generate(test_base.default_test_combinations())
def testGroupByWindowDynamicBatchWithPartialBatchWithDropRemainder(self):
# This test exercises nested batch functionality, dynamic batch size
# and drop_remainder=True together.
dataset = dataset_ops.Dataset.range(40).map(lambda x: x % 2)
def reduce_fn(key, ds):
# key == 0 -> .batch(5)
# key == 1 -> .batch(10)
return ds.batch(batch_size=(key + 1) * 5, drop_remainder=True)
dataset = dataset.apply(
grouping.group_by_window(
key_func=lambda x: x, reduce_func=reduce_fn, window_size=11))
dataset = distribute._RebatchDataset(dataset, num_replicas=2)
self.assertEqual([[None]], [ts.as_list() for ts in _flat_shapes(dataset)])
# The batches of 5 (value == 0) will be split into minibatches of (3, 2) and
# the batches of 10 (value == 1) split into minibatches of (5, 5)
# [(batch_size, value), ...]
pairs = [(3, 0), (2, 0), (3, 0), (2, 0), (5, 1), (5, 1), (3, 0), (2, 0)]
expected_output = [[value] * batch_size for batch_size, value in pairs]
self.assertDatasetProduces(dataset, expected_output)
@combinations.generate(test_base.default_test_combinations())
def testScanAfterBatch(self):
dataset = dataset_ops.Dataset.range(40).batch(10).apply(
scan_ops.scan(np.int64(2), lambda state, value: (state, value * state)))
dataset = distribute._RebatchDataset(dataset, num_replicas=2)
self.assertEqual([[None]],
[ts.as_list() for ts in _flat_shapes(dataset)])
expected_output = [[i * 2 for i in range(j*5, (j+1)*5)] for j in range(8)] # pylint: disable=g-complex-comprehension
self.assertDatasetProduces(dataset, expected_output)
@combinations.generate(test_base.default_test_combinations())
def testMakeBatchedFeaturesDataset(self):
# Set up
fn = os.path.join(self.get_temp_dir(), "tf_record.txt")
writer = python_io.TFRecordWriter(fn)
for i in range(1024):
writer.write(
example_pb2.Example(
features=feature_pb2.Features(
feature={
"value":
feature_pb2.Feature(
int64_list=feature_pb2.Int64List(value=[i]))
})).SerializeToString())
writer.close()
dataset = readers.make_batched_features_dataset(
file_pattern=fn,
batch_size=32,
features={"value": parsing_ops.FixedLenFeature([], dtypes.int64)},
shuffle=False,
num_epochs=1,
drop_final_batch=False)
rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
self.assertEqual([[None]],
[ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
expected_output = [{
"value": [k for k in range(i, i + 8)]
} for i in range(0, 1024, 8)] # pylint: disable=g-complex-comprehension
self.assertDatasetProduces(rebatched_dataset, expected_output)
@combinations.generate(test_base.default_test_combinations())
def testRaggedTensorDataset(self):
# Set up a dataset that produces ragged tensors with a static batch size.