[tf.data] Meta optimizer to manage tf.data optimization passes to perform optimizations in a meaningful order (instead of an arbitrary order as before).
PiperOrigin-RevId: 226944242
This commit is contained in:
parent
742778f04e
commit
a1193c2954
@ -3,6 +3,30 @@ licenses(["notice"]) # Apache 2.0
|
|||||||
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
|
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
|
||||||
load("//tensorflow/core:platform/default/build_config.bzl", "tf_protos_all")
|
load("//tensorflow/core:platform/default/build_config.bzl", "tf_protos_all")
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "meta_optimizer",
|
||||||
|
srcs = ["meta_optimizer.cc"],
|
||||||
|
hdrs = [
|
||||||
|
"meta_optimizer.h",
|
||||||
|
],
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
|
"//tensorflow/core/grappler:grappler_item",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core/grappler/clusters:cluster",
|
||||||
|
"//tensorflow/core/grappler/optimizers:arithmetic_optimizer",
|
||||||
|
"//tensorflow/core/grappler/optimizers:model_pruner",
|
||||||
|
"//tensorflow/core/grappler/optimizers:shape_optimizer",
|
||||||
|
"//tensorflow/core/grappler/optimizers:dependency_optimizer",
|
||||||
|
"//tensorflow/core/grappler/optimizers:function_optimizer",
|
||||||
|
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
|
||||||
|
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
|
||||||
|
"//tensorflow/core:lib_internal",
|
||||||
|
"//tensorflow/core:ptr_util",
|
||||||
|
] + tf_protos_all(),
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "filter_fusion",
|
name = "filter_fusion",
|
||||||
srcs = ["filter_fusion.cc"],
|
srcs = ["filter_fusion.cc"],
|
||||||
@ -561,6 +585,7 @@ cc_library(
|
|||||||
":map_fusion",
|
":map_fusion",
|
||||||
":map_parallelization",
|
":map_parallelization",
|
||||||
":map_vectorization",
|
":map_vectorization",
|
||||||
|
":meta_optimizer",
|
||||||
":noop_elimination",
|
":noop_elimination",
|
||||||
":shuffle_and_repeat_fusion",
|
":shuffle_and_repeat_fusion",
|
||||||
],
|
],
|
||||||
|
125
tensorflow/core/grappler/optimizers/data/meta_optimizer.cc
Normal file
125
tensorflow/core/grappler/optimizers/data/meta_optimizer.cc
Normal file
@ -0,0 +1,125 @@
|
|||||||
|
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/grappler/optimizers/data/meta_optimizer.h"
|
||||||
|
|
||||||
|
#include "tensorflow/core/grappler/clusters/cluster.h"
|
||||||
|
#include "tensorflow/core/grappler/grappler_item.h"
|
||||||
|
#include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"
|
||||||
|
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
|
||||||
|
#include "tensorflow/core/grappler/optimizers/dependency_optimizer.h"
|
||||||
|
#include "tensorflow/core/grappler/optimizers/function_optimizer.h"
|
||||||
|
#include "tensorflow/core/grappler/optimizers/model_pruner.h"
|
||||||
|
#include "tensorflow/core/grappler/optimizers/shape_optimizer.h"
|
||||||
|
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||||
|
#include "tensorflow/core/util/ptr_util.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace grappler {
|
||||||
|
|
||||||
|
Status TFDataMetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||||
|
GraphDef* output) {
|
||||||
|
// Stores the optimized item so far.
|
||||||
|
GrapplerItem optimized_item = item;
|
||||||
|
|
||||||
|
// Perform optimizations in a meaningful order.
|
||||||
|
for (const auto& optimization :
|
||||||
|
{"noop_elimination",
|
||||||
|
"shuffle_and_repeat_fusion",
|
||||||
|
"map_fusion",
|
||||||
|
"filter_fusion",
|
||||||
|
"map_and_filter_fusion",
|
||||||
|
"hoist_random_uniform",
|
||||||
|
"map_parallelization",
|
||||||
|
"map_and_batch_fusion",
|
||||||
|
"map_vectorization",
|
||||||
|
"make_numa_aware",
|
||||||
|
"latency_all_edges",
|
||||||
|
"make_sloppy",
|
||||||
|
"pruning",
|
||||||
|
"function",
|
||||||
|
"shape",
|
||||||
|
"arithmetic",
|
||||||
|
"dependency"}) {
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
ApplyOptimization(optimization, cluster, &optimized_item));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store the final result of all the optimizations in `output`.
|
||||||
|
output->Swap(&optimized_item.graph);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status TFDataMetaOptimizer::ApplyOptimization(const string& name,
|
||||||
|
Cluster* cluster,
|
||||||
|
GrapplerItem* item) const {
|
||||||
|
GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
|
||||||
|
|
||||||
|
const auto* optimizer = gtl::FindOrNull(enabled_optimizers_, name);
|
||||||
|
if (!optimizer) {
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
GraphDef result;
|
||||||
|
(*optimizer)->set_deadline_usec(this->deadline_usec());
|
||||||
|
TF_RETURN_IF_ERROR((*optimizer)->Optimize(cluster, *item, &result));
|
||||||
|
item->graph.Swap(&result);
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status TFDataMetaOptimizer::Init(
|
||||||
|
const tensorflow::RewriterConfig_CustomGraphOptimizer* config) {
|
||||||
|
if (!config) return Status::OK();
|
||||||
|
|
||||||
|
// Initialize custom tf.data optimizers based on config.
|
||||||
|
auto& optimizers = config->parameter_map().at("optimizers").list().s();
|
||||||
|
for (const auto& optimizer_name : optimizers) {
|
||||||
|
auto optimizer =
|
||||||
|
CustomGraphOptimizerRegistry::CreateByNameOrNull(optimizer_name);
|
||||||
|
if (optimizer) {
|
||||||
|
// None of our data optimizers implement a meaningful Init function.
|
||||||
|
// This returns an error in case any of them does.
|
||||||
|
TF_RETURN_IF_ERROR(optimizer->Init());
|
||||||
|
enabled_optimizers_[optimizer_name] = std::move(optimizer);
|
||||||
|
} else {
|
||||||
|
// This should never happen.
|
||||||
|
return errors::Internal(
|
||||||
|
"Tried to register a dataset optimizer that doesn't exist: ",
|
||||||
|
optimizer_name);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize standard grappler optimizers.
|
||||||
|
enabled_optimizers_["pruning"] = MakeUnique<ModelPruner>();
|
||||||
|
enabled_optimizers_["function"] =
|
||||||
|
MakeUnique<FunctionOptimizer>(RewriterConfig::ON);
|
||||||
|
enabled_optimizers_["shape"] = MakeUnique<ShapeOptimizer>();
|
||||||
|
enabled_optimizers_["arithmetic"] = MakeUnique<ArithmeticOptimizer>();
|
||||||
|
enabled_optimizers_["dependency"] = MakeUnique<DependencyOptimizer>();
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
void TFDataMetaOptimizer::Feedback(Cluster* cluster, const GrapplerItem& item,
|
||||||
|
const GraphDef& optimize_output,
|
||||||
|
double result) {
|
||||||
|
// no-op
|
||||||
|
}
|
||||||
|
|
||||||
|
REGISTER_GRAPH_OPTIMIZER_AS(TFDataMetaOptimizer, "tf_data_meta_optimizer");
|
||||||
|
|
||||||
|
} // end namespace grappler
|
||||||
|
} // end namespace tensorflow
|
56
tensorflow/core/grappler/optimizers/data/meta_optimizer.h
Normal file
56
tensorflow/core/grappler/optimizers/data/meta_optimizer.h
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_META_OPTIMIZER_H_
|
||||||
|
#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_META_OPTIMIZER_H_
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
|
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace grappler {
|
||||||
|
|
||||||
|
// This optimizer performs tf.data-specific optimizations by invoking
|
||||||
|
// other optimizers.
|
||||||
|
class TFDataMetaOptimizer : public CustomGraphOptimizer {
|
||||||
|
public:
|
||||||
|
TFDataMetaOptimizer() = default;
|
||||||
|
~TFDataMetaOptimizer() override = default;
|
||||||
|
|
||||||
|
string name() const override { return "tf_data_meta_optimizer"; };
|
||||||
|
|
||||||
|
Status Init(
|
||||||
|
const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override;
|
||||||
|
|
||||||
|
Status Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||||
|
GraphDef* output) override;
|
||||||
|
|
||||||
|
void Feedback(Cluster* cluster, const GrapplerItem& item,
|
||||||
|
const GraphDef& optimize_output, double result) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
absl::flat_hash_map<string, std::unique_ptr<GraphOptimizer>>
|
||||||
|
enabled_optimizers_;
|
||||||
|
|
||||||
|
// Applies an optimization with the specified name on `item`, and stores
|
||||||
|
// the result in `item.graph`
|
||||||
|
Status ApplyOptimization(const string& name, Cluster* cluster,
|
||||||
|
GrapplerItem* item) const;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // end namespace grappler
|
||||||
|
} // end namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_META_OPTIMIZER_H_
|
@ -40,6 +40,8 @@ namespace tensorflow {
|
|||||||
namespace data {
|
namespace data {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
static const char* const kOptimizerName = "tf_data_meta_optimizer";
|
||||||
|
|
||||||
// See documentation in ../../ops/dataset_ops.cc for a high-level
|
// See documentation in ../../ops/dataset_ops.cc for a high-level
|
||||||
// description of the following op.
|
// description of the following op.
|
||||||
class OptimizeDatasetOp : public UnaryDatasetOpKernel {
|
class OptimizeDatasetOp : public UnaryDatasetOpKernel {
|
||||||
@ -286,31 +288,6 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
|
|||||||
(*meta_graph_def.mutable_collection_def())["train_op"] = collection_def;
|
(*meta_graph_def.mutable_collection_def())["train_op"] = collection_def;
|
||||||
|
|
||||||
// Create Grappler item.
|
// Create Grappler item.
|
||||||
tensorflow::ConfigProto config;
|
|
||||||
RewriterConfig& rewriter_config =
|
|
||||||
*config.mutable_graph_options()->mutable_rewrite_options();
|
|
||||||
for (const string& optimization : optimizations_) {
|
|
||||||
rewriter_config.add_optimizers(optimization);
|
|
||||||
}
|
|
||||||
// If no optimizations were specified, supply a non-existent
|
|
||||||
// optimization to prevent Grappler from applying the default set of
|
|
||||||
// optimizations as some of them do not work out of the box at the
|
|
||||||
// moment (e.g. because we have no cost model for dataset ops).
|
|
||||||
if (optimizations_.empty()) {
|
|
||||||
rewriter_config.add_optimizers("non-existent");
|
|
||||||
} else {
|
|
||||||
// If we apply custom dataset optimizers, explicitly trigger a subset of
|
|
||||||
// standard grappler optimizations to further optimize modified dataset
|
|
||||||
// graphs (e.g. performing constant folding on merged functions,
|
|
||||||
// removing unused graph nodes)
|
|
||||||
// TODO(b/118175421): This should be part of the tf.data optimization
|
|
||||||
// pass manager.
|
|
||||||
// TODO(b/120437209): Apply `constfold` optimization when it is fixed.
|
|
||||||
for (const auto& optimizer :
|
|
||||||
{"pruning", "function", "shape", "arithmetic", "dependency"}) {
|
|
||||||
rewriter_config.add_optimizers(optimizer);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
tensorflow::grappler::ItemConfig item_config;
|
tensorflow::grappler::ItemConfig item_config;
|
||||||
item_config.apply_optimizations = true;
|
item_config.apply_optimizations = true;
|
||||||
std::unique_ptr<tensorflow::grappler::GrapplerItem> grappler_item =
|
std::unique_ptr<tensorflow::grappler::GrapplerItem> grappler_item =
|
||||||
@ -319,13 +296,21 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
|
|||||||
std::unordered_map<string, tensorflow::DeviceProperties> device_map;
|
std::unordered_map<string, tensorflow::DeviceProperties> device_map;
|
||||||
tensorflow::grappler::VirtualCluster cluster(device_map);
|
tensorflow::grappler::VirtualCluster cluster(device_map);
|
||||||
|
|
||||||
// Run optimizer.
|
// Run data optimizer using grappler's meta optimizer.
|
||||||
if (VLOG_IS_ON(2)) {
|
tensorflow::ConfigProto config;
|
||||||
LOG(INFO) << "Performing the following optimizations:";
|
RewriterConfig& rewriter_config =
|
||||||
for (const string& optimization : optimizations_) {
|
*config.mutable_graph_options()->mutable_rewrite_options();
|
||||||
LOG(INFO) << " " << optimization;
|
rewriter_config.add_optimizers(kOptimizerName);
|
||||||
}
|
|
||||||
|
auto custom_optimizer = rewriter_config.add_custom_optimizers();
|
||||||
|
custom_optimizer->set_name(kOptimizerName);
|
||||||
|
auto* custom_optimizations_list =
|
||||||
|
(*custom_optimizer->mutable_parameter_map())["optimizers"]
|
||||||
|
.mutable_list();
|
||||||
|
for (const auto& opt : optimizations_) {
|
||||||
|
custom_optimizations_list->add_s(opt);
|
||||||
}
|
}
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(tensorflow::grappler::RunMetaOptimizer(
|
TF_RETURN_IF_ERROR(tensorflow::grappler::RunMetaOptimizer(
|
||||||
*grappler_item, config, ctx->device(), &cluster, graph_def));
|
*grappler_item, config, ctx->device(), &cluster, graph_def));
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user