[tf.data] Add a meta optimizer that manages the tf.data optimization passes and performs them in a meaningful order (instead of an arbitrary order, as before).

PiperOrigin-RevId: 226944242
Rachel Lim 2018-12-26 12:39:23 -08:00 committed by TensorFlower Gardener
parent 742778f04e
commit a1193c2954
4 changed files with 222 additions and 31 deletions
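
For context, the sketch below is not part of the commit; it shows how a client can route a graph through the new pass by listing it as a Grappler custom optimizer, mirroring the OptimizeDatasetOp change in the last file of this diff. The two pass names in the parameter map are example choices taken from the ordered list in meta_optimizer.cc.

// Sketch only, assuming the registration name "tf_data_meta_optimizer"
// introduced in this commit.
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"

tensorflow::ConfigProto MakeTFDataOptimizerConfig() {
  tensorflow::ConfigProto config;
  tensorflow::RewriterConfig& rewriter_config =
      *config.mutable_graph_options()->mutable_rewrite_options();
  rewriter_config.add_optimizers("tf_data_meta_optimizer");
  auto* custom_optimizer = rewriter_config.add_custom_optimizers();
  custom_optimizer->set_name("tf_data_meta_optimizer");
  auto* pass_list =
      (*custom_optimizer->mutable_parameter_map())["optimizers"]
          .mutable_list();
  pass_list->add_s("noop_elimination");
  pass_list->add_s("map_and_batch_fusion");
  return config;
}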

tensorflow/core/grappler/optimizers/data/BUILD

@@ -3,6 +3,30 @@ licenses(["notice"]) # Apache 2.0
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load("//tensorflow/core:platform/default/build_config.bzl", "tf_protos_all")
cc_library(
name = "meta_optimizer",
srcs = ["meta_optimizer.cc"],
hdrs = [
"meta_optimizer.h",
],
visibility = ["//visibility:public"],
deps = [
"@com_google_absl//absl/container:flat_hash_map",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core:lib",
"//tensorflow/core/grappler/clusters:cluster",
"//tensorflow/core/grappler/optimizers:arithmetic_optimizer",
"//tensorflow/core/grappler/optimizers:model_pruner",
"//tensorflow/core/grappler/optimizers:shape_optimizer",
"//tensorflow/core/grappler/optimizers:dependency_optimizer",
"//tensorflow/core/grappler/optimizers:function_optimizer",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
"//tensorflow/core:lib_internal",
"//tensorflow/core:ptr_util",
] + tf_protos_all(),
)
cc_library(
name = "filter_fusion",
srcs = ["filter_fusion.cc"],
@@ -561,6 +585,7 @@ cc_library(
":map_fusion",
":map_parallelization",
":map_vectorization",
":meta_optimizer",
":noop_elimination",
":shuffle_and_repeat_fusion",
],

tensorflow/core/grappler/optimizers/data/meta_optimizer.cc

@@ -0,0 +1,125 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/data/meta_optimizer.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
#include "tensorflow/core/grappler/optimizers/dependency_optimizer.h"
#include "tensorflow/core/grappler/optimizers/function_optimizer.h"
#include "tensorflow/core/grappler/optimizers/model_pruner.h"
#include "tensorflow/core/grappler/optimizers/shape_optimizer.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
namespace grappler {
Status TFDataMetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* output) {
// Stores the optimized item so far.
GrapplerItem optimized_item = item;
// Perform optimizations in a meaningful order.
for (const auto& optimization :
{"noop_elimination",
"shuffle_and_repeat_fusion",
"map_fusion",
"filter_fusion",
"map_and_filter_fusion",
"hoist_random_uniform",
"map_parallelization",
"map_and_batch_fusion",
"map_vectorization",
"make_numa_aware",
"latency_all_edges",
"make_sloppy",
"pruning",
"function",
"shape",
"arithmetic",
"dependency"}) {
TF_RETURN_IF_ERROR(
ApplyOptimization(optimization, cluster, &optimized_item));
}
// Store the final result of all the optimizations in `output`.
output->Swap(&optimized_item.graph);
return Status::OK();
}
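// Illustrative sketch, not part of this file: each name in the list above is
// resolved through the map built in Init below, and the tf.data passes
// themselves become visible to CustomGraphOptimizerRegistry::CreateByNameOrNull
// via the same registration macro used at the bottom of this file. The pass
// name "my_dataset_pass" is hypothetical.
class MyDatasetPass : public CustomGraphOptimizer {
 public:
  string name() const override { return "my_dataset_pass"; }
  Status Init(
      const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
    return Status::OK();
  }
  Status Optimize(Cluster* cluster, const GrapplerItem& item,
                  GraphDef* output) override {
    *output = item.graph;  // A no-op pass: copy the graph through unchanged.
    return Status::OK();
  }
  void Feedback(Cluster* cluster, const GrapplerItem& item,
                const GraphDef& optimize_output, double result) override {}
};
REGISTER_GRAPH_OPTIMIZER_AS(MyDatasetPass, "my_dataset_pass");
// End of sketch.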
Status TFDataMetaOptimizer::ApplyOptimization(const string& name,
Cluster* cluster,
GrapplerItem* item) const {
GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
const auto* optimizer = gtl::FindOrNull(enabled_optimizers_, name);
if (!optimizer) {
return Status::OK();
}
GraphDef result;
(*optimizer)->set_deadline_usec(this->deadline_usec());
TF_RETURN_IF_ERROR((*optimizer)->Optimize(cluster, *item, &result));
item->graph.Swap(&result);
return Status::OK();
}
Status TFDataMetaOptimizer::Init(
const tensorflow::RewriterConfig_CustomGraphOptimizer* config) {
if (!config) return Status::OK();
// Initialize custom tf.data optimizers based on config.
auto& optimizers = config->parameter_map().at("optimizers").list().s();
for (const auto& optimizer_name : optimizers) {
auto optimizer =
CustomGraphOptimizerRegistry::CreateByNameOrNull(optimizer_name);
if (optimizer) {
// None of our data optimizers implements a meaningful Init function,
// so this call simply surfaces an error if one of them ever does.
TF_RETURN_IF_ERROR(optimizer->Init());
enabled_optimizers_[optimizer_name] = std::move(optimizer);
} else {
// This should never happen.
return errors::Internal(
"Tried to register a dataset optimizer that doesn't exist: ",
optimizer_name);
}
}
// Initialize standard grappler optimizers.
enabled_optimizers_["pruning"] = MakeUnique<ModelPruner>();
enabled_optimizers_["function"] =
MakeUnique<FunctionOptimizer>(RewriterConfig::ON);
enabled_optimizers_["shape"] = MakeUnique<ShapeOptimizer>();
enabled_optimizers_["arithmetic"] = MakeUnique<ArithmeticOptimizer>();
enabled_optimizers_["dependency"] = MakeUnique<DependencyOptimizer>();
return Status::OK();
}
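// Illustrative sketch, not part of this file: the config that Init expects
// carries a parameter_map entry keyed "optimizers" whose value is a list of
// strings; this is exactly what OptimizeDatasetOp builds in the last file of
// this commit. The helper name and pass names below are hypothetical.
Status InitMetaOptimizerForIllustration(TFDataMetaOptimizer* optimizer) {
  tensorflow::RewriterConfig_CustomGraphOptimizer config;
  config.set_name("tf_data_meta_optimizer");
  auto* pass_list =
      (*config.mutable_parameter_map())["optimizers"].mutable_list();
  pass_list->add_s("noop_elimination");
  pass_list->add_s("map_fusion");
  return optimizer->Init(&config);
}
// End of sketch.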
void TFDataMetaOptimizer::Feedback(Cluster* cluster, const GrapplerItem& item,
const GraphDef& optimize_output,
double result) {
// no-op
}
REGISTER_GRAPH_OPTIMIZER_AS(TFDataMetaOptimizer, "tf_data_meta_optimizer");
} // end namespace grappler
} // end namespace tensorflow

tensorflow/core/grappler/optimizers/data/meta_optimizer.h

@@ -0,0 +1,56 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_META_OPTIMIZER_H_
#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_META_OPTIMIZER_H_
#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
namespace tensorflow {
namespace grappler {
// This optimizer performs tf.data-specific optimizations by invoking
// other optimizers.
class TFDataMetaOptimizer : public CustomGraphOptimizer {
public:
TFDataMetaOptimizer() = default;
~TFDataMetaOptimizer() override = default;
string name() const override { return "tf_data_meta_optimizer"; }
Status Init(
const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override;
Status Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* output) override;
void Feedback(Cluster* cluster, const GrapplerItem& item,
const GraphDef& optimize_output, double result) override;
private:
absl::flat_hash_map<string, std::unique_ptr<GraphOptimizer>>
enabled_optimizers_;
// Applies the optimization with the specified name on `item`, and stores
// the result in `item.graph`.
Status ApplyOptimization(const string& name, Cluster* cluster,
GrapplerItem* item) const;
};
} // end namespace grappler
} // end namespace tensorflow
#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_META_OPTIMIZER_H_
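A brief usage sketch, not part of the commit: once the registration in meta_optimizer.cc has run, the meta optimizer can be created by name like any other custom graph optimizer. Note that Init tolerates a null config but then enables no passes at all, so Optimize would pass the graph through unchanged. The helper name below is hypothetical.

// Sketch only; assumes the usual grappler namespace and includes.
Status RunTFDataMetaOptimizer(
    Cluster* cluster, const GrapplerItem& item,
    const tensorflow::RewriterConfig_CustomGraphOptimizer* config,
    GraphDef* output) {
  auto optimizer = CustomGraphOptimizerRegistry::CreateByNameOrNull(
      "tf_data_meta_optimizer");
  if (optimizer == nullptr) {
    return errors::Internal("tf_data_meta_optimizer is not registered.");
  }
  TF_RETURN_IF_ERROR(optimizer->Init(config));
  return optimizer->Optimize(cluster, item, output);
}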

tensorflow/core/kernels/data/optimize_dataset_op.cc

@@ -40,6 +40,8 @@ namespace tensorflow {
namespace data {
namespace {
static const char* const kOptimizerName = "tf_data_meta_optimizer";
// See documentation in ../../ops/dataset_ops.cc for a high-level
// description of the following op.
class OptimizeDatasetOp : public UnaryDatasetOpKernel {
@@ -286,31 +288,6 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
(*meta_graph_def.mutable_collection_def())["train_op"] = collection_def;
// Create Grappler item.
tensorflow::ConfigProto config;
RewriterConfig& rewriter_config =
*config.mutable_graph_options()->mutable_rewrite_options();
for (const string& optimization : optimizations_) {
rewriter_config.add_optimizers(optimization);
}
// If no optimizations were specified, supply a non-existent
// optimization to prevent Grappler from applying the default set of
// optimizations as some of them do not work out of the box at the
// moment (e.g. because we have no cost model for dataset ops).
if (optimizations_.empty()) {
rewriter_config.add_optimizers("non-existent");
} else {
// If we apply custom dataset optimizers, explicitly trigger a subset of
// standard grappler optimizations to further optimize modified dataset
// graphs (e.g. performing constant folding on merged functions,
// removing unused graph nodes).
// TODO(b/118175421): This should be part of the tf.data optimization
// pass manager.
// TODO(b/120437209): Apply `constfold` optimization when it is fixed.
for (const auto& optimizer :
{"pruning", "function", "shape", "arithmetic", "dependency"}) {
rewriter_config.add_optimizers(optimizer);
}
}
tensorflow::grappler::ItemConfig item_config;
item_config.apply_optimizations = true;
std::unique_ptr<tensorflow::grappler::GrapplerItem> grappler_item =
@@ -319,13 +296,21 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
std::unordered_map<string, tensorflow::DeviceProperties> device_map;
tensorflow::grappler::VirtualCluster cluster(device_map);
// Run optimizer.
if (VLOG_IS_ON(2)) {
LOG(INFO) << "Performing the following optimizations:";
for (const string& optimization : optimizations_) {
LOG(INFO) << " " << optimization;
}
}
// Run data optimizer using grappler's meta optimizer.
tensorflow::ConfigProto config;
RewriterConfig& rewriter_config =
*config.mutable_graph_options()->mutable_rewrite_options();
rewriter_config.add_optimizers(kOptimizerName);
auto custom_optimizer = rewriter_config.add_custom_optimizers();
custom_optimizer->set_name(kOptimizerName);
auto* custom_optimizations_list =
(*custom_optimizer->mutable_parameter_map())["optimizers"]
.mutable_list();
for (const auto& opt : optimizations_) {
custom_optimizations_list->add_s(opt);
}
TF_RETURN_IF_ERROR(tensorflow::grappler::RunMetaOptimizer(
*grappler_item, config, ctx->device(), &cluster, graph_def));
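
For reference, a sketch (not part of the commit) of the rewrite options that the block above assembles, expressed as a textproto parsed in C++; the two tf.data pass names are example choices.

// Sketch: field names come from rewriter_config.proto; the parameter_map
// value is an AttrValue holding a list of strings.
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"

tensorflow::RewriterConfig MakeExampleRewriterConfig() {
  const char* kTextProto = R"pb(
    optimizers: "tf_data_meta_optimizer"
    custom_optimizers {
      name: "tf_data_meta_optimizer"
      parameter_map {
        key: "optimizers"
        value { list { s: "noop_elimination" s: "map_and_batch_fusion" } }
      }
    }
  )pb";
  tensorflow::RewriterConfig rewriter_config;
  CHECK(tensorflow::protobuf::TextFormat::ParseFromString(kTextProto,
                                                          &rewriter_config));
  return rewriter_config;
}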