[tf.data] Modifying auto-sharding to remove unsupported assert_cardinality
transformations (instead of triggering an error).
PiperOrigin-RevId: 308424479
Change-Id: I55f8e2f1eb818916003cab6a25ebed938066ce33
This commit is contained in:
parent
ea8087efe5
commit
3f3ab6a249
@ -30,12 +30,13 @@ limitations under the License.
|
||||
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
|
||||
#include "tensorflow/core/grappler/utils/functions.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
namespace {
|
||||
|
||||
// clang-format off
|
||||
constexpr char kAssertCardinalityDatasetOpName[] = "AssertCardinalityDataset";
|
||||
constexpr char kShardDatasetOpName[] = "ShardDataset";
|
||||
constexpr char kShuffleDatasetOpName[] = "ShuffleDataset";
|
||||
constexpr char kShuffleDatasetV2OpName[] = "ShuffleDatasetV2";
|
||||
@ -48,6 +49,7 @@ constexpr char kReshuffleEachIteration[] = "reshuffle_each_iteration";
|
||||
constexpr char kOutputShapes[] = "output_shapes";
|
||||
constexpr char kOutputTypes[] = "output_types";
|
||||
|
||||
// clang-format off
|
||||
constexpr std::array<const char*, 6> kReaderDatasetOps = {
|
||||
"FixedLengthRecordDataset",
|
||||
"FixedLengthRecordDatasetV2",
|
||||
@ -62,9 +64,8 @@ constexpr std::array<const char*, 2> kMultipleInputsDatasetOps = {
|
||||
"ZipDataset"
|
||||
};
|
||||
|
||||
constexpr std::array<const char*, 31> kPassThroughOps = {
|
||||
constexpr std::array<const char*, 30> kPassThroughOps = {
|
||||
"_Retval",
|
||||
"AssertCardinalityDataset",
|
||||
"AssertNextDataset",
|
||||
"BatchDataset",
|
||||
"BatchDatasetV2",
|
||||
@ -415,6 +416,16 @@ Status RecursivelyHandleOp(const NodeDef& node, int64 num_workers, int64 index,
|
||||
FunctionLibraryDefinition* flib,
|
||||
MutableGraphView* graph,
|
||||
absl::flat_hash_set<string>* nodes_to_delete) {
|
||||
if (node.op() == kAssertCardinalityDatasetOpName) {
|
||||
LOG(WARNING) << "The `assert_cardinality` transformation is currently not "
|
||||
"handled by the auto-shard rewrite and will be removed.";
|
||||
nodes_to_delete->insert(node.name());
|
||||
TF_RETURN_IF_ERROR(graph->UpdateFanouts(node.name(), node.input(0)));
|
||||
const NodeDef* input_node = graph_utils::GetInputNode(node, *graph, 0);
|
||||
return RecursivelyHandleOp(*input_node, num_workers, index, flib, graph,
|
||||
nodes_to_delete);
|
||||
}
|
||||
|
||||
if (IsDatasetNodeOfType(node, kUnshardableSourceDatasetOps)) {
|
||||
return errors::NotFound("Found an unshardable source dataset: ",
|
||||
node.DebugString());
|
||||
|
@ -20,6 +20,7 @@ from __future__ import print_function
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
|
||||
from tensorflow.python.data.experimental.ops import cardinality
|
||||
from tensorflow.python.data.experimental.ops import distribute
|
||||
from tensorflow.python.data.experimental.ops import distribute_options
|
||||
from tensorflow.python.data.experimental.ops import interleave_ops
|
||||
@ -429,6 +430,21 @@ class AutoShardDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
|
||||
dataset = distribute._AutoShardDataset(dataset, 2, 2)
|
||||
self.evaluate(self.getNext(dataset)())
|
||||
|
||||
# Regression test: auto-sharding a pipeline that contains an
# `assert_cardinality` transformation should still produce the sharded
# output instead of raising an error.
@combinations.generate(test_base.default_test_combinations())
|
||||
def testAssertCardinality(self):
|
||||
# Build the input pipeline: list the test files in a fixed (unshuffled)
# order, read each file's records, and group them into batches of 5.
dataset = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=False)
|
||||
dataset = dataset.flat_map(core_readers.TFRecordDataset)
|
||||
dataset = dataset.batch(5)
|
||||
# Attach the cardinality assertion that the auto-shard rewrite has to cope
# with. NOTE(review): 42 does not obviously match the element count of the
# sharded dataset -- presumably the assertion is dropped by the rewrite, so
# the value is never checked; confirm.
dataset = dataset.apply(cardinality.assert_cardinality(42))
|
||||
# NOTE(review): the positional arguments are presumably num_workers=5 and
# index=0 -- confirm against the `_AutoShardDataset` signature.
dataset = distribute._AutoShardDataset(dataset, 5, 0)
|
||||
|
||||
# Expected output: records 0..9 of files 0 and 5 only.
# NOTE(review): this looks like file-level sharding across 5 workers for
# worker 0 -- assumes the test base creates 10 files; confirm.
expected = [
|
||||
b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension
|
||||
for f in (0, 5)
|
||||
for r in range(0, 10)
|
||||
]
|
||||
# `chunk` is defined outside this hunk; presumably it splits `expected` into
# groups of 5 so the comparison lines up with `batch(5)` above -- confirm.
self.assertDatasetProduces(dataset, list(chunk(expected, 5)))
|
||||
|
||||
|
||||
class AutoShardTextLineDatasetTest(
|
||||
reader_dataset_ops_test_base.TextLineDatasetTestBase,
|
||||
|
Loading…
Reference in New Issue
Block a user