Add hybrid quantization support for BatchMatMul during conversion.
PiperOrigin-RevId: 344098609
Change-Id: Ib00f38d3bb7403fbd9d21169301c0a080e632d17
parent 256556c132 → commit f70160322a
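In short: this change threads a new `asymmetric_quantize_inputs` flag for BatchMatMul through the C API params struct, the FlatBuffer schema, the weight-quantization tooling, and the op versioning logic (the asymmetric hybrid form becomes version 4), with a matching versioning test.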
@@ -49,6 +49,7 @@
* Added int16x8 support for ABS, REDUCE_MAX and REDUCE_MIN operators.
* Added support for saved model's session initializer through
  `TFLiteConverter.from_saved_model`.
* Added dynamic range quantization support for the BatchMatMul op.

* TF Core:
  * Corrected higher-order gradients of control flow constructs (`tf.cond`,
@@ -213,6 +213,10 @@ typedef struct {
typedef struct {
  bool adj_x;
  bool adj_y;
  // Parameters for BatchMatMul version 4 or above.
  // If set to true and the weights are quantized, then non constant inputs
  // are quantized at evaluation time with asymmetric quantization.
  bool asymmetric_quantize_inputs;
} TfLiteBatchMatMulParams;

typedef struct {
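A minimal sketch (not part of this commit) of how a kernel might consume the new field; the `Prepare` name and the scratch-buffer handling are illustrative assumptions, not the actual BatchMatMul kernel:

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  // builtin_data carries the TfLiteBatchMatMulParams parsed for this node.
  const auto* params =
      reinterpret_cast<const TfLiteBatchMatMulParams*>(node->builtin_data);
  if (params->asymmetric_quantize_inputs) {
    // Hybrid path: non-constant float inputs get quantized at Eval time
    // with an asymmetric scheme, so scratch tensors for the quantized
    // values, scales, and zero points would be allocated here.
  }
  return kTfLiteOk;
}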
@@ -990,6 +990,10 @@ table SegmentSumOptions {
table BatchMatMulOptions {
  adj_x:bool;
  adj_y:bool;
  // Parameters for BatchMatMul version 4 or above.
  // If set to true, then weights-only op will use asymmetric quantization for
  // inputs.
  asymmetric_quantize_inputs: bool;
}

table CumsumOptions {
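For reference, flatc generates a builder helper for this table; a minimal sketch of serializing the options with the new field set, assuming the generated `tflite::CreateBatchMatMulOptions` from schema_generated.h:

#include "flatbuffers/flatbuffers.h"
#include "tensorflow/lite/schema/schema_generated.h"

flatbuffers::Offset<tflite::BatchMatMulOptions> MakeOptions(
    flatbuffers::FlatBufferBuilder& builder) {
  // adj_x and adj_y keep their defaults; the new field opts this op into
  // asymmetric quantization of its non-constant inputs.
  return tflite::CreateBatchMatMulOptions(
      builder, /*adj_x=*/false, /*adj_y=*/false,
      /*asymmetric_quantize_inputs=*/true);
}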
@@ -90,6 +90,7 @@ std::vector<int32_t> GetWeightInputIndices(const OperatorCodeT* op_code,
  } else if (builtin_op_code == BuiltinOperator_CONV_2D ||
             builtin_op_code == BuiltinOperator_DEPTHWISE_CONV_2D ||
             builtin_op_code == BuiltinOperator_FULLY_CONNECTED ||
             builtin_op_code == BuiltinOperator_BATCH_MATMUL ||
             builtin_op_code == BuiltinOperator_EMBEDDING_LOOKUP) {
    return {1};
  } else if (builtin_op_code == BuiltinOperator_SVDF) {
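The new branch returns {1} for BATCH_MATMUL: as for the other ops grouped here, the second input holds the weights, i.e. the tensor the dynamic-range pass quantizes.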
@@ -145,6 +146,7 @@ bool IsHybridEvaluationOp(const OperatorT* op, const OperatorCodeT* op_code,
      return custom_op_info->second.is_hybrid;
    }
  } else if (builtin_op_code == BuiltinOperator_FULLY_CONNECTED ||
             builtin_op_code == BuiltinOperator_BATCH_MATMUL ||
             builtin_op_code == BuiltinOperator_CONV_2D ||
             builtin_op_code == BuiltinOperator_SVDF ||
             builtin_op_code == BuiltinOperator_RNN ||
@@ -255,6 +257,10 @@ TfLiteStatus InsertQuantizableInputTensorsFromOperator(
      op->builtin_options.AsFullyConnectedOptions()
          ->asymmetric_quantize_inputs = use_updated_hybrid_scheme;
      break;
    case BuiltinOperator_BATCH_MATMUL:
      op->builtin_options.AsBatchMatMulOptions()
          ->asymmetric_quantize_inputs = use_updated_hybrid_scheme;
      break;
    case BuiltinOperator_LSTM:
      op->builtin_options.AsLSTMOptions()->asymmetric_quantize_inputs =
          use_updated_hybrid_scheme;
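`use_updated_hybrid_scheme` is the caller's toggle for the updated (asymmetric) hybrid scheme; BatchMatMul now forwards it into its options the same way FullyConnected and LSTM already do.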
@@ -557,8 +557,25 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
      }
      return 1;

    case BuiltinOperator_CONCATENATION:
    case BuiltinOperator_BATCH_MATMUL:
      // In case of int16 inputs, the version is 3.
      if (op_sig.input_types.at(0) == TensorType_INT16) {
        return 3;
      }
      if (op_sig.input_types.at(0) == TensorType_INT8) {
        return 2;
      }
      if (op_sig.input_types.at(0) == TensorType_FLOAT32 &&
          op_sig.input_types.at(1) == TensorType_INT8 &&
          op_sig.output_types.at(0) == TensorType_FLOAT32) {
        if (op_sig.options.input_quantization.asymmetric_quantize_inputs) {
          // This is to use the updated quantization scheme.
          return 4;
        }
      }
      return 1;

    case BuiltinOperator_CONCATENATION:
    case BuiltinOperator_SOFTMAX:
    case BuiltinOperator_MEAN:
    case BuiltinOperator_PAD:
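Taken together, the switch above versions BATCH_MATMUL as follows:

* int16 input: version 3
* int8 input: version 2
* float32 activations with int8 weights and float32 output, with `asymmetric_quantize_inputs` set: version 4 (the updated hybrid scheme)
* anything else, including the symmetric hybrid case: version 1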
@@ -792,6 +809,12 @@ OpSignature GetOpSignature(const OperatorCode* op_code, const Operator* op,
          std::max(GetNumDims(subgraph, op, 0), GetNumDims(subgraph, op, 1));
    } break;

    case BuiltinOperator_BATCH_MATMUL: {
      auto batch_matmul_option = op->builtin_options_as_BatchMatMulOptions();
      op_sig.options.input_quantization.asymmetric_quantize_inputs =
          batch_matmul_option->asymmetric_quantize_inputs();
    } break;

    default:
      break;
  }
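This is the point where the flag crosses from the serialized model into the `OpSignature`, which is how `GetBuiltinOperatorVersion` above can observe `op_sig.options.input_quantization.asymmetric_quantize_inputs`.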
@@ -18,6 +18,7 @@ limitations under the License.

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/lite/schema/schema_generated.h"

namespace tflite {
@@ -779,4 +780,49 @@ TEST(OpVersionTest, VersioningAbsTest) {
  };
  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
}

TEST(OpVersionTest, VersioningBatchMatMulTest) {
  // Default.
  OpSignature fake_op_sig = {
      .op = BuiltinOperator_BATCH_MATMUL,
      .input_types =
          std::vector<TensorType>{TensorType_FLOAT32, TensorType_FLOAT32},
      .output_types = std::vector<TensorType>{TensorType_FLOAT32},
  };
  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);

  // int8 input is version 2.
  fake_op_sig = {
      .op = BuiltinOperator_BATCH_MATMUL,
      .input_types = std::vector<TensorType>{TensorType_INT8, TensorType_INT8},
      .output_types = std::vector<TensorType>{TensorType_INT8},
  };
  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);

  // int16 input is version 3.
  fake_op_sig = {
      .op = BuiltinOperator_BATCH_MATMUL,
      .input_types = std::vector<TensorType>{TensorType_INT16, TensorType_INT8},
      .output_types = std::vector<TensorType>{TensorType_INT16},
  };
  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);

  // Symmetric hybrid quantized input is version 1.
  fake_op_sig = {
      .op = BuiltinOperator_BATCH_MATMUL,
      .input_types =
          std::vector<TensorType>{TensorType_FLOAT32, TensorType_INT8},
      .output_types = std::vector<TensorType>{TensorType_FLOAT32},
  };
  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);

  // Asymmetric hybrid quantized input is version 4.
  fake_op_sig = {
      .op = BuiltinOperator_BATCH_MATMUL,
      .input_types =
          std::vector<TensorType>{TensorType_FLOAT32, TensorType_INT8},
      .output_types = std::vector<TensorType>{TensorType_FLOAT32},
  };
  fake_op_sig.options.input_quantization.asymmetric_quantize_inputs = true;
  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4);
}
}  // namespace tflite