[TF2XLA] Look for stack traces in global FunctionLibraryDefinition if they are not found in local FunctionLibraryDefinition
XlaCompiler has two FunctionLibraryDefinitions, the function can be stored in either one. PiperOrigin-RevId: 354006305 Change-Id: I26d6b8558342148c0851b2760ee37031a87a3855
This commit is contained in:
parent
e8a58dc6a2
commit
d18b4969ad
@ -380,7 +380,8 @@ Status MaybeRewriteWhileNode(
|
||||
Status MaybeRewriteIfNode(
|
||||
std::function<Status(const NameAttrList&, const FunctionBody**)>
|
||||
get_function_body_fn,
|
||||
Graph* g, Node* n, FunctionLibraryDefinition* fld, bool* node_rewritten) {
|
||||
Graph* g, Node* n, FunctionLibraryDefinition* fld, bool* node_rewritten,
|
||||
const FunctionLibraryDefinition* global_fld) {
|
||||
// This node needs rewrite when either of these is true:
|
||||
// 1) Tin has DT_RESOURCE which requires rearrange;
|
||||
// 2) Tout has DT_RESOURCE.
|
||||
@ -456,8 +457,11 @@ Status MaybeRewriteIfNode(
|
||||
string new_name =
|
||||
fld->UniqueFunctionName(absl::StrCat(f.name(), "_rearrange_"));
|
||||
TF_RETURN_IF_ERROR(GraphToFunctionDef(*fbody->graph, new_name, &new_fdef));
|
||||
TF_RETURN_IF_ERROR(
|
||||
fld->AddFunctionDef(new_fdef, fld->GetStackTraces(f.name())));
|
||||
const StackTracesMap& stack_traces =
|
||||
fld->GetStackTraces(f.name()).empty() && global_fld
|
||||
? global_fld->GetStackTraces(f.name())
|
||||
: fld->GetStackTraces(f.name());
|
||||
TF_RETURN_IF_ERROR(fld->AddFunctionDef(new_fdef, stack_traces));
|
||||
|
||||
// Change node to use rewritten function.
|
||||
f.set_name(new_name);
|
||||
@ -506,7 +510,8 @@ Status MaybeRewriteIfNode(
|
||||
Status RearrangeFunctionArguments(
|
||||
std::function<Status(const NameAttrList&, const FunctionBody**)>
|
||||
get_function_body_fn,
|
||||
Graph* g, FunctionLibraryDefinition* fld) {
|
||||
Graph* g, FunctionLibraryDefinition* fld,
|
||||
const FunctionLibraryDefinition* global_fld) {
|
||||
// Inline StatefulPartitionedCall nodes.
|
||||
std::vector<Node*> call_nodes;
|
||||
for (Node* n : g->nodes()) {
|
||||
@ -535,8 +540,8 @@ Status RearrangeFunctionArguments(
|
||||
&node_rewritten));
|
||||
} else if (n->IsIfNode()) {
|
||||
bool node_rewritten;
|
||||
TF_RETURN_IF_ERROR(
|
||||
MaybeRewriteIfNode(get_function_body_fn, g, n, fld, &node_rewritten));
|
||||
TF_RETURN_IF_ERROR(MaybeRewriteIfNode(get_function_body_fn, g, n, fld,
|
||||
&node_rewritten, global_fld));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -29,10 +29,13 @@ namespace tensorflow {
|
||||
// arguments and return values.
|
||||
// `get_function_body_fn` is used to instantiate FunctionDef.
|
||||
// `fld` is used to store rewritten functions.
|
||||
// `global_fld` is used to potentially supply stack traces for functions when
|
||||
// they are not found in `fld`.
|
||||
Status RearrangeFunctionArguments(
|
||||
std::function<Status(const NameAttrList&, const FunctionBody**)>
|
||||
get_function_body_fn,
|
||||
Graph* g, FunctionLibraryDefinition* fld);
|
||||
Graph* g, FunctionLibraryDefinition* fld,
|
||||
const FunctionLibraryDefinition* global_fld = nullptr);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -1283,7 +1283,9 @@ Status XlaCompiler::CompileGraph(
|
||||
[this](const NameAttrList& function, const FunctionBody** fbody) {
|
||||
return FindFunctionBody(function, fbody);
|
||||
},
|
||||
graph.get(), local_flib_def_.get()));
|
||||
graph.get(), local_flib_def_.get(),
|
||||
pflr_->GetFunctionLibraryDefinition()));
|
||||
|
||||
if (VLOG_IS_ON(2)) {
|
||||
VLOG(2) << "XlaCompiler::CompileGraph: "
|
||||
<< DumpGraphToFile(absl::StrCat("xla_compile_graph_", name), *graph,
|
||||
|
Loading…
x
Reference in New Issue
Block a user