update tf cuda config and tf2tensort convertor to build with trt_6
This commit is contained in:
parent
e00241511d
commit
81c5dde0c9
@ -790,6 +790,14 @@ class TRT_TensorOrWeights::SimpleITensor : public nvinfer1::ITensor {
|
||||
float getDynamicRangeMax() const override { return 0.f; }
|
||||
#endif
|
||||
|
||||
#if IS_TRT_VERSION_GE(6, 0, 0, 0)
|
||||
void setAllowedFormats(nvinfer1::TensorFormats formats) override {}
|
||||
|
||||
nvinfer1::TensorFormats getAllowedFormats() const override { return 1; }
|
||||
|
||||
bool isShape() const override { return false; }
|
||||
#endif
|
||||
|
||||
private:
|
||||
nvinfer1::DataType trt_dtype_;
|
||||
nvinfer1::Dims trt_dims_;
|
||||
|
6
third_party/gpus/find_cuda_config.py
vendored
6
third_party/gpus/find_cuda_config.py
vendored
@ -388,6 +388,12 @@ def _find_tensorrt_config(base_paths, required_version):
|
||||
header_path, header_version = _find_header(base_paths, "NvInfer.h",
|
||||
required_version,
|
||||
get_header_version)
|
||||
|
||||
if ".." in header_version:
|
||||
header_path, header_version = _find_header(base_paths, "NvInferRTSafe.h",
|
||||
required_version,
|
||||
get_header_version)
|
||||
|
||||
tensorrt_version = header_version.split(".")[0]
|
||||
|
||||
library_path = _find_library(base_paths, "nvinfer", tensorrt_version)
|
||||
|
Loading…
Reference in New Issue
Block a user