Propagate the sign of the inference type into the quantization passes
The FakeQuant* ops in the inference graph carry no sign information, and "unsigned" is assumed when these FakeQuant* ops are legalized to quantized types. The quantization passes, however, need to generate ops for the target inference type and its sign, so when the inference type is signed, all the unsigned quantized types have to be converted to signed ones. A rewrite pattern is added for this purpose.

PiperOrigin-RevId: 272089485
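The remapping preserves each scale and shifts the zero point and storage bounds down by a fixed offset (2^(bits-1), i.e. 128 for 8-bit types). A minimal standalone C++ sketch of that arithmetic, separate from the MLIR implementation in the diff below and using the concrete values that appear in the updated tests:

#include <cassert>
#include <cstdint>

// Offset between the default unsigned minimum (0) and the default signed
// minimum (-2^(bits-1)); 128 for 8-bit types.
int64_t SignOffset(int num_bits) { return int64_t{1} << (num_bits - 1); }

int main() {
  const int64_t offset = SignOffset(8);

  // u8 zero point 155 -> i8 zero point 27 (see the pbtxt test below).
  assert(155 - offset == 27);
  // Per-axis u8 zero points {128, ...} -> i8 zero points {0, ...}.
  assert(128 - offset == 0);
  // Narrow-range u8<1:255> zero point 255 -> i8<-127:127> zero point 127:
  // the storage bounds shift by the same offset as the zero point.
  assert(1 - offset == -127 && 255 - offset == 127);
  return 0;
}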
parent 48d245cbe3
commit b81191bc81
@@ -336,6 +336,7 @@ cc_library(
         ":validators",
+        "//tensorflow/compiler/mlir/lite/quantization:quantization_config",
         "//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
         "//tensorflow/core:protos_all_proto_cc",
         "@com_google_absl//absl/memory",
         "@llvm//:support",
         "@local_config_mlir//:Analysis",
@@ -42,6 +42,8 @@ DataType ConvertIODataTypeToDataType(toco::IODataType dtype) {
       return DT_FLOAT;
     case toco::IODataType::QUANTIZED_UINT8:
       return DT_QUINT8;
+    case toco::IODataType::INT8:
+      return DT_QINT8;
     case toco::IODataType::INT32:
       return DT_INT32;
     case toco::IODataType::INT64:
@@ -22,11 +22,8 @@ limitations under the License.
 #include "llvm/Support/raw_ostream.h"
 #include "tensorflow/core/framework/types.pb.h"
 
-namespace mlir {
-namespace TFL {
-
 // Is this dtype a quantization type from TensorFlow.
-bool IsQuantizationType(tensorflow::DataType dtype) {
+static bool IsQuantizationType(tensorflow::DataType dtype) {
   switch (dtype) {
     case tensorflow::DT_QINT8:
     case tensorflow::DT_QUINT8:
@@ -39,22 +36,8 @@ bool IsQuantizationType(tensorflow::DataType dtype) {
   }
 }
 
-// Gets the width of this quantization type. Returns 0 if it isn't a
-// quantization type.
-int64_t GetQuantizationTypeWidth(tensorflow::DataType dtype) {
-  switch (dtype) {
-    case tensorflow::DT_QINT8:
-    case tensorflow::DT_QUINT8:
-      return 8;
-    case tensorflow::DT_QINT16:
-    case tensorflow::DT_QUINT16:
-      return 16;
-    case tensorflow::DT_QINT32:
-      return 32;
-    default:
-      return 0;
-  }
-}
+namespace mlir {
+namespace TFL {
 
 bool ParseInputNodeQuantSpecs(absl::string_view node_names,
                               absl::string_view min_values,
@@ -29,13 +29,6 @@ limitations under the License.
 namespace mlir {
 namespace TFL {
 
-// Is this dtype a quantization type from TensorFlow.
-bool IsQuantizationType(tensorflow::DataType dtype);
-
-// Gets the width of this quantization type. Returns 0 if it isn't a
-// quantization type.
-int64_t GetQuantizationTypeWidth(tensorflow::DataType dtype);
-
 struct QuantizationSpecs {
   // Which function this node quant specifications belong to.
   std::string target_func = "main";
@@ -66,6 +59,34 @@ struct QuantizationSpecs {
 
   // Whether run the passes to only quantize the weights.
   bool RunWeightQuantization() const { return weight_quantization; }
+
+  // Whether this inference type represents a signed storage type.
+  bool IsSignedInferenceType() {
+    switch (inference_type) {
+      case tensorflow::DT_QUINT8:
+      case tensorflow::DT_QUINT16:
+        return false;
+      default:
+        return true;
+    }
+  }
+
+  // Gets the width of this quantization type. Returns 0 if it isn't a
+  // quantization type.
+  int64_t GetQuantizationTypeWidth() {
+    switch (inference_type) {
+      case tensorflow::DT_QINT8:
+      case tensorflow::DT_QUINT8:
+        return 8;
+      case tensorflow::DT_QINT16:
+      case tensorflow::DT_QUINT16:
+        return 16;
+      case tensorflow::DT_QINT32:
+        return 32;
+      default:
+        return 0;
+    }
+  }
 };
 
 // Parses the command line flag strings to the quantization specification for
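For illustration, a standalone sketch of the helper logic added to QuantizationSpecs above, with a plain enum standing in for tensorflow::DataType and free functions standing in for the struct members; it makes the edge case explicit that only the unsigned quantized types count as unsigned, so a non-quantized inference type such as DT_FLOAT reports signed with width 0:

#include <cassert>
#include <cstdint>

// Stand-ins for the tensorflow::DataType values referenced in the diff.
enum DataType { DT_FLOAT, DT_QINT8, DT_QUINT8, DT_QINT16, DT_QUINT16, DT_QINT32 };

// Mirrors QuantizationSpecs::IsSignedInferenceType above.
bool IsSignedInferenceType(DataType t) {
  return t != DT_QUINT8 && t != DT_QUINT16;
}

// Mirrors QuantizationSpecs::GetQuantizationTypeWidth above.
int64_t GetQuantizationTypeWidth(DataType t) {
  switch (t) {
    case DT_QINT8:
    case DT_QUINT8:
      return 8;
    case DT_QINT16:
    case DT_QUINT16:
      return 16;
    case DT_QINT32:
      return 32;
    default:
      return 0;  // not a quantization type
  }
}

int main() {
  assert(!IsSignedInferenceType(DT_QUINT8));        // uint8 inference
  assert(IsSignedInferenceType(DT_QINT8));          // int8 inference
  assert(IsSignedInferenceType(DT_FLOAT));          // default branch
  assert(GetQuantizationTypeWidth(DT_FLOAT) == 0);  // non-quantized width
  return 0;
}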
@@ -175,6 +175,62 @@ struct QuantizationPattern : public RewritePattern {
   }
 };
+
+// Converts quantize ops with unsigned quantized types to those with signed
+// quantized types and preserves the scales.
+template <typename Q>
+struct ConvertUnsignedToSigned : public OpRewritePattern<Q> {
+  using BaseType = ConvertUnsignedToSigned<Q>;
+  using QType = quant::QuantizedType;
+
+  explicit ConvertUnsignedToSigned(MLIRContext* context)
+      : OpRewritePattern<Q>(context, 1) {}
+
+  PatternMatchResult matchAndRewrite(Q op,
+                                     PatternRewriter& rewriter) const override {
+    Type output_type = op.output()->getType();
+    auto qtype = QType::getQuantizedElementType(output_type);
+    if (!qtype || qtype.isSigned()) return this->matchFailure();
+
+    int num_bits = qtype.getStorageTypeIntegralWidth();
+    // This is a positive value, and will be applied on zero points and fixed
+    // point ranges.
+    int64_t offset =
+        QType::getDefaultMininumForInteger(/*isSigned=*/false, num_bits) -
+        QType::getDefaultMininumForInteger(/*isSigned=*/true, num_bits);
+
+    auto flags = quant::QuantizationFlags::Signed;
+    QType new_qtype;
+    if (auto uqtype = qtype.template dyn_cast<quant::UniformQuantizedType>()) {
+      new_qtype = quant::UniformQuantizedType::getChecked(
+          flags, qtype.getStorageType(), qtype.getExpressedType(),
+          uqtype.getScale(), uqtype.getZeroPoint() - offset,
+          uqtype.getStorageTypeMin() - offset,
+          uqtype.getStorageTypeMax() - offset, op.getLoc());
+    } else if (auto aqtype = qtype.template dyn_cast<
+                   quant::UniformQuantizedPerAxisType>()) {
+      auto zero_points = aqtype.getZeroPoints();
+      llvm::SmallVector<int64_t, 4> new_zero_points(zero_points.begin(),
+                                                    zero_points.end());
+      for (int i = 0, e = new_zero_points.size(); i != e; ++i) {
+        new_zero_points[i] -= offset;
+      }
+      new_qtype = quant::UniformQuantizedPerAxisType::getChecked(
+          flags, qtype.getStorageType(), qtype.getExpressedType(),
+          aqtype.getScales(), new_zero_points, aqtype.getQuantizedDimension(),
+          aqtype.getStorageTypeMin() - offset,
+          aqtype.getStorageTypeMax() - offset, op.getLoc());
+    } else {
+      return this->matchFailure();
+    }
+
+    Type new_output_type = new_qtype.castFromExpressedType(
+        QType::castToExpressedType(output_type));
+    rewriter.replaceOpWithNewOp<Q>(op, new_output_type, op.input(),
+                                   rewriter.getTypeAttr(new_output_type));
+    return this->matchSuccess();
+  }
+};
 
 // Converts the min/max/num_bits/narrow_range information to a
 // QuantizedType, and then returns the attribute containing the QuantizedType.
 // The `min` and `max` arguments can be FloatAttr or DenseFPElementsAttr and
@@ -183,7 +239,7 @@ struct QuantizationPattern : public RewritePattern {
 // if it is using signed int symmetric quantization.
 TypeAttr GetQuantizedTypeAttr(Builder builder, Type input_type, Attribute min,
                               Attribute max, IntegerAttr num_bits,
-                              BoolAttr narrow_range, bool is_signed = false);
+                              BoolAttr narrow_range, bool is_signed);
 
 // Casts the `target` type to a quantized type by using the quantization
 // parameters from the type in the `source` type attribute.
@@ -443,16 +443,16 @@ node {
   }
 }
 
-# TODO(fengliuai): make the storage type signed.
-# MLIR-LABEL: func @main(%arg0: tensor<1x1x1x256x!quant.uniform<u8:f32, 0.21632751372549019:155>>) -> tensor<1x6x31x!quant.uniform<u8:f32, 0.09363494573854933:150>>
+
+# MLIR-LABEL: func @main(%arg0: tensor<1x1x1x256x!quant.uniform<i8:f32, 0.21632751372549019:27>>) -> tensor<1x6x31x!quant.uniform<i8:f32, 0.09363494573854933:22>>
 # MLIR: attributes {tf.entry_function = {inputs = "input", outputs = "output"}
 # MLIR: %[[shape:.*]] = constant dense<[1, -1, 31]> : tensor<3xi32>
-# MLIR: %[[input:.*]] = "tfl.pseudo_input"(%arg0) : (tensor<1x1x1x256x!quant.uniform<u8:f32, 0.21632751372549019:155>>)
+# MLIR: %[[input:.*]] = "tfl.pseudo_input"(%arg0) : (tensor<1x1x1x256x!quant.uniform<i8:f32, 0.21632751372549019:27>>)
 # MLIR: %[[bias:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<186x!quant.uniform<i32:f32:0
-# MLIR: %[[weight:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<186x1x1x256x!quant.uniform<u8<1:255>:f32:0, {0.12581039038230116:128,
+# MLIR: %[[weight:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<186x1x1x256x!quant.uniform<i8<-127:127>:f32:0, {0.12581039038230116,
 # MLIR: %[[conv:.*]] = "tfl.conv_2d"(%[[input]], %[[weight]], %[[bias]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32}
-# MLIR: %[[reshape:.*]] = "tfl.reshape"(%[[conv]], %[[shape]]) : (tensor<1x1x1x186x!quant.uniform<u8:f32, 0.09363494573854933:150>>, tensor<3xi32>)
-# MLIR: return %[[reshape]] : tensor<1x6x31x!quant.uniform<u8:f32, 0.09363494573854933:150>>
+# MLIR: %[[reshape:.*]] = "tfl.reshape"(%[[conv]], %[[shape]]) : (tensor<1x1x1x186x!quant.uniform<i8:f32, 0.09363494573854933:22>>, tensor<3xi32>)
+# MLIR: return %[[reshape]] : tensor<1x6x31x!quant.uniform<i8:f32, 0.09363494573854933:22>>
 # MLIR: }
 
 
@@ -460,7 +460,7 @@ node {
 # CHECK: version: 3,
 # CHECK: operator_codes: [ {
 # CHECK: builtin_code: CONV_2D,
-# CHECK: version: 1
+# CHECK: version: 3
 # CHECK: }, {
 # CHECK: builtin_code: RESHAPE,
 # CHECK: version: 1
@@ -476,12 +476,12 @@ node {
 # CHECK: }
 # CHECK: }, {
 # CHECK: shape: [ 1, 1, 1, 256 ],
-# CHECK: type: UINT8,
+# CHECK: type: INT8,
 # CHECK: buffer: 2,
 # CHECK: name: "input",
 # CHECK: quantization: {
 # CHECK: scale: [ 0.216328 ],
-# CHECK: zero_point: [ 155 ]
+# CHECK: zero_point: [ 27 ]
 # CHECK: }
 # CHECK: }, {
 # CHECK: shape: [ 186 ],
@@ -494,30 +494,30 @@ node {
 # CHECK: }
 # CHECK: }, {
 # CHECK: shape: [ 186, 1, 1, 256 ],
-# CHECK: type: UINT8,
+# CHECK: type: INT8,
 # CHECK: buffer: 4,
 # CHECK: name: "tfl.pseudo_qconst1",
 # CHECK: quantization: {
 # CHECK: scale: [ 0.12581, 0.001755, 0.001908, 0.001967, 0.007431,
-# CHECK: zero_point: [ 128, 128, 128, 128, 128, 128, 128, 128, 128,
+# CHECK: zero_point: [ 0, 0, 0, 0, 0, 0, 0, 0, 0,
 # CHECK: }
 # CHECK: }, {
 # CHECK: shape: [ 1, 1, 1, 186 ],
-# CHECK: type: UINT8,
+# CHECK: type: INT8,
 # CHECK: buffer: 5,
 # CHECK: name: "tfl.conv_2d",
 # CHECK: quantization: {
 # CHECK: scale: [ 0.093635 ],
-# CHECK: zero_point: [ 150 ]
+# CHECK: zero_point: [ 22 ]
 # CHECK: }
 # CHECK: }, {
 # CHECK: shape: [ 1, 6, 31 ],
-# CHECK: type: UINT8,
+# CHECK: type: INT8,
 # CHECK: buffer: 6,
 # CHECK: name: "output",
 # CHECK: quantization: {
 # CHECK: scale: [ 0.093635 ],
-# CHECK: zero_point: [ 150 ]
+# CHECK: zero_point: [ 22 ]
 # CHECK: }
 # CHECK: } ],
 # CHECK: inputs: [ 1 ],
@@ -547,7 +547,7 @@ node {
 # CHECK: }, {
 # CHECK: data: [ 245, 255, 255, 255, 186, 254, 255, 255, 213, 254, 255, 255,
 # CHECK: }, {
-# CHECK: data: [ 136, 136, 136, 136, 136, 136, 136, 136, 136, 136, 136, 136,
+# CHECK: data: [ 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
 # CHECK: }, {
 # CHECK-EMPTY
 # CHECK: }, {
@@ -0,0 +1,34 @@
+// RUN: tf-opt %s -tfl-prepare-quantize -tfl-test-quantize-signed | FileCheck %s
+
+// CHECK-LABEL: uint8_to_int8
+func @uint8_to_int8(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
+  %1 = "tfl.quantize"(%arg0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 1.0:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 1.0:128>>
+  %2 = "tfl.dequantize"(%1) : (tensor<2x2x!quant.uniform<u8:f32, 1.0:128>>) -> tensor<2x2xf32>
+  return %2 : tensor<2x2xf32>
+
+// CHECK-NEXT: %[[q:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<2x2x!quant.uniform<i8:f32, 1.000000e+00>>} : (tensor<2x2xf32>)
+// CHECK-NEXT: %[[dq:.*]] = "tfl.dequantize"(%[[q]])
+// CHECK-NEXT: return %[[dq]] : tensor<2x2xf32>
+}
+
+// CHECK-LABEL: uint8_to_int8_per_axis
+func @uint8_to_int8_per_axis(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
+  %1 = "tfl.quantize"(%arg0) {qtype = tensor<2x2x!quant.uniform<u8:f32:1, {1.0:128, 1.0}>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32:1, {1.0:128, 1.0}>>
+  %2 = "tfl.dequantize"(%1) : (tensor<2x2x!quant.uniform<u8:f32:1, {1.0:128, 1.0}>>) -> tensor<2x2xf32>
+  return %2 : tensor<2x2xf32>
+
+// CHECK-NEXT: %[[q:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<2x2x!quant.uniform<i8:f32:1, {1.000000e+00,1.000000e+00:-128}>>}
+// CHECK-NEXT: %[[dq:.*]] = "tfl.dequantize"(%0)
+// CHECK-NEXT: return %[[dq]] : tensor<2x2xf32>
+}
+
+// CHECK-LABEL: uint8_to_int8_narrow_range
+func @uint8_to_int8_narrow_range(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
+  %1 = "tfl.quantize"(%arg0) {qtype = tensor<2x2x!quant.uniform<u8<1:255>:f32, 1.0:255>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8<1:255>:f32, 1.0:255>>
+  %2 = "tfl.dequantize"(%1) : (tensor<2x2x!quant.uniform<u8<1:255>:f32, 1.0:255>>) -> tensor<2x2xf32>
+  return %2 : tensor<2x2xf32>
+
+// CHECK-NEXT: %[[q:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<2x2x!quant.uniform<i8<-127:127>:f32, 1.000000e+00:127>>}
+// CHECK-NEXT: %[[dq:.*]] = "tfl.dequantize"(%[[q]])
+// CHECK-NEXT: return %[[dq]] : tensor<2x2xf32>
+}
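A note on reading the second case above: in the textual form of !quant.uniform, the trailing ":zero_point" is omitted when the zero point is 0. So the unsigned per-axis pair {1.0:128, 1.0} (zero points 128 and 0) becomes {1.000000e+00,1.000000e+00:-128} after conversion: 128 - 128 = 0 (elided) and 0 - 128 = -128, while both scales are preserved.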
@@ -39,7 +39,7 @@ def Merge2AttrsToArray : NativeCodeCall<"$_builder.getArrayAttr({$0, $1})">;
 // Use the tensor type information from $0 and convert min $1, max $2 and
 // numBits $3 and narrowRange $4 to a QuantizedType.
 def ConvertToQuantTypeFromAttrs : NativeCodeCall<
-    "GetQuantizedTypeAttr($_builder, $0->getType(), $1, $2, $3, $4)">;
+    "GetQuantizedTypeAttr($_builder, $0->getType(), $1, $2, $3, $4, /*is_signed=*/false)">;
 
 // Converts an integer attribute $0 to 32-bit with builder.
 def convertIntAttrTo32Bit : NativeCodeCall<
@@ -21,11 +21,13 @@ limitations under the License.
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/CommandLine.h"
+#include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
+#include "mlir/IR/PatternMatch.h"  // TF:local_config_mlir
 #include "mlir/Pass/Pass.h"  // TF:local_config_mlir
 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
 #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
 #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
 #include "tensorflow/core/framework/types.pb.h"
 
 // NOLINTNEXTLINE
 static llvm::cl::list<std::string> quantize_whitelist(
@@ -34,6 +36,12 @@ static llvm::cl::list<std::string> quantize_whitelist(
     "quantized. Only used in tests"),
     llvm::cl::CommaSeparated);
 
+// NOLINTNEXTLINE
+static llvm::cl::opt<bool> quantize_signed(
+    "tfl-test-quantize-signed", llvm::cl::value_desc("bool"),
+    llvm::cl::desc("signed inference type. Only used in tests"),
+    llvm::cl::init(false));
+
 //===----------------------------------------------------------------------===//
 // The prepare-quantize Pass.
 //
@@ -49,8 +57,13 @@ namespace {
 // training quantization simpler.
 class PrepareQuantizePass : public FunctionPass<PrepareQuantizePass> {
  public:
-  // Constructor used by the PassRegistration.
-  explicit PrepareQuantizePass() {}
+  // Constructor used by the PassRegistration and enforces uint8 quantization.
+  explicit PrepareQuantizePass() {
+    if (quantize_signed)
+      quant_specs_.inference_type = tensorflow::DT_QINT8;
+    else
+      quant_specs_.inference_type = tensorflow::DT_QUINT8;
+  }
 
   // Constructor used by manually creating the pass.
   explicit PrepareQuantizePass(const QuantizationSpecs& quant_specs)
@@ -107,9 +120,10 @@ bool PrepareQuantizePass::SetInputNodesQuantizationParams(FuncOp func) {
   }
 
   OpBuilder builder(func);
-  auto num_bits = builder.getI32IntegerAttr(
-      GetQuantizationTypeWidth(quant_specs_.inference_type));
-  auto narrow_range = builder.getBoolAttr(false);
+  bool is_signed = quant_specs_.IsSignedInferenceType();
+  IntegerAttr num_bits =
+      builder.getI32IntegerAttr(quant_specs_.GetQuantizationTypeWidth());
+  BoolAttr narrow_range = builder.getBoolAttr(false);
 
   for (int i = 0, e = func.getNumArguments(); i != e; ++i) {
     Value* arg = func.getArgument(i);
@@ -128,7 +142,8 @@ bool PrepareQuantizePass::SetInputNodesQuantizationParams(FuncOp func) {
     auto min_max = GetMinMaxValuesForArgument(func_name, i);
     TypeAttr params = GetQuantizedTypeAttr(
         builder, input_type, builder.getF64FloatAttr(min_max.first),
-        builder.getF64FloatAttr(min_max.second), num_bits, narrow_range);
+        builder.getF64FloatAttr(min_max.second), num_bits, narrow_range,
+        is_signed);
     builder.setInsertionPoint(input->getBlock(),
                               ++Block::iterator(input_op));
     auto q_op = builder.create<TFL::QuantizeOp>(loc, params.getValue(),
@@ -147,15 +162,24 @@ bool PrepareQuantizePass::SetInputNodesQuantizationParams(FuncOp func) {
 #include "tensorflow/compiler/mlir/lite/utils/generated_op_quant_spec_getters.inc"
 
 void PrepareQuantizePass::runOnFunction() {
+  FuncOp func = getFunction();
+
   // Set the quantization parameters for the quantizable input nodes. If this
   // failed, return the function immediately.
   // TODO(fengliuai): send the signal to the pass manager.
-  if (SetInputNodesQuantizationParams(getFunction())) return;
+  if (SetInputNodesQuantizationParams(func)) return;
 
-  // TODO(fengliuai): set the sign by the inference type from the spec
-  bool quantize_sign = false;
-  ApplyQuantizationParamsPropagation(getFunction(), quantize_sign,
-                                     GetOpQuantSpec);
+  // During the legalization, unsigned quantized type is used, so we have to
+  // convert all of them to signed.
+  bool is_signed = quant_specs_.IsSignedInferenceType();
+  if (is_signed) {
+    OwningRewritePatternList patterns;
+    patterns.insert<ConvertUnsignedToSigned<TFL::QuantizeOp>>(
+        func.getContext());
+    applyPatternsGreedily(func, patterns);
+  }
+
+  ApplyQuantizationParamsPropagation(func, is_signed, GetOpQuantSpec);
 }
 
 }  // namespace
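Note the ordering in the new runOnFunction body: the unsigned-to-signed rewrite is applied greedily before ApplyQuantizationParamsPropagation, so propagation only ever sees signed quantized types when is_signed is set. Per the RUN line of the new test, the behavior can be exercised with an invocation like the following (the input file name here is a placeholder):

tf-opt input.mlir -tfl-prepare-quantize -tfl-test-quantize-signed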
@@ -136,8 +136,9 @@ struct InsertTFLQuantOpsAfterTFFakeQuantOp
       rewriter.getI64IntegerAttr(tf_op.num_bits().getSExtValue());
   BoolAttr narrow_range = rewriter.getBoolAttr(tf_op.narrow_range());
   Type res_type = tf_op.getType();
-  TypeAttr qtype = GetQuantizedTypeAttr(rewriter, res_type, min_value,
-                                        max_value, num_bits, narrow_range);
+  TypeAttr qtype =
+      GetQuantizedTypeAttr(rewriter, res_type, min_value, max_value, num_bits,
+                           narrow_range, /*is_signed=*/false);
   if (!qtype) this->matchFailure();
 
   // Finally, use the quantization parameter to create the quantize and