Import quantization stats by using locations

Since we have removed the "name" attribute in the tf ops in the tf importer,
the quantization stats should be specified by the named location.

Since there are chances that op locations are changed over transformations,
this pass is only for debugging purpose.

PiperOrigin-RevId: 289104435
Change-Id: Ie6ed389b761b71eba4d33779e8588cda0e532d19
This commit is contained in:
Feng Liu 2020-01-10 09:12:48 -08:00 committed by TensorFlower Gardener
parent f146ef1740
commit d88b067ef1
2 changed files with 18 additions and 7 deletions

View File

@ -206,10 +206,17 @@ std::unique_ptr<OpPassBase<FuncOp>> CreateImportQuantStatsPass(
std::unique_ptr<OpPassBase<FuncOp>>
CreateImportQuantStatsPassForTFControlDialect(const std::string &stats_str) {
auto get_name_func = [](Operation *op) {
if (auto name = op->getAttrOfType<StringAttr>("name"))
return name.getValue();
else
return llvm::StringRef("");
Location loc = op->getLoc();
if (auto name = loc.dyn_cast<NameLoc>()) {
return name.getName().strref();
} else if (auto fused_name = loc.dyn_cast<FusedLoc>()) {
for (auto sub_loc : fused_name.getLocations()) {
if (auto named_sub_loc = sub_loc.dyn_cast<NameLoc>()) {
return named_sub_loc.getName().strref();
}
}
}
return llvm::StringRef("");
};
return CreateImportQuantStatsPass(get_name_func, stats_str);

View File

@ -3,7 +3,8 @@
// CHECK-LABEL: import_stats_skip
func @import_stats_skip(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor<2xf32>,tensor<2xf32>) {
%0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32, name = "skip"} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>)
%0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>)
loc(fused["skip1", "skip2.cc":10:8, callsite("op" at "skip3.cc":10:8)])
return %0#0, %0#1 : tensor<2xf32>, tensor<2xf32>
// CHECK-NEXT: "tfl.split"
@ -12,7 +13,8 @@ func @import_stats_skip(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor<2xf3
// CHECK-LABEL: import_stats_name
func @import_stats_name(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor<2xf32>,tensor<2xf32>) {
%0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32, name = "op"} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>)
%0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>)
loc(fused["skip1.cc":10:8, "op", callsite("skip2" at "skip3.cc":10:8)])
return %0#0, %0#1 : tensor<2xf32>, tensor<2xf32>
// CHECK-NEXT: %[[split:.*]]:2 = "tfl.split"
@ -23,7 +25,8 @@ func @import_stats_name(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor<2xf3
// CHECK-LABEL: import_stats_name_port
func @import_stats_name_port(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor<2xf32>,tensor<2xf32>) {
%0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32, name = "op_0"} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>)
%0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>)
loc(fused["skip1.cc":10:8, "op_0", callsite("skip2" at "skip3.cc":10:8)])
return %0#0, %0#1 : tensor<2xf32>, tensor<2xf32>
// CHECK-NEXT: %[[split:.*]]:2 = "tfl.split"
@ -34,6 +37,7 @@ func @import_stats_name_port(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor
// CHECK-LABEL: import_stats_name_regex
func @import_stats_name_regex(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor<2xf32>,tensor<2xf32>) {
%0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32, name = "op_regex"} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>)
loc(fused["skip1.cc":10:8, "op_regex", callsite("skip2" at "skip3.cc":10:8)])
return %0#0, %0#1 : tensor<2xf32>, tensor<2xf32>
// CHECK-NEXT: %[[split:.*]]:2 = "tfl.split"