Create new function overload for GraphToFunctionDef().
PiperOrigin-RevId: 274196887
This commit is contained in:
parent
e9b998ec50
commit
0d9f353863
@ -297,6 +297,76 @@ Status FillFunctionBody(
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status GraphToFunctionDefHelper(
|
||||||
|
const Graph& graph, const string& name,
|
||||||
|
const std::function<absl::optional<string>(const Node*)>& control_ret,
|
||||||
|
const std::vector<string>& output_names, FunctionDef* fdef) {
|
||||||
|
auto add_arg_or_retval = [](Node* node,
|
||||||
|
std::vector<OutputTensor>* args_or_retvals) {
|
||||||
|
int index;
|
||||||
|
TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "index", &index));
|
||||||
|
if (index >= args_or_retvals->size()) {
|
||||||
|
args_or_retvals->resize(index + 1);
|
||||||
|
}
|
||||||
|
if ((*args_or_retvals)[index].node == nullptr) {
|
||||||
|
(*args_or_retvals)[index].node = node;
|
||||||
|
} else {
|
||||||
|
return errors::InvalidArgument("Multiple '", node->type_string(),
|
||||||
|
"' nodes found with index ", index);
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
};
|
||||||
|
|
||||||
|
std::vector<const Node*> body_nodes;
|
||||||
|
std::vector<OutputTensor> inputs;
|
||||||
|
std::vector<OutputTensor> outputs;
|
||||||
|
std::vector<const Node*> control_outputs;
|
||||||
|
std::vector<string> control_output_names;
|
||||||
|
for (Node* node : graph.op_nodes()) {
|
||||||
|
if (node->IsArg()) {
|
||||||
|
TF_RETURN_IF_ERROR(add_arg_or_retval(node, &inputs));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (node->IsRetval()) {
|
||||||
|
TF_RETURN_IF_ERROR(add_arg_or_retval(node, &outputs));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (control_ret) {
|
||||||
|
auto control_ret_name = control_ret(node);
|
||||||
|
if (control_ret_name.has_value()) {
|
||||||
|
control_outputs.push_back(node);
|
||||||
|
control_output_names.push_back(control_ret_name.value());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
body_nodes.push_back(node);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto validate_args_retvals =
|
||||||
|
[](const std::vector<OutputTensor>& args_or_retvals,
|
||||||
|
const string& op_type) {
|
||||||
|
for (int i = 0, e = args_or_retvals.size(); i < e; ++i) {
|
||||||
|
if (args_or_retvals[i].node == nullptr) {
|
||||||
|
return errors::InvalidArgument("Missing '", op_type,
|
||||||
|
"' node at index ", i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
};
|
||||||
|
|
||||||
|
TF_RETURN_IF_ERROR(validate_args_retvals(inputs, "_Arg"));
|
||||||
|
TF_RETURN_IF_ERROR(validate_args_retvals(outputs, "_Retval"));
|
||||||
|
|
||||||
|
return GraphToFunctionDef(graph, name, /*append_hash_to_fn_name=*/false,
|
||||||
|
/*set_stateful_from_nodes=*/false,
|
||||||
|
/*copy_placeholder_attrs_from_nodes=*/false,
|
||||||
|
body_nodes, inputs, outputs, output_names,
|
||||||
|
control_outputs, control_output_names,
|
||||||
|
/*description=*/nullptr, fdef);
|
||||||
|
}
|
||||||
|
|
||||||
} // anonymous namespace
|
} // anonymous namespace
|
||||||
|
|
||||||
Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name,
|
Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name,
|
||||||
@ -499,70 +569,8 @@ Status GraphToFunctionDef(
|
|||||||
const Graph& graph, const string& name,
|
const Graph& graph, const string& name,
|
||||||
const std::function<absl::optional<string>(const Node*)>& control_ret,
|
const std::function<absl::optional<string>(const Node*)>& control_ret,
|
||||||
FunctionDef* fdef) {
|
FunctionDef* fdef) {
|
||||||
auto add_arg_or_retval = [](Node* node,
|
return GraphToFunctionDefHelper(graph, name, control_ret,
|
||||||
std::vector<OutputTensor>* args_or_retvals) {
|
/*output_names=*/{}, fdef);
|
||||||
int index;
|
|
||||||
TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "index", &index));
|
|
||||||
if (index >= args_or_retvals->size()) {
|
|
||||||
args_or_retvals->resize(index + 1);
|
|
||||||
}
|
|
||||||
if ((*args_or_retvals)[index].node == nullptr) {
|
|
||||||
(*args_or_retvals)[index].node = node;
|
|
||||||
} else {
|
|
||||||
return errors::InvalidArgument("Multiple '", node->type_string(),
|
|
||||||
"' nodes found with index ", index);
|
|
||||||
}
|
|
||||||
return Status::OK();
|
|
||||||
};
|
|
||||||
|
|
||||||
std::vector<const Node*> body_nodes;
|
|
||||||
std::vector<OutputTensor> inputs;
|
|
||||||
std::vector<OutputTensor> outputs;
|
|
||||||
std::vector<const Node*> control_outputs;
|
|
||||||
std::vector<string> control_output_names;
|
|
||||||
for (Node* node : graph.op_nodes()) {
|
|
||||||
if (node->IsArg()) {
|
|
||||||
TF_RETURN_IF_ERROR(add_arg_or_retval(node, &inputs));
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (node->IsRetval()) {
|
|
||||||
TF_RETURN_IF_ERROR(add_arg_or_retval(node, &outputs));
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (control_ret) {
|
|
||||||
auto control_ret_name = control_ret(node);
|
|
||||||
if (control_ret_name.has_value()) {
|
|
||||||
control_outputs.push_back(node);
|
|
||||||
control_output_names.push_back(control_ret_name.value());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
body_nodes.push_back(node);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto validate_args_retvals =
|
|
||||||
[](const std::vector<OutputTensor>& args_or_retvals,
|
|
||||||
const string& op_type) {
|
|
||||||
for (int i = 0, e = args_or_retvals.size(); i < e; ++i) {
|
|
||||||
if (args_or_retvals[i].node == nullptr) {
|
|
||||||
return errors::InvalidArgument("Missing '", op_type,
|
|
||||||
"' node at index ", i);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return Status::OK();
|
|
||||||
};
|
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(validate_args_retvals(inputs, "_Arg"));
|
|
||||||
TF_RETURN_IF_ERROR(validate_args_retvals(outputs, "_Retval"));
|
|
||||||
|
|
||||||
return GraphToFunctionDef(graph, name, /*append_hash_to_fn_name=*/false,
|
|
||||||
/*set_stateful_from_nodes=*/false,
|
|
||||||
/*copy_placeholder_attrs_from_nodes=*/false,
|
|
||||||
body_nodes, inputs, outputs, /*output_names=*/{},
|
|
||||||
control_outputs, control_output_names,
|
|
||||||
/*description=*/nullptr, fdef);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Status GraphToFunctionDef(const Graph& graph, const string& name,
|
Status GraphToFunctionDef(const Graph& graph, const string& name,
|
||||||
@ -570,4 +578,11 @@ Status GraphToFunctionDef(const Graph& graph, const string& name,
|
|||||||
return GraphToFunctionDef(graph, name, /*control_ret=*/nullptr, fdef);
|
return GraphToFunctionDef(graph, name, /*control_ret=*/nullptr, fdef);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status GraphToFunctionDef(const Graph& graph, const string& name,
|
||||||
|
const std::vector<std::string>& output_names,
|
||||||
|
FunctionDef* fdef) {
|
||||||
|
return GraphToFunctionDefHelper(graph, name, /*control_ret=*/nullptr,
|
||||||
|
output_names, fdef);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -56,6 +56,10 @@ Status GraphToFunctionDef(
|
|||||||
Status GraphToFunctionDef(const Graph& graph, const string& name,
|
Status GraphToFunctionDef(const Graph& graph, const string& name,
|
||||||
FunctionDef* fdef);
|
FunctionDef* fdef);
|
||||||
|
|
||||||
|
Status GraphToFunctionDef(const Graph& graph, const string& name,
|
||||||
|
const std::vector<std::string>& output_names,
|
||||||
|
FunctionDef* fdef);
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_CORE_FRAMEWORK_GRAPH_TO_FUNCTIONDEF_H_
|
#endif // TENSORFLOW_CORE_FRAMEWORK_GRAPH_TO_FUNCTIONDEF_H_
|
||||||
|
Loading…
x
Reference in New Issue
Block a user