From d88b067ef1928e7afab0ede675ae27514416bff8 Mon Sep 17 00:00:00 2001 From: Feng Liu Date: Fri, 10 Jan 2020 09:12:48 -0800 Subject: [PATCH] Import quantization stats by using locations Since we have removed the "name" attribute in the tf ops in the tf importer, the quantization stats should be specified by the named location. Since there are chances that op locations are changed over transformations, this pass is only for debugging purpose. PiperOrigin-RevId: 289104435 Change-Id: Ie6ed389b761b71eba4d33779e8588cda0e532d19 --- .../lite/quantization/import_quant_stats_pass.cc | 15 +++++++++++---- .../quantization/tests/import_quant_stats.mlir | 10 +++++++--- 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc b/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc index 4c4d8f1d9a2..45e87e63475 100644 --- a/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc +++ b/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc @@ -206,10 +206,17 @@ std::unique_ptr> CreateImportQuantStatsPass( std::unique_ptr> CreateImportQuantStatsPassForTFControlDialect(const std::string &stats_str) { auto get_name_func = [](Operation *op) { - if (auto name = op->getAttrOfType("name")) - return name.getValue(); - else - return llvm::StringRef(""); + Location loc = op->getLoc(); + if (auto name = loc.dyn_cast()) { + return name.getName().strref(); + } else if (auto fused_name = loc.dyn_cast()) { + for (auto sub_loc : fused_name.getLocations()) { + if (auto named_sub_loc = sub_loc.dyn_cast()) { + return named_sub_loc.getName().strref(); + } + } + } + return llvm::StringRef(""); }; return CreateImportQuantStatsPass(get_name_func, stats_str); diff --git a/tensorflow/compiler/mlir/lite/quantization/tests/import_quant_stats.mlir b/tensorflow/compiler/mlir/lite/quantization/tests/import_quant_stats.mlir index e7c4f9a27b2..248ccb265ab 100644 --- a/tensorflow/compiler/mlir/lite/quantization/tests/import_quant_stats.mlir +++ b/tensorflow/compiler/mlir/lite/quantization/tests/import_quant_stats.mlir @@ -3,7 +3,8 @@ // CHECK-LABEL: import_stats_skip func @import_stats_skip(%arg0: tensor<4xf32>, %cst: tensor) -> (tensor<2xf32>,tensor<2xf32>) { - %0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32, name = "skip"} : (tensor, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>) + %0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32} : (tensor, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>) + loc(fused["skip1", "skip2.cc":10:8, callsite("op" at "skip3.cc":10:8)]) return %0#0, %0#1 : tensor<2xf32>, tensor<2xf32> // CHECK-NEXT: "tfl.split" @@ -12,7 +13,8 @@ func @import_stats_skip(%arg0: tensor<4xf32>, %cst: tensor) -> (tensor<2xf3 // CHECK-LABEL: import_stats_name func @import_stats_name(%arg0: tensor<4xf32>, %cst: tensor) -> (tensor<2xf32>,tensor<2xf32>) { - %0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32, name = "op"} : (tensor, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>) + %0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32} : (tensor, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>) + loc(fused["skip1.cc":10:8, "op", callsite("skip2" at "skip3.cc":10:8)]) return %0#0, %0#1 : tensor<2xf32>, tensor<2xf32> // CHECK-NEXT: %[[split:.*]]:2 = "tfl.split" @@ -23,7 +25,8 @@ func @import_stats_name(%arg0: tensor<4xf32>, %cst: tensor) -> (tensor<2xf3 // CHECK-LABEL: import_stats_name_port func @import_stats_name_port(%arg0: tensor<4xf32>, %cst: tensor) -> (tensor<2xf32>,tensor<2xf32>) { - %0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32, name = "op_0"} : (tensor, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>) + %0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32} : (tensor, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>) + loc(fused["skip1.cc":10:8, "op_0", callsite("skip2" at "skip3.cc":10:8)]) return %0#0, %0#1 : tensor<2xf32>, tensor<2xf32> // CHECK-NEXT: %[[split:.*]]:2 = "tfl.split" @@ -34,6 +37,7 @@ func @import_stats_name_port(%arg0: tensor<4xf32>, %cst: tensor) -> (tensor // CHECK-LABEL: import_stats_name_regex func @import_stats_name_regex(%arg0: tensor<4xf32>, %cst: tensor) -> (tensor<2xf32>,tensor<2xf32>) { %0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32, name = "op_regex"} : (tensor, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>) + loc(fused["skip1.cc":10:8, "op_regex", callsite("skip2" at "skip3.cc":10:8)]) return %0#0, %0#1 : tensor<2xf32>, tensor<2xf32> // CHECK-NEXT: %[[split:.*]]:2 = "tfl.split"