From 373069d5590141d1d91513dd7ac7b83dcd277dc6 Mon Sep 17 00:00:00 2001 From: Taehee Jeong Date: Wed, 11 Mar 2020 23:12:24 -0700 Subject: [PATCH] Export the get_num_test_registerer_calls function from C++ to Python with pybind11 instead of swig. This is part of a larger effort to deprecate swig and eventually with modularization break pywrap_tensorflow into smaller components. It will also make exporting C++ ops to Python significantly easier. Please refer to https://github.com/tensorflow/community/blob/master/rfcs/20190208-pybind11.md for more information. Also modified pybind_extension to accept additional symbols to export. PiperOrigin-RevId: 300483745 Change-Id: I7168f25688e52e84679732531c7cb22befe29d85 --- tensorflow/lite/python/BUILD | 2 +- tensorflow/lite/python/interpreter_test.py | 2 +- tensorflow/lite/python/testdata/BUILD | 17 ++++++--- .../lite/python/testdata/test_registerer.i | 20 ----------- .../testdata/test_registerer_wrapper.cc | 36 +++++++++++++++++++ tensorflow/tensorflow.bzl | 18 +++++++--- 6 files changed, 63 insertions(+), 32 deletions(-) delete mode 100644 tensorflow/lite/python/testdata/test_registerer.i create mode 100644 tensorflow/lite/python/testdata/test_registerer_wrapper.cc diff --git a/tensorflow/lite/python/BUILD b/tensorflow/lite/python/BUILD index fd156346580..5903a96fb52 100644 --- a/tensorflow/lite/python/BUILD +++ b/tensorflow/lite/python/BUILD @@ -44,7 +44,7 @@ py_test( ], deps = [ ":interpreter", - "//tensorflow/lite/python/testdata:test_registerer_wrapper", + "//tensorflow/lite/python/testdata:_pywrap_test_registerer", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform", diff --git a/tensorflow/lite/python/interpreter_test.py b/tensorflow/lite/python/interpreter_test.py index 122f5f2d04c..2a10eb0cc69 100644 --- a/tensorflow/lite/python/interpreter_test.py +++ b/tensorflow/lite/python/interpreter_test.py @@ -33,7 +33,7 @@ if hasattr(sys, 'setdlopenflags') and hasattr(sys, 'getdlopenflags'): sys.setdlopenflags(sys.getdlopenflags() | ctypes.RTLD_GLOBAL) from tensorflow.lite.python import interpreter as interpreter_wrapper -from tensorflow.lite.python.testdata import test_registerer_wrapper as test_registerer +from tensorflow.lite.python.testdata import _pywrap_test_registerer as test_registerer from tensorflow.python.framework import test_util from tensorflow.python.platform import resource_loader from tensorflow.python.platform import test diff --git a/tensorflow/lite/python/testdata/BUILD b/tensorflow/lite/python/testdata/BUILD index 23eb44cc06b..bff1ebe8f3b 100644 --- a/tensorflow/lite/python/testdata/BUILD +++ b/tensorflow/lite/python/testdata/BUILD @@ -1,5 +1,6 @@ load("//tensorflow/lite:build_def.bzl", "tf_to_tflite") -load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc") +load("//tensorflow:tensorflow.bzl", "pybind_extension") +load("//tensorflow:tensorflow.bzl", "tf_binary_additional_srcs") package( default_visibility = ["//tensorflow:internal"], @@ -87,13 +88,19 @@ cc_library( alwayslink = 1, ) -tf_py_wrap_cc( - name = "test_registerer_wrapper", +pybind_extension( + name = "_pywrap_test_registerer", srcs = [ - "test_registerer.i", - ], + "test_registerer_wrapper.cc", + ] + tf_binary_additional_srcs(), + hdrs = ["test_registerer.h"], + additional_exported_symbols = ["TF_TestRegisterer"], + module_name = "_pywrap_test_registerer", deps = [ ":test_registerer", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:builtin_ops", "//third_party/python_runtime:headers", + "@pybind11", ], ) diff --git a/tensorflow/lite/python/testdata/test_registerer.i b/tensorflow/lite/python/testdata/test_registerer.i deleted file mode 100644 index 1cd41c9164d..00000000000 --- a/tensorflow/lite/python/testdata/test_registerer.i +++ /dev/null @@ -1,20 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -%{ -#include "tensorflow/lite/python/testdata/test_registerer.h" -%} - -%include "tensorflow/lite/python/testdata/test_registerer.h" diff --git a/tensorflow/lite/python/testdata/test_registerer_wrapper.cc b/tensorflow/lite/python/testdata/test_registerer_wrapper.cc new file mode 100644 index 00000000000..c50dee4346c --- /dev/null +++ b/tensorflow/lite/python/testdata/test_registerer_wrapper.cc @@ -0,0 +1,36 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "include/pybind11/pybind11.h" +#include "include/pybind11/pytypes.h" +#include "tensorflow/lite/python/testdata/test_registerer.h" + +PYBIND11_MODULE(_pywrap_test_registerer, m) { + m.doc() = R"pbdoc( + _pywrap_test_registerer + ----- + )pbdoc"; + m.def("get_num_test_registerer_calls", &tflite::get_num_test_registerer_calls, + R"pbdoc( + Returns the num_test_registerer_calls counter and re-sets it. + )pbdoc"); + m.def( + "TF_TestRegisterer", + [](uintptr_t resolver) { + tflite::TF_TestRegisterer( + reinterpret_cast(resolver)); + }, + R"pbdoc( + Dummy registerer function with the correct signature. Registers a fake + custom op needed by test models. Increments the + num_test_registerer_calls counter by one. + )pbdoc"); +} diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index 1a77d0f64f0..300839ff3b4 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -2526,6 +2526,7 @@ def pybind_extension( linkopts = [], deps = [], defines = [], + additional_exported_symbols = [], visibility = None, testonly = None, licenses = None, @@ -2544,15 +2545,22 @@ def pybind_extension( prefix = name[:p + 1] so_file = "%s%s.so" % (prefix, sname) pyd_file = "%s%s.pyd" % (prefix, sname) - symbol = "init%s" % sname - symbol2 = "init_%s" % sname - symbol3 = "PyInit_%s" % sname + exported_symbols = [ + "init%s" % sname, + "init_%s" % sname, + "PyInit_%s" % sname, + ] + additional_exported_symbols + exported_symbols_file = "%s-exported-symbols.lds" % name version_script_file = "%s-version-script.lds" % name + + exported_symbols_output = "\n".join(["_%s" % symbol for symbol in exported_symbols]) + version_script_output = "\n".join([" %s;" % symbol for symbol in exported_symbols]) + native.genrule( name = name + "_exported_symbols", outs = [exported_symbols_file], - cmd = "echo '_%s\n_%s\n_%s' >$@" % (symbol, symbol2, symbol3), + cmd = "echo '%s' >$@" % exported_symbols_output, output_licenses = ["unencumbered"], visibility = ["//visibility:private"], testonly = testonly, @@ -2561,7 +2569,7 @@ def pybind_extension( native.genrule( name = name + "_version_script", outs = [version_script_file], - cmd = "echo '{global:\n %s;\n %s;\n %s;\n local: *;};' >$@" % (symbol, symbol2, symbol3), + cmd = "echo '{global:\n%s\n local: *;};' >$@" % version_script_output, output_licenses = ["unencumbered"], visibility = ["//visibility:private"], testonly = testonly,