[tf.data] Reenable rebatch_dataset_test in TAP and remove some tests from it to reduce its size. The removed tests exercised logic that was important to check under the old rewrite-based rebatching implementation; now that rebatching is done with its own RebatchDataset kernel, they are no longer relevant.
PiperOrigin-RevId: 304296546 Change-Id: I673c5fb88bad993b92d760cd8832271716401e25
commit 7a96514f55
parent a0422e404e
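For context on the contract these tests exercise: _RebatchDataset takes an already-batched input pipeline and splits each batch across num_replicas, so a global batch size of 32 with four replicas yields consecutive per-replica batches of 8, and the static batch dimension becomes unknown (hence the [[None]] shape assertions throughout the file). A minimal sketch of that behavior, using the private distribute._RebatchDataset API exactly as the tests below do (illustrative standalone usage, not part of this change):

    from tensorflow.python.data.experimental.ops import distribute
    from tensorflow.python.data.ops import dataset_ops

    # Global batches of 32, rebatched for 4 replicas, come out as
    # consecutive per-replica batches of 8.
    dataset = dataset_ops.Dataset.range(64).batch(32)
    rebatched = distribute._RebatchDataset(dataset, num_replicas=4)
    # Produces: [0..7], [8..15], [16..23], [24..31], [32..39], ...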
@@ -580,16 +580,9 @@ tf_py_test(
     name = "rebatch_dataset_test",
     size = "small",
     srcs = ["rebatch_dataset_test.py"],
-    tags = [
-        "manual",  # TODO(b/152215379)
-        "notap",
-    ],
     deps = [
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python:client_testlib",
-        "//tensorflow/python:image_ops",
-        "//tensorflow/python:parsing_ops",
-        "//tensorflow/python/data/experimental/ops:readers",
         "//tensorflow/python/data/kernel_tests:test_base",
         "//tensorflow/python/data/ops:dataset_ops",
         "//tensorflow/python/data/util:nest",
@@ -17,30 +17,17 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import os
-
 from absl.testing import parameterized
 import numpy as np
 
-from tensorflow.core.example import example_pb2
-from tensorflow.core.example import feature_pb2
-from tensorflow.python.data.experimental.ops import batching
 from tensorflow.python.data.experimental.ops import distribute
-from tensorflow.python.data.experimental.ops import grouping
-from tensorflow.python.data.experimental.ops import readers
-from tensorflow.python.data.experimental.ops import scan_ops
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.util import nest
 from tensorflow.python.framework import combinations
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import tensor_shape
-from tensorflow.python.lib.io import python_io
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import image_ops
 from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import parsing_ops
-from tensorflow.python.ops import variables
 from tensorflow.python.ops.ragged import ragged_tensor
 from tensorflow.python.platform import test
 
@@ -206,290 +193,6 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
                        for i in range(0, 128, 8)]
     self.assertDatasetProduces(rebatched_dataset, expected_output)
 
-  @combinations.generate(test_base.default_test_combinations())
-  def testMapAndBatch(self):
-    dataset = dataset_ops.Dataset.range(1024).apply(
-        batching.map_and_batch(math_ops.square, 32))
-    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
-    self.assertEqual([[None]],
-                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
-    expected_output = [[k**2 for k in range(i, i + 8)]  # pylint: disable=g-complex-comprehension
-                       for i in range(0, 1024, 8)]
-    self.assertDatasetProduces(rebatched_dataset, expected_output)
-
-  @combinations.generate(test_base.default_test_combinations())
-  def testMapAndBatchWithCapturedInput(self):
-    captured_t = variables.Variable(42)
-    dataset = dataset_ops.Dataset.range(1024).apply(
-        batching.map_and_batch(lambda x: captured_t, 32))
-    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
-    self.assertEqual([[None]],
-                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
-    expected_output = [[42 for _ in range(i, i + 8)]  # pylint: disable=g-complex-comprehension
-                       for i in range(0, 1024, 8)]
-    self.evaluate(variables.global_variables_initializer())
-    self.assertDatasetProduces(
-        rebatched_dataset, expected_output, requires_initialization=True)
-
-  @combinations.generate(test_base.default_test_combinations())
-  def testPaddedBatch(self):
-    dataset = dataset_ops.Dataset.range(128).batch(
-        4, drop_remainder=True).padded_batch(
-            8, padded_shapes=[5])
-    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
-    # Each element is a list of 8 elements in which each element is a list of 5
-    # elements, first four are numbers and the last one is a padded zero.
-    expected_output = [[[j, j + 1, j + 2, j + 3, 0]  # pylint: disable=g-complex-comprehension
-                        for j in range(i, i + 32, 4)]  # generates 8 elements
-                       for i in range(0, 128, 32)]
-    self.assertDatasetProduces(dataset, expected_output)
-    self.assertEqual([[None, 5]],
-                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
-    # Each element is a list of 2 elements in which each element is a list of 5
-    # elements, first four are numbers and the last one is a padded zero.
-    expected_output = [[[j, j + 1, j + 2, j + 3, 0]  # pylint: disable=g-complex-comprehension
-                        for j in range(i, i + 8, 4)]  # generates 2 elements
-                       for i in range(0, 128, 8)]
-    self.assertDatasetProduces(rebatched_dataset, expected_output)
-
-  @combinations.generate(test_base.default_test_combinations())
-  def testConcatenate(self):
-    dataset1 = dataset_ops.Dataset.range(64).batch(8)
-    dataset2 = dataset_ops.Dataset.range(32).batch(8)
-    dataset = dataset1.concatenate(dataset2)
-    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
-    self.assertEqual([[None]],
-                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
-    expected_output = ([[i, i + 1] for i in range(0, 64, 2)] +
-                       [[i, i + 1] for i in range(0, 32, 2)])
-    self.assertDatasetProduces(rebatched_dataset, expected_output)
-
-  @combinations.generate(test_base.default_test_combinations())
-  def testConcatenateDifferentShapes(self):
-    dataset1 = dataset_ops.Dataset.range(64).batch(16)
-    dataset2 = dataset_ops.Dataset.range(32).batch(8)
-    dataset = dataset1.concatenate(dataset2)
-    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
-    self.assertEqual([[None]],
-                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
-    expected_output = ([[i, i + 1, i + 2, i + 3] for i in range(0, 64, 4)] +
-                       [[i, i + 1] for i in range(0, 32, 2)])
-    self.assertDatasetProduces(rebatched_dataset, expected_output)
-
-  @combinations.generate(test_base.default_test_combinations())
-  def testZip(self):
-    dataset1 = dataset_ops.Dataset.range(64).batch(8)
-    dataset2 = dataset_ops.Dataset.range(32).batch(8)
-    dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
-    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
-    self.assertEqual([[None], [None]],
-                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
-    expected_output = [([i, i + 1], [i, i + 1]) for i in range(0, 32, 2)]
-    self.assertDatasetProduces(rebatched_dataset, expected_output)
-
-  @combinations.generate(test_base.default_test_combinations())
-  def testZipDifferentShapes(self):
-    dataset1 = dataset_ops.Dataset.range(64).batch(16)
-    dataset2 = dataset_ops.Dataset.range(32).batch(8)
-    dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
-    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
-    self.assertEqual([[None], [None]],
-                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
-    expected_output = [([2 * i, 2 * i + 1, 2 * i + 2, 2 * i + 3], [i, i + 1])
-                       for i in range(0, 32, 2)]
-    self.assertDatasetProduces(rebatched_dataset, expected_output)
-
-  @combinations.generate(test_base.default_test_combinations())
-  def testFlatMapBatching(self):
-    dataset = dataset_ops.Dataset.range(2).flat_map(
-        lambda _: dataset_ops.Dataset.range(32).batch(  # pylint: disable=g-long-lambda
-            32))
-    # Two elements where each element is range(32)
-    expected_output = [[k for k in range(32)] for _ in range(2)]  # pylint: disable=g-complex-comprehension
-    self.assertDatasetProduces(dataset, expected_output)
-
-    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
-    self.assertEqual([[None]],
-                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
-    # Two elements where each element is a list of 4 elements where each element
-    # is a list of 8.
-    expected_output = [[k for k in range(i, i + 8)]  # pylint: disable=g-complex-comprehension
-                       for _ in range(2)
-                       for i in range(0, 32, 8)]  # generates 4 elements
-    self.assertDatasetProduces(rebatched_dataset, expected_output)
-
-  @combinations.generate(test_base.default_test_combinations())
-  def testInterleaveBatching(self):
-    dataset = dataset_ops.Dataset.range(2).interleave(
-        lambda _: dataset_ops.Dataset.range(32).batch(  # pylint: disable=g-long-lambda
-            32),
-        cycle_length=2)
-    # Two elements where each element is range(32)
-    expected_output = [[k for k in range(32)] for _ in range(2)]  # pylint: disable=g-complex-comprehension
-    self.assertDatasetProduces(dataset, expected_output)
-
-    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
-    self.assertEqual([[None]],
-                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
-    expected_output = [[k for k in range(i, i + 8)] for i in range(0, 32, 8)]
-    expected_output += expected_output
-    self.assertDatasetProduces(rebatched_dataset, expected_output)
-
-  @combinations.generate(test_base.default_test_combinations())
-  def testParallelInterleaveBatching(self):
-    dataset = dataset_ops.Dataset.range(2).interleave(
-        lambda _: dataset_ops.Dataset.range(32).batch(  # pylint: disable=g-long-lambda
-            32),
-        cycle_length=2,
-        num_parallel_calls=2)
-    # Two elements where each element is range(32)
-    expected_output = [[k for k in range(32)] for _ in range(2)]  # pylint: disable=g-complex-comprehension
-    self.assertDatasetProduces(dataset, expected_output)
-
-    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
-    self.assertEqual([[None]],
-                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
-    expected_output = [[k for k in range(i, i + 8)] for i in range(0, 32, 8)]
-    expected_output += expected_output
-    self.assertDatasetProduces(rebatched_dataset, expected_output)
-
-  @combinations.generate(test_base.default_test_combinations())
-  def testGroupByWindowStaticBatch(self):
-    dataset = dataset_ops.Dataset.from_tensor_slices(
-        [[array_ops.constant(i, dtype=dtypes.int64)] * 3 for i in range(40)])
-    reduce_fn = lambda bucket_id, ds: ds.batch(  # pylint: disable=g-long-lambda
-        batch_size=10)
-    dataset = dataset.apply(
-        grouping.group_by_window(
-            key_func=lambda x: x[0] % 4, reduce_func=reduce_fn, window_size=10))
-    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=2)
-
-    self.assertEqual([[None, 3]],
-                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
-    # pylint: disable=g-complex-comprehension
-    expected_output = [[[j + i * 4 + k * 20] * 3
-                        for i in range(5)]
-                       for j in range(4)
-                       for k in range(2)]
-    self.assertDatasetProduces(rebatched_dataset, expected_output)
-
-  @combinations.generate(test_base.default_test_combinations())
-  def testGroupByWindowDynamicBatch(self):
-    # {0, 1, 0, 1, ...}
-    dataset = dataset_ops.Dataset.range(40).map(lambda x: x % 2)
-
-    def reduce_fn(key, ds):
-      # key == 0 -> .batch(5)
-      # key == 1 -> .batch(10)
-      return ds.batch(batch_size=(key + 1) * 5)
-
-    dataset = dataset.apply(
-        grouping.group_by_window(
-            key_func=lambda x: x, reduce_func=reduce_fn, window_size=10))
-    dataset = distribute._RebatchDataset(dataset, num_replicas=2)
-
-    self.assertEqual([[None]], [ts.as_list() for ts in _flat_shapes(dataset)])
-
-    # The batches of 5 (value == 0) will be split into minibatches of (3, 2) and
-    # the batches of 10 (value == 1) split into minibatches of (5, 5)
-    # [(batch_size, value), ...]
-    pairs = [(3, 0), (2, 0), (3, 0), (2, 0), (5, 1), (5, 1)]
-    pairs = pairs * 2
-    expected_output = [[value] * batch_size for batch_size, value in pairs]
-    self.assertDatasetProduces(dataset, expected_output)
-
-  @combinations.generate(test_base.default_test_combinations())
-  def testGroupByWindowDynamicBatchWithPartialBatch(self):
-    # {0, 1, 0, 1, ...}
-    dataset = dataset_ops.Dataset.range(40).map(lambda x: x % 2)
-
-    def reduce_fn(key, ds):
-      # key == 0 -> .batch(5)
-      # key == 1 -> .batch(10)
-      return ds.batch(batch_size=(key + 1) * 5)
-
-    dataset = dataset.apply(
-        grouping.group_by_window(
-            key_func=lambda x: x, reduce_func=reduce_fn, window_size=11))
-    dataset = distribute._RebatchDataset(dataset, num_replicas=2)
-
-    self.assertEqual([[None]], [ts.as_list() for ts in _flat_shapes(dataset)])
-
-    pairs = [(3, 0), (2, 0), (3, 0), (2, 0), (1, 0), (0, 0), (5, 1), (5, 1),
-             (1, 1), (0, 1), (3, 0), (2, 0), (2, 0), (2, 0), (5, 1), (4, 1)]
-    expected_output = [[value] * batch_size for batch_size, value in pairs]
-    self.assertDatasetProduces(dataset, expected_output)
-
-  @combinations.generate(test_base.default_test_combinations())
-  def testGroupByWindowDynamicBatchWithPartialBatchWithDropRemainder(self):
-    # This test exercises nested batch functionality, dynamic batch size
-    # and drop_remainder=True together.
-    dataset = dataset_ops.Dataset.range(40).map(lambda x: x % 2)
-
-    def reduce_fn(key, ds):
-      # key == 0 -> .batch(5)
-      # key == 1 -> .batch(10)
-      return ds.batch(batch_size=(key + 1) * 5, drop_remainder=True)
-
-    dataset = dataset.apply(
-        grouping.group_by_window(
-            key_func=lambda x: x, reduce_func=reduce_fn, window_size=11))
-    dataset = distribute._RebatchDataset(dataset, num_replicas=2)
-
-    self.assertEqual([[None]], [ts.as_list() for ts in _flat_shapes(dataset)])
-
-    # The batches of 5 (value == 0) will be split into minibatches of (3, 2) and
-    # the batches of 10 (value == 1) split into minibatches of (5, 5)
-    # [(batch_size, value), ...]
-    pairs = [(3, 0), (2, 0), (3, 0), (2, 0), (5, 1), (5, 1), (3, 0), (2, 0)]
-    expected_output = [[value] * batch_size for batch_size, value in pairs]
-    self.assertDatasetProduces(dataset, expected_output)
-
-  @combinations.generate(test_base.default_test_combinations())
-  def testScanAfterBatch(self):
-    dataset = dataset_ops.Dataset.range(40).batch(10).apply(
-        scan_ops.scan(np.int64(2), lambda state, value: (state, value * state)))
-    dataset = distribute._RebatchDataset(dataset, num_replicas=2)
-
-    self.assertEqual([[None]],
-                     [ts.as_list() for ts in _flat_shapes(dataset)])
-    expected_output = [[i * 2 for i in range(j*5, (j+1)*5)] for j in range(8)]  # pylint: disable=g-complex-comprehension
-    self.assertDatasetProduces(dataset, expected_output)
-
-  @combinations.generate(test_base.default_test_combinations())
-  def testMakeBatchedFeaturesDataset(self):
-    # Set up
-    fn = os.path.join(self.get_temp_dir(), "tf_record.txt")
-    writer = python_io.TFRecordWriter(fn)
-    for i in range(1024):
-      writer.write(
-          example_pb2.Example(
-              features=feature_pb2.Features(
-                  feature={
-                      "value":
-                          feature_pb2.Feature(
-                              int64_list=feature_pb2.Int64List(value=[i]))
-                  })).SerializeToString())
-    writer.close()
-
-    dataset = readers.make_batched_features_dataset(
-        file_pattern=fn,
-        batch_size=32,
-        features={"value": parsing_ops.FixedLenFeature([], dtypes.int64)},
-        shuffle=False,
-        num_epochs=1,
-        drop_final_batch=False)
-
-    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
-
-    self.assertEqual([[None]],
-                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
-
-    expected_output = [{
-        "value": [k for k in range(i, i + 8)]
-    } for i in range(0, 1024, 8)]  # pylint: disable=g-complex-comprehension
-    self.assertDatasetProduces(rebatched_dataset, expected_output)
-
   @combinations.generate(test_base.default_test_combinations())
   def testRaggedTensorDataset(self):
     # Set up a dataset that produces ragged tensors with a static batch size.