[tf.data] Have experimental_slack fail on nonconforming input pipelines. Also updates optimize_dataset_op to fail if any of the optimizers fail (previously, it would log an error but continue).

PiperOrigin-RevId: 251259335
This commit is contained in:
Rachel Lim 2019-06-03 10:28:46 -07:00 committed by TensorFlower Gardener
parent a6e1a25eab
commit 5085774f0f
6 changed files with 197 additions and 61 deletions
tensorflow
core
grappler/optimizers/data
kernels/data
python/data

View File

@ -34,9 +34,71 @@ namespace grappler {
namespace {
// Op-type string "_Retval".
// NOTE(review): not referenced in the portion of the file visible here —
// confirm it is still used elsewhere in this translation unit.
constexpr char kRetValOp[] = "_Retval";
// Op type of the node whose `slack_period` attr the rewrite sets.
constexpr char kPrefetchDatasetOp[] = "PrefetchDataset";
// Returns true iff the op type of `node` is equal to one of the op names
// listed in `arr`; returns false otherwise (including when `arr` is empty).
template <std::size_t SIZE>
bool IsDatasetNodeOfType(const NodeDef& node,
                         const std::array<const char*, SIZE>& arr) {
  const auto& op_type = node.op();
  for (std::size_t idx = 0; idx < SIZE; ++idx) {
    if (op_type == arr[idx]) {
      return true;
    }
  }
  return false;
}
// Dataset ops for which the rewrite recurses into every input of the node.
constexpr std::array<const char*, 2> kMultipleInputsDatasetOps = {
"ZipDataset", "ConcatenateDataset"};
// Ops the rewrite walks through by following only input 0 of the node.
// We don't pass through "Batch*" ops and nested dataset ops (FlatMap, etc)
// because the correct slack_period cannot be determined directly in those
// cases.
constexpr std::array<const char*, 15> kPassThroughOps = {
"CacheDataset",
"FilterDataset",
"Identity",
"MapDataset",
"ModelDataset",
"OptimizeDataset",
"ParallelMapDataset",
"ReduceDataset",
"RepeatDataset",
"ShardDataset",
"ShuffleAndRepeatDataset",
"ShuffleDataset",
"SkipDataset",
"TakeDataset",
"WindowDataset"};
} // namespace
// Walks the input pipeline upstream from `dataset_node` and sets the
// `slack_period` attr to `slack_period_` on each PrefetchDataset node reached.
//
// Traversal rules:
//   * PrefetchDataset: set (or overwrite) `slack_period` and stop.
//   * Ops in `kMultipleInputsDatasetOps`: recurse into every input.
//   * Ops in `kPassThroughOps`: recurse into input 0 only.
//   * Anything else: fail.
//
// Returns OK on success, the first error from a recursive call, or
// InvalidArgument naming the unsupported op.
Status Slack::RecursivelyHandleOp(const MutableGraphView& graph,
                                  NodeDef* dataset_node) {
  const bool is_prefetch = dataset_node->op() == kPrefetchDatasetOp;
  if (is_prefetch) {
    // Overwrite an existing attr in place; otherwise add a new one.
    if (!HasNodeAttr(*dataset_node, "slack_period")) {
      AddNodeAttr("slack_period", slack_period_, dataset_node);
    } else {
      (*dataset_node->mutable_attr())["slack_period"].set_i(slack_period_);
    }
    return Status::OK();
  }

  if (IsDatasetNodeOfType(*dataset_node, kMultipleInputsDatasetOps)) {
    // For all multiple input datasets, all inputs are datasets themselves,
    // so each of them is visited in turn.
    const int num_inputs = dataset_node->input_size();
    for (int input_idx = 0; input_idx < num_inputs; ++input_idx) {
      NodeDef* upstream =
          graph_utils::GetInputNode(*dataset_node, graph, input_idx);
      TF_RETURN_IF_ERROR(RecursivelyHandleOp(graph, upstream));
    }
    return Status::OK();
  }

  if (IsDatasetNodeOfType(*dataset_node, kPassThroughOps)) {
    NodeDef* upstream = graph_utils::GetInputNode(*dataset_node, graph, 0);
    return RecursivelyHandleOp(graph, upstream);
  }

  return errors::InvalidArgument(
      "Encountered unsupported op \"", dataset_node->op(),
      "\" when rewriting the input pipeline graph to use slack in its "
      "final prefetch transformation.");
}
Status Slack::OptimizeAndCollectStats(Cluster* cluster,
const GrapplerItem& item,
GraphDef* output,
@ -63,30 +125,10 @@ Status Slack::OptimizeAndCollectStats(Cluster* cluster,
"Expected only one fetch node but there were ", item.fetch.size(), ": ",
absl::StrJoin(item.fetch, ", "));
}
// Walk the input pipeline backwards from the fetch node to find the last
// Walks the input pipeline backwards from the fetch node to find the last
// PrefetchDataset node in the pipeline.
// TODO(rachelim): This doesn't do the right thing when the "final" prefetch
// is nested under an interleave or flat_map. Make this work, similar to
// `auto_shard.cc` and `rebatch.cc`.
NodeDef* dataset_node = graph.GetNode(item.fetch.at(0));
while (true) {
if (dataset_node->op() == "PrefetchDataset") {
if (HasNodeAttr(*dataset_node, "slack_period")) {
(*dataset_node->mutable_attr())["slack_period"].set_i(slack_period_);
} else {
AddNodeAttr("slack_period", slack_period_, dataset_node);
}
return Status::OK();
}
if (dataset_node->op() == "Identity" ||
(absl::EndsWith(dataset_node->op(), "Dataset") &&
dataset_node->input_size() > 0)) {
dataset_node = graph_utils::GetInputNode(*dataset_node, graph);
} else {
break;
}
}
return Status::OK();
return RecursivelyHandleOp(graph, dataset_node);
}
void Slack::Feedback(Cluster* cluster, const GrapplerItem& item,

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "absl/strings/numbers.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/optimizers/data/optimizer_base.h"
namespace tensorflow {
@ -54,6 +55,9 @@ class Slack : public TFDataOptimizerBase {
private:
int64 slack_period_ = -1;
Status RecursivelyHandleOp(const MutableGraphView& graph,
NodeDef* dataset_node);
};
} // namespace grappler

View File

@ -60,6 +60,7 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
rewriter_config.add_optimizers(kOptimizerName);
rewriter_config.set_meta_optimizer_iterations(
RewriterConfig_NumIterationsType_ONE);
rewriter_config.set_fail_on_optimizer_errors(true);
auto custom_optimizer = rewriter_config.add_custom_optimizers();
custom_optimizer->set_name(kOptimizerName);
auto* custom_optimizations_list =

View File

@ -450,6 +450,23 @@ cuda_py_test(
xla_enable_strict_auto_jit = True,
)
# Test target for prefetch_with_slack_test.py, which holds the tests for the
# `experimental_slack` dataset option.
py_test(
name = "prefetch_with_slack_test",
size = "small",
srcs = ["prefetch_with_slack_test.py"],
deps = [
"//tensorflow/core:protos_all_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:iterator_ops",
"//tensorflow/python/data/ops:multi_device_iterator_ops",
"@absl_py//absl/testing:parameterized",
],
)
py_library(
name = "reader_dataset_ops_test_base",
srcs = [

View File

@ -0,0 +1,111 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `experimental_slack` option."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import multi_device_iterator_ops
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
# Tests for the `experimental_slack` dataset option: two cases where a
# pipeline containing prefetch() produces its elements normally, one with a
# map() applied after prefetch(), and two cases where an
# InvalidArgumentError is expected.
@test_util.run_all_in_graph_and_eager_modes
class PrefetchWithSlackTest(test_base.DatasetTestBase, parameterized.TestCase):
@test_util.run_v1_only("b/121264236")
def testPrefetchWithSlackOption(self):
"""Determines slack_period based on num devices attached to iterator."""
dataset = dataset_ops.Dataset.range(10)
dataset = dataset.prefetch(1)
options = dataset_ops.Options()
options.experimental_slack = True
dataset = dataset.with_options(options)
# Two target devices are attached, matching the `slack_period:2` config
# asserted below.
multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
dataset, ["/cpu:1", "/cpu:2"])
dataset = multi_device_iterator._dataset # pylint: disable=protected-access
self.assertIn("slack", dataset.options()._static_optimizations())
self.assertIn("slack:slack_period:2",
dataset.options()._static_optimization_configs())
config = config_pb2.ConfigProto(device_count={"CPU": 3})
with self.test_session(config=config):
self.evaluate(multi_device_iterator.initializer)
# Each get_next() returns one element per device: element i on the first
# device and element i + 1 on the second.
for i in range(0, 10, 2):
elem_on_1, elem_on_2 = multi_device_iterator.get_next()
self.assertEqual(i, self.evaluate(elem_on_1))
self.assertEqual(i + 1, self.evaluate(elem_on_2))
# After the 10 elements are consumed, the next fetch is expected to raise.
with self.assertRaises(errors.OutOfRangeError):
elem_on_1, elem_on_2 = multi_device_iterator.get_next()
self.evaluate(elem_on_1)
self.evaluate(elem_on_2)
def testPrefetchWithSlackOptionWithoutIterator(self):
"""Defaults to slack period of 1 without iterator."""
dataset = dataset_ops.Dataset.range(10)
dataset = dataset.prefetch(1)
options = dataset_ops.Options()
options.experimental_slack = True
dataset = dataset.with_options(options)
self.assertIn("slack", dataset.options()._static_optimizations())
self.assertIn("slack:slack_period:1",
dataset.options()._static_optimization_configs())
self.assertDatasetProduces(dataset, range(10))
def testWithPassthroughDataset(self):
"""Should still work with a passthrough dataset after prefetch()."""
dataset = dataset_ops.Dataset.range(10)
dataset = dataset.prefetch(1)
# map() is applied after prefetch(), so prefetch() is not the last
# transformation in the pipeline.
dataset = dataset.map(lambda x: x + 1)
options = dataset_ops.Options()
options.experimental_slack = True
dataset = dataset.with_options(options)
self.assertDatasetProduces(dataset, range(1, 11))
def testErrorWithoutPrefetch(self):
"""The rewrite fails if there is no prefetch() in the pipeline."""
dataset = dataset_ops.Dataset.range(10)
options = dataset_ops.Options()
options.experimental_slack = True
dataset = dataset.with_options(options)
# The error may surface either when the next-element getter is created or
# when it is evaluated, so both steps are inside the assertRaises block.
with self.assertRaises(errors.InvalidArgumentError):
get_next = self.getNext(dataset)
self.evaluate(get_next())
def testErrorWithInvalidDataset(self):
"""With a nested dataset op after prefetch, the rewrite should fail."""
dataset = dataset_ops.Dataset.range(10)
dataset = dataset.prefetch(1)
# flat_map() is the nested dataset op placed after prefetch().
dataset = dataset.flat_map(dataset_ops.Dataset.from_tensors)
options = dataset_ops.Options()
options.experimental_slack = True
dataset = dataset.with_options(options)
with self.assertRaises(errors.InvalidArgumentError):
get_next = self.getNext(dataset)
self.evaluate(get_next())
# Script entry point: enable eager execution with 3 CPU devices configured,
# then run the tests in this file.
if __name__ == "__main__":
ops.enable_eager_execution(
config=config_pb2.ConfigProto(device_count={"CPU": 3}))
test.main()

View File

@ -335,45 +335,6 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase,
self.evaluate(elem_on_2)
# Tests for the `experimental_slack` dataset option.
# Note: neither pipeline below contains a prefetch() transformation.
@test_util.run_all_in_graph_and_eager_modes
class PrefetchWithSlackTest(test_base.DatasetTestBase, parameterized.TestCase):
@test_util.run_v1_only("b/121264236")
def testPrefetchWithSlackOption(self):
dataset = dataset_ops.Dataset.range(10)
options = dataset_ops.Options()
options.experimental_slack = True
dataset = dataset.with_options(options)
# Two target devices are attached, matching the `slack_period:2` config
# asserted below.
multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
dataset, ["/cpu:1", "/cpu:2"])
dataset = multi_device_iterator._dataset # pylint: disable=protected-access
self.assertIn("slack", dataset.options()._static_optimizations())
self.assertIn("slack:slack_period:2",
dataset.options()._static_optimization_configs())
config = config_pb2.ConfigProto(device_count={"CPU": 3})
with self.test_session(config=config):
self.evaluate(multi_device_iterator.initializer)
# Each get_next() returns one element per device: element i on the first
# device and element i + 1 on the second.
for i in range(0, 10, 2):
elem_on_1, elem_on_2 = multi_device_iterator.get_next()
self.assertEqual(i, self.evaluate(elem_on_1))
self.assertEqual(i + 1, self.evaluate(elem_on_2))
# After the 10 elements are consumed, the next fetch is expected to raise.
with self.assertRaises(errors.OutOfRangeError):
elem_on_1, elem_on_2 = multi_device_iterator.get_next()
self.evaluate(elem_on_1)
self.evaluate(elem_on_2)
def testPrefetchWithSlackOptionWithoutIterator(self):
dataset = dataset_ops.Dataset.range(10)
options = dataset_ops.Options()
options.experimental_slack = True
dataset = dataset.with_options(options)
# Without a multi-device iterator the expected config is `slack_period:1`.
self.assertIn("slack", dataset.options()._static_optimizations())
self.assertIn("slack:slack_period:1",
dataset.options()._static_optimization_configs())
self.assertDatasetProduces(dataset, range(10))
if __name__ == "__main__":
ops.enable_eager_execution(
config=config_pb2.ConfigProto(device_count={"CPU": 3, "GPU": 1}))