Refactor Python utility functions from Python wrapper.
- Moves ErrorReporter to a separate class. - Moves some utility functions to a separate class. PiperOrigin-RevId: 230387612
This commit is contained in:
parent
0b8ac6e1b9
commit
bfcddf733e
@ -11,6 +11,8 @@ cc_library(
|
|||||||
srcs = ["interpreter_wrapper.cc"],
|
srcs = ["interpreter_wrapper.cc"],
|
||||||
hdrs = ["interpreter_wrapper.h"],
|
hdrs = ["interpreter_wrapper.h"],
|
||||||
deps = [
|
deps = [
|
||||||
|
":python_error_reporter",
|
||||||
|
":python_utils",
|
||||||
"//tensorflow/lite:framework",
|
"//tensorflow/lite:framework",
|
||||||
"//tensorflow/lite/kernels:builtin_ops",
|
"//tensorflow/lite/kernels:builtin_ops",
|
||||||
"//third_party/py/numpy:headers",
|
"//third_party/py/numpy:headers",
|
||||||
@ -19,6 +21,27 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "python_error_reporter",
|
||||||
|
srcs = ["python_error_reporter.cc"],
|
||||||
|
hdrs = ["python_error_reporter.h"],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/lite/core/api",
|
||||||
|
"//third_party/python_runtime:headers",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "python_utils",
|
||||||
|
srcs = ["python_utils.cc"],
|
||||||
|
hdrs = ["python_utils.h"],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/lite:framework",
|
||||||
|
"//third_party/py/numpy:headers",
|
||||||
|
"//third_party/python_runtime:headers",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
tf_py_wrap_cc(
|
tf_py_wrap_cc(
|
||||||
name = "tensorflow_wrap_interpreter_wrapper",
|
name = "tensorflow_wrap_interpreter_wrapper",
|
||||||
srcs = [
|
srcs = [
|
||||||
|
@ -21,6 +21,8 @@ limitations under the License.
|
|||||||
#include "tensorflow/lite/interpreter.h"
|
#include "tensorflow/lite/interpreter.h"
|
||||||
#include "tensorflow/lite/kernels/register.h"
|
#include "tensorflow/lite/kernels/register.h"
|
||||||
#include "tensorflow/lite/model.h"
|
#include "tensorflow/lite/model.h"
|
||||||
|
#include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h"
|
||||||
|
#include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
|
||||||
|
|
||||||
// Disallow Numpy 1.7 deprecated symbols.
|
// Disallow Numpy 1.7 deprecated symbols.
|
||||||
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
|
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
|
||||||
@ -60,36 +62,6 @@ limitations under the License.
|
|||||||
namespace tflite {
|
namespace tflite {
|
||||||
namespace interpreter_wrapper {
|
namespace interpreter_wrapper {
|
||||||
|
|
||||||
class PythonErrorReporter : public tflite::ErrorReporter {
|
|
||||||
public:
|
|
||||||
PythonErrorReporter() {}
|
|
||||||
|
|
||||||
// Report an error message
|
|
||||||
int Report(const char* format, va_list args) override {
|
|
||||||
char buf[1024];
|
|
||||||
int formatted = vsnprintf(buf, sizeof(buf), format, args);
|
|
||||||
buffer_ << buf;
|
|
||||||
return formatted;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set's a Python runtime exception with the last error.
|
|
||||||
PyObject* exception() {
|
|
||||||
std::string last_message = message();
|
|
||||||
PyErr_SetString(PyExc_RuntimeError, last_message.c_str());
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Gets the last error message and clears the buffer.
|
|
||||||
std::string message() {
|
|
||||||
std::string value = buffer_.str();
|
|
||||||
buffer_.clear();
|
|
||||||
return value;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
std::stringstream buffer_;
|
|
||||||
};
|
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
// Calls PyArray's initialization to initialize all the API pointers. Note that
|
// Calls PyArray's initialization to initialize all the API pointers. Note that
|
||||||
@ -114,61 +86,6 @@ std::unique_ptr<tflite::Interpreter> CreateInterpreter(
|
|||||||
return interpreter;
|
return interpreter;
|
||||||
}
|
}
|
||||||
|
|
||||||
int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) {
|
|
||||||
switch (tf_lite_type) {
|
|
||||||
case kTfLiteFloat32:
|
|
||||||
return NPY_FLOAT32;
|
|
||||||
case kTfLiteInt32:
|
|
||||||
return NPY_INT32;
|
|
||||||
case kTfLiteInt16:
|
|
||||||
return NPY_INT16;
|
|
||||||
case kTfLiteUInt8:
|
|
||||||
return NPY_UINT8;
|
|
||||||
case kTfLiteInt8:
|
|
||||||
return NPY_INT8;
|
|
||||||
case kTfLiteInt64:
|
|
||||||
return NPY_INT64;
|
|
||||||
case kTfLiteString:
|
|
||||||
return NPY_OBJECT;
|
|
||||||
case kTfLiteBool:
|
|
||||||
return NPY_BOOL;
|
|
||||||
case kTfLiteComplex64:
|
|
||||||
return NPY_COMPLEX64;
|
|
||||||
case kTfLiteNoType:
|
|
||||||
return NPY_NOTYPE;
|
|
||||||
// Avoid default so compiler errors created when new types are made.
|
|
||||||
}
|
|
||||||
return NPY_NOTYPE;
|
|
||||||
}
|
|
||||||
|
|
||||||
TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array) {
|
|
||||||
int pyarray_type = PyArray_TYPE(array);
|
|
||||||
switch (pyarray_type) {
|
|
||||||
case NPY_FLOAT32:
|
|
||||||
return kTfLiteFloat32;
|
|
||||||
case NPY_INT32:
|
|
||||||
return kTfLiteInt32;
|
|
||||||
case NPY_INT16:
|
|
||||||
return kTfLiteInt16;
|
|
||||||
case NPY_UINT8:
|
|
||||||
return kTfLiteUInt8;
|
|
||||||
case NPY_INT8:
|
|
||||||
return kTfLiteInt8;
|
|
||||||
case NPY_INT64:
|
|
||||||
return kTfLiteInt64;
|
|
||||||
case NPY_BOOL:
|
|
||||||
return kTfLiteBool;
|
|
||||||
case NPY_OBJECT:
|
|
||||||
case NPY_STRING:
|
|
||||||
case NPY_UNICODE:
|
|
||||||
return kTfLiteString;
|
|
||||||
case NPY_COMPLEX64:
|
|
||||||
return kTfLiteComplex64;
|
|
||||||
// Avoid default so compiler errors created when new types are made.
|
|
||||||
}
|
|
||||||
return kTfLiteNoType;
|
|
||||||
}
|
|
||||||
|
|
||||||
struct PyDecrefDeleter {
|
struct PyDecrefDeleter {
|
||||||
void operator()(PyObject* p) const { Py_DECREF(p); }
|
void operator()(PyObject* p) const { Py_DECREF(p); }
|
||||||
};
|
};
|
||||||
@ -307,7 +224,7 @@ PyObject* InterpreterWrapper::TensorType(int i) const {
|
|||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
int code = TfLiteTypeToPyArrayType(tensor->type);
|
int code = python_utils::TfLiteTypeToPyArrayType(tensor->type);
|
||||||
if (code == -1) {
|
if (code == -1) {
|
||||||
PyErr_Format(PyExc_ValueError, "Invalid tflite type code %d", code);
|
PyErr_Format(PyExc_ValueError, "Invalid tflite type code %d", code);
|
||||||
return nullptr;
|
return nullptr;
|
||||||
@ -352,12 +269,12 @@ PyObject* InterpreterWrapper::SetTensor(int i, PyObject* value) {
|
|||||||
PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_safe.get());
|
PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_safe.get());
|
||||||
const TfLiteTensor* tensor = interpreter_->tensor(i);
|
const TfLiteTensor* tensor = interpreter_->tensor(i);
|
||||||
|
|
||||||
if (TfLiteTypeFromPyArray(array) != tensor->type) {
|
if (python_utils::TfLiteTypeFromPyArray(array) != tensor->type) {
|
||||||
PyErr_Format(PyExc_ValueError,
|
PyErr_Format(PyExc_ValueError,
|
||||||
"Cannot set tensor:"
|
"Cannot set tensor:"
|
||||||
" Got tensor of type %d"
|
" Got tensor of type %d"
|
||||||
" but expected type %d for input %d ",
|
" but expected type %d for input %d ",
|
||||||
TfLiteTypeFromPyArray(array), tensor->type, i);
|
python_utils::TfLiteTypeFromPyArray(array), tensor->type, i);
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -400,7 +317,7 @@ PyObject* CheckGetTensorArgs(Interpreter* interpreter_, int tensor_index,
|
|||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
*type_num = TfLiteTypeToPyArrayType((*tensor)->type);
|
*type_num = python_utils::TfLiteTypeToPyArrayType((*tensor)->type);
|
||||||
if (*type_num == -1) {
|
if (*type_num == -1) {
|
||||||
PyErr_SetString(PyExc_ValueError, "Unknown tensor type.");
|
PyErr_SetString(PyExc_ValueError, "Unknown tensor type.");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/lite/interpreter.h"
|
#include "tensorflow/lite/interpreter.h"
|
||||||
#include "tensorflow/lite/model.h"
|
#include "tensorflow/lite/model.h"
|
||||||
#include "tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h"
|
#include "tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h"
|
||||||
|
#include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h"
|
||||||
%}
|
%}
|
||||||
|
|
||||||
|
|
||||||
|
@ -0,0 +1,43 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace interpreter_wrapper {
|
||||||
|
|
||||||
|
// Report an error message
|
||||||
|
int PythonErrorReporter::Report(const char* format, va_list args) {
|
||||||
|
char buf[1024];
|
||||||
|
int formatted = vsnprintf(buf, sizeof(buf), format, args);
|
||||||
|
buffer_ << buf;
|
||||||
|
return formatted;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set's a Python runtime exception with the last error.
|
||||||
|
PyObject* PythonErrorReporter::exception() {
|
||||||
|
std::string last_message = message();
|
||||||
|
PyErr_SetString(PyExc_RuntimeError, last_message.c_str());
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Gets the last error message and clears the buffer.
|
||||||
|
std::string PythonErrorReporter::message() {
|
||||||
|
std::string value = buffer_.str();
|
||||||
|
buffer_.clear();
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
} // namespace interpreter_wrapper
|
||||||
|
} // namespace tflite
|
@ -0,0 +1,49 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_ERROR_REPORTER_H_
|
||||||
|
#define TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_ERROR_REPORTER_H_
|
||||||
|
|
||||||
|
#include <Python.h>
|
||||||
|
|
||||||
|
#include <sstream>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "tensorflow/lite/core/api/error_reporter.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace interpreter_wrapper {
|
||||||
|
|
||||||
|
class PythonErrorReporter : public tflite::ErrorReporter {
|
||||||
|
public:
|
||||||
|
PythonErrorReporter() {}
|
||||||
|
|
||||||
|
// Report an error message
|
||||||
|
int Report(const char* format, va_list args) override;
|
||||||
|
|
||||||
|
// Sets a Python runtime exception with the last error and
|
||||||
|
// clears the error message buffer.
|
||||||
|
PyObject* exception();
|
||||||
|
|
||||||
|
// Gets the last error message and clears the buffer.
|
||||||
|
std::string message();
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::stringstream buffer_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace interpreter_wrapper
|
||||||
|
} // namespace tflite
|
||||||
|
#endif // TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_ERROR_REPORTER_H_
|
77
tensorflow/lite/python/interpreter_wrapper/python_utils.cc
Normal file
77
tensorflow/lite/python/interpreter_wrapper/python_utils.cc
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace python_utils {
|
||||||
|
|
||||||
|
int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) {
|
||||||
|
switch (tf_lite_type) {
|
||||||
|
case kTfLiteFloat32:
|
||||||
|
return NPY_FLOAT32;
|
||||||
|
case kTfLiteInt32:
|
||||||
|
return NPY_INT32;
|
||||||
|
case kTfLiteInt16:
|
||||||
|
return NPY_INT16;
|
||||||
|
case kTfLiteUInt8:
|
||||||
|
return NPY_UINT8;
|
||||||
|
case kTfLiteInt8:
|
||||||
|
return NPY_INT8;
|
||||||
|
case kTfLiteInt64:
|
||||||
|
return NPY_INT64;
|
||||||
|
case kTfLiteString:
|
||||||
|
return NPY_OBJECT;
|
||||||
|
case kTfLiteBool:
|
||||||
|
return NPY_BOOL;
|
||||||
|
case kTfLiteComplex64:
|
||||||
|
return NPY_COMPLEX64;
|
||||||
|
case kTfLiteNoType:
|
||||||
|
return NPY_NOTYPE;
|
||||||
|
// Avoid default so compiler errors created when new types are made.
|
||||||
|
}
|
||||||
|
return NPY_NOTYPE;
|
||||||
|
}
|
||||||
|
|
||||||
|
TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array) {
|
||||||
|
int pyarray_type = PyArray_TYPE(array);
|
||||||
|
switch (pyarray_type) {
|
||||||
|
case NPY_FLOAT32:
|
||||||
|
return kTfLiteFloat32;
|
||||||
|
case NPY_INT32:
|
||||||
|
return kTfLiteInt32;
|
||||||
|
case NPY_INT16:
|
||||||
|
return kTfLiteInt16;
|
||||||
|
case NPY_UINT8:
|
||||||
|
return kTfLiteUInt8;
|
||||||
|
case NPY_INT8:
|
||||||
|
return kTfLiteInt8;
|
||||||
|
case NPY_INT64:
|
||||||
|
return kTfLiteInt64;
|
||||||
|
case NPY_BOOL:
|
||||||
|
return kTfLiteBool;
|
||||||
|
case NPY_OBJECT:
|
||||||
|
case NPY_STRING:
|
||||||
|
case NPY_UNICODE:
|
||||||
|
return kTfLiteString;
|
||||||
|
case NPY_COMPLEX64:
|
||||||
|
return kTfLiteComplex64;
|
||||||
|
// Avoid default so compiler errors created when new types are made.
|
||||||
|
}
|
||||||
|
return kTfLiteNoType;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace python_utils
|
||||||
|
} // namespace tflite
|
38
tensorflow/lite/python/interpreter_wrapper/python_utils.h
Normal file
38
tensorflow/lite/python/interpreter_wrapper/python_utils.h
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_UTILS_H_
|
||||||
|
#define TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_UTILS_H_
|
||||||
|
|
||||||
|
#include "tensorflow/lite/context.h"
|
||||||
|
|
||||||
|
// Disallow Numpy 1.7 deprecated symbols.
|
||||||
|
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
|
||||||
|
|
||||||
|
#include <Python.h>
|
||||||
|
|
||||||
|
#include "numpy/arrayobject.h"
|
||||||
|
#include "numpy/ufuncobject.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace python_utils {
|
||||||
|
|
||||||
|
int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type);
|
||||||
|
|
||||||
|
TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array);
|
||||||
|
|
||||||
|
} // namespace python_utils
|
||||||
|
} // namespace tflite
|
||||||
|
#endif // TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_UTILS_H_
|
Loading…
Reference in New Issue
Block a user