Allow update for min/max values in calibrator.

PiperOrigin-RevId: 257871274
This commit is contained in:
Jian Li 2019-07-12 14:39:41 -07:00 committed by TensorFlower Gardener
parent c448af4c07
commit e60d77cdb9
6 changed files with 106 additions and 5 deletions

View File

@ -198,7 +198,7 @@ PyObject* CalibrationWrapper::QuantizeModel(int input_py_type,
return nullptr;
}
auto tflite_model = CreateMutableModel(*model_->GetModel());
reader_->AddCalibrationToModel(tflite_model.get());
reader_->AddCalibrationToModel(tflite_model.get(), /*update=*/false);
flatbuffers::FlatBufferBuilder builder;
auto status = tflite::optimize::QuantizeModel(
&builder, tflite_model.get(), TfLiteTypeToSchemaType(input_type),

View File

@ -48,6 +48,7 @@ tf_cc_test(
"//tensorflow/core:lib",
"//tensorflow/lite:framework",
"//tensorflow/lite/kernels:builtin_ops",
"@com_google_absl//absl/memory",
"@com_google_googletest//:gtest",
],
)

View File

@ -32,7 +32,8 @@ TfLiteStatus CalibrationReader::GetTensorStatsAsMap(
return kTfLiteOk;
}
TfLiteStatus CalibrationReader::AddCalibrationToModel(ModelT* model) const {
TfLiteStatus CalibrationReader::AddCalibrationToModel(ModelT* model,
bool update) const {
if (!model || model->subgraphs.empty()) {
return kTfLiteError;
}
@ -41,6 +42,15 @@ TfLiteStatus CalibrationReader::AddCalibrationToModel(ModelT* model) const {
auto minmax = tensorid_stat.second;
float min, max;
TF_LITE_ENSURE_STATUS(minmax.Get(&min, &max));
if (update) {
auto tensor = subgraph->tensors[tensorid_stat.first].get();
if (tensor->quantization) {
const float existing_min = tensor->quantization->min[0];
const float existing_max = tensor->quantization->max[0];
min = min < existing_min ? min : existing_min;
max = max > existing_max ? max : existing_max;
}
}
auto quant_params = absl::make_unique<tflite::QuantizationParametersT>();
quant_params->min.push_back(min);
quant_params->max.push_back(max);

View File

@ -42,7 +42,9 @@ class CalibrationReader {
// Annotates the tensors in the given model with statistics captured during
// calibration.
virtual TfLiteStatus AddCalibrationToModel(ModelT* model) const;
// When "update" is true, any existing min/max on a tensor is widened to
// include the calibration statistics instead of being overwritten.
virtual TfLiteStatus AddCalibrationToModel(ModelT* model, bool update) const;
virtual ~CalibrationReader() {}

View File

@ -50,7 +50,8 @@ namespace calibration {
//
// or adding calibration data to model itself.
// ModelT * original_floating_point_model = ...
// calibration_reader->AddCalibrationToModel(original_floating_point_model);
// calibration_reader->AddCalibrationToModel(original_floating_point_model,
// false);
//
TfLiteStatus BuildLoggingInterpreter(
const FlatBufferModel& model, const OpResolver& op_resolver,

View File

@ -12,16 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/tools/optimize/calibration/calibrator.h"
#include <cstring>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/memory/memory.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/tools/optimize/calibration/calibrator.h"
namespace {
tensorflow::string* g_test_model_file = nullptr;
@ -189,6 +191,91 @@ TEST(CalibratorTest, MultipleInvokes) {
EXPECT_NEAR(stats.at(6).max, 9.0f, eps);
}
// Verifies AddCalibrationToModel()'s "update" flag: with update == true,
// pre-existing quantization min/max on a tensor are merged (widened) with the
// calibration statistics; with update == false they are overwritten.
TEST(CalibratorTest, UpdateMinMax) {
  auto flatbuffer_model = ReadModel();
  ASSERT_TRUE(flatbuffer_model);
  std::unique_ptr<Interpreter> interpreter;
  std::unique_ptr<CalibrationReader> reader;
  auto status = BuildLoggingInterpreter(*flatbuffer_model,
                                        ops::builtin::BuiltinOpResolver{},
                                        &interpreter, &reader);
  EXPECT_EQ(kTfLiteOk, status);
  auto readonly_model = flatbuffer_model->GetModel();
  tflite::ModelT model;
  readonly_model->UnPackTo(&model);
  ASSERT_TRUE(interpreter);
  ASSERT_TRUE(reader);
  status = interpreter->AllocateTensors();
  EXPECT_EQ(kTfLiteOk, status);

  // Fill every element of input tensor i with the constant (i + 1), so each
  // input's observed min and max are both (i + 1).
  const size_t tensor_size = 1 * 8 * 8 * 3;
  for (size_t i = 0; i < interpreter->inputs().size(); i++) {
    int input_tensor_idx = interpreter->inputs()[i];
    TfLiteTensor* tensor = interpreter->tensor(input_tensor_idx);
    ASSERT_EQ(tensor->bytes, tensor_size * sizeof(float));
    for (size_t j = 0; j < tensor_size; j++) {
      tensor->data.f[j] = i + 1;
    }
  }

  // Seed input 0 with pre-existing quantization stats [0.5, 1.5] so the
  // update path has something to merge with.
  auto input_0_quant_params =
      absl::make_unique<tflite::QuantizationParametersT>();
  input_0_quant_params->min.push_back(0.5);
  input_0_quant_params->max.push_back(1.5);
  model.subgraphs[0]->tensors[0]->quantization =
      std::move(input_0_quant_params);

  // Invoke with update == true.
  status = interpreter->Invoke();
  ASSERT_EQ(kTfLiteOk, status);
  const float eps = 1e-6f;
  // Expected min/max after the merge: input 0 keeps its pre-existing min
  // (0.5 < observed 1.0) and takes the larger observed max is not larger, so
  // it keeps 1.5; all other tensors take the observed calibration values.
  const float expected_min[7] = {
      0.5f,  // input 0
      2.0f,  // input 1
      3.0f,  // input 2
      4.0f,  // input 3
      5.0f,  // Add(1, 2)
      6.0f,  // Output 5: Add(0, Add(1,2))
      9.0f,  // Output 6: Add(Add(1,2), 3)
  };
  const float expected_max[7] = {
      1.5f,  // input 0
      2.0f,  // input 1
      3.0f,  // input 2
      4.0f,  // input 3
      5.0f,  // Add(1, 2)
      6.0f,  // Output 5: Add(0, Add(1,2))
      9.0f,  // Output 6: Add(Add(1,2), 3)
  };
  status = reader->AddCalibrationToModel(&model, /*update=*/true);
  // Check the call succeeded; previously the status was silently dropped.
  EXPECT_EQ(kTfLiteOk, status);
  for (int tensor_idx = 0; tensor_idx < 7; tensor_idx++) {
    EXPECT_NEAR(model.subgraphs[0]->tensors[tensor_idx]->quantization->min[0],
                expected_min[tensor_idx], eps);
    EXPECT_NEAR(model.subgraphs[0]->tensors[tensor_idx]->quantization->max[0],
                expected_max[tensor_idx], eps);
  }

  // Invoke with update == false: the seeded stats on input 0 are discarded
  // and every tensor gets exactly the observed calibration min == max value.
  const float expected_value[7] = {
      1.0f,  // input 0
      2.0f,  // input 1
      3.0f,  // input 2
      4.0f,  // input 3
      5.0f,  // Add(1, 2)
      6.0f,  // Output 5: Add(0, Add(1,2))
      9.0f,  // Output 6: Add(Add(1,2), 3)
  };
  status = reader->AddCalibrationToModel(&model, /*update=*/false);
  EXPECT_EQ(kTfLiteOk, status);
  for (int tensor_idx = 0; tensor_idx < 7; tensor_idx++) {
    EXPECT_NEAR(model.subgraphs[0]->tensors[tensor_idx]->quantization->min[0],
                expected_value[tensor_idx], eps);
    EXPECT_NEAR(model.subgraphs[0]->tensors[tensor_idx]->quantization->max[0],
                expected_value[tensor_idx], eps);
  }
}
} // namespace
} // namespace calibration
} // namespace optimize