From 5a1377938c0c909ebd483bed37f6553a401fa9d6 Mon Sep 17 00:00:00 2001
From: Feng Liu
Date: Thu, 30 Jan 2020 12:22:52 -0800
Subject: [PATCH] Add a debug flag to the quantization driver and also avoid
 quantization on non-float types

PiperOrigin-RevId: 292400638
Change-Id: Id8b9483fa0471f8b894c9a31d1c938a2835fa480
---
 .../lite/quantization/quantization_driver.cc  | 53 +++++++++++++++++--
 .../mlir/lite/tests/prepare-quantize.mlir     | 27 +++++++---
 2 files changed, 71 insertions(+), 9 deletions(-)

diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc
index 2e134396d49..b2355b2ae6e 100644
--- a/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc
+++ b/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc
@@ -23,6 +23,7 @@ limitations under the License.
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/raw_ostream.h"
 #include "mlir/Dialect/QuantOps/QuantOps.h"  // TF:llvm-project
 #include "mlir/Dialect/QuantOps/QuantTypes.h"  // TF:llvm-project
 #include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
@@ -39,6 +40,8 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
 #include "tensorflow/core/platform/logging.h"

+#define DEBUG_TYPE "quantization-driver"
+
 namespace mlir {
 namespace quant {
 namespace {
@@ -281,6 +284,37 @@ class QuantizationDriver {
     cached.first->second = InitializeState(op, index, res, /*as_result=*/true);
   }

+  void DumpStates(Operation *current_op) {
+    if (current_op) {
+      llvm::errs() << "\n\n\n" << current_op->getName() << "\n";
+    }
+    fn_.walk([&](Operation *op) {
+      if (llvm::isa<quant::QuantizeCastOp>(op) ||
+          llvm::isa<quant::DequantizeCastOp>(op) || llvm::isa<ConstantOp>(op))
+        return;
+      if (current_op == op) llvm::errs() << "===>>>";
+      llvm::errs() << op->getName() << " : (";
+      for (auto i = 0; i < op->getNumOperands(); ++i) {
+        if (auto params = GetOperandQuantState(op, i).params)
+          params.print(llvm::errs());
+        else
+          op->getOperand(i).getType().cast<ShapedType>().getElementType().print(
+              llvm::errs());
+        llvm::errs() << ",";
+      }
+      llvm::errs() << ") -> (";
+      for (auto i = 0; i < op->getNumResults(); ++i) {
+        if (auto params = GetResultQuantState(op, i).params)
+          params.print(llvm::errs());
+        else
+          op->getResult(i).getType().cast<ShapedType>().getElementType().print(
+              llvm::errs());
+        llvm::errs() << ",";
+      }
+      llvm::errs() << ")\n";
+    });
+  }
+
   FuncOp fn_;
   OpBuilder builder_;
   bool is_signed_;
@@ -712,6 +746,8 @@ bool QuantizationDriver::PropagateParams() {
     Operation *op = work_list_.back();
     work_list_.pop_back();

+    LLVM_DEBUG(DumpStates(op));
+
     // This op has been quantized, so we should not consider it again.
     if (llvm::is_contained(quantized_, op)) continue;
     quantized_.insert(op);
@@ -736,12 +772,23 @@ bool QuantizationDriver::PropagateParams() {
     }

     // Use the final state to set all the operands' parameters.
-    for (int i = 0, e = op->getNumOperands(); i != e; ++i)
-      changed |= SetOperandParams(op, i, params);
+    for (int i = 0, e = op->getNumOperands(); i != e; ++i) {
+      if (auto type = op->getOperand(i).getType().dyn_cast<TensorType>()) {
+        // Without this check, the quantization parameters would accidentally
+        // be propagated through shared non-float tensors.
+        if (type.getElementType().isa<FloatType>())
+          changed |= SetOperandParams(op, i, params);
+      }
+    }

     // Use the final state to set all the results' parameters.
     for (int res = 0, e = op->getNumResults(); res != e; ++res)
-      changed |= SetResultParams(op, res, params);
+      if (auto type = op->getResult(res).getType().dyn_cast<TensorType>()) {
+        // Without this check, the quantization parameters would accidentally
+        // be propagated through shared non-float tensors.
+        if (type.getElementType().isa<FloatType>())
+          changed |= SetResultParams(op, res, params);
+      }
   }

   // TODO(fengliuai): make the bit width configurable.
diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir
index fc9c55089a3..9ae61357c09 100644
--- a/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir
@@ -242,6 +242,22 @@ func @QuantizePad(tensor<2x1x3x!quant.uniform<u8:f32, 0.1>>, tensor<3x2xi32>) -
 // CHECK: return %3 : tensor<?xf32>
 }

+// CHECK-LABEL: QuantizePad2
+// Only the second tfl.pad has sufficient quantization information.
+func @QuantizePad2(tensor<2x1x3x!quant.uniform<u8:f32, 0.1>>, tensor<2x1x3xf32>, tensor<3x2xi32>) -> (tensor<?xf32>, tensor<?xf32>) {
+^bb0(%arg0: tensor<2x1x3x!quant.uniform<u8:f32, 0.1>>, %arg1: tensor<2x1x3xf32>, %arg2: tensor<3x2xi32>):
+  %0 = "tfl.dequantize"(%arg0) : (tensor<2x1x3x!quant.uniform<u8:f32, 0.1>>) -> tensor<2x1x3xf32>
+  %1 = "tfl.pad"(%arg1, %arg2) : (tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor<?xf32>
+  %2 = "tfl.pad"(%0, %arg2) : (tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor<?xf32>
+  return %1, %2 : tensor<?xf32>, tensor<?xf32>
+
+// CHECK: %[[dq:.*]] = "tfl.dequantize"(%arg0)
+// CHECK: %[[pad1:.*]] = "tfl.pad"(%arg1, %arg2)
+// CHECK: %[[pad2:.*]] = "tfl.pad"(%[[dq]], %arg2)
+// CHECK: %[[q2:.*]] = "tfl.quantize"(%[[pad2]])
+// CHECK: %[[dq2:.*]] = "tfl.dequantize"(%[[q2]])
+}
+
 // CHECK-LABEL: QuantizeReshape2D
 func @QuantizeReshape2D(tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x36x16xf32> {
 ^bb0(%arg0: tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>>):
@@ -418,16 +434,15 @@ func @QuantizeConcatResToAllRequantizeArg(tensor<1x2x!quant.uniform
 }

-// CHECK-LABEL: RequantizeAlreadyQuantizedModel
-func @RequantizeAlreadyQuantizedModel(%arg0: tensor<1x73x73x64x!quant.uniform<u8:f32, 1.0>>, %arg1: tensor<1x147x147x96x!quant.uniform<u8:f32, 2.0>>) -> tensor<1x73x73x160x!quant.uniform<u8:f32, 1.0>> {
+// CHECK-LABEL: NotRequantizeAlreadyQuantizedModel
+func @NotRequantizeAlreadyQuantizedModel(%arg0: tensor<1x73x73x64x!quant.uniform<u8:f32, 1.0>>, %arg1: tensor<1x147x147x96x!quant.uniform<u8:f32, 2.0>>) -> tensor<1x73x73x160x!quant.uniform<u8:f32, 1.0>> {
   %9 = "tfl.max_pool_2d"(%arg1) {filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x147x147x96x!quant.uniform<u8:f32, 2.0>>) -> tensor<1x73x73x96x!quant.uniform<u8:f32, 2.0>>
   %10 = "tfl.concatenation"(%arg0, %9) {axis = 3 : i32, fused_activation_function = "NONE"} : (tensor<1x73x73x64x!quant.uniform<u8:f32, 1.0>>, tensor<1x73x73x96x!quant.uniform<u8:f32, 2.0>>) -> tensor<1x73x73x160x!quant.uniform<u8:f32, 1.0>>
   return %10 : tensor<1x73x73x160x!quant.uniform<u8:f32, 1.0>>

-// CHECK: %0 = "tfl.max_pool_2d"(%arg1) {filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x147x147x96x!quant.uniform<u8:f32, 2.0>>) -> tensor<1x73x73x96x!quant.uniform<u8:f32, 2.0>>
-// CHECK: %1 = "tfl.quantize"(%0) {qtype = tensor<1x73x73x96x!quant.uniform<u8:f32, 1.0>>} : (tensor<1x73x73x96x!quant.uniform<u8:f32, 2.0>>) -> tensor<1x73x73x96x!quant.uniform<u8:f32, 1.0>>
-// CHECK: %2 = "tfl.concatenation"(%arg0, %1) {axis = 3 : i32, fused_activation_function = "NONE"} : (tensor<1x73x73x64x!quant.uniform<u8:f32, 1.0>>, tensor<1x73x73x96x!quant.uniform<u8:f32, 1.0>>) -> tensor<1x73x73x160x!quant.uniform<u8:f32, 1.0>>
-// CHECK: return %2 : tensor<1x73x73x160x!quant.uniform<u8:f32, 1.0>>
+// CHECK: %[[max:.*]] = "tfl.max_pool_2d"(%arg1) {filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x147x147x96x!quant.uniform<u8:f32, 2.0>>) -> tensor<1x73x73x96x!quant.uniform<u8:f32, 2.0>>
+// CHECK: %[[cat:.*]] = "tfl.concatenation"(%arg0, %[[max]]) {axis = 3 : i32, fused_activation_function = "NONE"} : (tensor<1x73x73x64x!quant.uniform<u8:f32, 1.0>>, tensor<1x73x73x96x!quant.uniform<u8:f32, 2.0>>) -> tensor<1x73x73x160x!quant.uniform<u8:f32, 1.0>>
+// CHECK: return %[[cat]] : tensor<1x73x73x160x!quant.uniform<u8:f32, 1.0>>
 }

 // CHECK-LABEL: QuantizeChain
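
Usage note: the state dump added by this patch is wrapped in LLVM_DEBUG and
keyed to DEBUG_TYPE "quantization-driver", so it is compiled in only when LLVM
assertions are enabled. A minimal sketch of turning it on, assuming the usual
tf-opt driver and the -tfl-prepare-quantize pass that exercises this driver
(neither invocation is part of the patch itself):

  tf-opt -tfl-prepare-quantize -debug-only=quantization-driver \
      tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir

With the flag set, each PropagateParams iteration prints every op in the
function along with the quantization parameters (or, if none, the plain
element type) of each operand and result, and marks the op currently being
processed with "===>>>".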