Add pybind11 support to expose C++ functions to Python.

PiperOrigin-RevId: 240662841
This commit is contained in:
Amit Patankar 2019-03-27 16:09:30 -07:00 committed by TensorFlower Gardener
parent b29be96170
commit 233350baa2
4 changed files with 136 additions and 0 deletions

View File

@ -202,6 +202,7 @@ tensorflow/third_party/tflite_smartreply.BUILD
tensorflow/third_party/mkl_dnn/LICENSE
tensorflow/third_party/mkl_dnn/mkldnn.BUILD
tensorflow/third_party/pcre.BUILD
tensorflow/third_party/pybind11.BUILD
tensorflow/third_party/linenoise.BUILD
tensorflow/third_party/sqlite.BUILD
tensorflow/third_party/common.bzl

View File

@ -2212,3 +2212,105 @@ register_extension_info(
def tensorflow_opensource_extra_deps():
return []
def tf_pybind_extension(
name,
srcs,
hdrs,
module_name,
features = [],
srcs_version = "PY2AND3",
data = [],
copts = None,
nocopts = None,
linkopts = [],
deps = [],
visibility = None,
testonly = None,
licenses = None,
compatible_with = None,
restricted_to = None,
deprecation = None):
"""Builds a Python extension module."""
_ignore = [module_name]
p = name.rfind("/")
if p == -1:
sname = name
prefix = ""
else:
sname = name[p + 1:]
prefix = name[:p + 1]
so_file = "%s%s.so" % (prefix, sname)
pyd_file = "%s%s.pyd" % (prefix, sname)
symbol = "init%s" % sname
symbol2 = "init_%s" % sname
symbol3 = "PyInit_%s" % sname
exported_symbols_file = "%s-exported-symbols.lds" % name
version_script_file = "%s-version-script.lds" % name
native.genrule(
name = name + "_exported_symbols",
outs = [exported_symbols_file],
cmd = "echo '%s\n%s\n%s' >$@" % (symbol, symbol2, symbol3),
output_licenses = ["unencumbered"],
visibility = ["//visibility:private"],
testonly = testonly,
)
native.genrule(
name = name + "_version_script",
outs = [version_script_file],
cmd = "echo '{global:\n %s;\n %s;\n %s;\n local: *;};' >$@" % (symbol, symbol2, symbol3),
output_licenses = ["unencumbered"],
visibility = ["//visibility:private"],
testonly = testonly,
)
native.cc_binary(
name = so_file,
srcs = srcs + hdrs,
data = data,
copts = copts,
nocopts = nocopts,
linkopts = linkopts + select({
"//conditions:default": [
"-Wl,--version-script",
"$(location %s)" % version_script_file,
],
}),
deps = deps + [
exported_symbols_file,
version_script_file,
],
features = features,
linkshared = 1,
testonly = testonly,
licenses = licenses,
visibility = visibility,
deprecation = deprecation,
restricted_to = restricted_to,
compatible_with = compatible_with,
)
native.genrule(
name = name + "_pyd_copy",
srcs = [so_file],
outs = [pyd_file],
cmd = "cp $< $@",
output_to_bindir = True,
visibility = visibility,
deprecation = deprecation,
restricted_to = restricted_to,
compatible_with = compatible_with,
)
native.py_library(
name = name,
data = select({
"@org_tensorflow//tensorflow:windows": [pyd_file],
"//conditions:default": [so_file],
}),
srcs_version = srcs_version,
licenses = licenses,
testonly = testonly,
visibility = visibility,
deprecation = deprecation,
restricted_to = restricted_to,
compatible_with = compatible_with,
)

View File

@ -881,6 +881,17 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
],
)
tf_http_archive(
name = "pybind11",
urls = [
"https://mirror.bazel.build/github.com/pybind/pybind11/archive/v2.2.4.tar.gz",
"https://github.com/pybind/pybind11/archive/v2.2.4.tar.gz",
],
sha256 = "b69e83658513215b8d1443544d0549b7d231b9f201f6fc787a2b2218b408181e",
strip_prefix = "pybind11-2.2.4",
build_file = clean_dep("//third_party:pybind11.BUILD"),
)
##############################################################################
# BIND DEFINITIONS
#

22
third_party/pybind11.BUILD vendored Normal file
View File

@ -0,0 +1,22 @@
package(default_visibility = ["//visibility:public"])
cc_library(
name = "pybind11",
hdrs = glob(
include = [
"include/pybind11/*.h",
"include/pybind11/detail/*.h",
],
exclude = [
"include/pybind11/common.h",
"include/pybind11/eigen.h",
],
),
copts = [
"-fexceptions",
"-Xclang-only=-Wno-undefined-inline",
"-Xclang-only=-Wno-pragma-once-outside-header",
"-Xgcc-only=-Wno-error", # no way to just disable the pragma-once warning in gcc
],
includes = ["include"],
)