Append intermediate tensors to fused operator if necessary before creating calibrator.
PiperOrigin-RevId: 338118191 Change-Id: Ie4758443e914db22fd5513f043c70e8f5e2c18b2
This commit is contained in:
parent
84967b39fa
commit
59e30f648f
tensorflow/lite
@ -462,6 +462,8 @@ class TFLiteConverterBase(object):
|
||||
self.representative_dataset = RepresentativeDataset(
|
||||
self.representative_dataset)
|
||||
|
||||
# Add intermediate tensors to the model if needed.
|
||||
result = _calibrator.add_intermediate_tensors(result)
|
||||
calibrate_quantize = _calibrator.Calibrator(result)
|
||||
if self._experimental_calibrate_only or self._experimental_new_quantizer:
|
||||
calibrated = calibrate_quantize.calibrate(
|
||||
|
@ -16,6 +16,7 @@ cc_library(
|
||||
"//tensorflow/lite/python/interpreter_wrapper:numpy",
|
||||
"//tensorflow/lite/python/interpreter_wrapper:python_error_reporter",
|
||||
"//tensorflow/lite/python/interpreter_wrapper:python_utils",
|
||||
"//tensorflow/lite/tools/optimize:quantization_wrapper_utils",
|
||||
"//tensorflow/lite/tools/optimize:quantize_model",
|
||||
"//tensorflow/lite/tools/optimize/calibration:calibration_reader",
|
||||
"//tensorflow/lite/tools/optimize/calibration:calibrator_lib",
|
||||
|
@ -28,6 +28,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
|
||||
#include "tensorflow/lite/tools/optimize/calibration/calibration_reader.h"
|
||||
#include "tensorflow/lite/tools/optimize/calibration/calibrator.h"
|
||||
#include "tensorflow/lite/tools/optimize/quantization_wrapper_utils.h"
|
||||
#include "tensorflow/lite/tools/optimize/quantize_model.h"
|
||||
|
||||
#define TFLITE_PY_CHECK(x) \
|
||||
@ -94,6 +95,42 @@ inline TensorType TfLiteTypeToSchemaType(TfLiteType type) {
|
||||
|
||||
} // namespace
|
||||
|
||||
PyObject* AddIntermediateTensors(PyObject* data) {
|
||||
using tflite::interpreter_wrapper::PythonErrorReporter;
|
||||
char* buf = nullptr;
|
||||
Py_ssize_t length;
|
||||
std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
|
||||
::tflite::python::ImportNumpy();
|
||||
|
||||
if (python_utils::ConvertFromPyString(data, &buf, &length) == -1) {
|
||||
return nullptr;
|
||||
}
|
||||
std::unique_ptr<tflite::FlatBufferModel> model =
|
||||
tflite::FlatBufferModel::BuildFromBuffer(buf, length,
|
||||
error_reporter.get());
|
||||
if (!model) {
|
||||
PyErr_Format(PyExc_ValueError, "Invalid model");
|
||||
return nullptr;
|
||||
}
|
||||
flatbuffers::FlatBufferBuilder builder;
|
||||
auto tflite_model = CreateMutableModel(*model->GetModel());
|
||||
if (optimize::AddIntermediateTensorsToFusedOp(&builder, tflite_model.get()) !=
|
||||
kTfLiteOk) {
|
||||
error_reporter->exception();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (builder.GetSize()) {
|
||||
return python_utils::ConvertToPyString(
|
||||
reinterpret_cast<const char*>(builder.GetCurrentBufferPointer()),
|
||||
builder.GetSize());
|
||||
} else {
|
||||
// When AddIntermediateTensorsToFusedOp early returns, return the model as
|
||||
// it is.
|
||||
return python_utils::ConvertToPyString(buf, length);
|
||||
}
|
||||
}
|
||||
|
||||
CalibrationWrapper::CalibrationWrapper(
|
||||
std::unique_ptr<tflite::Interpreter> interpreter,
|
||||
std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> resolver,
|
||||
|
@ -50,6 +50,8 @@ class CalibrationReader;
|
||||
|
||||
namespace calibration_wrapper {
|
||||
|
||||
PyObject* AddIntermediateTensors(PyObject* data);
|
||||
|
||||
class CalibrationWrapper {
|
||||
public:
|
||||
// SWIG caller takes ownership of pointer.
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
#include "tensorflow/python/lib/core/pybind11_lib.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
using tflite::calibration_wrapper::AddIntermediateTensors;
|
||||
using tflite::calibration_wrapper::CalibrationWrapper;
|
||||
|
||||
PYBIND11_MODULE(_pywrap_tensorflow_lite_calibration_wrapper, m) {
|
||||
@ -25,6 +26,9 @@ PYBIND11_MODULE(_pywrap_tensorflow_lite_calibration_wrapper, m) {
|
||||
_pywrap_tensorflow_lite_calibration_wrapper
|
||||
-----
|
||||
)pbdoc";
|
||||
m.def("AddIntermediateTensors", [](py::handle& data) {
|
||||
return tensorflow::PyoOrThrow(AddIntermediateTensors(data.ptr()));
|
||||
});
|
||||
py::class_<CalibrationWrapper>(m, "CalibrationWrapper")
|
||||
.def(py::init([](py::handle& data) {
|
||||
return ::CalibrationWrapper::CreateWrapperCPPFromBuffer(data.ptr());
|
||||
|
@ -31,6 +31,11 @@ _calibration_wrapper = LazyLoader(
|
||||
"_pywrap_tensorflow_lite_calibration_wrapper")
|
||||
|
||||
|
||||
def add_intermediate_tensors(model_content):
|
||||
"""Adds intermedaite tensors to fused op if needed."""
|
||||
return _calibration_wrapper.AddIntermediateTensors(model_content)
|
||||
|
||||
|
||||
class Calibrator(object):
|
||||
"""Calibrates a floating point model and then quantizes it.
|
||||
|
||||
|
@ -199,5 +199,13 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
quantized_model = quantizer.calibrate(input_gen)
|
||||
self.assertIsNotNone(quantized_model)
|
||||
|
||||
def test_add_intermediate_tensors(self):
|
||||
model_path = resource_loader.get_path_to_datafile(
|
||||
'test_data/mobilenet_like_model.bin')
|
||||
model = open(model_path, 'rb').read()
|
||||
added_model = _calibrator.add_intermediate_tensors(model)
|
||||
self.assertIsNotNone(added_model)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
@ -68,6 +68,10 @@ TfLiteStatus LoadModel(const string& path, ModelT* model) {
|
||||
|
||||
TfLiteStatus AddIntermediateTensorsToFusedOp(
|
||||
flatbuffers::FlatBufferBuilder* builder, ModelT* model) {
|
||||
// Return early when the model has no operator.
|
||||
if (model->subgraphs.size() == 1 && model->subgraphs[0]->operators.empty()) {
|
||||
return kTfLiteOk;
|
||||
}
|
||||
// Return early if the model already has intermediate tensors.
|
||||
if (IntermediateTensorExists(model)) {
|
||||
return kTfLiteOk;
|
||||
|
Loading…
Reference in New Issue
Block a user