[TF:TRT] Support OSS build that uses TensorRT static libraries.
Allow the building of TensorFlow that uses TensorRT static libraries by adding options --define=TF_TENSORRT_STATIC=1 --action_env=TF_TENSORRT_STATIC_PATH= to the bazel command. PiperOrigin-RevId: 345503209 Change-Id: I3aeb39a672b1f9a7c2bab7f9f6018c304685c302
This commit is contained in:
parent
c10504b9fa
commit
b1660e7e18
@ -50,10 +50,26 @@ cc_library(
|
||||
]),
|
||||
)
|
||||
|
||||
# Copybara will replace tensorrt_lib_oss_placeholder with this one
|
||||
alias(
|
||||
name = "tensorrt_lib_oss",
|
||||
actual = select({
|
||||
"@local_config_tensorrt//:use_static_tensorrt": "@local_config_tensorrt//:tensorrt",
|
||||
"//conditions:default": ":tensorrt_stub",
|
||||
}),
|
||||
visibility = ["//visibility:private"],
|
||||
)
|
||||
|
||||
alias(
|
||||
name = "tensorrt_lib_oss_placeholder",
|
||||
actual = ":tensorrt_stub",
|
||||
visibility = ["//visibility:private"],
|
||||
)
|
||||
|
||||
alias(
|
||||
name = "tensorrt_lib",
|
||||
actual = select({
|
||||
"//tensorflow:oss": ":tensorrt_stub",
|
||||
"//tensorflow:oss": ":tensorrt_lib_oss",
|
||||
"//conditions:default": "@local_config_tensorrt//:tensorrt",
|
||||
}),
|
||||
visibility = ["//visibility:private"],
|
||||
@ -590,11 +606,21 @@ tf_proto_library(
|
||||
protodeps = tf_additional_all_protos(),
|
||||
)
|
||||
|
||||
tensorrt_static_define_oss_placeholder = []
|
||||
|
||||
# @unused
|
||||
# Copybara will replace tensorrt_static_define_oss_placeholder with this one
|
||||
tensorrt_static_define_oss = select({
|
||||
"@local_config_tensorrt//:use_static_tensorrt": ["TF_OSS_TENSORRT_STATIC=1"],
|
||||
"//conditions:default": [],
|
||||
})
|
||||
|
||||
cc_library(
|
||||
name = "py_utils",
|
||||
srcs = ["utils/py_utils.cc"],
|
||||
hdrs = ["utils/py_utils.h"],
|
||||
copts = tf_copts(),
|
||||
defines = tensorrt_static_define_oss,
|
||||
deps = if_tensorrt([
|
||||
":common_utils",
|
||||
":tensorrt_lib",
|
||||
|
@ -26,6 +26,10 @@ namespace tensorrt {
|
||||
|
||||
bool IsGoogleTensorRTEnabled() {
|
||||
#if GOOGLE_CUDA && GOOGLE_TENSORRT
|
||||
#if TF_OSS_TENSORRT_STATIC
|
||||
LOG(INFO) << "TensorRT libraries are statically linked, skip dlopen check";
|
||||
return true;
|
||||
#else
|
||||
auto handle_or = se::internal::DsoLoader::TryDlopenTensorRTLibraries();
|
||||
if (!handle_or.ok()) {
|
||||
LOG_WARNING_WITH_PREFIX
|
||||
@ -36,6 +40,7 @@ bool IsGoogleTensorRTEnabled() {
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
#else
|
||||
return false;
|
||||
#endif
|
||||
|
16
third_party/tensorrt/BUILD.tpl
vendored
16
third_party/tensorrt/BUILD.tpl
vendored
@ -10,6 +10,11 @@ package(default_visibility = ["//visibility:public"])
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
config_setting(
|
||||
name = "use_static_tensorrt",
|
||||
define_values = {"TF_TENSORRT_STATIC":"1"},
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tensorrt_headers",
|
||||
hdrs = [
|
||||
@ -22,12 +27,19 @@ cc_library(
|
||||
|
||||
cc_library(
|
||||
name = "tensorrt",
|
||||
srcs = [":tensorrt_lib"],
|
||||
srcs = select({
|
||||
":use_static_tensorrt": [":tensorrt_static_lib"],
|
||||
"//conditions:default": [":tensorrt_lib"],
|
||||
}),
|
||||
copts = cuda_default_copts(),
|
||||
data = [":tensorrt_lib"],
|
||||
data = select({
|
||||
":use_static_tensorrt": [],
|
||||
"//conditions:default": [":tensorrt_lib"],
|
||||
}),
|
||||
linkstatic = 1,
|
||||
deps = [
|
||||
":tensorrt_headers",
|
||||
# TODO(b/174608722): fix this line.
|
||||
"@local_config_cuda//cuda",
|
||||
],
|
||||
)
|
||||
|
19
third_party/tensorrt/tensorrt_configure.bzl
vendored
19
third_party/tensorrt/tensorrt_configure.bzl
vendored
@ -20,6 +20,7 @@ load(
|
||||
)
|
||||
|
||||
_TENSORRT_INSTALL_PATH = "TENSORRT_INSTALL_PATH"
|
||||
_TF_TENSORRT_STATIC_PATH = "TF_TENSORRT_STATIC_PATH"
|
||||
_TF_TENSORRT_CONFIG_REPO = "TF_TENSORRT_CONFIG_REPO"
|
||||
_TF_TENSORRT_VERSION = "TF_TENSORRT_VERSION"
|
||||
_TF_NEED_TENSORRT = "TF_NEED_TENSORRT"
|
||||
@ -91,6 +92,10 @@ def enable_tensorrt(repository_ctx):
|
||||
"""Returns whether to build with TensorRT support."""
|
||||
return int(get_host_environ(repository_ctx, _TF_NEED_TENSORRT, False))
|
||||
|
||||
def _get_tensorrt_static_path(repository_ctx):
|
||||
"""Returns the path for TensorRT static libraries."""
|
||||
return get_host_environ(repository_ctx, _TF_TENSORRT_STATIC_PATH, None)
|
||||
|
||||
def _create_local_tensorrt_repository(repository_ctx):
|
||||
# Resolve all labels before doing any real work. Resolving causes the
|
||||
# function to be restarted with all previous state being lost. This
|
||||
@ -110,6 +115,7 @@ def _create_local_tensorrt_repository(repository_ctx):
|
||||
|
||||
# Copy the library and header files.
|
||||
libraries = [lib_name(lib, cpu_value, trt_version) for lib in _TF_TENSORRT_LIBS]
|
||||
|
||||
library_dir = config["tensorrt_library_dir"] + "/"
|
||||
headers = _get_tensorrt_headers(trt_version)
|
||||
include_dir = config["tensorrt_include_dir"] + "/"
|
||||
@ -128,6 +134,19 @@ def _create_local_tensorrt_repository(repository_ctx):
|
||||
),
|
||||
]
|
||||
|
||||
tensorrt_static_path = _get_tensorrt_static_path(repository_ctx)
|
||||
raw_static_library_names = _TF_TENSORRT_LIBS + ["nvrtc", "myelin_compiler", "myelin_executor", "myelin_pattern_library", "myelin_pattern_runtime"]
|
||||
static_libraries = [lib_name(lib, cpu_value, trt_version, static = True) for lib in raw_static_library_names]
|
||||
if tensorrt_static_path != None:
|
||||
copy_rules = copy_rules + [
|
||||
make_copy_files_rule(
|
||||
repository_ctx,
|
||||
name = "tensorrt_static_lib",
|
||||
srcs = [tensorrt_static_path + library for library in static_libraries],
|
||||
outs = ["tensorrt/lib/" + library for library in static_libraries],
|
||||
),
|
||||
]
|
||||
|
||||
# Set up config file.
|
||||
repository_ctx.template(
|
||||
"build_defs.bzl",
|
||||
|
Loading…
Reference in New Issue
Block a user