[Build cleanup] Split "core_cpu_impl" into fine-grained targets (5/5).

This changes moves the remaining sources in "core_cpu_rump_impl" into fine-grained targets. It leaves two circular dependencies in place:

* ":colocation_graph" includes both "colocation_graph.{h,cc}" and "inspecting_placer.{h,cc}" because `ColocationGraph` has an `InspectingPlacer`, and an `InspectingPlacer` creates a `ColocationGraph`.
* ":function" includes both "function.{h,cc}" and "process_function_library_runtime.{h,cc}" because `ProcessFunctionLibraryRuntime` calls `NewFunctionLibraryRuntime()` and `FunctionLibraryRuntimeImpl` invokes methods on `ProcessFunctionLibraryRuntime`. Several files also depend on the indirect inclusion of "process_function_library_runtime.h" in "function.h".

PiperOrigin-RevId: 308949398
Change-Id: I65e0664556e9ee7659640a1ed23fdb4072ab16c2
This commit is contained in:
Derek Murray 2020-04-28 20:15:25 -07:00 committed by TensorFlower Gardener
parent 4bfe1dce64
commit 913b88bd0d
20 changed files with 386 additions and 149 deletions

View File

@ -157,6 +157,7 @@ filegroup(
"eval_const_tensor.h",
"function.h",
"function_body.h",
"function_def_utils.h",
"function_utils.h",
"graph_constructor.h",
"graph_def_builder_util.h",
@ -399,6 +400,50 @@ cc_library(
],
)
cc_library(
name = "colocation_graph",
srcs = [
"colocation_graph.cc",
"inspecting_placer.cc",
],
hdrs = [
"colocation_graph.h",
"inspecting_placer.h",
],
copts = tf_copts(),
deps = [
":composite_device",
":device",
":device_set",
":function_body",
":function_def_utils",
":input_colocation_exemption_registry",
":partitioning_utils",
":placer_inspection_required_ops_utils",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
],
)
cc_library(
name = "composite_device",
srcs = ["composite_device.cc"],
hdrs = ["composite_device.h"],
copts = tf_copts(),
deps = [
":device",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"@com_google_absl//absl/strings",
],
)
cc_library(
name = "constant_folding",
srcs = ["constant_folding.cc"],
@ -574,6 +619,50 @@ cc_library(
deps = ["//tensorflow/core:framework"],
)
cc_library(
name = "function",
srcs = [
"function.cc",
"process_function_library_runtime.cc",
],
hdrs = [
"function.h",
"process_function_library_runtime.h",
],
copts = tf_copts(),
deps = [
":device",
":device_mgr",
":device_set",
":executor",
":executor_factory",
":function_body",
":function_def_utils",
":function_optimization_registry",
":function_utils",
":gradients",
":graph_constructor",
":graph_optimizer",
":inline_function_utils",
":memory_types",
":optimization_registry",
":partitioning_utils",
":placer",
":process_util",
":rendezvous_mgr",
":rendezvous_util",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/profiler/lib:traceme",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
],
)
cc_library(
name = "function_body",
srcs = ["function_body.cc"],
@ -599,10 +688,25 @@ cc_library(
],
)
cc_library(
name = "function_def_utils",
srcs = ["function_def_utils.cc"],
hdrs = ["function_def_utils.h"],
copts = tf_copts(),
deps = [
":function_body",
":graph_constructor",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:protos_all_cc",
],
)
cc_library(
name = "function_utils",
srcs = ["function_utils.cc"],
hdrs = ["function_utils.h"],
copts = tf_copts(),
deps = [
":function_body",
"//tensorflow/core:framework",
@ -777,6 +881,21 @@ cc_library(
],
)
cc_library(
name = "isolate_placer_inspection_required_ops_pass",
srcs = ["isolate_placer_inspection_required_ops_pass.cc"],
hdrs = ["isolate_placer_inspection_required_ops_pass.h"],
copts = tf_copts(),
deps = [
":optimization_registry",
":placer_inspection_required_ops_utils",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
],
alwayslink = 1,
)
cc_library(
name = "local_device",
srcs = ["local_device.cc"],
@ -803,6 +922,81 @@ cc_library(
],
)
cc_library(
name = "lower_case_op",
srcs = ["lower_case_op.cc"],
hdrs = ["lower_case_op.h"],
deps = [
":inline_function_utils",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
],
)
cc_library(
name = "lower_function_call_op",
srcs = ["lower_function_call_op.cc"],
hdrs = ["lower_function_call_op.h"],
copts = tf_copts(),
deps = [
":function_def_utils",
":inline_function_utils",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"@com_google_absl//absl/algorithm:container",
],
)
cc_library(
name = "lower_functional_ops",
srcs = ["lower_functional_ops.cc"],
hdrs = ["lower_functional_ops.h"],
copts = tf_copts(),
deps = [
":function_utils",
":inline_function_utils",
":lower_case_op",
":lower_function_call_op",
":lower_if_op",
":lower_while_op",
":optimization_registry",
":session_options",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
],
alwayslink = 1,
)
cc_library(
name = "lower_if_op",
srcs = ["lower_if_op.cc"],
hdrs = ["lower_if_op.h"],
copts = tf_copts(),
deps = [
":inline_function_utils",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
],
)
cc_library(
name = "lower_while_op",
srcs = ["lower_while_op.cc"],
hdrs = ["lower_while_op.h"],
copts = tf_copts(),
deps = [
":inline_function_utils",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
)
cc_library(
name = "memory_types",
srcs = ["memory_types.cc"],
@ -918,6 +1112,37 @@ cc_library(
],
)
cc_library(
name = "placer",
srcs = ["placer.cc"],
hdrs = ["placer.h"],
copts = tf_copts(),
deps = [
":colocation_graph",
":device",
":device_set",
":session_options",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
)
cc_library(
name = "placer_inspection_required_ops_utils",
srcs = ["placer_inspection_required_ops_utils.cc"],
hdrs = ["placer_inspection_required_ops_utils.h"],
copts = tf_copts(),
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
],
)
cc_library(
name = "process_state",
srcs = ["process_state.cc"],
@ -1289,69 +1514,6 @@ cc_library(
alwayslink = 1,
)
tf_cuda_library(
name = "core_cpu_rump_impl",
srcs = [
"colocation_graph.cc",
"composite_device.cc",
"function.cc",
"inspecting_placer.cc",
"isolate_placer_inspection_required_ops_pass.cc",
"lower_case_op.cc",
"lower_function_call_op.cc",
"lower_functional_ops.cc",
"lower_if_op.cc",
"lower_while_op.cc",
"placer.cc",
"placer_inspection_required_ops_utils.cc",
"placer_inspection_required_ops_utils.h",
"process_function_library_runtime.cc",
],
hdrs = [":core_cpu_lib_headers"],
copts = tf_copts(),
deps = [
":device",
":entry",
":executor",
":executor_factory",
":function_body",
":function_optimization_registry",
":gradients",
":graph_constructor",
":graph_optimizer",
":graph_view",
":local_executor_params",
":immutable_executor_state",
":inline_function_utils",
":input_colocation_exemption_registry",
":pending_counts",
":propagator_debug_utils",
":propagator_state",
":session_options",
":simple_propagator_state",
":single_threaded_cpu_device",
"//tensorflow/core:graph",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:variant",
"//tensorflow/core/public:version",
"//tensorflow/core/grappler/utils:functions",
"//tensorflow/core/profiler/lib:annotated_traceme",
"//tensorflow/core/profiler/lib:scoped_annotation",
"//tensorflow/core/profiler/lib:traceme",
] + mkl_deps(),
alwayslink = 1,
)
tf_cuda_library(
name = "core_cpu_impl",
hdrs = [":core_cpu_lib_headers"],
@ -1367,7 +1529,6 @@ tf_cuda_library(
":collective_rma_local",
":collective_util",
":copy_tensor",
":core_cpu_rump_impl",
":costmodel_manager",
":debugger_state_interface",
":device",
@ -1376,11 +1537,14 @@ tf_cuda_library(
":device_resolver_local",
":device_set",
":entry",
":function",
":graph_def_builder_util",
":graph_view",
":hierarchical_tree_broadcaster",
":input_colocation_exemption_registry",
":isolate_placer_inspection_required_ops_pass",
":local_device",
":lower_functional_ops",
":memory_types",
":mkl_cpu_allocator",
":mkl_layout_pass",
@ -1389,6 +1553,7 @@ tf_cuda_library(
":parallel_concat_optimizer",
":partitioning_utils",
":pending_counts",
":placer",
":pool_allocator",
":process_state",
":process_util",

View File

@ -29,7 +29,6 @@ limitations under the License.
#include "tensorflow/core/common_runtime/composite_device.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
#include "tensorflow/core/common_runtime/inspecting_placer.h"
#include "tensorflow/core/common_runtime/partitioning_utils.h"

View File

@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/common_runtime/inline_function_utils.h"
#include "tensorflow/core/common_runtime/memory_types.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/function.h"
@ -1395,38 +1396,4 @@ std::unique_ptr<FunctionBody> SymbolicGradient(const FunctionBody& f) {
return SymbolicGradientHelper(f).Compute();
}
Status FunctionDefToBodyHelper(
const FunctionDef& fdef, const AttrSlice& attrs,
const FunctionLibraryDefinition* const lib_def,
const std::function<Status(const string&, const OpDef**)>& get_func_sig,
std::unique_ptr<FunctionBody>* fbody) {
// Instantiates the function template into a graph def.
InstantiationResult result;
TF_RETURN_IF_ERROR(InstantiateFunction(fdef, attrs, get_func_sig, &result));
std::unique_ptr<Graph> graph(new Graph(lib_def));
GraphConstructorOptions opts;
opts.allow_internal_ops = true;
opts.expect_device_spec = false;
TF_RETURN_IF_ERROR(ConvertNodeDefsToGraph(opts, result.nodes, graph.get()));
// Call BuildControlFlowInfo to validate that this function body has
// well-formed control flow.
std::vector<ControlFlowInfo> dummy;
TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph.get(), &dummy));
*fbody = absl::make_unique<FunctionBody>(fdef, result.arg_types,
result.ret_types, graph.release());
return Status::OK();
}
Status FunctionDefToBodyHelper(const FunctionDef& fdef, const AttrSlice& attrs,
const FunctionLibraryDefinition* lib_def,
std::unique_ptr<FunctionBody>* fbody) {
const auto get_func_sig = [&lib_def](const string& op, const OpDef** sig) {
return lib_def->LookUpOpDef(op, sig);
};
return FunctionDefToBodyHelper(fdef, attrs, lib_def, get_func_sig, fbody);
}
} // end namespace tensorflow

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/function_body.h"
#include "tensorflow/core/common_runtime/function_def_utils.h"
#include "tensorflow/core/common_runtime/function_utils.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/common_runtime/inline_function_utils.h"
@ -80,26 +81,6 @@ std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
// TODO(zhifengc): Asks math expert to say the comment again.
std::unique_ptr<FunctionBody> SymbolicGradient(const FunctionBody& f);
// Returns true iff `n` represents a function call. `n` can be a native
// function call (n.type_string() is the function name),
// a PartitionedCall/StatefulPartitionedCall, or a SymbolicGradient (which
// has been deprecated for a while).
bool IsFunctionCall(const FunctionLibraryDefinition& lib_def, const Node& n);
// Instantiates FunctionDef into a graph. Set *fbody to point to the
// FunctionBody that holds the instantiated FunctionDef.
Status FunctionDefToBodyHelper(const FunctionDef& fdef, const AttrSlice& attrs,
const FunctionLibraryDefinition* lib_def,
std::unique_ptr<FunctionBody>* fbody);
// Instantiates FunctionDef into a graph. Set *fbody to point to the
// FunctionBody that holds the instantiated FunctionDef. Use custom function
// signature lookup, in case instantiated function is not in the 'lib_def'.
Status FunctionDefToBodyHelper(
const FunctionDef& fdef, const AttrSlice& attrs,
const FunctionLibraryDefinition* lib_def,
const std::function<Status(const string&, const OpDef**)>& get_func_sig,
std::unique_ptr<FunctionBody>* fbody);
} // end namespace tensorflow
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_H_

View File

@ -0,0 +1,63 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/function_def_utils.h"
#include <vector>
#include "tensorflow/core/common_runtime/function_body.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/graph.h"
namespace tensorflow {
Status FunctionDefToBodyHelper(
const FunctionDef& fdef, const AttrSlice& attrs,
const FunctionLibraryDefinition* const lib_def,
const std::function<Status(const string&, const OpDef**)>& get_func_sig,
std::unique_ptr<FunctionBody>* fbody) {
// Instantiates the function template into a graph def.
InstantiationResult result;
TF_RETURN_IF_ERROR(InstantiateFunction(fdef, attrs, get_func_sig, &result));
std::unique_ptr<Graph> graph(new Graph(lib_def));
GraphConstructorOptions opts;
opts.allow_internal_ops = true;
opts.expect_device_spec = false;
TF_RETURN_IF_ERROR(ConvertNodeDefsToGraph(opts, result.nodes, graph.get()));
// Call BuildControlFlowInfo to validate that this function body has
// well-formed control flow.
std::vector<ControlFlowInfo> dummy;
TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph.get(), &dummy));
*fbody = absl::make_unique<FunctionBody>(fdef, result.arg_types,
result.ret_types, graph.release());
return Status::OK();
}
Status FunctionDefToBodyHelper(const FunctionDef& fdef, const AttrSlice& attrs,
const FunctionLibraryDefinition* lib_def,
std::unique_ptr<FunctionBody>* fbody) {
const auto get_func_sig = [&lib_def](const string& op, const OpDef** sig) {
return lib_def->LookUpOpDef(op, sig);
};
return FunctionDefToBodyHelper(fdef, attrs, lib_def, get_func_sig, fbody);
}
} // end namespace tensorflow

View File

@ -0,0 +1,49 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_DEF_UTILS_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_DEF_UTILS_H_
#include <functional>
#include <memory>
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
class AttrSlice;
struct FunctionBody;
class FunctionDef;
class FunctionLibraryDefinition;
class OpDef;
// Instantiates FunctionDef into a graph. Set *fbody to point to the
// FunctionBody that holds the instantiated FunctionDef.
Status FunctionDefToBodyHelper(const FunctionDef& fdef, const AttrSlice& attrs,
const FunctionLibraryDefinition* lib_def,
std::unique_ptr<FunctionBody>* fbody);
// Instantiates FunctionDef into a graph. Set *fbody to point to the
// FunctionBody that holds the instantiated FunctionDef. Use custom function
// signature lookup, in case instantiated function is not in the 'lib_def'.
Status FunctionDefToBodyHelper(
const FunctionDef& fdef, const AttrSlice& attrs,
const FunctionLibraryDefinition* lib_def,
const std::function<Status(const string&, const OpDef**)>& get_func_sig,
std::unique_ptr<FunctionBody>* fbody);
} // end namespace tensorflow
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_DEF_UTILS_H_

View File

@ -43,6 +43,11 @@ limitations under the License.
namespace tensorflow {
/*static*/ constexpr const char* const
LowerFunctionalOpsConstants::kLowerUsingSwitchMergeAttr;
/*static*/ constexpr const char* const
LowerFunctionalOpsConstants::kLowerAsMultiDeviceFunctionAttr;
namespace {
// A few string constant used throughout this module.
static constexpr const char* const kArgOp = FunctionLibraryDefinition::kArgOp;

View File

@ -231,6 +231,13 @@ inline bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph) {
return ExpandInlineFunctions(lib, graph, ExpandInlineFunctionsOptions());
}
struct LowerFunctionalOpsConstants {
static constexpr const char* const kLowerUsingSwitchMergeAttr =
"_lower_using_switch_merge";
static constexpr const char* const kLowerAsMultiDeviceFunctionAttr =
"_lower_as_multi_device_function";
};
} // end namespace tensorflow
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_INLINE_FUNCTION_UTILS_H_

View File

@ -21,7 +21,8 @@ limitations under the License.
#include "absl/strings/str_join.h"
#include "tensorflow/core/common_runtime/colocation_graph.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/function_body.h"
#include "tensorflow/core/common_runtime/function_def_utils.h"
#include "tensorflow/core/common_runtime/placer_inspection_required_ops_utils.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_util.h"

View File

@ -15,8 +15,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/lower_case_op.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
#include "tensorflow/core/common_runtime/inline_function_utils.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
@ -29,7 +28,7 @@ namespace {
using NodeOut = NodeBuilder::NodeOut;
constexpr const char* const kLowerAsMultiDeviceFunctionAttr =
LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr;
LowerFunctionalOpsConstants::kLowerAsMultiDeviceFunctionAttr;
// Convenience builder to make it easy to construct a case with a single
// function call in each branch. This first converts the Case node

View File

@ -16,11 +16,13 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_CASE_OP_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_CASE_OP_H_
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
class Graph;
class Node;
// Replaces Case node `n` with a lowered form that uses _SwitchN/Merge nodes.
Status RewriteCaseNode(Node* n, Graph* g, bool keep_node_fetchable);

View File

@ -16,9 +16,8 @@ limitations under the License.
#include "tensorflow/core/common_runtime/lower_function_call_op.h"
#include "absl/algorithm/container.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/function_def_utils.h"
#include "tensorflow/core/common_runtime/inline_function_utils.h"
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_node_util.h"
@ -30,7 +29,7 @@ using KeepCallerNode = InlineFunctionBodyOptions::KeepCallerNode;
using OutputControlSrc = InlineFunctionBodyOptions::OutputControlSource;
constexpr const char* const kLowerAsMultiDeviceFunctionAttr =
LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr;
LowerFunctionalOpsConstants::kLowerAsMultiDeviceFunctionAttr;
bool LowerAsMultiDeviceFunction(const Node* n) {
if (n->IsPartitionedCall()) return true;

View File

@ -16,11 +16,14 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_FUNCTION_CALL_OP_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_FUNCTION_CALL_OP_H_
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
class FunctionLibraryDefinition;
class Graph;
class Node;
// Replaces function call node `n` with its function body. Uses
// InlineFunctionBody from `common_runtime/function.{h,cc}`. If function
// inlining is not possible or safe (see ValidateInlining), leaves the graph in

View File

@ -15,7 +15,8 @@ limitations under the License.
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/function_utils.h"
#include "tensorflow/core/common_runtime/inline_function_utils.h"
#include "tensorflow/core/common_runtime/lower_case_op.h"
#include "tensorflow/core/common_runtime/lower_function_call_op.h"
#include "tensorflow/core/common_runtime/lower_if_op.h"
@ -27,17 +28,12 @@ limitations under the License.
namespace tensorflow {
/*static*/ constexpr const char* const
LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr;
/*static*/ constexpr const char* const
LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr;
namespace {
constexpr const char* const kLowerUsingSwitchMergeAttr =
LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr;
LowerFunctionalOpsConstants::kLowerUsingSwitchMergeAttr;
constexpr const char* const kLowerAsMultiDeviceFunctionAttr =
LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr;
LowerFunctionalOpsConstants::kLowerAsMultiDeviceFunctionAttr;
constexpr const char* const kTpuReplicateAttr = "_tpu_replicate";
constexpr const char* const kXlaClusterAttr = "_xla_compile_id";
@ -173,7 +169,7 @@ Status LowerFunctionalOpsPass::Run(
DCHECK(!lower_control_flow(n))
<< "Node " << FormatNodeForError(*n) << " of type "
<< n->type_string() << " has '"
<< LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr
<< LowerFunctionalOpsConstants::kLowerUsingSwitchMergeAttr
<< "' attr set but it does not support lowering.\n";
}
}

View File

@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_FUNCTIONAL_OPS_H_
#include "absl/types/optional.h"
#include "tensorflow/core/common_runtime/inline_function_utils.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/lib/core/status.h"
@ -38,9 +39,9 @@ class LowerFunctionalOpsPass : public GraphOptimizationPass {
Status Run(const GraphOptimizationPassOptions& options) override;
static constexpr const char* const kLowerUsingSwitchMergeAttr =
"_lower_using_switch_merge";
LowerFunctionalOpsConstants::kLowerUsingSwitchMergeAttr;
static constexpr const char* const kLowerAsMultiDeviceFunctionAttr =
"_lower_as_multi_device_function";
LowerFunctionalOpsConstants::kLowerAsMultiDeviceFunctionAttr;
private:
// If defined use the value to control if functional ops must be fetchable

View File

@ -15,8 +15,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/lower_if_op.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
#include "tensorflow/core/common_runtime/inline_function_utils.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
@ -27,7 +26,7 @@ namespace {
using NodeOut = NodeBuilder::NodeOut;
constexpr const char* const kLowerAsMultiDeviceFunctionAttr =
LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr;
LowerFunctionalOpsConstants::kLowerAsMultiDeviceFunctionAttr;
// Convenience builder to make it easy to construct a conditional with a single
// function call in the then and else branch. This first converts the if node

View File

@ -16,11 +16,13 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_IF_OP_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_IF_OP_H_
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
class Graph;
class Node;
// Replaces If node `n` with its lowered form that uses Switch and Merge nodes.
Status RewriteIfNode(Node* n, Graph* g, bool keep_node_fetchable);

View File

@ -15,9 +15,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/lower_while_op.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
#include "tensorflow/core/common_runtime/lower_if_op.h"
#include "tensorflow/core/common_runtime/inline_function_utils.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/graph.h"
@ -30,7 +28,7 @@ namespace {
using NodeOut = NodeBuilder::NodeOut;
constexpr const char* const kLowerAsMultiDeviceFunctionAttr =
LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr;
LowerFunctionalOpsConstants::kLowerAsMultiDeviceFunctionAttr;
// Helper to convert a functional While op to its lowered form.
//

View File

@ -16,11 +16,13 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_WHILE_OP_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_WHILE_OP_H_
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
class Graph;
class Node;
// Replaces While node `n` with its lowered form that uses Enter, Exit, Switch,
// Merge, NextIteration and LoopCond nodes.
Status RewriteWhileNode(Node* n, Graph* g, bool keep_node_fetchable);

View File

@ -19,7 +19,6 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/graph/graph.h"