Add custom FusedLoc naming in OpOrArgNameMapper.

PiperOrigin-RevId: 279430057
Change-Id: I5caf9227285a8dcc8aa4387007a0d882d9d4cb42
This commit is contained in:
Abdurrahman Akkas 2019-11-08 18:13:56 -08:00 committed by TensorFlower Gardener
parent 755aec33ef
commit 43d77b42e7
2 changed files with 24 additions and 7 deletions
tensorflow/compiler/mlir

View File

@ -484,7 +484,7 @@ node {
# CHECK: shape: [ 186 ],
# CHECK: type: INT32,
# CHECK: buffer: 3,
# CHECK: name: "tfl.pseudo_qconst",
# CHECK: name: "BoxPredictor_4/ClassPredictor/BiasAdd,BoxPredictor_4/ClassPredictor/Conv2D,BoxPredictor_4/ClassPredictor/biases",
# CHECK: quantization: {
# CHECK: scale: [ 0.027216, 0.00038, 0.000413, 0.000426, 0.001607,
# CHECK: zero_point: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
@ -493,7 +493,7 @@ node {
# CHECK: shape: [ 186, 1, 1, 256 ],
# CHECK: type: INT8,
# CHECK: buffer: 4,
# CHECK: name: "tfl.pseudo_qconst1",
# CHECK: name: "BoxPredictor_4/ClassPredictor/Conv2D,BoxPredictor_4/ClassPredictor/weights_quant/FakeQuantWithMinMaxVarsPerChannel",
# CHECK: quantization: {
# CHECK: scale: [ 0.12581, 0.001755, 0.001908, 0.001967, 0.007431,
# CHECK: zero_point: [ 0, 0, 0, 0, 0, 0, 0, 0, 0,
@ -502,7 +502,7 @@ node {
# CHECK: shape: [ 1, 1, 1, 186 ],
# CHECK: type: INT8,
# CHECK: buffer: 5,
# CHECK: name: "tfl.conv_2d",
# CHECK: name: "BoxPredictor_4/ClassPredictor/BiasAdd,BoxPredictor_4/ClassPredictor/Conv2D,BoxPredictor_4/ClassPredictor/biases1",
# CHECK: quantization: {
# CHECK: scale: [ 0.093635 ],
# CHECK: zero_point: [ 22 ]

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/Location.h" // TF:local_config_mlir
#include "mlir/IR/Operation.h" // TF:local_config_mlir
@ -72,18 +73,34 @@ bool OpOrArgNameMapper::IsUnique(llvm::StringRef name) {
namespace {
// Derives name from location.
llvm::StringRef GetNameFromLoc(mlir::Location loc) {
std::string GetNameFromLoc(mlir::Location loc) {
if (auto name_loc = loc.dyn_cast<mlir::NameLoc>())
return name_loc.getName().strref();
return name_loc.getName().str();
if (auto call_loc = loc.dyn_cast<mlir::CallSiteLoc>()) {
// Return name if CallSiteLoc's callee has a NameLoc (as should be the case
// if imported with DebugInfo), else use the fallback naming scheme below.
if (auto name_loc = call_loc.getCallee().dyn_cast<mlir::NameLoc>())
return name_loc.getName().strref();
return name_loc.getName().str();
}
return llvm::StringRef();
if (auto fused_loc = loc.dyn_cast<mlir::FusedLoc>()) {
llvm::ArrayRef<mlir::Location> locations = fused_loc.getLocations();
std::vector<std::string> names;
bool names_is_nonempty = false;
for (const auto& loc : locations) {
const std::string loc_name = GetNameFromLoc(loc);
names.push_back(loc_name);
if (!loc_name.empty()) {
names_is_nonempty = true;
}
}
if (names_is_nonempty) {
return llvm::join(names.begin(), names.end(), ",");
}
}
return "";
}
} // anonymous namespace