[tf.data] Making sure setting tf.data threading options does not prevent auto-sharding from using file-level sharding.
PiperOrigin-RevId: 343746154
Change-Id: I241e0c951c46086e9f043588ddb3aa0eff23e593
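For context, the user-facing configuration this change protects can be sketched roughly as below (the file glob is illustrative, and the option names match the TF 2.4-era API of this commit). Setting the two threading options is what inserts the MaxIntraOpParallelismDataset and PrivateThreadPoolDataset ops that the diff teaches auto-sharding to pass through.

    import tensorflow as tf

    # Sketch only: setting threading options appends
    # MaxIntraOpParallelismDataset / PrivateThreadPoolDataset nodes to the
    # pipeline; with this change the AUTO shard policy can still walk past
    # them to the file-based reader and shard by file.
    dataset = tf.data.Dataset.list_files("/data/records-*", shuffle=False)
    dataset = dataset.interleave(tf.data.TFRecordDataset)

    options = tf.data.Options()
    options.experimental_threading.max_intra_op_parallelism = 1  # -> MaxIntraOpParallelismDataset
    options.experimental_threading.private_threadpool_size = 1   # -> PrivateThreadPoolDataset
    options.experimental_distribute.auto_shard_policy = (
        tf.data.experimental.AutoShardPolicy.AUTO)
    dataset = dataset.with_options(options)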
commit 208edc153e
parent 0b06f2927b
tensorflow/core/grappler/optimizers/data
@@ -71,7 +71,7 @@ constexpr std::array<const char*, 2> kMultipleInputsDatasetOps = {
     "ZipDataset"
 };
 
-constexpr std::array<const char*, 26> kPassThroughOps = {
+constexpr std::array<const char*, 28> kPassThroughOps = {
     "_Retval",
     "AssertNextDataset",
     "BatchDataset",
@@ -83,12 +83,14 @@ constexpr std::array<const char*, 26> kPassThroughOps = {
     "Identity",
     "MapAndBatchDataset",
     "MapDataset",
+    "MaxIntraOpParallelismDataset",
     "ModelDataset",
     "OptimizeDataset",
     "PaddedBatchDataset",
     "ParallelMapDataset",
     "ParseExampleDataset",
     "PrefetchDataset",
+    "PrivateThreadPoolDataset",
     "ReduceDataset",
     "RebatchDataset",
     "RepeatDataset",
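Conceptually, kPassThroughOps lists dataset ops that the auto-shard rewrite may walk through on its way from the sink back to a file-based reader; adding the two threading ops keeps that walk from stopping early and falling back to data-level sharding. A toy Python sketch of the idea (the node structure and helper here are hypothetical, not the grappler implementation, which operates on GraphDef nodes in C++):

    # Hypothetical illustration of the pass-through walk done by the
    # auto-shard pass; names and structure are invented for exposition.
    PASS_THROUGH_OPS = {
        "BatchDataset", "MapDataset", "PrefetchDataset",
        # Treated as pass-through as of this commit:
        "MaxIntraOpParallelismDataset", "PrivateThreadPoolDataset",
    }

    def find_file_reader(node):
        """Follow input-dataset edges through pass-through ops to the reader."""
        while node.op in PASS_THROUGH_OPS:
            node = node.inputs[0]  # the input dataset is the first input
        return node  # e.g. a TFRecordDataset, whose file list can be sharded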
tensorflow/python/data/experimental/kernel_tests

@@ -493,6 +493,36 @@ class AutoShardDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
     ]
     self.assertDatasetProduces(dataset, list(chunk(expected, 5)))
 
+  @combinations.generate(test_base.default_test_combinations())
+  def testMaxIntraOpParallelism(self):
+    dataset = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=False)
+    dataset = dataset.flat_map(core_readers.TFRecordDataset)
+    dataset = dataset.batch(5)
+    dataset = dataset_ops._MaxIntraOpParallelismDataset(dataset, 1)
+    dataset = distribute._AutoShardDataset(dataset, 5, 0)
+
+    expected = [
+        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
+        for f in (0, 5)
+        for r in range(0, 10)
+    ]
+    self.assertDatasetProduces(dataset, list(chunk(expected, 5)))
+
+  @combinations.generate(test_base.default_test_combinations())
+  def testPrivateThreadpool(self):
+    dataset = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=False)
+    dataset = dataset.flat_map(core_readers.TFRecordDataset)
+    dataset = dataset.batch(5)
+    dataset = dataset_ops._PrivateThreadPoolDataset(dataset, 1)
+    dataset = distribute._AutoShardDataset(dataset, 5, 0)
+
+    expected = [
+        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
+        for f in (0, 5)
+        for r in range(0, 10)
+    ]
+    self.assertDatasetProduces(dataset, list(chunk(expected, 5)))
+
   @combinations.generate(test_base.default_test_combinations())
   def testMakeBatchedFeaturesDataset(self):
     files = 2
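As a sanity check on the expected values in both new tests: the pipeline is auto-sharded 5 ways and worker index 0 is inspected, so with 10 input files (an assumption about TFRecordDatasetTestBase), file-level sharding keeps every fifth file, i.e. files 0 and 5, which is why `f in (0, 5)` above:

    # Sketch of why the tests expect records from files 0 and 5 only.
    num_files, num_shards, shard_index = 10, 5, 0  # num_files assumed from the test base
    kept = [f for f in range(num_files) if f % num_shards == shard_index]
    assert kept == [0, 5]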