diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 7a21749faa7..07b89f6fd93 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -974,6 +974,7 @@ tf_cuda_library( "common_runtime/device.h", "common_runtime/device_factory.h", "common_runtime/function.h", + "common_runtime/function_optimization_registry.h", "common_runtime/optimization_registry.h", "common_runtime/shape_refiner.h", "//tensorflow/core/graph:core_cpu_headers", @@ -2471,6 +2472,7 @@ filegroup( "common_runtime/dma_helper.h", "common_runtime/executor.h", "common_runtime/executor_factory.h", + "common_runtime/function_optimization_registry.h", "common_runtime/graph_optimizer.h", "common_runtime/input_colocation_exemption_registry.h", "common_runtime/isolate_placer_inspection_required_ops_pass.h", @@ -2531,6 +2533,7 @@ tf_cuda_library( "common_runtime/executor.cc", "common_runtime/executor_factory.cc", "common_runtime/function.cc", + "common_runtime/function_optimization_registry.cc", "common_runtime/graph_optimizer.cc", "common_runtime/graph_runner.cc", "common_runtime/hierarchical_tree_broadcaster.cc", @@ -3147,6 +3150,10 @@ tf_cc_tests( "common_runtime/device_resolver_local_test.cc", "common_runtime/device_set_test.cc", "common_runtime/dynamic_device_mgr_test.cc", + "common_runtime/function_optimization_registration_test.cc", + "common_runtime/function_optimization_registry_no_pass_test.cc", + "common_runtime/function_optimization_registry_pass_failure_test.cc", + "common_runtime/function_optimization_registry_test.cc", "common_runtime/isolate_placer_inspection_required_ops_pass_test.cc", "common_runtime/optimization_registry_test.cc", "common_runtime/pending_counts_test.cc", diff --git a/tensorflow/core/common_runtime/function_optimization_registration_test.cc b/tensorflow/core/common_runtime/function_optimization_registration_test.cc new file mode 100644 index 00000000000..51381622231 --- /dev/null +++ b/tensorflow/core/common_runtime/function_optimization_registration_test.cc @@ -0,0 +1,57 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/core/common_runtime/device_set.h" +#include "tensorflow/core/common_runtime/function_optimization_registry.h" +#include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/protobuf/config.pb.h" + +namespace tensorflow { + +class TestFunctionPass : public FunctionOptimizationPass { + public: + static bool ran_; + + Status Run(const DeviceSet& device_set, const ConfigProto& config_proto, + std::unique_ptr* graph, FunctionLibraryDefinition* flib_def, + std::vector* control_ret_node_names, + bool* control_rets_updated) override { + ran_ = true; + return Status::OK(); + } +}; + +bool TestFunctionPass::ran_ = false; + +static function_optimization_registration::FunctionOptimizationPassRegistration + register_test_pass(std::make_unique()); + +TEST(FunctionOptimizationPassRegistry, RegisteredPass) { + EXPECT_FALSE(TestFunctionPass::ran_); + + DeviceSet device_set; + ConfigProto config_proto; + Status status = FunctionOptimizationPassRegistry::Global().Run( + device_set, config_proto, /*graph=*/nullptr, /*flib_def=*/nullptr, + /*control_ret_node_names=*/nullptr, /*control_rets_updated=*/nullptr); + + EXPECT_EQ(status, Status::OK()); + EXPECT_TRUE(TestFunctionPass::ran_); +} + +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/function_optimization_registry.cc b/tensorflow/core/common_runtime/function_optimization_registry.cc new file mode 100644 index 00000000000..8d622407c78 --- /dev/null +++ b/tensorflow/core/common_runtime/function_optimization_registry.cc @@ -0,0 +1,46 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/common_runtime/function_optimization_registry.h" + +namespace tensorflow { + +void FunctionOptimizationPassRegistry::Init( + std::unique_ptr pass) { + DCHECK(!pass_) << "Only one pass should be set."; + pass_ = std::move(pass); +} + +Status FunctionOptimizationPassRegistry::Run( + const DeviceSet& device_set, const ConfigProto& config_proto, + std::unique_ptr* graph, FunctionLibraryDefinition* flib_def, + std::vector* control_ret_node_names, + bool* control_rets_updated) { + if (pass_) + TF_RETURN_IF_ERROR(pass_->Run(device_set, config_proto, graph, flib_def, + control_ret_node_names, + control_rets_updated)); + + return Status::OK(); +} + +// static +FunctionOptimizationPassRegistry& FunctionOptimizationPassRegistry::Global() { + static FunctionOptimizationPassRegistry* kGlobalRegistry = + new FunctionOptimizationPassRegistry; + return *kGlobalRegistry; +} + +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/function_optimization_registry.h b/tensorflow/core/common_runtime/function_optimization_registry.h new file mode 100644 index 00000000000..2f9dcf94e77 --- /dev/null +++ b/tensorflow/core/common_runtime/function_optimization_registry.h @@ -0,0 +1,89 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_OPTIMIZATION_REGISTRY_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_OPTIMIZATION_REGISTRY_H_ + +#include +#include +#include + +#include "tensorflow/core/common_runtime/device_set.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/protobuf/config.pb.h" + +// Classes to maintain a static registry of Graph based passes to be applied to +// a function graph. + +namespace tensorflow { + +// A pass to be registered with the FunctionOptimizationPassRegistry. This pass +// takes in a DeviceSet (available devices for executing the Graph), ConfigProto +// (session configuration parameters), Graph (computation), +// FunctionLibraryDefinition (mapping between function names and function +// definitions of the Graph), control ret/target node names (names of nodes that +// must execute but their data outputs, if they have any, are irrelevant), and +// whether control ret nodes (via thier name) were updated. Mutations to the +// Graph and other associated arguments are performed inplace by the pass. +class FunctionOptimizationPass { + public: + virtual ~FunctionOptimizationPass() {} + virtual Status Run(const DeviceSet& device_set, + const ConfigProto& config_proto, + std::unique_ptr* graph, + FunctionLibraryDefinition* flib_def, + std::vector* control_ret_node_names, + bool* control_rets_updated) = 0; +}; + +// A global function optimization pass registry that is used to hold one +// FunctionOptimizationPass. Passes registered to this registry will run before +// passes registered in OptimizationPassRegistry. +class FunctionOptimizationPassRegistry { + public: + // Initializes registry with a pass. Only one pass should be set. An assertion + // will be triggered if the registry already has a pass set and is being + // initialized with another pass. + void Init(std::unique_ptr pass); + + // Runs a pass if the registry contains one. + Status Run(const DeviceSet& device_set, const ConfigProto& config_proto, + std::unique_ptr* graph, FunctionLibraryDefinition* flib_def, + std::vector* control_ret_node_names, + bool* control_rets_updated); + + // Returns the global registry of function graph passes. + static FunctionOptimizationPassRegistry& Global(); + + private: + std::unique_ptr pass_; +}; + +namespace function_optimization_registration { + +class FunctionOptimizationPassRegistration { + public: + explicit FunctionOptimizationPassRegistration( + std::unique_ptr pass) { + FunctionOptimizationPassRegistry::Global().Init(std::move(pass)); + } +}; + +} // namespace function_optimization_registration + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_OPTIMIZATION_REGISTRY_H_ diff --git a/tensorflow/core/common_runtime/function_optimization_registry_no_pass_test.cc b/tensorflow/core/common_runtime/function_optimization_registry_no_pass_test.cc new file mode 100644 index 00000000000..90f91244d62 --- /dev/null +++ b/tensorflow/core/common_runtime/function_optimization_registry_no_pass_test.cc @@ -0,0 +1,38 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/core/common_runtime/device_set.h" +#include "tensorflow/core/common_runtime/function_optimization_registry.h" +#include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/protobuf/config.pb.h" + +namespace tensorflow { + +TEST(FunctionOptimizationPassRegistry, NoPassSet) { + FunctionOptimizationPassRegistry::Global().Init( + std::unique_ptr()); + DeviceSet device_set; + ConfigProto config_proto; + Status status = FunctionOptimizationPassRegistry::Global().Run( + device_set, config_proto, /*graph=*/nullptr, /*flib_def=*/nullptr, + /*control_ret_node_names=*/nullptr, /*control_rets_updated=*/nullptr); + + EXPECT_EQ(status, Status::OK()); +} + +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/function_optimization_registry_pass_failure_test.cc b/tensorflow/core/common_runtime/function_optimization_registry_pass_failure_test.cc new file mode 100644 index 00000000000..4adf510849f --- /dev/null +++ b/tensorflow/core/common_runtime/function_optimization_registry_pass_failure_test.cc @@ -0,0 +1,57 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/core/common_runtime/device_set.h" +#include "tensorflow/core/common_runtime/function_optimization_registry.h" +#include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/protobuf/config.pb.h" + +namespace tensorflow { + +class FailingFunctionPass : public FunctionOptimizationPass { + public: + static bool ran_; + + Status Run(const DeviceSet& device_set, const ConfigProto& config_proto, + std::unique_ptr* graph, FunctionLibraryDefinition* flib_def, + std::vector* control_ret_node_names, + bool* control_rets_updated) override { + ran_ = true; + return errors::Unknown(""); + } +}; + +bool FailingFunctionPass::ran_ = false; + +TEST(FunctionOptimizationPassRegistry, PassWithError) { + EXPECT_FALSE(FailingFunctionPass::ran_); + + FunctionOptimizationPassRegistry::Global().Init( + std::make_unique()); + DeviceSet device_set; + ConfigProto config_proto; + Status status = FunctionOptimizationPassRegistry::Global().Run( + device_set, config_proto, /*graph=*/nullptr, /*flib_def=*/nullptr, + /*control_ret_node_names=*/nullptr, /*control_rets_updated=*/nullptr); + + EXPECT_TRUE(errors::IsUnknown(status)); + EXPECT_TRUE(FailingFunctionPass::ran_); +} + +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/function_optimization_registry_test.cc b/tensorflow/core/common_runtime/function_optimization_registry_test.cc new file mode 100644 index 00000000000..64ca05a9a3a --- /dev/null +++ b/tensorflow/core/common_runtime/function_optimization_registry_test.cc @@ -0,0 +1,58 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/common_runtime/function_optimization_registry.h" + +#include + +#include "tensorflow/core/common_runtime/device_set.h" +#include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/protobuf/config.pb.h" + +namespace tensorflow { + +class PassingFunctionPass : public FunctionOptimizationPass { + public: + static bool ran_; + + Status Run(const DeviceSet& device_set, const ConfigProto& config_proto, + std::unique_ptr* graph, FunctionLibraryDefinition* flib_def, + std::vector* control_ret_node_names, + bool* control_rets_updated) override { + ran_ = true; + return Status::OK(); + } +}; + +bool PassingFunctionPass::ran_ = false; + +TEST(FunctionOptimizationPassRegistry, PassNoError) { + EXPECT_FALSE(PassingFunctionPass::ran_); + + FunctionOptimizationPassRegistry::Global().Init( + std::make_unique()); + DeviceSet device_set; + ConfigProto config_proto; + Status status = FunctionOptimizationPassRegistry::Global().Run( + device_set, config_proto, /*graph=*/nullptr, /*flib_def=*/nullptr, + /*control_ret_node_names=*/nullptr, /*control_rets_updated=*/nullptr); + + EXPECT_EQ(status, Status::OK()); + EXPECT_TRUE(PassingFunctionPass::ran_); +} + +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc index 10c269f18a5..1b35b1d3ee9 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/strings/str_join.h" #include "tensorflow/core/common_runtime/device_set.h" #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/function_optimization_registry.h" #include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/common_runtime/partitioning_utils.h" #include "tensorflow/core/common_runtime/placer.h" @@ -672,6 +673,26 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice( function_name, function_key, ret_node_names.size(), lib_def->ReachableDefinitions(*fdef), std::move(ret_types)); + // Mapping from a function body node name to the control output name. + std::unordered_map node_name_to_control_ret; + + bool control_rets_updated = false; + TF_RETURN_IF_ERROR(FunctionOptimizationPassRegistry::Global().Run( + device_set_, options.config_proto, &graph, &data->lib_def_, + &control_ret_node_names, &control_rets_updated)); + + if (control_rets_updated) { + // Function graph pass may have resulted in different nodes/node names for + // control rets. + for (const auto& control_ret : control_ret_node_names) { + node_name_to_control_ret.emplace(control_ret, control_ret); + } + } else { + for (const auto& control_ret : fdef->control_ret()) { + node_name_to_control_ret.emplace(control_ret.second, control_ret.first); + } + } + GraphOptimizationPassOptions optimization_options; // TODO(iga): Thread other relevant options from SessionOptions. SessionOptions session_options; @@ -768,12 +789,6 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice( } } - // Mapping from a function body node name to the control output name. - std::unordered_map node_name_to_control_ret; - for (const auto& control_ret : fdef->control_ret()) { - node_name_to_control_ret.emplace(control_ret.second, control_ret.first); - } - // We must preserve control returns in each of the function components, // otherwise after function inlining we might prune side-effectful nodes. const auto control_ret =