update tf cuda config and tf2tensort convertor to build with trt_6
This commit is contained in:
parent
e00241511d
commit
81c5dde0c9
@ -790,6 +790,14 @@ class TRT_TensorOrWeights::SimpleITensor : public nvinfer1::ITensor {
|
|||||||
float getDynamicRangeMax() const override { return 0.f; }
|
float getDynamicRangeMax() const override { return 0.f; }
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#if IS_TRT_VERSION_GE(6, 0, 0, 0)
|
||||||
|
void setAllowedFormats(nvinfer1::TensorFormats formats) override {}
|
||||||
|
|
||||||
|
nvinfer1::TensorFormats getAllowedFormats() const override { return 1; }
|
||||||
|
|
||||||
|
bool isShape() const override { return false; }
|
||||||
|
#endif
|
||||||
|
|
||||||
private:
|
private:
|
||||||
nvinfer1::DataType trt_dtype_;
|
nvinfer1::DataType trt_dtype_;
|
||||||
nvinfer1::Dims trt_dims_;
|
nvinfer1::Dims trt_dims_;
|
||||||
|
6
third_party/gpus/find_cuda_config.py
vendored
6
third_party/gpus/find_cuda_config.py
vendored
@ -388,6 +388,12 @@ def _find_tensorrt_config(base_paths, required_version):
|
|||||||
header_path, header_version = _find_header(base_paths, "NvInfer.h",
|
header_path, header_version = _find_header(base_paths, "NvInfer.h",
|
||||||
required_version,
|
required_version,
|
||||||
get_header_version)
|
get_header_version)
|
||||||
|
|
||||||
|
if ".." in header_version:
|
||||||
|
header_path, header_version = _find_header(base_paths, "NvInferRTSafe.h",
|
||||||
|
required_version,
|
||||||
|
get_header_version)
|
||||||
|
|
||||||
tensorrt_version = header_version.split(".")[0]
|
tensorrt_version = header_version.split(".")[0]
|
||||||
|
|
||||||
library_path = _find_library(base_paths, "nvinfer", tensorrt_version)
|
library_path = _find_library(base_paths, "nvinfer", tensorrt_version)
|
||||||
|
Loading…
Reference in New Issue
Block a user