From 208edc153e450412e121196b92e25b127c1cb3f2 Mon Sep 17 00:00:00 2001
From: Jiri Simsa <jsimsa@google.com>
Date: Sun, 22 Nov 2020 11:36:20 -0800
Subject: [PATCH] [tf.data] Making sure setting tf.data threading options does
 not prevent auto-sharding from using file-level sharding.

PiperOrigin-RevId: 343746154
Change-Id: I241e0c951c46086e9f043588ddb3aa0eff23e593
---
 .../grappler/optimizers/data/auto_shard.cc   |  4 ++-
 .../kernel_tests/auto_shard_dataset_test.py  | 30 +++++++++++++++++++
 2 files changed, 33 insertions(+), 1 deletion(-)

diff --git a/tensorflow/core/grappler/optimizers/data/auto_shard.cc b/tensorflow/core/grappler/optimizers/data/auto_shard.cc
index 1288f9695b9..ae315e97ccb 100644
--- a/tensorflow/core/grappler/optimizers/data/auto_shard.cc
+++ b/tensorflow/core/grappler/optimizers/data/auto_shard.cc
@@ -71,7 +71,7 @@ constexpr std::array<const char*, 2> kMultipleInputsDatasetOps = {
     "ZipDataset"
 };
 
-constexpr std::array<const char*, 26> kPassThroughOps = {
+constexpr std::array<const char*, 28> kPassThroughOps = {
     "_Retval",
     "AssertNextDataset",
     "BatchDataset",
@@ -83,12 +83,14 @@ constexpr std::array<const char*, 26> kPassThroughOps = {
     "Identity",
     "MapAndBatchDataset",
     "MapDataset",
+    "MaxIntraOpParallelismDataset",
     "ModelDataset",
     "OptimizeDataset",
     "PaddedBatchDataset",
     "ParallelMapDataset",
     "ParseExampleDataset",
     "PrefetchDataset",
+    "PrivateThreadPoolDataset",
     "ReduceDataset",
     "RebatchDataset",
     "RepeatDataset",
diff --git a/tensorflow/python/data/experimental/kernel_tests/auto_shard_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/auto_shard_dataset_test.py
index d428baca9c0..86088488e05 100644
--- a/tensorflow/python/data/experimental/kernel_tests/auto_shard_dataset_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/auto_shard_dataset_test.py
@@ -493,6 +493,36 @@ class AutoShardDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
     ]
     self.assertDatasetProduces(dataset, list(chunk(expected, 5)))
 
+  @combinations.generate(test_base.default_test_combinations())
+  def testMaxIntraOpParallelism(self):
+    dataset = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=False)
+    dataset = dataset.flat_map(core_readers.TFRecordDataset)
+    dataset = dataset.batch(5)
+    dataset = dataset_ops._MaxIntraOpParallelismDataset(dataset, 1)
+    dataset = distribute._AutoShardDataset(dataset, 5, 0)
+
+    expected = [
+        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
+        for f in (0, 5)
+        for r in range(0, 10)
+    ]
+    self.assertDatasetProduces(dataset, list(chunk(expected, 5)))
+
+  @combinations.generate(test_base.default_test_combinations())
+  def testPrivateThreadpool(self):
+    dataset = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=False)
+    dataset = dataset.flat_map(core_readers.TFRecordDataset)
+    dataset = dataset.batch(5)
+    dataset = dataset_ops._PrivateThreadPoolDataset(dataset, 1)
+    dataset = distribute._AutoShardDataset(dataset, 5, 0)
+
+    expected = [
+        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
+        for f in (0, 5)
+        for r in range(0, 10)
+    ]
+    self.assertDatasetProduces(dataset, list(chunk(expected, 5)))
+
   @combinations.generate(test_base.default_test_combinations())
   def testMakeBatchedFeaturesDataset(self):
     files = 2
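
For context, a minimal user-level sketch of the pipeline shape this patch fixes, assuming
the TF 2.4-era tf.data.Options API (the file pattern and pool sizes below are hypothetical).
Setting either threading option inserts a MaxIntraOpParallelismDataset or
PrivateThreadPoolDataset node between the batch and the file reader; before this change
neither op was listed in kPassThroughOps, so the auto-shard rewrite could not walk past
them to the reader and fell back to element-level sharding instead of file-level sharding:

  import tensorflow as tf

  # Hypothetical input: any set of sharded TFRecord files.
  dataset = tf.data.Dataset.list_files("/tmp/data/*.tfrecord", shuffle=False)
  dataset = dataset.flat_map(tf.data.TFRecordDataset)
  dataset = dataset.batch(5)

  options = tf.data.Options()
  # Each option below inserts the corresponding op into the dataset graph:
  # MaxIntraOpParallelismDataset and PrivateThreadPoolDataset, respectively.
  options.experimental_threading.max_intra_op_parallelism = 1
  options.experimental_threading.private_threadpool_size = 1
  # Request file-level auto-sharding; the policy takes effect when the dataset
  # is consumed through a tf.distribute strategy.
  options.experimental_distribute.auto_shard_policy = (
      tf.data.experimental.AutoShardPolicy.FILE)
  dataset = dataset.with_options(options)

With the two ops registered as pass-through, the rewrite can reach the TFRecordDataset
reader and shard by file, which is what the two new tests exercise directly via
dataset_ops._MaxIntraOpParallelismDataset and dataset_ops._PrivateThreadPoolDataset.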