Refactor Python utility functions from Python wrapper.

- Moves ErrorReporter to a separate class.
- Moves some utility functions to a separate class.

PiperOrigin-RevId: 230387612
This commit is contained in:
Shashi Shekhar 2019-01-22 12:22:53 -08:00 committed by TensorFlower Gardener
parent 0b8ac6e1b9
commit bfcddf733e
7 changed files with 237 additions and 89 deletions

View File

@ -11,6 +11,8 @@ cc_library(
srcs = ["interpreter_wrapper.cc"],
hdrs = ["interpreter_wrapper.h"],
deps = [
":python_error_reporter",
":python_utils",
"//tensorflow/lite:framework",
"//tensorflow/lite/kernels:builtin_ops",
"//third_party/py/numpy:headers",
@ -19,6 +21,27 @@ cc_library(
],
)
cc_library(
name = "python_error_reporter",
srcs = ["python_error_reporter.cc"],
hdrs = ["python_error_reporter.h"],
deps = [
"//tensorflow/lite/core/api",
"//third_party/python_runtime:headers",
],
)
cc_library(
name = "python_utils",
srcs = ["python_utils.cc"],
hdrs = ["python_utils.h"],
deps = [
"//tensorflow/lite:framework",
"//third_party/py/numpy:headers",
"//third_party/python_runtime:headers",
],
)
tf_py_wrap_cc(
name = "tensorflow_wrap_interpreter_wrapper",
srcs = [

View File

@ -21,6 +21,8 @@ limitations under the License.
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h"
#include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
// Disallow Numpy 1.7 deprecated symbols.
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
@ -60,36 +62,6 @@ limitations under the License.
namespace tflite {
namespace interpreter_wrapper {
class PythonErrorReporter : public tflite::ErrorReporter {
public:
PythonErrorReporter() {}
// Report an error message
int Report(const char* format, va_list args) override {
char buf[1024];
int formatted = vsnprintf(buf, sizeof(buf), format, args);
buffer_ << buf;
return formatted;
}
// Set's a Python runtime exception with the last error.
PyObject* exception() {
std::string last_message = message();
PyErr_SetString(PyExc_RuntimeError, last_message.c_str());
return nullptr;
}
// Gets the last error message and clears the buffer.
std::string message() {
std::string value = buffer_.str();
buffer_.clear();
return value;
}
private:
std::stringstream buffer_;
};
namespace {
// Calls PyArray's initialization to initialize all the API pointers. Note that
@ -114,61 +86,6 @@ std::unique_ptr<tflite::Interpreter> CreateInterpreter(
return interpreter;
}
int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) {
switch (tf_lite_type) {
case kTfLiteFloat32:
return NPY_FLOAT32;
case kTfLiteInt32:
return NPY_INT32;
case kTfLiteInt16:
return NPY_INT16;
case kTfLiteUInt8:
return NPY_UINT8;
case kTfLiteInt8:
return NPY_INT8;
case kTfLiteInt64:
return NPY_INT64;
case kTfLiteString:
return NPY_OBJECT;
case kTfLiteBool:
return NPY_BOOL;
case kTfLiteComplex64:
return NPY_COMPLEX64;
case kTfLiteNoType:
return NPY_NOTYPE;
// Avoid default so compiler errors created when new types are made.
}
return NPY_NOTYPE;
}
TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array) {
int pyarray_type = PyArray_TYPE(array);
switch (pyarray_type) {
case NPY_FLOAT32:
return kTfLiteFloat32;
case NPY_INT32:
return kTfLiteInt32;
case NPY_INT16:
return kTfLiteInt16;
case NPY_UINT8:
return kTfLiteUInt8;
case NPY_INT8:
return kTfLiteInt8;
case NPY_INT64:
return kTfLiteInt64;
case NPY_BOOL:
return kTfLiteBool;
case NPY_OBJECT:
case NPY_STRING:
case NPY_UNICODE:
return kTfLiteString;
case NPY_COMPLEX64:
return kTfLiteComplex64;
// Avoid default so compiler errors created when new types are made.
}
return kTfLiteNoType;
}
struct PyDecrefDeleter {
void operator()(PyObject* p) const { Py_DECREF(p); }
};
@ -307,7 +224,7 @@ PyObject* InterpreterWrapper::TensorType(int i) const {
return nullptr;
}
int code = TfLiteTypeToPyArrayType(tensor->type);
int code = python_utils::TfLiteTypeToPyArrayType(tensor->type);
if (code == -1) {
PyErr_Format(PyExc_ValueError, "Invalid tflite type code %d", code);
return nullptr;
@ -352,12 +269,12 @@ PyObject* InterpreterWrapper::SetTensor(int i, PyObject* value) {
PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_safe.get());
const TfLiteTensor* tensor = interpreter_->tensor(i);
if (TfLiteTypeFromPyArray(array) != tensor->type) {
if (python_utils::TfLiteTypeFromPyArray(array) != tensor->type) {
PyErr_Format(PyExc_ValueError,
"Cannot set tensor:"
" Got tensor of type %d"
" but expected type %d for input %d ",
TfLiteTypeFromPyArray(array), tensor->type, i);
python_utils::TfLiteTypeFromPyArray(array), tensor->type, i);
return nullptr;
}
@ -400,7 +317,7 @@ PyObject* CheckGetTensorArgs(Interpreter* interpreter_, int tensor_index,
return nullptr;
}
*type_num = TfLiteTypeToPyArrayType((*tensor)->type);
*type_num = python_utils::TfLiteTypeToPyArrayType((*tensor)->type);
if (*type_num == -1) {
PyErr_SetString(PyExc_ValueError, "Unknown tensor type.");
return nullptr;

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h"
#include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h"
%}

View File

@ -0,0 +1,43 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h"
namespace tflite {
namespace interpreter_wrapper {
// Report an error message
int PythonErrorReporter::Report(const char* format, va_list args) {
char buf[1024];
int formatted = vsnprintf(buf, sizeof(buf), format, args);
buffer_ << buf;
return formatted;
}
// Set's a Python runtime exception with the last error.
PyObject* PythonErrorReporter::exception() {
std::string last_message = message();
PyErr_SetString(PyExc_RuntimeError, last_message.c_str());
return nullptr;
}
// Gets the last error message and clears the buffer.
std::string PythonErrorReporter::message() {
std::string value = buffer_.str();
buffer_.clear();
return value;
}
} // namespace interpreter_wrapper
} // namespace tflite

View File

@ -0,0 +1,49 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_ERROR_REPORTER_H_
#define TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_ERROR_REPORTER_H_
#include <Python.h>
#include <sstream>
#include <string>
#include "tensorflow/lite/core/api/error_reporter.h"
namespace tflite {
namespace interpreter_wrapper {
class PythonErrorReporter : public tflite::ErrorReporter {
public:
PythonErrorReporter() {}
// Report an error message
int Report(const char* format, va_list args) override;
// Sets a Python runtime exception with the last error and
// clears the error message buffer.
PyObject* exception();
// Gets the last error message and clears the buffer.
std::string message();
private:
std::stringstream buffer_;
};
} // namespace interpreter_wrapper
} // namespace tflite
#endif // TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_ERROR_REPORTER_H_

View File

@ -0,0 +1,77 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
namespace tflite {
namespace python_utils {
int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) {
switch (tf_lite_type) {
case kTfLiteFloat32:
return NPY_FLOAT32;
case kTfLiteInt32:
return NPY_INT32;
case kTfLiteInt16:
return NPY_INT16;
case kTfLiteUInt8:
return NPY_UINT8;
case kTfLiteInt8:
return NPY_INT8;
case kTfLiteInt64:
return NPY_INT64;
case kTfLiteString:
return NPY_OBJECT;
case kTfLiteBool:
return NPY_BOOL;
case kTfLiteComplex64:
return NPY_COMPLEX64;
case kTfLiteNoType:
return NPY_NOTYPE;
// Avoid default so compiler errors created when new types are made.
}
return NPY_NOTYPE;
}
TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array) {
int pyarray_type = PyArray_TYPE(array);
switch (pyarray_type) {
case NPY_FLOAT32:
return kTfLiteFloat32;
case NPY_INT32:
return kTfLiteInt32;
case NPY_INT16:
return kTfLiteInt16;
case NPY_UINT8:
return kTfLiteUInt8;
case NPY_INT8:
return kTfLiteInt8;
case NPY_INT64:
return kTfLiteInt64;
case NPY_BOOL:
return kTfLiteBool;
case NPY_OBJECT:
case NPY_STRING:
case NPY_UNICODE:
return kTfLiteString;
case NPY_COMPLEX64:
return kTfLiteComplex64;
// Avoid default so compiler errors created when new types are made.
}
return kTfLiteNoType;
}
} // namespace python_utils
} // namespace tflite

View File

@ -0,0 +1,38 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_UTILS_H_
#define TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_UTILS_H_
#include "tensorflow/lite/context.h"
// Disallow Numpy 1.7 deprecated symbols.
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
#include <Python.h>
#include "numpy/arrayobject.h"
#include "numpy/ufuncobject.h"
namespace tflite {
namespace python_utils {
int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type);
TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array);
} // namespace python_utils
} // namespace tflite
#endif // TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_UTILS_H_