From bfcddf733e85c790921afccec4ca80849e9eb9ab Mon Sep 17 00:00:00 2001 From: Shashi Shekhar <shashishekhar@google.com> Date: Tue, 22 Jan 2019 12:22:53 -0800 Subject: [PATCH] Refactor Python utility functions from Python wrapper. - Moves ErrorReporter to a separate class. - Moves some utility functions to a separate class. PiperOrigin-RevId: 230387612 --- .../lite/python/interpreter_wrapper/BUILD | 23 +++++ .../interpreter_wrapper.cc | 95 ++----------------- .../interpreter_wrapper/interpreter_wrapper.i | 1 + .../python_error_reporter.cc | 43 +++++++++ .../python_error_reporter.h | 49 ++++++++++ .../interpreter_wrapper/python_utils.cc | 77 +++++++++++++++ .../python/interpreter_wrapper/python_utils.h | 38 ++++++++ 7 files changed, 237 insertions(+), 89 deletions(-) create mode 100644 tensorflow/lite/python/interpreter_wrapper/python_error_reporter.cc create mode 100644 tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h create mode 100644 tensorflow/lite/python/interpreter_wrapper/python_utils.cc create mode 100644 tensorflow/lite/python/interpreter_wrapper/python_utils.h diff --git a/tensorflow/lite/python/interpreter_wrapper/BUILD b/tensorflow/lite/python/interpreter_wrapper/BUILD index 767a9fc4763..6de6fb48f78 100644 --- a/tensorflow/lite/python/interpreter_wrapper/BUILD +++ b/tensorflow/lite/python/interpreter_wrapper/BUILD @@ -11,6 +11,8 @@ cc_library( srcs = ["interpreter_wrapper.cc"], hdrs = ["interpreter_wrapper.h"], deps = [ + ":python_error_reporter", + ":python_utils", "//tensorflow/lite:framework", "//tensorflow/lite/kernels:builtin_ops", "//third_party/py/numpy:headers", @@ -19,6 +21,27 @@ cc_library( ], ) +cc_library( + name = "python_error_reporter", + srcs = ["python_error_reporter.cc"], + hdrs = ["python_error_reporter.h"], + deps = [ + "//tensorflow/lite/core/api", + "//third_party/python_runtime:headers", + ], +) + +cc_library( + name = "python_utils", + srcs = ["python_utils.cc"], + hdrs = ["python_utils.h"], + deps = [ + "//tensorflow/lite:framework", + "//third_party/py/numpy:headers", + "//third_party/python_runtime:headers", + ], +) + tf_py_wrap_cc( name = "tensorflow_wrap_interpreter_wrapper", srcs = [ diff --git a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc index d14af439ec0..9ccaabbfe97 100644 --- a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc +++ b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc @@ -21,6 +21,8 @@ limitations under the License. #include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/model.h" +#include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h" +#include "tensorflow/lite/python/interpreter_wrapper/python_utils.h" // Disallow Numpy 1.7 deprecated symbols. #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION @@ -60,36 +62,6 @@ limitations under the License. namespace tflite { namespace interpreter_wrapper { -class PythonErrorReporter : public tflite::ErrorReporter { - public: - PythonErrorReporter() {} - - // Report an error message - int Report(const char* format, va_list args) override { - char buf[1024]; - int formatted = vsnprintf(buf, sizeof(buf), format, args); - buffer_ << buf; - return formatted; - } - - // Set's a Python runtime exception with the last error. - PyObject* exception() { - std::string last_message = message(); - PyErr_SetString(PyExc_RuntimeError, last_message.c_str()); - return nullptr; - } - - // Gets the last error message and clears the buffer. - std::string message() { - std::string value = buffer_.str(); - buffer_.clear(); - return value; - } - - private: - std::stringstream buffer_; -}; - namespace { // Calls PyArray's initialization to initialize all the API pointers. Note that @@ -114,61 +86,6 @@ std::unique_ptr<tflite::Interpreter> CreateInterpreter( return interpreter; } -int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) { - switch (tf_lite_type) { - case kTfLiteFloat32: - return NPY_FLOAT32; - case kTfLiteInt32: - return NPY_INT32; - case kTfLiteInt16: - return NPY_INT16; - case kTfLiteUInt8: - return NPY_UINT8; - case kTfLiteInt8: - return NPY_INT8; - case kTfLiteInt64: - return NPY_INT64; - case kTfLiteString: - return NPY_OBJECT; - case kTfLiteBool: - return NPY_BOOL; - case kTfLiteComplex64: - return NPY_COMPLEX64; - case kTfLiteNoType: - return NPY_NOTYPE; - // Avoid default so compiler errors created when new types are made. - } - return NPY_NOTYPE; -} - -TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array) { - int pyarray_type = PyArray_TYPE(array); - switch (pyarray_type) { - case NPY_FLOAT32: - return kTfLiteFloat32; - case NPY_INT32: - return kTfLiteInt32; - case NPY_INT16: - return kTfLiteInt16; - case NPY_UINT8: - return kTfLiteUInt8; - case NPY_INT8: - return kTfLiteInt8; - case NPY_INT64: - return kTfLiteInt64; - case NPY_BOOL: - return kTfLiteBool; - case NPY_OBJECT: - case NPY_STRING: - case NPY_UNICODE: - return kTfLiteString; - case NPY_COMPLEX64: - return kTfLiteComplex64; - // Avoid default so compiler errors created when new types are made. - } - return kTfLiteNoType; -} - struct PyDecrefDeleter { void operator()(PyObject* p) const { Py_DECREF(p); } }; @@ -307,7 +224,7 @@ PyObject* InterpreterWrapper::TensorType(int i) const { return nullptr; } - int code = TfLiteTypeToPyArrayType(tensor->type); + int code = python_utils::TfLiteTypeToPyArrayType(tensor->type); if (code == -1) { PyErr_Format(PyExc_ValueError, "Invalid tflite type code %d", code); return nullptr; @@ -352,12 +269,12 @@ PyObject* InterpreterWrapper::SetTensor(int i, PyObject* value) { PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_safe.get()); const TfLiteTensor* tensor = interpreter_->tensor(i); - if (TfLiteTypeFromPyArray(array) != tensor->type) { + if (python_utils::TfLiteTypeFromPyArray(array) != tensor->type) { PyErr_Format(PyExc_ValueError, "Cannot set tensor:" " Got tensor of type %d" " but expected type %d for input %d ", - TfLiteTypeFromPyArray(array), tensor->type, i); + python_utils::TfLiteTypeFromPyArray(array), tensor->type, i); return nullptr; } @@ -400,7 +317,7 @@ PyObject* CheckGetTensorArgs(Interpreter* interpreter_, int tensor_index, return nullptr; } - *type_num = TfLiteTypeToPyArrayType((*tensor)->type); + *type_num = python_utils::TfLiteTypeToPyArrayType((*tensor)->type); if (*type_num == -1) { PyErr_SetString(PyExc_ValueError, "Unknown tensor type."); return nullptr; diff --git a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.i b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.i index f52ef1eeca7..ef4b28f0472 100644 --- a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.i +++ b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.i @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/model.h" #include "tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h" +#include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h" %} diff --git a/tensorflow/lite/python/interpreter_wrapper/python_error_reporter.cc b/tensorflow/lite/python/interpreter_wrapper/python_error_reporter.cc new file mode 100644 index 00000000000..803a4c29345 --- /dev/null +++ b/tensorflow/lite/python/interpreter_wrapper/python_error_reporter.cc @@ -0,0 +1,43 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h" + +namespace tflite { +namespace interpreter_wrapper { + +// Report an error message +int PythonErrorReporter::Report(const char* format, va_list args) { + char buf[1024]; + int formatted = vsnprintf(buf, sizeof(buf), format, args); + buffer_ << buf; + return formatted; +} + +// Set's a Python runtime exception with the last error. +PyObject* PythonErrorReporter::exception() { + std::string last_message = message(); + PyErr_SetString(PyExc_RuntimeError, last_message.c_str()); + return nullptr; +} + +// Gets the last error message and clears the buffer. +std::string PythonErrorReporter::message() { + std::string value = buffer_.str(); + buffer_.clear(); + return value; +} +} // namespace interpreter_wrapper +} // namespace tflite diff --git a/tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h b/tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h new file mode 100644 index 00000000000..7d4e308834a --- /dev/null +++ b/tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h @@ -0,0 +1,49 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_ERROR_REPORTER_H_ +#define TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_ERROR_REPORTER_H_ + +#include <Python.h> + +#include <sstream> +#include <string> + +#include "tensorflow/lite/core/api/error_reporter.h" + +namespace tflite { +namespace interpreter_wrapper { + +class PythonErrorReporter : public tflite::ErrorReporter { + public: + PythonErrorReporter() {} + + // Report an error message + int Report(const char* format, va_list args) override; + + // Sets a Python runtime exception with the last error and + // clears the error message buffer. + PyObject* exception(); + + // Gets the last error message and clears the buffer. + std::string message(); + + private: + std::stringstream buffer_; +}; + +} // namespace interpreter_wrapper +} // namespace tflite +#endif // TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_ERROR_REPORTER_H_ diff --git a/tensorflow/lite/python/interpreter_wrapper/python_utils.cc b/tensorflow/lite/python/interpreter_wrapper/python_utils.cc new file mode 100644 index 00000000000..2dc604356ab --- /dev/null +++ b/tensorflow/lite/python/interpreter_wrapper/python_utils.cc @@ -0,0 +1,77 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/python/interpreter_wrapper/python_utils.h" + +namespace tflite { +namespace python_utils { + +int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) { + switch (tf_lite_type) { + case kTfLiteFloat32: + return NPY_FLOAT32; + case kTfLiteInt32: + return NPY_INT32; + case kTfLiteInt16: + return NPY_INT16; + case kTfLiteUInt8: + return NPY_UINT8; + case kTfLiteInt8: + return NPY_INT8; + case kTfLiteInt64: + return NPY_INT64; + case kTfLiteString: + return NPY_OBJECT; + case kTfLiteBool: + return NPY_BOOL; + case kTfLiteComplex64: + return NPY_COMPLEX64; + case kTfLiteNoType: + return NPY_NOTYPE; + // Avoid default so compiler errors created when new types are made. + } + return NPY_NOTYPE; +} + +TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array) { + int pyarray_type = PyArray_TYPE(array); + switch (pyarray_type) { + case NPY_FLOAT32: + return kTfLiteFloat32; + case NPY_INT32: + return kTfLiteInt32; + case NPY_INT16: + return kTfLiteInt16; + case NPY_UINT8: + return kTfLiteUInt8; + case NPY_INT8: + return kTfLiteInt8; + case NPY_INT64: + return kTfLiteInt64; + case NPY_BOOL: + return kTfLiteBool; + case NPY_OBJECT: + case NPY_STRING: + case NPY_UNICODE: + return kTfLiteString; + case NPY_COMPLEX64: + return kTfLiteComplex64; + // Avoid default so compiler errors created when new types are made. + } + return kTfLiteNoType; +} + +} // namespace python_utils +} // namespace tflite diff --git a/tensorflow/lite/python/interpreter_wrapper/python_utils.h b/tensorflow/lite/python/interpreter_wrapper/python_utils.h new file mode 100644 index 00000000000..30a44226b8f --- /dev/null +++ b/tensorflow/lite/python/interpreter_wrapper/python_utils.h @@ -0,0 +1,38 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_UTILS_H_ +#define TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_UTILS_H_ + +#include "tensorflow/lite/context.h" + +// Disallow Numpy 1.7 deprecated symbols. +#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION + +#include <Python.h> + +#include "numpy/arrayobject.h" +#include "numpy/ufuncobject.h" + +namespace tflite { +namespace python_utils { + +int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type); + +TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array); + +} // namespace python_utils +} // namespace tflite +#endif // TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_UTILS_H_