diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index 94ddf76736e..51f63741da4 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -177,7 +177,8 @@ Status UpgradeLegacyGraph(Graph* graph, FunctionLibraryDefinition* flib_def, restrict_functionalization_to_tpu_nodes ? [](const Node* n) { return n->attrs().Find(kTpuReplicateAttr); } : NodeFilter{}; - return FunctionalizeControlFlow(graph, flib_def, node_filter); + return FunctionalizeControlFlow(graph, flib_def, node_filter, + /*include_functions=*/true); } // Stateful helper class to import a TensorFlow model into an MLIR Module. diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 10b26f9801c..596fa8e8e38 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -46,12 +46,254 @@ limitations under the License. namespace tensorflow { +// Helper functions for functionalizing control flow in functions. + +// Maps function name to +// - new function name, if the function body was functionalized +// - absl::nullopt, if not +using FuncMap = std::map>; +using FuncMapIter = std::map>::const_iterator; + +// Returns whether function has been processed before. +bool FunctionHasBeenProcessed(FuncMapIter func_iter, const FuncMap* func_map) { + return func_iter != func_map->end(); +} + +// Returns whether function has been modified (i.e., functionalized) before. +bool FunctionHasBeenModified(FuncMapIter func_iter) { + return func_iter->second.has_value(); +} + +// Returns a name for the new functionalized version of a function. +string GetNewFunctionName( + const string& func_name, Node* n, + AssociatedFunctionInfo::AssociatedFunctionType func_type, + FunctionLibraryDefinition* fld) { + // For SymbolicGradient, `func_name` is always "SymbolicGradient" which + // is not very informative. Use node name instead. + return ( + func_type == + AssociatedFunctionInfo::AssociatedFunctionType::kSymbolicGradient + ? fld->UniqueFunctionName(absl::StrCat(n->name(), "_f15n_")) + : fld->UniqueFunctionName(absl::StrCat(func_name, "_f15n_"))); +} + +// Returns name to which a modified function has been mapped. +const string& GetMappedFunctionName(FuncMapIter func_iter) { + DCHECK(func_iter->second.has_value()); + return func_iter->second.value(); +} + +// Updates `func_map` with function given by `canonicalized_name`. +void UpdateFunctionMap(FuncMap* func_map, const string& canonicalized_name, + const string& new_func_name, bool function_modified) { + // If function was modified store its new name, otherwise add empty entry to + // record that function has been processed and does not need to be rewritten. + (*func_map)[canonicalized_name] = + function_modified ? absl::make_optional(new_func_name) : absl::nullopt; +} + +// Adds new function def to graph's function library if necessary. +Status AddFunctionDefToGraphLibrary( + const string& func_name, const AssociatedFunctionInfo& associated_function, + Graph* graph, FunctionLibraryDefinition* fld) { + const OpRegistrationData* op_reg_data; + // We have to be careful with adding the function def since there are three + // different `OpRegistryInterface`s involved here: + // `fld`, `graph->flib_def()` and `graph->flib_def().default_registry()`. + // We have already added the function def to `fld` before calling this + // function but for the subsequent `RewriteAssociatedFunction` call we need + // the function def to be in one of the other two registries, otherwise + // `RewriteAssociatedFunction` will fail for the `kFunctionCallNode` case + // because it cannot find the associated function def. + // On the other hand, we should not add the function def if it is already + // contained in one of the last two registries, this would lead to errors when + // the function def is already in one registry and we try to add it to the + // other one (if we try to add it to the same it's fine). This can happen in + // cases where one of the last two registries is identical to `fld` (which we + // already updated). + // Therefore, before adding the function def we have to check if it's already + // contained in either `graph->flib_def()` or + // `graph->flib_def().default_registry()` which is done in the following line + // (we have to use `LookUp` instead of `Contains` or `Find` because the latter + // both don't check the default registry). + if (graph->flib_def().LookUp(func_name, &op_reg_data).ok()) + return Status::OK(); + + const FunctionDef* new_fdef = fld->Find(func_name); + DCHECK(new_fdef != nullptr); + FunctionDefLibrary fdef_lib; + *(fdef_lib.add_function()) = *new_fdef; + return graph->AddFunctionLibrary(fdef_lib); +} + +// Functionalizes function given by `func_name`. Update `func_map` accordingly. +Status FunctionalizeControlFlowForFunction( + const string& func_name, const string& new_func_name, + const protobuf::Map& attrs, + FunctionLibraryDefinition* fld, FunctionLibraryRuntime* flr, + FuncMap* func_map, bool* function_modified, + const NodeFilter& node_filter = {}); + +// Functionalizes all functions that are (directly or indirectly) associated to +// any node in `graph`. Adds processed functions to `func_map`. +Status FunctionalizeControlFlowForNodeAssociatedFunctions( + FuncMap* func_map, Graph* graph, FunctionLibraryDefinition* fld, + FunctionLibraryRuntime* flr, bool* any_function_modified, + const NodeFilter& node_filter) { + std::vector>> + nodes_to_associated_functions; + for (auto* n : graph->nodes()) { + auto associated_functions = GetAssociatedFunctions(*n, fld); + if (!associated_functions.empty()) { + nodes_to_associated_functions.push_back({n, associated_functions}); + } + } + for (const auto& pair : nodes_to_associated_functions) { + Node* n = pair.first; + auto associated_functions = pair.second; + for (auto& associated_function : associated_functions) { + // Note that if `n` is a function call node, then potential calls of + // `RewriteAssociatedFunction` below might delete `n` and create a new + // node instead, making `n` an invalid pointer. That's fine because in + // that case `n` only has one associated function, so this loop has only + // one iteration and we don't use `n` again after the rewrite. + // The invariant is guaranteed by `GetAssociatedFunctions` and confirmed + // below. + DCHECK(associated_function.type() != + AssociatedFunctionInfo::kFunctionCallNode || + associated_functions.size() == 1); + + // Process one node-function-pair. + string func_name = associated_function.func_name(); + string canonicalized_name = + Canonicalize(func_name, AttrSlice(&associated_function.attrs())); + auto func_iter = func_map->find(canonicalized_name); + string new_func_name; + if (FunctionHasBeenProcessed(func_iter, func_map)) { + if (FunctionHasBeenModified(func_iter)) { + *any_function_modified = true; + new_func_name = GetMappedFunctionName(func_iter); + TF_RETURN_IF_ERROR(RewriteAssociatedFunction( + graph, n, fld, associated_function, new_func_name)); + } + continue; + } + // Function is processed for the first time. + bool function_modified = false; + new_func_name = + GetNewFunctionName(func_name, n, associated_function.type(), fld); + // Perform functionalization for current function. + TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction( + func_name, new_func_name, associated_function.attrs(), fld, flr, + func_map, &function_modified, node_filter)); + UpdateFunctionMap(func_map, canonicalized_name, new_func_name, + function_modified); + if (function_modified) { + *any_function_modified = true; + TF_RETURN_IF_ERROR(AddFunctionDefToGraphLibrary( + new_func_name, associated_function, graph, fld)); + TF_RETURN_IF_ERROR(RewriteAssociatedFunction( + graph, n, fld, associated_function, new_func_name)); + } + } + } + return Status::OK(); +} + +Status FunctionalizeControlFlowForFunction( + const string& func_name, const string& new_func_name, + const protobuf::Map& attrs, + FunctionLibraryDefinition* fld, FunctionLibraryRuntime* flr, + FuncMap* func_map, bool* function_modified, const NodeFilter& node_filter) { + *function_modified = false; + + // Convert the function to a graph. + FunctionLibraryRuntime::Handle handle; + TF_RETURN_IF_ERROR(flr->Instantiate(func_name, AttrSlice(&attrs), &handle)); + Status ret_status = Status::OK(); + auto cleanup_handle = gtl::MakeCleanup([&]() { + auto s = flr->ReleaseHandle(handle); + if (!s.ok()) { + ret_status.Update(s); + } + }); + const FunctionBody* body = flr->GetFunctionBody(handle); + Graph* g = body->graph; + + // Check if the graph has Switch or Merge node. + bool has_switch_or_merge = false; + for (Node* n : body->graph->nodes()) { + // Skip nodes that are filtered out. + if (node_filter && !node_filter(n)) continue; + if (n->type_string() == "Switch" || n->type_string() == "Merge") { + has_switch_or_merge = true; + break; + } + } + // Before functionalizing control flow in `g` we functionalize control flow + // in functions (directly or indirectly) associated with nodes in `g`. + TF_RETURN_IF_ERROR(FunctionalizeControlFlowForNodeAssociatedFunctions( + func_map, g, fld, flr, function_modified, node_filter)); + + if (has_switch_or_merge) { + *function_modified = true; + + // Functionalize the function body. + if (VLOG_IS_ON(4)) { + DumpGraphToFile( + absl::StrCat("functionalize_control_flow_before_fdef_", func_name), + *g, fld); + } + TF_RETURN_IF_ERROR(FunctionalizeControlFlow(g, fld, node_filter)); + if (VLOG_IS_ON(4)) { + DumpGraphToFile( + absl::StrCat("functionalize_control_flow_after_fdef_", func_name), *g, + fld); + } + } + if (*function_modified) { + // Add rewritten FunctionDef into library. + FunctionDef functionalized_fdef; + TF_RETURN_IF_ERROR( + GraphToFunctionDef(*g, new_func_name, &functionalized_fdef)); + if (func_name == new_func_name) { + VLOG(2) << "Replacing function " << func_name; + TF_RETURN_IF_ERROR( + fld->ReplaceFunction(new_func_name, functionalized_fdef)); + } else { + VLOG(2) << "Adding function " << new_func_name; + TF_RETURN_IF_ERROR(fld->AddFunctionDef(functionalized_fdef)); + } + } + + return ret_status; +} + Status FunctionalizeControlFlow(Graph* graph, FunctionLibraryDefinition* library, - const NodeFilter& node_filter) { + const NodeFilter& node_filter, + bool include_functions) { VLOG(2) << "FunctionalizeControlFlow (initial): " << DumpGraphToFile("functionalize_initial", *graph, library); + if (include_functions) { + // Functionalize control flow in functions that are (directly or indirectly) + // associated with a node in `graph`. + auto pflr = absl::make_unique( + /*device_mgr=*/nullptr, tensorflow::Env::Default(), + /*config=*/nullptr, TF_GRAPH_DEF_VERSION, library, + tensorflow::OptimizerOptions()); + // `pflr` has only one `FunctionLibraryRuntime`, for `kDefaultFLRDevice` + // (because we constructed it with `device_mgr = nullptr`). + FunctionLibraryRuntime* flr = + pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); + + FuncMap func_map; + bool modified = false; + TF_RETURN_IF_ERROR(FunctionalizeControlFlowForNodeAssociatedFunctions( + &func_map, graph, library, flr, &modified, node_filter)); + } // Functionalize and remove while loops from graph. TF_RETURN_IF_ERROR(FunctionalizeWhileLoop(graph, library, node_filter)); @@ -68,153 +310,19 @@ Status FunctionalizeControlFlow(Graph* graph, Status FunctionalizeControlFlowForGraphDef(GraphDef* graph_def, FunctionLibraryDefinition* library, - const NodeFilter& node_filter) { + const NodeFilter& node_filter, + bool include_functions) { FunctionDefLibrary function_lib = graph_def->library(); Graph graph(OpRegistry::Global()); TF_RETURN_IF_ERROR(ConvertGraphDefToGraph({}, *graph_def, &graph)); - TF_RETURN_IF_ERROR(FunctionalizeControlFlow(&graph, library, node_filter)); + TF_RETURN_IF_ERROR(FunctionalizeControlFlow(&graph, library, node_filter, + include_functions)); graph.ToGraphDef(graph_def); std::swap(*graph_def->mutable_library(), function_lib); return Status::OK(); } -Status FunctionalizeControlFlowForFunction( - const string& func_name, const string& new_func_name, - const protobuf::Map& attrs, - FunctionLibraryDefinition* fld, FunctionLibraryRuntime* flr, - std::map>* canonicalized_name_to_new_name, - bool* modified) { - *modified = false; - - // Convert the function to Graph. - FunctionLibraryRuntime::Handle handle; - TF_RETURN_IF_ERROR(flr->Instantiate(func_name, AttrSlice(&attrs), &handle)); - Status ret_status = Status::OK(); - auto cleanup_handle = gtl::MakeCleanup([&]() { - auto s = flr->ReleaseHandle(handle); - if (!s.ok()) { - ret_status.Update(s); - } - }); - const FunctionBody* body = flr->GetFunctionBody(handle); - Graph* g = body->graph; - - // Check if the graph has Switch or Merge node. - bool has_switch_or_merge = false; - for (Node* n : body->graph->nodes()) { - if (n->type_string() == "Switch" || n->type_string() == "Merge") { - has_switch_or_merge = true; - break; - } - } - // We cannot return here directly if the graph has no Switch/Merge. - // It might contain function call nodes, or If/While nodes with Switch/Merge - // in function body. We still need to rewrite those functions and modify - // corresponding nodes. - - // If any node has associated functions, functionalize them first. - // Gather nodes with associated functions first, because rewriting those nodes - // might involve node deletion/addition. Avoid modifying nodes while iterating - // it. - std::vector>> - nodes_to_associated_functions; - for (auto* n : g->nodes()) { - auto associated_functions = GetAssociatedFunctions(*n, fld); - if (!associated_functions.empty()) { - nodes_to_associated_functions.push_back({n, associated_functions}); - } - } - for (const auto& iter : nodes_to_associated_functions) { - Node* n = iter.first; - auto associated_functions = iter.second; - for (auto& associated_function : associated_functions) { - string name = associated_function.func_name(); - string canonicalized_name = - Canonicalize(name, AttrSlice(&associated_function.attrs())); - auto iter = canonicalized_name_to_new_name->find(canonicalized_name); - string new_name; - bool function_modified; - if (iter != canonicalized_name_to_new_name->end()) { - // If we already processed this function, check if it was rewritten. If - // the function was rewritten, the entry will be non-empty. Otherwise - // the entry will be empty. - function_modified = iter->second.has_value(); - if (function_modified) { - new_name = iter->second.value(); - } - } else { - if (associated_function.type() == - AssociatedFunctionInfo::AssociatedFunctionType::kSymbolicGradient) { - // For SymbolicGradient, `name` is always "SymbolicGradient", - // which is not very informative. Use node name instead. - new_name = fld->UniqueFunctionName(absl::StrCat(n->name(), "_f15n_")); - } else { - new_name = fld->UniqueFunctionName(absl::StrCat(name, "_f15n_")); - } - TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction( - name, new_name, associated_function.attrs(), fld, flr, - canonicalized_name_to_new_name, &function_modified)); - if (function_modified) { - // If the function was rewritten, add an non-empty entry. So later we - // know we have processed this function, and it was rewritten into - // another function. - (*canonicalized_name_to_new_name)[canonicalized_name] = new_name; - } else { - // If the function was not rewritten, add an empty entry. So later - // we know we have processed this function, and it does not need to be - // rewritten. - (*canonicalized_name_to_new_name)[canonicalized_name] = absl::nullopt; - } - } - if (function_modified) { - *modified = true; - - // Notice that if "n" is a function call, RewriteAssociatedFunction() - // will delete it and create a new node instead, making "n" an invalid - // pointer. That's fine because in that case, associated_functions will - // only have one member and the loop will only run once. - TF_RETURN_IF_ERROR(RewriteAssociatedFunction( - g, n, fld, associated_function, new_name)); - } - } - } - - if (has_switch_or_merge) { - *modified = true; - - // Functionalize the function body. - if (VLOG_IS_ON(4)) { - DumpGraphToFile( - absl::StrCat("functionalize_control_flow_before_fdef_", func_name), - *g, fld); - } - TF_RETURN_IF_ERROR(FunctionalizeControlFlow(g, fld)); - if (VLOG_IS_ON(4)) { - DumpGraphToFile( - absl::StrCat("functionalize_control_flow_after_fdef_", func_name), *g, - fld); - } - } - - if (*modified) { - // Add rewritten FunctionDef into library. - FunctionDef functionalized_fdef; - TF_RETURN_IF_ERROR( - GraphToFunctionDef(*g, new_func_name, &functionalized_fdef)); - if (func_name == new_func_name) { - VLOG(2) << "Replacing function " << func_name; - TF_RETURN_IF_ERROR( - fld->ReplaceFunction(new_func_name, functionalized_fdef)); - } else { - VLOG(2) << "Adding function " << new_func_name; - TF_RETURN_IF_ERROR(fld->AddFunctionDef(functionalized_fdef)); - } - } - - return ret_status; -} - Status FunctionalizeControlFlowForXlaPass::Run( const GraphOptimizationPassOptions& options) { Graph* graph = options.graph->get(); @@ -241,7 +349,7 @@ Status FunctionalizeControlFlowForXlaPass::Run( // XlaLaunch ops are generated by EncapsulateXlaComputationsPass. {"XlaLaunch", "function"}, }; - std::map> canonicalized_name_to_new_name; + FuncMap func_map; bool fld_modified = false; for (Node* n : graph->nodes()) { auto it = kNodeTypeToFunctionAttrMapping->find(n->type_string()); @@ -258,7 +366,7 @@ Status FunctionalizeControlFlowForXlaPass::Run( bool modified; TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction( func.name(), new_func_name, func.attr(), options.flib_def, flr, - &canonicalized_name_to_new_name, &modified)); + &func_map, &modified)); if (modified) { n->ClearAttr(func_attr); func.set_name(new_func_name); diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.h b/tensorflow/compiler/tf2xla/functionalize_control_flow.h index f9e751e2d67..46abae27878 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.h +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.h @@ -30,6 +30,13 @@ namespace tensorflow { // // If `node_filter` is defined, then only loops and conditions for whose // nodes `node_filter` returns true are functionalized. + +// If `include_functions` is true, then loops and conditions inside of functions +// that are associated with nodes in `graph` (e.g., a function called from a +// node in `graph`) are also functionalized, otherwise they are not. +// This also handles transitive cases, e.g., a function body will be +// functionalized when it is called in another function that is called by some +// node in `graph` (and so on). The node filter also applies here. // // Precondition: // For any node in a loop or condition for which `node_filter` returns true, @@ -43,11 +50,13 @@ namespace tensorflow { // satisfies the above conditions. Status FunctionalizeControlFlow(Graph* graph, FunctionLibraryDefinition* library, - const NodeFilter& node_filter = {}); + const NodeFilter& node_filter = {}, + bool include_functions = false); Status FunctionalizeControlFlowForGraphDef(GraphDef* graph_def, FunctionLibraryDefinition* library, - const NodeFilter& node_filter = {}); + const NodeFilter& node_filter = {}, + bool include_functions = false); // This pass looks at the graph, and turns V1 control flow structure // (Switch/Merge/etc.) into V2 control flow structure (If/While). diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc index 79a042ad680..951ebdd7ec1 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc @@ -27,12 +27,15 @@ limitations under the License. #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/validate.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/version.h" +#include "tensorflow/core/util/dump_graph.h" #include "tensorflow/core/util/equal_graph_def.h" namespace tensorflow { @@ -63,18 +66,41 @@ Status FindIfThenAndElse(const GraphDef& graph, string* op_name, // math_ops.less(y, x), lambda: math_ops.multiply(y, 17), // lambda: math_ops.add(x, 23)) // -// Tests different node filters. -class ConditionalTestFixture : public ::testing::TestWithParam { +// Tests different node filters and functionalization inside of a function. +class ConditionalTestFixture + : public ::testing::TestWithParam> { protected: - void SetUp() override { restrict_to_tpu_nodes_ = GetParam(); } + void SetUp() override { + restrict_to_tpu_nodes_ = std::get<0>(GetParam()); + wrap_condition_in_function_ = std::get<1>(GetParam()); + } void RunTest(); private: + void BuildCondGraph(Graph* cond_graph); + void CheckGraphDef(const GraphDef& graph_def, + const FunctionLibraryDefinition& library); + bool restrict_to_tpu_nodes_ = false; + bool wrap_condition_in_function_ = false; }; -void ConditionalTestFixture::RunTest() { - Graph graph(OpRegistry::Global()); +TEST_P(ConditionalTestFixture, ConditionalTests) { RunTest(); } + +INSTANTIATE_TEST_SUITE_P( + FunctionalizeControlFlow, ConditionalTestFixture, + ::testing::Combine(::testing::Bool(), ::testing::Bool()), + [](const ::testing::TestParamInfo& + info) { + bool restrict_to_tpu_nodes = std::get<0>(info.param); + bool wrap_cond_in_function = std::get<1>(info.param); + string name = + absl::StrCat(restrict_to_tpu_nodes ? "with_filter" : "without_filter", + wrap_cond_in_function ? "_in_function" : "_in_graph"); + return name; + }); + +void ConditionalTestFixture::BuildCondGraph(Graph* cond_graph) { { Scope scope = Scope::NewRootScope().ExitOnError(); @@ -102,13 +128,117 @@ void ConditionalTestFixture::RunTest() { auto merge = ops::Merge(scope.WithOpName("cond/Merge"), std::initializer_list{add, mul}); - TF_EXPECT_OK(scope.ToGraph(&graph)); + TF_EXPECT_OK(scope.ToGraph(cond_graph)); // Set `_tpu_replicate` attribute for all nodes. - for (Node* n : graph.nodes()) { + for (Node* n : cond_graph->nodes()) { n->AddAttr("_tpu_replicate", "cluster"); } } +} + +void ConditionalTestFixture::CheckGraphDef( + const GraphDef& graph_def, const FunctionLibraryDefinition& library) { + string op_name; + NameAttrList then_fn; + NameAttrList else_fn; + TF_EXPECT_OK(FindIfThenAndElse(graph_def, &op_name, &then_fn, &else_fn)); + InstantiationResultForTest else_result; + TF_EXPECT_OK( + InstantiateFunctionForTest(else_fn.name(), library, &else_result)); + + // Outer graph + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32); + auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); + auto less = ops::Less(scope.WithOpName("cond/Less"), y, x); + auto if_op = + ops::If(scope.WithOpName(op_name), less, + std::initializer_list{less, y, x}, {DT_INT32}, then_fn, + else_fn, ops::If::OutputShapes({PartialTensorShape()})); + auto id = ops::Identity(scope.WithOpName("cond/Merge"), if_op.output[0]); + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + TF_EXPECT_GRAPH_EQ(expected, graph_def); + } + + // then body. + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg_0 = ops::_Arg(scope.WithOpName("arg0"), DT_BOOL, 0); + auto arg_1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1); + auto arg_2 = ops::_Arg(scope.WithOpName("arg2"), DT_INT32, 2); + auto identity = ops::Identity(scope.WithOpName("cond/Identity"), arg_0); + auto cond = ops::Const( + scope.WithOpName("cond").WithControlDependencies(identity), 17); + auto mul = ops::Mul(scope.WithOpName("cond/Mul"), arg_1, cond); + auto retval0 = ops::_Retval(scope.WithOpName("retval0_RetVal"), mul, 0); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK(InstantiateFunctionForTest(then_fn.name(), library, &result)); + + EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); + EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}), result.arg_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } + + // else body. + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg_0 = ops::_Arg(scope.WithOpName("arg0"), DT_BOOL, 0); + auto arg_1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1); + auto arg_2 = ops::_Arg(scope.WithOpName("arg2"), DT_INT32, 2); + auto identity = ops::Identity(scope.WithOpName("cond/Identity_1"), arg_0); + auto cond_1 = ops::Const( + scope.WithOpName("cond_1").WithControlDependencies(identity), 23); + auto add = ops::Add(scope.WithOpName("cond/false/add"), arg_2, cond_1); + auto retval0 = ops::_Retval(scope.WithOpName("retval0_RetVal"), add, 0); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK(InstantiateFunctionForTest(else_fn.name(), library, &result)); + + EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); + EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}), result.arg_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } +} + +void ConditionalTestFixture::RunTest() { + Graph graph(OpRegistry::Global()); + if (wrap_condition_in_function_) { + // Wrap condition in a function which is called from `graph`. + Scope scope = Scope::NewRootScope().ExitOnError(); + auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); + + Graph cond_graph(OpRegistry::Global()); + BuildCondGraph(&cond_graph); + + FunctionDef cond_fdef; + TF_ASSERT_OK(GraphToFunctionDef(cond_graph, "cond_fn", &cond_fdef)); + + FunctionDefLibrary fdef_lib; + *(fdef_lib.add_function()) = cond_fdef; + TF_ASSERT_OK(scope.graph()->AddFunctionLibrary(fdef_lib)); + NodeDef cond_fn; + cond_fn.set_name("cond_node"); + cond_fn.set_op("cond_fn"); + *(cond_fn.add_input()) = "source"; + Status status; + scope.graph()->AddNode(cond_fn, &status); + TF_ASSERT_OK(status); + TF_ASSERT_OK(scope.ToGraph(&graph)); + } else { + // Build condition in `graph`. + BuildCondGraph(&graph); + } + FunctionLibraryDefinition library(graph.flib_def()); // If `restrict_to_tpu_nodes_` is true let filter function return true for // `_tpu_replicate` nodes. NodeFilter node_filter = @@ -116,99 +246,47 @@ void ConditionalTestFixture::RunTest() { ? [](const Node* n) { return n->attrs().Find("_tpu_replicate"); } : NodeFilter{}; - FunctionLibraryDefinition library(OpRegistry::Global(), {}); GraphDef optimized_graph_def; graph.ToGraphDef(&optimized_graph_def); - TF_ASSERT_OK(FunctionalizeControlFlowForGraphDef(&optimized_graph_def, - &library, node_filter)); - TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library, node_filter)); - GraphDef converted_graph_def; - graph.ToGraphDef(&converted_graph_def); + TF_ASSERT_OK(FunctionalizeControlFlowForGraphDef( + &optimized_graph_def, &library, node_filter, + /*include_functions=*/wrap_condition_in_function_)); + TF_ASSERT_OK(FunctionalizeControlFlow( + &graph, &library, node_filter, + /*include_functions=*/wrap_condition_in_function_)); - for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) { - string op_name; - NameAttrList then_fn; - NameAttrList else_fn; - TF_EXPECT_OK(FindIfThenAndElse(graph_def, &op_name, &then_fn, &else_fn)); - InstantiationResultForTest else_result; - TF_EXPECT_OK( - InstantiateFunctionForTest(else_fn.name(), library, &else_result)); + if (wrap_condition_in_function_) { + // Check if function body was functionalized. + auto pflr = absl::make_unique( + /*device_mgr=*/nullptr, tensorflow::Env::Default(), + /*config=*/nullptr, TF_GRAPH_DEF_VERSION, &library, + tensorflow::OptimizerOptions()); + FunctionLibraryRuntime* flr = + pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); + FunctionLibraryRuntime::Handle handle; - // Outer graph - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32); - auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); - auto less = ops::Less(scope.WithOpName("cond/Less"), y, x); - auto if_op = - ops::If(scope.WithOpName(op_name), less, - std::initializer_list{less, y, x}, {DT_INT32}, then_fn, - else_fn, ops::If::OutputShapes({PartialTensorShape()})); - auto id = ops::Identity(scope.WithOpName("cond/Merge"), if_op.output[0]); - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); - TF_EXPECT_GRAPH_EQ(expected, graph_def); - } - - // then body. - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto arg_0 = ops::_Arg(scope.WithOpName("arg0"), DT_BOOL, 0); - auto arg_1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1); - auto arg_2 = ops::_Arg(scope.WithOpName("arg2"), DT_INT32, 2); - auto identity = ops::Identity(scope.WithOpName("cond/Identity"), arg_0); - auto cond = ops::Const( - scope.WithOpName("cond").WithControlDependencies(identity), 17); - auto mul = ops::Mul(scope.WithOpName("cond/Mul"), arg_1, cond); - auto retval0 = ops::_Retval(scope.WithOpName("retval0_RetVal"), mul, 0); - - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); - - InstantiationResultForTest result; - TF_EXPECT_OK( - InstantiateFunctionForTest(then_fn.name(), library, &result)); - - EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); - EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}), - result.arg_types); - TF_EXPECT_GRAPH_EQ(expected, result.gdef); - } - - // else body. - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto arg_0 = ops::_Arg(scope.WithOpName("arg0"), DT_BOOL, 0); - auto arg_1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1); - auto arg_2 = ops::_Arg(scope.WithOpName("arg2"), DT_INT32, 2); - auto identity = ops::Identity(scope.WithOpName("cond/Identity_1"), arg_0); - auto cond_1 = ops::Const( - scope.WithOpName("cond_1").WithControlDependencies(identity), 23); - auto add = ops::Add(scope.WithOpName("cond/false/add"), arg_2, cond_1); - auto retval0 = ops::_Retval(scope.WithOpName("retval0_RetVal"), add, 0); - - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); - - InstantiationResultForTest result; - TF_EXPECT_OK( - InstantiateFunctionForTest(else_fn.name(), library, &result)); - - EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); - EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}), - result.arg_types); - TF_EXPECT_GRAPH_EQ(expected, result.gdef); + // Functionalized function name is the type string of `cond_node`. + string func_name; + for (Node* n : graph.nodes()) { + if (n->name() == "cond_node") { + func_name = n->type_string(); + break; + } } + TF_ASSERT_OK(flr->Instantiate(func_name, AttrSlice(), &handle)); + const FunctionBody* body = flr->GetFunctionBody(handle); + GraphDef graph_def; + body->graph->ToGraphDef(&graph_def); + CheckGraphDef(graph_def, library); + } else { + // Check if graphs were functionalized. + CheckGraphDef(optimized_graph_def, library); + GraphDef converted_graph_def; + graph.ToGraphDef(&converted_graph_def); + CheckGraphDef(converted_graph_def, library); } } -TEST_P(ConditionalTestFixture, ConditionalTests) { RunTest(); } - -INSTANTIATE_TEST_SUITE_P( - FunctionalizeControlFlow, ConditionalTestFixture, ::testing::Bool(), - [](const ::testing::TestParamInfo& - info) { return info.param ? "with_filter" : "without_filter"; }); - // Returns the names of the "cond" and "body" functions for the While node // in a graph. Status FindWhileCondAndBody(const GraphDef& graph, NameAttrList* cond,