From 208edc153e450412e121196b92e25b127c1cb3f2 Mon Sep 17 00:00:00 2001
From: Jiri Simsa <jsimsa@google.com>
Date: Sun, 22 Nov 2020 11:36:20 -0800
Subject: [PATCH] [tf.data] Making sure setting tf.data threading options does
 not prevent auto-sharding from using file-level sharding.
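
Without this change, applying tf.data threading options inserts a
MaxIntraOpParallelismDataset or PrivateThreadPoolDataset op between the
file-based reader and the point where auto-sharding is applied; because
these ops were missing from kPassThroughOps, the auto-shard rewrite
could not walk past them to the file reader and fell back to data-level
sharding.

A minimal sketch of the user-facing scenario (assuming the public
tf.data.Options API; the attribute is named experimental_threading in
older releases):

    import tensorflow as tf

    dataset = tf.data.Dataset.list_files("/path/to/*.tfrecord", shuffle=False)
    dataset = dataset.flat_map(tf.data.TFRecordDataset)

    options = tf.data.Options()
    # Each option below inserts one of the dataset ops that this change
    # adds to kPassThroughOps, so auto-sharding can keep walking up the
    # graph and shard at the file level.
    options.threading.max_intra_op_parallelism = 1
    options.threading.private_threadpool_size = 1
    dataset = dataset.with_options(options)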

PiperOrigin-RevId: 343746154
Change-Id: I241e0c951c46086e9f043588ddb3aa0eff23e593
---
 .../grappler/optimizers/data/auto_shard.cc    |  4 ++-
 .../kernel_tests/auto_shard_dataset_test.py   | 30 +++++++++++++++++++
 2 files changed, 33 insertions(+), 1 deletion(-)

diff --git a/tensorflow/core/grappler/optimizers/data/auto_shard.cc b/tensorflow/core/grappler/optimizers/data/auto_shard.cc
index 1288f9695b9..ae315e97ccb 100644
--- a/tensorflow/core/grappler/optimizers/data/auto_shard.cc
+++ b/tensorflow/core/grappler/optimizers/data/auto_shard.cc
@@ -71,7 +71,7 @@ constexpr std::array<const char*, 2> kMultipleInputsDatasetOps = {
     "ZipDataset"
 };
 
-constexpr std::array<const char*, 26> kPassThroughOps = {
+constexpr std::array<const char*, 28> kPassThroughOps = {
     "_Retval",
     "AssertNextDataset",
     "BatchDataset",
@@ -83,12 +83,14 @@ constexpr std::array<const char*, 26> kPassThroughOps = {
     "Identity",
     "MapAndBatchDataset",
     "MapDataset",
+    "MaxIntraOpParallelismDataset",
     "ModelDataset",
     "OptimizeDataset",
     "PaddedBatchDataset",
     "ParallelMapDataset",
     "ParseExampleDataset",
     "PrefetchDataset",
+    "PrivateThreadPoolDataset",
     "ReduceDataset",
     "RebatchDataset",
     "RepeatDataset",
diff --git a/tensorflow/python/data/experimental/kernel_tests/auto_shard_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/auto_shard_dataset_test.py
index d428baca9c0..86088488e05 100644
--- a/tensorflow/python/data/experimental/kernel_tests/auto_shard_dataset_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/auto_shard_dataset_test.py
@@ -493,6 +493,36 @@ class AutoShardDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
     ]
     self.assertDatasetProduces(dataset, list(chunk(expected, 5)))
 
+  @combinations.generate(test_base.default_test_combinations())
+  def testMaxIntraOpParallelism(self):
+    dataset = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=False)
+    dataset = dataset.flat_map(core_readers.TFRecordDataset)
+    dataset = dataset.batch(5)
+    dataset = dataset_ops._MaxIntraOpParallelismDataset(dataset, 1)
+    dataset = distribute._AutoShardDataset(dataset, 5, 0)
+
+    expected = [
+        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
+        for f in (0, 5)
+        for r in range(0, 10)
+    ]
+    self.assertDatasetProduces(dataset, list(chunk(expected, 5)))
+
+  @combinations.generate(test_base.default_test_combinations())
+  def testPrivateThreadPool(self):
+    dataset = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=False)
+    dataset = dataset.flat_map(core_readers.TFRecordDataset)
+    dataset = dataset.batch(5)
+    dataset = dataset_ops._PrivateThreadPoolDataset(dataset, 1)
+    dataset = distribute._AutoShardDataset(dataset, 5, 0)
+
+    expected = [
+        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
+        for f in (0, 5)
+        for r in range(0, 10)
+    ]
+    self.assertDatasetProduces(dataset, list(chunk(expected, 5)))
+
   @combinations.generate(test_base.default_test_combinations())
   def testMakeBatchedFeaturesDataset(self):
     files = 2