Prune unreachable functions at the end of TPU rewrite passes.

PiperOrigin-RevId: 247681525
This commit is contained in:
Tong Shen 2019-05-10 14:59:09 -07:00 committed by TensorFlower Gardener
parent aa5110bdb2
commit 951b1d2d1d
7 changed files with 41 additions and 3 deletions

View File

@ -537,8 +537,9 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library,
XlaClusterInfo{func, func_name_attrs, xla_computation_node,
std::map<string, int>{}});
}
bool modified;
s = ExtractOutsideCompilation("_encapsulate", "_outside", clusters,
graph_out.get(), flr, lib_def.get());
graph_out.get(), flr, lib_def.get(), &modified);
if (!s.ok()) return s;
GraphDef graphdef_out;

View File

@ -1691,11 +1691,13 @@ Status ExtractOutsideCompilation(
const string& xla_cluster_attr_name,
const string& outside_compilation_attr_name,
const std::unordered_map<string, XlaClusterInfo>& clusters, Graph* g,
FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld) {
FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld,
bool* modified) {
if (VLOG_IS_ON(4)) {
DumpGraphToFile("extract_outside_compilation_before", *g, fld);
}
*modified = false;
auto node_name_index = g->BuildNodeNameIndex();
for (auto& iter : clusters) {
string xla_cluster_name = iter.first;
@ -1711,6 +1713,7 @@ Status ExtractOutsideCompilation(
func_name_attrs, func_name_attrs.name(), host_graph_func_name,
host_compute_core, flr, fld, &shape_inference_graphs,
&has_outside_compilation));
*modified |= has_outside_compilation;
string pivot_name = absl::StrCat(xla_cluster_name, "/pivot");
Node* pivot_node = node_name_index[pivot_name];

View File

@ -101,7 +101,8 @@ Status ExtractOutsideCompilation(
const string& xla_cluster_attr_name,
const string& outside_compilation_attr_name,
const std::unordered_map<string, XlaClusterInfo>& clusters, Graph* g,
FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld);
FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld,
bool* modified);
} // namespace tensorflow

View File

@ -253,6 +253,7 @@ Status FunctionalizeControlFlowPass::Run(
{"XlaLaunch", "function"},
};
std::map<string, absl::optional<string>> canonicalized_name_to_new_name;
bool fld_modified = false;
for (Node* n : graph->nodes()) {
auto it = kNodeTypeToFunctionAttrMapping->find(n->type_string());
if (it == kNodeTypeToFunctionAttrMapping->end()) {
@ -273,9 +274,16 @@ Status FunctionalizeControlFlowPass::Run(
n->ClearAttr(func_attr);
func.set_name(new_func_name);
n->AddAttr(func_attr, func);
fld_modified = true;
}
}
if (fld_modified) {
TF_RETURN_IF_ERROR(
PruneUnreachableFunctionsFromGraph(*graph, options.flib_def));
}
if (VLOG_IS_ON(4)) {
DumpGraphToFile("functionalize_control_flow_after", *graph,
options.flib_def);

View File

@ -733,6 +733,7 @@ Status RearrangeFunctionArgumentPass::Run(
{"XlaLaunch", "function"},
};
std::map<string, absl::optional<string>> canonicalized_name_to_new_name;
bool fld_modified = false;
for (Node* n : graph->nodes()) {
auto it = kNodeTypeToFunctionAttrMapping->find(n->type_string());
if (it == kNodeTypeToFunctionAttrMapping->end()) {
@ -753,8 +754,14 @@ Status RearrangeFunctionArgumentPass::Run(
n->ClearAttr(func_attr);
func.set_name(new_func_name);
n->AddAttr(func_attr, func);
fld_modified = true;
}
}
if (fld_modified) {
TF_RETURN_IF_ERROR(
PruneUnreachableFunctionsFromGraph(**options.graph, options.flib_def));
}
if (VLOG_IS_ON(4)) {
DumpGraphToFile("rearrange_function_argument_after", *graph,

View File

@ -773,4 +773,17 @@ Status PropagateConstIntoFunctionalNodes(
return Status::OK();
}
Status PruneUnreachableFunctionsFromGraph(const Graph& g,
FunctionLibraryDefinition* fld) {
GraphDef graph_def;
g.ToGraphDef(&graph_def);
FunctionLibraryDefinition reachable_functions =
fld->ReachableDefinitions(graph_def);
for (const string& func_name : fld->ListFunctionNames()) {
if (!reachable_functions.Find(func_name)) {
TF_RETURN_IF_ERROR(fld->RemoveFunction(func_name));
}
}
return Status::OK();
}
} // namespace tensorflow

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "absl/types/optional.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/op.h"
@ -197,6 +198,10 @@ Status PropagateConstIntoFunctionalNodes(
Graph* g, const FunctionLibraryDefinition* lookup_fld,
FunctionLibraryDefinition* fld);
// Prunes unreachable FunctionDefs from FunctionLibraryDefinition.
Status PruneUnreachableFunctionsFromGraph(const Graph& g,
FunctionLibraryDefinition* fld);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_