[Build cleanup] Split "core_cpu_impl" into fine-grained targets (5/5).
This changes moves the remaining sources in "core_cpu_rump_impl" into fine-grained targets. It leaves two circular dependencies in place: * ":colocation_graph" includes both "colocation_graph.{h,cc}" and "inspecting_placer.{h,cc}" because `ColocationGraph` has an `InspectingPlacer`, and an `InspectingPlacer` creates a `ColocationGraph`. * ":function" includes both "function.{h,cc}" and "process_function_library_runtime.{h,cc}" because `ProcessFunctionLibraryRuntime` calls `NewFunctionLibraryRuntime()` and `FunctionLibraryRuntimeImpl` invokes methods on `ProcessFunctionLibraryRuntime`. Several files also depend on the indirect inclusion of "process_function_library_runtime.h" in "function.h". PiperOrigin-RevId: 308949398 Change-Id: I65e0664556e9ee7659640a1ed23fdb4072ab16c2
This commit is contained in:
parent
4bfe1dce64
commit
913b88bd0d
tensorflow/core/common_runtime
BUILDcolocation_graph.ccfunction.ccfunction.hfunction_def_utils.ccfunction_def_utils.hinline_function_utils.ccinline_function_utils.hinspecting_placer.cclower_case_op.cclower_case_op.hlower_function_call_op.cclower_function_call_op.hlower_functional_ops.cclower_functional_ops.hlower_if_op.cclower_if_op.hlower_while_op.cclower_while_op.hplacer_inspection_required_ops_utils.cc
@ -157,6 +157,7 @@ filegroup(
|
||||
"eval_const_tensor.h",
|
||||
"function.h",
|
||||
"function_body.h",
|
||||
"function_def_utils.h",
|
||||
"function_utils.h",
|
||||
"graph_constructor.h",
|
||||
"graph_def_builder_util.h",
|
||||
@ -399,6 +400,50 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "colocation_graph",
|
||||
srcs = [
|
||||
"colocation_graph.cc",
|
||||
"inspecting_placer.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"colocation_graph.h",
|
||||
"inspecting_placer.h",
|
||||
],
|
||||
copts = tf_copts(),
|
||||
deps = [
|
||||
":composite_device",
|
||||
":device",
|
||||
":device_set",
|
||||
":function_body",
|
||||
":function_def_utils",
|
||||
":input_colocation_exemption_registry",
|
||||
":partitioning_utils",
|
||||
":placer_inspection_required_ops_utils",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "composite_device",
|
||||
srcs = ["composite_device.cc"],
|
||||
hdrs = ["composite_device.h"],
|
||||
copts = tf_copts(),
|
||||
deps = [
|
||||
":device",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "constant_folding",
|
||||
srcs = ["constant_folding.cc"],
|
||||
@ -574,6 +619,50 @@ cc_library(
|
||||
deps = ["//tensorflow/core:framework"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "function",
|
||||
srcs = [
|
||||
"function.cc",
|
||||
"process_function_library_runtime.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"function.h",
|
||||
"process_function_library_runtime.h",
|
||||
],
|
||||
copts = tf_copts(),
|
||||
deps = [
|
||||
":device",
|
||||
":device_mgr",
|
||||
":device_set",
|
||||
":executor",
|
||||
":executor_factory",
|
||||
":function_body",
|
||||
":function_def_utils",
|
||||
":function_optimization_registry",
|
||||
":function_utils",
|
||||
":gradients",
|
||||
":graph_constructor",
|
||||
":graph_optimizer",
|
||||
":inline_function_utils",
|
||||
":memory_types",
|
||||
":optimization_registry",
|
||||
":partitioning_utils",
|
||||
":placer",
|
||||
":process_util",
|
||||
":rendezvous_mgr",
|
||||
":rendezvous_util",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/profiler/lib:traceme",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "function_body",
|
||||
srcs = ["function_body.cc"],
|
||||
@ -599,10 +688,25 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "function_def_utils",
|
||||
srcs = ["function_def_utils.cc"],
|
||||
hdrs = ["function_def_utils.h"],
|
||||
copts = tf_copts(),
|
||||
deps = [
|
||||
":function_body",
|
||||
":graph_constructor",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "function_utils",
|
||||
srcs = ["function_utils.cc"],
|
||||
hdrs = ["function_utils.h"],
|
||||
copts = tf_copts(),
|
||||
deps = [
|
||||
":function_body",
|
||||
"//tensorflow/core:framework",
|
||||
@ -777,6 +881,21 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "isolate_placer_inspection_required_ops_pass",
|
||||
srcs = ["isolate_placer_inspection_required_ops_pass.cc"],
|
||||
hdrs = ["isolate_placer_inspection_required_ops_pass.h"],
|
||||
copts = tf_copts(),
|
||||
deps = [
|
||||
":optimization_registry",
|
||||
":placer_inspection_required_ops_utils",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "local_device",
|
||||
srcs = ["local_device.cc"],
|
||||
@ -803,6 +922,81 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "lower_case_op",
|
||||
srcs = ["lower_case_op.cc"],
|
||||
hdrs = ["lower_case_op.h"],
|
||||
deps = [
|
||||
":inline_function_utils",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "lower_function_call_op",
|
||||
srcs = ["lower_function_call_op.cc"],
|
||||
hdrs = ["lower_function_call_op.h"],
|
||||
copts = tf_copts(),
|
||||
deps = [
|
||||
":function_def_utils",
|
||||
":inline_function_utils",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "lower_functional_ops",
|
||||
srcs = ["lower_functional_ops.cc"],
|
||||
hdrs = ["lower_functional_ops.h"],
|
||||
copts = tf_copts(),
|
||||
deps = [
|
||||
":function_utils",
|
||||
":inline_function_utils",
|
||||
":lower_case_op",
|
||||
":lower_function_call_op",
|
||||
":lower_if_op",
|
||||
":lower_while_op",
|
||||
":optimization_registry",
|
||||
":session_options",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "lower_if_op",
|
||||
srcs = ["lower_if_op.cc"],
|
||||
hdrs = ["lower_if_op.h"],
|
||||
copts = tf_copts(),
|
||||
deps = [
|
||||
":inline_function_utils",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "lower_while_op",
|
||||
srcs = ["lower_while_op.cc"],
|
||||
hdrs = ["lower_while_op.h"],
|
||||
copts = tf_copts(),
|
||||
deps = [
|
||||
":inline_function_utils",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "memory_types",
|
||||
srcs = ["memory_types.cc"],
|
||||
@ -918,6 +1112,37 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "placer",
|
||||
srcs = ["placer.cc"],
|
||||
hdrs = ["placer.h"],
|
||||
copts = tf_copts(),
|
||||
deps = [
|
||||
":colocation_graph",
|
||||
":device",
|
||||
":device_set",
|
||||
":session_options",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "placer_inspection_required_ops_utils",
|
||||
srcs = ["placer_inspection_required_ops_utils.cc"],
|
||||
hdrs = ["placer_inspection_required_ops_utils.h"],
|
||||
copts = tf_copts(),
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "process_state",
|
||||
srcs = ["process_state.cc"],
|
||||
@ -1289,69 +1514,6 @@ cc_library(
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
name = "core_cpu_rump_impl",
|
||||
srcs = [
|
||||
"colocation_graph.cc",
|
||||
"composite_device.cc",
|
||||
"function.cc",
|
||||
"inspecting_placer.cc",
|
||||
"isolate_placer_inspection_required_ops_pass.cc",
|
||||
"lower_case_op.cc",
|
||||
"lower_function_call_op.cc",
|
||||
"lower_functional_ops.cc",
|
||||
"lower_if_op.cc",
|
||||
"lower_while_op.cc",
|
||||
"placer.cc",
|
||||
"placer_inspection_required_ops_utils.cc",
|
||||
"placer_inspection_required_ops_utils.h",
|
||||
"process_function_library_runtime.cc",
|
||||
],
|
||||
hdrs = [":core_cpu_lib_headers"],
|
||||
copts = tf_copts(),
|
||||
deps = [
|
||||
":device",
|
||||
":entry",
|
||||
":executor",
|
||||
":executor_factory",
|
||||
":function_body",
|
||||
":function_optimization_registry",
|
||||
":gradients",
|
||||
":graph_constructor",
|
||||
":graph_optimizer",
|
||||
":graph_view",
|
||||
":local_executor_params",
|
||||
":immutable_executor_state",
|
||||
":inline_function_utils",
|
||||
":input_colocation_exemption_registry",
|
||||
":pending_counts",
|
||||
":propagator_debug_utils",
|
||||
":propagator_state",
|
||||
":session_options",
|
||||
":simple_propagator_state",
|
||||
":single_threaded_cpu_device",
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@com_google_absl//absl/types:variant",
|
||||
"//tensorflow/core/public:version",
|
||||
"//tensorflow/core/grappler/utils:functions",
|
||||
"//tensorflow/core/profiler/lib:annotated_traceme",
|
||||
"//tensorflow/core/profiler/lib:scoped_annotation",
|
||||
"//tensorflow/core/profiler/lib:traceme",
|
||||
] + mkl_deps(),
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
name = "core_cpu_impl",
|
||||
hdrs = [":core_cpu_lib_headers"],
|
||||
@ -1367,7 +1529,6 @@ tf_cuda_library(
|
||||
":collective_rma_local",
|
||||
":collective_util",
|
||||
":copy_tensor",
|
||||
":core_cpu_rump_impl",
|
||||
":costmodel_manager",
|
||||
":debugger_state_interface",
|
||||
":device",
|
||||
@ -1376,11 +1537,14 @@ tf_cuda_library(
|
||||
":device_resolver_local",
|
||||
":device_set",
|
||||
":entry",
|
||||
":function",
|
||||
":graph_def_builder_util",
|
||||
":graph_view",
|
||||
":hierarchical_tree_broadcaster",
|
||||
":input_colocation_exemption_registry",
|
||||
":isolate_placer_inspection_required_ops_pass",
|
||||
":local_device",
|
||||
":lower_functional_ops",
|
||||
":memory_types",
|
||||
":mkl_cpu_allocator",
|
||||
":mkl_layout_pass",
|
||||
@ -1389,6 +1553,7 @@ tf_cuda_library(
|
||||
":parallel_concat_optimizer",
|
||||
":partitioning_utils",
|
||||
":pending_counts",
|
||||
":placer",
|
||||
":pool_allocator",
|
||||
":process_state",
|
||||
":process_util",
|
||||
|
@ -29,7 +29,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/common_runtime/composite_device.h"
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/device_set.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
|
||||
#include "tensorflow/core/common_runtime/inspecting_placer.h"
|
||||
#include "tensorflow/core/common_runtime/partitioning_utils.h"
|
||||
|
@ -29,6 +29,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/common_runtime/graph_optimizer.h"
|
||||
#include "tensorflow/core/common_runtime/inline_function_utils.h"
|
||||
#include "tensorflow/core/common_runtime/memory_types.h"
|
||||
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
|
||||
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
|
||||
#include "tensorflow/core/framework/collective.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
@ -1395,38 +1396,4 @@ std::unique_ptr<FunctionBody> SymbolicGradient(const FunctionBody& f) {
|
||||
return SymbolicGradientHelper(f).Compute();
|
||||
}
|
||||
|
||||
Status FunctionDefToBodyHelper(
|
||||
const FunctionDef& fdef, const AttrSlice& attrs,
|
||||
const FunctionLibraryDefinition* const lib_def,
|
||||
const std::function<Status(const string&, const OpDef**)>& get_func_sig,
|
||||
std::unique_ptr<FunctionBody>* fbody) {
|
||||
// Instantiates the function template into a graph def.
|
||||
InstantiationResult result;
|
||||
TF_RETURN_IF_ERROR(InstantiateFunction(fdef, attrs, get_func_sig, &result));
|
||||
|
||||
std::unique_ptr<Graph> graph(new Graph(lib_def));
|
||||
GraphConstructorOptions opts;
|
||||
opts.allow_internal_ops = true;
|
||||
opts.expect_device_spec = false;
|
||||
TF_RETURN_IF_ERROR(ConvertNodeDefsToGraph(opts, result.nodes, graph.get()));
|
||||
|
||||
// Call BuildControlFlowInfo to validate that this function body has
|
||||
// well-formed control flow.
|
||||
std::vector<ControlFlowInfo> dummy;
|
||||
TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph.get(), &dummy));
|
||||
|
||||
*fbody = absl::make_unique<FunctionBody>(fdef, result.arg_types,
|
||||
result.ret_types, graph.release());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status FunctionDefToBodyHelper(const FunctionDef& fdef, const AttrSlice& attrs,
|
||||
const FunctionLibraryDefinition* lib_def,
|
||||
std::unique_ptr<FunctionBody>* fbody) {
|
||||
const auto get_func_sig = [&lib_def](const string& op, const OpDef** sig) {
|
||||
return lib_def->LookUpOpDef(op, sig);
|
||||
};
|
||||
return FunctionDefToBodyHelper(fdef, attrs, lib_def, get_func_sig, fbody);
|
||||
}
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
#include "tensorflow/core/common_runtime/function_body.h"
|
||||
#include "tensorflow/core/common_runtime/function_def_utils.h"
|
||||
#include "tensorflow/core/common_runtime/function_utils.h"
|
||||
#include "tensorflow/core/common_runtime/graph_optimizer.h"
|
||||
#include "tensorflow/core/common_runtime/inline_function_utils.h"
|
||||
@ -80,26 +81,6 @@ std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
|
||||
// TODO(zhifengc): Asks math expert to say the comment again.
|
||||
std::unique_ptr<FunctionBody> SymbolicGradient(const FunctionBody& f);
|
||||
|
||||
// Returns true iff `n` represents a function call. `n` can be a native
|
||||
// function call (n.type_string() is the function name),
|
||||
// a PartitionedCall/StatefulPartitionedCall, or a SymbolicGradient (which
|
||||
// has been deprecated for a while).
|
||||
bool IsFunctionCall(const FunctionLibraryDefinition& lib_def, const Node& n);
|
||||
|
||||
// Instantiates FunctionDef into a graph. Set *fbody to point to the
|
||||
// FunctionBody that holds the instantiated FunctionDef.
|
||||
Status FunctionDefToBodyHelper(const FunctionDef& fdef, const AttrSlice& attrs,
|
||||
const FunctionLibraryDefinition* lib_def,
|
||||
std::unique_ptr<FunctionBody>* fbody);
|
||||
|
||||
// Instantiates FunctionDef into a graph. Set *fbody to point to the
|
||||
// FunctionBody that holds the instantiated FunctionDef. Use custom function
|
||||
// signature lookup, in case instantiated function is not in the 'lib_def'.
|
||||
Status FunctionDefToBodyHelper(
|
||||
const FunctionDef& fdef, const AttrSlice& attrs,
|
||||
const FunctionLibraryDefinition* lib_def,
|
||||
const std::function<Status(const string&, const OpDef**)>& get_func_sig,
|
||||
std::unique_ptr<FunctionBody>* fbody);
|
||||
} // end namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_H_
|
||||
|
63
tensorflow/core/common_runtime/function_def_utils.cc
Normal file
63
tensorflow/core/common_runtime/function_def_utils.cc
Normal file
@ -0,0 +1,63 @@
|
||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/common_runtime/function_def_utils.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/common_runtime/function_body.h"
|
||||
#include "tensorflow/core/common_runtime/graph_constructor.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/graph/control_flow.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
Status FunctionDefToBodyHelper(
|
||||
const FunctionDef& fdef, const AttrSlice& attrs,
|
||||
const FunctionLibraryDefinition* const lib_def,
|
||||
const std::function<Status(const string&, const OpDef**)>& get_func_sig,
|
||||
std::unique_ptr<FunctionBody>* fbody) {
|
||||
// Instantiates the function template into a graph def.
|
||||
InstantiationResult result;
|
||||
TF_RETURN_IF_ERROR(InstantiateFunction(fdef, attrs, get_func_sig, &result));
|
||||
|
||||
std::unique_ptr<Graph> graph(new Graph(lib_def));
|
||||
GraphConstructorOptions opts;
|
||||
opts.allow_internal_ops = true;
|
||||
opts.expect_device_spec = false;
|
||||
TF_RETURN_IF_ERROR(ConvertNodeDefsToGraph(opts, result.nodes, graph.get()));
|
||||
|
||||
// Call BuildControlFlowInfo to validate that this function body has
|
||||
// well-formed control flow.
|
||||
std::vector<ControlFlowInfo> dummy;
|
||||
TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph.get(), &dummy));
|
||||
|
||||
*fbody = absl::make_unique<FunctionBody>(fdef, result.arg_types,
|
||||
result.ret_types, graph.release());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status FunctionDefToBodyHelper(const FunctionDef& fdef, const AttrSlice& attrs,
|
||||
const FunctionLibraryDefinition* lib_def,
|
||||
std::unique_ptr<FunctionBody>* fbody) {
|
||||
const auto get_func_sig = [&lib_def](const string& op, const OpDef** sig) {
|
||||
return lib_def->LookUpOpDef(op, sig);
|
||||
};
|
||||
return FunctionDefToBodyHelper(fdef, attrs, lib_def, get_func_sig, fbody);
|
||||
}
|
||||
|
||||
} // end namespace tensorflow
|
49
tensorflow/core/common_runtime/function_def_utils.h
Normal file
49
tensorflow/core/common_runtime/function_def_utils.h
Normal file
@ -0,0 +1,49 @@
|
||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_DEF_UTILS_H_
|
||||
#define TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_DEF_UTILS_H_
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class AttrSlice;
|
||||
struct FunctionBody;
|
||||
class FunctionDef;
|
||||
class FunctionLibraryDefinition;
|
||||
class OpDef;
|
||||
|
||||
// Instantiates FunctionDef into a graph. Set *fbody to point to the
|
||||
// FunctionBody that holds the instantiated FunctionDef.
|
||||
Status FunctionDefToBodyHelper(const FunctionDef& fdef, const AttrSlice& attrs,
|
||||
const FunctionLibraryDefinition* lib_def,
|
||||
std::unique_ptr<FunctionBody>* fbody);
|
||||
|
||||
// Instantiates FunctionDef into a graph. Set *fbody to point to the
|
||||
// FunctionBody that holds the instantiated FunctionDef. Use custom function
|
||||
// signature lookup, in case instantiated function is not in the 'lib_def'.
|
||||
Status FunctionDefToBodyHelper(
|
||||
const FunctionDef& fdef, const AttrSlice& attrs,
|
||||
const FunctionLibraryDefinition* lib_def,
|
||||
const std::function<Status(const string&, const OpDef**)>& get_func_sig,
|
||||
std::unique_ptr<FunctionBody>* fbody);
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_DEF_UTILS_H_
|
@ -43,6 +43,11 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
/*static*/ constexpr const char* const
|
||||
LowerFunctionalOpsConstants::kLowerUsingSwitchMergeAttr;
|
||||
/*static*/ constexpr const char* const
|
||||
LowerFunctionalOpsConstants::kLowerAsMultiDeviceFunctionAttr;
|
||||
|
||||
namespace {
|
||||
// A few string constant used throughout this module.
|
||||
static constexpr const char* const kArgOp = FunctionLibraryDefinition::kArgOp;
|
||||
|
@ -231,6 +231,13 @@ inline bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph) {
|
||||
return ExpandInlineFunctions(lib, graph, ExpandInlineFunctionsOptions());
|
||||
}
|
||||
|
||||
struct LowerFunctionalOpsConstants {
|
||||
static constexpr const char* const kLowerUsingSwitchMergeAttr =
|
||||
"_lower_using_switch_merge";
|
||||
static constexpr const char* const kLowerAsMultiDeviceFunctionAttr =
|
||||
"_lower_as_multi_device_function";
|
||||
};
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_INLINE_FUNCTION_UTILS_H_
|
||||
|
@ -21,7 +21,8 @@ limitations under the License.
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "tensorflow/core/common_runtime/colocation_graph.h"
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/common_runtime/function_body.h"
|
||||
#include "tensorflow/core/common_runtime/function_def_utils.h"
|
||||
#include "tensorflow/core/common_runtime/placer_inspection_required_ops_utils.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
|
@ -15,8 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/common_runtime/lower_case_op.h"
|
||||
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
|
||||
#include "tensorflow/core/common_runtime/inline_function_utils.h"
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/graph/node_builder.h"
|
||||
@ -29,7 +28,7 @@ namespace {
|
||||
using NodeOut = NodeBuilder::NodeOut;
|
||||
|
||||
constexpr const char* const kLowerAsMultiDeviceFunctionAttr =
|
||||
LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr;
|
||||
LowerFunctionalOpsConstants::kLowerAsMultiDeviceFunctionAttr;
|
||||
|
||||
// Convenience builder to make it easy to construct a case with a single
|
||||
// function call in each branch. This first converts the Case node
|
||||
|
@ -16,11 +16,13 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_CASE_OP_H_
|
||||
#define TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_CASE_OP_H_
|
||||
|
||||
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class Graph;
|
||||
class Node;
|
||||
|
||||
// Replaces Case node `n` with a lowered form that uses _SwitchN/Merge nodes.
|
||||
Status RewriteCaseNode(Node* n, Graph* g, bool keep_node_fetchable);
|
||||
|
||||
|
@ -16,9 +16,8 @@ limitations under the License.
|
||||
#include "tensorflow/core/common_runtime/lower_function_call_op.h"
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/common_runtime/function_def_utils.h"
|
||||
#include "tensorflow/core/common_runtime/inline_function_utils.h"
|
||||
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/graph/graph_node_util.h"
|
||||
@ -30,7 +29,7 @@ using KeepCallerNode = InlineFunctionBodyOptions::KeepCallerNode;
|
||||
using OutputControlSrc = InlineFunctionBodyOptions::OutputControlSource;
|
||||
|
||||
constexpr const char* const kLowerAsMultiDeviceFunctionAttr =
|
||||
LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr;
|
||||
LowerFunctionalOpsConstants::kLowerAsMultiDeviceFunctionAttr;
|
||||
|
||||
bool LowerAsMultiDeviceFunction(const Node* n) {
|
||||
if (n->IsPartitionedCall()) return true;
|
||||
|
@ -16,11 +16,14 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_FUNCTION_CALL_OP_H_
|
||||
#define TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_FUNCTION_CALL_OP_H_
|
||||
|
||||
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class FunctionLibraryDefinition;
|
||||
class Graph;
|
||||
class Node;
|
||||
|
||||
// Replaces function call node `n` with its function body. Uses
|
||||
// InlineFunctionBody from `common_runtime/function.{h,cc}`. If function
|
||||
// inlining is not possible or safe (see ValidateInlining), leaves the graph in
|
||||
|
@ -15,7 +15,8 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
|
||||
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/common_runtime/function_utils.h"
|
||||
#include "tensorflow/core/common_runtime/inline_function_utils.h"
|
||||
#include "tensorflow/core/common_runtime/lower_case_op.h"
|
||||
#include "tensorflow/core/common_runtime/lower_function_call_op.h"
|
||||
#include "tensorflow/core/common_runtime/lower_if_op.h"
|
||||
@ -27,17 +28,12 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
/*static*/ constexpr const char* const
|
||||
LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr;
|
||||
/*static*/ constexpr const char* const
|
||||
LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr;
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr const char* const kLowerUsingSwitchMergeAttr =
|
||||
LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr;
|
||||
LowerFunctionalOpsConstants::kLowerUsingSwitchMergeAttr;
|
||||
constexpr const char* const kLowerAsMultiDeviceFunctionAttr =
|
||||
LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr;
|
||||
LowerFunctionalOpsConstants::kLowerAsMultiDeviceFunctionAttr;
|
||||
|
||||
constexpr const char* const kTpuReplicateAttr = "_tpu_replicate";
|
||||
constexpr const char* const kXlaClusterAttr = "_xla_compile_id";
|
||||
@ -173,7 +169,7 @@ Status LowerFunctionalOpsPass::Run(
|
||||
DCHECK(!lower_control_flow(n))
|
||||
<< "Node " << FormatNodeForError(*n) << " of type "
|
||||
<< n->type_string() << " has '"
|
||||
<< LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr
|
||||
<< LowerFunctionalOpsConstants::kLowerUsingSwitchMergeAttr
|
||||
<< "' attr set but it does not support lowering.\n";
|
||||
}
|
||||
}
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
#define TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_FUNCTIONAL_OPS_H_
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/core/common_runtime/inline_function_utils.h"
|
||||
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
@ -38,9 +39,9 @@ class LowerFunctionalOpsPass : public GraphOptimizationPass {
|
||||
Status Run(const GraphOptimizationPassOptions& options) override;
|
||||
|
||||
static constexpr const char* const kLowerUsingSwitchMergeAttr =
|
||||
"_lower_using_switch_merge";
|
||||
LowerFunctionalOpsConstants::kLowerUsingSwitchMergeAttr;
|
||||
static constexpr const char* const kLowerAsMultiDeviceFunctionAttr =
|
||||
"_lower_as_multi_device_function";
|
||||
LowerFunctionalOpsConstants::kLowerAsMultiDeviceFunctionAttr;
|
||||
|
||||
private:
|
||||
// If defined use the value to control if functional ops must be fetchable
|
||||
|
@ -15,8 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/common_runtime/lower_if_op.h"
|
||||
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
|
||||
#include "tensorflow/core/common_runtime/inline_function_utils.h"
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/graph/node_builder.h"
|
||||
@ -27,7 +26,7 @@ namespace {
|
||||
using NodeOut = NodeBuilder::NodeOut;
|
||||
|
||||
constexpr const char* const kLowerAsMultiDeviceFunctionAttr =
|
||||
LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr;
|
||||
LowerFunctionalOpsConstants::kLowerAsMultiDeviceFunctionAttr;
|
||||
|
||||
// Convenience builder to make it easy to construct a conditional with a single
|
||||
// function call in the then and else branch. This first converts the if node
|
||||
|
@ -16,11 +16,13 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_IF_OP_H_
|
||||
#define TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_IF_OP_H_
|
||||
|
||||
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class Graph;
|
||||
class Node;
|
||||
|
||||
// Replaces If node `n` with its lowered form that uses Switch and Merge nodes.
|
||||
Status RewriteIfNode(Node* n, Graph* g, bool keep_node_fetchable);
|
||||
|
||||
|
@ -15,9 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/common_runtime/lower_while_op.h"
|
||||
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
|
||||
#include "tensorflow/core/common_runtime/lower_if_op.h"
|
||||
#include "tensorflow/core/common_runtime/inline_function_utils.h"
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
@ -30,7 +28,7 @@ namespace {
|
||||
using NodeOut = NodeBuilder::NodeOut;
|
||||
|
||||
constexpr const char* const kLowerAsMultiDeviceFunctionAttr =
|
||||
LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr;
|
||||
LowerFunctionalOpsConstants::kLowerAsMultiDeviceFunctionAttr;
|
||||
|
||||
// Helper to convert a functional While op to its lowered form.
|
||||
//
|
||||
|
@ -16,11 +16,13 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_WHILE_OP_H_
|
||||
#define TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_WHILE_OP_H_
|
||||
|
||||
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class Graph;
|
||||
class Node;
|
||||
|
||||
// Replaces While node `n` with its lowered form that uses Enter, Exit, Switch,
|
||||
// Merge, NextIteration and LoopCond nodes.
|
||||
Status RewriteWhileNode(Node* n, Graph* g, bool keep_node_fetchable);
|
||||
|
@ -19,7 +19,6 @@ limitations under the License.
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
|
Loading…
Reference in New Issue
Block a user