[tf.data] Have experimental_slack fail on nonconforming input pipelines. Also updates optimize_dataset_op to fail if any of the optimizers fail (previously, it would log an error but continue).
PiperOrigin-RevId: 251259335
This commit is contained in:
parent
a6e1a25eab
commit
5085774f0f
tensorflow
core
python/data
@ -34,9 +34,71 @@ namespace grappler {
|
||||
namespace {
|
||||
|
||||
constexpr char kRetValOp[] = "_Retval";
|
||||
constexpr char kPrefetchDatasetOp[] = "PrefetchDataset";
|
||||
|
||||
template <std::size_t SIZE>
|
||||
bool IsDatasetNodeOfType(const NodeDef& node,
|
||||
const std::array<const char*, SIZE>& arr) {
|
||||
for (const auto& dataset_op_name : arr) {
|
||||
if (node.op() == dataset_op_name) return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// We don't pass through "Batch*" ops and nested dataset ops (FlatMap, etc)
|
||||
// because the correct slack_period cannot be determined directly in those
|
||||
// cases.
|
||||
constexpr std::array<const char*, 2> kMultipleInputsDatasetOps = {
|
||||
"ZipDataset", "ConcatenateDataset"};
|
||||
|
||||
constexpr std::array<const char*, 15> kPassThroughOps = {
|
||||
"CacheDataset",
|
||||
"FilterDataset",
|
||||
"Identity",
|
||||
"MapDataset",
|
||||
"ModelDataset",
|
||||
"OptimizeDataset",
|
||||
"ParallelMapDataset",
|
||||
"ReduceDataset",
|
||||
"RepeatDataset",
|
||||
"ShardDataset",
|
||||
"ShuffleAndRepeatDataset",
|
||||
"ShuffleDataset",
|
||||
"SkipDataset",
|
||||
"TakeDataset",
|
||||
"WindowDataset"};
|
||||
|
||||
} // namespace
|
||||
|
||||
Status Slack::RecursivelyHandleOp(const MutableGraphView& graph,
|
||||
NodeDef* dataset_node) {
|
||||
if (dataset_node->op() == kPrefetchDatasetOp) {
|
||||
if (HasNodeAttr(*dataset_node, "slack_period")) {
|
||||
(*dataset_node->mutable_attr())["slack_period"].set_i(slack_period_);
|
||||
} else {
|
||||
AddNodeAttr("slack_period", slack_period_, dataset_node);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
if (IsDatasetNodeOfType(*dataset_node, kPassThroughOps)) {
|
||||
NodeDef* input_node = graph_utils::GetInputNode(*dataset_node, graph, 0);
|
||||
return RecursivelyHandleOp(graph, input_node);
|
||||
}
|
||||
if (IsDatasetNodeOfType(*dataset_node, kMultipleInputsDatasetOps)) {
|
||||
// For all multiple input datasets, all inputs are datasets themselves
|
||||
for (int i = 0; i < dataset_node->input_size(); ++i) {
|
||||
NodeDef* input_node = graph_utils::GetInputNode(*dataset_node, graph, i);
|
||||
TF_RETURN_IF_ERROR(RecursivelyHandleOp(graph, input_node));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
return errors::InvalidArgument(
|
||||
"Encountered unsupported op \"", dataset_node->op(),
|
||||
"\" when rewriting the input pipeline graph to use slack in its "
|
||||
"final prefetch transformation.");
|
||||
}
|
||||
|
||||
Status Slack::OptimizeAndCollectStats(Cluster* cluster,
|
||||
const GrapplerItem& item,
|
||||
GraphDef* output,
|
||||
@ -63,30 +125,10 @@ Status Slack::OptimizeAndCollectStats(Cluster* cluster,
|
||||
"Expected only one fetch node but there were ", item.fetch.size(), ": ",
|
||||
absl::StrJoin(item.fetch, ", "));
|
||||
}
|
||||
// Walk the input pipeline backwards from the fetch node to find the last
|
||||
// Walks the input pipeline backwards from the fetch node to find the last
|
||||
// PrefetchDataset node in the pipeline.
|
||||
// TODO(rachelim): This doesn't do the right thing when the "final" prefetch
|
||||
// is nested under an interleave or flat_map. Make this work, similar to
|
||||
// `auto_shard.cc` and `rebatch.cc`.
|
||||
NodeDef* dataset_node = graph.GetNode(item.fetch.at(0));
|
||||
while (true) {
|
||||
if (dataset_node->op() == "PrefetchDataset") {
|
||||
if (HasNodeAttr(*dataset_node, "slack_period")) {
|
||||
(*dataset_node->mutable_attr())["slack_period"].set_i(slack_period_);
|
||||
} else {
|
||||
AddNodeAttr("slack_period", slack_period_, dataset_node);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
if (dataset_node->op() == "Identity" ||
|
||||
(absl::EndsWith(dataset_node->op(), "Dataset") &&
|
||||
dataset_node->input_size() > 0)) {
|
||||
dataset_node = graph_utils::GetInputNode(*dataset_node, graph);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
return RecursivelyHandleOp(graph, dataset_node);
|
||||
}
|
||||
|
||||
void Slack::Feedback(Cluster* cluster, const GrapplerItem& item,
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
|
||||
#include "absl/strings/numbers.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/grappler/mutable_graph_view.h"
|
||||
#include "tensorflow/core/grappler/optimizers/data/optimizer_base.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -54,6 +55,9 @@ class Slack : public TFDataOptimizerBase {
|
||||
|
||||
private:
|
||||
int64 slack_period_ = -1;
|
||||
|
||||
Status RecursivelyHandleOp(const MutableGraphView& graph,
|
||||
NodeDef* dataset_node);
|
||||
};
|
||||
|
||||
} // namespace grappler
|
||||
|
@ -60,6 +60,7 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
|
||||
rewriter_config.add_optimizers(kOptimizerName);
|
||||
rewriter_config.set_meta_optimizer_iterations(
|
||||
RewriterConfig_NumIterationsType_ONE);
|
||||
rewriter_config.set_fail_on_optimizer_errors(true);
|
||||
auto custom_optimizer = rewriter_config.add_custom_optimizers();
|
||||
custom_optimizer->set_name(kOptimizerName);
|
||||
auto* custom_optimizations_list =
|
||||
|
@ -450,6 +450,23 @@ cuda_py_test(
|
||||
xla_enable_strict_auto_jit = True,
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "prefetch_with_slack_test",
|
||||
size = "small",
|
||||
srcs = ["prefetch_with_slack_test.py"],
|
||||
deps = [
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python/data/kernel_tests:test_base",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//tensorflow/python/data/ops:iterator_ops",
|
||||
"//tensorflow/python/data/ops:multi_device_iterator_ops",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "reader_dataset_ops_test_base",
|
||||
srcs = [
|
||||
|
@ -0,0 +1,111 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for `experimental_slack` option."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.ops import multi_device_iterator_ops
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class PrefetchWithSlackTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
@test_util.run_v1_only("b/121264236")
|
||||
def testPrefetchWithSlackOption(self):
|
||||
"""Determines slack_period based on num devices attached to iterator."""
|
||||
dataset = dataset_ops.Dataset.range(10)
|
||||
dataset = dataset.prefetch(1)
|
||||
options = dataset_ops.Options()
|
||||
options.experimental_slack = True
|
||||
dataset = dataset.with_options(options)
|
||||
multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
|
||||
dataset, ["/cpu:1", "/cpu:2"])
|
||||
dataset = multi_device_iterator._dataset # pylint: disable=protected-access
|
||||
self.assertIn("slack", dataset.options()._static_optimizations())
|
||||
self.assertIn("slack:slack_period:2",
|
||||
dataset.options()._static_optimization_configs())
|
||||
|
||||
config = config_pb2.ConfigProto(device_count={"CPU": 3})
|
||||
with self.test_session(config=config):
|
||||
self.evaluate(multi_device_iterator.initializer)
|
||||
for i in range(0, 10, 2):
|
||||
elem_on_1, elem_on_2 = multi_device_iterator.get_next()
|
||||
self.assertEqual(i, self.evaluate(elem_on_1))
|
||||
self.assertEqual(i + 1, self.evaluate(elem_on_2))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
elem_on_1, elem_on_2 = multi_device_iterator.get_next()
|
||||
self.evaluate(elem_on_1)
|
||||
self.evaluate(elem_on_2)
|
||||
|
||||
def testPrefetchWithSlackOptionWithoutIterator(self):
|
||||
"""Defaults to slack period of 1 without iterator."""
|
||||
dataset = dataset_ops.Dataset.range(10)
|
||||
dataset = dataset.prefetch(1)
|
||||
options = dataset_ops.Options()
|
||||
options.experimental_slack = True
|
||||
dataset = dataset.with_options(options)
|
||||
self.assertIn("slack", dataset.options()._static_optimizations())
|
||||
self.assertIn("slack:slack_period:1",
|
||||
dataset.options()._static_optimization_configs())
|
||||
self.assertDatasetProduces(dataset, range(10))
|
||||
|
||||
def testWithPassthroughDataset(self):
|
||||
"""Should still work with a passthrough dataset after prefetch()."""
|
||||
dataset = dataset_ops.Dataset.range(10)
|
||||
dataset = dataset.prefetch(1)
|
||||
dataset = dataset.map(lambda x: x + 1)
|
||||
options = dataset_ops.Options()
|
||||
options.experimental_slack = True
|
||||
dataset = dataset.with_options(options)
|
||||
self.assertDatasetProduces(dataset, range(1, 11))
|
||||
|
||||
def testErrorWithoutPrefetch(self):
|
||||
"""The rewrite fails if there is no prefetch() in the pipeline."""
|
||||
dataset = dataset_ops.Dataset.range(10)
|
||||
options = dataset_ops.Options()
|
||||
options.experimental_slack = True
|
||||
dataset = dataset.with_options(options)
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
get_next = self.getNext(dataset)
|
||||
self.evaluate(get_next())
|
||||
|
||||
def testErrorWithInvalidDataset(self):
|
||||
"""With a nested dataset op after prefetch, the rewrite should fail."""
|
||||
dataset = dataset_ops.Dataset.range(10)
|
||||
dataset = dataset.prefetch(1)
|
||||
dataset = dataset.flat_map(dataset_ops.Dataset.from_tensors)
|
||||
options = dataset_ops.Options()
|
||||
options.experimental_slack = True
|
||||
dataset = dataset.with_options(options)
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
get_next = self.getNext(dataset)
|
||||
self.evaluate(get_next())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
ops.enable_eager_execution(
|
||||
config=config_pb2.ConfigProto(device_count={"CPU": 3}))
|
||||
test.main()
|
@ -335,45 +335,6 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase,
|
||||
self.evaluate(elem_on_2)
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class PrefetchWithSlackTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
@test_util.run_v1_only("b/121264236")
|
||||
def testPrefetchWithSlackOption(self):
|
||||
dataset = dataset_ops.Dataset.range(10)
|
||||
options = dataset_ops.Options()
|
||||
options.experimental_slack = True
|
||||
dataset = dataset.with_options(options)
|
||||
multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
|
||||
dataset, ["/cpu:1", "/cpu:2"])
|
||||
dataset = multi_device_iterator._dataset # pylint: disable=protected-access
|
||||
self.assertIn("slack", dataset.options()._static_optimizations())
|
||||
self.assertIn("slack:slack_period:2",
|
||||
dataset.options()._static_optimization_configs())
|
||||
|
||||
config = config_pb2.ConfigProto(device_count={"CPU": 3})
|
||||
with self.test_session(config=config):
|
||||
self.evaluate(multi_device_iterator.initializer)
|
||||
for i in range(0, 10, 2):
|
||||
elem_on_1, elem_on_2 = multi_device_iterator.get_next()
|
||||
self.assertEqual(i, self.evaluate(elem_on_1))
|
||||
self.assertEqual(i + 1, self.evaluate(elem_on_2))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
elem_on_1, elem_on_2 = multi_device_iterator.get_next()
|
||||
self.evaluate(elem_on_1)
|
||||
self.evaluate(elem_on_2)
|
||||
|
||||
def testPrefetchWithSlackOptionWithoutIterator(self):
|
||||
dataset = dataset_ops.Dataset.range(10)
|
||||
options = dataset_ops.Options()
|
||||
options.experimental_slack = True
|
||||
dataset = dataset.with_options(options)
|
||||
self.assertIn("slack", dataset.options()._static_optimizations())
|
||||
self.assertIn("slack:slack_period:1",
|
||||
dataset.options()._static_optimization_configs())
|
||||
|
||||
self.assertDatasetProduces(dataset, range(10))
|
||||
|
||||
if __name__ == "__main__":
|
||||
ops.enable_eager_execution(
|
||||
config=config_pb2.ConfigProto(device_count={"CPU": 3, "GPU": 1}))
|
||||
|
Loading…
Reference in New Issue
Block a user