Refactor Python utility functions from Python wrapper.
- Moves ErrorReporter to a separate class. - Moves some utility functions to a separate class. PiperOrigin-RevId: 230387612
This commit is contained in:
parent
0b8ac6e1b9
commit
bfcddf733e
@ -11,6 +11,8 @@ cc_library(
|
||||
srcs = ["interpreter_wrapper.cc"],
|
||||
hdrs = ["interpreter_wrapper.h"],
|
||||
deps = [
|
||||
":python_error_reporter",
|
||||
":python_utils",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite/kernels:builtin_ops",
|
||||
"//third_party/py/numpy:headers",
|
||||
@ -19,6 +21,27 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "python_error_reporter",
|
||||
srcs = ["python_error_reporter.cc"],
|
||||
hdrs = ["python_error_reporter.h"],
|
||||
deps = [
|
||||
"//tensorflow/lite/core/api",
|
||||
"//third_party/python_runtime:headers",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "python_utils",
|
||||
srcs = ["python_utils.cc"],
|
||||
hdrs = ["python_utils.h"],
|
||||
deps = [
|
||||
"//tensorflow/lite:framework",
|
||||
"//third_party/py/numpy:headers",
|
||||
"//third_party/python_runtime:headers",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_wrap_cc(
|
||||
name = "tensorflow_wrap_interpreter_wrapper",
|
||||
srcs = [
|
||||
|
@ -21,6 +21,8 @@ limitations under the License.
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/kernels/register.h"
|
||||
#include "tensorflow/lite/model.h"
|
||||
#include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h"
|
||||
#include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
|
||||
|
||||
// Disallow Numpy 1.7 deprecated symbols.
|
||||
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
|
||||
@ -60,36 +62,6 @@ limitations under the License.
|
||||
namespace tflite {
|
||||
namespace interpreter_wrapper {
|
||||
|
||||
class PythonErrorReporter : public tflite::ErrorReporter {
|
||||
public:
|
||||
PythonErrorReporter() {}
|
||||
|
||||
// Report an error message
|
||||
int Report(const char* format, va_list args) override {
|
||||
char buf[1024];
|
||||
int formatted = vsnprintf(buf, sizeof(buf), format, args);
|
||||
buffer_ << buf;
|
||||
return formatted;
|
||||
}
|
||||
|
||||
// Set's a Python runtime exception with the last error.
|
||||
PyObject* exception() {
|
||||
std::string last_message = message();
|
||||
PyErr_SetString(PyExc_RuntimeError, last_message.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Gets the last error message and clears the buffer.
|
||||
std::string message() {
|
||||
std::string value = buffer_.str();
|
||||
buffer_.clear();
|
||||
return value;
|
||||
}
|
||||
|
||||
private:
|
||||
std::stringstream buffer_;
|
||||
};
|
||||
|
||||
namespace {
|
||||
|
||||
// Calls PyArray's initialization to initialize all the API pointers. Note that
|
||||
@ -114,61 +86,6 @@ std::unique_ptr<tflite::Interpreter> CreateInterpreter(
|
||||
return interpreter;
|
||||
}
|
||||
|
||||
int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) {
|
||||
switch (tf_lite_type) {
|
||||
case kTfLiteFloat32:
|
||||
return NPY_FLOAT32;
|
||||
case kTfLiteInt32:
|
||||
return NPY_INT32;
|
||||
case kTfLiteInt16:
|
||||
return NPY_INT16;
|
||||
case kTfLiteUInt8:
|
||||
return NPY_UINT8;
|
||||
case kTfLiteInt8:
|
||||
return NPY_INT8;
|
||||
case kTfLiteInt64:
|
||||
return NPY_INT64;
|
||||
case kTfLiteString:
|
||||
return NPY_OBJECT;
|
||||
case kTfLiteBool:
|
||||
return NPY_BOOL;
|
||||
case kTfLiteComplex64:
|
||||
return NPY_COMPLEX64;
|
||||
case kTfLiteNoType:
|
||||
return NPY_NOTYPE;
|
||||
// Avoid default so compiler errors created when new types are made.
|
||||
}
|
||||
return NPY_NOTYPE;
|
||||
}
|
||||
|
||||
TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array) {
|
||||
int pyarray_type = PyArray_TYPE(array);
|
||||
switch (pyarray_type) {
|
||||
case NPY_FLOAT32:
|
||||
return kTfLiteFloat32;
|
||||
case NPY_INT32:
|
||||
return kTfLiteInt32;
|
||||
case NPY_INT16:
|
||||
return kTfLiteInt16;
|
||||
case NPY_UINT8:
|
||||
return kTfLiteUInt8;
|
||||
case NPY_INT8:
|
||||
return kTfLiteInt8;
|
||||
case NPY_INT64:
|
||||
return kTfLiteInt64;
|
||||
case NPY_BOOL:
|
||||
return kTfLiteBool;
|
||||
case NPY_OBJECT:
|
||||
case NPY_STRING:
|
||||
case NPY_UNICODE:
|
||||
return kTfLiteString;
|
||||
case NPY_COMPLEX64:
|
||||
return kTfLiteComplex64;
|
||||
// Avoid default so compiler errors created when new types are made.
|
||||
}
|
||||
return kTfLiteNoType;
|
||||
}
|
||||
|
||||
struct PyDecrefDeleter {
|
||||
void operator()(PyObject* p) const { Py_DECREF(p); }
|
||||
};
|
||||
@ -307,7 +224,7 @@ PyObject* InterpreterWrapper::TensorType(int i) const {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
int code = TfLiteTypeToPyArrayType(tensor->type);
|
||||
int code = python_utils::TfLiteTypeToPyArrayType(tensor->type);
|
||||
if (code == -1) {
|
||||
PyErr_Format(PyExc_ValueError, "Invalid tflite type code %d", code);
|
||||
return nullptr;
|
||||
@ -352,12 +269,12 @@ PyObject* InterpreterWrapper::SetTensor(int i, PyObject* value) {
|
||||
PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_safe.get());
|
||||
const TfLiteTensor* tensor = interpreter_->tensor(i);
|
||||
|
||||
if (TfLiteTypeFromPyArray(array) != tensor->type) {
|
||||
if (python_utils::TfLiteTypeFromPyArray(array) != tensor->type) {
|
||||
PyErr_Format(PyExc_ValueError,
|
||||
"Cannot set tensor:"
|
||||
" Got tensor of type %d"
|
||||
" but expected type %d for input %d ",
|
||||
TfLiteTypeFromPyArray(array), tensor->type, i);
|
||||
python_utils::TfLiteTypeFromPyArray(array), tensor->type, i);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
@ -400,7 +317,7 @@ PyObject* CheckGetTensorArgs(Interpreter* interpreter_, int tensor_index,
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
*type_num = TfLiteTypeToPyArrayType((*tensor)->type);
|
||||
*type_num = python_utils::TfLiteTypeToPyArrayType((*tensor)->type);
|
||||
if (*type_num == -1) {
|
||||
PyErr_SetString(PyExc_ValueError, "Unknown tensor type.");
|
||||
return nullptr;
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/model.h"
|
||||
#include "tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h"
|
||||
#include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h"
|
||||
%}
|
||||
|
||||
|
||||
|
@ -0,0 +1,43 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace interpreter_wrapper {
|
||||
|
||||
// Report an error message
|
||||
int PythonErrorReporter::Report(const char* format, va_list args) {
|
||||
char buf[1024];
|
||||
int formatted = vsnprintf(buf, sizeof(buf), format, args);
|
||||
buffer_ << buf;
|
||||
return formatted;
|
||||
}
|
||||
|
||||
// Set's a Python runtime exception with the last error.
|
||||
PyObject* PythonErrorReporter::exception() {
|
||||
std::string last_message = message();
|
||||
PyErr_SetString(PyExc_RuntimeError, last_message.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Gets the last error message and clears the buffer.
|
||||
std::string PythonErrorReporter::message() {
|
||||
std::string value = buffer_.str();
|
||||
buffer_.clear();
|
||||
return value;
|
||||
}
|
||||
} // namespace interpreter_wrapper
|
||||
} // namespace tflite
|
@ -0,0 +1,49 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_ERROR_REPORTER_H_
|
||||
#define TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_ERROR_REPORTER_H_
|
||||
|
||||
#include <Python.h>
|
||||
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/lite/core/api/error_reporter.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace interpreter_wrapper {
|
||||
|
||||
class PythonErrorReporter : public tflite::ErrorReporter {
|
||||
public:
|
||||
PythonErrorReporter() {}
|
||||
|
||||
// Report an error message
|
||||
int Report(const char* format, va_list args) override;
|
||||
|
||||
// Sets a Python runtime exception with the last error and
|
||||
// clears the error message buffer.
|
||||
PyObject* exception();
|
||||
|
||||
// Gets the last error message and clears the buffer.
|
||||
std::string message();
|
||||
|
||||
private:
|
||||
std::stringstream buffer_;
|
||||
};
|
||||
|
||||
} // namespace interpreter_wrapper
|
||||
} // namespace tflite
|
||||
#endif // TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_ERROR_REPORTER_H_
|
77
tensorflow/lite/python/interpreter_wrapper/python_utils.cc
Normal file
77
tensorflow/lite/python/interpreter_wrapper/python_utils.cc
Normal file
@ -0,0 +1,77 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace python_utils {
|
||||
|
||||
int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) {
|
||||
switch (tf_lite_type) {
|
||||
case kTfLiteFloat32:
|
||||
return NPY_FLOAT32;
|
||||
case kTfLiteInt32:
|
||||
return NPY_INT32;
|
||||
case kTfLiteInt16:
|
||||
return NPY_INT16;
|
||||
case kTfLiteUInt8:
|
||||
return NPY_UINT8;
|
||||
case kTfLiteInt8:
|
||||
return NPY_INT8;
|
||||
case kTfLiteInt64:
|
||||
return NPY_INT64;
|
||||
case kTfLiteString:
|
||||
return NPY_OBJECT;
|
||||
case kTfLiteBool:
|
||||
return NPY_BOOL;
|
||||
case kTfLiteComplex64:
|
||||
return NPY_COMPLEX64;
|
||||
case kTfLiteNoType:
|
||||
return NPY_NOTYPE;
|
||||
// Avoid default so compiler errors created when new types are made.
|
||||
}
|
||||
return NPY_NOTYPE;
|
||||
}
|
||||
|
||||
TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array) {
|
||||
int pyarray_type = PyArray_TYPE(array);
|
||||
switch (pyarray_type) {
|
||||
case NPY_FLOAT32:
|
||||
return kTfLiteFloat32;
|
||||
case NPY_INT32:
|
||||
return kTfLiteInt32;
|
||||
case NPY_INT16:
|
||||
return kTfLiteInt16;
|
||||
case NPY_UINT8:
|
||||
return kTfLiteUInt8;
|
||||
case NPY_INT8:
|
||||
return kTfLiteInt8;
|
||||
case NPY_INT64:
|
||||
return kTfLiteInt64;
|
||||
case NPY_BOOL:
|
||||
return kTfLiteBool;
|
||||
case NPY_OBJECT:
|
||||
case NPY_STRING:
|
||||
case NPY_UNICODE:
|
||||
return kTfLiteString;
|
||||
case NPY_COMPLEX64:
|
||||
return kTfLiteComplex64;
|
||||
// Avoid default so compiler errors created when new types are made.
|
||||
}
|
||||
return kTfLiteNoType;
|
||||
}
|
||||
|
||||
} // namespace python_utils
|
||||
} // namespace tflite
|
38
tensorflow/lite/python/interpreter_wrapper/python_utils.h
Normal file
38
tensorflow/lite/python/interpreter_wrapper/python_utils.h
Normal file
@ -0,0 +1,38 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_UTILS_H_
|
||||
#define TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_UTILS_H_
|
||||
|
||||
#include "tensorflow/lite/context.h"
|
||||
|
||||
// Disallow Numpy 1.7 deprecated symbols.
|
||||
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
|
||||
|
||||
#include <Python.h>
|
||||
|
||||
#include "numpy/arrayobject.h"
|
||||
#include "numpy/ufuncobject.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace python_utils {
|
||||
|
||||
int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type);
|
||||
|
||||
TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array);
|
||||
|
||||
} // namespace python_utils
|
||||
} // namespace tflite
|
||||
#endif // TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_UTILS_H_
|
Loading…
Reference in New Issue
Block a user