diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/merge_node_with_function.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/merge_node_with_function.pbtxt new file mode 100644 index 00000000000..e86ecbfe41f --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/merge_node_with_function.pbtxt @@ -0,0 +1,67 @@ +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-upgrade-legacy %s + +# This is a stripped down GraphDef of the model from b/175240312. To hit the +# bug, the GraphDef needs to have functions in the library and also a Merge node +# to go into certain part of the functionalization code where it crashes. + +node { + name: "input" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_BOOL + } + } +} +node { + name: "Switch0" + op: "Switch" + input: "input" + input: "input" + attr { + key: "T" + value { + type: DT_BOOL + } + } +} +node { + name: "func0" + op: "func_name" + input: "Switch0:1" +} +node { + name: "Merge" + op: "Merge" + input: "Switch0:1" + input: "Switch0" + input: "^func0" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_BOOL + } + } +} +library { + function { + signature { + name: "func_name" + input_arg { + name: "arg0" + type: DT_BOOL + } + } + ret { + key: "retval0" + value: "arg0" + } + } +} diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index 2cda524db21..19d558ad19b 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -182,8 +182,6 @@ class NameUniquifier : public OpOrArgNameMapper { Status UpgradeLegacyGraph(Graph* graph, FunctionLibraryDefinition* flib_def, bool restrict_functionalization_to_tpu_nodes) { - TF_RETURN_IF_ERROR(GenerateResourceSharedNameIfEmpty(*graph, *flib_def)); - // If `restrict_functionalization_to_tpu_nodes` is true let filter function // return true for `_tpu_replicate` nodes, otherwise don't set filter. NodeFilter node_filter = @@ -3386,11 +3384,17 @@ Status SavedModelSignatureDefImporterLite::InitializeGraph( TF_RETURN_IF_ERROR( RunGrappler(const_cast(&meta_graph_def_))); + GraphDef graph_def = meta_graph_def_.graph_def(); + if (import_options.upgrade_legacy) { + TF_RETURN_IF_ERROR(GenerateResourceSharedNameIfEmpty( + graph_def, graph_->flib_def().default_registry())); + } + GraphConstructorOptions graph_ctor_options; graph_ctor_options.allow_internal_ops = true; graph_ctor_options.add_default_attributes = true; - TF_RETURN_IF_ERROR(ConvertGraphDefToGraph( - graph_ctor_options, meta_graph_def_.graph_def(), graph_.get())); + TF_RETURN_IF_ERROR( + ConvertGraphDefToGraph(graph_ctor_options, graph_def, graph_.get())); // TODO(jpienaar): Remove need to const_cast. if (import_options.upgrade_legacy) { @@ -3732,6 +3736,10 @@ StatusOr ConvertGraphdefToMlir( if (add_default_attributes) { TF_RETURN_IF_ERROR(PreprocessGraphDef(&specs, &preprocessed_graphdef)); } + if (specs.upgrade_legacy) { + TF_RETURN_IF_ERROR(GenerateResourceSharedNameIfEmpty( + preprocessed_graphdef, graph.flib_def().default_registry())); + } TF_RETURN_IF_ERROR(ConvertGraphDefToGraph( options, std::move(preprocessed_graphdef), &graph)); return ConvertGraphToMlir(graph, debug_info, graph.flib_def(), specs, diff --git a/tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.cc b/tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.cc index 386bc4d397a..7a944960c58 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.cc @@ -34,8 +34,8 @@ const llvm::StringSet<>& GetSharedNameGenerationCompatibleOps() { return *ops; } -Status GenerateResourceSharedNameIfEmpty(Graph& graph, - FunctionLibraryDefinition& flib_def) { +Status GenerateResourceSharedNameIfEmpty( + GraphDef& gdef, const OpRegistryInterface* default_registry) { auto is_resource_op_with_empty_shared_name = [](const NodeDef& node_def, const OpDef& op_def) { if (!GetSharedNameGenerationCompatibleOps().contains(op_def.name())) { @@ -64,31 +64,37 @@ Status GenerateResourceSharedNameIfEmpty(Graph& graph, return iter->second.s().empty(); }; - // Upgrade nodes in the graph. - for (auto* node : graph.nodes()) { - if (is_resource_op_with_empty_shared_name(node->def(), node->op_def())) { - node->AddAttr("shared_name", node->name()); + FunctionDefLibrary* library = gdef.mutable_library(); + auto flib_def = library ? std::make_unique( + default_registry, *library) + : std::make_unique( + default_registry, FunctionDefLibrary()); + + if (library) { + // Upgrade nodes in the functions. + for (FunctionDef& fdef : *library->mutable_function()) { + auto func_name = fdef.signature().name(); + for (auto& node_def : *fdef.mutable_node_def()) { + const OpDef* op_def = nullptr; + TF_RETURN_IF_ERROR(flib_def->LookUpOpDef(node_def.op(), &op_def)); + if (is_resource_op_with_empty_shared_name(node_def, *op_def)) { + // Use the concat of function name and node name for such ops in a + // function as the shared_name. "@" is used as the separator because + // it is not allowed in the function name or the node name. + (*node_def.mutable_attr())["shared_name"].set_s( + absl::StrCat(node_def.name(), "@", func_name)); + } + } } } - // Upgrade nodes in the functions. - auto func_names = flib_def.ListFunctionNames(); - for (const auto& func_name : func_names) { - const FunctionDef* orig = flib_def.Find(func_name); - DCHECK(orig); - auto copy = *orig; - for (auto& node_def : *copy.mutable_node_def()) { - const OpDef* op_def = nullptr; - TF_RETURN_IF_ERROR(flib_def.LookUpOpDef(node_def.op(), &op_def)); - if (is_resource_op_with_empty_shared_name(node_def, *op_def)) { - // Use the concat of function name and node name for such ops in a - // function as the shared_name. "@" is used as the separator because it - // is not allowed in the function name or the node name. - (*node_def.mutable_attr())["shared_name"].set_s( - absl::StrCat(node_def.name(), "@", func_name)); - } + // Upgrade nodes in the GraphDef. + for (auto& node_def : *gdef.mutable_node()) { + const OpDef* op_def = nullptr; + TF_RETURN_IF_ERROR(flib_def->LookUpOpDef(node_def.op(), &op_def)); + if (is_resource_op_with_empty_shared_name(node_def, *op_def)) { + (*node_def.mutable_attr())["shared_name"].set_s(node_def.name()); } - TF_RETURN_IF_ERROR(flib_def.ReplaceFunction(func_name, copy)); } return tensorflow::Status::OK(); diff --git a/tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.h b/tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.h index f27379a323c..de0674d7f41 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.h @@ -27,8 +27,8 @@ class MetaGraphDef; // Generate the shared_name for resource handle ops in the graph and functions // if their shared_names are empty. Resource handle ops with empty shared_name // may have undesired semantics. -Status GenerateResourceSharedNameIfEmpty(Graph& graph, - FunctionLibraryDefinition& flib_def); +Status GenerateResourceSharedNameIfEmpty( + GraphDef& gdef, const OpRegistryInterface* default_registry); // Run grapler passes over `meta_graph_def`.graph_def(), and optimize it in // place.