Apply correct quantization recipe for intermediate tensors of LSTMs

PiperOrigin-RevId: 344372258
Change-Id: I7af460e317d52d58c96f9e7beee9b4bce1ef653b
Taehee Jeong 2020-11-25 23:31:07 -08:00 committed by TensorFlower Gardener
parent dd35a3e925
commit 38d042b105
7 changed files with 225 additions and 57 deletions

@@ -499,6 +499,8 @@ cc_library(
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"//tensorflow/compiler/mlir/lite/quantization/lite:tfl_to_std",
"//tensorflow/core:protos_all_cc",
"//tensorflow/lite/schema:schema_fbs",
"//tensorflow/lite/tools/optimize:operator_property",
"@com_google_absl//absl/memory",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",

@@ -0,0 +1,91 @@
// RUN: tf-opt %s -tfl-prepare-quantize | FileCheck %s
// CHECK-LABEL: QuantizeLstmCellInput
func @QuantizeLstmCellInput(%arg0: tensor<1x28x28xf32>) -> tensor<1x28x20xf32> {
%cst_1 = constant dense<1.0> : tensor<1x20xf32>
%cst_2 = constant unit
%cst_3 = constant dense<1.0> : tensor<20x20xf32>
%cst_7 = constant dense<1.0> : tensor<20xf32>
%cst_11 = constant dense<1.0> : tensor<20x28xf32>
%cell_input = constant dense<0.0> : tensor<1x20xf32>
%cell_stats = "quant.stats"(%cell_input) {layerStats = dense<[-2.73090601, 7.94872093]> : tensor<2xf32>} : (tensor<1x20xf32>) -> tensor<1x20xf32>
%0 = "tfl.unidirectional_sequence_lstm"(%arg0,
%cst_11, %cst_11, %cst_11, %cst_11,
%cst_3, %cst_3, %cst_3, %cst_3,
%cst_2, %cst_2, %cst_2,
%cst_7, %cst_7, %cst_7, %cst_7,
%cst_2, %cst_2,
%cst_1, %cell_stats,
%cst_2, %cst_2, %cst_2, %cst_2) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = false}
: ( tensor<1x28x28xf32>,
tensor<20x28xf32>, tensor<20x28xf32>, tensor<20x28xf32>, tensor<20x28xf32>,
tensor<20x20xf32>, tensor<20x20xf32>, tensor<20x20xf32>, tensor<20x20xf32>,
none, none, none,
tensor<20xf32>, tensor<20xf32>, tensor<20xf32>, tensor<20xf32>,
none, none,
tensor<1x20xf32>, tensor<1x20xf32>,
none, none, none, none) -> tensor<1x28x20xf32>
return %0 : tensor<1x28x20xf32>
// CHECK: %[[none:.*]] = constant unit
// CHECK: %[[cell_input:.*]] = constant dense<0.000000e+00> : tensor<1x20xf32>
// CHECK: %[[q:.*]] = "tfl.quantize"(%[[cell_input]]) {qtype = tensor<1x20x!quant.uniform<i16:f32, 2.44140625E-4>>} : (tensor<1x20xf32>) -> tensor<1x20x!quant.uniform<i16:f32, 2.44140625E-4>>
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[q]]) : (tensor<1x20x!quant.uniform<i16:f32, 2.44140625E-4>>) -> tensor<1x20xf32>
// Checks that input 19 is correctly fed from a dequantize op.
// CHECK: %[[lstm:.*]] = "tfl.unidirectional_sequence_lstm"(%arg0, {{(%[^%,]+, )+}}%[[dq]], %[[none]], %[[none]], %[[none]], %[[none]])
}
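
The expected i16 scale in the CHECK lines above can be reproduced by hand: the cell-state range from layerStats is rounded up to the next power of two and divided by 32768. Below is a minimal standalone sketch of that arithmetic (it is not part of the test file, and it assumes power_of_two_bound rounds the absolute maximum up to the next power of two, matching the pass change further down):

// Sketch only: reproduces the 2.44140625E-4 cell-state scale checked above.
#include <algorithm>
#include <cmath>
#include <cstdio>

int main() {
  const double layer_stats_min = -2.73090601;
  const double layer_stats_max = 7.94872093;
  const double abs_max =
      std::max(std::abs(layer_stats_min), std::abs(layer_stats_max));
  // Assumed behavior of power_of_two_bound: round up to the next power of two.
  const double bound = std::pow(2.0, std::ceil(std::log2(abs_max)));  // 8.0
  // Signed 16-bit, zero point 0, scale = bound / 32768.
  const double scale = bound / 32768.0;  // 2.44140625e-4
  std::printf("cell-state scale: %.10g\n", scale);
  return 0;
}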
// CHECK-LABEL: QuantizeIntermediates
func @QuantizeIntermediates(%arg0: tensor<1x5xf32>) -> tensor<*xf32> attributes {tf.entry_function = {inputs = "input0", outputs = "output24"}} {
%0 = "tfl.pseudo_const"() {value = dense<[[1.31760073, -0.78338623, 0.287265539, -0.383972764, -0.00321021513], [0.104248755, 1.07823908, 0.138089031, 0.76123321, -1.4124943]]> : tensor<2x5xf32>} : () -> tensor<2x5xf32>
%1 = "tfl.pseudo_const"() {value = dense<[[2.32939887, -0.623641372, -0.0191893689, 0.326861918, 0.734137893], [0.499284297, 1.25277913, 0.60228157, -1.39478016, 0.115529917]]> : tensor<2x5xf32>} : () -> tensor<2x5xf32>
%2 = "tfl.pseudo_const"() {value = dense<[[0.839470446, 0.564852297, -0.80136007, -0.0372898243, 0.57127893], [-5.516230e-01, -1.082380e+00, 1.41860521, -0.92541927, -1.13971734]]> : tensor<2x5xf32>} : () -> tensor<2x5xf32>
%3 = "tfl.pseudo_const"() {value = dense<[[-0.440826088, -0.0863231644, -0.707756281, -0.695703208, -1.87899077], [0.16942361, 0.206325337, 1.09067786, -2.18648934, 0.273400396]]> : tensor<2x5xf32>} : () -> tensor<2x5xf32>
%4 = "tfl.pseudo_const"() {value = dense<[[-1.65420437, 0.19633314, 0.828249216, -0.546153665], [-1.49073172, 1.6467551, 0.904948651, 1.1367631]]> : tensor<2x4xf32>} : () -> tensor<2x4xf32>
%5 = "tfl.pseudo_const"() {value = dense<[[-0.435141891, -0.940576493, 1.30446923, -1.02953017], [0.684501767, 0.363370508, -2.29151702, 2.41928673]]> : tensor<2x4xf32>} : () -> tensor<2x4xf32>
%6 = "tfl.pseudo_const"() {value = dense<[[0.270476967, 0.00706229592, 0.489950746, 1.05166924], [1.28193891, 0.273171216, 0.484176666, 1.11504579]]> : tensor<2x4xf32>} : () -> tensor<2x4xf32>
%7 = "tfl.pseudo_const"() {value = dense<[[-2.36692929, -3.483900e-01, 0.322934568, -1.56939185], [-5.623850e-01, -0.083735466, 1.73820043, 0.218063414]]> : tensor<2x4xf32>} : () -> tensor<2x4xf32>
%8 = "tfl.pseudo_const"() {value = dense<[1.43194032, -0.553496838]> : tensor<2xf32>} : () -> tensor<2xf32>
%9 = "tfl.pseudo_const"() {value = dense<[-1.66391921, 1.14934266]> : tensor<2xf32>} : () -> tensor<2xf32>
%10 = "tfl.pseudo_const"() {value = dense<[-1.59288621, 0.904723584]> : tensor<2xf32>} : () -> tensor<2xf32>
%11 = "tfl.pseudo_const"() {value = dense<[-0.323118627, 1.77580559]> : tensor<2xf32>} : () -> tensor<2xf32>
%12 = "tfl.pseudo_const"() {value = dense<[-1.0347594, -1.09994471]> : tensor<2xf32>} : () -> tensor<2xf32>
%13 = "tfl.pseudo_const"() {value = dense<[-2.03072214, -1.63648951]> : tensor<2xf32>} : () -> tensor<2xf32>
%14 = "tfl.pseudo_const"() {value = dense<[-1.90073407, -0.286088765]> : tensor<2xf32>} : () -> tensor<2xf32>
%15 = "tfl.pseudo_const"() {value = dense<[[0.580187321, -1.72028887], [1.48392391, 0.859561979], [0.316514879, 0.81852132], [0.0933789983, 0.58165586]]> : tensor<4x2xf32>} : () -> tensor<4x2xf32>
%16 = "tfl.pseudo_const"() {value = dense<[-0.0432887711, -0.431485623, -0.307492912, -0.882515907]> : tensor<4xf32>} : () -> tensor<4xf32>
%17 = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<1x4xf32>} : () -> tensor<1x4xf32>
%18 = "tfl.pseudo_const"() {value = dense<1.000000e+00> : tensor<1x2xf32>} : () -> tensor<1x2xf32>
%19 = "tfl.pseudo_const"() {value = dense<[0.928654432, -0.393729329]> : tensor<2xf32>} : () -> tensor<2xf32>
%20 = "tfl.pseudo_const"() {value = dense<[-0.76004064, -0.892570137]> : tensor<2xf32>} : () -> tensor<2xf32>
%21 = "tfl.pseudo_const"() {value = dense<[-0.330534697, -1.68513882]> : tensor<2xf32>} : () -> tensor<2xf32>
%22 = "tfl.pseudo_const"() {value = dense<[-0.896740913, -0.382640809]> : tensor<2xf32>} : () -> tensor<2xf32>
%23 = "tfl.unidirectional_sequence_lstm"(%arg0,
%0, %1, %2, %3,
%4, %5, %6, %7,
%8, %9, %10,
%11, %12, %13, %14,
%15, %16,
%17, %18,
%19, %20, %21, %22) {cell_clip = 5.000000e+01 : f32,
effective_hidden_scale_intermediate = tensor<!quant.calibrated<f32<-5.000000e-01:5.000000e-01>>>,
fused_activation_function = "TANH",
input_to_cell_intermediate = tensor<!quant.calibrated<f32<-4.000000e+00:4.000000e+00>>>,
input_to_forget_intermediate = tensor<!quant.calibrated<f32<-1.600000e+01:1.600000e+01>>>,
input_to_input_intermediate = tensor<!quant.calibrated<f32<-3.200000e+01:3.200000e+01>>>,
input_to_output_intermediate = tensor<!quant.calibrated<f32<-1.000000e+00:1.000000e+00>>>,
proj_clip = 0.000000e+00 : f32, time_major = false} : (
tensor<1x5xf32>,
tensor<2x5xf32>, tensor<2x5xf32>, tensor<2x5xf32>, tensor<2x5xf32>,
tensor<2x4xf32>, tensor<2x4xf32>, tensor<2x4xf32>, tensor<2x4xf32>,
tensor<2xf32>, tensor<2xf32>, tensor<2xf32>,
tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>,
tensor<4x2xf32>, tensor<4xf32>,
tensor<1x4xf32>, tensor<1x2xf32>,
tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) -> tensor<*xf32>
return %23 : tensor<*xf32>
// CHECK: effective_hidden_scale_intermediate = tensor<!quant.uniform<u8:f32, 0.0039215686274509803:128>>
// CHECK: input_to_cell_intermediate = tensor<!quant.uniform<i16<-32767:32767>:f32, 1.2207403790398877E-4>>
// CHECK: input_to_forget_intermediate = tensor<!quant.uniform<i16<-32767:32767>:f32, 4.8829615161595508E-4>>
// CHECK: input_to_input_intermediate = tensor<!quant.uniform<i16<-32767:32767>:f32, 9.7659230323191015E-4>>
// CHECK: input_to_output_intermediate = tensor<!quant.uniform<i16<-32767:32767>:f32, 3.0518509475997192E-5>>,
}
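
The scales in the CHECK lines above follow from the calibrated ranges on the op and from the recipe applied by the pass: 16-bit intermediates are quantized symmetrically over a narrow storage range of +/-32767, while effective_hidden_scale_intermediate is quantized asymmetrically to unsigned 8 bits. A minimal standalone sketch of that arithmetic (not part of the test file):

#include <algorithm>
#include <cmath>
#include <cstdio>

int main() {
  // 16-bit symmetric, narrow range: scale = max(|min|, |max|) / 32767.
  const double ranges[4][2] = {
      {-32.0, 32.0},  // input_to_input_intermediate  -> 9.7659230323191015E-4
      {-16.0, 16.0},  // input_to_forget_intermediate -> 4.8829615161595508E-4
      {-4.0, 4.0},    // input_to_cell_intermediate   -> 1.2207403790398877E-4
      {-1.0, 1.0},    // input_to_output_intermediate -> 3.0518509475997192E-5
  };
  for (const auto& range : ranges) {
    const double scale =
        std::max(std::abs(range[0]), std::abs(range[1])) / 32767.0;
    std::printf("i16 scale: %.17g\n", scale);
  }
  // 8-bit asymmetric, unsigned: scale = (max - min) / 255, zero point = -min / scale.
  const double hidden_min = -0.5, hidden_max = 0.5;  // effective_hidden_scale_intermediate
  const double hidden_scale = (hidden_max - hidden_min) / 255.0;  // 0.0039215686...
  const long zero_point = std::lround(-hidden_min / hidden_scale);  // 128
  std::printf("u8 scale: %.17g, zero point: %ld\n", hidden_scale, zero_point);
  return 0;
}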

@@ -166,37 +166,3 @@ func @QuantizeTransposeConv(%arg0: tensor<32x4x4x128xf32>, %arg1: tensor<4xi32>)
// PerTensor: %[[DEQUANTIZE:.*]] = "tfl.dequantize"(%[[QUANTIZE]]) : (tensor<1x32x42x128x!quant.uniform<i8<-127:127>:f32, 1.000000e+00>>) -> tensor<1x32x42x128xf32>
// PerTensor: "tfl.transpose_conv"(%arg1, %arg0, %[[DEQUANTIZE]]
}
// CHECK-LABEL: QuantizeLstmCellInput
func @QuantizeLstmCellInput(%arg0: tensor<1x28x28xf32>) -> tensor<1x28x20xf32> {
%cst_1 = constant dense<1.0> : tensor<1x20xf32>
%cst_2 = constant unit
%cst_3 = constant dense<1.0> : tensor<20x20xf32>
%cst_7 = constant dense<1.0> : tensor<20xf32>
%cst_11 = constant dense<1.0> : tensor<20x28xf32>
%cell_input = constant dense<0.0> : tensor<1x20xf32>
%cell_stats = "quant.stats"(%cell_input) {layerStats = dense<[-2.73090601, 7.94872093]> : tensor<2xf32>} : (tensor<1x20xf32>) -> tensor<1x20xf32>
%0 = "tfl.unidirectional_sequence_lstm"(%arg0,
%cst_11, %cst_11, %cst_11, %cst_11,
%cst_3, %cst_3, %cst_3, %cst_3,
%cst_2, %cst_2, %cst_2,
%cst_7, %cst_7, %cst_7, %cst_7,
%cst_2, %cst_2,
%cst_1, %cell_stats,
%cst_2, %cst_2, %cst_2, %cst_2) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = false}
: ( tensor<1x28x28xf32>,
tensor<20x28xf32>, tensor<20x28xf32>, tensor<20x28xf32>, tensor<20x28xf32>,
tensor<20x20xf32>, tensor<20x20xf32>, tensor<20x20xf32>, tensor<20x20xf32>,
none, none, none,
tensor<20xf32>, tensor<20xf32>, tensor<20xf32>, tensor<20xf32>,
none, none,
tensor<1x20xf32>, tensor<1x20xf32>,
none, none, none, none) -> tensor<1x28x20xf32>
return %0 : tensor<1x28x20xf32>
// CHECK: %[[none:.*]] = constant unit
// CHECK: %[[cell_input:.*]] = constant dense<0.000000e+00> : tensor<1x20xf32>
// CHECK: %[[q:.*]] = "tfl.quantize"(%[[cell_input]]) {qtype = tensor<1x20x!quant.uniform<i16:f32, 2.44140625E-4>>} : (tensor<1x20xf32>) -> tensor<1x20x!quant.uniform<i16:f32, 2.44140625E-4>>
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[q]]) : (tensor<1x20x!quant.uniform<i16:f32, 2.44140625E-4>>) -> tensor<1x20xf32>
// Checks that input 19 is correctly fed from a dequantize op.
// CHECK: %[[lstm:.*]] = "tfl.unidirectional_sequence_lstm"(%arg0, {{(%[^%,]+, )+}}%[[dq]], %[[none]], %[[none]], %[[none]], %[[none]])
}

@@ -19,25 +19,34 @@ limitations under the License.
#include <string>
#include "absl/memory/memory.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/tools/optimize/operator_property.h"
// NOLINTNEXTLINE
static llvm::cl::list<std::string> quantize_allowlist(
@@ -322,6 +331,101 @@ struct ConvertLstmStatsToQDQs : public OpRewritePattern<SourceOp> {
: OpRewritePattern<SourceOp>(context, /*benefit=*/2) {}
LogicalResult matchAndRewrite(SourceOp op,
PatternRewriter& rewriter) const override {
tflite::optimize::operator_property::OpVariant lstm_variant;
if (llvm::isa<TFL::LSTMOp>(op.getOperation())) {
lstm_variant.op_code = tflite::BuiltinOperator_LSTM;
} else if (llvm::isa<TFL::UnidirectionalSequenceLSTMOp>(
op.getOperation())) {
lstm_variant.op_code =
tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM;
} else {
op.emitError("ConvertLstmStatsToQDQs pass only supports LSTMs.");
return failure();
}
lstm_variant.use_projection =
!op.projection_weights().getType().template isa<NoneType>();
lstm_variant.use_peephole =
!op.cell_to_output_weights().getType().template isa<NoneType>();
lstm_variant.use_layer_norm =
!op.forget_layer_norm_coefficients().getType().template isa<NoneType>();
auto lstm_property =
tflite::optimize::operator_property::GetOperatorProperty(lstm_variant);
// Same as the ordering in //tensorflow/compiler/mlir/lite/ir/tfl_ops.td
const std::vector<std::string> intermediate_attributes = {
"input_to_input_intermediate", "input_to_forget_intermediate",
"input_to_cell_intermediate", "input_to_output_intermediate",
"effective_hidden_scale_intermediate"};
for (auto& enumerated_intermediates : lstm_property.intermediates) {
int index = enumerated_intermediates.first;
auto& tensor_property = enumerated_intermediates.second;
// intermediate tensors 0, 1, 2, 3 are only used with layer normalization.
if (!lstm_variant.use_layer_norm && index != 4) {
continue;
}
// intermediate tensor 4 is only used with projection.
if (!lstm_variant.use_projection && index == 4) {
continue;
}
TypeAttr attr =
op.template getAttrOfType<TypeAttr>(intermediate_attributes[index]);
if (!attr) {
op.emitError()
<< op.getOperationName()
<< " requires quantization values for intermediate tensor "
<< intermediate_attributes[index];
return failure();
}
auto quantized_type =
QuantizedType::getQuantizedElementType(attr.getValue());
if (!quantized_type) {
op.emitError() << intermediate_attributes[index]
<< " is not quantized.";
return failure();
}
auto calibrated_type =
quantized_type.dyn_cast<quant::CalibratedQuantizedType>();
if (!calibrated_type) {
int num_storage_bits = quantized_type.getStorageTypeIntegralWidth();
if (tensor_property.number_of_bits != num_storage_bits) {
op.emitError() << intermediate_attributes[index]
<< " is expected to be quantized with "
<< tensor_property.number_of_bits << " bits, but got "
<< num_storage_bits << " bits instead.";
return failure();
}
continue; // skip if it is already quantized.
}
quant::UniformQuantizedType qtype;
if (tensor_property.number_of_bits == 8) {
qtype = quant::fakeQuantAttrsToType(
op.getLoc(), tensor_property.number_of_bits,
calibrated_type.getMin(), calibrated_type.getMax(),
/*narrowRange=*/false, calibrated_type.getExpressedType(),
/*isSigned=*/false);
} else if (tensor_property.number_of_bits == 16) {
double max = std::max(std::abs(calibrated_type.getMin()),
std::abs(calibrated_type.getMax()));
qtype = quant::fakeQuantAttrsToType(
op.getLoc(), tensor_property.number_of_bits, -max, max,
/*narrowRange=*/true, calibrated_type.getExpressedType(),
/*isSigned=*/true);
} else {
op.emitError() << "Unsupported quantization bits: "
<< tensor_property.number_of_bits;
return failure();
}
op.setAttr(intermediate_attributes[index],
TypeAttr::get(qtype.castFromExpressedType(
qtype.castToExpressedType(attr.getValue()))));
}
quant::StatisticsOp stats_op = llvm::dyn_cast_or_null<quant::StatisticsOp>(
op.input_cell_state().getDefiningOp());
// Recurrent input is used within an LSTM, and thus should have one use.
@@ -338,10 +442,12 @@ struct ConvertLstmStatsToQDQs : public OpRewritePattern<SourceOp> {
std::abs(FloatAttr::getValueAsDouble(stats.getValue<APFloat>({1}))));
double bound = power_of_two_bound(max);
Type expressed = stats_op.getType().cast<ShapedType>().getElementType();
// maximum value is adjusted to get a scale of power_of_two(max)/32768.
quant::QuantizedType quant_type = quant::fakeQuantAttrsToType(
stats_op.getLoc(), 16, -bound, bound * 32767.0 / 32768.0,
/*narrow_range*/ false, expressed, /*is_signed*/ true);
// Set flags to 1 for signed type.
quant::QuantizedType quant_type = UniformQuantizedType::getChecked(
quant::QuantizationFlags::Signed,
IntegerType::get(16, expressed.getContext()), expressed,
/*scale=*/bound / 32768.0, /*zeroPoint=*/0, llvm::minIntN(16),
llvm::maxIntN(16), op.getLoc());
rewriter.setInsertionPointAfter(stats_op);
Type result_type = quant_type.castFromExpressedType(stats_op.getType());
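
For context, the fakeQuantAttrsToType call being replaced and the new UniformQuantizedType::getChecked call yield the same power-of-two scale; the direct construction additionally pins the zero point to 0 and states the full i16 storage range explicitly. A small standalone sketch of the equivalence, assuming the old 16-bit, non-narrow-range path used the usual (max - min) / (2^16 - 1) scale formula:

#include <cstdio>

int main() {
  // Example bound taken from the QuantizeLstmCellInput test above.
  const double bound = 8.0;
  // Old path: range [-bound, bound * 32767/32768] spread over 65535 steps.
  const double old_scale = (bound * 32767.0 / 32768.0 + bound) / 65535.0;
  // New path: scale passed directly, zero point 0, storage range [-32768, 32767].
  const double new_scale = bound / 32768.0;
  std::printf("old = %.10g, new = %.10g\n", old_scale, new_scale);  // both 0.000244140625
  return 0;
}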

@@ -195,7 +195,6 @@ cc_library(
compatible_with = get_compatible_with_cloud(),
deps = [
"//tensorflow/lite:framework",
"//tensorflow/lite/kernels/internal:types",
"//tensorflow/lite/schema:schema_fbs",
"//tensorflow/lite/schema:schema_utils",
],

@@ -22,21 +22,6 @@ namespace optimize {
namespace operator_property {
namespace {
// The op as well as its variants.
// TODO(jianlijianli): extend it to support ops that have multiple variants.
struct OpVariant {
BuiltinOperator op_code;
bool use_layer_norm = false;
bool use_projection = false;
bool use_peephole = false;
// An attribute to indicate if quantization is supported for this Op.
// This attribute is equivalent to the "quantizable" attribute in
// "OperatorProperty". It added here since OpVariants peeks inside the Op and
// determines its quantization related properties.
bool is_quantizable = true;
};
const OpVariant GetOperatorVariant(const ModelT* model, int subgraph_index,
int op_index) {
OpVariant op_variant;
@@ -67,12 +52,16 @@ const OpVariant GetOperatorVariant(const ModelT* model, int subgraph_index,
}
} // namespace
// Update operation definitions in the TensorFlow Lite dialect accordingly
// whenever the kernel support level needs to be updated.
// LINT.IfChange
OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
int op_index) {
OpVariant op_variant = GetOperatorVariant(model, subgraph_index, op_index);
return GetOperatorProperty(op_variant);
}
// Update operation definitions in the TensorFlow Lite dialect accordingly
// whenever the kernel support level needs to be updated.
// LINT.IfChange
OperatorProperty GetOperatorProperty(OpVariant op_variant) {
BuiltinOperator op_code = op_variant.op_code;
OperatorProperty property;
switch (op_code) {

@@ -125,8 +125,23 @@ struct OperatorProperty {
bool quantize_input_as_activations = false;
};
// The op as well as its variants.
// TODO(b/174283888): extend it to support ops that have multiple variants.
struct OpVariant {
BuiltinOperator op_code;
bool use_layer_norm = false;
bool use_projection = false;
bool use_peephole = false;
// An attribute to indicate if quantization is supported for this Op.
// This attribute is equivalent to the "quantizable" attribute in
// "OperatorProperty". It added here since OpVariants peeks inside the Op and
// determines its quantization related properties.
bool is_quantizable = true;
};
OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
int op_index);
OperatorProperty GetOperatorProperty(OpVariant op_variant);
} // namespace operator_property
} // namespace optimize
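
For illustration, a hedged usage sketch of the GetOperatorProperty(OpVariant) overload declared above, mirroring how the MLIR pass calls it without a ModelT; the flag values here are hypothetical and would normally be derived from the op's operands, as ConvertLstmStatsToQDQs does:

#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/tools/optimize/operator_property.h"

namespace example {

using tflite::optimize::operator_property::GetOperatorProperty;
using tflite::optimize::operator_property::OperatorProperty;
using tflite::optimize::operator_property::OpVariant;

OperatorProperty GetLstmPropertyForExample() {
  OpVariant lstm_variant;
  lstm_variant.op_code = tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM;
  lstm_variant.use_projection = true;   // hypothetical: projection weights present
  lstm_variant.use_peephole = false;    // hypothetical: no peephole weights
  lstm_variant.use_layer_norm = true;   // hypothetical: layer-norm coefficients present
  // No ModelT / subgraph / op indices needed: the variant is described directly.
  return GetOperatorProperty(lstm_variant);
}

}  // namespace example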