Add hybrid quantization support for BatchMatMul during conversion.

PiperOrigin-RevId: 344098609
Change-Id: Ib00f38d3bb7403fbd9d21169301c0a080e632d17
Author: Yunlu Li, 2020-11-24 11:28:54 -08:00, committed by TensorFlower Gardener
parent 256556c132
commit f70160322a
6 changed files with 85 additions and 1 deletion
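
Background for readers: "hybrid" (dynamic range) quantization stores weights as int8 but keeps activations in float, quantizing the non-constant input on the fly at evaluation time. The sketch below is a minimal, self-contained illustration of that evaluation-time math for a single dot product; it is not TFLite's implementation, and all names in it are illustrative. It assumes per-tensor quantization and a non-empty input.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Asymmetric quantization of a float vector to int8: derive a scale and a
// zero point from the observed min/max, then round-and-clamp each value.
void AsymmetricQuantize(const std::vector<float>& input, float* scale,
                        int32_t* zero_point, std::vector<int8_t>* quantized) {
  const auto [min_it, max_it] =
      std::minmax_element(input.begin(), input.end());
  const float rmin = std::min(*min_it, 0.0f);  // Range must include zero.
  const float rmax = std::max(*max_it, 0.0f);
  *scale = (rmax - rmin) / 255.0f;             // int8 spans 256 values.
  if (*scale == 0.0f) *scale = 1.0f;           // All-zero input; any scale works.
  *zero_point = static_cast<int32_t>(std::round(-128.0f - rmin / *scale));
  quantized->resize(input.size());
  for (size_t i = 0; i < input.size(); ++i) {
    const int32_t q =
        *zero_point + static_cast<int32_t>(std::round(input[i] / *scale));
    (*quantized)[i] = static_cast<int8_t>(std::clamp(q, -128, 127));
  }
}

// Hybrid dot product: the input carries a zero point (asymmetric), the
// weights are symmetric (zero point 0); accumulate in int32, rescale once.
float HybridDot(const std::vector<int8_t>& input, float input_scale,
                int32_t input_zero_point,
                const std::vector<int8_t>& weights, float weight_scale) {
  int32_t acc = 0;
  for (size_t i = 0; i < input.size(); ++i) {
    acc += (static_cast<int32_t>(input[i]) - input_zero_point) *
           static_cast<int32_t>(weights[i]);
  }
  return static_cast<float>(acc) * input_scale * weight_scale;
}

The asymmetric part is the nonzero zero point: unlike the older symmetric scheme, the input range need not be centered at zero, which wastes less of the int8 range on skewed activations.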

RELEASE.md

@@ -49,6 +49,7 @@
   * Added int16x8 support for ABS, REDUCE_MAX and REDUCE_MIN operators.
   * Added support for saved model's session initializer through
     `TFLiteConverter.from_saved_model`.
+  * Added dynamic range quantization support for the BatchMatMul op.
 * TF Core:
   * Corrected higher-order gradients of control flow constructs (`tf.cond`,
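
On the converter side, this feature is reached through the standard dynamic range quantization flow in Python (`tf.lite.TFLiteConverter` with `optimizations = [tf.lite.Optimize.DEFAULT]` and no representative dataset). At runtime the converted model is an ordinary .tflite file; below is a sketch of running one through the C++ interpreter using standard TFLite APIs (the model path and function name are placeholders):

#include <memory>

#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"

// Runs a dynamic-range-quantized model. The hybrid BatchMatMul kernel
// quantizes its float inputs internally at evaluation time, so the caller
// feeds and reads plain float tensors.
bool RunHybridModel(const char* model_path) {
  auto model = tflite::FlatBufferModel::BuildFromFile(model_path);
  if (!model) return false;
  tflite::ops::builtin::BuiltinOpResolver resolver;
  std::unique_ptr<tflite::Interpreter> interpreter;
  if (tflite::InterpreterBuilder(*model, resolver)(&interpreter) != kTfLiteOk)
    return false;
  if (interpreter->AllocateTensors() != kTfLiteOk) return false;
  // Fill interpreter->typed_input_tensor<float>(0) here, then:
  return interpreter->Invoke() == kTfLiteOk;
}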

tensorflow/lite/c/builtin_op_data.h

@@ -213,6 +213,10 @@ typedef struct {
 typedef struct {
   bool adj_x;
   bool adj_y;
+  // Parameters for BatchMatMul version 4 or above.
+  // If set to true and the weights are quantized, then non-constant inputs
+  // are quantized at evaluation time with asymmetric quantization.
+  bool asymmetric_quantize_inputs;
 } TfLiteBatchMatMulParams;

 typedef struct {
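
For context on how the new struct field is consumed: TFLite kernels receive their builtin options through `node->builtin_data`. The fragment below is a hypothetical illustration of reading the flag, not the actual BatchMatMul kernel (which lives under tensorflow/lite/kernels/):

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"

// Hypothetical Eval fragment: selects the input-quantization scheme based
// on the new flag.
TfLiteStatus EvalSketch(TfLiteContext* context, TfLiteNode* node) {
  const auto* params =
      reinterpret_cast<const TfLiteBatchMatMulParams*>(node->builtin_data);
  if (params->asymmetric_quantize_inputs) {
    // Version 4+ hybrid path: quantize the float input with an asymmetric
    // scale/zero-point before the int8 weight multiplication.
  } else {
    // Older hybrid path: symmetric input quantization (zero point fixed at 0).
  }
  return kTfLiteOk;
}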

tensorflow/lite/schema/schema.fbs

@@ -990,6 +990,10 @@ table SegmentSumOptions {
 table BatchMatMulOptions {
   adj_x:bool;
   adj_y:bool;
+  // Parameters for BatchMatMul version 4 or above.
+  // If set to true, then the weights-only quantized op will use asymmetric
+  // quantization for its inputs.
+  asymmetric_quantize_inputs: bool;
 }

 table CumsumOptions {
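
Because FlatBuffers fields are append-only with defaults, existing models stay readable: the new field simply reads back as false. After regenerating schema_generated.h, model-producing tools can set it through the generated builder; a sketch (the wrapper function name is illustrative, and the generated signature is assumed to follow flatc's usual convention of one argument per field in declaration order):

#include "flatbuffers/flatbuffers.h"
#include "tensorflow/lite/schema/schema_generated.h"

// Builds BatchMatMulOptions with the new field set; omitting the third
// argument would keep the default (false), which is what pre-existing
// models effectively contain.
flatbuffers::Offset<tflite::BatchMatMulOptions> MakeHybridBatchMatMulOptions(
    flatbuffers::FlatBufferBuilder& builder) {
  return tflite::CreateBatchMatMulOptions(builder,
                                          /*adj_x=*/false,
                                          /*adj_y=*/false,
                                          /*asymmetric_quantize_inputs=*/true);
}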

tensorflow/lite/tools/optimize/quantize_weights.cc

@@ -90,6 +90,7 @@ std::vector<int32_t> GetWeightInputIndices(const OperatorCodeT* op_code,
   } else if (builtin_op_code == BuiltinOperator_CONV_2D ||
              builtin_op_code == BuiltinOperator_DEPTHWISE_CONV_2D ||
              builtin_op_code == BuiltinOperator_FULLY_CONNECTED ||
+             builtin_op_code == BuiltinOperator_BATCH_MATMUL ||
              builtin_op_code == BuiltinOperator_EMBEDDING_LOOKUP) {
     return {1};
   } else if (builtin_op_code == BuiltinOperator_SVDF) {
@@ -145,6 +146,7 @@ bool IsHybridEvaluationOp(const OperatorT* op, const OperatorCodeT* op_code,
       return custom_op_info->second.is_hybrid;
     }
   } else if (builtin_op_code == BuiltinOperator_FULLY_CONNECTED ||
+             builtin_op_code == BuiltinOperator_BATCH_MATMUL ||
              builtin_op_code == BuiltinOperator_CONV_2D ||
              builtin_op_code == BuiltinOperator_SVDF ||
              builtin_op_code == BuiltinOperator_RNN ||
@@ -255,6 +257,10 @@ TfLiteStatus InsertQuantizableInputTensorsFromOperator(
       op->builtin_options.AsFullyConnectedOptions()
          ->asymmetric_quantize_inputs = use_updated_hybrid_scheme;
       break;
+    case BuiltinOperator_BATCH_MATMUL:
+      op->builtin_options.AsBatchMatMulOptions()
+         ->asymmetric_quantize_inputs = use_updated_hybrid_scheme;
+      break;
     case BuiltinOperator_LSTM:
       op->builtin_options.AsLSTMOptions()->asymmetric_quantize_inputs =
           use_updated_hybrid_scheme;
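
The pass above is driven by the weight-quantization tool: with this change, BatchMatMul's weight input (tensor index 1) is now quantized and, under the updated hybrid scheme, asymmetric_quantize_inputs is set on the op. A sketch of invoking it follows; the header path, namespace, and overload are assumptions based on this era of the codebase, and the wrapper name and threshold are illustrative:

#include "flatbuffers/flatbuffers.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/tools/optimize/quantize_weights.h"

// Runs dynamic-range weight quantization over a model, writing the result
// into the given FlatBufferBuilder.
TfLiteStatus QuantizeModelWeights(const tflite::Model* input_model,
                                  flatbuffers::FlatBufferBuilder* builder) {
  // Only quantize weight tensors with at least 1024 elements.
  return tflite::optimize::QuantizeWeights(builder, input_model,
                                           /*weights_min_num_elements=*/1024);
}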

tensorflow/lite/tools/versioning/op_version.cc

@@ -557,8 +557,25 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
       }
       return 1;
-    case BuiltinOperator_CONCATENATION:
+    case BuiltinOperator_BATCH_MATMUL:
+      // In case of int16 inputs, the version is 3.
+      if (op_sig.input_types.at(0) == TensorType_INT16) {
+        return 3;
+      }
+      if (op_sig.input_types.at(0) == TensorType_INT8) {
+        return 2;
+      }
+      if (op_sig.input_types.at(0) == TensorType_FLOAT32 &&
+          op_sig.input_types.at(1) == TensorType_INT8 &&
+          op_sig.output_types.at(0) == TensorType_FLOAT32) {
+        if (op_sig.options.input_quantization.asymmetric_quantize_inputs) {
+          // This is to use the updated quantization scheme.
+          return 4;
+        }
+      }
+      return 1;
+    case BuiltinOperator_CONCATENATION:
     case BuiltinOperator_SOFTMAX:
     case BuiltinOperator_MEAN:
     case BuiltinOperator_PAD:
@@ -792,6 +809,12 @@ OpSignature GetOpSignature(const OperatorCode* op_code, const Operator* op,
           std::max(GetNumDims(subgraph, op, 0), GetNumDims(subgraph, op, 1));
     } break;
+    case BuiltinOperator_BATCH_MATMUL: {
+      auto batch_matmul_option = op->builtin_options_as_BatchMatMulOptions();
+      op_sig.options.input_quantization.asymmetric_quantize_inputs =
+          batch_matmul_option->asymmetric_quantize_inputs();
+    } break;
     default:
       break;
   }

tensorflow/lite/tools/versioning/op_version_test.cc

@@ -18,6 +18,7 @@ limitations under the License.
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
+#include "tensorflow/lite/schema/schema_generated.h"

 namespace tflite {
@@ -779,4 +780,49 @@ TEST(OpVersionTest, VersioningAbsTest) {
   };
   EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
 }
+
+TEST(OpVersionTest, VersioningBatchMatMulTest) {
+  // Default.
+  OpSignature fake_op_sig = {
+      .op = BuiltinOperator_BATCH_MATMUL,
+      .input_types =
+          std::vector<TensorType>{TensorType_FLOAT32, TensorType_FLOAT32},
+      .output_types = std::vector<TensorType>{TensorType_FLOAT32},
+  };
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
+
+  // int8 input is version 2.
+  fake_op_sig = {
+      .op = BuiltinOperator_BATCH_MATMUL,
+      .input_types = std::vector<TensorType>{TensorType_INT8, TensorType_INT8},
+      .output_types = std::vector<TensorType>{TensorType_INT8},
+  };
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
+
+  // int16 input is version 3.
+  fake_op_sig = {
+      .op = BuiltinOperator_BATCH_MATMUL,
+      .input_types = std::vector<TensorType>{TensorType_INT16, TensorType_INT8},
+      .output_types = std::vector<TensorType>{TensorType_INT16},
+  };
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
+
+  // Symmetric hybrid quantized input is version 1.
+  fake_op_sig = {
+      .op = BuiltinOperator_BATCH_MATMUL,
+      .input_types =
+          std::vector<TensorType>{TensorType_FLOAT32, TensorType_INT8},
+      .output_types = std::vector<TensorType>{TensorType_FLOAT32},
+  };
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
+
+  // Asymmetric hybrid quantized input is version 4.
+  fake_op_sig = {
+      .op = BuiltinOperator_BATCH_MATMUL,
+      .input_types =
+          std::vector<TensorType>{TensorType_FLOAT32, TensorType_INT8},
+      .output_types = std::vector<TensorType>{TensorType_FLOAT32},
+  };
+  fake_op_sig.options.input_quantization.asymmetric_quantize_inputs = true;
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4);
+}
+
 }  // namespace tflite