Fix outside compilation to add new function instead of inplace function update.
PiperOrigin-RevId: 350863196 Change-Id: I70a41fd44028768032af09a71dc3f489ae785b45
This commit is contained in:
parent
ef06907104
commit
e81693934c
@ -565,6 +565,20 @@ void ReplaceLiftedArgNodePlaceholderWithArg(
|
||||
function_body.graph->RemoveNode(lifted_arg_node);
|
||||
}
|
||||
|
||||
// Adds function def to function definition library and update the function
|
||||
// callsite operation `callsite_node` to invoke new function instead.
|
||||
Status AddFunctionWithNewName(const std::string& new_name,
|
||||
const std::string& func_attr_name,
|
||||
const FunctionDef& function_def,
|
||||
NameAttrList* func_attr, Node* callsite_node,
|
||||
FunctionLibraryDefinition* fld) {
|
||||
TF_RETURN_IF_ERROR(fld->AddFunctionDef(function_def));
|
||||
func_attr->set_name(new_name);
|
||||
callsite_node->ClearAttr(func_attr_name);
|
||||
callsite_node->AddAttr(func_attr_name, *func_attr);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Reconnect outside compilation lifted arguments in a functional While node to
|
||||
// its outside compilation tensor sources.
|
||||
Status PostprocessLiftedArgsForWhile(
|
||||
@ -633,12 +647,15 @@ Status PostprocessLiftedArgsForWhile(
|
||||
*body_function_body, original_arg_count, i, lifted_arg_nodes, arg_node);
|
||||
}
|
||||
|
||||
const auto new_body_function_name =
|
||||
fld->UniqueFunctionName(absl::StrCat(body_func.name(), "_lifted_arg_"));
|
||||
FunctionDef rewritten_body_function_def;
|
||||
TF_RETURN_IF_ERROR(GraphToFunctionDef(
|
||||
*body_function_body->graph, body_func.name(), HostGraphControlRetMapping,
|
||||
&rewritten_body_function_def));
|
||||
TF_RETURN_IF_ERROR(
|
||||
fld->ReplaceFunction(body_func.name(), rewritten_body_function_def));
|
||||
*body_function_body->graph, new_body_function_name,
|
||||
HostGraphControlRetMapping, &rewritten_body_function_def));
|
||||
TF_RETURN_IF_ERROR(AddFunctionWithNewName(new_body_function_name, "body",
|
||||
rewritten_body_function_def,
|
||||
&body_func, n, fld));
|
||||
|
||||
// In cond_graph, just add new _Arg nodes.
|
||||
NameAttrList cond_func;
|
||||
@ -657,13 +674,15 @@ Status PostprocessLiftedArgsForWhile(
|
||||
TF_RETURN_IF_ERROR(arg_node_or.status());
|
||||
}
|
||||
|
||||
const auto new_cond_function_name =
|
||||
fld->UniqueFunctionName(absl::StrCat(cond_func.name(), "_lifted_arg_"));
|
||||
FunctionDef rewritten_cond_function_def;
|
||||
TF_RETURN_IF_ERROR(GraphToFunctionDef(
|
||||
*cond_function_body->graph, cond_func.name(), HostGraphControlRetMapping,
|
||||
&rewritten_cond_function_def));
|
||||
TF_RETURN_IF_ERROR(
|
||||
fld->ReplaceFunction(cond_func.name(), rewritten_cond_function_def));
|
||||
|
||||
*cond_function_body->graph, new_cond_function_name,
|
||||
HostGraphControlRetMapping, &rewritten_cond_function_def));
|
||||
TF_RETURN_IF_ERROR(AddFunctionWithNewName(new_cond_function_name, "cond",
|
||||
rewritten_cond_function_def,
|
||||
&cond_func, n, fld));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -779,19 +798,25 @@ Status PostprocessLiftedArgsForIf(
|
||||
else_branch_lifted_arg_nodes, else_branch_arg_node);
|
||||
}
|
||||
|
||||
const auto new_then_function_name = fld->UniqueFunctionName(
|
||||
absl::StrCat(then_branch_func.name(), "_lifted_arg_"));
|
||||
FunctionDef rewritten_then_branch_function_def;
|
||||
TF_RETURN_IF_ERROR(GraphToFunctionDef(
|
||||
*then_branch_function_body->graph, then_branch_func.name(),
|
||||
*then_branch_function_body->graph, new_then_function_name,
|
||||
HostGraphControlRetMapping, &rewritten_then_branch_function_def));
|
||||
TF_RETURN_IF_ERROR(fld->ReplaceFunction(then_branch_func.name(),
|
||||
rewritten_then_branch_function_def));
|
||||
TF_RETURN_IF_ERROR(AddFunctionWithNewName(
|
||||
new_then_function_name, "then_branch", rewritten_then_branch_function_def,
|
||||
&then_branch_func, n, fld));
|
||||
|
||||
const auto new_else_function_name = fld->UniqueFunctionName(
|
||||
absl::StrCat(else_branch_func.name(), "_lifted_arg_"));
|
||||
FunctionDef rewritten_else_branch_function_def;
|
||||
TF_RETURN_IF_ERROR(GraphToFunctionDef(
|
||||
*else_branch_function_body->graph, else_branch_func.name(),
|
||||
*else_branch_function_body->graph, new_else_function_name,
|
||||
HostGraphControlRetMapping, &rewritten_else_branch_function_def));
|
||||
TF_RETURN_IF_ERROR(fld->ReplaceFunction(else_branch_func.name(),
|
||||
rewritten_else_branch_function_def));
|
||||
TF_RETURN_IF_ERROR(AddFunctionWithNewName(
|
||||
new_else_function_name, "else_branch", rewritten_else_branch_function_def,
|
||||
&else_branch_func, n, fld));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -852,11 +877,19 @@ Status PostprocessLiftedArgsForCall(
|
||||
TF_RETURN_IF_ERROR(GraphToFunctionDef(*fbody->graph, n->type_string(),
|
||||
HostGraphControlRetMapping,
|
||||
&rewritten_fdef));
|
||||
TF_RETURN_IF_ERROR(fld->ReplaceFunction(n->type_string(), rewritten_fdef));
|
||||
const auto new_function_name =
|
||||
fld->UniqueFunctionName(absl::StrCat(n->type_string(), "_lifted_arg_"));
|
||||
rewritten_fdef.mutable_signature()->set_name(new_function_name);
|
||||
TF_RETURN_IF_ERROR(fld->AddFunctionDef(rewritten_fdef));
|
||||
|
||||
// We need to recreate the node. Otherwise TF will not know n->num_inputs()
|
||||
// has increased.
|
||||
NodeDef node_def = n->def();
|
||||
|
||||
// Function name is represented via the Op's type. Reset the op type to new
|
||||
// function def name;
|
||||
*node_def.mutable_op() = new_function_name;
|
||||
|
||||
for (int i = original_arg_count, end = data_types.size(); i < end; i++) {
|
||||
Node* outside_compilation_node =
|
||||
lifted_arg_nodes_and_outside_compilation_nodes[i - original_arg_count]
|
||||
@ -1439,14 +1472,15 @@ TF_ATTRIBUTE_NOINLINE Status BuildHostGraphForIfNode(
|
||||
|
||||
// Rewrites loop cond to add a node which sends loop cond to host.
|
||||
TF_ATTRIBUTE_NOINLINE Status AddSendLoopPredToLoopCond(
|
||||
FunctionLibraryDefinition* fld, const NameAttrList& loop_cond_func,
|
||||
const string& while_node_name, const string& host_transfer_key) {
|
||||
const string& cond_xla_func_name, const string& host_transfer_key,
|
||||
NameAttrList* loop_cond_func, FunctionLibraryDefinition* fld,
|
||||
Node* while_node) {
|
||||
// Instantiate the loop cond function.
|
||||
std::unique_ptr<FunctionBody> fbody;
|
||||
const FunctionDef* loop_cond_fdef = fld->Find(loop_cond_func.name());
|
||||
const FunctionDef* loop_cond_fdef = fld->Find(loop_cond_func->name());
|
||||
TF_RET_CHECK(loop_cond_fdef);
|
||||
TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
|
||||
*loop_cond_fdef, AttrSlice(&loop_cond_func.attr()), fld, &fbody));
|
||||
*loop_cond_fdef, AttrSlice(&loop_cond_func->attr()), fld, &fbody));
|
||||
Graph* g = fbody->graph;
|
||||
|
||||
// Find the _Retval node and the loop cond node.
|
||||
@ -1455,7 +1489,7 @@ TF_ATTRIBUTE_NOINLINE Status AddSendLoopPredToLoopCond(
|
||||
if (n->type_string() == "_Retval") {
|
||||
if (ret_node) {
|
||||
return errors::Internal("Multiple return node for loop cond function ",
|
||||
loop_cond_func.name(), ": ",
|
||||
loop_cond_func->name(), ": ",
|
||||
ret_node->DebugString(), " and ",
|
||||
n->DebugString());
|
||||
} else {
|
||||
@ -1465,14 +1499,14 @@ TF_ATTRIBUTE_NOINLINE Status AddSendLoopPredToLoopCond(
|
||||
}
|
||||
if (!ret_node) {
|
||||
return errors::Internal("No _Retval node for loop cond function ",
|
||||
loop_cond_func.name());
|
||||
loop_cond_func->name());
|
||||
}
|
||||
Node* loop_cond;
|
||||
TF_RETURN_IF_ERROR(ret_node->input_node(0, &loop_cond));
|
||||
|
||||
// Build the XlaSendToHost node.
|
||||
NodeDefBuilder send_loop_cond_builder(
|
||||
absl::StrCat("send_oc_while_cond_", while_node_name), "XlaSendToHost");
|
||||
absl::StrCat("send_oc_while_cond_", while_node->name()), "XlaSendToHost");
|
||||
send_loop_cond_builder.Attr("Tinput", DT_BOOL);
|
||||
send_loop_cond_builder.Attr("key",
|
||||
absl::StrCat(host_transfer_key, "_dtoh_0"));
|
||||
@ -1488,11 +1522,26 @@ TF_ATTRIBUTE_NOINLINE Status AddSendLoopPredToLoopCond(
|
||||
TF_RETURN_IF_ERROR(s);
|
||||
g->AddEdge(loop_cond, 0, send_loop_cond_node, 0);
|
||||
|
||||
// Replace original function.
|
||||
// Replace original function if loop_cond_func already has been re-written
|
||||
// for outside compilation.
|
||||
FunctionDef replace_fdef;
|
||||
TF_RETURN_IF_ERROR(
|
||||
GraphToFunctionDef(*g, loop_cond_func.name(), &replace_fdef));
|
||||
TF_RETURN_IF_ERROR(fld->ReplaceFunction(loop_cond_func.name(), replace_fdef));
|
||||
if (loop_cond_func->name() == cond_xla_func_name) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
GraphToFunctionDef(*g, loop_cond_func->name(), &replace_fdef));
|
||||
TF_RETURN_IF_ERROR(
|
||||
fld->ReplaceFunction(loop_cond_func->name(), replace_fdef));
|
||||
} else {
|
||||
// If original while cond function has not been modified, add a new function
|
||||
// with send loop predicated added and update the while node callsite
|
||||
// operation.
|
||||
const auto new_name = fld->UniqueFunctionName(
|
||||
absl::StrCat(loop_cond_func->name(), "_send_pred_added_"));
|
||||
TF_RETURN_IF_ERROR(GraphToFunctionDef(*g, new_name, &replace_fdef));
|
||||
TF_RETURN_IF_ERROR(fld->AddFunctionDef(replace_fdef));
|
||||
loop_cond_func->set_name(new_name);
|
||||
while_node->ClearAttr("cond");
|
||||
while_node->AddAttr("cond", *loop_cond_func);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
@ -2011,8 +2060,8 @@ Status ExtractOutsideCompilationForWhileNode(
|
||||
|
||||
// XLA computation: rewrite cond function to add a SendToHost node to send
|
||||
// loop predicate.
|
||||
TF_RETURN_IF_ERROR(
|
||||
AddSendLoopPredToLoopCond(fld, cond, n->name(), host_transfer_key));
|
||||
TF_RETURN_IF_ERROR(AddSendLoopPredToLoopCond(
|
||||
cond_xla_func_name, host_transfer_key, &cond, fld, n));
|
||||
n->AddAttr(kXlaTokenInputNodesAttrName,
|
||||
std::vector<string>{kXlaTokenArgNodeName});
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user