[tf.data] Meta optimizer to manage tf.data optimization passes to perform optimizations in a meaningful order (instead of an arbitrary order as before).
PiperOrigin-RevId: 226944242
This commit is contained in:
parent
742778f04e
commit
a1193c2954
@ -3,6 +3,30 @@ licenses(["notice"]) # Apache 2.0
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
|
||||
load("//tensorflow/core:platform/default/build_config.bzl", "tf_protos_all")
|
||||
|
||||
cc_library(
|
||||
name = "meta_optimizer",
|
||||
srcs = ["meta_optimizer.cc"],
|
||||
hdrs = [
|
||||
"meta_optimizer.h",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"//tensorflow/core/grappler:grappler_item",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/grappler/clusters:cluster",
|
||||
"//tensorflow/core/grappler/optimizers:arithmetic_optimizer",
|
||||
"//tensorflow/core/grappler/optimizers:model_pruner",
|
||||
"//tensorflow/core/grappler/optimizers:shape_optimizer",
|
||||
"//tensorflow/core/grappler/optimizers:dependency_optimizer",
|
||||
"//tensorflow/core/grappler/optimizers:function_optimizer",
|
||||
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
|
||||
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:ptr_util",
|
||||
] + tf_protos_all(),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "filter_fusion",
|
||||
srcs = ["filter_fusion.cc"],
|
||||
@ -561,6 +585,7 @@ cc_library(
|
||||
":map_fusion",
|
||||
":map_parallelization",
|
||||
":map_vectorization",
|
||||
":meta_optimizer",
|
||||
":noop_elimination",
|
||||
":shuffle_and_repeat_fusion",
|
||||
],
|
||||
|
125
tensorflow/core/grappler/optimizers/data/meta_optimizer.cc
Normal file
125
tensorflow/core/grappler/optimizers/data/meta_optimizer.cc
Normal file
@ -0,0 +1,125 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/grappler/optimizers/data/meta_optimizer.h"
|
||||
|
||||
#include "tensorflow/core/grappler/clusters/cluster.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"
|
||||
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
|
||||
#include "tensorflow/core/grappler/optimizers/dependency_optimizer.h"
|
||||
#include "tensorflow/core/grappler/optimizers/function_optimizer.h"
|
||||
#include "tensorflow/core/grappler/optimizers/model_pruner.h"
|
||||
#include "tensorflow/core/grappler/optimizers/shape_optimizer.h"
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
#include "tensorflow/core/util/ptr_util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
Status TFDataMetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
GraphDef* output) {
|
||||
// Stores the optimized item so far.
|
||||
GrapplerItem optimized_item = item;
|
||||
|
||||
// Perform optimizations in a meaningful order.
|
||||
for (const auto& optimization :
|
||||
{"noop_elimination",
|
||||
"shuffle_and_repeat_fusion",
|
||||
"map_fusion",
|
||||
"filter_fusion",
|
||||
"map_and_filter_fusion",
|
||||
"hoist_random_uniform",
|
||||
"map_parallelization",
|
||||
"map_and_batch_fusion",
|
||||
"map_vectorization",
|
||||
"make_numa_aware",
|
||||
"latency_all_edges",
|
||||
"make_sloppy",
|
||||
"pruning",
|
||||
"function",
|
||||
"shape",
|
||||
"arithmetic",
|
||||
"dependency"}) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
ApplyOptimization(optimization, cluster, &optimized_item));
|
||||
}
|
||||
|
||||
// Store the final result of all the optimizations in `output`.
|
||||
output->Swap(&optimized_item.graph);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TFDataMetaOptimizer::ApplyOptimization(const string& name,
|
||||
Cluster* cluster,
|
||||
GrapplerItem* item) const {
|
||||
GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
|
||||
|
||||
const auto* optimizer = gtl::FindOrNull(enabled_optimizers_, name);
|
||||
if (!optimizer) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
GraphDef result;
|
||||
(*optimizer)->set_deadline_usec(this->deadline_usec());
|
||||
TF_RETURN_IF_ERROR((*optimizer)->Optimize(cluster, *item, &result));
|
||||
item->graph.Swap(&result);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TFDataMetaOptimizer::Init(
|
||||
const tensorflow::RewriterConfig_CustomGraphOptimizer* config) {
|
||||
if (!config) return Status::OK();
|
||||
|
||||
// Initialize custom tf.data optimizers based on config.
|
||||
auto& optimizers = config->parameter_map().at("optimizers").list().s();
|
||||
for (const auto& optimizer_name : optimizers) {
|
||||
auto optimizer =
|
||||
CustomGraphOptimizerRegistry::CreateByNameOrNull(optimizer_name);
|
||||
if (optimizer) {
|
||||
// None of our data optimizers implement a meaningful Init function.
|
||||
// This returns an error in case any of them does.
|
||||
TF_RETURN_IF_ERROR(optimizer->Init());
|
||||
enabled_optimizers_[optimizer_name] = std::move(optimizer);
|
||||
} else {
|
||||
// This should never happen.
|
||||
return errors::Internal(
|
||||
"Tried to register a dataset optimizer that doesn't exist: ",
|
||||
optimizer_name);
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize standard grappler optimizers.
|
||||
enabled_optimizers_["pruning"] = MakeUnique<ModelPruner>();
|
||||
enabled_optimizers_["function"] =
|
||||
MakeUnique<FunctionOptimizer>(RewriterConfig::ON);
|
||||
enabled_optimizers_["shape"] = MakeUnique<ShapeOptimizer>();
|
||||
enabled_optimizers_["arithmetic"] = MakeUnique<ArithmeticOptimizer>();
|
||||
enabled_optimizers_["dependency"] = MakeUnique<DependencyOptimizer>();
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void TFDataMetaOptimizer::Feedback(Cluster* cluster, const GrapplerItem& item,
|
||||
const GraphDef& optimize_output,
|
||||
double result) {
|
||||
// no-op
|
||||
}
|
||||
|
||||
REGISTER_GRAPH_OPTIMIZER_AS(TFDataMetaOptimizer, "tf_data_meta_optimizer");
|
||||
|
||||
} // end namespace grappler
|
||||
} // end namespace tensorflow
|
56
tensorflow/core/grappler/optimizers/data/meta_optimizer.h
Normal file
56
tensorflow/core/grappler/optimizers/data/meta_optimizer.h
Normal file
@ -0,0 +1,56 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_META_OPTIMIZER_H_
|
||||
#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_META_OPTIMIZER_H_
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
// This optimizer performs tf.data-specific optimizations by invoking
|
||||
// other optimizers.
|
||||
class TFDataMetaOptimizer : public CustomGraphOptimizer {
|
||||
public:
|
||||
TFDataMetaOptimizer() = default;
|
||||
~TFDataMetaOptimizer() override = default;
|
||||
|
||||
string name() const override { return "tf_data_meta_optimizer"; };
|
||||
|
||||
Status Init(
|
||||
const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override;
|
||||
|
||||
Status Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
GraphDef* output) override;
|
||||
|
||||
void Feedback(Cluster* cluster, const GrapplerItem& item,
|
||||
const GraphDef& optimize_output, double result) override;
|
||||
|
||||
private:
|
||||
absl::flat_hash_map<string, std::unique_ptr<GraphOptimizer>>
|
||||
enabled_optimizers_;
|
||||
|
||||
// Applies an optimization with the specified name on `item`, and stores
|
||||
// the result in `item.graph`
|
||||
Status ApplyOptimization(const string& name, Cluster* cluster,
|
||||
GrapplerItem* item) const;
|
||||
};
|
||||
|
||||
} // end namespace grappler
|
||||
} // end namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_META_OPTIMIZER_H_
|
@ -40,6 +40,8 @@ namespace tensorflow {
|
||||
namespace data {
|
||||
namespace {
|
||||
|
||||
static const char* const kOptimizerName = "tf_data_meta_optimizer";
|
||||
|
||||
// See documentation in ../../ops/dataset_ops.cc for a high-level
|
||||
// description of the following op.
|
||||
class OptimizeDatasetOp : public UnaryDatasetOpKernel {
|
||||
@ -286,31 +288,6 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
|
||||
(*meta_graph_def.mutable_collection_def())["train_op"] = collection_def;
|
||||
|
||||
// Create Grappler item.
|
||||
tensorflow::ConfigProto config;
|
||||
RewriterConfig& rewriter_config =
|
||||
*config.mutable_graph_options()->mutable_rewrite_options();
|
||||
for (const string& optimization : optimizations_) {
|
||||
rewriter_config.add_optimizers(optimization);
|
||||
}
|
||||
// If no optimizations were specified, supply a non-existent
|
||||
// optimization to prevent Grappler from applying the default set of
|
||||
// optimizations as some of them do not work out of the box at the
|
||||
// moment (e.g. because we have no cost model for dataset ops).
|
||||
if (optimizations_.empty()) {
|
||||
rewriter_config.add_optimizers("non-existent");
|
||||
} else {
|
||||
// If we apply custom dataset optimizers, explicitly trigger a subset of
|
||||
// standard grappler optimizations to further optimize modified dataset
|
||||
// graphs (e.g. performing constant folding on merged functions,
|
||||
// removing unused graph nodes)
|
||||
// TODO(b/118175421): This should be part of the tf.data optimization
|
||||
// pass manager.
|
||||
// TODO(b/120437209): Apply `constfold` optimization when it is fixed.
|
||||
for (const auto& optimizer :
|
||||
{"pruning", "function", "shape", "arithmetic", "dependency"}) {
|
||||
rewriter_config.add_optimizers(optimizer);
|
||||
}
|
||||
}
|
||||
tensorflow::grappler::ItemConfig item_config;
|
||||
item_config.apply_optimizations = true;
|
||||
std::unique_ptr<tensorflow::grappler::GrapplerItem> grappler_item =
|
||||
@ -319,13 +296,21 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
|
||||
std::unordered_map<string, tensorflow::DeviceProperties> device_map;
|
||||
tensorflow::grappler::VirtualCluster cluster(device_map);
|
||||
|
||||
// Run optimizer.
|
||||
if (VLOG_IS_ON(2)) {
|
||||
LOG(INFO) << "Performing the following optimizations:";
|
||||
for (const string& optimization : optimizations_) {
|
||||
LOG(INFO) << " " << optimization;
|
||||
}
|
||||
// Run data optimizer using grappler's meta optimizer.
|
||||
tensorflow::ConfigProto config;
|
||||
RewriterConfig& rewriter_config =
|
||||
*config.mutable_graph_options()->mutable_rewrite_options();
|
||||
rewriter_config.add_optimizers(kOptimizerName);
|
||||
|
||||
auto custom_optimizer = rewriter_config.add_custom_optimizers();
|
||||
custom_optimizer->set_name(kOptimizerName);
|
||||
auto* custom_optimizations_list =
|
||||
(*custom_optimizer->mutable_parameter_map())["optimizers"]
|
||||
.mutable_list();
|
||||
for (const auto& opt : optimizations_) {
|
||||
custom_optimizations_list->add_s(opt);
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(tensorflow::grappler::RunMetaOptimizer(
|
||||
*grappler_item, config, ctx->device(), &cluster, graph_def));
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user