diff --git a/tensorflow/core/common_runtime/BUILD b/tensorflow/core/common_runtime/BUILD index 8a4f80ba3bc..c6b0088d8d0 100644 --- a/tensorflow/core/common_runtime/BUILD +++ b/tensorflow/core/common_runtime/BUILD @@ -157,6 +157,7 @@ filegroup( "eval_const_tensor.h", "function.h", "function_body.h", + "function_def_utils.h", "function_utils.h", "graph_constructor.h", "graph_def_builder_util.h", @@ -399,6 +400,50 @@ cc_library( ], ) +cc_library( + name = "colocation_graph", + srcs = [ + "colocation_graph.cc", + "inspecting_placer.cc", + ], + hdrs = [ + "colocation_graph.h", + "inspecting_placer.h", + ], + copts = tf_copts(), + deps = [ + ":composite_device", + ":device", + ":device_set", + ":function_body", + ":function_def_utils", + ":input_colocation_exemption_registry", + ":partitioning_utils", + ":placer_inspection_required_ops_utils", + "//tensorflow/core:framework", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + ], +) + +cc_library( + name = "composite_device", + srcs = ["composite_device.cc"], + hdrs = ["composite_device.h"], + copts = tf_copts(), + deps = [ + ":device", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + ], +) + cc_library( name = "constant_folding", srcs = ["constant_folding.cc"], @@ -574,6 +619,50 @@ cc_library( deps = ["//tensorflow/core:framework"], ) +cc_library( + name = "function", + srcs = [ + "function.cc", + "process_function_library_runtime.cc", + ], + hdrs = [ + "function.h", + "process_function_library_runtime.h", + ], + copts = tf_copts(), + deps = [ + ":device", + ":device_mgr", + ":device_set", + ":executor", + ":executor_factory", + ":function_body", + ":function_def_utils", + ":function_optimization_registry", + ":function_utils", + ":gradients", + ":graph_constructor", + ":graph_optimizer", + ":inline_function_utils", + ":memory_types", + ":optimization_registry", + ":partitioning_utils", + ":placer", + ":process_util", + ":rendezvous_mgr", + ":rendezvous_util", + "//tensorflow/core:framework", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/profiler/lib:traceme", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + ], +) + cc_library( name = "function_body", srcs = ["function_body.cc"], @@ -599,10 +688,25 @@ cc_library( ], ) +cc_library( + name = "function_def_utils", + srcs = ["function_def_utils.cc"], + hdrs = ["function_def_utils.h"], + copts = tf_copts(), + deps = [ + ":function_body", + ":graph_constructor", + "//tensorflow/core:framework", + "//tensorflow/core:graph", + "//tensorflow/core:protos_all_cc", + ], +) + cc_library( name = "function_utils", srcs = ["function_utils.cc"], hdrs = ["function_utils.h"], + copts = tf_copts(), deps = [ ":function_body", "//tensorflow/core:framework", @@ -777,6 +881,21 @@ cc_library( ], ) +cc_library( + name = "isolate_placer_inspection_required_ops_pass", + srcs = ["isolate_placer_inspection_required_ops_pass.cc"], + hdrs = ["isolate_placer_inspection_required_ops_pass.h"], + copts = tf_copts(), + deps = [ + ":optimization_registry", + ":placer_inspection_required_ops_utils", + "//tensorflow/core:framework", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + ], + alwayslink = 1, +) + cc_library( name = "local_device", srcs = ["local_device.cc"], @@ -803,6 +922,81 @@ cc_library( ], ) +cc_library( + name = "lower_case_op", + srcs = ["lower_case_op.cc"], + hdrs = ["lower_case_op.h"], + deps = [ + ":inline_function_utils", + "//tensorflow/core:framework", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "lower_function_call_op", + srcs = ["lower_function_call_op.cc"], + hdrs = ["lower_function_call_op.h"], + copts = tf_copts(), + deps = [ + ":function_def_utils", + ":inline_function_utils", + "//tensorflow/core:framework", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + ], +) + +cc_library( + name = "lower_functional_ops", + srcs = ["lower_functional_ops.cc"], + hdrs = ["lower_functional_ops.h"], + copts = tf_copts(), + deps = [ + ":function_utils", + ":inline_function_utils", + ":lower_case_op", + ":lower_function_call_op", + ":lower_if_op", + ":lower_while_op", + ":optimization_registry", + ":session_options", + "//tensorflow/core:framework", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + ], + alwayslink = 1, +) + +cc_library( + name = "lower_if_op", + srcs = ["lower_if_op.cc"], + hdrs = ["lower_if_op.h"], + copts = tf_copts(), + deps = [ + ":inline_function_utils", + "//tensorflow/core:framework", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "lower_while_op", + srcs = ["lower_while_op.cc"], + hdrs = ["lower_while_op.h"], + copts = tf_copts(), + deps = [ + ":inline_function_utils", + "//tensorflow/core:framework", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], +) + cc_library( name = "memory_types", srcs = ["memory_types.cc"], @@ -918,6 +1112,37 @@ cc_library( ], ) +cc_library( + name = "placer", + srcs = ["placer.cc"], + hdrs = ["placer.h"], + copts = tf_copts(), + deps = [ + ":colocation_graph", + ":device", + ":device_set", + ":session_options", + "//tensorflow/core:framework", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], +) + +cc_library( + name = "placer_inspection_required_ops_utils", + srcs = ["placer_inspection_required_ops_utils.cc"], + hdrs = ["placer_inspection_required_ops_utils.h"], + copts = tf_copts(), + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + ], +) + cc_library( name = "process_state", srcs = ["process_state.cc"], @@ -1289,69 +1514,6 @@ cc_library( alwayslink = 1, ) -tf_cuda_library( - name = "core_cpu_rump_impl", - srcs = [ - "colocation_graph.cc", - "composite_device.cc", - "function.cc", - "inspecting_placer.cc", - "isolate_placer_inspection_required_ops_pass.cc", - "lower_case_op.cc", - "lower_function_call_op.cc", - "lower_functional_ops.cc", - "lower_if_op.cc", - "lower_while_op.cc", - "placer.cc", - "placer_inspection_required_ops_utils.cc", - "placer_inspection_required_ops_utils.h", - "process_function_library_runtime.cc", - ], - hdrs = [":core_cpu_lib_headers"], - copts = tf_copts(), - deps = [ - ":device", - ":entry", - ":executor", - ":executor_factory", - ":function_body", - ":function_optimization_registry", - ":gradients", - ":graph_constructor", - ":graph_optimizer", - ":graph_view", - ":local_executor_params", - ":immutable_executor_state", - ":inline_function_utils", - ":input_colocation_exemption_registry", - ":pending_counts", - ":propagator_debug_utils", - ":propagator_state", - ":session_options", - ":simple_propagator_state", - ":single_threaded_cpu_device", - "//tensorflow/core:graph", - "//tensorflow/core:framework", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:protos_all_cc", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:optional", - "@com_google_absl//absl/types:variant", - "//tensorflow/core/public:version", - "//tensorflow/core/grappler/utils:functions", - "//tensorflow/core/profiler/lib:annotated_traceme", - "//tensorflow/core/profiler/lib:scoped_annotation", - "//tensorflow/core/profiler/lib:traceme", - ] + mkl_deps(), - alwayslink = 1, -) - tf_cuda_library( name = "core_cpu_impl", hdrs = [":core_cpu_lib_headers"], @@ -1367,7 +1529,6 @@ tf_cuda_library( ":collective_rma_local", ":collective_util", ":copy_tensor", - ":core_cpu_rump_impl", ":costmodel_manager", ":debugger_state_interface", ":device", @@ -1376,11 +1537,14 @@ tf_cuda_library( ":device_resolver_local", ":device_set", ":entry", + ":function", ":graph_def_builder_util", ":graph_view", ":hierarchical_tree_broadcaster", ":input_colocation_exemption_registry", + ":isolate_placer_inspection_required_ops_pass", ":local_device", + ":lower_functional_ops", ":memory_types", ":mkl_cpu_allocator", ":mkl_layout_pass", @@ -1389,6 +1553,7 @@ tf_cuda_library( ":parallel_concat_optimizer", ":partitioning_utils", ":pending_counts", + ":placer", ":pool_allocator", ":process_state", ":process_util", diff --git a/tensorflow/core/common_runtime/colocation_graph.cc b/tensorflow/core/common_runtime/colocation_graph.cc index f731edcb2c4..0a13f973106 100644 --- a/tensorflow/core/common_runtime/colocation_graph.cc +++ b/tensorflow/core/common_runtime/colocation_graph.cc @@ -29,7 +29,6 @@ limitations under the License. #include "tensorflow/core/common_runtime/composite_device.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_set.h" -#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h" #include "tensorflow/core/common_runtime/inspecting_placer.h" #include "tensorflow/core/common_runtime/partitioning_utils.h" diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index 167ff9b97e3..0df10490ef1 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/graph_optimizer.h" #include "tensorflow/core/common_runtime/inline_function_utils.h" #include "tensorflow/core/common_runtime/memory_types.h" +#include "tensorflow/core/common_runtime/process_function_library_runtime.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" #include "tensorflow/core/framework/collective.h" #include "tensorflow/core/framework/function.h" @@ -1395,38 +1396,4 @@ std::unique_ptr SymbolicGradient(const FunctionBody& f) { return SymbolicGradientHelper(f).Compute(); } -Status FunctionDefToBodyHelper( - const FunctionDef& fdef, const AttrSlice& attrs, - const FunctionLibraryDefinition* const lib_def, - const std::function& get_func_sig, - std::unique_ptr* fbody) { - // Instantiates the function template into a graph def. - InstantiationResult result; - TF_RETURN_IF_ERROR(InstantiateFunction(fdef, attrs, get_func_sig, &result)); - - std::unique_ptr graph(new Graph(lib_def)); - GraphConstructorOptions opts; - opts.allow_internal_ops = true; - opts.expect_device_spec = false; - TF_RETURN_IF_ERROR(ConvertNodeDefsToGraph(opts, result.nodes, graph.get())); - - // Call BuildControlFlowInfo to validate that this function body has - // well-formed control flow. - std::vector dummy; - TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph.get(), &dummy)); - - *fbody = absl::make_unique(fdef, result.arg_types, - result.ret_types, graph.release()); - return Status::OK(); -} - -Status FunctionDefToBodyHelper(const FunctionDef& fdef, const AttrSlice& attrs, - const FunctionLibraryDefinition* lib_def, - std::unique_ptr* fbody) { - const auto get_func_sig = [&lib_def](const string& op, const OpDef** sig) { - return lib_def->LookUpOpDef(op, sig); - }; - return FunctionDefToBodyHelper(fdef, attrs, lib_def, get_func_sig, fbody); -} - } // end namespace tensorflow diff --git a/tensorflow/core/common_runtime/function.h b/tensorflow/core/common_runtime/function.h index af4b333aca0..e75e5c82a40 100644 --- a/tensorflow/core/common_runtime/function.h +++ b/tensorflow/core/common_runtime/function.h @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/function_body.h" +#include "tensorflow/core/common_runtime/function_def_utils.h" #include "tensorflow/core/common_runtime/function_utils.h" #include "tensorflow/core/common_runtime/graph_optimizer.h" #include "tensorflow/core/common_runtime/inline_function_utils.h" @@ -80,26 +81,6 @@ std::unique_ptr NewFunctionLibraryRuntime( // TODO(zhifengc): Asks math expert to say the comment again. std::unique_ptr SymbolicGradient(const FunctionBody& f); -// Returns true iff `n` represents a function call. `n` can be a native -// function call (n.type_string() is the function name), -// a PartitionedCall/StatefulPartitionedCall, or a SymbolicGradient (which -// has been deprecated for a while). -bool IsFunctionCall(const FunctionLibraryDefinition& lib_def, const Node& n); - -// Instantiates FunctionDef into a graph. Set *fbody to point to the -// FunctionBody that holds the instantiated FunctionDef. -Status FunctionDefToBodyHelper(const FunctionDef& fdef, const AttrSlice& attrs, - const FunctionLibraryDefinition* lib_def, - std::unique_ptr* fbody); - -// Instantiates FunctionDef into a graph. Set *fbody to point to the -// FunctionBody that holds the instantiated FunctionDef. Use custom function -// signature lookup, in case instantiated function is not in the 'lib_def'. -Status FunctionDefToBodyHelper( - const FunctionDef& fdef, const AttrSlice& attrs, - const FunctionLibraryDefinition* lib_def, - const std::function& get_func_sig, - std::unique_ptr* fbody); } // end namespace tensorflow #endif // TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_H_ diff --git a/tensorflow/core/common_runtime/function_def_utils.cc b/tensorflow/core/common_runtime/function_def_utils.cc new file mode 100644 index 00000000000..b880a5488f9 --- /dev/null +++ b/tensorflow/core/common_runtime/function_def_utils.cc @@ -0,0 +1,63 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/common_runtime/function_def_utils.h" + +#include + +#include "tensorflow/core/common_runtime/function_body.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/graph/control_flow.h" +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { + +Status FunctionDefToBodyHelper( + const FunctionDef& fdef, const AttrSlice& attrs, + const FunctionLibraryDefinition* const lib_def, + const std::function& get_func_sig, + std::unique_ptr* fbody) { + // Instantiates the function template into a graph def. + InstantiationResult result; + TF_RETURN_IF_ERROR(InstantiateFunction(fdef, attrs, get_func_sig, &result)); + + std::unique_ptr graph(new Graph(lib_def)); + GraphConstructorOptions opts; + opts.allow_internal_ops = true; + opts.expect_device_spec = false; + TF_RETURN_IF_ERROR(ConvertNodeDefsToGraph(opts, result.nodes, graph.get())); + + // Call BuildControlFlowInfo to validate that this function body has + // well-formed control flow. + std::vector dummy; + TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph.get(), &dummy)); + + *fbody = absl::make_unique(fdef, result.arg_types, + result.ret_types, graph.release()); + return Status::OK(); +} + +Status FunctionDefToBodyHelper(const FunctionDef& fdef, const AttrSlice& attrs, + const FunctionLibraryDefinition* lib_def, + std::unique_ptr* fbody) { + const auto get_func_sig = [&lib_def](const string& op, const OpDef** sig) { + return lib_def->LookUpOpDef(op, sig); + }; + return FunctionDefToBodyHelper(fdef, attrs, lib_def, get_func_sig, fbody); +} + +} // end namespace tensorflow diff --git a/tensorflow/core/common_runtime/function_def_utils.h b/tensorflow/core/common_runtime/function_def_utils.h new file mode 100644 index 00000000000..f269cc6a608 --- /dev/null +++ b/tensorflow/core/common_runtime/function_def_utils.h @@ -0,0 +1,49 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_DEF_UTILS_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_DEF_UTILS_H_ + +#include +#include + +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +class AttrSlice; +struct FunctionBody; +class FunctionDef; +class FunctionLibraryDefinition; +class OpDef; + +// Instantiates FunctionDef into a graph. Set *fbody to point to the +// FunctionBody that holds the instantiated FunctionDef. +Status FunctionDefToBodyHelper(const FunctionDef& fdef, const AttrSlice& attrs, + const FunctionLibraryDefinition* lib_def, + std::unique_ptr* fbody); + +// Instantiates FunctionDef into a graph. Set *fbody to point to the +// FunctionBody that holds the instantiated FunctionDef. Use custom function +// signature lookup, in case instantiated function is not in the 'lib_def'. +Status FunctionDefToBodyHelper( + const FunctionDef& fdef, const AttrSlice& attrs, + const FunctionLibraryDefinition* lib_def, + const std::function& get_func_sig, + std::unique_ptr* fbody); + +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_DEF_UTILS_H_ diff --git a/tensorflow/core/common_runtime/inline_function_utils.cc b/tensorflow/core/common_runtime/inline_function_utils.cc index 0dba49a0510..a074942629d 100644 --- a/tensorflow/core/common_runtime/inline_function_utils.cc +++ b/tensorflow/core/common_runtime/inline_function_utils.cc @@ -43,6 +43,11 @@ limitations under the License. namespace tensorflow { +/*static*/ constexpr const char* const + LowerFunctionalOpsConstants::kLowerUsingSwitchMergeAttr; +/*static*/ constexpr const char* const + LowerFunctionalOpsConstants::kLowerAsMultiDeviceFunctionAttr; + namespace { // A few string constant used throughout this module. static constexpr const char* const kArgOp = FunctionLibraryDefinition::kArgOp; diff --git a/tensorflow/core/common_runtime/inline_function_utils.h b/tensorflow/core/common_runtime/inline_function_utils.h index bc873a3fc60..1469885ccda 100644 --- a/tensorflow/core/common_runtime/inline_function_utils.h +++ b/tensorflow/core/common_runtime/inline_function_utils.h @@ -231,6 +231,13 @@ inline bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph) { return ExpandInlineFunctions(lib, graph, ExpandInlineFunctionsOptions()); } +struct LowerFunctionalOpsConstants { + static constexpr const char* const kLowerUsingSwitchMergeAttr = + "_lower_using_switch_merge"; + static constexpr const char* const kLowerAsMultiDeviceFunctionAttr = + "_lower_as_multi_device_function"; +}; + } // end namespace tensorflow #endif // TENSORFLOW_CORE_COMMON_RUNTIME_INLINE_FUNCTION_UTILS_H_ diff --git a/tensorflow/core/common_runtime/inspecting_placer.cc b/tensorflow/core/common_runtime/inspecting_placer.cc index 2dd4eaff303..468ab37bcdb 100644 --- a/tensorflow/core/common_runtime/inspecting_placer.cc +++ b/tensorflow/core/common_runtime/inspecting_placer.cc @@ -21,7 +21,8 @@ limitations under the License. #include "absl/strings/str_join.h" #include "tensorflow/core/common_runtime/colocation_graph.h" #include "tensorflow/core/common_runtime/device.h" -#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/function_body.h" +#include "tensorflow/core/common_runtime/function_def_utils.h" #include "tensorflow/core/common_runtime/placer_inspection_required_ops_utils.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/node_def_util.h" diff --git a/tensorflow/core/common_runtime/lower_case_op.cc b/tensorflow/core/common_runtime/lower_case_op.cc index 24ca8a94b85..a13c55d5aa5 100644 --- a/tensorflow/core/common_runtime/lower_case_op.cc +++ b/tensorflow/core/common_runtime/lower_case_op.cc @@ -15,8 +15,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/lower_case_op.h" -#include "tensorflow/core/common_runtime/function.h" -#include "tensorflow/core/common_runtime/lower_functional_ops.h" +#include "tensorflow/core/common_runtime/inline_function_utils.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/node_builder.h" @@ -29,7 +28,7 @@ namespace { using NodeOut = NodeBuilder::NodeOut; constexpr const char* const kLowerAsMultiDeviceFunctionAttr = - LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr; + LowerFunctionalOpsConstants::kLowerAsMultiDeviceFunctionAttr; // Convenience builder to make it easy to construct a case with a single // function call in each branch. This first converts the Case node diff --git a/tensorflow/core/common_runtime/lower_case_op.h b/tensorflow/core/common_runtime/lower_case_op.h index 9148f43c6c1..110ac20a929 100644 --- a/tensorflow/core/common_runtime/lower_case_op.h +++ b/tensorflow/core/common_runtime/lower_case_op.h @@ -16,11 +16,13 @@ limitations under the License. #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_CASE_OP_H_ #define TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_CASE_OP_H_ -#include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/lib/core/status.h" namespace tensorflow { +class Graph; +class Node; + // Replaces Case node `n` with a lowered form that uses _SwitchN/Merge nodes. Status RewriteCaseNode(Node* n, Graph* g, bool keep_node_fetchable); diff --git a/tensorflow/core/common_runtime/lower_function_call_op.cc b/tensorflow/core/common_runtime/lower_function_call_op.cc index b1b657c9c22..ed72d6d720b 100644 --- a/tensorflow/core/common_runtime/lower_function_call_op.cc +++ b/tensorflow/core/common_runtime/lower_function_call_op.cc @@ -16,9 +16,8 @@ limitations under the License. #include "tensorflow/core/common_runtime/lower_function_call_op.h" #include "absl/algorithm/container.h" -#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/function_def_utils.h" #include "tensorflow/core/common_runtime/inline_function_utils.h" -#include "tensorflow/core/common_runtime/lower_functional_ops.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_node_util.h" @@ -30,7 +29,7 @@ using KeepCallerNode = InlineFunctionBodyOptions::KeepCallerNode; using OutputControlSrc = InlineFunctionBodyOptions::OutputControlSource; constexpr const char* const kLowerAsMultiDeviceFunctionAttr = - LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr; + LowerFunctionalOpsConstants::kLowerAsMultiDeviceFunctionAttr; bool LowerAsMultiDeviceFunction(const Node* n) { if (n->IsPartitionedCall()) return true; diff --git a/tensorflow/core/common_runtime/lower_function_call_op.h b/tensorflow/core/common_runtime/lower_function_call_op.h index 6a418a92822..89ce6b28220 100644 --- a/tensorflow/core/common_runtime/lower_function_call_op.h +++ b/tensorflow/core/common_runtime/lower_function_call_op.h @@ -16,11 +16,14 @@ limitations under the License. #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_FUNCTION_CALL_OP_H_ #define TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_FUNCTION_CALL_OP_H_ -#include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/lib/core/status.h" namespace tensorflow { +class FunctionLibraryDefinition; +class Graph; +class Node; + // Replaces function call node `n` with its function body. Uses // InlineFunctionBody from `common_runtime/function.{h,cc}`. If function // inlining is not possible or safe (see ValidateInlining), leaves the graph in diff --git a/tensorflow/core/common_runtime/lower_functional_ops.cc b/tensorflow/core/common_runtime/lower_functional_ops.cc index 8c99fce17d5..7bfc36c14fc 100644 --- a/tensorflow/core/common_runtime/lower_functional_ops.cc +++ b/tensorflow/core/common_runtime/lower_functional_ops.cc @@ -15,7 +15,8 @@ limitations under the License. #include "tensorflow/core/common_runtime/lower_functional_ops.h" -#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/function_utils.h" +#include "tensorflow/core/common_runtime/inline_function_utils.h" #include "tensorflow/core/common_runtime/lower_case_op.h" #include "tensorflow/core/common_runtime/lower_function_call_op.h" #include "tensorflow/core/common_runtime/lower_if_op.h" @@ -27,17 +28,12 @@ limitations under the License. namespace tensorflow { -/*static*/ constexpr const char* const - LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr; -/*static*/ constexpr const char* const - LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr; - namespace { constexpr const char* const kLowerUsingSwitchMergeAttr = - LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr; + LowerFunctionalOpsConstants::kLowerUsingSwitchMergeAttr; constexpr const char* const kLowerAsMultiDeviceFunctionAttr = - LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr; + LowerFunctionalOpsConstants::kLowerAsMultiDeviceFunctionAttr; constexpr const char* const kTpuReplicateAttr = "_tpu_replicate"; constexpr const char* const kXlaClusterAttr = "_xla_compile_id"; @@ -173,7 +169,7 @@ Status LowerFunctionalOpsPass::Run( DCHECK(!lower_control_flow(n)) << "Node " << FormatNodeForError(*n) << " of type " << n->type_string() << " has '" - << LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr + << LowerFunctionalOpsConstants::kLowerUsingSwitchMergeAttr << "' attr set but it does not support lowering.\n"; } } diff --git a/tensorflow/core/common_runtime/lower_functional_ops.h b/tensorflow/core/common_runtime/lower_functional_ops.h index 84d15a11572..32b6a450f1c 100644 --- a/tensorflow/core/common_runtime/lower_functional_ops.h +++ b/tensorflow/core/common_runtime/lower_functional_ops.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_FUNCTIONAL_OPS_H_ #include "absl/types/optional.h" +#include "tensorflow/core/common_runtime/inline_function_utils.h" #include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/lib/core/status.h" @@ -38,9 +39,9 @@ class LowerFunctionalOpsPass : public GraphOptimizationPass { Status Run(const GraphOptimizationPassOptions& options) override; static constexpr const char* const kLowerUsingSwitchMergeAttr = - "_lower_using_switch_merge"; + LowerFunctionalOpsConstants::kLowerUsingSwitchMergeAttr; static constexpr const char* const kLowerAsMultiDeviceFunctionAttr = - "_lower_as_multi_device_function"; + LowerFunctionalOpsConstants::kLowerAsMultiDeviceFunctionAttr; private: // If defined use the value to control if functional ops must be fetchable diff --git a/tensorflow/core/common_runtime/lower_if_op.cc b/tensorflow/core/common_runtime/lower_if_op.cc index 9b1d2b8e270..5cde4f9049c 100644 --- a/tensorflow/core/common_runtime/lower_if_op.cc +++ b/tensorflow/core/common_runtime/lower_if_op.cc @@ -15,8 +15,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/lower_if_op.h" -#include "tensorflow/core/common_runtime/function.h" -#include "tensorflow/core/common_runtime/lower_functional_ops.h" +#include "tensorflow/core/common_runtime/inline_function_utils.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/node_builder.h" @@ -27,7 +26,7 @@ namespace { using NodeOut = NodeBuilder::NodeOut; constexpr const char* const kLowerAsMultiDeviceFunctionAttr = - LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr; + LowerFunctionalOpsConstants::kLowerAsMultiDeviceFunctionAttr; // Convenience builder to make it easy to construct a conditional with a single // function call in the then and else branch. This first converts the if node diff --git a/tensorflow/core/common_runtime/lower_if_op.h b/tensorflow/core/common_runtime/lower_if_op.h index cfaf15e71f1..55b7b91b56f 100644 --- a/tensorflow/core/common_runtime/lower_if_op.h +++ b/tensorflow/core/common_runtime/lower_if_op.h @@ -16,11 +16,13 @@ limitations under the License. #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_IF_OP_H_ #define TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_IF_OP_H_ -#include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/lib/core/status.h" namespace tensorflow { +class Graph; +class Node; + // Replaces If node `n` with its lowered form that uses Switch and Merge nodes. Status RewriteIfNode(Node* n, Graph* g, bool keep_node_fetchable); diff --git a/tensorflow/core/common_runtime/lower_while_op.cc b/tensorflow/core/common_runtime/lower_while_op.cc index 1f8cbe374bb..e9d322721f2 100644 --- a/tensorflow/core/common_runtime/lower_while_op.cc +++ b/tensorflow/core/common_runtime/lower_while_op.cc @@ -15,9 +15,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/lower_while_op.h" -#include "tensorflow/core/common_runtime/function.h" -#include "tensorflow/core/common_runtime/lower_functional_ops.h" -#include "tensorflow/core/common_runtime/lower_if_op.h" +#include "tensorflow/core/common_runtime/inline_function_utils.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/graph/graph.h" @@ -30,7 +28,7 @@ namespace { using NodeOut = NodeBuilder::NodeOut; constexpr const char* const kLowerAsMultiDeviceFunctionAttr = - LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr; + LowerFunctionalOpsConstants::kLowerAsMultiDeviceFunctionAttr; // Helper to convert a functional While op to its lowered form. // diff --git a/tensorflow/core/common_runtime/lower_while_op.h b/tensorflow/core/common_runtime/lower_while_op.h index 9f016c45892..1dd22389ec4 100644 --- a/tensorflow/core/common_runtime/lower_while_op.h +++ b/tensorflow/core/common_runtime/lower_while_op.h @@ -16,11 +16,13 @@ limitations under the License. #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_WHILE_OP_H_ #define TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_WHILE_OP_H_ -#include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/lib/core/status.h" namespace tensorflow { +class Graph; +class Node; + // Replaces While node `n` with its lowered form that uses Enter, Exit, Switch, // Merge, NextIteration and LoopCond nodes. Status RewriteWhileNode(Node* n, Graph* g, bool keep_node_fetchable); diff --git a/tensorflow/core/common_runtime/placer_inspection_required_ops_utils.cc b/tensorflow/core/common_runtime/placer_inspection_required_ops_utils.cc index 3363930c882..75d150834cc 100644 --- a/tensorflow/core/common_runtime/placer_inspection_required_ops_utils.cc +++ b/tensorflow/core/common_runtime/placer_inspection_required_ops_utils.cc @@ -19,7 +19,6 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/types/optional.h" -#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/graph/graph.h"