[Build cleanup] Split "core_cpu_impl" into fine-grained targets (5/5).
This changes moves the remaining sources in "core_cpu_rump_impl" into fine-grained targets. It leaves two circular dependencies in place: * ":colocation_graph" includes both "colocation_graph.{h,cc}" and "inspecting_placer.{h,cc}" because `ColocationGraph` has an `InspectingPlacer`, and an `InspectingPlacer` creates a `ColocationGraph`. * ":function" includes both "function.{h,cc}" and "process_function_library_runtime.{h,cc}" because `ProcessFunctionLibraryRuntime` calls `NewFunctionLibraryRuntime()` and `FunctionLibraryRuntimeImpl` invokes methods on `ProcessFunctionLibraryRuntime`. Several files also depend on the indirect inclusion of "process_function_library_runtime.h" in "function.h". PiperOrigin-RevId: 308949398 Change-Id: I65e0664556e9ee7659640a1ed23fdb4072ab16c2
This commit is contained in:
parent
4bfe1dce64
commit
913b88bd0d
@ -157,6 +157,7 @@ filegroup(
|
|||||||
"eval_const_tensor.h",
|
"eval_const_tensor.h",
|
||||||
"function.h",
|
"function.h",
|
||||||
"function_body.h",
|
"function_body.h",
|
||||||
|
"function_def_utils.h",
|
||||||
"function_utils.h",
|
"function_utils.h",
|
||||||
"graph_constructor.h",
|
"graph_constructor.h",
|
||||||
"graph_def_builder_util.h",
|
"graph_def_builder_util.h",
|
||||||
@ -399,6 +400,50 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "colocation_graph",
|
||||||
|
srcs = [
|
||||||
|
"colocation_graph.cc",
|
||||||
|
"inspecting_placer.cc",
|
||||||
|
],
|
||||||
|
hdrs = [
|
||||||
|
"colocation_graph.h",
|
||||||
|
"inspecting_placer.h",
|
||||||
|
],
|
||||||
|
copts = tf_copts(),
|
||||||
|
deps = [
|
||||||
|
":composite_device",
|
||||||
|
":device",
|
||||||
|
":device_set",
|
||||||
|
":function_body",
|
||||||
|
":function_def_utils",
|
||||||
|
":input_colocation_exemption_registry",
|
||||||
|
":partitioning_utils",
|
||||||
|
":placer_inspection_required_ops_utils",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:graph",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
"@com_google_absl//absl/algorithm:container",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_set",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
|
"@com_google_absl//absl/types:optional",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "composite_device",
|
||||||
|
srcs = ["composite_device.cc"],
|
||||||
|
hdrs = ["composite_device.h"],
|
||||||
|
copts = tf_copts(),
|
||||||
|
deps = [
|
||||||
|
":device",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "constant_folding",
|
name = "constant_folding",
|
||||||
srcs = ["constant_folding.cc"],
|
srcs = ["constant_folding.cc"],
|
||||||
@ -574,6 +619,50 @@ cc_library(
|
|||||||
deps = ["//tensorflow/core:framework"],
|
deps = ["//tensorflow/core:framework"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "function",
|
||||||
|
srcs = [
|
||||||
|
"function.cc",
|
||||||
|
"process_function_library_runtime.cc",
|
||||||
|
],
|
||||||
|
hdrs = [
|
||||||
|
"function.h",
|
||||||
|
"process_function_library_runtime.h",
|
||||||
|
],
|
||||||
|
copts = tf_copts(),
|
||||||
|
deps = [
|
||||||
|
":device",
|
||||||
|
":device_mgr",
|
||||||
|
":device_set",
|
||||||
|
":executor",
|
||||||
|
":executor_factory",
|
||||||
|
":function_body",
|
||||||
|
":function_def_utils",
|
||||||
|
":function_optimization_registry",
|
||||||
|
":function_utils",
|
||||||
|
":gradients",
|
||||||
|
":graph_constructor",
|
||||||
|
":graph_optimizer",
|
||||||
|
":inline_function_utils",
|
||||||
|
":memory_types",
|
||||||
|
":optimization_registry",
|
||||||
|
":partitioning_utils",
|
||||||
|
":placer",
|
||||||
|
":process_util",
|
||||||
|
":rendezvous_mgr",
|
||||||
|
":rendezvous_util",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:graph",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:lib_internal",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
"//tensorflow/core/profiler/lib:traceme",
|
||||||
|
"@com_google_absl//absl/algorithm:container",
|
||||||
|
"@com_google_absl//absl/memory",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "function_body",
|
name = "function_body",
|
||||||
srcs = ["function_body.cc"],
|
srcs = ["function_body.cc"],
|
||||||
@ -599,10 +688,25 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "function_def_utils",
|
||||||
|
srcs = ["function_def_utils.cc"],
|
||||||
|
hdrs = ["function_def_utils.h"],
|
||||||
|
copts = tf_copts(),
|
||||||
|
deps = [
|
||||||
|
":function_body",
|
||||||
|
":graph_constructor",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:graph",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "function_utils",
|
name = "function_utils",
|
||||||
srcs = ["function_utils.cc"],
|
srcs = ["function_utils.cc"],
|
||||||
hdrs = ["function_utils.h"],
|
hdrs = ["function_utils.h"],
|
||||||
|
copts = tf_copts(),
|
||||||
deps = [
|
deps = [
|
||||||
":function_body",
|
":function_body",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
@ -777,6 +881,21 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "isolate_placer_inspection_required_ops_pass",
|
||||||
|
srcs = ["isolate_placer_inspection_required_ops_pass.cc"],
|
||||||
|
hdrs = ["isolate_placer_inspection_required_ops_pass.h"],
|
||||||
|
copts = tf_copts(),
|
||||||
|
deps = [
|
||||||
|
":optimization_registry",
|
||||||
|
":placer_inspection_required_ops_utils",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:graph",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
],
|
||||||
|
alwayslink = 1,
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "local_device",
|
name = "local_device",
|
||||||
srcs = ["local_device.cc"],
|
srcs = ["local_device.cc"],
|
||||||
@ -803,6 +922,81 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "lower_case_op",
|
||||||
|
srcs = ["lower_case_op.cc"],
|
||||||
|
hdrs = ["lower_case_op.h"],
|
||||||
|
deps = [
|
||||||
|
":inline_function_utils",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:graph",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "lower_function_call_op",
|
||||||
|
srcs = ["lower_function_call_op.cc"],
|
||||||
|
hdrs = ["lower_function_call_op.h"],
|
||||||
|
copts = tf_copts(),
|
||||||
|
deps = [
|
||||||
|
":function_def_utils",
|
||||||
|
":inline_function_utils",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:graph",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"@com_google_absl//absl/algorithm:container",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "lower_functional_ops",
|
||||||
|
srcs = ["lower_functional_ops.cc"],
|
||||||
|
hdrs = ["lower_functional_ops.h"],
|
||||||
|
copts = tf_copts(),
|
||||||
|
deps = [
|
||||||
|
":function_utils",
|
||||||
|
":inline_function_utils",
|
||||||
|
":lower_case_op",
|
||||||
|
":lower_function_call_op",
|
||||||
|
":lower_if_op",
|
||||||
|
":lower_while_op",
|
||||||
|
":optimization_registry",
|
||||||
|
":session_options",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:graph",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
],
|
||||||
|
alwayslink = 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "lower_if_op",
|
||||||
|
srcs = ["lower_if_op.cc"],
|
||||||
|
hdrs = ["lower_if_op.h"],
|
||||||
|
copts = tf_copts(),
|
||||||
|
deps = [
|
||||||
|
":inline_function_utils",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:graph",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "lower_while_op",
|
||||||
|
srcs = ["lower_while_op.cc"],
|
||||||
|
hdrs = ["lower_while_op.h"],
|
||||||
|
copts = tf_copts(),
|
||||||
|
deps = [
|
||||||
|
":inline_function_utils",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:graph",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "memory_types",
|
name = "memory_types",
|
||||||
srcs = ["memory_types.cc"],
|
srcs = ["memory_types.cc"],
|
||||||
@ -918,6 +1112,37 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "placer",
|
||||||
|
srcs = ["placer.cc"],
|
||||||
|
hdrs = ["placer.h"],
|
||||||
|
copts = tf_copts(),
|
||||||
|
deps = [
|
||||||
|
":colocation_graph",
|
||||||
|
":device",
|
||||||
|
":device_set",
|
||||||
|
":session_options",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:graph",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "placer_inspection_required_ops_utils",
|
||||||
|
srcs = ["placer_inspection_required_ops_utils.cc"],
|
||||||
|
hdrs = ["placer_inspection_required_ops_utils.h"],
|
||||||
|
copts = tf_copts(),
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:graph",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
|
"@com_google_absl//absl/types:optional",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "process_state",
|
name = "process_state",
|
||||||
srcs = ["process_state.cc"],
|
srcs = ["process_state.cc"],
|
||||||
@ -1289,69 +1514,6 @@ cc_library(
|
|||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
tf_cuda_library(
|
|
||||||
name = "core_cpu_rump_impl",
|
|
||||||
srcs = [
|
|
||||||
"colocation_graph.cc",
|
|
||||||
"composite_device.cc",
|
|
||||||
"function.cc",
|
|
||||||
"inspecting_placer.cc",
|
|
||||||
"isolate_placer_inspection_required_ops_pass.cc",
|
|
||||||
"lower_case_op.cc",
|
|
||||||
"lower_function_call_op.cc",
|
|
||||||
"lower_functional_ops.cc",
|
|
||||||
"lower_if_op.cc",
|
|
||||||
"lower_while_op.cc",
|
|
||||||
"placer.cc",
|
|
||||||
"placer_inspection_required_ops_utils.cc",
|
|
||||||
"placer_inspection_required_ops_utils.h",
|
|
||||||
"process_function_library_runtime.cc",
|
|
||||||
],
|
|
||||||
hdrs = [":core_cpu_lib_headers"],
|
|
||||||
copts = tf_copts(),
|
|
||||||
deps = [
|
|
||||||
":device",
|
|
||||||
":entry",
|
|
||||||
":executor",
|
|
||||||
":executor_factory",
|
|
||||||
":function_body",
|
|
||||||
":function_optimization_registry",
|
|
||||||
":gradients",
|
|
||||||
":graph_constructor",
|
|
||||||
":graph_optimizer",
|
|
||||||
":graph_view",
|
|
||||||
":local_executor_params",
|
|
||||||
":immutable_executor_state",
|
|
||||||
":inline_function_utils",
|
|
||||||
":input_colocation_exemption_registry",
|
|
||||||
":pending_counts",
|
|
||||||
":propagator_debug_utils",
|
|
||||||
":propagator_state",
|
|
||||||
":session_options",
|
|
||||||
":simple_propagator_state",
|
|
||||||
":single_threaded_cpu_device",
|
|
||||||
"//tensorflow/core:graph",
|
|
||||||
"//tensorflow/core:framework",
|
|
||||||
"//tensorflow/core:framework_internal",
|
|
||||||
"//tensorflow/core:lib",
|
|
||||||
"//tensorflow/core:lib_internal",
|
|
||||||
"//tensorflow/core:protos_all_cc",
|
|
||||||
"@com_google_absl//absl/algorithm:container",
|
|
||||||
"@com_google_absl//absl/container:flat_hash_map",
|
|
||||||
"@com_google_absl//absl/container:flat_hash_set",
|
|
||||||
"@com_google_absl//absl/memory",
|
|
||||||
"@com_google_absl//absl/strings",
|
|
||||||
"@com_google_absl//absl/types:optional",
|
|
||||||
"@com_google_absl//absl/types:variant",
|
|
||||||
"//tensorflow/core/public:version",
|
|
||||||
"//tensorflow/core/grappler/utils:functions",
|
|
||||||
"//tensorflow/core/profiler/lib:annotated_traceme",
|
|
||||||
"//tensorflow/core/profiler/lib:scoped_annotation",
|
|
||||||
"//tensorflow/core/profiler/lib:traceme",
|
|
||||||
] + mkl_deps(),
|
|
||||||
alwayslink = 1,
|
|
||||||
)
|
|
||||||
|
|
||||||
tf_cuda_library(
|
tf_cuda_library(
|
||||||
name = "core_cpu_impl",
|
name = "core_cpu_impl",
|
||||||
hdrs = [":core_cpu_lib_headers"],
|
hdrs = [":core_cpu_lib_headers"],
|
||||||
@ -1367,7 +1529,6 @@ tf_cuda_library(
|
|||||||
":collective_rma_local",
|
":collective_rma_local",
|
||||||
":collective_util",
|
":collective_util",
|
||||||
":copy_tensor",
|
":copy_tensor",
|
||||||
":core_cpu_rump_impl",
|
|
||||||
":costmodel_manager",
|
":costmodel_manager",
|
||||||
":debugger_state_interface",
|
":debugger_state_interface",
|
||||||
":device",
|
":device",
|
||||||
@ -1376,11 +1537,14 @@ tf_cuda_library(
|
|||||||
":device_resolver_local",
|
":device_resolver_local",
|
||||||
":device_set",
|
":device_set",
|
||||||
":entry",
|
":entry",
|
||||||
|
":function",
|
||||||
":graph_def_builder_util",
|
":graph_def_builder_util",
|
||||||
":graph_view",
|
":graph_view",
|
||||||
":hierarchical_tree_broadcaster",
|
":hierarchical_tree_broadcaster",
|
||||||
":input_colocation_exemption_registry",
|
":input_colocation_exemption_registry",
|
||||||
|
":isolate_placer_inspection_required_ops_pass",
|
||||||
":local_device",
|
":local_device",
|
||||||
|
":lower_functional_ops",
|
||||||
":memory_types",
|
":memory_types",
|
||||||
":mkl_cpu_allocator",
|
":mkl_cpu_allocator",
|
||||||
":mkl_layout_pass",
|
":mkl_layout_pass",
|
||||||
@ -1389,6 +1553,7 @@ tf_cuda_library(
|
|||||||
":parallel_concat_optimizer",
|
":parallel_concat_optimizer",
|
||||||
":partitioning_utils",
|
":partitioning_utils",
|
||||||
":pending_counts",
|
":pending_counts",
|
||||||
|
":placer",
|
||||||
":pool_allocator",
|
":pool_allocator",
|
||||||
":process_state",
|
":process_state",
|
||||||
":process_util",
|
":process_util",
|
||||||
|
@ -29,7 +29,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/common_runtime/composite_device.h"
|
#include "tensorflow/core/common_runtime/composite_device.h"
|
||||||
#include "tensorflow/core/common_runtime/device.h"
|
#include "tensorflow/core/common_runtime/device.h"
|
||||||
#include "tensorflow/core/common_runtime/device_set.h"
|
#include "tensorflow/core/common_runtime/device_set.h"
|
||||||
#include "tensorflow/core/common_runtime/function.h"
|
|
||||||
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
|
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
|
||||||
#include "tensorflow/core/common_runtime/inspecting_placer.h"
|
#include "tensorflow/core/common_runtime/inspecting_placer.h"
|
||||||
#include "tensorflow/core/common_runtime/partitioning_utils.h"
|
#include "tensorflow/core/common_runtime/partitioning_utils.h"
|
||||||
|
@ -29,6 +29,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/common_runtime/graph_optimizer.h"
|
#include "tensorflow/core/common_runtime/graph_optimizer.h"
|
||||||
#include "tensorflow/core/common_runtime/inline_function_utils.h"
|
#include "tensorflow/core/common_runtime/inline_function_utils.h"
|
||||||
#include "tensorflow/core/common_runtime/memory_types.h"
|
#include "tensorflow/core/common_runtime/memory_types.h"
|
||||||
|
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
|
||||||
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
|
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
|
||||||
#include "tensorflow/core/framework/collective.h"
|
#include "tensorflow/core/framework/collective.h"
|
||||||
#include "tensorflow/core/framework/function.h"
|
#include "tensorflow/core/framework/function.h"
|
||||||
@ -1395,38 +1396,4 @@ std::unique_ptr<FunctionBody> SymbolicGradient(const FunctionBody& f) {
|
|||||||
return SymbolicGradientHelper(f).Compute();
|
return SymbolicGradientHelper(f).Compute();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status FunctionDefToBodyHelper(
|
|
||||||
const FunctionDef& fdef, const AttrSlice& attrs,
|
|
||||||
const FunctionLibraryDefinition* const lib_def,
|
|
||||||
const std::function<Status(const string&, const OpDef**)>& get_func_sig,
|
|
||||||
std::unique_ptr<FunctionBody>* fbody) {
|
|
||||||
// Instantiates the function template into a graph def.
|
|
||||||
InstantiationResult result;
|
|
||||||
TF_RETURN_IF_ERROR(InstantiateFunction(fdef, attrs, get_func_sig, &result));
|
|
||||||
|
|
||||||
std::unique_ptr<Graph> graph(new Graph(lib_def));
|
|
||||||
GraphConstructorOptions opts;
|
|
||||||
opts.allow_internal_ops = true;
|
|
||||||
opts.expect_device_spec = false;
|
|
||||||
TF_RETURN_IF_ERROR(ConvertNodeDefsToGraph(opts, result.nodes, graph.get()));
|
|
||||||
|
|
||||||
// Call BuildControlFlowInfo to validate that this function body has
|
|
||||||
// well-formed control flow.
|
|
||||||
std::vector<ControlFlowInfo> dummy;
|
|
||||||
TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph.get(), &dummy));
|
|
||||||
|
|
||||||
*fbody = absl::make_unique<FunctionBody>(fdef, result.arg_types,
|
|
||||||
result.ret_types, graph.release());
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
Status FunctionDefToBodyHelper(const FunctionDef& fdef, const AttrSlice& attrs,
|
|
||||||
const FunctionLibraryDefinition* lib_def,
|
|
||||||
std::unique_ptr<FunctionBody>* fbody) {
|
|
||||||
const auto get_func_sig = [&lib_def](const string& op, const OpDef** sig) {
|
|
||||||
return lib_def->LookUpOpDef(op, sig);
|
|
||||||
};
|
|
||||||
return FunctionDefToBodyHelper(fdef, attrs, lib_def, get_func_sig, fbody);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // end namespace tensorflow
|
} // end namespace tensorflow
|
||||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/common_runtime/device.h"
|
#include "tensorflow/core/common_runtime/device.h"
|
||||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||||
#include "tensorflow/core/common_runtime/function_body.h"
|
#include "tensorflow/core/common_runtime/function_body.h"
|
||||||
|
#include "tensorflow/core/common_runtime/function_def_utils.h"
|
||||||
#include "tensorflow/core/common_runtime/function_utils.h"
|
#include "tensorflow/core/common_runtime/function_utils.h"
|
||||||
#include "tensorflow/core/common_runtime/graph_optimizer.h"
|
#include "tensorflow/core/common_runtime/graph_optimizer.h"
|
||||||
#include "tensorflow/core/common_runtime/inline_function_utils.h"
|
#include "tensorflow/core/common_runtime/inline_function_utils.h"
|
||||||
@ -80,26 +81,6 @@ std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
|
|||||||
// TODO(zhifengc): Asks math expert to say the comment again.
|
// TODO(zhifengc): Asks math expert to say the comment again.
|
||||||
std::unique_ptr<FunctionBody> SymbolicGradient(const FunctionBody& f);
|
std::unique_ptr<FunctionBody> SymbolicGradient(const FunctionBody& f);
|
||||||
|
|
||||||
// Returns true iff `n` represents a function call. `n` can be a native
|
|
||||||
// function call (n.type_string() is the function name),
|
|
||||||
// a PartitionedCall/StatefulPartitionedCall, or a SymbolicGradient (which
|
|
||||||
// has been deprecated for a while).
|
|
||||||
bool IsFunctionCall(const FunctionLibraryDefinition& lib_def, const Node& n);
|
|
||||||
|
|
||||||
// Instantiates FunctionDef into a graph. Set *fbody to point to the
|
|
||||||
// FunctionBody that holds the instantiated FunctionDef.
|
|
||||||
Status FunctionDefToBodyHelper(const FunctionDef& fdef, const AttrSlice& attrs,
|
|
||||||
const FunctionLibraryDefinition* lib_def,
|
|
||||||
std::unique_ptr<FunctionBody>* fbody);
|
|
||||||
|
|
||||||
// Instantiates FunctionDef into a graph. Set *fbody to point to the
|
|
||||||
// FunctionBody that holds the instantiated FunctionDef. Use custom function
|
|
||||||
// signature lookup, in case instantiated function is not in the 'lib_def'.
|
|
||||||
Status FunctionDefToBodyHelper(
|
|
||||||
const FunctionDef& fdef, const AttrSlice& attrs,
|
|
||||||
const FunctionLibraryDefinition* lib_def,
|
|
||||||
const std::function<Status(const string&, const OpDef**)>& get_func_sig,
|
|
||||||
std::unique_ptr<FunctionBody>* fbody);
|
|
||||||
} // end namespace tensorflow
|
} // end namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_H_
|
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_H_
|
||||||
|
63
tensorflow/core/common_runtime/function_def_utils.cc
Normal file
63
tensorflow/core/common_runtime/function_def_utils.cc
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/common_runtime/function_def_utils.h"
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/core/common_runtime/function_body.h"
|
||||||
|
#include "tensorflow/core/common_runtime/graph_constructor.h"
|
||||||
|
#include "tensorflow/core/framework/function.h"
|
||||||
|
#include "tensorflow/core/framework/node_def_util.h"
|
||||||
|
#include "tensorflow/core/graph/control_flow.h"
|
||||||
|
#include "tensorflow/core/graph/graph.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
Status FunctionDefToBodyHelper(
|
||||||
|
const FunctionDef& fdef, const AttrSlice& attrs,
|
||||||
|
const FunctionLibraryDefinition* const lib_def,
|
||||||
|
const std::function<Status(const string&, const OpDef**)>& get_func_sig,
|
||||||
|
std::unique_ptr<FunctionBody>* fbody) {
|
||||||
|
// Instantiates the function template into a graph def.
|
||||||
|
InstantiationResult result;
|
||||||
|
TF_RETURN_IF_ERROR(InstantiateFunction(fdef, attrs, get_func_sig, &result));
|
||||||
|
|
||||||
|
std::unique_ptr<Graph> graph(new Graph(lib_def));
|
||||||
|
GraphConstructorOptions opts;
|
||||||
|
opts.allow_internal_ops = true;
|
||||||
|
opts.expect_device_spec = false;
|
||||||
|
TF_RETURN_IF_ERROR(ConvertNodeDefsToGraph(opts, result.nodes, graph.get()));
|
||||||
|
|
||||||
|
// Call BuildControlFlowInfo to validate that this function body has
|
||||||
|
// well-formed control flow.
|
||||||
|
std::vector<ControlFlowInfo> dummy;
|
||||||
|
TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph.get(), &dummy));
|
||||||
|
|
||||||
|
*fbody = absl::make_unique<FunctionBody>(fdef, result.arg_types,
|
||||||
|
result.ret_types, graph.release());
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status FunctionDefToBodyHelper(const FunctionDef& fdef, const AttrSlice& attrs,
|
||||||
|
const FunctionLibraryDefinition* lib_def,
|
||||||
|
std::unique_ptr<FunctionBody>* fbody) {
|
||||||
|
const auto get_func_sig = [&lib_def](const string& op, const OpDef** sig) {
|
||||||
|
return lib_def->LookUpOpDef(op, sig);
|
||||||
|
};
|
||||||
|
return FunctionDefToBodyHelper(fdef, attrs, lib_def, get_func_sig, fbody);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // end namespace tensorflow
|
49
tensorflow/core/common_runtime/function_def_utils.h
Normal file
49
tensorflow/core/common_runtime/function_def_utils.h
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_DEF_UTILS_H_
|
||||||
|
#define TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_DEF_UTILS_H_
|
||||||
|
|
||||||
|
#include <functional>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
class AttrSlice;
|
||||||
|
struct FunctionBody;
|
||||||
|
class FunctionDef;
|
||||||
|
class FunctionLibraryDefinition;
|
||||||
|
class OpDef;
|
||||||
|
|
||||||
|
// Instantiates FunctionDef into a graph. Set *fbody to point to the
|
||||||
|
// FunctionBody that holds the instantiated FunctionDef.
|
||||||
|
Status FunctionDefToBodyHelper(const FunctionDef& fdef, const AttrSlice& attrs,
|
||||||
|
const FunctionLibraryDefinition* lib_def,
|
||||||
|
std::unique_ptr<FunctionBody>* fbody);
|
||||||
|
|
||||||
|
// Instantiates FunctionDef into a graph. Set *fbody to point to the
|
||||||
|
// FunctionBody that holds the instantiated FunctionDef. Use custom function
|
||||||
|
// signature lookup, in case instantiated function is not in the 'lib_def'.
|
||||||
|
Status FunctionDefToBodyHelper(
|
||||||
|
const FunctionDef& fdef, const AttrSlice& attrs,
|
||||||
|
const FunctionLibraryDefinition* lib_def,
|
||||||
|
const std::function<Status(const string&, const OpDef**)>& get_func_sig,
|
||||||
|
std::unique_ptr<FunctionBody>* fbody);
|
||||||
|
|
||||||
|
} // end namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_DEF_UTILS_H_
|
@ -43,6 +43,11 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
|
/*static*/ constexpr const char* const
|
||||||
|
LowerFunctionalOpsConstants::kLowerUsingSwitchMergeAttr;
|
||||||
|
/*static*/ constexpr const char* const
|
||||||
|
LowerFunctionalOpsConstants::kLowerAsMultiDeviceFunctionAttr;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
// A few string constant used throughout this module.
|
// A few string constant used throughout this module.
|
||||||
static constexpr const char* const kArgOp = FunctionLibraryDefinition::kArgOp;
|
static constexpr const char* const kArgOp = FunctionLibraryDefinition::kArgOp;
|
||||||
|
@ -231,6 +231,13 @@ inline bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph) {
|
|||||||
return ExpandInlineFunctions(lib, graph, ExpandInlineFunctionsOptions());
|
return ExpandInlineFunctions(lib, graph, ExpandInlineFunctionsOptions());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct LowerFunctionalOpsConstants {
|
||||||
|
static constexpr const char* const kLowerUsingSwitchMergeAttr =
|
||||||
|
"_lower_using_switch_merge";
|
||||||
|
static constexpr const char* const kLowerAsMultiDeviceFunctionAttr =
|
||||||
|
"_lower_as_multi_device_function";
|
||||||
|
};
|
||||||
|
|
||||||
} // end namespace tensorflow
|
} // end namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_INLINE_FUNCTION_UTILS_H_
|
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_INLINE_FUNCTION_UTILS_H_
|
||||||
|
@ -21,7 +21,8 @@ limitations under the License.
|
|||||||
#include "absl/strings/str_join.h"
|
#include "absl/strings/str_join.h"
|
||||||
#include "tensorflow/core/common_runtime/colocation_graph.h"
|
#include "tensorflow/core/common_runtime/colocation_graph.h"
|
||||||
#include "tensorflow/core/common_runtime/device.h"
|
#include "tensorflow/core/common_runtime/device.h"
|
||||||
#include "tensorflow/core/common_runtime/function.h"
|
#include "tensorflow/core/common_runtime/function_body.h"
|
||||||
|
#include "tensorflow/core/common_runtime/function_def_utils.h"
|
||||||
#include "tensorflow/core/common_runtime/placer_inspection_required_ops_utils.h"
|
#include "tensorflow/core/common_runtime/placer_inspection_required_ops_utils.h"
|
||||||
#include "tensorflow/core/framework/function.h"
|
#include "tensorflow/core/framework/function.h"
|
||||||
#include "tensorflow/core/framework/node_def_util.h"
|
#include "tensorflow/core/framework/node_def_util.h"
|
||||||
|
@ -15,8 +15,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/core/common_runtime/lower_case_op.h"
|
#include "tensorflow/core/common_runtime/lower_case_op.h"
|
||||||
|
|
||||||
#include "tensorflow/core/common_runtime/function.h"
|
#include "tensorflow/core/common_runtime/inline_function_utils.h"
|
||||||
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
|
|
||||||
#include "tensorflow/core/framework/node_def_builder.h"
|
#include "tensorflow/core/framework/node_def_builder.h"
|
||||||
#include "tensorflow/core/graph/graph.h"
|
#include "tensorflow/core/graph/graph.h"
|
||||||
#include "tensorflow/core/graph/node_builder.h"
|
#include "tensorflow/core/graph/node_builder.h"
|
||||||
@ -29,7 +28,7 @@ namespace {
|
|||||||
using NodeOut = NodeBuilder::NodeOut;
|
using NodeOut = NodeBuilder::NodeOut;
|
||||||
|
|
||||||
constexpr const char* const kLowerAsMultiDeviceFunctionAttr =
|
constexpr const char* const kLowerAsMultiDeviceFunctionAttr =
|
||||||
LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr;
|
LowerFunctionalOpsConstants::kLowerAsMultiDeviceFunctionAttr;
|
||||||
|
|
||||||
// Convenience builder to make it easy to construct a case with a single
|
// Convenience builder to make it easy to construct a case with a single
|
||||||
// function call in each branch. This first converts the Case node
|
// function call in each branch. This first converts the Case node
|
||||||
|
@ -16,11 +16,13 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_CASE_OP_H_
|
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_CASE_OP_H_
|
||||||
#define TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_CASE_OP_H_
|
#define TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_CASE_OP_H_
|
||||||
|
|
||||||
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
|
class Graph;
|
||||||
|
class Node;
|
||||||
|
|
||||||
// Replaces Case node `n` with a lowered form that uses _SwitchN/Merge nodes.
|
// Replaces Case node `n` with a lowered form that uses _SwitchN/Merge nodes.
|
||||||
Status RewriteCaseNode(Node* n, Graph* g, bool keep_node_fetchable);
|
Status RewriteCaseNode(Node* n, Graph* g, bool keep_node_fetchable);
|
||||||
|
|
||||||
|
@ -16,9 +16,8 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/common_runtime/lower_function_call_op.h"
|
#include "tensorflow/core/common_runtime/lower_function_call_op.h"
|
||||||
|
|
||||||
#include "absl/algorithm/container.h"
|
#include "absl/algorithm/container.h"
|
||||||
#include "tensorflow/core/common_runtime/function.h"
|
#include "tensorflow/core/common_runtime/function_def_utils.h"
|
||||||
#include "tensorflow/core/common_runtime/inline_function_utils.h"
|
#include "tensorflow/core/common_runtime/inline_function_utils.h"
|
||||||
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
|
|
||||||
#include "tensorflow/core/framework/node_def_util.h"
|
#include "tensorflow/core/framework/node_def_util.h"
|
||||||
#include "tensorflow/core/graph/graph.h"
|
#include "tensorflow/core/graph/graph.h"
|
||||||
#include "tensorflow/core/graph/graph_node_util.h"
|
#include "tensorflow/core/graph/graph_node_util.h"
|
||||||
@ -30,7 +29,7 @@ using KeepCallerNode = InlineFunctionBodyOptions::KeepCallerNode;
|
|||||||
using OutputControlSrc = InlineFunctionBodyOptions::OutputControlSource;
|
using OutputControlSrc = InlineFunctionBodyOptions::OutputControlSource;
|
||||||
|
|
||||||
constexpr const char* const kLowerAsMultiDeviceFunctionAttr =
|
constexpr const char* const kLowerAsMultiDeviceFunctionAttr =
|
||||||
LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr;
|
LowerFunctionalOpsConstants::kLowerAsMultiDeviceFunctionAttr;
|
||||||
|
|
||||||
bool LowerAsMultiDeviceFunction(const Node* n) {
|
bool LowerAsMultiDeviceFunction(const Node* n) {
|
||||||
if (n->IsPartitionedCall()) return true;
|
if (n->IsPartitionedCall()) return true;
|
||||||
|
@ -16,11 +16,14 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_FUNCTION_CALL_OP_H_
|
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_FUNCTION_CALL_OP_H_
|
||||||
#define TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_FUNCTION_CALL_OP_H_
|
#define TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_FUNCTION_CALL_OP_H_
|
||||||
|
|
||||||
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
|
class FunctionLibraryDefinition;
|
||||||
|
class Graph;
|
||||||
|
class Node;
|
||||||
|
|
||||||
// Replaces function call node `n` with its function body. Uses
|
// Replaces function call node `n` with its function body. Uses
|
||||||
// InlineFunctionBody from `common_runtime/function.{h,cc}`. If function
|
// InlineFunctionBody from `common_runtime/function.{h,cc}`. If function
|
||||||
// inlining is not possible or safe (see ValidateInlining), leaves the graph in
|
// inlining is not possible or safe (see ValidateInlining), leaves the graph in
|
||||||
|
@ -15,7 +15,8 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
|
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
|
||||||
|
|
||||||
#include "tensorflow/core/common_runtime/function.h"
|
#include "tensorflow/core/common_runtime/function_utils.h"
|
||||||
|
#include "tensorflow/core/common_runtime/inline_function_utils.h"
|
||||||
#include "tensorflow/core/common_runtime/lower_case_op.h"
|
#include "tensorflow/core/common_runtime/lower_case_op.h"
|
||||||
#include "tensorflow/core/common_runtime/lower_function_call_op.h"
|
#include "tensorflow/core/common_runtime/lower_function_call_op.h"
|
||||||
#include "tensorflow/core/common_runtime/lower_if_op.h"
|
#include "tensorflow/core/common_runtime/lower_if_op.h"
|
||||||
@ -27,17 +28,12 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
/*static*/ constexpr const char* const
|
|
||||||
LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr;
|
|
||||||
/*static*/ constexpr const char* const
|
|
||||||
LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr;
|
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
constexpr const char* const kLowerUsingSwitchMergeAttr =
|
constexpr const char* const kLowerUsingSwitchMergeAttr =
|
||||||
LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr;
|
LowerFunctionalOpsConstants::kLowerUsingSwitchMergeAttr;
|
||||||
constexpr const char* const kLowerAsMultiDeviceFunctionAttr =
|
constexpr const char* const kLowerAsMultiDeviceFunctionAttr =
|
||||||
LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr;
|
LowerFunctionalOpsConstants::kLowerAsMultiDeviceFunctionAttr;
|
||||||
|
|
||||||
constexpr const char* const kTpuReplicateAttr = "_tpu_replicate";
|
constexpr const char* const kTpuReplicateAttr = "_tpu_replicate";
|
||||||
constexpr const char* const kXlaClusterAttr = "_xla_compile_id";
|
constexpr const char* const kXlaClusterAttr = "_xla_compile_id";
|
||||||
@ -173,7 +169,7 @@ Status LowerFunctionalOpsPass::Run(
|
|||||||
DCHECK(!lower_control_flow(n))
|
DCHECK(!lower_control_flow(n))
|
||||||
<< "Node " << FormatNodeForError(*n) << " of type "
|
<< "Node " << FormatNodeForError(*n) << " of type "
|
||||||
<< n->type_string() << " has '"
|
<< n->type_string() << " has '"
|
||||||
<< LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr
|
<< LowerFunctionalOpsConstants::kLowerUsingSwitchMergeAttr
|
||||||
<< "' attr set but it does not support lowering.\n";
|
<< "' attr set but it does not support lowering.\n";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
#define TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_FUNCTIONAL_OPS_H_
|
#define TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_FUNCTIONAL_OPS_H_
|
||||||
|
|
||||||
#include "absl/types/optional.h"
|
#include "absl/types/optional.h"
|
||||||
|
#include "tensorflow/core/common_runtime/inline_function_utils.h"
|
||||||
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
|
||||||
@ -38,9 +39,9 @@ class LowerFunctionalOpsPass : public GraphOptimizationPass {
|
|||||||
Status Run(const GraphOptimizationPassOptions& options) override;
|
Status Run(const GraphOptimizationPassOptions& options) override;
|
||||||
|
|
||||||
static constexpr const char* const kLowerUsingSwitchMergeAttr =
|
static constexpr const char* const kLowerUsingSwitchMergeAttr =
|
||||||
"_lower_using_switch_merge";
|
LowerFunctionalOpsConstants::kLowerUsingSwitchMergeAttr;
|
||||||
static constexpr const char* const kLowerAsMultiDeviceFunctionAttr =
|
static constexpr const char* const kLowerAsMultiDeviceFunctionAttr =
|
||||||
"_lower_as_multi_device_function";
|
LowerFunctionalOpsConstants::kLowerAsMultiDeviceFunctionAttr;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// If defined use the value to control if functional ops must be fetchable
|
// If defined use the value to control if functional ops must be fetchable
|
||||||
|
@ -15,8 +15,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/core/common_runtime/lower_if_op.h"
|
#include "tensorflow/core/common_runtime/lower_if_op.h"
|
||||||
|
|
||||||
#include "tensorflow/core/common_runtime/function.h"
|
#include "tensorflow/core/common_runtime/inline_function_utils.h"
|
||||||
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
|
|
||||||
#include "tensorflow/core/framework/node_def_builder.h"
|
#include "tensorflow/core/framework/node_def_builder.h"
|
||||||
#include "tensorflow/core/graph/graph.h"
|
#include "tensorflow/core/graph/graph.h"
|
||||||
#include "tensorflow/core/graph/node_builder.h"
|
#include "tensorflow/core/graph/node_builder.h"
|
||||||
@ -27,7 +26,7 @@ namespace {
|
|||||||
using NodeOut = NodeBuilder::NodeOut;
|
using NodeOut = NodeBuilder::NodeOut;
|
||||||
|
|
||||||
constexpr const char* const kLowerAsMultiDeviceFunctionAttr =
|
constexpr const char* const kLowerAsMultiDeviceFunctionAttr =
|
||||||
LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr;
|
LowerFunctionalOpsConstants::kLowerAsMultiDeviceFunctionAttr;
|
||||||
|
|
||||||
// Convenience builder to make it easy to construct a conditional with a single
|
// Convenience builder to make it easy to construct a conditional with a single
|
||||||
// function call in the then and else branch. This first converts the if node
|
// function call in the then and else branch. This first converts the if node
|
||||||
|
@ -16,11 +16,13 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_IF_OP_H_
|
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_IF_OP_H_
|
||||||
#define TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_IF_OP_H_
|
#define TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_IF_OP_H_
|
||||||
|
|
||||||
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
|
class Graph;
|
||||||
|
class Node;
|
||||||
|
|
||||||
// Replaces If node `n` with its lowered form that uses Switch and Merge nodes.
|
// Replaces If node `n` with its lowered form that uses Switch and Merge nodes.
|
||||||
Status RewriteIfNode(Node* n, Graph* g, bool keep_node_fetchable);
|
Status RewriteIfNode(Node* n, Graph* g, bool keep_node_fetchable);
|
||||||
|
|
||||||
|
@ -15,9 +15,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/core/common_runtime/lower_while_op.h"
|
#include "tensorflow/core/common_runtime/lower_while_op.h"
|
||||||
|
|
||||||
#include "tensorflow/core/common_runtime/function.h"
|
#include "tensorflow/core/common_runtime/inline_function_utils.h"
|
||||||
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
|
|
||||||
#include "tensorflow/core/common_runtime/lower_if_op.h"
|
|
||||||
#include "tensorflow/core/framework/node_def_builder.h"
|
#include "tensorflow/core/framework/node_def_builder.h"
|
||||||
#include "tensorflow/core/framework/types.pb.h"
|
#include "tensorflow/core/framework/types.pb.h"
|
||||||
#include "tensorflow/core/graph/graph.h"
|
#include "tensorflow/core/graph/graph.h"
|
||||||
@ -30,7 +28,7 @@ namespace {
|
|||||||
using NodeOut = NodeBuilder::NodeOut;
|
using NodeOut = NodeBuilder::NodeOut;
|
||||||
|
|
||||||
constexpr const char* const kLowerAsMultiDeviceFunctionAttr =
|
constexpr const char* const kLowerAsMultiDeviceFunctionAttr =
|
||||||
LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr;
|
LowerFunctionalOpsConstants::kLowerAsMultiDeviceFunctionAttr;
|
||||||
|
|
||||||
// Helper to convert a functional While op to its lowered form.
|
// Helper to convert a functional While op to its lowered form.
|
||||||
//
|
//
|
||||||
|
@ -16,11 +16,13 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_WHILE_OP_H_
|
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_WHILE_OP_H_
|
||||||
#define TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_WHILE_OP_H_
|
#define TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_WHILE_OP_H_
|
||||||
|
|
||||||
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
|
class Graph;
|
||||||
|
class Node;
|
||||||
|
|
||||||
// Replaces While node `n` with its lowered form that uses Enter, Exit, Switch,
|
// Replaces While node `n` with its lowered form that uses Enter, Exit, Switch,
|
||||||
// Merge, NextIteration and LoopCond nodes.
|
// Merge, NextIteration and LoopCond nodes.
|
||||||
Status RewriteWhileNode(Node* n, Graph* g, bool keep_node_fetchable);
|
Status RewriteWhileNode(Node* n, Graph* g, bool keep_node_fetchable);
|
||||||
|
@ -19,7 +19,6 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
#include "absl/types/optional.h"
|
#include "absl/types/optional.h"
|
||||||
#include "tensorflow/core/common_runtime/function.h"
|
|
||||||
#include "tensorflow/core/framework/function.h"
|
#include "tensorflow/core/framework/function.h"
|
||||||
#include "tensorflow/core/framework/node_def_builder.h"
|
#include "tensorflow/core/framework/node_def_builder.h"
|
||||||
#include "tensorflow/core/graph/graph.h"
|
#include "tensorflow/core/graph/graph.h"
|
||||||
|
Loading…
Reference in New Issue
Block a user