From bfcddf733e85c790921afccec4ca80849e9eb9ab Mon Sep 17 00:00:00 2001
From: Shashi Shekhar <shashishekhar@google.com>
Date: Tue, 22 Jan 2019 12:22:53 -0800
Subject: [PATCH] Refactor Python utility functions from Python wrapper. -
 Moves ErrorReporter to a separate class. - Moves some utility functions to a
 separate class.

PiperOrigin-RevId: 230387612
---
 .../lite/python/interpreter_wrapper/BUILD     | 23 +++++
 .../interpreter_wrapper.cc                    | 95 ++-----------------
 .../interpreter_wrapper/interpreter_wrapper.i |  1 +
 .../python_error_reporter.cc                  | 43 +++++++++
 .../python_error_reporter.h                   | 49 ++++++++++
 .../interpreter_wrapper/python_utils.cc       | 77 +++++++++++++++
 .../python/interpreter_wrapper/python_utils.h | 38 ++++++++
 7 files changed, 237 insertions(+), 89 deletions(-)
 create mode 100644 tensorflow/lite/python/interpreter_wrapper/python_error_reporter.cc
 create mode 100644 tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h
 create mode 100644 tensorflow/lite/python/interpreter_wrapper/python_utils.cc
 create mode 100644 tensorflow/lite/python/interpreter_wrapper/python_utils.h

