Fix for the broken 16-bit interface after latest changes to master.
commit 69ee4de053 (parent 0d95600271)
@@ -234,11 +234,10 @@ PyObject* CalibrationWrapper::QuantizeModel(int input_py_type,
   flatbuffers::FlatBufferBuilder builder;
   auto status = kTfLiteOk;

-    status = tflite::optimize::QuantizeModel(
+    status = tflite::optimize::QuantizeModelAllOperators(
         &builder, tflite_model.get(), TfLiteTypeToSchemaType(input_type),
         TfLiteTypeToSchemaType(output_type), allow_float,
         TfLiteTypeToSchemaType(activations_type), error_reporter_.get());
   }

   if (status != kTfLiteOk) {
     error_reporter_->exception();
@@ -269,8 +268,7 @@ PyObject* CalibrationWrapper::QuantizeModel(int input_py_type,
   auto status = tflite::optimize::QuantizeModel(
       &builder, tflite_model.get(), TfLiteTypeToSchemaType(input_type),
       TfLiteTypeToSchemaType(output_type), allow_float, {op_name},
-      TensorType_INT8,
-      error_reporter_.get());
+      TensorType_INT8, error_reporter_.get());
   if (status != kTfLiteOk) {
     error_reporter_->exception();
     return nullptr;
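The two wrapper hunks above amount to the following dispatch. This is only an illustrative sketch of the two call shapes, not code from this commit; the quantize_whole_model flag is a hypothetical stand-in for the wrapper's existing branching.

// Sketch only: whole-model quantization now goes through
// QuantizeModelAllOperators and forwards the caller-selected activations type
// (int8 or int16), while quantizing a single named operator keeps int8
// activations.
TfLiteStatus status;
if (quantize_whole_model) {  // hypothetical flag, stands in for the wrapper's branch
  status = tflite::optimize::QuantizeModelAllOperators(
      &builder, tflite_model.get(), TfLiteTypeToSchemaType(input_type),
      TfLiteTypeToSchemaType(output_type), allow_float,
      TfLiteTypeToSchemaType(activations_type), error_reporter_.get());
} else {
  status = tflite::optimize::QuantizeModel(
      &builder, tflite_model.get(), TfLiteTypeToSchemaType(input_type),
      TfLiteTypeToSchemaType(output_type), allow_float, {op_name},
      TensorType_INT8, error_reporter_.get());
}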
@@ -1240,9 +1240,11 @@ TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
   return kTfLiteOk;
 }

-TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
-                           ModelT* model, const TensorType& input_type,
-                           const TensorType& output_type, bool allow_float,
+TfLiteStatus QuantizeModelAllOperators(flatbuffers::FlatBufferBuilder* builder,
+                                       ModelT* model,
+                                       const TensorType& input_type,
+                                       const TensorType& output_type,
+                                       bool allow_float,
                                        const TensorType& activations_type,
                                        ErrorReporter* error_reporter) {
   return QuantizeModel(builder, model, input_type, output_type, allow_float,
@@ -69,9 +69,11 @@ TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
 // could be TensorType_INT16 or TensorType_INT8.
 //
 // Note: This is a private API, subject to change.
-TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
-                           ModelT* model, const TensorType& input_type,
-                           const TensorType& output_type, bool allow_float,
+TfLiteStatus QuantizeModelAllOperators(flatbuffers::FlatBufferBuilder* builder,
+                                       ModelT* model,
+                                       const TensorType& input_type,
+                                       const TensorType& output_type,
+                                       bool allow_float,
                                        const TensorType& activations_type,
                                        ErrorReporter* error_reporter);

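For reference, a minimal sketch of how a caller might drive the renamed entry point for the 16x8 mode (mirroring the QuantizeOp16x8Test change below). The helper name, include paths, and the use of StderrReporter are assumptions for illustration and are not part of this commit.

#include "flatbuffers/flatbuffers.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/stderr_reporter.h"
#include "tensorflow/lite/tools/optimize/quantize_model.h"

// Sketch only: quantize every operator with int16 activations and int8
// weights (the 16x8 mode), float32 output, and float fallback allowed.
// On success the quantized model is serialized into `builder`, as with the
// existing QuantizeModel overloads.
TfLiteStatus QuantizeFor16x8(tflite::ModelT* model) {
  flatbuffers::FlatBufferBuilder builder;
  tflite::StderrReporter error_reporter;
  return tflite::optimize::QuantizeModelAllOperators(
      &builder, model, tflite::TensorType_INT16, tflite::TensorType_FLOAT32,
      /*allow_float=*/true, /*activations_type=*/tflite::TensorType_INT16,
      &error_reporter);
}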
@@ -190,8 +190,8 @@ TEST_P(QuantizeConvModelTest, OperatorsAreUnchanged) {
 }

 TEST_P(QuantizeConvModelTest, GraphIsFullyQuantized) {
-  auto status =
-      QuantizeModel(&builder_, &model_, tensor_type_, tensor_type_,
+  auto status = QuantizeModelAllOperators(
+      &builder_, &model_, tensor_type_, tensor_type_,
       /*allow_float*/ false, tensor_type_, &error_reporter_);
   EXPECT_EQ(status, kTfLiteOk);
   for (const auto& subgraph : model_.subgraphs) {
@@ -209,8 +209,8 @@ TEST_P(QuantizeConvModelTest, GraphIsFullyQuantized) {
 }

 TEST_P(QuantizeConvModelTest, FloatInputAndOutput) {
-  auto status =
-      QuantizeModel(&builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32,
+  auto status = QuantizeModelAllOperators(
+      &builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32,
       /*allow_float*/ false, tensor_type_, &error_reporter_);
   EXPECT_EQ(status, kTfLiteOk);

@@ -384,7 +384,8 @@ class QuantizeConcatModelTest : public QuantizeModelTest,
 //  concat - output
 //  input1 /
 TEST_P(QuantizeConcatModelTest, AddRequantBeforeConcat) {
-  auto status = QuantizeModel(&builder_, &model_, tensor_type_, tensor_type_,
+  auto status =
+      QuantizeModelAllOperators(&builder_, &model_, tensor_type_, tensor_type_,
                                 false, tensor_type_, &error_reporter_);
   EXPECT_EQ(status, kTfLiteOk);

@@ -549,8 +550,8 @@ class QuantizeConvModel1Test : public QuantizeModelTest {
 };

 TEST_F(QuantizeConvModel1Test, VerifyConvQuantizationWithUnitScale) {
-  auto status =
-      QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8, false,
+  auto status = QuantizeModelAllOperators(&builder_, &model_, TensorType_INT8,
+                                          TensorType_INT8, false,
                                           TensorType_INT8, &error_reporter_);
   EXPECT_EQ(status, kTfLiteOk);
   const auto& subgraph = model_.subgraphs[0];
@@ -658,7 +659,8 @@ INSTANTIATE_TEST_SUITE_P(QuantizeConvModel2TestInst, QuantizeConvModel2Test,
                          TensorType_INT16}));

 TEST_P(QuantizeConvModel2Test, VerifyConvQuantization) {
-  auto status = QuantizeModel(&builder_, &model_, tensor_type_, tensor_type_,
+  auto status =
+      QuantizeModelAllOperators(&builder_, &model_, tensor_type_, tensor_type_,
                                 false, tensor_type_, &error_reporter_);
   ASSERT_EQ(kTfLiteOk, status);
   const auto& subgraph = model_.subgraphs[0];
@@ -1308,7 +1310,8 @@ TEST_F(QuantizeFCTest, VerifyFC) {
   EXPECT_EQ(model_.operator_codes[1]->version, 1);
 }

-class QuantizeCustomOpTest : public QuantizeModelTest,
+class QuantizeCustomOpTest
+    : public QuantizeModelTest,
       public ::testing::WithParamInterface<tflite::TensorType> {
  protected:
   QuantizeCustomOpTest() {
@@ -1319,8 +1322,8 @@ class QuantizeCustomOpTest : public QuantizeModelTest,
 };

 TEST_P(QuantizeCustomOpTest, VerifyMixedQuantization) {
-  auto status =
-      QuantizeModel(&builder_, &model_, GetParam(), GetParam(),
+  auto status = QuantizeModelAllOperators(
+      &builder_, &model_, GetParam(), GetParam(),
       /*allow_float=*/true, GetParam(), &error_reporter_);
   ASSERT_EQ(kTfLiteOk, status);
   const auto& subgraph = model_.subgraphs[0];
@@ -1358,8 +1361,8 @@ class QuantizeOp16x8Test : public QuantizeModelTest {
 };

 TEST_F(QuantizeOp16x8Test, VerifyMixedQuantization16x8) {
-  auto status =
-      QuantizeModel(&builder_, &model_, TensorType_INT16, TensorType_FLOAT32,
+  auto status = QuantizeModelAllOperators(
+      &builder_, &model_, TensorType_INT16, TensorType_FLOAT32,
       /*allow_float=*/true, TensorType_INT16, &error_reporter_);
   ASSERT_EQ(kTfLiteOk, status);
   const auto& subgraph = model_.subgraphs[0];
@@ -1369,8 +1372,8 @@ TEST_F(QuantizeOp16x8Test, VerifyMixedQuantization16x8) {
   // The resulting model should be:
   // conv_2d->dequantize->log_softmax
   ASSERT_EQ(subgraph->operators.size(), 3);
-  const std::vector<BuiltinOperator> op_codes = {
-      BuiltinOperator_CONV_2D, BuiltinOperator_DEQUANTIZE,
+  const std::vector<BuiltinOperator> op_codes = {BuiltinOperator_CONV_2D,
+                                                 BuiltinOperator_DEQUANTIZE,
                                                  BuiltinOperator_LOG_SOFTMAX};
   const std::vector<TensorType> op_input_types = {
       TensorType_INT16, TensorType_INT16, TensorType_FLOAT32};