diff --git a/tensorflow/compiler/mlir/lite/tests/end2end/fake_quant_per_channel.pbtxt b/tensorflow/compiler/mlir/lite/tests/end2end/fake_quant_per_channel.pbtxt index 4dd3973537b..d4e62c386bd 100644 --- a/tensorflow/compiler/mlir/lite/tests/end2end/fake_quant_per_channel.pbtxt +++ b/tensorflow/compiler/mlir/lite/tests/end2end/fake_quant_per_channel.pbtxt @@ -484,7 +484,7 @@ node { # CHECK: shape: [ 186 ], # CHECK: type: INT32, # CHECK: buffer: 3, -# CHECK: name: "tfl.pseudo_qconst", +# CHECK: name: "BoxPredictor_4/ClassPredictor/BiasAdd,BoxPredictor_4/ClassPredictor/Conv2D,BoxPredictor_4/ClassPredictor/biases", # CHECK: quantization: { # CHECK: scale: [ 0.027216, 0.00038, 0.000413, 0.000426, 0.001607, # CHECK: zero_point: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, @@ -493,7 +493,7 @@ node { # CHECK: shape: [ 186, 1, 1, 256 ], # CHECK: type: INT8, # CHECK: buffer: 4, -# CHECK: name: "tfl.pseudo_qconst1", +# CHECK: name: "BoxPredictor_4/ClassPredictor/Conv2D,BoxPredictor_4/ClassPredictor/weights_quant/FakeQuantWithMinMaxVarsPerChannel", # CHECK: quantization: { # CHECK: scale: [ 0.12581, 0.001755, 0.001908, 0.001967, 0.007431, # CHECK: zero_point: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, @@ -502,7 +502,7 @@ node { # CHECK: shape: [ 1, 1, 1, 186 ], # CHECK: type: INT8, # CHECK: buffer: 5, -# CHECK: name: "tfl.conv_2d", +# CHECK: name: "BoxPredictor_4/ClassPredictor/BiasAdd,BoxPredictor_4/ClassPredictor/Conv2D,BoxPredictor_4/ClassPredictor/biases1", # CHECK: quantization: { # CHECK: scale: [ 0.093635 ], # CHECK: zero_point: [ 22 ] diff --git a/tensorflow/compiler/mlir/op_or_arg_name_mapper.cc b/tensorflow/compiler/mlir/op_or_arg_name_mapper.cc index c2a06806055..3f75877e906 100644 --- a/tensorflow/compiler/mlir/op_or_arg_name_mapper.cc +++ b/tensorflow/compiler/mlir/op_or_arg_name_mapper.cc @@ -19,6 +19,7 @@ limitations under the License. #include "llvm/ADT/APInt.h" #include "llvm/ADT/SmallString.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include "mlir/IR/Location.h" // TF:local_config_mlir #include "mlir/IR/Operation.h" // TF:local_config_mlir @@ -72,18 +73,34 @@ bool OpOrArgNameMapper::IsUnique(llvm::StringRef name) { namespace { // Derives name from location. -llvm::StringRef GetNameFromLoc(mlir::Location loc) { +std::string GetNameFromLoc(mlir::Location loc) { if (auto name_loc = loc.dyn_cast()) - return name_loc.getName().strref(); + return name_loc.getName().str(); if (auto call_loc = loc.dyn_cast()) { // Return name if CallSiteLoc's callee has a NameLoc (as should be the case // if imported with DebugInfo), else use the fallback naming scheme below. if (auto name_loc = call_loc.getCallee().dyn_cast()) - return name_loc.getName().strref(); + return name_loc.getName().str(); } - return llvm::StringRef(); + if (auto fused_loc = loc.dyn_cast()) { + llvm::ArrayRef locations = fused_loc.getLocations(); + std::vector names; + bool names_is_nonempty = false; + for (const auto& loc : locations) { + const std::string loc_name = GetNameFromLoc(loc); + names.push_back(loc_name); + if (!loc_name.empty()) { + names_is_nonempty = true; + } + } + if (names_is_nonempty) { + return llvm::join(names.begin(), names.end(), ","); + } + } + + return ""; } } // anonymous namespace