Export the get_num_test_registerer_calls function from C++ to Python with pybind11 instead of swig. This is part of a larger effort to deprecate swig and eventually with modularization break pywrap_tensorflow into smaller components. It will also make exporting C++ ops to Python significantly easier. Please refer to https://github.com/tensorflow/community/blob/master/rfcs/20190208-pybind11.md for more information.
Also modified pybind_extension to accept additional symbols to export. PiperOrigin-RevId: 300483745 Change-Id: I7168f25688e52e84679732531c7cb22befe29d85
This commit is contained in:
parent
4783a85dec
commit
373069d559
@ -44,7 +44,7 @@ py_test(
|
|||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":interpreter",
|
":interpreter",
|
||||||
"//tensorflow/lite/python/testdata:test_registerer_wrapper",
|
"//tensorflow/lite/python/testdata:_pywrap_test_registerer",
|
||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
"//tensorflow/python:platform",
|
"//tensorflow/python:platform",
|
||||||
|
@ -33,7 +33,7 @@ if hasattr(sys, 'setdlopenflags') and hasattr(sys, 'getdlopenflags'):
|
|||||||
sys.setdlopenflags(sys.getdlopenflags() | ctypes.RTLD_GLOBAL)
|
sys.setdlopenflags(sys.getdlopenflags() | ctypes.RTLD_GLOBAL)
|
||||||
|
|
||||||
from tensorflow.lite.python import interpreter as interpreter_wrapper
|
from tensorflow.lite.python import interpreter as interpreter_wrapper
|
||||||
from tensorflow.lite.python.testdata import test_registerer_wrapper as test_registerer
|
from tensorflow.lite.python.testdata import _pywrap_test_registerer as test_registerer
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.platform import resource_loader
|
from tensorflow.python.platform import resource_loader
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
17
tensorflow/lite/python/testdata/BUILD
vendored
17
tensorflow/lite/python/testdata/BUILD
vendored
@ -1,5 +1,6 @@
|
|||||||
load("//tensorflow/lite:build_def.bzl", "tf_to_tflite")
|
load("//tensorflow/lite:build_def.bzl", "tf_to_tflite")
|
||||||
load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
|
load("//tensorflow:tensorflow.bzl", "pybind_extension")
|
||||||
|
load("//tensorflow:tensorflow.bzl", "tf_binary_additional_srcs")
|
||||||
|
|
||||||
package(
|
package(
|
||||||
default_visibility = ["//tensorflow:internal"],
|
default_visibility = ["//tensorflow:internal"],
|
||||||
@ -87,13 +88,19 @@ cc_library(
|
|||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
tf_py_wrap_cc(
|
pybind_extension(
|
||||||
name = "test_registerer_wrapper",
|
name = "_pywrap_test_registerer",
|
||||||
srcs = [
|
srcs = [
|
||||||
"test_registerer.i",
|
"test_registerer_wrapper.cc",
|
||||||
],
|
] + tf_binary_additional_srcs(),
|
||||||
|
hdrs = ["test_registerer.h"],
|
||||||
|
additional_exported_symbols = ["TF_TestRegisterer"],
|
||||||
|
module_name = "_pywrap_test_registerer",
|
||||||
deps = [
|
deps = [
|
||||||
":test_registerer",
|
":test_registerer",
|
||||||
|
"//tensorflow/lite:framework",
|
||||||
|
"//tensorflow/lite/kernels:builtin_ops",
|
||||||
"//third_party/python_runtime:headers",
|
"//third_party/python_runtime:headers",
|
||||||
|
"@pybind11",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -1,20 +0,0 @@
|
|||||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
%{
|
|
||||||
#include "tensorflow/lite/python/testdata/test_registerer.h"
|
|
||||||
%}
|
|
||||||
|
|
||||||
%include "tensorflow/lite/python/testdata/test_registerer.h"
|
|
36
tensorflow/lite/python/testdata/test_registerer_wrapper.cc
vendored
Normal file
36
tensorflow/lite/python/testdata/test_registerer_wrapper.cc
vendored
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include "include/pybind11/pybind11.h"
|
||||||
|
#include "include/pybind11/pytypes.h"
|
||||||
|
#include "tensorflow/lite/python/testdata/test_registerer.h"
|
||||||
|
|
||||||
|
PYBIND11_MODULE(_pywrap_test_registerer, m) {
|
||||||
|
m.doc() = R"pbdoc(
|
||||||
|
_pywrap_test_registerer
|
||||||
|
-----
|
||||||
|
)pbdoc";
|
||||||
|
m.def("get_num_test_registerer_calls", &tflite::get_num_test_registerer_calls,
|
||||||
|
R"pbdoc(
|
||||||
|
Returns the num_test_registerer_calls counter and re-sets it.
|
||||||
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"TF_TestRegisterer",
|
||||||
|
[](uintptr_t resolver) {
|
||||||
|
tflite::TF_TestRegisterer(
|
||||||
|
reinterpret_cast<tflite::MutableOpResolver*>(resolver));
|
||||||
|
},
|
||||||
|
R"pbdoc(
|
||||||
|
Dummy registerer function with the correct signature. Registers a fake
|
||||||
|
custom op needed by test models. Increments the
|
||||||
|
num_test_registerer_calls counter by one.
|
||||||
|
)pbdoc");
|
||||||
|
}
|
@ -2526,6 +2526,7 @@ def pybind_extension(
|
|||||||
linkopts = [],
|
linkopts = [],
|
||||||
deps = [],
|
deps = [],
|
||||||
defines = [],
|
defines = [],
|
||||||
|
additional_exported_symbols = [],
|
||||||
visibility = None,
|
visibility = None,
|
||||||
testonly = None,
|
testonly = None,
|
||||||
licenses = None,
|
licenses = None,
|
||||||
@ -2544,15 +2545,22 @@ def pybind_extension(
|
|||||||
prefix = name[:p + 1]
|
prefix = name[:p + 1]
|
||||||
so_file = "%s%s.so" % (prefix, sname)
|
so_file = "%s%s.so" % (prefix, sname)
|
||||||
pyd_file = "%s%s.pyd" % (prefix, sname)
|
pyd_file = "%s%s.pyd" % (prefix, sname)
|
||||||
symbol = "init%s" % sname
|
exported_symbols = [
|
||||||
symbol2 = "init_%s" % sname
|
"init%s" % sname,
|
||||||
symbol3 = "PyInit_%s" % sname
|
"init_%s" % sname,
|
||||||
|
"PyInit_%s" % sname,
|
||||||
|
] + additional_exported_symbols
|
||||||
|
|
||||||
exported_symbols_file = "%s-exported-symbols.lds" % name
|
exported_symbols_file = "%s-exported-symbols.lds" % name
|
||||||
version_script_file = "%s-version-script.lds" % name
|
version_script_file = "%s-version-script.lds" % name
|
||||||
|
|
||||||
|
exported_symbols_output = "\n".join(["_%s" % symbol for symbol in exported_symbols])
|
||||||
|
version_script_output = "\n".join([" %s;" % symbol for symbol in exported_symbols])
|
||||||
|
|
||||||
native.genrule(
|
native.genrule(
|
||||||
name = name + "_exported_symbols",
|
name = name + "_exported_symbols",
|
||||||
outs = [exported_symbols_file],
|
outs = [exported_symbols_file],
|
||||||
cmd = "echo '_%s\n_%s\n_%s' >$@" % (symbol, symbol2, symbol3),
|
cmd = "echo '%s' >$@" % exported_symbols_output,
|
||||||
output_licenses = ["unencumbered"],
|
output_licenses = ["unencumbered"],
|
||||||
visibility = ["//visibility:private"],
|
visibility = ["//visibility:private"],
|
||||||
testonly = testonly,
|
testonly = testonly,
|
||||||
@ -2561,7 +2569,7 @@ def pybind_extension(
|
|||||||
native.genrule(
|
native.genrule(
|
||||||
name = name + "_version_script",
|
name = name + "_version_script",
|
||||||
outs = [version_script_file],
|
outs = [version_script_file],
|
||||||
cmd = "echo '{global:\n %s;\n %s;\n %s;\n local: *;};' >$@" % (symbol, symbol2, symbol3),
|
cmd = "echo '{global:\n%s\n local: *;};' >$@" % version_script_output,
|
||||||
output_licenses = ["unencumbered"],
|
output_licenses = ["unencumbered"],
|
||||||
visibility = ["//visibility:private"],
|
visibility = ["//visibility:private"],
|
||||||
testonly = testonly,
|
testonly = testonly,
|
||||||
|
Loading…
Reference in New Issue
Block a user