Remove lookup_library from FunctionalizeControlFlow().

It was only used when FunctionalizeControlFlow() was called in XlaCompiler instead of a graph optimization pass, and the purpose was to look for FunctionDefs defined outside XlaCompiler. Now it's not needed any more.

PiperOrigin-RevId: 270338766
This commit is contained in:
Tong Shen 2019-09-20 13:32:17 -07:00 committed by TensorFlower Gardener
parent 316a882856
commit a2a51fa318
5 changed files with 16 additions and 102 deletions

View File

@ -46,14 +46,15 @@ limitations under the License.
namespace tensorflow {
Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library,
Graph* graph,
// Transformation that converts TensorFlow's graph control flow constructs into
// functional equivalents.
Status FunctionalizeControlFlow(Graph* graph,
FunctionLibraryDefinition* library) {
VLOG(2) << "FunctionalizeControlFlow (initial): "
<< DumpGraphToFile("functionalize_initial", *graph, library);
// Functionalize and remove while loops from graph.
TF_RETURN_IF_ERROR(FunctionalizeWhileLoop(lookup_library, graph, library));
TF_RETURN_IF_ERROR(FunctionalizeWhileLoop(graph, library));
// FunctionalizeControlFlow is invoked for every function, so the loops's
// bodies and conditionals that were extracted into functions will be handled
@ -66,27 +67,13 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library,
return Status::OK();
}
// Transformation that converts TensorFlow's graph control flow constructs into
// functional equivalents.
Status FunctionalizeControlFlow(Graph* graph,
FunctionLibraryDefinition* library) {
return FunctionalizeControlFlow(/*lookup_library=*/nullptr, graph, library);
}
Status FunctionalizeControlFlowForGraphDef(GraphDef* graph_def,
FunctionLibraryDefinition* library) {
return FunctionalizeControlFlowForGraphDef(/*lookup_library=*/nullptr,
graph_def, library);
}
Status FunctionalizeControlFlowForGraphDef(
const FunctionLibraryDefinition* lookup_library, GraphDef* graph_def,
FunctionLibraryDefinition* library) {
FunctionDefLibrary function_lib = graph_def->library();
Graph graph(OpRegistry::Global());
TF_RETURN_IF_ERROR(ConvertGraphDefToGraph({}, *graph_def, &graph));
TF_RETURN_IF_ERROR(FunctionalizeControlFlow(lookup_library, &graph, library));
TF_RETURN_IF_ERROR(FunctionalizeControlFlow(&graph, library));
graph.ToGraphDef(graph_def);
std::swap(*graph_def->mutable_library(), function_lib);
return Status::OK();

View File

@ -25,19 +25,12 @@ namespace tensorflow {
// Transformation that converts tf.while_loop() loops into functional While
// operators and tf.cond() conditionals into function If operators, suitable for
// XLA compilation. If lookup_library is provided, use it to make the library
// for control flow self-contained.
// XLA compilation.
Status FunctionalizeControlFlow(Graph* graph,
FunctionLibraryDefinition* library);
Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library,
Graph* graph,
FunctionLibraryDefinition* library);
Status FunctionalizeControlFlowForGraphDef(GraphDef* graph_def,
FunctionLibraryDefinition* library);
Status FunctionalizeControlFlowForGraphDef(
const FunctionLibraryDefinition* lookup_library, GraphDef* graph_def,
FunctionLibraryDefinition* library);
// This pass looks at the graph, and turns V1 control flow structure
// (Switch/Merge/etc.) into V2 control flow structure (If/While).

View File

@ -394,18 +394,16 @@ TEST(FunctionalizeControlFlow, NoinlineLoopBody) {
TF_ASSERT_OK(scope.ToGraph(&graph));
}
FunctionLibraryDefinition lookup_lib(graph.flib_def());
FunctionLibraryDefinition library(OpRegistry::Global(), {});
// Function increment_fn will be copied from lookup_lib to library.
FunctionLibraryDefinition library(graph.flib_def());
GraphDef optimized_graph_def;
graph.ToGraphDef(&optimized_graph_def);
*(optimized_graph_def.mutable_library()->add_function()) =
GetNoinlineFunctionDef();
TF_ASSERT_OK(FunctionalizeControlFlowForGraphDef(
&lookup_lib, &optimized_graph_def, &library));
TF_ASSERT_OK(FunctionalizeControlFlow(&lookup_lib, &graph, &library));
TF_ASSERT_OK(
FunctionalizeControlFlowForGraphDef(&optimized_graph_def, &library));
TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library));
GraphDef converted_graph_def;
graph.ToGraphDef(&converted_graph_def);
@ -470,14 +468,12 @@ TEST(FunctionalizeControlFlow, MissingFunctionDefInLibrary) {
TF_ASSERT_OK(scope.ToGraph(&graph));
}
FunctionLibraryDefinition lookup_lib(graph.flib_def());
FunctionLibraryDefinition library(OpRegistry::Global(), {});
FunctionLibraryDefinition library(graph.flib_def());
GraphDef graph_def;
graph.ToGraphDef(&graph_def);
graph_def.clear_library();
Status status =
FunctionalizeControlFlowForGraphDef(&lookup_lib, &graph_def, &library);
Status status = FunctionalizeControlFlowForGraphDef(&graph_def, &library);
EXPECT_EQ(tensorflow::error::NOT_FOUND, status.code());
}

