Export the SerializeAsHexString functions from C++ to Python with pybind11 instead of swig. This is part of a larger effort to deprecate swig and eventually with modularization break pywrap_tensorflow into smaller components. It will also make exporting C++ ops to Python significantly easier. Please refer to https://github.com/tensorflow/community/blob/master/rfcs/20190208-pybind11.md for more information.

PiperOrigin-RevId: 288278210
Change-Id: Ie6e01d42e7b6155f88f2e4cd54953a785c58ac99
This commit is contained in:
Taehee Jeong 2020-01-06 04:24:51 -08:00 committed by TensorFlower Gardener
parent 3b05ac1fc2
commit 53843e51a7
3 changed files with 31 additions and 20 deletions

View File

@ -7,7 +7,7 @@ load(
)
load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite")
load("//tensorflow/lite/testing:tflite_model_test.bzl", "tflite_model_test")
load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
load("//tensorflow:tensorflow.bzl", "pybind_extension")
load(
"//tensorflow:tensorflow.bzl",
"py_test", # @unused
@ -121,8 +121,8 @@ py_library(
srcs = ["zip_test_utils.py"],
srcs_version = "PY2AND3",
deps = [
":_pywrap_string_util",
":generate_examples_report",
":string_util_wrapper",
"//tensorflow:tensorflow_py",
"//third_party/py/numpy",
],
@ -513,14 +513,19 @@ cc_library(
],
)
tf_py_wrap_cc(
name = "string_util_wrapper",
pybind_extension(
name = "_pywrap_string_util",
srcs = [
"string_util.i",
"string_util_wrapper.cc",
],
hdrs = ["string_util.h"],
features = ["-use_header_modules"],
module_name = "_pywrap_string_util",
deps = [
":string_util_lib",
"//tensorflow/python:pybind11_lib",
"//third_party/python_runtime:headers",
"@pybind11",
],
)

View File

@ -13,19 +13,25 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
%{
#define SWIG_FILE_WITH_INIT
#include "include/pybind11/pybind11.h"
#include "include/pybind11/pytypes.h"
#include "tensorflow/lite/testing/string_util.h"
#include "tensorflow/python/lib/core/pybind11_lib.h"
%}
namespace py = pybind11;
namespace tflite {
namespace testing {
namespace python {
PyObject* SerializeAsHexString(PyObject* string_tensor);
} // namespace python
} // namespace testing
} // namespace tflite
PYBIND11_MODULE(_pywrap_string_util, m) {
m.doc() = R"pbdoc(
_pywrap_string_util
-----
)pbdoc";
m.def(
"SerializeAsHexString",
[](py::handle& string_tensor) {
return tensorflow::pyo_or_throw(
tflite::testing::python::SerializeAsHexString(string_tensor.ptr()));
},
R"pbdoc(
Serializes TF Lite dynamic buffer format as a HexString.
)pbdoc");
}

View File

@ -32,8 +32,8 @@ from six import StringIO
# pylint: disable=g-import-not-at-top
import tensorflow as tf
from google.protobuf import text_format
from tensorflow.lite.testing import _pywrap_string_util
from tensorflow.lite.testing import generate_examples_report as report_lib
from tensorflow.lite.testing import string_util_wrapper
from tensorflow.python.framework import graph_util as tf_graph_util
# A map from names to functions which make test cases.
@ -156,7 +156,7 @@ def format_result(t):
values = ["{:.9f}".format(value) for value in list(t.flatten())]
return ",".join(values)
else:
return string_util_wrapper.SerializeAsHexString(t.flatten())
return _pywrap_string_util.SerializeAsHexString(t.flatten())
def write_examples(fp, examples):