[tf.data] Modifying auto-sharding to remove unsupported `assert_cardinality`
transformations and log a warning (instead of triggering an error).
PiperOrigin-RevId: 308424479 Change-Id: I55f8e2f1eb818916003cab6a25ebed938066ce33
This commit is contained in:
parent
ea8087efe5
commit
3f3ab6a249
@ -30,12 +30,13 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
|
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
|
||||||
#include "tensorflow/core/grappler/utils/functions.h"
|
#include "tensorflow/core/grappler/utils/functions.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
|
#include "tensorflow/core/platform/errors.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace grappler {
|
namespace grappler {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
// clang-format off
|
constexpr char kAssertCardinalityDatasetOpName[] = "AssertCardinalityDataset";
|
||||||
constexpr char kShardDatasetOpName[] = "ShardDataset";
|
constexpr char kShardDatasetOpName[] = "ShardDataset";
|
||||||
constexpr char kShuffleDatasetOpName[] = "ShuffleDataset";
|
constexpr char kShuffleDatasetOpName[] = "ShuffleDataset";
|
||||||
constexpr char kShuffleDatasetV2OpName[] = "ShuffleDatasetV2";
|
constexpr char kShuffleDatasetV2OpName[] = "ShuffleDatasetV2";
|
||||||
@ -48,6 +49,7 @@ constexpr char kReshuffleEachIteration[] = "reshuffle_each_iteration";
|
|||||||
constexpr char kOutputShapes[] = "output_shapes";
|
constexpr char kOutputShapes[] = "output_shapes";
|
||||||
constexpr char kOutputTypes[] = "output_types";
|
constexpr char kOutputTypes[] = "output_types";
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
constexpr std::array<const char*, 6> kReaderDatasetOps = {
|
constexpr std::array<const char*, 6> kReaderDatasetOps = {
|
||||||
"FixedLengthRecordDataset",
|
"FixedLengthRecordDataset",
|
||||||
"FixedLengthRecordDatasetV2",
|
"FixedLengthRecordDatasetV2",
|
||||||
@ -62,9 +64,8 @@ constexpr std::array<const char*, 2> kMultipleInputsDatasetOps = {
|
|||||||
"ZipDataset"
|
"ZipDataset"
|
||||||
};
|
};
|
||||||
|
|
||||||
constexpr std::array<const char*, 31> kPassThroughOps = {
|
constexpr std::array<const char*, 30> kPassThroughOps = {
|
||||||
"_Retval",
|
"_Retval",
|
||||||
"AssertCardinalityDataset",
|
|
||||||
"AssertNextDataset",
|
"AssertNextDataset",
|
||||||
"BatchDataset",
|
"BatchDataset",
|
||||||
"BatchDatasetV2",
|
"BatchDatasetV2",
|
||||||
@ -415,6 +416,16 @@ Status RecursivelyHandleOp(const NodeDef& node, int64 num_workers, int64 index,
|
|||||||
FunctionLibraryDefinition* flib,
|
FunctionLibraryDefinition* flib,
|
||||||
MutableGraphView* graph,
|
MutableGraphView* graph,
|
||||||
absl::flat_hash_set<string>* nodes_to_delete) {
|
absl::flat_hash_set<string>* nodes_to_delete) {
|
||||||
|
if (node.op() == kAssertCardinalityDatasetOpName) {
|
||||||
|
LOG(WARNING) << "The `assert_cardinality` transformation is currently not "
|
||||||
|
"handled by the auto-shard rewrite and will be removed.";
|
||||||
|
nodes_to_delete->insert(node.name());
|
||||||
|
TF_RETURN_IF_ERROR(graph->UpdateFanouts(node.name(), node.input(0)));
|
||||||
|
const NodeDef* input_node = graph_utils::GetInputNode(node, *graph, 0);
|
||||||
|
return RecursivelyHandleOp(*input_node, num_workers, index, flib, graph,
|
||||||
|
nodes_to_delete);
|
||||||
|
}
|
||||||
|
|
||||||
if (IsDatasetNodeOfType(node, kUnshardableSourceDatasetOps)) {
|
if (IsDatasetNodeOfType(node, kUnshardableSourceDatasetOps)) {
|
||||||
return errors::NotFound("Found an unshardable source dataset: ",
|
return errors::NotFound("Found an unshardable source dataset: ",
|
||||||
node.DebugString());
|
node.DebugString());
|
||||||
|
@ -20,6 +20,7 @@ from __future__ import print_function
|
|||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
|
|
||||||
from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
|
from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
|
||||||
|
from tensorflow.python.data.experimental.ops import cardinality
|
||||||
from tensorflow.python.data.experimental.ops import distribute
|
from tensorflow.python.data.experimental.ops import distribute
|
||||||
from tensorflow.python.data.experimental.ops import distribute_options
|
from tensorflow.python.data.experimental.ops import distribute_options
|
||||||
from tensorflow.python.data.experimental.ops import interleave_ops
|
from tensorflow.python.data.experimental.ops import interleave_ops
|
||||||
@ -429,6 +430,21 @@ class AutoShardDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
|
|||||||
dataset = distribute._AutoShardDataset(dataset, 2, 2)
|
dataset = distribute._AutoShardDataset(dataset, 2, 2)
|
||||||
self.evaluate(self.getNext(dataset)())
|
self.evaluate(self.getNext(dataset)())
|
||||||
|
|
||||||
|
@combinations.generate(test_base.default_test_combinations())
|
||||||
|
def testAssertCardinality(self):
|
||||||
|
dataset = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=False)
|
||||||
|
dataset = dataset.flat_map(core_readers.TFRecordDataset)
|
||||||
|
dataset = dataset.batch(5)
|
||||||
|
dataset = dataset.apply(cardinality.assert_cardinality(42))
|
||||||
|
dataset = distribute._AutoShardDataset(dataset, 5, 0)
|
||||||
|
|
||||||
|
expected = [
|
||||||
|
b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension
|
||||||
|
for f in (0, 5)
|
||||||
|
for r in range(0, 10)
|
||||||
|
]
|
||||||
|
self.assertDatasetProduces(dataset, list(chunk(expected, 5)))
|
||||||
|
|
||||||
|
|
||||||
class AutoShardTextLineDatasetTest(
|
class AutoShardTextLineDatasetTest(
|
||||||
reader_dataset_ops_test_base.TextLineDatasetTestBase,
|
reader_dataset_ops_test_base.TextLineDatasetTestBase,
|
||||||
|
Loading…
Reference in New Issue
Block a user