Export the get_num_test_registerer_calls function from C++ to Python with pybind11 instead of swig. This is part of a larger effort to deprecate swig and eventually with modularization break pywrap_tensorflow into smaller components. It will also make exporting C++ ops to Python significantly easier. Please refer to https://github.com/tensorflow/community/blob/master/rfcs/20190208-pybind11.md for more information.

Also modified pybind_extension to accept additional symbols to export.

PiperOrigin-RevId: 300483745
Change-Id: I7168f25688e52e84679732531c7cb22befe29d85
This commit is contained in:
Taehee Jeong 2020-03-11 23:12:24 -07:00 committed by TensorFlower Gardener
parent 4783a85dec
commit 373069d559
6 changed files with 63 additions and 32 deletions

View File

@ -44,7 +44,7 @@ py_test(
], ],
deps = [ deps = [
":interpreter", ":interpreter",
"//tensorflow/lite/python/testdata:test_registerer_wrapper", "//tensorflow/lite/python/testdata:_pywrap_test_registerer",
"//tensorflow/python:client_testlib", "//tensorflow/python:client_testlib",
"//tensorflow/python:framework_test_lib", "//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform", "//tensorflow/python:platform",

View File

@ -33,7 +33,7 @@ if hasattr(sys, 'setdlopenflags') and hasattr(sys, 'getdlopenflags'):
sys.setdlopenflags(sys.getdlopenflags() | ctypes.RTLD_GLOBAL) sys.setdlopenflags(sys.getdlopenflags() | ctypes.RTLD_GLOBAL)
from tensorflow.lite.python import interpreter as interpreter_wrapper from tensorflow.lite.python import interpreter as interpreter_wrapper
from tensorflow.lite.python.testdata import test_registerer_wrapper as test_registerer from tensorflow.lite.python.testdata import _pywrap_test_registerer as test_registerer
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
from tensorflow.python.platform import resource_loader from tensorflow.python.platform import resource_loader
from tensorflow.python.platform import test from tensorflow.python.platform import test

View File

@ -1,5 +1,6 @@
load("//tensorflow/lite:build_def.bzl", "tf_to_tflite") load("//tensorflow/lite:build_def.bzl", "tf_to_tflite")
load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc") load("//tensorflow:tensorflow.bzl", "pybind_extension")
load("//tensorflow:tensorflow.bzl", "tf_binary_additional_srcs")
package( package(
default_visibility = ["//tensorflow:internal"], default_visibility = ["//tensorflow:internal"],
@ -87,13 +88,19 @@ cc_library(
alwayslink = 1, alwayslink = 1,
) )
tf_py_wrap_cc( pybind_extension(
name = "test_registerer_wrapper", name = "_pywrap_test_registerer",
srcs = [ srcs = [
"test_registerer.i", "test_registerer_wrapper.cc",
], ] + tf_binary_additional_srcs(),
hdrs = ["test_registerer.h"],
additional_exported_symbols = ["TF_TestRegisterer"],
module_name = "_pywrap_test_registerer",
deps = [ deps = [
":test_registerer", ":test_registerer",
"//tensorflow/lite:framework",
"//tensorflow/lite/kernels:builtin_ops",
"//third_party/python_runtime:headers", "//third_party/python_runtime:headers",
"@pybind11",
], ],
) )

View File

@ -1,20 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
%{
#include "tensorflow/lite/python/testdata/test_registerer.h"
%}
%include "tensorflow/lite/python/testdata/test_registerer.h"

View File

@ -0,0 +1,36 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "include/pybind11/pybind11.h"
#include "include/pybind11/pytypes.h"
#include "tensorflow/lite/python/testdata/test_registerer.h"
PYBIND11_MODULE(_pywrap_test_registerer, m) {
m.doc() = R"pbdoc(
_pywrap_test_registerer
-----
)pbdoc";
m.def("get_num_test_registerer_calls", &tflite::get_num_test_registerer_calls,
R"pbdoc(
Returns the num_test_registerer_calls counter and re-sets it.
)pbdoc");
m.def(
"TF_TestRegisterer",
[](uintptr_t resolver) {
tflite::TF_TestRegisterer(
reinterpret_cast<tflite::MutableOpResolver*>(resolver));
},
R"pbdoc(
Dummy registerer function with the correct signature. Registers a fake
custom op needed by test models. Increments the
num_test_registerer_calls counter by one.
)pbdoc");
}

View File

@ -2526,6 +2526,7 @@ def pybind_extension(
linkopts = [], linkopts = [],
deps = [], deps = [],
defines = [], defines = [],
additional_exported_symbols = [],
visibility = None, visibility = None,
testonly = None, testonly = None,
licenses = None, licenses = None,
@ -2544,15 +2545,22 @@ def pybind_extension(
prefix = name[:p + 1] prefix = name[:p + 1]
so_file = "%s%s.so" % (prefix, sname) so_file = "%s%s.so" % (prefix, sname)
pyd_file = "%s%s.pyd" % (prefix, sname) pyd_file = "%s%s.pyd" % (prefix, sname)
symbol = "init%s" % sname exported_symbols = [
symbol2 = "init_%s" % sname "init%s" % sname,
symbol3 = "PyInit_%s" % sname "init_%s" % sname,
"PyInit_%s" % sname,
] + additional_exported_symbols
exported_symbols_file = "%s-exported-symbols.lds" % name exported_symbols_file = "%s-exported-symbols.lds" % name
version_script_file = "%s-version-script.lds" % name version_script_file = "%s-version-script.lds" % name
exported_symbols_output = "\n".join(["_%s" % symbol for symbol in exported_symbols])
version_script_output = "\n".join([" %s;" % symbol for symbol in exported_symbols])
native.genrule( native.genrule(
name = name + "_exported_symbols", name = name + "_exported_symbols",
outs = [exported_symbols_file], outs = [exported_symbols_file],
cmd = "echo '_%s\n_%s\n_%s' >$@" % (symbol, symbol2, symbol3), cmd = "echo '%s' >$@" % exported_symbols_output,
output_licenses = ["unencumbered"], output_licenses = ["unencumbered"],
visibility = ["//visibility:private"], visibility = ["//visibility:private"],
testonly = testonly, testonly = testonly,
@ -2561,7 +2569,7 @@ def pybind_extension(
native.genrule( native.genrule(
name = name + "_version_script", name = name + "_version_script",
outs = [version_script_file], outs = [version_script_file],
cmd = "echo '{global:\n %s;\n %s;\n %s;\n local: *;};' >$@" % (symbol, symbol2, symbol3), cmd = "echo '{global:\n%s\n local: *;};' >$@" % version_script_output,
output_licenses = ["unencumbered"], output_licenses = ["unencumbered"],
visibility = ["//visibility:private"], visibility = ["//visibility:private"],
testonly = testonly, testonly = testonly,