Refactor Python utility functions from Python wrapper.
- Moves ErrorReporter to a separate class. - Moves some utility functions to a separate class. PiperOrigin-RevId: 230387612
This commit is contained in:
		
							parent
							
								
									0b8ac6e1b9
								
							
						
					
					
						commit
						bfcddf733e
					
				| @ -11,6 +11,8 @@ cc_library( | |||||||
|     srcs = ["interpreter_wrapper.cc"], |     srcs = ["interpreter_wrapper.cc"], | ||||||
|     hdrs = ["interpreter_wrapper.h"], |     hdrs = ["interpreter_wrapper.h"], | ||||||
|     deps = [ |     deps = [ | ||||||
|  |         ":python_error_reporter", | ||||||
|  |         ":python_utils", | ||||||
|         "//tensorflow/lite:framework", |         "//tensorflow/lite:framework", | ||||||
|         "//tensorflow/lite/kernels:builtin_ops", |         "//tensorflow/lite/kernels:builtin_ops", | ||||||
|         "//third_party/py/numpy:headers", |         "//third_party/py/numpy:headers", | ||||||
| @ -19,6 +21,27 @@ cc_library( | |||||||
|     ], |     ], | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | cc_library( | ||||||
|  |     name = "python_error_reporter", | ||||||
|  |     srcs = ["python_error_reporter.cc"], | ||||||
|  |     hdrs = ["python_error_reporter.h"], | ||||||
|  |     deps = [ | ||||||
|  |         "//tensorflow/lite/core/api", | ||||||
|  |         "//third_party/python_runtime:headers", | ||||||
|  |     ], | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | cc_library( | ||||||
|  |     name = "python_utils", | ||||||
|  |     srcs = ["python_utils.cc"], | ||||||
|  |     hdrs = ["python_utils.h"], | ||||||
|  |     deps = [ | ||||||
|  |         "//tensorflow/lite:framework", | ||||||
|  |         "//third_party/py/numpy:headers", | ||||||
|  |         "//third_party/python_runtime:headers", | ||||||
|  |     ], | ||||||
|  | ) | ||||||
|  | 
 | ||||||
| tf_py_wrap_cc( | tf_py_wrap_cc( | ||||||
|     name = "tensorflow_wrap_interpreter_wrapper", |     name = "tensorflow_wrap_interpreter_wrapper", | ||||||
|     srcs = [ |     srcs = [ | ||||||
|  | |||||||
| @ -21,6 +21,8 @@ limitations under the License. | |||||||
| #include "tensorflow/lite/interpreter.h" | #include "tensorflow/lite/interpreter.h" | ||||||
| #include "tensorflow/lite/kernels/register.h" | #include "tensorflow/lite/kernels/register.h" | ||||||
| #include "tensorflow/lite/model.h" | #include "tensorflow/lite/model.h" | ||||||
|  | #include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h" | ||||||
|  | #include "tensorflow/lite/python/interpreter_wrapper/python_utils.h" | ||||||
| 
 | 
 | ||||||
| // Disallow Numpy 1.7 deprecated symbols.
 | // Disallow Numpy 1.7 deprecated symbols.
 | ||||||
| #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION | #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION | ||||||
| @ -60,36 +62,6 @@ limitations under the License. | |||||||
| namespace tflite { | namespace tflite { | ||||||
| namespace interpreter_wrapper { | namespace interpreter_wrapper { | ||||||
| 
 | 
 | ||||||
| class PythonErrorReporter : public tflite::ErrorReporter { |  | ||||||
|  public: |  | ||||||
|   PythonErrorReporter() {} |  | ||||||
| 
 |  | ||||||
|   // Report an error message
 |  | ||||||
|   int Report(const char* format, va_list args) override { |  | ||||||
|     char buf[1024]; |  | ||||||
|     int formatted = vsnprintf(buf, sizeof(buf), format, args); |  | ||||||
|     buffer_ << buf; |  | ||||||
|     return formatted; |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   // Set's a Python runtime exception with the last error.
 |  | ||||||
|   PyObject* exception() { |  | ||||||
|     std::string last_message = message(); |  | ||||||
|     PyErr_SetString(PyExc_RuntimeError, last_message.c_str()); |  | ||||||
|     return nullptr; |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   // Gets the last error message and clears the buffer.
 |  | ||||||
|   std::string message() { |  | ||||||
|     std::string value = buffer_.str(); |  | ||||||
|     buffer_.clear(); |  | ||||||
|     return value; |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|  private: |  | ||||||
|   std::stringstream buffer_; |  | ||||||
| }; |  | ||||||
| 
 |  | ||||||
| namespace { | namespace { | ||||||
| 
 | 
 | ||||||
| // Calls PyArray's initialization to initialize all the API pointers. Note that
 | // Calls PyArray's initialization to initialize all the API pointers. Note that
 | ||||||
| @ -114,61 +86,6 @@ std::unique_ptr<tflite::Interpreter> CreateInterpreter( | |||||||
|   return interpreter; |   return interpreter; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) { |  | ||||||
|   switch (tf_lite_type) { |  | ||||||
|     case kTfLiteFloat32: |  | ||||||
|       return NPY_FLOAT32; |  | ||||||
|     case kTfLiteInt32: |  | ||||||
|       return NPY_INT32; |  | ||||||
|     case kTfLiteInt16: |  | ||||||
|       return NPY_INT16; |  | ||||||
|     case kTfLiteUInt8: |  | ||||||
|       return NPY_UINT8; |  | ||||||
|     case kTfLiteInt8: |  | ||||||
|       return NPY_INT8; |  | ||||||
|     case kTfLiteInt64: |  | ||||||
|       return NPY_INT64; |  | ||||||
|     case kTfLiteString: |  | ||||||
|       return NPY_OBJECT; |  | ||||||
|     case kTfLiteBool: |  | ||||||
|       return NPY_BOOL; |  | ||||||
|     case kTfLiteComplex64: |  | ||||||
|       return NPY_COMPLEX64; |  | ||||||
|     case kTfLiteNoType: |  | ||||||
|       return NPY_NOTYPE; |  | ||||||
|       // Avoid default so compiler errors created when new types are made.
 |  | ||||||
|   } |  | ||||||
|   return NPY_NOTYPE; |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array) { |  | ||||||
|   int pyarray_type = PyArray_TYPE(array); |  | ||||||
|   switch (pyarray_type) { |  | ||||||
|     case NPY_FLOAT32: |  | ||||||
|       return kTfLiteFloat32; |  | ||||||
|     case NPY_INT32: |  | ||||||
|       return kTfLiteInt32; |  | ||||||
|     case NPY_INT16: |  | ||||||
|       return kTfLiteInt16; |  | ||||||
|     case NPY_UINT8: |  | ||||||
|       return kTfLiteUInt8; |  | ||||||
|     case NPY_INT8: |  | ||||||
|       return kTfLiteInt8; |  | ||||||
|     case NPY_INT64: |  | ||||||
|       return kTfLiteInt64; |  | ||||||
|     case NPY_BOOL: |  | ||||||
|       return kTfLiteBool; |  | ||||||
|     case NPY_OBJECT: |  | ||||||
|     case NPY_STRING: |  | ||||||
|     case NPY_UNICODE: |  | ||||||
|       return kTfLiteString; |  | ||||||
|     case NPY_COMPLEX64: |  | ||||||
|       return kTfLiteComplex64; |  | ||||||
|       // Avoid default so compiler errors created when new types are made.
 |  | ||||||
|   } |  | ||||||
|   return kTfLiteNoType; |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| struct PyDecrefDeleter { | struct PyDecrefDeleter { | ||||||
|   void operator()(PyObject* p) const { Py_DECREF(p); } |   void operator()(PyObject* p) const { Py_DECREF(p); } | ||||||
| }; | }; | ||||||
| @ -307,7 +224,7 @@ PyObject* InterpreterWrapper::TensorType(int i) const { | |||||||
|     return nullptr; |     return nullptr; | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   int code = TfLiteTypeToPyArrayType(tensor->type); |   int code = python_utils::TfLiteTypeToPyArrayType(tensor->type); | ||||||
|   if (code == -1) { |   if (code == -1) { | ||||||
|     PyErr_Format(PyExc_ValueError, "Invalid tflite type code %d", code); |     PyErr_Format(PyExc_ValueError, "Invalid tflite type code %d", code); | ||||||
|     return nullptr; |     return nullptr; | ||||||
| @ -352,12 +269,12 @@ PyObject* InterpreterWrapper::SetTensor(int i, PyObject* value) { | |||||||
|   PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_safe.get()); |   PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_safe.get()); | ||||||
|   const TfLiteTensor* tensor = interpreter_->tensor(i); |   const TfLiteTensor* tensor = interpreter_->tensor(i); | ||||||
| 
 | 
 | ||||||
|   if (TfLiteTypeFromPyArray(array) != tensor->type) { |   if (python_utils::TfLiteTypeFromPyArray(array) != tensor->type) { | ||||||
|     PyErr_Format(PyExc_ValueError, |     PyErr_Format(PyExc_ValueError, | ||||||
|                  "Cannot set tensor:" |                  "Cannot set tensor:" | ||||||
|                  " Got tensor of type %d" |                  " Got tensor of type %d" | ||||||
|                  " but expected type %d for input %d ", |                  " but expected type %d for input %d ", | ||||||
|                  TfLiteTypeFromPyArray(array), tensor->type, i); |                  python_utils::TfLiteTypeFromPyArray(array), tensor->type, i); | ||||||
|     return nullptr; |     return nullptr; | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
| @ -400,7 +317,7 @@ PyObject* CheckGetTensorArgs(Interpreter* interpreter_, int tensor_index, | |||||||
|     return nullptr; |     return nullptr; | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   *type_num = TfLiteTypeToPyArrayType((*tensor)->type); |   *type_num = python_utils::TfLiteTypeToPyArrayType((*tensor)->type); | ||||||
|   if (*type_num == -1) { |   if (*type_num == -1) { | ||||||
|     PyErr_SetString(PyExc_ValueError, "Unknown tensor type."); |     PyErr_SetString(PyExc_ValueError, "Unknown tensor type."); | ||||||
|     return nullptr; |     return nullptr; | ||||||
|  | |||||||
| @ -21,6 +21,7 @@ limitations under the License. | |||||||
| #include "tensorflow/lite/interpreter.h" | #include "tensorflow/lite/interpreter.h" | ||||||
| #include "tensorflow/lite/model.h" | #include "tensorflow/lite/model.h" | ||||||
| #include "tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h" | #include "tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h" | ||||||
|  | #include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h" | ||||||
| %} | %} | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -0,0 +1,43 @@ | |||||||
|  | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 | ||||||
|  | 
 | ||||||
|  | Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | you may not use this file except in compliance with the License. | ||||||
|  | You may obtain a copy of the License at | ||||||
|  | 
 | ||||||
|  |     http://www.apache.org/licenses/LICENSE-2.0
 | ||||||
|  | 
 | ||||||
|  | Unless required by applicable law or agreed to in writing, software | ||||||
|  | distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | See the License for the specific language governing permissions and | ||||||
|  | limitations under the License. | ||||||
|  | ==============================================================================*/ | ||||||
|  | 
 | ||||||
|  | #include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h" | ||||||
|  | 
 | ||||||
|  | namespace tflite { | ||||||
|  | namespace interpreter_wrapper { | ||||||
|  | 
 | ||||||
|  | // Report an error message
 | ||||||
|  | int PythonErrorReporter::Report(const char* format, va_list args) { | ||||||
|  |   char buf[1024]; | ||||||
|  |   int formatted = vsnprintf(buf, sizeof(buf), format, args); | ||||||
|  |   buffer_ << buf; | ||||||
|  |   return formatted; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Set's a Python runtime exception with the last error.
 | ||||||
|  | PyObject* PythonErrorReporter::exception() { | ||||||
|  |   std::string last_message = message(); | ||||||
|  |   PyErr_SetString(PyExc_RuntimeError, last_message.c_str()); | ||||||
|  |   return nullptr; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Gets the last error message and clears the buffer.
 | ||||||
|  | std::string PythonErrorReporter::message() { | ||||||
|  |   std::string value = buffer_.str(); | ||||||
|  |   buffer_.clear(); | ||||||
|  |   return value; | ||||||
|  | } | ||||||
|  | }  // namespace interpreter_wrapper
 | ||||||
|  | }  // namespace tflite
 | ||||||
| @ -0,0 +1,49 @@ | |||||||
|  | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 | ||||||
|  | 
 | ||||||
|  | Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | you may not use this file except in compliance with the License. | ||||||
|  | You may obtain a copy of the License at | ||||||
|  | 
 | ||||||
|  |     http://www.apache.org/licenses/LICENSE-2.0
 | ||||||
|  | 
 | ||||||
|  | Unless required by applicable law or agreed to in writing, software | ||||||
|  | distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | See the License for the specific language governing permissions and | ||||||
|  | limitations under the License. | ||||||
|  | ==============================================================================*/ | ||||||
|  | 
 | ||||||
|  | #ifndef TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_ERROR_REPORTER_H_ | ||||||
|  | #define TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_ERROR_REPORTER_H_ | ||||||
|  | 
 | ||||||
|  | #include <Python.h> | ||||||
|  | 
 | ||||||
|  | #include <sstream> | ||||||
|  | #include <string> | ||||||
|  | 
 | ||||||
|  | #include "tensorflow/lite/core/api/error_reporter.h" | ||||||
|  | 
 | ||||||
|  | namespace tflite { | ||||||
|  | namespace interpreter_wrapper { | ||||||
|  | 
 | ||||||
|  | class PythonErrorReporter : public tflite::ErrorReporter { | ||||||
|  |  public: | ||||||
|  |   PythonErrorReporter() {} | ||||||
|  | 
 | ||||||
|  |   // Report an error message
 | ||||||
|  |   int Report(const char* format, va_list args) override; | ||||||
|  | 
 | ||||||
|  |   // Sets a Python runtime exception with the last error and
 | ||||||
|  |   // clears the error message buffer.
 | ||||||
|  |   PyObject* exception(); | ||||||
|  | 
 | ||||||
|  |   // Gets the last error message and clears the buffer.
 | ||||||
|  |   std::string message(); | ||||||
|  | 
 | ||||||
|  |  private: | ||||||
|  |   std::stringstream buffer_; | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | }  // namespace interpreter_wrapper
 | ||||||
|  | }  // namespace tflite
 | ||||||
|  | #endif  // TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_ERROR_REPORTER_H_
 | ||||||
							
								
								
									
										77
									
								
								tensorflow/lite/python/interpreter_wrapper/python_utils.cc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										77
									
								
								tensorflow/lite/python/interpreter_wrapper/python_utils.cc
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,77 @@ | |||||||
|  | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 | ||||||
|  | 
 | ||||||
|  | Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | you may not use this file except in compliance with the License. | ||||||
|  | You may obtain a copy of the License at | ||||||
|  | 
 | ||||||
|  |     http://www.apache.org/licenses/LICENSE-2.0
 | ||||||
|  | 
 | ||||||
|  | Unless required by applicable law or agreed to in writing, software | ||||||
|  | distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | See the License for the specific language governing permissions and | ||||||
|  | limitations under the License. | ||||||
|  | ==============================================================================*/ | ||||||
|  | 
 | ||||||
|  | #include "tensorflow/lite/python/interpreter_wrapper/python_utils.h" | ||||||
|  | 
 | ||||||
|  | namespace tflite { | ||||||
|  | namespace python_utils { | ||||||
|  | 
 | ||||||
|  | int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) { | ||||||
|  |   switch (tf_lite_type) { | ||||||
|  |     case kTfLiteFloat32: | ||||||
|  |       return NPY_FLOAT32; | ||||||
|  |     case kTfLiteInt32: | ||||||
|  |       return NPY_INT32; | ||||||
|  |     case kTfLiteInt16: | ||||||
|  |       return NPY_INT16; | ||||||
|  |     case kTfLiteUInt8: | ||||||
|  |       return NPY_UINT8; | ||||||
|  |     case kTfLiteInt8: | ||||||
|  |       return NPY_INT8; | ||||||
|  |     case kTfLiteInt64: | ||||||
|  |       return NPY_INT64; | ||||||
|  |     case kTfLiteString: | ||||||
|  |       return NPY_OBJECT; | ||||||
|  |     case kTfLiteBool: | ||||||
|  |       return NPY_BOOL; | ||||||
|  |     case kTfLiteComplex64: | ||||||
|  |       return NPY_COMPLEX64; | ||||||
|  |     case kTfLiteNoType: | ||||||
|  |       return NPY_NOTYPE; | ||||||
|  |       // Avoid default so compiler errors created when new types are made.
 | ||||||
|  |   } | ||||||
|  |   return NPY_NOTYPE; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array) { | ||||||
|  |   int pyarray_type = PyArray_TYPE(array); | ||||||
|  |   switch (pyarray_type) { | ||||||
|  |     case NPY_FLOAT32: | ||||||
|  |       return kTfLiteFloat32; | ||||||
|  |     case NPY_INT32: | ||||||
|  |       return kTfLiteInt32; | ||||||
|  |     case NPY_INT16: | ||||||
|  |       return kTfLiteInt16; | ||||||
|  |     case NPY_UINT8: | ||||||
|  |       return kTfLiteUInt8; | ||||||
|  |     case NPY_INT8: | ||||||
|  |       return kTfLiteInt8; | ||||||
|  |     case NPY_INT64: | ||||||
|  |       return kTfLiteInt64; | ||||||
|  |     case NPY_BOOL: | ||||||
|  |       return kTfLiteBool; | ||||||
|  |     case NPY_OBJECT: | ||||||
|  |     case NPY_STRING: | ||||||
|  |     case NPY_UNICODE: | ||||||
|  |       return kTfLiteString; | ||||||
|  |     case NPY_COMPLEX64: | ||||||
|  |       return kTfLiteComplex64; | ||||||
|  |       // Avoid default so compiler errors created when new types are made.
 | ||||||
|  |   } | ||||||
|  |   return kTfLiteNoType; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | }  // namespace python_utils
 | ||||||
|  | }  // namespace tflite
 | ||||||
							
								
								
									
										38
									
								
								tensorflow/lite/python/interpreter_wrapper/python_utils.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										38
									
								
								tensorflow/lite/python/interpreter_wrapper/python_utils.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,38 @@ | |||||||
|  | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 | ||||||
|  | 
 | ||||||
|  | Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | you may not use this file except in compliance with the License. | ||||||
|  | You may obtain a copy of the License at | ||||||
|  | 
 | ||||||
|  |     http://www.apache.org/licenses/LICENSE-2.0
 | ||||||
|  | 
 | ||||||
|  | Unless required by applicable law or agreed to in writing, software | ||||||
|  | distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | See the License for the specific language governing permissions and | ||||||
|  | limitations under the License. | ||||||
|  | ==============================================================================*/ | ||||||
|  | 
 | ||||||
|  | #ifndef TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_UTILS_H_ | ||||||
|  | #define TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_UTILS_H_ | ||||||
|  | 
 | ||||||
|  | #include "tensorflow/lite/context.h" | ||||||
|  | 
 | ||||||
|  | // Disallow Numpy 1.7 deprecated symbols.
 | ||||||
|  | #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION | ||||||
|  | 
 | ||||||
|  | #include <Python.h> | ||||||
|  | 
 | ||||||
|  | #include "numpy/arrayobject.h" | ||||||
|  | #include "numpy/ufuncobject.h" | ||||||
|  | 
 | ||||||
|  | namespace tflite { | ||||||
|  | namespace python_utils { | ||||||
|  | 
 | ||||||
|  | int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type); | ||||||
|  | 
 | ||||||
|  | TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array); | ||||||
|  | 
 | ||||||
|  | }  // namespace python_utils
 | ||||||
|  | }  // namespace tflite
 | ||||||
|  | #endif  // TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_UTILS_H_
 | ||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user