Add utility functions for integration. This supports the calibration case that the model is initialized multiple times.
(1) quantization_wrapper is the external interface to help calibration and quantization. The interface uses only string and bool, so the dependencies are minimal. It has two functions: - CreateModelForCalibration copies a model to a new location and adds intermediate tensors if any of the ops need them. - CreateQuantizedModel quantizes a model in place. (2) quantization_wrapper_utils contains the helper functions for quantization_wrapper: - added a function to load a model - added a function to write a model PiperOrigin-RevId: 280510873 Change-Id: I58891d6e8d6d3b485242f321466ac91ce2fdffda
This commit is contained in:
parent
dd6f51d33b
commit
fa0fb0d4f6
@ -13,9 +13,9 @@ exports_files(glob([
|
|||||||
]))
|
]))
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "add_intermediate_tensors",
|
name = "quantization_wrapper_utils",
|
||||||
srcs = ["add_intermediate_tensors.cc"],
|
srcs = ["quantization_wrapper_utils.cc"],
|
||||||
hdrs = ["add_intermediate_tensors.h"],
|
hdrs = ["quantization_wrapper_utils.h"],
|
||||||
deps = [
|
deps = [
|
||||||
":operator_property",
|
":operator_property",
|
||||||
"//tensorflow/lite:framework",
|
"//tensorflow/lite:framework",
|
||||||
@ -26,14 +26,14 @@ cc_library(
|
|||||||
)
|
)
|
||||||
|
|
||||||
tf_cc_test(
|
tf_cc_test(
|
||||||
name = "add_intermediate_tensors_test",
|
name = "quantization_wrapper_utils_test",
|
||||||
srcs = ["add_intermediate_tensors_test.cc"],
|
srcs = ["quantization_wrapper_utils_test.cc"],
|
||||||
tags = [
|
tags = [
|
||||||
"tflite_not_portable_android",
|
"tflite_not_portable_android",
|
||||||
"tflite_not_portable_ios",
|
"tflite_not_portable_ios",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":add_intermediate_tensors",
|
":quantization_wrapper_utils",
|
||||||
"//tensorflow/lite:framework",
|
"//tensorflow/lite:framework",
|
||||||
"//tensorflow/lite/schema:schema_fbs",
|
"//tensorflow/lite/schema:schema_fbs",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
@ -42,6 +42,20 @@ tf_cc_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "quantization_wrapper",
|
||||||
|
srcs = ["quantization_wrapper.cc"],
|
||||||
|
hdrs = ["quantization_wrapper.h"],
|
||||||
|
deps = [
|
||||||
|
":quantization_wrapper_utils",
|
||||||
|
"//tensorflow/lite:framework",
|
||||||
|
"//tensorflow/lite/core/api",
|
||||||
|
"//tensorflow/lite/schema:schema_fbs",
|
||||||
|
"//tensorflow/lite/tools/optimize:quantize_model",
|
||||||
|
"@flatbuffers",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "quantization_utils",
|
name = "quantization_utils",
|
||||||
srcs = ["quantization_utils.cc"],
|
srcs = ["quantization_utils.cc"],
|
||||||
|
52
tensorflow/lite/tools/optimize/quantization_wrapper.cc
Normal file
52
tensorflow/lite/tools/optimize/quantization_wrapper.cc
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include "tensorflow/lite/tools/optimize/quantization_wrapper.h"
|
||||||
|
|
||||||
|
#include "tensorflow/lite/tools/optimize/quantization_wrapper_utils.h"
|
||||||
|
#include "tensorflow/lite/tools/optimize/quantize_model.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace optimize {
|
||||||
|
|
||||||
|
bool CreateModelForCalibration(const std::string& input_path,
|
||||||
|
const std::string& output_path) {
|
||||||
|
ModelT model;
|
||||||
|
if (LoadModel(input_path, &model) != kTfLiteOk) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
flatbuffers::FlatBufferBuilder builder;
|
||||||
|
if (AddIntemediateTensorsToFusedOp(&builder, &model) != kTfLiteOk) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return WriteFile(output_path, builder.GetBufferPointer(), builder.GetSize());
|
||||||
|
}
|
||||||
|
|
||||||
|
bool CreateQuantizedModel(const std::string& path) {
|
||||||
|
ModelT model;
|
||||||
|
if (LoadModel(path, &model) != kTfLiteOk) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
flatbuffers::FlatBufferBuilder builder;
|
||||||
|
tflite::StderrReporter error_reporter;
|
||||||
|
if (tflite::optimize::QuantizeModel(
|
||||||
|
&builder, &model, tflite::TensorType_FLOAT32,
|
||||||
|
tflite::TensorType_FLOAT32, &error_reporter) != kTfLiteOk) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return WriteFile(path, builder.GetBufferPointer(), builder.GetSize());
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace optimize
|
||||||
|
} // namespace tflite
|
39
tensorflow/lite/tools/optimize/quantization_wrapper.h
Normal file
39
tensorflow/lite/tools/optimize/quantization_wrapper.h
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#ifndef TENSORFLOW_LITE_TOOLS_OPTIMIZE_QUANTIZATION_WRAPPER_H_
|
||||||
|
#define TENSORFLOW_LITE_TOOLS_OPTIMIZE_QUANTIZATION_WRAPPER_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace optimize {
|
||||||
|
|
||||||
|
// Makes a copy of the model at input_path and writes it to output_path, adding
|
||||||
|
// tensors to the model needed for calibration.
|
||||||
|
// Returns true if it is successful.
|
||||||
|
// Example: a/b/c.tflite becomes a/b/c.calibrated.tflite and has
|
||||||
|
// intermediate tensors added according to operator properties.
|
||||||
|
bool CreateModelForCalibration(const std::string& input_path,
|
||||||
|
const std::string& output_path);
|
||||||
|
|
||||||
|
// Quantize a model in place. This function is only to be called after calling
|
||||||
|
// CreateModelForCalibration and running calibration over data.
|
||||||
|
// Returns true if it is successful.
|
||||||
|
bool CreateQuantizedModel(const std::string& path);
|
||||||
|
|
||||||
|
} // namespace optimize
|
||||||
|
} // namespace tflite
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_LITE_TOOLS_OPTIMIZE_QUANTIZATION_WRAPPER_H_
|
@ -12,10 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||||||
See the License for the specific language governing permissions and
|
See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
#include "tensorflow/lite/tools/optimize/add_intermediate_tensors.h"
|
#include "tensorflow/lite/tools/optimize/quantization_wrapper_utils.h"
|
||||||
|
|
||||||
|
#include <fstream>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
|
#include "tensorflow/lite/schema/schema_generated.h"
|
||||||
#include "tensorflow/lite/tools/optimize/operator_property.h"
|
#include "tensorflow/lite/tools/optimize/operator_property.h"
|
||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
@ -51,6 +53,19 @@ bool IntermediateTensorExists(ModelT* model) {
|
|||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
// Loads the flatbuffer model file at `path` and unpacks it into `*model`.
// Returns kTfLiteError if FlatBufferModel::BuildFromFile yields no model or
// the loaded buffer exposes no root model; otherwise returns kTfLiteOk.
TfLiteStatus LoadModel(const string& path, ModelT* model) {
  const auto flatbuffer_model = FlatBufferModel::BuildFromFile(path.c_str());
  // GetModel() is only consulted when the file itself could be loaded.
  const auto* root = flatbuffer_model ? flatbuffer_model->GetModel() : nullptr;
  if (root == nullptr) {
    return kTfLiteError;
  }
  root->UnPackTo(model);
  return kTfLiteOk;
}
|
||||||
|
|
||||||
TfLiteStatus AddIntemediateTensorsToFusedOp(
|
TfLiteStatus AddIntemediateTensorsToFusedOp(
|
||||||
flatbuffers::FlatBufferBuilder* builder, ModelT* model) {
|
flatbuffers::FlatBufferBuilder* builder, ModelT* model) {
|
||||||
// Return early if the model already has intermediate tensors.
|
// Return early if the model already has intermediate tensors.
|
||||||
@ -90,5 +105,14 @@ TfLiteStatus AddIntemediateTensorsToFusedOp(
|
|||||||
return kTfLiteOk;
|
return kTfLiteOk;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Writes `num_bytes` bytes starting at `bytes` to the file `out_file` in
// binary mode, replacing any existing content.
//
// Returns true only if the file could be opened and all data was written and
// flushed without the stream entering a failed state. Returns false if
// `bytes` is null while `num_bytes` is non-zero (nothing is opened or written
// in that case).
bool WriteFile(const std::string& out_file, const uint8_t* bytes,
               size_t num_bytes) {
  // A null buffer is only acceptable for an empty payload.
  if (bytes == nullptr && num_bytes > 0) {
    return false;
  }

  std::ofstream stream(out_file,
                       std::ios::binary | std::ios::out | std::ios::trunc);
  if (!stream.is_open()) {
    return false;
  }

  // Single unformatted write instead of one formatted insertion per byte.
  if (num_bytes > 0) {
    stream.write(reinterpret_cast<const char*>(bytes),
                 static_cast<std::streamsize>(num_bytes));
  }

  // Flush before inspecting the stream state: otherwise buffered data is only
  // written by the destructor, after the return value has been computed, and
  // a failure at that point would go unreported.
  stream.flush();
  return !stream.fail();
}
|
||||||
|
|
||||||
} // namespace optimize
|
} // namespace optimize
|
||||||
} // namespace tflite
|
} // namespace tflite
|
@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||||||
See the License for the specific language governing permissions and
|
See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
#ifndef TENSORFLOW_LITE_TOOLS_OPTIMIZE_ADD_INTERMEDIATE_TENSORS_H_
|
#ifndef TENSORFLOW_LITE_TOOLS_OPTIMIZE_QUANTIZATION_WRAPPER_UTILS_H_
|
||||||
#define TENSORFLOW_LITE_TOOLS_OPTIMIZE_ADD_INTERMEDIATE_TENSORS_H_
|
#define TENSORFLOW_LITE_TOOLS_OPTIMIZE_QUANTIZATION_WRAPPER_UTILS_H_
|
||||||
|
|
||||||
#include "tensorflow/lite/core/api/error_reporter.h"
|
#include "tensorflow/lite/core/api/error_reporter.h"
|
||||||
#include "tensorflow/lite/model.h"
|
#include "tensorflow/lite/model.h"
|
||||||
@ -22,13 +22,20 @@ limitations under the License.
|
|||||||
namespace tflite {
|
namespace tflite {
|
||||||
namespace optimize {
|
namespace optimize {
|
||||||
|
|
||||||
|
// Load a tflite model from path.
|
||||||
|
TfLiteStatus LoadModel(const string& path, ModelT* model);
|
||||||
|
|
||||||
// Going through the model and add intermediates tensors if the ops have any.
|
// Going through the model and add intermediates tensors if the ops have any.
|
||||||
// Returns early if the model has already intermediate tensors. This is to
|
// Returns early if the model has already intermediate tensors. This is to
|
||||||
// support cases where a model is initialized multiple times.
|
// support cases where a model is initialized multiple times.
|
||||||
TfLiteStatus AddIntemediateTensorsToFusedOp(
|
TfLiteStatus AddIntemediateTensorsToFusedOp(
|
||||||
flatbuffers::FlatBufferBuilder* builder, ModelT* input_model);
|
flatbuffers::FlatBufferBuilder* builder, ModelT* model);
|
||||||
|
|
||||||
|
// Write model to a given location.
|
||||||
|
bool WriteFile(const std::string& out_file, const uint8_t* bytes,
|
||||||
|
size_t num_bytes);
|
||||||
|
|
||||||
} // namespace optimize
|
} // namespace optimize
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
|
||||||
#endif // TENSORFLOW_LITE_TOOLS_OPTIMIZE_ADD_INTERMEDIATE_TENSORS_H_
|
#endif // TENSORFLOW_LITE_TOOLS_OPTIMIZE_QUANTIZATION_WRAPPER_UTILS_H_
|
@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||||||
See the License for the specific language governing permissions and
|
See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
#include "tensorflow/lite/tools/optimize/add_intermediate_tensors.h"
|
#include "tensorflow/lite/tools/optimize/quantization_wrapper_utils.h"
|
||||||
|
|
||||||
#include <gmock/gmock.h>
|
#include <gmock/gmock.h>
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
Loading…
Reference in New Issue
Block a user