Export the get_num_test_registerer_calls function from C++ to Python with pybind11 instead of swig. This is part of a larger effort to deprecate swig and eventually with modularization break pywrap_tensorflow into smaller components. It will also make exporting C++ ops to Python significantly easier. Please refer to https://github.com/tensorflow/community/blob/master/rfcs/20190208-pybind11.md for more information.
Also modified pybind_extension to accept additional symbols to export. PiperOrigin-RevId: 300483745 Change-Id: I7168f25688e52e84679732531c7cb22befe29d85
This commit is contained in:
parent
4783a85dec
commit
373069d559
@ -44,7 +44,7 @@ py_test(
|
||||
],
|
||||
deps = [
|
||||
":interpreter",
|
||||
"//tensorflow/lite/python/testdata:test_registerer_wrapper",
|
||||
"//tensorflow/lite/python/testdata:_pywrap_test_registerer",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform",
|
||||
|
@ -33,7 +33,7 @@ if hasattr(sys, 'setdlopenflags') and hasattr(sys, 'getdlopenflags'):
|
||||
sys.setdlopenflags(sys.getdlopenflags() | ctypes.RTLD_GLOBAL)
|
||||
|
||||
from tensorflow.lite.python import interpreter as interpreter_wrapper
|
||||
from tensorflow.lite.python.testdata import test_registerer_wrapper as test_registerer
|
||||
from tensorflow.lite.python.testdata import _pywrap_test_registerer as test_registerer
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import resource_loader
|
||||
from tensorflow.python.platform import test
|
||||
|
17
tensorflow/lite/python/testdata/BUILD
vendored
17
tensorflow/lite/python/testdata/BUILD
vendored
@ -1,5 +1,6 @@
|
||||
load("//tensorflow/lite:build_def.bzl", "tf_to_tflite")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
|
||||
load("//tensorflow:tensorflow.bzl", "pybind_extension")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_binary_additional_srcs")
|
||||
|
||||
package(
|
||||
default_visibility = ["//tensorflow:internal"],
|
||||
@ -87,13 +88,19 @@ cc_library(
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
tf_py_wrap_cc(
|
||||
name = "test_registerer_wrapper",
|
||||
pybind_extension(
|
||||
name = "_pywrap_test_registerer",
|
||||
srcs = [
|
||||
"test_registerer.i",
|
||||
],
|
||||
"test_registerer_wrapper.cc",
|
||||
] + tf_binary_additional_srcs(),
|
||||
hdrs = ["test_registerer.h"],
|
||||
additional_exported_symbols = ["TF_TestRegisterer"],
|
||||
module_name = "_pywrap_test_registerer",
|
||||
deps = [
|
||||
":test_registerer",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite/kernels:builtin_ops",
|
||||
"//third_party/python_runtime:headers",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
|
@ -1,20 +0,0 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
%{
|
||||
#include "tensorflow/lite/python/testdata/test_registerer.h"
|
||||
%}
|
||||
|
||||
%include "tensorflow/lite/python/testdata/test_registerer.h"
|
36
tensorflow/lite/python/testdata/test_registerer_wrapper.cc
vendored
Normal file
36
tensorflow/lite/python/testdata/test_registerer_wrapper.cc
vendored
Normal file
@ -0,0 +1,36 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "include/pybind11/pybind11.h"
|
||||
#include "include/pybind11/pytypes.h"
|
||||
#include "tensorflow/lite/python/testdata/test_registerer.h"
|
||||
|
||||
PYBIND11_MODULE(_pywrap_test_registerer, m) {
|
||||
m.doc() = R"pbdoc(
|
||||
_pywrap_test_registerer
|
||||
-----
|
||||
)pbdoc";
|
||||
m.def("get_num_test_registerer_calls", &tflite::get_num_test_registerer_calls,
|
||||
R"pbdoc(
|
||||
Returns the num_test_registerer_calls counter and re-sets it.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"TF_TestRegisterer",
|
||||
[](uintptr_t resolver) {
|
||||
tflite::TF_TestRegisterer(
|
||||
reinterpret_cast<tflite::MutableOpResolver*>(resolver));
|
||||
},
|
||||
R"pbdoc(
|
||||
Dummy registerer function with the correct signature. Registers a fake
|
||||
custom op needed by test models. Increments the
|
||||
num_test_registerer_calls counter by one.
|
||||
)pbdoc");
|
||||
}
|
@ -2526,6 +2526,7 @@ def pybind_extension(
|
||||
linkopts = [],
|
||||
deps = [],
|
||||
defines = [],
|
||||
additional_exported_symbols = [],
|
||||
visibility = None,
|
||||
testonly = None,
|
||||
licenses = None,
|
||||
@ -2544,15 +2545,22 @@ def pybind_extension(
|
||||
prefix = name[:p + 1]
|
||||
so_file = "%s%s.so" % (prefix, sname)
|
||||
pyd_file = "%s%s.pyd" % (prefix, sname)
|
||||
symbol = "init%s" % sname
|
||||
symbol2 = "init_%s" % sname
|
||||
symbol3 = "PyInit_%s" % sname
|
||||
exported_symbols = [
|
||||
"init%s" % sname,
|
||||
"init_%s" % sname,
|
||||
"PyInit_%s" % sname,
|
||||
] + additional_exported_symbols
|
||||
|
||||
exported_symbols_file = "%s-exported-symbols.lds" % name
|
||||
version_script_file = "%s-version-script.lds" % name
|
||||
|
||||
exported_symbols_output = "\n".join(["_%s" % symbol for symbol in exported_symbols])
|
||||
version_script_output = "\n".join([" %s;" % symbol for symbol in exported_symbols])
|
||||
|
||||
native.genrule(
|
||||
name = name + "_exported_symbols",
|
||||
outs = [exported_symbols_file],
|
||||
cmd = "echo '_%s\n_%s\n_%s' >$@" % (symbol, symbol2, symbol3),
|
||||
cmd = "echo '%s' >$@" % exported_symbols_output,
|
||||
output_licenses = ["unencumbered"],
|
||||
visibility = ["//visibility:private"],
|
||||
testonly = testonly,
|
||||
@ -2561,7 +2569,7 @@ def pybind_extension(
|
||||
native.genrule(
|
||||
name = name + "_version_script",
|
||||
outs = [version_script_file],
|
||||
cmd = "echo '{global:\n %s;\n %s;\n %s;\n local: *;};' >$@" % (symbol, symbol2, symbol3),
|
||||
cmd = "echo '{global:\n%s\n local: *;};' >$@" % version_script_output,
|
||||
output_licenses = ["unencumbered"],
|
||||
visibility = ["//visibility:private"],
|
||||
testonly = testonly,
|
||||
|
Loading…
Reference in New Issue
Block a user