Propagate the sign of the inference type into the quantization passes

The FakeQuant* ops in the inference graph don't carry "sign" information, so
"unsigned" is used during legalization when converting these FakeQuant* ops to
quantized types. The quantization passes need to generate ops for the target
inference type and its sign; if the inference type is signed, all the unsigned
quantized types therefore need to be converted to signed ones. A rewrite
pattern is added for this purpose.
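
For illustration, a minimal standalone sketch of the zero-point shift the new
rewrite pattern performs (the constants mirror the 8-bit test below; this
snippet is not part of the change):

#include <cassert>
#include <cstdint>

// Shifting both the stored value and the zero point by the same offset keeps
// the represented real value, scale * (q - zero_point), and the scale itself
// unchanged.
int main() {
  const int64_t offset = 0 - (-128);     // uint8 min - int8 min = 128
  const double scale = 0.21632751372549019;
  const int64_t zp_u8 = 155;             // unsigned zero point from the test
  const int64_t q_u8 = 200;              // an arbitrary stored uint8 value
  const int64_t zp_i8 = zp_u8 - offset;  // = 27, the signed zero point
  const int64_t q_i8 = q_u8 - offset;    // = 72
  assert(scale * (q_u8 - zp_u8) == scale * (q_i8 - zp_i8));
  return 0;
}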

PiperOrigin-RevId: 272089485
Feng Liu 2019-09-30 16:17:57 -07:00 committed by TensorFlower Gardener
parent 48d245cbe3
commit b81191bc81
10 changed files with 180 additions and 58 deletions

View File

@@ -336,6 +336,7 @@ cc_library(
":validators",
"//tensorflow/compiler/mlir/lite/quantization:quantization_config",
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"//tensorflow/core:protos_all_proto_cc",
"@com_google_absl//absl/memory",
"@llvm//:support",
"@local_config_mlir//:Analysis",

View File

@@ -42,6 +42,8 @@ DataType ConvertIODataTypeToDataType(toco::IODataType dtype) {
return DT_FLOAT;
case toco::IODataType::QUANTIZED_UINT8:
return DT_QUINT8;
case toco::IODataType::INT8:
return DT_QINT8;
case toco::IODataType::INT32:
return DT_INT32;
case toco::IODataType::INT64:

View File

@@ -22,11 +22,8 @@ limitations under the License.
#include "llvm/Support/raw_ostream.h"
#include "tensorflow/core/framework/types.pb.h"
namespace mlir {
namespace TFL {
// Is this dtype a quantization type from TensorFlow.
bool IsQuantizationType(tensorflow::DataType dtype) {
static bool IsQuantizationType(tensorflow::DataType dtype) {
switch (dtype) {
case tensorflow::DT_QINT8:
case tensorflow::DT_QUINT8:
@@ -39,22 +36,8 @@ bool IsQuantizationType(tensorflow::DataType dtype) {
}
}
// Gets the width of this quantization type. Returns 0 if it isn't a
// quantization type.
int64_t GetQuantizationTypeWidth(tensorflow::DataType dtype) {
switch (dtype) {
case tensorflow::DT_QINT8:
case tensorflow::DT_QUINT8:
return 8;
case tensorflow::DT_QINT16:
case tensorflow::DT_QUINT16:
return 16;
case tensorflow::DT_QINT32:
return 32;
default:
return 0;
}
}
namespace mlir {
namespace TFL {
bool ParseInputNodeQuantSpecs(absl::string_view node_names,
absl::string_view min_values,

View File

@@ -29,13 +29,6 @@ limitations under the License.
namespace mlir {
namespace TFL {
// Is this dtype a quantization type from TensorFlow.
bool IsQuantizationType(tensorflow::DataType dtype);
// Gets the width of this quantization type. Returns 0 if it isn't a
// quantization type.
int64_t GetQuantizationTypeWidth(tensorflow::DataType dtype);
struct QuantizationSpecs {
// Which function this node quant specifications belong to.
std::string target_func = "main";
@@ -66,6 +59,34 @@ struct QuantizationSpecs {
// Whether run the passes to only quantize the weights.
bool RunWeightQuantization() const { return weight_quantization; }
// Whether this inference type represents a signed storage type.
bool IsSignedInferenceType() {
switch (inference_type) {
case tensorflow::DT_QUINT8:
case tensorflow::DT_QUINT16:
return false;
default:
return true;
}
}
// Gets the width of this quantization type. Returns 0 if it isn't a
// quantization type.
int64_t GetQuantizationTypeWidth() {
switch (inference_type) {
case tensorflow::DT_QINT8:
case tensorflow::DT_QUINT8:
return 8;
case tensorflow::DT_QINT16:
case tensorflow::DT_QUINT16:
return 16;
case tensorflow::DT_QINT32:
return 32;
default:
return 0;
}
}
};
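As a usage sketch (not part of this change), a caller would consume the two
new helpers above roughly as follows:

// Hypothetical caller: derive the sign and bit width from the inference type.
QuantizationSpecs specs;
specs.inference_type = tensorflow::DT_QINT8;
const bool is_signed = specs.IsSignedInferenceType();    // true
const int64_t width = specs.GetQuantizationTypeWidth();  // 8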
// Parses the command line flag strings to the quantization specification for

View File

@@ -175,6 +175,62 @@ struct QuantizationPattern : public RewritePattern {
}
};
// Converts quantize ops with unsigned quantized types to those with signed
// quantized types and preserves the scales.
template <typename Q>
struct ConvertUnsignedToSigned : public OpRewritePattern<Q> {
using BaseType = ConvertUnsignedToSigned<Q>;
using QType = quant::QuantizedType;
explicit ConvertUnsignedToSigned(MLIRContext* context)
: OpRewritePattern<Q>(context, 1) {}
PatternMatchResult matchAndRewrite(Q op,
PatternRewriter& rewriter) const override {
Type output_type = op.output()->getType();
auto qtype = QType::getQuantizedElementType(output_type);
if (!qtype || qtype.isSigned()) return this->matchFailure();
int num_bits = qtype.getStorageTypeIntegralWidth();
// This offset is positive and is subtracted from the zero points and the
// storage type ranges.
int64_t offset =
QType::getDefaultMinimumForInteger(/*isSigned=*/false, num_bits) -
QType::getDefaultMinimumForInteger(/*isSigned=*/true, num_bits);
auto flags = quant::QuantizationFlags::Signed;
QType new_qtype;
if (auto uqtype = qtype.template dyn_cast<quant::UniformQuantizedType>()) {
new_qtype = quant::UniformQuantizedType::getChecked(
flags, qtype.getStorageType(), qtype.getExpressedType(),
uqtype.getScale(), uqtype.getZeroPoint() - offset,
uqtype.getStorageTypeMin() - offset,
uqtype.getStorageTypeMax() - offset, op.getLoc());
} else if (auto aqtype = qtype.template dyn_cast<
quant::UniformQuantizedPerAxisType>()) {
auto zero_points = aqtype.getZeroPoints();
llvm::SmallVector<int64_t, 4> new_zero_points(zero_points.begin(),
zero_points.end());
for (int i = 0, e = new_zero_points.size(); i != e; ++i) {
new_zero_points[i] -= offset;
}
new_qtype = quant::UniformQuantizedPerAxisType::getChecked(
flags, qtype.getStorageType(), qtype.getExpressedType(),
aqtype.getScales(), new_zero_points, aqtype.getQuantizedDimension(),
aqtype.getStorageTypeMin() - offset,
aqtype.getStorageTypeMax() - offset, op.getLoc());
} else {
return this->matchFailure();
}
Type new_output_type = new_qtype.castFromExpressedType(
QType::castToExpressedType(output_type));
rewriter.replaceOpWithNewOp<Q>(op, new_output_type, op.input(),
rewriter.getTypeAttr(new_output_type));
return this->matchSuccess();
}
};
// Converts the min/max/num_bits/narrow_range information to a
// QuantizedType, and then returns the attribute containing the QuantizedType.
// The `min` and `max` arguments can be FloatAttr or DenseFPElementsAttr and
@@ -183,7 +239,7 @@ struct QuantizationPattern : public RewritePattern {
// if it is using signed int symmetric quantization.
TypeAttr GetQuantizedTypeAttr(Builder builder, Type input_type, Attribute min,
Attribute max, IntegerAttr num_bits,
BoolAttr narrow_range, bool is_signed = false);
BoolAttr narrow_range, bool is_signed);
// Casts the `target` type to a quantized type by using the quantization
// parameters from the type in the `source` type attribute.
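
A minimal sketch of wiring the new pattern into a pass (it mirrors what
prepare_quantize.cc does further below; the wrapper function is illustrative):

// Instantiate ConvertUnsignedToSigned for tfl.quantize ops and run the
// greedy pattern driver over the function.
void ConvertFunctionToSigned(FuncOp func) {
  OwningRewritePatternList patterns;
  patterns.insert<ConvertUnsignedToSigned<TFL::QuantizeOp>>(func.getContext());
  applyPatternsGreedily(func, patterns);
}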

View File

@@ -443,16 +443,16 @@ node {
}
}
# TODO(fengliuai): make the storage type signed.
# MLIR-LABEL: func @main(%arg0: tensor<1x1x1x256x!quant.uniform<u8:f32, 0.21632751372549019:155>>) -> tensor<1x6x31x!quant.uniform<u8:f32, 0.09363494573854933:150>>
# MLIR-LABEL: func @main(%arg0: tensor<1x1x1x256x!quant.uniform<i8:f32, 0.21632751372549019:27>>) -> tensor<1x6x31x!quant.uniform<i8:f32, 0.09363494573854933:22>>
# MLIR: attributes {tf.entry_function = {inputs = "input", outputs = "output"}
# MLIR: %[[shape:.*]] = constant dense<[1, -1, 31]> : tensor<3xi32>
# MLIR: %[[input:.*]] = "tfl.pseudo_input"(%arg0) : (tensor<1x1x1x256x!quant.uniform<u8:f32, 0.21632751372549019:155>>)
# MLIR: %[[input:.*]] = "tfl.pseudo_input"(%arg0) : (tensor<1x1x1x256x!quant.uniform<i8:f32, 0.21632751372549019:27>>)
# MLIR: %[[bias:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<186x!quant.uniform<i32:f32:0
# MLIR: %[[weight:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<186x1x1x256x!quant.uniform<u8<1:255>:f32:0, {0.12581039038230116:128,
# MLIR: %[[weight:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<186x1x1x256x!quant.uniform<i8<-127:127>:f32:0, {0.12581039038230116,
# MLIR: %[[conv:.*]] = "tfl.conv_2d"(%[[input]], %[[weight]], %[[bias]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32}
# MLIR: %[[reshape:.*]] = "tfl.reshape"(%[[conv]], %[[shape]]) : (tensor<1x1x1x186x!quant.uniform<u8:f32, 0.09363494573854933:150>>, tensor<3xi32>)
# MLIR: return %[[reshape]] : tensor<1x6x31x!quant.uniform<u8:f32, 0.09363494573854933:150>>
# MLIR: %[[reshape:.*]] = "tfl.reshape"(%[[conv]], %[[shape]]) : (tensor<1x1x1x186x!quant.uniform<i8:f32, 0.09363494573854933:22>>, tensor<3xi32>)
# MLIR: return %[[reshape]] : tensor<1x6x31x!quant.uniform<i8:f32, 0.09363494573854933:22>>
# MLIR: }
@@ -460,7 +460,7 @@ node {
# CHECK: version: 3,
# CHECK: operator_codes: [ {
# CHECK: builtin_code: CONV_2D,
# CHECK: version: 1
# CHECK: version: 3
# CHECK: }, {
# CHECK: builtin_code: RESHAPE,
# CHECK: version: 1
@@ -476,12 +476,12 @@ node {
# CHECK: }
# CHECK: }, {
# CHECK: shape: [ 1, 1, 1, 256 ],
# CHECK: type: UINT8,
# CHECK: type: INT8,
# CHECK: buffer: 2,
# CHECK: name: "input",
# CHECK: quantization: {
# CHECK: scale: [ 0.216328 ],
# CHECK: zero_point: [ 155 ]
# CHECK: zero_point: [ 27 ]
# CHECK: }
# CHECK: }, {
# CHECK: shape: [ 186 ],
@@ -494,30 +494,30 @@ node {
# CHECK: }
# CHECK: }, {
# CHECK: shape: [ 186, 1, 1, 256 ],
# CHECK: type: UINT8,
# CHECK: type: INT8,
# CHECK: buffer: 4,
# CHECK: name: "tfl.pseudo_qconst1",
# CHECK: quantization: {
# CHECK: scale: [ 0.12581, 0.001755, 0.001908, 0.001967, 0.007431,
# CHECK: zero_point: [ 128, 128, 128, 128, 128, 128, 128, 128, 128,
# CHECK: zero_point: [ 0, 0, 0, 0, 0, 0, 0, 0, 0,
# CHECK: }
# CHECK: }, {
# CHECK: shape: [ 1, 1, 1, 186 ],
# CHECK: type: UINT8,
# CHECK: type: INT8,
# CHECK: buffer: 5,
# CHECK: name: "tfl.conv_2d",
# CHECK: quantization: {
# CHECK: scale: [ 0.093635 ],
# CHECK: zero_point: [ 150 ]
# CHECK: zero_point: [ 22 ]
# CHECK: }
# CHECK: }, {
# CHECK: shape: [ 1, 6, 31 ],
# CHECK: type: UINT8,
# CHECK: type: INT8,
# CHECK: buffer: 6,
# CHECK: name: "output",
# CHECK: quantization: {
# CHECK: scale: [ 0.093635 ],
# CHECK: zero_point: [ 150 ]
# CHECK: zero_point: [ 22 ]
# CHECK: }
# CHECK: } ],
# CHECK: inputs: [ 1 ],
@@ -547,7 +547,7 @@ node {
# CHECK: }, {
# CHECK: data: [ 245, 255, 255, 255, 186, 254, 255, 255, 213, 254, 255, 255,
# CHECK: }, {
# CHECK: data: [ 136, 136, 136, 136, 136, 136, 136, 136, 136, 136, 136, 136,
# CHECK: data: [ 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
# CHECK: }, {
# CHECK-EMPTY
# CHECK: }, {

View File

@@ -0,0 +1,34 @@
// RUN: tf-opt %s -tfl-prepare-quantize -tfl-test-quantize-signed | FileCheck %s
// CHECK-LABEL: uint8_to_int8
func @uint8_to_int8(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
%1 = "tfl.quantize"(%arg0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 1.0:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 1.0:128>>
%2 = "tfl.dequantize"(%1) : (tensor<2x2x!quant.uniform<u8:f32, 1.0:128>>) -> tensor<2x2xf32>
return %2 : tensor<2x2xf32>
// CHECK-NEXT: %[[q:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<2x2x!quant.uniform<i8:f32, 1.000000e+00>>} : (tensor<2x2xf32>)
// CHECK-NEXT: %[[dq:.*]] = "tfl.dequantize"(%[[q]])
// CHECK-NEXT: return %[[dq]] : tensor<2x2xf32>
}
// CHECK-LABEL: uint8_to_int8_per_axis
func @uint8_to_int8_per_axis(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
%1 = "tfl.quantize"(%arg0) {qtype = tensor<2x2x!quant.uniform<u8:f32:1, {1.0:128, 1.0}>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32:1, {1.0:128, 1.0}>>
%2 = "tfl.dequantize"(%1) : (tensor<2x2x!quant.uniform<u8:f32:1, {1.0:128, 1.0}>>) -> tensor<2x2xf32>
return %2 : tensor<2x2xf32>
// CHECK-NEXT: %[[q:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<2x2x!quant.uniform<i8:f32:1, {1.000000e+00,1.000000e+00:-128}>>}
// CHECK-NEXT: %[[dq:.*]] = "tfl.dequantize"(%[[q]])
// CHECK-NEXT: return %[[dq]] : tensor<2x2xf32>
}
// CHECK-LABEL: uint8_to_int8_narrow_range
func @uint8_to_int8_narrow_range(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
%1 = "tfl.quantize"(%arg0) {qtype = tensor<2x2x!quant.uniform<u8<1:255>:f32, 1.0:255>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8<1:255>:f32, 1.0:255>>
%2 = "tfl.dequantize"(%1) : (tensor<2x2x!quant.uniform<u8<1:255>:f32, 1.0:255>>) -> tensor<2x2xf32>
return %2 : tensor<2x2xf32>
// CHECK-NEXT: %[[q:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<2x2x!quant.uniform<i8<-127:127>:f32, 1.000000e+00:127>>}
// CHECK-NEXT: %[[dq:.*]] = "tfl.dequantize"(%[[q]])
// CHECK-NEXT: return %[[dq]] : tensor<2x2xf32>
}

View File

@@ -39,7 +39,7 @@ def Merge2AttrsToArray : NativeCodeCall<"$_builder.getArrayAttr({$0, $1})">;
// Use the tensor type information from $0 and convert min $1, max $2 and
// numBits $3 and narrowRange $4 to a QuantizedType.
def ConvertToQuantTypeFromAttrs : NativeCodeCall<
"GetQuantizedTypeAttr($_builder, $0->getType(), $1, $2, $3, $4)">;
"GetQuantizedTypeAttr($_builder, $0->getType(), $1, $2, $3, $4, /*is_signed=*/false)">;
// Converts an integer attribute $0 to 32-bit with builder.
def convertIntAttrTo32Bit : NativeCodeCall<

View File

@@ -21,11 +21,13 @@ limitations under the License.
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/CommandLine.h"
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/core/framework/types.pb.h"
// NOLINTNEXTLINE
static llvm::cl::list<std::string> quantize_whitelist(
@@ -34,6 +36,12 @@ static llvm::cl::list<std::string> quantize_whitelist(
"quantized. Only used in tests"),
llvm::cl::CommaSeparated);
// NOLINTNEXTLINE
static llvm::cl::opt<bool> quantize_signed(
"tfl-test-quantize-signed", llvm::cl::value_desc("bool"),
llvm::cl::desc("signed inference type. Only used in tests"),
llvm::cl::init(false));
//===----------------------------------------------------------------------===//
// The prepare-quantize Pass.
//
@@ -49,8 +57,13 @@ namespace {
// training quantization simpler.
class PrepareQuantizePass : public FunctionPass<PrepareQuantizePass> {
public:
// Constructor used by the PassRegistration.
explicit PrepareQuantizePass() {}
// Constructor used by the PassRegistration; selects the inference type
// based on the test flag (uint8 by default).
explicit PrepareQuantizePass() {
if (quantize_signed)
quant_specs_.inference_type = tensorflow::DT_QINT8;
else
quant_specs_.inference_type = tensorflow::DT_QUINT8;
}
// Constructor used by manually creating the pass.
explicit PrepareQuantizePass(const QuantizationSpecs& quant_specs)
@@ -107,9 +120,10 @@ bool PrepareQuantizePass::SetInputNodesQuantizationParams(FuncOp func) {
}
OpBuilder builder(func);
auto num_bits = builder.getI32IntegerAttr(
GetQuantizationTypeWidth(quant_specs_.inference_type));
auto narrow_range = builder.getBoolAttr(false);
bool is_signed = quant_specs_.IsSignedInferenceType();
IntegerAttr num_bits =
builder.getI32IntegerAttr(quant_specs_.GetQuantizationTypeWidth());
BoolAttr narrow_range = builder.getBoolAttr(false);
for (int i = 0, e = func.getNumArguments(); i != e; ++i) {
Value* arg = func.getArgument(i);
@@ -128,7 +142,8 @@ bool PrepareQuantizePass::SetInputNodesQuantizationParams(FuncOp func) {
auto min_max = GetMinMaxValuesForArgument(func_name, i);
TypeAttr params = GetQuantizedTypeAttr(
builder, input_type, builder.getF64FloatAttr(min_max.first),
builder.getF64FloatAttr(min_max.second), num_bits, narrow_range);
builder.getF64FloatAttr(min_max.second), num_bits, narrow_range,
is_signed);
builder.setInsertionPoint(input->getBlock(),
++Block::iterator(input_op));
auto q_op = builder.create<TFL::QuantizeOp>(loc, params.getValue(),
@@ -147,15 +162,24 @@ bool PrepareQuantizePass::SetInputNodesQuantizationParams(FuncOp func) {
#include "tensorflow/compiler/mlir/lite/utils/generated_op_quant_spec_getters.inc"
void PrepareQuantizePass::runOnFunction() {
FuncOp func = getFunction();
// Set the quantization parameters for the quantizable input nodes. If this
// fails, return immediately.
// TODO(fengliuai): send the signal to the pass manager.
if (SetInputNodesQuantizationParams(getFunction())) return;
if (SetInputNodesQuantizationParams(func)) return;
// TODO(fengliuai): set the sign by the inference type from the spec
bool quantize_sign = false;
ApplyQuantizationParamsPropagation(getFunction(), quantize_sign,
GetOpQuantSpec);
// During legalization, unsigned quantized types are used, so we have to
// convert all of them to signed.
bool is_signed = quant_specs_.IsSignedInferenceType();
if (is_signed) {
OwningRewritePatternList patterns;
patterns.insert<ConvertUnsignedToSigned<TFL::QuantizeOp>>(
func.getContext());
applyPatternsGreedily(func, patterns);
}
ApplyQuantizationParamsPropagation(func, is_signed, GetOpQuantSpec);
}
} // namespace

View File

@@ -136,8 +136,9 @@ struct InsertTFLQuantOpsAfterTFFakeQuantOp
rewriter.getI64IntegerAttr(tf_op.num_bits().getSExtValue());
BoolAttr narrow_range = rewriter.getBoolAttr(tf_op.narrow_range());
Type res_type = tf_op.getType();
TypeAttr qtype = GetQuantizedTypeAttr(rewriter, res_type, min_value,
max_value, num_bits, narrow_range);
TypeAttr qtype =
GetQuantizedTypeAttr(rewriter, res_type, min_value, max_value, num_bits,
narrow_range, /*is_signed=*/false);
if (!qtype) return this->matchFailure();
// Finally, use the quantization parameter to create the quantize and