Prune unreachable functions at the end of TPU rewrite passes.
PiperOrigin-RevId: 247681525
This commit is contained in:
parent
aa5110bdb2
commit
951b1d2d1d
@ -537,8 +537,9 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library,
|
||||
XlaClusterInfo{func, func_name_attrs, xla_computation_node,
|
||||
std::map<string, int>{}});
|
||||
}
|
||||
bool modified;
|
||||
s = ExtractOutsideCompilation("_encapsulate", "_outside", clusters,
|
||||
graph_out.get(), flr, lib_def.get());
|
||||
graph_out.get(), flr, lib_def.get(), &modified);
|
||||
if (!s.ok()) return s;
|
||||
|
||||
GraphDef graphdef_out;
|
||||
|
@ -1691,11 +1691,13 @@ Status ExtractOutsideCompilation(
|
||||
const string& xla_cluster_attr_name,
|
||||
const string& outside_compilation_attr_name,
|
||||
const std::unordered_map<string, XlaClusterInfo>& clusters, Graph* g,
|
||||
FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld) {
|
||||
FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld,
|
||||
bool* modified) {
|
||||
if (VLOG_IS_ON(4)) {
|
||||
DumpGraphToFile("extract_outside_compilation_before", *g, fld);
|
||||
}
|
||||
|
||||
*modified = false;
|
||||
auto node_name_index = g->BuildNodeNameIndex();
|
||||
for (auto& iter : clusters) {
|
||||
string xla_cluster_name = iter.first;
|
||||
@ -1711,6 +1713,7 @@ Status ExtractOutsideCompilation(
|
||||
func_name_attrs, func_name_attrs.name(), host_graph_func_name,
|
||||
host_compute_core, flr, fld, &shape_inference_graphs,
|
||||
&has_outside_compilation));
|
||||
*modified |= has_outside_compilation;
|
||||
|
||||
string pivot_name = absl::StrCat(xla_cluster_name, "/pivot");
|
||||
Node* pivot_node = node_name_index[pivot_name];
|
||||
|
@ -101,7 +101,8 @@ Status ExtractOutsideCompilation(
|
||||
const string& xla_cluster_attr_name,
|
||||
const string& outside_compilation_attr_name,
|
||||
const std::unordered_map<string, XlaClusterInfo>& clusters, Graph* g,
|
||||
FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld);
|
||||
FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld,
|
||||
bool* modified);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -253,6 +253,7 @@ Status FunctionalizeControlFlowPass::Run(
|
||||
{"XlaLaunch", "function"},
|
||||
};
|
||||
std::map<string, absl::optional<string>> canonicalized_name_to_new_name;
|
||||
bool fld_modified = false;
|
||||
for (Node* n : graph->nodes()) {
|
||||
auto it = kNodeTypeToFunctionAttrMapping->find(n->type_string());
|
||||
if (it == kNodeTypeToFunctionAttrMapping->end()) {
|
||||
@ -273,9 +274,16 @@ Status FunctionalizeControlFlowPass::Run(
|
||||
n->ClearAttr(func_attr);
|
||||
func.set_name(new_func_name);
|
||||
n->AddAttr(func_attr, func);
|
||||
|
||||
fld_modified = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (fld_modified) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
PruneUnreachableFunctionsFromGraph(*graph, options.flib_def));
|
||||
}
|
||||
|
||||
if (VLOG_IS_ON(4)) {
|
||||
DumpGraphToFile("functionalize_control_flow_after", *graph,
|
||||
options.flib_def);
|
||||
|
@ -733,6 +733,7 @@ Status RearrangeFunctionArgumentPass::Run(
|
||||
{"XlaLaunch", "function"},
|
||||
};
|
||||
std::map<string, absl::optional<string>> canonicalized_name_to_new_name;
|
||||
bool fld_modified = false;
|
||||
for (Node* n : graph->nodes()) {
|
||||
auto it = kNodeTypeToFunctionAttrMapping->find(n->type_string());
|
||||
if (it == kNodeTypeToFunctionAttrMapping->end()) {
|
||||
@ -753,8 +754,14 @@ Status RearrangeFunctionArgumentPass::Run(
|
||||
n->ClearAttr(func_attr);
|
||||
func.set_name(new_func_name);
|
||||
n->AddAttr(func_attr, func);
|
||||
|
||||
fld_modified = true;
|
||||
}
|
||||
}
|
||||
if (fld_modified) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
PruneUnreachableFunctionsFromGraph(**options.graph, options.flib_def));
|
||||
}
|
||||
|
||||
if (VLOG_IS_ON(4)) {
|
||||
DumpGraphToFile("rearrange_function_argument_after", *graph,
|
||||
|
@ -773,4 +773,17 @@ Status PropagateConstIntoFunctionalNodes(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PruneUnreachableFunctionsFromGraph(const Graph& g,
|
||||
FunctionLibraryDefinition* fld) {
|
||||
GraphDef graph_def;
|
||||
g.ToGraphDef(&graph_def);
|
||||
FunctionLibraryDefinition reachable_functions =
|
||||
fld->ReachableDefinitions(graph_def);
|
||||
for (const string& func_name : fld->ListFunctionNames()) {
|
||||
if (!reachable_functions.Find(func_name)) {
|
||||
TF_RETURN_IF_ERROR(fld->RemoveFunction(func_name));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace tensorflow
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/kernel_def.pb.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
@ -197,6 +198,10 @@ Status PropagateConstIntoFunctionalNodes(
|
||||
Graph* g, const FunctionLibraryDefinition* lookup_fld,
|
||||
FunctionLibraryDefinition* fld);
|
||||
|
||||
// Prunes unreachable FunctionDefs from FunctionLibraryDefinition.
|
||||
Status PruneUnreachableFunctionsFromGraph(const Graph& g,
|
||||
FunctionLibraryDefinition* fld);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_
|
||||
|
Loading…
Reference in New Issue
Block a user