Allow update for min/max values in calibrator.

PiperOrigin-RevId: 257871274
Author: Jian Li (2019-07-12 14:39:41 -07:00), committed by TensorFlower Gardener
parent c448af4c07
commit e60d77cdb9
6 changed files with 106 additions and 5 deletions
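In effect, the new boolean lets a caller merge freshly observed ranges into whatever min/max a tensor already carries, instead of always overwriting them. A minimal standalone sketch of that merge rule (the struct and function below are illustrative stand-ins, not TensorFlow Lite APIs):

#include <algorithm>

// Illustrative stand-in for a tensor's recorded quantization range.
struct TensorRange {
  bool has_range = false;
  float min = 0.0f;
  float max = 0.0f;
};

// Mirrors the semantics added in this commit: with update == true an
// existing range is widened by the observed one; with update == false the
// observed range replaces whatever was stored.
void ApplyCalibration(TensorRange* t, float observed_min, float observed_max,
                      bool update) {
  if (update && t->has_range) {
    t->min = std::min(t->min, observed_min);
    t->max = std::max(t->max, observed_max);
  } else {
    t->min = observed_min;
    t->max = observed_max;
  }
  t->has_range = true;
}

Because updating takes the min of mins and the max of maxes, the stored range can only grow across repeated calibration runs, so calibrating over several datasets yields the union of the observed ranges.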

tensorflow/lite/python/optimize/calibration_wrapper.cc

@@ -198,7 +198,7 @@ PyObject* CalibrationWrapper::QuantizeModel(int input_py_type,
     return nullptr;
   }
   auto tflite_model = CreateMutableModel(*model_->GetModel());
-  reader_->AddCalibrationToModel(tflite_model.get());
+  reader_->AddCalibrationToModel(tflite_model.get(), /*update=*/false);
   flatbuffers::FlatBufferBuilder builder;
   auto status = tflite::optimize::QuantizeModel(
       &builder, tflite_model.get(), TfLiteTypeToSchemaType(input_type),

tensorflow/lite/tools/optimize/calibration/BUILD

@@ -48,6 +48,7 @@ tf_cc_test(
         "//tensorflow/core:lib",
         "//tensorflow/lite:framework",
         "//tensorflow/lite/kernels:builtin_ops",
+        "@com_google_absl//absl/memory",
         "@com_google_googletest//:gtest",
     ],
 )

tensorflow/lite/tools/optimize/calibration/calibration_reader.cc

@@ -32,7 +32,8 @@ TfLiteStatus CalibrationReader::GetTensorStatsAsMap(
   return kTfLiteOk;
 }
 
-TfLiteStatus CalibrationReader::AddCalibrationToModel(ModelT* model) const {
+TfLiteStatus CalibrationReader::AddCalibrationToModel(ModelT* model,
+                                                      bool update) const {
   if (!model || model->subgraphs.empty()) {
     return kTfLiteError;
   }
@@ -41,6 +42,15 @@ TfLiteStatus CalibrationReader::AddCalibrationToModel(ModelT* model) const {
     auto minmax = tensorid_stat.second;
     float min, max;
     TF_LITE_ENSURE_STATUS(minmax.Get(&min, &max));
+    if (update) {
+      auto tensor = subgraph->tensors[tensorid_stat.first].get();
+      if (tensor->quantization) {
+        const float existing_min = tensor->quantization->min[0];
+        const float existing_max = tensor->quantization->max[0];
+        min = min < existing_min ? min : existing_min;
+        max = max > existing_max ? max : existing_max;
+      }
+    }
     auto quant_params = absl::make_unique<tflite::QuantizationParametersT>();
     quant_params->min.push_back(min);
     quant_params->max.push_back(max);

tensorflow/lite/tools/optimize/calibration/calibration_reader.h

@@ -42,7 +42,9 @@ class CalibrationReader {
   // Annotates the tensors in the given model with statistics captured during
   // calibration.
-  virtual TfLiteStatus AddCalibrationToModel(ModelT* model) const;
+  // "update" is a flag: when set to true, the min/max are updated, instead of
+  // being overwritten.
+  virtual TfLiteStatus AddCalibrationToModel(ModelT* model, bool update) const;
 
   virtual ~CalibrationReader() {}
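For a concrete sense of the new contract, a hedged pair of call sites (assuming a populated ModelT model and a live reader; these lines are illustrative, not part of the commit):

// First pass: write the observed ranges, overwriting anything present.
TfLiteStatus first = reader->AddCalibrationToModel(&model, /*update=*/false);
// Subsequent passes: widen the stored ranges with newly observed ones.
TfLiteStatus later = reader->AddCalibrationToModel(&model, /*update=*/true);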

tensorflow/lite/tools/optimize/calibration/calibrator.h

@@ -50,7 +50,8 @@ namespace calibration {
 //
 // or adding calibration data to model itself.
 // ModelT * original_floating_point_model = ...
-// calibration_reader->AddCalibrationToModel(original_floating_point_model);
+// calibration_reader->AddCalibrationToModel(original_floating_point_model,
+//                                           false);
 //
 TfLiteStatus BuildLoggingInterpreter(
     const FlatBufferModel& model, const OpResolver& op_resolver,
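Putting the pieces together, a sketch of a multi-batch calibration pass that merges its results into an already-annotated model. This is an assumption-laden illustration, not code from the commit: FeedBatch is a hypothetical helper, and error handling is reduced to TF_LITE_ENSURE_STATUS.

#include <memory>

// Sketch: run calibration over several representative batches, then fold the
// observed min/max into the model. FeedBatch() is a hypothetical helper that
// writes one batch into the interpreter's input tensors.
TfLiteStatus CalibrateAndMerge(const tflite::FlatBufferModel& flatbuffer_model,
                               tflite::ModelT* model, int num_batches) {
  std::unique_ptr<tflite::Interpreter> interpreter;
  std::unique_ptr<tflite::optimize::calibration::CalibrationReader> reader;
  TF_LITE_ENSURE_STATUS(tflite::optimize::calibration::BuildLoggingInterpreter(
      flatbuffer_model, tflite::ops::builtin::BuiltinOpResolver{},
      &interpreter, &reader));
  TF_LITE_ENSURE_STATUS(interpreter->AllocateTensors());
  for (int batch = 0; batch < num_batches; ++batch) {
    FeedBatch(interpreter.get(), batch);
    TF_LITE_ENSURE_STATUS(interpreter->Invoke());
  }
  // update == true: any range already stored on a tensor is widened by the
  // newly observed statistics instead of being replaced.
  return reader->AddCalibrationToModel(model, /*update=*/true);
}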

tensorflow/lite/tools/optimize/calibration/calibrator_test.cc

@@ -12,16 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#include "tensorflow/lite/tools/optimize/calibration/calibrator.h"
+
 #include <cstring>
 
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
 
+#include "absl/memory/memory.h"
 #include "tensorflow/core/lib/io/path.h"
 #include "tensorflow/core/platform/init_main.h"
 #include "tensorflow/core/util/command_line_flags.h"
 #include "tensorflow/lite/kernels/register.h"
 #include "tensorflow/lite/model.h"
-#include "tensorflow/lite/tools/optimize/calibration/calibrator.h"
 
 namespace {
tensorflow::string* g_test_model_file = nullptr; tensorflow::string* g_test_model_file = nullptr;
@@ -189,6 +191,91 @@ TEST(CalibratorTest, MultipleInvokes) {
   EXPECT_NEAR(stats.at(6).max, 9.0f, eps);
 }
 
+TEST(CalibratorTest, UpdateMinMax) {
+  auto flatbuffer_model = ReadModel();
+  ASSERT_TRUE(flatbuffer_model);
+  std::unique_ptr<Interpreter> interpreter;
+  std::unique_ptr<CalibrationReader> reader;
+  auto status = BuildLoggingInterpreter(*flatbuffer_model,
+                                        ops::builtin::BuiltinOpResolver{},
+                                        &interpreter, &reader);
+  EXPECT_EQ(kTfLiteOk, status);
+  auto readonly_model = flatbuffer_model->GetModel();
+  tflite::ModelT model;
+  readonly_model->UnPackTo(&model);
+
+  ASSERT_TRUE(interpreter);
+  ASSERT_TRUE(reader);
+  status = interpreter->AllocateTensors();
+  EXPECT_EQ(kTfLiteOk, status);
+  const size_t tensor_size = 1 * 8 * 8 * 3;
+  for (size_t i = 0; i < interpreter->inputs().size(); i++) {
+    int input_tensor_idx = interpreter->inputs()[i];
+    TfLiteTensor* tensor = interpreter->tensor(input_tensor_idx);
+    ASSERT_EQ(tensor->bytes, tensor_size * sizeof(float));
+    for (size_t j = 0; j < tensor_size; j++) {
+      tensor->data.f[j] = i + 1;
+    }
+  }
+  // Seed input 0 with a pre-existing quantization range of [0.5, 1.5].
+  auto input_0_quant_params =
+      absl::make_unique<tflite::QuantizationParametersT>();
+  input_0_quant_params->min.push_back(0.5);
+  input_0_quant_params->max.push_back(1.5);
+  model.subgraphs[0]->tensors[0]->quantization =
+      std::move(input_0_quant_params);
+
+  status = interpreter->Invoke();
+  ASSERT_EQ(kTfLiteOk, status);
+  const float eps = 1e-6f;
+
+  // Add calibration with update == true: the seeded range on input 0 is
+  // widened rather than overwritten.
+  const float expected_min[7] = {
+      0.5f,  // input 0
+      2.0f,  // input 1
+      3.0f,  // input 2
+      4.0f,  // input 3
+      5.0f,  // Add(1, 2)
+      6.0f,  // Output 5: Add(0, Add(1,2))
+      9.0f,  // Output 6: Add(Add(1,2), 3)
+  };
+  const float expected_max[7] = {
+      1.5f,  // input 0
+      2.0f,  // input 1
+      3.0f,  // input 2
+      4.0f,  // input 3
+      5.0f,  // Add(1, 2)
+      6.0f,  // Output 5: Add(0, Add(1,2))
+      9.0f,  // Output 6: Add(Add(1,2), 3)
+  };
+  status = reader->AddCalibrationToModel(&model, /*update=*/true);
+  EXPECT_EQ(kTfLiteOk, status);
+  for (int tensor_idx = 0; tensor_idx < 7; tensor_idx++) {
+    EXPECT_NEAR(model.subgraphs[0]->tensors[tensor_idx]->quantization->min[0],
+                expected_min[tensor_idx], eps);
+    EXPECT_NEAR(model.subgraphs[0]->tensors[tensor_idx]->quantization->max[0],
+                expected_max[tensor_idx], eps);
+  }
+
+  // Add calibration with update == false: every range is overwritten with
+  // the values observed during calibration.
+  const float expected_value[7] = {
+      1.0f,  // input 0
+      2.0f,  // input 1
+      3.0f,  // input 2
+      4.0f,  // input 3
+      5.0f,  // Add(1, 2)
+      6.0f,  // Output 5: Add(0, Add(1,2))
+      9.0f,  // Output 6: Add(Add(1,2), 3)
+  };
+  status = reader->AddCalibrationToModel(&model, /*update=*/false);
+  EXPECT_EQ(kTfLiteOk, status);
+  for (int tensor_idx = 0; tensor_idx < 7; tensor_idx++) {
+    EXPECT_NEAR(model.subgraphs[0]->tensors[tensor_idx]->quantization->min[0],
+                expected_value[tensor_idx], eps);
+    EXPECT_NEAR(model.subgraphs[0]->tensors[tensor_idx]->quantization->max[0],
+                expected_value[tensor_idx], eps);
+  }
+}
+
 }  // namespace
 }  // namespace calibration
 }  // namespace optimize