Add a debug flag to the quantization driver and also avoid quantization on non-float types
PiperOrigin-RevId: 292400638
Change-Id: Id8b9483fa0471f8b894c9a31d1c938a2835fa480
parent a59426a65c
commit 5a1377938c
tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc

@@ -23,6 +23,7 @@ limitations under the License.
 #include "llvm/Support/Casting.h"
+#include "llvm/Support/Debug.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/raw_ostream.h"
 #include "mlir/Dialect/QuantOps/QuantOps.h"  // TF:llvm-project
 #include "mlir/Dialect/QuantOps/QuantTypes.h"  // TF:llvm-project
 #include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
@@ -39,6 +40,8 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
 #include "tensorflow/core/platform/logging.h"
 
+#define DEBUG_TYPE "quantization-driver"
+
 namespace mlir {
 namespace quant {
 namespace {
@@ -281,6 +284,37 @@ class QuantizationDriver {
     cached.first->second = InitializeState(op, index, res, /*as_result=*/true);
   }
 
+  void DumpStates(Operation *current_op) {
+    if (current_op) {
+      llvm::errs() << "\n\n\n" << current_op->getName() << "\n";
+    }
+    fn_.walk([&](Operation *op) {
+      if (llvm::isa<quant::QuantizeCastOp>(op) ||
+          llvm::isa<quant::DequantizeCastOp>(op) || llvm::isa<ConstantOp>(op))
+        return;
+      if (current_op == op) llvm::errs() << "===>>>";
+      llvm::errs() << op->getName() << " : (";
+      for (auto i = 0; i < op->getNumOperands(); ++i) {
+        if (auto params = GetOperandQuantState(op, i).params)
+          params.print(llvm::errs());
+        else
+          op->getOperand(i).getType().cast<ShapedType>().getElementType().print(
+              llvm::errs());
+        llvm::errs() << ",";
+      }
+      llvm::errs() << ") -> (";
+      for (auto i = 0; i < op->getNumResults(); ++i) {
+        if (auto params = GetResultQuantState(op, i).params)
+          params.print(llvm::errs());
+        else
+          op->getResult(i).getType().cast<ShapedType>().getElementType().print(
+              llvm::errs());
+        llvm::errs() << ",";
+      }
+      llvm::errs() << ")\n";
+    });
+  }
+
   FuncOp fn_;
   OpBuilder builder_;
   bool is_signed_;
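Reading off the print calls above: each walked op is rendered as name : (operand states,) -> (result states,), where annotated values print their quantization parameters and unannotated ones fall back to their element type; ===>>> marks the op currently taken from the work list, and quantize/dequantize casts and constants are skipped. Illustrative output only (ops and parameter values invented for the example):

    tfl.average_pool_2d : (f32,) -> (f32,)
    ===>>>tfl.conv_2d : (!quant.uniform<u8:f32, 1.000000e-01:128>,f32,f32,) -> (f32,)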
@@ -712,6 +746,8 @@ bool QuantizationDriver::PropagateParams() {
     Operation *op = work_list_.back();
     work_list_.pop_back();
 
+    LLVM_DEBUG(DumpStates(op));
+
     // This op has been quantized, so we should not consider it again.
     if (llvm::is_contained(quantized_, op)) continue;
     quantized_.insert(op);
@@ -736,12 +772,23 @@ bool QuantizationDriver::PropagateParams() {
     }
 
     // Use the final state to set all the operands' parameters.
-    for (int i = 0, e = op->getNumOperands(); i != e; ++i)
-      changed |= SetOperandParams(op, i, params);
+    for (int i = 0, e = op->getNumOperands(); i != e; ++i) {
+      if (auto type = op->getOperand(i).getType().dyn_cast<ShapedType>()) {
+        // Without this check, it will accidentally propagate the quantization
+        // information through the shared non-float tensors.
+        if (type.getElementType().isa<FloatType>())
+          changed |= SetOperandParams(op, i, params);
+      }
+    }
 
     // Use the final state to set all the results' parameters.
     for (int res = 0, e = op->getNumResults(); res != e; ++res)
-      changed |= SetResultParams(op, res, params);
+      if (auto type = op->getResult(res).getType().dyn_cast<ShapedType>()) {
+        // Without this check, it will accidentally propagate the quantization
+        // information through the shared non-float tensors.
+        if (type.getElementType().isa<FloatType>())
+          changed |= SetResultParams(op, res, params);
+      }
   }
 
   // TODO(fengliuai): make the bit width configurable.
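The float-element guard appears verbatim in both the operand and the result loop; factored into a standalone predicate, it reads as below (the helper name IsQuantizableValue is hypothetical, not part of this change):

    // Sketch only: true iff the value is a shaped tensor with a float element
    // type -- the only values whose quantization state the driver should set.
    static bool IsQuantizableValue(Value value) {
      auto type = value.getType().dyn_cast<ShapedType>();
      // Rejects non-shaped values, and shaped values with integer element
      // types such as the i32 paddings operand shared by two tfl.pad ops.
      return type && type.getElementType().isa<FloatType>();
    }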
tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir

@@ -242,6 +242,22 @@ func @QuantizePad(tensor<2x1x3x!quant.uniform<u8:f32, 0.1>>, tensor<3x2xi32>) -> tensor<?xf32> {
 // CHECK: return %3 : tensor<?xf32>
 }
 
+// CHECK-LABEL: QuantizePad2
+// Only the second tfl.pad has sufficient quantization information.
+func @QuantizePad2(tensor<2x1x3x!quant.uniform<u8:f32, 0.1>>, tensor<2x1x3xf32>, tensor<3x2xi32>) -> (tensor<?xf32>, tensor<?xf32>) {
+^bb0(%arg0: tensor<2x1x3x!quant.uniform<u8:f32, 0.1>>, %arg1: tensor<2x1x3xf32>, %arg2: tensor<3x2xi32>):
+  %0 = "tfl.dequantize"(%arg0) : (tensor<2x1x3x!quant.uniform<u8:f32, 0.1>>) -> tensor<2x1x3xf32>
+  %1 = "tfl.pad"(%arg1, %arg2) : (tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor<?xf32>
+  %2 = "tfl.pad"(%0, %arg2) : (tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor<?xf32>
+  return %1, %2 : tensor<?xf32>, tensor<?xf32>
+
+// CHECK: %[[dq:.*]] = "tfl.dequantize"(%arg0)
+// CHECK: %[[pad1:.*]] = "tfl.pad"(%arg1, %arg2)
+// CHECK: %[[pad2:.*]] = "tfl.pad"(%[[dq]], %arg2)
+// CHECK: %[[q2:.*]] = "tfl.quantize"(%[[pad2]])
+// CHECK: %[[dq2:.*]] = "tfl.dequantize"(%[[q2]])
+}
+
 // CHECK-LABEL: QuantizeReshape2D
 func @QuantizeReshape2D(tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x36x16xf32> {
 ^bb0(%arg0: tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>>):
@@ -418,16 +434,15 @@ func @QuantizeConcatResToAllRequantizeArg(tensor<1x2x!quant.uniform<u8:f32, 2.0:
 // CHECK: return %[[Q]] : tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
 }
 
-// CHECK-LABEL: RequantizeAlreadyQuantizedModel
-func @RequantizeAlreadyQuantizedModel(%arg0: tensor<1x73x73x64x!quant.uniform<u8:f32, 1.0>>, %arg1: tensor<1x147x147x96x!quant.uniform<u8:f32, 2.0>>) -> tensor<1x73x73x160x!quant.uniform<u8:f32, 1.0>> {
+// CHECK-LABEL: NotRequantizeAlreadyQuantizedModel
+func @NotRequantizeAlreadyQuantizedModel(%arg0: tensor<1x73x73x64x!quant.uniform<u8:f32, 1.0>>, %arg1: tensor<1x147x147x96x!quant.uniform<u8:f32, 2.0>>) -> tensor<1x73x73x160x!quant.uniform<u8:f32, 1.0>> {
   %9 = "tfl.max_pool_2d"(%arg1) {filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x147x147x96x!quant.uniform<u8:f32, 2.0>>) -> tensor<1x73x73x96x!quant.uniform<u8:f32, 2.0>>
   %10 = "tfl.concatenation"(%arg0, %9) {axis = 3 : i32, fused_activation_function = "NONE"} : (tensor<1x73x73x64x!quant.uniform<u8:f32, 1.0>>, tensor<1x73x73x96x!quant.uniform<u8:f32, 2.0>>) -> tensor<1x73x73x160x!quant.uniform<u8:f32, 1.0>>
   return %10 : tensor<1x73x73x160x!quant.uniform<u8:f32, 1.0>>
 
-// CHECK: %0 = "tfl.max_pool_2d"(%arg1) {filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x147x147x96x!quant.uniform<u8:f32, 2.000000e+00>>) -> tensor<1x73x73x96x!quant.uniform<u8:f32, 2.000000e+00>>
-// CHECK: %1 = "tfl.quantize"(%0) {qtype = tensor<1x73x73x96x!quant.uniform<u8:f32, 1.000000e+00>>} : (tensor<1x73x73x96x!quant.uniform<u8:f32, 2.000000e+00>>) -> tensor<1x73x73x96x!quant.uniform<u8:f32, 1.000000e+00>>
-// CHECK: %2 = "tfl.concatenation"(%arg0, %1) {axis = 3 : i32, fused_activation_function = "NONE"} : (tensor<1x73x73x64x!quant.uniform<u8:f32, 1.000000e+00>>, tensor<1x73x73x96x!quant.uniform<u8:f32, 1.000000e+00>>) -> tensor<1x73x73x160x!quant.uniform<u8:f32, 1.000000e+00>>
-// CHECK: return %2 : tensor<1x73x73x160x!quant.uniform<u8:f32, 1.000000e+00>>
+// CHECK: %[[max:.*]] = "tfl.max_pool_2d"(%arg1) {filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x147x147x96x!quant.uniform<u8:f32, 2.000000e+00>>) -> tensor<1x73x73x96x!quant.uniform<u8:f32, 2.000000e+00>>
+// CHECK: %[[cat:.*]] = "tfl.concatenation"(%arg0, %[[max]]) {axis = 3 : i32, fused_activation_function = "NONE"} : (tensor<1x73x73x64x!quant.uniform<u8:f32, 1.000000e+00>>, tensor<1x73x73x96x!quant.uniform<u8:f32, 2.000000e+00>>) -> tensor<1x73x73x160x!quant.uniform<u8:f32, 1.000000e+00>>
+// CHECK: return %[[cat]] : tensor<1x73x73x160x!quant.uniform<u8:f32, 1.000000e+00>>
 }
 
 // CHECK-LABEL: QuantizeChain
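For reference, the tests above are exercised through a RUN line at the top of the .mlir file and verified with FileCheck; a sketch of the typical invocation (the exact flags used by this file are an assumption):

    // RUN: tf-opt %s -tfl-prepare-quantize | FileCheck %s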