diff --git a/tensorflow/core/grappler/optimizers/data/auto_shard.cc b/tensorflow/core/grappler/optimizers/data/auto_shard.cc
index 1288f9695b9..4d324ecbd3d 100644
--- a/tensorflow/core/grappler/optimizers/data/auto_shard.cc
+++ b/tensorflow/core/grappler/optimizers/data/auto_shard.cc
@@ -45,9 +45,6 @@ constexpr char kShuffleDatasetV3OpName[] = "ShuffleDatasetV3";
 constexpr char kPrefetchDatasetOpName[] = "PrefetchDataset";
 constexpr char kRebatchDatasetOpName[] = "RebatchDataset";
 constexpr char kRebatchDatasetV2OpName[] = "RebatchDatasetV2";
-constexpr char kTensorDatasetOpName[] = "TensorDataset";
-constexpr char kTensorSliceDatasetOpName[] = "TensorSliceDataset";
-constexpr char kPlaceholderOpName[] = "Placeholder";
 
 constexpr char kNumWorkersAttrName[] = "num_workers";
 constexpr char kNumReplicasAttrName[] = "num_replicas";
@@ -71,13 +68,12 @@ constexpr std::array<const char*, 2> kMultipleInputsDatasetOps = {
     "ZipDataset"
 };
 
