Fix outside compilation to add new function instead of inplace function update.

PiperOrigin-RevId: 350863196
Change-Id: I70a41fd44028768032af09a71dc3f489ae785b45
This commit is contained in:
A. Unique TensorFlower 2021-01-08 17:10:58 -08:00 committed by TensorFlower Gardener
parent ef06907104
commit e81693934c

View File

@ -565,6 +565,20 @@ void ReplaceLiftedArgNodePlaceholderWithArg(
function_body.graph->RemoveNode(lifted_arg_node);
}
// Adds function def to function definition library and update the function
// callsite operation `callsite_node` to invoke new function instead.
Status AddFunctionWithNewName(const std::string& new_name,
const std::string& func_attr_name,
const FunctionDef& function_def,
NameAttrList* func_attr, Node* callsite_node,
FunctionLibraryDefinition* fld) {
TF_RETURN_IF_ERROR(fld->AddFunctionDef(function_def));
func_attr->set_name(new_name);
callsite_node->ClearAttr(func_attr_name);
callsite_node->AddAttr(func_attr_name, *func_attr);
return Status::OK();
}
// Reconnect outside compilation lifted arguments in a functional While node to
// its outside compilation tensor sources.
Status PostprocessLiftedArgsForWhile(
@ -633,12 +647,15 @@ Status PostprocessLiftedArgsForWhile(
*body_function_body, original_arg_count, i, lifted_arg_nodes, arg_node);
}
const auto new_body_function_name =
fld->UniqueFunctionName(absl::StrCat(body_func.name(), "_lifted_arg_"));
FunctionDef rewritten_body_function_def;
TF_RETURN_IF_ERROR(GraphToFunctionDef(
*body_function_body->graph, body_func.name(), HostGraphControlRetMapping,
&rewritten_body_function_def));
TF_RETURN_IF_ERROR(
fld->ReplaceFunction(body_func.name(), rewritten_body_function_def));
*body_function_body->graph, new_body_function_name,
HostGraphControlRetMapping, &rewritten_body_function_def));
TF_RETURN_IF_ERROR(AddFunctionWithNewName(new_body_function_name, "body",
rewritten_body_function_def,
&body_func, n, fld));
// In cond_graph, just add new _Arg nodes.
NameAttrList cond_func;
@ -657,13 +674,15 @@ Status PostprocessLiftedArgsForWhile(
TF_RETURN_IF_ERROR(arg_node_or.status());
}
const auto new_cond_function_name =
fld->UniqueFunctionName(absl::StrCat(cond_func.name(), "_lifted_arg_"));
FunctionDef rewritten_cond_function_def;
TF_RETURN_IF_ERROR(GraphToFunctionDef(
*cond_function_body->graph, cond_func.name(), HostGraphControlRetMapping,
&rewritten_cond_function_def));
TF_RETURN_IF_ERROR(
fld->ReplaceFunction(cond_func.name(), rewritten_cond_function_def));
*cond_function_body->graph, new_cond_function_name,
HostGraphControlRetMapping, &rewritten_cond_function_def));
TF_RETURN_IF_ERROR(AddFunctionWithNewName(new_cond_function_name, "cond",
rewritten_cond_function_def,
&cond_func, n, fld));
return Status::OK();
}
@ -779,19 +798,25 @@ Status PostprocessLiftedArgsForIf(
else_branch_lifted_arg_nodes, else_branch_arg_node);
}
const auto new_then_function_name = fld->UniqueFunctionName(
absl::StrCat(then_branch_func.name(), "_lifted_arg_"));
FunctionDef rewritten_then_branch_function_def;
TF_RETURN_IF_ERROR(GraphToFunctionDef(
*then_branch_function_body->graph, then_branch_func.name(),
*then_branch_function_body->graph, new_then_function_name,
HostGraphControlRetMapping, &rewritten_then_branch_function_def));
TF_RETURN_IF_ERROR(fld->ReplaceFunction(then_branch_func.name(),
rewritten_then_branch_function_def));
TF_RETURN_IF_ERROR(AddFunctionWithNewName(
new_then_function_name, "then_branch", rewritten_then_branch_function_def,
&then_branch_func, n, fld));
const auto new_else_function_name = fld->UniqueFunctionName(
absl::StrCat(else_branch_func.name(), "_lifted_arg_"));
FunctionDef rewritten_else_branch_function_def;
TF_RETURN_IF_ERROR(GraphToFunctionDef(
*else_branch_function_body->graph, else_branch_func.name(),
*else_branch_function_body->graph, new_else_function_name,
HostGraphControlRetMapping, &rewritten_else_branch_function_def));
TF_RETURN_IF_ERROR(fld->ReplaceFunction(else_branch_func.name(),
rewritten_else_branch_function_def));
TF_RETURN_IF_ERROR(AddFunctionWithNewName(
new_else_function_name, "else_branch", rewritten_else_branch_function_def,
&else_branch_func, n, fld));
return Status::OK();
}
@ -852,11 +877,19 @@ Status PostprocessLiftedArgsForCall(
TF_RETURN_IF_ERROR(GraphToFunctionDef(*fbody->graph, n->type_string(),
HostGraphControlRetMapping,
&rewritten_fdef));
TF_RETURN_IF_ERROR(fld->ReplaceFunction(n->type_string(), rewritten_fdef));
const auto new_function_name =
fld->UniqueFunctionName(absl::StrCat(n->type_string(), "_lifted_arg_"));
rewritten_fdef.mutable_signature()->set_name(new_function_name);
TF_RETURN_IF_ERROR(fld->AddFunctionDef(rewritten_fdef));
// We need to recreate the node. Otherwise TF will not know n->num_inputs()
// has increased.
NodeDef node_def = n->def();
// Function name is represented via the Op's type. Reset the op type to new
// function def name;
*node_def.mutable_op() = new_function_name;
for (int i = original_arg_count, end = data_types.size(); i < end; i++) {
Node* outside_compilation_node =
lifted_arg_nodes_and_outside_compilation_nodes[i - original_arg_count]
@ -1439,14 +1472,15 @@ TF_ATTRIBUTE_NOINLINE Status BuildHostGraphForIfNode(
// Rewrites loop cond to add a node which sends loop cond to host.
TF_ATTRIBUTE_NOINLINE Status AddSendLoopPredToLoopCond(
FunctionLibraryDefinition* fld, const NameAttrList& loop_cond_func,
const string& while_node_name, const string& host_transfer_key) {
const string& cond_xla_func_name, const string& host_transfer_key,
NameAttrList* loop_cond_func, FunctionLibraryDefinition* fld,
Node* while_node) {
// Instantiate the loop cond function.
std::unique_ptr<FunctionBody> fbody;
const FunctionDef* loop_cond_fdef = fld->Find(loop_cond_func.name());
const FunctionDef* loop_cond_fdef = fld->Find(loop_cond_func->name());
TF_RET_CHECK(loop_cond_fdef);
TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
*loop_cond_fdef, AttrSlice(&loop_cond_func.attr()), fld, &fbody));
*loop_cond_fdef, AttrSlice(&loop_cond_func->attr()), fld, &fbody));
Graph* g = fbody->graph;
// Find the _Retval node and the loop cond node.
@ -1455,7 +1489,7 @@ TF_ATTRIBUTE_NOINLINE Status AddSendLoopPredToLoopCond(
if (n->type_string() == "_Retval") {
if (ret_node) {
return errors::Internal("Multiple return node for loop cond function ",
loop_cond_func.name(), ": ",
loop_cond_func->name(), ": ",
ret_node->DebugString(), " and ",
n->DebugString());
} else {
@ -1465,14 +1499,14 @@ TF_ATTRIBUTE_NOINLINE Status AddSendLoopPredToLoopCond(
}
if (!ret_node) {
return errors::Internal("No _Retval node for loop cond function ",
loop_cond_func.name());
loop_cond_func->name());
}
Node* loop_cond;
TF_RETURN_IF_ERROR(ret_node->input_node(0, &loop_cond));
// Build the XlaSendToHost node.
NodeDefBuilder send_loop_cond_builder(
absl::StrCat("send_oc_while_cond_", while_node_name), "XlaSendToHost");
absl::StrCat("send_oc_while_cond_", while_node->name()), "XlaSendToHost");
send_loop_cond_builder.Attr("Tinput", DT_BOOL);
send_loop_cond_builder.Attr("key",
absl::StrCat(host_transfer_key, "_dtoh_0"));
@ -1488,11 +1522,26 @@ TF_ATTRIBUTE_NOINLINE Status AddSendLoopPredToLoopCond(
TF_RETURN_IF_ERROR(s);
g->AddEdge(loop_cond, 0, send_loop_cond_node, 0);
// Replace original function.
// Replace original function if loop_cond_func already has been re-written
// for outside compilation.
FunctionDef replace_fdef;
TF_RETURN_IF_ERROR(
GraphToFunctionDef(*g, loop_cond_func.name(), &replace_fdef));
TF_RETURN_IF_ERROR(fld->ReplaceFunction(loop_cond_func.name(), replace_fdef));
if (loop_cond_func->name() == cond_xla_func_name) {
TF_RETURN_IF_ERROR(
GraphToFunctionDef(*g, loop_cond_func->name(), &replace_fdef));
TF_RETURN_IF_ERROR(
fld->ReplaceFunction(loop_cond_func->name(), replace_fdef));
} else {
// If original while cond function has not been modified, add a new function
// with send loop predicated added and update the while node callsite
// operation.
const auto new_name = fld->UniqueFunctionName(
absl::StrCat(loop_cond_func->name(), "_send_pred_added_"));
TF_RETURN_IF_ERROR(GraphToFunctionDef(*g, new_name, &replace_fdef));
TF_RETURN_IF_ERROR(fld->AddFunctionDef(replace_fdef));
loop_cond_func->set_name(new_name);
while_node->ClearAttr("cond");
while_node->AddAttr("cond", *loop_cond_func);
}
return Status::OK();
}
@ -2011,8 +2060,8 @@ Status ExtractOutsideCompilationForWhileNode(
// XLA computation: rewrite cond function to add a SendToHost node to send
// loop predicate.
TF_RETURN_IF_ERROR(
AddSendLoopPredToLoopCond(fld, cond, n->name(), host_transfer_key));
TF_RETURN_IF_ERROR(AddSendLoopPredToLoopCond(
cond_xla_func_name, host_transfer_key, &cond, fld, n));
n->AddAttr(kXlaTokenInputNodesAttrName,
std::vector<string>{kXlaTokenArgNodeName});