[TF:TRT] Support OSS build that uses TensorRT static libraries.
Allow the building of TensorFlow that uses TensorRT static libraries by adding options --define=TF_TENSORRT_STATIC=1 --action_env=TF_TENSORRT_STATIC_PATH= to the bazel command. PiperOrigin-RevId: 345503209 Change-Id: I3aeb39a672b1f9a7c2bab7f9f6018c304685c302
This commit is contained in:
parent
c10504b9fa
commit
b1660e7e18
@ -50,10 +50,26 @@ cc_library(
|
|||||||
]),
|
]),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Copybara will replace tensorrt_lib_oss_placeholder with this one
|
||||||
|
alias(
|
||||||
|
name = "tensorrt_lib_oss",
|
||||||
|
actual = select({
|
||||||
|
"@local_config_tensorrt//:use_static_tensorrt": "@local_config_tensorrt//:tensorrt",
|
||||||
|
"//conditions:default": ":tensorrt_stub",
|
||||||
|
}),
|
||||||
|
visibility = ["//visibility:private"],
|
||||||
|
)
|
||||||
|
|
||||||
|
alias(
|
||||||
|
name = "tensorrt_lib_oss_placeholder",
|
||||||
|
actual = ":tensorrt_stub",
|
||||||
|
visibility = ["//visibility:private"],
|
||||||
|
)
|
||||||
|
|
||||||
alias(
|
alias(
|
||||||
name = "tensorrt_lib",
|
name = "tensorrt_lib",
|
||||||
actual = select({
|
actual = select({
|
||||||
"//tensorflow:oss": ":tensorrt_stub",
|
"//tensorflow:oss": ":tensorrt_lib_oss",
|
||||||
"//conditions:default": "@local_config_tensorrt//:tensorrt",
|
"//conditions:default": "@local_config_tensorrt//:tensorrt",
|
||||||
}),
|
}),
|
||||||
visibility = ["//visibility:private"],
|
visibility = ["//visibility:private"],
|
||||||
@ -590,11 +606,21 @@ tf_proto_library(
|
|||||||
protodeps = tf_additional_all_protos(),
|
protodeps = tf_additional_all_protos(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tensorrt_static_define_oss_placeholder = []
|
||||||
|
|
||||||
|
# @unused
|
||||||
|
# Copybara will replace tensorrt_static_define_oss_placeholder with this one
|
||||||
|
tensorrt_static_define_oss = select({
|
||||||
|
"@local_config_tensorrt//:use_static_tensorrt": ["TF_OSS_TENSORRT_STATIC=1"],
|
||||||
|
"//conditions:default": [],
|
||||||
|
})
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "py_utils",
|
name = "py_utils",
|
||||||
srcs = ["utils/py_utils.cc"],
|
srcs = ["utils/py_utils.cc"],
|
||||||
hdrs = ["utils/py_utils.h"],
|
hdrs = ["utils/py_utils.h"],
|
||||||
copts = tf_copts(),
|
copts = tf_copts(),
|
||||||
|
defines = tensorrt_static_define_oss,
|
||||||
deps = if_tensorrt([
|
deps = if_tensorrt([
|
||||||
":common_utils",
|
":common_utils",
|
||||||
":tensorrt_lib",
|
":tensorrt_lib",
|
||||||
|
@ -26,6 +26,10 @@ namespace tensorrt {
|
|||||||
|
|
||||||
bool IsGoogleTensorRTEnabled() {
|
bool IsGoogleTensorRTEnabled() {
|
||||||
#if GOOGLE_CUDA && GOOGLE_TENSORRT
|
#if GOOGLE_CUDA && GOOGLE_TENSORRT
|
||||||
|
#if TF_OSS_TENSORRT_STATIC
|
||||||
|
LOG(INFO) << "TensorRT libraries are statically linked, skip dlopen check";
|
||||||
|
return true;
|
||||||
|
#else
|
||||||
auto handle_or = se::internal::DsoLoader::TryDlopenTensorRTLibraries();
|
auto handle_or = se::internal::DsoLoader::TryDlopenTensorRTLibraries();
|
||||||
if (!handle_or.ok()) {
|
if (!handle_or.ok()) {
|
||||||
LOG_WARNING_WITH_PREFIX
|
LOG_WARNING_WITH_PREFIX
|
||||||
@ -36,6 +40,7 @@ bool IsGoogleTensorRTEnabled() {
|
|||||||
} else {
|
} else {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
#else
|
#else
|
||||||
return false;
|
return false;
|
||||||
#endif
|
#endif
|
||||||
|
16
third_party/tensorrt/BUILD.tpl
vendored
16
third_party/tensorrt/BUILD.tpl
vendored
@ -10,6 +10,11 @@ package(default_visibility = ["//visibility:public"])
|
|||||||
|
|
||||||
exports_files(["LICENSE"])
|
exports_files(["LICENSE"])
|
||||||
|
|
||||||
|
config_setting(
|
||||||
|
name = "use_static_tensorrt",
|
||||||
|
define_values = {"TF_TENSORRT_STATIC":"1"},
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "tensorrt_headers",
|
name = "tensorrt_headers",
|
||||||
hdrs = [
|
hdrs = [
|
||||||
@ -22,12 +27,19 @@ cc_library(
|
|||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "tensorrt",
|
name = "tensorrt",
|
||||||
srcs = [":tensorrt_lib"],
|
srcs = select({
|
||||||
|
":use_static_tensorrt": [":tensorrt_static_lib"],
|
||||||
|
"//conditions:default": [":tensorrt_lib"],
|
||||||
|
}),
|
||||||
copts = cuda_default_copts(),
|
copts = cuda_default_copts(),
|
||||||
data = [":tensorrt_lib"],
|
data = select({
|
||||||
|
":use_static_tensorrt": [],
|
||||||
|
"//conditions:default": [":tensorrt_lib"],
|
||||||
|
}),
|
||||||
linkstatic = 1,
|
linkstatic = 1,
|
||||||
deps = [
|
deps = [
|
||||||
":tensorrt_headers",
|
":tensorrt_headers",
|
||||||
|
# TODO(b/174608722): fix this line.
|
||||||
"@local_config_cuda//cuda",
|
"@local_config_cuda//cuda",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
19
third_party/tensorrt/tensorrt_configure.bzl
vendored
19
third_party/tensorrt/tensorrt_configure.bzl
vendored
@ -20,6 +20,7 @@ load(
|
|||||||
)
|
)
|
||||||
|
|
||||||
_TENSORRT_INSTALL_PATH = "TENSORRT_INSTALL_PATH"
|
_TENSORRT_INSTALL_PATH = "TENSORRT_INSTALL_PATH"
|
||||||
|
_TF_TENSORRT_STATIC_PATH = "TF_TENSORRT_STATIC_PATH"
|
||||||
_TF_TENSORRT_CONFIG_REPO = "TF_TENSORRT_CONFIG_REPO"
|
_TF_TENSORRT_CONFIG_REPO = "TF_TENSORRT_CONFIG_REPO"
|
||||||
_TF_TENSORRT_VERSION = "TF_TENSORRT_VERSION"
|
_TF_TENSORRT_VERSION = "TF_TENSORRT_VERSION"
|
||||||
_TF_NEED_TENSORRT = "TF_NEED_TENSORRT"
|
_TF_NEED_TENSORRT = "TF_NEED_TENSORRT"
|
||||||
@ -91,6 +92,10 @@ def enable_tensorrt(repository_ctx):
|
|||||||
"""Returns whether to build with TensorRT support."""
|
"""Returns whether to build with TensorRT support."""
|
||||||
return int(get_host_environ(repository_ctx, _TF_NEED_TENSORRT, False))
|
return int(get_host_environ(repository_ctx, _TF_NEED_TENSORRT, False))
|
||||||
|
|
||||||
|
def _get_tensorrt_static_path(repository_ctx):
|
||||||
|
"""Returns the path for TensorRT static libraries."""
|
||||||
|
return get_host_environ(repository_ctx, _TF_TENSORRT_STATIC_PATH, None)
|
||||||
|
|
||||||
def _create_local_tensorrt_repository(repository_ctx):
|
def _create_local_tensorrt_repository(repository_ctx):
|
||||||
# Resolve all labels before doing any real work. Resolving causes the
|
# Resolve all labels before doing any real work. Resolving causes the
|
||||||
# function to be restarted with all previous state being lost. This
|
# function to be restarted with all previous state being lost. This
|
||||||
@ -110,6 +115,7 @@ def _create_local_tensorrt_repository(repository_ctx):
|
|||||||
|
|
||||||
# Copy the library and header files.
|
# Copy the library and header files.
|
||||||
libraries = [lib_name(lib, cpu_value, trt_version) for lib in _TF_TENSORRT_LIBS]
|
libraries = [lib_name(lib, cpu_value, trt_version) for lib in _TF_TENSORRT_LIBS]
|
||||||
|
|
||||||
library_dir = config["tensorrt_library_dir"] + "/"
|
library_dir = config["tensorrt_library_dir"] + "/"
|
||||||
headers = _get_tensorrt_headers(trt_version)
|
headers = _get_tensorrt_headers(trt_version)
|
||||||
include_dir = config["tensorrt_include_dir"] + "/"
|
include_dir = config["tensorrt_include_dir"] + "/"
|
||||||
@ -128,6 +134,19 @@ def _create_local_tensorrt_repository(repository_ctx):
|
|||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
tensorrt_static_path = _get_tensorrt_static_path(repository_ctx)
|
||||||
|
raw_static_library_names = _TF_TENSORRT_LIBS + ["nvrtc", "myelin_compiler", "myelin_executor", "myelin_pattern_library", "myelin_pattern_runtime"]
|
||||||
|
static_libraries = [lib_name(lib, cpu_value, trt_version, static = True) for lib in raw_static_library_names]
|
||||||
|
if tensorrt_static_path != None:
|
||||||
|
copy_rules = copy_rules + [
|
||||||
|
make_copy_files_rule(
|
||||||
|
repository_ctx,
|
||||||
|
name = "tensorrt_static_lib",
|
||||||
|
srcs = [tensorrt_static_path + library for library in static_libraries],
|
||||||
|
outs = ["tensorrt/lib/" + library for library in static_libraries],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
# Set up config file.
|
# Set up config file.
|
||||||
repository_ctx.template(
|
repository_ctx.template(
|
||||||
"build_defs.bzl",
|
"build_defs.bzl",
|
||||||
|
Loading…
x
Reference in New Issue
Block a user