[Build cleanup] Split "core_cpu_impl" into fine-grained targets (5/5).
This changes moves the remaining sources in "core_cpu_rump_impl" into fine-grained targets. It leaves two circular dependencies in place:
* ":colocation_graph" includes both "colocation_graph.{h,cc}" and "inspecting_placer.{h,cc}" because `ColocationGraph` has an `InspectingPlacer`, and an `InspectingPlacer` creates a `ColocationGraph`.
* ":function" includes both "function.{h,cc}" and "process_function_library_runtime.{h,cc}" because `ProcessFunctionLibraryRuntime` calls `NewFunctionLibraryRuntime()` and `FunctionLibraryRuntimeImpl` invokes methods on `ProcessFunctionLibraryRuntime`. Several files also depend on the indirect inclusion of "process_function_library_runtime.h" in "function.h".
PiperOrigin-RevId: 308949398
Change-Id: I65e0664556e9ee7659640a1ed23fdb4072ab16c2
			
			
This commit is contained in:
		
							parent
							
								
									4bfe1dce64
								
							
						
					
					
						commit
						913b88bd0d
					
				@ -157,6 +157,7 @@ filegroup(
 | 
			
		||||
        "eval_const_tensor.h",
 | 
			
		||||
        "function.h",
 | 
			
		||||
        "function_body.h",
 | 
			
		||||
        "function_def_utils.h",
 | 
			
		||||
        "function_utils.h",
 | 
			
		||||
        "graph_constructor.h",
 | 
			
		||||
        "graph_def_builder_util.h",
 | 
			
		||||
@ -399,6 +400,50 @@ cc_library(
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
cc_library(
 | 
			
		||||
    name = "colocation_graph",
 | 
			
		||||
    srcs = [
 | 
			
		||||
        "colocation_graph.cc",
 | 
			
		||||
        "inspecting_placer.cc",
 | 
			
		||||
    ],
 | 
			
		||||
    hdrs = [
 | 
			
		||||
        "colocation_graph.h",
 | 
			
		||||
        "inspecting_placer.h",
 | 
			
		||||
    ],
 | 
			
		||||
    copts = tf_copts(),
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":composite_device",
 | 
			
		||||
        ":device",
 | 
			
		||||
        ":device_set",
 | 
			
		||||
        ":function_body",
 | 
			
		||||
        ":function_def_utils",
 | 
			
		||||
        ":input_colocation_exemption_registry",
 | 
			
		||||
        ":partitioning_utils",
 | 
			
		||||
        ":placer_inspection_required_ops_utils",
 | 
			
		||||
        "//tensorflow/core:framework",
 | 
			
		||||
        "//tensorflow/core:graph",
 | 
			
		||||
        "//tensorflow/core:lib",
 | 
			
		||||
        "//tensorflow/core:protos_all_cc",
 | 
			
		||||
        "@com_google_absl//absl/algorithm:container",
 | 
			
		||||
        "@com_google_absl//absl/container:flat_hash_set",
 | 
			
		||||
        "@com_google_absl//absl/strings",
 | 
			
		||||
        "@com_google_absl//absl/types:optional",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
cc_library(
 | 
			
		||||
    name = "composite_device",
 | 
			
		||||
    srcs = ["composite_device.cc"],
 | 
			
		||||
    hdrs = ["composite_device.h"],
 | 
			
		||||
    copts = tf_copts(),
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":device",
 | 
			
		||||
        "//tensorflow/core:framework",
 | 
			
		||||
        "//tensorflow/core:lib",
 | 
			
		||||
        "@com_google_absl//absl/strings",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
cc_library(
 | 
			
		||||
    name = "constant_folding",
 | 
			
		||||
    srcs = ["constant_folding.cc"],
 | 
			
		||||
@ -574,6 +619,50 @@ cc_library(
 | 
			
		||||
    deps = ["//tensorflow/core:framework"],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
cc_library(
 | 
			
		||||
    name = "function",
 | 
			
		||||
    srcs = [
 | 
			
		||||
        "function.cc",
 | 
			
		||||
        "process_function_library_runtime.cc",
 | 
			
		||||
    ],
 | 
			
		||||
    hdrs = [
 | 
			
		||||
        "function.h",
 | 
			
		||||
        "process_function_library_runtime.h",
 | 
			
		||||
    ],
 | 
			
		||||
    copts = tf_copts(),
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":device",
 | 
			
		||||
        ":device_mgr",
 | 
			
		||||
        ":device_set",
 | 
			
		||||
        ":executor",
 | 
			
		||||
        ":executor_factory",
 | 
			
		||||
        ":function_body",
 | 
			
		||||
        ":function_def_utils",
 | 
			
		||||
        ":function_optimization_registry",
 | 
			
		||||
        ":function_utils",
 | 
			
		||||
        ":gradients",
 | 
			
		||||
        ":graph_constructor",
 | 
			
		||||
        ":graph_optimizer",
 | 
			
		||||
        ":inline_function_utils",
 | 
			
		||||
        ":memory_types",
 | 
			
		||||
        ":optimization_registry",
 | 
			
		||||
        ":partitioning_utils",
 | 
			
		||||
        ":placer",
 | 
			
		||||
        ":process_util",
 | 
			
		||||
        ":rendezvous_mgr",
 | 
			
		||||
        ":rendezvous_util",
 | 
			
		||||
        "//tensorflow/core:framework",
 | 
			
		||||
        "//tensorflow/core:graph",
 | 
			
		||||
        "//tensorflow/core:lib",
 | 
			
		||||
        "//tensorflow/core:lib_internal",
 | 
			
		||||
        "//tensorflow/core:protos_all_cc",
 | 
			
		||||
        "//tensorflow/core/profiler/lib:traceme",
 | 
			
		||||
        "@com_google_absl//absl/algorithm:container",
 | 
			
		||||
        "@com_google_absl//absl/memory",
 | 
			
		||||
        "@com_google_absl//absl/strings",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
cc_library(
 | 
			
		||||
    name = "function_body",
 | 
			
		||||
    srcs = ["function_body.cc"],
 | 
			
		||||
@ -599,10 +688,25 @@ cc_library(
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
cc_library(
 | 
			
		||||
    name = "function_def_utils",
 | 
			
		||||
    srcs = ["function_def_utils.cc"],
 | 
			
		||||
    hdrs = ["function_def_utils.h"],
 | 
			
		||||
    copts = tf_copts(),
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":function_body",
 | 
			
		||||
        ":graph_constructor",
 | 
			
		||||
        "//tensorflow/core:framework",
 | 
			
		||||
        "//tensorflow/core:graph",
 | 
			
		||||
        "//tensorflow/core:protos_all_cc",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
cc_library(
 | 
			
		||||
    name = "function_utils",
 | 
			
		||||
    srcs = ["function_utils.cc"],
 | 
			
		||||
    hdrs = ["function_utils.h"],
 | 
			
		||||
    copts = tf_copts(),
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":function_body",
 | 
			
		||||
        "//tensorflow/core:framework",
 | 
			
		||||
@ -777,6 +881,21 @@ cc_library(
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
cc_library(
 | 
			
		||||
    name = "isolate_placer_inspection_required_ops_pass",
 | 
			
		||||
    srcs = ["isolate_placer_inspection_required_ops_pass.cc"],
 | 
			
		||||
    hdrs = ["isolate_placer_inspection_required_ops_pass.h"],
 | 
			
		||||
    copts = tf_copts(),
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":optimization_registry",
 | 
			
		||||
        ":placer_inspection_required_ops_utils",
 | 
			
		||||
        "//tensorflow/core:framework",
 | 
			
		||||
        "//tensorflow/core:graph",
 | 
			
		||||
        "//tensorflow/core:lib",
 | 
			
		||||
    ],
 | 
			
		||||
    alwayslink = 1,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
cc_library(
 | 
			
		||||
    name = "local_device",
 | 
			
		||||
    srcs = ["local_device.cc"],
 | 
			
		||||
@ -803,6 +922,81 @@ cc_library(
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
cc_library(
 | 
			
		||||
    name = "lower_case_op",
 | 
			
		||||
    srcs = ["lower_case_op.cc"],
 | 
			
		||||
    hdrs = ["lower_case_op.h"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":inline_function_utils",
 | 
			
		||||
        "//tensorflow/core:framework",
 | 
			
		||||
        "//tensorflow/core:graph",
 | 
			
		||||
        "//tensorflow/core:lib",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
cc_library(
 | 
			
		||||
    name = "lower_function_call_op",
 | 
			
		||||
    srcs = ["lower_function_call_op.cc"],
 | 
			
		||||
    hdrs = ["lower_function_call_op.h"],
 | 
			
		||||
    copts = tf_copts(),
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":function_def_utils",
 | 
			
		||||
        ":inline_function_utils",
 | 
			
		||||
        "//tensorflow/core:framework",
 | 
			
		||||
        "//tensorflow/core:graph",
 | 
			
		||||
        "//tensorflow/core:lib",
 | 
			
		||||
        "@com_google_absl//absl/algorithm:container",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
cc_library(
 | 
			
		||||
    name = "lower_functional_ops",
 | 
			
		||||
    srcs = ["lower_functional_ops.cc"],
 | 
			
		||||
    hdrs = ["lower_functional_ops.h"],
 | 
			
		||||
    copts = tf_copts(),
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":function_utils",
 | 
			
		||||
        ":inline_function_utils",
 | 
			
		||||
        ":lower_case_op",
 | 
			
		||||
        ":lower_function_call_op",
 | 
			
		||||
        ":lower_if_op",
 | 
			
		||||
        ":lower_while_op",
 | 
			
		||||
        ":optimization_registry",
 | 
			
		||||
        ":session_options",
 | 
			
		||||
        "//tensorflow/core:framework",
 | 
			
		||||
        "//tensorflow/core:graph",
 | 
			
		||||
        "//tensorflow/core:lib",
 | 
			
		||||
    ],
 | 
			
		||||
    alwayslink = 1,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
cc_library(
 | 
			
		||||
    name = "lower_if_op",
 | 
			
		||||
    srcs = ["lower_if_op.cc"],
 | 
			
		||||
    hdrs = ["lower_if_op.h"],
 | 
			
		||||
    copts = tf_copts(),
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":inline_function_utils",
 | 
			
		||||
        "//tensorflow/core:framework",
 | 
			
		||||
        "//tensorflow/core:graph",
 | 
			
		||||
        "//tensorflow/core:lib",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
cc_library(
 | 
			
		||||
    name = "lower_while_op",
 | 
			
		||||
    srcs = ["lower_while_op.cc"],
 | 
			
		||||
    hdrs = ["lower_while_op.h"],
 | 
			
		||||
    copts = tf_copts(),
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":inline_function_utils",
 | 
			
		||||
        "//tensorflow/core:framework",
 | 
			
		||||
        "//tensorflow/core:graph",
 | 
			
		||||
        "//tensorflow/core:lib",
 | 
			
		||||
        "//tensorflow/core:protos_all_cc",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
cc_library(
 | 
			
		||||
    name = "memory_types",
 | 
			
		||||
    srcs = ["memory_types.cc"],
 | 
			
		||||
@ -918,6 +1112,37 @@ cc_library(
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
cc_library(
 | 
			
		||||
    name = "placer",
 | 
			
		||||
    srcs = ["placer.cc"],
 | 
			
		||||
    hdrs = ["placer.h"],
 | 
			
		||||
    copts = tf_copts(),
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":colocation_graph",
 | 
			
		||||
        ":device",
 | 
			
		||||
        ":device_set",
 | 
			
		||||
        ":session_options",
 | 
			
		||||
        "//tensorflow/core:framework",
 | 
			
		||||
        "//tensorflow/core:graph",
 | 
			
		||||
        "//tensorflow/core:lib",
 | 
			
		||||
        "//tensorflow/core:protos_all_cc",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
cc_library(
 | 
			
		||||
    name = "placer_inspection_required_ops_utils",
 | 
			
		||||
    srcs = ["placer_inspection_required_ops_utils.cc"],
 | 
			
		||||
    hdrs = ["placer_inspection_required_ops_utils.h"],
 | 
			
		||||
    copts = tf_copts(),
 | 
			
		||||
    deps = [
 | 
			
		||||
        "//tensorflow/core:framework",
 | 
			
		||||
        "//tensorflow/core:graph",
 | 
			
		||||
        "//tensorflow/core:lib",
 | 
			
		||||
        "@com_google_absl//absl/strings",
 | 
			
		||||
        "@com_google_absl//absl/types:optional",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
cc_library(
 | 
			
		||||
    name = "process_state",
 | 
			
		||||
    srcs = ["process_state.cc"],
 | 
			
		||||
@ -1289,69 +1514,6 @@ cc_library(
 | 
			
		||||
    alwayslink = 1,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
tf_cuda_library(
 | 
			
		||||
    name = "core_cpu_rump_impl",
 | 
			
		||||
    srcs = [
 | 
			
		||||
        "colocation_graph.cc",
 | 
			
		||||
        "composite_device.cc",
 | 
			
		||||
        "function.cc",
 | 
			
		||||
        "inspecting_placer.cc",
 | 
			
		||||
        "isolate_placer_inspection_required_ops_pass.cc",
 | 
			
		||||
        "lower_case_op.cc",
 | 
			
		||||
        "lower_function_call_op.cc",
 | 
			
		||||
        "lower_functional_ops.cc",
 | 
			
		||||
        "lower_if_op.cc",
 | 
			
		||||
        "lower_while_op.cc",
 | 
			
		||||
        "placer.cc",
 | 
			
		||||
        "placer_inspection_required_ops_utils.cc",
 | 
			
		||||
        "placer_inspection_required_ops_utils.h",
 | 
			
		||||
        "process_function_library_runtime.cc",
 | 
			
		||||
    ],
 | 
			
		||||
    hdrs = [":core_cpu_lib_headers"],
 | 
			
		||||
    copts = tf_copts(),
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":device",
 | 
			
		||||
        ":entry",
 | 
			
		||||
        ":executor",
 | 
			
		||||
        ":executor_factory",
 | 
			
		||||
        ":function_body",
 | 
			
		||||
        ":function_optimization_registry",
 | 
			
		||||
        ":gradients",
 | 
			
		||||
        ":graph_constructor",
 | 
			
		||||
        ":graph_optimizer",
 | 
			
		||||
        ":graph_view",
 | 
			
		||||
        ":local_executor_params",
 | 
			
		||||
        ":immutable_executor_state",
 | 
			
		||||
        ":inline_function_utils",
 | 
			
		||||
        ":input_colocation_exemption_registry",
 | 
			
		||||
        ":pending_counts",
 | 
			
		||||
        ":propagator_debug_utils",
 | 
			
		||||
        ":propagator_state",
 | 
			
		||||
        ":session_options",
 | 
			
		||||
        ":simple_propagator_state",
 | 
			
		||||
        ":single_threaded_cpu_device",
 | 
			
		||||
        "//tensorflow/core:graph",
 | 
			
		||||
        "//tensorflow/core:framework",
 | 
			
		||||
        "//tensorflow/core:framework_internal",
 | 
			
		||||
        "//tensorflow/core:lib",
 | 
			
		||||
        "//tensorflow/core:lib_internal",
 | 
			
		||||
        "//tensorflow/core:protos_all_cc",
 | 
			
		||||
        "@com_google_absl//absl/algorithm:container",
 | 
			
		||||
        "@com_google_absl//absl/container:flat_hash_map",
 | 
			
		||||
        "@com_google_absl//absl/container:flat_hash_set",
 | 
			
		||||
        "@com_google_absl//absl/memory",
 | 
			
		||||
        "@com_google_absl//absl/strings",
 | 
			
		||||
        "@com_google_absl//absl/types:optional",
 | 
			
		||||
        "@com_google_absl//absl/types:variant",
 | 
			
		||||
        "//tensorflow/core/public:version",
 | 
			
		||||
        "//tensorflow/core/grappler/utils:functions",
 | 
			
		||||
        "//tensorflow/core/profiler/lib:annotated_traceme",
 | 
			
		||||
        "//tensorflow/core/profiler/lib:scoped_annotation",
 | 
			
		||||
        "//tensorflow/core/profiler/lib:traceme",
 | 
			
		||||
    ] + mkl_deps(),
 | 
			
		||||
    alwayslink = 1,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
tf_cuda_library(
 | 
			
		||||
    name = "core_cpu_impl",
 | 
			
		||||
    hdrs = [":core_cpu_lib_headers"],
 | 
			
		||||
@ -1367,7 +1529,6 @@ tf_cuda_library(
 | 
			
		||||
        ":collective_rma_local",
 | 
			
		||||
        ":collective_util",
 | 
			
		||||
        ":copy_tensor",
 | 
			
		||||
        ":core_cpu_rump_impl",
 | 
			
		||||
        ":costmodel_manager",
 | 
			
		||||
        ":debugger_state_interface",
 | 
			
		||||
        ":device",
 | 
			
		||||
@ -1376,11 +1537,14 @@ tf_cuda_library(
 | 
			
		||||
        ":device_resolver_local",
 | 
			
		||||
        ":device_set",
 | 
			
		||||
        ":entry",
 | 
			
		||||
        ":function",
 | 
			
		||||
        ":graph_def_builder_util",
 | 
			
		||||
        ":graph_view",
 | 
			
		||||
        ":hierarchical_tree_broadcaster",
 | 
			
		||||
        ":input_colocation_exemption_registry",
 | 
			
		||||
        ":isolate_placer_inspection_required_ops_pass",
 | 
			
		||||
        ":local_device",
 | 
			
		||||
        ":lower_functional_ops",
 | 
			
		||||
        ":memory_types",
 | 
			
		||||
        ":mkl_cpu_allocator",
 | 
			
		||||
        ":mkl_layout_pass",
 | 
			
		||||
@ -1389,6 +1553,7 @@ tf_cuda_library(
 | 
			
		||||
        ":parallel_concat_optimizer",
 | 
			
		||||
        ":partitioning_utils",
 | 
			
		||||
        ":pending_counts",
 | 
			
		||||
        ":placer",
 | 
			
		||||
        ":pool_allocator",
 | 
			
		||||
        ":process_state",
 | 
			
		||||
        ":process_util",
 | 
			
		||||
 | 
			
		||||
@ -29,7 +29,6 @@ limitations under the License.
 | 
			
		||||
#include "tensorflow/core/common_runtime/composite_device.h"
 | 
			
		||||
#include "tensorflow/core/common_runtime/device.h"
 | 
			
		||||
#include "tensorflow/core/common_runtime/device_set.h"
 | 
			
		||||
#include "tensorflow/core/common_runtime/function.h"
 | 
			
		||||
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
 | 
			
		||||
#include "tensorflow/core/common_runtime/inspecting_placer.h"
 | 
			
		||||
#include "tensorflow/core/common_runtime/partitioning_utils.h"
 | 
			
		||||
 | 
			
		||||
@ -29,6 +29,7 @@ limitations under the License.
 | 
			
		||||
#include "tensorflow/core/common_runtime/graph_optimizer.h"
 | 
			
		||||
#include "tensorflow/core/common_runtime/inline_function_utils.h"
 | 
			
		||||
#include "tensorflow/core/common_runtime/memory_types.h"
 | 
			
		||||
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
 | 
			
		||||
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
 | 
			
		||||
#include "tensorflow/core/framework/collective.h"
 | 
			
		||||
#include "tensorflow/core/framework/function.h"
 | 
			
		||||
@ -1395,38 +1396,4 @@ std::unique_ptr<FunctionBody> SymbolicGradient(const FunctionBody& f) {
 | 
			
		||||
  return SymbolicGradientHelper(f).Compute();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Status FunctionDefToBodyHelper(
 | 
			
		||||
    const FunctionDef& fdef, const AttrSlice& attrs,
 | 
			
		||||
    const FunctionLibraryDefinition* const lib_def,
 | 
			
		||||
    const std::function<Status(const string&, const OpDef**)>& get_func_sig,
 | 
			
		||||
    std::unique_ptr<FunctionBody>* fbody) {
 | 
			
		||||
  // Instantiates the function template into a graph def.
 | 
			
		||||
  InstantiationResult result;
 | 
			
		||||
  TF_RETURN_IF_ERROR(InstantiateFunction(fdef, attrs, get_func_sig, &result));
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<Graph> graph(new Graph(lib_def));
 | 
			
		||||
  GraphConstructorOptions opts;
 | 
			
		||||
  opts.allow_internal_ops = true;
 | 
			
		||||
  opts.expect_device_spec = false;
 | 
			
		||||
  TF_RETURN_IF_ERROR(ConvertNodeDefsToGraph(opts, result.nodes, graph.get()));
 | 
			
		||||
 | 
			
		||||
  // Call BuildControlFlowInfo to validate that this function body has
 | 
			
		||||
  // well-formed control flow.
 | 
			
		||||
  std::vector<ControlFlowInfo> dummy;
 | 
			
		||||
  TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph.get(), &dummy));
 | 
			
		||||
 | 
			
		||||
  *fbody = absl::make_unique<FunctionBody>(fdef, result.arg_types,
 | 
			
		||||
                                           result.ret_types, graph.release());
 | 
			
		||||
  return Status::OK();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Status FunctionDefToBodyHelper(const FunctionDef& fdef, const AttrSlice& attrs,
 | 
			
		||||
                               const FunctionLibraryDefinition* lib_def,
 | 
			
		||||
                               std::unique_ptr<FunctionBody>* fbody) {
 | 
			
		||||
  const auto get_func_sig = [&lib_def](const string& op, const OpDef** sig) {
 | 
			
		||||
    return lib_def->LookUpOpDef(op, sig);
 | 
			
		||||
  };
 | 
			
		||||
  return FunctionDefToBodyHelper(fdef, attrs, lib_def, get_func_sig, fbody);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // end namespace tensorflow
 | 
			
		||||
 | 
			
		||||
@ -23,6 +23,7 @@ limitations under the License.
 | 
			
		||||
#include "tensorflow/core/common_runtime/device.h"
 | 
			
		||||
#include "tensorflow/core/common_runtime/device_mgr.h"
 | 
			
		||||
#include "tensorflow/core/common_runtime/function_body.h"
 | 
			
		||||
#include "tensorflow/core/common_runtime/function_def_utils.h"
 | 
			
		||||
#include "tensorflow/core/common_runtime/function_utils.h"
 | 
			
		||||
#include "tensorflow/core/common_runtime/graph_optimizer.h"
 | 
			
		||||
#include "tensorflow/core/common_runtime/inline_function_utils.h"
 | 
			
		||||
@ -80,26 +81,6 @@ std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
 | 
			
		||||
// TODO(zhifengc): Asks math expert to say the comment again.
 | 
			
		||||
std::unique_ptr<FunctionBody> SymbolicGradient(const FunctionBody& f);
 | 
			
		||||
 | 
			
		||||
// Returns true iff `n` represents a function call. `n` can be a native
 | 
			
		||||
// function call (n.type_string() is the function name),
 | 
			
		||||
// a PartitionedCall/StatefulPartitionedCall, or a SymbolicGradient (which
 | 
			
		||||
// has been deprecated for a while).
 | 
			
		||||
bool IsFunctionCall(const FunctionLibraryDefinition& lib_def, const Node& n);
 | 
			
		||||
 | 
			
		||||
// Instantiates FunctionDef into a graph. Set *fbody to point to the
 | 
			
		||||
// FunctionBody that holds the instantiated FunctionDef.
 | 
			
		||||
Status FunctionDefToBodyHelper(const FunctionDef& fdef, const AttrSlice& attrs,
 | 
			
		||||
                               const FunctionLibraryDefinition* lib_def,
 | 
			
		||||
                               std::unique_ptr<FunctionBody>* fbody);
 | 
			
		||||
 | 
			
		||||
// Instantiates FunctionDef into a graph. Set *fbody to point to the
 | 
			
		||||
// FunctionBody that holds the instantiated FunctionDef. Use custom function
 | 
			
		||||
// signature lookup, in case instantiated function is not in the 'lib_def'.
 | 
			
		||||
Status FunctionDefToBodyHelper(
 | 
			
		||||
    const FunctionDef& fdef, const AttrSlice& attrs,
 | 
			
		||||
    const FunctionLibraryDefinition* lib_def,
 | 
			
		||||
    const std::function<Status(const string&, const OpDef**)>& get_func_sig,
 | 
			
		||||
    std::unique_ptr<FunctionBody>* fbody);
 | 
			
		||||
}  // end namespace tensorflow
 | 
			
		||||
 | 
			
		||||
#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_H_
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										63
									
								
								tensorflow/core/common_runtime/function_def_utils.cc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										63
									
								
								tensorflow/core/common_runtime/function_def_utils.cc
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,63 @@
 | 
			
		||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
 | 
			
		||||
 | 
			
		||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
you may not use this file except in compliance with the License.
 | 
			
		||||
You may obtain a copy of the License at
 | 
			
		||||
 | 
			
		||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 | 
			
		||||
Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
See the License for the specific language governing permissions and
 | 
			
		||||
limitations under the License.
 | 
			
		||||
==============================================================================*/
 | 
			
		||||
 | 
			
		||||
#include "tensorflow/core/common_runtime/function_def_utils.h"
 | 
			
		||||
 | 
			
		||||
#include <vector>
 | 
			
		||||
 | 
			
		||||
#include "tensorflow/core/common_runtime/function_body.h"
 | 
			
		||||
#include "tensorflow/core/common_runtime/graph_constructor.h"
 | 
			
		||||
#include "tensorflow/core/framework/function.h"
 | 
			
		||||
#include "tensorflow/core/framework/node_def_util.h"
 | 
			
		||||
#include "tensorflow/core/graph/control_flow.h"
 | 
			
		||||
#include "tensorflow/core/graph/graph.h"
 | 
			
		||||
 | 
			
		||||
namespace tensorflow {
 | 
			
		||||
 | 
			
		||||
Status FunctionDefToBodyHelper(
 | 
			
		||||
    const FunctionDef& fdef, const AttrSlice& attrs,
 | 
			
		||||
    const FunctionLibraryDefinition* const lib_def,
 | 
			
		||||
    const std::function<Status(const string&, const OpDef**)>& get_func_sig,
 | 
			
		||||
    std::unique_ptr<FunctionBody>* fbody) {
 | 
			
		||||
  // Instantiates the function template into a graph def.
 | 
			
		||||
  InstantiationResult result;
 | 
			
		||||
  TF_RETURN_IF_ERROR(InstantiateFunction(fdef, attrs, get_func_sig, &result));
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<Graph> graph(new Graph(lib_def));
 | 
			
		||||
  GraphConstructorOptions opts;
 | 
			
		||||
  opts.allow_internal_ops = true;
 | 
			
		||||
  opts.expect_device_spec = false;
 | 
			
		||||
  TF_RETURN_IF_ERROR(ConvertNodeDefsToGraph(opts, result.nodes, graph.get()));
 | 
			
		||||
 | 
			
		||||
  // Call BuildControlFlowInfo to validate that this function body has
 | 
			
		||||
  // well-formed control flow.
 | 
			
		||||
  std::vector<ControlFlowInfo> dummy;
 | 
			
		||||
  TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph.get(), &dummy));
 | 
			
		||||
 | 
			
		||||
  *fbody = absl::make_unique<FunctionBody>(fdef, result.arg_types,
 | 
			
		||||
                                           result.ret_types, graph.release());
 | 
			
		||||
  return Status::OK();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Status FunctionDefToBodyHelper(const FunctionDef& fdef, const AttrSlice& attrs,
 | 
			
		||||
                               const FunctionLibraryDefinition* lib_def,
 | 
			
		||||
                               std::unique_ptr<FunctionBody>* fbody) {
 | 
			
		||||
  const auto get_func_sig = [&lib_def](const string& op, const OpDef** sig) {
 | 
			
		||||
    return lib_def->LookUpOpDef(op, sig);
 | 
			
		||||
  };
 | 
			
		||||
  return FunctionDefToBodyHelper(fdef, attrs, lib_def, get_func_sig, fbody);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // end namespace tensorflow
 | 
			
		||||
							
								
								
									
										49
									
								
								tensorflow/core/common_runtime/function_def_utils.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										49
									
								
								tensorflow/core/common_runtime/function_def_utils.h
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,49 @@
 | 
			
		||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
 | 
			
		||||
 | 
			
		||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
you may not use this file except in compliance with the License.
 | 
			
		||||
You may obtain a copy of the License at
 | 
			
		||||
 | 
			
		||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 | 
			
		||||
Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
See the License for the specific language governing permissions and
 | 
			
		||||
limitations under the License.
 | 
			
		||||
==============================================================================*/
 | 
			
		||||
 | 
			
		||||
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_DEF_UTILS_H_
 | 
			
		||||
#define TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_DEF_UTILS_H_
 | 
			
		||||
 | 
			
		||||
#include <functional>
 | 
			
		||||
#include <memory>
 | 
			
		||||
 | 
			
		||||
#include "tensorflow/core/lib/core/status.h"
 | 
			
		||||
 | 
			
		||||
namespace tensorflow {
 | 
			
		||||
 | 
			
		||||
class AttrSlice;
 | 
			
		||||
struct FunctionBody;
 | 
			
		||||
class FunctionDef;
 | 
			
		||||
class FunctionLibraryDefinition;
 | 
			
		||||
class OpDef;
 | 
			
		||||
 | 
			
		||||
// Instantiates FunctionDef into a graph. Set *fbody to point to the
 | 
			
		||||
// FunctionBody that holds the instantiated FunctionDef.
 | 
			
		||||
Status FunctionDefToBodyHelper(const FunctionDef& fdef, const AttrSlice& attrs,
 | 
			
		||||
                               const FunctionLibraryDefinition* lib_def,
 | 
			
		||||
                               std::unique_ptr<FunctionBody>* fbody);
 | 
			
		||||
 | 
			
		||||
// Instantiates FunctionDef into a graph. Set *fbody to point to the
 | 
			
		||||
// FunctionBody that holds the instantiated FunctionDef. Use custom function
 | 
			
		||||
// signature lookup, in case instantiated function is not in the 'lib_def'.
 | 
			
		||||
Status FunctionDefToBodyHelper(
 | 
			
		||||
    const FunctionDef& fdef, const AttrSlice& attrs,
 | 
			
		||||
    const FunctionLibraryDefinition* lib_def,
 | 
			
		||||
    const std::function<Status(const string&, const OpDef**)>& get_func_sig,
 | 
			
		||||
    std::unique_ptr<FunctionBody>* fbody);
 | 
			
		||||
 | 
			
		||||
}  // end namespace tensorflow
 | 
			
		||||
 | 
			
		||||
#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_DEF_UTILS_H_
 | 
			
		||||
@ -43,6 +43,11 @@ limitations under the License.
 | 
			
		||||
 | 
			
		||||
namespace tensorflow {
 | 
			
		||||
 | 
			
		||||
/*static*/ constexpr const char* const
 | 
			
		||||
    LowerFunctionalOpsConstants::kLowerUsingSwitchMergeAttr;
 | 
			
		||||
/*static*/ constexpr const char* const
 | 
			
		||||
    LowerFunctionalOpsConstants::kLowerAsMultiDeviceFunctionAttr;
 | 
			
		||||
 | 
			
		||||
namespace {
 | 
			
		||||
// A few string constant used throughout this module.
 | 
			
		||||
static constexpr const char* const kArgOp = FunctionLibraryDefinition::kArgOp;
 | 
			
		||||
 | 
			
		||||
@ -231,6 +231,13 @@ inline bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph) {
 | 
			
		||||
  return ExpandInlineFunctions(lib, graph, ExpandInlineFunctionsOptions());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
struct LowerFunctionalOpsConstants {
 | 
			
		||||
  static constexpr const char* const kLowerUsingSwitchMergeAttr =
 | 
			
		||||
      "_lower_using_switch_merge";
 | 
			
		||||
  static constexpr const char* const kLowerAsMultiDeviceFunctionAttr =
 | 
			
		||||
      "_lower_as_multi_device_function";
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
}  // end namespace tensorflow
 | 
			
		||||
 | 
			
		||||
#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_INLINE_FUNCTION_UTILS_H_
 | 
			
		||||
 | 
			
		||||
@ -21,7 +21,8 @@ limitations under the License.
 | 
			
		||||
#include "absl/strings/str_join.h"
 | 
			
		||||
#include "tensorflow/core/common_runtime/colocation_graph.h"
 | 
			
		||||
#include "tensorflow/core/common_runtime/device.h"
 | 
			
		||||
#include "tensorflow/core/common_runtime/function.h"
 | 
			
		||||
#include "tensorflow/core/common_runtime/function_body.h"
 | 
			
		||||
#include "tensorflow/core/common_runtime/function_def_utils.h"
 | 
			
		||||
#include "tensorflow/core/common_runtime/placer_inspection_required_ops_utils.h"
 | 
			
		||||
#include "tensorflow/core/framework/function.h"
 | 
			
		||||
#include "tensorflow/core/framework/node_def_util.h"
 | 
			
		||||
 | 
			
		||||
@ -15,8 +15,7 @@ limitations under the License.
 | 
			
		||||
 | 
			
		||||
#include "tensorflow/core/common_runtime/lower_case_op.h"
 | 
			
		||||
 | 
			
		||||
#include "tensorflow/core/common_runtime/function.h"
 | 
			
		||||
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
 | 
			
		||||
#include "tensorflow/core/common_runtime/inline_function_utils.h"
 | 
			
		||||
#include "tensorflow/core/framework/node_def_builder.h"
 | 
			
		||||
#include "tensorflow/core/graph/graph.h"
 | 
			
		||||
#include "tensorflow/core/graph/node_builder.h"
 | 
			
		||||
@ -29,7 +28,7 @@ namespace {
 | 
			
		||||
using NodeOut = NodeBuilder::NodeOut;
 | 
			
		||||
 | 
			
		||||
constexpr const char* const kLowerAsMultiDeviceFunctionAttr =
 | 
			
		||||
    LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr;
 | 
			
		||||
    LowerFunctionalOpsConstants::kLowerAsMultiDeviceFunctionAttr;
 | 
			
		||||
 | 
			
		||||
// Convenience builder to make it easy to construct a case with a single
 | 
			
		||||
// function call in each branch. This first converts the Case node
 | 
			
		||||
 | 
			
		||||
@ -16,11 +16,13 @@ limitations under the License.
 | 
			
		||||
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_CASE_OP_H_
 | 
			
		||||
#define TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_CASE_OP_H_
 | 
			
		||||
 | 
			
		||||
#include "tensorflow/core/common_runtime/optimization_registry.h"
 | 
			
		||||
#include "tensorflow/core/lib/core/status.h"
 | 
			
		||||
 | 
			
		||||
namespace tensorflow {
 | 
			
		||||
 | 
			
		||||
class Graph;
 | 
			
		||||
class Node;
 | 
			
		||||
 | 
			
		||||
// Replaces Case node `n` with a lowered form that uses _SwitchN/Merge nodes.
 | 
			
		||||
Status RewriteCaseNode(Node* n, Graph* g, bool keep_node_fetchable);
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -16,9 +16,8 @@ limitations under the License.
 | 
			
		||||
#include "tensorflow/core/common_runtime/lower_function_call_op.h"
 | 
			
		||||
 | 
			
		||||
#include "absl/algorithm/container.h"
 | 
			
		||||
#include "tensorflow/core/common_runtime/function.h"
 | 
			
		||||
#include "tensorflow/core/common_runtime/function_def_utils.h"
 | 
			
		||||
#include "tensorflow/core/common_runtime/inline_function_utils.h"
 | 
			
		||||
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
 | 
			
		||||
#include "tensorflow/core/framework/node_def_util.h"
 | 
			
		||||
#include "tensorflow/core/graph/graph.h"
 | 
			
		||||
#include "tensorflow/core/graph/graph_node_util.h"
 | 
			
		||||
@ -30,7 +29,7 @@ using KeepCallerNode = InlineFunctionBodyOptions::KeepCallerNode;
 | 
			
		||||
using OutputControlSrc = InlineFunctionBodyOptions::OutputControlSource;
 | 
			
		||||
 | 
			
		||||
constexpr const char* const kLowerAsMultiDeviceFunctionAttr =
 | 
			
		||||
    LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr;
 | 
			
		||||
    LowerFunctionalOpsConstants::kLowerAsMultiDeviceFunctionAttr;
 | 
			
		||||
 | 
			
		||||
bool LowerAsMultiDeviceFunction(const Node* n) {
 | 
			
		||||
  if (n->IsPartitionedCall()) return true;
 | 
			
		||||
 | 
			
		||||
@ -16,11 +16,14 @@ limitations under the License.
 | 
			
		||||
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_FUNCTION_CALL_OP_H_
 | 
			
		||||
#define TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_FUNCTION_CALL_OP_H_
 | 
			
		||||
 | 
			
		||||
#include "tensorflow/core/common_runtime/optimization_registry.h"
 | 
			
		||||
#include "tensorflow/core/lib/core/status.h"
 | 
			
		||||
 | 
			
		||||
namespace tensorflow {
 | 
			
		||||
 | 
			
		||||
class FunctionLibraryDefinition;
 | 
			
		||||
class Graph;
 | 
			
		||||
class Node;
 | 
			
		||||
 | 
			
		||||
// Replaces function call node `n` with its function body. Uses
 | 
			
		||||
// InlineFunctionBody from `common_runtime/function.{h,cc}`. If function
 | 
			
		||||
// inlining is not possible or safe (see ValidateInlining), leaves the graph in
 | 
			
		||||
 | 
			
		||||
@ -15,7 +15,8 @@ limitations under the License.
 | 
			
		||||
 | 
			
		||||
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
 | 
			
		||||
 | 
			
		||||
#include "tensorflow/core/common_runtime/function.h"
 | 
			
		||||
#include "tensorflow/core/common_runtime/function_utils.h"
 | 
			
		||||
#include "tensorflow/core/common_runtime/inline_function_utils.h"
 | 
			
		||||
#include "tensorflow/core/common_runtime/lower_case_op.h"
 | 
			
		||||
#include "tensorflow/core/common_runtime/lower_function_call_op.h"
 | 
			
		||||
#include "tensorflow/core/common_runtime/lower_if_op.h"
 | 
			
		||||
@ -27,17 +28,12 @@ limitations under the License.
 | 
			
		||||
 | 
			
		||||
namespace tensorflow {
 | 
			
		||||
 | 
			
		||||
/*static*/ constexpr const char* const
 | 
			
		||||
    LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr;
 | 
			
		||||
/*static*/ constexpr const char* const
 | 
			
		||||
    LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr;
 | 
			
		||||
 | 
			
		||||
namespace {
 | 
			
		||||
 | 
			
		||||
constexpr const char* const kLowerUsingSwitchMergeAttr =
 | 
			
		||||
    LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr;
 | 
			
		||||
    LowerFunctionalOpsConstants::kLowerUsingSwitchMergeAttr;
 | 
			
		||||
constexpr const char* const kLowerAsMultiDeviceFunctionAttr =
 | 
			
		||||
    LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr;
 | 
			
		||||
    LowerFunctionalOpsConstants::kLowerAsMultiDeviceFunctionAttr;
 | 
			
		||||
 | 
			
		||||
constexpr const char* const kTpuReplicateAttr = "_tpu_replicate";
 | 
			
		||||
constexpr const char* const kXlaClusterAttr = "_xla_compile_id";
 | 
			
		||||
@ -173,7 +169,7 @@ Status LowerFunctionalOpsPass::Run(
 | 
			
		||||
      DCHECK(!lower_control_flow(n))
 | 
			
		||||
          << "Node " << FormatNodeForError(*n) << " of type "
 | 
			
		||||
          << n->type_string() << " has '"
 | 
			
		||||
          << LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr
 | 
			
		||||
          << LowerFunctionalOpsConstants::kLowerUsingSwitchMergeAttr
 | 
			
		||||
          << "' attr set but it does not support lowering.\n";
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
@ -17,6 +17,7 @@ limitations under the License.
 | 
			
		||||
#define TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_FUNCTIONAL_OPS_H_
 | 
			
		||||
 | 
			
		||||
#include "absl/types/optional.h"
 | 
			
		||||
#include "tensorflow/core/common_runtime/inline_function_utils.h"
 | 
			
		||||
#include "tensorflow/core/common_runtime/optimization_registry.h"
 | 
			
		||||
#include "tensorflow/core/lib/core/status.h"
 | 
			
		||||
 | 
			
		||||
@ -38,9 +39,9 @@ class LowerFunctionalOpsPass : public GraphOptimizationPass {
 | 
			
		||||
  Status Run(const GraphOptimizationPassOptions& options) override;
 | 
			
		||||
 | 
			
		||||
  static constexpr const char* const kLowerUsingSwitchMergeAttr =
 | 
			
		||||
      "_lower_using_switch_merge";
 | 
			
		||||
      LowerFunctionalOpsConstants::kLowerUsingSwitchMergeAttr;
 | 
			
		||||
  static constexpr const char* const kLowerAsMultiDeviceFunctionAttr =
 | 
			
		||||
      "_lower_as_multi_device_function";
 | 
			
		||||
      LowerFunctionalOpsConstants::kLowerAsMultiDeviceFunctionAttr;
 | 
			
		||||
 | 
			
		||||
 private:
 | 
			
		||||
  // If defined use the value to control if functional ops must be fetchable
 | 
			
		||||
 | 
			
		||||
@ -15,8 +15,7 @@ limitations under the License.
 | 
			
		||||
 | 
			
		||||
#include "tensorflow/core/common_runtime/lower_if_op.h"
 | 
			
		||||
 | 
			
		||||
#include "tensorflow/core/common_runtime/function.h"
 | 
			
		||||
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
 | 
			
		||||
#include "tensorflow/core/common_runtime/inline_function_utils.h"
 | 
			
		||||
#include "tensorflow/core/framework/node_def_builder.h"
 | 
			
		||||
#include "tensorflow/core/graph/graph.h"
 | 
			
		||||
#include "tensorflow/core/graph/node_builder.h"
 | 
			
		||||
@ -27,7 +26,7 @@ namespace {
 | 
			
		||||
using NodeOut = NodeBuilder::NodeOut;
 | 
			
		||||
 | 
			
		||||
constexpr const char* const kLowerAsMultiDeviceFunctionAttr =
 | 
			
		||||
    LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr;
 | 
			
		||||
    LowerFunctionalOpsConstants::kLowerAsMultiDeviceFunctionAttr;
 | 
			
		||||
 | 
			
		||||
// Convenience builder to make it easy to construct a conditional with a single
 | 
			
		||||
// function call in the then and else branch. This first converts the if node
 | 
			
		||||
 | 
			
		||||
@ -16,11 +16,13 @@ limitations under the License.
 | 
			
		||||
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_IF_OP_H_
 | 
			
		||||
#define TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_IF_OP_H_
 | 
			
		||||
 | 
			
		||||
#include "tensorflow/core/common_runtime/optimization_registry.h"
 | 
			
		||||
#include "tensorflow/core/lib/core/status.h"
 | 
			
		||||
 | 
			
		||||
namespace tensorflow {
 | 
			
		||||
 | 
			
		||||
class Graph;
 | 
			
		||||
class Node;
 | 
			
		||||
 | 
			
		||||
// Replaces If node `n` with its lowered form that uses Switch and Merge nodes.
 | 
			
		||||
Status RewriteIfNode(Node* n, Graph* g, bool keep_node_fetchable);
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -15,9 +15,7 @@ limitations under the License.
 | 
			
		||||
 | 
			
		||||
#include "tensorflow/core/common_runtime/lower_while_op.h"
 | 
			
		||||
 | 
			
		||||
#include "tensorflow/core/common_runtime/function.h"
 | 
			
		||||
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
 | 
			
		||||
#include "tensorflow/core/common_runtime/lower_if_op.h"
 | 
			
		||||
#include "tensorflow/core/common_runtime/inline_function_utils.h"
 | 
			
		||||
#include "tensorflow/core/framework/node_def_builder.h"
 | 
			
		||||
#include "tensorflow/core/framework/types.pb.h"
 | 
			
		||||
#include "tensorflow/core/graph/graph.h"
 | 
			
		||||
@ -30,7 +28,7 @@ namespace {
 | 
			
		||||
using NodeOut = NodeBuilder::NodeOut;
 | 
			
		||||
 | 
			
		||||
constexpr const char* const kLowerAsMultiDeviceFunctionAttr =
 | 
			
		||||
    LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr;
 | 
			
		||||
    LowerFunctionalOpsConstants::kLowerAsMultiDeviceFunctionAttr;
 | 
			
		||||
 | 
			
		||||
// Helper to convert a functional While op to its lowered form.
 | 
			
		||||
//
 | 
			
		||||
 | 
			
		||||
@ -16,11 +16,13 @@ limitations under the License.
 | 
			
		||||
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_WHILE_OP_H_
 | 
			
		||||
#define TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_WHILE_OP_H_
 | 
			
		||||
 | 
			
		||||
#include "tensorflow/core/common_runtime/optimization_registry.h"
 | 
			
		||||
#include "tensorflow/core/lib/core/status.h"
 | 
			
		||||
 | 
			
		||||
namespace tensorflow {
 | 
			
		||||
 | 
			
		||||
class Graph;
 | 
			
		||||
class Node;
 | 
			
		||||
 | 
			
		||||
// Replaces While node `n` with its lowered form that uses Enter, Exit, Switch,
 | 
			
		||||
// Merge, NextIteration and LoopCond nodes.
 | 
			
		||||
Status RewriteWhileNode(Node* n, Graph* g, bool keep_node_fetchable);
 | 
			
		||||
 | 
			
		||||
@ -19,7 +19,6 @@ limitations under the License.
 | 
			
		||||
 | 
			
		||||
#include "absl/strings/str_cat.h"
 | 
			
		||||
#include "absl/types/optional.h"
 | 
			
		||||
#include "tensorflow/core/common_runtime/function.h"
 | 
			
		||||
#include "tensorflow/core/framework/function.h"
 | 
			
		||||
#include "tensorflow/core/framework/node_def_builder.h"
 | 
			
		||||
#include "tensorflow/core/graph/graph.h"
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user