[tf.data] Making sure that setting tf.data threading options does not prevent auto-sharding from using file-level sharding.

PiperOrigin-RevId: 343746154
Change-Id: I241e0c951c46086e9f043588ddb3aa0eff23e593
Author:    Jiri Simsa
Date:      2020-11-22 11:36:20 -08:00
Committer: Geeta Chavan
Parent:    0b06f2927b
Commit:    208edc153e

2 changed files with 33 additions and 1 deletion
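
For context, the threading options involved are set through tf.data.Options. Below is a minimal sketch of an affected pipeline, assuming the TF 2.4-era option names (experimental_threading, experimental_distribute) and a hypothetical file pattern:

import tensorflow as tf

# Minimal sketch, assuming TF 2.4-era option names and a hypothetical
# file pattern. Setting either threading option below inserts a
# MaxIntraOpParallelismDataset / PrivateThreadPoolDataset op into the
# dataset graph; before this change, auto-sharding did not recognize
# those ops as pass-through and fell back to data-level sharding.
dataset = tf.data.Dataset.list_files("/data/train-*.tfrecord", shuffle=False)
dataset = dataset.flat_map(tf.data.TFRecordDataset)

options = tf.data.Options()
options.experimental_threading.max_intra_op_parallelism = 1
options.experimental_threading.private_threadpool_size = 1
# Ask tf.distribute to shard by file rather than by element.
options.experimental_distribute.auto_shard_policy = (
    tf.data.experimental.AutoShardPolicy.FILE)
dataset = dataset.with_options(options)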
--- a/tensorflow/core/grappler/optimizers/data/auto_shard.cc
+++ b/tensorflow/core/grappler/optimizers/data/auto_shard.cc
@@ -71,7 +71,7 @@ constexpr std::array<const char*, 2> kMultipleInputsDatasetOps = {
     "ZipDataset"
 };
 
-constexpr std::array<const char*, 26> kPassThroughOps = {
+constexpr std::array<const char*, 28> kPassThroughOps = {
     "_Retval",
     "AssertNextDataset",
     "BatchDataset",
@@ -83,12 +83,14 @@ constexpr std::array<const char*, 26> kPassThroughOps = {
     "Identity",
     "MapAndBatchDataset",
     "MapDataset",
+    "MaxIntraOpParallelismDataset",
     "ModelDataset",
     "OptimizeDataset",
     "PaddedBatchDataset",
     "ParallelMapDataset",
     "ParseExampleDataset",
     "PrefetchDataset",
+    "PrivateThreadPoolDataset",
     "ReduceDataset",
     "RebatchDataset",
     "RepeatDataset",

--- a/tensorflow/python/data/experimental/kernel_tests/auto_shard_dataset_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/auto_shard_dataset_test.py
@@ -493,6 +493,36 @@ class AutoShardDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
     ]
     self.assertDatasetProduces(dataset, list(chunk(expected, 5)))
 
+  @combinations.generate(test_base.default_test_combinations())
+  def testMaxIntraOpParallelism(self):
+    dataset = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=False)
+    dataset = dataset.flat_map(core_readers.TFRecordDataset)
+    dataset = dataset.batch(5)
+    dataset = dataset_ops._MaxIntraOpParallelismDataset(dataset, 1)
+    dataset = distribute._AutoShardDataset(dataset, 5, 0)
+
+    expected = [
+        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
+        for f in (0, 5)
+        for r in range(0, 10)
+    ]
+    self.assertDatasetProduces(dataset, list(chunk(expected, 5)))
+
+  @combinations.generate(test_base.default_test_combinations())
+  def testPrivateThreadpool(self):
+    dataset = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=False)
+    dataset = dataset.flat_map(core_readers.TFRecordDataset)
+    dataset = dataset.batch(5)
+    dataset = dataset_ops._PrivateThreadPoolDataset(dataset, 1)
+    dataset = distribute._AutoShardDataset(dataset, 5, 0)
+
+    expected = [
+        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
+        for f in (0, 5)
+        for r in range(0, 10)
+    ]
+    self.assertDatasetProduces(dataset, list(chunk(expected, 5)))
+
   @combinations.generate(test_base.default_test_combinations())
   def testMakeBatchedFeaturesDataset(self):
     files = 2