diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/preserve-entry-func-names.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/preserve-entry-func-names.mlir new file mode 100644 index 00000000000..cd1e022de5e --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/preserve-entry-func-names.mlir @@ -0,0 +1,24 @@ +// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s + +func @main(%arg0: tensor<10xi32>, %arg1: tensor<10xi32>) -> tensor<10xi32> +attributes {tf.entry_function = {inputs = "foo,bar", outputs = "Add"}} { + %0 = "tf.Placeholder.input"(%arg0) {device = "", dtype = "tfdtype$DT_INT32", shape = "tfshape$dim { size: 10 }"} : (tensor<10xi32>) -> tensor<10xi32> + %1 = "tf.Placeholder.input"(%arg1) {device = "", dtype = "tfdtype$DT_INT32", shape = "tfshape$dim { size: 10 }"} : (tensor<10xi32>) -> tensor<10xi32> + // This node would be renamed to bar1 + %2 = "tf.Identity"(%1) {device = "", dtype = "tfdtype$DT_INT32"} : (tensor<10xi32>) -> tensor<10xi32> loc ("bar") + // The following node would be renamed to bar2 + %3 = "tf.Identity"(%2) {device = "", dtype = "tfdtype$DT_INT32"} : (tensor<10xi32>) -> tensor<10xi32> loc ("bar1") + %4 = "tf.Add"(%0, %3) {T = "tfdtype$DT_INT32", device = ""} : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32> loc("Add") + return %4 : tensor<10xi32> +} + +// CHECK: name: "bar1" +// CHECK-NEXT: op: "Identity" +// CHECK: name: "bar2" +// CHECK-NEXT: op: "Identity" +// CHECK: name: "Add" +// CHECK-NEXT: op: "Add" +// CHECK: name: "foo" +// CHECK-NEXT: op: "Placeholder" +// CHECK: name: "bar" +// CHECK-NEXT: op: "Placeholder" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/tf_identity_n.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/tf_identity_n.mlir index bc4db2ec05f..d6c4d3ecd6f 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/tf_identity_n.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/tf_identity_n.mlir @@ -1,16 +1,16 @@ // RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s func @main() -> tensor<2x3xi32> { - %0 = "tf.Const"() {value = dense<5> : tensor<2x3xi32>} : () -> (tensor<2x3xi32>) loc("Const0") - %1 = "tf.Const"() {value = dense<4.2> : tensor<4x5xf32>} : () -> (tensor<4x5xf32>) loc("Const1") + %0 = "tf.Const"() {value = dense<5> : tensor<2x3xi32>} : () -> (tensor<2x3xi32>) loc("x") + %1 = "tf.Const"() {value = dense<4.2> : tensor<4x5xf32>} : () -> (tensor<4x5xf32>) loc("y") %2:2 = "tf.IdentityN"(%0, %1) : (tensor<2x3xi32>, tensor<4x5xf32>) -> (tensor<2x3xi32>, tensor<4x5xf32>) loc("MyIdentityN") return %2#0 : tensor<2x3xi32> } // CHECK: name: "MyIdentityN" // CHECK-NEXT: op: "IdentityN" -// CHECK-NEXT: input: "Const0" -// CHECK-NEXT: input: "Const1" +// CHECK-NEXT: input: "x" +// CHECK-NEXT: input: "y" // CHECK-NEXT: attr { // CHECK-NEXT: key: "T" // CHECK-NEXT: value { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/unique_name.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/unique_name.mlir new file mode 100644 index 00000000000..dcf713c06a8 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/unique_name.mlir @@ -0,0 +1,16 @@ +// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s + +func @main() { +^bb0: + // CHECK: name: "foo" + %0 = "tf.Const"() {dtype = "tfdtype$DT_INT32", value = dense<0> : tensor} : () -> (tensor) loc("foo") + // CHECK: name: "foo1" + %1 = "tf.Const"() {dtype = "tfdtype$DT_INT32", value = dense<1> : tensor} : () -> (tensor) loc("foo") + // CHECK: name: "foo2" + %2 = "tf.Const"() {dtype = "tfdtype$DT_INT32", value = dense<2> : tensor} : () -> (tensor) loc("foo1") + // CHECK: name: "Unnamed" + %3 = "tf.Const"() {dtype = "tfdtype$DT_INT32", value = dense<3> : tensor} : () -> (tensor) loc("2") + // CHECK: name: "Unnamed1" + %4 = "tf.Const"() {dtype = "tfdtype$DT_INT32", value = dense<3> : tensor} : () -> (tensor) loc("3") + return +} diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc index a2823826fe4..3b59b953a9b 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc @@ -171,9 +171,16 @@ class Exporter { }; std::string Exporter::UniqueName(llvm::StringRef prefix) { - std::string name = prefix; + // Remove the digits at the end of prefix. + // Otherwise, there would be collision between first instance of prefix = + // "foo1" and the second instance of prefix = "foo". In both cases this + // function would return "foo1". + int i = prefix.size() - 1; + while (i >= 0 && isdigit(prefix[i])) i--; + + std::string name = (i >= 0) ? prefix.substr(0, i + 1) : "Unnamed"; auto& val = name_to_count_[name]; - if (val) name = (prefix + llvm::Twine(val)).str(); + if (val) name = (name + llvm::Twine(val)).str(); ++val; return name; } @@ -330,6 +337,9 @@ Status Exporter::AddArgumentNode(mlir::BlockArgument* arg, unsigned index) { } for (auto* r : input->getResults()) state.types.push_back(r->getType()); auto* inst = builder.createOperation(state); + // If it is one of the specified input names, then the new + // instruction should have the same name. + op_to_name_[inst].assign(op_to_name_[input]); for (int index = 0, e = input->getNumResults(); index != e; ++index) { input->getResult(index)->replaceAllUsesWith(inst->getResult(index)); } @@ -393,7 +403,8 @@ StatusOr> Exporter::Convert(const ExporterConfigs& confs, TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(*flib)); Exporter exporter(graph.get(), tf_dialect); - // Set input and output names. + // Set input and output names and increment the use counter for them to help + // generate unique names. if (!output_names.empty()) { auto term = block.getTerminator(); TF_RET_CHECK(output_names.size() == term->getNumOperands()) @@ -401,12 +412,14 @@ StatusOr> Exporter::Convert(const ExporterConfigs& confs, << ") != terminator operands (" << term->getNumOperands() << ")"; int i = 0; for (auto it : term->getOperands()) { + exporter.name_to_count_[output_names[i].str()] = 1; exporter.op_to_name_[it->getDefiningOp()] = output_names[i++]; } } if (!input_names.empty()) { TF_RET_CHECK(input_names.size() == block.getNumArguments()); for (auto it : llvm::enumerate(function.getArguments())) { + exporter.name_to_count_[input_names[it.index()].str()] = 1; exporter.op_to_name_[*it.value()->user_begin()] = input_names[it.index()]; } }