From f70160322a579144950dff1537dcbe3c7c09d6f5 Mon Sep 17 00:00:00 2001
From: Yunlu Li
Date: Tue, 24 Nov 2020 11:28:54 -0800
Subject: [PATCH] Add hybrid quantization support for BatchMatMul during
 conversion.

PiperOrigin-RevId: 344098609
Change-Id: Ib00f38d3bb7403fbd9d21169301c0a080e632d17
---
 RELEASE.md                                    |  1 +
 tensorflow/lite/c/builtin_op_data.h           |  4 ++
 tensorflow/lite/schema/schema.fbs             |  4 ++
 .../lite/tools/optimize/quantize_weights.cc   |  6 +++
 .../lite/tools/versioning/op_version.cc       | 25 +++++++++-
 .../lite/tools/versioning/op_version_test.cc  | 46 +++++++++++++++++++
 6 files changed, 85 insertions(+), 1 deletion(-)

diff --git a/RELEASE.md b/RELEASE.md
index 9ccef55583a..a07685356c6 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -49,6 +49,7 @@
     * Added int16x8 support for ABS, REDUCE_MAX and REDUCE_MIN operators.
     * Added support for saved model's session initializer through
       `TFLiteConverter.from_saved_model`.
+    * Added dynamic range quantization support for the BatchMatMul op.
 * TF Core:
     * Corrected higher-order gradients of control flow constructs (`tf.cond`,

diff --git a/tensorflow/lite/c/builtin_op_data.h b/tensorflow/lite/c/builtin_op_data.h
index 5801e3c4abe..31d792ad12d 100644
--- a/tensorflow/lite/c/builtin_op_data.h
+++ b/tensorflow/lite/c/builtin_op_data.h
@@ -213,6 +213,10 @@ typedef struct {
 typedef struct {
   bool adj_x;
   bool adj_y;
+  // Parameters for BatchMatMul version 4 or above.
+  // If set to true and the weights are quantized, then non-constant inputs
+  // are quantized at evaluation time with asymmetric quantization.
+  bool asymmetric_quantize_inputs;
 } TfLiteBatchMatMulParams;

 typedef struct {

diff --git a/tensorflow/lite/schema/schema.fbs b/tensorflow/lite/schema/schema.fbs
index 0170e682eba..a3e729c3ce1 100644
--- a/tensorflow/lite/schema/schema.fbs
+++ b/tensorflow/lite/schema/schema.fbs
@@ -990,6 +990,10 @@ table SegmentSumOptions {
 table BatchMatMulOptions {
   adj_x:bool;
   adj_y:bool;
+  // Parameters for BatchMatMul version 4 or above.
+  // If set to true, then the weights-only op will use asymmetric quantization
+  // for inputs.
+  asymmetric_quantize_inputs: bool;
 }

 table CumsumOptions {

diff --git a/tensorflow/lite/tools/optimize/quantize_weights.cc b/tensorflow/lite/tools/optimize/quantize_weights.cc
index 1b22cb56117..29f5c1fa6c7 100644
--- a/tensorflow/lite/tools/optimize/quantize_weights.cc
+++ b/tensorflow/lite/tools/optimize/quantize_weights.cc
@@ -90,6 +90,7 @@ std::vector<int32_t> GetWeightInputIndices(const OperatorCodeT* op_code,
   } else if (builtin_op_code == BuiltinOperator_CONV_2D ||
              builtin_op_code == BuiltinOperator_DEPTHWISE_CONV_2D ||
              builtin_op_code == BuiltinOperator_FULLY_CONNECTED ||
+             builtin_op_code == BuiltinOperator_BATCH_MATMUL ||
              builtin_op_code == BuiltinOperator_EMBEDDING_LOOKUP) {
     return {1};
   } else if (builtin_op_code == BuiltinOperator_SVDF) {
@@ -145,6 +146,7 @@ bool IsHybridEvaluationOp(const OperatorT* op, const OperatorCodeT* op_code,
       return custom_op_info->second.is_hybrid;
     }
   } else if (builtin_op_code == BuiltinOperator_FULLY_CONNECTED ||
+             builtin_op_code == BuiltinOperator_BATCH_MATMUL ||
              builtin_op_code == BuiltinOperator_CONV_2D ||
              builtin_op_code == BuiltinOperator_SVDF ||
              builtin_op_code == BuiltinOperator_RNN ||
@@ -255,6 +257,10 @@ TfLiteStatus InsertQuantizableInputTensorsFromOperator(
       op->builtin_options.AsFullyConnectedOptions()
           ->asymmetric_quantize_inputs = use_updated_hybrid_scheme;
       break;
+    case BuiltinOperator_BATCH_MATMUL:
+      op->builtin_options.AsBatchMatMulOptions()
+          ->asymmetric_quantize_inputs = use_updated_hybrid_scheme;
+      break;
     case BuiltinOperator_LSTM:
       op->builtin_options.AsLSTMOptions()->asymmetric_quantize_inputs =
           use_updated_hybrid_scheme;

diff --git a/tensorflow/lite/tools/versioning/op_version.cc b/tensorflow/lite/tools/versioning/op_version.cc
index 1526921149d..042198a5d05 100644
--- a/tensorflow/lite/tools/versioning/op_version.cc
+++ b/tensorflow/lite/tools/versioning/op_version.cc
@@ -557,8 +557,25 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
       }
       return 1;

-    case BuiltinOperator_CONCATENATION:
     case BuiltinOperator_BATCH_MATMUL:
+      // In case of int16 inputs, the version is 3.
+      if (op_sig.input_types.at(0) == TensorType_INT16) {
+        return 3;
+      }
+      if (op_sig.input_types.at(0) == TensorType_INT8) {
+        return 2;
+      }
+      if (op_sig.input_types.at(0) == TensorType_FLOAT32 &&
+          op_sig.input_types.at(1) == TensorType_INT8 &&
+          op_sig.output_types.at(0) == TensorType_FLOAT32) {
+        if (op_sig.options.input_quantization.asymmetric_quantize_inputs) {
+          // This is to use the updated quantization scheme.
+          return 4;
+        }
+      }
+      return 1;
+
+    case BuiltinOperator_CONCATENATION:
     case BuiltinOperator_SOFTMAX:
     case BuiltinOperator_MEAN:
     case BuiltinOperator_PAD:
@@ -792,6 +809,12 @@ OpSignature GetOpSignature(const OperatorCode* op_code, const Operator* op,
           std::max(GetNumDims(subgraph, op, 0), GetNumDims(subgraph, op, 1));
     } break;

+    case BuiltinOperator_BATCH_MATMUL: {
+      auto batch_matmul_option = op->builtin_options_as_BatchMatMulOptions();
+      op_sig.options.input_quantization.asymmetric_quantize_inputs =
+          batch_matmul_option->asymmetric_quantize_inputs();
+    } break;
+
     default:
       break;
   }

diff --git a/tensorflow/lite/tools/versioning/op_version_test.cc b/tensorflow/lite/tools/versioning/op_version_test.cc
index 8a34a45a584..449109a8f93 100644
--- a/tensorflow/lite/tools/versioning/op_version_test.cc
+++ b/tensorflow/lite/tools/versioning/op_version_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
 #include <vector>

 #include <gtest/gtest.h>
+#include "tensorflow/lite/schema/schema_generated.h"

 namespace tflite {

@@ -779,4 +780,49 @@ TEST(OpVersionTest, VersioningAbsTest) {
   };
   EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
 }
+TEST(OpVersionTest, VersioningBatchMatMulTest) {
+  // Default.
+  OpSignature fake_op_sig = {
+      .op = BuiltinOperator_BATCH_MATMUL,
+      .input_types =
+          std::vector<TensorType>{TensorType_FLOAT32, TensorType_FLOAT32},
+      .output_types = std::vector<TensorType>{TensorType_FLOAT32},
+  };
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
+
+  // int8 input is version 2.
+  fake_op_sig = {
+      .op = BuiltinOperator_BATCH_MATMUL,
+      .input_types = std::vector<TensorType>{TensorType_INT8, TensorType_INT8},
+      .output_types = std::vector<TensorType>{TensorType_INT8},
+  };
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
+
+  // int16 input is version 3.
+  fake_op_sig = {
+      .op = BuiltinOperator_BATCH_MATMUL,
+      .input_types = std::vector<TensorType>{TensorType_INT16, TensorType_INT8},
+      .output_types = std::vector<TensorType>{TensorType_INT16},
+  };
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
+
+  // Symmetric hybrid quantized input is version 1.
+  fake_op_sig = {
+      .op = BuiltinOperator_BATCH_MATMUL,
+      .input_types =
+          std::vector<TensorType>{TensorType_FLOAT32, TensorType_INT8},
+      .output_types = std::vector<TensorType>{TensorType_FLOAT32},
+  };
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
+
+  // Asymmetric hybrid quantized input is version 4.
+  fake_op_sig = {
+      .op = BuiltinOperator_BATCH_MATMUL,
+      .input_types =
+          std::vector<TensorType>{TensorType_FLOAT32, TensorType_INT8},
+      .output_types = std::vector<TensorType>{TensorType_FLOAT32},
+  };
+  fake_op_sig.options.input_quantization.asymmetric_quantize_inputs = true;
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4);
+}
 }  // namespace tflite
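
---
Note (not part of the patch): a minimal sketch of the user-facing path this change
enables, dynamic range quantization of a model containing a BatchMatMul via the
TFLite converter. The module definition, shapes, and file paths below are
illustrative assumptions, not code from this change; only
tf.lite.TFLiteConverter.from_saved_model and tf.lite.Optimize.DEFAULT are the
standard public API.

    import numpy as np
    import tensorflow as tf


    class MatMulModule(tf.Module):
      """Float model whose graph lowers to a TFLite BATCH_MATMUL op."""

      def __init__(self):
        super().__init__()
        # Constant right-hand operand: the tensor the weight quantizer
        # targets (input index {1} in GetWeightInputIndices above).
        self.w = tf.Variable(
            np.random.uniform(size=(1, 8, 16)).astype(np.float32))

      @tf.function(input_signature=[tf.TensorSpec([1, 4, 8], tf.float32)])
      def __call__(self, x):
        # tf.matmul on 3-D inputs converts to the BATCH_MATMUL builtin.
        return tf.matmul(x, self.w)


    module = MatMulModule()
    tf.saved_model.save(
        module, "/tmp/batch_matmul_model",
        signatures=module.__call__.get_concrete_function())

    converter = tf.lite.TFLiteConverter.from_saved_model(
        "/tmp/batch_matmul_model")
    # Dynamic range quantization: weights become int8 while activations stay
    # float, and non-constant inputs are quantized at evaluation time. When
    # the updated (asymmetric) hybrid scheme is applied, the emitted op
    # carries asymmetric_quantize_inputs = true and is stamped as version 4
    # per the op_version.cc change above.
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    tflite_model = converter.convert()

    with open("/tmp/batch_matmul_hybrid.tflite", "wb") as f:
      f.write(tflite_model)

Without the asymmetric flag the hybrid op falls back to version 1 (the symmetric
scheme), which is exactly the distinction the last two cases of
VersioningBatchMatMulTest exercise.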