Automated rollback of commit 92f736f429e398df261cd2f3c8c949840dd06a76

PiperOrigin-RevId: 240915460
Eugene Zhulenev authored 2019-03-28 21:17:34 -07:00, committed by TensorFlower Gardener
parent 937bf0a2e6
commit 9e467f4df3
14 changed files with 1076 additions and 488 deletions

View File

@@ -177,8 +177,7 @@ class FunctionInstantiationHelper {
} else {
gnode->set_op(FunctionLibraryDefinition::kArgOp);
}
DataType dtype = arg_def.is_ref() ? MakeRefType(dtypes[i]) : dtypes[i];
AddAttr("T", dtype, gnode);
AddAttr("T", dtypes[i], gnode);
AddAttr("index", arg_index, gnode);
result_.arg_types.push_back(dtypes[i]);
++arg_index;
@@ -344,8 +343,7 @@ class FunctionInstantiationHelper {
gnode->set_op(FunctionLibraryDefinition::kRetOp);
}
AddInput(nodes_.size() - 1, item->nid, item->idx + i);
DataType dtype = ret_def.is_ref() ? MakeRefType(dtypes[i]) : dtypes[i];
AddAttr("T", dtype, gnode);
AddAttr("T", dtypes[i], gnode);
AddAttr("index", (*ret_index)++, gnode);
result_.ret_types.push_back(dtypes[i]);
}
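
The `MakeRefType` call in this hunk wraps an instantiated dtype into its reference variant for ref arguments. For context, a minimal sketch of the standard types.h helpers involved (the function name here is illustrative):

#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

// Illustrative only: how a base dtype relates to its reference variant.
void RefTypeExample() {
  DataType dtype = DT_FLOAT;
  DataType ref_dtype = MakeRefType(dtype);    // DT_FLOAT_REF
  CHECK(IsRefType(ref_dtype));                // true
  CHECK_EQ(RemoveRefType(ref_dtype), dtype);  // back to DT_FLOAT
}

}  // namespace tensorflow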

View File

@@ -41,7 +41,6 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":utils",
"@com_google_absl//absl/types:optional",
"//tensorflow/core/grappler/utils:functions",
"//tensorflow/core/grappler/utils:topological_sort",
"//tensorflow/core/grappler:mutable_graph_view",

View File

@@ -15,7 +15,6 @@ limitations under the License.
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "absl/types/optional.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
@@ -604,31 +603,22 @@ class SymbolicShapeRefiner {
" was not previously added to SymbolicShapeRefiner.");
}
const absl::optional<GrapplerFunctionItem>& maybe_grappler_function_item =
it->second;
if (!maybe_grappler_function_item.has_value()) {
VLOG(3) << "Skip failed to instantiate function call: function_name="
<< function_node->op();
auto* ctx = GetNodeContext(function_node);
auto* ic = ctx->inference_context.get();
for (int i = 0; i < ic->num_outputs(); ++i) {
TF_RETURN_IF_ERROR(SetUnknownShape(function_node, i));
}
return Status::OK();
}
// Copy (not reference) so that changes we make here (e.g., replacing
// _Arg with Const and _Retval with Identity) don't affect one in
// Placeholder with Const) don't affect one in
// fun_to_grappler_function_item_.
GrapplerFunctionItem grappler_function_item = *maybe_grappler_function_item;
GrapplerFunctionItem grappler_function_item = it->second;
MutableGraphView gv(&grappler_function_item.graph);
// Forward shapes from function input nodes to argument nodes.
for (int i = 0; i < grappler_function_item.inputs().size(); ++i) {
auto& fun_input = grappler_function_item.input(i);
NodeDef* fun_node = gv.GetNode(fun_input.node_name);
if (fun_input.placeholders.size() > 1) {
// TODO(jmdecker): Handle case with multiple input placeholders
return errors::Unimplemented(
"Input arguments with multiple placeholders are not yet "
"supported.");
}
NodeDef* fun_node = gv.GetNode(fun_input.input_name);
const TensorId input_tensor = ParseTensorName(function_node->input(i));
if (IsControlInput(input_tensor)) {
@@ -659,18 +649,11 @@ class SymbolicShapeRefiner {
proto.mutable_dim(i)->set_size(-1);
}
}
// Turn _Arg node into a Placeholder. _Arg node is a system op without a
// valid shape function.
*attr_output_shape.mutable_shape() = proto;
fun_node->set_op("Placeholder");
(*fun_node->mutable_attr())["dtype"] = (*fun_node->mutable_attr())["T"];
(*fun_node->mutable_attr()).erase("index");
(*fun_node->mutable_attr()).erase("T");
(*fun_node->mutable_attr())["shape"] = attr_output_shape;
}
// Replace input nodes with Consts, if values are known. Note that
// Replace input Placeholders with Consts, if values are known. Note that
// we don't check exceptions here as it's done in the above loop.
auto* ctx = GetNodeContext(function_node);
auto* ic = ctx->inference_context.get();
@@ -701,15 +684,6 @@ class SymbolicShapeRefiner {
}
}
// Replace output _Retval nodes with Identity nodes. _Retval is a system op
// without outputs and registered shape function.
for (const auto& output_arg : grappler_function_item.outputs()) {
NodeDef* output_node = gv.GetNode(output_arg.node_name);
DCHECK_EQ(output_node->op(), "_Retval");
output_node->set_op("Identity");
output_node->mutable_attr()->erase("index");
}
// Perform inference on function body.
GraphProperties gp(grappler_function_item);
TF_RETURN_IF_ERROR(gp.InferStatically(true, aggressive_shape_inference_));
@@ -720,9 +694,16 @@ class SymbolicShapeRefiner {
ctx->output_tensor_protos.resize(grappler_function_item.output_size(),
nullptr);
for (auto const& out_arg : grappler_function_item.outputs()) {
if (out_arg.output_nodes.size() > 1) {
// TODO(jmdecker): Handle case of multiple output tensors
return errors::Unimplemented(
"Output arguments with multiple output tensors are not yet "
"supported.");
}
// It is guaranteed that output_tensors does not contain any control
// inputs, so port_id >= 0.
TensorId out_tensor = ParseTensorName(out_arg.node_name);
TensorId out_tensor = ParseTensorName(out_arg.output_nodes[0]);
const NodeDef* retnode = gv.GetNode(out_tensor.node());
if (retnode == nullptr) {
@@ -1061,18 +1042,9 @@ class SymbolicShapeRefiner {
CHECK_NOTNULL(function_library_.Find(function_node->op()));
GrapplerFunctionItem grappler_function_item;
Status function_instantiated =
TF_RETURN_IF_ERROR(
MakeGrapplerFunctionItem(*function_def, function_library_,
graph_def_version_, &grappler_function_item);
// If function instantiation failed we will skip it during shape inference.
if (!function_instantiated.ok()) {
VLOG(3) << "Failed to instantiate a function. Error: "
<< function_instantiated.error_message();
fun_to_grappler_function_item_[function_def->signature().name()] =
absl::nullopt;
return Status::OK();
}
graph_def_version_, &grappler_function_item));
if (grappler_function_item.inputs().size() > function_node->input_size()) {
return errors::FailedPrecondition(
@@ -1719,9 +1691,7 @@ class SymbolicShapeRefiner {
std::unordered_map<const NodeDef*, NodeContext> node_to_context_;
std::unordered_map<ShapeId, ShapeHandle, HashShapeId> unknown_shapes_;
std::unordered_map<DimId, DimensionHandle, HashDimId> unknown_dims_;
// Store function instantiations only for valid functions. If function
// instantiation failed it will have an `absl::nullopt`.
std::unordered_map<string, absl::optional<GrapplerFunctionItem>>
std::unordered_map<string, GrapplerFunctionItem>
fun_to_grappler_function_item_;
FunctionLibraryDefinition function_library_;
const std::unordered_map<string, std::unordered_set<int>>& fed_ports_;

View File

@@ -67,10 +67,6 @@ bool IsApproximateEqual(const NodeDef& node) {
return node.op() == "ApproximateEqual";
}
bool IsArg(const NodeDef& node) {
return node.op() == "_Arg" || node.op() == "_DeviceArg";
}
bool IsArgMax(const NodeDef& node) { return node.op() == "ArgMax"; }
bool IsArgMin(const NodeDef& node) { return node.op() == "ArgMin"; }
@@ -423,10 +419,6 @@ bool IsRestore(const NodeDef& node) {
node.op() == "RestoreSlice");
}
bool IsRetval(const NodeDef& node) {
return node.op() == "_Retval" || node.op() == "_DeviceRetval";
}
bool IsReverse(const NodeDef& node) {
return node.op() == "Reverse" || node.op() == "ReverseV2";
}
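
The predicates in this file are plain string comparisons on `node.op()` that optimizers use to branch on op type. A hedged usage sketch (the counting helper is hypothetical; `IsMerge` is an existing predicate in this header):

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/grappler/op_types.h"

namespace tensorflow {
namespace grappler {

// Hypothetical helper: count Merge nodes in a graph using an op predicate.
int CountMergeNodes(const GraphDef& graph) {
  int count = 0;
  for (const NodeDef& node : graph.node()) {
    if (IsMerge(node)) ++count;
  }
  return count;
}

}  // namespace grappler
}  // namespace tensorflow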

View File

@@ -33,7 +33,6 @@ bool IsAnyMaxPool(const NodeDef& node);
bool IsAnyMin(const NodeDef& node);
bool IsAnyMul(const NodeDef& node);
bool IsApproximateEqual(const NodeDef& node);
bool IsArg(const NodeDef& node);
bool IsArgMax(const NodeDef& node);
bool IsArgMin(const NodeDef& node);
bool IsAssert(const NodeDef& node);
@@ -138,7 +137,6 @@ bool IsRelu6Grad(const NodeDef& node);
bool IsReluGrad(const NodeDef& node);
bool IsReshape(const NodeDef& node);
bool IsRestore(const NodeDef& node);
bool IsRetval(const NodeDef& node);
bool IsReverse(const NodeDef& node);
bool IsReverseV2(const NodeDef& node);
bool IsRsqrt(const NodeDef& node);

View File

@@ -226,19 +226,11 @@ Status RecursivelyHandleOp(const NodeDef& node, int64 num_workers,
// Replace optimized function with a new FunctionDef.
TF_RETURN_IF_ERROR(flib->ReplaceFunction(func_name, optimized_func));
} else {
VLOG(2) << "Failed to optimize dataset function. Error: "
<< s.error_message();
}
} else if (IsDatasetNodeOfType(node, kSourceDatasetOps)) {
return errors::InvalidArgument(
"Reached a source dataset: ", node.op(),
" without encountering a batch transformation.");
} else if (IsRetval(node)) {
// _Retvals are added to the function body graph in place of function outputs.
NodeDef* input_node = graph_utils::GetInputNode(node, *graph, 0);
TF_RETURN_IF_ERROR(
RecursivelyHandleOp(*input_node, num_workers, flib, graph));
} else {
return errors::InvalidArgument("Encountered an unsupported op: ",
node.op());

View File

@@ -76,7 +76,7 @@ bool DependencyOptimizer::SafeToRemoveIdentity(const NodeDef& node) const {
return false;
}
for (const auto& consumer : node_map_->GetOutputs(node.name())) {
if (node.input_size() > 1 && (IsRetval(*consumer) || IsMerge(*consumer))) {
if (node.input_size() > 1 && IsMerge(*consumer)) {
return false;
}
if (IsSwitch(*input)) {

View File

@@ -109,10 +109,6 @@ bool IsDirectFunctionCall(const FunctionDef& func, const NodeDef& func_node) {
// Check if func_node has a function attribute with a function name matching
// FunctionDef signature.
bool IsIndirectFunctionCall(const FunctionDef& func, const NodeDef& func_node) {
if (!IsPartitionedCall(func_node) && !IsStatefulPartitionedCall(func_node)) {
return false;
}
auto* func_attr = AttrSlice(func_node).Find(kFuncAttrName);
return func_attr != nullptr && func_attr->has_func() &&
func_attr->func().name() == func.signature().name();
@@ -824,7 +820,10 @@ Status SpecializeFunction(const NodeDef& func_node, const FunctionDef& func,
// update outputs for the fetch nodes, so we just skip them.
std::vector<std::pair<int, int>> output_mapping;
if (!signature.is_in_fetch_set) {
int num_func_outputs = item.output_size();
int num_func_outputs = 0;
for (const auto& out_arg : item.outputs()) {
num_func_outputs += out_arg.output_nodes.size();
}
absl::flat_hash_set<int> remove;
for (int i = 0; i < num_func_outputs; ++i) {
@@ -975,8 +974,10 @@ NodeDef InlinedFunctionInputsNode(const NodeDef& func_node,
AttrValue::ListValue* type_list =
(*inputs.mutable_attr())["T"].mutable_list();
for (const InputArgInstantiation& input_arg : item.inputs()) {
type_list->add_type(input_arg.data_type);
for (const InputArgExpansion& input_arg : item.inputs()) {
for (int i = 0; i < input_arg.placeholders.size(); ++i) {
type_list->add_type(input_arg.data_type);
}
}
return inputs;
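
To make the restored loop concrete: an input argument backed by N placeholders contributes N copies of its dtype to the IdentityN type list. A small sketch under that assumption (all names hypothetical):

#include "tensorflow/core/grappler/utils/functions.h"

namespace tensorflow {
namespace grappler {

void TypeListExample() {
  // One logical input "x" expanded into two placeholders.
  InputArgExpansion arg{/*input_name=*/"x", /*data_type=*/DT_FLOAT,
                        /*is_ref=*/false,
                        /*placeholders=*/{"x/ph_0", "x/ph_1"}};
  // The loop above appends DT_FLOAT twice to the IdentityN "T" type list.
  (void)arg;
}

}  // namespace grappler
}  // namespace tensorflow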
@@ -995,11 +996,12 @@ NodeDef InlinedFunctionOutputsNode(
AttrValue::ListValue* type_list =
(*outputs.mutable_attr())["T"].mutable_list();
for (const OutputArgInstantiation& output_arg : item.outputs()) {
const absl::string_view output_tensor =
output_tensors.at(output_arg.node_name);
type_list->add_type(output_arg.data_type);
outputs.add_input(strings::StrCat(func_node.name(), "/", output_tensor));
for (const OutputArgExpansion& output_arg : item.outputs()) {
for (const string& output_node : output_arg.output_nodes) {
const absl::string_view output_tensor = output_tensors.at(output_node);
type_list->add_type(output_arg.data_type);
outputs.add_input(strings::StrCat(func_node.name(), "/", output_tensor));
}
}
return outputs;
@@ -1026,24 +1028,29 @@ Status InlineDirectFunctionCall(const NodeDef& func_node,
". Error: ", item_status.error_message());
}
// Mapping from input arg node name to function input position.
absl::flat_hash_map<absl::string_view, int> input_args_idx;
for (const InputArgInstantiation& input_arg : item.inputs()) {
const int idx = input_args_idx.size();
input_args_idx[input_arg.node_name] = idx;
// Mapping from input placeholder name to function input position.
absl::flat_hash_map<absl::string_view, int> input_placeholders_idx;
for (const InputArgExpansion& input_arg : item.inputs()) {
for (const string& placeholder : input_arg.placeholders) {
const int idx = input_placeholders_idx.size();
input_placeholders_idx[placeholder] = idx;
}
}
// Mapping from the '_Retval' node name to the output tensor.
absl::flat_hash_map<absl::string_view, absl::string_view> output_tensors;
for (const NodeDef& func_body_node : item.function_body().node()) {
if (!IsRetval(func_body_node)) continue;
if (func_body_node.input_size() != 1) {
return errors::Internal("_Retval node must have single input: ",
SummarizeNodeDef(func_body_node));
// Bypass identity nodes added to the graph in place of function outputs.
absl::flat_hash_set<absl::string_view> output_nodes;
for (const OutputArgExpansion& output_arg : item.outputs()) {
for (const string& output_node : output_arg.output_nodes) {
output_nodes.insert(output_node);
}
output_tensors.emplace(func_body_node.name(), func_body_node.input(0));
}
// For each function output value we added an identity node that reads the
// tensor from one of the function body nodes. When we inline function into
// the main graph we want to bypass these nodes, so we keep a mapping from
// 'output node name' -> 'output tensor name'.
absl::flat_hash_map<absl::string_view, absl::string_view> output_tensors;
// Hook inlined function inputs to IdentityN node.
NodeDef* func_inputs = optimized_graph->add_node();
*func_inputs = InlinedFunctionInputsNode(func_node, item);
@@ -1051,18 +1058,22 @@ Status InlineDirectFunctionCall(const NodeDef& func_node,
for (NodeDef& func_body_node : *item.mutable_function_body().mutable_node()) {
const string& node_name = func_body_node.name();
// Skip function output nodes.
if (IsRetval(func_body_node)) continue;
// Skip output identity node, and update a mapping to the output tensor.
if (IsIdentity(func_body_node) && output_nodes.count(node_name)) {
output_tensors.emplace(node_name, func_body_node.input(0));
continue;
}
// Turn _Arg nodes added in place of input arguments into identity nodes.
const auto input_arg_idx = input_args_idx.find(node_name);
if (input_arg_idx != input_args_idx.end()) {
// Turn placeholders added in place of input arguments into identity nodes.
const auto input_placeholder_idx = input_placeholders_idx.find(node_name);
if (input_placeholder_idx != input_placeholders_idx.end()) {
CHECK_EQ(0, func_body_node.input_size());
func_body_node.set_op("Identity");
func_body_node.mutable_attr()->erase("index");
(*func_body_node.mutable_attr())["T"] = func_body_node.attr().at("dtype");
func_body_node.mutable_attr()->erase("dtype");
func_body_node.mutable_attr()->erase("shape");
func_body_node.add_input(
strings::StrCat(func_inputs->name(), ":", input_arg_idx->second));
func_body_node.add_input(strings::StrCat(func_inputs->name(), ":",
input_placeholder_idx->second));
} else {
// Update the input names if any.
for (string& input : *func_body_node.mutable_input()) {
@@ -1335,8 +1346,10 @@ Status MaybeDeadOutputs(const FunctionOptimizerContext& ctx,
// Names of the function body nodes that return function output values.
absl::flat_hash_set<absl::string_view> output_nodes;
for (const auto& output_arg : item.outputs()) {
output_nodes.insert(output_arg.node_name);
for (const auto& output_expansion : item.outputs()) {
for (const auto& output_node : output_expansion.output_nodes) {
output_nodes.insert(output_node);
}
}
GraphTopologyView topology_view;
@@ -1417,10 +1430,7 @@ Status CheckThatSideEffectsWillExecute(
// can't produce any visible side-effects.
const bool read_only = IsReadVariableOp(func_body_node);
// _Retval marked as stateful, but we will remove it before inlining.
const bool retval = IsRetval(func_body_node);
if (read_only || retval || !node_must_execute) continue;
if (read_only || !node_must_execute) continue;
VLOG(3) << "Check that node " << func_body_node.name()
<< " will execute after inlining.";
@@ -1460,7 +1470,7 @@ Status CheckThatSideEffectsWillExecute(
Status PlaceInlinedFunctionBody(
const NodeDef& func_node, const GrapplerFunctionItem& item,
const absl::flat_hash_map<absl::string_view, int>& input_args_idx,
const absl::flat_hash_map<absl::string_view, int>& input_placeholders_idx,
FunctionOptimizerContext* ctx, GraphDef* placed_graph_def) {
// Control flow lowering and Placer works with a Graph object.
std::unique_ptr<Graph> func_body_graph =
@@ -1488,14 +1498,15 @@ Status PlaceInlinedFunctionBody(
TF_RETURN_IF_ERROR(pass.Run(opt_options));
// ------------------------------------------------------------------------ //
// Before placing the function body nodes we pin input arguments to the
// Before placing the function body nodes we pin input placeholders to the
// same device as their corresponding input nodes.
for (Node* func_body_node : func_body_graph->nodes()) {
const auto input_arg_idx = input_args_idx.find(func_body_node->name());
const auto input_placeholder_idx =
input_placeholders_idx.find(func_body_node->name());
if (input_arg_idx != input_args_idx.end()) {
const int input_idx = input_arg_idx->second;
if (input_placeholder_idx != input_placeholders_idx.end()) {
const int input_idx = input_placeholder_idx->second;
const GraphView::OutputPort output_port =
ctx->graph_view().GetRegularFanin({&func_node, input_idx});
@@ -1620,26 +1631,45 @@ Status InlineIndirectFunctionCall(const NodeDef& func_node,
inputs.push_back(tensor_id);
}
// Mapping from input argument node to function input position.
absl::flat_hash_map<absl::string_view, int> input_args_idx;
for (const InputArgInstantiation& input_arg : item.inputs()) {
const int idx = input_args_idx.size();
input_args_idx[input_arg.node_name] = idx;
// Mapping from input placeholder name to function input position.
absl::flat_hash_map<absl::string_view, int> input_placeholders_idx;
for (const InputArgExpansion& input_arg : item.inputs()) {
for (const string& placeholder : input_arg.placeholders) {
const int idx = input_placeholders_idx.size();
input_placeholders_idx[placeholder] = idx;
}
}
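
The nested loop assigns flat, function-wide positions to placeholders. A standalone sketch of the same indexing (hypothetical helper and data):

#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"

// Hypothetical: with placeholder lists {"a"} and {"b0", "b1"}, the result is
// {a: 0, b0: 1, b1: 2} -- the same flat indexing as the loop above.
absl::flat_hash_map<std::string, int> FlattenPlaceholders(
    const std::vector<std::vector<std::string>>& expansions) {
  absl::flat_hash_map<std::string, int> idx;
  for (const auto& placeholders : expansions) {
    for (const std::string& placeholder : placeholders) {
      const int i = idx.size();
      idx[placeholder] = i;
    }
  }
  return idx;
}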
const string prefix = strings::StrCat(func_node.name(), "/");
// ------------------------------------------------------------------------ //
// Mapping from the '_Retval' node name to the output tensor.
absl::flat_hash_map<absl::string_view, string> output_tensors;
// For each function output value we added an identity node that reads the
// tensor from one of the function body nodes. When we inline function into
// the main graph we want to bypass these nodes, so we keep a mapping from
// 'output node name' -> 'output tensor name'.
absl::flat_hash_map<string, string> output_tensors;
for (const NodeDef& func_body_node : item.function_body().node()) {
if (!IsRetval(func_body_node)) continue;
if (func_body_node.input_size() != 1) {
return errors::Internal("_Retval node must have single input: ",
SummarizeNodeDef(func_body_node));
// Unique names of nodes producing tensors in `output_tensors`.
absl::flat_hash_set<string> output_tensors_nodes;
// Identity nodes added to the function body in place of function outputs.
absl::flat_hash_set<string> output_nodes;
for (const OutputArgExpansion& output_arg : item.outputs()) {
for (const string& output_node : output_arg.output_nodes) {
output_nodes.insert(output_node);
}
}
for (const NodeDef& func_body_node : item.graph.node()) {
const string& node_name = func_body_node.name();
if (IsIdentity(func_body_node) && output_nodes.count(node_name)) {
const string& output_tensor = func_body_node.input(0);
output_tensors.emplace(node_name, output_tensor);
SafeTensorId tensor_id = ParseTensorName(output_tensor);
output_tensors_nodes.insert(tensor_id.node());
}
output_tensors.emplace(func_body_node.name(), func_body_node.input(0));
}
// ------------------------------------------------------------------------ //
@@ -1713,8 +1743,8 @@ Status InlineIndirectFunctionCall(const NodeDef& func_node,
// make sure that after inlining all nodes will have valid device assignment.
GraphDef placed_graph_def;
TF_RETURN_IF_ERROR(PlaceInlinedFunctionBody(func_node, item, input_args_idx,
ctx, &placed_graph_def));
TF_RETURN_IF_ERROR(PlaceInlinedFunctionBody(
func_node, item, input_placeholders_idx, ctx, &placed_graph_def));
// ------------------------------------------------------------------------ //
// After all nodes placed we need to prepare them for inlining into the
@@ -1728,14 +1758,15 @@ Status InlineIndirectFunctionCall(const NodeDef& func_node,
for (NodeDef& func_body_node : *placed_graph_def.mutable_node()) {
const string& node_name = func_body_node.name();
// Turn _Arg nodes added in place of input arguments into identity nodes.
const auto input_arg_idx = input_args_idx.find(node_name);
if (input_arg_idx != input_args_idx.end()) {
// Turn placeholders added in place of input arguments into identity nodes.
const auto input_placeholder_idx = input_placeholders_idx.find(node_name);
if (input_placeholder_idx != input_placeholders_idx.end()) {
DCHECK_EQ(0, func_body_node.input_size());
func_body_node.set_op("Identity");
func_body_node.mutable_attr()->erase("index");
(*func_body_node.mutable_attr())["T"] = func_body_node.attr().at("dtype");
func_body_node.mutable_attr()->erase("dtype");
func_body_node.mutable_attr()->erase("shape");
const int input_idx = input_arg_idx->second;
const int input_idx = input_placeholder_idx->second;
func_body_node.add_input(inputs[input_idx].ToString());
// Add a control dependency on 'inputs_ready' node, to guarantee that all
@@ -1788,7 +1819,17 @@ Status InlineIndirectFunctionCall(const NodeDef& func_node,
// ------------------------------------------------------------------------ //
// Check that after inlining all side-effects will be executed in well defined
// order. We do it by checking if there is a path from stateful/dataset ops to
// one of the control output nodes.
// one of the output nodes.
// Because we rename all the nodes before inlining, we need a copy of
// output_nodes with the new names.
absl::flat_hash_set<string> inlined_output_nodes;
for (const string& output_node : output_nodes) {
inlined_output_nodes.insert(inlined_node_name(output_node));
}
const auto is_inlined_output_node = [&](const NodeDef& node) -> bool {
return inlined_output_nodes.find(node.name()) != inlined_output_nodes.end();
};
// Names of the inlined control output nodes.
absl::flat_hash_set<string> inlined_control_output_nodes;
@@ -1844,8 +1885,10 @@ Status InlineIndirectFunctionCall(const NodeDef& func_node,
}
for (NodeDef& func_body_node : *placed_graph_def.mutable_node()) {
// We bypass _Retval nodes and fetch tensors from `retval.input(0)`.
if (IsRetval(func_body_node)) continue;
// Skip output identity nodes.
if (IsIdentity(func_body_node) && is_inlined_output_node(func_body_node))
continue;
optimized_graph->add_node()->Swap(&func_body_node);
}
@@ -1853,17 +1896,19 @@ Status InlineIndirectFunctionCall(const NodeDef& func_node,
// not copy the original function call node, so we have to setup tensor
// mapping from old output tensors, to the outputs of inlined nodes.
int output_idx = 0;
for (const OutputArgInstantiation& output : item.outputs()) {
const string& output_tensor = output_tensors.at(output.node_name);
for (const OutputArgExpansion& output : item.outputs()) {
for (const string& output_node : output.output_nodes) {
const string& output_tensor = output_tensors.at(output_node);
const SafeTensorId from_tensor(func_node.name(), output_idx++);
const SafeTensorId to_tensor = ParseTensorName(output_tensor);
const SafeTensorId from_tensor(func_node.name(), output_idx++);
const SafeTensorId to_tensor = ParseTensorName(output_tensor);
const SafeTensorId inlined_to_tensor =
SafeTensorId(absl::StrCat(func_node.name(), "/", to_tensor.node()),
to_tensor.index());
const SafeTensorId inlined_to_tensor =
SafeTensorId(absl::StrCat(func_node.name(), "/", to_tensor.node()),
to_tensor.index());
ctx->AddTensorMapping(from_tensor, inlined_to_tensor);
ctx->AddTensorMapping(from_tensor, inlined_to_tensor);
}
}
// If function call node was in keep_ops set, it means that we need to keep a

View File

@@ -123,7 +123,7 @@ TEST_F(FunctionOptimizerTest, InlineFunction_SkipErrorsIfGraphNotModified) {
// Standard XTimesTwo() function.
FunctionDef x_times_two = test::function::XTimesTwo();
// Function signature has non-type attribute (currently not supported).
// Function with sequence of tensors as an input (currently not supported).
FunctionDef my_identity_n = FunctionDefHelper::Create(
// Name
"MyIdentityN",

View File

@@ -367,8 +367,8 @@ TEST_F(MetaOptimizerTest, OptimizeFunctionLibrary) {
if (node.name() == "my_mul/inlined_inputs" && ++count) {
EXPECT_EQ("IdentityN", node.op());
EXPECT_EQ(2, node.input_size());
EXPECT_EQ("x", node.input(0));
EXPECT_EQ("x", node.input(1));
EXPECT_EQ("x:0", node.input(0));
EXPECT_EQ("x:0", node.input(1));
} else if (node.name() == "my_mul/x" && ++count) {
EXPECT_EQ("Identity", node.op());
EXPECT_EQ(1, node.input_size());
@@ -623,17 +623,17 @@ TEST_F(MetaOptimizerTest, OptimizeFunctionLibraryWithRestrictions) {
MetaOptimizer optimizer(nullptr, config_proto);
// Define simple function library with two identical mul functions.
FunctionDef mul_func_1 = FunctionDefHelper::Create(
"MyMul1", {"x:float", "y:float"}, {"z:float"}, {},
{{{"mul"}, "Mul", {"x", "y"}, {{"T", DT_FLOAT}}}},
/*ret_def=*/
{{"z", "mul:z:0"}});
FunctionDef mul_func_1 =
FunctionDefHelper::Create("MyMul1", {"x:float", "y:float"}, {"z:float"},
{}, {{{"mul"}, "Mul", {"x", "y"}, {}}},
/*ret_def=*/
{{"z", "mul:z:0"}});
FunctionDef mul_func_2 = FunctionDefHelper::Create(
"MyMul2", {"x:float", "y:float"}, {"z:float"}, {},
{{{"mul"}, "Mul", {"x", "y"}, {{"T", DT_FLOAT}}}},
/*ret_def=*/
{{"z", "mul:z:0"}});
FunctionDef mul_func_2 =
FunctionDefHelper::Create("MyMul2", {"x:float", "y:float"}, {"z:float"},
{}, {{{"mul"}, "Mul", {"x", "y"}, {}}},
/*ret_def=*/
{{"z", "mul:z:0"}});
// Tensorflow graph:
//

View File

@@ -172,7 +172,6 @@ cc_library(
hdrs = ["functions.h"],
visibility = ["//visibility:public"],
deps = [
"//tensorflow/core:core_cpu_base",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",

View File

@@ -17,9 +17,7 @@ limitations under the License.
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_replace.h"
#include "absl/strings/substitute.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function.pb.h"
@@ -36,29 +34,306 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
namespace {
Status RegisterFunctionBodyOutputs(const OpRegistrationData& registration,
const NodeDef& node,
GrapplerFunctionConnectivity* connectivity) {
tensorflow::NameRangeMap outputs_range_map;
TF_RETURN_IF_ERROR(tensorflow::NameRangesForNode(
node, registration.op_def, nullptr, &outputs_range_map));
connectivity->RegisterFunctionBodyOutputs(node.name(),
std::move(outputs_range_map));
return Status::OK();
}
Status RegisterFunctionBodyOutputs(const FunctionLibraryDefinition& flib,
const NodeDef& node,
GrapplerFunctionConnectivity* connectivity) {
const OpRegistrationData* registration;
TF_RETURN_IF_ERROR(flib.LookUp(node.op(), &registration));
return RegisterFunctionBodyOutputs(*registration, node, connectivity);
}
// Replace the placeholder attribute values with the values specified in
// instantiation attributes.
Status ResolveFunctionBodyNodeAttrPlaceholders(
const AttrSlice& func_instantiation_attr, NodeDef* node) {
for (auto& attr : *node->mutable_attr()) {
const string& placeholder = attr.second.placeholder();
if (placeholder.empty()) continue;
const AttrValue* attr_value = func_instantiation_attr.Find(placeholder);
if (attr_value) {
attr.second = *attr_value;
} else {
return errors::InvalidArgument("Can't resolve placeholder: ",
placeholder);
}
}
return Status::OK();
}
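
A before/after sketch of this resolution step (node and attribute names are hypothetical):

#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"

namespace tensorflow {

void PlaceholderResolutionExample() {
  // Before: a function body node whose "T" attr is the placeholder "T".
  NodeDef node;
  (*node.mutable_attr())["T"].set_placeholder("T");
  // Given instantiation attrs {"T": DT_FLOAT}, the helper rewrites it to:
  AttrValue resolved;
  resolved.set_type(DT_FLOAT);
  (*node.mutable_attr())["T"] = resolved;
}

}  // namespace tensorflow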
} // namespace
void GrapplerFunctionConnectivity::RegisterInputArgExpansion(
InputArgExpansion input_arg_expansion) {
string input_name = input_arg_expansion.input_name;
const auto& placeholders = input_arg_expansion.placeholders;
for (int i = 0; i < placeholders.size(); ++i) {
const string& placeholder = input_arg_expansion.placeholders[i];
input_arg_placeholders_.insert(
{placeholder, InputArgPlaceholder{input_name, /*input_index=*/i}});
}
input_arg_expansions_.insert(
{std::move(input_name), std::move(input_arg_expansion)});
}
void GrapplerFunctionConnectivity::RegisterFunctionBodyOutputs(
const string& node_name, tensorflow::NameRangeMap&& outputs) {
function_body_outputs_[node_name] = std::move(outputs);
}
Status GrapplerFunctionConnectivity::ExpandFunctionDefInput(
const string& func_def_input, std::vector<string>* graph_def_inputs) const {
using ::tensorflow::strings::Scanner;
if (IsControlInput(func_def_input)) {
graph_def_inputs->push_back(func_def_input);
return Status::OK();
}
// Parse input format: "node_name[:node_output][:position]"
string node_name;
string node_output;
int position = -1;
StringPiece capture;
StringPiece remaining;
// Parse "node_name"
if (Scanner(func_def_input)
.One(strings::Scanner::LETTER_DIGIT_DOT_UNDERSCORE)
.Any(strings::Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE)
.GetResult(&remaining, &capture)) {
node_name = string(capture.data(), capture.size());
}
// Parse "node_output" if it exists
if (Scanner(remaining)
.OneLiteral(":")
.RestartCapture()
.One(strings::Scanner::LETTER)
.Any(strings::Scanner::LETTER_DIGIT_UNDERSCORE)
.GetResult(&remaining, &capture)) {
node_output = string(capture.data(), capture.size());
}
// Parse "position" if it exists
if (Scanner(remaining)
.OneLiteral(":")
.RestartCapture()
.Many(strings::Scanner::DIGIT)
.GetResult(nullptr, &capture)) {
CHECK(strings::safe_strto32(capture, &position));
}
// If "node_output" is not empty, it must be an output of a function body node
bool is_function_body_output = !node_output.empty();
// Function input argument: "node_name[:position]"
if (!is_function_body_output) {
auto input_arg = input_arg_expansions_.find(node_name);
if (input_arg != input_arg_expansions_.end()) {
const InputArgExpansion& input_arg_expansion = input_arg->second;
const auto& placeholders = input_arg_expansion.placeholders;
if (position == -1) {
// If position is not defined use all placeholders
graph_def_inputs->reserve(placeholders.size());
for (const string& placeholder : placeholders) {
graph_def_inputs->push_back(placeholder);
}
} else {
if (position > input_arg_expansion.placeholders.size() - 1) {
return errors::InvalidArgument("Invalid input ", node_name,
"position: ", position,
" (out of range)");
}
graph_def_inputs->push_back(input_arg_expansion.placeholders[position]);
}
return Status::OK();
}
}
// Function body output: "node_name:node_output[:position]"
if (is_function_body_output) {
auto function_body_outputs = function_body_outputs_.find(node_name);
if (function_body_outputs != function_body_outputs_.end()) {
const tensorflow::NameRangeMap& outputs = function_body_outputs->second;
auto output = outputs.find(node_output);
if (output != outputs.end()) {
const auto& output_range = output->second;
if (position == -1) {
graph_def_inputs->reserve(graph_def_inputs->size() +
output_range.second - output_range.first);
// If position is not defined expand node output range
for (int i = output_range.first; i < output_range.second; ++i) {
graph_def_inputs->push_back(
i == 0 ? node_name : absl::StrCat(node_name, ":", i));
}
} else {
if (position > (output_range.second - output_range.first)) {
return errors::InvalidArgument(
"Invalid node ", node_name, " output ", node_output,
" position: ", position, " (out of range)");
}
int pos = output_range.first + position;
graph_def_inputs->push_back(
pos == 0 ? node_name : absl::StrCat(node_name, ":", pos));
}
return Status::OK();
}
}
}
return errors::InvalidArgument("Failed to expand a function def input: ",
func_def_input);
}
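
A hedged usage sketch of the expansion rules above, assuming the `InputArgExpansion` layout used later in this file (all names hypothetical):

#include <vector>

#include "tensorflow/core/grappler/utils/functions.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {
namespace grappler {

void ExpandExample() {
  GrapplerFunctionConnectivity connectivity;
  connectivity.RegisterInputArgExpansion(
      InputArgExpansion{/*input_name=*/"x", /*data_type=*/DT_FLOAT,
                        /*is_ref=*/false,
                        /*placeholders=*/{"x/ph_0", "x/ph_1"}});
  std::vector<string> expanded;
  TF_CHECK_OK(connectivity.ExpandFunctionDefInput("x", &expanded));
  // expanded == {"x/ph_0", "x/ph_1"}: no position selects all placeholders.
  expanded.clear();
  TF_CHECK_OK(connectivity.ExpandFunctionDefInput("x:1", &expanded));
  // expanded == {"x/ph_1"}: an explicit position selects one placeholder.
}

}  // namespace grappler
}  // namespace tensorflow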
Status GrapplerFunctionConnectivity::ExpandNodeInputs(
NodeDef* function_body_node) const {
std::vector<string> expanded_inputs;
for (const string& function_def_input : function_body_node->input()) {
TF_RETURN_IF_ERROR(
ExpandFunctionDefInput(function_def_input, &expanded_inputs));
}
function_body_node->clear_input();
for (string& expanded_input : expanded_inputs)
function_body_node->add_input(std::move(expanded_input));
return Status::OK();
}
Status GrapplerFunctionConnectivity::AsFunctionDefInput(
const string& graph_def_input, string* func_def_input) const {
if (IsControlInput(graph_def_input)) {
*func_def_input = graph_def_input;
return Status::OK();
}
const TensorId tensor = ParseTensorName(graph_def_input);
DCHECK_GE(tensor.index(), 0);
const absl::string_view node_name = tensor.node();
const int index = tensor.index();
// Check if it's an input arg placeholder
if (tensor.index() == 0) {
const auto is_input_placeholder = input_arg_placeholders_.find(node_name);
if (is_input_placeholder != input_arg_placeholders_.end()) {
const InputArgPlaceholder& placeholder = is_input_placeholder->second;
*func_def_input =
absl::StrCat(placeholder.input_name, ":", placeholder.input_index);
return Status::OK();
}
}
// It must be output from one of the function body nodes
const auto is_body_output = function_body_outputs_.find(tensor.node());
if (is_body_output != function_body_outputs_.end()) {
const tensorflow::NameRangeMap& outputs_range_map = is_body_output->second;
for (const auto& el : outputs_range_map) {
const auto& output_name = el.first;
const auto& output_range = el.second;
if (index >= output_range.first && index < output_range.second) {
int pos = index - output_range.first;
*func_def_input = absl::StrCat(node_name, ":", output_name, ":", pos);
return Status::OK();
}
}
}
return errors::InvalidArgument("Unknown graph def input: ", graph_def_input);
}
Status GrapplerFunctionConnectivity::AsFunctionDefNode(
NodeDef* function_body_node) const {
string func_def_input;
for (int i = 0; i < function_body_node->input_size(); ++i) {
TF_RETURN_IF_ERROR(
AsFunctionDefInput(function_body_node->input(i), &func_def_input));
function_body_node->set_input(i, func_def_input);
}
return Status::OK();
}
Status GrapplerFunctionItemInstantiation::GetTypeAttr(
const string& type_attr_name, DataType* data_type) const {
const AttrValue* type_attr = func_instantiation_attr_.Find(type_attr_name);
if (type_attr == nullptr) {
return errors::InvalidArgument("Type attribute ", type_attr_name,
" is not defined");
} else if (type_attr->type() == DT_INVALID) {
return errors::InvalidArgument("Type attribute ", type_attr_name,
" is not defined with a valid type");
} else {
*data_type = type_attr->type();
}
return Status::OK();
}
Status GrapplerFunctionItemInstantiation::GetArgType(
const OpDef::ArgDef& arg, DataType* data_type) const {
if (arg.type() != DT_INVALID) {
*data_type = arg.type();
} else {
if (!arg.type_list_attr().empty() || !arg.number_attr().empty()) {
return errors::InvalidArgument(
"Arguments with sequence of tensors are not supported. Unsupported "
"argument name: ",
arg.name());
}
TF_RETURN_IF_ERROR(GetTypeAttr(arg.type_attr(), data_type));
}
return Status::OK();
}
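
A sketch of dtype resolution through instantiation attributes, assuming the constructor used elsewhere in this file (the call node is a stand-in):

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/grappler/utils/functions.h"

namespace tensorflow {
namespace grappler {

void GetArgTypeExample() {
  OpDef::ArgDef arg;
  arg.set_name("x");
  arg.set_type_attr("T");  // dtype must come from instantiation attributes

  NodeDef call;            // hypothetical function call node carrying attrs
  (*call.mutable_attr())["T"].set_type(DT_FLOAT);
  const AttrSlice func_instantiation_attr(call);

  GrapplerFunctionItemInstantiation instantiation(func_instantiation_attr);
  DataType dtype;
  TF_CHECK_OK(instantiation.GetArgType(arg, &dtype));  // dtype == DT_FLOAT
}

}  // namespace grappler
}  // namespace tensorflow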
GrapplerFunctionItem::GrapplerFunctionItem(
string func_name, string description, AttrSlice func_attr,
std::vector<InputArgInstantiation> input_args,
std::vector<OutputArgInstantiation> output_args,
std::vector<InputArgExpansion> input_arg_expansions,
std::vector<OutputArgExpansion> output_arg_expansions,
std::vector<ControlOutput> control_outputs, const int graph_def_version,
const bool is_stateful, GraphDef&& function_body)
: description_(std::move(description)),
func_attr_(func_attr),
input_args_(std::move(input_args)),
output_args_(std::move(output_args)),
input_arg_expansions_(std::move(input_arg_expansions)),
output_arg_expansions_(std::move(output_arg_expansions)),
control_outputs_(std::move(control_outputs)),
is_stateful_(is_stateful) {
id = std::move(func_name);
graph = std::move(function_body);
graph.mutable_versions()->set_producer(graph_def_version);
// Fill the feed nodes with function input arguments.
for (const InputArgInstantiation& input_arg : input_args_) {
feed.push_back({input_arg.node_name, Tensor()});
graph.mutable_versions()->set_producer(graph_def_version);
// Fill the feed nodes with input placeholders.
for (const InputArgExpansion& input_arg : input_arg_expansions_) {
for (const string& placeholder : input_arg.placeholders) {
feed.push_back({placeholder, Tensor()});
}
}
// Fill the fetch nodes with outputs.
for (const OutputArgInstantiation& output_arg : output_args_) {
fetch.push_back(output_arg.node_name);
for (const OutputArgExpansion& output_arg : output_arg_expansions_) {
for (const string& output_node : output_arg.output_nodes) {
fetch.push_back(output_node);
}
}
// We must keep all control output nodes.
for (const ControlOutput& control_output : control_outputs_) {
@@ -72,29 +347,28 @@ GrapplerFunctionItem::GrapplerFunctionItem(
const string& GrapplerFunctionItem::description() const { return description_; }
const std::vector<InputArgInstantiation>& GrapplerFunctionItem::inputs() const {
return input_args_;
const std::vector<InputArgExpansion>& GrapplerFunctionItem::inputs() const {
return input_arg_expansions_;
}
const InputArgInstantiation& GrapplerFunctionItem::input(int i) const {
return input_args_[i];
const InputArgExpansion& GrapplerFunctionItem::input(int i) const {
return input_arg_expansions_[i];
}
const std::size_t GrapplerFunctionItem::input_size() const {
return input_args_.size();
return input_arg_expansions_.size();
}
const std::vector<OutputArgInstantiation>& GrapplerFunctionItem::outputs()
const {
return output_args_;
const std::vector<OutputArgExpansion>& GrapplerFunctionItem::outputs() const {
return output_arg_expansions_;
}
const OutputArgInstantiation& GrapplerFunctionItem::output(int i) const {
return output_args_[i];
const OutputArgExpansion& GrapplerFunctionItem::output(int i) const {
return output_arg_expansions_[i];
}
const std::size_t GrapplerFunctionItem::output_size() const {
return output_args_.size();
return output_arg_expansions_.size();
}
const std::vector<ControlOutput>& GrapplerFunctionItem::control_outputs()
@@ -153,23 +427,15 @@ Status InstantiationTypeParameters(
return errors::InvalidArgument("Type parameters output map must be empty");
}
const auto resolve_type_attr = [&](const OpDef::ArgDef& arg) -> Status {
if (!arg.type_attr().empty()) {
DataType dtype;
TF_RETURN_IF_ERROR(
GetNodeAttr(func_instantiation_attr, arg.type_attr(), &dtype));
type_parameters->emplace(arg.type_attr(), dtype);
GrapplerFunctionItemInstantiation instantiation(func_instantiation_attr);
} else if (!arg.type_list_attr().empty()) {
std::vector<DataType> dtypes;
TF_RETURN_IF_ERROR(
GetNodeAttr(func_instantiation_attr, arg.type_list_attr(), &dtypes));
int index = 0;
for (const DataType& dtype : dtypes) {
type_parameters->emplace(absl::StrCat(arg.type_list_attr(), ":", index),
dtype);
++index;
}
const auto resolve_type_attr = [&](const OpDef::ArgDef& arg) {
// Check if it's an unknown and unresolved type.
if (arg.type() == DT_INVALID &&
type_parameters->find(arg.type_attr()) == type_parameters->end()) {
DataType data_type;
TF_RETURN_IF_ERROR(instantiation.GetArgType(arg, &data_type));
type_parameters->insert({arg.type_attr(), data_type});
}
return Status::OK();
};
@@ -193,7 +459,8 @@ Status InstantiationBodyParameters(
for (auto& attr : func_body_node.attr()) {
const string& placeholder = attr.second.placeholder();
if (placeholder.empty() || body_parameters->contains(placeholder)) {
if (placeholder.empty() ||
body_parameters->find(placeholder) != body_parameters->end()) {
continue;
}
@@ -231,13 +498,15 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func,
}
}
// Instantiate function into a statically defined FunctionBody Graph.
std::unique_ptr<FunctionBody> fbody;
TF_RETURN_IF_ERROR(
FunctionDefToBodyHelper(func, func_instantiation_attr, &flib, &fbody));
// Helper methods to lookup function instantiation attributes
GrapplerFunctionItemInstantiation instantiation(func_instantiation_attr);
// Mapping from FunctionDef input format (name[:output][:position]) to
// GraphDef input format (name[:position])
GrapplerFunctionConnectivity connectivity;
// Instantiate function body into a statically defined graph def.
GraphDef function_body;
fbody->graph->ToGraphDef(&function_body);
// Function body shares the library with the graph that instantiated it. We do
// not need a full copy of the function library, just the reachable subset.
@@ -249,25 +518,122 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func,
flib.num_functions() - function_body.library().function_size(),
signature.name(), function_body.library().function_size());
const int num_instantiated_inputs = fbody->arg_types.size();
const int num_instantiated_outputs = fbody->ret_types.size();
// TODO(ezhulenev): support functions with tensor sequence inputs/outputs
std::vector<InputArgInstantiation> inputs;
inputs.reserve(num_instantiated_inputs);
for (int in_id = 0; in_id < num_instantiated_inputs; ++in_id) {
const Node* node = fbody->arg_nodes[in_id];
const DataType& dtype = fbody->arg_types[in_id];
inputs.emplace_back(node->name(), dtype);
// Make sure that there are no tensor lists in inputs or outputs.
for (const OpDef::ArgDef& input : signature.input_arg()) {
if (!input.type_list_attr().empty() || !input.number_attr().empty()) {
return errors::InvalidArgument(
"Inputs with lists of tensors are not supported. Input: ",
input.name());
}
}
for (const OpDef::ArgDef& output : signature.output_arg()) {
if (!output.type_list_attr().empty() || !output.number_attr().empty()) {
return errors::InvalidArgument(
"Outputs with lists of tensors are not supported. Output: ",
output.name());
}
}
std::vector<OutputArgInstantiation> outputs;
outputs.reserve(num_instantiated_outputs);
std::vector<InputArgExpansion> inputs;
inputs.reserve(signature.input_arg_size());
for (int out_id = 0; out_id < num_instantiated_outputs; ++out_id) {
const Node* node = fbody->ret_nodes[out_id];
const DataType& dtype = fbody->ret_types[out_id];
outputs.emplace_back(node->name(), dtype);
// For each input argument create a placeholder in function body.
for (const OpDef::ArgDef& input : signature.input_arg()) {
DataType input_data_type;
TF_RETURN_IF_ERROR(instantiation.GetArgType(input, &input_data_type));
NodeDef* placeholder = function_body.add_node();
placeholder->set_name(input.name());
placeholder->set_op("Placeholder");
(*placeholder->mutable_attr())["dtype"].set_type(input_data_type);
(*placeholder->mutable_attr())["shape"].mutable_shape()->set_unknown_rank(
true);
InputArgExpansion input_expansion{/*input_name=*/input.name(),
/*data_type=*/input_data_type,
/*is_ref=*/input.is_ref(),
/*placeholders=*/{input.name()}};
connectivity.RegisterInputArgExpansion(input_expansion);
inputs.push_back(std::move(input_expansion));
}
// Keep names of all nodes in the function body to guarantee that we do not
// add an identity with a duplicate name.
absl::flat_hash_set<absl::string_view> func_body_nodes;
// Generate unique output node name: "${out_arg_name}_output_node_${index}".
const auto output_node_name = [&func_body_nodes](const OpDef::ArgDef& out,
int index) -> string {
string name = absl::StrCat(out.name(), "_output_node_", index);
int i = 1;
while (func_body_nodes.find(name) != func_body_nodes.end()) {
name = absl::StrCat(out.name(), "_output_node_", index, "_", i++);
}
return name;
};
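
Restated as a standalone helper, the naming scheme above behaves like this (hypothetical function; e.g. "z_output_node_0", then "z_output_node_0_1" on a collision):

#include <string>

#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"

// Hypothetical restatement of the lambda above.
std::string UniqueOutputNodeName(const absl::flat_hash_set<std::string>& taken,
                                 const std::string& arg_name, int index) {
  std::string name = absl::StrCat(arg_name, "_output_node_", index);
  int i = 1;
  while (taken.contains(name)) {
    name = absl::StrCat(arg_name, "_output_node_", index, "_", i++);
  }
  return name;
}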
// Add all function nodes to the function body.
for (const NodeDef& func_def_node : func.node_def()) {
func_body_nodes.insert(func_def_node.name());
NodeDef* new_node = function_body.add_node();
*new_node = func_def_node;
const OpRegistrationData* registration;
TF_RETURN_IF_ERROR(flib.LookUp(func_def_node.op(), &registration));
// Resolve all placeholder values using function instantiation attributes.
TF_RETURN_IF_ERROR(ResolveFunctionBodyNodeAttrPlaceholders(
func_instantiation_attr, new_node));
// Register node output range in a function connectivity.
TF_RETURN_IF_ERROR(RegisterFunctionBodyOutputs(*registration, func_def_node,
&connectivity));
}
// Rewrite inputs to use GraphDef format
for (NodeDef& node : *function_body.mutable_node()) {
TF_RETURN_IF_ERROR(connectivity.ExpandNodeInputs(&node));
}
std::vector<OutputArgExpansion> outputs;
outputs.reserve(signature.output_arg_size());
// For each function output argument we create an Identity node in the
// function body that reads the output tensor from a function body node.
for (const OpDef::ArgDef& out : signature.output_arg()) {
DataType output_data_type;
TF_RETURN_IF_ERROR(instantiation.GetArgType(out, &output_data_type));
std::vector<string> output_tensors;
auto ret = func.ret().find(out.name());
TF_RETURN_IF_ERROR(
ret != func.ret().end()
// Expand outputs using provided output mapping
? connectivity.ExpandFunctionDefInput(ret->second, &output_tensors)
// Otherwise output must be one of the function inputs
: connectivity.ExpandFunctionDefInput(out.name(), &output_tensors));
absl::InlinedVector<string, 1> output_nodes;
for (int i = 0; i < output_tensors.size(); ++i) {
const string& output_tensor = output_tensors[i];
NodeDef* identity = function_body.add_node();
identity->set_name(output_node_name(out, i));
identity->set_op("Identity");
(*identity->mutable_attr())["T"].set_type(output_data_type);
identity->add_input(output_tensor);
output_nodes.push_back(identity->name());
}
OutputArgExpansion output{/*output_name=*/out.name(),
/*data_type=*/output_data_type,
/*is_ref=*/out.is_ref(),
/*output_nodes=*/std::move(output_nodes)};
outputs.push_back(std::move(output));
}
// Control outputs ensure that all side-effectful nodes in the function body
@@ -295,42 +661,70 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func,
item);
}
// Register GrapplerFunctionItem input arg expansion and function body outputs
// in the GrapplerFunctionConnectivity.
Status RegisterGrapplerFunctionConnectivity(
const GrapplerFunctionItem& item, const FunctionLibraryDefinition& flib,
GrapplerFunctionConnectivity* connectivity) {
for (const InputArgExpansion& input : item.inputs()) {
connectivity->RegisterInputArgExpansion(input);
}
for (const NodeDef& func_body_node : item.function_body().node()) {
TF_RETURN_IF_ERROR(
RegisterFunctionBodyOutputs(flib, func_body_node, connectivity));
}
return Status::OK();
}
Status ReplaceInputWithConst(const NodeDef& input_const, int input_index,
GrapplerFunctionItem* item) {
if (!IsConstant(input_const)) {
return errors::InvalidArgument("Input node is not a constant: ",
SummarizeNodeDef(input_const));
}
if (input_index < 0 || input_index >= item->input_size()) {
return errors::InvalidArgument(
"Function input index is out of bound: index=", input_index,
" input_size=", item->input_size());
return errors::InvalidArgument("Input node ", input_const.name(),
" is not a constant");
}
const InputArgInstantiation& input_arg = item->input(input_index);
auto& inputs = item->input_arg_expansions_;
// Find input arg expansion and input placeholder position in it for the
// given function input position.
InputArgExpansion* input_arg_expansion = nullptr;
int placeholder_idx = input_index;
for (InputArgExpansion& input : inputs) {
if (placeholder_idx < input.placeholders.size()) {
input_arg_expansion = &input;
break;
}
placeholder_idx -= input.placeholders.size();
}
if (input_arg_expansion == nullptr) {
return errors::InvalidArgument("Input placeholder not found: input_index=",
input_index, " function=", item->id);
}
// Delete placeholder from input expansion.
string placeholder_name = input_arg_expansion->placeholders[placeholder_idx];
input_arg_expansion->placeholders.erase(
input_arg_expansion->placeholders.begin() + placeholder_idx);
// Delete empty input expansions.
inputs.erase(std::remove_if(inputs.begin(), inputs.end(),
[](const InputArgExpansion& input) {
return input.placeholders.empty();
}),
inputs.end());
// Replace placeholder node in the function body with a const node.
for (NodeDef& node : *item->graph.mutable_node()) {
// Replace '_Arg' node in the function body with a 'Const' node.
if (node.name() == input_arg.node_name) {
if (node.name() == placeholder_name) {
node = input_const;
node.set_name(input_arg.node_name);
node.clear_input();
node.set_name(placeholder_name);
node.clear_input(); // remove potential control inputs
node.clear_device(); // device placement is defined by instantiating node
}
// Update index in all inputs after the removed const input.
if (IsArg(node)) {
auto attrs = AttrSlice(node);
int index;
TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "index", &index));
if (index >= input_index) {
(*node.mutable_attr())["index"].set_i(index - 1);
}
}
}
item->input_args_.erase(item->input_args_.begin() + input_index);
return Status::OK();
}
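
The placeholder lookup above maps a flat function input position onto an (expansion, placeholder) pair. A minimal sketch of the same arithmetic (hypothetical helper; with counts {2, 1}, input_index 2 lands in expansion 1 at offset 0):

#include <vector>

// Hypothetical: returns the expansion index and rewrites *index to the
// placeholder offset inside it, or returns -1 if out of range.
int FindExpansion(const std::vector<int>& placeholder_counts, int* index) {
  for (int e = 0; e < static_cast<int>(placeholder_counts.size()); ++e) {
    if (*index < placeholder_counts[e]) return e;
    *index -= placeholder_counts[e];
  }
  return -1;
}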
@@ -339,24 +733,31 @@ Status RemoveFunctionOutputs(const absl::flat_hash_set<int>& remove_outputs,
std::vector<std::pair<int, int>>* output_mapping) {
DCHECK(output_mapping->empty());
// Code below assumes that we do not support tensor list outputs and there is
// a 1-to-1 mapping between output tensor and output argument expansion.
for (const OutputArgExpansion& out_arg : item->outputs()) {
DCHECK(out_arg.output_nodes.size() == 1)
<< "Output arg expansion must have single output";
}
// Do some sanity checking of the removed outputs positions.
for (int remove_output : remove_outputs) {
if (remove_output < 0 || remove_output >= item->output_size()) {
return errors::InvalidArgument(
"Function output index is out of bound: index=", remove_output,
" output_size=", item->output_size());
" max_output_index=", item->output_size());
}
}
absl::flat_hash_set<const OutputArgInstantiation*> remove_output_args;
const auto is_remove_output_arg = [&](const OutputArgInstantiation& output) {
absl::flat_hash_set<const OutputArgExpansion*> remove_output_args;
const auto is_remove_output_arg = [&](const OutputArgExpansion& output) {
return remove_output_args.find(&output) != remove_output_args.end();
};
for (int i = 0; i < item->output_size(); ++i) {
const OutputArgInstantiation& output = item->output(i);
if (remove_outputs.contains(i)) {
VLOG(3) << "Remove functions output: name=" << output.node_name
const OutputArgExpansion& output = item->output(i);
if (remove_outputs.find(i) != remove_outputs.end()) {
VLOG(3) << "Remove functions output: output_name=" << output.output_name
<< "(index = " << i << ")";
remove_output_args.insert(&output);
} else if (!remove_output_args.empty()) {
@@ -365,130 +766,12 @@ Status RemoveFunctionOutputs(const absl::flat_hash_set<int>& remove_outputs,
}
}
// Update 'index' attribute in all '_Retval' nodes that are in output mapping.
for (NodeDef& node : *item->graph.mutable_node()) {
if (IsRetval(node)) {
auto attrs = AttrSlice(node);
int index;
TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "index", &index));
for (const auto& mapping : *output_mapping) {
const int from = mapping.first;
const int to = mapping.second;
if (index == from) {
(*node.mutable_attr())["index"].set_i(to);
}
}
}
}
auto& o = item->output_args_;
auto& o = item->output_arg_expansions_;
o.erase(std::remove_if(o.begin(), o.end(), is_remove_output_arg), o.end());
return Status::OK();
}
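
A hedged usage sketch: assuming a `GrapplerFunctionItem item` with three outputs and the parameter order used throughout this file, removing output 1 shifts output 2 into position 1:

absl::flat_hash_set<int> remove_outputs = {1};
std::vector<std::pair<int, int>> output_mapping;
TF_CHECK_OK(RemoveFunctionOutputs(remove_outputs, &item, &output_mapping));
// output_mapping == {{2, 1}}: the old output 2 becomes the new output 1.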
namespace {
// FunctionDef uses a different connectivity encoding for the function body
// nodes than a GraphDef (see function.proto for details). This is a helper class that
// converts inputs in GraphDef format (node[:position]) to the FunctionDef
// format (node:output[:position]).
class MakeFunctionDefHelper {
public:
MakeFunctionDefHelper() = default;
Status Initialize(const GrapplerFunctionItem& item,
const FunctionLibraryDefinition& flib);
// Converts input name from GraphDef format (name[:position]) to the
// FunctionDef input format (name[:output][:position]) using registered input
// arg instantiations and function body outputs.
Status AsFunctionDefInput(const string& graph_def_input,
string* func_def_input) const;
// Updates Node inputs from GraphDef to FunctionDef format.
Status AsFunctionDefNode(NodeDef* function_body_node) const;
private:
absl::flat_hash_set<absl::string_view> input_nodes_;
// Mapping from function body node name to output names range map.
absl::flat_hash_map<string, tensorflow::NameRangeMap> function_body_outputs_;
};
Status MakeFunctionDefHelper::Initialize(
const GrapplerFunctionItem& item, const FunctionLibraryDefinition& flib) {
for (const InputArgInstantiation& input_arg : item.inputs()) {
input_nodes_.insert(input_arg.node_name);
}
for (const NodeDef& node : item.function_body().node()) {
const OpRegistrationData* registration;
TF_RETURN_IF_ERROR(flib.LookUp(node.op(), &registration));
tensorflow::NameRangeMap outputs_range_map;
TF_RETURN_IF_ERROR(tensorflow::NameRangesForNode(
node, registration->op_def, nullptr, &outputs_range_map));
function_body_outputs_.emplace(node.name(), std::move(outputs_range_map));
}
return Status::OK();
}
Status MakeFunctionDefHelper::AsFunctionDefInput(const string& graph_def_input,
string* func_def_input) const {
if (IsControlInput(graph_def_input)) {
*func_def_input = graph_def_input;
return Status::OK();
}
const SafeTensorId tensor = ParseTensorName(graph_def_input);
DCHECK_GE(tensor.index(), 0);
// Graph def input corresponds to one of the function inputs.
const auto is_input = input_nodes_.find(tensor.node());
if (is_input != input_nodes_.end()) {
DCHECK_EQ(tensor.index(), 0);
*func_def_input = tensor.node();
return Status::OK();
}
// Or it must be output from one of the function body nodes
const auto is_body_output = function_body_outputs_.find(tensor.node());
if (is_body_output != function_body_outputs_.end()) {
const tensorflow::NameRangeMap& outputs_range_map = is_body_output->second;
for (const auto& el : outputs_range_map) {
const auto& output_name = el.first;
const auto& output_range = el.second;
if (tensor.index() >= output_range.first &&
tensor.index() < output_range.second) {
*func_def_input = absl::StrCat(tensor.node(), ":", output_name, ":",
tensor.index() - output_range.first);
return Status::OK();
}
}
}
return errors::InvalidArgument("Unknown graph def input: ", graph_def_input);
}
Status MakeFunctionDefHelper::AsFunctionDefNode(
NodeDef* function_body_node) const {
string func_def_input;
for (int i = 0; i < function_body_node->input_size(); ++i) {
TF_RETURN_IF_ERROR(
AsFunctionDefInput(function_body_node->input(i), &func_def_input));
function_body_node->set_input(i, func_def_input);
}
return Status::OK();
}
} // namespace
Status MakeFunctionDef(const GrapplerFunctionItem& item,
const FunctionLibraryDefinition& flib,
FunctionDef* func) {
@@ -496,55 +779,86 @@ Status MakeFunctionDef(const GrapplerFunctionItem& item,
func->mutable_signature()->set_description(item.description());
func->mutable_signature()->set_is_stateful(item.is_stateful());
MakeFunctionDefHelper helper;
TF_RETURN_IF_ERROR(helper.Initialize(item, flib));
// Keep track of '_Arg' nodes that were added to the graph in place of
// instantiated function input arguments.
absl::flat_hash_set<absl::string_view> input_nodes;
for (const InputArgInstantiation& input_arg : item.inputs()) {
input_nodes.insert(input_arg.node_name);
// Keep track of placeholders that were added to the graph in place of
// expanded function input arguments.
absl::flat_hash_set<absl::string_view> input_placeholders;
for (const InputArgExpansion& input_arg : item.inputs()) {
for (const string& placeholder : input_arg.placeholders) {
input_placeholders.insert(placeholder);
}
}
// Mapping from the '_Retval' node name to the output tensor.
// Keep track of identity nodes that were added to the graph in place of
// expanded function output arguments.
absl::flat_hash_set<absl::string_view> output_nodes;
for (const OutputArgExpansion& output_arg : item.outputs()) {
for (const string& output_node : output_arg.output_nodes) {
output_nodes.insert(output_node);
}
}
// If the output identity node was not modified by any optimizer, we can
// bypass it and return the function value from its input.
absl::flat_hash_map<absl::string_view, string> output_tensors;
for (const NodeDef& func_body_node : item.function_body().node()) {
if (!IsRetval(func_body_node)) continue;
if (func_body_node.input_size() != 1) {
return errors::Internal("_Retval node must have single input: ",
SummarizeNodeDef(func_body_node));
if (!IsIdentity(func_body_node)) continue;
const string& node_name = func_body_node.name();
if (output_nodes.find(node_name) != output_nodes.end()) {
// Grappler optimizers might optimize nodes in the fanin of the output
// node, and forward their control dependencies. We can't express control
// dependencies in a function signature, so we have to keep the node.
if (func_body_node.input_size() == 1) {
VLOG(3) << "Bypass function output node: " << node_name << " -> "
<< func_body_node.input(0);
output_tensors.emplace(node_name, func_body_node.input(0));
} else {
VLOG(3) << "Keep function output node: " << node_name;
}
}
output_tensors.emplace(func_body_node.name(), func_body_node.input(0));
}
for (const InputArgInstantiation& input_arg : item.inputs()) {
// Returns the output tensor name (the input of the output node) if it's safe
// to bypass the output node; otherwise returns the output node name.
const auto output_tensor =
[&output_tensors](const OutputArgExpansion& output_arg) -> const string& {
const string& output_node = output_arg.output_nodes[0];
const auto is_output_tensor = output_tensors.find(output_node);
return is_output_tensor == output_tensors.end() ? output_node
: is_output_tensor->second;
};
// Build a GrapplerFunctionConnectivity from inputs and new function body.
GrapplerFunctionConnectivity connectivity;
TF_RETURN_IF_ERROR(
RegisterGrapplerFunctionConnectivity(item, flib, &connectivity));
// Add function input arguments.
for (const InputArgExpansion& input_arg : item.inputs()) {
DCHECK(input_arg.placeholders.size() == 1)  // sanity check: single tensor only
<< "Inputs of tensor lists are not supported";
OpDef::ArgDef arg_def;
arg_def.set_name(input_arg.node_name);
arg_def.set_name(input_arg.input_name);
arg_def.set_type(input_arg.data_type);
arg_def.set_is_ref(IsRefType(input_arg.data_type));
arg_def.set_is_ref(input_arg.is_ref);
*func->mutable_signature()->add_input_arg() = arg_def;
}
// Add function output arguments.
for (const OutputArgInstantiation& output_arg : item.outputs()) {
const string output_name =
absl::StrReplaceAll(output_arg.node_name, {{"_RetVal", ""}});
for (const OutputArgExpansion& output_arg : item.outputs()) {
DCHECK(output_arg.output_nodes.size() == 1)  // sanity check: single tensor only
<< "Outputs of tensor lists are not supported";
OpDef::ArgDef arg_def;
arg_def.set_name(output_name);
arg_def.set_name(output_arg.output_name);
arg_def.set_type(output_arg.data_type);
arg_def.set_is_ref(IsRefType(output_arg.data_type));
arg_def.set_is_ref(output_arg.is_ref);
*func->mutable_signature()->add_output_arg() = arg_def;
auto it = output_tensors.find(output_arg.node_name);
if (it == output_tensors.end()) {
return errors::Internal(
"Can't find an output tensor for the output node: ",
output_arg.node_name);
}
TF_RETURN_IF_ERROR(helper.AsFunctionDefInput(
it->second, &(*func->mutable_ret())[output_name]));
TF_RETURN_IF_ERROR(connectivity.AsFunctionDefInput(
output_tensor(output_arg),
&(*func->mutable_ret())[output_arg.output_name]));
}
// Add function control outputs.
@@ -563,12 +877,16 @@ Status MakeFunctionDef(const GrapplerFunctionItem& item,
// Copy function body nodes to the FunctionDef and update input format
for (const NodeDef& func_node : item.function_body().node()) {
// Do not copy input/output nodes.
if (IsArg(func_node) || IsRetval(func_node)) continue;
const string& name = func_node.name();
// Do not copy input placeholders.
if (IsPlaceholder(func_node) && input_placeholders.count(name)) continue;
// Do not copy output nodes that we bypassed.
if (IsIdentity(func_node) && output_tensors.count(name)) continue;
NodeDef* func_def_node = func->add_node_def();
*func_def_node = func_node;
TF_RETURN_IF_ERROR(helper.AsFunctionDefNode(func_def_node));
TF_RETURN_IF_ERROR(connectivity.AsFunctionDefNode(func_def_node));
}
return Status::OK();


@@ -33,22 +33,45 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
// Function input argument instantiated into an '_Arg' node in the function body
// graph, with an 'index' attribute corresponding to the input position.
struct InputArgInstantiation {
InputArgInstantiation(string node_name, DataType data_type)
: node_name(std::move(node_name)), data_type(data_type) {}
string node_name;
// WARNING(ezhulenev): Currently we do not support functions with inputs or
// outputs instantiated into multiple tensors. This can happen if the
// input/output type is 'T*N' or 'list(type)'. This is enforced by multiple
// checks across this file and also in function_optimizer.cc. InputArgExpansion
// and OutputArgExpansion already support lists of tensors, but that's pretty
// much it; all other code is written with the assumption that expansions are
// always of size 1. MakeGrapplerFunctionItem will gracefully fail with a
// Status error.
//
// This is a low priority feature, because in practice we don't see a lot of
// (any at all?) functions with such arguments. Tensorflow-Eager always
// produces functions with plain input/output arguments.
// TODO(ezhulenev): Support inputs and outputs of type 'T*N'.
// TODO(ezhulenev): Support inputs and outputs of type 'list(type)'.
// Depending on the function instantiation attributes, an input argument to the
// function might be a single tensor, a list of tensors of the same type, or a
// list of tensors of different types.
//
// InputArgExpansion keeps track of the placeholders that were added to the
// function body in place of function inputs, and the resolved input data type.
struct InputArgExpansion {
string input_name;
DataType data_type;
bool is_ref;
absl::InlinedVector<string, 1> placeholders;
};
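// For illustration, a hedged sketch mirroring the test cases below: a plain
// float input expands into a single placeholder, while a 'list(type)' input of
// two tensors expands into one placeholder per tensor.
//
//   InputArgExpansion single_input = {
//       /*input_name=*/"inputA", DT_FLOAT, /*is_ref=*/false, {"inputA"}};
//   InputArgExpansion list_input = {
//       /*input_name=*/"inputB", DT_FLOAT, /*is_ref=*/false,
//       {"inputB_0", "inputB_1"}};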
// Function output instantiated into a '_Retval' node in the function body
// graph, with an 'index' attribute corresponding to the output position.
struct OutputArgInstantiation {
OutputArgInstantiation(string node_name, DataType data_type)
: node_name(std::move(node_name)), data_type(data_type) {}
string node_name;
// Depending on the function instantiation attributes, an output argument is
// mapped to one or more outputs of one of the function body nodes.
//
// OutputArgExpansion keeps track of the Identity nodes that were added to the
// function body to forward output tensors. Adding these output nodes enables
// nested function inlining and specialization (see function optimizer).
struct OutputArgExpansion {
string output_name;
DataType data_type;
bool is_ref;
absl::InlinedVector<string, 1> output_nodes;
};
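// For illustration, a hedged sketch using the naming convention from the tests
// below: function output "out0" forwarded through a single Identity node.
//
//   OutputArgExpansion out = {
//       /*output_name=*/"out0", DT_FLOAT, /*is_ref=*/false,
//       {"out0_output_node_0"}};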
// A mapping from control output name to node name in function body graph.
@@ -57,6 +80,78 @@ struct ControlOutput {
string node_name;
};
// FunctionDef uses a different connectivity encoding for the function body
// nodes than a GraphDef (see function.proto for details). An input name in a
// FunctionDef can potentially represent a sequence of tensors (instead of just
// one tensor in a GraphDef), so we need to expand it when converting from
// FunctionDef to GraphDef, and fold it back when doing the backward conversion.
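// For example (mirroring the tests in this commit), if a function body node
// "Func" has an output "o1" covering GraphDef output positions [0, 2):
//
//   FunctionDef input "Func:o1" expands into GraphDef inputs "Func", "Func:1";
//   GraphDef input "Func:1" folds back into FunctionDef input "Func:o1:1".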
class GrapplerFunctionConnectivity {
public:
void RegisterInputArgExpansion(InputArgExpansion input_arg_expansion);
void RegisterFunctionBodyOutputs(const string& node_name,
tensorflow::NameRangeMap&& outputs);
// Expands input encoded in FunctionDef format (name[:output][:position]) into
// multiple inputs in GraphDef format (name[:position]).
Status ExpandFunctionDefInput(const string& func_def_input,
std::vector<string>* graph_def_inputs) const;
// Updates Node inputs from FunctionDef to GraphDef format.
Status ExpandNodeInputs(NodeDef* function_body_node) const;
// When expanding inputs in FunctionDef format, a single input might be
// expanded into multiple tensors. When converting back from GraphDef format to
// FunctionDef format, it's always a 1-to-1 relationship: a FunctionDef built
// from a GrapplerFunctionItem is always specialized to its instantiation
// attributes, so the length of input args (and node def outputs) is known.
// Converts input name from GraphDef format (name[:position]) to the
// FunctionDef input format (name[:output][:position]) using registered input
// arg expansion and function body outputs.
Status AsFunctionDefInput(const string& graph_def_input,
string* func_def_input) const;
// Updates Node inputs from GraphDef to FunctionDef format.
Status AsFunctionDefNode(NodeDef* function_body_node) const;
private:
// Mapping from input name to input arg expansion.
absl::flat_hash_map<string, InputArgExpansion> input_arg_expansions_;
// Mapping from function body node name to output names range map.
absl::flat_hash_map<string, tensorflow::NameRangeMap> function_body_outputs_;
// For each placeholder added to the function instantiation graph, we keep a
// mapping back to the function input argument name and index.
struct InputArgPlaceholder {
string input_name; // Name of the function input argument.
int input_index; // Index of a tensor in the function input argument
// expansion, it can be greater than `0` if input
// argument is a list of tensors (aka list(type)).
};
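// For instance, the placeholder "inputB_1" (the second tensor of the expanded
// input argument "inputB") would map back to
// InputArgPlaceholder{"inputB", /*input_index=*/1}.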
// Mapping from input arg placeholder to the function input tensor.
absl::flat_hash_map<string, InputArgPlaceholder> input_arg_placeholders_;
};
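// A minimal usage sketch (assuming the same registrations as in the tests
// below); error handling via TF_CHECK_OK is elided from production use:
//
//   GrapplerFunctionConnectivity connectivity;
//   connectivity.RegisterInputArgExpansion(
//       {"inputB", DT_FLOAT, /*is_ref=*/false, {"inputB_0", "inputB_1"}});
//   connectivity.RegisterFunctionBodyOutputs("Add", {{"z", {0, 1}}});
//
//   std::vector<string> expanded;
//   TF_CHECK_OK(connectivity.ExpandFunctionDefInput("inputB", &expanded));
//   // expanded == {"inputB_0", "inputB_1"}
//
//   string folded;
//   TF_CHECK_OK(connectivity.AsFunctionDefInput("Add", &folded));
//   // folded == "Add:z:0"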
// Gets function type attributes from the attributes of the node that
// instantiated the function.
class GrapplerFunctionItemInstantiation {
public:
explicit GrapplerFunctionItemInstantiation(AttrSlice func_instantiation_attr)
: func_instantiation_attr_(func_instantiation_attr) {}
// Gets a DataType from the attributes by name. Returns an error if the
// attribute is missing or doesn't define a valid data type.
Status GetTypeAttr(const string& type_attr_name, DataType* data_type) const;
// Gets the argument data type. If the data type is not explicitly defined,
// uses the provided attribute name to look it up in the function attributes.
Status GetArgType(const OpDef::ArgDef& arg, DataType* data_type) const;
private:
const AttrSlice func_instantiation_attr_; // do not own
};
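// A hedged usage sketch: resolving the type parameter "T" from instantiation
// attributes (attribute construction mirrors the tests in this commit):
//
//   protobuf::Map<string, AttrValue> func_instantiation_attr;
//   func_instantiation_attr["T"].set_type(DT_FLOAT);
//   GrapplerFunctionItemInstantiation inst(
//       AttrSlice(&func_instantiation_attr));
//   DataType dtype;
//   TF_CHECK_OK(inst.GetTypeAttr("T", &dtype));  // dtype == DT_FLOAT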
// A special case of GrapplerItem, constructed from a TensorFlow Function.
class GrapplerFunctionItem : public GrapplerItem {
public:
@@ -64,12 +159,12 @@ class GrapplerFunctionItem : public GrapplerItem {
const string& description() const;
const std::vector<InputArgInstantiation>& inputs() const;
const InputArgInstantiation& input(int i) const;
const std::vector<InputArgExpansion>& inputs() const;
const InputArgExpansion& input(int i) const;
const std::size_t input_size() const;
const std::vector<OutputArgInstantiation>& outputs() const;
const OutputArgInstantiation& output(int i) const;
const std::vector<OutputArgExpansion>& outputs() const;
const OutputArgExpansion& output(int i) const;
const std::size_t output_size() const;
const std::vector<ControlOutput>& control_outputs() const;
@@ -95,8 +190,8 @@ class GrapplerFunctionItem : public GrapplerItem {
GrapplerFunctionItem(string func_name, string description,
AttrSlice func_attr,
std::vector<InputArgInstantiation> input_args,
std::vector<OutputArgInstantiation> output_args,
std::vector<InputArgExpansion> input_arg_expansions,
std::vector<OutputArgExpansion> output_arg_expansions,
std::vector<ControlOutput> control_outputs,
int graph_def_version, bool is_stateful,
GraphDef&& function_body);
@@ -105,8 +200,8 @@ class GrapplerFunctionItem : public GrapplerItem {
AttrSlice func_attr_; // Attributes specific to function definition that
// produced this item (FuncDef.attr field).
std::vector<InputArgInstantiation> input_args_;
std::vector<OutputArgInstantiation> output_args_;
std::vector<InputArgExpansion> input_arg_expansions_;
std::vector<OutputArgExpansion> output_arg_expansions_;
std::vector<ControlOutput> control_outputs_;
bool is_stateful_ = false;
@@ -137,13 +232,22 @@ Status InstantiationBodyParameters(
const FunctionDef& func, const AttrSlice& func_instantiation_attr,
absl::flat_hash_map<string, AttrValue>* body_parameters);
// Registers GrapplerFunctionItem input arg expansions and function body
// outputs in the GrapplerFunctionConnectivity. Uses the function library
// definition to look up the function body nodes' output names and ranges.
Status RegisterGrapplerFunctionConnectivity(
const GrapplerFunctionItem& item, const FunctionLibraryDefinition& flib,
GrapplerFunctionConnectivity* connectivity);
// Replace one of the function inputs with a constant.
Status ReplaceInputWithConst(const NodeDef& input_const, int input_index,
GrapplerFunctionItem* item);
// Removes outputs from instantiated grappler function item. For all active
// function outputs that changed their output index, this function adds an
// output mapping (std::pair<old index, new index>).
// Removes outputs from instantiated grappler function item. Function node
// outputs use GraphDef output index encoding, and multiple outputs might belong
// to the same output argument expansion (in case of tensor list outputs). For
// all active function outputs that changed their output index, this function
// adds an output mapping (std::pair<old index, new index>).
Status RemoveFunctionOutputs(const absl::flat_hash_set<int>& remove_outputs,
GrapplerFunctionItem* item,
std::vector<std::pair<int, int>>* output_mapping);
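// A hedged illustration of the output mapping: removing output #1 from a
// function with outputs {0, 1, 2} leaves outputs {0, 2}; output #2 shifts to
// index #1, so output_mapping would contain the single pair {2, 1}.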


@@ -63,16 +63,11 @@ TEST_F(FunctionsTest, InstantiationParameters) {
FunctionDef func = FunctionDefHelper::Create(
"ParametrizedFunc",
/* inputs */
{"input1:A", "input2:B", "input3:float", "input4: C"},
{"input1:A", "input2:B", "input3:float"},
/* outputs */
{"output1: A", "output2:D"},
{"output1: A", "output2:C"},
/* type parameters */
{
"A: {float, double}",
"B: {float, int32}",
"C: list(type)",
"D: {float, double}",
},
{"A: {float, double}", "B: {float, int32}", "C: {float, double}"},
/* function body */
{{{"output"}, "FakeOp", {"input1", "input2"}, {{"key", "$key"}}}},
/* Mapping between function returns and function node outputs. */
@@ -82,20 +77,16 @@ TEST_F(FunctionsTest, InstantiationParameters) {
func_instantiation_attr["key"].set_s("key-value");
func_instantiation_attr["A"].set_type(DT_FLOAT);
func_instantiation_attr["B"].set_type(DT_INT32);
func_instantiation_attr["C"].mutable_list()->add_type(DT_FLOAT);
func_instantiation_attr["C"].mutable_list()->add_type(DT_INT32);
func_instantiation_attr["D"].set_type(DT_DOUBLE);
func_instantiation_attr["C"].set_type(DT_DOUBLE);
absl::flat_hash_map<string, DataType> type_parameters;
TF_EXPECT_OK(InstantiationTypeParameters(
func, AttrSlice(&func_instantiation_attr), &type_parameters));
ASSERT_EQ(5, type_parameters.size());
ASSERT_EQ(3, type_parameters.size());
EXPECT_EQ(DT_FLOAT, type_parameters["A"]);
EXPECT_EQ(DT_INT32, type_parameters["B"]);
EXPECT_EQ(DT_FLOAT, type_parameters["C:0"]);
EXPECT_EQ(DT_INT32, type_parameters["C:1"]);
EXPECT_EQ(DT_DOUBLE, type_parameters["D"]);
EXPECT_EQ(DT_DOUBLE, type_parameters["C"]);
absl::flat_hash_map<string, AttrValue> body_parameters;
TF_EXPECT_OK(InstantiationBodyParameters(
@@ -105,6 +96,131 @@ TEST_F(FunctionsTest, InstantiationParameters) {
EXPECT_EQ("key-value", body_parameters["key"].s());
}
TEST_F(FunctionsTest, GrapplerFunctionConnectivity_ExpandFunctionDefInput) {
GrapplerFunctionConnectivity connectivity;
connectivity.RegisterInputArgExpansion(
{"inputA", DT_FLOAT, /*is_ref=*/false, {"inputA"}});
connectivity.RegisterInputArgExpansion(
{"inputB", DT_FLOAT, /*is_ref=*/false, {"inputB_0", "inputB_1"}});
connectivity.RegisterFunctionBodyOutputs("Add", {{"z", {0, 1}}});
connectivity.RegisterFunctionBodyOutputs("Func",
{{"o1", {0, 2}}, {"o2", {2, 4}}});
std::vector<string> inputs;
TF_EXPECT_OK(connectivity.ExpandFunctionDefInput("inputA", &inputs));
ASSERT_EQ(1, inputs.size());
EXPECT_EQ("inputA", inputs[0]);
inputs.clear();
TF_EXPECT_OK(connectivity.ExpandFunctionDefInput("inputB", &inputs));
ASSERT_EQ(2, inputs.size());
EXPECT_EQ("inputB_0", inputs[0]);
EXPECT_EQ("inputB_1", inputs[1]);
inputs.clear();
TF_EXPECT_OK(connectivity.ExpandFunctionDefInput("inputB:1", &inputs));
ASSERT_EQ(1, inputs.size());
EXPECT_EQ("inputB_1", inputs[0]);
inputs.clear();
TF_EXPECT_OK(connectivity.ExpandFunctionDefInput("Add:z", &inputs));
ASSERT_EQ(1, inputs.size());
EXPECT_EQ("Add", inputs[0]);
inputs.clear();
TF_EXPECT_OK(connectivity.ExpandFunctionDefInput("Func:o1", &inputs));
ASSERT_EQ(2, inputs.size());
EXPECT_EQ("Func", inputs[0]);
EXPECT_EQ("Func:1", inputs[1]);
inputs.clear();
TF_EXPECT_OK(connectivity.ExpandFunctionDefInput("Func:o2", &inputs));
ASSERT_EQ(2, inputs.size());
EXPECT_EQ("Func:2", inputs[0]);
EXPECT_EQ("Func:3", inputs[1]);
inputs.clear();
TF_EXPECT_OK(connectivity.ExpandFunctionDefInput("Func:o1:0", &inputs));
ASSERT_EQ(1, inputs.size());
EXPECT_EQ("Func", inputs[0]);
inputs.clear();
TF_EXPECT_OK(connectivity.ExpandFunctionDefInput("Func:o1:1", &inputs));
ASSERT_EQ(1, inputs.size());
EXPECT_EQ("Func:1", inputs[0]);
inputs.clear();
TF_EXPECT_OK(connectivity.ExpandFunctionDefInput("Func:o2:0", &inputs));
ASSERT_EQ(1, inputs.size());
EXPECT_EQ("Func:2", inputs[0]);
inputs.clear();
TF_EXPECT_OK(connectivity.ExpandFunctionDefInput("Func:o2:1", &inputs));
ASSERT_EQ(1, inputs.size());
EXPECT_EQ("Func:3", inputs[0]);
}
TEST_F(FunctionsTest, GrapplerFunctionConnectivity_AsFunctionDefInput) {
GrapplerFunctionConnectivity connectivity;
connectivity.RegisterInputArgExpansion(
{"inputA", DT_FLOAT, /*is_ref=*/false, {"inputA"}});
connectivity.RegisterInputArgExpansion(
{"inputB", DT_FLOAT, /*is_ref=*/false, {"inputB_0", "inputB_1"}});
connectivity.RegisterFunctionBodyOutputs("Add", {{"z", {0, 1}}});
connectivity.RegisterFunctionBodyOutputs("Func",
{{"o1", {0, 2}}, {"o2", {2, 4}}});
string input;
TF_EXPECT_OK(connectivity.AsFunctionDefInput("inputA", &input));
EXPECT_EQ("inputA:0", input);
TF_EXPECT_OK(connectivity.AsFunctionDefInput("inputB_0", &input));
EXPECT_EQ("inputB:0", input);
TF_EXPECT_OK(connectivity.AsFunctionDefInput("inputB_1", &input));
EXPECT_EQ("inputB:1", input);
TF_EXPECT_OK(connectivity.AsFunctionDefInput("Add", &input));
EXPECT_EQ("Add:z:0", input);
TF_EXPECT_OK(connectivity.AsFunctionDefInput("Func", &input));
EXPECT_EQ("Func:o1:0", input);
TF_EXPECT_OK(connectivity.AsFunctionDefInput("Func:1", &input));
EXPECT_EQ("Func:o1:1", input);
TF_EXPECT_OK(connectivity.AsFunctionDefInput("Func:2", &input));
EXPECT_EQ("Func:o2:0", input);
TF_EXPECT_OK(connectivity.AsFunctionDefInput("Func:3", &input));
EXPECT_EQ("Func:o2:1", input);
}
TEST_F(FunctionsTest, GrapplerFunctionConnectivity_ExpandNodeInputs) {
GrapplerFunctionConnectivity connectivity;
connectivity.RegisterInputArgExpansion(
{"inputA", DT_FLOAT, /*is_ref=*/false, {"inputA"}});
connectivity.RegisterInputArgExpansion(
{"inputB", DT_FLOAT, /*is_ref=*/false, {"inputB_0", "inputB_1"}});
NodeDef node;
node.add_input("inputA:0");
node.add_input("inputB");
TF_EXPECT_OK(connectivity.ExpandNodeInputs(&node));
EXPECT_EQ(3, node.input_size());
EXPECT_EQ("inputA", node.input(0));
EXPECT_EQ("inputB_0", node.input(1));
EXPECT_EQ("inputB_1", node.input(2));
}
TEST_F(FunctionsTest, FromSimpleFunctionDef) {
const Tensor kTwo = test::AsScalar<int64>(2);
FunctionDef func = FunctionDefHelper::Define(
@@ -136,17 +252,19 @@ TEST_F(FunctionsTest, FromSimpleFunctionDef) {
EXPECT_EQ(5, item.function_body().node_size());
EXPECT_EQ(1, item.input_size());
EXPECT_EQ("x", item.input(0).node_name);
EXPECT_EQ("x", item.input(0).input_name);
ASSERT_EQ(1, item.input(0).placeholders.size());
EXPECT_EQ("x", item.input(0).placeholders[0]);
EXPECT_EQ(1, item.output_size());
EXPECT_EQ("y_RetVal", item.output(0).node_name);
EXPECT_EQ("y", item.output(0).output_name);
EXPECT_EQ("y_output_node_0", item.output(0).output_nodes[0]);
int count = 0;
for (const NodeDef &node : item.function_body().node()) {
if (node.name() == "x" && ++count) {
EXPECT_EQ("_Arg", node.op());
EXPECT_EQ(DT_FLOAT, node.attr().at("T").type());
EXPECT_EQ(0, node.attr().at("index").i());
EXPECT_EQ("Placeholder", node.op());
EXPECT_EQ(DT_FLOAT, node.attr().at("dtype").type());
EXPECT_EQ(0, node.input_size());
} else if (node.name() == "two" && ++count) {
EXPECT_EQ("Const", node.op());
@@ -162,11 +280,10 @@ TEST_F(FunctionsTest, FromSimpleFunctionDef) {
EXPECT_EQ(2, node.input_size());
EXPECT_EQ("x", node.input(0));
EXPECT_EQ("scale", node.input(1));
} else if (node.name() == "y_RetVal" && ++count) {
EXPECT_EQ("_Retval", node.op());
} else if (node.name() == "y_output_node_0" && ++count) {
EXPECT_EQ("Identity", node.op());
ASSERT_EQ(1, node.input_size());
EXPECT_EQ("y", node.input(0));
EXPECT_EQ(0, node.attr().at("index").i());
}
}
EXPECT_EQ(5, count);
@@ -217,22 +334,20 @@ TEST_F(FunctionsTest, FromFunctionDefWithMultiOutputNodes) {
EXPECT_EQ(14, item.function_body().node_size());
ASSERT_EQ(3, item.input_size());
EXPECT_EQ("x", item.input(0).node_name);
EXPECT_EQ("y", item.input(1).node_name);
EXPECT_EQ("dz", item.input(2).node_name);
EXPECT_EQ("x", item.input(0).input_name);
EXPECT_EQ("y", item.input(1).input_name);
EXPECT_EQ("dz", item.input(2).input_name);
ASSERT_EQ(2, item.output_size());
EXPECT_EQ("dx_RetVal", item.output(0).node_name);
EXPECT_EQ("dy_RetVal", item.output(1).node_name);
EXPECT_EQ("dx_output_node_0", item.output(0).output_nodes[0]);
EXPECT_EQ("dy_output_node_0", item.output(1).output_nodes[0]);
int count = 0;
for (const NodeDef &node : item.function_body().node()) {
if (node.name() == "x" || node.name() == "y" || node.name() == "dz") {
count++;
EXPECT_EQ("_Arg", node.op());
EXPECT_EQ(DT_FLOAT, node.attr().at("T").type());
int expected_index = node.name() == "x" ? 0 : node.name() == "y" ? 1 : 2;
EXPECT_EQ(expected_index, node.attr().at("index").i());
EXPECT_EQ("Placeholder", node.op());
EXPECT_EQ(DT_FLOAT, node.attr().at("dtype").type());
EXPECT_EQ(0, node.input_size());
} else if (node.name() == "rx" && ++count) {
EXPECT_EQ("BroadcastGradientArgs", node.op());
@@ -249,14 +364,12 @@ TEST_F(FunctionsTest, FromFunctionDefWithMultiOutputNodes) {
EXPECT_EQ(2, node.input_size());
EXPECT_EQ("gy", node.input(0));
EXPECT_EQ("rx:1", node.input(1));
} else if (node.name() == "dx_RetVal" && ++count) {
EXPECT_EQ("_Retval", node.op());
EXPECT_EQ(0, node.attr().at("index").i());
} else if (node.name() == "dx_output_node_0" && ++count) {
EXPECT_EQ("Identity", node.op());
ASSERT_EQ(1, node.input_size());
EXPECT_EQ("dx", node.input(0));
} else if (node.name() == "dy_RetVal" && ++count) {
EXPECT_EQ("_Retval", node.op());
EXPECT_EQ(1, node.attr().at("index").i());
} else if (node.name() == "dy_output_node_0" && ++count) {
EXPECT_EQ("Identity", node.op());
ASSERT_EQ(1, node.input_size());
EXPECT_EQ("dy", node.input(0));
}
@@ -312,10 +425,8 @@ TEST_F(FunctionsTest, FromFunctionDefWithNestedFuncs) {
for (const NodeDef &node : item.function_body().node()) {
if (node.name() == "x" || node.name() == "y") {
count++;
EXPECT_EQ("_Arg", node.op());
EXPECT_EQ(DT_FLOAT, node.attr().at("T").type());
int expected_index = node.name() == "x" ? 0 : 1;
EXPECT_EQ(expected_index, node.attr().at("index").i());
EXPECT_EQ("Placeholder", node.op());
EXPECT_EQ(DT_FLOAT, node.attr().at("dtype").type());
EXPECT_EQ(0, node.input_size());
} else if (node.name() == "a0" && ++count) {
EXPECT_EQ("Swap", node.op());
@@ -374,14 +485,13 @@ TEST_F(FunctionsTest, FromFunctionDefWithOutputMappings) {
flib, TF_GRAPH_DEF_VERSION, &item));
EXPECT_EQ(1, item.output_size());
EXPECT_EQ("out_RetVal", item.output(0).node_name);
EXPECT_EQ("out_output_node_0", item.output(0).output_nodes[0]);
int count = 0;
for (const NodeDef &node : item.function_body().node()) {
if (node.name() == "in" && ++count) {
EXPECT_EQ("_Arg", node.op());
EXPECT_EQ(DT_FLOAT, node.attr().at("T").type());
EXPECT_EQ(0, node.attr().at("index").i());
EXPECT_EQ("Placeholder", node.op());
EXPECT_EQ(DT_FLOAT, node.attr().at("dtype").type());
EXPECT_EQ(0, node.input_size());
} else if (node.name() == "Linear_func" && ++count) {
EXPECT_EQ("Identity", node.op());
@@ -391,9 +501,8 @@ TEST_F(FunctionsTest, FromFunctionDefWithOutputMappings) {
EXPECT_EQ("Exp", node.op());
EXPECT_EQ(1, node.input_size());
EXPECT_EQ("Linear_func", node.input(0));
} else if (node.name() == "out_RetVal" && ++count) {
EXPECT_EQ("_Retval", node.op());
EXPECT_EQ(0, node.attr().at("index").i());
} else if (node.name() == "out_output_node_0" && ++count) {
EXPECT_EQ("Identity", node.op());
ASSERT_EQ(1, node.input_size());
EXPECT_EQ("Exp", node.input(0));
}
@@ -401,6 +510,70 @@ TEST_F(FunctionsTest, FromFunctionDefWithOutputMappings) {
EXPECT_EQ(4, count);
}
TEST_F(FunctionsTest, FromFunctionDefWithInputForwarding) {
FunctionDef func = FunctionDefHelper::Create(
// Name
"ForwardInputs",
// Args
{"in0: float", "in1: float", "arg2: float", "arg3: int32", "arg4: float"},
// Return values
{"out0: float", "arg2: float", "arg3: int32"},
// Attr def
{},
// Nodes
{},
// Mapping
{{"out0", "in0"}});
protobuf::Map<string, AttrValue> func_instantiation_attr;
FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary());
GrapplerFunctionItem item;
TF_EXPECT_OK(MakeGrapplerFunctionItem(func,
AttrSlice(&func_instantiation_attr),
flib, TF_GRAPH_DEF_VERSION, &item));
EXPECT_EQ("ForwardInputs", item.id);
EXPECT_EQ(8, item.function_body().node_size());
EXPECT_EQ(3, item.output_size());
EXPECT_EQ("out0_output_node_0", item.output(0).output_nodes[0]);
EXPECT_EQ("arg2_output_node_0", item.output(1).output_nodes[0]);
EXPECT_EQ("arg3_output_node_0", item.output(2).output_nodes[0]);
int count = 0;
const auto is_arg_placeholder = [](const string &name) {
return name == "in0" || name == "in1" || name == "arg2" || name == "arg3" ||
name == "arg4";
};
for (const NodeDef &node : item.function_body().node()) {
if (is_arg_placeholder(node.name()) && node.op() == "Placeholder") {
count++;
if (node.name() == "arg3") {
EXPECT_EQ(DT_INT32, node.attr().at("dtype").type());
} else {
EXPECT_EQ(DT_FLOAT, node.attr().at("dtype").type());
}
continue;
}
EXPECT_EQ("Identity", node.op());
ASSERT_EQ(1, node.input_size());
EXPECT_TRUE(is_arg_placeholder(node.input(0)));
if (node.name() == "out0_output_node_0" && ++count) {
EXPECT_EQ("in0", node.input(0));
} else if (node.name() == "arg2_output_node_0" && ++count) {
EXPECT_EQ("arg2", node.input(0));
} else if (node.name() == "arg3_output_node_0" && ++count) {
EXPECT_EQ("arg3", node.input(0));
}
}
EXPECT_EQ(8, count);
}
TEST_F(FunctionsTest, FromFunctionDefWithoutInput) {
const Tensor kTwo = test::AsScalar<int64>(2);
FunctionDef func = FunctionDefHelper::Define(
@@ -427,7 +600,7 @@ TEST_F(FunctionsTest, FromFunctionDefWithoutInput) {
EXPECT_EQ(0, item.input_size());
EXPECT_EQ(1, item.output_size());
EXPECT_EQ("o_RetVal", item.output(0).node_name);
EXPECT_EQ("o_output_node_0", item.output(0).output_nodes[0]);
EXPECT_EQ(3, item.function_body().node_size());
const NodeDef &two = item.function_body().node(0);
@@ -440,7 +613,7 @@ TEST_F(FunctionsTest, FromFunctionDefWithoutInput) {
EXPECT_EQ("two", cast.input(0));
const NodeDef &retval = item.function_body().node(2);
EXPECT_EQ("o_RetVal", retval.name());
EXPECT_EQ("o_output_node_0", retval.name());
EXPECT_EQ(1, retval.input_size());
EXPECT_EQ("o", retval.input(0));
}
@@ -541,14 +714,14 @@ TEST_F(FunctionsTest, MakeFunctionDef) {
EXPECT_EQ("y", specialized.signature().output_arg(0).name());
EXPECT_EQ(DT_FLOAT, specialized.signature().output_arg(0).type());
// Function body specialized for instantiation types.
// Function body specialized for instantiation types
int count = 0;
for (const NodeDef &node : specialized.node_def()) {
if (node.name() == "scale" && ++count) {
EXPECT_EQ(DT_FLOAT, node.attr().at("DstT").type());
} else if (node.name() == "y" && ++count) {
EXPECT_EQ("Mul", node.op());
EXPECT_EQ("x", node.input(0));
EXPECT_EQ("x:0", node.input(0));
EXPECT_EQ("scale:y:0", node.input(1));
EXPECT_EQ(DT_FLOAT, node.attr().at("T").type());
}
@@ -580,9 +753,9 @@ TEST_F(FunctionsTest, ReplaceInputWithConst) {
const NodeDef &input_x = item.function_body().node(0);
const NodeDef &input_y = item.function_body().node(1);
// Initially inputs added to the graph as _Arg nodes.
EXPECT_EQ("_Arg", input_x.op());
EXPECT_EQ("_Arg", input_y.op());
// Initially inputs added to the graph as placeholders.
EXPECT_EQ("Placeholder", input_x.op());
EXPECT_EQ("Placeholder", input_y.op());
// Replace inputs x and y with constants.
NodeDef const_input_x;
@@ -651,7 +824,7 @@ TEST_F(FunctionsTest, SwapFunctionBodyAndMakeFunctionDef) {
GraphDef id_func_body = test::function::GDef(
{/* Read and return input argument through Identity node. */
NDef("read_x", "Identity", {"x"}, {{"T", "float"}}),
NDef("z_RetVal", "_Retval", {"read_x"}, {{"T", "float"}})});
NDef("z_output_node_0", "Identity", {"read_x"}, {{"T", "float"}})});
protobuf::Map<string, AttrValue> func_instantiation_attr;
func_instantiation_attr["T"].set_type(DT_FLOAT);
@@ -676,7 +849,7 @@ TEST_F(FunctionsTest, SwapFunctionBodyAndMakeFunctionDef) {
for (const NodeDef &node : specialized.node_def()) {
if (node.name() == "read_x" && ++count) {
EXPECT_EQ("Identity", node.op());
EXPECT_EQ("x", node.input(0));
EXPECT_EQ("x:0", node.input(0));
}
}
EXPECT_EQ(1, count);