[tf.data] Improvements to static optimizations.
This CL: 1) Introduces a mechanism for identifying function graphs associated with tf.data user defined functions and uses it to exclude such graphs from Grappler optimizations applied by TensorFlow optimizer to the entire function library runtime (in graph mode). 2) Extends the set of Grappler optimizations applied by tf.data optimizer to function graphs associated with tf.data user-defined functions to match the default TensorFlow optimizer. 3) Introduces a new tf.data rewrite, which sets the device of Conv2D without explicit device assignment to CPU. The layout optimization assumes that all ops without explicit device assignment will be placed on GPU (if possible), which is not true for ops without tf.data user-defined function and result in incorrect layout optimization. PiperOrigin-RevId: 306538159 Change-Id: I963831533b462dc1fb96406caa811c082cdb7125
This commit is contained in:
parent
76079c00ac
commit
3ec28a7ee3
@ -59,6 +59,8 @@ namespace data {
|
||||
|
||||
using TraceMeMetadata = std::vector<std::pair<StringPiece, string>>;
|
||||
|
||||
constexpr char kTFDataFunction[] = "_tf_data_function";
|
||||
|
||||
constexpr int kInfiniteCardinality = -1;
|
||||
constexpr int kUnknownCardinality = -2;
|
||||
|
||||
|
@ -640,9 +640,13 @@ cc_library(
|
||||
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
|
||||
"//tensorflow/core/grappler/optimizers:dependency_optimizer",
|
||||
"//tensorflow/core/grappler/optimizers:function_optimizer",
|
||||
"//tensorflow/core/grappler/optimizers:loop_optimizer",
|
||||
"//tensorflow/core/grappler/optimizers:model_pruner",
|
||||
"//tensorflow/core/grappler/optimizers:remapper",
|
||||
"//tensorflow/core/grappler/optimizers:shape_optimizer",
|
||||
"//tensorflow/core/grappler/utils:functions",
|
||||
"//tensorflow/core/grappler:grappler_item",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:ptr_util",
|
||||
] + tf_protos_all(),
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/grappler/optimizers/data/fusion_utils.h"
|
||||
|
||||
#include "tensorflow/core/framework/dataset.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
#include "tensorflow/core/framework/op_def.pb.h"
|
||||
@ -428,8 +429,14 @@ FunctionDef* FuseFunctions(
|
||||
StringPiece fused_name_prefix, const SetFunctionSignatureFn& set_signature,
|
||||
const SetInputFn& set_input, const SetOutputFn& set_output,
|
||||
const SetNodesFn& set_nodes, FunctionDefLibrary* library) {
|
||||
if (first_function.attr_size() != 0 || second_function.attr_size() != 0)
|
||||
return nullptr; // Functions with attributes are currently not supported
|
||||
auto has_attrs = [](const FunctionDef& func) {
|
||||
return !(
|
||||
func.attr_size() == 0 ||
|
||||
(func.attr_size() == 1 && func.attr().contains(data::kTFDataFunction)));
|
||||
};
|
||||
if (has_attrs(first_function) || has_attrs(second_function)) {
|
||||
return nullptr; // Functions with attributes are currently not supported.
|
||||
}
|
||||
|
||||
// This function will be used as a clone of second function, having unique
|
||||
// names.
|
||||
|
@ -16,6 +16,9 @@ limitations under the License.
|
||||
#include "tensorflow/core/grappler/optimizers/data/meta_optimizer.h"
|
||||
|
||||
#include "absl/strings/str_split.h"
|
||||
#include "tensorflow/core/framework/dataset.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/versions.pb.h"
|
||||
#include "tensorflow/core/grappler/clusters/cluster.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"
|
||||
@ -23,8 +26,11 @@ limitations under the License.
|
||||
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
|
||||
#include "tensorflow/core/grappler/optimizers/dependency_optimizer.h"
|
||||
#include "tensorflow/core/grappler/optimizers/function_optimizer.h"
|
||||
#include "tensorflow/core/grappler/optimizers/loop_optimizer.h"
|
||||
#include "tensorflow/core/grappler/optimizers/model_pruner.h"
|
||||
#include "tensorflow/core/grappler/optimizers/remapper.h"
|
||||
#include "tensorflow/core/grappler/optimizers/shape_optimizer.h"
|
||||
#include "tensorflow/core/grappler/utils/functions.h"
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
#include "tensorflow/core/util/ptr_util.h"
|
||||
|
||||
@ -56,8 +62,12 @@ constexpr std::array<const char*, 16> kTFDataOptimizations = {
|
||||
"inject_prefetch"};
|
||||
|
||||
// Standard grappler optimizations, in the order we want to perform them.
|
||||
constexpr std::array<const char*, 5> kGrapplerOptimizations = {
|
||||
"pruning", "function", "shape", "arithmetic", "dependency"};
|
||||
// The order matches the order in the generic meta optimizer.
|
||||
constexpr std::array<const char*, 9> kGrapplerOptimizations = {
|
||||
"pruning", "function", "common_subgraph_elimination",
|
||||
"shape", "arithmetic", "layout_optimizer",
|
||||
"remapper", "loop", "dependency",
|
||||
};
|
||||
|
||||
// Parses a list of string optimizer configurations into a map from
|
||||
// optimizer name -> rewriter config for that optimizer.
|
||||
@ -116,6 +126,48 @@ Status TFDataMetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
|
||||
// Store the final result of all the optimizations in `output`.
|
||||
output->Swap(&optimized_item.graph);
|
||||
|
||||
// Optimize tf.data user-defined functions.
|
||||
FunctionLibraryDefinition flib =
|
||||
FunctionLibraryDefinition(OpRegistry::Global(), output->library())
|
||||
.ReachableDefinitions(*output);
|
||||
const auto producer = output->versions().producer();
|
||||
bool optimized_functions = false;
|
||||
for (const FunctionDef& func : output->library().function()) {
|
||||
// Skip non tf.data functions.
|
||||
if (!func.attr().contains(data::kTFDataFunction)) continue;
|
||||
VLOG(3) << "Optimize function: function=" << func.signature().name();
|
||||
optimized_functions = true;
|
||||
|
||||
// Make a GrapplerItem from a FunctionDef.
|
||||
GrapplerFunctionItem func_item;
|
||||
TF_RETURN_IF_ERROR(
|
||||
MakeGrapplerFunctionItem(func, flib, producer, &func_item));
|
||||
|
||||
GraphDef optimized_func_graph;
|
||||
TF_RETURN_IF_ERROR(Optimize(cluster, func_item, &optimized_func_graph));
|
||||
|
||||
// Function body optimization might have created new functions. Add them to
|
||||
// the library.
|
||||
for (const FunctionDef& func_def :
|
||||
optimized_func_graph.library().function()) {
|
||||
if (flib.Find(func_def.signature().name()) == nullptr) {
|
||||
TF_RETURN_IF_ERROR(flib.AddFunctionDef(func_def));
|
||||
}
|
||||
}
|
||||
|
||||
// Convert optimized graph back to FunctionDef.
|
||||
FunctionDef optimized_func;
|
||||
func_item.SwapFunctionBody(std::move(optimized_func_graph));
|
||||
TF_RETURN_IF_ERROR(MakeFunctionDef(func_item, flib, &optimized_func));
|
||||
|
||||
// Replace optimized function with a new FunctionDef.
|
||||
TF_RETURN_IF_ERROR(
|
||||
flib.ReplaceFunction(func.signature().name(), optimized_func));
|
||||
}
|
||||
if (optimized_functions) {
|
||||
*output->mutable_library() = flib.ToProto();
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -170,15 +222,26 @@ Status TFDataMetaOptimizer::Init(
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize standard grappler optimizers.
|
||||
// Enable a subset of grappler optimization that are enabled by default.
|
||||
//
|
||||
// Layout optimizations are excluded because they assume that ops without
|
||||
// explicit device assignment will be placed on GPU (if available) but that's
|
||||
// not the case for operations within tf.data functions.
|
||||
//
|
||||
// TODO(b/120437209): Re-enable constant folding.
|
||||
//
|
||||
// TODO(jsimsa): Make the set of generic Grappler optimization applied to
|
||||
// tf.data functions configurable.
|
||||
enabled_optimizers_["pruning"] = MakeUnique<ModelPruner>();
|
||||
enabled_optimizers_["function"] = MakeUnique<FunctionOptimizer>(
|
||||
RewriterConfig::ON, /*lower_control_flow=*/true);
|
||||
enabled_optimizers_["shape"] = MakeUnique<ShapeOptimizer>();
|
||||
enabled_optimizers_["remapping"] = MakeUnique<Remapper>(RewriterConfig::ON);
|
||||
enabled_optimizers_["common_subgraph_elimination"] =
|
||||
MakeUnique<CommonSubgraphElimination>();
|
||||
enabled_optimizers_["arithmetic"] = MakeUnique<ArithmeticOptimizer>();
|
||||
enabled_optimizers_["dependency"] = MakeUnique<DependencyOptimizer>();
|
||||
enabled_optimizers_["loop"] = MakeUnique<LoopOptimizer>();
|
||||
enabled_optimizers_["function"] = MakeUnique<FunctionOptimizer>(
|
||||
RewriterConfig::ON, /*lower_control_flow=*/true);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include "absl/strings/substitute.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/common_runtime/metrics.h"
|
||||
#include "tensorflow/core/framework/dataset.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_util.h"
|
||||
@ -90,6 +91,10 @@ bool IsRunOnceOptimizer(const string& name) {
|
||||
name == "loop_optimizer" || name == "auto_mixed_precision";
|
||||
}
|
||||
|
||||
bool IsTFDataFunction(const FunctionDef& func) {
|
||||
return func.attr().contains(data::kTFDataFunction);
|
||||
}
|
||||
|
||||
// Creates a function library stub from a real function library: copy only
|
||||
// signatures and attributes of all the function defined in fdef_lib. This stub
|
||||
// can be swapped with real function library in a graph, before passing it to
|
||||
@ -696,6 +701,9 @@ Status MetaOptimizer::OptimizeConsumeItem(Cluster* cluster, GrapplerItem&& item,
|
||||
// the function optimizer, before we can optimize function body.
|
||||
if (IsParametrized(func)) continue;
|
||||
|
||||
// Skip tf.data functions as they are optimized by tf.data meta optimizer.
|
||||
if (IsTFDataFunction(func)) continue;
|
||||
|
||||
VLOG(3) << "Optimize function: function=" << func_name << " ["
|
||||
<< function_idx++ << " of "
|
||||
<< optimized_graph->library().function_size() << "]";
|
||||
|
@ -61,7 +61,6 @@ void AutoShardDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
|
||||
// function optimization and explicitly handle function modifications
|
||||
// for those datasets in the rewrite.
|
||||
OP_REQUIRES_OK(ctx, RewriteDataset(ctx, input, std::move(config_factory),
|
||||
/*optimize_function_library=*/false,
|
||||
/*record_fingerprint=*/false, output));
|
||||
}
|
||||
|
||||
|
@ -56,7 +56,6 @@ void OptimizeDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
|
||||
return CreateConfig(optimizations, optimization_configs_);
|
||||
};
|
||||
Status s = RewriteDataset(ctx, input, std::move(config_factory),
|
||||
/*optimize_function_library=*/true,
|
||||
/*record_fingerprint=*/true, output);
|
||||
if (errors::IsDeadlineExceeded(s)) {
|
||||
// Ignore DeadlineExceeded as it implies that the attempted rewrite took too
|
||||
|
@ -79,8 +79,7 @@ void RemoveFakeSinks(FunctionDef* function_def) {
|
||||
|
||||
Status ApplyRewrites(OpKernelContext* ctx,
|
||||
const std::function<RewriterConfig(void)> config_factory,
|
||||
bool optimize_function_library, GraphDef* graph_def,
|
||||
string* output_node) {
|
||||
GraphDef* graph_def, string* output_node) {
|
||||
// Add an identity node as the fetch node, otherwise we might get 'placeholder
|
||||
// is both fed and fetched' errors in some cases when using input list with
|
||||
// placeholder dataset nodes.
|
||||
@ -117,8 +116,6 @@ Status ApplyRewrites(OpKernelContext* ctx,
|
||||
std::unique_ptr<tensorflow::grappler::GrapplerItem> grappler_item =
|
||||
tensorflow::grappler::GrapplerItemFromMetaGraphDef(
|
||||
"graph", meta_graph_def, item_config);
|
||||
grappler_item->optimization_options().optimize_function_library =
|
||||
optimize_function_library;
|
||||
std::unordered_map<string, tensorflow::DeviceProperties> device_map;
|
||||
tensorflow::grappler::VirtualCluster cluster(device_map);
|
||||
|
||||
@ -143,8 +140,7 @@ Status ApplyRewrites(OpKernelContext* ctx,
|
||||
|
||||
Status RewriteDataset(OpKernelContext* ctx, const DatasetBase* input,
|
||||
std::function<RewriterConfig(void)> config_factory,
|
||||
bool optimize_function_library, bool record_fingerprint,
|
||||
DatasetBase** rewritten_input) {
|
||||
bool record_fingerprint, DatasetBase** rewritten_input) {
|
||||
SerializationContext::Params params;
|
||||
std::vector<std::pair<string, Tensor>> input_list;
|
||||
params.input_list = &input_list;
|
||||
@ -166,9 +162,8 @@ Status RewriteDataset(OpKernelContext* ctx, const DatasetBase* input,
|
||||
}
|
||||
|
||||
VLOG(3) << "Before graph rewrites: " << graph_def.DebugString();
|
||||
TF_RETURN_IF_ERROR(ApplyRewrites(ctx, config_factory,
|
||||
optimize_function_library, &graph_def,
|
||||
&output_node));
|
||||
TF_RETURN_IF_ERROR(
|
||||
ApplyRewrites(ctx, config_factory, &graph_def, &output_node));
|
||||
VLOG(3) << "After graph rewrites: " << graph_def.DebugString();
|
||||
|
||||
// Instantiate the optimized input pipeline by running the optimized graph
|
||||
|
@ -27,8 +27,7 @@ namespace data {
|
||||
// Rewrites the input dataset using the given config.
|
||||
Status RewriteDataset(OpKernelContext* ctx, const DatasetBase* input,
|
||||
std::function<RewriterConfig(void)> config_factory,
|
||||
bool optimize_function_library, bool record_fingerprint,
|
||||
DatasetBase** rewritten_input);
|
||||
bool record_fingerprint, DatasetBase** rewritten_input);
|
||||
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
||||
|
@ -1,4 +1,5 @@
|
||||
load("//tensorflow:tensorflow.bzl", "tf_py_test")
|
||||
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
|
||||
|
||||
package(
|
||||
default_visibility = ["//tensorflow:internal"],
|
||||
@ -74,6 +75,28 @@ tf_py_test(
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "grappler_test",
|
||||
size = "medium",
|
||||
srcs = ["grappler_test.py"],
|
||||
deps = [
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:nn_ops",
|
||||
"//tensorflow/python:parsing_ops",
|
||||
"//tensorflow/python/data/experimental/ops:optimization_options",
|
||||
"//tensorflow/python/data/experimental/ops:testing",
|
||||
"//tensorflow/python/data/kernel_tests:test_base",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "hoist_random_uniform_test",
|
||||
size = "small",
|
||||
|
@ -0,0 +1,81 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for the generic Grappler optimizations used within tf.data."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.core.example import example_pb2
|
||||
from tensorflow.core.example import feature_pb2
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn_ops
|
||||
from tensorflow.python.ops import parsing_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class GrapplerTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testConstantFoldingVarLenFeature(self):
|
||||
example = example_pb2.Example(features=feature_pb2.Features(feature={}))
|
||||
dataset = dataset_ops.Dataset.from_tensors(example.SerializeToString())
|
||||
|
||||
def parse_fn(serialized):
|
||||
features = {"x": parsing_ops.VarLenFeature(dtypes.int64)}
|
||||
parsed = parsing_ops.parse_single_example(serialized, features)
|
||||
parsed = parsed["x"].values
|
||||
|
||||
size = array_ops.size(parsed)
|
||||
value = math_ops.cast(parsed, dtypes.bool)
|
||||
return control_flow_ops.cond(size > 0,
|
||||
lambda: array_ops.reshape(value, []),
|
||||
lambda: array_ops.zeros([], dtypes.bool))
|
||||
|
||||
dataset = dataset.map(parse_fn)
|
||||
|
||||
self.assertDatasetProduces(dataset, expected_output=[0])
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testLayoutOptimizationConv2D(self):
|
||||
if not test_util.is_gpu_available():
|
||||
self.skipTest("No GPU available")
|
||||
|
||||
# Compute convolution with input and filter of [1, 1, 1, 1] shape.
|
||||
# Verify that Grappler doesn't transpose Conv2D data format to NCHW.
|
||||
dataset = dataset_ops.Dataset.from_tensors((1, 1))
|
||||
|
||||
def map_function(x, y):
|
||||
i = math_ops.cast(x, dtypes.float32)
|
||||
i = array_ops.reshape(i, [1, 1, 1, 1])
|
||||
f = math_ops.cast(y, dtypes.float32)
|
||||
f = array_ops.reshape(f, [1, 1, 1, 1])
|
||||
c = nn_ops.conv2d(i, f, strides=[1, 1, 1, 1], padding="VALID")
|
||||
return array_ops.reshape(c, ())
|
||||
|
||||
dataset = dataset.map(map_function)
|
||||
self.assertDatasetProduces(dataset, expected_output=[1])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
@ -2818,7 +2818,7 @@ class Options(options_lib.OptionsBase):
|
||||
result.extend(
|
||||
optimization_options.OptimizationOptions()._graph_rewrites()) # pylint: disable=protected-access
|
||||
|
||||
if self.experimental_deterministic is False:
|
||||
if self.experimental_deterministic is False: # pylint: disable=g-bool-id-comparison
|
||||
result.append("make_sloppy")
|
||||
if self.experimental_stats and self.experimental_stats.latency_all_edges:
|
||||
result.append("latency_all_edges")
|
||||
@ -3108,7 +3108,6 @@ class DatasetSpec(type_spec.BatchableTypeSpec):
|
||||
class StructuredFunctionWrapper(object):
|
||||
"""A function wrapper that supports structured arguments and return values."""
|
||||
|
||||
# pylint: disable=protected-access
|
||||
def __init__(self,
|
||||
func,
|
||||
transformation_name,
|
||||
@ -3151,6 +3150,7 @@ class StructuredFunctionWrapper(object):
|
||||
ValueError: If an invalid combination of `dataset`, `input_classes`,
|
||||
`input_shapes`, and `input_types` is passed.
|
||||
"""
|
||||
# pylint: disable=protected-access
|
||||
if input_structure is None:
|
||||
if dataset is None:
|
||||
if input_classes is None or input_shapes is None or input_types is None:
|
||||
@ -3271,6 +3271,7 @@ class StructuredFunctionWrapper(object):
|
||||
|
||||
else:
|
||||
defun_kwargs.update({"func_name": func_name})
|
||||
defun_kwargs.update({"_tf_data_function": True})
|
||||
|
||||
# Note: _wrapper_helper will apply autograph based on context.
|
||||
@eager_function.defun_with_attributes(
|
||||
@ -3287,9 +3288,9 @@ class StructuredFunctionWrapper(object):
|
||||
with tracking.resource_tracker_scope(resource_tracker):
|
||||
# TODO(b/141462134): Switch to using garbage collection.
|
||||
self._function = wrapper_fn.get_concrete_function()
|
||||
|
||||
if add_to_graph:
|
||||
self._function.add_to_graph(ops.get_default_graph())
|
||||
|
||||
if resource_tracker.resources:
|
||||
_warn_if_collections(transformation_name)
|
||||
|
||||
@ -3301,7 +3302,6 @@ class StructuredFunctionWrapper(object):
|
||||
"if the random op has not been provided any seed. Explicitly set "
|
||||
"the seed in the function if this is not the intended behavior."
|
||||
%(outer_graph_seed, func_name), stacklevel=4)
|
||||
# pylint: enable=protected-access
|
||||
|
||||
@property
|
||||
def output_structure(self):
|
||||
|
Loading…
Reference in New Issue
Block a user