diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc index ae2e91bcac2..a3e29738d35 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc @@ -280,6 +280,14 @@ class FakeITensor : public nvinfer1::ITensor { float getDynamicRangeMax() const override { return 0.f; } #endif +#if IS_TRT_VERSION_GE(6, 0, 0, 0) + void setAllowedFormats(nvinfer1::TensorFormats formats) override {} + + nvinfer1::TensorFormats getAllowedFormats() const override { return 1; } + + bool isShape() const override { return false; } +#endif + private: string name_; nvinfer1::Dims dims_; diff --git a/third_party/gpus/find_cuda_config.py b/third_party/gpus/find_cuda_config.py index f6c86ad702e..24e13190d0f 100644 --- a/third_party/gpus/find_cuda_config.py +++ b/third_party/gpus/find_cuda_config.py @@ -390,7 +390,8 @@ def _find_tensorrt_config(base_paths, required_version): get_header_version) if ".." in header_version: - header_path, header_version = _find_header(base_paths, "NvInferRTSafe.h", + # From TRT 6.0 onwards, version information has been moved to NvInferVersion.h. + header_path, header_version = _find_header(base_paths, "NvInferVersion.h", required_version, get_header_version) diff --git a/third_party/tensorrt/tensorrt_configure.bzl b/third_party/tensorrt/tensorrt_configure.bzl index 3c5550abc9d..381a183049a 100644 --- a/third_party/tensorrt/tensorrt_configure.bzl +++ b/third_party/tensorrt/tensorrt_configure.bzl @@ -27,6 +27,16 @@ _DEFINE_TENSORRT_SONAME_MAJOR = "#define NV_TENSORRT_SONAME_MAJOR" _DEFINE_TENSORRT_SONAME_MINOR = "#define NV_TENSORRT_SONAME_MINOR" _DEFINE_TENSORRT_SONAME_PATCH = "#define NV_TENSORRT_SONAME_PATCH" +def _at_least_version(actual_version, required_version): + actual = [int(v) for v in actual_version.split(".")] + required = [int(v) for v in required_version.split(".")] + return actual >= required + +def _update_tensorrt_headers(tensorrt_version): + if not _at_least_version(tensorrt_version, "6"): + return + _TF_TENSORRT_HEADERS = ["NvInferVersion.h", "NvInfer.h", "NvUtils.h", "NvInferPlugin.h", "NvInferRTSafe.h", "NvInferRTExt.h", "NvInferPluginUtils.h"] + def _tpl(repository_ctx, tpl, substitutions): repository_ctx.template( tpl, @@ -69,6 +79,7 @@ def _tensorrt_configure_impl(repository_ctx): cpu_value = get_cpu_value(repository_ctx) # Copy the library and header files. + _update_tensorrt_headers(trt_version) libraries = [lib_name(lib, cpu_value, trt_version) for lib in _TF_TENSORRT_LIBS] library_dir = config["tensorrt_library_dir"] + "/" headers = _TF_TENSORRT_HEADERS