Merge pull request #45116 from geetachavan1/cherrypicks_YI1P6
[CherryPick:r2.4][tf.data] Making sure setting tf.data threading options does not prevent auto-sharding from using file-level sharding.
This commit is contained in:
commit
4041fe82d3
@@ -71,7 +71,7 @@ constexpr std::array<const char*, 2> kMultipleInputsDatasetOps = {
|
|||||||
"ZipDataset"
|
"ZipDataset"
|
||||||
};
|
};
|
||||||
|
|
||||||
constexpr std::array<const char*, 26> kPassThroughOps = {
|
constexpr std::array<const char*, 28> kPassThroughOps = {
|
||||||
"_Retval",
|
"_Retval",
|
||||||
"AssertNextDataset",
|
"AssertNextDataset",
|
||||||
"BatchDataset",
|
"BatchDataset",
|
||||||
@@ -83,12 +83,14 @@ constexpr std::array<const char*, 26> kPassThroughOps = {
|
|||||||
"Identity",
|
"Identity",
|
||||||
"MapAndBatchDataset",
|
"MapAndBatchDataset",
|
||||||
"MapDataset",
|
"MapDataset",
|
||||||
|
"MaxIntraOpParallelismDataset",
|
||||||
"ModelDataset",
|
"ModelDataset",
|
||||||
"OptimizeDataset",
|
"OptimizeDataset",
|
||||||
"PaddedBatchDataset",
|
"PaddedBatchDataset",
|
||||||
"ParallelMapDataset",
|
"ParallelMapDataset",
|
||||||
"ParseExampleDataset",
|
"ParseExampleDataset",
|
||||||
"PrefetchDataset",
|
"PrefetchDataset",
|
||||||
|
"PrivateThreadPoolDataset",
|
||||||
"ReduceDataset",
|
"ReduceDataset",
|
||||||
"RebatchDataset",
|
"RebatchDataset",
|
||||||
"RepeatDataset",
|
"RepeatDataset",
|
||||||
|
@@ -493,6 +493,36 @@ class AutoShardDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
|
|||||||
]
|
]
|
||||||
self.assertDatasetProduces(dataset, list(chunk(expected, 5)))
|
self.assertDatasetProduces(dataset, list(chunk(expected, 5)))
|
||||||
|
|
||||||
|
@combinations.generate(test_base.default_test_combinations())
|
||||||
|
def testMaxIntraOpParallelism(self):
|
||||||
|
dataset = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=False)
|
||||||
|
dataset = dataset.flat_map(core_readers.TFRecordDataset)
|
||||||
|
dataset = dataset.batch(5)
|
||||||
|
dataset = dataset_ops._MaxIntraOpParallelismDataset(dataset, 1)
|
||||||
|
dataset = distribute._AutoShardDataset(dataset, 5, 0)
|
||||||
|
|
||||||
|
expected = [
|
||||||
|
b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension
|
||||||
|
for f in (0, 5)
|
||||||
|
for r in range(0, 10)
|
||||||
|
]
|
||||||
|
self.assertDatasetProduces(dataset, list(chunk(expected, 5)))
|
||||||
|
|
||||||
|
@combinations.generate(test_base.default_test_combinations())
|
||||||
|
def testPrivateThreadpool(self):
|
||||||
|
dataset = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=False)
|
||||||
|
dataset = dataset.flat_map(core_readers.TFRecordDataset)
|
||||||
|
dataset = dataset.batch(5)
|
||||||
|
dataset = dataset_ops._PrivateThreadPoolDataset(dataset, 1)
|
||||||
|
dataset = distribute._AutoShardDataset(dataset, 5, 0)
|
||||||
|
|
||||||
|
expected = [
|
||||||
|
b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension
|
||||||
|
for f in (0, 5)
|
||||||
|
for r in range(0, 10)
|
||||||
|
]
|
||||||
|
self.assertDatasetProduces(dataset, list(chunk(expected, 5)))
|
||||||
|
|
||||||
@combinations.generate(test_base.default_test_combinations())
|
@combinations.generate(test_base.default_test_combinations())
|
||||||
def testMakeBatchedFeaturesDataset(self):
|
def testMakeBatchedFeaturesDataset(self):
|
||||||
files = 2
|
files = 2
|
||||||
|
Loading…
x
Reference in New Issue
Block a user