From d6027bd76ab7b6e9d28a301722740c707d620cbc Mon Sep 17 00:00:00 2001
From: Derek Murray
Date: Mon, 27 Apr 2020 10:40:33 -0700
Subject: [PATCH] [Build cleanup] Split "core_cpu_impl" into fine-grained targets (3/n).

This change splits many (but not all) of the function-related targets into
separate cc_library targets. The main changes are:

* Move "graph/graph_constructor.{h,cc}" to
  "common_runtime/graph_constructor.{h,cc}" and leave a forwarding header.
  This code depends on common_runtime and is built as part of it, so it makes
  sense to move it across. The "graph_constructor" library includes
  "shape_refiner.{h,cc}", "graph_runner.{h,cc}", and "eval_const_tensor.{h,cc}"
  because of a circular dependency between these modules.

* Split "function.{h,cc}" into "function_body.{h,cc}", "function_utils.{h,cc}",
  and "inline_function_utils.{h,cc}" (plus the original, slimmed-down module).
  This enables other targets in common_runtime to depend on just the function
  utilities they need, without the whole runtime, which breaks some cycles.

* New fine-grained targets for "constant_folding",
  "function_optimization_registry", and "graph_optimizer".

PiperOrigin-RevId: 308651243
Change-Id: Iac59c1db4ebdd16609f89d6caee6b7e6ba7ff0a1
---
 tensorflow/c/BUILD | 1 +
 tensorflow/c/c_api_function_test.cc | 1 +
 tensorflow/compiler/tf2tensorrt/BUILD | 1 +
 tensorflow/core/common_runtime/BUILD | 184 ++-
 .../core/common_runtime/constant_folding.cc | 2 +-
 .../core/common_runtime/direct_session.cc | 2 +-
 .../core/common_runtime/executor_test.cc | 2 +-
 tensorflow/core/common_runtime/function.cc | 1124 +----------------
 tensorflow/core/common_runtime/function.h | 302 +----
 .../core/common_runtime/function_body.cc | 64 +
 .../core/common_runtime/function_body.h | 52 +
 .../core/common_runtime/function_test.cc | 2 +-
 .../function_threadpool_test.cc | 5 +-
 .../core/common_runtime/function_utils.cc | 368 ++++++
 .../core/common_runtime/function_utils.h | 105 ++
 .../graph_constructor.cc | 2 +-
 .../core/common_runtime/graph_constructor.h | 204 +++
 .../common_runtime/graph_execution_state.cc | 2 +-
 .../core/common_runtime/graph_optimizer.cc | 20 +-
 .../core/common_runtime/graph_optimizer.h | 11 +
 .../core/common_runtime/graph_runner.cc | 6 +-
 tensorflow/core/common_runtime/graph_runner.h | 7 +-
 .../common_runtime/inline_function_utils.cc | 865 +++++++++++++
 .../common_runtime/inline_function_utils.h | 236 ++++
 ...lacer_inspection_required_ops_pass_test.cc | 2 +-
 .../core/common_runtime/lower_case_op_test.cc | 5 +-
 .../common_runtime/lower_function_call_op.cc | 1 +
 .../lower_function_call_op_test.cc | 5 +-
 .../lower_functional_ops_test.cc | 2 +-
 .../core/common_runtime/lower_if_op_test.cc | 5 +-
 .../common_runtime/lower_while_op_test.cc | 5 +-
 .../core/common_runtime/partitioning_utils.cc | 2 +-
 ...acer_inspection_required_ops_utils_test.cc | 2 +-
 tensorflow/core/common_runtime/placer_test.cc | 2 +-
 .../process_function_library_runtime.cc | 2 +-
 .../core/common_runtime/shape_refiner.cc | 5 +-
 tensorflow/core/graph/BUILD | 4 +-
 tensorflow/core/graph/graph_constructor.h | 185 +--
 tensorflow/core/graph/subgraph.cc | 1 -
 tensorflow/python/BUILD | 1 +
 .../tools/def_file_filter/symbols_pybind.txt | 2 +-
 41 files changed, 2131 insertions(+), 1668 deletions(-)
 create mode 100644 tensorflow/core/common_runtime/function_body.cc
 create mode 100644 tensorflow/core/common_runtime/function_body.h
 create mode 100644 tensorflow/core/common_runtime/function_utils.cc
 create mode 100644 tensorflow/core/common_runtime/function_utils.h
 rename
tensorflow/core/{graph => common_runtime}/graph_constructor.cc (99%) create mode 100644 tensorflow/core/common_runtime/graph_constructor.h create mode 100644 tensorflow/core/common_runtime/inline_function_utils.cc create mode 100644 tensorflow/core/common_runtime/inline_function_utils.h diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 867d575cac2..aafa89cc3d8 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -604,6 +604,7 @@ tf_cc_test( ":c_api", ":c_api_internal", ":c_test_util", + "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", diff --git a/tensorflow/c/c_api_function_test.cc b/tensorflow/c/c_api_function_test.cc index bbf645200c6..3fff9bcd371 100644 --- a/tensorflow/c/c_api_function_test.cc +++ b/tensorflow/c/c_api_function_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/c/c_api.h" #include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/c_test_util.h" +#include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/lib/hash/hash.h" diff --git a/tensorflow/compiler/tf2tensorrt/BUILD b/tensorflow/compiler/tf2tensorrt/BUILD index da32f620e72..582ebbbe1bd 100644 --- a/tensorflow/compiler/tf2tensorrt/BUILD +++ b/tensorflow/compiler/tf2tensorrt/BUILD @@ -491,6 +491,7 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:lib_proto_parsing", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/common_runtime:core_cpu", "//tensorflow/core/grappler/costs:graph_properties", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", diff --git a/tensorflow/core/common_runtime/BUILD b/tensorflow/core/common_runtime/BUILD index f1b203d3c47..2edeb1fbff2 100644 --- a/tensorflow/core/common_runtime/BUILD +++ b/tensorflow/core/common_runtime/BUILD @@ -58,6 +58,7 @@ tf_cuda_library( "device_factory.h", "function.h", "function_optimization_registry.h", + "graph_constructor.h", "optimization_registry.h", "shape_refiner.h", "//tensorflow/core/graph:core_cpu_headers", @@ -153,7 +154,12 @@ filegroup( "device_set.h", "eval_const_tensor.h", "function.h", + "function_body.h", + "function_utils.h", + "graph_constructor.h", + "graph_optimizer.h", "graph_runner.h", + "inline_function_utils.h", "metrics.h", "process_function_library_runtime.h", "scoped_allocator.h", @@ -167,9 +173,6 @@ filegroup( tf_cuda_library( name = "core_cpu_base_no_ops", srcs = [ - "eval_const_tensor.cc", - "graph_optimizer.h", - "shape_refiner.cc", "//tensorflow/core/graph:core_cpu_base_no_ops_srcs", "//tensorflow/core/public:session_options.h", "//tensorflow/core/public:version.h", @@ -190,6 +193,7 @@ tf_cuda_library( "@com_google_absl//absl/container:flat_hash_set", "//third_party/eigen3", ] + if_static([ + ":graph_constructor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", ]), @@ -221,7 +225,6 @@ filegroup( "executor.h", "executor_factory.h", "function_optimization_registry.h", - "graph_optimizer.h", "input_colocation_exemption_registry.h", "isolate_placer_inspection_required_ops_pass.h", "local_device.h", @@ -390,6 +393,27 @@ cc_library( ], ) +cc_library( + name = "constant_folding", + srcs = ["constant_folding.cc"], + hdrs = ["constant_folding.h"], + copts = tf_copts(), + deps = [ + ":device", + ":device_factory", + ":executor", + ":function_utils", + ":graph_constructor", + ":memory_types", + ":rendezvous_mgr", + 
":session_options", + "//tensorflow/core:framework", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ], +) + cc_library( name = "costmodel_manager", srcs = ["costmodel_manager.cc"], @@ -524,19 +548,6 @@ cc_library( ], ) -cc_library( - name = "graph_view", - srcs = ["graph_view.cc"], - hdrs = ["graph_view.h"], - copts = tf_copts(), - deps = [ - ":device", - "//tensorflow/core:framework", - "//tensorflow/core:graph", - "//tensorflow/core:lib", - ], -) - cc_library( name = "device_set", srcs = ["device_set.cc"], @@ -557,6 +568,112 @@ cc_library( deps = ["//tensorflow/core:framework"], ) +cc_library( + name = "function_body", + srcs = ["function_body.cc"], + hdrs = ["function_body.h"], + copts = tf_copts(), + deps = [ + ":device", + "//tensorflow/core:framework", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "function_optimization_registry", + srcs = ["function_optimization_registry.cc"], + hdrs = ["function_optimization_registry.h"], + deps = [ + ":device_set", + "//tensorflow/core:framework", + "//tensorflow/core:graph", + "//tensorflow/core:protos_all_cc", + ], +) + +cc_library( + name = "function_utils", + srcs = ["function_utils.cc"], + hdrs = ["function_utils.h"], + deps = [ + ":function_body", + "//tensorflow/core:framework", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], +) + +# This library also includes "eval_const_tensor", "graph_runner", and +# "shape_refiner", because there are circular dependencies between these +# modules. +cc_library( + name = "graph_constructor", + srcs = [ + "eval_const_tensor.cc", + "graph_constructor.cc", + "graph_runner.cc", + "shape_refiner.cc", + "//tensorflow/core/framework:versions.h", + ], + hdrs = [ + "eval_const_tensor.h", + "graph_constructor.h", + "graph_runner.h", + "shape_refiner.h", + ], + copts = tf_copts(), + deps = [ + ":device", + ":device_factory", + ":executor", + ":function_utils", + ":memory_types", + ":rendezvous_mgr", + ":session_options", + ":single_threaded_cpu_device", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "graph_optimizer", + srcs = ["graph_optimizer.cc"], + hdrs = ["graph_optimizer.h"], + copts = tf_copts(), + deps = [ + ":constant_folding", + ":function_utils", + ":graph_constructor", + ":inline_function_utils", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "graph_view", + srcs = ["graph_view.cc"], + hdrs = ["graph_view.h"], + copts = tf_copts(), + deps = [ + ":device", + "//tensorflow/core:framework", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + ], +) + cc_library( name = "hierarchical_tree_broadcaster", srcs = ["hierarchical_tree_broadcaster.cc"], @@ -592,6 +709,29 @@ cc_library( ], ) +cc_library( + name = "inline_function_utils", + srcs = ["inline_function_utils.cc"], + hdrs = ["inline_function_utils.h"], + copts = tf_copts(), + deps = [ + ":device", + ":function_body", + ":function_utils", + ":graph_constructor", + "//tensorflow/core:framework", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + 
"//tensorflow/core/profiler/lib:traceme", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + ], +) + cc_library( name = "input_colocation_exemption_registry", srcs = ["input_colocation_exemption_registry.cc"], @@ -685,6 +825,7 @@ cc_library( copts = tf_copts(), deps = [ ":device_set", + ":graph_constructor", "//tensorflow/core:framework", "//tensorflow/core:graph", "//tensorflow/core:lib", @@ -1062,11 +1203,7 @@ tf_cuda_library( srcs = [ "colocation_graph.cc", "composite_device.cc", - "constant_folding.cc", "function.cc", - "function_optimization_registry.cc", - "graph_optimizer.cc", - "graph_runner.cc", "inspecting_placer.cc", "isolate_placer_inspection_required_ops_pass.cc", "lower_case_op.cc", @@ -1087,9 +1224,14 @@ tf_cuda_library( ":entry", ":executor", ":executor_factory", + ":function_body", + ":function_optimization_registry", + ":graph_constructor", + ":graph_optimizer", ":graph_view", ":local_executor_params", ":immutable_executor_state", + ":inline_function_utils", ":input_colocation_exemption_registry", ":pending_counts", ":propagator_debug_utils", diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc index 467147921be..f87efb369ed 100644 --- a/tensorflow/core/common_runtime/constant_folding.cc +++ b/tensorflow/core/common_runtime/constant_folding.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/executor.h" -#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/function_utils.h" #include "tensorflow/core/common_runtime/graph_runner.h" #include "tensorflow/core/common_runtime/memory_types.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index c1426394a17..d104e0a985f 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/executor.h" #include "tensorflow/core/common_runtime/executor_factory.h" #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/common_runtime/graph_optimizer.h" #include "tensorflow/core/common_runtime/memory_types.h" #include "tensorflow/core/common_runtime/metrics.h" @@ -49,7 +50,6 @@ limitations under the License. #include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph.h" -#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_partition.h" #include "tensorflow/core/graph/subgraph.h" #include "tensorflow/core/graph/tensor_id.h" diff --git a/tensorflow/core/common_runtime/executor_test.cc b/tensorflow/core/common_runtime/executor_test.cc index fe62a8459f1..79dbdd3bf44 100644 --- a/tensorflow/core/common_runtime/executor_test.cc +++ b/tensorflow/core/common_runtime/executor_test.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" #include "tensorflow/core/common_runtime/process_util.h" #include "tensorflow/core/common_runtime/step_stats_collector.h" @@ -27,7 +28,6 @@ limitations under the License. #include "tensorflow/core/framework/step_stats.pb.h" #include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/graph/algorithm.h" -#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/random/simple_philox.h" #include "tensorflow/core/lib/strings/strcat.h" diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index 7fb1328a519..a9e12d2385b 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -24,7 +24,9 @@ limitations under the License. #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/executor.h" #include "tensorflow/core/common_runtime/executor_factory.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/common_runtime/graph_optimizer.h" +#include "tensorflow/core/common_runtime/inline_function_utils.h" #include "tensorflow/core/common_runtime/memory_types.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" #include "tensorflow/core/framework/collective.h" @@ -37,7 +39,6 @@ limitations under the License. #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/graph/gradients.h" -#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/graph/optimizer_cse.h" #include "tensorflow/core/lib/core/threadpool.h" @@ -95,30 +96,6 @@ struct EndpointEq { // The following Add* routines are used to add a few graph nodes while // functions are transformed. -static Node* AddNoOp(StringPiece name, Graph* g) { - NodeDef ndef; - ndef.set_name(g->NewName(absl::StrCat(kNodeLabel, "/", name))); - ndef.set_op("NoOp"); - Status s; - Node* ret = g->AddNode(ndef, &s); - TF_CHECK_OK(s); - return ret; -} - -static Node* AddIdentity(StringPiece name, Graph* g, Endpoint input) { - DCHECK_LT(0, input.dtype()); - NodeDef ndef; - ndef.set_name(g->NewName(absl::StrCat(kNodeLabel, "/", name))); - ndef.set_op("Identity"); - ndef.add_input(input.name()); - AddNodeAttr("T", BaseType(input.dtype()), &ndef); - Status s; - Node* ret = g->AddNode(ndef, &s); - TF_CHECK_OK(s); - g->AddEdge(input.node, input.index, ret, 0); - return ret; -} - static Node* AddArg(Graph* g, DataType dtype, int index) { DCHECK_LT(0, dtype); DCHECK_LT(dtype, DT_FLOAT_REF); @@ -859,32 +836,6 @@ Status FunctionLibraryRuntimeImpl::ReleaseHandle(Handle handle) { return parent_status; } -void DumpGraph(StringPiece label, const Graph* g) { - // TODO(zhifengc): Change Graph to record #nodes. 
- VLOG(2) << "Graph " << label << " #nodes " << g->num_nodes() << " #edges " - << g->num_edges(); - if (VLOG_IS_ON(5)) { - for (const auto& line : str_util::Split(DebugString(g), '\n')) { - VLOG(5) << "|| " << line; - } - } -} - -void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr* g, - const GraphOptimizer::Options& graph_optimizer_options) { - OptimizerOptions opts; - opts.set_do_common_subexpression_elimination(true); - opts.set_do_function_inlining(true); - opts.set_do_constant_folding(true); - GraphOptimizer optimizer(opts); - optimizer.Optimize(lib, lib->env(), lib->device(), g, - graph_optimizer_options); -} - -void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr* g) { - OptimizeGraph(lib, g, GraphOptimizer::Options()); -} - namespace { // Removes all stateless nodes that do not contribute to a return // value from the function body. Unlike `RemoveDeadNodes()`, which is @@ -1316,1077 +1267,6 @@ std::unique_ptr NewFunctionLibraryRuntime( optimizer_options, custom_kernel_creator, session_metadata, parent)); } -bool RemoveDeadNodes(Graph* g) { - VLOG(2) << "Removing dead nodes"; - std::unordered_set nodes; - for (auto n : g->nodes()) { - if (n->IsSource() || n->IsSink() || n->IsControlFlow() || - n->op_def().is_stateful()) { - nodes.insert(n); - } - } - return PruneForReverseReachability(g, std::move(nodes)); -} - -namespace { -// If 'edges' contains only 1 non-control edge, returns it. Otherwise, -// returns a nullptr. -const Edge* GetTheOnlyDataEdge(const EdgeSet& edges) { - const Edge* ret = nullptr; - for (const Edge* e : edges) { - if (e->IsControlEdge() || ret) { - // Don't touch it if there is a control edge. - return nullptr; - } - if (IsRefType(e->src()->output_type(e->src_output()))) { - // Don't touch it if the identity node is effectively de-reffing - // a ref. - return nullptr; - } - if (IsRecv(e->src()) || IsSwitch(e->src())) { - // Don't touch it if the identity is introduced for control flow. - // Recv disables all its successors if it receives a dead signal. - // When Recv has an outgoing control edge, the current executor - // would not disable the destination. The current solution (see - // graph_partition.cc) is to add an identity after Recv and change - // the control edge to be from this identity node. So the identity - // can't be removed. - return nullptr; - } - ret = e; - } - return ret; -} -} // end namespace - -bool RemoveIdentityNodes(Graph* g) { - VLOG(2) << "Removing identity nodes"; - bool removed_any = false; - gtl::InlinedVector matches; - for (Node* n : g->nodes()) { - if (!n->IsIdentity()) continue; - if (!GetTheOnlyDataEdge(n->in_edges())) continue; - - // Some identity nodes are used as sink nodes to give names to output - // tensors. These nodes are not going to be executed unless they are in the - // fetch set. But if they are in the fetch set we don't want to remove them. 
- if (n->out_edges().empty()) continue; - - matches.push_back(n); - } - if (!matches.empty()) { - for (Node* n : matches) { - const Edge* in = GetTheOnlyDataEdge(n->in_edges()); - for (const Edge* out : n->out_edges()) { - if (out->IsControlEdge()) { - g->AddControlEdge(in->src(), out->dst()); - } else { - g->AddEdge(in->src(), in->src_output(), out->dst(), out->dst_input()); - } - } - VLOG(2) << "Remove Identity: " << n->DebugString(); - g->RemoveNode(n); - removed_any = true; - } - } - return removed_any; -} - -bool RemoveListArrayConverter(Graph* g) { - VLOG(2) << "Removing list array converter"; - gtl::InlinedVector matches; - for (Node* n : g->nodes()) { - if ((n->type_string() == "_ListToArray") || - (n->type_string() == "_ArrayToList")) { - matches.push_back(n); - } - } - bool removed_any = false; - if (!matches.empty()) { - for (Node* n : matches) { - if (n->num_inputs() != n->num_outputs()) { - continue; // Not expected. Skip. - } - gtl::InlinedVector identity_nodes(n->num_inputs(), nullptr); - - const auto no_op = [&](StringPiece name) -> Node* { - return AddNoOp(absl::StrCat(n->name(), "/", name), g); - }; - - const auto identity = [&](StringPiece name, Endpoint input) -> Node* { - Node* node = AddIdentity(absl::StrCat(n->name(), "/", name), g, input); - node->set_requested_device(input.node->def().device()); - return node; - }; - - // Process input edges first. - Node* input_control_node = nullptr; - for (const Edge* e : n->in_edges()) { - if (e->IsControlEdge()) { - if (input_control_node == nullptr) { - // If node "n" has any control dependencies, adds a no-op - // node (input_control_node) which the additional Identity - // nodes depends on and the input_control_node depends on - // the node "n"s control dependencies. - input_control_node = no_op("input_control_node"); - } - g->AddControlEdge(e->src(), input_control_node); - } else { - const int index = e->dst_input(); - Node** id_node = &identity_nodes[index]; - if (*id_node != nullptr) { - LOG(ERROR) - << "RemoveListArrayConverter unexpected duplicated input: " - << e->dst_input(); - return removed_any; - } - *id_node = identity("input", {e->src(), e->src_output()}); - } - } - - // If node "n" has any control dependencies, the added identity - // nodes should have control dependencies on input_control_node. - if (input_control_node != nullptr) { - for (Node* id : identity_nodes) { - g->AddControlEdge(input_control_node, id); - } - } - - Node* output_control_node = nullptr; - for (const Edge* e : n->out_edges()) { - if (e->IsControlEdge()) { - if (output_control_node == nullptr) { - // If node "n" is control-depended upon by other nodes, - // adds a no-op node (output_control_node) which those - // nodes will depend on and output_control_node depends on - // all Identity nodes. - output_control_node = no_op("output_control_node"); - } - g->AddControlEdge(output_control_node, e->dst()); - } else { - Node* id_node = identity_nodes[e->src_output()]; - if (id_node == nullptr) { - LOG(ERROR) << "RemoveListArrayConverter unexpected missing input: " - << e->src_output(); - return removed_any; - } - CHECK(id_node); - g->AddEdge(id_node, 0, e->dst(), e->dst_input()); - } - } - - // If any nodes have control dependencies on node "n", those - // nodes should have control dependencies on - // output_control_node. 
- if (output_control_node != nullptr) { - for (Node* id : identity_nodes) { - g->AddControlEdge(id, output_control_node); - } - } - - g->RemoveNode(n); - removed_any = true; - } - } - return removed_any; -} - -Status NameAndAttrsFromFunctionCall(const NodeDef& call_def, - NameAttrList* function) { - if (call_def.op() == "PartitionedCall" || - call_def.op() == "StatefulPartitionedCall") { - TF_RETURN_IF_ERROR(GetNodeAttr(call_def, "f", function)); - } else { - function->set_name(call_def.op()); - *function->mutable_attr() = call_def.attr(); - } - return Status::OK(); -} - -Status InstantiateFunctionCall(const NodeDef& call_def, - FunctionLibraryRuntime* flr, - FunctionLibraryRuntime::Handle* handle) { - NameAttrList function; - TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(call_def, &function)); - return flr->Instantiate(function.name(), AttrSlice(&function.attr()), handle); -} - -namespace { - -std::vector InputDevices(const Node& caller) { - std::vector input_devices(caller.in_edges().size()); - std::vector input_tensors(caller.in_edges().size()); - - for (const Edge* edge : caller.in_edges()) { - if (edge->IsControlEdge()) continue; - const string& input_device = edge->src()->has_assigned_device_name() - ? edge->src()->assigned_device_name() - : edge->src()->requested_device(); - input_devices[edge->dst_input()] = input_device; - input_tensors[edge->dst_input()] = - absl::StrCat(edge->src()->name(), ":", edge->src_output()); - } - - if (VLOG_IS_ON(4)) { - VLOG(4) << "Function instantiation input devices:"; - for (int i = 0; i < input_devices.size(); ++i) { - if (input_tensors[i].empty()) continue; // skip control edges - VLOG(4) << " [index " << i << "]" - << " device: " << input_devices[i] - << " (input: " << input_tensors[i] << ")"; - } - } - - return input_devices; -} - -// Place input nodes on the same device as the corresponding caller input -// node. Do not specify any placement for all other nodes. -class DefaultFunctionBodyPlacer : public InlinedFunctionBodyPlacer { - public: - explicit DefaultFunctionBodyPlacer(const Node& caller) - : input_devices_(InputDevices(caller)) {} - - absl::optional InputNodeDevice(int input_index) const override { - return input_devices_[input_index]; - } - absl::optional OutputNodeDevice(int output_index) const override { - return absl::nullopt; - } - bool ColocateInputOutputIdentities() const override { return false; } - absl::optional ControlNodeDevice() const override { - return absl::nullopt; - } - absl::optional BodyNodeDevice(const NodeDef& ndef) const override { - return absl::nullopt; - } - - private: - const std::vector input_devices_; -}; - -// Place all nodes on the same device as caller node. -class SingleDeviceFunctionBodyPlacer : public InlinedFunctionBodyPlacer { - public: - explicit SingleDeviceFunctionBodyPlacer(const Node& caller) - : caller_device_(caller.def().device()) {} - - absl::optional InputNodeDevice(int input_index) const override { - return caller_device_; - } - absl::optional OutputNodeDevice(int output_index) const override { - return caller_device_; - } - bool ColocateInputOutputIdentities() const override { return false; } - absl::optional ControlNodeDevice() const override { - return caller_device_; - } - absl::optional BodyNodeDevice(const NodeDef& ndef) const override { - return caller_device_; - } - - private: - const string caller_device_; -}; - -// Place input nodes on the same device as the corresponding caller input -// node. Do not place output node. 
Place control nodes on the same device as -// caller node. For all function body nodes overrides job, replica and task -// parts of the device assignment to match function caller node. -class MultiDeviceFunctionBodyPlacer : public InlinedFunctionBodyPlacer { - public: - explicit MultiDeviceFunctionBodyPlacer(const Node& caller) - : caller_device_(caller.def().device()), - input_devices_(InputDevices(caller)) { - has_parsed_caller_device_ = - DeviceNameUtils::ParseFullName(caller_device_, &caller_parsed_device_); - } - - absl::optional InputNodeDevice(int input_index) const override { - return input_devices_[input_index]; - } - absl::optional OutputNodeDevice(int output_index) const override { - return absl::nullopt; - } - bool ColocateInputOutputIdentities() const override { return true; } - absl::optional ControlNodeDevice() const override { - return caller_device_; - } - absl::optional BodyNodeDevice(const NodeDef& ndef) const override { - // TODO(ezhulenev): If function would have been instantiated as a - // multi-device function and executed via FunctionLibraryRuntime, it could - // be potentially placed on any available device. However there are multiple - // tests relying on this assumption. Fix them, and remove this line. - if (ndef.device().empty()) return caller_device_; - - if (!has_parsed_caller_device_) return ndef.device(); - - DeviceNameUtils::ParsedName ndef_parsed_device; - if (!DeviceNameUtils::ParseFullName(ndef.device(), &ndef_parsed_device)) - return ndef.device(); - - if (caller_parsed_device_.has_job) { - ndef_parsed_device.has_job = caller_parsed_device_.has_job; - ndef_parsed_device.job = caller_parsed_device_.job; - } - - if (caller_parsed_device_.has_replica) { - ndef_parsed_device.has_replica = caller_parsed_device_.has_replica; - ndef_parsed_device.replica = caller_parsed_device_.replica; - } - - if (caller_parsed_device_.has_task) { - ndef_parsed_device.has_task = caller_parsed_device_.has_task; - ndef_parsed_device.task = caller_parsed_device_.task; - } - return DeviceNameUtils::ParsedNameToString(ndef_parsed_device); - } - - private: - string caller_device_; - bool has_parsed_caller_device_; - DeviceNameUtils::ParsedName caller_parsed_device_; - std::vector input_devices_; -}; - -} // namespace - -std::unique_ptr -InlinedFunctionBodyPlacer::DefaultPlacer(const Graph& graph, - const Node& caller) { - VLOG(3) << "Create default placer for inlined function body."; - return absl::make_unique(caller); -} - -std::unique_ptr -InlinedFunctionBodyPlacer::SingleDevicePlacer(const Graph& graph, - const Node& caller) { - VLOG(3) << "Create single device placer for inlined function body."; - return absl::make_unique(caller); -} - -std::unique_ptr -InlinedFunctionBodyPlacer::MultiDevicePlacer(const Graph& graph, - const Node& caller) { - VLOG(3) << "Create multi device placer for inlined function body."; - return absl::make_unique(caller); -} - -namespace { - -Status ValidateNoInline(const FunctionBody* fbody) { - const auto attr = AttrSlice(&fbody->fdef.attr()); - bool noinline = false; - if (TryGetNodeAttr(attr, kNoInlineAttr, &noinline) && noinline) { - return errors::InvalidArgument( - "Can't inline function marked with '_noinline'"); - } - return Status::OK(); -} - -using OutputControlSrc = InlineFunctionBodyOptions::OutputControlSource; - -// Propagate the debug info of `nodes` in function `func` to the `target` node. -// If the debug info of any node is missing, its node name and function name -// is used. 
-void PropagateDebugInfoToNode(const string& func, - const std::vector& nodes, - NodeDef* target) { - if (nodes.empty() || target->has_experimental_debug_info()) { - return; - } - for (const Node* node : nodes) { - const auto& node_def = node->def(); - if (node_def.has_experimental_debug_info()) { - target->mutable_experimental_debug_info()->MergeFrom( - node_def.experimental_debug_info()); - } else { - target->mutable_experimental_debug_info()->add_original_node_names( - node_def.name()); - target->mutable_experimental_debug_info()->add_original_func_names(func); - } - } -} -} // namespace - -string InlineFunctionBodyOptions::DebugString() const { - const auto true_false = [](bool b) { return b ? "true" : "false"; }; - - const auto keep_caller_node_str = [this]() -> string { - switch (keep_caller_node) { - case KeepCallerNode::kDoNotKeep: - return "DoNotKeep"; - case KeepCallerNode::kFetchable: - return "Fetchable"; - case KeepCallerNode::kTargetable: - return "Targetable"; - } - }; - - return absl::StrCat( - "disable_inlining=", true_false(disable_inlining), - ", ignore_noinline=", true_false(ignore_noinline), - ", inline_impl_selection_group_functions=", - true_false(inline_impl_selection_group_functions), - ", keep_caller_node=", keep_caller_node_str(), ", output_control_src=", - output_control_src == OutputControlSrc::kDataOutputs ? "DataOutputs" - : "ControlOutputs", - ", inlined_function_body_placer=", inlined_function_body_placer.name, - ", uniquify_frame_names=", true_false(uniquify_frame_names)); -} - -Status ValidateInlining(const Node* node, const FunctionBody* fbody, - const InlineFunctionBodyOptions& options) { - // TODO(ezhulenev): Currently common_runtime function inlining can't guarantee - // that all side-effectful ops will be executed after inlining. See Grappler - // function_optimizer for details. Unify all function inlining mechanism. - // Do not inline if `!fbody->control_ret_nodes.empty()`. 
- - const auto num_node_inputs = static_cast(node->num_inputs()); - const auto num_node_outputs = static_cast(node->num_outputs()); - - if (num_node_inputs != fbody->arg_types.size() || - num_node_inputs != fbody->arg_nodes.size()) { - return errors::InvalidArgument( - "Node inputs do not match function arguments: inputs=", num_node_inputs, - " arg_types=", fbody->arg_types.size(), - " arg_nodes=", fbody->arg_nodes.size()); - } - - if (num_node_outputs != fbody->ret_types.size() || - num_node_outputs != fbody->ret_nodes.size()) { - return errors::InvalidArgument( - "Node outputs do not match function returns: outputs=", - num_node_outputs, " ret_types=", fbody->ret_types.size(), - " ret_nodes=", fbody->ret_nodes.size()); - } - - for (int i = 0; i < node->num_inputs(); ++i) { - if (node->input_type(i) != fbody->arg_types[i]) { - return errors::InvalidArgument( - "Node input type doesn't match function argument type: ", - node->input_type(i), " != ", fbody->arg_types[i], " @ index=", i); - } - } - for (int i = 0; i < node->num_outputs(); ++i) { - if (node->output_type(i) != fbody->ret_types[i]) { - return errors::InvalidArgument( - "Node output type doesn't match function return type: ", - node->output_type(i), " != ", fbody->ret_types[i], " @ index=", i); - } - } - - if (options.disable_inlining) { - return errors::InvalidArgument( - "Function inlining explicitly disabled by 'options.disable_inlining'"); - } - - if (!options.inline_impl_selection_group_functions) { - bool is_impl_selection_group_function = - fbody->fdef.attr().find("api_implements") != fbody->fdef.attr().end(); - if (is_impl_selection_group_function) { - return errors::InvalidArgument( - "Inlining of implementation selection group function ", - fbody->fdef.signature().name(), - " is disabled by options.inline_impl_selection_group_functions"); - } - } - - if (!options.ignore_noinline) { - TF_RETURN_IF_ERROR(ValidateNoInline(fbody)); - } - - return Status::OK(); -} - -// Function inlining must preserve function execution semantics with regards to -// side-effects visibility. Tensorflow in Eager mode has an automatic control -// dependencies tracking mechanism, which enforces well-defined execution order -// of all side-effects. Any other frontend (e.g. Swift) must produce graphs -// following the same rules, to ensure that function inlining works correctly. -// -// IMPORTANT: Currently we do not have a true notion of "side-effectful" node, -// we assume that all stateful nodes might have side-effects, though it's not -// true in practice, e.g. `ReadVariableOp` doesn't have an observable -// side-effect. -// -// Automatic control dependency rules in Tensorflow 2.0 (python in eager mode): -// -// 1) When a function has a resource (DT_RESOURCE data type) input argument it -// "captures" the mutable resource. This is implemented by automatically -// adding a incoming control edge from the previous side-effectful op -// touching that resource, and an outgoing control edge to the next -// side-effectful op using the same resource. This serializes the mutations -// of the resource to make graph execution deterministic. -// -// 2) All stateful ops inside a function body are guaranteed to execute in -// program order, this is achieved by adding control edges between stateful -// ops at graph construction time. Stateful ops (or ops that must execute) -// should be in the function control return set. 
Having a data edge to the -// regular function output might be not enough, because after function -// inlining it might happen that data output is unused. -// -// 3) Furthermore, all ops accepting the same resource as an input are -// guaranteed to run in program order. This is also done by adding control -// edges at graph construction time. The last op touching the resource -// must be in a control return set, which will guarantee that all side -// effects to the resource will happen before function completion. -// -// Function inlining must preserve side-effect visibility: -// -// 1) All side-effects to the captured resources, that happened before function -// call must be visible to the function body nodes using that resources. -// -// 2) All side-effects to the captured resources, that happened inside function -// body, must be visible to every op/function using that resource after the -// function call completed. -// -// To guarantee that these properties are preserved after inlining we: -// -// 1) Create "input_control_node" NoOp. Function call node incoming control -// edges will be forwarded *to* this node. Function inputs (Identity nodes) -// will have a control edge *from* this node. If function body has nodes -// without inputs, they will have a control edge *from* this node. -// -// 2) Create "output_control_node" NoOp. All nodes that have incoming control -// edge *from* the function call node, will be forwarded to this node. -// -// We have two options for choosing which nodes will have a control edge *to* -// the "output control node": -// a) control returns (`control_ret` field in FunctionDef) -// b) data returns (`ret` field in FunctionDef) -// -// We do a) for multi-device function calls in Tensorflow v2 and b) -// for the rest for compatibility with Tensorflow v1. -// -// Following the automatic control dependencies tracking rules, a node that -// has an incoming control edge from the function call node is dependent on -// the side-effects happening inside the function body. The output control -// node will guarantee side-effects execution order. -// -// If function call node doesn't have an outgoing control edge, it means that -// no one is interested in observing side-effects that might have happened. -// -// Function inlining might leave the graph in partially-placed state. Function -// inlining caller must call Placer to guarantee that all nodes are placed. -// -// Function inlining with `options.override_device=true` will leave graph in -// fully placed state, by overriding all inlined nodes devices with the caller -// node device, but it will make functions always single-device. These functions -// after inlining will not be able to handle resources on multiple devices. This -// is currently acceptable for XLA use cases (XLA cluster is always executed on -// a single device). -// -// TODO(ezhulenev): Documentation above is ahead of implementation below. -Status InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g, - Node* caller, const FunctionBody* fbody, - const InlineFunctionBodyOptions& options) { - VLOG(3) << "Inline function call: " << SummarizeNode(*caller) << " [" - << options.DebugString() << "]"; - - Status validation = ValidateInlining(caller, fbody, options); - if (!validation.ok()) { - return errors::Internal("Inlining mismatch: ", validation.error_message()); - } - - // Placer is responsible for assigning devices for all nodes that we will add - // to the graph. 
- const std::unique_ptr placer = - options.inlined_function_body_placer.get(*g, *caller); - - // We can't possibly introduce a duplicate control edge during function - // inlining, so we skip this check in calls to the 'g->AddControlEdge(...)'. - static constexpr bool kDoNotCheckDuplicates = true; - - // ------------------------------------------------------------------------ // - // Helper functions to create `NoOp` and `Identity` nodes for auxiliary - // control nodes and inlined function inputs and outputs. - - // Add a NoOp node for function control inputs/outputs. - const auto no_op = [&](StringPiece name) -> Node* { - Node* node = AddNoOp(absl::StrCat(caller->name(), "/", name), g); - const absl::optional device = placer->ControlNodeDevice(); - if (device.has_value()) node->set_requested_device(*device); - return node; - }; - - // Add an Identity node for function input. - const auto input_identity = [&](StringPiece name, Endpoint input, - int index) -> Node* { - Node* node = AddIdentity(absl::StrCat(caller->name(), "/", name), g, input); - const absl::optional device = placer->InputNodeDevice(index); - if (device.has_value()) node->set_requested_device(*device); - bool colocate_identity = placer->ColocateInputOutputIdentities(); - if (colocate_identity) { - node->AddAttr(kColocationAttrName, - std::vector{absl::StrCat(kColocationGroupPrefix, - input.node->name())}); - } - return node; - }; - - // Add an Identity node for function output. - const auto output_identity = [&](StringPiece name, Endpoint input, - int index) -> Node* { - Node* node = AddIdentity(absl::StrCat(caller->name(), "/", name), g, input); - const absl::optional device = placer->OutputNodeDevice(index); - if (device.has_value()) node->set_requested_device(*device); - bool colocate_identity = placer->ColocateInputOutputIdentities(); - if (colocate_identity) { - node->AddAttr(kColocationAttrName, - std::vector{absl::StrCat(kColocationGroupPrefix, - input.node->name())}); - } - return node; - }; - - // ------------------------------------------------------------------------ // - // Input edges. For data edges coming into "caller", we first compute the - // : for the i-th input in "inputs". - // If "caller" has any input control dependencies, we add a NoOp - // node "input_control_node", which depends on "caller"'s control inputs. - std::vector inputs(caller->num_inputs()); - Node* input_control_node = nullptr; - for (const Edge* e : caller->in_edges()) { - if (e->IsControlEdge()) { - if (input_control_node == nullptr) { - input_control_node = no_op("input_control_node"); - } - g->AddControlEdge(e->src(), input_control_node, kDoNotCheckDuplicates); - } else { - inputs[e->dst_input()] = {e->src(), e->src_output()}; - } - } - if (input_control_node != nullptr) { - VLOG(3) << "Created input control node: " << input_control_node->name(); - } - - // ------------------------------------------------------------------------ // - // Duplicate fbody->graph into 'g'. First, we copy the nodes of - // fbody->graph into 'g' except the source and sink nodes. We copy - // edges among nodes in 'fbody->graph'. - // - // If 'x' is a node in fbody->graph and its copy in 'g' is 'y', we - // remember 'y' in node_map[x->id()]. - std::vector node_map(fbody->graph->num_node_ids()); - for (Node* n : fbody->graph->op_nodes()) { - NodeDef ndef = n->def(); - - // Maybe override requested node device assignment. 
- const absl::optional device = placer->BodyNodeDevice(ndef); - if (device.has_value()) ndef.set_device(*device); - - // Add inlined function name to inlined node debug information. - PropagateDebugInfoToNode(fbody->fdef.signature().name(), {n}, &ndef); - - // Add the function node name as a prefix: - // 1) to node name to avoid collisions - // 2) to frame name to avoid multiple LoopCond nodes in one frame - // 3) to colocation attribute - const string prefix = strings::StrCat(caller->name(), "/"); - TF_RETURN_IF_ERROR(AddPrefixAndSuffixToNode(prefix, /*suffix=*/"", &ndef, - options.uniquify_frame_names)); - - Status added_node; - Node* clone = g->AddNode(ndef, &added_node); - TF_CHECK_OK(added_node); - node_map[n->id()] = clone; - - // If there is an input control node, and one of: - // a) the node has no data or control inputs, or - // b) the node is a function call (including SymbolicGradient), - // then add a control edge from the input control node to the clone (only - // if it does not already have a control input). - // - // We must not execute any nodes if the original function call would not - // have executed. This is especially critical when the function call is - // inside a control-flow construct like tf.cond(). Case (a) ensures that - // such nodes do not run. - // - // The purpose of case (b) is to ensure that instances of case (a) created - // by further inlining steps also receive the control dependency. - // - // This edge is required to transfer execution frame down to all function - // body nodes of inlined nested function calls. - if (input_control_node) { - const auto is_input_edge = [](const Edge* e) -> bool { - return !e->src()->IsSource(); - }; - const auto is_control_edge = [](const Edge* e) -> bool { - return !e->src()->IsSource() && e->IsControlEdge(); - }; - - // Forward execution frame if: - // - // a) The node has no data or control inputs. - // b) OR the node is a function call without control inputs (control edge - // will be used in nested function inlining to forward execution frame - // to constants inside the function body). - // - // c) Do not forward control frame to function argument nodes, they will - // be connected to the corresponding function input later. - const bool forward_execution_frame = - (absl::c_none_of(n->in_edges(), is_input_edge) || // (a) - (n->IsFunctionCall() && // (b) - absl::c_none_of(n->in_edges(), is_control_edge))) && // - !n->IsArg(); // (c) - - if (forward_execution_frame) { - VLOG(4) << "Add control edge from input control node to: " - << clone->name(); - g->AddControlEdge(input_control_node, clone, kDoNotCheckDuplicates); - } - } - } - for (const Edge* e : fbody->graph->edges()) { - if (e->src()->IsSource() || e->src()->IsSink() || e->dst()->IsSource() || - e->dst()->IsSink()) { - continue; - } - Node* src_copy = node_map[e->src()->id()]; - Node* dst_copy = node_map[e->dst()->id()]; - g->AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input()); - } - - // ------------------------------------------------------------------------ // - // Connect input edges. - // - // We create one Identity node for each input. Then, we connect inputs[i] to - // the i-th identity node added. The nodes that previously connected - // to the j-th output of i-th arg node are reconnected to the i-th - // identity node. - // - // The added identity nodes depend on "input_control_node". 
- VLOG(4) << "Add input Identity nodes for each function argument:"; - for (std::size_t i = 0; i < fbody->arg_nodes.size(); ++i) { - Node* arg = node_map[fbody->arg_nodes[i]->id()]; - Node* n = input_identity("input", inputs[i], i); - VLOG(4) << " [index " << i << "] " - << fbody->fdef.signature().input_arg(i).name() << " as " - << n->name() << " (input: " << inputs[i].name() - << ", requested_device: " << n->requested_device() << ")"; - - if (input_control_node) { - g->AddControlEdge(input_control_node, n, kDoNotCheckDuplicates); - } - for (const Edge* e : arg->out_edges()) { - if (e->IsControlEdge()) { - g->AddControlEdge(n, e->dst(), kDoNotCheckDuplicates); - } else { - g->AddEdge(n, 0, e->dst(), e->dst_input()); - } - } - node_map[fbody->arg_nodes[i]->id()] = n; - g->RemoveNode(arg); // 'arg' is disconnected. - } - - // ------------------------------------------------------------------------ // - // Connect output edges. - // - // For i-th return node in fbody->graph, we add in "g" an identity node - // (outputs[i-th]). We then reconnect every incoming edge into the i-th return - // node to the added identity node. - // - // For every data edge coming out of "callee"s i-th output, we reconnect it to - // the i-th identity added above. - // - // If "callee" is control-depended upon by any other nodes, we add a NoOp node - // "output_control_node". "output_control_node" depends on all identity nodes - // added above or on all control return nodes (controlled by - // `options.output_control_src` value). And nodes previously depend on - // "callee" is changed to depend on "output_control_node". - // - // If `keep_node_fetchable` is `true` we always add an output control node, to - // guarantee that executing a fetchable node will execute all side-effects. - VLOG(4) << "Add output Identity nodes for each function output argument:"; - std::vector outputs(caller->num_outputs()); - for (std::size_t i = 0; i < fbody->ret_nodes.size(); ++i) { - Node* ret = node_map[fbody->ret_nodes[i]->id()]; - Endpoint data; // Data input for the ret node. - for (const Edge* e : ret->in_edges()) { - if (!e->IsControlEdge()) { - data = {e->src(), e->src_output()}; - break; - } - } - CHECK(data.node != nullptr); - Node* n = output_identity("output", data, i); - outputs[i] = n; - VLOG(4) << " [index " << i << "] " - << fbody->fdef.signature().output_arg(i).name() << " as " - << n->name() << " (ret: " << data.node->name() << ":" << data.index - << ", requested_device: " << n->requested_device() << ")"; - for (const Edge* e : ret->in_edges()) { - if (e->IsControlEdge()) { - g->AddControlEdge(e->src(), n, kDoNotCheckDuplicates); - } - } - g->RemoveNode(ret); // 'ret' is disconnected. 
- } - - Node* output_control_node = nullptr; - const bool has_control_outputs = absl::c_any_of( - caller->out_edges(), [](const Edge* e) { return e->IsControlEdge(); }); - - using KeepCallerNode = InlineFunctionBodyOptions::KeepCallerNode; - const bool keep_caller_node = - options.keep_caller_node == KeepCallerNode::kFetchable || - options.keep_caller_node == KeepCallerNode::kTargetable; - - if (has_control_outputs || keep_caller_node) { - output_control_node = no_op("output_control_node"); - VLOG(4) << "Add output control node: " << output_control_node->name(); - if (options.output_control_src == OutputControlSrc::kDataOutputs) { - for (Node* n : outputs) { - VLOG(4) << " [data output] add control edge from: " << n->name(); - g->AddControlEdge(n, output_control_node, kDoNotCheckDuplicates); - } - } else { - for (Node* fbody_node : fbody->control_ret_nodes) { - Node* n = node_map[fbody_node->id()]; - VLOG(4) << " [control output] add control edge from: " << n->name(); - g->AddControlEdge(n, output_control_node, kDoNotCheckDuplicates); - } - } - } - - // We can't leave output control node without incoming control edges, because - // in this case outgoing control edge will loose execution frame information. - // We connect input_control_node and output_control_node with a control edge - // to forward execution frame to the controlled nodes. Above we add a control - // edge to all function calls inside function body, to guarantee that we will - // always have input_control_node when we need it. - if (output_control_node && output_control_node->in_edges().empty()) { - if (input_control_node) { - VLOG(4) - << "Add add a control edge between input and output control nodes: " - << input_control_node->name() << " to " - << output_control_node->name(); - g->AddControlEdge(input_control_node, output_control_node, - kDoNotCheckDuplicates); - } else { - VLOG(4) << "Function inlining potentially dropped execution frame " - "information from outgoing control edges."; - } - } - - for (const Edge* e : caller->out_edges()) { - if (e->IsControlEdge()) { - g->AddControlEdge(output_control_node, e->dst(), kDoNotCheckDuplicates); - } else { - g->AddEdge(outputs[e->src_output()], 0, e->dst(), e->dst_input()); - } - } - - // ------------------------------------------------------------------------ // - // Add an IdentityN or NoOp node in-place of caller node to keep `caller` - // fetchable or targetable. - - if (keep_caller_node) { - std::vector output_tensors; - absl::c_transform(outputs, std::back_inserter(output_tensors), - [](Node* n) { return NodeBuilder::NodeOut(n, 0); }); - - Node* caller_substitute_node; - if (options.keep_caller_node == KeepCallerNode::kTargetable || - output_tensors.empty()) { - // IdentityN node must have at least one data input. If function has no - // data outputs, we can't keep it fetchable. - TF_CHECK_OK(NodeBuilder(caller->name(), "NoOp") - .Device(caller->requested_device()) - .ControlInput(output_control_node) - .Finalize(g, &caller_substitute_node)); - - } else if (options.keep_caller_node == KeepCallerNode::kFetchable) { - TF_CHECK_OK(NodeBuilder(caller->name(), "IdentityN") - .Device(caller->requested_device()) - .Input(output_tensors) - .ControlInput(output_control_node) - .Finalize(g, &caller_substitute_node)); - } - } - - // ------------------------------------------------------------------------ // - // 'caller' is replaced with inlined function body nodes and maybe IdentityN - // to keep it fetchable. 
- VLOG(3) << "Successfully inlined function call node: " << caller->name(); - g->RemoveNode(caller); - - return Status::OK(); -} - -bool IsFunctionCall(const FunctionLibraryDefinition& lib_def, - const Node& node) { - return node.IsFunctionCall(); -} - -bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph, - const ExpandInlineFunctionsOptions& options) { - std::vector> candidates; - - const FunctionLibraryDefinition* fld = lib->GetFunctionLibraryDefinition(); - - for (Node* node : graph->nodes()) { - // Skip nodes that are not function calls or SymbolicGradient calls. - if (!IsFunctionCall(*lib->GetFunctionLibraryDefinition(), *node)) { - continue; - } - // Skip function calls that marked noinline. - bool noinline; - if (fld->GetAttr(*node, kNoInlineAttr, &noinline).ok() && noinline) { - VLOG(3) << "noinline: " << SummarizeNode(*node); - continue; - } - FunctionLibraryRuntime::Handle handle; - Status s = InstantiateFunctionCall(node->def(), lib, &handle); - if (!s.ok()) { - LOG(ERROR) << "Failed to instantiate a function: " << s.error_message(); - continue; - } - const FunctionBody* fbody = lib->GetFunctionBody(handle); - CHECK_NOTNULL(fbody); - candidates.emplace_back(node, fbody); - } - - bool inlined_any = false; - for (const auto& p : candidates) { - Status inlined = InlineFunctionBody(*fld, graph, p.first, p.second, - p.first->IsPartitionedCall() - ? options.multi_device_options - : options.native_options); - if (inlined.ok()) { - inlined_any = true; - } else { - VLOG(1) << "Failed to inline function call: node=" << p.first->name() - << " error=" << inlined.error_message(); - } - } - - // TODO(ezhulenev): Release handles for inlined function calls. - - return inlined_any; -} - -string NewName(const Node* n, bool pretty) { - if (pretty) { - return strings::StrCat(n->type_string(), n->id()); - } else { - return strings::StrCat("n", n->id()); - } -} - -// TODO(zhifengc): Maybe this should be the default Graph::AsGraphDef. -// and stash the original NodeDef name as an attr for documentation -// purpose. -void ToGraphDef(const Graph* g, GraphDef* gdef, bool pretty) { - // We visit nodes in forward topological sort order, which is a - // possible execution order of the graph. - gtl::InlinedVector inputs; - gdef->Clear(); - *gdef->mutable_versions() = g->versions(); - - std::vector start_nodes; - for (Node* n : g->nodes()) { - if (n->out_edges().empty()) { - start_nodes.push_back(n); - } - } - - ReverseDFSFrom(*g, start_nodes, nullptr, [gdef, pretty, &inputs](Node* n) { - if (!n->IsOp()) return; - NodeDef* ndef = gdef->add_node(); - ndef->set_name(NewName(n, pretty)); - ndef->set_op(n->type_string()); - for (const auto& attr : n->attrs()) { - (*ndef->mutable_attr())[attr.first] = attr.second; - } - - if (!n->assigned_device_name().empty()) { - ndef->set_device(n->assigned_device_name()); - } else { - ndef->set_device(n->requested_device()); - } - - inputs.clear(); - inputs.resize(n->num_inputs()); - for (const Edge* e : n->in_edges()) { - if (e->IsControlEdge()) { - inputs.push_back(e); - } else { - if (inputs[e->dst_input()] == nullptr) { - inputs[e->dst_input()] = e; - } else { - LOG(WARNING) << "Malformed graph node. multiple input edges: " - << n->DebugString(); - } - } - } - // node->name() is merely NodeDef::name, which are not guaranteed - // to be unique and stable after optimization rewrites. Therefore, - // we use "n" instead. 
- for (const Edge* e : inputs) { - if (e == nullptr) { - ndef->add_input("unknown"); - continue; - } - const string srcname = NewName(e->src(), pretty); - if (!e->src()->IsOp()) { - } else if (e->IsControlEdge()) { - ndef->add_input(strings::StrCat("^", srcname)); - } else if (e->src_output() == 0) { - ndef->add_input(srcname); - } else { - ndef->add_input(strings::StrCat(srcname, ":", e->src_output())); - } - } - }); -} - -string DebugString(const Graph* g) { - GraphDef gdef; - ToGraphDef(g, &gdef); - return DebugString(gdef); -} - -FunctionBody::FunctionBody(const FunctionDef& f, DataTypeSlice arg_t, - DataTypeSlice ret_t, Graph* g) - : fdef(f), - graph(g), - arg_types(arg_t.begin(), arg_t.end()), - ret_types(ret_t.begin(), ret_t.end()) { - // 1. Find regular Arg/Ret nodes. - this->arg_nodes.resize(arg_types.size()); - this->ret_nodes.resize(ret_types.size()); - for (Node* n : this->graph->op_nodes()) { - gtl::InlinedVector* node_vec; - if (n->type_string() == kRetOp || n->type_string() == kDeviceRetOp) { - node_vec = &this->ret_nodes; - } else if (n->type_string() == kArgOp || n->type_string() == kDeviceArgOp) { - node_vec = &this->arg_nodes; - } else { - continue; - } - int index; - TF_CHECK_OK(GetNodeAttr(n->attrs(), "index", &index)); - CHECK_LE(0, index); - CHECK_LT(index, node_vec->size()); - (*node_vec)[index] = n; - } - // 2. Find ControlRet nodes that must be always executed. - std::unordered_set control_ret_node_names; - for (const auto& control_ret : fdef.control_ret()) { - control_ret_node_names.insert(control_ret.second); - } - this->control_ret_nodes.reserve(control_ret_node_names.size()); - for (Node* n : this->graph->op_nodes()) { - if (control_ret_node_names.count(n->name()) > 0) { - this->control_ret_nodes.push_back(n); - } - } -} - -FunctionBody::~FunctionBody() { delete this->graph; } - class SymbolicGradientHelper { public: explicit SymbolicGradientHelper(const FunctionBody& f) : fbody_(&f) {} diff --git a/tensorflow/core/common_runtime/function.h b/tensorflow/core/common_runtime/function.h index 9071a5cfa50..af4b333aca0 100644 --- a/tensorflow/core/common_runtime/function.h +++ b/tensorflow/core/common_runtime/function.h @@ -22,7 +22,10 @@ limitations under the License. #include "absl/types/optional.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/common_runtime/function_body.h" +#include "tensorflow/core/common_runtime/function_utils.h" #include "tensorflow/core/common_runtime/graph_optimizer.h" +#include "tensorflow/core/common_runtime/inline_function_utils.h" #include "tensorflow/core/common_runtime/process_function_library_runtime.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/graph/graph.h" @@ -30,8 +33,6 @@ limitations under the License. namespace tensorflow { -static constexpr const char* const kNoInlineAttr = "_noinline"; - // Get default customizable kernel creator if set const CustomKernelCreator* GetDefaultCustomKernelCreator(); @@ -67,88 +68,6 @@ std::unique_ptr NewFunctionLibraryRuntime( const SessionMetadata* session_metadata, ProcessFunctionLibraryRuntime* parent); -// FunctionLibraryRuntime::GetFunctionBody returns a description of an -// instantiated function that is represented as a Graph with arg/ret -// nodes annotated. -struct FunctionBody { - FunctionDef fdef; - Graph* graph = nullptr; // owned. - DataTypeVector arg_types; - DataTypeVector ret_types; - // arg_nodes[i] contains the i'th function input. 
In other words, - // GetNodeAttr(arg_nodes[i]->attrs(), "index") == i. - gtl::InlinedVector arg_nodes; - // ret_nodes[i] contains the i'th function output. In other words, - // GetNodeAttr(ret_nodes[i]->attrs(), "index") == i. - gtl::InlinedVector ret_nodes; - gtl::InlinedVector control_ret_nodes; - - FunctionBody() {} - FunctionBody(const FunctionDef& f, DataTypeSlice arg_types, - DataTypeSlice ret_types, Graph* g); - ~FunctionBody(); -}; - -// Debugging facility. Returns a debug string for a graph -// representing an instantiated function. -string DebugString(const Graph* g); - -// A few hand-crafted optimization on the instantiated function body -// (a Graph*). - -// Removes nodes that are -// 1. not stateful; and -// 2. not _Arg; and -// 3. not reachable from _Retval. -// -// This function is triggered by function inlining, unlike 'PruneFunctionBody' -// it doesn't preserve nodes that are reachable from control returns. Function -// inlining is responsible for connecting control return nodes with the nodes -// that have input control edges from the inlined function call node. -// -// Assuming that automatic control dependency tracking is correct, absence of -// outgoing control edge from the function call node means that no one needs to -// observe side-effect that might have been generated by the function (see -// documentation in common_runtime/function.cc for details). -// -// Returns true iff any node is removed from "g". -bool RemoveDeadNodes(Graph* g); - -// Find a pattern: -// src -(in)-> node -(out)-> dst, where -// 1) node is an identity node; -// 2) in is the only incoming data edge; -// 3) out is the only outgoing data edge; -// -// Rewrites the above pattern with src->dst and relevant data -// dependencies updated. Repeat the process until no such pattern -// left. -bool RemoveIdentityNodes(Graph* g); - -// Rewrites _ListToArray and _ArrayToList to a set of Identity nodes. -bool RemoveListArrayConverter(Graph* g); - -// Dump the contents of the "graph" to log files if the logging level is -// sufficiently high. -void DumpGraph(StringPiece label, const Graph* g); - -// Applies graph rewrite optimization such as inlining, dead code -// removal, etc. -// -// **g is a graph constructed based on the runtime library 'lib'. -// OptimizeGraph mutates **g extensively and replaces '*g' with a -// complete copy. Therefore, the caller should not keep any references -// to nodes *g. -void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr* g, - const GraphOptimizer::Options& graph_optimizer_options); -void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr* g); - -// Convert the Graph of a function to a GraphDef. -// -// Handles renaming of nodes to avoid duplicate names which may -// be present after various rewriting operations. -void ToGraphDef(const Graph* g, GraphDef* gdef, bool pretty = false); - // Given a numerical function "f", returns another numerical function // "g", such that if "f" takes N inputs and produces M outputs, "g" // takes N + M inputs and produces N outputs. I.e., if @@ -161,221 +80,6 @@ void ToGraphDef(const Graph* g, GraphDef* gdef, bool pretty = false); // TODO(zhifengc): Asks math expert to say the comment again. std::unique_ptr SymbolicGradient(const FunctionBody& f); -// Optionally override device assignment for nodes added to the graph for -// inlined functions: -// (1) Identity nodes added in place of function input arguments. -// (2) Identity nodes added in place of function return values. 
-// (3) Special NoOp nodes that enforce side-effects execution order. -// (4) All nodes inside function body specified in FunctionDef. -class InlinedFunctionBodyPlacer { - public: - virtual ~InlinedFunctionBodyPlacer() = default; - - virtual absl::optional InputNodeDevice(int input_index) const = 0; - virtual absl::optional OutputNodeDevice(int output_index) const = 0; - // Returns true if the added input/output identity nodes should be colocated - // with the corresponding input/output from the function body. - virtual bool ColocateInputOutputIdentities() const = 0; - virtual absl::optional ControlNodeDevice() const = 0; - virtual absl::optional BodyNodeDevice(const NodeDef& ndef) const = 0; - - // Place input nodes on the same device as the corresponding caller input - // node. Do not specify any placement for all other nodes. - static std::unique_ptr DefaultPlacer( - const Graph& graph, const Node& caller); - - // Place all nodes on the same device as caller node. - static std::unique_ptr SingleDevicePlacer( - const Graph& graph, const Node& caller); - - // Place input nodes on the same device as the corresponding caller input - // node. Do not place output node. Place control nodes on the same device as - // caller node. For all function body nodes overrides job, replica and task - // parts of the device assignment to match function caller node. - static std::unique_ptr MultiDevicePlacer( - const Graph& graph, const Node& caller); - - using Factory = std::function( - const Graph&, const Node&)>; - - struct Config { - string name; - Factory get; - }; - - static Config Default() { return {"default", DefaultPlacer}; } - static Config SingleDevice() { return {"single_device", SingleDevicePlacer}; } - static Config MultiDevice() { return {"multi_device", MultiDevicePlacer}; } -}; - -struct InlineFunctionBodyOptions { - // All nodes that have incoming control edge *from* the function call node, - // will be forwarded to the "output control node". There are two options for - // choosing which nodes will have a control edge *to* the "output control - // node": - // a) control returns (`control_ret` field in FunctionDef) - // b) data returns (`ret` field in FunctionDef) - enum class OutputControlSource { kDataOutputs, kControlOutputs }; - - // Keep a node in a graph with the same name as the function call node: - // - // a) DoNotKeep: Function call node is fully inlined, and there is no node in - // a graph with the same name. - // - // b) Fetchable: Add an IdentityN node to the graph in place of the inlined - // function call node. It will have a control edge from inlined - // 'output_control_node' and data edges from function output nodes. - // The IdentityN node will be placed on the same device as the caller node. - // - // This is mostly for compatibility with Tensorflow v1 and sessions. - // When we prepare a graph for execution in - // GraphExecutionState::MakeForBaseGraph we don't know what nodes will be - // fetched, so we can't safely remove any of them. When graph executed as a - // function it has 'Retval' nodes for all fetched tensors, and we can - // safely inline function calls. - // - // c) Targetable: Add a NoOp node to the graph in place of the inlined - // function call node. It will have a control edge from inline - // 'output_control_node' and no data edges. NoOp node will be placed on the - // same device as the caller node. This will keep the inlined function call - // node a valid 'session.run' target, and also will keep it a valid control - // output node. 
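To make the KeepCallerNode choices above concrete, a rough configuration sketch (illustrative only, not part of this patch) for a caller that wants session-style fetchability together with TF2 control outputs, using the options and placer configs declared below:

// Sketch only: one plausible way to fill in InlineFunctionBodyOptions.
InlineFunctionBodyOptions MakeSessionCompatibleOptions() {
  InlineFunctionBodyOptions opts;
  opts.keep_caller_node =
      InlineFunctionBodyOptions::KeepCallerNode::kFetchable;
  opts.output_control_src =
      InlineFunctionBodyOptions::OutputControlSource::kControlOutputs;
  opts.inlined_function_body_placer = InlinedFunctionBodyPlacer::MultiDevice();
  return opts;
}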
- enum class KeepCallerNode { kDoNotKeep, kFetchable, kTargetable }; - - // If 'true' function inlining is completely disabled. This allows to control - // function inlining for different types of function calls (see - // 'ExpandInlineFunctionsOptions' below). - bool disable_inlining = false; - // Ignore '_noinline' function attribute. - bool ignore_noinline = false; - // If 'true' function inlining will inline functions in implementation - // selection group. Normally those functions should not be inlined; they will - // be handled by Grappler. - bool inline_impl_selection_group_functions = false; - // Controls if we want to keep a node with the name as the function call node - // in a graph after function inlining. - KeepCallerNode keep_caller_node = KeepCallerNode::kDoNotKeep; - // For compatibility with Tensorflow v1 by default we will use data outputs. - // Control returns were added to Tensorflow v2 with automatic control - // dependencies tracking in Eager mode. - OutputControlSource output_control_src = OutputControlSource::kDataOutputs; - // Inlined function body placer decides what requested device assignments - // should be added to the nodes added to the graph. See documentation above - // for available strategies. - InlinedFunctionBodyPlacer::Config inlined_function_body_placer = - InlinedFunctionBodyPlacer::Default(); - // If true, frame names in the function body will be - // made unique in the resulting graph (e.g. by prepending a unique prefix). - // NOTE(mrry): Only set this option to false when there is a single function - // call in the graph (e.g. when making a remote function call via - // ClusterFunctionLibraryRuntime). This option is provided because the graph - // partitioner generates frame names that must remain unmodified across all - // partitions of a multi-device function. - bool uniquify_frame_names = true; - - // A human-readable debug string for this options. - string DebugString() const; -}; - -// Returns 'Status::OK()' iff the function '*fbody' can be inlined at 'node' -// based on the type signature of 'node' and 'fbody': -// -// (1) Caller node has the same number of inputs and outputs as the function. -// (2) Caller node inputs and outputs have the same data types as function -// inputs and returns. -// (3) Validation rules defined in InlineFunctionBodyOptions. -// -// If function can't be safely inlined, returns error message with details why -// inlining is not possible or safe. -Status ValidateInlining(const Node* node, const FunctionBody* fbody, - const InlineFunctionBodyOptions& options); - -// Given a "caller" in graph "g", which is a function call of a function -// to "fbody". Replaces the "caller" with fbody->graph and connects -// edges properly. "override_device" specifies whether inlining should replace -// explicitly specified devices inside fbody with the callee's device. -// -// Returns 'Status::OK()' if function was successfully inlined into the graph. -// If function inlining is not possible returns an error with a reason, and -// leaves the graph in unmodified state. -Status InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g, - Node* caller, const FunctionBody* fbody, - const InlineFunctionBodyOptions& options); - -// There are three types of function calls that could be invoked during -// *Tensorflow graph execution*: -// -// 1) Native function call (node.type_string() is the function name). These -// functions are always executed on a single-device, which is the device of -// the function call node. 
-// -// 2) Multi-device function calls (PartitionedCall or StatefulPartitionedCall -// ops) can execute on multiple devices and accept DT_RESOURCE inputs that -// belong to different devices. This type of functions was added in -// Tensorflow 2.0 Eager mode, and it has control outputs to represent -// side-effects that must always execute (see `control_ret` in FunctionDef). -// -// 3) SymbolicGradient has been deprecated for a while, but we still keep it and -// use `native` options for inlining for compatibility. -// -// We need to have distinct inlining rules for compatibility with Tensorflow v1. -// -// There are few other places in Tensorflow that could execute functions: -// -// 1) common_runtime/eager/kernel_and_device.{h,cc} - executes "top level" -// functions directly via function library runtime, without going through -// the graph. -// 2) tf.data pipelines - also execute functions directly via function library -// runtime with custom executors. -struct ExpandInlineFunctionsOptions { - ExpandInlineFunctionsOptions() : native_options(), multi_device_options() { - using OutputControlSrc = InlineFunctionBodyOptions::OutputControlSource; - multi_device_options.output_control_src = OutputControlSrc::kControlOutputs; - } - - InlineFunctionBodyOptions native_options; - InlineFunctionBodyOptions multi_device_options; -}; - -// WARNING(ezhulenev): PLEASE DO NOT USE THIS FUNCTION. This is a temporary -// workaround that will be enabled only during the function inlining unification -// (b/126811947). Contact ezhulenev@ if you think you need it. -// TODO(ezhulenev): Delete this function. -bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph, - const ExpandInlineFunctionsOptions& options); - -// For each node in "graph", if "lib" indicates that the node is a -// function call, inline the function body. Returns true if at least -// one node is inlined. -// -// This routine goes through "graph" nodes once and applies the -// inlining. The caller may decide to apply the inlining on "graph" -// multiple times by calling ExpandInlineFunctions a few times. -// -// Function calls that can't be safely inlined into the graph (ValidateInlining -// returns error), are ignored. -// -// TODO(ezhulenev): We do not FunctionLibraryRuntime for this. We need just the -// FunctionLibraryDefinition and FunctionDefToBodyHelper to implement this (see -// lower_function_call.cc). -inline bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph) { - return ExpandInlineFunctions(lib, graph, ExpandInlineFunctionsOptions()); -} - -// Extracts function name and attributes from `call_def` -// `call_def` can be a native function call (where the op type is the function -// name) or a call through PartitionedCall/StatefulPartitionedCall. -Status NameAndAttrsFromFunctionCall(const NodeDef& call_def, - NameAttrList* function); - -// Extracts function name and attributes from `call_def` and invokes -// flr->Instantiate(name, attrs, handle). -// `call_def` can be a native function call (where the op type is the function -// name) or a call through PartitionedCall/StatefulPartitionedCall. -Status InstantiateFunctionCall(const NodeDef& call_def, - FunctionLibraryRuntime* flr, - FunctionLibraryRuntime::Handle* handle); - // Returns true iff `n` represents a function call. 
`n` can be a native // function call (n.type_string() is the function name), // a PartitionedCall/StatefulPartitionedCall, or a SymbolicGradient (which diff --git a/tensorflow/core/common_runtime/function_body.cc b/tensorflow/core/common_runtime/function_body.cc new file mode 100644 index 00000000000..3b3442bf7f5 --- /dev/null +++ b/tensorflow/core/common_runtime/function_body.cc @@ -0,0 +1,64 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/common_runtime/function_body.h" + +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { + +FunctionBody::FunctionBody(const FunctionDef& f, DataTypeSlice arg_t, + DataTypeSlice ret_t, Graph* g) + : fdef(f), + graph(g), + arg_types(arg_t.begin(), arg_t.end()), + ret_types(ret_t.begin(), ret_t.end()) { + // 1. Find regular Arg/Ret nodes. + this->arg_nodes.resize(arg_types.size()); + this->ret_nodes.resize(ret_types.size()); + for (Node* n : this->graph->op_nodes()) { + gtl::InlinedVector* node_vec; + if (n->type_string() == FunctionLibraryDefinition::kRetOp || + n->type_string() == FunctionLibraryDefinition::kDeviceRetOp) { + node_vec = &this->ret_nodes; + } else if (n->type_string() == FunctionLibraryDefinition::kArgOp || + n->type_string() == FunctionLibraryDefinition::kDeviceArgOp) { + node_vec = &this->arg_nodes; + } else { + continue; + } + int index; + TF_CHECK_OK(GetNodeAttr(n->attrs(), "index", &index)); + CHECK_LE(0, index); + CHECK_LT(index, node_vec->size()); + (*node_vec)[index] = n; + } + // 2. Find ControlRet nodes that must be always executed. + std::unordered_set control_ret_node_names; + for (const auto& control_ret : fdef.control_ret()) { + control_ret_node_names.insert(control_ret.second); + } + this->control_ret_nodes.reserve(control_ret_node_names.size()); + for (Node* n : this->graph->op_nodes()) { + if (control_ret_node_names.count(n->name()) > 0) { + this->control_ret_nodes.push_back(n); + } + } +} + +FunctionBody::~FunctionBody() { delete this->graph; } + +} // end namespace tensorflow diff --git a/tensorflow/core/common_runtime/function_body.h b/tensorflow/core/common_runtime/function_body.h new file mode 100644 index 00000000000..cbd602612a2 --- /dev/null +++ b/tensorflow/core/common_runtime/function_body.h @@ -0,0 +1,52 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_BODY_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_BODY_H_ + +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" + +namespace tensorflow { + +class Graph; +class Node; + +// FunctionLibraryRuntime::GetFunctionBody returns a description of an +// instantiated function that is represented as a Graph with arg/ret +// nodes annotated. +struct FunctionBody { + FunctionDef fdef; + Graph* graph = nullptr; // owned. + DataTypeVector arg_types; + DataTypeVector ret_types; + // arg_nodes[i] contains the i'th function input. In other words, + // GetNodeAttr(arg_nodes[i]->attrs(), "index") == i. + gtl::InlinedVector arg_nodes; + // ret_nodes[i] contains the i'th function output. In other words, + // GetNodeAttr(ret_nodes[i]->attrs(), "index") == i. + gtl::InlinedVector ret_nodes; + gtl::InlinedVector control_ret_nodes; + + FunctionBody() {} + FunctionBody(const FunctionDef& f, DataTypeSlice arg_types, + DataTypeSlice ret_types, Graph* g); + ~FunctionBody(); +}; + +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_BODY_H_ diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc index 9dbb2e77c94..581b7adbef7 100644 --- a/tensorflow/core/common_runtime/function_test.cc +++ b/tensorflow/core/common_runtime/function_test.cc @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/executor.h" #include "tensorflow/core/common_runtime/executor_factory.h" #include "tensorflow/core/common_runtime/function_testlib.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" #include "tensorflow/core/common_runtime/step_stats_collector.h" #include "tensorflow/core/framework/function.h" @@ -40,7 +41,6 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/framework/versions.pb.h" -#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/lib/core/notification.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" diff --git a/tensorflow/core/common_runtime/function_threadpool_test.cc b/tensorflow/core/common_runtime/function_threadpool_test.cc index 3b1f90e7198..0786d9032a8 100644 --- a/tensorflow/core/common_runtime/function_threadpool_test.cc +++ b/tensorflow/core/common_runtime/function_threadpool_test.cc @@ -13,8 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/common_runtime/function.h" - #include #include @@ -25,7 +23,9 @@ limitations under the License. 
#include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/executor.h" +#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/function_testlib.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" #include "tensorflow/core/common_runtime/step_stats_collector.h" #include "tensorflow/core/framework/function.h" @@ -33,7 +33,6 @@ limitations under the License. #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor_testutil.h" -#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/lib/core/notification.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" diff --git a/tensorflow/core/common_runtime/function_utils.cc b/tensorflow/core/common_runtime/function_utils.cc new file mode 100644 index 00000000000..c332927cb95 --- /dev/null +++ b/tensorflow/core/common_runtime/function_utils.cc @@ -0,0 +1,368 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/common_runtime/function_utils.h" + +#include "tensorflow/core/common_runtime/function_body.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/versions.pb.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/control_flow.h" +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { + +static constexpr const char* const kNodeLabel = "Func"; + +// Represents the index-th output of a node. +struct Endpoint { + Node* node; + int index; + + // Returns the string name represents this endpoint. + string name() const { + if (index == 0) { + return node->name(); + } else { + return strings::StrCat(node->name(), ":", index); + } + } + + DataType dtype() const { return node->output_type(index); } +}; + +// The following Add* routines are used to add a few graph nodes while +// functions are transformed. 
+static Node* AddNoOp(StringPiece name, Graph* g) { + NodeDef ndef; + ndef.set_name(g->NewName(absl::StrCat(kNodeLabel, "/", name))); + ndef.set_op("NoOp"); + Status s; + Node* ret = g->AddNode(ndef, &s); + TF_CHECK_OK(s); + return ret; +} + +static Node* AddIdentity(StringPiece name, Graph* g, Endpoint input) { + DCHECK_LT(0, input.dtype()); + NodeDef ndef; + ndef.set_name(g->NewName(absl::StrCat(kNodeLabel, "/", name))); + ndef.set_op("Identity"); + ndef.add_input(input.name()); + AddNodeAttr("T", BaseType(input.dtype()), &ndef); + Status s; + Node* ret = g->AddNode(ndef, &s); + TF_CHECK_OK(s); + g->AddEdge(input.node, input.index, ret, 0); + return ret; +} + +void DumpGraph(StringPiece label, const Graph* g) { + // TODO(zhifengc): Change Graph to record #nodes. + VLOG(2) << "Graph " << label << " #nodes " << g->num_nodes() << " #edges " + << g->num_edges(); + if (VLOG_IS_ON(5)) { + for (const auto& line : str_util::Split(DebugString(g), '\n')) { + VLOG(5) << "|| " << line; + } + } +} + +bool RemoveDeadNodes(Graph* g) { + VLOG(2) << "Removing dead nodes"; + std::unordered_set nodes; + for (auto n : g->nodes()) { + if (n->IsSource() || n->IsSink() || n->IsControlFlow() || + n->op_def().is_stateful()) { + nodes.insert(n); + } + } + return PruneForReverseReachability(g, std::move(nodes)); +} + +namespace { +// If 'edges' contains only 1 non-control edge, returns it. Otherwise, +// returns a nullptr. +const Edge* GetTheOnlyDataEdge(const EdgeSet& edges) { + const Edge* ret = nullptr; + for (const Edge* e : edges) { + if (e->IsControlEdge() || ret) { + // Don't touch it if there is a control edge. + return nullptr; + } + if (IsRefType(e->src()->output_type(e->src_output()))) { + // Don't touch it if the identity node is effectively de-reffing + // a ref. + return nullptr; + } + if (IsRecv(e->src()) || IsSwitch(e->src())) { + // Don't touch it if the identity is introduced for control flow. + // Recv disables all its successors if it receives a dead signal. + // When Recv has an outgoing control edge, the current executor + // would not disable the destination. The current solution (see + // graph_partition.cc) is to add an identity after Recv and change + // the control edge to be from this identity node. So the identity + // can't be removed. + return nullptr; + } + ret = e; + } + return ret; +} +} // end namespace + +bool RemoveIdentityNodes(Graph* g) { + VLOG(2) << "Removing identity nodes"; + bool removed_any = false; + gtl::InlinedVector matches; + for (Node* n : g->nodes()) { + if (!n->IsIdentity()) continue; + if (!GetTheOnlyDataEdge(n->in_edges())) continue; + + // Some identity nodes are used as sink nodes to give names to output + // tensors. These nodes are not going to be executed unless they are in the + // fetch set. But if they are in the fetch set we don't want to remove them. 
+ if (n->out_edges().empty()) continue; + + matches.push_back(n); + } + if (!matches.empty()) { + for (Node* n : matches) { + const Edge* in = GetTheOnlyDataEdge(n->in_edges()); + for (const Edge* out : n->out_edges()) { + if (out->IsControlEdge()) { + g->AddControlEdge(in->src(), out->dst()); + } else { + g->AddEdge(in->src(), in->src_output(), out->dst(), out->dst_input()); + } + } + VLOG(2) << "Remove Identity: " << n->DebugString(); + g->RemoveNode(n); + removed_any = true; + } + } + return removed_any; +} + +bool RemoveListArrayConverter(Graph* g) { + VLOG(2) << "Removing list array converter"; + gtl::InlinedVector matches; + for (Node* n : g->nodes()) { + if ((n->type_string() == "_ListToArray") || + (n->type_string() == "_ArrayToList")) { + matches.push_back(n); + } + } + bool removed_any = false; + if (!matches.empty()) { + for (Node* n : matches) { + if (n->num_inputs() != n->num_outputs()) { + continue; // Not expected. Skip. + } + gtl::InlinedVector identity_nodes(n->num_inputs(), nullptr); + + const auto no_op = [&](StringPiece name) -> Node* { + return AddNoOp(absl::StrCat(n->name(), "/", name), g); + }; + + const auto identity = [&](StringPiece name, Endpoint input) -> Node* { + Node* node = AddIdentity(absl::StrCat(n->name(), "/", name), g, input); + node->set_requested_device(input.node->def().device()); + return node; + }; + + // Process input edges first. + Node* input_control_node = nullptr; + for (const Edge* e : n->in_edges()) { + if (e->IsControlEdge()) { + if (input_control_node == nullptr) { + // If node "n" has any control dependencies, adds a no-op + // node (input_control_node) which the additional Identity + // nodes depends on and the input_control_node depends on + // the node "n"s control dependencies. + input_control_node = no_op("input_control_node"); + } + g->AddControlEdge(e->src(), input_control_node); + } else { + const int index = e->dst_input(); + Node** id_node = &identity_nodes[index]; + if (*id_node != nullptr) { + LOG(ERROR) + << "RemoveListArrayConverter unexpected duplicated input: " + << e->dst_input(); + return removed_any; + } + *id_node = identity("input", {e->src(), e->src_output()}); + } + } + + // If node "n" has any control dependencies, the added identity + // nodes should have control dependencies on input_control_node. + if (input_control_node != nullptr) { + for (Node* id : identity_nodes) { + g->AddControlEdge(input_control_node, id); + } + } + + Node* output_control_node = nullptr; + for (const Edge* e : n->out_edges()) { + if (e->IsControlEdge()) { + if (output_control_node == nullptr) { + // If node "n" is control-depended upon by other nodes, + // adds a no-op node (output_control_node) which those + // nodes will depend on and output_control_node depends on + // all Identity nodes. + output_control_node = no_op("output_control_node"); + } + g->AddControlEdge(output_control_node, e->dst()); + } else { + Node* id_node = identity_nodes[e->src_output()]; + if (id_node == nullptr) { + LOG(ERROR) << "RemoveListArrayConverter unexpected missing input: " + << e->src_output(); + return removed_any; + } + CHECK(id_node); + g->AddEdge(id_node, 0, e->dst(), e->dst_input()); + } + } + + // If any nodes have control dependencies on node "n", those + // nodes should have control dependencies on + // output_control_node. 
+ if (output_control_node != nullptr) { + for (Node* id : identity_nodes) { + g->AddControlEdge(id, output_control_node); + } + } + + g->RemoveNode(n); + removed_any = true; + } + } + return removed_any; +} + +Status NameAndAttrsFromFunctionCall(const NodeDef& call_def, + NameAttrList* function) { + if (call_def.op() == "PartitionedCall" || + call_def.op() == "StatefulPartitionedCall") { + TF_RETURN_IF_ERROR(GetNodeAttr(call_def, "f", function)); + } else { + function->set_name(call_def.op()); + *function->mutable_attr() = call_def.attr(); + } + return Status::OK(); +} + +Status InstantiateFunctionCall(const NodeDef& call_def, + FunctionLibraryRuntime* flr, + FunctionLibraryRuntime::Handle* handle) { + NameAttrList function; + TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(call_def, &function)); + return flr->Instantiate(function.name(), AttrSlice(&function.attr()), handle); +} + +bool IsFunctionCall(const FunctionLibraryDefinition& lib_def, + const Node& node) { + return node.IsFunctionCall(); +} + +string NewName(const Node* n, bool pretty) { + if (pretty) { + return strings::StrCat(n->type_string(), n->id()); + } else { + return strings::StrCat("n", n->id()); + } +} + +// TODO(zhifengc): Maybe this should be the default Graph::AsGraphDef. +// and stash the original NodeDef name as an attr for documentation +// purpose. +void ToGraphDef(const Graph* g, GraphDef* gdef, bool pretty) { + // We visit nodes in forward topological sort order, which is a + // possible execution order of the graph. + gtl::InlinedVector inputs; + gdef->Clear(); + *gdef->mutable_versions() = g->versions(); + + std::vector start_nodes; + for (Node* n : g->nodes()) { + if (n->out_edges().empty()) { + start_nodes.push_back(n); + } + } + + ReverseDFSFrom(*g, start_nodes, nullptr, [gdef, pretty, &inputs](Node* n) { + if (!n->IsOp()) return; + NodeDef* ndef = gdef->add_node(); + ndef->set_name(NewName(n, pretty)); + ndef->set_op(n->type_string()); + for (const auto& attr : n->attrs()) { + (*ndef->mutable_attr())[attr.first] = attr.second; + } + + if (!n->assigned_device_name().empty()) { + ndef->set_device(n->assigned_device_name()); + } else { + ndef->set_device(n->requested_device()); + } + + inputs.clear(); + inputs.resize(n->num_inputs()); + for (const Edge* e : n->in_edges()) { + if (e->IsControlEdge()) { + inputs.push_back(e); + } else { + if (inputs[e->dst_input()] == nullptr) { + inputs[e->dst_input()] = e; + } else { + LOG(WARNING) << "Malformed graph node. multiple input edges: " + << n->DebugString(); + } + } + } + // node->name() is merely NodeDef::name, which are not guaranteed + // to be unique and stable after optimization rewrites. Therefore, + // we use "n" instead. 
+ for (const Edge* e : inputs) { + if (e == nullptr) { + ndef->add_input("unknown"); + continue; + } + const string srcname = NewName(e->src(), pretty); + if (!e->src()->IsOp()) { + } else if (e->IsControlEdge()) { + ndef->add_input(strings::StrCat("^", srcname)); + } else if (e->src_output() == 0) { + ndef->add_input(srcname); + } else { + ndef->add_input(strings::StrCat(srcname, ":", e->src_output())); + } + } + }); +} + +string DebugString(const Graph* g) { + GraphDef gdef; + ToGraphDef(g, &gdef); + return DebugString(gdef); +} + +} // end namespace tensorflow diff --git a/tensorflow/core/common_runtime/function_utils.h b/tensorflow/core/common_runtime/function_utils.h new file mode 100644 index 00000000000..8a3de2a8402 --- /dev/null +++ b/tensorflow/core/common_runtime/function_utils.h @@ -0,0 +1,105 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_UTILS_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_UTILS_H_ + +#include +#include + +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +class AttrSlice; +class Graph; +class GraphDef; +class NameAttrList; +class Node; +class NodeDef; +class OpDef; + +// Debugging facility. Returns a debug string for a graph +// representing an instantiated function. +string DebugString(const Graph* g); + +// Dump the contents of the "graph" to log files if the logging level is +// sufficiently high. +void DumpGraph(StringPiece label, const Graph* g); + +// Convert the Graph of a function to a GraphDef. +// +// Handles renaming of nodes to avoid duplicate names which may +// be present after various rewriting operations. +void ToGraphDef(const Graph* g, GraphDef* gdef, bool pretty = false); + +// Extracts function name and attributes from `call_def` +// `call_def` can be a native function call (where the op type is the function +// name) or a call through PartitionedCall/StatefulPartitionedCall. +Status NameAndAttrsFromFunctionCall(const NodeDef& call_def, + NameAttrList* function); + +// A few hand-crafted optimization on the instantiated function body +// (a Graph*). + +// Removes nodes that are +// 1. not stateful; and +// 2. not _Arg; and +// 3. not reachable from _Retval. +// +// This function is triggered by function inlining, unlike 'PruneFunctionBody' +// it doesn't preserve nodes that are reachable from control returns. Function +// inlining is responsible for connecting control return nodes with the nodes +// that have input control edges from the inlined function call node. +// +// Assuming that automatic control dependency tracking is correct, absence of +// outgoing control edge from the function call node means that no one needs to +// observe side-effect that might have been generated by the function (see +// documentation in common_runtime/function.cc for details). 
+// +// Returns true iff any node is removed from "g". +bool RemoveDeadNodes(Graph* g); + +// Find a pattern: +// src -(in)-> node -(out)-> dst, where +// 1) node is an identity node; +// 2) in is the only incoming data edge; +// 3) out is the only outgoing data edge; +// +// Rewrites the above pattern with src->dst and relevant data +// dependencies updated. Repeat the process until no such pattern +// left. +bool RemoveIdentityNodes(Graph* g); + +// Rewrites _ListToArray and _ArrayToList to a set of Identity nodes. +bool RemoveListArrayConverter(Graph* g); + +// Extracts function name and attributes from `call_def` and invokes +// flr->Instantiate(name, attrs, handle). +// `call_def` can be a native function call (where the op type is the function +// name) or a call through PartitionedCall/StatefulPartitionedCall. +Status InstantiateFunctionCall(const NodeDef& call_def, + FunctionLibraryRuntime* flr, + FunctionLibraryRuntime::Handle* handle); + +// Returns true iff `n` represents a function call. `n` can be a native +// function call (n.type_string() is the function name), +// a PartitionedCall/StatefulPartitionedCall, or a SymbolicGradient (which +// has been deprecated for a while). +bool IsFunctionCall(const FunctionLibraryDefinition& lib_def, const Node& n); +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_UTILS_H_ diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/common_runtime/graph_constructor.cc similarity index 99% rename from tensorflow/core/graph/graph_constructor.cc rename to tensorflow/core/common_runtime/graph_constructor.cc index feaf9f6e70e..ab5b086b25c 100644 --- a/tensorflow/core/graph/graph_constructor.cc +++ b/tensorflow/core/common_runtime/graph_constructor.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" #include #include diff --git a/tensorflow/core/common_runtime/graph_constructor.h b/tensorflow/core/common_runtime/graph_constructor.h new file mode 100644 index 00000000000..c58a4aafd40 --- /dev/null +++ b/tensorflow/core/common_runtime/graph_constructor.h @@ -0,0 +1,204 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_CONSTRUCTOR_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_CONSTRUCTOR_H_ + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/tensor_id.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +class ShapeRefiner; + +// Construct a Graph *g out of a GraphDef gdef. Returns non-OK on +// error, in which case *g is left in an incomplete state. 
+// +// *g is expected to be an empty graph (with no more than a source and sink +// nodes) when provided to ConvertGraphDefToGraph. To enhance an existing Graph, +// see ImportGraphDef. +struct GraphConstructorOptions { + GraphConstructorOptions() {} + + // If true, allows internal ops in the GraphDef. + bool allow_internal_ops = false; + + // If true, the graph def is expected to have fully specified + // devices for all nodes. A node in the resulting graph "g" has the + // device name set accordingly. + // + // TODO(zhifengc): if possible, consider removing this option. + bool expect_device_spec = false; + + // If true, validates that nodes being converted have all expected attrs + // set and no unknown attrs set by calling ValidateNodeDef(). + // Setting validate_nodes without add_default_attributes, will fail if + // the GraphDef does not have all required attributes set. + bool validate_nodes = false; + + // If true, GraphConstructor will add attributes with their default + // value to the Node when they are missing from the NodeDef. + bool add_default_attributes = true; +}; +extern Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts, + const GraphDef& gdef, Graph* g); +extern Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts, + GraphDef&& gdef, Graph* g); + +// Same as ConvertGraphDefToGraph, but takes just nodes. Used by function +// instantiation. +// TODO(irving): This will turn into std::vector soon. +extern Status ConvertNodeDefsToGraph(const GraphConstructorOptions& opts, + gtl::ArraySlice nodes, Graph* g); + +// Options for calling ImportGraphDef(). +struct ImportGraphDefOptions { + ImportGraphDefOptions() + : uniquify_names(false), + uniquify_prefix(false), + skip_mapped_nodes(false), + validate_shape(true) {} + + // Name prefix to use for nodes imported from the GraphDef. For example, if + // prefix="animals" and GraphDef contains a node "bunny" then the node will be + // named "animals/bunny" in *g. Must not be already used as a node name or + // prefix in the graph. + string prefix; + + // If true, imported node names will be modified if their name already exists + // in the graph. If false, conflicting names will be treated as an error. Note + // that this option has no effect if `prefix` is specified, since `prefix` + // will guarantee all node names are unique. + bool uniquify_names; + + // If true, `prefix` will be modified if it already exists as a node name or + // prefix in the graph. If false, a conflicting prefix will be treated as an + // error. This option has no effect if `prefix` isn't specified. + bool uniquify_prefix; + + // Maps tensors in `gdef` to existing tensors in `g`. Inputs in `gdef` + // corresponding to `input_map` keys will be remapped to the nodes in `g` + // corresponding to the values. + // + // Keys should not include `prefix`, i.e., a key ID's name should be the name + // as it originally appears in `gdef`. + // + // If this is non-empty, ImportGraphDef must be called with the shape refiner + // used to create the existing nodes referenced in `input_map`. + // TODO(skyewm): can we remove this requirement? How do we access the original + // shape refiner? + std::map input_map; + + // If true, nodes that will have all output edges removed because of + // overrides in `input_map` will not be imported. + bool skip_mapped_nodes; + + // The names of existing nodes in `g` that the imported graph should have + // control dependencies on. 
+ // + // Note that to avoid creating many redundant control edges, ImportGraphDef() + // won't add control edges to nodes that will inherit the dependencies from + // other nodes in `gdef`. + std::vector control_dependencies; + + // Tensors in `gdef` that will be returned via the ImportGraphDefResults + // output parameter of `ImportGraphDef()`. If this list is non-empty, the + // caller must pass a results object to `ImportGraphDef()`. The + // `return_tensors` field will be populated with the imported nodes in `g`. + // + // Entries should not include `prefix`, i.e., each ID's name should be the + // name as it originally appears in `gdef`. + // + // If this contains a tensor that's also being remapped via `input_map`, the + // corresponding existing tensor in `g` will be returned. + std::vector return_tensors; + + // The names of nodes in `gdef` that will be returned via the + // ImportGraphDefResults output parameter of `ImportGraphDef()`. If this list + // is non-empty, the caller must pass a results object to + // `ImportGraphDef()`. The `return_nodes` field will be populated with the + // imported nodes in `g`. + // + // Entries should not include `prefix`, i.e., each node's name should be the + // name as it originally appears in `gdef`. + // + // Unlike `return_tensors`, `input_map` has no effect on the nodes + // returned. `return_nodes` must be empty if `skip_mapped_nodes` is true. + // TODO(skyewm): make this work with `skip_mapped_nodes` if there's a need. + std::vector return_nodes; + + // If true, checks that all colocation constraints are nodes in the GraphDef. + bool validate_colocation_constraints = true; + + // If false skips shape validation. + bool validate_shape; + + // TODO(ashankar): Enable handling of GraphDefs produced by newer binaries + // with ops that are not defined in the binary calling ImportGraphDef. + // Similar to the producer_op_list argument to import_graph_def in the + // python API. + + // Try to set default execution device for this grapth. + string default_device; +}; + +// Optional results that may be returned by ImportGraphDef. +struct ImportGraphDefResults { + // The requested tensors associated with + // ImportGraphDefOptions::return_tensors. Note that the index may be different + // than the requested index if the returned tensor has been remapped according + // to `input_map`. + typedef int Index; + std::vector> return_tensors; + + // The requested nodes associated with ImportGraphDefOptions::return_nodes. + std::vector return_nodes; + + // Keys in ImportGraphDefOptions::input_map that don't appear in `gdef` and + // weren't used as an input to any node in `gdef`. These keys are likely due + // to typos, and callers may wish to treat their existence as an error. + std::vector missing_unused_input_map_keys; +}; + +// Adds the graph in GraphDef `gdef` into an existing Graph `*g`. +// +// On error, returns non-OK and leaves `*g` unmodified. +// +// `refiner` can be null. It should be non-null if the caller +// intends to add additional nodes to the graph after the import. This +// allows the caller to validate shapes of those nodes (since +// ShapeRefiner::AddNode must be called in topological order). +// +// `results` must be non-null if `opts.return_tensors` or `opts.result_nodes` is +// non-empty. It can also be set to fetch the unused input map keys. If it's +// non-null, all the vector fields must be empty. +// +// TODO(ashankar): Push this mechanism and get rid of Session::Extend() +// as a means of enhancing an existing Graph. 
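A short usage sketch for the ImportGraphDef declaration that follows (illustrative only, not part of this patch; the GraphDef and Graph are assumed to exist, and "init_op" is a hypothetical node name):

// Sketch only: import a GraphDef under a prefix and look up one node.
Status ImportWithPrefix(const GraphDef& gdef, Graph* graph) {
  ImportGraphDefOptions opts;
  opts.prefix = "imported";
  opts.uniquify_prefix = true;
  opts.return_nodes.push_back("init_op");  // hypothetical node in `gdef`
  ImportGraphDefResults results;
  TF_RETURN_IF_ERROR(
      ImportGraphDef(opts, gdef, graph, /*refiner=*/nullptr, &results));
  // results.return_nodes[0] is the imported "init_op" node (possibly prefixed).
  return Status::OK();
}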
+extern Status ImportGraphDef(const ImportGraphDefOptions& opts, + const GraphDef& gdef, Graph* g, + ShapeRefiner* refiner, + ImportGraphDefResults* results = nullptr); + +// Make a copy of "src" into "*dest". +// +// REQUIRES: "*dest" is a freshly allocated graph without any nodes or edges +// other than the implicit Source/Sink nodes. +extern void CopyGraph(const Graph& src, Graph* dest); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_CONSTRUCTOR_H_ diff --git a/tensorflow/core/common_runtime/graph_execution_state.cc b/tensorflow/core/common_runtime/graph_execution_state.cc index 42247c664ec..944c676d0a9 100644 --- a/tensorflow/core/common_runtime/graph_execution_state.cc +++ b/tensorflow/core/common_runtime/graph_execution_state.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/strings/str_join.h" #include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/common_runtime/metrics.h" #include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/common_runtime/placer.h" @@ -40,7 +41,6 @@ limitations under the License. #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/collective_order.h" #include "tensorflow/core/graph/graph.h" -#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/subgraph.h" #include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/graph/validate.h" diff --git a/tensorflow/core/common_runtime/graph_optimizer.cc b/tensorflow/core/common_runtime/graph_optimizer.cc index 410c4c2ad09..746930750ad 100644 --- a/tensorflow/core/common_runtime/graph_optimizer.cc +++ b/tensorflow/core/common_runtime/graph_optimizer.cc @@ -16,9 +16,10 @@ limitations under the License. 
#include "tensorflow/core/common_runtime/graph_optimizer.h" #include "tensorflow/core/common_runtime/constant_folding.h" -#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/function_utils.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" +#include "tensorflow/core/common_runtime/inline_function_utils.h" #include "tensorflow/core/graph/algorithm.h" -#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/graph/optimizer_cse.h" @@ -144,4 +145,19 @@ void GraphOptimizer::Optimize(FunctionLibraryRuntime* runtime, Env* env, options.inline_with_single_device_body_placer); } +void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr* g, + const GraphOptimizer::Options& graph_optimizer_options) { + OptimizerOptions opts; + opts.set_do_common_subexpression_elimination(true); + opts.set_do_function_inlining(true); + opts.set_do_constant_folding(true); + GraphOptimizer optimizer(opts); + optimizer.Optimize(lib, lib->env(), lib->device(), g, + graph_optimizer_options); +} + +void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr* g) { + OptimizeGraph(lib, g, GraphOptimizer::Options()); +} + } // end namespace tensorflow diff --git a/tensorflow/core/common_runtime/graph_optimizer.h b/tensorflow/core/common_runtime/graph_optimizer.h index 77c9d62c27f..099ea8efa12 100644 --- a/tensorflow/core/common_runtime/graph_optimizer.h +++ b/tensorflow/core/common_runtime/graph_optimizer.h @@ -91,6 +91,17 @@ class GraphOptimizer { TF_DISALLOW_COPY_AND_ASSIGN(GraphOptimizer); }; +// Applies graph rewrite optimization such as inlining, dead code +// removal, etc. +// +// **g is a graph constructed based on the runtime library 'lib'. +// OptimizeGraph mutates **g extensively and replaces '*g' with a +// complete copy. Therefore, the caller should not keep any references +// to nodes *g. +void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr* g, + const GraphOptimizer::Options& graph_optimizer_options); +void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr* g); + } // end namespace tensorflow #endif // TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_OPTIMIZER_H_ diff --git a/tensorflow/core/common_runtime/graph_runner.cc b/tensorflow/core/common_runtime/graph_runner.cc index 0b3970a468f..2c17bf54a17 100644 --- a/tensorflow/core/common_runtime/graph_runner.cc +++ b/tensorflow/core/common_runtime/graph_runner.cc @@ -20,9 +20,10 @@ limitations under the License. #include "tensorflow/core/common_runtime/graph_runner.h" +#include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/executor.h" -#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/common_runtime/memory_types.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" #include "tensorflow/core/common_runtime/single_threaded_cpu_device.h" @@ -31,11 +32,12 @@ limitations under the License. 
#include "tensorflow/core/framework/tensor_util.h" #include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/graph/algorithm.h" -#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/graph/subgraph.h" #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/env.h" #include "tensorflow/core/public/session_options.h" namespace tensorflow { diff --git a/tensorflow/core/common_runtime/graph_runner.h b/tensorflow/core/common_runtime/graph_runner.h index 1c4b2b719cd..563586e0534 100644 --- a/tensorflow/core/common_runtime/graph_runner.h +++ b/tensorflow/core/common_runtime/graph_runner.h @@ -20,15 +20,16 @@ limitations under the License. #include #include -#include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/platform/env.h" namespace tensorflow { +class Device; +class Env; +class Graph; + // GraphRunner takes a Graph, some inputs to feed, and some outputs // to fetch and executes the graph required to feed and fetch the // inputs and outputs. diff --git a/tensorflow/core/common_runtime/inline_function_utils.cc b/tensorflow/core/common_runtime/inline_function_utils.cc new file mode 100644 index 00000000000..0dba49a0510 --- /dev/null +++ b/tensorflow/core/common_runtime/inline_function_utils.cc @@ -0,0 +1,865 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/common_runtime/inline_function_utils.h" + +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/function_utils.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" +#include "tensorflow/core/framework/collective.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/versions.pb.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/control_flow.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/graph/optimizer_cse.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/profiler/lib/traceme.h" +#include "tensorflow/core/protobuf/config.pb.h" + +namespace tensorflow { + +namespace { +// A few string constant used throughout this module. +static constexpr const char* const kArgOp = FunctionLibraryDefinition::kArgOp; +static constexpr const char* const kDeviceArgOp = + FunctionLibraryDefinition::kDeviceArgOp; +static constexpr const char* const kRetOp = FunctionLibraryDefinition::kRetOp; +static constexpr const char* const kDeviceRetOp = + FunctionLibraryDefinition::kDeviceRetOp; +static constexpr const char* const kGradientOp = + FunctionLibraryDefinition::kGradientOp; +static constexpr const char* const kNodeLabel = "Func"; +static constexpr const char* const kFuncAttr = + FunctionLibraryDefinition::kFuncAttr; + +// Represents the index-th output of a node. +struct Endpoint { + Node* node; + int index; + + // Returns the string name represents this endpoint. + string name() const { + if (index == 0) { + return node->name(); + } else { + return strings::StrCat(node->name(), ":", index); + } + } + + DataType dtype() const { return node->output_type(index); } +}; + +struct EndpointHash { + uint64 operator()(const Endpoint& x) const { + return Hash64(reinterpret_cast(&x.node), sizeof(Node*), + x.index); + } +}; + +struct EndpointEq { + bool operator()(const Endpoint& x, const Endpoint& y) const { + return (x.node == y.node) && (x.index == y.index); + } +}; + +// The following Add* routines are used to add a few graph nodes while +// functions are transformed. 
+static Node* AddNoOp(StringPiece name, Graph* g) { + NodeDef ndef; + ndef.set_name(g->NewName(absl::StrCat(kNodeLabel, "/", name))); + ndef.set_op("NoOp"); + Status s; + Node* ret = g->AddNode(ndef, &s); + TF_CHECK_OK(s); + return ret; +} + +static Node* AddIdentity(StringPiece name, Graph* g, Endpoint input) { + DCHECK_LT(0, input.dtype()); + NodeDef ndef; + ndef.set_name(g->NewName(absl::StrCat(kNodeLabel, "/", name))); + ndef.set_op("Identity"); + ndef.add_input(input.name()); + AddNodeAttr("T", BaseType(input.dtype()), &ndef); + Status s; + Node* ret = g->AddNode(ndef, &s); + TF_CHECK_OK(s); + g->AddEdge(input.node, input.index, ret, 0); + return ret; +} + +std::vector InputDevices(const Node& caller) { + std::vector input_devices(caller.in_edges().size()); + std::vector input_tensors(caller.in_edges().size()); + + for (const Edge* edge : caller.in_edges()) { + if (edge->IsControlEdge()) continue; + const string& input_device = edge->src()->has_assigned_device_name() + ? edge->src()->assigned_device_name() + : edge->src()->requested_device(); + input_devices[edge->dst_input()] = input_device; + input_tensors[edge->dst_input()] = + absl::StrCat(edge->src()->name(), ":", edge->src_output()); + } + + if (VLOG_IS_ON(4)) { + VLOG(4) << "Function instantiation input devices:"; + for (int i = 0; i < input_devices.size(); ++i) { + if (input_tensors[i].empty()) continue; // skip control edges + VLOG(4) << " [index " << i << "]" + << " device: " << input_devices[i] + << " (input: " << input_tensors[i] << ")"; + } + } + + return input_devices; +} + +// Place input nodes on the same device as the corresponding caller input +// node. Do not specify any placement for all other nodes. +class DefaultFunctionBodyPlacer : public InlinedFunctionBodyPlacer { + public: + explicit DefaultFunctionBodyPlacer(const Node& caller) + : input_devices_(InputDevices(caller)) {} + + absl::optional InputNodeDevice(int input_index) const override { + return input_devices_[input_index]; + } + absl::optional OutputNodeDevice(int output_index) const override { + return absl::nullopt; + } + bool ColocateInputOutputIdentities() const override { return false; } + absl::optional ControlNodeDevice() const override { + return absl::nullopt; + } + absl::optional BodyNodeDevice(const NodeDef& ndef) const override { + return absl::nullopt; + } + + private: + const std::vector input_devices_; +}; + +// Place all nodes on the same device as caller node. +class SingleDeviceFunctionBodyPlacer : public InlinedFunctionBodyPlacer { + public: + explicit SingleDeviceFunctionBodyPlacer(const Node& caller) + : caller_device_(caller.def().device()) {} + + absl::optional InputNodeDevice(int input_index) const override { + return caller_device_; + } + absl::optional OutputNodeDevice(int output_index) const override { + return caller_device_; + } + bool ColocateInputOutputIdentities() const override { return false; } + absl::optional ControlNodeDevice() const override { + return caller_device_; + } + absl::optional BodyNodeDevice(const NodeDef& ndef) const override { + return caller_device_; + } + + private: + const string caller_device_; +}; + +// Place input nodes on the same device as the corresponding caller input +// node. Do not place output node. Place control nodes on the same device as +// caller node. For all function body nodes overrides job, replica and task +// parts of the device assignment to match function caller node. 
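To connect the three placement strategies described above with the inlining options, a brief selection sketch (illustrative only, not part of this patch; using whether the caller is a PartitionedCall as the discriminator is an assumption here):

// Sketch only: choose a placer config based on the kind of function call.
InlinedFunctionBodyPlacer::Config ChoosePlacer(const Node& caller) {
  return caller.IsPartitionedCall() ? InlinedFunctionBodyPlacer::MultiDevice()
                                    : InlinedFunctionBodyPlacer::Default();
}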
+class MultiDeviceFunctionBodyPlacer : public InlinedFunctionBodyPlacer { + public: + explicit MultiDeviceFunctionBodyPlacer(const Node& caller) + : caller_device_(caller.def().device()), + input_devices_(InputDevices(caller)) { + has_parsed_caller_device_ = + DeviceNameUtils::ParseFullName(caller_device_, &caller_parsed_device_); + } + + absl::optional InputNodeDevice(int input_index) const override { + return input_devices_[input_index]; + } + absl::optional OutputNodeDevice(int output_index) const override { + return absl::nullopt; + } + bool ColocateInputOutputIdentities() const override { return true; } + absl::optional ControlNodeDevice() const override { + return caller_device_; + } + absl::optional BodyNodeDevice(const NodeDef& ndef) const override { + // TODO(ezhulenev): If function would have been instantiated as a + // multi-device function and executed via FunctionLibraryRuntime, it could + // be potentially placed on any available device. However there are multiple + // tests relying on this assumption. Fix them, and remove this line. + if (ndef.device().empty()) return caller_device_; + + if (!has_parsed_caller_device_) return ndef.device(); + + DeviceNameUtils::ParsedName ndef_parsed_device; + if (!DeviceNameUtils::ParseFullName(ndef.device(), &ndef_parsed_device)) + return ndef.device(); + + if (caller_parsed_device_.has_job) { + ndef_parsed_device.has_job = caller_parsed_device_.has_job; + ndef_parsed_device.job = caller_parsed_device_.job; + } + + if (caller_parsed_device_.has_replica) { + ndef_parsed_device.has_replica = caller_parsed_device_.has_replica; + ndef_parsed_device.replica = caller_parsed_device_.replica; + } + + if (caller_parsed_device_.has_task) { + ndef_parsed_device.has_task = caller_parsed_device_.has_task; + ndef_parsed_device.task = caller_parsed_device_.task; + } + return DeviceNameUtils::ParsedNameToString(ndef_parsed_device); + } + + private: + string caller_device_; + bool has_parsed_caller_device_; + DeviceNameUtils::ParsedName caller_parsed_device_; + std::vector input_devices_; +}; + +} // namespace + +std::unique_ptr +InlinedFunctionBodyPlacer::DefaultPlacer(const Graph& graph, + const Node& caller) { + VLOG(3) << "Create default placer for inlined function body."; + return absl::make_unique(caller); +} + +std::unique_ptr +InlinedFunctionBodyPlacer::SingleDevicePlacer(const Graph& graph, + const Node& caller) { + VLOG(3) << "Create single device placer for inlined function body."; + return absl::make_unique(caller); +} + +std::unique_ptr +InlinedFunctionBodyPlacer::MultiDevicePlacer(const Graph& graph, + const Node& caller) { + VLOG(3) << "Create multi device placer for inlined function body."; + return absl::make_unique(caller); +} + +namespace { + +Status ValidateNoInline(const FunctionBody* fbody) { + const auto attr = AttrSlice(&fbody->fdef.attr()); + bool noinline = false; + if (TryGetNodeAttr(attr, kNoInlineAttr, &noinline) && noinline) { + return errors::InvalidArgument( + "Can't inline function marked with '_noinline'"); + } + return Status::OK(); +} + +using OutputControlSrc = InlineFunctionBodyOptions::OutputControlSource; + +// Propagate the debug info of `nodes` in function `func` to the `target` node. +// If the debug info of any node is missing, its node name and function name +// is used. 
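+//
+// For example, propagating from a node "x" of function "my_func", where "x"
+// carries no debug info of its own, records "x" in `target`'s
+// original_node_names and "my_func" in its original_func_names.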
+void PropagateDebugInfoToNode(const string& func,
+                              const std::vector<const Node*>& nodes,
+                              NodeDef* target) {
+  if (nodes.empty() || target->has_experimental_debug_info()) {
+    return;
+  }
+  for (const Node* node : nodes) {
+    const auto& node_def = node->def();
+    if (node_def.has_experimental_debug_info()) {
+      target->mutable_experimental_debug_info()->MergeFrom(
+          node_def.experimental_debug_info());
+    } else {
+      target->mutable_experimental_debug_info()->add_original_node_names(
+          node_def.name());
+      target->mutable_experimental_debug_info()->add_original_func_names(func);
+    }
+  }
+}
+}  // namespace
+
+string InlineFunctionBodyOptions::DebugString() const {
+  const auto true_false = [](bool b) { return b ? "true" : "false"; };
+
+  const auto keep_caller_node_str = [this]() -> string {
+    switch (keep_caller_node) {
+      case KeepCallerNode::kDoNotKeep:
+        return "DoNotKeep";
+      case KeepCallerNode::kFetchable:
+        return "Fetchable";
+      case KeepCallerNode::kTargetable:
+        return "Targetable";
+    }
+  };
+
+  return absl::StrCat(
+      "disable_inlining=", true_false(disable_inlining),
+      ", ignore_noinline=", true_false(ignore_noinline),
+      ", inline_impl_selection_group_functions=",
+      true_false(inline_impl_selection_group_functions),
+      ", keep_caller_node=", keep_caller_node_str(), ", output_control_src=",
+      output_control_src == OutputControlSrc::kDataOutputs ? "DataOutputs"
+                                                           : "ControlOutputs",
+      ", inlined_function_body_placer=", inlined_function_body_placer.name,
+      ", uniquify_frame_names=", true_false(uniquify_frame_names));
+}
+
+Status ValidateInlining(const Node* node, const FunctionBody* fbody,
+                        const InlineFunctionBodyOptions& options) {
+  // TODO(ezhulenev): Currently common_runtime function inlining can't guarantee
+  // that all side-effectful ops will be executed after inlining. See Grappler
+  // function_optimizer for details. Unify all function inlining mechanisms.
+  // Do not inline if `!fbody->control_ret_nodes.empty()`.
+ + const auto num_node_inputs = static_cast(node->num_inputs()); + const auto num_node_outputs = static_cast(node->num_outputs()); + + if (num_node_inputs != fbody->arg_types.size() || + num_node_inputs != fbody->arg_nodes.size()) { + return errors::InvalidArgument( + "Node inputs do not match function arguments: inputs=", num_node_inputs, + " arg_types=", fbody->arg_types.size(), + " arg_nodes=", fbody->arg_nodes.size()); + } + + if (num_node_outputs != fbody->ret_types.size() || + num_node_outputs != fbody->ret_nodes.size()) { + return errors::InvalidArgument( + "Node outputs do not match function returns: outputs=", + num_node_outputs, " ret_types=", fbody->ret_types.size(), + " ret_nodes=", fbody->ret_nodes.size()); + } + + for (int i = 0; i < node->num_inputs(); ++i) { + if (node->input_type(i) != fbody->arg_types[i]) { + return errors::InvalidArgument( + "Node input type doesn't match function argument type: ", + node->input_type(i), " != ", fbody->arg_types[i], " @ index=", i); + } + } + for (int i = 0; i < node->num_outputs(); ++i) { + if (node->output_type(i) != fbody->ret_types[i]) { + return errors::InvalidArgument( + "Node output type doesn't match function return type: ", + node->output_type(i), " != ", fbody->ret_types[i], " @ index=", i); + } + } + + if (options.disable_inlining) { + return errors::InvalidArgument( + "Function inlining explicitly disabled by 'options.disable_inlining'"); + } + + if (!options.inline_impl_selection_group_functions) { + bool is_impl_selection_group_function = + fbody->fdef.attr().find("api_implements") != fbody->fdef.attr().end(); + if (is_impl_selection_group_function) { + return errors::InvalidArgument( + "Inlining of implementation selection group function ", + fbody->fdef.signature().name(), + " is disabled by options.inline_impl_selection_group_functions"); + } + } + + if (!options.ignore_noinline) { + TF_RETURN_IF_ERROR(ValidateNoInline(fbody)); + } + + return Status::OK(); +} + +// Function inlining must preserve function execution semantics with regards to +// side-effects visibility. Tensorflow in Eager mode has an automatic control +// dependencies tracking mechanism, which enforces well-defined execution order +// of all side-effects. Any other frontend (e.g. Swift) must produce graphs +// following the same rules, to ensure that function inlining works correctly. +// +// IMPORTANT: Currently we do not have a true notion of "side-effectful" node, +// we assume that all stateful nodes might have side-effects, though it's not +// true in practice, e.g. `ReadVariableOp` doesn't have an observable +// side-effect. +// +// Automatic control dependency rules in Tensorflow 2.0 (python in eager mode): +// +// 1) When a function has a resource (DT_RESOURCE data type) input argument it +// "captures" the mutable resource. This is implemented by automatically +// adding a incoming control edge from the previous side-effectful op +// touching that resource, and an outgoing control edge to the next +// side-effectful op using the same resource. This serializes the mutations +// of the resource to make graph execution deterministic. +// +// 2) All stateful ops inside a function body are guaranteed to execute in +// program order, this is achieved by adding control edges between stateful +// ops at graph construction time. Stateful ops (or ops that must execute) +// should be in the function control return set. 
Having a data edge to the +// regular function output might be not enough, because after function +// inlining it might happen that data output is unused. +// +// 3) Furthermore, all ops accepting the same resource as an input are +// guaranteed to run in program order. This is also done by adding control +// edges at graph construction time. The last op touching the resource +// must be in a control return set, which will guarantee that all side +// effects to the resource will happen before function completion. +// +// Function inlining must preserve side-effect visibility: +// +// 1) All side-effects to the captured resources, that happened before function +// call must be visible to the function body nodes using that resources. +// +// 2) All side-effects to the captured resources, that happened inside function +// body, must be visible to every op/function using that resource after the +// function call completed. +// +// To guarantee that these properties are preserved after inlining we: +// +// 1) Create "input_control_node" NoOp. Function call node incoming control +// edges will be forwarded *to* this node. Function inputs (Identity nodes) +// will have a control edge *from* this node. If function body has nodes +// without inputs, they will have a control edge *from* this node. +// +// 2) Create "output_control_node" NoOp. All nodes that have incoming control +// edge *from* the function call node, will be forwarded to this node. +// +// We have two options for choosing which nodes will have a control edge *to* +// the "output control node": +// a) control returns (`control_ret` field in FunctionDef) +// b) data returns (`ret` field in FunctionDef) +// +// We do a) for multi-device function calls in Tensorflow v2 and b) +// for the rest for compatibility with Tensorflow v1. +// +// Following the automatic control dependencies tracking rules, a node that +// has an incoming control edge from the function call node is dependent on +// the side-effects happening inside the function body. The output control +// node will guarantee side-effects execution order. +// +// If function call node doesn't have an outgoing control edge, it means that +// no one is interested in observing side-effects that might have happened. +// +// Function inlining might leave the graph in partially-placed state. Function +// inlining caller must call Placer to guarantee that all nodes are placed. +// +// Function inlining with `options.override_device=true` will leave graph in +// fully placed state, by overriding all inlined nodes devices with the caller +// node device, but it will make functions always single-device. These functions +// after inlining will not be able to handle resources on multiple devices. This +// is currently acceptable for XLA use cases (XLA cluster is always executed on +// a single device). +// +// TODO(ezhulenev): Documentation above is ahead of implementation below. +Status InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g, + Node* caller, const FunctionBody* fbody, + const InlineFunctionBodyOptions& options) { + VLOG(3) << "Inline function call: " << SummarizeNode(*caller) << " [" + << options.DebugString() << "]"; + + Status validation = ValidateInlining(caller, fbody, options); + if (!validation.ok()) { + return errors::Internal("Inlining mismatch: ", validation.error_message()); + } + + // Placer is responsible for assigning devices for all nodes that we will add + // to the graph. 
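+  // Which placer is used is controlled by
+  // `options.inlined_function_body_placer` (see InlinedFunctionBodyPlacer for
+  // the "default", "single_device" and "multi_device" strategies).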
+  const std::unique_ptr<InlinedFunctionBodyPlacer> placer =
+      options.inlined_function_body_placer.get(*g, *caller);
+
+  // We can't possibly introduce a duplicate control edge during function
+  // inlining, so we skip the check in calls to 'g->AddControlEdge(...)'.
+  static constexpr bool kDoNotCheckDuplicates = true;
+
+  // ------------------------------------------------------------------------ //
+  // Helper functions to create `NoOp` and `Identity` nodes for auxiliary
+  // control nodes and inlined function inputs and outputs.
+
+  // Add a NoOp node for function control inputs/outputs.
+  const auto no_op = [&](StringPiece name) -> Node* {
+    Node* node = AddNoOp(absl::StrCat(caller->name(), "/", name), g);
+    const absl::optional<string> device = placer->ControlNodeDevice();
+    if (device.has_value()) node->set_requested_device(*device);
+    return node;
+  };
+
+  // Add an Identity node for function input.
+  const auto input_identity = [&](StringPiece name, Endpoint input,
+                                  int index) -> Node* {
+    Node* node = AddIdentity(absl::StrCat(caller->name(), "/", name), g, input);
+    const absl::optional<string> device = placer->InputNodeDevice(index);
+    if (device.has_value()) node->set_requested_device(*device);
+    bool colocate_identity = placer->ColocateInputOutputIdentities();
+    if (colocate_identity) {
+      node->AddAttr(kColocationAttrName,
+                    std::vector<string>{absl::StrCat(kColocationGroupPrefix,
+                                                     input.node->name())});
+    }
+    return node;
+  };
+
+  // Add an Identity node for function output.
+  const auto output_identity = [&](StringPiece name, Endpoint input,
+                                   int index) -> Node* {
+    Node* node = AddIdentity(absl::StrCat(caller->name(), "/", name), g, input);
+    const absl::optional<string> device = placer->OutputNodeDevice(index);
+    if (device.has_value()) node->set_requested_device(*device);
+    bool colocate_identity = placer->ColocateInputOutputIdentities();
+    if (colocate_identity) {
+      node->AddAttr(kColocationAttrName,
+                    std::vector<string>{absl::StrCat(kColocationGroupPrefix,
+                                                     input.node->name())});
+    }
+    return node;
+  };
+
+  // ------------------------------------------------------------------------ //
+  // Input edges. For data edges coming into "caller", we first compute the
+  // <src>:<src_output> for the i-th input in "inputs".
+  // If "caller" has any input control dependencies, we add a NoOp
+  // node "input_control_node", which depends on "caller"'s control inputs.
+  std::vector<Endpoint> inputs(caller->num_inputs());
+  Node* input_control_node = nullptr;
+  for (const Edge* e : caller->in_edges()) {
+    if (e->IsControlEdge()) {
+      if (input_control_node == nullptr) {
+        input_control_node = no_op("input_control_node");
+      }
+      g->AddControlEdge(e->src(), input_control_node, kDoNotCheckDuplicates);
+    } else {
+      inputs[e->dst_input()] = {e->src(), e->src_output()};
+    }
+  }
+  if (input_control_node != nullptr) {
+    VLOG(3) << "Created input control node: " << input_control_node->name();
+  }
+
+  // ------------------------------------------------------------------------ //
+  // Duplicate fbody->graph into 'g'. First, we copy the nodes of
+  // fbody->graph into 'g' except the source and sink nodes. We copy
+  // edges among nodes in 'fbody->graph'.
+  //
+  // If 'x' is a node in fbody->graph and its copy in 'g' is 'y', we
+  // remember 'y' in node_map[x->id()].
+  std::vector<Node*> node_map(fbody->graph->num_node_ids());
+  for (Node* n : fbody->graph->op_nodes()) {
+    NodeDef ndef = n->def();
+
+    // Maybe override requested node device assignment.
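+    // (With the multi-device placer this only rewrites the job/replica/task
+    // parts of an explicitly requested device; see the placer classes above.)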
+ const absl::optional device = placer->BodyNodeDevice(ndef); + if (device.has_value()) ndef.set_device(*device); + + // Add inlined function name to inlined node debug information. + PropagateDebugInfoToNode(fbody->fdef.signature().name(), {n}, &ndef); + + // Add the function node name as a prefix: + // 1) to node name to avoid collisions + // 2) to frame name to avoid multiple LoopCond nodes in one frame + // 3) to colocation attribute + const string prefix = strings::StrCat(caller->name(), "/"); + TF_RETURN_IF_ERROR(AddPrefixAndSuffixToNode(prefix, /*suffix=*/"", &ndef, + options.uniquify_frame_names)); + + Status added_node; + Node* clone = g->AddNode(ndef, &added_node); + TF_CHECK_OK(added_node); + node_map[n->id()] = clone; + + // If there is an input control node, and one of: + // a) the node has no data or control inputs, or + // b) the node is a function call (including SymbolicGradient), + // then add a control edge from the input control node to the clone (only + // if it does not already have a control input). + // + // We must not execute any nodes if the original function call would not + // have executed. This is especially critical when the function call is + // inside a control-flow construct like tf.cond(). Case (a) ensures that + // such nodes do not run. + // + // The purpose of case (b) is to ensure that instances of case (a) created + // by further inlining steps also receive the control dependency. + // + // This edge is required to transfer execution frame down to all function + // body nodes of inlined nested function calls. + if (input_control_node) { + const auto is_input_edge = [](const Edge* e) -> bool { + return !e->src()->IsSource(); + }; + const auto is_control_edge = [](const Edge* e) -> bool { + return !e->src()->IsSource() && e->IsControlEdge(); + }; + + // Forward execution frame if: + // + // a) The node has no data or control inputs. + // b) OR the node is a function call without control inputs (control edge + // will be used in nested function inlining to forward execution frame + // to constants inside the function body). + // + // c) Do not forward control frame to function argument nodes, they will + // be connected to the corresponding function input later. + const bool forward_execution_frame = + (absl::c_none_of(n->in_edges(), is_input_edge) || // (a) + (n->IsFunctionCall() && // (b) + absl::c_none_of(n->in_edges(), is_control_edge))) && // + !n->IsArg(); // (c) + + if (forward_execution_frame) { + VLOG(4) << "Add control edge from input control node to: " + << clone->name(); + g->AddControlEdge(input_control_node, clone, kDoNotCheckDuplicates); + } + } + } + for (const Edge* e : fbody->graph->edges()) { + if (e->src()->IsSource() || e->src()->IsSink() || e->dst()->IsSource() || + e->dst()->IsSink()) { + continue; + } + Node* src_copy = node_map[e->src()->id()]; + Node* dst_copy = node_map[e->dst()->id()]; + g->AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input()); + } + + // ------------------------------------------------------------------------ // + // Connect input edges. + // + // We create one Identity node for each input. Then, we connect inputs[i] to + // the i-th identity node added. The nodes that previously connected + // to the j-th output of i-th arg node are reconnected to the i-th + // identity node. + // + // The added identity nodes depend on "input_control_node". 
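+  //
+  // For example, for a caller "f" whose 0-th input comes from "x:0", the
+  // function body's 0-th "_Arg" node is replaced by an Identity node (named
+  // roughly "Func/f/input") that reads "x:0"; every former consumer of that
+  // argument now reads from the Identity node instead.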
+ VLOG(4) << "Add input Identity nodes for each function argument:"; + for (std::size_t i = 0; i < fbody->arg_nodes.size(); ++i) { + Node* arg = node_map[fbody->arg_nodes[i]->id()]; + Node* n = input_identity("input", inputs[i], i); + VLOG(4) << " [index " << i << "] " + << fbody->fdef.signature().input_arg(i).name() << " as " + << n->name() << " (input: " << inputs[i].name() + << ", requested_device: " << n->requested_device() << ")"; + + if (input_control_node) { + g->AddControlEdge(input_control_node, n, kDoNotCheckDuplicates); + } + for (const Edge* e : arg->out_edges()) { + if (e->IsControlEdge()) { + g->AddControlEdge(n, e->dst(), kDoNotCheckDuplicates); + } else { + g->AddEdge(n, 0, e->dst(), e->dst_input()); + } + } + node_map[fbody->arg_nodes[i]->id()] = n; + g->RemoveNode(arg); // 'arg' is disconnected. + } + + // ------------------------------------------------------------------------ // + // Connect output edges. + // + // For i-th return node in fbody->graph, we add in "g" an identity node + // (outputs[i-th]). We then reconnect every incoming edge into the i-th return + // node to the added identity node. + // + // For every data edge coming out of "callee"s i-th output, we reconnect it to + // the i-th identity added above. + // + // If "callee" is control-depended upon by any other nodes, we add a NoOp node + // "output_control_node". "output_control_node" depends on all identity nodes + // added above or on all control return nodes (controlled by + // `options.output_control_src` value). And nodes previously depend on + // "callee" is changed to depend on "output_control_node". + // + // If `keep_node_fetchable` is `true` we always add an output control node, to + // guarantee that executing a fetchable node will execute all side-effects. + VLOG(4) << "Add output Identity nodes for each function output argument:"; + std::vector outputs(caller->num_outputs()); + for (std::size_t i = 0; i < fbody->ret_nodes.size(); ++i) { + Node* ret = node_map[fbody->ret_nodes[i]->id()]; + Endpoint data; // Data input for the ret node. + for (const Edge* e : ret->in_edges()) { + if (!e->IsControlEdge()) { + data = {e->src(), e->src_output()}; + break; + } + } + CHECK(data.node != nullptr); + Node* n = output_identity("output", data, i); + outputs[i] = n; + VLOG(4) << " [index " << i << "] " + << fbody->fdef.signature().output_arg(i).name() << " as " + << n->name() << " (ret: " << data.node->name() << ":" << data.index + << ", requested_device: " << n->requested_device() << ")"; + for (const Edge* e : ret->in_edges()) { + if (e->IsControlEdge()) { + g->AddControlEdge(e->src(), n, kDoNotCheckDuplicates); + } + } + g->RemoveNode(ret); // 'ret' is disconnected. 
+  }
+
+  Node* output_control_node = nullptr;
+  const bool has_control_outputs = absl::c_any_of(
+      caller->out_edges(), [](const Edge* e) { return e->IsControlEdge(); });
+
+  using KeepCallerNode = InlineFunctionBodyOptions::KeepCallerNode;
+  const bool keep_caller_node =
+      options.keep_caller_node == KeepCallerNode::kFetchable ||
+      options.keep_caller_node == KeepCallerNode::kTargetable;
+
+  if (has_control_outputs || keep_caller_node) {
+    output_control_node = no_op("output_control_node");
+    VLOG(4) << "Add output control node: " << output_control_node->name();
+    if (options.output_control_src == OutputControlSrc::kDataOutputs) {
+      for (Node* n : outputs) {
+        VLOG(4) << " [data output] add control edge from: " << n->name();
+        g->AddControlEdge(n, output_control_node, kDoNotCheckDuplicates);
+      }
+    } else {
+      for (Node* fbody_node : fbody->control_ret_nodes) {
+        Node* n = node_map[fbody_node->id()];
+        VLOG(4) << " [control output] add control edge from: " << n->name();
+        g->AddControlEdge(n, output_control_node, kDoNotCheckDuplicates);
+      }
+    }
+  }
+
+  // We can't leave the output control node without incoming control edges,
+  // because in that case the outgoing control edge will lose execution frame
+  // information. We connect input_control_node and output_control_node with a
+  // control edge to forward the execution frame to the controlled nodes. Above
+  // we add a control edge to all function calls inside the function body, to
+  // guarantee that we will always have input_control_node when we need it.
+  if (output_control_node && output_control_node->in_edges().empty()) {
+    if (input_control_node) {
+      VLOG(4)
+          << "Add a control edge between input and output control nodes: "
+          << input_control_node->name() << " to "
+          << output_control_node->name();
+      g->AddControlEdge(input_control_node, output_control_node,
+                        kDoNotCheckDuplicates);
+    } else {
+      VLOG(4) << "Function inlining potentially dropped execution frame "
+                 "information from outgoing control edges.";
+    }
+  }
+
+  for (const Edge* e : caller->out_edges()) {
+    if (e->IsControlEdge()) {
+      g->AddControlEdge(output_control_node, e->dst(), kDoNotCheckDuplicates);
+    } else {
+      g->AddEdge(outputs[e->src_output()], 0, e->dst(), e->dst_input());
+    }
+  }
+
+  // ------------------------------------------------------------------------ //
+  // Add an IdentityN or NoOp node in place of the caller node to keep `caller`
+  // fetchable or targetable.
+
+  if (keep_caller_node) {
+    std::vector<NodeBuilder::NodeOut> output_tensors;
+    absl::c_transform(outputs, std::back_inserter(output_tensors),
+                      [](Node* n) { return NodeBuilder::NodeOut(n, 0); });
+
+    Node* caller_substitute_node;
+    if (options.keep_caller_node == KeepCallerNode::kTargetable ||
+        output_tensors.empty()) {
+      // IdentityN node must have at least one data input. If function has no
+      // data outputs, we can't keep it fetchable.
+      TF_CHECK_OK(NodeBuilder(caller->name(), "NoOp")
+                      .Device(caller->requested_device())
+                      .ControlInput(output_control_node)
+                      .Finalize(g, &caller_substitute_node));
+
+    } else if (options.keep_caller_node == KeepCallerNode::kFetchable) {
+      TF_CHECK_OK(NodeBuilder(caller->name(), "IdentityN")
+                      .Device(caller->requested_device())
+                      .Input(output_tensors)
+                      .ControlInput(output_control_node)
+                      .Finalize(g, &caller_substitute_node));
+    }
+  }
+
+  // ------------------------------------------------------------------------ //
+  // 'caller' is replaced with inlined function body nodes and maybe IdentityN
+  // to keep it fetchable.
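+  //
+  // At this point every data and control edge of 'caller' has been rewired to
+  // the inlined body (or to the substitute IdentityN/NoOp node above), so the
+  // original call node can be removed from the graph.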
+ VLOG(3) << "Successfully inlined function call node: " << caller->name(); + g->RemoveNode(caller); + + return Status::OK(); +} + +bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph, + const ExpandInlineFunctionsOptions& options) { + std::vector> candidates; + + const FunctionLibraryDefinition* fld = lib->GetFunctionLibraryDefinition(); + + for (Node* node : graph->nodes()) { + // Skip nodes that are not function calls or SymbolicGradient calls. + if (!IsFunctionCall(*lib->GetFunctionLibraryDefinition(), *node)) { + continue; + } + // Skip function calls that marked noinline. + bool noinline; + if (fld->GetAttr(*node, kNoInlineAttr, &noinline).ok() && noinline) { + VLOG(3) << "noinline: " << SummarizeNode(*node); + continue; + } + FunctionLibraryRuntime::Handle handle; + Status s = InstantiateFunctionCall(node->def(), lib, &handle); + if (!s.ok()) { + LOG(ERROR) << "Failed to instantiate a function: " << s.error_message(); + continue; + } + const FunctionBody* fbody = lib->GetFunctionBody(handle); + CHECK_NOTNULL(fbody); + candidates.emplace_back(node, fbody); + } + + bool inlined_any = false; + for (const auto& p : candidates) { + Status inlined = InlineFunctionBody(*fld, graph, p.first, p.second, + p.first->IsPartitionedCall() + ? options.multi_device_options + : options.native_options); + if (inlined.ok()) { + inlined_any = true; + } else { + VLOG(1) << "Failed to inline function call: node=" << p.first->name() + << " error=" << inlined.error_message(); + } + } + + // TODO(ezhulenev): Release handles for inlined function calls. + + return inlined_any; +} + +} // end namespace tensorflow diff --git a/tensorflow/core/common_runtime/inline_function_utils.h b/tensorflow/core/common_runtime/inline_function_utils.h new file mode 100644 index 00000000000..bc873a3fc60 --- /dev/null +++ b/tensorflow/core/common_runtime/inline_function_utils.h @@ -0,0 +1,236 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_INLINE_FUNCTION_UTILS_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_INLINE_FUNCTION_UTILS_H_ + +#include +#include + +#include "absl/types/optional.h" +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/function_body.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/protobuf/config.pb.h" + +namespace tensorflow { + +static constexpr const char* const kNoInlineAttr = "_noinline"; + +// Optionally override device assignment for nodes added to the graph for +// inlined functions: +// (1) Identity nodes added in place of function input arguments. +// (2) Identity nodes added in place of function return values. +// (3) Special NoOp nodes that enforce side-effects execution order. +// (4) All nodes inside function body specified in FunctionDef. 
+class InlinedFunctionBodyPlacer { + public: + virtual ~InlinedFunctionBodyPlacer() = default; + + virtual absl::optional InputNodeDevice(int input_index) const = 0; + virtual absl::optional OutputNodeDevice(int output_index) const = 0; + // Returns true if the added input/output identity nodes should be colocated + // with the corresponding input/output from the function body. + virtual bool ColocateInputOutputIdentities() const = 0; + virtual absl::optional ControlNodeDevice() const = 0; + virtual absl::optional BodyNodeDevice(const NodeDef& ndef) const = 0; + + // Place input nodes on the same device as the corresponding caller input + // node. Do not specify any placement for all other nodes. + static std::unique_ptr DefaultPlacer( + const Graph& graph, const Node& caller); + + // Place all nodes on the same device as caller node. + static std::unique_ptr SingleDevicePlacer( + const Graph& graph, const Node& caller); + + // Place input nodes on the same device as the corresponding caller input + // node. Do not place output node. Place control nodes on the same device as + // caller node. For all function body nodes overrides job, replica and task + // parts of the device assignment to match function caller node. + static std::unique_ptr MultiDevicePlacer( + const Graph& graph, const Node& caller); + + using Factory = std::function( + const Graph&, const Node&)>; + + struct Config { + string name; + Factory get; + }; + + static Config Default() { return {"default", DefaultPlacer}; } + static Config SingleDevice() { return {"single_device", SingleDevicePlacer}; } + static Config MultiDevice() { return {"multi_device", MultiDevicePlacer}; } +}; + +struct InlineFunctionBodyOptions { + // All nodes that have incoming control edge *from* the function call node, + // will be forwarded to the "output control node". There are two options for + // choosing which nodes will have a control edge *to* the "output control + // node": + // a) control returns (`control_ret` field in FunctionDef) + // b) data returns (`ret` field in FunctionDef) + enum class OutputControlSource { kDataOutputs, kControlOutputs }; + + // Keep a node in a graph with the same name as the function call node: + // + // a) DoNotKeep: Function call node is fully inlined, and there is no node in + // a graph with the same name. + // + // b) Fetchable: Add an IdentityN node to the graph in place of the inlined + // function call node. It will have a control edge from inlined + // 'output_control_node' and data edges from function output nodes. + // The IdentityN node will be placed on the same device as the caller node. + // + // This is mostly for compatibility with Tensorflow v1 and sessions. + // When we prepare a graph for execution in + // GraphExecutionState::MakeForBaseGraph we don't know what nodes will be + // fetched, so we can't safely remove any of them. When graph executed as a + // function it has 'Retval' nodes for all fetched tensors, and we can + // safely inline function calls. + // + // c) Targetable: Add a NoOp node to the graph in place of the inlined + // function call node. It will have a control edge from inline + // 'output_control_node' and no data edges. NoOp node will be placed on the + // same device as the caller node. This will keep the inlined function call + // node a valid 'session.run' target, and also will keep it a valid control + // output node. + enum class KeepCallerNode { kDoNotKeep, kFetchable, kTargetable }; + + // If 'true' function inlining is completely disabled. 
This allows to control + // function inlining for different types of function calls (see + // 'ExpandInlineFunctionsOptions' below). + bool disable_inlining = false; + // Ignore '_noinline' function attribute. + bool ignore_noinline = false; + // If 'true' function inlining will inline functions in implementation + // selection group. Normally those functions should not be inlined; they will + // be handled by Grappler. + bool inline_impl_selection_group_functions = false; + // Controls if we want to keep a node with the name as the function call node + // in a graph after function inlining. + KeepCallerNode keep_caller_node = KeepCallerNode::kDoNotKeep; + // For compatibility with Tensorflow v1 by default we will use data outputs. + // Control returns were added to Tensorflow v2 with automatic control + // dependencies tracking in Eager mode. + OutputControlSource output_control_src = OutputControlSource::kDataOutputs; + // Inlined function body placer decides what requested device assignments + // should be added to the nodes added to the graph. See documentation above + // for available strategies. + InlinedFunctionBodyPlacer::Config inlined_function_body_placer = + InlinedFunctionBodyPlacer::Default(); + // If true, frame names in the function body will be + // made unique in the resulting graph (e.g. by prepending a unique prefix). + // NOTE(mrry): Only set this option to false when there is a single function + // call in the graph (e.g. when making a remote function call via + // ClusterFunctionLibraryRuntime). This option is provided because the graph + // partitioner generates frame names that must remain unmodified across all + // partitions of a multi-device function. + bool uniquify_frame_names = true; + + // A human-readable debug string for this options. + string DebugString() const; +}; + +// Returns 'Status::OK()' iff the function '*fbody' can be inlined at 'node' +// based on the type signature of 'node' and 'fbody': +// +// (1) Caller node has the same number of inputs and outputs as the function. +// (2) Caller node inputs and outputs have the same data types as function +// inputs and returns. +// (3) Validation rules defined in InlineFunctionBodyOptions. +// +// If function can't be safely inlined, returns error message with details why +// inlining is not possible or safe. +Status ValidateInlining(const Node* node, const FunctionBody* fbody, + const InlineFunctionBodyOptions& options); + +// Given a "caller" in graph "g", which is a function call of a function +// to "fbody". Replaces the "caller" with fbody->graph and connects +// edges properly. "override_device" specifies whether inlining should replace +// explicitly specified devices inside fbody with the callee's device. +// +// Returns 'Status::OK()' if function was successfully inlined into the graph. +// If function inlining is not possible returns an error with a reason, and +// leaves the graph in unmodified state. +Status InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g, + Node* caller, const FunctionBody* fbody, + const InlineFunctionBodyOptions& options); + +// There are three types of function calls that could be invoked during +// *Tensorflow graph execution*: +// +// 1) Native function call (node.type_string() is the function name). These +// functions are always executed on a single-device, which is the device of +// the function call node. 
+// +// 2) Multi-device function calls (PartitionedCall or StatefulPartitionedCall +// ops) can execute on multiple devices and accept DT_RESOURCE inputs that +// belong to different devices. This type of functions was added in +// Tensorflow 2.0 Eager mode, and it has control outputs to represent +// side-effects that must always execute (see `control_ret` in FunctionDef). +// +// 3) SymbolicGradient has been deprecated for a while, but we still keep it and +// use `native` options for inlining for compatibility. +// +// We need to have distinct inlining rules for compatibility with Tensorflow v1. +// +// There are few other places in Tensorflow that could execute functions: +// +// 1) common_runtime/eager/kernel_and_device.{h,cc} - executes "top level" +// functions directly via function library runtime, without going through +// the graph. +// 2) tf.data pipelines - also execute functions directly via function library +// runtime with custom executors. +struct ExpandInlineFunctionsOptions { + ExpandInlineFunctionsOptions() : native_options(), multi_device_options() { + using OutputControlSrc = InlineFunctionBodyOptions::OutputControlSource; + multi_device_options.output_control_src = OutputControlSrc::kControlOutputs; + } + + InlineFunctionBodyOptions native_options; + InlineFunctionBodyOptions multi_device_options; +}; + +// WARNING(ezhulenev): PLEASE DO NOT USE THIS FUNCTION. This is a temporary +// workaround that will be enabled only during the function inlining unification +// (b/126811947). Contact ezhulenev@ if you think you need it. +// TODO(ezhulenev): Delete this function. +bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph, + const ExpandInlineFunctionsOptions& options); + +// For each node in "graph", if "lib" indicates that the node is a +// function call, inline the function body. Returns true if at least +// one node is inlined. +// +// This routine goes through "graph" nodes once and applies the +// inlining. The caller may decide to apply the inlining on "graph" +// multiple times by calling ExpandInlineFunctions a few times. +// +// Function calls that can't be safely inlined into the graph (ValidateInlining +// returns error), are ignored. +// +// TODO(ezhulenev): We do not FunctionLibraryRuntime for this. We need just the +// FunctionLibraryDefinition and FunctionDefToBodyHelper to implement this (see +// lower_function_call.cc). +inline bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph) { + return ExpandInlineFunctions(lib, graph, ExpandInlineFunctionsOptions()); +} + +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_INLINE_FUNCTION_UTILS_H_ diff --git a/tensorflow/core/common_runtime/isolate_placer_inspection_required_ops_pass_test.cc b/tensorflow/core/common_runtime/isolate_placer_inspection_required_ops_pass_test.cc index ef1a74c5f29..c5e61dbefbb 100644 --- a/tensorflow/core/common_runtime/isolate_placer_inspection_required_ops_pass_test.cc +++ b/tensorflow/core/common_runtime/isolate_placer_inspection_required_ops_pass_test.cc @@ -20,12 +20,12 @@ limitations under the License. 
#include "absl/memory/memory.h" #include "absl/strings/str_join.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/graph/graph.h" -#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/graph_def_builder_util.h" #include "tensorflow/core/lib/core/status_test_util.h" diff --git a/tensorflow/core/common_runtime/lower_case_op_test.cc b/tensorflow/core/common_runtime/lower_case_op_test.cc index ce34a21f0ca..0ed46fb25ed 100644 --- a/tensorflow/core/common_runtime/lower_case_op_test.cc +++ b/tensorflow/core/common_runtime/lower_case_op_test.cc @@ -13,8 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/common_runtime/lower_functional_ops.h" - #include "tensorflow/cc/client/client_session.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/array_ops.h" @@ -22,12 +20,13 @@ limitations under the License. #include "tensorflow/cc/ops/function_ops.h" #include "tensorflow/cc/ops/resource_variable_ops.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/common_runtime/graph_runner.h" +#include "tensorflow/core/common_runtime/lower_functional_ops.h" #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/tensor_testutil.h" -#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/graph_def_builder_util.h" #include "tensorflow/core/lib/core/status_test_util.h" diff --git a/tensorflow/core/common_runtime/lower_function_call_op.cc b/tensorflow/core/common_runtime/lower_function_call_op.cc index f65c157e485..b1b657c9c22 100644 --- a/tensorflow/core/common_runtime/lower_function_call_op.cc +++ b/tensorflow/core/common_runtime/lower_function_call_op.cc @@ -17,6 +17,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/inline_function_utils.h" #include "tensorflow/core/common_runtime/lower_functional_ops.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/graph/graph.h" diff --git a/tensorflow/core/common_runtime/lower_function_call_op_test.cc b/tensorflow/core/common_runtime/lower_function_call_op_test.cc index c7e6b16dca6..1a88f93737f 100644 --- a/tensorflow/core/common_runtime/lower_function_call_op_test.cc +++ b/tensorflow/core/common_runtime/lower_function_call_op_test.cc @@ -13,8 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/common_runtime/lower_functional_ops.h" - #include "tensorflow/cc/client/client_session.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/array_ops.h" @@ -22,12 +20,13 @@ limitations under the License. 
#include "tensorflow/cc/ops/function_ops.h" #include "tensorflow/cc/ops/resource_variable_ops.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/common_runtime/graph_runner.h" +#include "tensorflow/core/common_runtime/lower_functional_ops.h" #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/tensor_testutil.h" -#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/graph_def_builder_util.h" #include "tensorflow/core/lib/core/status_test_util.h" diff --git a/tensorflow/core/common_runtime/lower_functional_ops_test.cc b/tensorflow/core/common_runtime/lower_functional_ops_test.cc index 21f2a5e82d8..a2c530e7429 100644 --- a/tensorflow/core/common_runtime/lower_functional_ops_test.cc +++ b/tensorflow/core/common_runtime/lower_functional_ops_test.cc @@ -21,12 +21,12 @@ limitations under the License. #include "tensorflow/cc/ops/control_flow_ops_internal.h" #include "tensorflow/cc/ops/function_ops.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/common_runtime/graph_runner.h" #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/tensor_testutil.h" -#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/graph_def_builder_util.h" #include "tensorflow/core/lib/core/status_test_util.h" diff --git a/tensorflow/core/common_runtime/lower_if_op_test.cc b/tensorflow/core/common_runtime/lower_if_op_test.cc index 71e1a011aa3..1990480f079 100644 --- a/tensorflow/core/common_runtime/lower_if_op_test.cc +++ b/tensorflow/core/common_runtime/lower_if_op_test.cc @@ -13,8 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/common_runtime/lower_functional_ops.h" - #include "tensorflow/cc/client/client_session.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/array_ops.h" @@ -22,12 +20,13 @@ limitations under the License. 
#include "tensorflow/cc/ops/function_ops.h" #include "tensorflow/cc/ops/resource_variable_ops.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/common_runtime/graph_runner.h" +#include "tensorflow/core/common_runtime/lower_functional_ops.h" #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/tensor_testutil.h" -#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/graph_def_builder_util.h" #include "tensorflow/core/lib/core/status_test_util.h" diff --git a/tensorflow/core/common_runtime/lower_while_op_test.cc b/tensorflow/core/common_runtime/lower_while_op_test.cc index 5aec79f35c3..a1ae5924718 100644 --- a/tensorflow/core/common_runtime/lower_while_op_test.cc +++ b/tensorflow/core/common_runtime/lower_while_op_test.cc @@ -13,19 +13,18 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/common_runtime/lower_functional_ops.h" - #include "tensorflow/cc/client/client_session.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/cc/ops/control_flow_ops_internal.h" #include "tensorflow/cc/ops/function_ops.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/common_runtime/graph_runner.h" +#include "tensorflow/core/common_runtime/lower_functional_ops.h" #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" -#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/graph_def_builder_util.h" #include "tensorflow/core/lib/core/status_test_util.h" diff --git a/tensorflow/core/common_runtime/partitioning_utils.cc b/tensorflow/core/common_runtime/partitioning_utils.cc index 7ddddc8f7ba..6cb56080a27 100644 --- a/tensorflow/core/common_runtime/partitioning_utils.cc +++ b/tensorflow/core/common_runtime/partitioning_utils.cc @@ -16,10 +16,10 @@ limitations under the License. #include +#include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/graph/graph.h" -#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_partition.h" namespace tensorflow { diff --git a/tensorflow/core/common_runtime/placer_inspection_required_ops_utils_test.cc b/tensorflow/core/common_runtime/placer_inspection_required_ops_utils_test.cc index 0f8a439752a..aa619cbd105 100644 --- a/tensorflow/core/common_runtime/placer_inspection_required_ops_utils_test.cc +++ b/tensorflow/core/common_runtime/placer_inspection_required_ops_utils_test.cc @@ -19,10 +19,10 @@ limitations under the License. 
#include "absl/memory/memory.h" #include "absl/strings/str_join.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/graph/graph.h" -#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/graph_def_builder_util.h" #include "tensorflow/core/lib/core/status_test_util.h" diff --git a/tensorflow/core/common_runtime/placer_test.cc b/tensorflow/core/common_runtime/placer_test.cc index 2b3152a57c0..0a12714ba7c 100644 --- a/tensorflow/core/common_runtime/placer_test.cc +++ b/tensorflow/core/common_runtime/placer_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_set.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/framework/device_attributes.pb.h" #include "tensorflow/core/framework/function.h" @@ -34,7 +35,6 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/graph/graph.h" -#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/graph_def_builder_util.h" #include "tensorflow/core/lib/core/errors.h" diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc index e8a486c60c7..42bde655735 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/device_set.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/function_optimization_registry.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/common_runtime/partitioning_utils.h" #include "tensorflow/core/common_runtime/placer.h" @@ -35,7 +36,6 @@ limitations under the License. #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/graph/graph.h" -#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_node_util.h" #include "tensorflow/core/graph/graph_partition.h" #include "tensorflow/core/lib/core/blocking_counter.h" diff --git a/tensorflow/core/common_runtime/shape_refiner.cc b/tensorflow/core/common_runtime/shape_refiner.cc index d6ab1e30a55..a968aaf09b6 100644 --- a/tensorflow/core/common_runtime/shape_refiner.cc +++ b/tensorflow/core/common_runtime/shape_refiner.cc @@ -20,7 +20,8 @@ limitations under the License. 
#include #include "tensorflow/core/common_runtime/eval_const_tensor.h" -#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/function_utils.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/node_def.pb.h" @@ -28,9 +29,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/graph/algorithm.h" -#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/public/session.h" namespace tensorflow { diff --git a/tensorflow/core/graph/BUILD b/tensorflow/core/graph/BUILD index c8ea3ee1437..36fae822623 100644 --- a/tensorflow/core/graph/BUILD +++ b/tensorflow/core/graph/BUILD @@ -163,11 +163,10 @@ filegroup( ], ) -# Both of these files depend on common_runtime. +# This file depends on common_runtime. filegroup( name = "core_cpu_base_no_ops_srcs", srcs = [ - "graph_constructor.cc", "graph_def_builder_util.cc", ], ) @@ -250,7 +249,6 @@ filegroup( "gradients.h", "graph.cc", "graph.h", - "graph_constructor.cc", "graph_constructor.h", "graph_def_builder.cc", "graph_def_builder.h", diff --git a/tensorflow/core/graph/graph_constructor.h b/tensorflow/core/graph/graph_constructor.h index 6930a57bf2e..bbbcf6ce423 100644 --- a/tensorflow/core/graph/graph_constructor.h +++ b/tensorflow/core/graph/graph_constructor.h @@ -16,189 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_CORE_GRAPH_GRAPH_CONSTRUCTOR_H_ #define TENSORFLOW_CORE_GRAPH_GRAPH_CONSTRUCTOR_H_ -#include "tensorflow/core/framework/graph.pb.h" -#include "tensorflow/core/graph/graph.h" -#include "tensorflow/core/graph/tensor_id.h" -#include "tensorflow/core/lib/core/status.h" - -namespace tensorflow { -class ShapeRefiner; - -// Construct a Graph *g out of a GraphDef gdef. Returns non-OK on -// error, in which case *g is left in an incomplete state. -// -// *g is expected to be an empty graph (with no more than a source and sink -// nodes) when provided to ConvertGraphDefToGraph. To enhance an existing Graph, -// see ImportGraphDef. -struct GraphConstructorOptions { - GraphConstructorOptions() {} - - // If true, allows internal ops in the GraphDef. - bool allow_internal_ops = false; - - // If true, the graph def is expected to have fully specified - // devices for all nodes. A node in the resulting graph "g" has the - // device name set accordingly. - // - // TODO(zhifengc): if possible, consider removing this option. - bool expect_device_spec = false; - - // If true, validates that nodes being converted have all expected attrs - // set and no unknown attrs set by calling ValidateNodeDef(). - // Setting validate_nodes without add_default_attributes, will fail if - // the GraphDef does not have all required attributes set. - bool validate_nodes = false; - - // If true, GraphConstructor will add attributes with their default - // value to the Node when they are missing from the NodeDef. - bool add_default_attributes = true; -}; -extern Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts, - const GraphDef& gdef, Graph* g); -extern Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts, - GraphDef&& gdef, Graph* g); - -// Same as ConvertGraphDefToGraph, but takes just nodes. Used by function -// instantiation. 
-// TODO(irving): This will turn into std::vector soon. -extern Status ConvertNodeDefsToGraph(const GraphConstructorOptions& opts, - gtl::ArraySlice nodes, Graph* g); - -// Options for calling ImportGraphDef(). -struct ImportGraphDefOptions { - ImportGraphDefOptions() - : uniquify_names(false), - uniquify_prefix(false), - skip_mapped_nodes(false), - validate_shape(true) {} - - // Name prefix to use for nodes imported from the GraphDef. For example, if - // prefix="animals" and GraphDef contains a node "bunny" then the node will be - // named "animals/bunny" in *g. Must not be already used as a node name or - // prefix in the graph. - string prefix; - - // If true, imported node names will be modified if their name already exists - // in the graph. If false, conflicting names will be treated as an error. Note - // that this option has no effect if `prefix` is specified, since `prefix` - // will guarantee all node names are unique. - bool uniquify_names; - - // If true, `prefix` will be modified if it already exists as a node name or - // prefix in the graph. If false, a conflicting prefix will be treated as an - // error. This option has no effect if `prefix` isn't specified. - bool uniquify_prefix; - - // Maps tensors in `gdef` to existing tensors in `g`. Inputs in `gdef` - // corresponding to `input_map` keys will be remapped to the nodes in `g` - // corresponding to the values. - // - // Keys should not include `prefix`, i.e., a key ID's name should be the name - // as it originally appears in `gdef`. - // - // If this is non-empty, ImportGraphDef must be called with the shape refiner - // used to create the existing nodes referenced in `input_map`. - // TODO(skyewm): can we remove this requirement? How do we access the original - // shape refiner? - std::map input_map; - - // If true, nodes that will have all output edges removed because of - // overrides in `input_map` will not be imported. - bool skip_mapped_nodes; - - // The names of existing nodes in `g` that the imported graph should have - // control dependencies on. - // - // Note that to avoid creating many redundant control edges, ImportGraphDef() - // won't add control edges to nodes that will inherit the dependencies from - // other nodes in `gdef`. - std::vector control_dependencies; - - // Tensors in `gdef` that will be returned via the ImportGraphDefResults - // output parameter of `ImportGraphDef()`. If this list is non-empty, the - // caller must pass a results object to `ImportGraphDef()`. The - // `return_tensors` field will be populated with the imported nodes in `g`. - // - // Entries should not include `prefix`, i.e., each ID's name should be the - // name as it originally appears in `gdef`. - // - // If this contains a tensor that's also being remapped via `input_map`, the - // corresponding existing tensor in `g` will be returned. - std::vector return_tensors; - - // The names of nodes in `gdef` that will be returned via the - // ImportGraphDefResults output parameter of `ImportGraphDef()`. If this list - // is non-empty, the caller must pass a results object to - // `ImportGraphDef()`. The `return_nodes` field will be populated with the - // imported nodes in `g`. - // - // Entries should not include `prefix`, i.e., each node's name should be the - // name as it originally appears in `gdef`. - // - // Unlike `return_tensors`, `input_map` has no effect on the nodes - // returned. `return_nodes` must be empty if `skip_mapped_nodes` is true. 
-  // TODO(skyewm): make this work with `skip_mapped_nodes` if there's a need.
-  std::vector<string> return_nodes;
-
-  // If true, checks that all colocation constraints are nodes in the GraphDef.
-  bool validate_colocation_constraints = true;
-
-  // If false skips shape validation.
-  bool validate_shape;
-
-  // TODO(ashankar): Enable handling of GraphDefs produced by newer binaries
-  // with ops that are not defined in the binary calling ImportGraphDef.
-  // Similar to the producer_op_list argument to import_graph_def in the
-  // python API.
-
-  // Try to set default execution device for this graph.
-  string default_device;
-};
-
-// Optional results that may be returned by ImportGraphDef.
-struct ImportGraphDefResults {
-  // The requested tensors associated with
-  // ImportGraphDefOptions::return_tensors. Note that the index may be different
-  // than the requested index if the returned tensor has been remapped according
-  // to `input_map`.
-  typedef int Index;
-  std::vector<std::pair<Node*, Index>> return_tensors;
-
-  // The requested nodes associated with ImportGraphDefOptions::return_nodes.
-  std::vector<Node*> return_nodes;
-
-  // Keys in ImportGraphDefOptions::input_map that don't appear in `gdef` and
-  // weren't used as an input to any node in `gdef`. These keys are likely due
-  // to typos, and callers may wish to treat their existence as an error.
-  std::vector<SafeTensorId> missing_unused_input_map_keys;
-};
-
-// Adds the graph in GraphDef `gdef` into an existing Graph `*g`.
-//
-// On error, returns non-OK and leaves `*g` unmodified.
-//
-// `refiner` can be null. It should be non-null if the caller
-// intends to add additional nodes to the graph after the import. This
-// allows the caller to validate shapes of those nodes (since
-// ShapeRefiner::AddNode must be called in topological order).
-//
-// `results` must be non-null if `opts.return_tensors` or `opts.return_nodes` is
-// non-empty. It can also be set to fetch the unused input map keys. If it's
-// non-null, all the vector fields must be empty.
-//
-// TODO(ashankar): Push this mechanism and get rid of Session::Extend()
-// as a means of enhancing an existing Graph.
-extern Status ImportGraphDef(const ImportGraphDefOptions& opts,
-                             const GraphDef& gdef, Graph* g,
-                             ShapeRefiner* refiner,
-                             ImportGraphDefResults* results = nullptr);
-
-// Make a copy of "src" into "*dest".
-//
-// REQUIRES: "*dest" is a freshly allocated graph without any nodes or edges
-// other than the implicit Source/Sink nodes.
-extern void CopyGraph(const Graph& src, Graph* dest);
-
-}  // namespace tensorflow
+#include "tensorflow/core/common_runtime/graph_constructor.h"
 
 #endif  // TENSORFLOW_CORE_GRAPH_GRAPH_CONSTRUCTOR_H_
diff --git a/tensorflow/core/graph/subgraph.cc b/tensorflow/core/graph/subgraph.cc
index e70427f9ef8..32c1aeac0de 100644
--- a/tensorflow/core/graph/subgraph.cc
+++ b/tensorflow/core/graph/subgraph.cc
@@ -27,7 +27,6 @@ limitations under the License.
#include "tensorflow/core/framework/types.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph.h" -#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 420f5002739..2e7e87cd656 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -5951,6 +5951,7 @@ filegroup( "//tensorflow/core/common_runtime:core_cpu_rump_impl", # quantize_training "//tensorflow/core/common_runtime:device", # device_lib, tfe, tf_session "//tensorflow/core/common_runtime:device_factory", # device_lib, tfe, tf_session + "//tensorflow/core/common_runtime:graph_constructor", # tf_session "//tensorflow/core/common_runtime:session_options", # device_lib, tfe, tf_session "//tensorflow/core/common_runtime:session_state", # tf_session "//tensorflow/core/data/service:server_lib", # server_lib diff --git a/tensorflow/tools/def_file_filter/symbols_pybind.txt b/tensorflow/tools/def_file_filter/symbols_pybind.txt index 80429b4e448..f61f6c04992 100644 --- a/tensorflow/tools/def_file_filter/symbols_pybind.txt +++ b/tensorflow/tools/def_file_filter/symbols_pybind.txt @@ -220,7 +220,7 @@ tensorflow::ImportGraphDef [op_gen_lib] # tf_session tensorflow::ApiDefMap::~ApiDefMap -[core_cpu_base_no_ops] # tf_session +[graph_constructor] # tf_session tensorflow::ShapeRefiner::~ShapeRefiner [python_api] # tf_session