diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc index f568b947959..21e25fde582 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc @@ -790,6 +790,14 @@ class TRT_TensorOrWeights::SimpleITensor : public nvinfer1::ITensor { float getDynamicRangeMax() const override { return 0.f; } #endif +#if IS_TRT_VERSION_GE(6, 0, 0, 0) + void setAllowedFormats(nvinfer1::TensorFormats formats) override {} + + nvinfer1::TensorFormats getAllowedFormats() const override { return 1; } + + bool isShape() const override { return false; } +#endif + private: nvinfer1::DataType trt_dtype_; nvinfer1::Dims trt_dims_; diff --git a/third_party/gpus/find_cuda_config.py b/third_party/gpus/find_cuda_config.py index 576456e3646..f6c86ad702e 100644 --- a/third_party/gpus/find_cuda_config.py +++ b/third_party/gpus/find_cuda_config.py @@ -388,6 +388,12 @@ def _find_tensorrt_config(base_paths, required_version): header_path, header_version = _find_header(base_paths, "NvInfer.h", required_version, get_header_version) + + if ".." in header_version: + header_path, header_version = _find_header(base_paths, "NvInferRTSafe.h", + required_version, + get_header_version) + tensorrt_version = header_version.split(".")[0] library_path = _find_library(base_paths, "nvinfer", tensorrt_version)