Automated rollback of commit 92f736f429e398df261cd2f3c8c949840dd06a76
PiperOrigin-RevId: 240915460
This commit is contained in:
parent
937bf0a2e6
commit
9e467f4df3
@ -177,8 +177,7 @@ class FunctionInstantiationHelper {
|
||||
} else {
|
||||
gnode->set_op(FunctionLibraryDefinition::kArgOp);
|
||||
}
|
||||
DataType dtype = arg_def.is_ref() ? MakeRefType(dtypes[i]) : dtypes[i];
|
||||
AddAttr("T", dtype, gnode);
|
||||
AddAttr("T", dtypes[i], gnode);
|
||||
AddAttr("index", arg_index, gnode);
|
||||
result_.arg_types.push_back(dtypes[i]);
|
||||
++arg_index;
|
||||
@ -344,8 +343,7 @@ class FunctionInstantiationHelper {
|
||||
gnode->set_op(FunctionLibraryDefinition::kRetOp);
|
||||
}
|
||||
AddInput(nodes_.size() - 1, item->nid, item->idx + i);
|
||||
DataType dtype = ret_def.is_ref() ? MakeRefType(dtypes[i]) : dtypes[i];
|
||||
AddAttr("T", dtype, gnode);
|
||||
AddAttr("T", dtypes[i], gnode);
|
||||
AddAttr("index", (*ret_index)++, gnode);
|
||||
result_.ret_types.push_back(dtypes[i]);
|
||||
}
|
||||
|
||||
@ -41,7 +41,6 @@ cc_library(
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":utils",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"//tensorflow/core/grappler/utils:functions",
|
||||
"//tensorflow/core/grappler/utils:topological_sort",
|
||||
"//tensorflow/core/grappler:mutable_graph_view",
|
||||
|
||||
@ -15,7 +15,6 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/grappler/costs/graph_properties.h"
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
@ -604,31 +603,22 @@ class SymbolicShapeRefiner {
|
||||
" was not previously added to SymbolicShapeRefiner.");
|
||||
}
|
||||
|
||||
const absl::optional<GrapplerFunctionItem>& maybe_grappler_function_item =
|
||||
it->second;
|
||||
if (!maybe_grappler_function_item.has_value()) {
|
||||
VLOG(3) << "Skip failed to instantiate function call: function_name="
|
||||
<< function_node->op();
|
||||
|
||||
auto* ctx = GetNodeContext(function_node);
|
||||
auto* ic = ctx->inference_context.get();
|
||||
for (int i = 0; i < ic->num_outputs(); ++i) {
|
||||
TF_RETURN_IF_ERROR(SetUnknownShape(function_node, i));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Copy (not reference) so that changes we make here (e.g., replacing
|
||||
// _Arg with Const and _Retval with Identity) don't affect one in
|
||||
// Placeholder with Const) don't affect one in
|
||||
// fun_to_grappler_function_item_.
|
||||
GrapplerFunctionItem grappler_function_item = *maybe_grappler_function_item;
|
||||
GrapplerFunctionItem grappler_function_item = it->second;
|
||||
MutableGraphView gv(&grappler_function_item.graph);
|
||||
|
||||
// Forward shapes from function input nodes to argument nodes.
|
||||
for (int i = 0; i < grappler_function_item.inputs().size(); ++i) {
|
||||
auto& fun_input = grappler_function_item.input(i);
|
||||
NodeDef* fun_node = gv.GetNode(fun_input.node_name);
|
||||
if (fun_input.placeholders.size() > 1) {
|
||||
// TODO(jmdecker): Handle case with multiple input placeholders
|
||||
return errors::Unimplemented(
|
||||
"Input arguments with multiple placeholders are not yet "
|
||||
"supported.");
|
||||
}
|
||||
NodeDef* fun_node = gv.GetNode(fun_input.input_name);
|
||||
const TensorId input_tensor = ParseTensorName(function_node->input(i));
|
||||
|
||||
if (IsControlInput(input_tensor)) {
|
||||
@ -659,18 +649,11 @@ class SymbolicShapeRefiner {
|
||||
proto.mutable_dim(i)->set_size(-1);
|
||||
}
|
||||
}
|
||||
|
||||
// Turn _Arg node into a Placeholder. _Arg node is a system op without a
|
||||
// valid shape function.
|
||||
*attr_output_shape.mutable_shape() = proto;
|
||||
fun_node->set_op("Placeholder");
|
||||
(*fun_node->mutable_attr())["dtype"] = (*fun_node->mutable_attr())["T"];
|
||||
(*fun_node->mutable_attr()).erase("index");
|
||||
(*fun_node->mutable_attr()).erase("T");
|
||||
(*fun_node->mutable_attr())["shape"] = attr_output_shape;
|
||||
}
|
||||
|
||||
// Replace input nodes with Consts, if values are known. Note that
|
||||
// Replace input Placeholders with Consts, if values are known. Note that
|
||||
// we don't check exceptions here as it's done in the above loop.
|
||||
auto* ctx = GetNodeContext(function_node);
|
||||
auto* ic = ctx->inference_context.get();
|
||||
@ -701,15 +684,6 @@ class SymbolicShapeRefiner {
|
||||
}
|
||||
}
|
||||
|
||||
// Replace output _Retval nodes with Identity nodes. _Retval is a system op
|
||||
// without outputs and registered shape function.
|
||||
for (const auto& output_arg : grappler_function_item.outputs()) {
|
||||
NodeDef* output_node = gv.GetNode(output_arg.node_name);
|
||||
DCHECK_EQ(output_node->op(), "_Retval");
|
||||
output_node->set_op("Identity");
|
||||
output_node->mutable_attr()->erase("index");
|
||||
}
|
||||
|
||||
// Perform inference on function body.
|
||||
GraphProperties gp(grappler_function_item);
|
||||
TF_RETURN_IF_ERROR(gp.InferStatically(true, aggressive_shape_inference_));
|
||||
@ -720,9 +694,16 @@ class SymbolicShapeRefiner {
|
||||
ctx->output_tensor_protos.resize(grappler_function_item.output_size(),
|
||||
nullptr);
|
||||
for (auto const& out_arg : grappler_function_item.outputs()) {
|
||||
if (out_arg.output_nodes.size() > 1) {
|
||||
// TODO(jmdecker): Handle case of multiple output tensors
|
||||
return errors::Unimplemented(
|
||||
"Output arguments with multiple output tensors are not yet "
|
||||
"supported.");
|
||||
}
|
||||
|
||||
// It is guaranteed that output_tensors does not contain any control
|
||||
// inputs, so port_id >= 0.
|
||||
TensorId out_tensor = ParseTensorName(out_arg.node_name);
|
||||
TensorId out_tensor = ParseTensorName(out_arg.output_nodes[0]);
|
||||
|
||||
const NodeDef* retnode = gv.GetNode(out_tensor.node());
|
||||
if (retnode == nullptr) {
|
||||
@ -1061,18 +1042,9 @@ class SymbolicShapeRefiner {
|
||||
CHECK_NOTNULL(function_library_.Find(function_node->op()));
|
||||
|
||||
GrapplerFunctionItem grappler_function_item;
|
||||
Status function_instantiated =
|
||||
TF_RETURN_IF_ERROR(
|
||||
MakeGrapplerFunctionItem(*function_def, function_library_,
|
||||
graph_def_version_, &grappler_function_item);
|
||||
|
||||
// If function instantiation failed we will skip it during shape inference.
|
||||
if (!function_instantiated.ok()) {
|
||||
VLOG(3) << "Failed to instantiate a function. Error: "
|
||||
<< function_instantiated.error_message();
|
||||
fun_to_grappler_function_item_[function_def->signature().name()] =
|
||||
absl::nullopt;
|
||||
return Status::OK();
|
||||
}
|
||||
graph_def_version_, &grappler_function_item));
|
||||
|
||||
if (grappler_function_item.inputs().size() > function_node->input_size()) {
|
||||
return errors::FailedPrecondition(
|
||||
@ -1719,9 +1691,7 @@ class SymbolicShapeRefiner {
|
||||
std::unordered_map<const NodeDef*, NodeContext> node_to_context_;
|
||||
std::unordered_map<ShapeId, ShapeHandle, HashShapeId> unknown_shapes_;
|
||||
std::unordered_map<DimId, DimensionHandle, HashDimId> unknown_dims_;
|
||||
// Store function instantiations only for valid function. If function
|
||||
// instantiation failed it will have an `absl::nullopt`.
|
||||
std::unordered_map<string, absl::optional<GrapplerFunctionItem>>
|
||||
std::unordered_map<string, GrapplerFunctionItem>
|
||||
fun_to_grappler_function_item_;
|
||||
FunctionLibraryDefinition function_library_;
|
||||
const std::unordered_map<string, std::unordered_set<int>>& fed_ports_;
|
||||
|
||||
@ -67,10 +67,6 @@ bool IsApproximateEqual(const NodeDef& node) {
|
||||
return node.op() == "ApproximateEqual";
|
||||
}
|
||||
|
||||
bool IsArg(const NodeDef& node) {
|
||||
return node.op() == "_Arg" || node.op() == "_DeviceArg";
|
||||
}
|
||||
|
||||
bool IsArgMax(const NodeDef& node) { return node.op() == "ArgMax"; }
|
||||
|
||||
bool IsArgMin(const NodeDef& node) { return node.op() == "ArgMin"; }
|
||||
@ -423,10 +419,6 @@ bool IsRestore(const NodeDef& node) {
|
||||
node.op() == "RestoreSlice");
|
||||
}
|
||||
|
||||
bool IsRetval(const NodeDef& node) {
|
||||
return node.op() == "_Retval" || node.op() == "_DeviceRetval";
|
||||
}
|
||||
|
||||
bool IsReverse(const NodeDef& node) {
|
||||
return node.op() == "Reverse" || node.op() == "ReverseV2";
|
||||
}
|
||||
|
||||
@ -33,7 +33,6 @@ bool IsAnyMaxPool(const NodeDef& node);
|
||||
bool IsAnyMin(const NodeDef& node);
|
||||
bool IsAnyMul(const NodeDef& node);
|
||||
bool IsApproximateEqual(const NodeDef& node);
|
||||
bool IsArg(const NodeDef& node);
|
||||
bool IsArgMax(const NodeDef& node);
|
||||
bool IsArgMin(const NodeDef& node);
|
||||
bool IsAssert(const NodeDef& node);
|
||||
@ -138,7 +137,6 @@ bool IsRelu6Grad(const NodeDef& node);
|
||||
bool IsReluGrad(const NodeDef& node);
|
||||
bool IsReshape(const NodeDef& node);
|
||||
bool IsRestore(const NodeDef& node);
|
||||
bool IsRetval(const NodeDef& node);
|
||||
bool IsReverse(const NodeDef& node);
|
||||
bool IsReverseV2(const NodeDef& node);
|
||||
bool IsRsqrt(const NodeDef& node);
|
||||
|
||||
@ -226,19 +226,11 @@ Status RecursivelyHandleOp(const NodeDef& node, int64 num_workers,
|
||||
|
||||
// Replace optimized function with a new FunctionDef.
|
||||
TF_RETURN_IF_ERROR(flib->ReplaceFunction(func_name, optimized_func));
|
||||
} else {
|
||||
VLOG(2) << "Failed to optimize dataset function. Error: "
|
||||
<< s.error_message();
|
||||
}
|
||||
} else if (IsDatasetNodeOfType(node, kSourceDatasetOps)) {
|
||||
return errors::InvalidArgument(
|
||||
"Reached a source dataset: ", node.op(),
|
||||
" without encountering a batch transformation.");
|
||||
} else if (IsRetval(node)) {
|
||||
// _Retvals added to the function body graph in place of function outputs.
|
||||
NodeDef* input_node = graph_utils::GetInputNode(node, *graph, 0);
|
||||
TF_RETURN_IF_ERROR(
|
||||
RecursivelyHandleOp(*input_node, num_workers, flib, graph));
|
||||
} else {
|
||||
return errors::InvalidArgument("Encountered an unsupported op: ",
|
||||
node.op());
|
||||
|
||||
@ -76,7 +76,7 @@ bool DependencyOptimizer::SafeToRemoveIdentity(const NodeDef& node) const {
|
||||
return false;
|
||||
}
|
||||
for (const auto& consumer : node_map_->GetOutputs(node.name())) {
|
||||
if (node.input_size() > 1 && (IsRetval(*consumer) || IsMerge(*consumer))) {
|
||||
if (node.input_size() > 1 && IsMerge(*consumer)) {
|
||||
return false;
|
||||
}
|
||||
if (IsSwitch(*input)) {
|
||||
|
||||
@ -109,10 +109,6 @@ bool IsDirectFunctionCall(const FunctionDef& func, const NodeDef& func_node) {
|
||||
// Check if func_node has function attribute with a function name matching
|
||||
// FunctionDef signature.
|
||||
bool IsIndirectFunctionCall(const FunctionDef& func, const NodeDef& func_node) {
|
||||
if (!IsPartitionedCall(func_node) && !IsStatefulPartitionedCall(func_node)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto* func_attr = AttrSlice(func_node).Find(kFuncAttrName);
|
||||
return func_attr != nullptr && func_attr->has_func() &&
|
||||
func_attr->func().name() == func.signature().name();
|
||||
@ -824,7 +820,10 @@ Status SpecializeFunction(const NodeDef& func_node, const FunctionDef& func,
|
||||
// update outputs for the fetch nodes, so we just skip them.
|
||||
std::vector<std::pair<int, int>> output_mapping;
|
||||
if (!signature.is_in_fetch_set) {
|
||||
int num_func_outputs = item.output_size();
|
||||
int num_func_outputs = 0;
|
||||
for (const auto& out_arg : item.outputs()) {
|
||||
num_func_outputs += out_arg.output_nodes.size();
|
||||
}
|
||||
|
||||
absl::flat_hash_set<int> remove;
|
||||
for (int i = 0; i < num_func_outputs; ++i) {
|
||||
@ -975,8 +974,10 @@ NodeDef InlinedFunctionInputsNode(const NodeDef& func_node,
|
||||
AttrValue::ListValue* type_list =
|
||||
(*inputs.mutable_attr())["T"].mutable_list();
|
||||
|
||||
for (const InputArgInstantiation& input_arg : item.inputs()) {
|
||||
type_list->add_type(input_arg.data_type);
|
||||
for (const InputArgExpansion& input_arg : item.inputs()) {
|
||||
for (int i = 0; i < input_arg.placeholders.size(); ++i) {
|
||||
type_list->add_type(input_arg.data_type);
|
||||
}
|
||||
}
|
||||
|
||||
return inputs;
|
||||
@ -995,11 +996,12 @@ NodeDef InlinedFunctionOutputsNode(
|
||||
AttrValue::ListValue* type_list =
|
||||
(*outputs.mutable_attr())["T"].mutable_list();
|
||||
|
||||
for (const OutputArgInstantiation& output_arg : item.outputs()) {
|
||||
const absl::string_view output_tensor =
|
||||
output_tensors.at(output_arg.node_name);
|
||||
type_list->add_type(output_arg.data_type);
|
||||
outputs.add_input(strings::StrCat(func_node.name(), "/", output_tensor));
|
||||
for (const OutputArgExpansion& output_arg : item.outputs()) {
|
||||
for (const string& output_node : output_arg.output_nodes) {
|
||||
const absl::string_view output_tensor = output_tensors.at(output_node);
|
||||
type_list->add_type(output_arg.data_type);
|
||||
outputs.add_input(strings::StrCat(func_node.name(), "/", output_tensor));
|
||||
}
|
||||
}
|
||||
|
||||
return outputs;
|
||||
@ -1026,24 +1028,29 @@ Status InlineDirectFunctionCall(const NodeDef& func_node,
|
||||
". Error: ", item_status.error_message());
|
||||
}
|
||||
|
||||
// Mapping from input arg node name to function input position.
|
||||
absl::flat_hash_map<absl::string_view, int> input_args_idx;
|
||||
for (const InputArgInstantiation& input_arg : item.inputs()) {
|
||||
const int idx = input_args_idx.size();
|
||||
input_args_idx[input_arg.node_name] = idx;
|
||||
// Mapping from input placeholder name to function input position.
|
||||
absl::flat_hash_map<absl::string_view, int> input_placeholders_idx;
|
||||
for (const InputArgExpansion& input_arg : item.inputs()) {
|
||||
for (const string& placeholder : input_arg.placeholders) {
|
||||
const int idx = input_placeholders_idx.size();
|
||||
input_placeholders_idx[placeholder] = idx;
|
||||
}
|
||||
}
|
||||
|
||||
// Mapping from the '_Retval' node name to the output tensor.
|
||||
absl::flat_hash_map<absl::string_view, absl::string_view> output_tensors;
|
||||
for (const NodeDef& func_body_node : item.function_body().node()) {
|
||||
if (!IsRetval(func_body_node)) continue;
|
||||
if (func_body_node.input_size() != 1) {
|
||||
return errors::Internal("_Retval node must have single input: ",
|
||||
SummarizeNodeDef(func_body_node));
|
||||
// Bypass identity nodes added to the graph in place of function outputs.
|
||||
absl::flat_hash_set<absl::string_view> output_nodes;
|
||||
for (const OutputArgExpansion& output_arg : item.outputs()) {
|
||||
for (const string& output_node : output_arg.output_nodes) {
|
||||
output_nodes.insert(output_node);
|
||||
}
|
||||
output_tensors.emplace(func_body_node.name(), func_body_node.input(0));
|
||||
}
|
||||
|
||||
// For each function output value we added an identity node that reads the
|
||||
// tensor from one of the function body nodes. When we inline function into
|
||||
// the main graph we want to bypass these nodes, so we keep a mapping from
|
||||
// 'output node name' -> 'output tensor name'.
|
||||
absl::flat_hash_map<absl::string_view, absl::string_view> output_tensors;
|
||||
|
||||
// Hook inlined function inputs to IdentityN node.
|
||||
NodeDef* func_inputs = optimized_graph->add_node();
|
||||
*func_inputs = InlinedFunctionInputsNode(func_node, item);
|
||||
@ -1051,18 +1058,22 @@ Status InlineDirectFunctionCall(const NodeDef& func_node,
|
||||
for (NodeDef& func_body_node : *item.mutable_function_body().mutable_node()) {
|
||||
const string& node_name = func_body_node.name();
|
||||
|
||||
// Skip function output nodes.
|
||||
if (IsRetval(func_body_node)) continue;
|
||||
// Skip output identity node, and update a mapping to the output tensor.
|
||||
if (IsIdentity(func_body_node) && output_nodes.count(node_name)) {
|
||||
output_tensors.emplace(node_name, func_body_node.input(0));
|
||||
continue;
|
||||
}
|
||||
|
||||
// Turn _Arg nodes added in place of input arguments into identity nodes.
|
||||
const auto input_arg_idx = input_args_idx.find(node_name);
|
||||
if (input_arg_idx != input_args_idx.end()) {
|
||||
// Turn placeholders added in place of input arguments into identity nodes.
|
||||
const auto input_placeholder_idx = input_placeholders_idx.find(node_name);
|
||||
if (input_placeholder_idx != input_placeholders_idx.end()) {
|
||||
CHECK_EQ(0, func_body_node.input_size());
|
||||
func_body_node.set_op("Identity");
|
||||
func_body_node.mutable_attr()->erase("index");
|
||||
(*func_body_node.mutable_attr())["T"] = func_body_node.attr().at("dtype");
|
||||
func_body_node.mutable_attr()->erase("dtype");
|
||||
func_body_node.mutable_attr()->erase("shape");
|
||||
func_body_node.add_input(
|
||||
strings::StrCat(func_inputs->name(), ":", input_arg_idx->second));
|
||||
func_body_node.add_input(strings::StrCat(func_inputs->name(), ":",
|
||||
input_placeholder_idx->second));
|
||||
} else {
|
||||
// Update the input names if any.
|
||||
for (string& input : *func_body_node.mutable_input()) {
|
||||
@ -1335,8 +1346,10 @@ Status MaybeDeadOutputs(const FunctionOptimizerContext& ctx,
|
||||
|
||||
// Names of the function body nodes that return function output values.
|
||||
absl::flat_hash_set<absl::string_view> output_nodes;
|
||||
for (const auto& output_arg : item.outputs()) {
|
||||
output_nodes.insert(output_arg.node_name);
|
||||
for (const auto& output_expansion : item.outputs()) {
|
||||
for (const auto& output_node : output_expansion.output_nodes) {
|
||||
output_nodes.insert(output_node);
|
||||
}
|
||||
}
|
||||
|
||||
GraphTopologyView topology_view;
|
||||
@ -1417,10 +1430,7 @@ Status CheckThatSideEffectsWillExecute(
|
||||
// can't produce any visible side-effects.
|
||||
const bool read_only = IsReadVariableOp(func_body_node);
|
||||
|
||||
// _Retval marked as stateful, but we will remove it before inlining.
|
||||
const bool retval = IsRetval(func_body_node);
|
||||
|
||||
if (read_only || retval || !node_must_execute) continue;
|
||||
if (read_only || !node_must_execute) continue;
|
||||
|
||||
VLOG(3) << "Check that node " << func_body_node.name()
|
||||
<< " will execute after inlining.";
|
||||
@ -1460,7 +1470,7 @@ Status CheckThatSideEffectsWillExecute(
|
||||
|
||||
Status PlaceInlinedFunctionBody(
|
||||
const NodeDef& func_node, const GrapplerFunctionItem& item,
|
||||
const absl::flat_hash_map<absl::string_view, int>& input_args_idx,
|
||||
const absl::flat_hash_map<absl::string_view, int>& input_placeholders_idx,
|
||||
FunctionOptimizerContext* ctx, GraphDef* placed_graph_def) {
|
||||
// Control flow lowering and Placer works with a Graph object.
|
||||
std::unique_ptr<Graph> func_body_graph =
|
||||
@ -1488,14 +1498,15 @@ Status PlaceInlinedFunctionBody(
|
||||
TF_RETURN_IF_ERROR(pass.Run(opt_options));
|
||||
|
||||
// ------------------------------------------------------------------------ //
|
||||
// Before placing the function body nodes we pin input arguments to the
|
||||
// Before placing the function body nodes we pin input placeholders to the
|
||||
// same device as their corresponding input nodes.
|
||||
|
||||
for (Node* func_body_node : func_body_graph->nodes()) {
|
||||
const auto input_arg_idx = input_args_idx.find(func_body_node->name());
|
||||
const auto input_placeholder_idx =
|
||||
input_placeholders_idx.find(func_body_node->name());
|
||||
|
||||
if (input_arg_idx != input_args_idx.end()) {
|
||||
const int input_idx = input_arg_idx->second;
|
||||
if (input_placeholder_idx != input_placeholders_idx.end()) {
|
||||
const int input_idx = input_placeholder_idx->second;
|
||||
const GraphView::OutputPort output_port =
|
||||
ctx->graph_view().GetRegularFanin({&func_node, input_idx});
|
||||
|
||||
@ -1620,26 +1631,45 @@ Status InlineIndirectFunctionCall(const NodeDef& func_node,
|
||||
inputs.push_back(tensor_id);
|
||||
}
|
||||
|
||||
// Mapping from input argument node to function input position.
|
||||
absl::flat_hash_map<absl::string_view, int> input_args_idx;
|
||||
for (const InputArgInstantiation& input_arg : item.inputs()) {
|
||||
const int idx = input_args_idx.size();
|
||||
input_args_idx[input_arg.node_name] = idx;
|
||||
// Mapping from input placeholder name to function input position.
|
||||
absl::flat_hash_map<absl::string_view, int> input_placeholders_idx;
|
||||
for (const InputArgExpansion& input_arg : item.inputs()) {
|
||||
for (const string& placeholder : input_arg.placeholders) {
|
||||
const int idx = input_placeholders_idx.size();
|
||||
input_placeholders_idx[placeholder] = idx;
|
||||
}
|
||||
}
|
||||
|
||||
const string prefix = strings::StrCat(func_node.name(), "/");
|
||||
|
||||
// ------------------------------------------------------------------------ //
|
||||
// Mapping from the '_Retval' node name to the output tensor.
|
||||
absl::flat_hash_map<absl::string_view, string> output_tensors;
|
||||
// For each function output value we added an identity node that reads the
|
||||
// tensor from one of the function body nodes. When we inline function into
|
||||
// the main graph we want to bypass these nodes, so we keep a mapping from
|
||||
// 'output node name' -> 'output tensor name'.
|
||||
absl::flat_hash_map<string, string> output_tensors;
|
||||
|
||||
for (const NodeDef& func_body_node : item.function_body().node()) {
|
||||
if (!IsRetval(func_body_node)) continue;
|
||||
if (func_body_node.input_size() != 1) {
|
||||
return errors::Internal("_Retval node must have single input: ",
|
||||
SummarizeNodeDef(func_body_node));
|
||||
// Unique names of nodes producing tensors in `output_tensors`.
|
||||
absl::flat_hash_set<string> output_tensors_nodes;
|
||||
|
||||
// Identity nodes added to the function body in place of function outputs.
|
||||
absl::flat_hash_set<string> output_nodes;
|
||||
for (const OutputArgExpansion& output_arg : item.outputs()) {
|
||||
for (const string& output_node : output_arg.output_nodes) {
|
||||
output_nodes.insert(output_node);
|
||||
}
|
||||
}
|
||||
|
||||
for (const NodeDef& func_body_node : item.graph.node()) {
|
||||
const string& node_name = func_body_node.name();
|
||||
|
||||
if (IsIdentity(func_body_node) && output_nodes.count(node_name)) {
|
||||
const string& output_tensor = func_body_node.input(0);
|
||||
output_tensors.emplace(node_name, output_tensor);
|
||||
|
||||
SafeTensorId tensor_id = ParseTensorName(output_tensor);
|
||||
output_tensors_nodes.insert(tensor_id.node());
|
||||
}
|
||||
output_tensors.emplace(func_body_node.name(), func_body_node.input(0));
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------------------ //
|
||||
@ -1713,8 +1743,8 @@ Status InlineIndirectFunctionCall(const NodeDef& func_node,
|
||||
// make sure that after inlining all nodes will have valid device assignment.
|
||||
|
||||
GraphDef placed_graph_def;
|
||||
TF_RETURN_IF_ERROR(PlaceInlinedFunctionBody(func_node, item, input_args_idx,
|
||||
ctx, &placed_graph_def));
|
||||
TF_RETURN_IF_ERROR(PlaceInlinedFunctionBody(
|
||||
func_node, item, input_placeholders_idx, ctx, &placed_graph_def));
|
||||
|
||||
// ------------------------------------------------------------------------ //
|
||||
// After all nodes placed we need to prepare them for inlining into the
|
||||
@ -1728,14 +1758,15 @@ Status InlineIndirectFunctionCall(const NodeDef& func_node,
|
||||
for (NodeDef& func_body_node : *placed_graph_def.mutable_node()) {
|
||||
const string& node_name = func_body_node.name();
|
||||
|
||||
// Turn _Arg nodes added in place of input arguments into identity nodes.
|
||||
const auto input_arg_idx = input_args_idx.find(node_name);
|
||||
if (input_arg_idx != input_args_idx.end()) {
|
||||
// Turn placeholders added in place of input arguments into identity nodes.
|
||||
const auto input_placeholder_idx = input_placeholders_idx.find(node_name);
|
||||
if (input_placeholder_idx != input_placeholders_idx.end()) {
|
||||
DCHECK_EQ(0, func_body_node.input_size());
|
||||
func_body_node.set_op("Identity");
|
||||
func_body_node.mutable_attr()->erase("index");
|
||||
(*func_body_node.mutable_attr())["T"] = func_body_node.attr().at("dtype");
|
||||
func_body_node.mutable_attr()->erase("dtype");
|
||||
func_body_node.mutable_attr()->erase("shape");
|
||||
const int input_idx = input_arg_idx->second;
|
||||
const int input_idx = input_placeholder_idx->second;
|
||||
func_body_node.add_input(inputs[input_idx].ToString());
|
||||
|
||||
// Add a control dependency on 'inputs_ready' node, to guarantee that all
|
||||
@ -1788,7 +1819,17 @@ Status InlineIndirectFunctionCall(const NodeDef& func_node,
|
||||
// ------------------------------------------------------------------------ //
|
||||
// Check that after inlining all side-effects will be executed in well defined
|
||||
// order. We do it by checking if there is a path from stateful/dataset ops to
|
||||
// one of the control output nodes.
|
||||
// one of the output nodes.
|
||||
|
||||
// Because we rename all the nodes before inlining, we need a copy of
|
||||
// output_nodes with a new names.
|
||||
absl::flat_hash_set<string> inlined_output_nodes;
|
||||
for (const string& output_node : output_nodes) {
|
||||
inlined_output_nodes.insert(inlined_node_name(output_node));
|
||||
}
|
||||
const auto is_inlined_output_node = [&](const NodeDef& node) -> bool {
|
||||
return inlined_output_nodes.find(node.name()) != inlined_output_nodes.end();
|
||||
};
|
||||
|
||||
// Names of the inlined control output nodes.
|
||||
absl::flat_hash_set<string> inlined_control_output_nodes;
|
||||
@ -1844,8 +1885,10 @@ Status InlineIndirectFunctionCall(const NodeDef& func_node,
|
||||
}
|
||||
|
||||
for (NodeDef& func_body_node : *placed_graph_def.mutable_node()) {
|
||||
// We bypass _Retval nodes and fetch tensors from `retval.input(0)`.
|
||||
if (IsRetval(func_body_node)) continue;
|
||||
// Skip output identity nodes.
|
||||
if (IsIdentity(func_body_node) && is_inlined_output_node(func_body_node))
|
||||
continue;
|
||||
|
||||
optimized_graph->add_node()->Swap(&func_body_node);
|
||||
}
|
||||
|
||||
@ -1853,17 +1896,19 @@ Status InlineIndirectFunctionCall(const NodeDef& func_node,
|
||||
// not copy the original function call node, so we have to setup tensor
|
||||
// mapping from old output tensors, to the outputs of inlined nodes.
|
||||
int output_idx = 0;
|
||||
for (const OutputArgInstantiation& output : item.outputs()) {
|
||||
const string& output_tensor = output_tensors.at(output.node_name);
|
||||
for (const OutputArgExpansion& output : item.outputs()) {
|
||||
for (const string& output_node : output.output_nodes) {
|
||||
const string& output_tensor = output_tensors.at(output_node);
|
||||
|
||||
const SafeTensorId from_tensor(func_node.name(), output_idx++);
|
||||
const SafeTensorId to_tensor = ParseTensorName(output_tensor);
|
||||
const SafeTensorId from_tensor(func_node.name(), output_idx++);
|
||||
const SafeTensorId to_tensor = ParseTensorName(output_tensor);
|
||||
|
||||
const SafeTensorId inlined_to_tensor =
|
||||
SafeTensorId(absl::StrCat(func_node.name(), "/", to_tensor.node()),
|
||||
to_tensor.index());
|
||||
const SafeTensorId inlined_to_tensor =
|
||||
SafeTensorId(absl::StrCat(func_node.name(), "/", to_tensor.node()),
|
||||
to_tensor.index());
|
||||
|
||||
ctx->AddTensorMapping(from_tensor, inlined_to_tensor);
|
||||
ctx->AddTensorMapping(from_tensor, inlined_to_tensor);
|
||||
}
|
||||
}
|
||||
|
||||
// If function call node was in keep_ops set, it means that we need to keep a
|
||||
|
||||
@ -123,7 +123,7 @@ TEST_F(FunctionOptimizerTest, InlineFunction_SkipErrorsIfGraphNotModified) {
|
||||
// Standard XTimesTwo() function.
|
||||
FunctionDef x_times_two = test::function::XTimesTwo();
|
||||
|
||||
// Function signature has non-type attribute (currently not supported).
|
||||
// Function with sequence of tensors as an input (currently not supported).
|
||||
FunctionDef my_identity_n = FunctionDefHelper::Create(
|
||||
// Name
|
||||
"MyIdentityN",
|
||||
|
||||
@ -367,8 +367,8 @@ TEST_F(MetaOptimizerTest, OptimizeFunctionLibrary) {
|
||||
if (node.name() == "my_mul/inlined_inputs" && ++count) {
|
||||
EXPECT_EQ("IdentityN", node.op());
|
||||
EXPECT_EQ(2, node.input_size());
|
||||
EXPECT_EQ("x", node.input(0));
|
||||
EXPECT_EQ("x", node.input(1));
|
||||
EXPECT_EQ("x:0", node.input(0));
|
||||
EXPECT_EQ("x:0", node.input(1));
|
||||
} else if (node.name() == "my_mul/x" && ++count) {
|
||||
EXPECT_EQ("Identity", node.op());
|
||||
EXPECT_EQ(1, node.input_size());
|
||||
@ -623,17 +623,17 @@ TEST_F(MetaOptimizerTest, OptimizeFunctionLibraryWithRestrictions) {
|
||||
MetaOptimizer optimizer(nullptr, config_proto);
|
||||
|
||||
// Define simple function library with two identical mul functions.
|
||||
FunctionDef mul_func_1 = FunctionDefHelper::Create(
|
||||
"MyMul1", {"x:float", "y:float"}, {"z:float"}, {},
|
||||
{{{"mul"}, "Mul", {"x", "y"}, {{"T", DT_FLOAT}}}},
|
||||
/*ret_def=*/
|
||||
{{"z", "mul:z:0"}});
|
||||
FunctionDef mul_func_1 =
|
||||
FunctionDefHelper::Create("MyMul1", {"x:float", "y:float"}, {"z:float"},
|
||||
{}, {{{"mul"}, "Mul", {"x", "y"}, {}}},
|
||||
/*ret_def=*/
|
||||
{{"z", "mul:z:0"}});
|
||||
|
||||
FunctionDef mul_func_2 = FunctionDefHelper::Create(
|
||||
"MyMul2", {"x:float", "y:float"}, {"z:float"}, {},
|
||||
{{{"mul"}, "Mul", {"x", "y"}, {{"T", DT_FLOAT}}}},
|
||||
/*ret_def=*/
|
||||
{{"z", "mul:z:0"}});
|
||||
FunctionDef mul_func_2 =
|
||||
FunctionDefHelper::Create("MyMul2", {"x:float", "y:float"}, {"z:float"},
|
||||
{}, {{{"mul"}, "Mul", {"x", "y"}, {}}},
|
||||
/*ret_def=*/
|
||||
{{"z", "mul:z:0"}});
|
||||
|
||||
// Tensorflow graph:
|
||||
//
|
||||
|
||||
@ -172,7 +172,6 @@ cc_library(
|
||||
hdrs = ["functions.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/core:core_cpu_base",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
|
||||
@ -17,9 +17,7 @@ limitations under the License.
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_replace.h"
|
||||
#include "absl/strings/substitute.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
@ -36,29 +34,306 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
namespace {
|
||||
|
||||
Status RegisterFunctionBodyOutputs(const OpRegistrationData& registration,
|
||||
const NodeDef& node,
|
||||
GrapplerFunctionConnectivity* connectivity) {
|
||||
tensorflow::NameRangeMap outputs_range_map;
|
||||
TF_RETURN_IF_ERROR(tensorflow::NameRangesForNode(
|
||||
node, registration.op_def, nullptr, &outputs_range_map));
|
||||
connectivity->RegisterFunctionBodyOutputs(node.name(),
|
||||
std::move(outputs_range_map));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status RegisterFunctionBodyOutputs(const FunctionLibraryDefinition& flib,
|
||||
const NodeDef& node,
|
||||
GrapplerFunctionConnectivity* connectivity) {
|
||||
const OpRegistrationData* registration;
|
||||
TF_RETURN_IF_ERROR(flib.LookUp(node.op(), ®istration));
|
||||
return RegisterFunctionBodyOutputs(*registration, node, connectivity);
|
||||
}
|
||||
|
||||
// Replace the placeholder attribute values with the values specified in
|
||||
// instantiation attributes.
|
||||
Status ResolveFunctionBodyNodeAttrPlaceholders(
|
||||
const AttrSlice& func_instantiation_attr, NodeDef* node) {
|
||||
for (auto& attr : *node->mutable_attr()) {
|
||||
const string& placeholder = attr.second.placeholder();
|
||||
if (placeholder.empty()) continue;
|
||||
|
||||
const AttrValue* attr_value = func_instantiation_attr.Find(placeholder);
|
||||
if (attr_value) {
|
||||
attr.second = *attr_value;
|
||||
} else {
|
||||
return errors::InvalidArgument("Can't resolve placeholder: ",
|
||||
placeholder);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void GrapplerFunctionConnectivity::RegisterInputArgExpansion(
|
||||
InputArgExpansion input_arg_expansion) {
|
||||
string input_name = input_arg_expansion.input_name;
|
||||
const auto& placeholders = input_arg_expansion.placeholders;
|
||||
|
||||
for (int i = 0; i < placeholders.size(); ++i) {
|
||||
const string& placeholder = input_arg_expansion.placeholders[i];
|
||||
input_arg_placeholders_.insert(
|
||||
{placeholder, InputArgPlaceholder{input_name, /*input_index=*/i}});
|
||||
}
|
||||
input_arg_expansions_.insert(
|
||||
{std::move(input_name), std::move(input_arg_expansion)});
|
||||
}
|
||||
|
||||
void GrapplerFunctionConnectivity::RegisterFunctionBodyOutputs(
|
||||
const string& node_name, tensorflow::NameRangeMap&& outputs) {
|
||||
function_body_outputs_[node_name] = std::move(outputs);
|
||||
}
|
||||
|
||||
Status GrapplerFunctionConnectivity::ExpandFunctionDefInput(
|
||||
const string& func_def_input, std::vector<string>* graph_def_inputs) const {
|
||||
using ::tensorflow::strings::Scanner;
|
||||
|
||||
if (IsControlInput(func_def_input)) {
|
||||
graph_def_inputs->push_back(func_def_input);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Parse input format: "node_name[:node_output][:position]"
|
||||
string node_name;
|
||||
string node_output;
|
||||
int position = -1;
|
||||
|
||||
StringPiece capture;
|
||||
StringPiece remaining;
|
||||
|
||||
// Parse "node_name"
|
||||
if (Scanner(func_def_input)
|
||||
.One(strings::Scanner::LETTER_DIGIT_DOT_UNDERSCORE)
|
||||
.Any(strings::Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE)
|
||||
.GetResult(&remaining, &capture)) {
|
||||
node_name = string(capture.data(), capture.size());
|
||||
}
|
||||
|
||||
// Parse "node_output" if it exists
|
||||
if (Scanner(remaining)
|
||||
.OneLiteral(":")
|
||||
.RestartCapture()
|
||||
.One(strings::Scanner::LETTER)
|
||||
.Any(strings::Scanner::LETTER_DIGIT_UNDERSCORE)
|
||||
.GetResult(&remaining, &capture)) {
|
||||
node_output = string(capture.data(), capture.size());
|
||||
}
|
||||
|
||||
// Parse "position" if it exists
|
||||
if (Scanner(remaining)
|
||||
.OneLiteral(":")
|
||||
.RestartCapture()
|
||||
.Many(strings::Scanner::DIGIT)
|
||||
.GetResult(nullptr, &capture)) {
|
||||
CHECK(strings::safe_strto32(capture, &position));
|
||||
}
|
||||
|
||||
// If "node_output" is not empty, it must be an output of a function body node
|
||||
bool is_function_body_output = !node_output.empty();
|
||||
|
||||
// Function input argument: "node_name[:position]"
|
||||
if (!is_function_body_output) {
|
||||
auto input_arg = input_arg_expansions_.find(node_name);
|
||||
if (input_arg != input_arg_expansions_.end()) {
|
||||
const InputArgExpansion& input_arg_expansion = input_arg->second;
|
||||
const auto& placeholders = input_arg_expansion.placeholders;
|
||||
|
||||
if (position == -1) {
|
||||
// If position is not defined use all placeholders
|
||||
graph_def_inputs->reserve(placeholders.size());
|
||||
for (const string& placeholder : placeholders) {
|
||||
graph_def_inputs->push_back(placeholder);
|
||||
}
|
||||
} else {
|
||||
if (position > input_arg_expansion.placeholders.size() - 1) {
|
||||
return errors::InvalidArgument("Invalid input ", node_name,
|
||||
"position: ", position,
|
||||
" (out of range)");
|
||||
}
|
||||
graph_def_inputs->push_back(input_arg_expansion.placeholders[position]);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
|
||||
// Function body output: "node_name:node_output[:position]"
|
||||
if (is_function_body_output) {
|
||||
auto function_body_outputs = function_body_outputs_.find(node_name);
|
||||
if (function_body_outputs != function_body_outputs_.end()) {
|
||||
const tensorflow::NameRangeMap& outputs = function_body_outputs->second;
|
||||
auto output = outputs.find(node_output);
|
||||
if (output != outputs.end()) {
|
||||
const auto& output_range = output->second;
|
||||
|
||||
if (position == -1) {
|
||||
graph_def_inputs->reserve(graph_def_inputs->size() +
|
||||
output_range.second - output_range.first);
|
||||
// If position is not defined expand node output range
|
||||
for (int i = output_range.first; i < output_range.second; ++i) {
|
||||
graph_def_inputs->push_back(
|
||||
i == 0 ? node_name : absl::StrCat(node_name, ":", i));
|
||||
}
|
||||
} else {
|
||||
if (position > (output_range.second - output_range.first)) {
|
||||
return errors::InvalidArgument(
|
||||
"Invalid node ", node_name, " output ", node_output,
|
||||
" position: ", position, " (out of range)");
|
||||
}
|
||||
int pos = output_range.first + position;
|
||||
graph_def_inputs->push_back(
|
||||
pos == 0 ? node_name : absl::StrCat(node_name, ":", pos));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return errors::InvalidArgument("Failed to expand a function def input: ",
|
||||
func_def_input);
|
||||
}
|
||||
|
||||
Status GrapplerFunctionConnectivity::ExpandNodeInputs(
|
||||
NodeDef* function_body_node) const {
|
||||
std::vector<string> expanded_inputs;
|
||||
|
||||
for (const string& function_def_input : function_body_node->input()) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
ExpandFunctionDefInput(function_def_input, &expanded_inputs));
|
||||
}
|
||||
|
||||
function_body_node->clear_input();
|
||||
for (string& expanded_input : expanded_inputs)
|
||||
function_body_node->add_input(std::move(expanded_input));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GrapplerFunctionConnectivity::AsFunctionDefInput(
|
||||
const string& graph_def_input, string* func_def_input) const {
|
||||
if (IsControlInput(graph_def_input)) {
|
||||
*func_def_input = graph_def_input;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
const TensorId tensor = ParseTensorName(graph_def_input);
|
||||
DCHECK_GE(tensor.index(), 0);
|
||||
|
||||
const absl::string_view node_name = tensor.node();
|
||||
const int index = tensor.index();
|
||||
|
||||
// Check if it's an input arg placeholder
|
||||
if (tensor.index() == 0) {
|
||||
const auto is_input_placeholder = input_arg_placeholders_.find(node_name);
|
||||
if (is_input_placeholder != input_arg_placeholders_.end()) {
|
||||
const InputArgPlaceholder& placeholder = is_input_placeholder->second;
|
||||
*func_def_input =
|
||||
absl::StrCat(placeholder.input_name, ":", placeholder.input_index);
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
|
||||
// It must be output from one of the function body nodes
|
||||
const auto is_body_output = function_body_outputs_.find(tensor.node());
|
||||
if (is_body_output != function_body_outputs_.end()) {
|
||||
const tensorflow::NameRangeMap& outputs_range_map = is_body_output->second;
|
||||
|
||||
for (const auto& el : outputs_range_map) {
|
||||
const auto& output_name = el.first;
|
||||
const auto& output_range = el.second;
|
||||
if (index >= output_range.first && index < output_range.second) {
|
||||
int pos = index - output_range.first;
|
||||
*func_def_input = absl::StrCat(node_name, ":", output_name, ":", pos);
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return errors::InvalidArgument("Unknown graph def input: ", graph_def_input);
|
||||
}
|
||||
|
||||
Status GrapplerFunctionConnectivity::AsFunctionDefNode(
|
||||
NodeDef* function_body_node) const {
|
||||
string func_def_input;
|
||||
|
||||
for (int i = 0; i < function_body_node->input_size(); ++i) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
AsFunctionDefInput(function_body_node->input(i), &func_def_input));
|
||||
function_body_node->set_input(i, func_def_input);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GrapplerFunctionItemInstantiation::GetTypeAttr(
|
||||
const string& type_attr_name, DataType* data_type) const {
|
||||
const AttrValue* type_attr = func_instantiation_attr_.Find(type_attr_name);
|
||||
if (type_attr == nullptr) {
|
||||
return errors::InvalidArgument("Type attribute ", type_attr_name,
|
||||
" is not defined");
|
||||
} else if (type_attr->type() == DT_INVALID) {
|
||||
return errors::InvalidArgument("Type attribute ", type_attr_name,
|
||||
" is not defined with a valid type");
|
||||
} else {
|
||||
*data_type = type_attr->type();
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GrapplerFunctionItemInstantiation::GetArgType(
|
||||
const OpDef::ArgDef& arg, DataType* data_type) const {
|
||||
if (arg.type() != DT_INVALID) {
|
||||
*data_type = arg.type();
|
||||
} else {
|
||||
if (!arg.type_list_attr().empty() || !arg.number_attr().empty()) {
|
||||
return errors::InvalidArgument(
|
||||
"Arguments with sequence of tensors are not supported. Unsupported "
|
||||
"argument name: ",
|
||||
arg.name());
|
||||
}
|
||||
TF_RETURN_IF_ERROR(GetTypeAttr(arg.type_attr(), data_type));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
GrapplerFunctionItem::GrapplerFunctionItem(
|
||||
string func_name, string description, AttrSlice func_attr,
|
||||
std::vector<InputArgInstantiation> input_args,
|
||||
std::vector<OutputArgInstantiation> output_args,
|
||||
std::vector<InputArgExpansion> input_arg_expansions,
|
||||
std::vector<OutputArgExpansion> output_arg_expansions,
|
||||
std::vector<ControlOutput> control_outputs, const int graph_def_version,
|
||||
const bool is_stateful, GraphDef&& function_body)
|
||||
: description_(std::move(description)),
|
||||
func_attr_(func_attr),
|
||||
input_args_(std::move(input_args)),
|
||||
output_args_(std::move(output_args)),
|
||||
input_arg_expansions_(std::move(input_arg_expansions)),
|
||||
output_arg_expansions_(std::move(output_arg_expansions)),
|
||||
control_outputs_(std::move(control_outputs)),
|
||||
is_stateful_(is_stateful) {
|
||||
id = std::move(func_name);
|
||||
graph = std::move(function_body);
|
||||
graph.mutable_versions()->set_producer(graph_def_version);
|
||||
|
||||
// Fill the feed nodes with function input arguments.
|
||||
for (const InputArgInstantiation& input_arg : input_args_) {
|
||||
feed.push_back({input_arg.node_name, Tensor()});
|
||||
graph.mutable_versions()->set_producer(graph_def_version);
|
||||
// Fill the feed nodes with input placeholders.
|
||||
for (const InputArgExpansion& input_arg : input_arg_expansions_) {
|
||||
for (const string& placeholder : input_arg.placeholders) {
|
||||
feed.push_back({placeholder, Tensor()});
|
||||
}
|
||||
}
|
||||
// Fill the fetch nodes with outputs.
|
||||
for (const OutputArgInstantiation& output_arg : output_args_) {
|
||||
fetch.push_back(output_arg.node_name);
|
||||
for (const OutputArgExpansion& output_arg : output_arg_expansions_) {
|
||||
for (const string& output_node : output_arg.output_nodes) {
|
||||
fetch.push_back(output_node);
|
||||
}
|
||||
}
|
||||
// We must keep all control output nodes.
|
||||
for (const ControlOutput& control_output : control_outputs_) {
|
||||
@ -72,29 +347,28 @@ GrapplerFunctionItem::GrapplerFunctionItem(
|
||||
|
||||
const string& GrapplerFunctionItem::description() const { return description_; }
|
||||
|
||||
const std::vector<InputArgInstantiation>& GrapplerFunctionItem::inputs() const {
|
||||
return input_args_;
|
||||
const std::vector<InputArgExpansion>& GrapplerFunctionItem::inputs() const {
|
||||
return input_arg_expansions_;
|
||||
}
|
||||
|
||||
const InputArgInstantiation& GrapplerFunctionItem::input(int i) const {
|
||||
return input_args_[i];
|
||||
const InputArgExpansion& GrapplerFunctionItem::input(int i) const {
|
||||
return input_arg_expansions_[i];
|
||||
}
|
||||
|
||||
const std::size_t GrapplerFunctionItem::input_size() const {
|
||||
return input_args_.size();
|
||||
return input_arg_expansions_.size();
|
||||
}
|
||||
|
||||
const std::vector<OutputArgInstantiation>& GrapplerFunctionItem::outputs()
|
||||
const {
|
||||
return output_args_;
|
||||
const std::vector<OutputArgExpansion>& GrapplerFunctionItem::outputs() const {
|
||||
return output_arg_expansions_;
|
||||
}
|
||||
|
||||
const OutputArgInstantiation& GrapplerFunctionItem::output(int i) const {
|
||||
return output_args_[i];
|
||||
const OutputArgExpansion& GrapplerFunctionItem::output(int i) const {
|
||||
return output_arg_expansions_[i];
|
||||
}
|
||||
|
||||
const std::size_t GrapplerFunctionItem::output_size() const {
|
||||
return output_args_.size();
|
||||
return output_arg_expansions_.size();
|
||||
}
|
||||
|
||||
const std::vector<ControlOutput>& GrapplerFunctionItem::control_outputs()
|
||||
@ -153,23 +427,15 @@ Status InstantiationTypeParameters(
|
||||
return errors::InvalidArgument("Type parameters output map must be empty");
|
||||
}
|
||||
|
||||
const auto resolve_type_attr = [&](const OpDef::ArgDef& arg) -> Status {
|
||||
if (!arg.type_attr().empty()) {
|
||||
DataType dtype;
|
||||
TF_RETURN_IF_ERROR(
|
||||
GetNodeAttr(func_instantiation_attr, arg.type_attr(), &dtype));
|
||||
type_parameters->emplace(arg.type_attr(), dtype);
|
||||
GrapplerFunctionItemInstantiation instantiation(func_instantiation_attr);
|
||||
|
||||
} else if (!arg.type_list_attr().empty()) {
|
||||
std::vector<DataType> dtypes;
|
||||
TF_RETURN_IF_ERROR(
|
||||
GetNodeAttr(func_instantiation_attr, arg.type_list_attr(), &dtypes));
|
||||
int index = 0;
|
||||
for (const DataType& dtype : dtypes) {
|
||||
type_parameters->emplace(absl::StrCat(arg.type_list_attr(), ":", index),
|
||||
dtype);
|
||||
++index;
|
||||
}
|
||||
const auto resolve_type_attr = [&](const OpDef::ArgDef& arg) {
|
||||
// Check if it's unknown and unresolved type.
|
||||
if (arg.type() == DT_INVALID &&
|
||||
type_parameters->find(arg.type_attr()) == type_parameters->end()) {
|
||||
DataType data_type;
|
||||
TF_RETURN_IF_ERROR(instantiation.GetArgType(arg, &data_type));
|
||||
type_parameters->insert({arg.type_attr(), data_type});
|
||||
}
|
||||
return Status::OK();
|
||||
};
|
||||
@ -193,7 +459,8 @@ Status InstantiationBodyParameters(
|
||||
for (auto& attr : func_body_node.attr()) {
|
||||
const string& placeholder = attr.second.placeholder();
|
||||
|
||||
if (placeholder.empty() || body_parameters->contains(placeholder)) {
|
||||
if (placeholder.empty() ||
|
||||
body_parameters->find(placeholder) != body_parameters->end()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
@ -231,13 +498,15 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func,
|
||||
}
|
||||
}
|
||||
|
||||
// Instantiate function into a statically defined FunctionBody Graph.
|
||||
std::unique_ptr<FunctionBody> fbody;
|
||||
TF_RETURN_IF_ERROR(
|
||||
FunctionDefToBodyHelper(func, func_instantiation_attr, &flib, &fbody));
|
||||
// Helper methods to lookup function instantiation attributes
|
||||
GrapplerFunctionItemInstantiation instantiation(func_instantiation_attr);
|
||||
|
||||
// Mapping from FunctionDef input format (name[:output][:position]) to
|
||||
// GraphDef input format (name[:position])
|
||||
GrapplerFunctionConnectivity connectivity;
|
||||
|
||||
// Instantiate function body into a statically defined graph def.
|
||||
GraphDef function_body;
|
||||
fbody->graph->ToGraphDef(&function_body);
|
||||
|
||||
// Function body shares the library with the graph that instantiated it. We do
|
||||
// not need a full copy of the function library, just the reachable subset.
|
||||
@ -249,25 +518,122 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func,
|
||||
flib.num_functions() - function_body.library().function_size(),
|
||||
signature.name(), function_body.library().function_size());
|
||||
|
||||
const int num_instantiated_inputs = fbody->arg_types.size();
|
||||
const int num_instantiated_outputs = fbody->ret_types.size();
|
||||
// TODO(ezhulenev): support functions with tensor sequence inputs/outputs
|
||||
|
||||
std::vector<InputArgInstantiation> inputs;
|
||||
inputs.reserve(num_instantiated_inputs);
|
||||
|
||||
for (int in_id = 0; in_id < num_instantiated_inputs; ++in_id) {
|
||||
const Node* node = fbody->arg_nodes[in_id];
|
||||
const DataType& dtype = fbody->arg_types[in_id];
|
||||
inputs.emplace_back(node->name(), dtype);
|
||||
// Make sure that there are no tensor lists in inputs or outputs.
|
||||
for (const OpDef::ArgDef& input : signature.input_arg()) {
|
||||
if (!input.type_list_attr().empty() || !input.number_attr().empty()) {
|
||||
return errors::InvalidArgument(
|
||||
"Inputs with lists of tensors are not supported. Input: ",
|
||||
input.name());
|
||||
}
|
||||
}
|
||||
for (const OpDef::ArgDef& output : signature.output_arg()) {
|
||||
if (!output.type_list_attr().empty() || !output.number_attr().empty()) {
|
||||
return errors::InvalidArgument(
|
||||
"Outputs with lists of tensors are not supported. Output: ",
|
||||
output.name());
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<OutputArgInstantiation> outputs;
|
||||
outputs.reserve(num_instantiated_outputs);
|
||||
std::vector<InputArgExpansion> inputs;
|
||||
inputs.reserve(signature.input_arg_size());
|
||||
|
||||
for (int out_id = 0; out_id < num_instantiated_outputs; ++out_id) {
|
||||
const Node* node = fbody->ret_nodes[out_id];
|
||||
const DataType& dtype = fbody->ret_types[out_id];
|
||||
outputs.emplace_back(node->name(), dtype);
|
||||
// For each input argument create a placeholder in function body.
|
||||
for (const OpDef::ArgDef& input : signature.input_arg()) {
|
||||
DataType input_data_type;
|
||||
TF_RETURN_IF_ERROR(instantiation.GetArgType(input, &input_data_type));
|
||||
|
||||
NodeDef* placeholder = function_body.add_node();
|
||||
placeholder->set_name(input.name());
|
||||
placeholder->set_op("Placeholder");
|
||||
(*placeholder->mutable_attr())["dtype"].set_type(input_data_type);
|
||||
(*placeholder->mutable_attr())["shape"].mutable_shape()->set_unknown_rank(
|
||||
true);
|
||||
|
||||
InputArgExpansion input_expansion{/*input_name=*/input.name(),
|
||||
/*data_type=*/input_data_type,
|
||||
/*is_ref=*/input.is_ref(),
|
||||
/*placeholders=*/{input.name()}};
|
||||
connectivity.RegisterInputArgExpansion(input_expansion);
|
||||
inputs.push_back(std::move(input_expansion));
|
||||
}
|
||||
|
||||
// Keep names of all nodes in the function body to guarantee that we do not
|
||||
// add an identity with a duplicate name.
|
||||
absl::flat_hash_set<absl::string_view> func_body_nodes;
|
||||
|
||||
// Generate unique output node name: "${out_arg_name}_output_node_${index}".
|
||||
const auto output_node_name = [&func_body_nodes](const OpDef::ArgDef& out,
|
||||
int index) -> string {
|
||||
string name = absl::StrCat(out.name(), "_output_node_", index);
|
||||
int i = 1;
|
||||
while (func_body_nodes.find(name) != func_body_nodes.end()) {
|
||||
name = absl::StrCat(out.name(), "_output_node_", index, "_", i++);
|
||||
}
|
||||
return name;
|
||||
};
|
||||
|
||||
// Add all function nodes to the function body.
|
||||
for (const NodeDef& func_def_node : func.node_def()) {
|
||||
func_body_nodes.insert(func_def_node.name());
|
||||
|
||||
NodeDef* new_node = function_body.add_node();
|
||||
*new_node = func_def_node;
|
||||
|
||||
const OpRegistrationData* registration;
|
||||
TF_RETURN_IF_ERROR(flib.LookUp(func_def_node.op(), ®istration));
|
||||
|
||||
// Resolve all placeholder values using function instantiation attributes.
|
||||
TF_RETURN_IF_ERROR(ResolveFunctionBodyNodeAttrPlaceholders(
|
||||
func_instantiation_attr, new_node));
|
||||
|
||||
// Register node output range in a function connectivity.
|
||||
TF_RETURN_IF_ERROR(RegisterFunctionBodyOutputs(*registration, func_def_node,
|
||||
&connectivity));
|
||||
}
|
||||
|
||||
// Rewrite inputs to use GraphDef format
|
||||
for (NodeDef& node : *function_body.mutable_node()) {
|
||||
TF_RETURN_IF_ERROR(connectivity.ExpandNodeInputs(&node));
|
||||
}
|
||||
|
||||
std::vector<OutputArgExpansion> outputs;
|
||||
outputs.reserve(signature.output_arg_size());
|
||||
|
||||
// For each function output argument we create an Identity node in the
|
||||
// function body, that reads output tensor from the function body node.
|
||||
for (const OpDef::ArgDef& out : signature.output_arg()) {
|
||||
DataType output_data_type;
|
||||
TF_RETURN_IF_ERROR(instantiation.GetArgType(out, &output_data_type));
|
||||
|
||||
std::vector<string> output_tensors;
|
||||
auto ret = func.ret().find(out.name());
|
||||
TF_RETURN_IF_ERROR(
|
||||
ret != func.ret().end()
|
||||
// Expand outputs using provided output mapping
|
||||
? connectivity.ExpandFunctionDefInput(ret->second, &output_tensors)
|
||||
// Otherwise output must be one of the function inputs
|
||||
: connectivity.ExpandFunctionDefInput(out.name(), &output_tensors));
|
||||
|
||||
absl::InlinedVector<string, 1> output_nodes;
|
||||
for (int i = 0; i < output_tensors.size(); ++i) {
|
||||
const string& output_tensor = output_tensors[i];
|
||||
|
||||
NodeDef* identity = function_body.add_node();
|
||||
identity->set_name(output_node_name(out, i));
|
||||
identity->set_op("Identity");
|
||||
(*identity->mutable_attr())["T"].set_type(output_data_type);
|
||||
identity->add_input(output_tensor);
|
||||
|
||||
output_nodes.push_back(identity->name());
|
||||
}
|
||||
|
||||
OutputArgExpansion output{/*output_name=*/out.name(),
|
||||
/*data_type=*/output_data_type,
|
||||
/*is_ref=*/out.is_ref(),
|
||||
/*output_nodes=*/std::move(output_nodes)};
|
||||
outputs.push_back(std::move(output));
|
||||
}
|
||||
|
||||
// Control outputs ensure that all side-effectful nodes in the function body
|
||||
@ -295,42 +661,70 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func,
|
||||
item);
|
||||
}
|
||||
|
||||
// Register GrapplerFunctionItem input arg expansion and function body outputs
|
||||
// in the GrapplerFunctionConnectivity.
|
||||
Status RegisterGrapplerFunctionConnectivity(
|
||||
const GrapplerFunctionItem& item, const FunctionLibraryDefinition& flib,
|
||||
GrapplerFunctionConnectivity* connectivity) {
|
||||
for (const InputArgExpansion& input : item.inputs()) {
|
||||
connectivity->RegisterInputArgExpansion(input);
|
||||
}
|
||||
for (const NodeDef& func_body_node : item.function_body().node()) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
RegisterFunctionBodyOutputs(flib, func_body_node, connectivity));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ReplaceInputWithConst(const NodeDef& input_const, int input_index,
|
||||
GrapplerFunctionItem* item) {
|
||||
if (!IsConstant(input_const)) {
|
||||
return errors::InvalidArgument("Input node is not a constant: ",
|
||||
SummarizeNodeDef(input_const));
|
||||
}
|
||||
if (input_index < 0 || input_index >= item->input_size()) {
|
||||
return errors::InvalidArgument(
|
||||
"Function input index is out of bound: index=", input_index,
|
||||
" input_size=", item->input_size());
|
||||
return errors::InvalidArgument("Input node ", input_const.name(),
|
||||
" is not a constant");
|
||||
}
|
||||
|
||||
const InputArgInstantiation& input_arg = item->input(input_index);
|
||||
auto& inputs = item->input_arg_expansions_;
|
||||
|
||||
// Find input arg expansion and input placeholder position in it for the
|
||||
// given function input position.
|
||||
InputArgExpansion* input_arg_expansion = nullptr;
|
||||
int placeholder_idx = input_index;
|
||||
|
||||
for (InputArgExpansion& input : inputs) {
|
||||
if (placeholder_idx < input.placeholders.size()) {
|
||||
input_arg_expansion = &input;
|
||||
break;
|
||||
}
|
||||
placeholder_idx -= input.placeholders.size();
|
||||
}
|
||||
|
||||
if (input_arg_expansion == nullptr) {
|
||||
return errors::InvalidArgument("Input placeholder not found: input_index=",
|
||||
input_index, " function=", item->id);
|
||||
}
|
||||
|
||||
// Delete placeholder from input expansion.
|
||||
string placeholder_name = input_arg_expansion->placeholders[placeholder_idx];
|
||||
input_arg_expansion->placeholders.erase(
|
||||
input_arg_expansion->placeholders.begin() + placeholder_idx);
|
||||
|
||||
// Delete empty input expansions.
|
||||
inputs.erase(std::remove_if(inputs.begin(), inputs.end(),
|
||||
[](const InputArgExpansion& input) {
|
||||
return input.placeholders.empty();
|
||||
}),
|
||||
inputs.end());
|
||||
|
||||
// Replace placeholder node in the function body with a const node.
|
||||
for (NodeDef& node : *item->graph.mutable_node()) {
|
||||
// Replace '_Arg' node in the function body with a 'Const' node.
|
||||
if (node.name() == input_arg.node_name) {
|
||||
if (node.name() == placeholder_name) {
|
||||
node = input_const;
|
||||
node.set_name(input_arg.node_name);
|
||||
node.clear_input();
|
||||
node.set_name(placeholder_name);
|
||||
node.clear_input(); // remove potential control inputs
|
||||
node.clear_device(); // device placement is defined by instantiating node
|
||||
}
|
||||
|
||||
// Update index in all inputs after the removed const input.
|
||||
if (IsArg(node)) {
|
||||
auto attrs = AttrSlice(node);
|
||||
int index;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "index", &index));
|
||||
if (index >= input_index) {
|
||||
(*node.mutable_attr())["index"].set_i(index - 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
item->input_args_.erase(item->input_args_.begin() + input_index);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -339,24 +733,31 @@ Status RemoveFunctionOutputs(const absl::flat_hash_set<int>& remove_outputs,
|
||||
std::vector<std::pair<int, int>>* output_mapping) {
|
||||
DCHECK(output_mapping->empty());
|
||||
|
||||
// Code below assumes that we do not support tensor list outputs and there is
|
||||
// a 1-to-1 mapping between output tensor and output argument expansion.
|
||||
for (const OutputArgExpansion& out_arg : item->outputs()) {
|
||||
DCHECK(out_arg.output_nodes.size() == 1)
|
||||
<< "Output arg expansion must have single output";
|
||||
}
|
||||
|
||||
// Do some sanity checking of the removed outputs positions.
|
||||
for (int remove_output : remove_outputs) {
|
||||
if (remove_output < 0 || remove_output >= item->output_size()) {
|
||||
return errors::InvalidArgument(
|
||||
"Function output index is out of bound: index=", remove_output,
|
||||
" output_size=", item->output_size());
|
||||
" max_output_index=", item->output_size());
|
||||
}
|
||||
}
|
||||
|
||||
absl::flat_hash_set<const OutputArgInstantiation*> remove_output_args;
|
||||
const auto is_remove_output_arg = [&](const OutputArgInstantiation& output) {
|
||||
absl::flat_hash_set<const OutputArgExpansion*> remove_output_args;
|
||||
const auto is_remove_output_arg = [&](const OutputArgExpansion& output) {
|
||||
return remove_output_args.find(&output) != remove_output_args.end();
|
||||
};
|
||||
|
||||
for (int i = 0; i < item->output_size(); ++i) {
|
||||
const OutputArgInstantiation& output = item->output(i);
|
||||
if (remove_outputs.contains(i)) {
|
||||
VLOG(3) << "Remove functions output: name=" << output.node_name
|
||||
const OutputArgExpansion& output = item->output(i);
|
||||
if (remove_outputs.find(i) != remove_outputs.end()) {
|
||||
VLOG(3) << "Remove functions output: output_name=" << output.output_name
|
||||
<< "(index = " << i << ")";
|
||||
remove_output_args.insert(&output);
|
||||
} else if (!remove_output_args.empty()) {
|
||||
@ -365,130 +766,12 @@ Status RemoveFunctionOutputs(const absl::flat_hash_set<int>& remove_outputs,
|
||||
}
|
||||
}
|
||||
|
||||
// Update 'index' attribute in all '_Retval' nodes that are in output mapping.
|
||||
for (NodeDef& node : *item->graph.mutable_node()) {
|
||||
if (IsRetval(node)) {
|
||||
auto attrs = AttrSlice(node);
|
||||
int index;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "index", &index));
|
||||
|
||||
for (const auto& mapping : *output_mapping) {
|
||||
const int from = mapping.first;
|
||||
const int to = mapping.second;
|
||||
if (index == from) {
|
||||
(*node.mutable_attr())["index"].set_i(to);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto& o = item->output_args_;
|
||||
auto& o = item->output_arg_expansions_;
|
||||
o.erase(std::remove_if(o.begin(), o.end(), is_remove_output_arg), o.end());
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// FunctionDef uses different connectivity encoding for the function body nodes,
|
||||
// than a GraphDef (see function.proto for details). This is a helper class that
|
||||
// converts inputs in GraphDef format (node[:position]) to the FunctionDef
|
||||
// format (node:output[:position]).
|
||||
class MakeFunctionDefHelper {
|
||||
public:
|
||||
MakeFunctionDefHelper() = default;
|
||||
|
||||
Status Initialize(const GrapplerFunctionItem& item,
|
||||
const FunctionLibraryDefinition& flib);
|
||||
|
||||
// Converts input name from GraphDef format (name[:position]) to the
|
||||
// FunctionDef input format (name[:output][:position]) using registered input
|
||||
// arg instantiations and function body outputs.
|
||||
Status AsFunctionDefInput(const string& graph_def_input,
|
||||
string* func_def_input) const;
|
||||
|
||||
// Updates Node inputs from GraphDef to FunctionDef format.
|
||||
Status AsFunctionDefNode(NodeDef* function_body_node) const;
|
||||
|
||||
private:
|
||||
absl::flat_hash_set<absl::string_view> input_nodes_;
|
||||
// Mapping from function body node name to output names range map.
|
||||
absl::flat_hash_map<string, tensorflow::NameRangeMap> function_body_outputs_;
|
||||
};
|
||||
|
||||
Status MakeFunctionDefHelper::Initialize(
|
||||
const GrapplerFunctionItem& item, const FunctionLibraryDefinition& flib) {
|
||||
for (const InputArgInstantiation& input_arg : item.inputs()) {
|
||||
input_nodes_.insert(input_arg.node_name);
|
||||
}
|
||||
|
||||
for (const NodeDef& node : item.function_body().node()) {
|
||||
const OpRegistrationData* registration;
|
||||
TF_RETURN_IF_ERROR(flib.LookUp(node.op(), ®istration));
|
||||
|
||||
tensorflow::NameRangeMap outputs_range_map;
|
||||
TF_RETURN_IF_ERROR(tensorflow::NameRangesForNode(
|
||||
node, registration->op_def, nullptr, &outputs_range_map));
|
||||
|
||||
function_body_outputs_.emplace(node.name(), std::move(outputs_range_map));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status MakeFunctionDefHelper::AsFunctionDefInput(const string& graph_def_input,
|
||||
string* func_def_input) const {
|
||||
if (IsControlInput(graph_def_input)) {
|
||||
*func_def_input = graph_def_input;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
const SafeTensorId tensor = ParseTensorName(graph_def_input);
|
||||
DCHECK_GE(tensor.index(), 0);
|
||||
|
||||
// Graph def input corresponds to one of the function inputs.
|
||||
const auto is_input = input_nodes_.find(tensor.node());
|
||||
if (is_input != input_nodes_.end()) {
|
||||
DCHECK_EQ(tensor.index(), 0);
|
||||
*func_def_input = tensor.node();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Or it must be output from one of the function body nodes
|
||||
const auto is_body_output = function_body_outputs_.find(tensor.node());
|
||||
if (is_body_output != function_body_outputs_.end()) {
|
||||
const tensorflow::NameRangeMap& outputs_range_map = is_body_output->second;
|
||||
|
||||
for (const auto& el : outputs_range_map) {
|
||||
const auto& output_name = el.first;
|
||||
const auto& output_range = el.second;
|
||||
if (tensor.index() >= output_range.first &&
|
||||
tensor.index() < output_range.second) {
|
||||
*func_def_input = absl::StrCat(tensor.node(), ":", output_name, ":",
|
||||
tensor.index() - output_range.first);
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return errors::InvalidArgument("Unknown graph def input: ", graph_def_input);
|
||||
}
|
||||
|
||||
Status MakeFunctionDefHelper::AsFunctionDefNode(
|
||||
NodeDef* function_body_node) const {
|
||||
string func_def_input;
|
||||
|
||||
for (int i = 0; i < function_body_node->input_size(); ++i) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
AsFunctionDefInput(function_body_node->input(i), &func_def_input));
|
||||
function_body_node->set_input(i, func_def_input);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status MakeFunctionDef(const GrapplerFunctionItem& item,
|
||||
const FunctionLibraryDefinition& flib,
|
||||
FunctionDef* func) {
|
||||
@ -496,55 +779,86 @@ Status MakeFunctionDef(const GrapplerFunctionItem& item,
|
||||
func->mutable_signature()->set_description(item.description());
|
||||
func->mutable_signature()->set_is_stateful(item.is_stateful());
|
||||
|
||||
MakeFunctionDefHelper helper;
|
||||
TF_RETURN_IF_ERROR(helper.Initialize(item, flib));
|
||||
|
||||
// Keep track of '_Arg' nodes that were added to the graph in place of
|
||||
// instantiated function input arguments.
|
||||
absl::flat_hash_set<absl::string_view> input_nodes;
|
||||
for (const InputArgInstantiation& input_arg : item.inputs()) {
|
||||
input_nodes.insert(input_arg.node_name);
|
||||
// Keep track of placeholders that were added to the graph in place of
|
||||
// expanded function input arguments.
|
||||
absl::flat_hash_set<absl::string_view> input_placeholders;
|
||||
for (const InputArgExpansion& input_arg : item.inputs()) {
|
||||
for (const string& placeholder : input_arg.placeholders) {
|
||||
input_placeholders.insert(placeholder);
|
||||
}
|
||||
}
|
||||
|
||||
// Mapping from the '_Retval' node name to the output tensor.
|
||||
// Keep track of identity nodes that were added to the graph in place of
|
||||
// expanded function output arguments.
|
||||
absl::flat_hash_set<absl::string_view> output_nodes;
|
||||
for (const OutputArgExpansion& output_arg : item.outputs()) {
|
||||
for (const string& output_node : output_arg.output_nodes) {
|
||||
output_nodes.insert(output_node);
|
||||
}
|
||||
}
|
||||
|
||||
// If the output identity node was not modified by any optimizer, we can
|
||||
// bypass it and returns the function value from its input.
|
||||
absl::flat_hash_map<absl::string_view, string> output_tensors;
|
||||
for (const NodeDef& func_body_node : item.function_body().node()) {
|
||||
if (!IsRetval(func_body_node)) continue;
|
||||
if (func_body_node.input_size() != 1) {
|
||||
return errors::Internal("_Retval node must have single input: ",
|
||||
SummarizeNodeDef(func_body_node));
|
||||
if (!IsIdentity(func_body_node)) continue;
|
||||
|
||||
const string& node_name = func_body_node.name();
|
||||
if (output_nodes.find(node_name) != output_nodes.end()) {
|
||||
// Grappler optimizers might optimize nodes in the fanin of the output
|
||||
// node, and forward their control dependencies. We can't express control
|
||||
// dependencies in a function signature, so we have to keep the node.
|
||||
if (func_body_node.input_size() == 1) {
|
||||
VLOG(3) << "Bypass function output node: " << node_name << " -> "
|
||||
<< func_body_node.input(0);
|
||||
output_tensors.emplace(node_name, func_body_node.input(0));
|
||||
} else {
|
||||
VLOG(3) << "Keep function output node: " << node_name;
|
||||
}
|
||||
}
|
||||
output_tensors.emplace(func_body_node.name(), func_body_node.input(0));
|
||||
}
|
||||
|
||||
for (const InputArgInstantiation& input_arg : item.inputs()) {
|
||||
// Return output tensor name (input of the output node) if it's safe to bypass
|
||||
// output node, otherwise returns the output node name.
|
||||
const auto output_tensor =
|
||||
[&output_tensors](const OutputArgExpansion& output_arg) -> const string& {
|
||||
const string& output_node = output_arg.output_nodes[0];
|
||||
const auto is_output_tensor = output_tensors.find(output_node);
|
||||
return is_output_tensor == output_tensors.end() ? output_node
|
||||
: is_output_tensor->second;
|
||||
};
|
||||
|
||||
// Build a GrapplerFunctionConnectivity from inputs and new function body.
|
||||
GrapplerFunctionConnectivity connectivity;
|
||||
TF_RETURN_IF_ERROR(
|
||||
RegisterGrapplerFunctionConnectivity(item, flib, &connectivity));
|
||||
|
||||
// Add function input arguments.
|
||||
for (const InputArgExpansion& input_arg : item.inputs()) {
|
||||
DCHECK(input_arg.placeholders.size() == 1) // do some sanity checking
|
||||
<< "Inputs of tensor lists are not supported";
|
||||
|
||||
OpDef::ArgDef arg_def;
|
||||
arg_def.set_name(input_arg.node_name);
|
||||
arg_def.set_name(input_arg.input_name);
|
||||
arg_def.set_type(input_arg.data_type);
|
||||
arg_def.set_is_ref(IsRefType(input_arg.data_type));
|
||||
arg_def.set_is_ref(input_arg.is_ref);
|
||||
*func->mutable_signature()->add_input_arg() = arg_def;
|
||||
}
|
||||
|
||||
// Add function output arguments.
|
||||
for (const OutputArgInstantiation& output_arg : item.outputs()) {
|
||||
const string output_name =
|
||||
absl::StrReplaceAll(output_arg.node_name, {{"_RetVal", ""}});
|
||||
for (const OutputArgExpansion& output_arg : item.outputs()) {
|
||||
DCHECK(output_arg.output_nodes.size() == 1) // do some sanity checking
|
||||
<< "Outputs of tensor lists are not supported";
|
||||
|
||||
OpDef::ArgDef arg_def;
|
||||
arg_def.set_name(output_name);
|
||||
arg_def.set_name(output_arg.output_name);
|
||||
arg_def.set_type(output_arg.data_type);
|
||||
arg_def.set_is_ref(IsRefType(output_arg.data_type));
|
||||
arg_def.set_is_ref(output_arg.is_ref);
|
||||
*func->mutable_signature()->add_output_arg() = arg_def;
|
||||
|
||||
auto it = output_tensors.find(output_arg.node_name);
|
||||
if (it == output_tensors.end()) {
|
||||
return errors::Internal(
|
||||
"Can't find an output tensor for the output node: ",
|
||||
output_arg.node_name);
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(helper.AsFunctionDefInput(
|
||||
it->second, &(*func->mutable_ret())[output_name]));
|
||||
TF_RETURN_IF_ERROR(connectivity.AsFunctionDefInput(
|
||||
output_tensor(output_arg),
|
||||
&(*func->mutable_ret())[output_arg.output_name]));
|
||||
}
|
||||
|
||||
// Add function control outputs.
|
||||
@ -563,12 +877,16 @@ Status MakeFunctionDef(const GrapplerFunctionItem& item,
|
||||
|
||||
// Copy function body nodes to the FunctionDef and update input format
|
||||
for (const NodeDef& func_node : item.function_body().node()) {
|
||||
// Do not copy input/output nodes.
|
||||
if (IsArg(func_node) || IsRetval(func_node)) continue;
|
||||
const string& name = func_node.name();
|
||||
|
||||
// Do not copy input placeholders.
|
||||
if (IsPlaceholder(func_node) && input_placeholders.count(name)) continue;
|
||||
// Do not copy output nodes that we bypassed.
|
||||
if (IsIdentity(func_node) && output_tensors.count(name)) continue;
|
||||
|
||||
NodeDef* func_def_node = func->add_node_def();
|
||||
*func_def_node = func_node;
|
||||
TF_RETURN_IF_ERROR(helper.AsFunctionDefNode(func_def_node));
|
||||
TF_RETURN_IF_ERROR(connectivity.AsFunctionDefNode(func_def_node));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
||||
@ -33,22 +33,45 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
// Function input argument instantiated into an '_Arg' node in the function body
|
||||
// graph, with an 'index' attribute corresponding to the input position.
|
||||
struct InputArgInstantiation {
|
||||
InputArgInstantiation(string node_name, DataType data_type)
|
||||
: node_name(std::move(node_name)), data_type(data_type) {}
|
||||
string node_name;
|
||||
// WARNING(ezhulenev): Currently we do not support functions with inputs or
|
||||
// outputs instantiated into multiple tensors. This can happen if the
|
||||
// input/output type is 'T*N' or 'list(type)'. This is enforced by multiple
|
||||
// checks across this file and also function_optimizer.cc. InputArgExpansion and
|
||||
// OutputArgExpansion already support lists of tensors, but that's pretty much
|
||||
// it, all other code is written with assumption that expansions are always of
|
||||
// size 1. MakeGrapplerFunctionItem will gracefully fail with Status error.
|
||||
//
|
||||
// This is a low priority feature, because in practice we don't see a lot (any
|
||||
// at all?) functions with such arguments. Tensorflow-Eager always produces
|
||||
// functions with plain input/output arguments.
|
||||
|
||||
// TODO(ezhulenev): Support inputs and outputs of type 'T*N'.
|
||||
// TODO(ezhulenev): Support inputs and outputs of type 'list(type)'.
|
||||
|
||||
// Depending on the function instantiation attributes, input argument to the
|
||||
// function might be a single tensor, list of tensors of the same type, or a
|
||||
// list of tensors of different types.
|
||||
//
|
||||
// InputArgExpansion keeps track of the placeholders that were added to the
|
||||
// function body in place of function inputs and a resolved input data type.
|
||||
struct InputArgExpansion {
|
||||
string input_name;
|
||||
DataType data_type;
|
||||
bool is_ref;
|
||||
absl::InlinedVector<string, 1> placeholders;
|
||||
};
|
||||
|
||||
// Function output instantiated into a '_Retval' node in the function body
|
||||
// graph, with an 'index' attribute corresponding to the output position.
|
||||
struct OutputArgInstantiation {
|
||||
OutputArgInstantiation(string node_name, DataType data_type)
|
||||
: node_name(std::move(node_name)), data_type(data_type) {}
|
||||
string node_name;
|
||||
// Depending on the function instantiation attributes, output argument is mapped
|
||||
// to one or more outputs of one of the function body nodes.
|
||||
//
|
||||
// OutputArgExpansion keeps track of the Identity nodes that were added to the
|
||||
// function body to forward output tensors. Adding these output nodes allows
|
||||
// nested function inlining and specialization (see function optimizer).
|
||||
struct OutputArgExpansion {
|
||||
string output_name;
|
||||
DataType data_type;
|
||||
bool is_ref;
|
||||
absl::InlinedVector<string, 1> output_nodes;
|
||||
};
|
||||
|
||||
// A mapping from control output name to node name in function body graph.
|
||||
@ -57,6 +80,78 @@ struct ControlOutput {
|
||||
string node_name;
|
||||
};
|
||||
|
||||
// FunctionDef uses different connectivity encoding for the function body nodes,
|
||||
// then a GraphDef (see function.proto for details). Input name in FunctionDef
|
||||
// can potentially represent a sequence of tensors (instead just one tensor in
|
||||
// GraphDef), we need to expand it when converting from FunctionDef to GraphDef,
|
||||
// and fold it back when doing backward conversion.
|
||||
class GrapplerFunctionConnectivity {
|
||||
public:
|
||||
void RegisterInputArgExpansion(InputArgExpansion input_arg_expansion);
|
||||
void RegisterFunctionBodyOutputs(const string& node_name,
|
||||
tensorflow::NameRangeMap&& outputs);
|
||||
|
||||
// Expands input encoded in FunctionDef format (name[:output][:position]) into
|
||||
// multiple inputs in GraphDef format (name[:position]).
|
||||
Status ExpandFunctionDefInput(const string& func_def_input,
|
||||
std::vector<string>* graph_def_inputs) const;
|
||||
|
||||
// Updates Node inputs from FunctionDef to GraphDef format.
|
||||
Status ExpandNodeInputs(NodeDef* function_body_node) const;
|
||||
|
||||
// When expanding inputs in function def format, single input might be
|
||||
// expanded into multiple tensors. When converting back to the function def
|
||||
// format from graph def format, it's always a 1-to-1 relationship.
|
||||
// FunctionDef built from GrapplerFunctionItem is always specialized to its
|
||||
// instantiation attributes and length of input args (and node def outputs) is
|
||||
// known.
|
||||
|
||||
// Converts input name from GraphDef format (name[:position]) to the
|
||||
// FunctionDef input format (name[:output][:position]) using registered input
|
||||
// arg expansion and function body outputs.
|
||||
Status AsFunctionDefInput(const string& graph_def_input,
|
||||
string* func_def_input) const;
|
||||
|
||||
// Updates Node inputs from GraphDef to FunctionDef format.
|
||||
Status AsFunctionDefNode(NodeDef* function_body_node) const;
|
||||
|
||||
private:
|
||||
// Mapping from input name to input arg expansion.
|
||||
absl::flat_hash_map<string, InputArgExpansion> input_arg_expansions_;
|
||||
// Mapping from function body node name to output names range map.
|
||||
absl::flat_hash_map<string, tensorflow::NameRangeMap> function_body_outputs_;
|
||||
|
||||
// For each placeholder added to the function instantiation graph, we keep a
|
||||
// mapping back to the function input argument name and index.
|
||||
struct InputArgPlaceholder {
|
||||
string input_name; // Name of the function input argument.
|
||||
int input_index; // Index of a tensor in the function input argument
|
||||
// expansion, it can be greater than `0` if input
|
||||
// argument is a list of tensors (aka list(type)).
|
||||
};
|
||||
// Mapping from input arg placeholder to the function input tensor.
|
||||
absl::flat_hash_map<string, InputArgPlaceholder> input_arg_placeholders_;
|
||||
};
|
||||
|
||||
// Get Function type attributes using attributes of a node that instantiated
|
||||
// a function.
|
||||
class GrapplerFunctionItemInstantiation {
|
||||
public:
|
||||
explicit GrapplerFunctionItemInstantiation(AttrSlice func_instantiation_attr)
|
||||
: func_instantiation_attr_(func_instantiation_attr) {}
|
||||
|
||||
// Get DataType from attributes by name. Return error if attribute is missing,
|
||||
// or it doesn't define a valid data type.
|
||||
Status GetTypeAttr(const string& type_attr_name, DataType* data_type) const;
|
||||
|
||||
// Get argument data type. If data type is not explicitly defined, uses
|
||||
// provided attribute name to look it up in function attributes.
|
||||
Status GetArgType(const OpDef::ArgDef& arg, DataType* data_type) const;
|
||||
|
||||
private:
|
||||
const AttrSlice func_instantiation_attr_; // do not own
|
||||
};
|
||||
|
||||
// A special case of GrapplerItem, constructed from a TensorFlow Function.
|
||||
class GrapplerFunctionItem : public GrapplerItem {
|
||||
public:
|
||||
@ -64,12 +159,12 @@ class GrapplerFunctionItem : public GrapplerItem {
|
||||
|
||||
const string& description() const;
|
||||
|
||||
const std::vector<InputArgInstantiation>& inputs() const;
|
||||
const InputArgInstantiation& input(int i) const;
|
||||
const std::vector<InputArgExpansion>& inputs() const;
|
||||
const InputArgExpansion& input(int i) const;
|
||||
const std::size_t input_size() const;
|
||||
|
||||
const std::vector<OutputArgInstantiation>& outputs() const;
|
||||
const OutputArgInstantiation& output(int i) const;
|
||||
const std::vector<OutputArgExpansion>& outputs() const;
|
||||
const OutputArgExpansion& output(int i) const;
|
||||
const std::size_t output_size() const;
|
||||
|
||||
const std::vector<ControlOutput>& control_outputs() const;
|
||||
@ -95,8 +190,8 @@ class GrapplerFunctionItem : public GrapplerItem {
|
||||
|
||||
GrapplerFunctionItem(string func_name, string description,
|
||||
AttrSlice func_attr,
|
||||
std::vector<InputArgInstantiation> input_args,
|
||||
std::vector<OutputArgInstantiation> output_args,
|
||||
std::vector<InputArgExpansion> input_arg_expansions,
|
||||
std::vector<OutputArgExpansion> output_arg_expansions,
|
||||
std::vector<ControlOutput> control_outputs,
|
||||
int graph_def_version, bool is_stateful,
|
||||
GraphDef&& function_body);
|
||||
@ -105,8 +200,8 @@ class GrapplerFunctionItem : public GrapplerItem {
|
||||
AttrSlice func_attr_; // Attributes specific to function definition that
|
||||
// produced this item (FuncDef.attr field).
|
||||
|
||||
std::vector<InputArgInstantiation> input_args_;
|
||||
std::vector<OutputArgInstantiation> output_args_;
|
||||
std::vector<InputArgExpansion> input_arg_expansions_;
|
||||
std::vector<OutputArgExpansion> output_arg_expansions_;
|
||||
std::vector<ControlOutput> control_outputs_;
|
||||
|
||||
bool is_stateful_ = false;
|
||||
@ -137,13 +232,22 @@ Status InstantiationBodyParameters(
|
||||
const FunctionDef& func, const AttrSlice& func_instantiation_attr,
|
||||
absl::flat_hash_map<string, AttrValue>* body_parameters);
|
||||
|
||||
// Register GrapplerFunctionItem input arg expansion and function body outputs
|
||||
// in the GrapplerFunctionConnectivity. Use function library definition to
|
||||
// lookup function body nodes output names and ranges.
|
||||
Status RegisterGrapplerFunctionConnectivity(
|
||||
const GrapplerFunctionItem& item, const FunctionLibraryDefinition& flib,
|
||||
GrapplerFunctionConnectivity* connectivity);
|
||||
|
||||
// Replace one of the function inputs with a constant.
|
||||
Status ReplaceInputWithConst(const NodeDef& input_const, int input_index,
|
||||
GrapplerFunctionItem* item);
|
||||
|
||||
// Removes outputs from instantiated grappler function item. For all active
|
||||
// function outputs that changed its output index, this function adds an output
|
||||
// mapping (std::pair<old index, new index>).
|
||||
// Removes outputs from instantiated grappler function item. Function node
|
||||
// outputs use GraphDef output index encoding, and multiple outputs might belong
|
||||
// to the same output argument expansion (in case of tensor list outputs). For
|
||||
// all active function outputs that changed its output index, this function adds
|
||||
// an output mapping (std::pair<old index, new index>).
|
||||
Status RemoveFunctionOutputs(const absl::flat_hash_set<int>& remove_outputs,
|
||||
GrapplerFunctionItem* item,
|
||||
std::vector<std::pair<int, int>>* output_mapping);
|
||||
|
||||
@ -63,16 +63,11 @@ TEST_F(FunctionsTest, InstantiationParameters) {
|
||||
FunctionDef func = FunctionDefHelper::Create(
|
||||
"ParametrizedFunc",
|
||||
/* inputs */
|
||||
{"input1:A", "input2:B", "input3:float", "input4: C"},
|
||||
{"input1:A", "input2:B", "input3:float"},
|
||||
/* outputs */
|
||||
{"output1: A", "output2:D"},
|
||||
{"output1: A", "output2:C"},
|
||||
/* type parameters */
|
||||
{
|
||||
"A: {float, double}",
|
||||
"B: {float, int32}",
|
||||
"C: list(type)",
|
||||
"D: {float, double}",
|
||||
},
|
||||
{"A: {float, double}", "B: {float, int32}", "C: {float, double}"},
|
||||
/* function body*/
|
||||
{{{"output"}, "FakeOp", {"input1", "input2"}, {{"key", "$key"}}}},
|
||||
/* Mapping between function returns and function node outputs. */
|
||||
@ -82,20 +77,16 @@ TEST_F(FunctionsTest, InstantiationParameters) {
|
||||
func_instantiation_attr["key"].set_s("key-value");
|
||||
func_instantiation_attr["A"].set_type(DT_FLOAT);
|
||||
func_instantiation_attr["B"].set_type(DT_INT32);
|
||||
func_instantiation_attr["C"].mutable_list()->add_type(DT_FLOAT);
|
||||
func_instantiation_attr["C"].mutable_list()->add_type(DT_INT32);
|
||||
func_instantiation_attr["D"].set_type(DT_DOUBLE);
|
||||
func_instantiation_attr["C"].set_type(DT_DOUBLE);
|
||||
|
||||
absl::flat_hash_map<string, DataType> type_parameters;
|
||||
TF_EXPECT_OK(InstantiationTypeParameters(
|
||||
func, AttrSlice(&func_instantiation_attr), &type_parameters));
|
||||
|
||||
ASSERT_EQ(5, type_parameters.size());
|
||||
ASSERT_EQ(3, type_parameters.size());
|
||||
EXPECT_EQ(DT_FLOAT, type_parameters["A"]);
|
||||
EXPECT_EQ(DT_INT32, type_parameters["B"]);
|
||||
EXPECT_EQ(DT_FLOAT, type_parameters["C:0"]);
|
||||
EXPECT_EQ(DT_INT32, type_parameters["C:1"]);
|
||||
EXPECT_EQ(DT_DOUBLE, type_parameters["D"]);
|
||||
EXPECT_EQ(DT_DOUBLE, type_parameters["C"]);
|
||||
|
||||
absl::flat_hash_map<string, AttrValue> body_parameters;
|
||||
TF_EXPECT_OK(InstantiationBodyParameters(
|
||||
@ -105,6 +96,131 @@ TEST_F(FunctionsTest, InstantiationParameters) {
|
||||
EXPECT_EQ("key-value", body_parameters["key"].s());
|
||||
}
|
||||
|
||||
TEST_F(FunctionsTest, GrapplerFunctionConnectivity_ExpandFunctionDefInput) {
|
||||
GrapplerFunctionConnectivity connectivity;
|
||||
|
||||
connectivity.RegisterInputArgExpansion(
|
||||
{"inputA", DT_FLOAT, /*is_ref=*/false, {"inputA"}});
|
||||
connectivity.RegisterInputArgExpansion(
|
||||
{"inputB", DT_FLOAT, /*is_ref=*/false, {"inputB_0", "inputB_1"}});
|
||||
|
||||
connectivity.RegisterFunctionBodyOutputs("Add", {{"z", {0, 1}}});
|
||||
connectivity.RegisterFunctionBodyOutputs("Func",
|
||||
{{"o1", {0, 2}}, {"o2", {2, 4}}});
|
||||
|
||||
std::vector<string> inputs;
|
||||
TF_EXPECT_OK(connectivity.ExpandFunctionDefInput("inputA", &inputs));
|
||||
ASSERT_EQ(1, inputs.size());
|
||||
EXPECT_EQ("inputA", inputs[0]);
|
||||
|
||||
inputs.clear();
|
||||
TF_EXPECT_OK(connectivity.ExpandFunctionDefInput("inputB", &inputs));
|
||||
ASSERT_EQ(2, inputs.size());
|
||||
EXPECT_EQ("inputB_0", inputs[0]);
|
||||
EXPECT_EQ("inputB_1", inputs[1]);
|
||||
|
||||
inputs.clear();
|
||||
TF_EXPECT_OK(connectivity.ExpandFunctionDefInput("inputB:1", &inputs));
|
||||
ASSERT_EQ(1, inputs.size());
|
||||
EXPECT_EQ("inputB_1", inputs[0]);
|
||||
|
||||
inputs.clear();
|
||||
TF_EXPECT_OK(connectivity.ExpandFunctionDefInput("Add:z", &inputs));
|
||||
ASSERT_EQ(1, inputs.size());
|
||||
EXPECT_EQ("Add", inputs[0]);
|
||||
|
||||
inputs.clear();
|
||||
TF_EXPECT_OK(connectivity.ExpandFunctionDefInput("Func:o1", &inputs));
|
||||
ASSERT_EQ(2, inputs.size());
|
||||
EXPECT_EQ("Func", inputs[0]);
|
||||
EXPECT_EQ("Func:1", inputs[1]);
|
||||
|
||||
inputs.clear();
|
||||
TF_EXPECT_OK(connectivity.ExpandFunctionDefInput("Func:o2", &inputs));
|
||||
ASSERT_EQ(2, inputs.size());
|
||||
EXPECT_EQ("Func:2", inputs[0]);
|
||||
EXPECT_EQ("Func:3", inputs[1]);
|
||||
|
||||
inputs.clear();
|
||||
TF_EXPECT_OK(connectivity.ExpandFunctionDefInput("Func:o1:0", &inputs));
|
||||
ASSERT_EQ(1, inputs.size());
|
||||
EXPECT_EQ("Func", inputs[0]);
|
||||
|
||||
inputs.clear();
|
||||
TF_EXPECT_OK(connectivity.ExpandFunctionDefInput("Func:o1:1", &inputs));
|
||||
ASSERT_EQ(1, inputs.size());
|
||||
EXPECT_EQ("Func:1", inputs[0]);
|
||||
|
||||
inputs.clear();
|
||||
TF_EXPECT_OK(connectivity.ExpandFunctionDefInput("Func:o2:0", &inputs));
|
||||
ASSERT_EQ(1, inputs.size());
|
||||
EXPECT_EQ("Func:2", inputs[0]);
|
||||
|
||||
inputs.clear();
|
||||
TF_EXPECT_OK(connectivity.ExpandFunctionDefInput("Func:o2:1", &inputs));
|
||||
ASSERT_EQ(1, inputs.size());
|
||||
EXPECT_EQ("Func:3", inputs[0]);
|
||||
}
|
||||
|
||||
TEST_F(FunctionsTest, GrapplerFunctionConnectivity_AsFunctionDefInput) {
|
||||
GrapplerFunctionConnectivity connectivity;
|
||||
|
||||
connectivity.RegisterInputArgExpansion(
|
||||
{"inputA", DT_FLOAT, /*is_ref=*/false, {"inputA"}});
|
||||
connectivity.RegisterInputArgExpansion(
|
||||
{"inputB", DT_FLOAT, /*is_ref=*/false, {"inputB_0", "inputB_1"}});
|
||||
|
||||
connectivity.RegisterFunctionBodyOutputs("Add", {{"z", {0, 1}}});
|
||||
connectivity.RegisterFunctionBodyOutputs("Func",
|
||||
{{"o1", {0, 2}}, {"o2", {2, 4}}});
|
||||
|
||||
string input;
|
||||
|
||||
TF_EXPECT_OK(connectivity.AsFunctionDefInput("inputA", &input));
|
||||
EXPECT_EQ("inputA:0", input);
|
||||
|
||||
TF_EXPECT_OK(connectivity.AsFunctionDefInput("inputB_0", &input));
|
||||
EXPECT_EQ("inputB:0", input);
|
||||
|
||||
TF_EXPECT_OK(connectivity.AsFunctionDefInput("inputB_1", &input));
|
||||
EXPECT_EQ("inputB:1", input);
|
||||
|
||||
TF_EXPECT_OK(connectivity.AsFunctionDefInput("Add", &input));
|
||||
EXPECT_EQ("Add:z:0", input);
|
||||
|
||||
TF_EXPECT_OK(connectivity.AsFunctionDefInput("Func", &input));
|
||||
EXPECT_EQ("Func:o1:0", input);
|
||||
|
||||
TF_EXPECT_OK(connectivity.AsFunctionDefInput("Func:1", &input));
|
||||
EXPECT_EQ("Func:o1:1", input);
|
||||
|
||||
TF_EXPECT_OK(connectivity.AsFunctionDefInput("Func:2", &input));
|
||||
EXPECT_EQ("Func:o2:0", input);
|
||||
|
||||
TF_EXPECT_OK(connectivity.AsFunctionDefInput("Func:3", &input));
|
||||
EXPECT_EQ("Func:o2:1", input);
|
||||
}
|
||||
|
||||
TEST_F(FunctionsTest, GrapplerFunctionConnectivity_ExpandNodeInputs) {
|
||||
GrapplerFunctionConnectivity connectivity;
|
||||
|
||||
connectivity.RegisterInputArgExpansion(
|
||||
{"inputA", DT_FLOAT, /*is_ref=*/false, {"inputA"}});
|
||||
connectivity.RegisterInputArgExpansion(
|
||||
{"inputB", DT_FLOAT, /*is_ref=*/false, {"inputB_0", "inputB_1"}});
|
||||
|
||||
NodeDef node;
|
||||
node.add_input("inputA:0");
|
||||
node.add_input("inputB");
|
||||
|
||||
TF_EXPECT_OK(connectivity.ExpandNodeInputs(&node));
|
||||
|
||||
EXPECT_EQ(3, node.input_size());
|
||||
EXPECT_EQ("inputA", node.input(0));
|
||||
EXPECT_EQ("inputB_0", node.input(1));
|
||||
EXPECT_EQ("inputB_1", node.input(2));
|
||||
}
|
||||
|
||||
TEST_F(FunctionsTest, FromSimpleFunctionDef) {
|
||||
const Tensor kTwo = test::AsScalar<int64>(2);
|
||||
FunctionDef func = FunctionDefHelper::Define(
|
||||
@ -136,17 +252,19 @@ TEST_F(FunctionsTest, FromSimpleFunctionDef) {
|
||||
EXPECT_EQ(5, item.function_body().node_size());
|
||||
|
||||
EXPECT_EQ(1, item.input_size());
|
||||
EXPECT_EQ("x", item.input(0).node_name);
|
||||
EXPECT_EQ("x", item.input(0).input_name);
|
||||
ASSERT_EQ(1, item.input(0).placeholders.size());
|
||||
EXPECT_EQ("x", item.input(0).placeholders[0]);
|
||||
|
||||
EXPECT_EQ(1, item.output_size());
|
||||
EXPECT_EQ("y_RetVal", item.output(0).node_name);
|
||||
EXPECT_EQ("y", item.output(0).output_name);
|
||||
EXPECT_EQ("y_output_node_0", item.output(0).output_nodes[0]);
|
||||
|
||||
int count = 0;
|
||||
for (const NodeDef &node : item.function_body().node()) {
|
||||
if (node.name() == "x" && ++count) {
|
||||
EXPECT_EQ("_Arg", node.op());
|
||||
EXPECT_EQ(DT_FLOAT, node.attr().at("T").type());
|
||||
EXPECT_EQ(0, node.attr().at("index").i());
|
||||
EXPECT_EQ("Placeholder", node.op());
|
||||
EXPECT_EQ(DT_FLOAT, node.attr().at("dtype").type());
|
||||
EXPECT_EQ(0, node.input_size());
|
||||
} else if (node.name() == "two" && ++count) {
|
||||
EXPECT_EQ("Const", node.op());
|
||||
@ -162,11 +280,10 @@ TEST_F(FunctionsTest, FromSimpleFunctionDef) {
|
||||
EXPECT_EQ(2, node.input_size());
|
||||
EXPECT_EQ("x", node.input(0));
|
||||
EXPECT_EQ("scale", node.input(1));
|
||||
} else if (node.name() == "y_RetVal" && ++count) {
|
||||
EXPECT_EQ("_Retval", node.op());
|
||||
} else if (node.name() == "y_output_node_0" && ++count) {
|
||||
EXPECT_EQ("Identity", node.op());
|
||||
ASSERT_EQ(1, node.input_size());
|
||||
EXPECT_EQ("y", node.input(0));
|
||||
EXPECT_EQ(0, node.attr().at("index").i());
|
||||
}
|
||||
}
|
||||
EXPECT_EQ(5, count);
|
||||
@ -217,22 +334,20 @@ TEST_F(FunctionsTest, FromFunctionDefWithMultiOutputNodes) {
|
||||
EXPECT_EQ(14, item.function_body().node_size());
|
||||
|
||||
ASSERT_EQ(3, item.input_size());
|
||||
EXPECT_EQ("x", item.input(0).node_name);
|
||||
EXPECT_EQ("y", item.input(1).node_name);
|
||||
EXPECT_EQ("dz", item.input(2).node_name);
|
||||
EXPECT_EQ("x", item.input(0).input_name);
|
||||
EXPECT_EQ("y", item.input(1).input_name);
|
||||
EXPECT_EQ("dz", item.input(2).input_name);
|
||||
|
||||
ASSERT_EQ(2, item.output_size());
|
||||
EXPECT_EQ("dx_RetVal", item.output(0).node_name);
|
||||
EXPECT_EQ("dy_RetVal", item.output(1).node_name);
|
||||
EXPECT_EQ("dx_output_node_0", item.output(0).output_nodes[0]);
|
||||
EXPECT_EQ("dy_output_node_0", item.output(1).output_nodes[0]);
|
||||
|
||||
int count = 0;
|
||||
for (const NodeDef &node : item.function_body().node()) {
|
||||
if (node.name() == "x" || node.name() == "y" || node.name() == "dz") {
|
||||
count++;
|
||||
EXPECT_EQ("_Arg", node.op());
|
||||
EXPECT_EQ(DT_FLOAT, node.attr().at("T").type());
|
||||
int expected_index = node.name() == "x" ? 0 : node.name() == "y" ? 1 : 2;
|
||||
EXPECT_EQ(expected_index, node.attr().at("index").i());
|
||||
EXPECT_EQ("Placeholder", node.op());
|
||||
EXPECT_EQ(DT_FLOAT, node.attr().at("dtype").type());
|
||||
EXPECT_EQ(0, node.input_size());
|
||||
} else if (node.name() == "rx" && ++count) {
|
||||
EXPECT_EQ("BroadcastGradientArgs", node.op());
|
||||
@ -249,14 +364,12 @@ TEST_F(FunctionsTest, FromFunctionDefWithMultiOutputNodes) {
|
||||
EXPECT_EQ(2, node.input_size());
|
||||
EXPECT_EQ("gy", node.input(0));
|
||||
EXPECT_EQ("rx:1", node.input(1));
|
||||
} else if (node.name() == "dx_RetVal" && ++count) {
|
||||
EXPECT_EQ("_Retval", node.op());
|
||||
EXPECT_EQ(0, node.attr().at("index").i());
|
||||
} else if (node.name() == "dx_output_node_0" && ++count) {
|
||||
EXPECT_EQ("Identity", node.op());
|
||||
ASSERT_EQ(1, node.input_size());
|
||||
EXPECT_EQ("dx", node.input(0));
|
||||
} else if (node.name() == "dy_RetVal" && ++count) {
|
||||
EXPECT_EQ("_Retval", node.op());
|
||||
EXPECT_EQ(1, node.attr().at("index").i());
|
||||
} else if (node.name() == "dy_output_node_0" && ++count) {
|
||||
EXPECT_EQ("Identity", node.op());
|
||||
ASSERT_EQ(1, node.input_size());
|
||||
EXPECT_EQ("dy", node.input(0));
|
||||
}
|
||||
@ -312,10 +425,8 @@ TEST_F(FunctionsTest, FromFunctionDefWithNestedFuncs) {
|
||||
for (const NodeDef &node : item.function_body().node()) {
|
||||
if (node.name() == "x" || node.name() == "y") {
|
||||
count++;
|
||||
EXPECT_EQ("_Arg", node.op());
|
||||
EXPECT_EQ(DT_FLOAT, node.attr().at("T").type());
|
||||
int expected_index = node.name() == "x" ? 0 : 1;
|
||||
EXPECT_EQ(expected_index, node.attr().at("index").i());
|
||||
EXPECT_EQ("Placeholder", node.op());
|
||||
EXPECT_EQ(DT_FLOAT, node.attr().at("dtype").type());
|
||||
EXPECT_EQ(0, node.input_size());
|
||||
} else if (node.name() == "a0" && ++count) {
|
||||
EXPECT_EQ("Swap", node.op());
|
||||
@ -374,14 +485,13 @@ TEST_F(FunctionsTest, FromFunctionDefWithOutputMappings) {
|
||||
flib, TF_GRAPH_DEF_VERSION, &item));
|
||||
|
||||
EXPECT_EQ(1, item.output_size());
|
||||
EXPECT_EQ("out_RetVal", item.output(0).node_name);
|
||||
EXPECT_EQ("out_output_node_0", item.output(0).output_nodes[0]);
|
||||
|
||||
int count = 0;
|
||||
for (const NodeDef &node : item.function_body().node()) {
|
||||
if (node.name() == "in" && ++count) {
|
||||
EXPECT_EQ("_Arg", node.op());
|
||||
EXPECT_EQ(DT_FLOAT, node.attr().at("T").type());
|
||||
EXPECT_EQ(0, node.attr().at("index").i());
|
||||
EXPECT_EQ("Placeholder", node.op());
|
||||
EXPECT_EQ(DT_FLOAT, node.attr().at("dtype").type());
|
||||
EXPECT_EQ(0, node.input_size());
|
||||
} else if (node.name() == "Linear_func" && ++count) {
|
||||
EXPECT_EQ("Identity", node.op());
|
||||
@ -391,9 +501,8 @@ TEST_F(FunctionsTest, FromFunctionDefWithOutputMappings) {
|
||||
EXPECT_EQ("Exp", node.op());
|
||||
EXPECT_EQ(1, node.input_size());
|
||||
EXPECT_EQ("Linear_func", node.input(0));
|
||||
} else if (node.name() == "out_RetVal" && ++count) {
|
||||
EXPECT_EQ("_Retval", node.op());
|
||||
EXPECT_EQ(0, node.attr().at("index").i());
|
||||
} else if (node.name() == "out_output_node_0" && ++count) {
|
||||
EXPECT_EQ("Identity", node.op());
|
||||
ASSERT_EQ(1, node.input_size());
|
||||
EXPECT_EQ("Exp", node.input(0));
|
||||
}
|
||||
@ -401,6 +510,70 @@ TEST_F(FunctionsTest, FromFunctionDefWithOutputMappings) {
|
||||
EXPECT_EQ(4, count);
|
||||
}
|
||||
|
||||
TEST_F(FunctionsTest, FromFunctionDefWithInputForwarding) {
|
||||
FunctionDef func = FunctionDefHelper::Create(
|
||||
// Name
|
||||
"ForwardInputs",
|
||||
// Args
|
||||
{"in0: float", "in1: float", "arg2: float", "arg3: int32", "arg4: float"},
|
||||
// Return values
|
||||
{"out0: float", "arg2: float", "arg3: int32"},
|
||||
// Attr def
|
||||
{},
|
||||
// Nodes
|
||||
{},
|
||||
// Mapping
|
||||
{{"out0", "in0"}});
|
||||
|
||||
protobuf::Map<string, AttrValue> func_instantiation_attr;
|
||||
FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary());
|
||||
|
||||
GrapplerFunctionItem item;
|
||||
TF_EXPECT_OK(MakeGrapplerFunctionItem(func,
|
||||
AttrSlice(&func_instantiation_attr),
|
||||
flib, TF_GRAPH_DEF_VERSION, &item));
|
||||
|
||||
EXPECT_EQ("ForwardInputs", item.id);
|
||||
EXPECT_EQ(8, item.function_body().node_size());
|
||||
|
||||
EXPECT_EQ(3, item.output_size());
|
||||
EXPECT_EQ("out0_output_node_0", item.output(0).output_nodes[0]);
|
||||
EXPECT_EQ("arg2_output_node_0", item.output(1).output_nodes[0]);
|
||||
EXPECT_EQ("arg3_output_node_0", item.output(2).output_nodes[0]);
|
||||
|
||||
int count = 0;
|
||||
|
||||
const auto is_arg_placeholder = [](const string &name) {
|
||||
return name == "in0" || name == "in1" || name == "arg2" || name == "arg3" ||
|
||||
name == "arg4";
|
||||
};
|
||||
|
||||
for (const NodeDef &node : item.function_body().node()) {
|
||||
if (is_arg_placeholder(node.name()) && node.op() == "Placeholder") {
|
||||
count++;
|
||||
if (node.name() == "arg3") {
|
||||
EXPECT_EQ(DT_INT32, node.attr().at("dtype").type());
|
||||
} else {
|
||||
EXPECT_EQ(DT_FLOAT, node.attr().at("dtype").type());
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
EXPECT_EQ("Identity", node.op());
|
||||
ASSERT_EQ(1, node.input_size());
|
||||
EXPECT_TRUE(is_arg_placeholder(node.input(0)));
|
||||
|
||||
if (node.name() == "out0_output_node_0" && ++count) {
|
||||
EXPECT_EQ("in0", node.input(0));
|
||||
} else if (node.name() == "arg2_output_node_0" && ++count) {
|
||||
EXPECT_EQ("arg2", node.input(0));
|
||||
} else if (node.name() == "arg3_output_node_0" && ++count) {
|
||||
EXPECT_EQ("arg3", node.input(0));
|
||||
}
|
||||
}
|
||||
EXPECT_EQ(8, count);
|
||||
}
|
||||
|
||||
TEST_F(FunctionsTest, FromFunctionDefWithoutInput) {
|
||||
const Tensor kTwo = test::AsScalar<int64>(2);
|
||||
FunctionDef func = FunctionDefHelper::Define(
|
||||
@ -427,7 +600,7 @@ TEST_F(FunctionsTest, FromFunctionDefWithoutInput) {
|
||||
|
||||
EXPECT_EQ(0, item.input_size());
|
||||
EXPECT_EQ(1, item.output_size());
|
||||
EXPECT_EQ("o_RetVal", item.output(0).node_name);
|
||||
EXPECT_EQ("o_output_node_0", item.output(0).output_nodes[0]);
|
||||
EXPECT_EQ(3, item.function_body().node_size());
|
||||
|
||||
const NodeDef &two = item.function_body().node(0);
|
||||
@ -440,7 +613,7 @@ TEST_F(FunctionsTest, FromFunctionDefWithoutInput) {
|
||||
EXPECT_EQ("two", cast.input(0));
|
||||
|
||||
const NodeDef &retval = item.function_body().node(2);
|
||||
EXPECT_EQ("o_RetVal", retval.name());
|
||||
EXPECT_EQ("o_output_node_0", retval.name());
|
||||
EXPECT_EQ(1, retval.input_size());
|
||||
EXPECT_EQ("o", retval.input(0));
|
||||
}
|
||||
@ -541,14 +714,14 @@ TEST_F(FunctionsTest, MakeFunctionDef) {
|
||||
EXPECT_EQ("y", specialized.signature().output_arg(0).name());
|
||||
EXPECT_EQ(DT_FLOAT, specialized.signature().output_arg(0).type());
|
||||
|
||||
// Function body specialized for instantiation types.
|
||||
// Function body specialized for instantiation types
|
||||
int count = 0;
|
||||
for (const NodeDef &node : specialized.node_def()) {
|
||||
if (node.name() == "scale" && ++count) {
|
||||
EXPECT_EQ(DT_FLOAT, node.attr().at("DstT").type());
|
||||
} else if (node.name() == "y" && ++count) {
|
||||
EXPECT_EQ("Mul", node.op());
|
||||
EXPECT_EQ("x", node.input(0));
|
||||
EXPECT_EQ("x:0", node.input(0));
|
||||
EXPECT_EQ("scale:y:0", node.input(1));
|
||||
EXPECT_EQ(DT_FLOAT, node.attr().at("T").type());
|
||||
}
|
||||
@ -580,9 +753,9 @@ TEST_F(FunctionsTest, ReplaceInputWithConst) {
|
||||
const NodeDef &input_x = item.function_body().node(0);
|
||||
const NodeDef &input_y = item.function_body().node(1);
|
||||
|
||||
// Initially inputs added to the graph as _Arg nodes.
|
||||
EXPECT_EQ("_Arg", input_x.op());
|
||||
EXPECT_EQ("_Arg", input_y.op());
|
||||
// Initially inputs added to the graph as placeholders.
|
||||
EXPECT_EQ("Placeholder", input_x.op());
|
||||
EXPECT_EQ("Placeholder", input_y.op());
|
||||
|
||||
// Replace inputs x and y with constants.
|
||||
NodeDef const_input_x;
|
||||
@ -651,7 +824,7 @@ TEST_F(FunctionsTest, SwapFunctionBodyAndMakeFunctionDef) {
|
||||
GraphDef id_func_body = test::function::GDef(
|
||||
{/* Read and return input argument through Identity node. */
|
||||
NDef("read_x", "Identity", {"x"}, {{"T", "float"}}),
|
||||
NDef("z_RetVal", "_Retval", {"read_x"}, {{"T", "float"}})});
|
||||
NDef("z_output_node_0", "Identity", {"read_x"}, {{"T", "float"}})});
|
||||
|
||||
protobuf::Map<string, AttrValue> func_instantiation_attr;
|
||||
func_instantiation_attr["T"].set_type(DT_FLOAT);
|
||||
@ -676,7 +849,7 @@ TEST_F(FunctionsTest, SwapFunctionBodyAndMakeFunctionDef) {
|
||||
for (const NodeDef &node : specialized.node_def()) {
|
||||
if (node.name() == "read_x" && ++count) {
|
||||
EXPECT_EQ("Identity", node.op());
|
||||
EXPECT_EQ("x", node.input(0));
|
||||
EXPECT_EQ("x:0", node.input(0));
|
||||
}
|
||||
}
|
||||
EXPECT_EQ(1, count);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user