Merge pull request #45116 from geetachavan1/cherrypicks_YI1P6

[CherryPick:r2.4][tf.data] Making sure setting tf.data threading options does not prevent auto-sharding from using file-level sharding.
commit 4041fe82d3
Goldie Gadde authored 2020-11-23 15:07:02 -08:00, committed by GitHub
2 changed files with 33 additions and 1 deletion
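For context, a minimal sketch of the user-facing pipeline this fix targets, assuming the TF 2.4 options API (options.experimental_threading, options.experimental_distribute); the file pattern and values here are illustrative, not part of the change:

import tensorflow as tf

# Illustrative only: setting threading options inserts
# MaxIntraOpParallelismDataset / PrivateThreadPoolDataset ops into the
# dataset graph; before this fix, auto-sharding did not recognize them as
# pass-through and fell back to element-level sharding.
dataset = tf.data.Dataset.list_files("/data/train-*.tfrecord", shuffle=False)
dataset = dataset.flat_map(tf.data.TFRecordDataset)
dataset = dataset.batch(5)

options = tf.data.Options()
options.experimental_threading.max_intra_op_parallelism = 1
options.experimental_threading.private_threadpool_size = 2
# Request sharding by whole files rather than by individual elements.
options.experimental_distribute.auto_shard_policy = (
    tf.data.experimental.AutoShardPolicy.FILE)
dataset = dataset.with_options(options)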

tensorflow/core/grappler/optimizers/data/auto_shard.cc

@@ -71,7 +71,7 @@ constexpr std::array<const char*, 2> kMultipleInputsDatasetOps = {
     "ZipDataset"
 };

-constexpr std::array<const char*, 26> kPassThroughOps = {
+constexpr std::array<const char*, 28> kPassThroughOps = {
     "_Retval",
     "AssertNextDataset",
     "BatchDataset",
@@ -83,12 +83,14 @@ constexpr std::array<const char*, 26> kPassThroughOps = {
     "Identity",
     "MapAndBatchDataset",
     "MapDataset",
+    "MaxIntraOpParallelismDataset",
     "ModelDataset",
     "OptimizeDataset",
     "PaddedBatchDataset",
     "ParallelMapDataset",
     "ParseExampleDataset",
     "PrefetchDataset",
+    "PrivateThreadPoolDataset",
     "ReduceDataset",
     "RebatchDataset",
     "RepeatDataset",

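Why adding these two ops matters: the auto-shard rewrite walks the dataset graph from the sink node toward its source, continuing only through ops it knows are pass-through; if the walk reaches a reader dataset (e.g. TFRecordDataset), its filename input can be sharded per file. Below is a rough Python rendering of that walk, with hypothetical names (PASS_THROUGH_OPS, READER_OPS, graph.input) standing in for the C++ internals; it is a sketch of the idea, not the actual implementation.

# Hypothetical sketch of the graph walk; the real logic is the C++ pass above.
PASS_THROUGH_OPS = {
    "BatchDataset", "MapDataset", "PrefetchDataset",
    # Newly recognized, so threading options no longer stop the walk:
    "MaxIntraOpParallelismDataset", "PrivateThreadPoolDataset",
}
READER_OPS = {"TFRecordDataset", "TextLineDataset", "FixedLengthRecordDataset"}

def find_reader(node, graph):
  # Follow input 0 through pass-through ops until a reader dataset appears.
  if node.op in READER_OPS:
    return node  # Shard this reader's filename input: file-level sharding.
  if node.op in PASS_THROUGH_OPS:
    return find_reader(graph.input(node, 0), graph)
  return None  # Unknown op: fall back to element-level sharding.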
tensorflow/python/data/experimental/kernel_tests/auto_shard_dataset_test.py

@@ -493,6 +493,36 @@ class AutoShardDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
     ]
     self.assertDatasetProduces(dataset, list(chunk(expected, 5)))
+
+  @combinations.generate(test_base.default_test_combinations())
+  def testMaxIntraOpParallelism(self):
+    dataset = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=False)
+    dataset = dataset.flat_map(core_readers.TFRecordDataset)
+    dataset = dataset.batch(5)
+    dataset = dataset_ops._MaxIntraOpParallelismDataset(dataset, 1)
+    dataset = distribute._AutoShardDataset(dataset, 5, 0)
+
+    expected = [
+        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
+        for f in (0, 5)
+        for r in range(0, 10)
+    ]
+    self.assertDatasetProduces(dataset, list(chunk(expected, 5)))
+
+  @combinations.generate(test_base.default_test_combinations())
+  def testPrivateThreadpool(self):
+    dataset = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=False)
+    dataset = dataset.flat_map(core_readers.TFRecordDataset)
+    dataset = dataset.batch(5)
+    dataset = dataset_ops._PrivateThreadPoolDataset(dataset, 1)
+    dataset = distribute._AutoShardDataset(dataset, 5, 0)
+
+    expected = [
+        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
+        for f in (0, 5)
+        for r in range(0, 10)
+    ]
+    self.assertDatasetProduces(dataset, list(chunk(expected, 5)))

   @combinations.generate(test_base.default_test_combinations())
   def testMakeBatchedFeaturesDataset(self):
     files = 2