Merge pull request #28633 from jhalakpatel:fix_trt6_version_parsing
PiperOrigin-RevId: 248706905
This commit is contained in:
commit
cc03fdce67
15
third_party/gpus/find_cuda_config.py
vendored
15
third_party/gpus/find_cuda_config.py
vendored
@ -384,20 +384,21 @@ def _find_tensorrt_config(base_paths, required_version):
|
|||||||
_get_header_version(path, name)
|
_get_header_version(path, name)
|
||||||
for name in ("NV_TENSORRT_MAJOR", "NV_TENSORRT_MINOR",
|
for name in ("NV_TENSORRT_MAJOR", "NV_TENSORRT_MINOR",
|
||||||
"NV_TENSORRT_PATCH"))
|
"NV_TENSORRT_PATCH"))
|
||||||
|
if not all(version):
|
||||||
|
return None # Versions not found, make _matches_version returns False.
|
||||||
return ".".join(version)
|
return ".".join(version)
|
||||||
|
|
||||||
header_path, header_version = _find_header(base_paths, "NvInfer.h",
|
try:
|
||||||
required_version,
|
header_path, header_version = _find_header(base_paths, "NvInfer.h",
|
||||||
get_header_version)
|
required_version,
|
||||||
|
get_header_version)
|
||||||
if ".." in header_version:
|
except ConfigError:
|
||||||
# From TRT 6.0 onwards, version information has been moved to NvInferVersion.h.
|
# TensorRT 6 moved the version information to NvInferVersion.h.
|
||||||
header_path, header_version = _find_header(base_paths, "NvInferVersion.h",
|
header_path, header_version = _find_header(base_paths, "NvInferVersion.h",
|
||||||
required_version,
|
required_version,
|
||||||
get_header_version)
|
get_header_version)
|
||||||
|
|
||||||
tensorrt_version = header_version.split(".")[0]
|
tensorrt_version = header_version.split(".")[0]
|
||||||
|
|
||||||
library_path = _find_library(base_paths, "nvinfer", tensorrt_version)
|
library_path = _find_library(base_paths, "nvinfer", tensorrt_version)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
28
third_party/tensorrt/tensorrt_configure.bzl
vendored
28
third_party/tensorrt/tensorrt_configure.bzl
vendored
@ -22,6 +22,15 @@ _TF_NEED_TENSORRT = "TF_NEED_TENSORRT"
|
|||||||
|
|
||||||
_TF_TENSORRT_LIBS = ["nvinfer", "nvinfer_plugin"]
|
_TF_TENSORRT_LIBS = ["nvinfer", "nvinfer_plugin"]
|
||||||
_TF_TENSORRT_HEADERS = ["NvInfer.h", "NvUtils.h", "NvInferPlugin.h"]
|
_TF_TENSORRT_HEADERS = ["NvInfer.h", "NvUtils.h", "NvInferPlugin.h"]
|
||||||
|
_TF_TENSORRT_HEADERS_V6 = [
|
||||||
|
"NvInfer.h",
|
||||||
|
"NvUtils.h",
|
||||||
|
"NvInferPlugin.h",
|
||||||
|
"NvInferVersion.h",
|
||||||
|
"NvInferRTSafe.h",
|
||||||
|
"NvInferRTExt.h",
|
||||||
|
"NvInferPluginUtils.h",
|
||||||
|
]
|
||||||
|
|
||||||
_DEFINE_TENSORRT_SONAME_MAJOR = "#define NV_TENSORRT_SONAME_MAJOR"
|
_DEFINE_TENSORRT_SONAME_MAJOR = "#define NV_TENSORRT_SONAME_MAJOR"
|
||||||
_DEFINE_TENSORRT_SONAME_MINOR = "#define NV_TENSORRT_SONAME_MINOR"
|
_DEFINE_TENSORRT_SONAME_MINOR = "#define NV_TENSORRT_SONAME_MINOR"
|
||||||
@ -32,18 +41,10 @@ def _at_least_version(actual_version, required_version):
|
|||||||
required = [int(v) for v in required_version.split(".")]
|
required = [int(v) for v in required_version.split(".")]
|
||||||
return actual >= required
|
return actual >= required
|
||||||
|
|
||||||
def _update_tensorrt_headers(tensorrt_version):
|
def _get_tensorrt_headers(tensorrt_version):
|
||||||
if not _at_least_version(tensorrt_version, "6"):
|
if _at_least_version(tensorrt_version, "6"):
|
||||||
return
|
return _TF_TENSORRT_HEADERS_V6
|
||||||
_TF_TENSORRT_HEADERS = [
|
return _TF_TENSORRT_HEADERS
|
||||||
"NvInferVersion.h",
|
|
||||||
"NvInfer.h",
|
|
||||||
"NvUtils.h",
|
|
||||||
"NvInferPlugin.h",
|
|
||||||
"NvInferRTSafe.h",
|
|
||||||
"NvInferRTExt.h",
|
|
||||||
"NvInferPluginUtils.h",
|
|
||||||
]
|
|
||||||
|
|
||||||
def _tpl(repository_ctx, tpl, substitutions):
|
def _tpl(repository_ctx, tpl, substitutions):
|
||||||
repository_ctx.template(
|
repository_ctx.template(
|
||||||
@ -87,10 +88,9 @@ def _tensorrt_configure_impl(repository_ctx):
|
|||||||
cpu_value = get_cpu_value(repository_ctx)
|
cpu_value = get_cpu_value(repository_ctx)
|
||||||
|
|
||||||
# Copy the library and header files.
|
# Copy the library and header files.
|
||||||
_update_tensorrt_headers(trt_version)
|
|
||||||
libraries = [lib_name(lib, cpu_value, trt_version) for lib in _TF_TENSORRT_LIBS]
|
libraries = [lib_name(lib, cpu_value, trt_version) for lib in _TF_TENSORRT_LIBS]
|
||||||
library_dir = config["tensorrt_library_dir"] + "/"
|
library_dir = config["tensorrt_library_dir"] + "/"
|
||||||
headers = _TF_TENSORRT_HEADERS
|
headers = _get_tensorrt_headers(trt_version)
|
||||||
include_dir = config["tensorrt_include_dir"] + "/"
|
include_dir = config["tensorrt_include_dir"] + "/"
|
||||||
copy_rules = [
|
copy_rules = [
|
||||||
make_copy_files_rule(
|
make_copy_files_rule(
|
||||||
|
Loading…
Reference in New Issue
Block a user