From 9e467f4df3d982b55a9c697e0552514e60febec7 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Thu, 28 Mar 2019 21:17:34 -0700 Subject: [PATCH] Automated rollback of commit 92f736f429e398df261cd2f3c8c949840dd06a76 PiperOrigin-RevId: 240915460 --- tensorflow/core/framework/function.cc | 6 +- tensorflow/core/grappler/costs/BUILD | 1 - .../core/grappler/costs/graph_properties.cc | 72 +- tensorflow/core/grappler/op_types.cc | 8 - tensorflow/core/grappler/op_types.h | 2 - .../core/grappler/optimizers/data/rebatch.cc | 8 - .../optimizers/dependency_optimizer.cc | 2 +- .../grappler/optimizers/function_optimizer.cc | 195 +++-- .../optimizers/function_optimizer_test.cc | 2 +- .../optimizers/meta_optimizer_test.cc | 24 +- tensorflow/core/grappler/utils/BUILD | 1 - tensorflow/core/grappler/utils/functions.cc | 804 ++++++++++++------ tensorflow/core/grappler/utils/functions.h | 150 +++- .../core/grappler/utils/functions_test.cc | 289 +++++-- 14 files changed, 1076 insertions(+), 488 deletions(-) diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc index 756b9849585..b46705a88e2 100644 --- a/tensorflow/core/framework/function.cc +++ b/tensorflow/core/framework/function.cc @@ -177,8 +177,7 @@ class FunctionInstantiationHelper { } else { gnode->set_op(FunctionLibraryDefinition::kArgOp); } - DataType dtype = arg_def.is_ref() ? MakeRefType(dtypes[i]) : dtypes[i]; - AddAttr("T", dtype, gnode); + AddAttr("T", dtypes[i], gnode); AddAttr("index", arg_index, gnode); result_.arg_types.push_back(dtypes[i]); ++arg_index; @@ -344,8 +343,7 @@ class FunctionInstantiationHelper { gnode->set_op(FunctionLibraryDefinition::kRetOp); } AddInput(nodes_.size() - 1, item->nid, item->idx + i); - DataType dtype = ret_def.is_ref() ? 
MakeRefType(dtypes[i]) : dtypes[i]; - AddAttr("T", dtype, gnode); + AddAttr("T", dtypes[i], gnode); AddAttr("index", (*ret_index)++, gnode); result_.ret_types.push_back(dtypes[i]); } diff --git a/tensorflow/core/grappler/costs/BUILD b/tensorflow/core/grappler/costs/BUILD index c87b2506a1b..84d813fe771 100644 --- a/tensorflow/core/grappler/costs/BUILD +++ b/tensorflow/core/grappler/costs/BUILD @@ -41,7 +41,6 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":utils", - "@com_google_absl//absl/types:optional", "//tensorflow/core/grappler/utils:functions", "//tensorflow/core/grappler/utils:topological_sort", "//tensorflow/core/grappler:mutable_graph_view", diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc index 28f4108c6ef..e4136273402 100644 --- a/tensorflow/core/grappler/costs/graph_properties.cc +++ b/tensorflow/core/grappler/costs/graph_properties.cc @@ -15,7 +15,6 @@ limitations under the License. #include "tensorflow/core/grappler/costs/graph_properties.h" -#include "absl/types/optional.h" #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/node_def_util.h" @@ -604,31 +603,22 @@ class SymbolicShapeRefiner { " was not previously added to SymbolicShapeRefiner."); } - const absl::optional& maybe_grappler_function_item = - it->second; - if (!maybe_grappler_function_item.has_value()) { - VLOG(3) << "Skip failed to instantiate function call: function_name=" - << function_node->op(); - - auto* ctx = GetNodeContext(function_node); - auto* ic = ctx->inference_context.get(); - for (int i = 0; i < ic->num_outputs(); ++i) { - TF_RETURN_IF_ERROR(SetUnknownShape(function_node, i)); - } - - return Status::OK(); - } - // Copy (not reference) so that changes we make here (e.g., replacing - // _Arg with Const and _Retval with Identity) don't affect one in + // Placeholder with Const) don't affect one in // 
fun_to_grappler_function_item_. - GrapplerFunctionItem grappler_function_item = *maybe_grappler_function_item; + GrapplerFunctionItem grappler_function_item = it->second; MutableGraphView gv(&grappler_function_item.graph); // Forward shapes from function input nodes to argument nodes. for (int i = 0; i < grappler_function_item.inputs().size(); ++i) { auto& fun_input = grappler_function_item.input(i); - NodeDef* fun_node = gv.GetNode(fun_input.node_name); + if (fun_input.placeholders.size() > 1) { + // TODO(jmdecker): Handle case with multiple input placeholders + return errors::Unimplemented( + "Input arguments with multiple placeholders are not yet " + "supported."); + } + NodeDef* fun_node = gv.GetNode(fun_input.input_name); const TensorId input_tensor = ParseTensorName(function_node->input(i)); if (IsControlInput(input_tensor)) { @@ -659,18 +649,11 @@ class SymbolicShapeRefiner { proto.mutable_dim(i)->set_size(-1); } } - - // Turn _Arg node into a Placeholder. _Arg node is a system op without a - // valid shape function. *attr_output_shape.mutable_shape() = proto; - fun_node->set_op("Placeholder"); - (*fun_node->mutable_attr())["dtype"] = (*fun_node->mutable_attr())["T"]; - (*fun_node->mutable_attr()).erase("index"); - (*fun_node->mutable_attr()).erase("T"); (*fun_node->mutable_attr())["shape"] = attr_output_shape; } - // Replace input nodes with Consts, if values are known. Note that + // Replace input Placeholders with Consts, if values are known. Note that // we don't check exceptions here as it's done in the above loop. auto* ctx = GetNodeContext(function_node); auto* ic = ctx->inference_context.get(); @@ -701,15 +684,6 @@ class SymbolicShapeRefiner { } } - // Replace output _Retval nodes with Identity nodes. _Retval is a system op - // without outputs and registered shape function. 
- for (const auto& output_arg : grappler_function_item.outputs()) { - NodeDef* output_node = gv.GetNode(output_arg.node_name); - DCHECK_EQ(output_node->op(), "_Retval"); - output_node->set_op("Identity"); - output_node->mutable_attr()->erase("index"); - } - // Perform inference on function body. GraphProperties gp(grappler_function_item); TF_RETURN_IF_ERROR(gp.InferStatically(true, aggressive_shape_inference_)); @@ -720,9 +694,16 @@ class SymbolicShapeRefiner { ctx->output_tensor_protos.resize(grappler_function_item.output_size(), nullptr); for (auto const& out_arg : grappler_function_item.outputs()) { + if (out_arg.output_nodes.size() > 1) { + // TODO(jmdecker): Handle case of multiple output tensors + return errors::Unimplemented( + "Output arguments with multiple output tensors are not yet " + "supported."); + } + // It is guaranteed that output_tensors does not contain any control // inputs, so port_id >= 0. - TensorId out_tensor = ParseTensorName(out_arg.node_name); + TensorId out_tensor = ParseTensorName(out_arg.output_nodes[0]); const NodeDef* retnode = gv.GetNode(out_tensor.node()); if (retnode == nullptr) { @@ -1061,18 +1042,9 @@ class SymbolicShapeRefiner { CHECK_NOTNULL(function_library_.Find(function_node->op())); GrapplerFunctionItem grappler_function_item; - Status function_instantiated = + TF_RETURN_IF_ERROR( MakeGrapplerFunctionItem(*function_def, function_library_, - graph_def_version_, &grappler_function_item); - - // If function instantiation failed we will skip it during shape inference. - if (!function_instantiated.ok()) { - VLOG(3) << "Failed to instantiate a function. 
Error: " - << function_instantiated.error_message(); - fun_to_grappler_function_item_[function_def->signature().name()] = - absl::nullopt; - return Status::OK(); - } + graph_def_version_, &grappler_function_item)); if (grappler_function_item.inputs().size() > function_node->input_size()) { return errors::FailedPrecondition( @@ -1719,9 +1691,7 @@ class SymbolicShapeRefiner { std::unordered_map node_to_context_; std::unordered_map unknown_shapes_; std::unordered_map unknown_dims_; - // Store function instantiations only for valid function. If function - // instantiation failed it will have an `absl::nullopt`. - std::unordered_map> + std::unordered_map fun_to_grappler_function_item_; FunctionLibraryDefinition function_library_; const std::unordered_map>& fed_ports_; diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc index 02fec134915..d417e812caf 100644 --- a/tensorflow/core/grappler/op_types.cc +++ b/tensorflow/core/grappler/op_types.cc @@ -67,10 +67,6 @@ bool IsApproximateEqual(const NodeDef& node) { return node.op() == "ApproximateEqual"; } -bool IsArg(const NodeDef& node) { - return node.op() == "_Arg" || node.op() == "_DeviceArg"; -} - bool IsArgMax(const NodeDef& node) { return node.op() == "ArgMax"; } bool IsArgMin(const NodeDef& node) { return node.op() == "ArgMin"; } @@ -423,10 +419,6 @@ bool IsRestore(const NodeDef& node) { node.op() == "RestoreSlice"); } -bool IsRetval(const NodeDef& node) { - return node.op() == "_Retval" || node.op() == "_DeviceRetval"; -} - bool IsReverse(const NodeDef& node) { return node.op() == "Reverse" || node.op() == "ReverseV2"; } diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h index fec635bccf2..12223453a74 100644 --- a/tensorflow/core/grappler/op_types.h +++ b/tensorflow/core/grappler/op_types.h @@ -33,7 +33,6 @@ bool IsAnyMaxPool(const NodeDef& node); bool IsAnyMin(const NodeDef& node); bool IsAnyMul(const NodeDef& node); bool IsApproximateEqual(const 
NodeDef& node); -bool IsArg(const NodeDef& node); bool IsArgMax(const NodeDef& node); bool IsArgMin(const NodeDef& node); bool IsAssert(const NodeDef& node); @@ -138,7 +137,6 @@ bool IsRelu6Grad(const NodeDef& node); bool IsReluGrad(const NodeDef& node); bool IsReshape(const NodeDef& node); bool IsRestore(const NodeDef& node); -bool IsRetval(const NodeDef& node); bool IsReverse(const NodeDef& node); bool IsReverseV2(const NodeDef& node); bool IsRsqrt(const NodeDef& node); diff --git a/tensorflow/core/grappler/optimizers/data/rebatch.cc b/tensorflow/core/grappler/optimizers/data/rebatch.cc index 3bb3d648d99..b10e30ed1c9 100644 --- a/tensorflow/core/grappler/optimizers/data/rebatch.cc +++ b/tensorflow/core/grappler/optimizers/data/rebatch.cc @@ -226,19 +226,11 @@ Status RecursivelyHandleOp(const NodeDef& node, int64 num_workers, // Replace optimized function with a new FunctionDef. TF_RETURN_IF_ERROR(flib->ReplaceFunction(func_name, optimized_func)); - } else { - VLOG(2) << "Failed to optimize dataset function. Error: " - << s.error_message(); } } else if (IsDatasetNodeOfType(node, kSourceDatasetOps)) { return errors::InvalidArgument( "Reached a source dataset: ", node.op(), " without encountering a batch transformation."); - } else if (IsRetval(node)) { - // _Retvals added to the function body graph in place of function outputs. 
- NodeDef* input_node = graph_utils::GetInputNode(node, *graph, 0); - TF_RETURN_IF_ERROR( - RecursivelyHandleOp(*input_node, num_workers, flib, graph)); } else { return errors::InvalidArgument("Encountered an unsupported op: ", node.op()); diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer.cc b/tensorflow/core/grappler/optimizers/dependency_optimizer.cc index 1788fb97913..81d5aef4b36 100644 --- a/tensorflow/core/grappler/optimizers/dependency_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/dependency_optimizer.cc @@ -76,7 +76,7 @@ bool DependencyOptimizer::SafeToRemoveIdentity(const NodeDef& node) const { return false; } for (const auto& consumer : node_map_->GetOutputs(node.name())) { - if (node.input_size() > 1 && (IsRetval(*consumer) || IsMerge(*consumer))) { + if (node.input_size() > 1 && IsMerge(*consumer)) { return false; } if (IsSwitch(*input)) { diff --git a/tensorflow/core/grappler/optimizers/function_optimizer.cc b/tensorflow/core/grappler/optimizers/function_optimizer.cc index 642c75819e9..06d30a50c94 100644 --- a/tensorflow/core/grappler/optimizers/function_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/function_optimizer.cc @@ -109,10 +109,6 @@ bool IsDirectFunctionCall(const FunctionDef& func, const NodeDef& func_node) { // Check if func_node has function attribute with a function name matching // FunctionDef signature. bool IsIndirectFunctionCall(const FunctionDef& func, const NodeDef& func_node) { - if (!IsPartitionedCall(func_node) && !IsStatefulPartitionedCall(func_node)) { - return false; - } - auto* func_attr = AttrSlice(func_node).Find(kFuncAttrName); return func_attr != nullptr && func_attr->has_func() && func_attr->func().name() == func.signature().name(); @@ -824,7 +820,10 @@ Status SpecializeFunction(const NodeDef& func_node, const FunctionDef& func, // update outputs for the fetch nodes, so we just skip them. 
std::vector> output_mapping; if (!signature.is_in_fetch_set) { - int num_func_outputs = item.output_size(); + int num_func_outputs = 0; + for (const auto& out_arg : item.outputs()) { + num_func_outputs += out_arg.output_nodes.size(); + } absl::flat_hash_set remove; for (int i = 0; i < num_func_outputs; ++i) { @@ -975,8 +974,10 @@ NodeDef InlinedFunctionInputsNode(const NodeDef& func_node, AttrValue::ListValue* type_list = (*inputs.mutable_attr())["T"].mutable_list(); - for (const InputArgInstantiation& input_arg : item.inputs()) { - type_list->add_type(input_arg.data_type); + for (const InputArgExpansion& input_arg : item.inputs()) { + for (int i = 0; i < input_arg.placeholders.size(); ++i) { + type_list->add_type(input_arg.data_type); + } } return inputs; @@ -995,11 +996,12 @@ NodeDef InlinedFunctionOutputsNode( AttrValue::ListValue* type_list = (*outputs.mutable_attr())["T"].mutable_list(); - for (const OutputArgInstantiation& output_arg : item.outputs()) { - const absl::string_view output_tensor = - output_tensors.at(output_arg.node_name); - type_list->add_type(output_arg.data_type); - outputs.add_input(strings::StrCat(func_node.name(), "/", output_tensor)); + for (const OutputArgExpansion& output_arg : item.outputs()) { + for (const string& output_node : output_arg.output_nodes) { + const absl::string_view output_tensor = output_tensors.at(output_node); + type_list->add_type(output_arg.data_type); + outputs.add_input(strings::StrCat(func_node.name(), "/", output_tensor)); + } } return outputs; @@ -1026,24 +1028,29 @@ Status InlineDirectFunctionCall(const NodeDef& func_node, ". Error: ", item_status.error_message()); } - // Mapping from input arg node name to function input position. - absl::flat_hash_map input_args_idx; - for (const InputArgInstantiation& input_arg : item.inputs()) { - const int idx = input_args_idx.size(); - input_args_idx[input_arg.node_name] = idx; + // Mapping from input placeholder name to function input position. 
+ absl::flat_hash_map input_placeholders_idx; + for (const InputArgExpansion& input_arg : item.inputs()) { + for (const string& placeholder : input_arg.placeholders) { + const int idx = input_placeholders_idx.size(); + input_placeholders_idx[placeholder] = idx; + } } - // Mapping from the '_Retval' node name to the output tensor. - absl::flat_hash_map output_tensors; - for (const NodeDef& func_body_node : item.function_body().node()) { - if (!IsRetval(func_body_node)) continue; - if (func_body_node.input_size() != 1) { - return errors::Internal("_Retval node must have single input: ", - SummarizeNodeDef(func_body_node)); + // Bypass identity nodes added to the graph in place of function outputs. + absl::flat_hash_set output_nodes; + for (const OutputArgExpansion& output_arg : item.outputs()) { + for (const string& output_node : output_arg.output_nodes) { + output_nodes.insert(output_node); } - output_tensors.emplace(func_body_node.name(), func_body_node.input(0)); } + // For each function output value we added an identity node that reads the + // tensor from one of the function body nodes. When we inline function into + // the main graph we want to bypass these nodes, so we keep a mapping from + // 'output node name' -> 'output tensor name'. + absl::flat_hash_map output_tensors; + // Hook inlined function inputs to IdentityN node. NodeDef* func_inputs = optimized_graph->add_node(); *func_inputs = InlinedFunctionInputsNode(func_node, item); @@ -1051,18 +1058,22 @@ Status InlineDirectFunctionCall(const NodeDef& func_node, for (NodeDef& func_body_node : *item.mutable_function_body().mutable_node()) { const string& node_name = func_body_node.name(); - // Skip function output nodes. - if (IsRetval(func_body_node)) continue; + // Skip output identity node, and update a mapping to the output tensor. 
+ if (IsIdentity(func_body_node) && output_nodes.count(node_name)) { + output_tensors.emplace(node_name, func_body_node.input(0)); + continue; + } - // Turn _Arg nodes added in place of input arguments into identity nodes. - const auto input_arg_idx = input_args_idx.find(node_name); - if (input_arg_idx != input_args_idx.end()) { + // Turn placeholders added in place of input arguments into identity nodes. + const auto input_placeholder_idx = input_placeholders_idx.find(node_name); + if (input_placeholder_idx != input_placeholders_idx.end()) { CHECK_EQ(0, func_body_node.input_size()); func_body_node.set_op("Identity"); - func_body_node.mutable_attr()->erase("index"); + (*func_body_node.mutable_attr())["T"] = func_body_node.attr().at("dtype"); + func_body_node.mutable_attr()->erase("dtype"); func_body_node.mutable_attr()->erase("shape"); - func_body_node.add_input( - strings::StrCat(func_inputs->name(), ":", input_arg_idx->second)); + func_body_node.add_input(strings::StrCat(func_inputs->name(), ":", + input_placeholder_idx->second)); } else { // Update the input names if any. for (string& input : *func_body_node.mutable_input()) { @@ -1335,8 +1346,10 @@ Status MaybeDeadOutputs(const FunctionOptimizerContext& ctx, // Names of the function body nodes that return function output values. absl::flat_hash_set output_nodes; - for (const auto& output_arg : item.outputs()) { - output_nodes.insert(output_arg.node_name); + for (const auto& output_expansion : item.outputs()) { + for (const auto& output_node : output_expansion.output_nodes) { + output_nodes.insert(output_node); + } } GraphTopologyView topology_view; @@ -1417,10 +1430,7 @@ Status CheckThatSideEffectsWillExecute( // can't produce any visible side-effects. const bool read_only = IsReadVariableOp(func_body_node); - // _Retval marked as stateful, but we will remove it before inlining. 
- const bool retval = IsRetval(func_body_node); - - if (read_only || retval || !node_must_execute) continue; + if (read_only || !node_must_execute) continue; VLOG(3) << "Check that node " << func_body_node.name() << " will execute after inlining."; @@ -1460,7 +1470,7 @@ Status CheckThatSideEffectsWillExecute( Status PlaceInlinedFunctionBody( const NodeDef& func_node, const GrapplerFunctionItem& item, - const absl::flat_hash_map& input_args_idx, + const absl::flat_hash_map& input_placeholders_idx, FunctionOptimizerContext* ctx, GraphDef* placed_graph_def) { // Control flow lowering and Placer works with a Graph object. std::unique_ptr func_body_graph = @@ -1488,14 +1498,15 @@ Status PlaceInlinedFunctionBody( TF_RETURN_IF_ERROR(pass.Run(opt_options)); // ------------------------------------------------------------------------ // - // Before placing the function body nodes we pin input arguments to the + // Before placing the function body nodes we pin input placeholders to the // same device as their corresponding input nodes. for (Node* func_body_node : func_body_graph->nodes()) { - const auto input_arg_idx = input_args_idx.find(func_body_node->name()); + const auto input_placeholder_idx = + input_placeholders_idx.find(func_body_node->name()); - if (input_arg_idx != input_args_idx.end()) { - const int input_idx = input_arg_idx->second; + if (input_placeholder_idx != input_placeholders_idx.end()) { + const int input_idx = input_placeholder_idx->second; const GraphView::OutputPort output_port = ctx->graph_view().GetRegularFanin({&func_node, input_idx}); @@ -1620,26 +1631,45 @@ Status InlineIndirectFunctionCall(const NodeDef& func_node, inputs.push_back(tensor_id); } - // Mapping from input argument node to function input position. 
- absl::flat_hash_map input_args_idx; - for (const InputArgInstantiation& input_arg : item.inputs()) { - const int idx = input_args_idx.size(); - input_args_idx[input_arg.node_name] = idx; + // Mapping from input placeholder name to function input position. + absl::flat_hash_map input_placeholders_idx; + for (const InputArgExpansion& input_arg : item.inputs()) { + for (const string& placeholder : input_arg.placeholders) { + const int idx = input_placeholders_idx.size(); + input_placeholders_idx[placeholder] = idx; + } } const string prefix = strings::StrCat(func_node.name(), "/"); // ------------------------------------------------------------------------ // - // Mapping from the '_Retval' node name to the output tensor. - absl::flat_hash_map output_tensors; + // For each function output value we added an identity node that reads the + // tensor from one of the function body nodes. When we inline function into + // the main graph we want to bypass these nodes, so we keep a mapping from + // 'output node name' -> 'output tensor name'. + absl::flat_hash_map output_tensors; - for (const NodeDef& func_body_node : item.function_body().node()) { - if (!IsRetval(func_body_node)) continue; - if (func_body_node.input_size() != 1) { - return errors::Internal("_Retval node must have single input: ", - SummarizeNodeDef(func_body_node)); + // Unique names of nodes producing tensors in `output_tensors`. + absl::flat_hash_set output_tensors_nodes; + + // Identity nodes added to the function body in place of function outputs. 
+ absl::flat_hash_set output_nodes; + for (const OutputArgExpansion& output_arg : item.outputs()) { + for (const string& output_node : output_arg.output_nodes) { + output_nodes.insert(output_node); + } + } + + for (const NodeDef& func_body_node : item.graph.node()) { + const string& node_name = func_body_node.name(); + + if (IsIdentity(func_body_node) && output_nodes.count(node_name)) { + const string& output_tensor = func_body_node.input(0); + output_tensors.emplace(node_name, output_tensor); + + SafeTensorId tensor_id = ParseTensorName(output_tensor); + output_tensors_nodes.insert(tensor_id.node()); } - output_tensors.emplace(func_body_node.name(), func_body_node.input(0)); } // ------------------------------------------------------------------------ // @@ -1713,8 +1743,8 @@ Status InlineIndirectFunctionCall(const NodeDef& func_node, // make sure that after inlining all nodes will have valid device assignment. GraphDef placed_graph_def; - TF_RETURN_IF_ERROR(PlaceInlinedFunctionBody(func_node, item, input_args_idx, - ctx, &placed_graph_def)); + TF_RETURN_IF_ERROR(PlaceInlinedFunctionBody( + func_node, item, input_placeholders_idx, ctx, &placed_graph_def)); // ------------------------------------------------------------------------ // // After all nodes placed we need to prepare them for inlining into the @@ -1728,14 +1758,15 @@ Status InlineIndirectFunctionCall(const NodeDef& func_node, for (NodeDef& func_body_node : *placed_graph_def.mutable_node()) { const string& node_name = func_body_node.name(); - // Turn _Arg nodes added in place of input arguments into identity nodes. - const auto input_arg_idx = input_args_idx.find(node_name); - if (input_arg_idx != input_args_idx.end()) { + // Turn placeholders added in place of input arguments into identity nodes. 
+ const auto input_placeholder_idx = input_placeholders_idx.find(node_name); + if (input_placeholder_idx != input_placeholders_idx.end()) { DCHECK_EQ(0, func_body_node.input_size()); func_body_node.set_op("Identity"); - func_body_node.mutable_attr()->erase("index"); + (*func_body_node.mutable_attr())["T"] = func_body_node.attr().at("dtype"); + func_body_node.mutable_attr()->erase("dtype"); func_body_node.mutable_attr()->erase("shape"); - const int input_idx = input_arg_idx->second; + const int input_idx = input_placeholder_idx->second; func_body_node.add_input(inputs[input_idx].ToString()); // Add a control dependency on 'inputs_ready' node, to guarantee that all @@ -1788,7 +1819,17 @@ Status InlineIndirectFunctionCall(const NodeDef& func_node, // ------------------------------------------------------------------------ // // Check that after inlining all side-effects will be executed in well defined // order. We do it by checking if there is a path from stateful/dataset ops to - // one of the control output nodes. + // one of the output nodes. + + // Because we rename all the nodes before inlining, we need a copy of + // output_nodes with a new names. + absl::flat_hash_set inlined_output_nodes; + for (const string& output_node : output_nodes) { + inlined_output_nodes.insert(inlined_node_name(output_node)); + } + const auto is_inlined_output_node = [&](const NodeDef& node) -> bool { + return inlined_output_nodes.find(node.name()) != inlined_output_nodes.end(); + }; // Names of the inlined control output nodes. absl::flat_hash_set inlined_control_output_nodes; @@ -1844,8 +1885,10 @@ Status InlineIndirectFunctionCall(const NodeDef& func_node, } for (NodeDef& func_body_node : *placed_graph_def.mutable_node()) { - // We bypass _Retval nodes and fetch tensors from `retval.input(0)`. - if (IsRetval(func_body_node)) continue; + // Skip output identity nodes. 
+ if (IsIdentity(func_body_node) && is_inlined_output_node(func_body_node)) + continue; + optimized_graph->add_node()->Swap(&func_body_node); } @@ -1853,17 +1896,19 @@ Status InlineIndirectFunctionCall(const NodeDef& func_node, // not copy the original function call node, so we have to setup tensor // mapping from old output tensors, to the outputs of inlined nodes. int output_idx = 0; - for (const OutputArgInstantiation& output : item.outputs()) { - const string& output_tensor = output_tensors.at(output.node_name); + for (const OutputArgExpansion& output : item.outputs()) { + for (const string& output_node : output.output_nodes) { + const string& output_tensor = output_tensors.at(output_node); - const SafeTensorId from_tensor(func_node.name(), output_idx++); - const SafeTensorId to_tensor = ParseTensorName(output_tensor); + const SafeTensorId from_tensor(func_node.name(), output_idx++); + const SafeTensorId to_tensor = ParseTensorName(output_tensor); - const SafeTensorId inlined_to_tensor = - SafeTensorId(absl::StrCat(func_node.name(), "/", to_tensor.node()), - to_tensor.index()); + const SafeTensorId inlined_to_tensor = + SafeTensorId(absl::StrCat(func_node.name(), "/", to_tensor.node()), + to_tensor.index()); - ctx->AddTensorMapping(from_tensor, inlined_to_tensor); + ctx->AddTensorMapping(from_tensor, inlined_to_tensor); + } } // If function call node was in keep_ops set, it means that we need to keep a diff --git a/tensorflow/core/grappler/optimizers/function_optimizer_test.cc b/tensorflow/core/grappler/optimizers/function_optimizer_test.cc index 04e9137bef3..83f9468e3f3 100644 --- a/tensorflow/core/grappler/optimizers/function_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/function_optimizer_test.cc @@ -123,7 +123,7 @@ TEST_F(FunctionOptimizerTest, InlineFunction_SkipErrorsIfGraphNotModified) { // Standard XTimesTwo() function. 
FunctionDef x_times_two = test::function::XTimesTwo(); - // Function signature has non-type attribute (currently not supported). + // Function with sequence of tensors as an input (currently not supported). FunctionDef my_identity_n = FunctionDefHelper::Create( // Name "MyIdentityN", diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc index dca7b603590..0970134ed2b 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc @@ -367,8 +367,8 @@ TEST_F(MetaOptimizerTest, OptimizeFunctionLibrary) { if (node.name() == "my_mul/inlined_inputs" && ++count) { EXPECT_EQ("IdentityN", node.op()); EXPECT_EQ(2, node.input_size()); - EXPECT_EQ("x", node.input(0)); - EXPECT_EQ("x", node.input(1)); + EXPECT_EQ("x:0", node.input(0)); + EXPECT_EQ("x:0", node.input(1)); } else if (node.name() == "my_mul/x" && ++count) { EXPECT_EQ("Identity", node.op()); EXPECT_EQ(1, node.input_size()); @@ -623,17 +623,17 @@ TEST_F(MetaOptimizerTest, OptimizeFunctionLibraryWithRestrictions) { MetaOptimizer optimizer(nullptr, config_proto); // Define simple function library with two identical mul functions. 
- FunctionDef mul_func_1 = FunctionDefHelper::Create( - "MyMul1", {"x:float", "y:float"}, {"z:float"}, {}, - {{{"mul"}, "Mul", {"x", "y"}, {{"T", DT_FLOAT}}}}, - /*ret_def=*/ - {{"z", "mul:z:0"}}); + FunctionDef mul_func_1 = + FunctionDefHelper::Create("MyMul1", {"x:float", "y:float"}, {"z:float"}, + {}, {{{"mul"}, "Mul", {"x", "y"}, {}}}, + /*ret_def=*/ + {{"z", "mul:z:0"}}); - FunctionDef mul_func_2 = FunctionDefHelper::Create( - "MyMul2", {"x:float", "y:float"}, {"z:float"}, {}, - {{{"mul"}, "Mul", {"x", "y"}, {{"T", DT_FLOAT}}}}, - /*ret_def=*/ - {{"z", "mul:z:0"}}); + FunctionDef mul_func_2 = + FunctionDefHelper::Create("MyMul2", {"x:float", "y:float"}, {"z:float"}, + {}, {{{"mul"}, "Mul", {"x", "y"}, {}}}, + /*ret_def=*/ + {{"z", "mul:z:0"}}); // Tensorflow graph: // diff --git a/tensorflow/core/grappler/utils/BUILD b/tensorflow/core/grappler/utils/BUILD index b24f27dde77..d565fe6db3f 100644 --- a/tensorflow/core/grappler/utils/BUILD +++ b/tensorflow/core/grappler/utils/BUILD @@ -172,7 +172,6 @@ cc_library( hdrs = ["functions.h"], visibility = ["//visibility:public"], deps = [ - "//tensorflow/core:core_cpu_base", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", diff --git a/tensorflow/core/grappler/utils/functions.cc b/tensorflow/core/grappler/utils/functions.cc index 3061a9780c7..2ec9794b68a 100644 --- a/tensorflow/core/grappler/utils/functions.cc +++ b/tensorflow/core/grappler/utils/functions.cc @@ -17,9 +17,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" -#include "absl/strings/str_replace.h" #include "absl/strings/substitute.h" -#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function.pb.h" @@ -36,29 +34,306 @@ limitations under the License. 
namespace tensorflow { namespace grappler { +namespace { + +Status RegisterFunctionBodyOutputs(const OpRegistrationData& registration, + const NodeDef& node, + GrapplerFunctionConnectivity* connectivity) { + tensorflow::NameRangeMap outputs_range_map; + TF_RETURN_IF_ERROR(tensorflow::NameRangesForNode( + node, registration.op_def, nullptr, &outputs_range_map)); + connectivity->RegisterFunctionBodyOutputs(node.name(), + std::move(outputs_range_map)); + return Status::OK(); +} + +Status RegisterFunctionBodyOutputs(const FunctionLibraryDefinition& flib, + const NodeDef& node, + GrapplerFunctionConnectivity* connectivity) { + const OpRegistrationData* registration; + TF_RETURN_IF_ERROR(flib.LookUp(node.op(), ®istration)); + return RegisterFunctionBodyOutputs(*registration, node, connectivity); +} + +// Replace the placeholder attribute values with the values specified in +// instantiation attributes. +Status ResolveFunctionBodyNodeAttrPlaceholders( + const AttrSlice& func_instantiation_attr, NodeDef* node) { + for (auto& attr : *node->mutable_attr()) { + const string& placeholder = attr.second.placeholder(); + if (placeholder.empty()) continue; + + const AttrValue* attr_value = func_instantiation_attr.Find(placeholder); + if (attr_value) { + attr.second = *attr_value; + } else { + return errors::InvalidArgument("Can't resolve placeholder: ", + placeholder); + } + } + return Status::OK(); +} + +} // namespace + +void GrapplerFunctionConnectivity::RegisterInputArgExpansion( + InputArgExpansion input_arg_expansion) { + string input_name = input_arg_expansion.input_name; + const auto& placeholders = input_arg_expansion.placeholders; + + for (int i = 0; i < placeholders.size(); ++i) { + const string& placeholder = input_arg_expansion.placeholders[i]; + input_arg_placeholders_.insert( + {placeholder, InputArgPlaceholder{input_name, /*input_index=*/i}}); + } + input_arg_expansions_.insert( + {std::move(input_name), std::move(input_arg_expansion)}); +} + +void 
GrapplerFunctionConnectivity::RegisterFunctionBodyOutputs( + const string& node_name, tensorflow::NameRangeMap&& outputs) { + function_body_outputs_[node_name] = std::move(outputs); +} + +Status GrapplerFunctionConnectivity::ExpandFunctionDefInput( + const string& func_def_input, std::vector* graph_def_inputs) const { + using ::tensorflow::strings::Scanner; + + if (IsControlInput(func_def_input)) { + graph_def_inputs->push_back(func_def_input); + return Status::OK(); + } + + // Parse input format: "node_name[:node_output][:position]" + string node_name; + string node_output; + int position = -1; + + StringPiece capture; + StringPiece remaining; + + // Parse "node_name" + if (Scanner(func_def_input) + .One(strings::Scanner::LETTER_DIGIT_DOT_UNDERSCORE) + .Any(strings::Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE) + .GetResult(&remaining, &capture)) { + node_name = string(capture.data(), capture.size()); + } + + // Parse "node_output" if it exists + if (Scanner(remaining) + .OneLiteral(":") + .RestartCapture() + .One(strings::Scanner::LETTER) + .Any(strings::Scanner::LETTER_DIGIT_UNDERSCORE) + .GetResult(&remaining, &capture)) { + node_output = string(capture.data(), capture.size()); + } + + // Parse "position" if it exists + if (Scanner(remaining) + .OneLiteral(":") + .RestartCapture() + .Many(strings::Scanner::DIGIT) + .GetResult(nullptr, &capture)) { + CHECK(strings::safe_strto32(capture, &position)); + } + + // If "node_output" is not empty, it must be an output of a function body node + bool is_function_body_output = !node_output.empty(); + + // Function input argument: "node_name[:position]" + if (!is_function_body_output) { + auto input_arg = input_arg_expansions_.find(node_name); + if (input_arg != input_arg_expansions_.end()) { + const InputArgExpansion& input_arg_expansion = input_arg->second; + const auto& placeholders = input_arg_expansion.placeholders; + + if (position == -1) { + // If position is not defined use all placeholders + 
graph_def_inputs->reserve(placeholders.size()); + for (const string& placeholder : placeholders) { + graph_def_inputs->push_back(placeholder); + } + } else { + if (position > input_arg_expansion.placeholders.size() - 1) { + return errors::InvalidArgument("Invalid input ", node_name, + "position: ", position, + " (out of range)"); + } + graph_def_inputs->push_back(input_arg_expansion.placeholders[position]); + } + + return Status::OK(); + } + } + + // Function body output: "node_name:node_output[:position]" + if (is_function_body_output) { + auto function_body_outputs = function_body_outputs_.find(node_name); + if (function_body_outputs != function_body_outputs_.end()) { + const tensorflow::NameRangeMap& outputs = function_body_outputs->second; + auto output = outputs.find(node_output); + if (output != outputs.end()) { + const auto& output_range = output->second; + + if (position == -1) { + graph_def_inputs->reserve(graph_def_inputs->size() + + output_range.second - output_range.first); + // If position is not defined expand node output range + for (int i = output_range.first; i < output_range.second; ++i) { + graph_def_inputs->push_back( + i == 0 ? node_name : absl::StrCat(node_name, ":", i)); + } + } else { + if (position > (output_range.second - output_range.first)) { + return errors::InvalidArgument( + "Invalid node ", node_name, " output ", node_output, + " position: ", position, " (out of range)"); + } + int pos = output_range.first + position; + graph_def_inputs->push_back( + pos == 0 ? 
node_name : absl::StrCat(node_name, ":", pos)); + } + + return Status::OK(); + } + } + } + + return errors::InvalidArgument("Failed to expand a function def input: ", + func_def_input); +} + +Status GrapplerFunctionConnectivity::ExpandNodeInputs( + NodeDef* function_body_node) const { + std::vector expanded_inputs; + + for (const string& function_def_input : function_body_node->input()) { + TF_RETURN_IF_ERROR( + ExpandFunctionDefInput(function_def_input, &expanded_inputs)); + } + + function_body_node->clear_input(); + for (string& expanded_input : expanded_inputs) + function_body_node->add_input(std::move(expanded_input)); + return Status::OK(); +} + +Status GrapplerFunctionConnectivity::AsFunctionDefInput( + const string& graph_def_input, string* func_def_input) const { + if (IsControlInput(graph_def_input)) { + *func_def_input = graph_def_input; + return Status::OK(); + } + + const TensorId tensor = ParseTensorName(graph_def_input); + DCHECK_GE(tensor.index(), 0); + + const absl::string_view node_name = tensor.node(); + const int index = tensor.index(); + + // Check if it's an input arg placeholder + if (tensor.index() == 0) { + const auto is_input_placeholder = input_arg_placeholders_.find(node_name); + if (is_input_placeholder != input_arg_placeholders_.end()) { + const InputArgPlaceholder& placeholder = is_input_placeholder->second; + *func_def_input = + absl::StrCat(placeholder.input_name, ":", placeholder.input_index); + return Status::OK(); + } + } + + // It must be output from one of the function body nodes + const auto is_body_output = function_body_outputs_.find(tensor.node()); + if (is_body_output != function_body_outputs_.end()) { + const tensorflow::NameRangeMap& outputs_range_map = is_body_output->second; + + for (const auto& el : outputs_range_map) { + const auto& output_name = el.first; + const auto& output_range = el.second; + if (index >= output_range.first && index < output_range.second) { + int pos = index - output_range.first; + 
*func_def_input = absl::StrCat(node_name, ":", output_name, ":", pos); + return Status::OK(); + } + } + } + + return errors::InvalidArgument("Unknown graph def input: ", graph_def_input); +} + +Status GrapplerFunctionConnectivity::AsFunctionDefNode( + NodeDef* function_body_node) const { + string func_def_input; + + for (int i = 0; i < function_body_node->input_size(); ++i) { + TF_RETURN_IF_ERROR( + AsFunctionDefInput(function_body_node->input(i), &func_def_input)); + function_body_node->set_input(i, func_def_input); + } + + return Status::OK(); +} + +Status GrapplerFunctionItemInstantiation::GetTypeAttr( + const string& type_attr_name, DataType* data_type) const { + const AttrValue* type_attr = func_instantiation_attr_.Find(type_attr_name); + if (type_attr == nullptr) { + return errors::InvalidArgument("Type attribute ", type_attr_name, + " is not defined"); + } else if (type_attr->type() == DT_INVALID) { + return errors::InvalidArgument("Type attribute ", type_attr_name, + " is not defined with a valid type"); + } else { + *data_type = type_attr->type(); + } + return Status::OK(); +} + +Status GrapplerFunctionItemInstantiation::GetArgType( + const OpDef::ArgDef& arg, DataType* data_type) const { + if (arg.type() != DT_INVALID) { + *data_type = arg.type(); + } else { + if (!arg.type_list_attr().empty() || !arg.number_attr().empty()) { + return errors::InvalidArgument( + "Arguments with sequence of tensors are not supported. 
Unsupported " + "argument name: ", + arg.name()); + } + TF_RETURN_IF_ERROR(GetTypeAttr(arg.type_attr(), data_type)); + } + return Status::OK(); +} + GrapplerFunctionItem::GrapplerFunctionItem( string func_name, string description, AttrSlice func_attr, - std::vector input_args, - std::vector output_args, + std::vector input_arg_expansions, + std::vector output_arg_expansions, std::vector control_outputs, const int graph_def_version, const bool is_stateful, GraphDef&& function_body) : description_(std::move(description)), func_attr_(func_attr), - input_args_(std::move(input_args)), - output_args_(std::move(output_args)), + input_arg_expansions_(std::move(input_arg_expansions)), + output_arg_expansions_(std::move(output_arg_expansions)), control_outputs_(std::move(control_outputs)), is_stateful_(is_stateful) { id = std::move(func_name); graph = std::move(function_body); - graph.mutable_versions()->set_producer(graph_def_version); - // Fill the feed nodes with function input arguments. - for (const InputArgInstantiation& input_arg : input_args_) { - feed.push_back({input_arg.node_name, Tensor()}); + graph.mutable_versions()->set_producer(graph_def_version); + // Fill the feed nodes with input placeholders. + for (const InputArgExpansion& input_arg : input_arg_expansions_) { + for (const string& placeholder : input_arg.placeholders) { + feed.push_back({placeholder, Tensor()}); + } } // Fill the fetch nodes with outputs. - for (const OutputArgInstantiation& output_arg : output_args_) { - fetch.push_back(output_arg.node_name); + for (const OutputArgExpansion& output_arg : output_arg_expansions_) { + for (const string& output_node : output_arg.output_nodes) { + fetch.push_back(output_node); + } } // We must keep all control output nodes. 
for (const ControlOutput& control_output : control_outputs_) { @@ -72,29 +347,28 @@ GrapplerFunctionItem::GrapplerFunctionItem( const string& GrapplerFunctionItem::description() const { return description_; } -const std::vector& GrapplerFunctionItem::inputs() const { - return input_args_; +const std::vector& GrapplerFunctionItem::inputs() const { + return input_arg_expansions_; } -const InputArgInstantiation& GrapplerFunctionItem::input(int i) const { - return input_args_[i]; +const InputArgExpansion& GrapplerFunctionItem::input(int i) const { + return input_arg_expansions_[i]; } const std::size_t GrapplerFunctionItem::input_size() const { - return input_args_.size(); + return input_arg_expansions_.size(); } -const std::vector& GrapplerFunctionItem::outputs() - const { - return output_args_; +const std::vector& GrapplerFunctionItem::outputs() const { + return output_arg_expansions_; } -const OutputArgInstantiation& GrapplerFunctionItem::output(int i) const { - return output_args_[i]; +const OutputArgExpansion& GrapplerFunctionItem::output(int i) const { + return output_arg_expansions_[i]; } const std::size_t GrapplerFunctionItem::output_size() const { - return output_args_.size(); + return output_arg_expansions_.size(); } const std::vector& GrapplerFunctionItem::control_outputs() @@ -153,23 +427,15 @@ Status InstantiationTypeParameters( return errors::InvalidArgument("Type parameters output map must be empty"); } - const auto resolve_type_attr = [&](const OpDef::ArgDef& arg) -> Status { - if (!arg.type_attr().empty()) { - DataType dtype; - TF_RETURN_IF_ERROR( - GetNodeAttr(func_instantiation_attr, arg.type_attr(), &dtype)); - type_parameters->emplace(arg.type_attr(), dtype); + GrapplerFunctionItemInstantiation instantiation(func_instantiation_attr); - } else if (!arg.type_list_attr().empty()) { - std::vector dtypes; - TF_RETURN_IF_ERROR( - GetNodeAttr(func_instantiation_attr, arg.type_list_attr(), &dtypes)); - int index = 0; - for (const DataType& dtype : dtypes) { 
- type_parameters->emplace(absl::StrCat(arg.type_list_attr(), ":", index), - dtype); - ++index; - } + const auto resolve_type_attr = [&](const OpDef::ArgDef& arg) { + // Check if it's unknown and unresolved type. + if (arg.type() == DT_INVALID && + type_parameters->find(arg.type_attr()) == type_parameters->end()) { + DataType data_type; + TF_RETURN_IF_ERROR(instantiation.GetArgType(arg, &data_type)); + type_parameters->insert({arg.type_attr(), data_type}); } return Status::OK(); }; @@ -193,7 +459,8 @@ Status InstantiationBodyParameters( for (auto& attr : func_body_node.attr()) { const string& placeholder = attr.second.placeholder(); - if (placeholder.empty() || body_parameters->contains(placeholder)) { + if (placeholder.empty() || + body_parameters->find(placeholder) != body_parameters->end()) { continue; } @@ -231,13 +498,15 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func, } } - // Instantiate function into a statically defined FunctionBody Graph. - std::unique_ptr fbody; - TF_RETURN_IF_ERROR( - FunctionDefToBodyHelper(func, func_instantiation_attr, &flib, &fbody)); + // Helper methods to lookup function instantiation attributes + GrapplerFunctionItemInstantiation instantiation(func_instantiation_attr); + // Mapping from FunctionDef input format (name[:output][:position]) to + // GraphDef input format (name[:position]) + GrapplerFunctionConnectivity connectivity; + + // Instantiate function body into a statically defined graph def. GraphDef function_body; - fbody->graph->ToGraphDef(&function_body); // Function body shares the library with the graph that instantiated it. We do // not need a full copy of the function library, just the reachable subset. 
@@ -249,25 +518,122 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func, flib.num_functions() - function_body.library().function_size(), signature.name(), function_body.library().function_size()); - const int num_instantiated_inputs = fbody->arg_types.size(); - const int num_instantiated_outputs = fbody->ret_types.size(); + // TODO(ezhulenev): support functions with tensor sequence inputs/outputs - std::vector inputs; - inputs.reserve(num_instantiated_inputs); - - for (int in_id = 0; in_id < num_instantiated_inputs; ++in_id) { - const Node* node = fbody->arg_nodes[in_id]; - const DataType& dtype = fbody->arg_types[in_id]; - inputs.emplace_back(node->name(), dtype); + // Make sure that there are no tensor lists in inputs or outputs. + for (const OpDef::ArgDef& input : signature.input_arg()) { + if (!input.type_list_attr().empty() || !input.number_attr().empty()) { + return errors::InvalidArgument( + "Inputs with lists of tensors are not supported. Input: ", + input.name()); + } + } + for (const OpDef::ArgDef& output : signature.output_arg()) { + if (!output.type_list_attr().empty() || !output.number_attr().empty()) { + return errors::InvalidArgument( + "Outputs with lists of tensors are not supported. Output: ", + output.name()); + } } - std::vector outputs; - outputs.reserve(num_instantiated_outputs); + std::vector inputs; + inputs.reserve(signature.input_arg_size()); - for (int out_id = 0; out_id < num_instantiated_outputs; ++out_id) { - const Node* node = fbody->ret_nodes[out_id]; - const DataType& dtype = fbody->ret_types[out_id]; - outputs.emplace_back(node->name(), dtype); + // For each input argument create a placeholder in function body. 
+ for (const OpDef::ArgDef& input : signature.input_arg()) { + DataType input_data_type; + TF_RETURN_IF_ERROR(instantiation.GetArgType(input, &input_data_type)); + + NodeDef* placeholder = function_body.add_node(); + placeholder->set_name(input.name()); + placeholder->set_op("Placeholder"); + (*placeholder->mutable_attr())["dtype"].set_type(input_data_type); + (*placeholder->mutable_attr())["shape"].mutable_shape()->set_unknown_rank( + true); + + InputArgExpansion input_expansion{/*input_name=*/input.name(), + /*data_type=*/input_data_type, + /*is_ref=*/input.is_ref(), + /*placeholders=*/{input.name()}}; + connectivity.RegisterInputArgExpansion(input_expansion); + inputs.push_back(std::move(input_expansion)); + } + + // Keep names of all nodes in the function body to guarantee that we do not + // add an identity with a duplicate name. + absl::flat_hash_set func_body_nodes; + + // Generate unique output node name: "${out_arg_name}_output_node_${index}". + const auto output_node_name = [&func_body_nodes](const OpDef::ArgDef& out, + int index) -> string { + string name = absl::StrCat(out.name(), "_output_node_", index); + int i = 1; + while (func_body_nodes.find(name) != func_body_nodes.end()) { + name = absl::StrCat(out.name(), "_output_node_", index, "_", i++); + } + return name; + }; + + // Add all function nodes to the function body. + for (const NodeDef& func_def_node : func.node_def()) { + func_body_nodes.insert(func_def_node.name()); + + NodeDef* new_node = function_body.add_node(); + *new_node = func_def_node; + + const OpRegistrationData* registration; + TF_RETURN_IF_ERROR(flib.LookUp(func_def_node.op(), ®istration)); + + // Resolve all placeholder values using function instantiation attributes. + TF_RETURN_IF_ERROR(ResolveFunctionBodyNodeAttrPlaceholders( + func_instantiation_attr, new_node)); + + // Register node output range in a function connectivity. 
+ TF_RETURN_IF_ERROR(RegisterFunctionBodyOutputs(*registration, func_def_node, + &connectivity)); + } + + // Rewrite inputs to use GraphDef format + for (NodeDef& node : *function_body.mutable_node()) { + TF_RETURN_IF_ERROR(connectivity.ExpandNodeInputs(&node)); + } + + std::vector outputs; + outputs.reserve(signature.output_arg_size()); + + // For each function output argument we create an Identity node in the + // function body, that reads output tensor from the function body node. + for (const OpDef::ArgDef& out : signature.output_arg()) { + DataType output_data_type; + TF_RETURN_IF_ERROR(instantiation.GetArgType(out, &output_data_type)); + + std::vector output_tensors; + auto ret = func.ret().find(out.name()); + TF_RETURN_IF_ERROR( + ret != func.ret().end() + // Expand outputs using provided output mapping + ? connectivity.ExpandFunctionDefInput(ret->second, &output_tensors) + // Otherwise output must be one of the function inputs + : connectivity.ExpandFunctionDefInput(out.name(), &output_tensors)); + + absl::InlinedVector output_nodes; + for (int i = 0; i < output_tensors.size(); ++i) { + const string& output_tensor = output_tensors[i]; + + NodeDef* identity = function_body.add_node(); + identity->set_name(output_node_name(out, i)); + identity->set_op("Identity"); + (*identity->mutable_attr())["T"].set_type(output_data_type); + identity->add_input(output_tensor); + + output_nodes.push_back(identity->name()); + } + + OutputArgExpansion output{/*output_name=*/out.name(), + /*data_type=*/output_data_type, + /*is_ref=*/out.is_ref(), + /*output_nodes=*/std::move(output_nodes)}; + outputs.push_back(std::move(output)); } // Control outputs ensure that all side-effectful nodes in the function body @@ -295,42 +661,70 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func, item); } +// Register GrapplerFunctionItem input arg expansion and function body outputs +// in the GrapplerFunctionConnectivity. 
+Status RegisterGrapplerFunctionConnectivity( + const GrapplerFunctionItem& item, const FunctionLibraryDefinition& flib, + GrapplerFunctionConnectivity* connectivity) { + for (const InputArgExpansion& input : item.inputs()) { + connectivity->RegisterInputArgExpansion(input); + } + for (const NodeDef& func_body_node : item.function_body().node()) { + TF_RETURN_IF_ERROR( + RegisterFunctionBodyOutputs(flib, func_body_node, connectivity)); + } + return Status::OK(); +} + Status ReplaceInputWithConst(const NodeDef& input_const, int input_index, GrapplerFunctionItem* item) { if (!IsConstant(input_const)) { - return errors::InvalidArgument("Input node is not a constant: ", - SummarizeNodeDef(input_const)); - } - if (input_index < 0 || input_index >= item->input_size()) { - return errors::InvalidArgument( - "Function input index is out of bound: index=", input_index, - " input_size=", item->input_size()); + return errors::InvalidArgument("Input node ", input_const.name(), + " is not a constant"); } - const InputArgInstantiation& input_arg = item->input(input_index); + auto& inputs = item->input_arg_expansions_; + // Find input arg expansion and input placeholder position in it for the + // given function input position. + InputArgExpansion* input_arg_expansion = nullptr; + int placeholder_idx = input_index; + + for (InputArgExpansion& input : inputs) { + if (placeholder_idx < input.placeholders.size()) { + input_arg_expansion = &input; + break; + } + placeholder_idx -= input.placeholders.size(); + } + + if (input_arg_expansion == nullptr) { + return errors::InvalidArgument("Input placeholder not found: input_index=", + input_index, " function=", item->id); + } + + // Delete placeholder from input expansion. + string placeholder_name = input_arg_expansion->placeholders[placeholder_idx]; + input_arg_expansion->placeholders.erase( + input_arg_expansion->placeholders.begin() + placeholder_idx); + + // Delete empty input expansions. 
+ inputs.erase(std::remove_if(inputs.begin(), inputs.end(), + [](const InputArgExpansion& input) { + return input.placeholders.empty(); + }), + inputs.end()); + + // Replace placeholder node in the function body with a const node. for (NodeDef& node : *item->graph.mutable_node()) { - // Replace '_Arg' node in the function body with a 'Const' node. - if (node.name() == input_arg.node_name) { + if (node.name() == placeholder_name) { node = input_const; - node.set_name(input_arg.node_name); - node.clear_input(); + node.set_name(placeholder_name); + node.clear_input(); // remove potential control inputs node.clear_device(); // device placement is defined by instantiating node } - - // Update index in all inputs after the removed const input. - if (IsArg(node)) { - auto attrs = AttrSlice(node); - int index; - TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "index", &index)); - if (index >= input_index) { - (*node.mutable_attr())["index"].set_i(index - 1); - } - } } - item->input_args_.erase(item->input_args_.begin() + input_index); - return Status::OK(); } @@ -339,24 +733,31 @@ Status RemoveFunctionOutputs(const absl::flat_hash_set& remove_outputs, std::vector>* output_mapping) { DCHECK(output_mapping->empty()); + // Code below assumes that we do not support tensor list outputs and there is + // a 1-to-1 mapping between output tensor and output argument expansion. + for (const OutputArgExpansion& out_arg : item->outputs()) { + DCHECK(out_arg.output_nodes.size() == 1) + << "Output arg expansion must have single output"; + } + // Do some sanity checking of the removed outputs positions. 
for (int remove_output : remove_outputs) { if (remove_output < 0 || remove_output >= item->output_size()) { return errors::InvalidArgument( "Function output index is out of bound: index=", remove_output, - " output_size=", item->output_size()); + " max_output_index=", item->output_size()); } } - absl::flat_hash_set remove_output_args; - const auto is_remove_output_arg = [&](const OutputArgInstantiation& output) { + absl::flat_hash_set remove_output_args; + const auto is_remove_output_arg = [&](const OutputArgExpansion& output) { return remove_output_args.find(&output) != remove_output_args.end(); }; for (int i = 0; i < item->output_size(); ++i) { - const OutputArgInstantiation& output = item->output(i); - if (remove_outputs.contains(i)) { - VLOG(3) << "Remove functions output: name=" << output.node_name + const OutputArgExpansion& output = item->output(i); + if (remove_outputs.find(i) != remove_outputs.end()) { + VLOG(3) << "Remove functions output: output_name=" << output.output_name << "(index = " << i << ")"; remove_output_args.insert(&output); } else if (!remove_output_args.empty()) { @@ -365,130 +766,12 @@ Status RemoveFunctionOutputs(const absl::flat_hash_set& remove_outputs, } } - // Update 'index' attribute in all '_Retval' nodes that are in output mapping. - for (NodeDef& node : *item->graph.mutable_node()) { - if (IsRetval(node)) { - auto attrs = AttrSlice(node); - int index; - TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "index", &index)); - - for (const auto& mapping : *output_mapping) { - const int from = mapping.first; - const int to = mapping.second; - if (index == from) { - (*node.mutable_attr())["index"].set_i(to); - } - } - } - } - - auto& o = item->output_args_; + auto& o = item->output_arg_expansions_; o.erase(std::remove_if(o.begin(), o.end(), is_remove_output_arg), o.end()); return Status::OK(); } -namespace { - -// FunctionDef uses different connectivity encoding for the function body nodes, -// than a GraphDef (see function.proto for details). 
This is a helper class that -// converts inputs in GraphDef format (node[:position]) to the FunctionDef -// format (node:output[:position]). -class MakeFunctionDefHelper { - public: - MakeFunctionDefHelper() = default; - - Status Initialize(const GrapplerFunctionItem& item, - const FunctionLibraryDefinition& flib); - - // Converts input name from GraphDef format (name[:position]) to the - // FunctionDef input format (name[:output][:position]) using registered input - // arg instantiations and function body outputs. - Status AsFunctionDefInput(const string& graph_def_input, - string* func_def_input) const; - - // Updates Node inputs from GraphDef to FunctionDef format. - Status AsFunctionDefNode(NodeDef* function_body_node) const; - - private: - absl::flat_hash_set input_nodes_; - // Mapping from function body node name to output names range map. - absl::flat_hash_map function_body_outputs_; -}; - -Status MakeFunctionDefHelper::Initialize( - const GrapplerFunctionItem& item, const FunctionLibraryDefinition& flib) { - for (const InputArgInstantiation& input_arg : item.inputs()) { - input_nodes_.insert(input_arg.node_name); - } - - for (const NodeDef& node : item.function_body().node()) { - const OpRegistrationData* registration; - TF_RETURN_IF_ERROR(flib.LookUp(node.op(), ®istration)); - - tensorflow::NameRangeMap outputs_range_map; - TF_RETURN_IF_ERROR(tensorflow::NameRangesForNode( - node, registration->op_def, nullptr, &outputs_range_map)); - - function_body_outputs_.emplace(node.name(), std::move(outputs_range_map)); - } - - return Status::OK(); -} - -Status MakeFunctionDefHelper::AsFunctionDefInput(const string& graph_def_input, - string* func_def_input) const { - if (IsControlInput(graph_def_input)) { - *func_def_input = graph_def_input; - return Status::OK(); - } - - const SafeTensorId tensor = ParseTensorName(graph_def_input); - DCHECK_GE(tensor.index(), 0); - - // Graph def input corresponds to one of the function inputs. 
- const auto is_input = input_nodes_.find(tensor.node()); - if (is_input != input_nodes_.end()) { - DCHECK_EQ(tensor.index(), 0); - *func_def_input = tensor.node(); - return Status::OK(); - } - - // Or it must be output from one of the function body nodes - const auto is_body_output = function_body_outputs_.find(tensor.node()); - if (is_body_output != function_body_outputs_.end()) { - const tensorflow::NameRangeMap& outputs_range_map = is_body_output->second; - - for (const auto& el : outputs_range_map) { - const auto& output_name = el.first; - const auto& output_range = el.second; - if (tensor.index() >= output_range.first && - tensor.index() < output_range.second) { - *func_def_input = absl::StrCat(tensor.node(), ":", output_name, ":", - tensor.index() - output_range.first); - return Status::OK(); - } - } - } - - return errors::InvalidArgument("Unknown graph def input: ", graph_def_input); -} - -Status MakeFunctionDefHelper::AsFunctionDefNode( - NodeDef* function_body_node) const { - string func_def_input; - - for (int i = 0; i < function_body_node->input_size(); ++i) { - TF_RETURN_IF_ERROR( - AsFunctionDefInput(function_body_node->input(i), &func_def_input)); - function_body_node->set_input(i, func_def_input); - } - - return Status::OK(); -} - -} // namespace - Status MakeFunctionDef(const GrapplerFunctionItem& item, const FunctionLibraryDefinition& flib, FunctionDef* func) { @@ -496,55 +779,86 @@ Status MakeFunctionDef(const GrapplerFunctionItem& item, func->mutable_signature()->set_description(item.description()); func->mutable_signature()->set_is_stateful(item.is_stateful()); - MakeFunctionDefHelper helper; - TF_RETURN_IF_ERROR(helper.Initialize(item, flib)); - - // Keep track of '_Arg' nodes that were added to the graph in place of - // instantiated function input arguments. 
- absl::flat_hash_set input_nodes; - for (const InputArgInstantiation& input_arg : item.inputs()) { - input_nodes.insert(input_arg.node_name); + // Keep track of placeholders that were added to the graph in place of + // expanded function input arguments. + absl::flat_hash_set input_placeholders; + for (const InputArgExpansion& input_arg : item.inputs()) { + for (const string& placeholder : input_arg.placeholders) { + input_placeholders.insert(placeholder); + } } - // Mapping from the '_Retval' node name to the output tensor. + // Keep track of identity nodes that were added to the graph in place of + // expanded function output arguments. + absl::flat_hash_set output_nodes; + for (const OutputArgExpansion& output_arg : item.outputs()) { + for (const string& output_node : output_arg.output_nodes) { + output_nodes.insert(output_node); + } + } + + // If the output identity node was not modified by any optimizer, we can + // bypass it and return the function value from its input. absl::flat_hash_map output_tensors; for (const NodeDef& func_body_node : item.function_body().node()) { - if (!IsRetval(func_body_node)) continue; - if (func_body_node.input_size() != 1) { - return errors::Internal("_Retval node must have single input: ", - SummarizeNodeDef(func_body_node)); + if (!IsIdentity(func_body_node)) continue; + + const string& node_name = func_body_node.name(); + if (output_nodes.find(node_name) != output_nodes.end()) { + // Grappler optimizers might optimize nodes in the fanin of the output + // node, and forward their control dependencies. We can't express control + // dependencies in a function signature, so we have to keep the node.
+ if (func_body_node.input_size() == 1) { + VLOG(3) << "Bypass function output node: " << node_name << " -> " + << func_body_node.input(0); + output_tensors.emplace(node_name, func_body_node.input(0)); + } else { + VLOG(3) << "Keep function output node: " << node_name; + } } - output_tensors.emplace(func_body_node.name(), func_body_node.input(0)); } - for (const InputArgInstantiation& input_arg : item.inputs()) { + // Return output tensor name (input of the output node) if it's safe to bypass + // output node, otherwise returns the output node name. + const auto output_tensor = + [&output_tensors](const OutputArgExpansion& output_arg) -> const string& { + const string& output_node = output_arg.output_nodes[0]; + const auto is_output_tensor = output_tensors.find(output_node); + return is_output_tensor == output_tensors.end() ? output_node + : is_output_tensor->second; + }; + + // Build a GrapplerFunctionConnectivity from inputs and new function body. + GrapplerFunctionConnectivity connectivity; + TF_RETURN_IF_ERROR( + RegisterGrapplerFunctionConnectivity(item, flib, &connectivity)); + + // Add function input arguments. + for (const InputArgExpansion& input_arg : item.inputs()) { + DCHECK(input_arg.placeholders.size() == 1) // do some sanity checking + << "Inputs of tensor lists are not supported"; + OpDef::ArgDef arg_def; - arg_def.set_name(input_arg.node_name); + arg_def.set_name(input_arg.input_name); arg_def.set_type(input_arg.data_type); - arg_def.set_is_ref(IsRefType(input_arg.data_type)); + arg_def.set_is_ref(input_arg.is_ref); *func->mutable_signature()->add_input_arg() = arg_def; } // Add function output arguments. 
- for (const OutputArgInstantiation& output_arg : item.outputs()) { - const string output_name = - absl::StrReplaceAll(output_arg.node_name, {{"_RetVal", ""}}); + for (const OutputArgExpansion& output_arg : item.outputs()) { + DCHECK(output_arg.output_nodes.size() == 1) // do some sanity checking + << "Outputs of tensor lists are not supported"; OpDef::ArgDef arg_def; - arg_def.set_name(output_name); + arg_def.set_name(output_arg.output_name); arg_def.set_type(output_arg.data_type); - arg_def.set_is_ref(IsRefType(output_arg.data_type)); + arg_def.set_is_ref(output_arg.is_ref); *func->mutable_signature()->add_output_arg() = arg_def; - auto it = output_tensors.find(output_arg.node_name); - if (it == output_tensors.end()) { - return errors::Internal( - "Can't find an output tensor for the output node: ", - output_arg.node_name); - } - - TF_RETURN_IF_ERROR(helper.AsFunctionDefInput( - it->second, &(*func->mutable_ret())[output_name])); + TF_RETURN_IF_ERROR(connectivity.AsFunctionDefInput( + output_tensor(output_arg), + &(*func->mutable_ret())[output_arg.output_name])); } // Add function control outputs. @@ -563,12 +877,16 @@ Status MakeFunctionDef(const GrapplerFunctionItem& item, // Copy function body nodes to the FunctionDef and update input format for (const NodeDef& func_node : item.function_body().node()) { - // Do not copy input/output nodes. - if (IsArg(func_node) || IsRetval(func_node)) continue; + const string& name = func_node.name(); + + // Do not copy input placeholders. + if (IsPlaceholder(func_node) && input_placeholders.count(name)) continue; + // Do not copy output nodes that we bypassed. 
+ if (IsIdentity(func_node) && output_tensors.count(name)) continue; NodeDef* func_def_node = func->add_node_def(); *func_def_node = func_node; - TF_RETURN_IF_ERROR(helper.AsFunctionDefNode(func_def_node)); + TF_RETURN_IF_ERROR(connectivity.AsFunctionDefNode(func_def_node)); } return Status::OK(); diff --git a/tensorflow/core/grappler/utils/functions.h b/tensorflow/core/grappler/utils/functions.h index b03b89af2ab..d450f6a41fc 100644 --- a/tensorflow/core/grappler/utils/functions.h +++ b/tensorflow/core/grappler/utils/functions.h @@ -33,22 +33,45 @@ limitations under the License. namespace tensorflow { namespace grappler { -// Function input argument instantiated into an '_Arg' node in the function body -// graph, with an 'index' attribute corresponding to the input position. -struct InputArgInstantiation { - InputArgInstantiation(string node_name, DataType data_type) - : node_name(std::move(node_name)), data_type(data_type) {} - string node_name; +// WARNING(ezhulenev): Currently we do not support functions with inputs or +// outputs instantiated into multiple tensors. This can happen if the +// input/output type is 'T*N' or 'list(type)'. This is enforced by multiple +// checks across this file and also function_optimizer.cc. InputArgExpansion and +// OutputArgExpansion already support lists of tensors, but that's pretty much +// it, all other code is written with assumption that expansions are always of +// size 1. MakeGrapplerFunctionItem will gracefully fail with Status error. +// +// This is a low priority feature, because in practice we don't see a lot (any +// at all?) functions with such arguments. Tensorflow-Eager always produces +// functions with plain input/output arguments. + +// TODO(ezhulenev): Support inputs and outputs of type 'T*N'. +// TODO(ezhulenev): Support inputs and outputs of type 'list(type)'. 
+ +// Depending on the function instantiation attributes, input argument to the +// function might be a single tensor, list of tensors of the same type, or a +// list of tensors of different types. +// +// InputArgExpansion keeps track of the placeholders that were added to the +// function body in place of function inputs and a resolved input data type. +struct InputArgExpansion { + string input_name; DataType data_type; + bool is_ref; + absl::InlinedVector placeholders; }; -// Function output instantiated into a '_Retval' node in the function body -// graph, with an 'index' attribute corresponding to the output position. -struct OutputArgInstantiation { - OutputArgInstantiation(string node_name, DataType data_type) - : node_name(std::move(node_name)), data_type(data_type) {} - string node_name; +// Depending on the function instantiation attributes, output argument is mapped +// to one or more outputs of one of the function body nodes. +// +// OutputArgExpansion keeps track of the Identity nodes that were added to the +// function body to forward output tensors. Adding these output nodes allows +// nested function inlining and specialization (see function optimizer). +struct OutputArgExpansion { + string output_name; DataType data_type; + bool is_ref; + absl::InlinedVector output_nodes; }; // A mapping from control output name to node name in function body graph. @@ -57,6 +80,78 @@ struct ControlOutput { string node_name; }; +// FunctionDef uses different connectivity encoding for the function body nodes, +// then a GraphDef (see function.proto for details). Input name in FunctionDef +// can potentially represent a sequence of tensors (instead just one tensor in +// GraphDef), we need to expand it when converting from FunctionDef to GraphDef, +// and fold it back when doing backward conversion. 
+class GrapplerFunctionConnectivity { + public: + void RegisterInputArgExpansion(InputArgExpansion input_arg_expansion); + void RegisterFunctionBodyOutputs(const string& node_name, + tensorflow::NameRangeMap&& outputs); + + // Expands input encoded in FunctionDef format (name[:output][:position]) into + // multiple inputs in GraphDef format (name[:position]). + Status ExpandFunctionDefInput(const string& func_def_input, + std::vector* graph_def_inputs) const; + + // Updates Node inputs from FunctionDef to GraphDef format. + Status ExpandNodeInputs(NodeDef* function_body_node) const; + + // When expanding inputs in function def format, single input might be + // expanded into multiple tensors. When converting back to the function def + // format from graph def format, it's always a 1-to-1 relationship. + // FunctionDef built from GrapplerFunctionItem is always specialized to its + // instantiation attributes and length of input args (and node def outputs) is + // known. + + // Converts input name from GraphDef format (name[:position]) to the + // FunctionDef input format (name[:output][:position]) using registered input + // arg expansion and function body outputs. + Status AsFunctionDefInput(const string& graph_def_input, + string* func_def_input) const; + + // Updates Node inputs from GraphDef to FunctionDef format. + Status AsFunctionDefNode(NodeDef* function_body_node) const; + + private: + // Mapping from input name to input arg expansion. + absl::flat_hash_map input_arg_expansions_; + // Mapping from function body node name to output names range map. + absl::flat_hash_map function_body_outputs_; + + // For each placeholder added to the function instantiation graph, we keep a + // mapping back to the function input argument name and index. + struct InputArgPlaceholder { + string input_name; // Name of the function input argument. 
+ int input_index; // Index of a tensor in the function input argument + // expansion, it can be greater than `0` if input + // argument is a list of tensors (aka list(type)). + }; + // Mapping from input arg placeholder to the function input tensor. + absl::flat_hash_map input_arg_placeholders_; +}; + +// Get Function type attributes using attributes of a node that instantiated +// a function. +class GrapplerFunctionItemInstantiation { + public: + explicit GrapplerFunctionItemInstantiation(AttrSlice func_instantiation_attr) + : func_instantiation_attr_(func_instantiation_attr) {} + + // Get DataType from attributes by name. Return error if attribute is missing, + // or it doesn't define a valid data type. + Status GetTypeAttr(const string& type_attr_name, DataType* data_type) const; + + // Get argument data type. If data type is not explicitly defined, uses + // provided attribute name to look it up in function attributes. + Status GetArgType(const OpDef::ArgDef& arg, DataType* data_type) const; + + private: + const AttrSlice func_instantiation_attr_; // do not own +}; + // A special case of GrapplerItem, constructed from a TensorFlow Function. 
class GrapplerFunctionItem : public GrapplerItem { public: @@ -64,12 +159,12 @@ class GrapplerFunctionItem : public GrapplerItem { const string& description() const; - const std::vector& inputs() const; - const InputArgInstantiation& input(int i) const; + const std::vector& inputs() const; + const InputArgExpansion& input(int i) const; const std::size_t input_size() const; - const std::vector& outputs() const; - const OutputArgInstantiation& output(int i) const; + const std::vector& outputs() const; + const OutputArgExpansion& output(int i) const; const std::size_t output_size() const; const std::vector& control_outputs() const; @@ -95,8 +190,8 @@ class GrapplerFunctionItem : public GrapplerItem { GrapplerFunctionItem(string func_name, string description, AttrSlice func_attr, - std::vector input_args, - std::vector output_args, + std::vector input_arg_expansions, + std::vector output_arg_expansions, std::vector control_outputs, int graph_def_version, bool is_stateful, GraphDef&& function_body); @@ -105,8 +200,8 @@ class GrapplerFunctionItem : public GrapplerItem { AttrSlice func_attr_; // Attributes specific to function definition that // produced this item (FuncDef.attr field). - std::vector input_args_; - std::vector output_args_; + std::vector input_arg_expansions_; + std::vector output_arg_expansions_; std::vector control_outputs_; bool is_stateful_ = false; @@ -137,13 +232,22 @@ Status InstantiationBodyParameters( const FunctionDef& func, const AttrSlice& func_instantiation_attr, absl::flat_hash_map* body_parameters); +// Register GrapplerFunctionItem input arg expansion and function body outputs +// in the GrapplerFunctionConnectivity. Use function library definition to +// lookup function body nodes output names and ranges. +Status RegisterGrapplerFunctionConnectivity( + const GrapplerFunctionItem& item, const FunctionLibraryDefinition& flib, + GrapplerFunctionConnectivity* connectivity); + // Replace one of the function inputs with a constant. 
Status ReplaceInputWithConst(const NodeDef& input_const, int input_index, GrapplerFunctionItem* item); -// Removes outputs from instantiated grappler function item. For all active -// function outputs that changed its output index, this function adds an output -// mapping (std::pair). +// Removes outputs from instantiated grappler function item. Function node +// outputs use GraphDef output index encoding, and multiple outputs might belong +// to the same output argument expansion (in case of tensor list outputs). For +// all active function outputs that changed its output index, this function adds +// an output mapping (std::pair). Status RemoveFunctionOutputs(const absl::flat_hash_set& remove_outputs, GrapplerFunctionItem* item, std::vector>* output_mapping); diff --git a/tensorflow/core/grappler/utils/functions_test.cc b/tensorflow/core/grappler/utils/functions_test.cc index 8cc938ec845..813e6a318cf 100644 --- a/tensorflow/core/grappler/utils/functions_test.cc +++ b/tensorflow/core/grappler/utils/functions_test.cc @@ -63,16 +63,11 @@ TEST_F(FunctionsTest, InstantiationParameters) { FunctionDef func = FunctionDefHelper::Create( "ParametrizedFunc", /* inputs */ - {"input1:A", "input2:B", "input3:float", "input4: C"}, + {"input1:A", "input2:B", "input3:float"}, /* outputs */ - {"output1: A", "output2:D"}, + {"output1: A", "output2:C"}, /* type parameters */ - { - "A: {float, double}", - "B: {float, int32}", - "C: list(type)", - "D: {float, double}", - }, + {"A: {float, double}", "B: {float, int32}", "C: {float, double}"}, /* function body*/ {{{"output"}, "FakeOp", {"input1", "input2"}, {{"key", "$key"}}}}, /* Mapping between function returns and function node outputs. 
*/ @@ -82,20 +77,16 @@ TEST_F(FunctionsTest, InstantiationParameters) { func_instantiation_attr["key"].set_s("key-value"); func_instantiation_attr["A"].set_type(DT_FLOAT); func_instantiation_attr["B"].set_type(DT_INT32); - func_instantiation_attr["C"].mutable_list()->add_type(DT_FLOAT); - func_instantiation_attr["C"].mutable_list()->add_type(DT_INT32); - func_instantiation_attr["D"].set_type(DT_DOUBLE); + func_instantiation_attr["C"].set_type(DT_DOUBLE); absl::flat_hash_map type_parameters; TF_EXPECT_OK(InstantiationTypeParameters( func, AttrSlice(&func_instantiation_attr), &type_parameters)); - ASSERT_EQ(5, type_parameters.size()); + ASSERT_EQ(3, type_parameters.size()); EXPECT_EQ(DT_FLOAT, type_parameters["A"]); EXPECT_EQ(DT_INT32, type_parameters["B"]); - EXPECT_EQ(DT_FLOAT, type_parameters["C:0"]); - EXPECT_EQ(DT_INT32, type_parameters["C:1"]); - EXPECT_EQ(DT_DOUBLE, type_parameters["D"]); + EXPECT_EQ(DT_DOUBLE, type_parameters["C"]); absl::flat_hash_map body_parameters; TF_EXPECT_OK(InstantiationBodyParameters( @@ -105,6 +96,131 @@ TEST_F(FunctionsTest, InstantiationParameters) { EXPECT_EQ("key-value", body_parameters["key"].s()); } +TEST_F(FunctionsTest, GrapplerFunctionConnectivity_ExpandFunctionDefInput) { + GrapplerFunctionConnectivity connectivity; + + connectivity.RegisterInputArgExpansion( + {"inputA", DT_FLOAT, /*is_ref=*/false, {"inputA"}}); + connectivity.RegisterInputArgExpansion( + {"inputB", DT_FLOAT, /*is_ref=*/false, {"inputB_0", "inputB_1"}}); + + connectivity.RegisterFunctionBodyOutputs("Add", {{"z", {0, 1}}}); + connectivity.RegisterFunctionBodyOutputs("Func", + {{"o1", {0, 2}}, {"o2", {2, 4}}}); + + std::vector inputs; + TF_EXPECT_OK(connectivity.ExpandFunctionDefInput("inputA", &inputs)); + ASSERT_EQ(1, inputs.size()); + EXPECT_EQ("inputA", inputs[0]); + + inputs.clear(); + TF_EXPECT_OK(connectivity.ExpandFunctionDefInput("inputB", &inputs)); + ASSERT_EQ(2, inputs.size()); + EXPECT_EQ("inputB_0", inputs[0]); + EXPECT_EQ("inputB_1", 
inputs[1]); + + inputs.clear(); + TF_EXPECT_OK(connectivity.ExpandFunctionDefInput("inputB:1", &inputs)); + ASSERT_EQ(1, inputs.size()); + EXPECT_EQ("inputB_1", inputs[0]); + + inputs.clear(); + TF_EXPECT_OK(connectivity.ExpandFunctionDefInput("Add:z", &inputs)); + ASSERT_EQ(1, inputs.size()); + EXPECT_EQ("Add", inputs[0]); + + inputs.clear(); + TF_EXPECT_OK(connectivity.ExpandFunctionDefInput("Func:o1", &inputs)); + ASSERT_EQ(2, inputs.size()); + EXPECT_EQ("Func", inputs[0]); + EXPECT_EQ("Func:1", inputs[1]); + + inputs.clear(); + TF_EXPECT_OK(connectivity.ExpandFunctionDefInput("Func:o2", &inputs)); + ASSERT_EQ(2, inputs.size()); + EXPECT_EQ("Func:2", inputs[0]); + EXPECT_EQ("Func:3", inputs[1]); + + inputs.clear(); + TF_EXPECT_OK(connectivity.ExpandFunctionDefInput("Func:o1:0", &inputs)); + ASSERT_EQ(1, inputs.size()); + EXPECT_EQ("Func", inputs[0]); + + inputs.clear(); + TF_EXPECT_OK(connectivity.ExpandFunctionDefInput("Func:o1:1", &inputs)); + ASSERT_EQ(1, inputs.size()); + EXPECT_EQ("Func:1", inputs[0]); + + inputs.clear(); + TF_EXPECT_OK(connectivity.ExpandFunctionDefInput("Func:o2:0", &inputs)); + ASSERT_EQ(1, inputs.size()); + EXPECT_EQ("Func:2", inputs[0]); + + inputs.clear(); + TF_EXPECT_OK(connectivity.ExpandFunctionDefInput("Func:o2:1", &inputs)); + ASSERT_EQ(1, inputs.size()); + EXPECT_EQ("Func:3", inputs[0]); +} + +TEST_F(FunctionsTest, GrapplerFunctionConnectivity_AsFunctionDefInput) { + GrapplerFunctionConnectivity connectivity; + + connectivity.RegisterInputArgExpansion( + {"inputA", DT_FLOAT, /*is_ref=*/false, {"inputA"}}); + connectivity.RegisterInputArgExpansion( + {"inputB", DT_FLOAT, /*is_ref=*/false, {"inputB_0", "inputB_1"}}); + + connectivity.RegisterFunctionBodyOutputs("Add", {{"z", {0, 1}}}); + connectivity.RegisterFunctionBodyOutputs("Func", + {{"o1", {0, 2}}, {"o2", {2, 4}}}); + + string input; + + TF_EXPECT_OK(connectivity.AsFunctionDefInput("inputA", &input)); + EXPECT_EQ("inputA:0", input); + + 
TF_EXPECT_OK(connectivity.AsFunctionDefInput("inputB_0", &input)); + EXPECT_EQ("inputB:0", input); + + TF_EXPECT_OK(connectivity.AsFunctionDefInput("inputB_1", &input)); + EXPECT_EQ("inputB:1", input); + + TF_EXPECT_OK(connectivity.AsFunctionDefInput("Add", &input)); + EXPECT_EQ("Add:z:0", input); + + TF_EXPECT_OK(connectivity.AsFunctionDefInput("Func", &input)); + EXPECT_EQ("Func:o1:0", input); + + TF_EXPECT_OK(connectivity.AsFunctionDefInput("Func:1", &input)); + EXPECT_EQ("Func:o1:1", input); + + TF_EXPECT_OK(connectivity.AsFunctionDefInput("Func:2", &input)); + EXPECT_EQ("Func:o2:0", input); + + TF_EXPECT_OK(connectivity.AsFunctionDefInput("Func:3", &input)); + EXPECT_EQ("Func:o2:1", input); +} + +TEST_F(FunctionsTest, GrapplerFunctionConnectivity_ExpandNodeInputs) { + GrapplerFunctionConnectivity connectivity; + + connectivity.RegisterInputArgExpansion( + {"inputA", DT_FLOAT, /*is_ref=*/false, {"inputA"}}); + connectivity.RegisterInputArgExpansion( + {"inputB", DT_FLOAT, /*is_ref=*/false, {"inputB_0", "inputB_1"}}); + + NodeDef node; + node.add_input("inputA:0"); + node.add_input("inputB"); + + TF_EXPECT_OK(connectivity.ExpandNodeInputs(&node)); + + EXPECT_EQ(3, node.input_size()); + EXPECT_EQ("inputA", node.input(0)); + EXPECT_EQ("inputB_0", node.input(1)); + EXPECT_EQ("inputB_1", node.input(2)); +} + TEST_F(FunctionsTest, FromSimpleFunctionDef) { const Tensor kTwo = test::AsScalar(2); FunctionDef func = FunctionDefHelper::Define( @@ -136,17 +252,19 @@ TEST_F(FunctionsTest, FromSimpleFunctionDef) { EXPECT_EQ(5, item.function_body().node_size()); EXPECT_EQ(1, item.input_size()); - EXPECT_EQ("x", item.input(0).node_name); + EXPECT_EQ("x", item.input(0).input_name); + ASSERT_EQ(1, item.input(0).placeholders.size()); + EXPECT_EQ("x", item.input(0).placeholders[0]); EXPECT_EQ(1, item.output_size()); - EXPECT_EQ("y_RetVal", item.output(0).node_name); + EXPECT_EQ("y", item.output(0).output_name); + EXPECT_EQ("y_output_node_0", item.output(0).output_nodes[0]); int 
count = 0; for (const NodeDef &node : item.function_body().node()) { if (node.name() == "x" && ++count) { - EXPECT_EQ("_Arg", node.op()); - EXPECT_EQ(DT_FLOAT, node.attr().at("T").type()); - EXPECT_EQ(0, node.attr().at("index").i()); + EXPECT_EQ("Placeholder", node.op()); + EXPECT_EQ(DT_FLOAT, node.attr().at("dtype").type()); EXPECT_EQ(0, node.input_size()); } else if (node.name() == "two" && ++count) { EXPECT_EQ("Const", node.op()); @@ -162,11 +280,10 @@ TEST_F(FunctionsTest, FromSimpleFunctionDef) { EXPECT_EQ(2, node.input_size()); EXPECT_EQ("x", node.input(0)); EXPECT_EQ("scale", node.input(1)); - } else if (node.name() == "y_RetVal" && ++count) { - EXPECT_EQ("_Retval", node.op()); + } else if (node.name() == "y_output_node_0" && ++count) { + EXPECT_EQ("Identity", node.op()); ASSERT_EQ(1, node.input_size()); EXPECT_EQ("y", node.input(0)); - EXPECT_EQ(0, node.attr().at("index").i()); } } EXPECT_EQ(5, count); @@ -217,22 +334,20 @@ TEST_F(FunctionsTest, FromFunctionDefWithMultiOutputNodes) { EXPECT_EQ(14, item.function_body().node_size()); ASSERT_EQ(3, item.input_size()); - EXPECT_EQ("x", item.input(0).node_name); - EXPECT_EQ("y", item.input(1).node_name); - EXPECT_EQ("dz", item.input(2).node_name); + EXPECT_EQ("x", item.input(0).input_name); + EXPECT_EQ("y", item.input(1).input_name); + EXPECT_EQ("dz", item.input(2).input_name); ASSERT_EQ(2, item.output_size()); - EXPECT_EQ("dx_RetVal", item.output(0).node_name); - EXPECT_EQ("dy_RetVal", item.output(1).node_name); + EXPECT_EQ("dx_output_node_0", item.output(0).output_nodes[0]); + EXPECT_EQ("dy_output_node_0", item.output(1).output_nodes[0]); int count = 0; for (const NodeDef &node : item.function_body().node()) { if (node.name() == "x" || node.name() == "y" || node.name() == "dz") { count++; - EXPECT_EQ("_Arg", node.op()); - EXPECT_EQ(DT_FLOAT, node.attr().at("T").type()); - int expected_index = node.name() == "x" ? 0 : node.name() == "y" ? 
1 : 2; - EXPECT_EQ(expected_index, node.attr().at("index").i()); + EXPECT_EQ("Placeholder", node.op()); + EXPECT_EQ(DT_FLOAT, node.attr().at("dtype").type()); EXPECT_EQ(0, node.input_size()); } else if (node.name() == "rx" && ++count) { EXPECT_EQ("BroadcastGradientArgs", node.op()); @@ -249,14 +364,12 @@ TEST_F(FunctionsTest, FromFunctionDefWithMultiOutputNodes) { EXPECT_EQ(2, node.input_size()); EXPECT_EQ("gy", node.input(0)); EXPECT_EQ("rx:1", node.input(1)); - } else if (node.name() == "dx_RetVal" && ++count) { - EXPECT_EQ("_Retval", node.op()); - EXPECT_EQ(0, node.attr().at("index").i()); + } else if (node.name() == "dx_output_node_0" && ++count) { + EXPECT_EQ("Identity", node.op()); ASSERT_EQ(1, node.input_size()); EXPECT_EQ("dx", node.input(0)); - } else if (node.name() == "dy_RetVal" && ++count) { - EXPECT_EQ("_Retval", node.op()); - EXPECT_EQ(1, node.attr().at("index").i()); + } else if (node.name() == "dy_output_node_0" && ++count) { + EXPECT_EQ("Identity", node.op()); ASSERT_EQ(1, node.input_size()); EXPECT_EQ("dy", node.input(0)); } @@ -312,10 +425,8 @@ TEST_F(FunctionsTest, FromFunctionDefWithNestedFuncs) { for (const NodeDef &node : item.function_body().node()) { if (node.name() == "x" || node.name() == "y") { count++; - EXPECT_EQ("_Arg", node.op()); - EXPECT_EQ(DT_FLOAT, node.attr().at("T").type()); - int expected_index = node.name() == "x" ? 
0 : 1; - EXPECT_EQ(expected_index, node.attr().at("index").i()); + EXPECT_EQ("Placeholder", node.op()); + EXPECT_EQ(DT_FLOAT, node.attr().at("dtype").type()); EXPECT_EQ(0, node.input_size()); } else if (node.name() == "a0" && ++count) { EXPECT_EQ("Swap", node.op()); @@ -374,14 +485,13 @@ TEST_F(FunctionsTest, FromFunctionDefWithOutputMappings) { flib, TF_GRAPH_DEF_VERSION, &item)); EXPECT_EQ(1, item.output_size()); - EXPECT_EQ("out_RetVal", item.output(0).node_name); + EXPECT_EQ("out_output_node_0", item.output(0).output_nodes[0]); int count = 0; for (const NodeDef &node : item.function_body().node()) { if (node.name() == "in" && ++count) { - EXPECT_EQ("_Arg", node.op()); - EXPECT_EQ(DT_FLOAT, node.attr().at("T").type()); - EXPECT_EQ(0, node.attr().at("index").i()); + EXPECT_EQ("Placeholder", node.op()); + EXPECT_EQ(DT_FLOAT, node.attr().at("dtype").type()); EXPECT_EQ(0, node.input_size()); } else if (node.name() == "Linear_func" && ++count) { EXPECT_EQ("Identity", node.op()); @@ -391,9 +501,8 @@ TEST_F(FunctionsTest, FromFunctionDefWithOutputMappings) { EXPECT_EQ("Exp", node.op()); EXPECT_EQ(1, node.input_size()); EXPECT_EQ("Linear_func", node.input(0)); - } else if (node.name() == "out_RetVal" && ++count) { - EXPECT_EQ("_Retval", node.op()); - EXPECT_EQ(0, node.attr().at("index").i()); + } else if (node.name() == "out_output_node_0" && ++count) { + EXPECT_EQ("Identity", node.op()); ASSERT_EQ(1, node.input_size()); EXPECT_EQ("Exp", node.input(0)); } @@ -401,6 +510,70 @@ TEST_F(FunctionsTest, FromFunctionDefWithOutputMappings) { EXPECT_EQ(4, count); } +TEST_F(FunctionsTest, FromFunctionDefWithInputForwarding) { + FunctionDef func = FunctionDefHelper::Create( + // Name + "ForwardInputs", + // Args + {"in0: float", "in1: float", "arg2: float", "arg3: int32", "arg4: float"}, + // Return values + {"out0: float", "arg2: float", "arg3: int32"}, + // Attr def + {}, + // Nodes + {}, + // Mapping + {{"out0", "in0"}}); + + protobuf::Map func_instantiation_attr; + 
FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary()); + + GrapplerFunctionItem item; + TF_EXPECT_OK(MakeGrapplerFunctionItem(func, + AttrSlice(&func_instantiation_attr), + flib, TF_GRAPH_DEF_VERSION, &item)); + + EXPECT_EQ("ForwardInputs", item.id); + EXPECT_EQ(8, item.function_body().node_size()); + + EXPECT_EQ(3, item.output_size()); + EXPECT_EQ("out0_output_node_0", item.output(0).output_nodes[0]); + EXPECT_EQ("arg2_output_node_0", item.output(1).output_nodes[0]); + EXPECT_EQ("arg3_output_node_0", item.output(2).output_nodes[0]); + + int count = 0; + + const auto is_arg_placeholder = [](const string &name) { + return name == "in0" || name == "in1" || name == "arg2" || name == "arg3" || + name == "arg4"; + }; + + for (const NodeDef &node : item.function_body().node()) { + if (is_arg_placeholder(node.name()) && node.op() == "Placeholder") { + count++; + if (node.name() == "arg3") { + EXPECT_EQ(DT_INT32, node.attr().at("dtype").type()); + } else { + EXPECT_EQ(DT_FLOAT, node.attr().at("dtype").type()); + } + continue; + } + + EXPECT_EQ("Identity", node.op()); + ASSERT_EQ(1, node.input_size()); + EXPECT_TRUE(is_arg_placeholder(node.input(0))); + + if (node.name() == "out0_output_node_0" && ++count) { + EXPECT_EQ("in0", node.input(0)); + } else if (node.name() == "arg2_output_node_0" && ++count) { + EXPECT_EQ("arg2", node.input(0)); + } else if (node.name() == "arg3_output_node_0" && ++count) { + EXPECT_EQ("arg3", node.input(0)); + } + } + EXPECT_EQ(8, count); +} + TEST_F(FunctionsTest, FromFunctionDefWithoutInput) { const Tensor kTwo = test::AsScalar(2); FunctionDef func = FunctionDefHelper::Define( @@ -427,7 +600,7 @@ TEST_F(FunctionsTest, FromFunctionDefWithoutInput) { EXPECT_EQ(0, item.input_size()); EXPECT_EQ(1, item.output_size()); - EXPECT_EQ("o_RetVal", item.output(0).node_name); + EXPECT_EQ("o_output_node_0", item.output(0).output_nodes[0]); EXPECT_EQ(3, item.function_body().node_size()); const NodeDef &two = 
item.function_body().node(0); @@ -440,7 +613,7 @@ TEST_F(FunctionsTest, FromFunctionDefWithoutInput) { EXPECT_EQ("two", cast.input(0)); const NodeDef &retval = item.function_body().node(2); - EXPECT_EQ("o_RetVal", retval.name()); + EXPECT_EQ("o_output_node_0", retval.name()); EXPECT_EQ(1, retval.input_size()); EXPECT_EQ("o", retval.input(0)); } @@ -541,14 +714,14 @@ TEST_F(FunctionsTest, MakeFunctionDef) { EXPECT_EQ("y", specialized.signature().output_arg(0).name()); EXPECT_EQ(DT_FLOAT, specialized.signature().output_arg(0).type()); - // Function body specialized for instantiation types. + // Function body specialized for instantiation types int count = 0; for (const NodeDef &node : specialized.node_def()) { if (node.name() == "scale" && ++count) { EXPECT_EQ(DT_FLOAT, node.attr().at("DstT").type()); } else if (node.name() == "y" && ++count) { EXPECT_EQ("Mul", node.op()); - EXPECT_EQ("x", node.input(0)); + EXPECT_EQ("x:0", node.input(0)); EXPECT_EQ("scale:y:0", node.input(1)); EXPECT_EQ(DT_FLOAT, node.attr().at("T").type()); } @@ -580,9 +753,9 @@ TEST_F(FunctionsTest, ReplaceInputWithConst) { const NodeDef &input_x = item.function_body().node(0); const NodeDef &input_y = item.function_body().node(1); - // Initially inputs added to the graph as _Arg nodes. - EXPECT_EQ("_Arg", input_x.op()); - EXPECT_EQ("_Arg", input_y.op()); + // Initially inputs added to the graph as placeholders. + EXPECT_EQ("Placeholder", input_x.op()); + EXPECT_EQ("Placeholder", input_y.op()); // Replace inputs x and y with constants. NodeDef const_input_x; @@ -651,7 +824,7 @@ TEST_F(FunctionsTest, SwapFunctionBodyAndMakeFunctionDef) { GraphDef id_func_body = test::function::GDef( {/* Read and return input argument through Identity node. 
*/ NDef("read_x", "Identity", {"x"}, {{"T", "float"}}), - NDef("z_RetVal", "_Retval", {"read_x"}, {{"T", "float"}})}); + NDef("z_output_node_0", "Identity", {"read_x"}, {{"T", "float"}})}); protobuf::Map func_instantiation_attr; func_instantiation_attr["T"].set_type(DT_FLOAT); @@ -676,7 +849,7 @@ TEST_F(FunctionsTest, SwapFunctionBodyAndMakeFunctionDef) { for (const NodeDef &node : specialized.node_def()) { if (node.name() == "read_x" && ++count) { EXPECT_EQ("Identity", node.op()); - EXPECT_EQ("x", node.input(0)); + EXPECT_EQ("x:0", node.input(0)); } } EXPECT_EQ(1, count);