Add separate registry for function graph optimization passes in function runtime.

This adds a separate registry similar to OptimizationPassRegistry but for invoking function graph optimization passes. Other fields that are necessary but not exposed via OptimizationPassRegistry are exposed in this new registry (e.g. control ret node names). This is also a workaround to moving MLIR dependencies into the runtime as the TensorFlow Graph -> TensorFlow MLIR converters depend on the runtime.

PiperOrigin-RevId: 293692159
Change-Id: If616fe4d8080f08f4d1513429fec222417d5fcfe
This commit is contained in:
Andy Ly 2020-02-06 15:34:02 -08:00 committed by TensorFlower Gardener
parent 9347d87f92
commit d00b2c39ab
8 changed files with 373 additions and 6 deletions

View File

@ -974,6 +974,7 @@ tf_cuda_library(
"common_runtime/device.h",
"common_runtime/device_factory.h",
"common_runtime/function.h",
"common_runtime/function_optimization_registry.h",
"common_runtime/optimization_registry.h",
"common_runtime/shape_refiner.h",
"//tensorflow/core/graph:core_cpu_headers",
@ -2471,6 +2472,7 @@ filegroup(
"common_runtime/dma_helper.h",
"common_runtime/executor.h",
"common_runtime/executor_factory.h",
"common_runtime/function_optimization_registry.h",
"common_runtime/graph_optimizer.h",
"common_runtime/input_colocation_exemption_registry.h",
"common_runtime/isolate_placer_inspection_required_ops_pass.h",
@ -2531,6 +2533,7 @@ tf_cuda_library(
"common_runtime/executor.cc",
"common_runtime/executor_factory.cc",
"common_runtime/function.cc",
"common_runtime/function_optimization_registry.cc",
"common_runtime/graph_optimizer.cc",
"common_runtime/graph_runner.cc",
"common_runtime/hierarchical_tree_broadcaster.cc",
@ -3147,6 +3150,10 @@ tf_cc_tests(
"common_runtime/device_resolver_local_test.cc",
"common_runtime/device_set_test.cc",
"common_runtime/dynamic_device_mgr_test.cc",
"common_runtime/function_optimization_registration_test.cc",
"common_runtime/function_optimization_registry_no_pass_test.cc",
"common_runtime/function_optimization_registry_pass_failure_test.cc",
"common_runtime/function_optimization_registry_test.cc",
"common_runtime/isolate_placer_inspection_required_ops_pass_test.cc",
"common_runtime/optimization_registry_test.cc",
"common_runtime/pending_counts_test.cc",

View File

@ -0,0 +1,57 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/function_optimization_registry.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/config.pb.h"
namespace tensorflow {
class TestFunctionPass : public FunctionOptimizationPass {
public:
static bool ran_;
Status Run(const DeviceSet& device_set, const ConfigProto& config_proto,
std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def,
std::vector<std::string>* control_ret_node_names,
bool* control_rets_updated) override {
ran_ = true;
return Status::OK();
}
};
bool TestFunctionPass::ran_ = false;
static function_optimization_registration::FunctionOptimizationPassRegistration
register_test_pass(std::make_unique<TestFunctionPass>());
TEST(FunctionOptimizationPassRegistry, RegisteredPass) {
EXPECT_FALSE(TestFunctionPass::ran_);
DeviceSet device_set;
ConfigProto config_proto;
Status status = FunctionOptimizationPassRegistry::Global().Run(
device_set, config_proto, /*graph=*/nullptr, /*flib_def=*/nullptr,
/*control_ret_node_names=*/nullptr, /*control_rets_updated=*/nullptr);
EXPECT_EQ(status, Status::OK());
EXPECT_TRUE(TestFunctionPass::ran_);
}
} // namespace tensorflow

View File

@ -0,0 +1,46 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/function_optimization_registry.h"
namespace tensorflow {
void FunctionOptimizationPassRegistry::Init(
std::unique_ptr<FunctionOptimizationPass> pass) {
DCHECK(!pass_) << "Only one pass should be set.";
pass_ = std::move(pass);
}
Status FunctionOptimizationPassRegistry::Run(
const DeviceSet& device_set, const ConfigProto& config_proto,
std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def,
std::vector<std::string>* control_ret_node_names,
bool* control_rets_updated) {
if (pass_)
TF_RETURN_IF_ERROR(pass_->Run(device_set, config_proto, graph, flib_def,
control_ret_node_names,
control_rets_updated));
return Status::OK();
}
// static
FunctionOptimizationPassRegistry& FunctionOptimizationPassRegistry::Global() {
static FunctionOptimizationPassRegistry* kGlobalRegistry =
new FunctionOptimizationPassRegistry;
return *kGlobalRegistry;
}
} // namespace tensorflow

View File

@ -0,0 +1,89 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_OPTIMIZATION_REGISTRY_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_OPTIMIZATION_REGISTRY_H_
#include <memory>
#include <string>
#include <vector>
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/protobuf/config.pb.h"
// Classes to maintain a static registry of Graph based passes to be applied to
// a function graph.
namespace tensorflow {
// A pass to be registered with the FunctionOptimizationPassRegistry. This pass
// takes in a DeviceSet (available devices for executing the Graph), ConfigProto
// (session configuration parameters), Graph (computation),
// FunctionLibraryDefinition (mapping between function names and function
// definitions of the Graph), control ret/target node names (names of nodes that
// must execute but their data outputs, if they have any, are irrelevant), and
// whether control ret nodes (via thier name) were updated. Mutations to the
// Graph and other associated arguments are performed inplace by the pass.
class FunctionOptimizationPass {
public:
virtual ~FunctionOptimizationPass() {}
virtual Status Run(const DeviceSet& device_set,
const ConfigProto& config_proto,
std::unique_ptr<Graph>* graph,
FunctionLibraryDefinition* flib_def,
std::vector<std::string>* control_ret_node_names,
bool* control_rets_updated) = 0;
};
// A global function optimization pass registry that is used to hold one
// FunctionOptimizationPass. Passes registered to this registry will run before
// passes registered in OptimizationPassRegistry.
class FunctionOptimizationPassRegistry {
public:
// Initializes registry with a pass. Only one pass should be set. An assertion
// will be triggered if the registry already has a pass set and is being
// initialized with another pass.
void Init(std::unique_ptr<FunctionOptimizationPass> pass);
// Runs a pass if the registry contains one.
Status Run(const DeviceSet& device_set, const ConfigProto& config_proto,
std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def,
std::vector<std::string>* control_ret_node_names,
bool* control_rets_updated);
// Returns the global registry of function graph passes.
static FunctionOptimizationPassRegistry& Global();
private:
std::unique_ptr<FunctionOptimizationPass> pass_;
};
namespace function_optimization_registration {
class FunctionOptimizationPassRegistration {
public:
explicit FunctionOptimizationPassRegistration(
std::unique_ptr<FunctionOptimizationPass> pass) {
FunctionOptimizationPassRegistry::Global().Init(std::move(pass));
}
};
} // namespace function_optimization_registration
} // namespace tensorflow
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_OPTIMIZATION_REGISTRY_H_

View File

@ -0,0 +1,38 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/function_optimization_registry.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/config.pb.h"
namespace tensorflow {
TEST(FunctionOptimizationPassRegistry, NoPassSet) {
FunctionOptimizationPassRegistry::Global().Init(
std::unique_ptr<FunctionOptimizationPass>());
DeviceSet device_set;
ConfigProto config_proto;
Status status = FunctionOptimizationPassRegistry::Global().Run(
device_set, config_proto, /*graph=*/nullptr, /*flib_def=*/nullptr,
/*control_ret_node_names=*/nullptr, /*control_rets_updated=*/nullptr);
EXPECT_EQ(status, Status::OK());
}
} // namespace tensorflow

View File

@ -0,0 +1,57 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/function_optimization_registry.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/config.pb.h"
namespace tensorflow {
class FailingFunctionPass : public FunctionOptimizationPass {
public:
static bool ran_;
Status Run(const DeviceSet& device_set, const ConfigProto& config_proto,
std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def,
std::vector<std::string>* control_ret_node_names,
bool* control_rets_updated) override {
ran_ = true;
return errors::Unknown("");
}
};
bool FailingFunctionPass::ran_ = false;
TEST(FunctionOptimizationPassRegistry, PassWithError) {
EXPECT_FALSE(FailingFunctionPass::ran_);
FunctionOptimizationPassRegistry::Global().Init(
std::make_unique<FailingFunctionPass>());
DeviceSet device_set;
ConfigProto config_proto;
Status status = FunctionOptimizationPassRegistry::Global().Run(
device_set, config_proto, /*graph=*/nullptr, /*flib_def=*/nullptr,
/*control_ret_node_names=*/nullptr, /*control_rets_updated=*/nullptr);
EXPECT_TRUE(errors::IsUnknown(status));
EXPECT_TRUE(FailingFunctionPass::ran_);
}
} // namespace tensorflow

View File

@ -0,0 +1,58 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/function_optimization_registry.h"
#include <memory>
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/config.pb.h"
namespace tensorflow {
class PassingFunctionPass : public FunctionOptimizationPass {
public:
static bool ran_;
Status Run(const DeviceSet& device_set, const ConfigProto& config_proto,
std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def,
std::vector<std::string>* control_ret_node_names,
bool* control_rets_updated) override {
ran_ = true;
return Status::OK();
}
};
bool PassingFunctionPass::ran_ = false;
TEST(FunctionOptimizationPassRegistry, PassNoError) {
EXPECT_FALSE(PassingFunctionPass::ran_);
FunctionOptimizationPassRegistry::Global().Init(
std::make_unique<PassingFunctionPass>());
DeviceSet device_set;
ConfigProto config_proto;
Status status = FunctionOptimizationPassRegistry::Global().Run(
device_set, config_proto, /*graph=*/nullptr, /*flib_def=*/nullptr,
/*control_ret_node_names=*/nullptr, /*control_rets_updated=*/nullptr);
EXPECT_EQ(status, Status::OK());
EXPECT_TRUE(PassingFunctionPass::ran_);
}
} // namespace tensorflow

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "absl/strings/str_join.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/function_optimization_registry.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/common_runtime/partitioning_utils.h"
#include "tensorflow/core/common_runtime/placer.h"
@ -672,6 +673,26 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
function_name, function_key, ret_node_names.size(),
lib_def->ReachableDefinitions(*fdef), std::move(ret_types));
// Mapping from a function body node name to the control output name.
std::unordered_map<string, string> node_name_to_control_ret;
bool control_rets_updated = false;
TF_RETURN_IF_ERROR(FunctionOptimizationPassRegistry::Global().Run(
device_set_, options.config_proto, &graph, &data->lib_def_,
&control_ret_node_names, &control_rets_updated));
if (control_rets_updated) {
// Function graph pass may have resulted in different nodes/node names for
// control rets.
for (const auto& control_ret : control_ret_node_names) {
node_name_to_control_ret.emplace(control_ret, control_ret);
}
} else {
for (const auto& control_ret : fdef->control_ret()) {
node_name_to_control_ret.emplace(control_ret.second, control_ret.first);
}
}
GraphOptimizationPassOptions optimization_options;
// TODO(iga): Thread other relevant options from SessionOptions.
SessionOptions session_options;
@ -768,12 +789,6 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
}
}
// Mapping from a function body node name to the control output name.
std::unordered_map<string, string> node_name_to_control_ret;
for (const auto& control_ret : fdef->control_ret()) {
node_name_to_control_ret.emplace(control_ret.second, control_ret.first);
}
// We must preserve control returns in each of the function components,
// otherwise after function inlining we might prune side-effectful nodes.
const auto control_ret =