Apply correct quantization recipe for intermediate tensors of LSTMs
PiperOrigin-RevId: 344372258
Change-Id: I7af460e317d52d58c96f9e7beee9b4bce1ef653b
parent dd35a3e925
commit 38d042b105
@@ -499,6 +499,8 @@ cc_library(
         "//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
         "//tensorflow/compiler/mlir/lite/quantization/lite:tfl_to_std",
         "//tensorflow/core:protos_all_cc",
+        "//tensorflow/lite/schema:schema_fbs",
+        "//tensorflow/lite/tools/optimize:operator_property",
         "@com_google_absl//absl/memory",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:IR",
@@ -0,0 +1,91 @@
+// RUN: tf-opt %s -tfl-prepare-quantize | FileCheck %s
+
+// CHECK-LABEL: QuantizeLstmCellInput
+func @QuantizeLstmCellInput(%arg0: tensor<1x28x28xf32>) -> tensor<1x28x20xf32> {
+  %cst_1 = constant dense<1.0> : tensor<1x20xf32>
+  %cst_2 = constant unit
+  %cst_3 = constant dense<1.0> : tensor<20x20xf32>
+  %cst_7 = constant dense<1.0> : tensor<20xf32>
+  %cst_11 = constant dense<1.0> : tensor<20x28xf32>
+  %cell_input = constant dense<0.0> : tensor<1x20xf32>
+  %cell_stats = "quant.stats"(%cell_input) {layerStats = dense<[-2.73090601, 7.94872093]> : tensor<2xf32>} : (tensor<1x20xf32>) -> tensor<1x20xf32>
+  %0 = "tfl.unidirectional_sequence_lstm"(%arg0,
+    %cst_11, %cst_11, %cst_11, %cst_11,
+    %cst_3, %cst_3, %cst_3, %cst_3,
+    %cst_2, %cst_2, %cst_2,
+    %cst_7, %cst_7, %cst_7, %cst_7,
+    %cst_2, %cst_2,
+    %cst_1, %cell_stats,
+    %cst_2, %cst_2, %cst_2, %cst_2) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = false}
+      : ( tensor<1x28x28xf32>,
+          tensor<20x28xf32>, tensor<20x28xf32>, tensor<20x28xf32>, tensor<20x28xf32>,
+          tensor<20x20xf32>, tensor<20x20xf32>, tensor<20x20xf32>, tensor<20x20xf32>,
+          none, none, none,
+          tensor<20xf32>, tensor<20xf32>, tensor<20xf32>, tensor<20xf32>,
+          none, none,
+          tensor<1x20xf32>, tensor<1x20xf32>,
+          none, none, none, none) -> tensor<1x28x20xf32>
+  return %0 : tensor<1x28x20xf32>
+// CHECK: %[[none:.*]] = constant unit
+// CHECK: %[[cell_input:.*]] = constant dense<0.000000e+00> : tensor<1x20xf32>
+// CHECK: %[[q:.*]] = "tfl.quantize"(%[[cell_input]]) {qtype = tensor<1x20x!quant.uniform<i16:f32, 2.44140625E-4>>} : (tensor<1x20xf32>) -> tensor<1x20x!quant.uniform<i16:f32, 2.44140625E-4>>
+// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[q]]) : (tensor<1x20x!quant.uniform<i16:f32, 2.44140625E-4>>) -> tensor<1x20xf32>
+// Checks if input 19 is correctly passed from a dequantize op.
+// CHECK: %[[lstm:.*]] = "tfl.unidirectional_sequence_lstm"(%arg0, {{(%[^%,]+, )+}}%[[dq]], %[[none]], %[[none]], %[[none]], %[[none]])
+}
+
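The expected qtype in the CHECK lines above can be derived from the quant.stats range. A minimal sketch of the arithmetic (mine, not part of the commit), assuming power_of_two_bound rounds the absolute maximum up to the next power of two, consistent with the pass change further below:

    #include <algorithm>
    #include <cmath>
    #include <cstdio>

    int main() {
      // Calibration statistics from the quant.stats op above.
      double stats_min = -2.73090601, stats_max = 7.94872093;
      double max = std::max(std::abs(stats_min), std::abs(stats_max));  // ~7.9487
      // Assumed behavior of power_of_two_bound: round up to the next power of two.
      double bound = std::pow(2.0, std::ceil(std::log2(max)));  // 8.0
      // 16-bit signed cell state: scale = bound / 2^15, zero point 0.
      std::printf("%E\n", bound / 32768.0);  // prints 2.441406E-04
      return 0;
    }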
+// CHECK-LABEL: QuantizeIntermediates
+func @QuantizeIntermediates(%arg0: tensor<1x5xf32>) -> tensor<*xf32> attributes {tf.entry_function = {inputs = "input0", outputs = "output24"}} {
+  %0 = "tfl.pseudo_const"() {value = dense<[[1.31760073, -0.78338623, 0.287265539, -0.383972764, -0.00321021513], [0.104248755, 1.07823908, 0.138089031, 0.76123321, -1.4124943]]> : tensor<2x5xf32>} : () -> tensor<2x5xf32>
+  %1 = "tfl.pseudo_const"() {value = dense<[[2.32939887, -0.623641372, -0.0191893689, 0.326861918, 0.734137893], [0.499284297, 1.25277913, 0.60228157, -1.39478016, 0.115529917]]> : tensor<2x5xf32>} : () -> tensor<2x5xf32>
+  %2 = "tfl.pseudo_const"() {value = dense<[[0.839470446, 0.564852297, -0.80136007, -0.0372898243, 0.57127893], [-5.516230e-01, -1.082380e+00, 1.41860521, -0.92541927, -1.13971734]]> : tensor<2x5xf32>} : () -> tensor<2x5xf32>
+  %3 = "tfl.pseudo_const"() {value = dense<[[-0.440826088, -0.0863231644, -0.707756281, -0.695703208, -1.87899077], [0.16942361, 0.206325337, 1.09067786, -2.18648934, 0.273400396]]> : tensor<2x5xf32>} : () -> tensor<2x5xf32>
+  %4 = "tfl.pseudo_const"() {value = dense<[[-1.65420437, 0.19633314, 0.828249216, -0.546153665], [-1.49073172, 1.6467551, 0.904948651, 1.1367631]]> : tensor<2x4xf32>} : () -> tensor<2x4xf32>
+  %5 = "tfl.pseudo_const"() {value = dense<[[-0.435141891, -0.940576493, 1.30446923, -1.02953017], [0.684501767, 0.363370508, -2.29151702, 2.41928673]]> : tensor<2x4xf32>} : () -> tensor<2x4xf32>
+  %6 = "tfl.pseudo_const"() {value = dense<[[0.270476967, 0.00706229592, 0.489950746, 1.05166924], [1.28193891, 0.273171216, 0.484176666, 1.11504579]]> : tensor<2x4xf32>} : () -> tensor<2x4xf32>
+  %7 = "tfl.pseudo_const"() {value = dense<[[-2.36692929, -3.483900e-01, 0.322934568, -1.56939185], [-5.623850e-01, -0.083735466, 1.73820043, 0.218063414]]> : tensor<2x4xf32>} : () -> tensor<2x4xf32>
+  %8 = "tfl.pseudo_const"() {value = dense<[1.43194032, -0.553496838]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %9 = "tfl.pseudo_const"() {value = dense<[-1.66391921, 1.14934266]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %10 = "tfl.pseudo_const"() {value = dense<[-1.59288621, 0.904723584]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %11 = "tfl.pseudo_const"() {value = dense<[-0.323118627, 1.77580559]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %12 = "tfl.pseudo_const"() {value = dense<[-1.0347594, -1.09994471]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %13 = "tfl.pseudo_const"() {value = dense<[-2.03072214, -1.63648951]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %14 = "tfl.pseudo_const"() {value = dense<[-1.90073407, -0.286088765]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %15 = "tfl.pseudo_const"() {value = dense<[[0.580187321, -1.72028887], [1.48392391, 0.859561979], [0.316514879, 0.81852132], [0.0933789983, 0.58165586]]> : tensor<4x2xf32>} : () -> tensor<4x2xf32>
+  %16 = "tfl.pseudo_const"() {value = dense<[-0.0432887711, -0.431485623, -0.307492912, -0.882515907]> : tensor<4xf32>} : () -> tensor<4xf32>
+  %17 = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<1x4xf32>} : () -> tensor<1x4xf32>
+  %18 = "tfl.pseudo_const"() {value = dense<1.000000e+00> : tensor<1x2xf32>} : () -> tensor<1x2xf32>
+  %19 = "tfl.pseudo_const"() {value = dense<[0.928654432, -0.393729329]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %20 = "tfl.pseudo_const"() {value = dense<[-0.76004064, -0.892570137]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %21 = "tfl.pseudo_const"() {value = dense<[-0.330534697, -1.68513882]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %22 = "tfl.pseudo_const"() {value = dense<[-0.896740913, -0.382640809]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %23 = "tfl.unidirectional_sequence_lstm"(%arg0,
+    %0, %1, %2, %3,
+    %4, %5, %6, %7,
+    %8, %9, %10,
+    %11, %12, %13, %14,
+    %15, %16,
+    %17, %18,
+    %19, %20, %21, %22) {cell_clip = 5.000000e+01 : f32,
+      effective_hidden_scale_intermediate = tensor<!quant.calibrated<f32<-5.000000e-01:5.000000e-01>>>,
+      fused_activation_function = "TANH",
+      input_to_cell_intermediate = tensor<!quant.calibrated<f32<-4.000000e+00:4.000000e+00>>>,
+      input_to_forget_intermediate = tensor<!quant.calibrated<f32<-1.600000e+01:1.600000e+01>>>,
+      input_to_input_intermediate = tensor<!quant.calibrated<f32<-3.200000e+01:3.200000e+01>>>,
+      input_to_output_intermediate = tensor<!quant.calibrated<f32<-1.000000e+00:1.000000e+00>>>,
+      proj_clip = 0.000000e+00 : f32, time_major = false} : (
+        tensor<1x5xf32>,
+        tensor<2x5xf32>, tensor<2x5xf32>, tensor<2x5xf32>, tensor<2x5xf32>,
+        tensor<2x4xf32>, tensor<2x4xf32>, tensor<2x4xf32>, tensor<2x4xf32>,
+        tensor<2xf32>, tensor<2xf32>, tensor<2xf32>,
+        tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>,
+        tensor<4x2xf32>, tensor<4xf32>,
+        tensor<1x4xf32>, tensor<1x2xf32>,
+        tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) -> tensor<*xf32>
+  return %23 : tensor<*xf32>
+// CHECK: effective_hidden_scale_intermediate = tensor<!quant.uniform<u8:f32, 0.0039215686274509803:128>>
+// CHECK: input_to_cell_intermediate = tensor<!quant.uniform<i16<-32767:32767>:f32, 1.2207403790398877E-4>>
+// CHECK: input_to_forget_intermediate = tensor<!quant.uniform<i16<-32767:32767>:f32, 4.8829615161595508E-4>>
+// CHECK: input_to_input_intermediate = tensor<!quant.uniform<i16<-32767:32767>:f32, 9.7659230323191015E-4>>
+// CHECK: input_to_output_intermediate = tensor<!quant.uniform<i16<-32767:32767>:f32, 3.0518509475997192E-5>>,
+}
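For reference (this derivation is mine, added for readability), the CHECK'd types follow from the calibrated ranges via the recipes implemented by the pass change below. A 16-bit intermediate is quantized symmetrically with narrow range, so scale = max / 32767: the ranges ±32, ±16, ±4, and ±1 give 32/32767 = 9.7659230323191015E-4, 16/32767 = 4.8829615161595508E-4, 4/32767 = 1.2207403790398877E-4, and 1/32767 = 3.0518509475997192E-5 respectively. The 8-bit effective_hidden_scale intermediate is quantized asymmetrically over [-0.5, 0.5], so scale = (0.5 - (-0.5)) / 255 = 0.0039215686274509803 and zero point = round(0.5 / scale) = 128, matching the u8 CHECK line.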
@@ -166,37 +166,3 @@ func @QuantizeTransposeConv(%arg0: tensor<32x4x4x128xf32>, %arg1: tensor<4xi32>)
 // PerTensor: %[[DEQUANTIZE:.*]] = "tfl.dequantize"(%[[QUANTIZE]]) : (tensor<1x32x42x128x!quant.uniform<i8<-127:127>:f32, 1.000000e+00>>) -> tensor<1x32x42x128xf32>
 // PerTensor: "tfl.transpose_conv"(%arg1, %arg0, %[[DEQUANTIZE]]
 }
-
-// CHECK-LABEL: QuantizeLstmCellInput
-func @QuantizeLstmCellInput(%arg0: tensor<1x28x28xf32>) -> tensor<1x28x20xf32> {
-  %cst_1 = constant dense<1.0> : tensor<1x20xf32>
-  %cst_2 = constant unit
-  %cst_3 = constant dense<1.0> : tensor<20x20xf32>
-  %cst_7 = constant dense<1.0> : tensor<20xf32>
-  %cst_11 = constant dense<1.0> : tensor<20x28xf32>
-  %cell_input = constant dense<0.0> : tensor<1x20xf32>
-  %cell_stats = "quant.stats"(%cell_input) {layerStats = dense<[-2.73090601, 7.94872093]> : tensor<2xf32>} : (tensor<1x20xf32>) -> tensor<1x20xf32>
-  %0 = "tfl.unidirectional_sequence_lstm"(%arg0,
-    %cst_11, %cst_11, %cst_11, %cst_11,
-    %cst_3, %cst_3, %cst_3, %cst_3,
-    %cst_2, %cst_2, %cst_2,
-    %cst_7, %cst_7, %cst_7, %cst_7,
-    %cst_2, %cst_2,
-    %cst_1, %cell_stats,
-    %cst_2, %cst_2, %cst_2, %cst_2) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = false}
-      : ( tensor<1x28x28xf32>,
-          tensor<20x28xf32>, tensor<20x28xf32>, tensor<20x28xf32>, tensor<20x28xf32>,
-          tensor<20x20xf32>, tensor<20x20xf32>, tensor<20x20xf32>, tensor<20x20xf32>,
-          none, none, none,
-          tensor<20xf32>, tensor<20xf32>, tensor<20xf32>, tensor<20xf32>,
-          none, none,
-          tensor<1x20xf32>, tensor<1x20xf32>,
-          none, none, none, none) -> tensor<1x28x20xf32>
-  return %0 : tensor<1x28x20xf32>
-// CHECK: %[[none:.*]] = constant unit
-// CHECK: %[[cell_input:.*]] = constant dense<0.000000e+00> : tensor<1x20xf32>
-// CHECK: %[[q:.*]] = "tfl.quantize"(%[[cell_input]]) {qtype = tensor<1x20x!quant.uniform<i16:f32, 2.44140625E-4>>} : (tensor<1x20xf32>) -> tensor<1x20x!quant.uniform<i16:f32, 2.44140625E-4>>
-// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[q]]) : (tensor<1x20x!quant.uniform<i16:f32, 2.44140625E-4>>) -> tensor<1x20xf32>
-// Checks if input 19 is correctly passed from a dequantize op.
-// CHECK: %[[lstm:.*]] = "tfl.unidirectional_sequence_lstm"(%arg0, {{(%[^%,]+, )+}}%[[dq]], %[[none]], %[[none]], %[[none]], %[[none]])
-}
-
@@ -19,25 +19,34 @@ limitations under the License.
 #include <string>
 
 #include "absl/memory/memory.h"
+#include "llvm/ADT/Optional.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/CommandLine.h"
+#include "llvm/Support/MathExtras.h"
+#include "llvm/Support/raw_ostream.h"
+#include "mlir/Dialect/Quant/FakeQuantSupport.h"  // from @llvm-project
 #include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
+#include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
 #include "mlir/IR/Function.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
+#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
 #include "tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.h"
 #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
+#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
 #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
 #include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/lite/schema/schema_generated.h"
+#include "tensorflow/lite/tools/optimize/operator_property.h"
 
 // NOLINTNEXTLINE
 static llvm::cl::list<std::string> quantize_allowlist(
@@ -322,6 +331,101 @@ struct ConvertLstmStatsToQDQs : public OpRewritePattern<SourceOp> {
       : OpRewritePattern<SourceOp>(context, /*benefit=*/2) {}
   LogicalResult matchAndRewrite(SourceOp op,
                                 PatternRewriter& rewriter) const override {
+    tflite::optimize::operator_property::OpVariant lstm_variant;
+    if (llvm::isa<TFL::LSTMOp>(op.getOperation())) {
+      lstm_variant.op_code = tflite::BuiltinOperator_LSTM;
+    } else if (llvm::isa<TFL::UnidirectionalSequenceLSTMOp>(
+                   op.getOperation())) {
+      lstm_variant.op_code =
+          tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM;
+    } else {
+      op.emitError("ConvertLstmStatsToQDQs pass only supports LSTMs.");
+      return failure();
+    }
+    lstm_variant.use_projection =
+        !op.projection_weights().getType().template isa<NoneType>();
+    lstm_variant.use_peephole =
+        !op.cell_to_output_weights().getType().template isa<NoneType>();
+    lstm_variant.use_layer_norm =
+        !op.forget_layer_norm_coefficients().getType().template isa<NoneType>();
+
+    auto lstm_property =
+        tflite::optimize::operator_property::GetOperatorProperty(lstm_variant);
+
+    // Same as the ordering in //tensorflow/compiler/mlir/lite/ir/tfl_ops.td
+    const std::vector<std::string> intermediate_attributes = {
+        "input_to_input_intermediate", "input_to_forget_intermediate",
+        "input_to_cell_intermediate", "input_to_output_intermediate",
+        "effective_hidden_scale_intermediate"};
+
+    for (auto& enumerated_intermediates : lstm_property.intermediates) {
+      int index = enumerated_intermediates.first;
+      auto& tensor_property = enumerated_intermediates.second;
+      // Intermediate tensors 0, 1, 2, 3 are only used with layer normalization.
+      if (!lstm_variant.use_layer_norm && index != 4) {
+        continue;
+      }
+      // Intermediate tensor 4 is only used with projection.
+      if (!lstm_variant.use_projection && index == 4) {
+        continue;
+      }
+      TypeAttr attr =
+          op.template getAttrOfType<TypeAttr>(intermediate_attributes[index]);
+
+      if (!attr) {
+        op.emitError()
+            << op.getOperationName()
+            << " requires quantization values for intermediate tensor "
+            << intermediate_attributes[index];
+        return failure();
+      }
+      auto quantized_type =
+          QuantizedType::getQuantizedElementType(attr.getValue());
+      if (!quantized_type) {
+        op.emitError() << intermediate_attributes[index]
+                       << " is not quantized.";
+        return failure();
+      }
+      auto calibrated_type =
+          quantized_type.dyn_cast<quant::CalibratedQuantizedType>();
+      if (!calibrated_type) {
+        int num_storage_bits = quantized_type.getStorageTypeIntegralWidth();
+        if (tensor_property.number_of_bits != num_storage_bits) {
+          op.emitError() << intermediate_attributes[index]
+                         << " is expected to be quantized with "
+                         << tensor_property.number_of_bits << " bits, but got "
+                         << num_storage_bits << " bits instead.";
+          return failure();
+        }
+        continue;  // Skip if it is already quantized.
+      }
+      quant::UniformQuantizedType qtype;
+      if (tensor_property.number_of_bits == 8) {
+        qtype = quant::fakeQuantAttrsToType(
+            op.getLoc(), tensor_property.number_of_bits,
+            calibrated_type.getMin(), calibrated_type.getMax(),
+            /*narrowRange=*/false, calibrated_type.getExpressedType(),
+            /*isSigned=*/false);
+      } else if (tensor_property.number_of_bits == 16) {
+        double max = std::max(std::abs(calibrated_type.getMin()),
+                              std::abs(calibrated_type.getMax()));
+        qtype = quant::fakeQuantAttrsToType(
+            op.getLoc(), tensor_property.number_of_bits, -max, max,
+            /*narrowRange=*/true, calibrated_type.getExpressedType(),
+            /*isSigned=*/true);
+      } else {
+        op.emitError() << "Unsupported quantization bits: "
+                       << tensor_property.number_of_bits;
+        return failure();
+      }
+
+      op.setAttr(intermediate_attributes[index],
+                 TypeAttr::get(qtype.castFromExpressedType(
+                     qtype.castToExpressedType(attr.getValue()))));
+    }
+
     quant::StatisticsOp stats_op = llvm::dyn_cast_or_null<quant::StatisticsOp>(
         op.input_cell_state().getDefiningOp());
     // Recurrent input is used within an LSTM, and thus should have one use.
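For context (my sketch, not part of this diff): a pattern like ConvertLstmStatsToQDQs is typically instantiated for both LSTM op types it accepts and handed to the greedy rewrite driver inside the prepare-quantize pass. The exact plumbing is outside the hunks shown here, so the surrounding pass members below are assumptions:

    // Hypothetical registration inside the pass's runOnFunction; names of the
    // surrounding pass members are assumed, not shown in this diff.
    OwningRewritePatternList patterns;
    patterns.insert<ConvertLstmStatsToQDQs<TFL::LSTMOp>,
                    ConvertLstmStatsToQDQs<TFL::UnidirectionalSequenceLSTMOp>>(
        &getContext());
    applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));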
@@ -338,10 +442,12 @@ struct ConvertLstmStatsToQDQs : public OpRewritePattern<SourceOp> {
         std::abs(FloatAttr::getValueAsDouble(stats.getValue<APFloat>({1}))));
     double bound = power_of_two_bound(max);
     Type expressed = stats_op.getType().cast<ShapedType>().getElementType();
-    // maximum value is adjusted to get a scale of power_of_two(max)/32768.
-    quant::QuantizedType quant_type = quant::fakeQuantAttrsToType(
-        stats_op.getLoc(), 16, -bound, bound * 32767.0 / 32768.0,
-        /*narrow_range*/ false, expressed, /*is_signed*/ true);
+    // Set flags to 1 for signed type.
+    quant::QuantizedType quant_type = UniformQuantizedType::getChecked(
+        quant::QuantizationFlags::Signed,
+        IntegerType::get(16, expressed.getContext()), expressed,
+        /*scale=*/bound / 32768.0, /*zeroPoint=*/0, llvm::minIntN(16),
+        llvm::maxIntN(16), op.getLoc());
 
     rewriter.setInsertionPointAfter(stats_op);
     Type result_type = quant_type.castFromExpressedType(stats_op.getType());
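power_of_two_bound is defined elsewhere in this file and does not appear in the diff; from its call sites here, it rounds the absolute maximum up to the next power of two, which is what makes the resulting scale (bound / 32768.0) itself a power of two. A plausible implementation under that assumption:

    #include <cmath>

    // Assumed shape of the helper, inferred from its uses in this diff:
    // round a positive maximum up to the next power of two.
    double power_of_two_bound(double value) {
      return std::pow(2.0, std::ceil(std::log2(value)));
    }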
@@ -195,7 +195,6 @@ cc_library(
     compatible_with = get_compatible_with_cloud(),
     deps = [
         "//tensorflow/lite:framework",
-        "//tensorflow/lite/kernels/internal:types",
         "//tensorflow/lite/schema:schema_fbs",
         "//tensorflow/lite/schema:schema_utils",
     ],
@@ -22,21 +22,6 @@ namespace optimize {
 namespace operator_property {
 
 namespace {
-
-// The op as well as it variants.
-// TODO(jianlijianli): extend it to support ops that has multiple variants.
-struct OpVariant {
-  BuiltinOperator op_code;
-  bool use_layer_norm = false;
-  bool use_projection = false;
-  bool use_peephole = false;
-  // An attribute to indicate if quantization is supported for this Op.
-  // This attribute is equivalent to the "quantizable" attribute in
-  // "OperatorProperty". It added here since OpVariants peeks inside the Op and
-  // determines its quantization related properties.
-  bool is_quantizable = true;
-};
-
 const OpVariant GetOperatorVariant(const ModelT* model, int subgraph_index,
                                    int op_index) {
   OpVariant op_variant;
@@ -67,12 +52,16 @@ const OpVariant GetOperatorVariant(const ModelT* model, int subgraph_index,
 }
 }  // namespace
 
-// Update operation defintions in TensorFlow Lite dialect accordingly when there
-// are any needs on updating the kernel support level.
-// LINT.IfChange
 OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
                                      int op_index) {
   OpVariant op_variant = GetOperatorVariant(model, subgraph_index, op_index);
+  return GetOperatorProperty(op_variant);
+}
+
+// Update operation definitions in the TensorFlow Lite dialect accordingly
+// whenever the kernel support level changes.
+// LINT.IfChange
+OperatorProperty GetOperatorProperty(OpVariant op_variant) {
   BuiltinOperator op_code = op_variant.op_code;
   OperatorProperty property;
   switch (op_code) {
@@ -125,8 +125,23 @@ struct OperatorProperty {
   bool quantize_input_as_activations = false;
 };
 
+// The op as well as its variants.
+// TODO(b/174283888): extend it to support ops that have multiple variants.
+struct OpVariant {
+  BuiltinOperator op_code;
+  bool use_layer_norm = false;
+  bool use_projection = false;
+  bool use_peephole = false;
+  // An attribute to indicate if quantization is supported for this Op.
+  // This attribute is equivalent to the "quantizable" attribute in
+  // "OperatorProperty". It is added here since OpVariant peeks inside the Op
+  // and determines its quantization-related properties.
+  bool is_quantizable = true;
+};
+
 OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
                                      int op_index);
+OperatorProperty GetOperatorProperty(OpVariant op_variant);
+
 }  // namespace operator_property
 }  // namespace optimize
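The new GetOperatorProperty overload lets callers that already know an op's configuration, such as the MLIR pass above, query quantization properties without a flatbuffer ModelT. A usage sketch (mine, mirroring how the pass builds its OpVariant):

    #include "tensorflow/lite/tools/optimize/operator_property.h"

    namespace op_prop = tflite::optimize::operator_property;

    // Look up quantization properties for a layer-norm LSTM with projection.
    op_prop::OperatorProperty LstmProperty() {
      op_prop::OpVariant variant;
      variant.op_code = tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM;
      variant.use_layer_norm = true;
      variant.use_projection = true;
      // The returned property.intermediates drives the per-intermediate
      // recipe applied by ConvertLstmStatsToQDQs above.
      return op_prop::GetOperatorProperty(variant);
    }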