Remove lookup_library from FunctionalizeControlFlow().
It was only used when FunctionalizeControlFlow() was called in XlaCompiler instead of a graph optimization pass, and the purpose was to look for FunctionDefs defined outside XlaCompiler. Now it's not needed any more. PiperOrigin-RevId: 270338766
This commit is contained in:
parent
316a882856
commit
a2a51fa318
@ -46,14 +46,15 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library,
|
// Transformation that converts TensorFlow's graph control flow constructs into
|
||||||
Graph* graph,
|
// functional equivalents.
|
||||||
|
Status FunctionalizeControlFlow(Graph* graph,
|
||||||
FunctionLibraryDefinition* library) {
|
FunctionLibraryDefinition* library) {
|
||||||
VLOG(2) << "FunctionalizeControlFlow (initial): "
|
VLOG(2) << "FunctionalizeControlFlow (initial): "
|
||||||
<< DumpGraphToFile("functionalize_initial", *graph, library);
|
<< DumpGraphToFile("functionalize_initial", *graph, library);
|
||||||
|
|
||||||
// Functionalize and remove while loops from graph.
|
// Functionalize and remove while loops from graph.
|
||||||
TF_RETURN_IF_ERROR(FunctionalizeWhileLoop(lookup_library, graph, library));
|
TF_RETURN_IF_ERROR(FunctionalizeWhileLoop(graph, library));
|
||||||
|
|
||||||
// FunctionalizeControlFlow is invoked for every function, so the loops's
|
// FunctionalizeControlFlow is invoked for every function, so the loops's
|
||||||
// bodies and conditionals that were extracted into functions will be handled
|
// bodies and conditionals that were extracted into functions will be handled
|
||||||
@ -66,27 +67,13 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Transformation that converts TensorFlow's graph control flow constructs into
|
|
||||||
// functional equivalents.
|
|
||||||
Status FunctionalizeControlFlow(Graph* graph,
|
|
||||||
FunctionLibraryDefinition* library) {
|
|
||||||
return FunctionalizeControlFlow(/*lookup_library=*/nullptr, graph, library);
|
|
||||||
}
|
|
||||||
|
|
||||||
Status FunctionalizeControlFlowForGraphDef(GraphDef* graph_def,
|
Status FunctionalizeControlFlowForGraphDef(GraphDef* graph_def,
|
||||||
FunctionLibraryDefinition* library) {
|
|
||||||
return FunctionalizeControlFlowForGraphDef(/*lookup_library=*/nullptr,
|
|
||||||
graph_def, library);
|
|
||||||
}
|
|
||||||
|
|
||||||
Status FunctionalizeControlFlowForGraphDef(
|
|
||||||
const FunctionLibraryDefinition* lookup_library, GraphDef* graph_def,
|
|
||||||
FunctionLibraryDefinition* library) {
|
FunctionLibraryDefinition* library) {
|
||||||
FunctionDefLibrary function_lib = graph_def->library();
|
FunctionDefLibrary function_lib = graph_def->library();
|
||||||
Graph graph(OpRegistry::Global());
|
Graph graph(OpRegistry::Global());
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(ConvertGraphDefToGraph({}, *graph_def, &graph));
|
TF_RETURN_IF_ERROR(ConvertGraphDefToGraph({}, *graph_def, &graph));
|
||||||
TF_RETURN_IF_ERROR(FunctionalizeControlFlow(lookup_library, &graph, library));
|
TF_RETURN_IF_ERROR(FunctionalizeControlFlow(&graph, library));
|
||||||
graph.ToGraphDef(graph_def);
|
graph.ToGraphDef(graph_def);
|
||||||
std::swap(*graph_def->mutable_library(), function_lib);
|
std::swap(*graph_def->mutable_library(), function_lib);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
@ -25,19 +25,12 @@ namespace tensorflow {
|
|||||||
|
|
||||||
// Transformation that converts tf.while_loop() loops into functional While
|
// Transformation that converts tf.while_loop() loops into functional While
|
||||||
// operators and tf.cond() conditionals into function If operators, suitable for
|
// operators and tf.cond() conditionals into function If operators, suitable for
|
||||||
// XLA compilation. If lookup_library is provided, use it to make the library
|
// XLA compilation.
|
||||||
// for control flow self-contained.
|
|
||||||
Status FunctionalizeControlFlow(Graph* graph,
|
Status FunctionalizeControlFlow(Graph* graph,
|
||||||
FunctionLibraryDefinition* library);
|
FunctionLibraryDefinition* library);
|
||||||
Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library,
|
|
||||||
Graph* graph,
|
|
||||||
FunctionLibraryDefinition* library);
|
|
||||||
|
|
||||||
Status FunctionalizeControlFlowForGraphDef(GraphDef* graph_def,
|
Status FunctionalizeControlFlowForGraphDef(GraphDef* graph_def,
|
||||||
FunctionLibraryDefinition* library);
|
FunctionLibraryDefinition* library);
|
||||||
Status FunctionalizeControlFlowForGraphDef(
|
|
||||||
const FunctionLibraryDefinition* lookup_library, GraphDef* graph_def,
|
|
||||||
FunctionLibraryDefinition* library);
|
|
||||||
|
|
||||||
// This pass looks at the graph, and turns V1 control flow structure
|
// This pass looks at the graph, and turns V1 control flow structure
|
||||||
// (Switch/Merge/etc.) into V2 control flow structure (If/While).
|
// (Switch/Merge/etc.) into V2 control flow structure (If/While).
|
||||||
|
@ -394,18 +394,16 @@ TEST(FunctionalizeControlFlow, NoinlineLoopBody) {
|
|||||||
TF_ASSERT_OK(scope.ToGraph(&graph));
|
TF_ASSERT_OK(scope.ToGraph(&graph));
|
||||||
}
|
}
|
||||||
|
|
||||||
FunctionLibraryDefinition lookup_lib(graph.flib_def());
|
FunctionLibraryDefinition library(graph.flib_def());
|
||||||
FunctionLibraryDefinition library(OpRegistry::Global(), {});
|
|
||||||
// Function increment_fn will be copied from lookup_lib to library.
|
|
||||||
GraphDef optimized_graph_def;
|
GraphDef optimized_graph_def;
|
||||||
graph.ToGraphDef(&optimized_graph_def);
|
graph.ToGraphDef(&optimized_graph_def);
|
||||||
|
|
||||||
*(optimized_graph_def.mutable_library()->add_function()) =
|
*(optimized_graph_def.mutable_library()->add_function()) =
|
||||||
GetNoinlineFunctionDef();
|
GetNoinlineFunctionDef();
|
||||||
|
|
||||||
TF_ASSERT_OK(FunctionalizeControlFlowForGraphDef(
|
TF_ASSERT_OK(
|
||||||
&lookup_lib, &optimized_graph_def, &library));
|
FunctionalizeControlFlowForGraphDef(&optimized_graph_def, &library));
|
||||||
TF_ASSERT_OK(FunctionalizeControlFlow(&lookup_lib, &graph, &library));
|
TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library));
|
||||||
GraphDef converted_graph_def;
|
GraphDef converted_graph_def;
|
||||||
graph.ToGraphDef(&converted_graph_def);
|
graph.ToGraphDef(&converted_graph_def);
|
||||||
|
|
||||||
@ -470,14 +468,12 @@ TEST(FunctionalizeControlFlow, MissingFunctionDefInLibrary) {
|
|||||||
TF_ASSERT_OK(scope.ToGraph(&graph));
|
TF_ASSERT_OK(scope.ToGraph(&graph));
|
||||||
}
|
}
|
||||||
|
|
||||||
FunctionLibraryDefinition lookup_lib(graph.flib_def());
|
FunctionLibraryDefinition library(graph.flib_def());
|
||||||
FunctionLibraryDefinition library(OpRegistry::Global(), {});
|
|
||||||
GraphDef graph_def;
|
GraphDef graph_def;
|
||||||
graph.ToGraphDef(&graph_def);
|
graph.ToGraphDef(&graph_def);
|
||||||
graph_def.clear_library();
|
graph_def.clear_library();
|
||||||
|
|
||||||
Status status =
|
Status status = FunctionalizeControlFlowForGraphDef(&graph_def, &library);
|
||||||
FunctionalizeControlFlowForGraphDef(&lookup_lib, &graph_def, &library);
|
|
||||||
EXPECT_EQ(tensorflow::error::NOT_FOUND, status.code());
|
EXPECT_EQ(tensorflow::error::NOT_FOUND, status.code());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -211,58 +211,7 @@ Status BuildLoopBody(const Graph& graph, WhileLoopFrame* frame,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Copy the FunctionDef of given function from lookup_library to library, if
|
Status FunctionalizeLoop(Graph* graph, WhileLoopFrame* frame,
|
||||||
// it can be found in lookup_library but is missing from library.
|
|
||||||
Status AddMissingFunctionByName(const string& function_name,
|
|
||||||
const FunctionLibraryDefinition* lookup_library,
|
|
||||||
FunctionLibraryDefinition* library) {
|
|
||||||
if (!library->Find(function_name) && lookup_library->Find(function_name)) {
|
|
||||||
return library->AddFunctionDef(*lookup_library->Find(function_name));
|
|
||||||
}
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Iterate over all functions that the given fdef refers to. Copy the missing
|
|
||||||
// FunctionDefs from lookup_library to library.
|
|
||||||
Status AddMissingFunctionDef(const FunctionDef& fdef,
|
|
||||||
const FunctionLibraryDefinition* lookup_library,
|
|
||||||
FunctionLibraryDefinition* library) {
|
|
||||||
TF_RET_CHECK(lookup_library);
|
|
||||||
for (const NodeDef& node : fdef.node_def()) {
|
|
||||||
if (library->Find(node.op())) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
// The function referred by 'SymbolicGradient' node is specified in its
|
|
||||||
// attribute 'f'.
|
|
||||||
if (node.op() == FunctionLibraryDefinition::kGradientOp) {
|
|
||||||
const AttrValue* attr =
|
|
||||||
AttrSlice(&node.attr()).Find(FunctionLibraryDefinition::kFuncAttr);
|
|
||||||
if (!attr) {
|
|
||||||
return errors::InvalidArgument("SymbolicGradient is missing attr: f");
|
|
||||||
}
|
|
||||||
const string& func_name = attr->func().name();
|
|
||||||
TF_RETURN_IF_ERROR(
|
|
||||||
AddMissingFunctionByName(func_name, lookup_library, library));
|
|
||||||
// Copy the user-defined gradient function if it exists.
|
|
||||||
const string grad_name = lookup_library->FindGradient(func_name);
|
|
||||||
if (!grad_name.empty() && library->FindGradient(func_name).empty()) {
|
|
||||||
TF_RETURN_IF_ERROR(
|
|
||||||
AddMissingFunctionByName(grad_name, lookup_library, library));
|
|
||||||
GradientDef grad_def;
|
|
||||||
grad_def.set_function_name(func_name);
|
|
||||||
grad_def.set_gradient_func(grad_name);
|
|
||||||
TF_RETURN_IF_ERROR(library->AddGradientDef(grad_def));
|
|
||||||
}
|
|
||||||
} else if (lookup_library->Find(node.op())) {
|
|
||||||
TF_RETURN_IF_ERROR(
|
|
||||||
library->AddFunctionDef(*lookup_library->Find(node.op())));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library,
|
|
||||||
Graph* graph, WhileLoopFrame* frame,
|
|
||||||
FunctionLibraryDefinition* library) {
|
FunctionLibraryDefinition* library) {
|
||||||
VLOG(2) << "Frame " << frame->name << " before: "
|
VLOG(2) << "Frame " << frame->name << " before: "
|
||||||
<< DumpGraphToFile("functionalize_before", *graph, library);
|
<< DumpGraphToFile("functionalize_before", *graph, library);
|
||||||
@ -479,14 +428,6 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library,
|
|||||||
|
|
||||||
TF_RETURN_IF_ERROR(library->AddFunctionDef(cond_fdef));
|
TF_RETURN_IF_ERROR(library->AddFunctionDef(cond_fdef));
|
||||||
TF_RETURN_IF_ERROR(library->AddFunctionDef(body_fdef));
|
TF_RETURN_IF_ERROR(library->AddFunctionDef(body_fdef));
|
||||||
if (lookup_library) {
|
|
||||||
// Copy missing FunctionDefs from lookup_library to library to make library
|
|
||||||
// self-contained.
|
|
||||||
TF_RETURN_IF_ERROR(
|
|
||||||
AddMissingFunctionDef(cond_fdef, lookup_library, library));
|
|
||||||
TF_RETURN_IF_ERROR(
|
|
||||||
AddMissingFunctionDef(body_fdef, lookup_library, library));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Builds a While operator.
|
// Builds a While operator.
|
||||||
NodeDef while_def;
|
NodeDef while_def;
|
||||||
@ -568,8 +509,7 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library,
|
|||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
Status FunctionalizeWhileLoop(const FunctionLibraryDefinition* lookup_library,
|
Status FunctionalizeWhileLoop(Graph* graph,
|
||||||
Graph* graph,
|
|
||||||
FunctionLibraryDefinition* library) {
|
FunctionLibraryDefinition* library) {
|
||||||
// Note: BuildControlFlowInfo() requires that the graph's source node is
|
// Note: BuildControlFlowInfo() requires that the graph's source node is
|
||||||
// connected to all source nodes in the graph. Many graphs violate this
|
// connected to all source nodes in the graph. Many graphs violate this
|
||||||
@ -604,8 +544,7 @@ Status FunctionalizeWhileLoop(const FunctionLibraryDefinition* lookup_library,
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(FunctionalizeLoop(graph, frame, library));
|
||||||
FunctionalizeLoop(lookup_library, graph, frame, library));
|
|
||||||
|
|
||||||
// If the parent has no remaining children, add it to the worklist.
|
// If the parent has no remaining children, add it to the worklist.
|
||||||
--frame->parent->num_children;
|
--frame->parent->num_children;
|
||||||
|
@ -24,8 +24,7 @@ namespace tensorflow {
|
|||||||
// Transformation that converts tf.while_loop() loops into functional While
|
// Transformation that converts tf.while_loop() loops into functional While
|
||||||
// operators, suitable for XLA compilation. If lookup_library is provided, use
|
// operators, suitable for XLA compilation. If lookup_library is provided, use
|
||||||
// it to make the library for control flow self-contained.
|
// it to make the library for control flow self-contained.
|
||||||
Status FunctionalizeWhileLoop(const FunctionLibraryDefinition* lookup_library,
|
Status FunctionalizeWhileLoop(Graph* graph, FunctionLibraryDefinition* library);
|
||||||
Graph* graph, FunctionLibraryDefinition* library);
|
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user