[tf.data] Modifying auto-sharding to remove unsupported assert_cardinality transformations (instead of triggering an error).

PiperOrigin-RevId: 308424479
Change-Id: I55f8e2f1eb818916003cab6a25ebed938066ce33
Author:    Jiri Simsa
Date:      2020-04-25 10:33:38 -07:00
Committer: TensorFlower Gardener
Commit:    3f3ab6a249 (parent ea8087efe5)
2 changed files with 30 additions and 3 deletions
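
The change is visible from Python whenever an auto-sharded input pipeline contains `assert_cardinality`. Below is a minimal sketch (not part of this commit) of the new behavior, using the same internal helpers as the test added here; the TFRecord file names are hypothetical.

from tensorflow.python.data.experimental.ops import cardinality
from tensorflow.python.data.experimental.ops import distribute
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers as core_readers

# Hypothetical input files; any TFRecord-backed pipeline behaves the same way.
dataset = dataset_ops.Dataset.list_files(
    ["data-0.tfrecord", "data-1.tfrecord"], shuffle=False)
dataset = dataset.flat_map(core_readers.TFRecordDataset)
dataset = dataset.batch(5)
dataset = dataset.apply(cardinality.assert_cardinality(42))

# Previously the auto-shard rewrite failed here because AssertCardinalityDataset
# was neither a pass-through op nor a reader; it is now dropped with a
# LOG(WARNING) and the rest of the pipeline is sharded across the 2 workers.
dataset = distribute._AutoShardDataset(dataset, 2, 0)

Presumably the node is dropped rather than kept because the asserted value refers to the unsharded pipeline, so preserving it after sharding would assert the wrong cardinality.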

@@ -30,12 +30,13 @@ limitations under the License.
 #include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
 #include "tensorflow/core/grappler/utils/functions.h"
 #include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/errors.h"
 
 namespace tensorflow {
 namespace grappler {
 namespace {
 
-// clang-format off
+constexpr char kAssertCardinalityDatasetOpName[] = "AssertCardinalityDataset";
 constexpr char kShardDatasetOpName[] = "ShardDataset";
 constexpr char kShuffleDatasetOpName[] = "ShuffleDataset";
 constexpr char kShuffleDatasetV2OpName[] = "ShuffleDatasetV2";
@@ -48,6 +49,7 @@ constexpr char kReshuffleEachIteration[] = "reshuffle_each_iteration";
 constexpr char kOutputShapes[] = "output_shapes";
 constexpr char kOutputTypes[] = "output_types";
 
+// clang-format off
 constexpr std::array<const char*, 6> kReaderDatasetOps = {
     "FixedLengthRecordDataset",
     "FixedLengthRecordDatasetV2",
@@ -62,9 +64,8 @@ constexpr std::array<const char*, 2> kMultipleInputsDatasetOps = {
     "ZipDataset"
 };
 
-constexpr std::array<const char*, 31> kPassThroughOps = {
+constexpr std::array<const char*, 30> kPassThroughOps = {
     "_Retval",
-    "AssertCardinalityDataset",
     "AssertNextDataset",
     "BatchDataset",
     "BatchDatasetV2",
@@ -415,6 +416,16 @@ Status RecursivelyHandleOp(const NodeDef& node, int64 num_workers, int64 index,
                            FunctionLibraryDefinition* flib,
                            MutableGraphView* graph,
                            absl::flat_hash_set<string>* nodes_to_delete) {
+  if (node.op() == kAssertCardinalityDatasetOpName) {
+    LOG(WARNING) << "The `assert_cardinality` transformation is currently not "
+                    "handled by the auto-shard rewrite and will be removed.";
+    nodes_to_delete->insert(node.name());
+    TF_RETURN_IF_ERROR(graph->UpdateFanouts(node.name(), node.input(0)));
+    const NodeDef* input_node = graph_utils::GetInputNode(node, *graph, 0);
+    return RecursivelyHandleOp(*input_node, num_workers, index, flib, graph,
+                               nodes_to_delete);
+  }
+
   if (IsDatasetNodeOfType(node, kUnshardableSourceDatasetOps)) {
     return errors::NotFound("Found an unshardable source dataset: ",
                             node.DebugString());


@@ -20,6 +20,7 @@ from __future__ import print_function
 from absl.testing import parameterized
 
 from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
+from tensorflow.python.data.experimental.ops import cardinality
 from tensorflow.python.data.experimental.ops import distribute
 from tensorflow.python.data.experimental.ops import distribute_options
 from tensorflow.python.data.experimental.ops import interleave_ops
@@ -429,6 +430,21 @@ class AutoShardDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
     dataset = distribute._AutoShardDataset(dataset, 2, 2)
     self.evaluate(self.getNext(dataset)())
 
+  @combinations.generate(test_base.default_test_combinations())
+  def testAssertCardinality(self):
+    dataset = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=False)
+    dataset = dataset.flat_map(core_readers.TFRecordDataset)
+    dataset = dataset.batch(5)
+    dataset = dataset.apply(cardinality.assert_cardinality(42))
+    dataset = distribute._AutoShardDataset(dataset, 5, 0)
+    expected = [
+        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
+        for f in (0, 5)
+        for r in range(0, 10)
+    ]
+    self.assertDatasetProduces(dataset, list(chunk(expected, 5)))
+
 
 class AutoShardTextLineDatasetTest(
     reader_dataset_ops_test_base.TextLineDatasetTestBase,