diff --git a/tensorflow/lite/python/interpreter_wrapper/BUILD b/tensorflow/lite/python/interpreter_wrapper/BUILD
index 767a9fc4763..6de6fb48f78 100644
--- a/tensorflow/lite/python/interpreter_wrapper/BUILD
+++ b/tensorflow/lite/python/interpreter_wrapper/BUILD
@@ -11,6 +11,8 @@ cc_library(
     srcs = ["interpreter_wrapper.cc"],
     hdrs = ["interpreter_wrapper.h"],
     deps = [
+        ":python_error_reporter",
+        ":python_utils",
         "//tensorflow/lite:framework",
         "//tensorflow/lite/kernels:builtin_ops",
         "//third_party/py/numpy:headers",
@@ -19,6 +21,27 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "python_error_reporter",
+    srcs = ["python_error_reporter.cc"],
+    hdrs = ["python_error_reporter.h"],
+    deps = [
+        "//tensorflow/lite/core/api",
+        "//third_party/python_runtime:headers",
+    ],
+)
+
+cc_library(
+    name = "python_utils",
+    srcs = ["python_utils.cc"],
+    hdrs = ["python_utils.h"],
+    deps = [
+        "//tensorflow/lite:framework",
+        "//third_party/py/numpy:headers",
+        "//third_party/python_runtime:headers",
+    ],
+)
+
 tf_py_wrap_cc(
     name = "tensorflow_wrap_interpreter_wrapper",
     srcs = [
diff --git a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc
index d14af439ec0..9ccaabbfe97 100644
--- a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc
+++ b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc
@@ -21,6 +21,8 @@ limitations under the License.
 #include "tensorflow/lite/interpreter.h"
 #include "tensorflow/lite/kernels/register.h"
 #include "tensorflow/lite/model.h"
+#include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h"
+#include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
 
 // Disallow Numpy 1.7 deprecated symbols.
 #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
@@ -60,36 +62,6 @@ limitations under the License.
 namespace tflite {
 namespace interpreter_wrapper {
 
-class PythonErrorReporter : public tflite::ErrorReporter {
- public:
-  PythonErrorReporter() {}
-
-  // Report an error message
-  int Report(const char* format, va_list args) override {
-    char buf[1024];
-    int formatted = vsnprintf(buf, sizeof(buf), format, args);
-    buffer_ << buf;
-    return formatted;
-  }
-
-  // Set's a Python runtime exception with the last error.
-  PyObject* exception() {
-    std::string last_message = message();
-    PyErr_SetString(PyExc_RuntimeError, last_message.c_str());
-    return nullptr;
-  }
-
-  // Gets the last error message and clears the buffer.
-  std::string message() {
-    std::string value = buffer_.str();
-    buffer_.clear();
-    return value;
-  }
-
- private:
-  std::stringstream buffer_;
-};
-
 namespace {
 
 // Calls PyArray's initialization to initialize all the API pointers. Note that
@@ -114,61 +86,6 @@ std::unique_ptr<tflite::Interpreter> CreateInterpreter(
   return interpreter;
 }
 
-int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) {
-  switch (tf_lite_type) {
-    case kTfLiteFloat32:
-      return NPY_FLOAT32;
-    case kTfLiteInt32:
-      return NPY_INT32;
-    case kTfLiteInt16:
-      return NPY_INT16;
-    case kTfLiteUInt8:
-      return NPY_UINT8;
-    case kTfLiteInt8:
-      return NPY_INT8;
-    case kTfLiteInt64:
-      return NPY_INT64;
-    case kTfLiteString:
-      return NPY_OBJECT;
-    case kTfLiteBool:
-      return NPY_BOOL;
-    case kTfLiteComplex64:
-      return NPY_COMPLEX64;
-    case kTfLiteNoType:
-      return NPY_NOTYPE;
-      // Avoid default so compiler errors created when new types are made.
-  }
-  return NPY_NOTYPE;
-}
-
-TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array) {
-  int pyarray_type = PyArray_TYPE(array);
-  switch (pyarray_type) {
-    case NPY_FLOAT32:
-      return kTfLiteFloat32;
-    case NPY_INT32:
-      return kTfLiteInt32;
-    case NPY_INT16:
-      return kTfLiteInt16;
-    case NPY_UINT8:
-      return kTfLiteUInt8;
-    case NPY_INT8:
-      return kTfLiteInt8;
-    case NPY_INT64:
-      return kTfLiteInt64;
-    case NPY_BOOL:
-      return kTfLiteBool;
-    case NPY_OBJECT:
-    case NPY_STRING:
-    case NPY_UNICODE:
-      return kTfLiteString;
-    case NPY_COMPLEX64:
-      return kTfLiteComplex64;
-      // Avoid default so compiler errors created when new types are made.
-  }
-  return kTfLiteNoType;
-}
-
 struct PyDecrefDeleter {
   void operator()(PyObject* p) const { Py_DECREF(p); }
 };
@@ -307,7 +224,7 @@ PyObject* InterpreterWrapper::TensorType(int i) const {
     return nullptr;
   }
 
-  int code = TfLiteTypeToPyArrayType(tensor->type);
+  int code = python_utils::TfLiteTypeToPyArrayType(tensor->type);
   if (code == -1) {
     PyErr_Format(PyExc_ValueError, "Invalid tflite type code %d", code);
     return nullptr;
@@ -352,12 +269,12 @@ PyObject* InterpreterWrapper::SetTensor(int i, PyObject* value) {
   PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_safe.get());
   const TfLiteTensor* tensor = interpreter_->tensor(i);
 
-  if (TfLiteTypeFromPyArray(array) != tensor->type) {
+  if (python_utils::TfLiteTypeFromPyArray(array) != tensor->type) {
     PyErr_Format(PyExc_ValueError,
                  "Cannot set tensor:"
                  " Got tensor of type %d"
                  " but expected type %d for input %d ",
-                 TfLiteTypeFromPyArray(array), tensor->type, i);
+                 python_utils::TfLiteTypeFromPyArray(array), tensor->type, i);
     return nullptr;
   }
 
@@ -400,7 +317,7 @@ PyObject* CheckGetTensorArgs(Interpreter* interpreter_, int tensor_index,
     return nullptr;
   }
 
-  *type_num = TfLiteTypeToPyArrayType((*tensor)->type);
+  *type_num = python_utils::TfLiteTypeToPyArrayType((*tensor)->type);
   if (*type_num == -1) {
     PyErr_SetString(PyExc_ValueError, "Unknown tensor type.");
     return nullptr;
diff --git a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.i b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.i
index f52ef1eeca7..ef4b28f0472 100644
--- a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.i
+++ b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.i
@@ -21,6 +21,7 @@ limitations under the License.
 #include "tensorflow/lite/interpreter.h"
 #include "tensorflow/lite/model.h"
 #include "tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h"
+#include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h"
 %}
 
 
diff --git a/tensorflow/lite/python/interpreter_wrapper/python_error_reporter.cc b/tensorflow/lite/python/interpreter_wrapper/python_error_reporter.cc
new file mode 100644
index 00000000000..803a4c29345
--- /dev/null
+++ b/tensorflow/lite/python/interpreter_wrapper/python_error_reporter.cc
@@ -0,0 +1,43 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h"
+
+namespace tflite {
+namespace interpreter_wrapper {
+
+// Report an error message
+int PythonErrorReporter::Report(const char* format, va_list args) {
+  char buf[1024];
+  int formatted = vsnprintf(buf, sizeof(buf), format, args);
+  buffer_ << buf;
+  return formatted;
+}
+
+// Set's a Python runtime exception with the last error.
+PyObject* PythonErrorReporter::exception() {
+  std::string last_message = message();
+  PyErr_SetString(PyExc_RuntimeError, last_message.c_str());
+  return nullptr;
+}
+
+// Gets the last error message and clears the buffer.
+std::string PythonErrorReporter::message() {
+  std::string value = buffer_.str();
+  buffer_.clear();
+  return value;
+}
+}  // namespace interpreter_wrapper
+}  // namespace tflite
diff --git a/tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h b/tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h
new file mode 100644
index 00000000000..7d4e308834a
--- /dev/null
+++ b/tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h
@@ -0,0 +1,49 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_ERROR_REPORTER_H_
+#define TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_ERROR_REPORTER_H_
+
+#include <Python.h>
+
+#include <sstream>
+#include <string>
+
+#include "tensorflow/lite/core/api/error_reporter.h"
+
+namespace tflite {
+namespace interpreter_wrapper {
+
+class PythonErrorReporter : public tflite::ErrorReporter {
+ public:
+  PythonErrorReporter() {}
+
+  // Report an error message
+  int Report(const char* format, va_list args) override;
+
+  // Sets a Python runtime exception with the last error and
+  // clears the error message buffer.
+  PyObject* exception();
+
+  // Gets the last error message and clears the buffer.
+  std::string message();
+
+ private:
+  std::stringstream buffer_;
+};
+
+}  // namespace interpreter_wrapper
+}  // namespace tflite
+#endif  // TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_ERROR_REPORTER_H_
diff --git a/tensorflow/lite/python/interpreter_wrapper/python_utils.cc b/tensorflow/lite/python/interpreter_wrapper/python_utils.cc
new file mode 100644
index 00000000000..2dc604356ab
--- /dev/null
+++ b/tensorflow/lite/python/interpreter_wrapper/python_utils.cc
@@ -0,0 +1,77 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
+
+namespace tflite {
+namespace python_utils {
+
+int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) {
+  switch (tf_lite_type) {
+    case kTfLiteFloat32:
+      return NPY_FLOAT32;
+    case kTfLiteInt32:
+      return NPY_INT32;
+    case kTfLiteInt16:
+      return NPY_INT16;
+    case kTfLiteUInt8:
+      return NPY_UINT8;
+    case kTfLiteInt8:
+      return NPY_INT8;
+    case kTfLiteInt64:
+      return NPY_INT64;
+    case kTfLiteString:
+      return NPY_OBJECT;
+    case kTfLiteBool:
+      return NPY_BOOL;
+    case kTfLiteComplex64:
+      return NPY_COMPLEX64;
+    case kTfLiteNoType:
+      return NPY_NOTYPE;
+      // Avoid default so compiler errors created when new types are made.
+  }
+  return NPY_NOTYPE;
+}
+
+TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array) {
+  int pyarray_type = PyArray_TYPE(array);
+  switch (pyarray_type) {
+    case NPY_FLOAT32:
+      return kTfLiteFloat32;
+    case NPY_INT32:
+      return kTfLiteInt32;
+    case NPY_INT16:
+      return kTfLiteInt16;
+    case NPY_UINT8:
+      return kTfLiteUInt8;
+    case NPY_INT8:
+      return kTfLiteInt8;
+    case NPY_INT64:
+      return kTfLiteInt64;
+    case NPY_BOOL:
+      return kTfLiteBool;
+    case NPY_OBJECT:
+    case NPY_STRING:
+    case NPY_UNICODE:
+      return kTfLiteString;
+    case NPY_COMPLEX64:
+      return kTfLiteComplex64;
+      // Avoid default so compiler errors created when new types are made.
+  }
+  return kTfLiteNoType;
+}
+
+}  // namespace python_utils
+}  // namespace tflite
diff --git a/tensorflow/lite/python/interpreter_wrapper/python_utils.h b/tensorflow/lite/python/interpreter_wrapper/python_utils.h
new file mode 100644
index 00000000000..30a44226b8f
--- /dev/null
+++ b/tensorflow/lite/python/interpreter_wrapper/python_utils.h
@@ -0,0 +1,38 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_UTILS_H_
+#define TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_UTILS_H_
+
+#include "tensorflow/lite/context.h"
+
+// Disallow Numpy 1.7 deprecated symbols.
+#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
+
+#include <Python.h>
+
+#include "numpy/arrayobject.h"
+#include "numpy/ufuncobject.h"
+
+namespace tflite {
+namespace python_utils {
+
+int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type);
+
+TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array);
+
+}  // namespace python_utils
+}  // namespace tflite
+#endif  // TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_UTILS_H_