commit e60d77cdb9 (parent c448af4c07)

Allow update for min/max values in calibrator.

PiperOrigin-RevId: 257871274
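In short: CalibrationReader::AddCalibrationToModel gains a bool update parameter. The semantics, taken from the diffs below: with update == false the logged min/max overwrite any existing quantization stats (the old behavior); with update == true the logged range is merged with the existing one, so the range can only widen. A minimal sketch of the new call sites:

    reader->AddCalibrationToModel(&model, /*update=*/false);  // overwrite existing min/max
    reader->AddCalibrationToModel(&model, /*update=*/true);   // widen existing min/max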
--- a/tensorflow/lite/python/optimize/calibration_wrapper.cc
+++ b/tensorflow/lite/python/optimize/calibration_wrapper.cc
@@ -198,7 +198,7 @@ PyObject* CalibrationWrapper::QuantizeModel(int input_py_type,
     return nullptr;
   }
   auto tflite_model = CreateMutableModel(*model_->GetModel());
-  reader_->AddCalibrationToModel(tflite_model.get());
+  reader_->AddCalibrationToModel(tflite_model.get(), /*update=*/false);
   flatbuffers::FlatBufferBuilder builder;
   auto status = tflite::optimize::QuantizeModel(
       &builder, tflite_model.get(), TfLiteTypeToSchemaType(input_type),
--- a/tensorflow/lite/tools/optimize/calibration/BUILD
+++ b/tensorflow/lite/tools/optimize/calibration/BUILD
@@ -48,6 +48,7 @@ tf_cc_test(
         "//tensorflow/core:lib",
         "//tensorflow/lite:framework",
         "//tensorflow/lite/kernels:builtin_ops",
+        "@com_google_absl//absl/memory",
         "@com_google_googletest//:gtest",
     ],
 )
--- a/tensorflow/lite/tools/optimize/calibration/calibration_reader.cc
+++ b/tensorflow/lite/tools/optimize/calibration/calibration_reader.cc
@@ -32,7 +32,8 @@ TfLiteStatus CalibrationReader::GetTensorStatsAsMap(
   return kTfLiteOk;
 }
 
-TfLiteStatus CalibrationReader::AddCalibrationToModel(ModelT* model) const {
+TfLiteStatus CalibrationReader::AddCalibrationToModel(ModelT* model,
+                                                      bool update) const {
   if (!model || model->subgraphs.empty()) {
     return kTfLiteError;
   }
@@ -41,6 +42,15 @@ TfLiteStatus CalibrationReader::AddCalibrationToModel(ModelT* model) const {
     auto minmax = tensorid_stat.second;
     float min, max;
     TF_LITE_ENSURE_STATUS(minmax.Get(&min, &max));
+    if (update) {
+      auto tensor = subgraph->tensors[tensorid_stat.first].get();
+      if (tensor->quantization) {
+        const float existing_min = tensor->quantization->min[0];
+        const float existing_max = tensor->quantization->max[0];
+        min = min < existing_min ? min : existing_min;
+        max = max > existing_max ? max : existing_max;
+      }
+    }
    auto quant_params = absl::make_unique<tflite::QuantizationParametersT>();
    quant_params->min.push_back(min);
    quant_params->max.push_back(max);
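To make the merge semantics above concrete, here is a small standalone sketch (plain C++, not TFLite code; MergeRange is a made-up name mirroring the update branch added above), using the same values as the UpdateMinMax test below:

    #include <algorithm>
    #include <cassert>

    // Mirrors the update branch: union the logged range with the existing
    // one instead of replacing it.
    void MergeRange(float logged_min, float logged_max, bool update,
                    float* min, float* max) {
      if (update) {
        *min = std::min(*min, logged_min);
        *max = std::max(*max, logged_max);
      } else {
        *min = logged_min;
        *max = logged_max;
      }
    }

    int main() {
      float min = 0.5f, max = 1.5f;  // existing quantization params on input 0
      MergeRange(1.0f, 1.0f, /*update=*/true, &min, &max);
      assert(min == 0.5f && max == 1.5f);  // range widened (here: unchanged)
      MergeRange(1.0f, 1.0f, /*update=*/false, &min, &max);
      assert(min == 1.0f && max == 1.0f);  // range overwritten
      return 0;
    }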
--- a/tensorflow/lite/tools/optimize/calibration/calibration_reader.h
+++ b/tensorflow/lite/tools/optimize/calibration/calibration_reader.h
@@ -42,7 +42,9 @@ class CalibrationReader {
 
   // Annotates the tensors in the given model with statistics captured during
   // calibration.
-  virtual TfLiteStatus AddCalibrationToModel(ModelT* model) const;
+  // "update" is a flag: when set to true, the min/max are updated instead of
+  // being overwritten.
+  virtual TfLiteStatus AddCalibrationToModel(ModelT* model, bool update) const;
 
   virtual ~CalibrationReader() {}
 
--- a/tensorflow/lite/tools/optimize/calibration/calibrator.h
+++ b/tensorflow/lite/tools/optimize/calibration/calibrator.h
@@ -50,7 +50,8 @@ namespace calibration {
 //
 // or adding calibration data to model itself.
 // ModelT * original_floating_point_model = ...
-// calibration_reader->AddCalibrationToModel(original_floating_point_model);
+// calibration_reader->AddCalibrationToModel(original_floating_point_model,
+//                                           false);
 //
 TfLiteStatus BuildLoggingInterpreter(
     const FlatBufferModel& model, const OpResolver& op_resolver,
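The full flow this header comment describes, condensed from the UpdateMinMax test below (a sketch: status checks are elided and model_path is a placeholder):

    auto flatbuffer_model = FlatBufferModel::BuildFromFile(model_path);
    std::unique_ptr<Interpreter> interpreter;
    std::unique_ptr<CalibrationReader> reader;
    BuildLoggingInterpreter(*flatbuffer_model,
                            ops::builtin::BuiltinOpResolver{},
                            &interpreter, &reader);
    interpreter->AllocateTensors();
    // ... fill input tensors with representative data ...
    interpreter->Invoke();  // logging kernels record per-tensor min/max
    tflite::ModelT model;
    flatbuffer_model->GetModel()->UnPackTo(&model);
    reader->AddCalibrationToModel(&model, /*update=*/false);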
--- a/tensorflow/lite/tools/optimize/calibration/calibrator_test.cc
+++ b/tensorflow/lite/tools/optimize/calibration/calibrator_test.cc
@@ -12,16 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 #include "tensorflow/lite/tools/optimize/calibration/calibrator.h"
 
+#include <cstring>
+
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
+#include "absl/memory/memory.h"
 #include "tensorflow/core/lib/io/path.h"
 #include "tensorflow/core/platform/init_main.h"
 #include "tensorflow/core/util/command_line_flags.h"
 #include "tensorflow/lite/kernels/register.h"
 #include "tensorflow/lite/model.h"
-#include "tensorflow/lite/tools/optimize/calibration/calibrator.h"
 
 namespace {
 tensorflow::string* g_test_model_file = nullptr;
@@ -189,6 +191,91 @@ TEST(CalibratorTest, MultipleInvokes) {
   EXPECT_NEAR(stats.at(6).max, 9.0f, eps);
 }
 
+TEST(CalibratorTest, UpdateMinMax) {
+  auto flatbuffer_model = ReadModel();
+  ASSERT_TRUE(flatbuffer_model);
+  std::unique_ptr<Interpreter> interpreter;
+  std::unique_ptr<CalibrationReader> reader;
+  auto status = BuildLoggingInterpreter(*flatbuffer_model,
+                                        ops::builtin::BuiltinOpResolver{},
+                                        &interpreter, &reader);
+  EXPECT_EQ(kTfLiteOk, status);
+  auto readonly_model = flatbuffer_model->GetModel();
+  tflite::ModelT model;
+  readonly_model->UnPackTo(&model);
+
+  ASSERT_TRUE(interpreter);
+  ASSERT_TRUE(reader);
+  status = interpreter->AllocateTensors();
+
+  EXPECT_EQ(kTfLiteOk, status);
+  const size_t tensor_size = 1 * 8 * 8 * 3;
+  for (size_t i = 0; i < interpreter->inputs().size(); i++) {
+    int input_tensor_idx = interpreter->inputs()[i];
+    TfLiteTensor* tensor = interpreter->tensor(input_tensor_idx);
+    ASSERT_EQ(tensor->bytes, tensor_size * sizeof(float));
+    for (size_t j = 0; j < tensor_size; j++) {
+      tensor->data.f[j] = i + 1;
+    }
+  }
+  auto input_0_quant_params =
+      absl::make_unique<tflite::QuantizationParametersT>();
+  input_0_quant_params->min.push_back(0.5);
+  input_0_quant_params->max.push_back(1.5);
+  model.subgraphs[0]->tensors[0]->quantization =
+      std::move(input_0_quant_params);
+
+  // Invoke, then add the logged stats with update == true.
+  status = interpreter->Invoke();
+  ASSERT_EQ(kTfLiteOk, status);
+  const float eps = 1e-6f;
+  // Verify the min/max of the tensors.
+  const float expected_min[7] = {
+      0.5f,  // input 0
+      2.0f,  // input 1
+      3.0f,  // input 2
+      4.0f,  // input 3
+      5.0f,  // Add(1, 2)
+      6.0f,  // Output 5: Add(0, Add(1,2))
+      9.0f,  // Output 6: Add(Add(1,2), 3)
+  };
+  const float expected_max[7] = {
+      1.5f,  // input 0
+      2.0f,  // input 1
+      3.0f,  // input 2
+      4.0f,  // input 3
+      5.0f,  // Add(1, 2)
+      6.0f,  // Output 5: Add(0, Add(1,2))
+      9.0f,  // Output 6: Add(Add(1,2), 3)
+  };
+  status = reader->AddCalibrationToModel(&model, /*update=*/true);
+  for (int tensor_idx = 0; tensor_idx < 7; tensor_idx++) {
+    EXPECT_NEAR(model.subgraphs[0]->tensors[tensor_idx]->quantization->min[0],
+                expected_min[tensor_idx], eps);
+    EXPECT_NEAR(model.subgraphs[0]->tensors[tensor_idx]->quantization->max[0],
+                expected_max[tensor_idx], eps);
+  }
+
+  // Now add the logged stats with update == false; they overwrite the
+  // existing min/max. Verify the min/max of the tensors.
+  const float expected_value[7] = {
+      1.0f,  // input 0
+      2.0f,  // input 1
+      3.0f,  // input 2
+      4.0f,  // input 3
+      5.0f,  // Add(1, 2)
+      6.0f,  // Output 5: Add(0, Add(1,2))
+      9.0f,  // Output 6: Add(Add(1,2), 3)
+  };
+  status = reader->AddCalibrationToModel(&model, /*update=*/false);
+  for (int tensor_idx = 0; tensor_idx < 7; tensor_idx++) {
+    EXPECT_NEAR(model.subgraphs[0]->tensors[tensor_idx]->quantization->min[0],
+                expected_value[tensor_idx], eps);
+    EXPECT_NEAR(model.subgraphs[0]->tensors[tensor_idx]->quantization->max[0],
+                expected_value[tensor_idx], eps);
+  }
+}
+
 }  // namespace
 }  // namespace calibration
 }  // namespace optimize