diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD
index e416574344d..deb5230c760 100644
--- a/tensorflow/compiler/mlir/lite/BUILD
+++ b/tensorflow/compiler/mlir/lite/BUILD
@@ -499,6 +499,8 @@ cc_library(
         "//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
         "//tensorflow/compiler/mlir/lite/quantization/lite:tfl_to_std",
         "//tensorflow/core:protos_all_cc",
+        "//tensorflow/lite/schema:schema_fbs",
+        "//tensorflow/lite/tools/optimize:operator_property",
         "@com_google_absl//absl/memory",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:IR",
diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-quantize-lstm.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-quantize-lstm.mlir
new file mode 100644
index 00000000000..5c06bb47d05
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/tests/prepare-quantize-lstm.mlir
@@ -0,0 +1,91 @@
+// RUN: tf-opt %s -tfl-prepare-quantize | FileCheck %s
+
+// CHECK-LABEL: QuantizeLstmCellInput
+func @QuantizeLstmCellInput(%arg0: tensor<1x28x28xf32>) -> tensor<1x28x20xf32> {
+  %cst_1 = constant dense<1.0> : tensor<1x20xf32>
+  %cst_2 = constant unit
+  %cst_3 = constant dense<1.0> : tensor<20x20xf32>
+  %cst_7 = constant dense<1.0> : tensor<20xf32>
+  %cst_11 = constant dense<1.0> : tensor<20x28xf32>
+  %cell_input = constant dense<0.0> : tensor<1x20xf32>
+  %cell_stats = "quant.stats"(%cell_input) {layerStats = dense<[-2.73090601, 7.94872093]> : tensor<2xf32>} : (tensor<1x20xf32>) -> tensor<1x20xf32>
+  %0 = "tfl.unidirectional_sequence_lstm"(%arg0,
+    %cst_11, %cst_11, %cst_11, %cst_11,
+    %cst_3, %cst_3, %cst_3, %cst_3,
+    %cst_2, %cst_2, %cst_2,
+    %cst_7, %cst_7, %cst_7, %cst_7,
+    %cst_2, %cst_2,
+    %cst_1, %cell_stats,
+    %cst_2, %cst_2, %cst_2, %cst_2) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = false}
+      : ( tensor<1x28x28xf32>,
+          tensor<20x28xf32>, tensor<20x28xf32>, tensor<20x28xf32>, tensor<20x28xf32>,
+          tensor<20x20xf32>, tensor<20x20xf32>, tensor<20x20xf32>, tensor<20x20xf32>,
+          none, none, none,
+          tensor<20xf32>, tensor<20xf32>, tensor<20xf32>, tensor<20xf32>,
+          none, none,
+          tensor<1x20xf32>, tensor<1x20xf32>,
+          none, none, none, none) -> tensor<1x28x20xf32>
+  return %0 : tensor<1x28x20xf32>
+// CHECK: %[[none:.*]] = constant unit
+// CHECK: %[[cell_input:.*]] = constant dense<0.000000e+00> : tensor<1x20xf32>
+// CHECK: %[[q:.*]] = "tfl.quantize"(%[[cell_input]]) {qtype = tensor<1x20x!quant.uniform<i16:f32, 2.44140625E-4>>} : (tensor<1x20xf32>) -> tensor<1x20x!quant.uniform<i16:f32, 2.44140625E-4>>
+// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[q]]) : (tensor<1x20x!quant.uniform<i16:f32, 2.44140625E-4>>) -> tensor<1x20xf32>
+// Checks that input 19 is correctly passed from a dequantize op.
+// CHECK: %[[lstm:.*]] = "tfl.unidirectional_sequence_lstm"(%arg0, {{(%[^%,]+, )+}}%[[dq]], %[[none]], %[[none]], %[[none]], %[[none]])
+}
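+
+// For reference, the i16 scale expected above can be derived by hand from the
+// pass's power-of-two rule for the cell state: max(|-2.73090601|, |7.94872093|)
+// is rounded up to the power-of-two bound 8.0, and the symmetric 16-bit scale
+// is 8.0 / 32768 = 2.44140625E-4 with zero point 0.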
+
+// CHECK-LABEL: QuantizeIntermediates
+func @QuantizeIntermediates(%arg0: tensor<1x5xf32>) -> tensor<*xf32> attributes {tf.entry_function = {inputs = "input0", outputs = "output24"}} {
+  %0 = "tfl.pseudo_const"() {value = dense<[[1.31760073, -0.78338623, 0.287265539, -0.383972764, -0.00321021513], [0.104248755, 1.07823908, 0.138089031, 0.76123321, -1.4124943]]> : tensor<2x5xf32>} : () -> tensor<2x5xf32>
+  %1 = "tfl.pseudo_const"() {value = dense<[[2.32939887, -0.623641372, -0.0191893689, 0.326861918, 0.734137893], [0.499284297, 1.25277913, 0.60228157, -1.39478016, 0.115529917]]> : tensor<2x5xf32>} : () -> tensor<2x5xf32>
+  %2 = "tfl.pseudo_const"() {value = dense<[[0.839470446, 0.564852297, -0.80136007, -0.0372898243, 0.57127893], [-5.516230e-01, -1.082380e+00, 1.41860521, -0.92541927, -1.13971734]]> : tensor<2x5xf32>} : () -> tensor<2x5xf32>
+  %3 = "tfl.pseudo_const"() {value = dense<[[-0.440826088, -0.0863231644, -0.707756281, -0.695703208, -1.87899077], [0.16942361, 0.206325337, 1.09067786, -2.18648934, 0.273400396]]> : tensor<2x5xf32>} : () -> tensor<2x5xf32>
+  %4 = "tfl.pseudo_const"() {value = dense<[[-1.65420437, 0.19633314, 0.828249216, -0.546153665], [-1.49073172, 1.6467551, 0.904948651, 1.1367631]]> : tensor<2x4xf32>} : () -> tensor<2x4xf32>
+  %5 = "tfl.pseudo_const"() {value = dense<[[-0.435141891, -0.940576493, 1.30446923, -1.02953017], [0.684501767, 0.363370508, -2.29151702, 2.41928673]]> : tensor<2x4xf32>} : () -> tensor<2x4xf32>
+  %6 = "tfl.pseudo_const"() {value = dense<[[0.270476967, 0.00706229592, 0.489950746, 1.05166924], [1.28193891, 0.273171216, 0.484176666, 1.11504579]]> : tensor<2x4xf32>} : () -> tensor<2x4xf32>
+  %7 = "tfl.pseudo_const"() {value = dense<[[-2.36692929, -3.483900e-01, 0.322934568, -1.56939185], [-5.623850e-01, -0.083735466, 1.73820043, 0.218063414]]> : tensor<2x4xf32>} : () -> tensor<2x4xf32>
+  %8 = "tfl.pseudo_const"() {value = dense<[1.43194032, -0.553496838]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %9 = "tfl.pseudo_const"() {value = dense<[-1.66391921, 1.14934266]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %10 = "tfl.pseudo_const"() {value = dense<[-1.59288621, 0.904723584]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %11 = "tfl.pseudo_const"() {value = dense<[-0.323118627, 1.77580559]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %12 = "tfl.pseudo_const"() {value = dense<[-1.0347594, -1.09994471]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %13 = "tfl.pseudo_const"() {value = dense<[-2.03072214, -1.63648951]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %14 = "tfl.pseudo_const"() {value = dense<[-1.90073407, -0.286088765]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %15 = "tfl.pseudo_const"() {value = dense<[[0.580187321, -1.72028887], [1.48392391, 0.859561979], [0.316514879, 0.81852132], [0.0933789983, 0.58165586]]> : tensor<4x2xf32>} : () -> tensor<4x2xf32>
+  %16 = "tfl.pseudo_const"() {value = dense<[-0.0432887711, -0.431485623, -0.307492912, -0.882515907]> : tensor<4xf32>} : () -> tensor<4xf32>
+  %17 = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<1x4xf32>} : () -> tensor<1x4xf32>
+  %18 = "tfl.pseudo_const"() {value = dense<1.000000e+00> : tensor<1x2xf32>} : () -> tensor<1x2xf32>
+  %19 = "tfl.pseudo_const"() {value = dense<[0.928654432, -0.393729329]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %20 = "tfl.pseudo_const"() {value = dense<[-0.76004064, -0.892570137]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %21 = "tfl.pseudo_const"() {value = dense<[-0.330534697, -1.68513882]> : tensor<2xf32>} : () -> tensor<2xf32>
"tfl.pseudo_const"() {value = dense<[-0.330534697, -1.68513882]> : tensor<2xf32>} : () -> tensor<2xf32> + %22 = "tfl.pseudo_const"() {value = dense<[-0.896740913, -0.382640809]> : tensor<2xf32>} : () -> tensor<2xf32> + %23 = "tfl.unidirectional_sequence_lstm"(%arg0, + %0, %1, %2, %3, + %4, %5, %6, %7, + %8, %9, %10, + %11, %12, %13, %14, + %15, %16, + %17, %18, + %19, %20, %21, %22) {cell_clip = 5.000000e+01 : f32, + effective_hidden_scale_intermediate = tensor>>, + fused_activation_function = "TANH", + input_to_cell_intermediate = tensor>>, + input_to_forget_intermediate = tensor>>, + input_to_input_intermediate = tensor>>, + input_to_output_intermediate = tensor>>, + proj_clip = 0.000000e+00 : f32, time_major = false} : ( + tensor<1x5xf32>, + tensor<2x5xf32>, tensor<2x5xf32>, tensor<2x5xf32>, tensor<2x5xf32>, + tensor<2x4xf32>, tensor<2x4xf32>, tensor<2x4xf32>, tensor<2x4xf32>, + tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, + tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, + tensor<4x2xf32>, tensor<4xf32>, + tensor<1x4xf32>, tensor<1x2xf32>, + tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) -> tensor<*xf32> + return %23 : tensor<*xf32> +// CHECK: effective_hidden_scale_intermediate = tensor> +// CHECK: input_to_cell_intermediate = tensor:f32, 1.2207403790398877E-4>> +// CHECK: input_to_forget_intermediate = tensor:f32, 4.8829615161595508E-4>> +// CHECK: input_to_input_intermediate = tensor:f32, 9.7659230323191015E-4>> +// CHECK: input_to_output_intermediate = tensor:f32, 3.0518509475997192E-5>>, +} diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-quantize-signed.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-quantize-signed.mlir index 12a158ad77c..505faf51fc7 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-quantize-signed.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-quantize-signed.mlir @@ -166,37 +166,3 @@ func @QuantizeTransposeConv(%arg0: tensor<32x4x4x128xf32>, %arg1: tensor<4xi32>) // PerTensor: %[[DEQUANTIZE:.*]] = "tfl.dequantize"(%[[QUANTIZE]]) : (tensor<1x32x42x128x!quant.uniform:f32, 1.000000e+00>>) -> tensor<1x32x42x128xf32> // PerTensor: "tfl.transpose_conv"(%arg1, %arg0, %[[DEQUANTIZE]] } - -// CHECK-LABEL: QuantizeLstmCellInput -func @QuantizeLstmCellInput(%arg0: tensor<1x28x28xf32>) -> tensor<1x28x20xf32> { - %cst_1 = constant dense<1.0> : tensor<1x20xf32> - %cst_2 = constant unit - %cst_3 = constant dense<1.0> : tensor<20x20xf32> - %cst_7 = constant dense<1.0> : tensor<20xf32> - %cst_11 = constant dense<1.0> : tensor<20x28xf32> - %cell_input = constant dense<0.0> : tensor<1x20xf32> - %cell_stats = "quant.stats"(%cell_input) {layerStats = dense<[-2.73090601, 7.94872093]> : tensor<2xf32>} : (tensor<1x20xf32>) -> tensor<1x20xf32> - %0 = "tfl.unidirectional_sequence_lstm"(%arg0, - %cst_11, %cst_11, %cst_11, %cst_11, - %cst_3, %cst_3, %cst_3, %cst_3, - %cst_2, %cst_2, %cst_2, - %cst_7, %cst_7, %cst_7, %cst_7, - %cst_2, %cst_2, - %cst_1, %cell_stats, - %cst_2, %cst_2, %cst_2, %cst_2) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = false} - : ( tensor<1x28x28xf32>, - tensor<20x28xf32>, tensor<20x28xf32>, tensor<20x28xf32>, tensor<20x28xf32>, - tensor<20x20xf32>, tensor<20x20xf32>, tensor<20x20xf32>, tensor<20x20xf32>, - none, none, none, - tensor<20xf32>, tensor<20xf32>, tensor<20xf32>, tensor<20xf32>, - none, none, - tensor<1x20xf32>, tensor<1x20xf32>, - none, none, none, none) -> tensor<1x28x20xf32> - return %0 : tensor<1x28x20xf32> -// CHECK: %[[none:.*]] = 
diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-quantize-signed.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-quantize-signed.mlir
index 12a158ad77c..505faf51fc7 100644
--- a/tensorflow/compiler/mlir/lite/tests/prepare-quantize-signed.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/prepare-quantize-signed.mlir
@@ -166,37 +166,3 @@ func @QuantizeTransposeConv(%arg0: tensor<32x4x4x128xf32>, %arg1: tensor<4xi32>)
 // PerTensor: %[[DEQUANTIZE:.*]] = "tfl.dequantize"(%[[QUANTIZE]]) : (tensor<1x32x42x128x!quant.uniform<i8<-127:127>:f32, 1.000000e+00>>) -> tensor<1x32x42x128xf32>
 // PerTensor: "tfl.transpose_conv"(%arg1, %arg0, %[[DEQUANTIZE]]
 }
-
-// CHECK-LABEL: QuantizeLstmCellInput
-func @QuantizeLstmCellInput(%arg0: tensor<1x28x28xf32>) -> tensor<1x28x20xf32> {
-  %cst_1 = constant dense<1.0> : tensor<1x20xf32>
-  %cst_2 = constant unit
-  %cst_3 = constant dense<1.0> : tensor<20x20xf32>
-  %cst_7 = constant dense<1.0> : tensor<20xf32>
-  %cst_11 = constant dense<1.0> : tensor<20x28xf32>
-  %cell_input = constant dense<0.0> : tensor<1x20xf32>
-  %cell_stats = "quant.stats"(%cell_input) {layerStats = dense<[-2.73090601, 7.94872093]> : tensor<2xf32>} : (tensor<1x20xf32>) -> tensor<1x20xf32>
-  %0 = "tfl.unidirectional_sequence_lstm"(%arg0,
-    %cst_11, %cst_11, %cst_11, %cst_11,
-    %cst_3, %cst_3, %cst_3, %cst_3,
-    %cst_2, %cst_2, %cst_2,
-    %cst_7, %cst_7, %cst_7, %cst_7,
-    %cst_2, %cst_2,
-    %cst_1, %cell_stats,
-    %cst_2, %cst_2, %cst_2, %cst_2) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = false}
-      : ( tensor<1x28x28xf32>,
-          tensor<20x28xf32>, tensor<20x28xf32>, tensor<20x28xf32>, tensor<20x28xf32>,
-          tensor<20x20xf32>, tensor<20x20xf32>, tensor<20x20xf32>, tensor<20x20xf32>,
-          none, none, none,
-          tensor<20xf32>, tensor<20xf32>, tensor<20xf32>, tensor<20xf32>,
-          none, none,
-          tensor<1x20xf32>, tensor<1x20xf32>,
-          none, none, none, none) -> tensor<1x28x20xf32>
-  return %0 : tensor<1x28x20xf32>
-// CHECK: %[[none:.*]] = constant unit
-// CHECK: %[[cell_input:.*]] = constant dense<0.000000e+00> : tensor<1x20xf32>
-// CHECK: %[[q:.*]] = "tfl.quantize"(%[[cell_input]]) {qtype = tensor<1x20x!quant.uniform<i16:f32, 2.44140625E-4>>} : (tensor<1x20xf32>) -> tensor<1x20x!quant.uniform<i16:f32, 2.44140625E-4>>
-// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[q]]) : (tensor<1x20x!quant.uniform<i16:f32, 2.44140625E-4>>) -> tensor<1x20xf32>
-// Checks if input 19 is correctly passed from a dequantize op.
-// CHECK: %[[lstm:.*]] = "tfl.unidirectional_sequence_lstm"(%arg0, {{(%[^%,]+, )+}}%[[dq]], %[[none]], %[[none]], %[[none]], %[[none]])
-}
diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc
index 235117fcb92..6bf469bf4ff 100644
--- a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc
@@ -19,25 +19,34 @@ limitations under the License.
 #include

 #include "absl/memory/memory.h"
+#include "llvm/ADT/Optional.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/CommandLine.h"
+#include "llvm/Support/MathExtras.h"
+#include "llvm/Support/raw_ostream.h"
+#include "mlir/Dialect/Quant/FakeQuantSupport.h"  // from @llvm-project
 #include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
+#include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
 #include "mlir/IR/Function.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
+#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
 #include "tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.h"
 #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
+#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
 #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
 #include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/lite/schema/schema_generated.h"
+#include "tensorflow/lite/tools/optimize/operator_property.h"

 // NOLINTNEXTLINE
 static llvm::cl::list<std::string> quantize_allowlist(
@@ -322,6 +331,101 @@ struct ConvertLstmStatsToQDQs : public OpRewritePattern<SourceOp> {
       : OpRewritePattern<SourceOp>(context, /*benefit=*/2) {}
   LogicalResult matchAndRewrite(SourceOp op,
                                 PatternRewriter& rewriter) const override {
+    tflite::optimize::operator_property::OpVariant lstm_variant;
+    if (llvm::isa<LSTMOp>(op.getOperation())) {
+      lstm_variant.op_code = tflite::BuiltinOperator_LSTM;
+    } else if (llvm::isa<UnidirectionalSequenceLSTMOp>(op.getOperation())) {
+      lstm_variant.op_code =
+          tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM;
+    } else {
+      op.emitError("ConvertLstmStatsToQDQs pass only supports LSTMs.");
+      return failure();
+    }
+    lstm_variant.use_projection =
+        !op.projection_weights().getType().template isa<NoneType>();
+    lstm_variant.use_peephole =
+        !op.cell_to_output_weights().getType().template isa<NoneType>();
+    lstm_variant.use_layer_norm =
+        !op.forget_layer_norm_coefficients().getType().template isa<NoneType>();
+
+    auto lstm_property =
+        tflite::optimize::operator_property::GetOperatorProperty(lstm_variant);
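+    // lstm_property.intermediates maps each of the five intermediate tensor
+    // indices to its expected quantization parameters (roughly: 16-bit
+    // symmetric for intermediates 0-3 and 8-bit for the effective hidden
+    // scale), which are matched against the attributes below.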
+
+    // Same as the ordering of //tensorflow/compiler/mlir/lite/ir/tfl_ops.td.
+    const std::vector<std::string> intermediate_attributes = {
+        "input_to_input_intermediate", "input_to_forget_intermediate",
+        "input_to_cell_intermediate", "input_to_output_intermediate",
+        "effective_hidden_scale_intermediate"};
+
+    for (auto& enumerated_intermediates : lstm_property.intermediates) {
+      int index = enumerated_intermediates.first;
+      auto& tensor_property = enumerated_intermediates.second;
+      // Intermediate tensors 0, 1, 2, 3 are only used with layer normalization.
+      if (!lstm_variant.use_layer_norm && index != 4) {
+        continue;
+      }
+      // Intermediate tensor 4 is only used with projection.
+      if (!lstm_variant.use_projection && index == 4) {
+        continue;
+      }
+      TypeAttr attr =
+          op.template getAttrOfType<TypeAttr>(intermediate_attributes[index]);
+
+      if (!attr) {
+        op.emitError()
+            << op.getOperationName()
+            << " requires quantization values for intermediate tensor "
+            << intermediate_attributes[index];
+        return failure();
+      }
+      auto quantized_type =
+          QuantizedType::getQuantizedElementType(attr.getValue());
+      if (!quantized_type) {
+        op.emitError() << intermediate_attributes[index]
+                       << " is not quantized.";
+        return failure();
+      }
+      auto calibrated_type =
+          quantized_type.dyn_cast<quant::CalibratedQuantizedType>();
+      if (!calibrated_type) {
+        int num_storage_bits = quantized_type.getStorageTypeIntegralWidth();
+        if (tensor_property.number_of_bits != num_storage_bits) {
+          op.emitError() << intermediate_attributes[index]
+                         << " is expected to be quantized with "
+                         << tensor_property.number_of_bits << " bits, but got "
+                         << num_storage_bits << " bits instead.";
+          return failure();
+        }
+        continue;  // Skip if it is already quantized.
+      }
+      quant::UniformQuantizedType qtype;
+      if (tensor_property.number_of_bits == 8) {
+        qtype = quant::fakeQuantAttrsToType(
+            op.getLoc(), tensor_property.number_of_bits,
+            calibrated_type.getMin(), calibrated_type.getMax(),
+            /*narrowRange=*/false, calibrated_type.getExpressedType(),
+            /*isSigned=*/false);
+      } else if (tensor_property.number_of_bits == 16) {
+        double max = std::max(std::abs(calibrated_type.getMin()),
+                              std::abs(calibrated_type.getMax()));
+        qtype = quant::fakeQuantAttrsToType(
+            op.getLoc(), tensor_property.number_of_bits, -max, max,
+            /*narrowRange=*/true, calibrated_type.getExpressedType(),
+            /*isSigned=*/true);
+      } else {
+        op.emitError() << "Unsupported quantization bits: "
+                       << tensor_property.number_of_bits;
+        return failure();
+      }
+
+      op.setAttr(intermediate_attributes[index],
+                 TypeAttr::get(qtype.castFromExpressedType(
+                     qtype.castToExpressedType(attr.getValue()))));
+    }
+
     quant::StatisticsOp stats_op = llvm::dyn_cast_or_null<quant::StatisticsOp>(
         op.input_cell_state().getDefiningOp());
     // Recurrent input is used within an LSTM, and thus should have one use.
@@ -338,10 +442,12 @@ struct ConvertLstmStatsToQDQs : public OpRewritePattern<SourceOp> {
         std::abs(FloatAttr::getValueAsDouble(stats.getValue<APFloat>({1}))));
     double bound = power_of_two_bound(max);
     Type expressed = stats_op.getType().cast<ShapedType>().getElementType();
-    // maximum value is adjusted to get a scale of power_of_two(max)/32768.
-    quant::QuantizedType quant_type = quant::fakeQuantAttrsToType(
-        stats_op.getLoc(), 16, -bound, bound * 32767.0 / 32768.0,
-        /*narrow_range*/ false, expressed, /*is_signed*/ true);
+    // Construct the 16-bit type directly: signed storage, zero point 0, and
+    // a scale of exactly power_of_two(max) / 32768.
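+    // For example, with layerStats [-2.73090601, 7.94872093] (see the
+    // prepare-quantize-lstm.mlir test above), max is 7.94872093, bound is
+    // 8.0, and the resulting type is !quant.uniform<i16:f32, 2.44140625E-4>.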
+    quant::QuantizedType quant_type = UniformQuantizedType::getChecked(
+        quant::QuantizationFlags::Signed,
+        IntegerType::get(16, expressed.getContext()), expressed,
+        /*scale=*/bound / 32768.0, /*zeroPoint=*/0, llvm::minIntN(16),
+        llvm::maxIntN(16), op.getLoc());

     rewriter.setInsertionPointAfter(stats_op);
     Type result_type = quant_type.castFromExpressedType(stats_op.getType());
diff --git a/tensorflow/lite/tools/optimize/BUILD b/tensorflow/lite/tools/optimize/BUILD
index c77c08eeb6b..88015d7634d 100644
--- a/tensorflow/lite/tools/optimize/BUILD
+++ b/tensorflow/lite/tools/optimize/BUILD
@@ -195,7 +195,6 @@ cc_library(
     compatible_with = get_compatible_with_cloud(),
     deps = [
         "//tensorflow/lite:framework",
-        "//tensorflow/lite/kernels/internal:types",
         "//tensorflow/lite/schema:schema_fbs",
         "//tensorflow/lite/schema:schema_utils",
     ],
diff --git a/tensorflow/lite/tools/optimize/operator_property.cc b/tensorflow/lite/tools/optimize/operator_property.cc
index e1cbad8e60b..45dff78ef92 100644
--- a/tensorflow/lite/tools/optimize/operator_property.cc
+++ b/tensorflow/lite/tools/optimize/operator_property.cc
@@ -22,21 +22,6 @@ namespace optimize {
 namespace operator_property {
 namespace {
-
-// The op as well as it variants.
-// TODO(jianlijianli): extend it to support ops that has multiple variants.
-struct OpVariant {
-  BuiltinOperator op_code;
-  bool use_layer_norm = false;
-  bool use_projection = false;
-  bool use_peephole = false;
-  // An attribute to indicate if quantization is supported for this Op.
-  // This attribute is equivalent to the "quantizable" attribute in
-  // "OperatorProperty". It added here since OpVariants peeks inside the Op and
-  // determines its quantization related properties.
-  bool is_quantizable = true;
-};
-
 const OpVariant GetOperatorVariant(const ModelT* model, int subgraph_index,
                                    int op_index) {
   OpVariant op_variant;
@@ -67,12 +52,16 @@ const OpVariant GetOperatorVariant(const ModelT* model, int subgraph_index,
   }
 }  // namespace

-// Update operation defintions in TensorFlow Lite dialect accordingly when there
-// are any needs on updating the kernel support level.
-// LINT.IfChange
 OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
                                      int op_index) {
   OpVariant op_variant = GetOperatorVariant(model, subgraph_index, op_index);
+  return GetOperatorProperty(op_variant);
+}
+
+// Update the operation definitions in the TensorFlow Lite dialect accordingly
+// whenever the kernel support level here changes.
+// LINT.IfChange
+OperatorProperty GetOperatorProperty(OpVariant op_variant) {
   BuiltinOperator op_code = op_variant.op_code;
   OperatorProperty property;
   switch (op_code) {
diff --git a/tensorflow/lite/tools/optimize/operator_property.h b/tensorflow/lite/tools/optimize/operator_property.h
index 58922a60e27..98b09a71c99 100644
--- a/tensorflow/lite/tools/optimize/operator_property.h
+++ b/tensorflow/lite/tools/optimize/operator_property.h
@@ -125,8 +125,23 @@ struct OperatorProperty {
   bool quantize_input_as_activations = false;
 };

+// The op as well as its variants.
+// TODO(b/174283888): extend this to support ops that have multiple variants.
+struct OpVariant {
+  BuiltinOperator op_code;
+  bool use_layer_norm = false;
+  bool use_projection = false;
+  bool use_peephole = false;
+  // An attribute to indicate if quantization is supported for this Op.
+  // This attribute is equivalent to the "quantizable" attribute in
+  // "OperatorProperty". It is added here since OpVariant peeks inside the Op
+  // and determines its quantization-related properties.
+  bool is_quantizable = true;
+};
+
 OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
                                      int op_index);
+OperatorProperty GetOperatorProperty(OpVariant op_variant);

 }  // namespace operator_property
 }  // namespace optimize