Remove lookup_library from FunctionalizeControlFlow().
It was only used when FunctionalizeControlFlow() was called in XlaCompiler instead of a graph optimization pass, and the purpose was to look for FunctionDefs defined outside XlaCompiler. Now it's not needed any more. PiperOrigin-RevId: 270338766
This commit is contained in:
parent
316a882856
commit
a2a51fa318
@ -46,14 +46,15 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library,
|
||||
Graph* graph,
|
||||
// Transformation that converts TensorFlow's graph control flow constructs into
|
||||
// functional equivalents.
|
||||
Status FunctionalizeControlFlow(Graph* graph,
|
||||
FunctionLibraryDefinition* library) {
|
||||
VLOG(2) << "FunctionalizeControlFlow (initial): "
|
||||
<< DumpGraphToFile("functionalize_initial", *graph, library);
|
||||
|
||||
// Functionalize and remove while loops from graph.
|
||||
TF_RETURN_IF_ERROR(FunctionalizeWhileLoop(lookup_library, graph, library));
|
||||
TF_RETURN_IF_ERROR(FunctionalizeWhileLoop(graph, library));
|
||||
|
||||
// FunctionalizeControlFlow is invoked for every function, so the loops's
|
||||
// bodies and conditionals that were extracted into functions will be handled
|
||||
@ -66,27 +67,13 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Transformation that converts TensorFlow's graph control flow constructs into
|
||||
// functional equivalents.
|
||||
Status FunctionalizeControlFlow(Graph* graph,
|
||||
FunctionLibraryDefinition* library) {
|
||||
return FunctionalizeControlFlow(/*lookup_library=*/nullptr, graph, library);
|
||||
}
|
||||
|
||||
Status FunctionalizeControlFlowForGraphDef(GraphDef* graph_def,
|
||||
FunctionLibraryDefinition* library) {
|
||||
return FunctionalizeControlFlowForGraphDef(/*lookup_library=*/nullptr,
|
||||
graph_def, library);
|
||||
}
|
||||
|
||||
Status FunctionalizeControlFlowForGraphDef(
|
||||
const FunctionLibraryDefinition* lookup_library, GraphDef* graph_def,
|
||||
FunctionLibraryDefinition* library) {
|
||||
FunctionDefLibrary function_lib = graph_def->library();
|
||||
Graph graph(OpRegistry::Global());
|
||||
|
||||
TF_RETURN_IF_ERROR(ConvertGraphDefToGraph({}, *graph_def, &graph));
|
||||
TF_RETURN_IF_ERROR(FunctionalizeControlFlow(lookup_library, &graph, library));
|
||||
TF_RETURN_IF_ERROR(FunctionalizeControlFlow(&graph, library));
|
||||
graph.ToGraphDef(graph_def);
|
||||
std::swap(*graph_def->mutable_library(), function_lib);
|
||||
return Status::OK();
|
||||
|
@ -25,19 +25,12 @@ namespace tensorflow {
|
||||
|
||||
// Transformation that converts tf.while_loop() loops into functional While
|
||||
// operators and tf.cond() conditionals into function If operators, suitable for
|
||||
// XLA compilation. If lookup_library is provided, use it to make the library
|
||||
// for control flow self-contained.
|
||||
// XLA compilation.
|
||||
Status FunctionalizeControlFlow(Graph* graph,
|
||||
FunctionLibraryDefinition* library);
|
||||
Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library,
|
||||
Graph* graph,
|
||||
FunctionLibraryDefinition* library);
|
||||
|
||||
Status FunctionalizeControlFlowForGraphDef(GraphDef* graph_def,
|
||||
FunctionLibraryDefinition* library);
|
||||
Status FunctionalizeControlFlowForGraphDef(
|
||||
const FunctionLibraryDefinition* lookup_library, GraphDef* graph_def,
|
||||
FunctionLibraryDefinition* library);
|
||||
|
||||
// This pass looks at the graph, and turns V1 control flow structure
|
||||
// (Switch/Merge/etc.) into V2 control flow structure (If/While).
|
||||
|
@ -394,18 +394,16 @@ TEST(FunctionalizeControlFlow, NoinlineLoopBody) {
|
||||
TF_ASSERT_OK(scope.ToGraph(&graph));
|
||||
}
|
||||
|
||||
FunctionLibraryDefinition lookup_lib(graph.flib_def());
|
||||
FunctionLibraryDefinition library(OpRegistry::Global(), {});
|
||||
// Function increment_fn will be copied from lookup_lib to library.
|
||||
FunctionLibraryDefinition library(graph.flib_def());
|
||||
GraphDef optimized_graph_def;
|
||||
graph.ToGraphDef(&optimized_graph_def);
|
||||
|
||||
*(optimized_graph_def.mutable_library()->add_function()) =
|
||||
GetNoinlineFunctionDef();
|
||||
|
||||
TF_ASSERT_OK(FunctionalizeControlFlowForGraphDef(
|
||||
&lookup_lib, &optimized_graph_def, &library));
|
||||
TF_ASSERT_OK(FunctionalizeControlFlow(&lookup_lib, &graph, &library));
|
||||
TF_ASSERT_OK(
|
||||
FunctionalizeControlFlowForGraphDef(&optimized_graph_def, &library));
|
||||
TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library));
|
||||
GraphDef converted_graph_def;
|
||||
graph.ToGraphDef(&converted_graph_def);
|
||||
|
||||
@ -470,14 +468,12 @@ TEST(FunctionalizeControlFlow, MissingFunctionDefInLibrary) {
|
||||
TF_ASSERT_OK(scope.ToGraph(&graph));
|
||||
}
|
||||
|
||||
FunctionLibraryDefinition lookup_lib(graph.flib_def());
|
||||
FunctionLibraryDefinition library(OpRegistry::Global(), {});
|
||||
FunctionLibraryDefinition library(graph.flib_def());
|
||||
GraphDef graph_def;
|
||||
graph.ToGraphDef(&graph_def);
|
||||
graph_def.clear_library();
|
||||
|
||||
Status status =
|
||||
FunctionalizeControlFlowForGraphDef(&lookup_lib, &graph_def, &library);
|
||||
Status status = FunctionalizeControlFlowForGraphDef(&graph_def, &library);
|
||||
EXPECT_EQ(tensorflow::error::NOT_FOUND, status.code());
|
||||
}
|
||||
|
||||
|
@ -211,58 +211,7 @@ Status BuildLoopBody(const Graph& graph, WhileLoopFrame* frame,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Copy the FunctionDef of given function from lookup_library to library, if
|
||||
// it can be found in lookup_library but is missing from library.
|
||||
Status AddMissingFunctionByName(const string& function_name,
|
||||
const FunctionLibraryDefinition* lookup_library,
|
||||
FunctionLibraryDefinition* library) {
|
||||
if (!library->Find(function_name) && lookup_library->Find(function_name)) {
|
||||
return library->AddFunctionDef(*lookup_library->Find(function_name));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Iterate over all functions that the given fdef refers to. Copy the missing
|
||||
// FunctionDefs from lookup_library to library.
|
||||
Status AddMissingFunctionDef(const FunctionDef& fdef,
|
||||
const FunctionLibraryDefinition* lookup_library,
|
||||
FunctionLibraryDefinition* library) {
|
||||
TF_RET_CHECK(lookup_library);
|
||||
for (const NodeDef& node : fdef.node_def()) {
|
||||
if (library->Find(node.op())) {
|
||||
continue;
|
||||
}
|
||||
// The function referred by 'SymbolicGradient' node is specified in its
|
||||
// attribute 'f'.
|
||||
if (node.op() == FunctionLibraryDefinition::kGradientOp) {
|
||||
const AttrValue* attr =
|
||||
AttrSlice(&node.attr()).Find(FunctionLibraryDefinition::kFuncAttr);
|
||||
if (!attr) {
|
||||
return errors::InvalidArgument("SymbolicGradient is missing attr: f");
|
||||
}
|
||||
const string& func_name = attr->func().name();
|
||||
TF_RETURN_IF_ERROR(
|
||||
AddMissingFunctionByName(func_name, lookup_library, library));
|
||||
// Copy the user-defined gradient function if it exists.
|
||||
const string grad_name = lookup_library->FindGradient(func_name);
|
||||
if (!grad_name.empty() && library->FindGradient(func_name).empty()) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
AddMissingFunctionByName(grad_name, lookup_library, library));
|
||||
GradientDef grad_def;
|
||||
grad_def.set_function_name(func_name);
|
||||
grad_def.set_gradient_func(grad_name);
|
||||
TF_RETURN_IF_ERROR(library->AddGradientDef(grad_def));
|
||||
}
|
||||
} else if (lookup_library->Find(node.op())) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
library->AddFunctionDef(*lookup_library->Find(node.op())));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library,
|
||||
Graph* graph, WhileLoopFrame* frame,
|
||||
Status FunctionalizeLoop(Graph* graph, WhileLoopFrame* frame,
|
||||
FunctionLibraryDefinition* library) {
|
||||
VLOG(2) << "Frame " << frame->name << " before: "
|
||||
<< DumpGraphToFile("functionalize_before", *graph, library);
|
||||
@ -479,14 +428,6 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library,
|
||||
|
||||
TF_RETURN_IF_ERROR(library->AddFunctionDef(cond_fdef));
|
||||
TF_RETURN_IF_ERROR(library->AddFunctionDef(body_fdef));
|
||||
if (lookup_library) {
|
||||
// Copy missing FunctionDefs from lookup_library to library to make library
|
||||
// self-contained.
|
||||
TF_RETURN_IF_ERROR(
|
||||
AddMissingFunctionDef(cond_fdef, lookup_library, library));
|
||||
TF_RETURN_IF_ERROR(
|
||||
AddMissingFunctionDef(body_fdef, lookup_library, library));
|
||||
}
|
||||
|
||||
// Builds a While operator.
|
||||
NodeDef while_def;
|
||||
@ -568,8 +509,7 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library,
|
||||
}
|
||||
} // namespace
|
||||
|
||||
Status FunctionalizeWhileLoop(const FunctionLibraryDefinition* lookup_library,
|
||||
Graph* graph,
|
||||
Status FunctionalizeWhileLoop(Graph* graph,
|
||||
FunctionLibraryDefinition* library) {
|
||||
// Note: BuildControlFlowInfo() requires that the graph's source node is
|
||||
// connected to all source nodes in the graph. Many graphs violate this
|
||||
@ -604,8 +544,7 @@ Status FunctionalizeWhileLoop(const FunctionLibraryDefinition* lookup_library,
|
||||
continue;
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
FunctionalizeLoop(lookup_library, graph, frame, library));
|
||||
TF_RETURN_IF_ERROR(FunctionalizeLoop(graph, frame, library));
|
||||
|
||||
// If the parent has no remaining children, add it to the worklist.
|
||||
--frame->parent->num_children;
|
||||
|
@ -24,8 +24,7 @@ namespace tensorflow {
|
||||
// Transformation that converts tf.while_loop() loops into functional While
|
||||
// operators, suitable for XLA compilation. If lookup_library is provided, use
|
||||
// it to make the library for control flow self-contained.
|
||||
Status FunctionalizeWhileLoop(const FunctionLibraryDefinition* lookup_library,
|
||||
Graph* graph, FunctionLibraryDefinition* library);
|
||||
Status FunctionalizeWhileLoop(Graph* graph, FunctionLibraryDefinition* library);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user