Add hybrid BatchMatMul kernel that supports legacy symmetric_quantize_inputs.
PiperOrigin-RevId: 346666303 Change-Id: Ife2d74a25aa24a8444c86741dc57b23d6de66ad6
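Background for reviewers (not part of the diff): a hybrid kernel keeps the int8 weights fixed, quantizes the float activations on the fly, runs the matmul in integer arithmetic, and rescales back to float. With a per-batch activation scale s_a and zero point z_a and a per-tensor weight scale s_w, each output element is s_a * s_w * (sum_i q_a[i]*q_w[i] - z_a * sum_i q_w[i]); the cached weight sums are what the kernel's row_sums scratch tensor holds, and the legacy symmetric scheme is simply the z_a == 0 special case, which is why one kernel can serve both modes. A minimal self-contained sketch of that identity (names are illustrative, not TFLite API):

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Sketch: one output element of a hybrid (float activations, int8 weights)
    // matmul. The legacy symmetric scheme is the zp_a == 0 special case.
    float HybridDotSketch(const std::vector<int8_t>& q_a,  // quantized activations
                          const std::vector<int8_t>& q_w,  // one weight column
                          float scale_a, int32_t zp_a,     // per-batch input params
                          float scale_w) {                 // per-tensor weight scale
      int32_t acc = 0;
      int32_t col_sum = 0;  // cached once per weight column in the real kernel
      for (std::size_t i = 0; i < q_a.size(); ++i) {
        acc += static_cast<int32_t>(q_a[i]) * q_w[i];
        col_sum += q_w[i];
      }
      // Undo the activation zero-point shift, then rescale to float.
      return scale_a * scale_w * static_cast<float>(acc - zp_a * col_sum);
    }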
RELEASE.md
@@ -58,10 +58,11 @@
     directly.
 * 16 bits quantization
   * Added int16x8 support for ABS, REDUCE_MAX and REDUCE_MIN operators.
-* Added support for saved model's session initializer through
+* Added support for saved model's session initializer through
   `TFLiteConverter.from_saved_model`.
-* Added dynamic range quantization support for the BatchMatMul op.
 * Added DEPTH_TO_SPACE support in post-training quantization.
+* Added dynamic range quantization support for the BatchMatMul op.
+  * Both symmetric and asymmetric quantized input tensors are supported.
 * Add `RFFT2D` as builtin op. (`RFFT2D` also supports `RFFTD`.) Currently
   only supports float32 input.
 * TFLite supports SignatureDef:
tensorflow/lite/core/api/flatbuffer_conversions.cc
@@ -765,6 +765,8 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
           op->builtin_options_as_BatchMatMulOptions()) {
     params->adj_x = bmm_params->adj_x();
     params->adj_y = bmm_params->adj_y();
+    params->asymmetric_quantize_inputs =
+        bmm_params->asymmetric_quantize_inputs();
   }
   *builtin_data = params.release();
   return kTfLiteOk;
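A note on backward compatibility (an inference, not stated in the diff): FlatBuffers boolean fields default to false, so models serialized before `asymmetric_quantize_inputs` existed in `BatchMatMulOptions` deserialize with the flag unset and fall through to the legacy symmetric path. The params struct this hunk populates, reconstructed here from the fields the diff touches (see `tensorflow/lite/c/builtin_op_data.h` for the authoritative definition):

    // Reconstructed for reference from the fields used above.
    typedef struct {
      bool adj_x;
      bool adj_y;
      // False (symmetric) for any model whose serialized options predate this
      // field, because FlatBuffers booleans default to false.
      bool asymmetric_quantize_inputs;
    } TfLiteBatchMatMulParams;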
tensorflow/lite/kernels/batch_matmul.cc
@@ -450,6 +450,8 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node, OpData* data,
                         TfLiteTensor* scaling_factors,
                         TfLiteTensor* accum_scratch, TfLiteTensor* row_sums,
                         TfLiteTensor* input_offsets, TfLiteTensor* output) {
+  const auto* params =
+      reinterpret_cast<TfLiteBatchMatMulParams*>(node->builtin_data);
   const int32_t num_input_dims = input_shape.DimensionsCount();

   // Input row/cols have been swapped at this point, so dims are
@@ -465,18 +467,20 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node, OpData* data,
   float* scaling_factors_ptr = GetTensorData<float>(scaling_factors);
   int32_t* input_offset_ptr = nullptr;
   int32_t* row_sums_ptr = nullptr;
-  // Only asymmetric quantization is supported.
   input_offset_ptr = GetTensorData<int32_t>(input_offsets);
   row_sums_ptr = GetTensorData<int32_t>(row_sums);
+  if (!params->asymmetric_quantize_inputs) {
+    memset(input_offset_ptr, 0, input_offsets->bytes);
+  }
   int8_t* quant_data = GetTensorData<int8_t>(input_quantized);
   const int8_t* filter_data = GetTensorData<int8_t>(filter);
   const float* input_ptr = GetTensorData<float>(input);
   // Quantize each batch independently.
+  tensor_utils::BatchQuantizeFloats(input_ptr, num_batches_to_quantize,
+                                    input_size, quant_data, scaling_factors_ptr,
+                                    input_offset_ptr,
+                                    params->asymmetric_quantize_inputs);
   for (int b = 0; b < num_batches_to_quantize; ++b) {
-    const int offset = b * input_size;
-    tensor_utils::AsymmetricQuantizeFloats(
-        input_ptr + offset, input_size, quant_data + offset,
-        &scaling_factors_ptr[b], &input_offset_ptr[b]);
     // Incorporate scaling of the filter.
     scaling_factors_ptr[b] *= filter->params.scale;
   }
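What `BatchQuantizeFloats` does with its final boolean is the crux of the change. A self-contained sketch of the assumed behavior (helper name, rounding, and clamping details are illustrative; the real routine lives in `tensor_utils` and has SIMD paths): symmetric mode derives one scale per batch from max|x| and leaves the zero points untouched, which is why the kernel zeroes `input_offsets` up front; asymmetric mode fits [min, max] onto the full int8 range and records a per-batch zero point.

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // Hypothetical stand-in for tensor_utils::BatchQuantizeFloats, matching
    // the call above: per-batch scales always, per-batch zero points only
    // when do_asymmetric is true.
    void BatchQuantizeFloatsSketch(const float* input, int n_batch, int n_data,
                                   int8_t* quantized, float* scaling_factors,
                                   int32_t* zero_points, bool do_asymmetric) {
      for (int b = 0; b < n_batch; ++b) {
        const float* batch = input + b * n_data;
        int8_t* out = quantized + b * n_data;
        const auto [min_it, max_it] = std::minmax_element(batch, batch + n_data);
        if (do_asymmetric) {
          // Fit [min, max] onto the full int8 range and record a zero point.
          const float range = std::max(*max_it - *min_it, 1e-6f);
          const float scale = range / 255.0f;
          const int32_t zp =
              static_cast<int32_t>(std::round(-128.0f - *min_it / scale));
          for (int i = 0; i < n_data; ++i) {
            const int32_t q =
                static_cast<int32_t>(std::round(batch[i] / scale)) + zp;
            out[i] = static_cast<int8_t>(std::clamp(q, -128, 127));
          }
          scaling_factors[b] = scale;
          zero_points[b] = zp;
        } else {
          // Legacy symmetric: scale from max |x|, zero point implicitly 0.
          const float max_abs = std::max(std::fabs(*min_it), std::fabs(*max_it));
          const float scale = max_abs > 0.0f ? max_abs / 127.0f : 1.0f;
          for (int i = 0; i < n_data; ++i) {
            const int32_t q = static_cast<int32_t>(std::round(batch[i] / scale));
            out[i] = static_cast<int8_t>(std::clamp(q, -127, 127));
          }
          scaling_factors[b] = scale;
          // zero_points[b] deliberately untouched; the caller zeroed them.
        }
      }
    }

The loop that survives in the kernel then only folds `filter->params.scale` into each batch's scaling factor.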
tensorflow/lite/kernels/batch_matmul_test.cc
@@ -281,12 +281,12 @@ INSTANTIATE_TEST_SUITE_P(

 // In the hybrid model the weights are quantized int8. But the input
 // and output are expected to be in float precision.
-class HybridAsymmetricBatchMatMulOpModel : public SingleOpModel {
+class HybridBatchMatMulOpModel : public SingleOpModel {
  public:
-  HybridAsymmetricBatchMatMulOpModel(
-      int units, int batches, const TensorData& lhs, const TensorData& rhs,
-      const TensorData& output = {TensorType_FLOAT32}, bool adj_x = false,
-      bool adj_y = false)
+  HybridBatchMatMulOpModel(int units, int batches, const TensorData& lhs,
+                           const TensorData& rhs,
+                           const TensorData& output = {TensorType_FLOAT32},
+                           bool asymmetric_quantize_inputs = true)
       : units_(units), batches_(batches) {
     int total_input_size = 1;
     for (size_t i = 0; i < lhs.shape.size(); ++i) {
@@ -299,9 +299,11 @@ class HybridAsymmetricBatchMatMulOpModel : public SingleOpModel {

     output_id_ = AddOutput(output);

-    SetBuiltinOp(BuiltinOperator_BATCH_MATMUL,
-                 BuiltinOptions_BatchMatMulOptions,
-                 CreateBatchMatMulOptions(builder_, adj_x, adj_y).Union());
+    SetBuiltinOp(
+        BuiltinOperator_BATCH_MATMUL, BuiltinOptions_BatchMatMulOptions,
+        CreateBatchMatMulOptions(builder_, /*adj_x=*/false, /*adj_y=*/false,
+                                 asymmetric_quantize_inputs)
+            .Union());
     BuildInterpreter({GetShape(lhs_id_), GetShape(rhs_id_)});
   }
   void SetWeights(const std::vector<float>& data) {
@@ -340,7 +342,7 @@ class HybridAsymmetricBatchMatMulOpTest : public SingleOpTest {
 };

 TEST_P(HybridAsymmetricBatchMatMulOpTest, SimpleTestQuantizedInt8) {
-  HybridAsymmetricBatchMatMulOpModel m(
+  HybridBatchMatMulOpModel m(
       /*units=*/3, /*batches=*/2,
       /*lhs=*/{TensorType_FLOAT32, {2, 10}},
       /*rhs=*/{TensorType_INT8, {10, 3}, 0, 0, 10.0 / 127.0, 0});
@@ -371,7 +373,7 @@ TEST_P(HybridAsymmetricBatchMatMulOpTest, SimpleTestQuantizedInt8) {
 }

 TEST_P(HybridAsymmetricBatchMatMulOpTest, QuantizedInt8BroadcastWeights) {
-  HybridAsymmetricBatchMatMulOpModel m(
+  HybridBatchMatMulOpModel m(
       /*units=*/3, /*batches=*/2,
       /*lhs=*/{TensorType_FLOAT32, {2, 2, 10}},
       /*rhs=*/{TensorType_INT8, {10, 3}, 0, 0, 10.0 / 127.0, 0});
@@ -402,7 +404,7 @@ TEST_P(HybridAsymmetricBatchMatMulOpTest, QuantizedInt8BroadcastWeights) {
 }

 TEST_P(HybridAsymmetricBatchMatMulOpTest, QuantizedInt8BroadcastBigWeights) {
-  HybridAsymmetricBatchMatMulOpModel m(
+  HybridBatchMatMulOpModel m(
       /*units=*/9, /*batches=*/2,
       /*lhs=*/{TensorType_FLOAT32, {2, 2, 10}},
       /*rhs=*/{TensorType_INT8, {10, 9}, 0, 0, 10.0 / 127.0, 0});
@@ -437,7 +439,7 @@ TEST_P(HybridAsymmetricBatchMatMulOpTest, QuantizedInt8BroadcastBigWeights) {
 }

 TEST_P(HybridAsymmetricBatchMatMulOpTest, QuantizedInt8BroadcastInputs) {
-  HybridAsymmetricBatchMatMulOpModel m(
+  HybridBatchMatMulOpModel m(
       /*units=*/3, /*batches=*/2,
       /*lhs=*/{TensorType_FLOAT32, {2, 10}},
       /*rhs=*/{TensorType_INT8, {2, 10, 3}, 0, 0, 10.0 / 127.0, 0});
@@ -470,6 +472,148 @@ INSTANTIATE_TEST_SUITE_P(
     HybridAsymmetricBatchMatMulOpTest, HybridAsymmetricBatchMatMulOpTest,
     ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMap)));

+class HybridSymmetricBatchMatMulOpTest : public SingleOpTest {
+ protected:
+  const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
+    return *kKernelMap;
+  }
+};
+
+TEST_P(HybridSymmetricBatchMatMulOpTest, SimpleTestQuantizedInt8) {
+  HybridBatchMatMulOpModel m(
+      /*units=*/3, /*batches=*/2,
+      /*lhs=*/{TensorType_FLOAT32, {2, 10}},
+      /*rhs=*/{TensorType_INT8, {10, 3}, 0, 0, 10.0 / 127.0, 0},
+      /*output=*/{TensorType_FLOAT32}, /*asymmetric_quantize_inputs=*/false);
+
+  m.SetSignedWeights({
+      1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5,
+      6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9, 10, 10, 10,
+  });
+
+  m.SetInput({
+      11, 12, 13, 14, 15, 16, 17, 18, -19, -20,  // batch 1, 0
+      11, 12, 13, 14, 15, 16, 17, -18, 19, -20,  // batch 1, 1
+  });
+
+  m.Invoke();
+
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(
+                                 {
+                                     194,
+                                     194,
+                                     194,
+                                     248,
+                                     248,
+                                     248,
+                                 },
+                                 /*max_abs_error=*/0.64f)));
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3}));
+}
+
+TEST_P(HybridSymmetricBatchMatMulOpTest, QuantizedInt8BroadcastWeights) {
+  HybridBatchMatMulOpModel m(
+      /*units=*/3, /*batches=*/2,
+      /*lhs=*/{TensorType_FLOAT32, {2, 2, 10}},
+      /*rhs=*/{TensorType_INT8, {10, 3}, 0, 0, 10.0 / 127.0, 0},
+      /*output=*/{TensorType_FLOAT32}, /*asymmetric_quantize_inputs=*/false);
+
+  m.SetSignedWeights({
+      1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5,
+      6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9, 10, 10, 10,
+  });
+
+  m.SetInput({
+      1, 2, 3, 4, 5, 6, 7, 8, -9, -10,           // batch 0, 0
+      1, 2, 3, 4, 5, 6, 7, -8, 9, -10,           // batch 0, 1
+      11, 12, 13, 14, 15, 16, 17, 18, -19, -20,  // batch 1, 0
+      11, 12, 13, 14, 15, 16, 17, -18, 19, -20,  // batch 1, 1
+  });
+
+  m.Invoke();
+
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(
+                                 {
+                                     24, 24, 24,     //
+                                     56, 56, 56,     //
+                                     194, 194, 194,  //
+                                     248, 248, 248,  //
+                                 },
+                                 /*max_abs_error=*/1.3f)));
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 3}));
+}
+
+TEST_P(HybridSymmetricBatchMatMulOpTest, QuantizedInt8BroadcastBigWeights) {
+  HybridBatchMatMulOpModel m(
+      /*units=*/9, /*batches=*/2,
+      /*lhs=*/{TensorType_FLOAT32, {2, 2, 10}},
+      /*rhs=*/{TensorType_INT8, {10, 9}, 0, 0, 10.0 / 127.0, 0},
+      {TensorType_FLOAT32}, false);
+
+  m.SetSignedWeights({
+      1, 1, 1, 17, 17, 17, 26, 26, 26, 2,  2,  2,  18, 18, 18, 27, 27, 27,
+      3, 3, 3, 19, 19, 19, 28, 28, 28, 4,  4,  4,  20, 20, 20, 29, 29, 29,
+      5, 5, 5, 21, 21, 21, 30, 30, 30, 6,  6,  6,  22, 22, 22, 31, 31, 31,
+      7, 7, 7, 23, 23, 23, 32, 32, 32, 8,  8,  8,  24, 24, 24, 33, 33, 33,
+      9, 9, 9, 25, 25, 25, 34, 34, 34, 10, 10, 10, 26, 26, 26, 35, 35, 35,
+  });
+
+  m.SetInput({
+      1, 2, 3, 4, 5, 6, 7, 8, -9, -10,           // batch 0, 0
+      1, 2, 3, 4, 5, 6, 7, -8, 9, -10,           // batch 0, 1
+      11, 12, 13, 14, 15, 16, 17, 18, -19, -20,  // batch 1, 0
+      11, 12, 13, 14, 15, 16, 17, -18, 19, -20,  // batch 1, 1
+  });
+
+  m.Invoke();
+
+  EXPECT_THAT(m.GetOutput(),
+              ElementsAreArray(ArrayFloatNear(
+                  {
+                      23, 23, 23, 296, 296, 296, 451, 451, 451,           //
+                      58, 58, 58, 362, 362, 362, 529, 529, 529,           //
+                      193, 193, 193, 1424, 1424, 1424, 2118, 2118, 2118,  //
+                      253, 253, 253, 1519, 1519, 1519, 2223, 2223, 2223   //
+                  },
+                  /*max_abs_error=*/1.3f)));
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 9}));
+}
+
+TEST_P(HybridSymmetricBatchMatMulOpTest, QuantizedInt8BroadcastInputs) {
+  HybridBatchMatMulOpModel m(
+      /*units=*/3, /*batches=*/2,
+      /*lhs=*/{TensorType_FLOAT32, {2, 10}},
+      /*rhs=*/{TensorType_INT8, {2, 10, 3}, 0, 0, 10.0 / 127.0, 0},
+      {TensorType_FLOAT32}, false);
+
+  m.SetSignedWeights({
+      1, -3, 1, 2, -2, 2, 3, -1, 3, 4,  0, 4,  5, 1, 5, 6, 2, 6, 7, 3,
+      7, 8,  4, 8, 9,  5, 9, 10, 6, 10, 1, 1,  1, 2, 2, 2, 3, 3, 3, 4,
+      4, 4,  5, 5, 5,  6, 6, 6,  7, 7,  7, 8,  8, 8, 9, 9, 9, 10, 10, 10,
+  });
+
+  m.SetInput({
+      1, 2, 3, 4, 5, 6, 7, 8, -9, -10,  // batch 0, 0
+      1, 2, 3, 4, 5, 6, 7, -8, 9, -10,  // batch 0, 1
+  });
+
+  m.Invoke();
+
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(
+                                 {
+                                     24, -45, 24,  //
+                                     56, -19, 56,  //
+                                     24, 24, 24,   //
+                                     56, 56, 56,   //
+                                 },
+                                 /*max_abs_error=*/0.64f)));
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 3}));
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    HybridSymmetricBatchMatMulOpTest, HybridSymmetricBatchMatMulOpTest,
+    ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMap)));
+
 class QuantizedBatchMatMulOpModel : public SingleOpModel {
  public:
   QuantizedBatchMatMulOpModel(int units, int batches, const TensorData& lhs,
tensorflow/lite/kernels/register.cc
@@ -301,7 +301,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
   AddBuiltin(BuiltinOperator_SEGMENT_SUM, Register_SEGMENT_SUM());
   AddBuiltin(BuiltinOperator_BATCH_MATMUL, Register_BATCH_MATMUL(),
              /* min_version = */ 1,
-             /* max_version = */ 3);
+             /* max_version = */ 4);
   AddBuiltin(BuiltinOperator_CUMSUM, Register_CUMSUM());
   // Version one of the broadcast-to op won't be supported since it
   // was rolled back and the builtin op code number has been changed because
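Why the bump to `max_version = 4`: this registers the kernel as capable of serving op versions 1 through 4, while the converter separately assigns each serialized op the lowest version that supports its options. A hypothetical sketch of that converter-side choice, assuming (not shown in this diff) that version 4 is what signals a hybrid BatchMatMul whose options carry `asymmetric_quantize_inputs`:

    // Hypothetical sketch of converter-side version selection; only the
    // "new option field => bumped version" rule matters here. The real
    // mapping of tensor types to versions 1-3 is elided.
    int SelectBatchMatMulVersionSketch(bool is_hybrid,
                                       bool asymmetric_quantize_inputs) {
      if (is_hybrid && asymmetric_quantize_inputs) {
        return 4;  // requires a runtime that parses the new option field
      }
      return 1;  // placeholder: existing graphs keep their current version
    }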
tensorflow/lite/tools/versioning/runtime_version.cc
@@ -61,6 +61,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
           {{BuiltinOperator_BATCH_MATMUL, 1}, "2.3.0"},
           {{BuiltinOperator_BATCH_MATMUL, 2}, "2.3.0"},
           {{BuiltinOperator_BATCH_MATMUL, 3}, "2.4.0"},
+          {{BuiltinOperator_BATCH_MATMUL, 4}, kPendingReleaseVersion},
           // Version one of the broadcast-to op won't be supported since
           // it was rolled back and the builtin op code number
           // has been changed because of the builtin op code shortage problem.