Fix outside compilation to add new function instead of inplace function update.

PiperOrigin-RevId: 350863196
Change-Id: I70a41fd44028768032af09a71dc3f489ae785b45
This commit is contained in:
A. Unique TensorFlower 2021-01-08 17:10:58 -08:00 committed by TensorFlower Gardener
parent ef06907104
commit e81693934c

View File

@ -565,6 +565,20 @@ void ReplaceLiftedArgNodePlaceholderWithArg(
function_body.graph->RemoveNode(lifted_arg_node); function_body.graph->RemoveNode(lifted_arg_node);
} }
// Adds function def to function definition library and update the function
// callsite operation `callsite_node` to invoke new function instead.
Status AddFunctionWithNewName(const std::string& new_name,
const std::string& func_attr_name,
const FunctionDef& function_def,
NameAttrList* func_attr, Node* callsite_node,
FunctionLibraryDefinition* fld) {
TF_RETURN_IF_ERROR(fld->AddFunctionDef(function_def));
func_attr->set_name(new_name);
callsite_node->ClearAttr(func_attr_name);
callsite_node->AddAttr(func_attr_name, *func_attr);
return Status::OK();
}
// Reconnect outside compilation lifted arguments in a functional While node to // Reconnect outside compilation lifted arguments in a functional While node to
// its outside compilation tensor sources. // its outside compilation tensor sources.
Status PostprocessLiftedArgsForWhile( Status PostprocessLiftedArgsForWhile(
@ -633,12 +647,15 @@ Status PostprocessLiftedArgsForWhile(
*body_function_body, original_arg_count, i, lifted_arg_nodes, arg_node); *body_function_body, original_arg_count, i, lifted_arg_nodes, arg_node);
} }
const auto new_body_function_name =
fld->UniqueFunctionName(absl::StrCat(body_func.name(), "_lifted_arg_"));
FunctionDef rewritten_body_function_def; FunctionDef rewritten_body_function_def;
TF_RETURN_IF_ERROR(GraphToFunctionDef( TF_RETURN_IF_ERROR(GraphToFunctionDef(
*body_function_body->graph, body_func.name(), HostGraphControlRetMapping, *body_function_body->graph, new_body_function_name,
&rewritten_body_function_def)); HostGraphControlRetMapping, &rewritten_body_function_def));
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(AddFunctionWithNewName(new_body_function_name, "body",
fld->ReplaceFunction(body_func.name(), rewritten_body_function_def)); rewritten_body_function_def,
&body_func, n, fld));
// In cond_graph, just add new _Arg nodes. // In cond_graph, just add new _Arg nodes.
NameAttrList cond_func; NameAttrList cond_func;
@ -657,13 +674,15 @@ Status PostprocessLiftedArgsForWhile(
TF_RETURN_IF_ERROR(arg_node_or.status()); TF_RETURN_IF_ERROR(arg_node_or.status());
} }
const auto new_cond_function_name =
fld->UniqueFunctionName(absl::StrCat(cond_func.name(), "_lifted_arg_"));
FunctionDef rewritten_cond_function_def; FunctionDef rewritten_cond_function_def;
TF_RETURN_IF_ERROR(GraphToFunctionDef( TF_RETURN_IF_ERROR(GraphToFunctionDef(
*cond_function_body->graph, cond_func.name(), HostGraphControlRetMapping, *cond_function_body->graph, new_cond_function_name,
&rewritten_cond_function_def)); HostGraphControlRetMapping, &rewritten_cond_function_def));
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(AddFunctionWithNewName(new_cond_function_name, "cond",
fld->ReplaceFunction(cond_func.name(), rewritten_cond_function_def)); rewritten_cond_function_def,
&cond_func, n, fld));
return Status::OK(); return Status::OK();
} }
@ -779,19 +798,25 @@ Status PostprocessLiftedArgsForIf(
else_branch_lifted_arg_nodes, else_branch_arg_node); else_branch_lifted_arg_nodes, else_branch_arg_node);
} }
const auto new_then_function_name = fld->UniqueFunctionName(
absl::StrCat(then_branch_func.name(), "_lifted_arg_"));
FunctionDef rewritten_then_branch_function_def; FunctionDef rewritten_then_branch_function_def;
TF_RETURN_IF_ERROR(GraphToFunctionDef( TF_RETURN_IF_ERROR(GraphToFunctionDef(
*then_branch_function_body->graph, then_branch_func.name(), *then_branch_function_body->graph, new_then_function_name,
HostGraphControlRetMapping, &rewritten_then_branch_function_def)); HostGraphControlRetMapping, &rewritten_then_branch_function_def));
TF_RETURN_IF_ERROR(fld->ReplaceFunction(then_branch_func.name(), TF_RETURN_IF_ERROR(AddFunctionWithNewName(
rewritten_then_branch_function_def)); new_then_function_name, "then_branch", rewritten_then_branch_function_def,
&then_branch_func, n, fld));
const auto new_else_function_name = fld->UniqueFunctionName(
absl::StrCat(else_branch_func.name(), "_lifted_arg_"));
FunctionDef rewritten_else_branch_function_def; FunctionDef rewritten_else_branch_function_def;
TF_RETURN_IF_ERROR(GraphToFunctionDef( TF_RETURN_IF_ERROR(GraphToFunctionDef(
*else_branch_function_body->graph, else_branch_func.name(), *else_branch_function_body->graph, new_else_function_name,
HostGraphControlRetMapping, &rewritten_else_branch_function_def)); HostGraphControlRetMapping, &rewritten_else_branch_function_def));
TF_RETURN_IF_ERROR(fld->ReplaceFunction(else_branch_func.name(), TF_RETURN_IF_ERROR(AddFunctionWithNewName(
rewritten_else_branch_function_def)); new_else_function_name, "else_branch", rewritten_else_branch_function_def,
&else_branch_func, n, fld));
return Status::OK(); return Status::OK();
} }
@ -852,11 +877,19 @@ Status PostprocessLiftedArgsForCall(
TF_RETURN_IF_ERROR(GraphToFunctionDef(*fbody->graph, n->type_string(), TF_RETURN_IF_ERROR(GraphToFunctionDef(*fbody->graph, n->type_string(),
HostGraphControlRetMapping, HostGraphControlRetMapping,
&rewritten_fdef)); &rewritten_fdef));
TF_RETURN_IF_ERROR(fld->ReplaceFunction(n->type_string(), rewritten_fdef)); const auto new_function_name =
fld->UniqueFunctionName(absl::StrCat(n->type_string(), "_lifted_arg_"));
rewritten_fdef.mutable_signature()->set_name(new_function_name);
TF_RETURN_IF_ERROR(fld->AddFunctionDef(rewritten_fdef));
// We need to recreate the node. Otherwise TF will not know n->num_inputs() // We need to recreate the node. Otherwise TF will not know n->num_inputs()
// has increased. // has increased.
NodeDef node_def = n->def(); NodeDef node_def = n->def();
// Function name is represented via the Op's type. Reset the op type to new
// function def name;
*node_def.mutable_op() = new_function_name;
for (int i = original_arg_count, end = data_types.size(); i < end; i++) { for (int i = original_arg_count, end = data_types.size(); i < end; i++) {
Node* outside_compilation_node = Node* outside_compilation_node =
lifted_arg_nodes_and_outside_compilation_nodes[i - original_arg_count] lifted_arg_nodes_and_outside_compilation_nodes[i - original_arg_count]
@ -1439,14 +1472,15 @@ TF_ATTRIBUTE_NOINLINE Status BuildHostGraphForIfNode(
// Rewrites loop cond to add a node which sends loop cond to host. // Rewrites loop cond to add a node which sends loop cond to host.
TF_ATTRIBUTE_NOINLINE Status AddSendLoopPredToLoopCond( TF_ATTRIBUTE_NOINLINE Status AddSendLoopPredToLoopCond(
FunctionLibraryDefinition* fld, const NameAttrList& loop_cond_func, const string& cond_xla_func_name, const string& host_transfer_key,
const string& while_node_name, const string& host_transfer_key) { NameAttrList* loop_cond_func, FunctionLibraryDefinition* fld,
Node* while_node) {
// Instantiate the loop cond function. // Instantiate the loop cond function.
std::unique_ptr<FunctionBody> fbody; std::unique_ptr<FunctionBody> fbody;
const FunctionDef* loop_cond_fdef = fld->Find(loop_cond_func.name()); const FunctionDef* loop_cond_fdef = fld->Find(loop_cond_func->name());
TF_RET_CHECK(loop_cond_fdef); TF_RET_CHECK(loop_cond_fdef);
TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
*loop_cond_fdef, AttrSlice(&loop_cond_func.attr()), fld, &fbody)); *loop_cond_fdef, AttrSlice(&loop_cond_func->attr()), fld, &fbody));
Graph* g = fbody->graph; Graph* g = fbody->graph;
// Find the _Retval node and the loop cond node. // Find the _Retval node and the loop cond node.
@ -1455,7 +1489,7 @@ TF_ATTRIBUTE_NOINLINE Status AddSendLoopPredToLoopCond(
if (n->type_string() == "_Retval") { if (n->type_string() == "_Retval") {
if (ret_node) { if (ret_node) {
return errors::Internal("Multiple return node for loop cond function ", return errors::Internal("Multiple return node for loop cond function ",
loop_cond_func.name(), ": ", loop_cond_func->name(), ": ",
ret_node->DebugString(), " and ", ret_node->DebugString(), " and ",
n->DebugString()); n->DebugString());
} else { } else {
@ -1465,14 +1499,14 @@ TF_ATTRIBUTE_NOINLINE Status AddSendLoopPredToLoopCond(
} }
if (!ret_node) { if (!ret_node) {
return errors::Internal("No _Retval node for loop cond function ", return errors::Internal("No _Retval node for loop cond function ",
loop_cond_func.name()); loop_cond_func->name());
} }
Node* loop_cond; Node* loop_cond;
TF_RETURN_IF_ERROR(ret_node->input_node(0, &loop_cond)); TF_RETURN_IF_ERROR(ret_node->input_node(0, &loop_cond));
// Build the XlaSendToHost node. // Build the XlaSendToHost node.
NodeDefBuilder send_loop_cond_builder( NodeDefBuilder send_loop_cond_builder(
absl::StrCat("send_oc_while_cond_", while_node_name), "XlaSendToHost"); absl::StrCat("send_oc_while_cond_", while_node->name()), "XlaSendToHost");
send_loop_cond_builder.Attr("Tinput", DT_BOOL); send_loop_cond_builder.Attr("Tinput", DT_BOOL);
send_loop_cond_builder.Attr("key", send_loop_cond_builder.Attr("key",
absl::StrCat(host_transfer_key, "_dtoh_0")); absl::StrCat(host_transfer_key, "_dtoh_0"));
@ -1488,11 +1522,26 @@ TF_ATTRIBUTE_NOINLINE Status AddSendLoopPredToLoopCond(
TF_RETURN_IF_ERROR(s); TF_RETURN_IF_ERROR(s);
g->AddEdge(loop_cond, 0, send_loop_cond_node, 0); g->AddEdge(loop_cond, 0, send_loop_cond_node, 0);
// Replace original function. // Replace original function if loop_cond_func already has been re-written
// for outside compilation.
FunctionDef replace_fdef; FunctionDef replace_fdef;
if (loop_cond_func->name() == cond_xla_func_name) {
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
GraphToFunctionDef(*g, loop_cond_func.name(), &replace_fdef)); GraphToFunctionDef(*g, loop_cond_func->name(), &replace_fdef));
TF_RETURN_IF_ERROR(fld->ReplaceFunction(loop_cond_func.name(), replace_fdef)); TF_RETURN_IF_ERROR(
fld->ReplaceFunction(loop_cond_func->name(), replace_fdef));
} else {
// If original while cond function has not been modified, add a new function
// with send loop predicated added and update the while node callsite
// operation.
const auto new_name = fld->UniqueFunctionName(
absl::StrCat(loop_cond_func->name(), "_send_pred_added_"));
TF_RETURN_IF_ERROR(GraphToFunctionDef(*g, new_name, &replace_fdef));
TF_RETURN_IF_ERROR(fld->AddFunctionDef(replace_fdef));
loop_cond_func->set_name(new_name);
while_node->ClearAttr("cond");
while_node->AddAttr("cond", *loop_cond_func);
}
return Status::OK(); return Status::OK();
} }
@ -2011,8 +2060,8 @@ Status ExtractOutsideCompilationForWhileNode(
// XLA computation: rewrite cond function to add a SendToHost node to send // XLA computation: rewrite cond function to add a SendToHost node to send
// loop predicate. // loop predicate.
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(AddSendLoopPredToLoopCond(
AddSendLoopPredToLoopCond(fld, cond, n->name(), host_transfer_key)); cond_xla_func_name, host_transfer_key, &cond, fld, n));
n->AddAttr(kXlaTokenInputNodesAttrName, n->AddAttr(kXlaTokenInputNodesAttrName,
std::vector<string>{kXlaTokenArgNodeName}); std::vector<string>{kXlaTokenArgNodeName});