Add custom FusedLoc naming in OpOrArgNameMapper.
PiperOrigin-RevId: 279430057 Change-Id: I5caf9227285a8dcc8aa4387007a0d882d9d4cb42
This commit is contained in:
parent
755aec33ef
commit
43d77b42e7
tensorflow/compiler/mlir
@ -484,7 +484,7 @@ node {
|
||||
# CHECK: shape: [ 186 ],
|
||||
# CHECK: type: INT32,
|
||||
# CHECK: buffer: 3,
|
||||
# CHECK: name: "tfl.pseudo_qconst",
|
||||
# CHECK: name: "BoxPredictor_4/ClassPredictor/BiasAdd,BoxPredictor_4/ClassPredictor/Conv2D,BoxPredictor_4/ClassPredictor/biases",
|
||||
# CHECK: quantization: {
|
||||
# CHECK: scale: [ 0.027216, 0.00038, 0.000413, 0.000426, 0.001607,
|
||||
# CHECK: zero_point: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
@ -493,7 +493,7 @@ node {
|
||||
# CHECK: shape: [ 186, 1, 1, 256 ],
|
||||
# CHECK: type: INT8,
|
||||
# CHECK: buffer: 4,
|
||||
# CHECK: name: "tfl.pseudo_qconst1",
|
||||
# CHECK: name: "BoxPredictor_4/ClassPredictor/Conv2D,BoxPredictor_4/ClassPredictor/weights_quant/FakeQuantWithMinMaxVarsPerChannel",
|
||||
# CHECK: quantization: {
|
||||
# CHECK: scale: [ 0.12581, 0.001755, 0.001908, 0.001967, 0.007431,
|
||||
# CHECK: zero_point: [ 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
@ -502,7 +502,7 @@ node {
|
||||
# CHECK: shape: [ 1, 1, 1, 186 ],
|
||||
# CHECK: type: INT8,
|
||||
# CHECK: buffer: 5,
|
||||
# CHECK: name: "tfl.conv_2d",
|
||||
# CHECK: name: "BoxPredictor_4/ClassPredictor/BiasAdd,BoxPredictor_4/ClassPredictor/Conv2D,BoxPredictor_4/ClassPredictor/biases1",
|
||||
# CHECK: quantization: {
|
||||
# CHECK: scale: [ 0.093635 ],
|
||||
# CHECK: zero_point: [ 22 ]
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
|
||||
#include "llvm/ADT/APInt.h"
|
||||
#include "llvm/ADT/SmallString.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "mlir/IR/Location.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Operation.h" // TF:local_config_mlir
|
||||
@ -72,18 +73,34 @@ bool OpOrArgNameMapper::IsUnique(llvm::StringRef name) {
|
||||
|
||||
namespace {
|
||||
// Derives name from location.
|
||||
llvm::StringRef GetNameFromLoc(mlir::Location loc) {
|
||||
std::string GetNameFromLoc(mlir::Location loc) {
|
||||
if (auto name_loc = loc.dyn_cast<mlir::NameLoc>())
|
||||
return name_loc.getName().strref();
|
||||
return name_loc.getName().str();
|
||||
|
||||
if (auto call_loc = loc.dyn_cast<mlir::CallSiteLoc>()) {
|
||||
// Return name if CallSiteLoc's callee has a NameLoc (as should be the case
|
||||
// if imported with DebugInfo), else use the fallback naming scheme below.
|
||||
if (auto name_loc = call_loc.getCallee().dyn_cast<mlir::NameLoc>())
|
||||
return name_loc.getName().strref();
|
||||
return name_loc.getName().str();
|
||||
}
|
||||
|
||||
return llvm::StringRef();
|
||||
if (auto fused_loc = loc.dyn_cast<mlir::FusedLoc>()) {
|
||||
llvm::ArrayRef<mlir::Location> locations = fused_loc.getLocations();
|
||||
std::vector<std::string> names;
|
||||
bool names_is_nonempty = false;
|
||||
for (const auto& loc : locations) {
|
||||
const std::string loc_name = GetNameFromLoc(loc);
|
||||
names.push_back(loc_name);
|
||||
if (!loc_name.empty()) {
|
||||
names_is_nonempty = true;
|
||||
}
|
||||
}
|
||||
if (names_is_nonempty) {
|
||||
return llvm::join(names.begin(), names.end(), ",");
|
||||
}
|
||||
}
|
||||
|
||||
return "";
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user