Merge pull request #28633 from jhalakpatel:fix_trt6_version_parsing

PiperOrigin-RevId: 248706905
This commit is contained in:
TensorFlower Gardener 2019-05-17 06:14:26 -07:00
commit cc03fdce67
2 changed files with 22 additions and 21 deletions

View File

@ -384,20 +384,21 @@ def _find_tensorrt_config(base_paths, required_version):
_get_header_version(path, name) _get_header_version(path, name)
for name in ("NV_TENSORRT_MAJOR", "NV_TENSORRT_MINOR", for name in ("NV_TENSORRT_MAJOR", "NV_TENSORRT_MINOR",
"NV_TENSORRT_PATCH")) "NV_TENSORRT_PATCH"))
if not all(version):
return None # Versions not found, make _matches_version returns False.
return ".".join(version) return ".".join(version)
header_path, header_version = _find_header(base_paths, "NvInfer.h", try:
required_version, header_path, header_version = _find_header(base_paths, "NvInfer.h",
get_header_version) required_version,
get_header_version)
if ".." in header_version: except ConfigError:
# From TRT 6.0 onwards, version information has been moved to NvInferVersion.h. # TensorRT 6 moved the version information to NvInferVersion.h.
header_path, header_version = _find_header(base_paths, "NvInferVersion.h", header_path, header_version = _find_header(base_paths, "NvInferVersion.h",
required_version, required_version,
get_header_version) get_header_version)
tensorrt_version = header_version.split(".")[0] tensorrt_version = header_version.split(".")[0]
library_path = _find_library(base_paths, "nvinfer", tensorrt_version) library_path = _find_library(base_paths, "nvinfer", tensorrt_version)
return { return {

View File

@ -22,6 +22,15 @@ _TF_NEED_TENSORRT = "TF_NEED_TENSORRT"
_TF_TENSORRT_LIBS = ["nvinfer", "nvinfer_plugin"] _TF_TENSORRT_LIBS = ["nvinfer", "nvinfer_plugin"]
_TF_TENSORRT_HEADERS = ["NvInfer.h", "NvUtils.h", "NvInferPlugin.h"] _TF_TENSORRT_HEADERS = ["NvInfer.h", "NvUtils.h", "NvInferPlugin.h"]
_TF_TENSORRT_HEADERS_V6 = [
"NvInfer.h",
"NvUtils.h",
"NvInferPlugin.h",
"NvInferVersion.h",
"NvInferRTSafe.h",
"NvInferRTExt.h",
"NvInferPluginUtils.h",
]
_DEFINE_TENSORRT_SONAME_MAJOR = "#define NV_TENSORRT_SONAME_MAJOR" _DEFINE_TENSORRT_SONAME_MAJOR = "#define NV_TENSORRT_SONAME_MAJOR"
_DEFINE_TENSORRT_SONAME_MINOR = "#define NV_TENSORRT_SONAME_MINOR" _DEFINE_TENSORRT_SONAME_MINOR = "#define NV_TENSORRT_SONAME_MINOR"
@ -32,18 +41,10 @@ def _at_least_version(actual_version, required_version):
required = [int(v) for v in required_version.split(".")] required = [int(v) for v in required_version.split(".")]
return actual >= required return actual >= required
def _update_tensorrt_headers(tensorrt_version): def _get_tensorrt_headers(tensorrt_version):
if not _at_least_version(tensorrt_version, "6"): if _at_least_version(tensorrt_version, "6"):
return return _TF_TENSORRT_HEADERS_V6
_TF_TENSORRT_HEADERS = [ return _TF_TENSORRT_HEADERS
"NvInferVersion.h",
"NvInfer.h",
"NvUtils.h",
"NvInferPlugin.h",
"NvInferRTSafe.h",
"NvInferRTExt.h",
"NvInferPluginUtils.h",
]
def _tpl(repository_ctx, tpl, substitutions): def _tpl(repository_ctx, tpl, substitutions):
repository_ctx.template( repository_ctx.template(
@ -87,10 +88,9 @@ def _tensorrt_configure_impl(repository_ctx):
cpu_value = get_cpu_value(repository_ctx) cpu_value = get_cpu_value(repository_ctx)
# Copy the library and header files. # Copy the library and header files.
_update_tensorrt_headers(trt_version)
libraries = [lib_name(lib, cpu_value, trt_version) for lib in _TF_TENSORRT_LIBS] libraries = [lib_name(lib, cpu_value, trt_version) for lib in _TF_TENSORRT_LIBS]
library_dir = config["tensorrt_library_dir"] + "/" library_dir = config["tensorrt_library_dir"] + "/"
headers = _TF_TENSORRT_HEADERS headers = _get_tensorrt_headers(trt_version)
include_dir = config["tensorrt_include_dir"] + "/" include_dir = config["tensorrt_include_dir"] + "/"
copy_rules = [ copy_rules = [
make_copy_files_rule( make_copy_files_rule(