diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 33b8c9ee080..41c8c1e9e68 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -46,14 +46,15 @@ limitations under the License. namespace tensorflow { -Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library, - Graph* graph, +// Transformation that converts TensorFlow's graph control flow constructs into +// functional equivalents. +Status FunctionalizeControlFlow(Graph* graph, FunctionLibraryDefinition* library) { VLOG(2) << "FunctionalizeControlFlow (initial): " << DumpGraphToFile("functionalize_initial", *graph, library); // Functionalize and remove while loops from graph. - TF_RETURN_IF_ERROR(FunctionalizeWhileLoop(lookup_library, graph, library)); + TF_RETURN_IF_ERROR(FunctionalizeWhileLoop(graph, library)); // FunctionalizeControlFlow is invoked for every function, so the loops's // bodies and conditionals that were extracted into functions will be handled @@ -66,27 +67,13 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library, return Status::OK(); } -// Transformation that converts TensorFlow's graph control flow constructs into -// functional equivalents. -Status FunctionalizeControlFlow(Graph* graph, - FunctionLibraryDefinition* library) { - return FunctionalizeControlFlow(/*lookup_library=*/nullptr, graph, library); -} - Status FunctionalizeControlFlowForGraphDef(GraphDef* graph_def, - FunctionLibraryDefinition* library) { - return FunctionalizeControlFlowForGraphDef(/*lookup_library=*/nullptr, - graph_def, library); -} - -Status FunctionalizeControlFlowForGraphDef( - const FunctionLibraryDefinition* lookup_library, GraphDef* graph_def, FunctionLibraryDefinition* library) { FunctionDefLibrary function_lib = graph_def->library(); Graph graph(OpRegistry::Global()); TF_RETURN_IF_ERROR(ConvertGraphDefToGraph({}, *graph_def, &graph)); - TF_RETURN_IF_ERROR(FunctionalizeControlFlow(lookup_library, &graph, library)); + TF_RETURN_IF_ERROR(FunctionalizeControlFlow(&graph, library)); graph.ToGraphDef(graph_def); std::swap(*graph_def->mutable_library(), function_lib); return Status::OK(); diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.h b/tensorflow/compiler/tf2xla/functionalize_control_flow.h index 22cd422599c..fb35d1b4198 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.h +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.h @@ -25,19 +25,12 @@ namespace tensorflow { // Transformation that converts tf.while_loop() loops into functional While // operators and tf.cond() conditionals into function If operators, suitable for -// XLA compilation. If lookup_library is provided, use it to make the library -// for control flow self-contained. +// XLA compilation. Status FunctionalizeControlFlow(Graph* graph, FunctionLibraryDefinition* library); -Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library, - Graph* graph, - FunctionLibraryDefinition* library); Status FunctionalizeControlFlowForGraphDef(GraphDef* graph_def, FunctionLibraryDefinition* library); -Status FunctionalizeControlFlowForGraphDef( - const FunctionLibraryDefinition* lookup_library, GraphDef* graph_def, - FunctionLibraryDefinition* library); // This pass looks at the graph, and turns V1 control flow structure // (Switch/Merge/etc.) into V2 control flow structure (If/While). diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc index 6c6b6cd1a77..a8e5aa87b3e 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc @@ -394,18 +394,16 @@ TEST(FunctionalizeControlFlow, NoinlineLoopBody) { TF_ASSERT_OK(scope.ToGraph(&graph)); } - FunctionLibraryDefinition lookup_lib(graph.flib_def()); - FunctionLibraryDefinition library(OpRegistry::Global(), {}); - // Function increment_fn will be copied from lookup_lib to library. + FunctionLibraryDefinition library(graph.flib_def()); GraphDef optimized_graph_def; graph.ToGraphDef(&optimized_graph_def); *(optimized_graph_def.mutable_library()->add_function()) = GetNoinlineFunctionDef(); - TF_ASSERT_OK(FunctionalizeControlFlowForGraphDef( - &lookup_lib, &optimized_graph_def, &library)); - TF_ASSERT_OK(FunctionalizeControlFlow(&lookup_lib, &graph, &library)); + TF_ASSERT_OK( + FunctionalizeControlFlowForGraphDef(&optimized_graph_def, &library)); + TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library)); GraphDef converted_graph_def; graph.ToGraphDef(&converted_graph_def); @@ -470,14 +468,12 @@ TEST(FunctionalizeControlFlow, MissingFunctionDefInLibrary) { TF_ASSERT_OK(scope.ToGraph(&graph)); } - FunctionLibraryDefinition lookup_lib(graph.flib_def()); - FunctionLibraryDefinition library(OpRegistry::Global(), {}); + FunctionLibraryDefinition library(graph.flib_def()); GraphDef graph_def; graph.ToGraphDef(&graph_def); graph_def.clear_library(); - Status status = - FunctionalizeControlFlowForGraphDef(&lookup_lib, &graph_def, &library); + Status status = FunctionalizeControlFlowForGraphDef(&graph_def, &library); EXPECT_EQ(tensorflow::error::NOT_FOUND, status.code()); } diff --git a/tensorflow/compiler/tf2xla/functionalize_while.cc b/tensorflow/compiler/tf2xla/functionalize_while.cc index 74790f9ee4d..db61a612a81 100644 --- a/tensorflow/compiler/tf2xla/functionalize_while.cc +++ b/tensorflow/compiler/tf2xla/functionalize_while.cc @@ -211,58 +211,7 @@ Status BuildLoopBody(const Graph& graph, WhileLoopFrame* frame, return Status::OK(); } -// Copy the FunctionDef of given function from lookup_library to library, if -// it can be found in lookup_library but is missing from library. -Status AddMissingFunctionByName(const string& function_name, - const FunctionLibraryDefinition* lookup_library, - FunctionLibraryDefinition* library) { - if (!library->Find(function_name) && lookup_library->Find(function_name)) { - return library->AddFunctionDef(*lookup_library->Find(function_name)); - } - return Status::OK(); -} - -// Iterate over all functions that the given fdef refers to. Copy the missing -// FunctionDefs from lookup_library to library. -Status AddMissingFunctionDef(const FunctionDef& fdef, - const FunctionLibraryDefinition* lookup_library, - FunctionLibraryDefinition* library) { - TF_RET_CHECK(lookup_library); - for (const NodeDef& node : fdef.node_def()) { - if (library->Find(node.op())) { - continue; - } - // The function referred by 'SymbolicGradient' node is specified in its - // attribute 'f'. - if (node.op() == FunctionLibraryDefinition::kGradientOp) { - const AttrValue* attr = - AttrSlice(&node.attr()).Find(FunctionLibraryDefinition::kFuncAttr); - if (!attr) { - return errors::InvalidArgument("SymbolicGradient is missing attr: f"); - } - const string& func_name = attr->func().name(); - TF_RETURN_IF_ERROR( - AddMissingFunctionByName(func_name, lookup_library, library)); - // Copy the user-defined gradient function if it exists. - const string grad_name = lookup_library->FindGradient(func_name); - if (!grad_name.empty() && library->FindGradient(func_name).empty()) { - TF_RETURN_IF_ERROR( - AddMissingFunctionByName(grad_name, lookup_library, library)); - GradientDef grad_def; - grad_def.set_function_name(func_name); - grad_def.set_gradient_func(grad_name); - TF_RETURN_IF_ERROR(library->AddGradientDef(grad_def)); - } - } else if (lookup_library->Find(node.op())) { - TF_RETURN_IF_ERROR( - library->AddFunctionDef(*lookup_library->Find(node.op()))); - } - } - return Status::OK(); -} - -Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, - Graph* graph, WhileLoopFrame* frame, +Status FunctionalizeLoop(Graph* graph, WhileLoopFrame* frame, FunctionLibraryDefinition* library) { VLOG(2) << "Frame " << frame->name << " before: " << DumpGraphToFile("functionalize_before", *graph, library); @@ -479,14 +428,6 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, TF_RETURN_IF_ERROR(library->AddFunctionDef(cond_fdef)); TF_RETURN_IF_ERROR(library->AddFunctionDef(body_fdef)); - if (lookup_library) { - // Copy missing FunctionDefs from lookup_library to library to make library - // self-contained. - TF_RETURN_IF_ERROR( - AddMissingFunctionDef(cond_fdef, lookup_library, library)); - TF_RETURN_IF_ERROR( - AddMissingFunctionDef(body_fdef, lookup_library, library)); - } // Builds a While operator. NodeDef while_def; @@ -568,8 +509,7 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, } } // namespace -Status FunctionalizeWhileLoop(const FunctionLibraryDefinition* lookup_library, - Graph* graph, +Status FunctionalizeWhileLoop(Graph* graph, FunctionLibraryDefinition* library) { // Note: BuildControlFlowInfo() requires that the graph's source node is // connected to all source nodes in the graph. Many graphs violate this @@ -604,8 +544,7 @@ Status FunctionalizeWhileLoop(const FunctionLibraryDefinition* lookup_library, continue; } - TF_RETURN_IF_ERROR( - FunctionalizeLoop(lookup_library, graph, frame, library)); + TF_RETURN_IF_ERROR(FunctionalizeLoop(graph, frame, library)); // If the parent has no remaining children, add it to the worklist. --frame->parent->num_children; diff --git a/tensorflow/compiler/tf2xla/functionalize_while.h b/tensorflow/compiler/tf2xla/functionalize_while.h index a708c6e4ec4..207b29b8498 100644 --- a/tensorflow/compiler/tf2xla/functionalize_while.h +++ b/tensorflow/compiler/tf2xla/functionalize_while.h @@ -24,8 +24,7 @@ namespace tensorflow { // Transformation that converts tf.while_loop() loops into functional While // operators, suitable for XLA compilation. If lookup_library is provided, use // it to make the library for control flow self-contained. -Status FunctionalizeWhileLoop(const FunctionLibraryDefinition* lookup_library, - Graph* graph, FunctionLibraryDefinition* library); +Status FunctionalizeWhileLoop(Graph* graph, FunctionLibraryDefinition* library); } // namespace tensorflow