Add support for SVDF in MLIR quantizer

* Added quantization traits for SVDF in tfl_ops.td.
* Added logic for removing dangling SVDF ops.
* Added handling for 10-bit quantization of weights, using a [-512, 512] range (sketched below).
* Unified the logic for quantizing complex ops (LSTM, SVDF).
* Ensured quantization range contains zero.

PiperOrigin-RevId: 350702871
Change-Id: I719a5cd005d0e70db5266e6700ad3c2e58846e7a
Taehee Jeong, 2021-01-07 22:31:18 -08:00 (committed by TensorFlower Gardener)
parent 230472ae51
commit fe004863b7
9 changed files with 263 additions and 170 deletions
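A minimal standalone sketch of the new scale arithmetic, to make the last three bullets concrete. This is not TensorFlow code: the helper names are illustrative, and only the constants are taken from the QuantizeSVDF test added further down; the real logic lives in ConvertOpStatsToQDQs in transforms/prepare_quantize_helper.h.

```cpp
// Standalone sketch (not TensorFlow code); helper names are illustrative.
#include <algorithm>
#include <cstdio>

// 10-bit symmetric weight scale: weights are mapped onto [-512, 512], so the
// scale is max|w| / 512 (512 == -minIntN(10)).
double TenBitWeightScale(double max_abs_weight) {
  return max_abs_weight / 512.0;
}

// Activation ranges are widened to always contain zero before quantization
// parameters are derived from them.
void EnsureRangeContainsZero(double& min, double& max) {
  min = std::min(min, 0.0);
  max = std::max(max, 0.0);
}

int main() {
  // Time-weight magnitude taken from the QuantizeSVDF test in this change.
  std::printf("10-bit scale: %.10f\n", TenBitWeightScale(1.897219181060791));
  double min = 2.07937503, max = 13.65;  // layerStats of the SVDF input.
  EnsureRangeContainsZero(min, max);
  std::printf("activation range: [%f, %f]\n", min, max);
  return 0;
}
```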


@@ -530,7 +530,7 @@ cc_library(
],
hdrs = [
"transforms/passes.h",
"transforms/prepare_quantize_lstm.h",
"transforms/prepare_quantize_helper.h",
],
deps = [
"convert_type",


@@ -4303,7 +4303,8 @@ def TFL_SVDFOp :
TFL_Op<"svdf", [
PredOpTrait<"the input and result tensor elemental types must be same",
TFL_TCresVTEtIsSameAsOp<0, 0>>,
TFL_StatefulOp]> {
TFL_StatefulOp,
AccumulatorUniformScale<3, 2, 4>]> {
let summary = "Single value decomposition filter operator";
@@ -4321,10 +4322,10 @@ def TFL_SVDFOp :
TFL_TensorOf<[F32, QI8, QUI8]>:$feature_weights,
// Time weights
TFL_TensorOf<[F32, QI8]>:$time_weights,
TFL_TensorOf<[F32, QI16]>:$time_weights,
// Bias
TFL_TensorOfOrNone<[F32]>:$input_gate_bias,
TFL_TensorOfOrNone<[F32, QI32]>:$input_gate_bias,
// Activation state.
TFL_StatefulTensor:$activation_state,
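The new AccumulatorUniformScale<3, 2, 4> trait ties the bias (operand 3) to a derived scale, assuming the trait follows the same convention as on the existing conv-style ops: the accumulator's scale is the product of the scales of the two listed operands, here the time weights (operand 2) and the activation state (operand 4). A quick cross-check against the constants in the QuantizeSVDF test further down:

```cpp
// Illustrative cross-check, not TensorFlow code.
#include <cstdio>

int main() {
  const double time_weights_scale = 0.0037055062130093575;      // operand 2
  const double activation_state_scale = 0.0037514108011770368;  // operand 4
  // Product matches the 1.3900876031311922e-05 bias scale in QuantizeSVDF.
  std::printf("bias scale: %.16e\n",
              time_weights_scale * activation_state_scale);
  return 0;
}
```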


@@ -72,10 +72,10 @@ static void ExpandVerySmallRange(ArrayRef<double> mins, ArrayRef<double> maxs,
// input_type/min/max/storage_type_width/narrow_range.
// This is the entry point to the Quant dialect and is used for quantizing
// both activations and weights.
static Type GetQuantizedType(Builder builder, Type input_type,
ArrayRef<double> min, ArrayRef<double> max,
int quant_dim, int storage_type_width,
bool narrow_range, bool is_signed) {
Type GetQuantizedType(Builder builder, Type input_type, ArrayRef<double> min,
ArrayRef<double> max, int quant_dim,
int storage_type_width, bool narrow_range,
bool is_signed) {
auto converter =
quant::ExpressedToQuantizedConverter::forInputType(input_type);
@@ -253,10 +253,10 @@ TypeAttr CastQuantizedTypeAttrFromExpressedType(Builder builder,
return TypeAttr::get(final_type);
}
static void ExtractMinMaxFromAttr(DenseFPElementsAttr values, int dim_size,
int slice_size, bool symmetric,
SmallVector<double, 4>& mins,
SmallVector<double, 4>& maxs) {
void ExtractMinMaxFromAttr(DenseFPElementsAttr values, int dim_size,
int slice_size, bool symmetric,
SmallVector<double, 4>& mins,
SmallVector<double, 4>& maxs) {
// If all the element values are same we don't need to scan the content.
if (values.isSplat()) {
double single_value =


@@ -553,6 +553,22 @@ quant::UniformQuantizedType GetFixedOutputRange(bool is_signed, int bit_width,
int64_t zero_point,
int64_t storage_min = -128,
int64_t storage_max = 127);
// Extracts the min and max values from the DenseFPElementsAttr and stores them
// into `mins` and `maxs`. When mins and maxs are extracted per-channel,
// `dim_size` is the number of channels and `slice_size` is the size of the
// slice for each channel. When `symmetric` is true, the range is expanded to
// [-M, M].
void ExtractMinMaxFromAttr(DenseFPElementsAttr values, int dim_size,
int slice_size, bool symmetric,
SmallVector<double, 4>& mins,
SmallVector<double, 4>& maxs);
// Returns the quantized type for the
// input_type/min/max/storage_type_width/narrow_range.
Type GetQuantizedType(Builder builder, Type input_type, ArrayRef<double> min,
ArrayRef<double> max, int quant_dim,
int storage_type_width, bool narrow_range,
bool is_signed);
} // namespace quant
} // namespace mlir
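The 10-bit weight path below calls the now-exported ExtractMinMaxFromAttr with dim_size = 1, slice_size = 1, and symmetric = true. A plain-C++ analogue of that single-channel case (not the MLIR implementation, which operates on a DenseFPElementsAttr) looks roughly like this:

```cpp
// Plain-C++ analogue of the single-channel, symmetric case; illustrative only.
#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

void ExtractMinMaxSingleChannel(const std::vector<double>& values,
                                bool symmetric, double& min, double& max) {
  min = std::numeric_limits<double>::max();
  max = std::numeric_limits<double>::lowest();
  for (double v : values) {
    min = std::min(min, v);
    max = std::max(max, v);
  }
  if (symmetric) {
    // Expand the range to [-M, M] around zero.
    max = std::max(std::abs(min), std::abs(max));
    min = -max;
  }
}
```

The resulting maximum is then divided by 512 (-minIntN(10)) to obtain the 10-bit weight scale, as in the ConvertOpStatsToQDQs code below.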


@@ -380,3 +380,25 @@ func @QuantizeLstmFull(%arg0: tensor<1x5xf32>) -> tensor<*xf32> attributes {tf.e
// CHECK: "tfl.quantize"(%[[lstm]]) {qtype = tensor<*x!quant.uniform<i8:f32, 0.015686274509803921:-1>>}
}
// CHECK-LABEL: QuantizeSVDF
func @QuantizeSVDF(%arg0: tensor<1x3xf32>) -> tensor<1x2xf32> {
%0 = "quant.stats"(%arg0) {layerStats = dense<[2.07937503, 1.365000e+01]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32>
%1 = "tfl.pseudo_const"() {value = dense<[[1.125947117805481, 1.0, 1.1], [-1.164743185043335, -1.0, -1.1]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
%2 = "tfl.pseudo_const"() {value = dense<[[1.8328168392181396], [-1.897219181060791]]> : tensor<2x1xf32>} : () -> tensor<2x1xf32>
%3 = "tfl.pseudo_const"() {value = dense<[1.4014043807983398, -1.0950859785079956]> : tensor<2xf32>} : () -> tensor<2xf32>
%4 = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<1x4xf32>} : () -> tensor<1x4xf32>
%5 = "quant.stats"(%4) {layerStats = dense<[-56.2916565, 122.922478]> : tensor<2xf32>} : (tensor<1x4xf32>) -> tensor<1x4xf32>
%6 = "tfl.svdf"(%0, %1, %2, %3, %5) {fused_activation_function = "RELU", rank = 1 : i32} : (tensor<1x3xf32>, tensor<2x3xf32>, tensor<2x1xf32>, tensor<2xf32>, tensor<1x4xf32>) -> tensor<1x2xf32>
%7 = "quant.stats"(%6) {layerStats = dense<[0.000000e+00, 33.0349121]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32>
return %7 : tensor<1x2xf32>
// CHECK-DAG: %[[input_0:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x3x!quant.uniform<i8:f32, 0.053529410268746171:-128>>)
// CHECK-DAG: %[[input_1:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x3x!quant.uniform<i8<-127:127>:f32, 0.0091712061814435818>>)
// CHECK-DAG: %[[input_2:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x1x!quant.uniform<i16<-512:512>:f32, 0.0037055062130093575>>)
// CHECK-DAG: %[[input_3:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x!quant.uniform<i32:f32, 1.3900876031311922E-5>>)
// CHECK-DAG: %[[input_4:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x4x!quant.uniform<i16<-32767:32767>:f32, 0.0037514108011770368>>)
// CHECK: %[[svdf:.*]] = "tfl.svdf"(%[[input_0]], %[[input_1]], %[[input_2]], %[[input_3]], %[[input_4]])
// CHECK: %[[q:.*]] = "tfl.quantize"(%[[svdf]]) {qtype = tensor<1x2x!quant.uniform<i8:f32, 0.12954867493872549:-128>>}
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%11)
// CHECK: return %[[dq]]
}
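As a sanity check on the CHECK lines above, each expected scale can be recomputed from the quant.stats and constant values in the test; tiny discrepancies in the last digits come from the float32 rounding of the stats. This is a standalone arithmetic sketch, not part of the test:

```cpp
// Standalone arithmetic sketch, not part of the test.
#include <cstdio>

int main() {
  // input_0: asymmetric int8 over [0, 13.65] (stats widened to include zero).
  std::printf("input_0 scale: %.17f\n", 13.65 / 255.0);
  // input_1: symmetric narrow-range int8, max|w| = 1.164743185043335.
  std::printf("input_1 scale: %.19f\n", 1.164743185043335 / 127.0);
  // input_2: 10-bit symmetric time weights, max|w| = 1.897219181060791.
  std::printf("input_2 scale: %.19f\n", 1.897219181060791 / 512.0);
  // input_4: 16-bit symmetric activation state, max|x| = 122.922478.
  std::printf("input_4 scale: %.19f\n", 122.922478 / 32767.0);
  // input_3: bias scale derived as input_2 scale * input_4 scale.
  std::printf("input_3 scale: %.16e\n",
              (1.897219181060791 / 512.0) * (122.922478 / 32767.0));
  // output: asymmetric int8 over [0, 33.0349121].
  std::printf("output  scale: %.17f\n", 33.0349121 / 255.0);
  return 0;
}
```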


@@ -29,7 +29,7 @@ limitations under the License.
#include "mlir/IR/Location.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/prepare_quantize_lstm.h"
#include "tensorflow/compiler/mlir/lite/transforms/prepare_quantize_helper.h"
//===----------------------------------------------------------------------===//
// The Pass to add default quantization parameters for the activations which


@@ -139,21 +139,20 @@ struct RemoveVolatileOps : public OpRewritePattern<DequantizeOp> {
}
};
// Removes LSTMs that have dangling output.
// LSTMs are not removed automatically becuase they are stateful ops.
template <typename LstmOpTy>
struct PruneUnusedLstm : public OpRewritePattern<LstmOpTy> {
// Removes operations with side effects (e.g. LSTM, SVDF) that have dangling
// outputs.
template <typename OpTy>
struct PruneUnusedOpsWithSideEffect : public OpRewritePattern<OpTy> {
public:
explicit PruneUnusedLstm(MLIRContext* context)
: OpRewritePattern<LstmOpTy>(context) {}
explicit PruneUnusedOpsWithSideEffect(MLIRContext* context)
: OpRewritePattern<OpTy>(context) {}
LogicalResult matchAndRewrite(LstmOpTy lstm_op,
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter& rewriter) const override {
Operation* op = lstm_op.getOperation();
if (op->isKnownTerminator()) {
if (op.getOperation()->isKnownTerminator()) {
return failure();
}
for (auto result : op->getOpResults()) {
for (auto result : op.getOperation()->getOpResults()) {
if (!result.use_empty()) {
return failure();
}
@@ -171,8 +170,11 @@ void PostQuantizePass::runOnFunction() {
auto* ctx = func.getContext();
TFL::populateWithGenerated(ctx, patterns);
patterns.insert<quant::FoldTrivalRequantizeOp<QuantizeOp>>(ctx);
patterns.insert<PruneUnusedLstm<TFL::LSTMOp>>(ctx);
patterns.insert<PruneUnusedLstm<TFL::UnidirectionalSequenceLSTMOp>>(ctx);
patterns.insert<PruneUnusedOpsWithSideEffect<TFL::LSTMOp>>(ctx);
patterns
.insert<PruneUnusedOpsWithSideEffect<TFL::UnidirectionalSequenceLSTMOp>>(
ctx);
patterns.insert<PruneUnusedOpsWithSideEffect<TFL::SVDFOp>>(ctx);
applyPatternsAndFoldGreedily(func, std::move(patterns));
if (!emit_quant_adaptor_ops_) {


@@ -43,7 +43,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/transforms/prepare_quantize_lstm.h"
#include "tensorflow/compiler/mlir/lite/transforms/prepare_quantize_helper.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/monitoring/counter.h"
@@ -388,6 +388,7 @@ void PrepareQuantizePass::runOnFunction() {
patterns_2.insert<ConvertLstmStatsToQDQs<LSTMOp>>(ctx, quant_specs_);
patterns_2.insert<ConvertLstmStatsToQDQs<UnidirectionalSequenceLSTMOp>>(
ctx, quant_specs_);
patterns_2.insert<ConvertSvdfStatsToQDQs>(ctx);
}
applyPatternsAndFoldGreedily(func, std::move(patterns_2));


@@ -15,8 +15,8 @@ limitations under the License.
// Transform pass for LSTMs.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PREPARE_QUANTIZE_LSTM
#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PREPARE_QUANTIZE_LSTM
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PREPARE_QUANTIZE_HELPER
#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PREPARE_QUANTIZE_HELPER
#include <algorithm>
#include <cmath>
@@ -221,15 +221,187 @@ struct PrepareLstmOutputScale : public OpRewritePattern<SourceOp> {
}
};
template <typename SourceOp>
struct ConvertOpStatsToQDQs : public OpRewritePattern<SourceOp> {
public:
explicit ConvertOpStatsToQDQs(MLIRContext* context,
PatternBenefit benefit = 1)
: OpRewritePattern<SourceOp>(context, benefit) {}
protected:
LogicalResult processInputs(
SourceOp op, const operator_property::OpVariant& op_variant,
const operator_property::OperatorProperty& op_property,
PatternRewriter& rewriter) const {
for (auto& enumerated_inputs : op_property.inputs) {
int index = enumerated_inputs.first;
auto& tensor_property = enumerated_inputs.second;
Value input = op.getOperand(index);
if (input.getDefiningOp() == nullptr) continue;
// TODO(b/172517537): make this work with non-PTQ case.
if (llvm::isa<ConstantOp, TFL::ConstOp>(input.getDefiningOp())) {
// Tensors with derived scale are biases, and handled in propagation.
if (tensor_property.use_derived_scale) continue;
// For weights, use quantization scale inferred from the values.
if (failed(processConstantOp(op, input.getDefiningOp(), index,
tensor_property, rewriter))) {
return failure();
}
} else {
if (auto stats_op =
llvm::dyn_cast<quant::StatisticsOp>(input.getDefiningOp())) {
if (failed(replaceStatsOp(op, stats_op, index, tensor_property,
rewriter))) {
return failure();
}
} else if (!llvm::isa<DQ>(input.getDefiningOp()) &&
!llvm::isa<SameScalesOpInterface>(input.getDefiningOp())) {
// Continue if StatisticsOp is already converted to Q-DQ pair, or
// stats op is not immediately available to the input because it's
// connected to ops with same scale requirements.
// TODO(b/172517537): make this work with non-PTQ case.
op.emitError() << "Input " << index
<< " should be from DequantizeCast, Statistics,"
<< " or ops with the same scale requirement.";
input.getDefiningOp()->emitError();
return failure();
}
}
}
return success();
}
LogicalResult processConstantOp(
SourceOp op, Operation* const_op, int input_index,
const operator_property::TensorProperty& tensor_property,
PatternRewriter& rewriter) const {
// Non-float tensors are not weights and do not require quantization.
auto type = const_op->getResult(0).getType().dyn_cast<ShapedType>();
if (!type || !type.getElementType().isa<FloatType>()) return success();
DenseFPElementsAttr attr;
if (!matchPattern(const_op->getResult(0), m_Constant(&attr))) {
const_op->emitError("Not a constant op.");
return failure();
}
UniformQuantizedType quant_type = nullptr;
// When the number of bits is 10 (instead of 16), quantize the tensor to the
// [-512, 512] range instead of [-32767, 32767].
// For now this behavior is specific to SVDF, where 6 bits are reserved for
// the reduction that follows the element-wise multiplication between the
// state and the time weights.
if (tensor_property.number_of_bits == 10) {
SmallVector<double, 4> mins(1, std::numeric_limits<double>::max());
SmallVector<double, 4> maxs(1, std::numeric_limits<double>::min());
// Computes the effective min/max values of the attribute values.
quant::ExtractMinMaxFromAttr(attr, /*dim_size=*/1, /*slice_size=*/1,
/*symmetric=*/true, mins, maxs);
double scale = maxs[0] / -llvm::minIntN(tensor_property.number_of_bits);
quant_type = UniformQuantizedType::getChecked(
quant::QuantizationFlags::Signed,
rewriter.getIntegerType(16, /*isSigned=*/true),
attr.getType().getElementType(), scale, /*zeroPoint=*/0,
llvm::minIntN(10), -llvm::minIntN(10), const_op->getLoc());
} else {
quant_type =
quant::GetUniformQuantizedTypeForWeight(
attr, /*symmetric=*/true,
/*num_bits=*/tensor_property.number_of_bits, /*is_signed=*/true,
/*narrow_range=*/true)
.template dyn_cast<quant::UniformQuantizedType>();
}
if (!quant_type) {
const_op->emitError("Failed to get quantized type");
return failure();
}
// TODO(b/172517537): duplicate the constant when the bias is shared.
Type expressed_type = const_op->getResult(0).getType();
Type cast_type = quant_type.castFromExpressedType(expressed_type);
rewriter.setInsertionPointAfter(const_op);
auto q = rewriter.create<Q>(const_op->getLoc(), cast_type,
const_op->getResult(0));
auto dq = rewriter.create<DQ>(const_op->getLoc(), expressed_type, q);
op.setOperand(input_index, dq.getResult());
return success();
}
LogicalResult replaceStatsOp(
SourceOp op, quant::StatisticsOp stats_op, int input_index,
const operator_property::TensorProperty& tensor_property,
PatternRewriter& rewriter) const {
if (tensor_property.state_tensor && !stats_op.getResult().hasOneUse()) {
// TODO(b/172517537): check if other tensors should go through this
// check too.
op.emitError() << "Input tensor [" << input_index
<< "] is a state tensor, but has more than one use.";
return failure();
}
auto stats = stats_op.layerStats().dyn_cast<DenseFPElementsAttr>();
if (!stats || stats.getNumElements() != 2) {
stats_op.emitError("Stats should have 2 values.");
return failure();
}
quant::QuantizedType quant_type;
double min = FloatAttr::getValueAsDouble(stats.getValue<APFloat>({0}));
double max = FloatAttr::getValueAsDouble(stats.getValue<APFloat>({1}));
// Make sure the range includes zero.
min = std::min(min, 0.0);
max = std::max(max, 0.0);
Type expressed = getElementTypeOrSelf(stats_op.getType());
if (tensor_property.extend_to_power_of_two) {
if (tensor_property.number_of_bits != 16) {
op.emitError(
"extended power of 2 scale is only supported for 16-bit"
" quantization.");
return failure();
}
double bound = PowerOfTwoBound(std::max(std::abs(min), std::abs(max)));
// Set flags to 1 for signed type.
quant_type = UniformQuantizedType::getChecked(
quant::QuantizationFlags::Signed,
rewriter.getIntegerType(tensor_property.number_of_bits), expressed,
/*scale=*/bound / -llvm::minIntN(tensor_property.number_of_bits),
/*zeroPoint=*/0, llvm::minIntN(tensor_property.number_of_bits),
llvm::maxIntN(tensor_property.number_of_bits), op.getLoc());
} else {
// int16 uses the narrow range [-32767, 32767].
if (tensor_property.number_of_bits == 16) {
max = std::max(std::abs(min), std::abs(max));
min = -max;
quant_type = quant::fakeQuantAttrsToType(
op.getLoc(), tensor_property.number_of_bits, min, max,
/*narrowRange=*/true, expressed,
/*isSigned=*/true);
} else {
quant_type = quant::fakeQuantAttrsToType(
op.getLoc(), tensor_property.number_of_bits, min, max,
/*narrowRange=*/false, expressed,
/*isSigned=*/true);
}
}
rewriter.setInsertionPointAfter(stats_op);
Type result_type = quant_type.castFromExpressedType(stats_op.getType());
auto q = rewriter.create<Q>(stats_op.getLoc(), result_type, stats_op.arg());
rewriter.replaceOpWithNewOp<DQ>(stats_op, stats_op.getType(), q);
return success();
}
};
// Quantize LSTM according to its quantization recipe.
template <typename SourceOp>
struct ConvertLstmStatsToQDQs : public OpRewritePattern<SourceOp> {
struct ConvertLstmStatsToQDQs : public ConvertOpStatsToQDQs<SourceOp> {
public:
ConvertLstmStatsToQDQs(MLIRContext* context,
const QuantizationSpecs& quant_specs)
: OpRewritePattern<SourceOp>(context, /*benefit=*/2),
quant_specs(quant_specs) {}
: ConvertOpStatsToQDQs<SourceOp>(context), quant_specs(quant_specs) {}
LogicalResult matchAndRewrite(SourceOp op,
PatternRewriter& rewriter) const override {
operator_property::OpVariant lstm_variant;
@@ -239,7 +411,8 @@ struct ConvertLstmStatsToQDQs : public OpRewritePattern<SourceOp> {
}
if (failed(processIntermediates(op, lstm_variant, lstm_property)) ||
failed(processInputs(op, lstm_variant, lstm_property, rewriter))) {
failed(ConvertOpStatsToQDQs<SourceOp>::processInputs(
op, lstm_variant, lstm_property, rewriter))) {
return failure();
}
@@ -311,143 +484,6 @@ struct ConvertLstmStatsToQDQs : public OpRewritePattern<SourceOp> {
}
return success();
}
LogicalResult processInputs(
SourceOp op, const operator_property::OpVariant& lstm_variant,
const operator_property::OperatorProperty& lstm_property,
PatternRewriter& rewriter) const {
for (auto& enumerated_inputs : lstm_property.inputs) {
int index = enumerated_inputs.first;
auto& tensor_property = enumerated_inputs.second;
Value input = op.getOperand(index);
if (input.getDefiningOp() == nullptr) continue;
// TODO(b/172517537): make this work with non-PTQ case.
if (llvm::isa<ConstantOp, TFL::ConstOp>(input.getDefiningOp())) {
// Tensors with derived scale are biases, and handled in propagation.
if (tensor_property.use_derived_scale) continue;
if (failed(processConstantOp(op, input.getDefiningOp(), index,
tensor_property, rewriter))) {
return failure();
}
} else {
if (auto stats_op =
llvm::dyn_cast<quant::StatisticsOp>(input.getDefiningOp())) {
if (failed(replaceStatsOp(op, stats_op, index, tensor_property,
rewriter))) {
return failure();
}
} else if (!llvm::isa<DQ>(input.getDefiningOp()) &&
!llvm::isa<SameScalesOpInterface>(input.getDefiningOp())) {
// Continue if StatisticsOp is already converted to Q-DQ pair, or
// stats op is not immediately available to the input because it's
// connected to ops with same scale requirements.
// TODO(b/172517537): make this work with non-PTQ case.
op.emitError() << "Input " << index
<< " should be from DequantizeCast, Statistics, "
<< ", or ops with same scale requirement.";
input.getDefiningOp()->emitError();
return failure();
}
}
}
return success();
}
// For weights, use quantization scale directly inferred from the values.
//
// input 1~4: input to gate weights
// input 5~8: recurrent to gate weights
// input 9~11: peephole weights, input 16: projection weight
// input 20~23: normalization weights
LogicalResult processConstantOp(
SourceOp op, Operation* const_op, int input_index,
const operator_property::TensorProperty& tensor_property,
PatternRewriter& rewriter) const {
// Non-float tensors are neither weights nor require quantization.
auto type = const_op->getResult(0).getType().dyn_cast<ShapedType>();
if (!type || !type.getElementType().isa<FloatType>()) return success();
DenseFPElementsAttr attr;
if (!matchPattern(const_op->getResult(0), m_Constant(&attr))) {
const_op->emitError("Not a constant op.");
return failure();
}
UniformQuantizedType quant_type =
quant::GetUniformQuantizedTypeForWeight(
attr, /*symmetric=*/true,
/*num_bits=*/tensor_property.number_of_bits, /*is_signed=*/true,
/*narrow_range=*/true)
.template dyn_cast<quant::UniformQuantizedType>();
if (!quant_type) {
const_op->emitError("Failed to get quantized type");
return failure();
}
// TODO(b/172517537): duplicate the constant when the bias is shared.
Type expressed_type = const_op->getResult(0).getType();
Type cast_type = quant_type.castFromExpressedType(expressed_type);
rewriter.setInsertionPointAfter(const_op);
auto q = rewriter.create<Q>(const_op->getLoc(), cast_type,
const_op->getResult(0));
auto dq = rewriter.create<DQ>(const_op->getLoc(), expressed_type, q);
op.setOperand(input_index, dq.getResult());
return success();
}
LogicalResult replaceStatsOp(
SourceOp op, quant::StatisticsOp stats_op, int input_index,
const operator_property::TensorProperty& tensor_property,
PatternRewriter& rewriter) const {
if (tensor_property.state_tensor && !stats_op.getResult().hasOneUse()) {
// TODO(b/172517537): check if other tensors should go through this
// check too.
op.emitError() << "Input tensor [" << input_index
<< "] is a state tensor, but has more than one use.";
return failure();
}
auto stats = stats_op.layerStats().dyn_cast<DenseFPElementsAttr>();
if (!stats || stats.getNumElements() != 2) {
stats_op.emitError("Stats should have 2 values.");
return failure();
}
quant::QuantizedType quant_type;
double min = FloatAttr::getValueAsDouble(stats.getValue<APFloat>({0}));
double max = FloatAttr::getValueAsDouble(stats.getValue<APFloat>({1}));
Type expressed = getElementTypeOrSelf(stats_op.getType());
if (tensor_property.extend_to_power_of_two) {
if (tensor_property.number_of_bits != 16) {
op.emitError(
"extended power of 2 scale is only supported for 16-bit"
" quantization.");
return failure();
}
double bound = PowerOfTwoBound(std::max(std::abs(min), std::abs(max)));
// Set flags to 1 for signed type.
quant_type = UniformQuantizedType::getChecked(
quant::QuantizationFlags::Signed,
rewriter.getIntegerType(tensor_property.number_of_bits), expressed,
/*scale=*/bound / -llvm::minIntN(tensor_property.number_of_bits),
/*zeroPoint=*/0, llvm::minIntN(tensor_property.number_of_bits),
llvm::maxIntN(tensor_property.number_of_bits), op.getLoc());
} else {
quant_type = quant::fakeQuantAttrsToType(
op.getLoc(), tensor_property.number_of_bits, min, max,
/*narrowRange=*/false, expressed,
/*isSigned=*/true);
}
rewriter.setInsertionPointAfter(stats_op);
Type result_type = quant_type.castFromExpressedType(stats_op.getType());
auto q = rewriter.create<Q>(stats_op.getLoc(), result_type, stats_op.arg());
rewriter.replaceOpWithNewOp<DQ>(stats_op, stats_op.getType(), q);
return success();
}
};
// Returns a function that returns the quantized type of a bias input.
@@ -509,7 +545,22 @@ std::unique_ptr<quant::OpQuantSpec> GetLstmOpQuantSpec(LstmOp op) {
return spec;
}
struct ConvertSvdfStatsToQDQs : public ConvertOpStatsToQDQs<TFL::SVDFOp> {
public:
explicit ConvertSvdfStatsToQDQs(MLIRContext* context)
: ConvertOpStatsToQDQs<TFL::SVDFOp>(context) {}
LogicalResult matchAndRewrite(TFL::SVDFOp op,
PatternRewriter& rewriter) const override {
operator_property::OpVariant op_variant = {
.op_code = tflite::BuiltinOperator_SVDF,
};
auto op_property = operator_property::GetOperatorProperty(op_variant);
return ConvertOpStatsToQDQs<TFL::SVDFOp>::processInputs(
op, op_variant, op_property, rewriter);
}
};
} // namespace TFL
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PREPARE_QUANTIZE_LSTM
#endif // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PREPARE_QUANTIZE_HELPER