Fix for the broken 16-bit interface after latest changes to master.

Elena Zhelezina 2020-03-27 15:59:37 +00:00
parent 0d95600271
commit 69ee4de053
4 changed files with 49 additions and 44 deletions
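
The whole-model overload of tflite::optimize::QuantizeModel is renamed to QuantizeModelAllOperators, the Python calibration wrapper drops a stray closing brace and calls the renamed entry point, and every call site in the unit tests is updated to match. Behaviour of the 16x8 path (16-bit activations with 8-bit weights) is unchanged: the renamed function still forwards to QuantizeModel with GetAllOperatorOutputs(model).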

tensorflow/lite/python/optimize/calibration_wrapper.cc

@@ -233,12 +233,11 @@ PyObject* CalibrationWrapper::QuantizeModel(int input_py_type,
   reader_->AddCalibrationToModel(tflite_model.get(), /*update=*/false);
   flatbuffers::FlatBufferBuilder builder;
   auto status = kTfLiteOk;
 
-  status = tflite::optimize::QuantizeModel(
-      &builder, tflite_model.get(), TfLiteTypeToSchemaType(input_type),
-      TfLiteTypeToSchemaType(output_type), allow_float,
-      TfLiteTypeToSchemaType(activations_type), error_reporter_.get());
-  }
+  status = tflite::optimize::QuantizeModelAllOperators(
+      &builder, tflite_model.get(), TfLiteTypeToSchemaType(input_type),
+      TfLiteTypeToSchemaType(output_type), allow_float,
+      TfLiteTypeToSchemaType(activations_type), error_reporter_.get());
 
   if (status != kTfLiteOk) {
     error_reporter_->exception();
@@ -269,8 +268,7 @@ PyObject* CalibrationWrapper::QuantizeModel(int input_py_type,
   auto status = tflite::optimize::QuantizeModel(
       &builder, tflite_model.get(), TfLiteTypeToSchemaType(input_type),
       TfLiteTypeToSchemaType(output_type), allow_float, {op_name},
-      TensorType_INT8,
-      error_reporter_.get());
+      TensorType_INT8, error_reporter_.get());
   if (status != kTfLiteOk) {
     error_reporter_->exception();
     return nullptr;
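
Both hunks above sit in the calibration wrapper: the first switches the whole-model quantization path to QuantizeModelAllOperators and removes the closing brace left behind by the merge, while the second merely re-wraps the argument list of the per-operator QuantizeModel call.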

tensorflow/lite/tools/optimize/quantize_model.cc

@@ -1240,11 +1240,13 @@ TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
   return kTfLiteOk;
 }
 
-TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
-                           ModelT* model, const TensorType& input_type,
-                           const TensorType& output_type, bool allow_float,
-                           const TensorType& activations_type,
-                           ErrorReporter* error_reporter) {
+TfLiteStatus QuantizeModelAllOperators(flatbuffers::FlatBufferBuilder* builder,
+                                       ModelT* model,
+                                       const TensorType& input_type,
+                                       const TensorType& output_type,
+                                       bool allow_float,
+                                       const TensorType& activations_type,
+                                       ErrorReporter* error_reporter) {
   return QuantizeModel(builder, model, input_type, output_type, allow_float,
                        GetAllOperatorOutputs(model), activations_type,
                        error_reporter);
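
As the body shows, the rename is behaviour-preserving: QuantizeModelAllOperators is a thin wrapper that forwards to the general QuantizeModel overload with GetAllOperatorOutputs(model) as the set of operator outputs to quantize.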

tensorflow/lite/tools/optimize/quantize_model.h

@@ -69,11 +69,13 @@ TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
 // could be TensorType_INT16 or TensorType_INT8.
 //
 // Note: This is a private API, subject to change.
-TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
-                           ModelT* model, const TensorType& input_type,
-                           const TensorType& output_type, bool allow_float,
-                           const TensorType& activations_type,
-                           ErrorReporter* error_reporter);
+TfLiteStatus QuantizeModelAllOperators(flatbuffers::FlatBufferBuilder* builder,
+                                       ModelT* model,
+                                       const TensorType& input_type,
+                                       const TensorType& output_type,
+                                       bool allow_float,
+                                       const TensorType& activations_type,
+                                       ErrorReporter* error_reporter);
 
 // Quantizes input_model and populates the provided builder with the new model
 // with all possible input parameters.
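
For reference, a minimal sketch of a 16x8 call site through the renamed entry point, modeled on the QuantizeOp16x8Test updated below; the helper name QuantizeTo16x8 and the StderrReporter setup are illustrative assumptions, not part of this commit:

    #include "flatbuffers/flatbuffers.h"
    #include "tensorflow/lite/stderr_reporter.h"
    #include "tensorflow/lite/tools/optimize/quantize_model.h"

    // Illustrative helper (not part of this commit): quantize all operators
    // to 16-bit activations with float32 model input/output, letting
    // unsupported operators stay in float (allow_float = true).
    TfLiteStatus QuantizeTo16x8(tflite::ModelT* model,
                                flatbuffers::FlatBufferBuilder* builder) {
      tflite::StderrReporter reporter;
      return tflite::optimize::QuantizeModelAllOperators(
          builder, model, tflite::TensorType_INT16, tflite::TensorType_FLOAT32,
          /*allow_float=*/true, tflite::TensorType_INT16, &reporter);
    }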

tensorflow/lite/tools/optimize/quantize_model_test.cc

@@ -190,9 +190,9 @@ TEST_P(QuantizeConvModelTest, OperatorsAreUnchanged) {
 }
 
 TEST_P(QuantizeConvModelTest, GraphIsFullyQuantized) {
-  auto status =
-      QuantizeModel(&builder_, &model_, tensor_type_, tensor_type_,
-                    /*allow_float*/ false, tensor_type_, &error_reporter_);
+  auto status = QuantizeModelAllOperators(
+      &builder_, &model_, tensor_type_, tensor_type_,
+      /*allow_float*/ false, tensor_type_, &error_reporter_);
   EXPECT_EQ(status, kTfLiteOk);
   for (const auto& subgraph : model_.subgraphs) {
     for (const auto& tensor : subgraph->tensors) {
@@ -209,9 +209,9 @@ TEST_P(QuantizeConvModelTest, GraphIsFullyQuantized) {
 }
 
 TEST_P(QuantizeConvModelTest, FloatInputAndOutput) {
-  auto status =
-      QuantizeModel(&builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32,
-                    /*allow_float*/ false, tensor_type_, &error_reporter_);
+  auto status = QuantizeModelAllOperators(
+      &builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32,
+      /*allow_float*/ false, tensor_type_, &error_reporter_);
   EXPECT_EQ(status, kTfLiteOk);
 
   for (int32_t subgraph_idx = 0; subgraph_idx < model_.subgraphs.size();
@@ -384,8 +384,9 @@ class QuantizeConcatModelTest : public QuantizeModelTest,
 //          concat - output
 // input1 /
 TEST_P(QuantizeConcatModelTest, AddRequantBeforeConcat) {
-  auto status = QuantizeModel(&builder_, &model_, tensor_type_, tensor_type_,
-                              false, tensor_type_, &error_reporter_);
+  auto status =
+      QuantizeModelAllOperators(&builder_, &model_, tensor_type_, tensor_type_,
+                                false, tensor_type_, &error_reporter_);
   EXPECT_EQ(status, kTfLiteOk);
 
   // There is only one subgraph.
@@ -549,9 +550,9 @@ class QuantizeConvModel1Test : public QuantizeModelTest {
 };
 
 TEST_F(QuantizeConvModel1Test, VerifyConvQuantizationWithUnitScale) {
-  auto status =
-      QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8, false,
-                    TensorType_INT8, &error_reporter_);
+  auto status = QuantizeModelAllOperators(&builder_, &model_, TensorType_INT8,
+                                          TensorType_INT8, false,
+                                          TensorType_INT8, &error_reporter_);
   EXPECT_EQ(status, kTfLiteOk);
 
   const auto& subgraph = model_.subgraphs[0];
@@ -658,8 +659,9 @@ INSTANTIATE_TEST_SUITE_P(QuantizeConvModel2TestInst, QuantizeConvModel2Test,
                          TensorType_INT16}));
 
 TEST_P(QuantizeConvModel2Test, VerifyConvQuantization) {
-  auto status = QuantizeModel(&builder_, &model_, tensor_type_, tensor_type_,
-                              false, tensor_type_, &error_reporter_);
+  auto status =
+      QuantizeModelAllOperators(&builder_, &model_, tensor_type_, tensor_type_,
+                                false, tensor_type_, &error_reporter_);
   ASSERT_EQ(kTfLiteOk, status);
   const auto& subgraph = model_.subgraphs[0];
   auto conv_op = subgraph->operators[0].get();
@@ -1308,8 +1310,9 @@ TEST_F(QuantizeFCTest, VerifyFC) {
   EXPECT_EQ(model_.operator_codes[1]->version, 1);
 }
 
-class QuantizeCustomOpTest : public QuantizeModelTest,
-                             public ::testing::WithParamInterface<tflite::TensorType> {
+class QuantizeCustomOpTest
+    : public QuantizeModelTest,
+      public ::testing::WithParamInterface<tflite::TensorType> {
  protected:
   QuantizeCustomOpTest() {
     input_model_ = ReadModel(internal::kModelMixed);
@@ -1319,9 +1322,9 @@ class QuantizeCustomOpTest : public QuantizeModelTest,
 };
 
 TEST_P(QuantizeCustomOpTest, VerifyMixedQuantization) {
-  auto status =
-      QuantizeModel(&builder_, &model_, GetParam(), GetParam(),
-                    /*allow_float=*/true, GetParam(), &error_reporter_);
+  auto status = QuantizeModelAllOperators(
+      &builder_, &model_, GetParam(), GetParam(),
+      /*allow_float=*/true, GetParam(), &error_reporter_);
   ASSERT_EQ(kTfLiteOk, status);
   const auto& subgraph = model_.subgraphs[0];
   auto float_graph = readonly_model_->subgraphs()->Get(0);
@@ -1335,7 +1338,7 @@ TEST_P(QuantizeCustomOpTest, VerifyMixedQuantization) {
       BuiltinOperator_CUSTOM,   BuiltinOperator_CUSTOM,
       BuiltinOperator_QUANTIZE, BuiltinOperator_SQUEEZE};
   const std::vector<TensorType> op_input_types = {
-      GetParam(), GetParam(), TensorType_FLOAT32,
+      GetParam(),         GetParam(),         TensorType_FLOAT32,
       TensorType_FLOAT32, TensorType_FLOAT32, GetParam()};
   for (int i = 0; i < subgraph->operators.size(); ++i) {
     OperatorT* op = subgraph->operators[i].get();
@@ -1358,9 +1361,9 @@ class QuantizeOp16x8Test : public QuantizeModelTest {
 };
 
 TEST_F(QuantizeOp16x8Test, VerifyMixedQuantization16x8) {
-  auto status =
-      QuantizeModel(&builder_, &model_, TensorType_INT16, TensorType_FLOAT32,
-                    /*allow_float=*/true, TensorType_INT16, &error_reporter_);
+  auto status = QuantizeModelAllOperators(
+      &builder_, &model_, TensorType_INT16, TensorType_FLOAT32,
+      /*allow_float=*/true, TensorType_INT16, &error_reporter_);
   ASSERT_EQ(kTfLiteOk, status);
   const auto& subgraph = model_.subgraphs[0];
   auto float_graph = readonly_model_->subgraphs()->Get(0);
@@ -1369,11 +1372,11 @@ TEST_F(QuantizeOp16x8Test, VerifyMixedQuantization16x8) {
   // The resulting model should be:
   // conv_2d->dequantize->log_softmax
   ASSERT_EQ(subgraph->operators.size(), 3);
-  const std::vector<BuiltinOperator> op_codes = {
-      BuiltinOperator_CONV_2D, BuiltinOperator_DEQUANTIZE,
-      BuiltinOperator_LOG_SOFTMAX};
+  const std::vector<BuiltinOperator> op_codes = {BuiltinOperator_CONV_2D,
+                                                 BuiltinOperator_DEQUANTIZE,
+                                                 BuiltinOperator_LOG_SOFTMAX};
   const std::vector<TensorType> op_input_types = {
-      TensorType_INT16, TensorType_INT16, TensorType_FLOAT32};
+      TensorType_INT16, TensorType_INT16, TensorType_FLOAT32};
   for (int i = 0; i < subgraph->operators.size(); ++i) {
     OperatorT* op = subgraph->operators[i].get();
     ASSERT_EQ(model_.operator_codes[op->opcode_index]->builtin_code,