Merge pull request #28633 from jhalakpatel:fix_trt6_version_parsing
PiperOrigin-RevId: 248706905
This commit is contained in:
commit
cc03fdce67
15
third_party/gpus/find_cuda_config.py
vendored
15
third_party/gpus/find_cuda_config.py
vendored
@ -384,20 +384,21 @@ def _find_tensorrt_config(base_paths, required_version):
|
||||
_get_header_version(path, name)
|
||||
for name in ("NV_TENSORRT_MAJOR", "NV_TENSORRT_MINOR",
|
||||
"NV_TENSORRT_PATCH"))
|
||||
if not all(version):
|
||||
return None # Versions not found, make _matches_version returns False.
|
||||
return ".".join(version)
|
||||
|
||||
header_path, header_version = _find_header(base_paths, "NvInfer.h",
|
||||
required_version,
|
||||
get_header_version)
|
||||
|
||||
if ".." in header_version:
|
||||
# From TRT 6.0 onwards, version information has been moved to NvInferVersion.h.
|
||||
try:
|
||||
header_path, header_version = _find_header(base_paths, "NvInfer.h",
|
||||
required_version,
|
||||
get_header_version)
|
||||
except ConfigError:
|
||||
# TensorRT 6 moved the version information to NvInferVersion.h.
|
||||
header_path, header_version = _find_header(base_paths, "NvInferVersion.h",
|
||||
required_version,
|
||||
get_header_version)
|
||||
|
||||
tensorrt_version = header_version.split(".")[0]
|
||||
|
||||
library_path = _find_library(base_paths, "nvinfer", tensorrt_version)
|
||||
|
||||
return {
|
||||
|
28
third_party/tensorrt/tensorrt_configure.bzl
vendored
28
third_party/tensorrt/tensorrt_configure.bzl
vendored
@ -22,6 +22,15 @@ _TF_NEED_TENSORRT = "TF_NEED_TENSORRT"
|
||||
|
||||
_TF_TENSORRT_LIBS = ["nvinfer", "nvinfer_plugin"]
|
||||
_TF_TENSORRT_HEADERS = ["NvInfer.h", "NvUtils.h", "NvInferPlugin.h"]
|
||||
_TF_TENSORRT_HEADERS_V6 = [
|
||||
"NvInfer.h",
|
||||
"NvUtils.h",
|
||||
"NvInferPlugin.h",
|
||||
"NvInferVersion.h",
|
||||
"NvInferRTSafe.h",
|
||||
"NvInferRTExt.h",
|
||||
"NvInferPluginUtils.h",
|
||||
]
|
||||
|
||||
_DEFINE_TENSORRT_SONAME_MAJOR = "#define NV_TENSORRT_SONAME_MAJOR"
|
||||
_DEFINE_TENSORRT_SONAME_MINOR = "#define NV_TENSORRT_SONAME_MINOR"
|
||||
@ -32,18 +41,10 @@ def _at_least_version(actual_version, required_version):
|
||||
required = [int(v) for v in required_version.split(".")]
|
||||
return actual >= required
|
||||
|
||||
def _update_tensorrt_headers(tensorrt_version):
|
||||
if not _at_least_version(tensorrt_version, "6"):
|
||||
return
|
||||
_TF_TENSORRT_HEADERS = [
|
||||
"NvInferVersion.h",
|
||||
"NvInfer.h",
|
||||
"NvUtils.h",
|
||||
"NvInferPlugin.h",
|
||||
"NvInferRTSafe.h",
|
||||
"NvInferRTExt.h",
|
||||
"NvInferPluginUtils.h",
|
||||
]
|
||||
def _get_tensorrt_headers(tensorrt_version):
|
||||
if _at_least_version(tensorrt_version, "6"):
|
||||
return _TF_TENSORRT_HEADERS_V6
|
||||
return _TF_TENSORRT_HEADERS
|
||||
|
||||
def _tpl(repository_ctx, tpl, substitutions):
|
||||
repository_ctx.template(
|
||||
@ -87,10 +88,9 @@ def _tensorrt_configure_impl(repository_ctx):
|
||||
cpu_value = get_cpu_value(repository_ctx)
|
||||
|
||||
# Copy the library and header files.
|
||||
_update_tensorrt_headers(trt_version)
|
||||
libraries = [lib_name(lib, cpu_value, trt_version) for lib in _TF_TENSORRT_LIBS]
|
||||
library_dir = config["tensorrt_library_dir"] + "/"
|
||||
headers = _TF_TENSORRT_HEADERS
|
||||
headers = _get_tensorrt_headers(trt_version)
|
||||
include_dir = config["tensorrt_include_dir"] + "/"
|
||||
copy_rules = [
|
||||
make_copy_files_rule(
|
||||
|
Loading…
Reference in New Issue
Block a user