-constexpr std::array<const char*, 26> kPassThroughOps = {
+constexpr std::array<const char*, 25> kPassThroughOps = {
     "_Retval",
     "AssertNextDataset",
     "BatchDataset",
     "CacheDataset",
     "ExperimentalMapAndBatchDataset",
-    "ExperimentalParseExampleDataset",
     "ExperimentalRebatchDataset",
     "FilterDataset",
     "Identity",
@@ -417,33 +413,6 @@ Status ProcessDatasetSourceNode(MutableGraphView* graph, const NodeDef& node,
   return Status::OK();
 }
 
-const NodeDef* FindFuncAndTensorSliceDataset(
-    const NodeDef* node, int64 num_workers, int64 index,
-    FunctionLibraryDefinition* flib, MutableGraphView* graph,
-    absl::flat_hash_set<string>* nodes_to_delete) {
-  if (IsDatasetNodeOfType(*node, kFuncDatasetOps)) {
-    const NodeDef* input_node = graph_utils::GetInputNode(*node, *graph, 0);
-    if (input_node->op() == kTensorSliceDatasetOpName ||
-        input_node->op() == kTensorDatasetOpName) {
-      const NodeDef* next_input_node =
-          graph_utils::GetInputNode(*input_node, *graph, 0);
-      if (next_input_node->op() == kPlaceholderOpName) {
-        return node;
-      }
-    }
-  }
-
-  if (!IsDatasetNodeOfType(*node, kPassThroughOps)) {
-    return nullptr;
-  }
-
-  // Sometimes there are other nodes between the last InterleaveDataset and the
-  // second to last FlatMapDataset, so we need to skip over those.
-  const NodeDef* input_node = graph_utils::GetInputNode(*node, *graph, 0);
-  return FindFuncAndTensorSliceDataset(input_node, num_workers, index, flib,
-                                       graph, nodes_to_delete);
-}
-
 Status RecursivelyHandleOp(const NodeDef& node, int64 num_workers, int64 index,
                            FunctionLibraryDefinition* flib,
                            MutableGraphView* graph,
@@ -472,39 +441,6 @@ Status RecursivelyHandleOp(const NodeDef& node, int64 num_workers, int64 index,
     return Status::OK();
   }
 
-  // This handles the case for the following subgraph:
-  // Placeholder -> TensorSliceDataset -> FlatMapDataset -x->
-  // (other preprocessing datasets) -> InterleaveDataset
-  // and then inserting the shard node immediately after the FlatMapDataset.
-  //
-  // This is used for some training pipelines where a dataset is created with
-  // the following code:
-  //
-  // def make_dataset_pipeline():
-  //   file_globs = [...]
-  //   datasets = []
-  //   for file_glob in file_globs:
-  //     datasets.append(Dataset.list_files(file_glob).map(TFRecordReader))
-  //   dataset = Dataset.from_tensor_slices(datasets)
-  //   dataset = dataset.flat_map(lambda x: x)
-  //   dataset = ... # additional preprocessing
-  //   dataset = dataset.interleave(lambda x: x, cycle_length=...)
-  //   return dataset
-  if (IsDatasetNodeOfType(node, kFuncDatasetOps)) {
-    const NodeDef* input_node = graph_utils::GetInputNode(node, *graph, 0);
-    const NodeDef* flat_map_node = FindFuncAndTensorSliceDataset(
-        input_node, num_workers, index, flib, graph, nodes_to_delete);
-
-    if (flat_map_node != nullptr) {
-      auto fanouts = graph->GetFanouts(*flat_map_node, false);
-      // FlatMapDataset should only be the input to one other dataset.
-      if (fanouts.size() == 1) {
-        return ProcessDatasetSourceNode(graph, *fanouts.begin()->node,
-                                        nodes_to_delete, num_workers, index);
-      }
-    }
-  }
-
   // This handles the case where a reader Dataset is contained within a
   // FuncDataset (e.g. FlatMap, ParallelInterleave, etc...). For example:
   //
@@ -634,6 +570,7 @@ Status OptimizeGraph(const GrapplerItem& item, int64 num_workers, int64 index,
   MutableGraphView graph(output);
   FunctionLibraryDefinition flib(OpRegistry::Global(), item.graph.library());
+
   NodeDef* sink_node;
   TF_RETURN_IF_ERROR(graph_utils::GetFetchNode(graph, item, &sink_node));
 
diff --git a/tensorflow/python/data/experimental/kernel_tests/auto_shard_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/auto_shard_dataset_test.py
index d428baca9c0..564dda0cf11 100644
--- a/tensorflow/python/data/experimental/kernel_tests/auto_shard_dataset_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/auto_shard_dataset_test.py
@@ -103,43 +103,6 @@ class AutoShardDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
     ]
     self.assertDatasetProducesWithShuffle(dataset, expected, 5, 4, shuffle)
 
-  @combinations.generate(
-      combinations.times(test_base.default_test_combinations(),
-                         combinations.combine(batch_size=[1, 3, 10])))
-  def testDatasetOfReaderDatasetsPipeline(self, batch_size):
-    # This tests a scenario where a list_files may return multiple files
-    # due to the glob containing wildcards.
-    def batch(iterator, n):
-      l = len(iterator)
-      for i in range(0, l, n):
-        yield iterator[i:min(i + n, l)]
-
-    datasets = []
-    for files in batch(self.test_filenames, batch_size):
-      datasets.append(
-          dataset_ops.Dataset.list_files(files, shuffle=False).map(
-              core_readers.TFRecordDataset))
-    dataset = dataset_ops.Dataset.from_tensor_slices(datasets)
-    dataset = dataset.flat_map(lambda x: x)
-
-    # Simulate additional ops in between flat_map and interleave. This should be
-    # a no-op since if ShardDataset is placed right after flat_map, we will only
-    # have two datasets left at this point.
-    dataset = dataset.prefetch(1)
-    dataset = dataset.prefetch(1)
-
-    dataset = dataset.interleave(
-        lambda x: x, cycle_length=1, num_parallel_calls=1)
-
-    dataset = distribute._AutoShardDataset(dataset, 5, 0)
-    expected = [
-        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
-        for f in (0, 5)
-        for r in range(0, 10)
-    ]
-
-    self.assertDatasetProduces(dataset, expected)
-
   @combinations.generate(test_base.default_test_combinations())
   def testZipReaderPipeline(self):
     dataset1 = dataset_ops.Dataset.list_files(