From 3f3ab6a249054b89947f33cadf308d685b0aff31 Mon Sep 17 00:00:00 2001 From: Jiri Simsa Date: Sat, 25 Apr 2020 10:33:38 -0700 Subject: [PATCH] [tf.data] Modifying auto-sharding to remove unsupported `assert_cardinality` transformations (instead of triggering an error). PiperOrigin-RevId: 308424479 Change-Id: I55f8e2f1eb818916003cab6a25ebed938066ce33 --- .../core/grappler/optimizers/data/auto_shard.cc | 17 ++++++++++++++--- .../kernel_tests/auto_shard_dataset_test.py | 16 ++++++++++++++++ 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/data/auto_shard.cc b/tensorflow/core/grappler/optimizers/data/auto_shard.cc index feabd7b2b5e..3e8583d74e9 100644 --- a/tensorflow/core/grappler/optimizers/data/auto_shard.cc +++ b/tensorflow/core/grappler/optimizers/data/auto_shard.cc @@ -30,12 +30,13 @@ limitations under the License. #include "tensorflow/core/grappler/optimizers/data/graph_utils.h" #include "tensorflow/core/grappler/utils/functions.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/errors.h" namespace tensorflow { namespace grappler { namespace { -// clang-format off +constexpr char kAssertCardinalityDatasetOpName[] = "AssertCardinalityDataset"; constexpr char kShardDatasetOpName[] = "ShardDataset"; constexpr char kShuffleDatasetOpName[] = "ShuffleDataset"; constexpr char kShuffleDatasetV2OpName[] = "ShuffleDatasetV2"; @@ -48,6 +49,7 @@ constexpr char kReshuffleEachIteration[] = "reshuffle_each_iteration"; constexpr char kOutputShapes[] = "output_shapes"; constexpr char kOutputTypes[] = "output_types"; +// clang-format off constexpr std::array kReaderDatasetOps = { "FixedLengthRecordDataset", "FixedLengthRecordDatasetV2", @@ -62,9 +64,8 @@ constexpr std::array kMultipleInputsDatasetOps = { "ZipDataset" }; -constexpr std::array kPassThroughOps = { +constexpr std::array kPassThroughOps = { "_Retval", - "AssertCardinalityDataset", "AssertNextDataset", "BatchDataset", "BatchDatasetV2", @@ -415,6 +416,16 @@ Status RecursivelyHandleOp(const NodeDef& node, int64 num_workers, int64 index, FunctionLibraryDefinition* flib, MutableGraphView* graph, absl::flat_hash_set* nodes_to_delete) { + if (node.op() == kAssertCardinalityDatasetOpName) { + LOG(WARNING) << "The `assert_cardinality` transformation is currently not " + "handled by the auto-shard rewrite and will be removed."; + nodes_to_delete->insert(node.name()); + TF_RETURN_IF_ERROR(graph->UpdateFanouts(node.name(), node.input(0))); + const NodeDef* input_node = graph_utils::GetInputNode(node, *graph, 0); + return RecursivelyHandleOp(*input_node, num_workers, index, flib, graph, + nodes_to_delete); + } + if (IsDatasetNodeOfType(node, kUnshardableSourceDatasetOps)) { return errors::NotFound("Found an unshardable source dataset: ", node.DebugString()); diff --git a/tensorflow/python/data/experimental/kernel_tests/auto_shard_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/auto_shard_dataset_test.py index b2b348a436e..8271dbada7a 100644 --- a/tensorflow/python/data/experimental/kernel_tests/auto_shard_dataset_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/auto_shard_dataset_test.py @@ -20,6 +20,7 @@ from __future__ import print_function from absl.testing import parameterized from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base +from tensorflow.python.data.experimental.ops import cardinality from tensorflow.python.data.experimental.ops import distribute from tensorflow.python.data.experimental.ops import distribute_options from tensorflow.python.data.experimental.ops import interleave_ops @@ -429,6 +430,21 @@ class AutoShardDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase, dataset = distribute._AutoShardDataset(dataset, 2, 2) self.evaluate(self.getNext(dataset)()) + @combinations.generate(test_base.default_test_combinations()) + def testAssertCardinality(self): + dataset = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=False) + dataset = dataset.flat_map(core_readers.TFRecordDataset) + dataset = dataset.batch(5) + dataset = dataset.apply(cardinality.assert_cardinality(42)) + dataset = distribute._AutoShardDataset(dataset, 5, 0) + + expected = [ + b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension + for f in (0, 5) + for r in range(0, 10) + ] + self.assertDatasetProduces(dataset, list(chunk(expected, 5))) + class AutoShardTextLineDatasetTest( reader_dataset_ops_test_base.TextLineDatasetTestBase,