Add hybrid quantization support for BatchMatMul during conversion.
PiperOrigin-RevId: 344098609
Change-Id: Ib00f38d3bb7403fbd9d21169301c0a080e632d17
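With this change, the converter's standard dynamic range quantization path also covers BatchMatMul: constant weights are stored as int8 and the op is evaluated in hybrid mode. As a minimal usage sketch (not part of this commit; the saved-model path is a placeholder and the model is assumed to contain a batched matmul):

    import tensorflow as tf

    # Hypothetical saved model containing a BatchMatMul op (e.g. tf.matmul on
    # 3-D inputs); the path is a placeholder.
    converter = tf.lite.TFLiteConverter.from_saved_model("/tmp/batch_matmul_model")
    # Enables dynamic range quantization: weights become int8, and with the
    # updated hybrid scheme the non-constant input is asymmetrically quantized
    # at evaluation time.
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    tflite_model = converter.convert()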
@@ -49,6 +49,7 @@
   * Added int16x8 support for ABS, REDUCE_MAX and REDUCE_MIN operators.
   * Added support for saved model's session initializer through
     `TFLiteConverter.from_saved_model`.
+  * Added dynamic range quantization support for the BatchMatMul op.
 * TF Core:
   * Corrected higher-order gradients of control flow constructs (`tf.cond`,
@@ -213,6 +213,10 @@ typedef struct {
 typedef struct {
   bool adj_x;
   bool adj_y;
+  // Parameters for BatchMatMul version 4 or above.
+  // If set to true and the weights are quantized, then non constant inputs
+  // are quantized at evaluation time with asymmetric quantization.
+  bool asymmetric_quantize_inputs;
 } TfLiteBatchMatMulParams;
 
 typedef struct {
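At runtime a kernel reaches this flag through the node's builtin_data (as TfLiteBatchMatMulParams), the same way the existing adj_x/adj_y fields and the identically named flag on FullyConnected are consumed; since flatbuffer booleans default to false, models without the field keep the previous symmetric hybrid behavior.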
@@ -990,6 +990,10 @@ table SegmentSumOptions {
 table BatchMatMulOptions {
   adj_x:bool;
   adj_y:bool;
+  // Parameters for BatchMatMul version 4 or above.
+  // If set to true, then weights-only op will use asymmetric quantization for
+  // inputs.
+  asymmetric_quantize_inputs: bool;
 }
 
 table CumsumOptions {
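Since schema.fbs changed, the generated accessors used below (BatchMatMulOptions::asymmetric_quantize_inputs() and the object-API field reached via AsBatchMatMulOptions()) come from a regenerated schema_generated.h; within TensorFlow this is produced by the build, and a standalone equivalent would be roughly flatc --cpp --gen-object-api schema.fbs.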
@@ -90,6 +90,7 @@ std::vector<int32_t> GetWeightInputIndices(const OperatorCodeT* op_code,
   } else if (builtin_op_code == BuiltinOperator_CONV_2D ||
              builtin_op_code == BuiltinOperator_DEPTHWISE_CONV_2D ||
              builtin_op_code == BuiltinOperator_FULLY_CONNECTED ||
+             builtin_op_code == BuiltinOperator_BATCH_MATMUL ||
              builtin_op_code == BuiltinOperator_EMBEDDING_LOOKUP) {
     return {1};
   } else if (builtin_op_code == BuiltinOperator_SVDF) {
@@ -145,6 +146,7 @@ bool IsHybridEvaluationOp(const OperatorT* op, const OperatorCodeT* op_code,
       return custom_op_info->second.is_hybrid;
     }
   } else if (builtin_op_code == BuiltinOperator_FULLY_CONNECTED ||
+             builtin_op_code == BuiltinOperator_BATCH_MATMUL ||
              builtin_op_code == BuiltinOperator_CONV_2D ||
              builtin_op_code == BuiltinOperator_SVDF ||
              builtin_op_code == BuiltinOperator_RNN ||
@@ -255,6 +257,10 @@ TfLiteStatus InsertQuantizableInputTensorsFromOperator(
         op->builtin_options.AsFullyConnectedOptions()
             ->asymmetric_quantize_inputs = use_updated_hybrid_scheme;
         break;
+      case BuiltinOperator_BATCH_MATMUL:
+        op->builtin_options.AsBatchMatMulOptions()
+            ->asymmetric_quantize_inputs = use_updated_hybrid_scheme;
+        break;
       case BuiltinOperator_LSTM:
         op->builtin_options.AsLSTMOptions()->asymmetric_quantize_inputs =
             use_updated_hybrid_scheme;
@@ -557,8 +557,25 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
       }
       return 1;
 
-    case BuiltinOperator_CONCATENATION:
     case BuiltinOperator_BATCH_MATMUL:
+      // In case of int16 inputs, the version is 3.
+      if (op_sig.input_types.at(0) == TensorType_INT16) {
+        return 3;
+      }
+      if (op_sig.input_types.at(0) == TensorType_INT8) {
+        return 2;
+      }
+      if (op_sig.input_types.at(0) == TensorType_FLOAT32 &&
+          op_sig.input_types.at(1) == TensorType_INT8 &&
+          op_sig.output_types.at(0) == TensorType_FLOAT32) {
+        if (op_sig.options.input_quantization.asymmetric_quantize_inputs) {
+          // This is to use the updated quantization scheme.
+          return 4;
+        }
+      }
+      return 1;
+
+    case BuiltinOperator_CONCATENATION:
     case BuiltinOperator_SOFTMAX:
     case BuiltinOperator_MEAN:
     case BuiltinOperator_PAD:
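Net effect of the versioning change: BatchMatMul reports version 1 for float32 inputs (including the symmetric hybrid case), 2 for int8, 3 for int16, and 4 for hybrid evaluation with asymmetrically quantized inputs, i.e. the updated quantization scheme. The tests further below exercise each case.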
@@ -792,6 +809,12 @@ OpSignature GetOpSignature(const OperatorCode* op_code, const Operator* op,
           std::max(GetNumDims(subgraph, op, 0), GetNumDims(subgraph, op, 1));
     } break;
 
+    case BuiltinOperator_BATCH_MATMUL: {
+      auto batch_matmul_option = op->builtin_options_as_BatchMatMulOptions();
+      op_sig.options.input_quantization.asymmetric_quantize_inputs =
+          batch_matmul_option->asymmetric_quantize_inputs();
+    } break;
+
     default:
       break;
   }
@@ -18,6 +18,7 @@ limitations under the License.
 
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
+#include "tensorflow/lite/schema/schema_generated.h"
 
 namespace tflite {
 
@@ -779,4 +780,49 @@ TEST(OpVersionTest, VersioningAbsTest) {
   };
   EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
 }
+
+TEST(OpVersionTest, VersioningBatchMatMulTest) {
+  // Default.
+  OpSignature fake_op_sig = {
+      .op = BuiltinOperator_BATCH_MATMUL,
+      .input_types =
+          std::vector<TensorType>{TensorType_FLOAT32, TensorType_FLOAT32},
+      .output_types = std::vector<TensorType>{TensorType_FLOAT32},
+  };
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
+
+  // int8 input is version 2.
+  fake_op_sig = {
+      .op = BuiltinOperator_BATCH_MATMUL,
+      .input_types = std::vector<TensorType>{TensorType_INT8, TensorType_INT8},
+      .output_types = std::vector<TensorType>{TensorType_INT8},
+  };
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
+
+  // int16 input is version 3.
+  fake_op_sig = {
+      .op = BuiltinOperator_BATCH_MATMUL,
+      .input_types = std::vector<TensorType>{TensorType_INT16, TensorType_INT8},
+      .output_types = std::vector<TensorType>{TensorType_INT16},
+  };
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
+
+  // Symmetric hybrid quantized input is version 1.
+  fake_op_sig = {
+      .op = BuiltinOperator_BATCH_MATMUL,
+      .input_types =
+          std::vector<TensorType>{TensorType_FLOAT32, TensorType_INT8},
+      .output_types = std::vector<TensorType>{TensorType_FLOAT32},
+  };
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
+
+  // Asymmetric hybrid quantized input is version 4.
+  fake_op_sig = {
+      .op = BuiltinOperator_BATCH_MATMUL,
+      .input_types =
+          std::vector<TensorType>{TensorType_FLOAT32, TensorType_INT8},
+      .output_types = std::vector<TensorType>{TensorType_FLOAT32},
+  };
+  fake_op_sig.options.input_quantization.asymmetric_quantize_inputs = true;
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4);
+}
 }  // namespace tflite
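Assuming the usual Bazel target layout for this file (op_version_test.cc under tensorflow/lite/tools/versioning), the new test can be run with something like bazel test //tensorflow/lite/tools/versioning:op_version_test --test_filter=OpVersionTest.VersioningBatchMatMulTest.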