[tf.data] Have experimental_slack fail on nonconforming input pipelines. Also updates optimize_dataset_op to fail if any of the optimizers fail (previously, it would log an error but continue).
PiperOrigin-RevId: 251259335
This commit is contained in: parent a6e1a25eab, commit 5085774f0f
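In user-facing terms, the behavior change is that a slack-enabled pipeline whose final prefetch cannot be located now fails loudly instead of being silently left unoptimized. A minimal sketch using the public tf.data API (assumes a TF build with the `experimental_slack` option; this mirrors what the new prefetch_with_slack_test.py below exercises, it is not part of the change itself):

    import tensorflow as tf

    options = tf.data.Options()
    options.experimental_slack = True  # requests the "slack" graph rewrite

    # Conforming pipeline: prefetch() is reachable from the fetch node through
    # pass-through ops, so the rewrite can set its slack_period attribute.
    ok = tf.data.Dataset.range(10).prefetch(1).with_options(options)

    # Nonconforming pipeline: there is no prefetch() to attach slack to.
    # After this change, iterating `bad` raises InvalidArgumentError rather
    # than logging an error and continuing.
    bad = tf.data.Dataset.range(10).with_options(options)
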
@@ -34,9 +34,71 @@ namespace grappler {

 namespace {

 constexpr char kRetValOp[] = "_Retval";
+constexpr char kPrefetchDatasetOp[] = "PrefetchDataset";
+
+template <std::size_t SIZE>
+bool IsDatasetNodeOfType(const NodeDef& node,
+                         const std::array<const char*, SIZE>& arr) {
+  for (const auto& dataset_op_name : arr) {
+    if (node.op() == dataset_op_name) return true;
+  }
+  return false;
+}
+
+// We don't pass through "Batch*" ops and nested dataset ops (FlatMap, etc)
+// because the correct slack_period cannot be determined directly in those
+// cases.
+constexpr std::array<const char*, 2> kMultipleInputsDatasetOps = {
+    "ZipDataset", "ConcatenateDataset"};
+
+constexpr std::array<const char*, 15> kPassThroughOps = {
+    "CacheDataset",
+    "FilterDataset",
+    "Identity",
+    "MapDataset",
+    "ModelDataset",
+    "OptimizeDataset",
+    "ParallelMapDataset",
+    "ReduceDataset",
+    "RepeatDataset",
+    "ShardDataset",
+    "ShuffleAndRepeatDataset",
+    "ShuffleDataset",
+    "SkipDataset",
+    "TakeDataset",
+    "WindowDataset"};

 }  // namespace

+Status Slack::RecursivelyHandleOp(const MutableGraphView& graph,
+                                  NodeDef* dataset_node) {
+  if (dataset_node->op() == kPrefetchDatasetOp) {
+    if (HasNodeAttr(*dataset_node, "slack_period")) {
+      (*dataset_node->mutable_attr())["slack_period"].set_i(slack_period_);
+    } else {
+      AddNodeAttr("slack_period", slack_period_, dataset_node);
+    }
+    return Status::OK();
+  }
+  if (IsDatasetNodeOfType(*dataset_node, kPassThroughOps)) {
+    NodeDef* input_node = graph_utils::GetInputNode(*dataset_node, graph, 0);
+    return RecursivelyHandleOp(graph, input_node);
+  }
+  if (IsDatasetNodeOfType(*dataset_node, kMultipleInputsDatasetOps)) {
+    // For all multiple input datasets, all inputs are datasets themselves.
+    for (int i = 0; i < dataset_node->input_size(); ++i) {
+      NodeDef* input_node = graph_utils::GetInputNode(*dataset_node, graph, i);
+      TF_RETURN_IF_ERROR(RecursivelyHandleOp(graph, input_node));
+    }
+    return Status::OK();
+  }
+
+  return errors::InvalidArgument(
+      "Encountered unsupported op \"", dataset_node->op(),
+      "\" when rewriting the input pipeline graph to use slack in its "
+      "final prefetch transformation.");
+}
+
 Status Slack::OptimizeAndCollectStats(Cluster* cluster,
                                       const GrapplerItem& item,
                                       GraphDef* output,
@@ -63,30 +125,10 @@ Status Slack::OptimizeAndCollectStats(Cluster* cluster,
         "Expected only one fetch node but there were ", item.fetch.size(), ": ",
         absl::StrJoin(item.fetch, ", "));
   }
-  // Walk the input pipeline backwards from the fetch node to find the last
+  // Walks the input pipeline backwards from the fetch node to find the last
   // PrefetchDataset node in the pipeline.
+  // TODO(rachelim): This doesn't do the right thing when the "final" prefetch
+  // is nested under an interleave or flat_map. Make this work, similar to
+  // `auto_shard.cc` and `rebatch.cc`.
   NodeDef* dataset_node = graph.GetNode(item.fetch.at(0));
-  while (true) {
-    if (dataset_node->op() == "PrefetchDataset") {
-      if (HasNodeAttr(*dataset_node, "slack_period")) {
-        (*dataset_node->mutable_attr())["slack_period"].set_i(slack_period_);
-      } else {
-        AddNodeAttr("slack_period", slack_period_, dataset_node);
-      }
-      return Status::OK();
-    }
-    if (dataset_node->op() == "Identity" ||
-        (absl::EndsWith(dataset_node->op(), "Dataset") &&
-         dataset_node->input_size() > 0)) {
-      dataset_node = graph_utils::GetInputNode(*dataset_node, graph);
-    } else {
-      break;
-    }
-  }
-  return Status::OK();
+  return RecursivelyHandleOp(graph, dataset_node);
 }

 void Slack::Feedback(Cluster* cluster, const GrapplerItem& item,
@@ -18,6 +18,7 @@ limitations under the License.

 #include "absl/strings/numbers.h"
 #include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/grappler/mutable_graph_view.h"
 #include "tensorflow/core/grappler/optimizers/data/optimizer_base.h"

 namespace tensorflow {
@@ -54,6 +55,9 @@ class Slack : public TFDataOptimizerBase {

  private:
  int64 slack_period_ = -1;

+  Status RecursivelyHandleOp(const MutableGraphView& graph,
+                             NodeDef* dataset_node);
 };

 }  // namespace grappler
@@ -60,6 +60,7 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
     rewriter_config.add_optimizers(kOptimizerName);
     rewriter_config.set_meta_optimizer_iterations(
         RewriterConfig_NumIterationsType_ONE);
+    rewriter_config.set_fail_on_optimizer_errors(true);
     auto custom_optimizer = rewriter_config.add_custom_optimizers();
     custom_optimizer->set_name(kOptimizerName);
     auto* custom_optimizations_list =
@@ -450,6 +450,23 @@ cuda_py_test(
     xla_enable_strict_auto_jit = True,
 )

+py_test(
+    name = "prefetch_with_slack_test",
+    size = "small",
+    srcs = ["prefetch_with_slack_test.py"],
+    deps = [
+        "//tensorflow/core:protos_all_py",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:errors",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python/data/kernel_tests:test_base",
+        "//tensorflow/python/data/ops:dataset_ops",
+        "//tensorflow/python/data/ops:iterator_ops",
+        "//tensorflow/python/data/ops:multi_device_iterator_ops",
+        "@absl_py//absl/testing:parameterized",
+    ],
+)
+
 py_library(
     name = "reader_dataset_ops_test_base",
     srcs = [
@@ -0,0 +1,111 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `experimental_slack` option."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import multi_device_iterator_ops
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import test
+
+
+@test_util.run_all_in_graph_and_eager_modes
+class PrefetchWithSlackTest(test_base.DatasetTestBase, parameterized.TestCase):
+
+  @test_util.run_v1_only("b/121264236")
+  def testPrefetchWithSlackOption(self):
+    """Determines slack_period based on num devices attached to iterator."""
+    dataset = dataset_ops.Dataset.range(10)
+    dataset = dataset.prefetch(1)
+    options = dataset_ops.Options()
+    options.experimental_slack = True
+    dataset = dataset.with_options(options)
+    multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
+        dataset, ["/cpu:1", "/cpu:2"])
+    dataset = multi_device_iterator._dataset  # pylint: disable=protected-access
+    self.assertIn("slack", dataset.options()._static_optimizations())
+    self.assertIn("slack:slack_period:2",
+                  dataset.options()._static_optimization_configs())
+
+    config = config_pb2.ConfigProto(device_count={"CPU": 3})
+    with self.test_session(config=config):
+      self.evaluate(multi_device_iterator.initializer)
+      for i in range(0, 10, 2):
+        elem_on_1, elem_on_2 = multi_device_iterator.get_next()
+        self.assertEqual(i, self.evaluate(elem_on_1))
+        self.assertEqual(i + 1, self.evaluate(elem_on_2))
+      with self.assertRaises(errors.OutOfRangeError):
+        elem_on_1, elem_on_2 = multi_device_iterator.get_next()
+        self.evaluate(elem_on_1)
+        self.evaluate(elem_on_2)
+
+  def testPrefetchWithSlackOptionWithoutIterator(self):
+    """Defaults to slack period of 1 without iterator."""
+    dataset = dataset_ops.Dataset.range(10)
+    dataset = dataset.prefetch(1)
+    options = dataset_ops.Options()
+    options.experimental_slack = True
+    dataset = dataset.with_options(options)
+    self.assertIn("slack", dataset.options()._static_optimizations())
+    self.assertIn("slack:slack_period:1",
+                  dataset.options()._static_optimization_configs())
+    self.assertDatasetProduces(dataset, range(10))
+
+  def testWithPassthroughDataset(self):
+    """Should still work with a passthrough dataset after prefetch()."""
+    dataset = dataset_ops.Dataset.range(10)
+    dataset = dataset.prefetch(1)
+    dataset = dataset.map(lambda x: x + 1)
+    options = dataset_ops.Options()
+    options.experimental_slack = True
+    dataset = dataset.with_options(options)
+    self.assertDatasetProduces(dataset, range(1, 11))
+
+  def testErrorWithoutPrefetch(self):
+    """The rewrite fails if there is no prefetch() in the pipeline."""
+    dataset = dataset_ops.Dataset.range(10)
+    options = dataset_ops.Options()
+    options.experimental_slack = True
+    dataset = dataset.with_options(options)
+    with self.assertRaises(errors.InvalidArgumentError):
+      get_next = self.getNext(dataset)
+      self.evaluate(get_next())
+
+  def testErrorWithInvalidDataset(self):
+    """With a nested dataset op after prefetch, the rewrite should fail."""
+    dataset = dataset_ops.Dataset.range(10)
+    dataset = dataset.prefetch(1)
+    dataset = dataset.flat_map(dataset_ops.Dataset.from_tensors)
+    options = dataset_ops.Options()
+    options.experimental_slack = True
+    dataset = dataset.with_options(options)
+    with self.assertRaises(errors.InvalidArgumentError):
+      get_next = self.getNext(dataset)
+      self.evaluate(get_next())
+
+
+if __name__ == "__main__":
+  ops.enable_eager_execution(
+      config=config_pb2.ConfigProto(device_count={"CPU": 3}))
+  test.main()
@@ -335,45 +335,6 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase,
         self.evaluate(elem_on_2)


-@test_util.run_all_in_graph_and_eager_modes
-class PrefetchWithSlackTest(test_base.DatasetTestBase, parameterized.TestCase):
-
-  @test_util.run_v1_only("b/121264236")
-  def testPrefetchWithSlackOption(self):
-    dataset = dataset_ops.Dataset.range(10)
-    options = dataset_ops.Options()
-    options.experimental_slack = True
-    dataset = dataset.with_options(options)
-    multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
-        dataset, ["/cpu:1", "/cpu:2"])
-    dataset = multi_device_iterator._dataset  # pylint: disable=protected-access
-    self.assertIn("slack", dataset.options()._static_optimizations())
-    self.assertIn("slack:slack_period:2",
-                  dataset.options()._static_optimization_configs())
-
-    config = config_pb2.ConfigProto(device_count={"CPU": 3})
-    with self.test_session(config=config):
-      self.evaluate(multi_device_iterator.initializer)
-      for i in range(0, 10, 2):
-        elem_on_1, elem_on_2 = multi_device_iterator.get_next()
-        self.assertEqual(i, self.evaluate(elem_on_1))
-        self.assertEqual(i + 1, self.evaluate(elem_on_2))
-      with self.assertRaises(errors.OutOfRangeError):
-        elem_on_1, elem_on_2 = multi_device_iterator.get_next()
-        self.evaluate(elem_on_1)
-        self.evaluate(elem_on_2)
-
-  def testPrefetchWithSlackOptionWithoutIterator(self):
-    dataset = dataset_ops.Dataset.range(10)
-    options = dataset_ops.Options()
-    options.experimental_slack = True
-    dataset = dataset.with_options(options)
-    self.assertIn("slack", dataset.options()._static_optimizations())
-    self.assertIn("slack:slack_period:1",
-                  dataset.options()._static_optimization_configs())
-
-    self.assertDatasetProduces(dataset, range(10))
-
-
 if __name__ == "__main__":
   ops.enable_eager_execution(
       config=config_pb2.ConfigProto(device_count={"CPU": 3, "GPU": 1}))