[TF:TRT] Support OSS build that uses TensorRT static libraries.

Allow the building of TensorFlow that uses TensorRT static libraries by adding
options --define=TF_TENSORRT_STATIC=1 --action_env=TF_TENSORRT_STATIC_PATH= to
the bazel command.

PiperOrigin-RevId: 345503209
Change-Id: I3aeb39a672b1f9a7c2bab7f9f6018c304685c302
This commit is contained in:
Bixia Zheng 2020-12-03 11:58:23 -08:00 committed by TensorFlower Gardener
parent c10504b9fa
commit b1660e7e18
4 changed files with 65 additions and 3 deletions

View File

@ -50,10 +50,26 @@ cc_library(
]), ]),
) )
# Copybara will replace tensorrt_lib_oss_placeholder with this one
alias(
name = "tensorrt_lib_oss",
actual = select({
"@local_config_tensorrt//:use_static_tensorrt": "@local_config_tensorrt//:tensorrt",
"//conditions:default": ":tensorrt_stub",
}),
visibility = ["//visibility:private"],
)
alias(
name = "tensorrt_lib_oss_placeholder",
actual = ":tensorrt_stub",
visibility = ["//visibility:private"],
)
alias( alias(
name = "tensorrt_lib", name = "tensorrt_lib",
actual = select({ actual = select({
"//tensorflow:oss": ":tensorrt_stub", "//tensorflow:oss": ":tensorrt_lib_oss",
"//conditions:default": "@local_config_tensorrt//:tensorrt", "//conditions:default": "@local_config_tensorrt//:tensorrt",
}), }),
visibility = ["//visibility:private"], visibility = ["//visibility:private"],
@ -590,11 +606,21 @@ tf_proto_library(
protodeps = tf_additional_all_protos(), protodeps = tf_additional_all_protos(),
) )
tensorrt_static_define_oss_placeholder = []
# @unused
# Copybara will replace tensorrt_static_define_oss_placeholder with this one
tensorrt_static_define_oss = select({
"@local_config_tensorrt//:use_static_tensorrt": ["TF_OSS_TENSORRT_STATIC=1"],
"//conditions:default": [],
})
cc_library( cc_library(
name = "py_utils", name = "py_utils",
srcs = ["utils/py_utils.cc"], srcs = ["utils/py_utils.cc"],
hdrs = ["utils/py_utils.h"], hdrs = ["utils/py_utils.h"],
copts = tf_copts(), copts = tf_copts(),
defines = tensorrt_static_define_oss,
deps = if_tensorrt([ deps = if_tensorrt([
":common_utils", ":common_utils",
":tensorrt_lib", ":tensorrt_lib",

View File

@ -26,6 +26,10 @@ namespace tensorrt {
bool IsGoogleTensorRTEnabled() { bool IsGoogleTensorRTEnabled() {
#if GOOGLE_CUDA && GOOGLE_TENSORRT #if GOOGLE_CUDA && GOOGLE_TENSORRT
#if TF_OSS_TENSORRT_STATIC
LOG(INFO) << "TensorRT libraries are statically linked, skip dlopen check";
return true;
#else
auto handle_or = se::internal::DsoLoader::TryDlopenTensorRTLibraries(); auto handle_or = se::internal::DsoLoader::TryDlopenTensorRTLibraries();
if (!handle_or.ok()) { if (!handle_or.ok()) {
LOG_WARNING_WITH_PREFIX LOG_WARNING_WITH_PREFIX
@ -36,6 +40,7 @@ bool IsGoogleTensorRTEnabled() {
} else { } else {
return true; return true;
} }
#endif
#else #else
return false; return false;
#endif #endif

View File

@ -10,6 +10,11 @@ package(default_visibility = ["//visibility:public"])
exports_files(["LICENSE"]) exports_files(["LICENSE"])
config_setting(
name = "use_static_tensorrt",
define_values = {"TF_TENSORRT_STATIC":"1"},
)
cc_library( cc_library(
name = "tensorrt_headers", name = "tensorrt_headers",
hdrs = [ hdrs = [
@ -22,12 +27,19 @@ cc_library(
cc_library( cc_library(
name = "tensorrt", name = "tensorrt",
srcs = [":tensorrt_lib"], srcs = select({
":use_static_tensorrt": [":tensorrt_static_lib"],
"//conditions:default": [":tensorrt_lib"],
}),
copts = cuda_default_copts(), copts = cuda_default_copts(),
data = [":tensorrt_lib"], data = select({
":use_static_tensorrt": [],
"//conditions:default": [":tensorrt_lib"],
}),
linkstatic = 1, linkstatic = 1,
deps = [ deps = [
":tensorrt_headers", ":tensorrt_headers",
# TODO(b/174608722): fix this line.
"@local_config_cuda//cuda", "@local_config_cuda//cuda",
], ],
) )

View File

@ -20,6 +20,7 @@ load(
) )
_TENSORRT_INSTALL_PATH = "TENSORRT_INSTALL_PATH" _TENSORRT_INSTALL_PATH = "TENSORRT_INSTALL_PATH"
_TF_TENSORRT_STATIC_PATH = "TF_TENSORRT_STATIC_PATH"
_TF_TENSORRT_CONFIG_REPO = "TF_TENSORRT_CONFIG_REPO" _TF_TENSORRT_CONFIG_REPO = "TF_TENSORRT_CONFIG_REPO"
_TF_TENSORRT_VERSION = "TF_TENSORRT_VERSION" _TF_TENSORRT_VERSION = "TF_TENSORRT_VERSION"
_TF_NEED_TENSORRT = "TF_NEED_TENSORRT" _TF_NEED_TENSORRT = "TF_NEED_TENSORRT"
@ -91,6 +92,10 @@ def enable_tensorrt(repository_ctx):
"""Returns whether to build with TensorRT support.""" """Returns whether to build with TensorRT support."""
return int(get_host_environ(repository_ctx, _TF_NEED_TENSORRT, False)) return int(get_host_environ(repository_ctx, _TF_NEED_TENSORRT, False))
def _get_tensorrt_static_path(repository_ctx):
"""Returns the path for TensorRT static libraries."""
return get_host_environ(repository_ctx, _TF_TENSORRT_STATIC_PATH, None)
def _create_local_tensorrt_repository(repository_ctx): def _create_local_tensorrt_repository(repository_ctx):
# Resolve all labels before doing any real work. Resolving causes the # Resolve all labels before doing any real work. Resolving causes the
# function to be restarted with all previous state being lost. This # function to be restarted with all previous state being lost. This
@ -110,6 +115,7 @@ def _create_local_tensorrt_repository(repository_ctx):
# Copy the library and header files. # Copy the library and header files.
libraries = [lib_name(lib, cpu_value, trt_version) for lib in _TF_TENSORRT_LIBS] libraries = [lib_name(lib, cpu_value, trt_version) for lib in _TF_TENSORRT_LIBS]
library_dir = config["tensorrt_library_dir"] + "/" library_dir = config["tensorrt_library_dir"] + "/"
headers = _get_tensorrt_headers(trt_version) headers = _get_tensorrt_headers(trt_version)
include_dir = config["tensorrt_include_dir"] + "/" include_dir = config["tensorrt_include_dir"] + "/"
@ -128,6 +134,19 @@ def _create_local_tensorrt_repository(repository_ctx):
), ),
] ]
tensorrt_static_path = _get_tensorrt_static_path(repository_ctx)
raw_static_library_names = _TF_TENSORRT_LIBS + ["nvrtc", "myelin_compiler", "myelin_executor", "myelin_pattern_library", "myelin_pattern_runtime"]
static_libraries = [lib_name(lib, cpu_value, trt_version, static = True) for lib in raw_static_library_names]
if tensorrt_static_path != None:
copy_rules = copy_rules + [
make_copy_files_rule(
repository_ctx,
name = "tensorrt_static_lib",
srcs = [tensorrt_static_path + library for library in static_libraries],
outs = ["tensorrt/lib/" + library for library in static_libraries],
),
]
# Set up config file. # Set up config file.
repository_ctx.template( repository_ctx.template(
"build_defs.bzl", "build_defs.bzl",