From fe004863b7308658000d48e31632b17e5fc537eb Mon Sep 17 00:00:00 2001 From: Taehee Jeong Date: Thu, 7 Jan 2021 22:31:18 -0800 Subject: [PATCH] Add support for SVDF in MLIR quantizer * Added quantization traits for SVDF in tfl_ops.td. * Added logic for removing dangling SVDF ops * Added handling for 10-bit quantization of weights, using [-512, 512] range. * Unified logic for quantizing complex ops. (LSTM, SVDF) * Ensured quantization range contains zero. PiperOrigin-RevId: 350702871 Change-Id: I719a5cd005d0e70db5266e6700ad3c2e58846e7a --- tensorflow/compiler/mlir/lite/BUILD | 2 +- tensorflow/compiler/mlir/lite/ir/tfl_ops.td | 7 +- .../lite/quantization/quantization_utils.cc | 16 +- .../lite/quantization/quantization_utils.h | 16 + ...ir => prepare-quantize-post-training.mlir} | 22 ++ .../lite/transforms/default_quant_params.cc | 2 +- .../mlir/lite/transforms/post_quantize.cc | 26 +- .../mlir/lite/transforms/prepare_quantize.cc | 3 +- ...ntize_lstm.h => prepare_quantize_helper.h} | 339 ++++++++++-------- 9 files changed, 263 insertions(+), 170 deletions(-) rename tensorflow/compiler/mlir/lite/tests/{prepare-quantize-lstm.mlir => prepare-quantize-post-training.mlir} (94%) rename tensorflow/compiler/mlir/lite/transforms/{prepare_quantize_lstm.h => prepare_quantize_helper.h} (84%) diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 4bb5e055eea..9f78594a2cb 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -530,7 +530,7 @@ cc_library( ], hdrs = [ "transforms/passes.h", - "transforms/prepare_quantize_lstm.h", + "transforms/prepare_quantize_helper.h", ], deps = [ "convert_type", diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index e9b4500380b..6a461142c27 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -4303,7 +4303,8 @@ def TFL_SVDFOp : TFL_Op<"svdf", [ PredOpTrait<"the input and result tensor elemental types must be same", TFL_TCresVTEtIsSameAsOp<0, 0>>, - TFL_StatefulOp]> { + TFL_StatefulOp, + AccumulatorUniformScale<3, 2, 4>]> { let summary = "Single value decomposition filter operator"; @@ -4321,10 +4322,10 @@ def TFL_SVDFOp : TFL_TensorOf<[F32, QI8, QUI8]>:$feature_weights, // Time weights - TFL_TensorOf<[F32, QI8]>:$time_weights, + TFL_TensorOf<[F32, QI16]>:$time_weights, // Bias - TFL_TensorOfOrNone<[F32]>:$input_gate_bias, + TFL_TensorOfOrNone<[F32, QI32]>:$input_gate_bias, // Activation state. TFL_StatefulTensor:$activation_state, diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc index c8ac2c2dea3..72b5d534259 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc @@ -72,10 +72,10 @@ static void ExpandVerySmallRange(ArrayRef mins, ArrayRef maxs, // input_type/min/max/storag_type_width/narrow_range. // This is entry point to the Quant dialect and used for both quantizing // activations and weights. 
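The hunk below exposes GetQuantizedType, the helper that turns a calibration min/max range and a storage width into a Quant-dialect type. As a minimal sketch of the underlying arithmetic only (plain stand-alone C++, not part of this patch; the helper name and the int8 storage bounds are illustrative assumptions), the asymmetric 8-bit case looks like this, including the nudge that keeps zero inside the range mentioned in the commit message:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Toy (scale, zero_point) pair for a signed 8-bit storage type. The real
// code builds a quant::UniformQuantizedType instead of a plain struct.
struct QuantParams {
  double scale;
  int64_t zero_point;
};

QuantParams ChooseInt8Params(double min, double max) {
  // Quantization ranges are expanded so they always contain zero.
  min = std::min(min, 0.0);
  max = std::max(max, 0.0);
  const double qmin = -128.0, qmax = 127.0;
  const double scale = (max - min) / (qmax - qmin);
  const auto zero_point =
      static_cast<int64_t>(std::llround(qmin - min / scale));
  return {scale, zero_point};
}

int main() {
  // Activation stats taken from the QuantizeSVDF test added in this patch.
  const QuantParams p = ChooseInt8Params(2.07937503, 13.65);
  std::printf("scale=%.10f zero_point=%lld\n", p.scale,
              static_cast<long long>(p.zero_point));
  return 0;
}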
-static Type GetQuantizedType(Builder builder, Type input_type, - ArrayRef min, ArrayRef max, - int quant_dim, int storage_type_width, - bool narrow_range, bool is_signed) { +Type GetQuantizedType(Builder builder, Type input_type, ArrayRef min, + ArrayRef max, int quant_dim, + int storage_type_width, bool narrow_range, + bool is_signed) { auto converter = quant::ExpressedToQuantizedConverter::forInputType(input_type); @@ -253,10 +253,10 @@ TypeAttr CastQuantizedTypeAttrFromExpressedType(Builder builder, return TypeAttr::get(final_type); } -static void ExtractMinMaxFromAttr(DenseFPElementsAttr values, int dim_size, - int slice_size, bool symmetric, - SmallVector& mins, - SmallVector& maxs) { +void ExtractMinMaxFromAttr(DenseFPElementsAttr values, int dim_size, + int slice_size, bool symmetric, + SmallVector& mins, + SmallVector& maxs) { // If all the element values are same we don't need to scan the content. if (values.isSplat()) { double single_value = diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h index 48b0db6613a..04d46273982 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h @@ -553,6 +553,22 @@ quant::UniformQuantizedType GetFixedOutputRange(bool is_signed, int bit_width, int64_t zero_point, int64_t storage_min = -128, int64_t storage_max = 127); + +// Extrace min and max values from the DenseFPElementsAttr, and stores them into +// `mins` and `maxs`. When mins and maxs are extracted per-channel, `dim_size` +// is number of channels and `slice_size` is the size of slice per each channel. +// When `symmetric` is true, the range is expanded to [-M, M]. +void ExtractMinMaxFromAttr(DenseFPElementsAttr values, int dim_size, + int slice_size, bool symmetric, + SmallVector& mins, + SmallVector& maxs); + +// Returns the quantized type for the +// input_type/min/max/storag_type_width/narrow_range. 
+Type GetQuantizedType(Builder builder, Type input_type, ArrayRef min, + ArrayRef max, int quant_dim, + int storage_type_width, bool narrow_range, + bool is_signed); } // namespace quant } // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-quantize-lstm.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-quantize-post-training.mlir similarity index 94% rename from tensorflow/compiler/mlir/lite/tests/prepare-quantize-lstm.mlir rename to tensorflow/compiler/mlir/lite/tests/prepare-quantize-post-training.mlir index aac58c9c43e..46a9876cd03 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-quantize-lstm.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-quantize-post-training.mlir @@ -380,3 +380,25 @@ func @QuantizeLstmFull(%arg0: tensor<1x5xf32>) -> tensor<*xf32> attributes {tf.e // CHECK: "tfl.quantize"(%[[lstm]]) {qtype = tensor<*x!quant.uniform>} } +// CHECK-LABEL: QuantizeSVDF +func @QuantizeSVDF(%arg0: tensor<1x3xf32>) -> tensor<1x2xf32> { + %0 = "quant.stats"(%arg0) {layerStats = dense<[2.07937503, 1.365000e+01]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32> + %1 = "tfl.pseudo_const"() {value = dense<[[1.125947117805481, 1.0, 1.1], [-1.164743185043335, -1.0, -1.1]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + %2 = "tfl.pseudo_const"() {value = dense<[[1.8328168392181396], [-1.897219181060791]]> : tensor<2x1xf32>} : () -> tensor<2x1xf32> + %3 = "tfl.pseudo_const"() {value = dense<[1.4014043807983398, -1.0950859785079956]> : tensor<2xf32>} : () -> tensor<2xf32> + %4 = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<1x4xf32>} : () -> tensor<1x4xf32> + %5 = "quant.stats"(%4) {layerStats = dense<[-56.2916565, 122.922478]> : tensor<2xf32>} : (tensor<1x4xf32>) -> tensor<1x4xf32> + %6 = "tfl.svdf"(%0, %1, %2, %3, %5) {fused_activation_function = "RELU", rank = 1 : i32} : (tensor<1x3xf32>, tensor<2x3xf32>, tensor<2x1xf32>, tensor<2xf32>, tensor<1x4xf32>) -> tensor<1x2xf32> + %7 = "quant.stats"(%6) {layerStats = dense<[0.000000e+00, 33.0349121]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32> + return %7 : tensor<1x2xf32> + +// CHECK-DAG: %[[input_0:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x3x!quant.uniform>) +// CHECK-DAG: %[[input_1:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x3x!quant.uniform:f32, 0.0091712061814435818>>) +// CHECK-DAG: %[[input_2:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x1x!quant.uniform:f32, 0.0037055062130093575>>) +// CHECK-DAG: %[[input_3:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x!quant.uniform>) +// CHECK-DAG: %[[input_4:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x4x!quant.uniform:f32, 0.0037514108011770368>>) +// CHECK: %[[svdf:.*]] = "tfl.svdf"(%[[input_0]], %[[input_1]], %[[input_2]], %[[input_3]], %[[input_4]]) +// CHECK: %[[q:.*]] = "tfl.quantize"(%[[svdf]]) {qtype = tensor<1x2x!quant.uniform>} +// CHECK: %[[dq:.*]] = "tfl.dequantize"(%11) +// CHECK: return %[[dq]] +} diff --git a/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc b/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc index c4531628612..e5536eb10da 100644 --- a/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc +++ b/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc @@ -29,7 +29,7 @@ limitations under the License. 
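The weight scales checked in the QuantizeSVDF test above follow directly from the two symmetric schemes used for SVDF: 8-bit narrow-range feature weights (scale = max|w| / 127) and 10-bit time weights (scale = max|w| / 512, i.e. divided by -minIntN(10)). A stand-alone sketch (not part of the patch) reproducing those constants:

#include <cassert>
#include <cmath>
#include <cstdio>

int main() {
  // Largest-magnitude weight values from the test's constant tensors.
  const double feature_max_abs = 1.164743185043335;
  const double time_max_abs = 1.897219181060791;

  // 8-bit symmetric, narrow range: storage values in [-127, 127].
  const double feature_scale = feature_max_abs / 127.0;
  // 10-bit scheme from processConstantOp: storage values in [-512, 512].
  const double time_scale = time_max_abs / 512.0;

  std::printf("feature scale = %.19f\n", feature_scale);
  std::printf("time    scale = %.19f\n", time_scale);

  // Matches the scales in the CHECK-DAG lines of the test above.
  assert(std::abs(feature_scale - 0.0091712061814435818) < 1e-15);
  assert(std::abs(time_scale - 0.0037055062130093575) < 1e-15);
  return 0;
}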
#include "mlir/IR/Location.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" -#include "tensorflow/compiler/mlir/lite/transforms/prepare_quantize_lstm.h" +#include "tensorflow/compiler/mlir/lite/transforms/prepare_quantize_helper.h" //===----------------------------------------------------------------------===// // The Pass to add default quantization parameters for the activations which diff --git a/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc index be8b0967d1e..8ee5556afb1 100644 --- a/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc @@ -139,21 +139,20 @@ struct RemoveVolatileOps : public OpRewritePattern { } }; -// Removes LSTMs that have dangling output. -// LSTMs are not removed automatically becuase they are stateful ops. -template -struct PruneUnusedLstm : public OpRewritePattern { +// Removes operations with side effect (i.e. LSTM, SVDF) that have dangling +// output. +template +struct PruneUnusedOpsWithSideEffect : public OpRewritePattern { public: - explicit PruneUnusedLstm(MLIRContext* context) - : OpRewritePattern(context) {} + explicit PruneUnusedOpsWithSideEffect(MLIRContext* context) + : OpRewritePattern(context) {} - LogicalResult matchAndRewrite(LstmOpTy lstm_op, + LogicalResult matchAndRewrite(OpTy op, PatternRewriter& rewriter) const override { - Operation* op = lstm_op.getOperation(); - if (op->isKnownTerminator()) { + if (op.getOperation()->isKnownTerminator()) { return failure(); } - for (auto result : op->getOpResults()) { + for (auto result : op.getOperation()->getOpResults()) { if (!result.use_empty()) { return failure(); } @@ -171,8 +170,11 @@ void PostQuantizePass::runOnFunction() { auto* ctx = func.getContext(); TFL::populateWithGenerated(ctx, patterns); patterns.insert>(ctx); - patterns.insert>(ctx); - patterns.insert>(ctx); + patterns.insert>(ctx); + patterns + .insert>( + ctx); + patterns.insert>(ctx); applyPatternsAndFoldGreedily(func, std::move(patterns)); if (!emit_quant_adaptor_ops_) { diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc index ac8d5a0a10a..b8b3117069f 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc @@ -43,7 +43,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" -#include "tensorflow/compiler/mlir/lite/transforms/prepare_quantize_lstm.h" +#include "tensorflow/compiler/mlir/lite/transforms/prepare_quantize_helper.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/monitoring/counter.h" @@ -388,6 +388,7 @@ void PrepareQuantizePass::runOnFunction() { patterns_2.insert>(ctx, quant_specs_); patterns_2.insert>( ctx, quant_specs_); + patterns_2.insert(ctx); } applyPatternsAndFoldGreedily(func, std::move(patterns_2)); diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_lstm.h b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_helper.h similarity index 84% rename from tensorflow/compiler/mlir/lite/transforms/prepare_quantize_lstm.h rename to tensorflow/compiler/mlir/lite/transforms/prepare_quantize_helper.h index 74c3e23a298..3e96f53e58f 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_lstm.h +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_helper.h @@ -15,8 +15,8 @@ limitations under the License. // Transform pass for LSTMs. -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PREPARE_QUANTIZE_LSTM -#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PREPARE_QUANTIZE_LSTM +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PREPARE_QUANTIZE_HELPER +#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PREPARE_QUANTIZE_HELPER #include #include @@ -221,15 +221,187 @@ struct PrepareLstmOutputScale : public OpRewritePattern { } }; +template +struct ConvertOpStatsToQDQs : public OpRewritePattern { + public: + explicit ConvertOpStatsToQDQs(MLIRContext* context, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit) {} + + protected: + LogicalResult processInputs( + SourceOp op, const operator_property::OpVariant& op_variant, + const operator_property::OperatorProperty& op_property, + PatternRewriter& rewriter) const { + for (auto& enumerated_inputs : op_property.inputs) { + int index = enumerated_inputs.first; + auto& tensor_property = enumerated_inputs.second; + + Value input = op.getOperand(index); + + if (input.getDefiningOp() == nullptr) continue; + + // TODO(b/172517537): make this work with non-PTQ case. + if (llvm::isa(input.getDefiningOp())) { + // Tensors with derived scale are biases, and handled in propagation. + if (tensor_property.use_derived_scale) continue; + // For weights, use quantization scale inferred from the values. + if (failed(processConstantOp(op, input.getDefiningOp(), index, + tensor_property, rewriter))) { + return failure(); + } + } else { + if (auto stats_op = + llvm::dyn_cast(input.getDefiningOp())) { + if (failed(replaceStatsOp(op, stats_op, index, tensor_property, + rewriter))) { + return failure(); + } + } else if (!llvm::isa(input.getDefiningOp()) && + !llvm::isa(input.getDefiningOp())) { + // Continue if StatisticsOp is already converted to Q-DQ pair, or + // stats op is not immediately available to the input because it's + // connected to ops with same scale requirements. + // TODO(b/172517537): make this work with non-PTQ case. 
+ op.emitError() << "Input " << index + << " should be from DequantizeCast, Statistics, " + << ", or ops with same scale requirement."; + input.getDefiningOp()->emitError(); + return failure(); + } + } + } + return success(); + } + + LogicalResult processConstantOp( + SourceOp op, Operation* const_op, int input_index, + const operator_property::TensorProperty& tensor_property, + PatternRewriter& rewriter) const { + // Non-float tensors are neither weights nor require quantization. + auto type = const_op->getResult(0).getType().dyn_cast(); + if (!type || !type.getElementType().isa()) return success(); + + DenseFPElementsAttr attr; + if (!matchPattern(const_op->getResult(0), m_Constant(&attr))) { + const_op->emitError("Not a constant op."); + return failure(); + } + + UniformQuantizedType quant_type = nullptr; + // When the number of bits is 10 (instead of 16), quantize the tensor to + // [-512, 512], instead of [-32767, 32767]. + // For now this behavior is specific for SVDF, where 6 bits are reserved for + // the reduce operation after element-wise multiplication between state and + // time weights. + if (tensor_property.number_of_bits == 10) { + SmallVector mins(1, std::numeric_limits::max()); + SmallVector maxs(1, std::numeric_limits::min()); + // Computes the effective min/max values of the attribute values. + quant::ExtractMinMaxFromAttr(attr, /*dim_size=*/1, /*slice_size=*/1, + /*symmetric=*/true, mins, maxs); + double scale = maxs[0] / -llvm::minIntN(tensor_property.number_of_bits); + quant_type = UniformQuantizedType::getChecked( + quant::QuantizationFlags::Signed, + rewriter.getIntegerType(16, /*isSigned=*/true), + attr.getType().getElementType(), scale, /*zeroPoint=*/0, + llvm::minIntN(10), -llvm::minIntN(10), const_op->getLoc()); + } else { + quant_type = + quant::GetUniformQuantizedTypeForWeight( + attr, /*symmetric=*/true, + /*num_bits=*/tensor_property.number_of_bits, /*is_signed=*/true, + /*narrow_range=*/true) + .template dyn_cast(); + } + if (!quant_type) { + const_op->emitError("Failed to get quantized type"); + return failure(); + } + + // TODO(b/172517537): duplicate the constant when the bias is shared. + Type expressed_type = const_op->getResult(0).getType(); + Type cast_type = quant_type.castFromExpressedType(expressed_type); + rewriter.setInsertionPointAfter(const_op); + auto q = rewriter.create(const_op->getLoc(), cast_type, + const_op->getResult(0)); + auto dq = rewriter.create(const_op->getLoc(), expressed_type, q); + op.setOperand(input_index, dq.getResult()); + return success(); + } + + LogicalResult replaceStatsOp( + SourceOp op, quant::StatisticsOp stats_op, int input_index, + const operator_property::TensorProperty& tensor_property, + PatternRewriter& rewriter) const { + if (tensor_property.state_tensor && !stats_op.getResult().hasOneUse()) { + // TODO(b/172517537): check if other tensors should go through this + // check too. + op.emitError() << "Input tensor [" << input_index + << "] is a state tensor, but has more than one use."; + return failure(); + } + auto stats = stats_op.layerStats().dyn_cast(); + if (!stats || stats.getNumElements() != 2) { + stats_op.emitError("Stats should have 2 values."); + return failure(); + } + quant::QuantizedType quant_type; + double min = FloatAttr::getValueAsDouble(stats.getValue({0})); + double max = FloatAttr::getValueAsDouble(stats.getValue({1})); + // Make sure the range includes zero. 
+ min = std::min(min, 0.0); + max = std::max(max, 0.0); + Type expressed = getElementTypeOrSelf(stats_op.getType()); + + if (tensor_property.extend_to_power_of_two) { + if (tensor_property.number_of_bits != 16) { + op.emitError( + "extended power of 2 scale is only supported for 16-bit" + " quantization."); + return failure(); + } + + double bound = PowerOfTwoBound(std::max(std::abs(min), std::abs(max))); + // Set flags to 1 for signed type. + quant_type = UniformQuantizedType::getChecked( + quant::QuantizationFlags::Signed, + rewriter.getIntegerType(tensor_property.number_of_bits), expressed, + /*scale=*/bound / -llvm::minIntN(tensor_property.number_of_bits), + /*zeroPoint=*/0, llvm::minIntN(tensor_property.number_of_bits), + llvm::maxIntN(tensor_property.number_of_bits), op.getLoc()); + } else { + // int16 uses range [-32767, 32767] + if (tensor_property.number_of_bits == 16) { + max = std::max(std::abs(min), std::abs(max)); + min = -max; + quant_type = quant::fakeQuantAttrsToType( + op.getLoc(), tensor_property.number_of_bits, min, max, + /*narrowRange=*/true, expressed, + /*isSigned=*/true); + } else { + quant_type = quant::fakeQuantAttrsToType( + op.getLoc(), tensor_property.number_of_bits, min, max, + /*narrowRange=*/false, expressed, + /*isSigned=*/true); + } + } + rewriter.setInsertionPointAfter(stats_op); + Type result_type = quant_type.castFromExpressedType(stats_op.getType()); + auto q = rewriter.create(stats_op.getLoc(), result_type, stats_op.arg()); + rewriter.replaceOpWithNewOp(stats_op, stats_op.getType(), q); + return success(); + } +}; + // Quantize LSTM according to its quantization recipe. template -struct ConvertLstmStatsToQDQs : public OpRewritePattern { +struct ConvertLstmStatsToQDQs : public ConvertOpStatsToQDQs { public: ConvertLstmStatsToQDQs(MLIRContext* context, const QuantizationSpecs& quant_specs) - : OpRewritePattern(context, /*benefit=*/2), - quant_specs(quant_specs) {} + : ConvertOpStatsToQDQs(context), quant_specs(quant_specs) {} LogicalResult matchAndRewrite(SourceOp op, PatternRewriter& rewriter) const override { operator_property::OpVariant lstm_variant; @@ -239,7 +411,8 @@ struct ConvertLstmStatsToQDQs : public OpRewritePattern { } if (failed(processIntermediates(op, lstm_variant, lstm_property)) || - failed(processInputs(op, lstm_variant, lstm_property, rewriter))) { + failed(ConvertOpStatsToQDQs::processInputs( + op, lstm_variant, lstm_property, rewriter))) { return failure(); } @@ -311,143 +484,6 @@ struct ConvertLstmStatsToQDQs : public OpRewritePattern { } return success(); } - - LogicalResult processInputs( - SourceOp op, const operator_property::OpVariant& lstm_variant, - const operator_property::OperatorProperty& lstm_property, - PatternRewriter& rewriter) const { - for (auto& enumerated_inputs : lstm_property.inputs) { - int index = enumerated_inputs.first; - auto& tensor_property = enumerated_inputs.second; - - Value input = op.getOperand(index); - - if (input.getDefiningOp() == nullptr) continue; - - // TODO(b/172517537): make this work with non-PTQ case. - if (llvm::isa(input.getDefiningOp())) { - // Tensors with derived scale are biases, and handled in propagation. 
- if (tensor_property.use_derived_scale) continue; - if (failed(processConstantOp(op, input.getDefiningOp(), index, - tensor_property, rewriter))) { - return failure(); - } - } else { - if (auto stats_op = - llvm::dyn_cast(input.getDefiningOp())) { - if (failed(replaceStatsOp(op, stats_op, index, tensor_property, - rewriter))) { - return failure(); - } - } else if (!llvm::isa(input.getDefiningOp()) && - !llvm::isa(input.getDefiningOp())) { - // Continue if StatisticsOp is already converted to Q-DQ pair, or - // stats op is not immediately available to the input because it's - // connected to ops with same scale requirements. - // TODO(b/172517537): make this work with non-PTQ case. - op.emitError() << "Input " << index - << " should be from DequantizeCast, Statistics, " - << ", or ops with same scale requirement."; - input.getDefiningOp()->emitError(); - return failure(); - } - } - } - return success(); - } - - // For weights, use quantization scale directly inferred from the values. - // - // input 1~4: input to gate weights - // input 5~8: recurrent to gate weights - // input 9~11: peephole weights, input 16: projection weight - // input 20~23: normalization weights - LogicalResult processConstantOp( - SourceOp op, Operation* const_op, int input_index, - const operator_property::TensorProperty& tensor_property, - PatternRewriter& rewriter) const { - // Non-float tensors are neither weights nor require quantization. - auto type = const_op->getResult(0).getType().dyn_cast(); - if (!type || !type.getElementType().isa()) return success(); - - DenseFPElementsAttr attr; - if (!matchPattern(const_op->getResult(0), m_Constant(&attr))) { - const_op->emitError("Not a constant op."); - return failure(); - } - - UniformQuantizedType quant_type = - quant::GetUniformQuantizedTypeForWeight( - attr, /*symmetric=*/true, - /*num_bits=*/tensor_property.number_of_bits, /*is_signed=*/true, - /*narrow_range=*/true) - .template dyn_cast(); - - if (!quant_type) { - const_op->emitError("Failed to get quantized type"); - return failure(); - } - - // TODO(b/172517537): duplicate the constant when the bias is shared. - Type expressed_type = const_op->getResult(0).getType(); - Type cast_type = quant_type.castFromExpressedType(expressed_type); - rewriter.setInsertionPointAfter(const_op); - auto q = rewriter.create(const_op->getLoc(), cast_type, - const_op->getResult(0)); - auto dq = rewriter.create(const_op->getLoc(), expressed_type, q); - op.setOperand(input_index, dq.getResult()); - return success(); - } - - LogicalResult replaceStatsOp( - SourceOp op, quant::StatisticsOp stats_op, int input_index, - const operator_property::TensorProperty& tensor_property, - PatternRewriter& rewriter) const { - if (tensor_property.state_tensor && !stats_op.getResult().hasOneUse()) { - // TODO(b/172517537): check if other tensors should go through this - // check too. 
- op.emitError() << "Input tensor [" << input_index - << "] is a state tensor, but has more than one use."; - return failure(); - } - auto stats = stats_op.layerStats().dyn_cast(); - if (!stats || stats.getNumElements() != 2) { - stats_op.emitError("Stats should have 2 values."); - return failure(); - } - quant::QuantizedType quant_type; - double min = FloatAttr::getValueAsDouble(stats.getValue({0})); - double max = FloatAttr::getValueAsDouble(stats.getValue({1})); - Type expressed = getElementTypeOrSelf(stats_op.getType()); - - if (tensor_property.extend_to_power_of_two) { - if (tensor_property.number_of_bits != 16) { - op.emitError( - "extended power of 2 scale is only supported for 16-bit" - " quantization."); - return failure(); - } - - double bound = PowerOfTwoBound(std::max(std::abs(min), std::abs(max))); - // Set flags to 1 for signed type. - quant_type = UniformQuantizedType::getChecked( - quant::QuantizationFlags::Signed, - rewriter.getIntegerType(tensor_property.number_of_bits), expressed, - /*scale=*/bound / -llvm::minIntN(tensor_property.number_of_bits), - /*zeroPoint=*/0, llvm::minIntN(tensor_property.number_of_bits), - llvm::maxIntN(tensor_property.number_of_bits), op.getLoc()); - } else { - quant_type = quant::fakeQuantAttrsToType( - op.getLoc(), tensor_property.number_of_bits, min, max, - /*narrowRange=*/false, expressed, - /*isSigned=*/true); - } - rewriter.setInsertionPointAfter(stats_op); - Type result_type = quant_type.castFromExpressedType(stats_op.getType()); - auto q = rewriter.create(stats_op.getLoc(), result_type, stats_op.arg()); - rewriter.replaceOpWithNewOp(stats_op, stats_op.getType(), q); - return success(); - } }; // Returns a function that returns the quantized type of a bias input. @@ -509,7 +545,22 @@ std::unique_ptr GetLstmOpQuantSpec(LstmOp op) { return spec; } +struct ConvertSvdfStatsToQDQs : public ConvertOpStatsToQDQs { + public: + explicit ConvertSvdfStatsToQDQs(MLIRContext* context) + : ConvertOpStatsToQDQs(context) {} + LogicalResult matchAndRewrite(TFL::SVDFOp op, + PatternRewriter& rewriter) const override { + operator_property::OpVariant op_variant = { + .op_code = tflite::BuiltinOperator_SVDF, + }; + auto op_property = operator_property::GetOperatorProperty(op_variant); + return ConvertOpStatsToQDQs::processInputs( + op, op_variant, op_property, rewriter); + } +}; + } // namespace TFL } // namespace mlir -#endif // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PREPARE_QUANTIZE_LSTM +#endif // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PREPARE_QUANTIZE_HELPER
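For reference, the two 16-bit activation branches of replaceStatsOp reduce to simple arithmetic on the calibration range. The sketch below (stand-alone C++, not part of the patch) assumes PowerOfTwoBound rounds the absolute maximum up to the next power of two, which matches how it is used here; with the activation-state stats from the QuantizeSVDF test, the non-power-of-two branch reproduces the 0.0037514108011770368 scale checked in that test:

#include <algorithm>
#include <cmath>
#include <cstdio>

// Assumed behavior of PowerOfTwoBound (defined elsewhere in
// prepare_quantize_helper.h): round up to the next power of two.
static double PowerOfTwoBound(double value) {
  return std::pow(2.0, std::ceil(std::log2(value)));
}

// Sketch of the 16-bit cases in replaceStatsOp.
static double SixteenBitScale(double min, double max,
                              bool extend_to_power_of_two) {
  // The range is first nudged so that it contains zero.
  min = std::min(min, 0.0);
  max = std::max(max, 0.0);
  const double max_abs = std::max(std::abs(min), std::abs(max));
  if (extend_to_power_of_two) {
    // scale = bound / -minIntN(16) == bound / 32768
    return PowerOfTwoBound(max_abs) / 32768.0;
  }
  // Narrow-range symmetric int16: storage range [-32767, 32767].
  return max_abs / 32767.0;
}

int main() {
  const double min = -56.2916565, max = 122.922478;  // activation_state stats
  std::printf("power-of-two scale: %.19f\n", SixteenBitScale(min, max, true));
  std::printf("symmetric    scale: %.19f\n", SixteenBitScale(min, max, false));
  return 0;
}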