[TF2XLA] Look for stack traces in global FunctionLibraryDefinition if they are not found in local FunctionLibraryDefinition

XlaCompiler has two FunctionLibraryDefinitions, the function can be stored in either one.

PiperOrigin-RevId: 354006305
Change-Id: I26d6b8558342148c0851b2760ee37031a87a3855
This commit is contained in:
George Karpenkov 2021-01-26 19:46:50 -08:00 committed by TensorFlower Gardener
parent e8a58dc6a2
commit d18b4969ad
3 changed files with 18 additions and 8 deletions

View File

@ -380,7 +380,8 @@ Status MaybeRewriteWhileNode(
Status MaybeRewriteIfNode(
std::function<Status(const NameAttrList&, const FunctionBody**)>
get_function_body_fn,
Graph* g, Node* n, FunctionLibraryDefinition* fld, bool* node_rewritten) {
Graph* g, Node* n, FunctionLibraryDefinition* fld, bool* node_rewritten,
const FunctionLibraryDefinition* global_fld) {
// This node needs rewrite when either of these is true:
// 1) Tin has DT_RESOURCE which requires rearrange;
// 2) Tout has DT_RESOURCE.
@ -456,8 +457,11 @@ Status MaybeRewriteIfNode(
string new_name =
fld->UniqueFunctionName(absl::StrCat(f.name(), "_rearrange_"));
TF_RETURN_IF_ERROR(GraphToFunctionDef(*fbody->graph, new_name, &new_fdef));
TF_RETURN_IF_ERROR(
fld->AddFunctionDef(new_fdef, fld->GetStackTraces(f.name())));
const StackTracesMap& stack_traces =
fld->GetStackTraces(f.name()).empty() && global_fld
? global_fld->GetStackTraces(f.name())
: fld->GetStackTraces(f.name());
TF_RETURN_IF_ERROR(fld->AddFunctionDef(new_fdef, stack_traces));
// Change node to use rewritten function.
f.set_name(new_name);
@ -506,7 +510,8 @@ Status MaybeRewriteIfNode(
Status RearrangeFunctionArguments(
std::function<Status(const NameAttrList&, const FunctionBody**)>
get_function_body_fn,
Graph* g, FunctionLibraryDefinition* fld) {
Graph* g, FunctionLibraryDefinition* fld,
const FunctionLibraryDefinition* global_fld) {
// Inline StatefulPartitionedCall nodes.
std::vector<Node*> call_nodes;
for (Node* n : g->nodes()) {
@ -535,8 +540,8 @@ Status RearrangeFunctionArguments(
&node_rewritten));
} else if (n->IsIfNode()) {
bool node_rewritten;
TF_RETURN_IF_ERROR(
MaybeRewriteIfNode(get_function_body_fn, g, n, fld, &node_rewritten));
TF_RETURN_IF_ERROR(MaybeRewriteIfNode(get_function_body_fn, g, n, fld,
&node_rewritten, global_fld));
}
}

View File

@ -29,10 +29,13 @@ namespace tensorflow {
// arguments and return values.
// `get_function_body_fn` is used to instantiate FunctionDef.
// `fld` is used to store rewritten functions.
// `global_fld` is used to potentially supply stack traces for functions when
// they are not found in `fld`.
Status RearrangeFunctionArguments(
std::function<Status(const NameAttrList&, const FunctionBody**)>
get_function_body_fn,
Graph* g, FunctionLibraryDefinition* fld);
Graph* g, FunctionLibraryDefinition* fld,
const FunctionLibraryDefinition* global_fld = nullptr);
} // namespace tensorflow

View File

@ -1283,7 +1283,9 @@ Status XlaCompiler::CompileGraph(
[this](const NameAttrList& function, const FunctionBody** fbody) {
return FindFunctionBody(function, fbody);
},
graph.get(), local_flib_def_.get()));
graph.get(), local_flib_def_.get(),
pflr_->GetFunctionLibraryDefinition()));
if (VLOG_IS_ON(2)) {
VLOG(2) << "XlaCompiler::CompileGraph: "
<< DumpGraphToFile(absl::StrCat("xla_compile_graph_", name), *graph,