Add custom FusedLoc naming in OpOrArgNameMapper.
PiperOrigin-RevId: 279430057 Change-Id: I5caf9227285a8dcc8aa4387007a0d882d9d4cb42
This commit is contained in:
parent
755aec33ef
commit
43d77b42e7
@ -484,7 +484,7 @@ node {
|
|||||||
# CHECK: shape: [ 186 ],
|
# CHECK: shape: [ 186 ],
|
||||||
# CHECK: type: INT32,
|
# CHECK: type: INT32,
|
||||||
# CHECK: buffer: 3,
|
# CHECK: buffer: 3,
|
||||||
# CHECK: name: "tfl.pseudo_qconst",
|
# CHECK: name: "BoxPredictor_4/ClassPredictor/BiasAdd,BoxPredictor_4/ClassPredictor/Conv2D,BoxPredictor_4/ClassPredictor/biases",
|
||||||
# CHECK: quantization: {
|
# CHECK: quantization: {
|
||||||
# CHECK: scale: [ 0.027216, 0.00038, 0.000413, 0.000426, 0.001607,
|
# CHECK: scale: [ 0.027216, 0.00038, 0.000413, 0.000426, 0.001607,
|
||||||
# CHECK: zero_point: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
# CHECK: zero_point: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
@ -493,7 +493,7 @@ node {
|
|||||||
# CHECK: shape: [ 186, 1, 1, 256 ],
|
# CHECK: shape: [ 186, 1, 1, 256 ],
|
||||||
# CHECK: type: INT8,
|
# CHECK: type: INT8,
|
||||||
# CHECK: buffer: 4,
|
# CHECK: buffer: 4,
|
||||||
# CHECK: name: "tfl.pseudo_qconst1",
|
# CHECK: name: "BoxPredictor_4/ClassPredictor/Conv2D,BoxPredictor_4/ClassPredictor/weights_quant/FakeQuantWithMinMaxVarsPerChannel",
|
||||||
# CHECK: quantization: {
|
# CHECK: quantization: {
|
||||||
# CHECK: scale: [ 0.12581, 0.001755, 0.001908, 0.001967, 0.007431,
|
# CHECK: scale: [ 0.12581, 0.001755, 0.001908, 0.001967, 0.007431,
|
||||||
# CHECK: zero_point: [ 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
# CHECK: zero_point: [ 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
@ -502,7 +502,7 @@ node {
|
|||||||
# CHECK: shape: [ 1, 1, 1, 186 ],
|
# CHECK: shape: [ 1, 1, 1, 186 ],
|
||||||
# CHECK: type: INT8,
|
# CHECK: type: INT8,
|
||||||
# CHECK: buffer: 5,
|
# CHECK: buffer: 5,
|
||||||
# CHECK: name: "tfl.conv_2d",
|
# CHECK: name: "BoxPredictor_4/ClassPredictor/BiasAdd,BoxPredictor_4/ClassPredictor/Conv2D,BoxPredictor_4/ClassPredictor/biases1",
|
||||||
# CHECK: quantization: {
|
# CHECK: quantization: {
|
||||||
# CHECK: scale: [ 0.093635 ],
|
# CHECK: scale: [ 0.093635 ],
|
||||||
# CHECK: zero_point: [ 22 ]
|
# CHECK: zero_point: [ 22 ]
|
||||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "llvm/ADT/APInt.h"
|
#include "llvm/ADT/APInt.h"
|
||||||
#include "llvm/ADT/SmallString.h"
|
#include "llvm/ADT/SmallString.h"
|
||||||
|
#include "llvm/ADT/StringExtras.h"
|
||||||
#include "llvm/ADT/StringRef.h"
|
#include "llvm/ADT/StringRef.h"
|
||||||
#include "mlir/IR/Location.h" // TF:local_config_mlir
|
#include "mlir/IR/Location.h" // TF:local_config_mlir
|
||||||
#include "mlir/IR/Operation.h" // TF:local_config_mlir
|
#include "mlir/IR/Operation.h" // TF:local_config_mlir
|
||||||
@ -72,18 +73,34 @@ bool OpOrArgNameMapper::IsUnique(llvm::StringRef name) {
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
// Derives name from location.
|
// Derives name from location.
|
||||||
llvm::StringRef GetNameFromLoc(mlir::Location loc) {
|
std::string GetNameFromLoc(mlir::Location loc) {
|
||||||
if (auto name_loc = loc.dyn_cast<mlir::NameLoc>())
|
if (auto name_loc = loc.dyn_cast<mlir::NameLoc>())
|
||||||
return name_loc.getName().strref();
|
return name_loc.getName().str();
|
||||||
|
|
||||||
if (auto call_loc = loc.dyn_cast<mlir::CallSiteLoc>()) {
|
if (auto call_loc = loc.dyn_cast<mlir::CallSiteLoc>()) {
|
||||||
// Return name if CallSiteLoc's callee has a NameLoc (as should be the case
|
// Return name if CallSiteLoc's callee has a NameLoc (as should be the case
|
||||||
// if imported with DebugInfo), else use the fallback naming scheme below.
|
// if imported with DebugInfo), else use the fallback naming scheme below.
|
||||||
if (auto name_loc = call_loc.getCallee().dyn_cast<mlir::NameLoc>())
|
if (auto name_loc = call_loc.getCallee().dyn_cast<mlir::NameLoc>())
|
||||||
return name_loc.getName().strref();
|
return name_loc.getName().str();
|
||||||
}
|
}
|
||||||
|
|
||||||
return llvm::StringRef();
|
if (auto fused_loc = loc.dyn_cast<mlir::FusedLoc>()) {
|
||||||
|
llvm::ArrayRef<mlir::Location> locations = fused_loc.getLocations();
|
||||||
|
std::vector<std::string> names;
|
||||||
|
bool names_is_nonempty = false;
|
||||||
|
for (const auto& loc : locations) {
|
||||||
|
const std::string loc_name = GetNameFromLoc(loc);
|
||||||
|
names.push_back(loc_name);
|
||||||
|
if (!loc_name.empty()) {
|
||||||
|
names_is_nonempty = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (names_is_nonempty) {
|
||||||
|
return llvm::join(names.begin(), names.end(), ",");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return "";
|
||||||
}
|
}
|
||||||
} // anonymous namespace
|
} // anonymous namespace
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user