Merge pull request #28604 from jhalakpatel:build_tf_with_trt_6
PiperOrigin-RevId: 247755315
This commit is contained in:
commit
7bd23b650b
@ -790,6 +790,14 @@ class TRT_TensorOrWeights::SimpleITensor : public nvinfer1::ITensor {
|
|||||||
float getDynamicRangeMax() const override { return 0.f; }
|
float getDynamicRangeMax() const override { return 0.f; }
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#if IS_TRT_VERSION_GE(6, 0, 0, 0)
|
||||||
|
void setAllowedFormats(nvinfer1::TensorFormats formats) override {}
|
||||||
|
|
||||||
|
nvinfer1::TensorFormats getAllowedFormats() const override { return 1; }
|
||||||
|
|
||||||
|
bool isShape() const override { return false; }
|
||||||
|
#endif
|
||||||
|
|
||||||
private:
|
private:
|
||||||
nvinfer1::DataType trt_dtype_;
|
nvinfer1::DataType trt_dtype_;
|
||||||
nvinfer1::Dims trt_dims_;
|
nvinfer1::Dims trt_dims_;
|
||||||
|
@ -280,6 +280,14 @@ class FakeITensor : public nvinfer1::ITensor {
|
|||||||
float getDynamicRangeMax() const override { return 0.f; }
|
float getDynamicRangeMax() const override { return 0.f; }
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#if IS_TRT_VERSION_GE(6, 0, 0, 0)
|
||||||
|
void setAllowedFormats(nvinfer1::TensorFormats formats) override {}
|
||||||
|
|
||||||
|
nvinfer1::TensorFormats getAllowedFormats() const override { return 1; }
|
||||||
|
|
||||||
|
bool isShape() const override { return false; }
|
||||||
|
#endif
|
||||||
|
|
||||||
private:
|
private:
|
||||||
string name_;
|
string name_;
|
||||||
nvinfer1::Dims dims_;
|
nvinfer1::Dims dims_;
|
||||||
|
7
third_party/gpus/find_cuda_config.py
vendored
7
third_party/gpus/find_cuda_config.py
vendored
@ -389,6 +389,13 @@ def _find_tensorrt_config(base_paths, required_version):
|
|||||||
header_path, header_version = _find_header(base_paths, "NvInfer.h",
|
header_path, header_version = _find_header(base_paths, "NvInfer.h",
|
||||||
required_version,
|
required_version,
|
||||||
get_header_version)
|
get_header_version)
|
||||||
|
|
||||||
|
if ".." in header_version:
|
||||||
|
# From TRT 6.0 onwards, version information has been moved to NvInferVersion.h.
|
||||||
|
header_path, header_version = _find_header(base_paths, "NvInferVersion.h",
|
||||||
|
required_version,
|
||||||
|
get_header_version)
|
||||||
|
|
||||||
tensorrt_version = header_version.split(".")[0]
|
tensorrt_version = header_version.split(".")[0]
|
||||||
|
|
||||||
library_path = _find_library(base_paths, "nvinfer", tensorrt_version)
|
library_path = _find_library(base_paths, "nvinfer", tensorrt_version)
|
||||||
|
19
third_party/tensorrt/tensorrt_configure.bzl
vendored
19
third_party/tensorrt/tensorrt_configure.bzl
vendored
@ -27,6 +27,24 @@ _DEFINE_TENSORRT_SONAME_MAJOR = "#define NV_TENSORRT_SONAME_MAJOR"
|
|||||||
_DEFINE_TENSORRT_SONAME_MINOR = "#define NV_TENSORRT_SONAME_MINOR"
|
_DEFINE_TENSORRT_SONAME_MINOR = "#define NV_TENSORRT_SONAME_MINOR"
|
||||||
_DEFINE_TENSORRT_SONAME_PATCH = "#define NV_TENSORRT_SONAME_PATCH"
|
_DEFINE_TENSORRT_SONAME_PATCH = "#define NV_TENSORRT_SONAME_PATCH"
|
||||||
|
|
||||||
|
def _at_least_version(actual_version, required_version):
|
||||||
|
actual = [int(v) for v in actual_version.split(".")]
|
||||||
|
required = [int(v) for v in required_version.split(".")]
|
||||||
|
return actual >= required
|
||||||
|
|
||||||
|
def _update_tensorrt_headers(tensorrt_version):
|
||||||
|
if not _at_least_version(tensorrt_version, "6"):
|
||||||
|
return
|
||||||
|
_TF_TENSORRT_HEADERS = [
|
||||||
|
"NvInferVersion.h",
|
||||||
|
"NvInfer.h",
|
||||||
|
"NvUtils.h",
|
||||||
|
"NvInferPlugin.h",
|
||||||
|
"NvInferRTSafe.h",
|
||||||
|
"NvInferRTExt.h",
|
||||||
|
"NvInferPluginUtils.h",
|
||||||
|
]
|
||||||
|
|
||||||
def _tpl(repository_ctx, tpl, substitutions):
|
def _tpl(repository_ctx, tpl, substitutions):
|
||||||
repository_ctx.template(
|
repository_ctx.template(
|
||||||
tpl,
|
tpl,
|
||||||
@ -69,6 +87,7 @@ def _tensorrt_configure_impl(repository_ctx):
|
|||||||
cpu_value = get_cpu_value(repository_ctx)
|
cpu_value = get_cpu_value(repository_ctx)
|
||||||
|
|
||||||
# Copy the library and header files.
|
# Copy the library and header files.
|
||||||
|
_update_tensorrt_headers(trt_version)
|
||||||
libraries = [lib_name(lib, cpu_value, trt_version) for lib in _TF_TENSORRT_LIBS]
|
libraries = [lib_name(lib, cpu_value, trt_version) for lib in _TF_TENSORRT_LIBS]
|
||||||
library_dir = config["tensorrt_library_dir"] + "/"
|
library_dir = config["tensorrt_library_dir"] + "/"
|
||||||
headers = _TF_TENSORRT_HEADERS
|
headers = _TF_TENSORRT_HEADERS
|
||||||
|
Loading…
Reference in New Issue
Block a user