diff --git a/tensorflow/lite/testing/BUILD b/tensorflow/lite/testing/BUILD index a42228313f9..18502b78c48 100644 --- a/tensorflow/lite/testing/BUILD +++ b/tensorflow/lite/testing/BUILD @@ -7,7 +7,7 @@ load( ) load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite") load("//tensorflow/lite/testing:tflite_model_test.bzl", "tflite_model_test") -load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc") +load("//tensorflow:tensorflow.bzl", "pybind_extension") load( "//tensorflow:tensorflow.bzl", "py_test", # @unused @@ -121,8 +121,8 @@ py_library( srcs = ["zip_test_utils.py"], srcs_version = "PY2AND3", deps = [ + ":_pywrap_string_util", ":generate_examples_report", - ":string_util_wrapper", "//tensorflow:tensorflow_py", "//third_party/py/numpy", ], @@ -513,14 +513,19 @@ cc_library( ], ) -tf_py_wrap_cc( - name = "string_util_wrapper", +pybind_extension( + name = "_pywrap_string_util", srcs = [ - "string_util.i", + "string_util_wrapper.cc", ], + hdrs = ["string_util.h"], + features = ["-use_header_modules"], + module_name = "_pywrap_string_util", deps = [ ":string_util_lib", + "//tensorflow/python:pybind11_lib", "//third_party/python_runtime:headers", + "@pybind11", ], ) diff --git a/tensorflow/lite/testing/string_util.i b/tensorflow/lite/testing/string_util_wrapper.cc similarity index 56% rename from tensorflow/lite/testing/string_util.i rename to tensorflow/lite/testing/string_util_wrapper.cc index 574abb79653..f5b490ab617 100644 --- a/tensorflow/lite/testing/string_util.i +++ b/tensorflow/lite/testing/string_util_wrapper.cc @@ -13,19 +13,25 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -%{ - -#define SWIG_FILE_WITH_INIT +#include "include/pybind11/pybind11.h" +#include "include/pybind11/pytypes.h" #include "tensorflow/lite/testing/string_util.h" +#include "tensorflow/python/lib/core/pybind11_lib.h" -%} +namespace py = pybind11; -namespace tflite { -namespace testing { -namespace python { - -PyObject* SerializeAsHexString(PyObject* string_tensor); - -} // namespace python -} // namespace testing -} // namespace tflite +PYBIND11_MODULE(_pywrap_string_util, m) { + m.doc() = R"pbdoc( + _pywrap_string_util + ----- + )pbdoc"; + m.def( + "SerializeAsHexString", + [](py::handle& string_tensor) { + return tensorflow::pyo_or_throw( + tflite::testing::python::SerializeAsHexString(string_tensor.ptr())); + }, + R"pbdoc( + Serializes TF Lite dynamic buffer format as a HexString. + )pbdoc"); +} diff --git a/tensorflow/lite/testing/zip_test_utils.py b/tensorflow/lite/testing/zip_test_utils.py index 3d380ff0385..dcfe77875ff 100644 --- a/tensorflow/lite/testing/zip_test_utils.py +++ b/tensorflow/lite/testing/zip_test_utils.py @@ -32,8 +32,8 @@ from six import StringIO # pylint: disable=g-import-not-at-top import tensorflow as tf from google.protobuf import text_format +from tensorflow.lite.testing import _pywrap_string_util from tensorflow.lite.testing import generate_examples_report as report_lib -from tensorflow.lite.testing import string_util_wrapper from tensorflow.python.framework import graph_util as tf_graph_util # A map from names to functions which make test cases. @@ -156,7 +156,7 @@ def format_result(t): values = ["{:.9f}".format(value) for value in list(t.flatten())] return ",".join(values) else: - return string_util_wrapper.SerializeAsHexString(t.flatten()) + return _pywrap_string_util.SerializeAsHexString(t.flatten()) def write_examples(fp, examples):