diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD index 7593023ff4d..57b0df3cb2d 100644 --- a/tensorflow/core/grappler/optimizers/data/BUILD +++ b/tensorflow/core/grappler/optimizers/data/BUILD @@ -3,6 +3,30 @@ licenses(["notice"]) # Apache 2.0 load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow/core:platform/default/build_config.bzl", "tf_protos_all") +cc_library( + name = "meta_optimizer", + srcs = ["meta_optimizer.cc"], + hdrs = [ + "meta_optimizer.h", + ], + visibility = ["//visibility:public"], + deps = [ + "@com_google_absl//absl/container:flat_hash_map", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core:lib", + "//tensorflow/core/grappler/clusters:cluster", + "//tensorflow/core/grappler/optimizers:arithmetic_optimizer", + "//tensorflow/core/grappler/optimizers:model_pruner", + "//tensorflow/core/grappler/optimizers:shape_optimizer", + "//tensorflow/core/grappler/optimizers:dependency_optimizer", + "//tensorflow/core/grappler/optimizers:function_optimizer", + "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", + "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", + "//tensorflow/core:lib_internal", + "//tensorflow/core:ptr_util", + ] + tf_protos_all(), +) + cc_library( name = "filter_fusion", srcs = ["filter_fusion.cc"], @@ -561,6 +585,7 @@ cc_library( ":map_fusion", ":map_parallelization", ":map_vectorization", + ":meta_optimizer", ":noop_elimination", ":shuffle_and_repeat_fusion", ], diff --git a/tensorflow/core/grappler/optimizers/data/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/data/meta_optimizer.cc new file mode 100644 index 00000000000..0fc0cf43c62 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/meta_optimizer.cc @@ -0,0 +1,125 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/data/meta_optimizer.h" + +#include "tensorflow/core/grappler/clusters/cluster.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h" +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" +#include "tensorflow/core/grappler/optimizers/dependency_optimizer.h" +#include "tensorflow/core/grappler/optimizers/function_optimizer.h" +#include "tensorflow/core/grappler/optimizers/model_pruner.h" +#include "tensorflow/core/grappler/optimizers/shape_optimizer.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/util/ptr_util.h" + +namespace tensorflow { +namespace grappler { + +Status TFDataMetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* output) { + // Stores the optimized item so far. + GrapplerItem optimized_item = item; + + // Perform optimizations in a meaningful order. + for (const auto& optimization : + {"noop_elimination", + "shuffle_and_repeat_fusion", + "map_fusion", + "filter_fusion", + "map_and_filter_fusion", + "hoist_random_uniform", + "map_parallelization", + "map_and_batch_fusion", + "map_vectorization", + "make_numa_aware", + "latency_all_edges", + "make_sloppy", + "pruning", + "function", + "shape", + "arithmetic", + "dependency"}) { + TF_RETURN_IF_ERROR( + ApplyOptimization(optimization, cluster, &optimized_item)); + } + + // Store the final result of all the optimizations in `output`. + output->Swap(&optimized_item.graph); + return Status::OK(); +} + +Status TFDataMetaOptimizer::ApplyOptimization(const string& name, + Cluster* cluster, + GrapplerItem* item) const { + GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED(); + + const auto* optimizer = gtl::FindOrNull(enabled_optimizers_, name); + if (!optimizer) { + return Status::OK(); + } + + GraphDef result; + (*optimizer)->set_deadline_usec(this->deadline_usec()); + TF_RETURN_IF_ERROR((*optimizer)->Optimize(cluster, *item, &result)); + item->graph.Swap(&result); + + return Status::OK(); +} + +Status TFDataMetaOptimizer::Init( + const tensorflow::RewriterConfig_CustomGraphOptimizer* config) { + if (!config) return Status::OK(); + + // Initialize custom tf.data optimizers based on config. + auto& optimizers = config->parameter_map().at("optimizers").list().s(); + for (const auto& optimizer_name : optimizers) { + auto optimizer = + CustomGraphOptimizerRegistry::CreateByNameOrNull(optimizer_name); + if (optimizer) { + // None of our data optimizers implement a meaningful Init function. + // This returns an error in case any of them does. + TF_RETURN_IF_ERROR(optimizer->Init()); + enabled_optimizers_[optimizer_name] = std::move(optimizer); + } else { + // This should never happen. + return errors::Internal( + "Tried to register a dataset optimizer that doesn't exist: ", + optimizer_name); + } + } + + // Initialize standard grappler optimizers. + enabled_optimizers_["pruning"] = MakeUnique(); + enabled_optimizers_["function"] = + MakeUnique(RewriterConfig::ON); + enabled_optimizers_["shape"] = MakeUnique(); + enabled_optimizers_["arithmetic"] = MakeUnique(); + enabled_optimizers_["dependency"] = MakeUnique(); + + return Status::OK(); +} + +void TFDataMetaOptimizer::Feedback(Cluster* cluster, const GrapplerItem& item, + const GraphDef& optimize_output, + double result) { + // no-op +} + +REGISTER_GRAPH_OPTIMIZER_AS(TFDataMetaOptimizer, "tf_data_meta_optimizer"); + +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/meta_optimizer.h b/tensorflow/core/grappler/optimizers/data/meta_optimizer.h new file mode 100644 index 00000000000..c39ddda4cb4 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/meta_optimizer.h @@ -0,0 +1,56 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_META_OPTIMIZER_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_META_OPTIMIZER_H_ + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h" + +namespace tensorflow { +namespace grappler { + +// This optimizer performs tf.data-specific optimizations by invoking +// other optimizers. +class TFDataMetaOptimizer : public CustomGraphOptimizer { + public: + TFDataMetaOptimizer() = default; + ~TFDataMetaOptimizer() override = default; + + string name() const override { return "tf_data_meta_optimizer"; }; + + Status Init( + const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override; + + Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* output) override; + + void Feedback(Cluster* cluster, const GrapplerItem& item, + const GraphDef& optimize_output, double result) override; + + private: + absl::flat_hash_map> + enabled_optimizers_; + + // Applies an optimization with the specified name on `item`, and stores + // the result in `item.graph` + Status ApplyOptimization(const string& name, Cluster* cluster, + GrapplerItem* item) const; +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_META_OPTIMIZER_H_ diff --git a/tensorflow/core/kernels/data/optimize_dataset_op.cc b/tensorflow/core/kernels/data/optimize_dataset_op.cc index 9c50d8050a8..04cc48a0be5 100644 --- a/tensorflow/core/kernels/data/optimize_dataset_op.cc +++ b/tensorflow/core/kernels/data/optimize_dataset_op.cc @@ -40,6 +40,8 @@ namespace tensorflow { namespace data { namespace { +static const char* const kOptimizerName = "tf_data_meta_optimizer"; + // See documentation in ../../ops/dataset_ops.cc for a high-level // description of the following op. class OptimizeDatasetOp : public UnaryDatasetOpKernel { @@ -286,31 +288,6 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel { (*meta_graph_def.mutable_collection_def())["train_op"] = collection_def; // Create Grappler item. - tensorflow::ConfigProto config; - RewriterConfig& rewriter_config = - *config.mutable_graph_options()->mutable_rewrite_options(); - for (const string& optimization : optimizations_) { - rewriter_config.add_optimizers(optimization); - } - // If no optimizations were specified, supply a non-existent - // optimization to prevent Grappler from applying the default set of - // optimizations as some of them do not work out of the box at the - // moment (e.g. because we have no cost model for dataset ops). - if (optimizations_.empty()) { - rewriter_config.add_optimizers("non-existent"); - } else { - // If we apply custom dataset optimizers, explicitly trigger a subset of - // standard grappler optimizations to further optimize modified dataset - // graphs (e.g. performing constant folding on merged functions, - // removing unused graph nodes) - // TODO(b/118175421): This should be part of the tf.data optimization - // pass manager. - // TODO(b/120437209): Apply `constfold` optimization when it is fixed. - for (const auto& optimizer : - {"pruning", "function", "shape", "arithmetic", "dependency"}) { - rewriter_config.add_optimizers(optimizer); - } - } tensorflow::grappler::ItemConfig item_config; item_config.apply_optimizations = true; std::unique_ptr grappler_item = @@ -319,13 +296,21 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel { std::unordered_map device_map; tensorflow::grappler::VirtualCluster cluster(device_map); - // Run optimizer. - if (VLOG_IS_ON(2)) { - LOG(INFO) << "Performing the following optimizations:"; - for (const string& optimization : optimizations_) { - LOG(INFO) << " " << optimization; - } + // Run data optimizer using grappler's meta optimizer. + tensorflow::ConfigProto config; + RewriterConfig& rewriter_config = + *config.mutable_graph_options()->mutable_rewrite_options(); + rewriter_config.add_optimizers(kOptimizerName); + + auto custom_optimizer = rewriter_config.add_custom_optimizers(); + custom_optimizer->set_name(kOptimizerName); + auto* custom_optimizations_list = + (*custom_optimizer->mutable_parameter_map())["optimizers"] + .mutable_list(); + for (const auto& opt : optimizations_) { + custom_optimizations_list->add_s(opt); } + TF_RETURN_IF_ERROR(tensorflow::grappler::RunMetaOptimizer( *grappler_item, config, ctx->device(), &cluster, graph_def));