Import quantization stats by using locations
Since we have removed the "name" attribute in the tf ops in the tf importer, the quantization stats should be specified by the named location. Since there are chances that op locations are changed over transformations, this pass is only for debugging purpose. PiperOrigin-RevId: 289104435 Change-Id: Ie6ed389b761b71eba4d33779e8588cda0e532d19
This commit is contained in:
parent
f146ef1740
commit
d88b067ef1
@ -206,10 +206,17 @@ std::unique_ptr<OpPassBase<FuncOp>> CreateImportQuantStatsPass(
|
||||
std::unique_ptr<OpPassBase<FuncOp>>
|
||||
CreateImportQuantStatsPassForTFControlDialect(const std::string &stats_str) {
|
||||
auto get_name_func = [](Operation *op) {
|
||||
if (auto name = op->getAttrOfType<StringAttr>("name"))
|
||||
return name.getValue();
|
||||
else
|
||||
return llvm::StringRef("");
|
||||
Location loc = op->getLoc();
|
||||
if (auto name = loc.dyn_cast<NameLoc>()) {
|
||||
return name.getName().strref();
|
||||
} else if (auto fused_name = loc.dyn_cast<FusedLoc>()) {
|
||||
for (auto sub_loc : fused_name.getLocations()) {
|
||||
if (auto named_sub_loc = sub_loc.dyn_cast<NameLoc>()) {
|
||||
return named_sub_loc.getName().strref();
|
||||
}
|
||||
}
|
||||
}
|
||||
return llvm::StringRef("");
|
||||
};
|
||||
|
||||
return CreateImportQuantStatsPass(get_name_func, stats_str);
|
||||
|
@ -3,7 +3,8 @@
|
||||
|
||||
// CHECK-LABEL: import_stats_skip
|
||||
func @import_stats_skip(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor<2xf32>,tensor<2xf32>) {
|
||||
%0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32, name = "skip"} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>)
|
||||
%0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>)
|
||||
loc(fused["skip1", "skip2.cc":10:8, callsite("op" at "skip3.cc":10:8)])
|
||||
return %0#0, %0#1 : tensor<2xf32>, tensor<2xf32>
|
||||
|
||||
// CHECK-NEXT: "tfl.split"
|
||||
@ -12,7 +13,8 @@ func @import_stats_skip(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor<2xf3
|
||||
|
||||
// CHECK-LABEL: import_stats_name
|
||||
func @import_stats_name(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor<2xf32>,tensor<2xf32>) {
|
||||
%0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32, name = "op"} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>)
|
||||
%0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>)
|
||||
loc(fused["skip1.cc":10:8, "op", callsite("skip2" at "skip3.cc":10:8)])
|
||||
return %0#0, %0#1 : tensor<2xf32>, tensor<2xf32>
|
||||
|
||||
// CHECK-NEXT: %[[split:.*]]:2 = "tfl.split"
|
||||
@ -23,7 +25,8 @@ func @import_stats_name(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor<2xf3
|
||||
|
||||
// CHECK-LABEL: import_stats_name_port
|
||||
func @import_stats_name_port(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor<2xf32>,tensor<2xf32>) {
|
||||
%0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32, name = "op_0"} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>)
|
||||
%0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>)
|
||||
loc(fused["skip1.cc":10:8, "op_0", callsite("skip2" at "skip3.cc":10:8)])
|
||||
return %0#0, %0#1 : tensor<2xf32>, tensor<2xf32>
|
||||
|
||||
// CHECK-NEXT: %[[split:.*]]:2 = "tfl.split"
|
||||
@ -34,6 +37,7 @@ func @import_stats_name_port(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor
|
||||
// CHECK-LABEL: import_stats_name_regex
|
||||
func @import_stats_name_regex(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor<2xf32>,tensor<2xf32>) {
|
||||
%0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32, name = "op_regex"} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>)
|
||||
loc(fused["skip1.cc":10:8, "op_regex", callsite("skip2" at "skip3.cc":10:8)])
|
||||
return %0#0, %0#1 : tensor<2xf32>, tensor<2xf32>
|
||||
|
||||
// CHECK-NEXT: %[[split:.*]]:2 = "tfl.split"
|
||||
|
Loading…
Reference in New Issue
Block a user