Add a debug flag to the quantization driver and avoid quantization on non-float types

PiperOrigin-RevId: 292400638
Change-Id: Id8b9483fa0471f8b894c9a31d1c938a2835fa480
Feng Liu 2020-01-30 12:22:52 -08:00 committed by TensorFlower Gardener
parent a59426a65c
commit 5a1377938c
2 changed files with 71 additions and 9 deletions
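
The new debug output uses LLVM's standard `DEBUG_TYPE`/`LLVM_DEBUG` machinery: it is compiled in only when `NDEBUG` is not defined and, at runtime, prints only under `-debug` or `-debug-only=quantization-driver`. A minimal sketch of that mechanism (the function and message below are illustrative, not part of this commit):

```cpp
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "quantization-driver"

// Hypothetical example: the statement inside LLVM_DEBUG is compiled only in
// builds without NDEBUG, and executes only when the tool is invoked with
// `-debug` or `-debug-only=quantization-driver`.
void ExampleTrace(int work_list_size) {
  LLVM_DEBUG(llvm::errs() << "work list size: " << work_list_size << "\n");
}
```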


@@ -23,6 +23,7 @@ limitations under the License.
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/QuantOps/QuantOps.h"  // TF:llvm-project
#include "mlir/Dialect/QuantOps/QuantTypes.h"  // TF:llvm-project
#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
@@ -39,6 +40,8 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/core/platform/logging.h"

#define DEBUG_TYPE "quantization-driver"

namespace mlir {
namespace quant {
namespace {
@@ -281,6 +284,37 @@ class QuantizationDriver {
    cached.first->second = InitializeState(op, index, res, /*as_result=*/true);
  }
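  // Dumps the quantization parameters propagated so far. Every op in the
  // function, except quantize/dequantize casts and constants, is printed as
  // `name : (operand params) -> (result params)`, falling back to the tensor
  // element type when no parameters have been assigned yet; the op currently
  // being processed is marked with "===>>>".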
  void DumpStates(Operation *current_op) {
    if (current_op) {
      llvm::errs() << "\n\n\n" << current_op->getName() << "\n";
    }
    fn_.walk([&](Operation *op) {
      if (llvm::isa<quant::QuantizeCastOp>(op) ||
          llvm::isa<quant::DequantizeCastOp>(op) || llvm::isa<ConstantOp>(op))
        return;
      if (current_op == op) llvm::errs() << "===>>>";
      llvm::errs() << op->getName() << " : (";
      for (auto i = 0; i < op->getNumOperands(); ++i) {
        if (auto params = GetOperandQuantState(op, i).params)
          params.print(llvm::errs());
        else
          op->getOperand(i).getType().cast<ShapedType>().getElementType().print(
              llvm::errs());
        llvm::errs() << ",";
      }
      llvm::errs() << ") -> (";
      for (auto i = 0; i < op->getNumResults(); ++i) {
        if (auto params = GetResultQuantState(op, i).params)
          params.print(llvm::errs());
        else
          op->getResult(i).getType().cast<ShapedType>().getElementType().print(
              llvm::errs());
        llvm::errs() << ",";
      }
      llvm::errs() << ")\n";
    });
  }
  FuncOp fn_;
  OpBuilder builder_;
  bool is_signed_;
@@ -712,6 +746,8 @@ bool QuantizationDriver::PropagateParams() {
    Operation *op = work_list_.back();
    work_list_.pop_back();

    LLVM_DEBUG(DumpStates(op));

    // This op has been quantized, so we should not consider it again.
    if (llvm::is_contained(quantized_, op)) continue;
    quantized_.insert(op);
@@ -736,13 +772,24 @@ bool QuantizationDriver::PropagateParams() {
      }

      // Use the final state to set all the operands' parameters.
      for (int i = 0, e = op->getNumOperands(); i != e; ++i) {
        if (auto type = op->getOperand(i).getType().dyn_cast<ShapedType>()) {
          // Without this check, it will accidentally propagate the quantization
          // information through the shared non-float tensors.
          if (type.getElementType().isa<FloatType>())
            changed |= SetOperandParams(op, i, params);
        }
      }

      // Use the final state to set all the results' parameters.
      for (int res = 0, e = op->getNumResults(); res != e; ++res)
        if (auto type = op->getResult(res).getType().dyn_cast<ShapedType>()) {
          // Without this check, it will accidentally propagate the quantization
          // information through the shared non-float tensors.
          if (type.getElementType().isa<FloatType>())
            changed |= SetResultParams(op, res, params);
        }
    }

    // TODO(fengliuai): make the bit width configurable.
    auto spec = GetQuantSpec(op);
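
The two guards above encode the same rule: quantization parameters are only propagated to tensors whose element type is floating point, so shared integer operands (such as the `tensor<3x2xi32>` paddings of `tfl.pad`) no longer leak parameters between ops. A standalone sketch of that predicate, assuming the MLIR headers of this vintage (`IsQuantizableTensor` is illustrative, not part of this commit):

```cpp
#include "mlir/IR/StandardTypes.h"  // TF:llvm-project
#include "mlir/IR/Types.h"          // TF:llvm-project

// Illustrative helper: returns true only for shaped types with a
// floating-point element type, i.e. tensors that may carry quantization
// parameters. Integer tensors (e.g. tfl.pad's paddings operand) return false.
static bool IsQuantizableTensor(mlir::Type type) {
  auto shaped = type.dyn_cast<mlir::ShapedType>();
  return shaped && shaped.getElementType().isa<mlir::FloatType>();
}
```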


@@ -242,6 +242,22 @@ func @QuantizePad(tensor<2x1x3x!quant.uniform<u8:f32, 0.1>>, tensor<3x2xi32>) ->
// CHECK: return %3 : tensor<?xf32>
}
// CHECK-LABEL: QuantizePad2
// Only the second tfl.pad has sufficient quantization information.
func @QuantizePad2(tensor<2x1x3x!quant.uniform<u8:f32, 0.1>>, tensor<2x1x3xf32>, tensor<3x2xi32>) -> (tensor<?xf32>, tensor<?xf32>) {
^bb0(%arg0: tensor<2x1x3x!quant.uniform<u8:f32, 0.1>>, %arg1: tensor<2x1x3xf32>, %arg2: tensor<3x2xi32>):
  %0 = "tfl.dequantize"(%arg0) : (tensor<2x1x3x!quant.uniform<u8:f32, 0.1>>) -> tensor<2x1x3xf32>
  %1 = "tfl.pad"(%arg1, %arg2) : (tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor<?xf32>
  %2 = "tfl.pad"(%0, %arg2) : (tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor<?xf32>
  return %1, %2 : tensor<?xf32>, tensor<?xf32>

// CHECK: %[[dq:.*]] = "tfl.dequantize"(%arg0)
// CHECK: %[[pad1:.*]] = "tfl.pad"(%arg1, %arg2)
// CHECK: %[[pad2:.*]] = "tfl.pad"(%[[dq]], %arg2)
// CHECK: %[[q2:.*]] = "tfl.quantize"(%[[pad2]])
// CHECK: %[[dq2:.*]] = "tfl.dequantize"(%[[q2]])
}
// CHECK-LABEL: QuantizeReshape2D
func @QuantizeReshape2D(tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x36x16xf32> {
^bb0(%arg0: tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>>):
@@ -418,16 +434,15 @@ func @QuantizeConcatResToAllRequantizeArg(tensor<1x2x!quant.uniform<u8:f32, 2.0:
// CHECK: return %[[Q]] : tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
}
-// CHECK-LABEL: RequantizeAlreadyQuantizedModel
-func @RequantizeAlreadyQuantizedModel(%arg0: tensor<1x73x73x64x!quant.uniform<u8:f32, 1.0>>, %arg1: tensor<1x147x147x96x!quant.uniform<u8:f32, 2.0>>) -> tensor<1x73x73x160x!quant.uniform<u8:f32, 1.0>> {
+// CHECK-LABEL: NotRequantizeAlreadyQuantizedModel
+func @NotRequantizeAlreadyQuantizedModel(%arg0: tensor<1x73x73x64x!quant.uniform<u8:f32, 1.0>>, %arg1: tensor<1x147x147x96x!quant.uniform<u8:f32, 2.0>>) -> tensor<1x73x73x160x!quant.uniform<u8:f32, 1.0>> {
   %9 = "tfl.max_pool_2d"(%arg1) {filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x147x147x96x!quant.uniform<u8:f32, 2.0>>) -> tensor<1x73x73x96x!quant.uniform<u8:f32, 2.0>>
   %10 = "tfl.concatenation"(%arg0, %9) {axis = 3 : i32, fused_activation_function = "NONE"} : (tensor<1x73x73x64x!quant.uniform<u8:f32, 1.0>>, tensor<1x73x73x96x!quant.uniform<u8:f32, 2.0>>) -> tensor<1x73x73x160x!quant.uniform<u8:f32, 1.0>>
   return %10 : tensor<1x73x73x160x!quant.uniform<u8:f32, 1.0>>
-// CHECK: %0 = "tfl.max_pool_2d"(%arg1) {filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x147x147x96x!quant.uniform<u8:f32, 2.000000e+00>>) -> tensor<1x73x73x96x!quant.uniform<u8:f32, 2.000000e+00>>
-// CHECK: %1 = "tfl.quantize"(%0) {qtype = tensor<1x73x73x96x!quant.uniform<u8:f32, 1.000000e+00>>} : (tensor<1x73x73x96x!quant.uniform<u8:f32, 2.000000e+00>>) -> tensor<1x73x73x96x!quant.uniform<u8:f32, 1.000000e+00>>
-// CHECK: %2 = "tfl.concatenation"(%arg0, %1) {axis = 3 : i32, fused_activation_function = "NONE"} : (tensor<1x73x73x64x!quant.uniform<u8:f32, 1.000000e+00>>, tensor<1x73x73x96x!quant.uniform<u8:f32, 1.000000e+00>>) -> tensor<1x73x73x160x!quant.uniform<u8:f32, 1.000000e+00>>
-// CHECK: return %2 : tensor<1x73x73x160x!quant.uniform<u8:f32, 1.000000e+00>>
+// CHECK: %[[max:.*]] = "tfl.max_pool_2d"(%arg1) {filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x147x147x96x!quant.uniform<u8:f32, 2.000000e+00>>) -> tensor<1x73x73x96x!quant.uniform<u8:f32, 2.000000e+00>>
+// CHECK: %[[cat:.*]] = "tfl.concatenation"(%arg0, %[[max]]) {axis = 3 : i32, fused_activation_function = "NONE"} : (tensor<1x73x73x64x!quant.uniform<u8:f32, 1.000000e+00>>, tensor<1x73x73x96x!quant.uniform<u8:f32, 2.000000e+00>>) -> tensor<1x73x73x160x!quant.uniform<u8:f32, 1.000000e+00>>
+// CHECK: return %[[cat]] : tensor<1x73x73x160x!quant.uniform<u8:f32, 1.000000e+00>>
}
// CHECK-LABEL: QuantizeChain