View File

@ -211,58 +211,7 @@ Status BuildLoopBody(const Graph& graph, WhileLoopFrame* frame,
return Status::OK();
}
// Copy the FunctionDef of given function from lookup_library to library, if
// it can be found in lookup_library but is missing from library.
Status AddMissingFunctionByName(const string& function_name,
const FunctionLibraryDefinition* lookup_library,
FunctionLibraryDefinition* library) {
if (!library->Find(function_name) && lookup_library->Find(function_name)) {
return library->AddFunctionDef(*lookup_library->Find(function_name));
}
return Status::OK();
}
// Iterate over all functions that the given fdef refers to. Copy the missing
// FunctionDefs from lookup_library to library.
Status AddMissingFunctionDef(const FunctionDef& fdef,
const FunctionLibraryDefinition* lookup_library,
FunctionLibraryDefinition* library) {
TF_RET_CHECK(lookup_library);
for (const NodeDef& node : fdef.node_def()) {
if (library->Find(node.op())) {
continue;
}
// The function referred by 'SymbolicGradient' node is specified in its
// attribute 'f'.
if (node.op() == FunctionLibraryDefinition::kGradientOp) {
const AttrValue* attr =
AttrSlice(&node.attr()).Find(FunctionLibraryDefinition::kFuncAttr);
if (!attr) {
return errors::InvalidArgument("SymbolicGradient is missing attr: f");
}
const string& func_name = attr->func().name();
TF_RETURN_IF_ERROR(
AddMissingFunctionByName(func_name, lookup_library, library));
// Copy the user-defined gradient function if it exists.
const string grad_name = lookup_library->FindGradient(func_name);
if (!grad_name.empty() && library->FindGradient(func_name).empty()) {
TF_RETURN_IF_ERROR(
AddMissingFunctionByName(grad_name, lookup_library, library));
GradientDef grad_def;
grad_def.set_function_name(func_name);
grad_def.set_gradient_func(grad_name);
TF_RETURN_IF_ERROR(library->AddGradientDef(grad_def));
}
} else if (lookup_library->Find(node.op())) {
TF_RETURN_IF_ERROR(
library->AddFunctionDef(*lookup_library->Find(node.op())));
}
}
return Status::OK();
}
Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library,
Graph* graph, WhileLoopFrame* frame,
Status FunctionalizeLoop(Graph* graph, WhileLoopFrame* frame,
FunctionLibraryDefinition* library) {
VLOG(2) << "Frame " << frame->name << " before: "
<< DumpGraphToFile("functionalize_before", *graph, library);
@ -479,14 +428,6 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library,
TF_RETURN_IF_ERROR(library->AddFunctionDef(cond_fdef));
TF_RETURN_IF_ERROR(library->AddFunctionDef(body_fdef));
if (lookup_library) {
// Copy missing FunctionDefs from lookup_library to library to make library
// self-contained.
TF_RETURN_IF_ERROR(
AddMissingFunctionDef(cond_fdef, lookup_library, library));
TF_RETURN_IF_ERROR(
AddMissingFunctionDef(body_fdef, lookup_library, library));
}
// Builds a While operator.
NodeDef while_def;
@ -568,8 +509,7 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library,
}
} // namespace
Status FunctionalizeWhileLoop(const FunctionLibraryDefinition* lookup_library,
Graph* graph,
Status FunctionalizeWhileLoop(Graph* graph,
FunctionLibraryDefinition* library) {
// Note: BuildControlFlowInfo() requires that the graph's source node is
// connected to all source nodes in the graph. Many graphs violate this
@ -604,8 +544,7 @@ Status FunctionalizeWhileLoop(const FunctionLibraryDefinition* lookup_library,
continue;
}
TF_RETURN_IF_ERROR(
FunctionalizeLoop(lookup_library, graph, frame, library));
TF_RETURN_IF_ERROR(FunctionalizeLoop(graph, frame, library));
// If the parent has no remaining children, add it to the worklist.
--frame->parent->num_children;

View File

@ -24,8 +24,7 @@ namespace tensorflow {
// Transformation that converts tf.while_loop() loops into functional While
// operators, suitable for XLA compilation. If lookup_library is provided, use
// it to make the library for control flow self-contained.
Status FunctionalizeWhileLoop(const FunctionLibraryDefinition* lookup_library,
Graph* graph, FunctionLibraryDefinition* library);
Status FunctionalizeWhileLoop(Graph* graph, FunctionLibraryDefinition* library);
} // namespace tensorflow