diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h index 7a49e6e2561..ab159a92109 100644 --- a/tensorflow/core/framework/dataset.h +++ b/tensorflow/core/framework/dataset.h @@ -59,6 +59,8 @@ namespace data { using TraceMeMetadata = std::vector>; +constexpr char kTFDataFunction[] = "_tf_data_function"; + constexpr int kInfiniteCardinality = -1; constexpr int kUnknownCardinality = -2; diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD index 519f689b278..0b88f2f6a71 100644 --- a/tensorflow/core/grappler/optimizers/data/BUILD +++ b/tensorflow/core/grappler/optimizers/data/BUILD @@ -640,9 +640,13 @@ cc_library( "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", "//tensorflow/core/grappler/optimizers:dependency_optimizer", "//tensorflow/core/grappler/optimizers:function_optimizer", + "//tensorflow/core/grappler/optimizers:loop_optimizer", "//tensorflow/core/grappler/optimizers:model_pruner", + "//tensorflow/core/grappler/optimizers:remapper", "//tensorflow/core/grappler/optimizers:shape_optimizer", + "//tensorflow/core/grappler/utils:functions", "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core:framework", "//tensorflow/core:lib_internal", "//tensorflow/core:ptr_util", ] + tf_protos_all(), diff --git a/tensorflow/core/grappler/optimizers/data/fusion_utils.cc b/tensorflow/core/grappler/optimizers/data/fusion_utils.cc index d5308ad31a8..f9ee6f046ef 100644 --- a/tensorflow/core/grappler/optimizers/data/fusion_utils.cc +++ b/tensorflow/core/grappler/optimizers/data/fusion_utils.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/core/grappler/optimizers/data/fusion_utils.h" +#include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/op_def.pb.h" @@ -428,8 +429,14 @@ FunctionDef* FuseFunctions( StringPiece fused_name_prefix, const SetFunctionSignatureFn& set_signature, const SetInputFn& set_input, const SetOutputFn& set_output, const SetNodesFn& set_nodes, FunctionDefLibrary* library) { - if (first_function.attr_size() != 0 || second_function.attr_size() != 0) - return nullptr; // Functions with attributes are currently not supported + auto has_attrs = [](const FunctionDef& func) { + return !( + func.attr_size() == 0 || + (func.attr_size() == 1 && func.attr().contains(data::kTFDataFunction))); + }; + if (has_attrs(first_function) || has_attrs(second_function)) { + return nullptr; // Functions with attributes are currently not supported. + } // This function will be used as a clone of second function, having unique // names. diff --git a/tensorflow/core/grappler/optimizers/data/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/data/meta_optimizer.cc index 39b59a229df..b8b630c760d 100644 --- a/tensorflow/core/grappler/optimizers/data/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/data/meta_optimizer.cc @@ -16,6 +16,9 @@ limitations under the License. #include "tensorflow/core/grappler/optimizers/data/meta_optimizer.h" #include "absl/strings/str_split.h" +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/grappler/clusters/cluster.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h" @@ -23,8 +26,11 @@ limitations under the License. 
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" #include "tensorflow/core/grappler/optimizers/dependency_optimizer.h" #include "tensorflow/core/grappler/optimizers/function_optimizer.h" +#include "tensorflow/core/grappler/optimizers/loop_optimizer.h" #include "tensorflow/core/grappler/optimizers/model_pruner.h" +#include "tensorflow/core/grappler/optimizers/remapper.h" #include "tensorflow/core/grappler/optimizers/shape_optimizer.h" +#include "tensorflow/core/grappler/utils/functions.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/util/ptr_util.h" @@ -56,8 +62,12 @@ constexpr std::array kTFDataOptimizations = { "inject_prefetch"}; // Standard grappler optimizations, in the order we want to perform them. -constexpr std::array kGrapplerOptimizations = { - "pruning", "function", "shape", "arithmetic", "dependency"}; +// The order matches the order in the generic meta optimizer. +constexpr std::array kGrapplerOptimizations = { + "pruning", "function", "common_subgraph_elimination", + "shape", "arithmetic", "layout_optimizer", + "remapper", "loop", "dependency", +}; // Parses a list of string optimizer configurations into a map from // optimizer name -> rewriter config for that optimizer. @@ -116,6 +126,48 @@ Status TFDataMetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, // Store the final result of all the optimizations in `output`. output->Swap(&optimized_item.graph); + + // Optimize tf.data user-defined functions. + FunctionLibraryDefinition flib = + FunctionLibraryDefinition(OpRegistry::Global(), output->library()) + .ReachableDefinitions(*output); + const auto producer = output->versions().producer(); + bool optimized_functions = false; + for (const FunctionDef& func : output->library().function()) { + // Skip non tf.data functions. 
+ if (!func.attr().contains(data::kTFDataFunction)) continue; + VLOG(3) << "Optimize function: function=" << func.signature().name(); + optimized_functions = true; + + // Make a GrapplerItem from a FunctionDef. + GrapplerFunctionItem func_item; + TF_RETURN_IF_ERROR( + MakeGrapplerFunctionItem(func, flib, producer, &func_item)); + + GraphDef optimized_func_graph; + TF_RETURN_IF_ERROR(Optimize(cluster, func_item, &optimized_func_graph)); + + // Function body optimization might have created new functions. Add them to + // the library. + for (const FunctionDef& func_def : + optimized_func_graph.library().function()) { + if (flib.Find(func_def.signature().name()) == nullptr) { + TF_RETURN_IF_ERROR(flib.AddFunctionDef(func_def)); + } + } + + // Convert optimized graph back to FunctionDef. + FunctionDef optimized_func; + func_item.SwapFunctionBody(std::move(optimized_func_graph)); + TF_RETURN_IF_ERROR(MakeFunctionDef(func_item, flib, &optimized_func)); + + // Replace optimized function with a new FunctionDef. + TF_RETURN_IF_ERROR( + flib.ReplaceFunction(func.signature().name(), optimized_func)); + } + if (optimized_functions) { + *output->mutable_library() = flib.ToProto(); + } return Status::OK(); } @@ -170,15 +222,26 @@ Status TFDataMetaOptimizer::Init( } } - // Initialize standard grappler optimizers. + // Enable a subset of grappler optimizations that are enabled by default. + // + // Layout optimizations are excluded because they assume that ops without + // explicit device assignment will be placed on GPU (if available) but that's + // not the case for operations within tf.data functions. + // + // TODO(b/120437209): Re-enable constant folding. + // + // TODO(jsimsa): Make the set of generic Grappler optimizations applied to + // tf.data functions configurable. 
enabled_optimizers_["pruning"] = MakeUnique(); - enabled_optimizers_["function"] = MakeUnique( - RewriterConfig::ON, /*lower_control_flow=*/true); enabled_optimizers_["shape"] = MakeUnique(); + enabled_optimizers_["remapper"] = MakeUnique(RewriterConfig::ON); enabled_optimizers_["common_subgraph_elimination"] = MakeUnique(); enabled_optimizers_["arithmetic"] = MakeUnique(); enabled_optimizers_["dependency"] = MakeUnique(); + enabled_optimizers_["loop"] = MakeUnique(); + enabled_optimizers_["function"] = MakeUnique( + RewriterConfig::ON, /*lower_control_flow=*/true); return Status::OK(); } diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc index 0c8fa0449f5..1f7ae9f1b44 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc @@ -19,6 +19,7 @@ limitations under the License. #include "absl/strings/substitute.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/metrics.h" +#include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/tensor_util.h" @@ -90,6 +91,10 @@ bool IsRunOnceOptimizer(const string& name) { name == "loop_optimizer" || name == "auto_mixed_precision"; } +bool IsTFDataFunction(const FunctionDef& func) { + return func.attr().contains(data::kTFDataFunction); +} + // Creates a function library stub from a real function library: copy only // signatures and attributes of all the function defined in fdef_lib. This stub // can be swapped with real function library in a graph, before passing it to @@ -696,6 +701,9 @@ Status MetaOptimizer::OptimizeConsumeItem(Cluster* cluster, GrapplerItem&& item, // the function optimizer, before we can optimize function body. 
if (IsParametrized(func)) continue; + // Skip tf.data functions as they are optimized by tf.data meta optimizer. + if (IsTFDataFunction(func)) continue; + VLOG(3) << "Optimize function: function=" << func_name << " [" << function_idx++ << " of " << optimized_graph->library().function_size() << "]"; diff --git a/tensorflow/core/kernels/data/experimental/auto_shard_dataset_op.cc b/tensorflow/core/kernels/data/experimental/auto_shard_dataset_op.cc index 7edbe1e8712..821314740a2 100644 --- a/tensorflow/core/kernels/data/experimental/auto_shard_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/auto_shard_dataset_op.cc @@ -61,7 +61,6 @@ void AutoShardDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input, // function optimization and explicitly handle function modifications // for those datasets in the rewrite. OP_REQUIRES_OK(ctx, RewriteDataset(ctx, input, std::move(config_factory), - /*optimize_function_library=*/false, /*record_fingerprint=*/false, output)); } diff --git a/tensorflow/core/kernels/data/optimize_dataset_op.cc b/tensorflow/core/kernels/data/optimize_dataset_op.cc index c249d82ee9b..d5d4e8c8f14 100644 --- a/tensorflow/core/kernels/data/optimize_dataset_op.cc +++ b/tensorflow/core/kernels/data/optimize_dataset_op.cc @@ -56,7 +56,6 @@ void OptimizeDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input, return CreateConfig(optimizations, optimization_configs_); }; Status s = RewriteDataset(ctx, input, std::move(config_factory), - /*optimize_function_library=*/true, /*record_fingerprint=*/true, output); if (errors::IsDeadlineExceeded(s)) { // Ignore DeadlineExceeded as it implies that the attempted rewrite took too diff --git a/tensorflow/core/kernels/data/rewrite_utils.cc b/tensorflow/core/kernels/data/rewrite_utils.cc index 609c402fd29..6e3f0d36298 100644 --- a/tensorflow/core/kernels/data/rewrite_utils.cc +++ b/tensorflow/core/kernels/data/rewrite_utils.cc @@ -79,8 +79,7 @@ void RemoveFakeSinks(FunctionDef* function_def) { 
Status ApplyRewrites(OpKernelContext* ctx, const std::function config_factory, - bool optimize_function_library, GraphDef* graph_def, - string* output_node) { + GraphDef* graph_def, string* output_node) { // Add an identity node as the fetch node, otherwise we might get 'placeholder // is both fed and fetched' errors in some cases when using input list with // placeholder dataset nodes. @@ -117,8 +116,6 @@ Status ApplyRewrites(OpKernelContext* ctx, std::unique_ptr grappler_item = tensorflow::grappler::GrapplerItemFromMetaGraphDef( "graph", meta_graph_def, item_config); - grappler_item->optimization_options().optimize_function_library = - optimize_function_library; std::unordered_map device_map; tensorflow::grappler::VirtualCluster cluster(device_map); @@ -143,8 +140,7 @@ Status ApplyRewrites(OpKernelContext* ctx, Status RewriteDataset(OpKernelContext* ctx, const DatasetBase* input, std::function config_factory, - bool optimize_function_library, bool record_fingerprint, - DatasetBase** rewritten_input) { + bool record_fingerprint, DatasetBase** rewritten_input) { SerializationContext::Params params; std::vector> input_list; params.input_list = &input_list; @@ -166,9 +162,8 @@ Status RewriteDataset(OpKernelContext* ctx, const DatasetBase* input, } VLOG(3) << "Before graph rewrites: " << graph_def.DebugString(); - TF_RETURN_IF_ERROR(ApplyRewrites(ctx, config_factory, - optimize_function_library, &graph_def, - &output_node)); + TF_RETURN_IF_ERROR( + ApplyRewrites(ctx, config_factory, &graph_def, &output_node)); VLOG(3) << "After graph rewrites: " << graph_def.DebugString(); // Instantiate the optimized input pipeline by running the optimized graph diff --git a/tensorflow/core/kernels/data/rewrite_utils.h b/tensorflow/core/kernels/data/rewrite_utils.h index ebad6d276f3..aed878e79cf 100644 --- a/tensorflow/core/kernels/data/rewrite_utils.h +++ b/tensorflow/core/kernels/data/rewrite_utils.h @@ -27,8 +27,7 @@ namespace data { // Rewrites the input dataset using the given 
config. Status RewriteDataset(OpKernelContext* ctx, const DatasetBase* input, std::function config_factory, - bool optimize_function_library, bool record_fingerprint, - DatasetBase** rewritten_input); + bool record_fingerprint, DatasetBase** rewritten_input); } // namespace data } // namespace tensorflow diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD b/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD index d32cba79124..1411481f0ac 100644 --- a/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD +++ b/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD @@ -1,4 +1,5 @@ load("//tensorflow:tensorflow.bzl", "tf_py_test") +load("//tensorflow:tensorflow.bzl", "cuda_py_test") package( default_visibility = ["//tensorflow:internal"], @@ -74,6 +75,28 @@ tf_py_test( ], ) +cuda_py_test( + name = "grappler_test", + size = "medium", + srcs = ["grappler_test.py"], + deps = [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn_ops", + "//tensorflow/python:parsing_ops", + "//tensorflow/python/data/experimental/ops:optimization_options", + "//tensorflow/python/data/experimental/ops:testing", + "//tensorflow/python/data/kernel_tests:test_base", + "//tensorflow/python/data/ops:dataset_ops", + "@absl_py//absl/testing:parameterized", + ], +) + tf_py_test( name = "hoist_random_uniform_test", size = "small", diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/grappler_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/grappler_test.py new file mode 100644 index 00000000000..f8ec7f9036d --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/optimization/grappler_test.py @@ -0,0 +1,81 @@ +# Copyright 2020 The TensorFlow 
Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the generic Grappler optimizations used within tf.data.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized + +from tensorflow.core.example import example_pb2 +from tensorflow.core.example import feature_pb2 +from tensorflow.python.data.kernel_tests import test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import combinations +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import parsing_ops +from tensorflow.python.platform import test + + +class GrapplerTest(test_base.DatasetTestBase, parameterized.TestCase): + + @combinations.generate(test_base.default_test_combinations()) + def testConstantFoldingVarLenFeature(self): + example = example_pb2.Example(features=feature_pb2.Features(feature={})) + dataset = dataset_ops.Dataset.from_tensors(example.SerializeToString()) + + def parse_fn(serialized): + features = {"x": parsing_ops.VarLenFeature(dtypes.int64)} + parsed = 
parsing_ops.parse_single_example(serialized, features) + parsed = parsed["x"].values + + size = array_ops.size(parsed) + value = math_ops.cast(parsed, dtypes.bool) + return control_flow_ops.cond(size > 0, + lambda: array_ops.reshape(value, []), + lambda: array_ops.zeros([], dtypes.bool)) + + dataset = dataset.map(parse_fn) + + self.assertDatasetProduces(dataset, expected_output=[0]) + + @combinations.generate(test_base.default_test_combinations()) + def testLayoutOptimizationConv2D(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + # Compute convolution with input and filter of [1, 1, 1, 1] shape. + # Verify that Grappler doesn't transpose Conv2D data format to NCHW. + dataset = dataset_ops.Dataset.from_tensors((1, 1)) + + def map_function(x, y): + i = math_ops.cast(x, dtypes.float32) + i = array_ops.reshape(i, [1, 1, 1, 1]) + f = math_ops.cast(y, dtypes.float32) + f = array_ops.reshape(f, [1, 1, 1, 1]) + c = nn_ops.conv2d(i, f, strides=[1, 1, 1, 1], padding="VALID") + return array_ops.reshape(c, ()) + + dataset = dataset.map(map_function) + self.assertDatasetProduces(dataset, expected_output=[1]) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index b3c2e6edbff..2eddeb6eac6 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -2818,7 +2818,7 @@ class Options(options_lib.OptionsBase): result.extend( optimization_options.OptimizationOptions()._graph_rewrites()) # pylint: disable=protected-access - if self.experimental_deterministic is False: + if self.experimental_deterministic is False: # pylint: disable=g-bool-id-comparison result.append("make_sloppy") if self.experimental_stats and self.experimental_stats.latency_all_edges: result.append("latency_all_edges") @@ -3108,7 +3108,6 @@ class DatasetSpec(type_spec.BatchableTypeSpec): class StructuredFunctionWrapper(object): """A function 
wrapper that supports structured arguments and return values.""" - # pylint: disable=protected-access def __init__(self, func, transformation_name, @@ -3151,6 +3150,7 @@ class StructuredFunctionWrapper(object): ValueError: If an invalid combination of `dataset`, `input_classes`, `input_shapes`, and `input_types` is passed. """ + # pylint: disable=protected-access if input_structure is None: if dataset is None: if input_classes is None or input_shapes is None or input_types is None: @@ -3271,6 +3271,7 @@ class StructuredFunctionWrapper(object): else: defun_kwargs.update({"func_name": func_name}) + defun_kwargs.update({"_tf_data_function": True}) # Note: _wrapper_helper will apply autograph based on context. @eager_function.defun_with_attributes( @@ -3287,9 +3288,9 @@ class StructuredFunctionWrapper(object): with tracking.resource_tracker_scope(resource_tracker): # TODO(b/141462134): Switch to using garbage collection. self._function = wrapper_fn.get_concrete_function() - if add_to_graph: self._function.add_to_graph(ops.get_default_graph()) + if resource_tracker.resources: _warn_if_collections(transformation_name) @@ -3301,7 +3302,6 @@ class StructuredFunctionWrapper(object): "if the random op has not been provided any seed. Explicitly set " "the seed in the function if this is not the intended behavior." %(outer_graph_seed, func_name), stacklevel=4) - # pylint: enable=protected-access @property def output_structure(self):