update version information file. Also upadate tensorrt bazel configuration file

This commit is contained in:
jhalakp 2019-05-09 15:55:58 -07:00
parent 81c5dde0c9
commit 8db2e909e5
3 changed files with 21 additions and 1 deletions

View File

@ -280,6 +280,14 @@ class FakeITensor : public nvinfer1::ITensor {
float getDynamicRangeMax() const override { return 0.f; }
#endif
#if IS_TRT_VERSION_GE(6, 0, 0, 0)
void setAllowedFormats(nvinfer1::TensorFormats formats) override {}
nvinfer1::TensorFormats getAllowedFormats() const override { return 1; }
bool isShape() const override { return false; }
#endif
private:
string name_;
nvinfer1::Dims dims_;

View File

@ -390,7 +390,8 @@ def _find_tensorrt_config(base_paths, required_version):
get_header_version)
if ".." in header_version:
header_path, header_version = _find_header(base_paths, "NvInferRTSafe.h",
# From TRT 6.0 onwards, version information has been moved to NvInferVersion.h.
header_path, header_version = _find_header(base_paths, "NvInferVersion.h",
required_version,
get_header_version)

View File

@ -27,6 +27,16 @@ _DEFINE_TENSORRT_SONAME_MAJOR = "#define NV_TENSORRT_SONAME_MAJOR"
_DEFINE_TENSORRT_SONAME_MINOR = "#define NV_TENSORRT_SONAME_MINOR"
_DEFINE_TENSORRT_SONAME_PATCH = "#define NV_TENSORRT_SONAME_PATCH"
def _at_least_version(actual_version, required_version):
actual = [int(v) for v in actual_version.split(".")]
required = [int(v) for v in required_version.split(".")]
return actual >= required
def _update_tensorrt_headers(tensorrt_version):
if not _at_least_version(tensorrt_version, "6"):
return
_TF_TENSORRT_HEADERS = ["NvInferVersion.h", "NvInfer.h", "NvUtils.h", "NvInferPlugin.h", "NvInferRTSafe.h", "NvInferRTExt.h", "NvInferPluginUtils.h"]
def _tpl(repository_ctx, tpl, substitutions):
repository_ctx.template(
tpl,
@ -69,6 +79,7 @@ def _tensorrt_configure_impl(repository_ctx):
cpu_value = get_cpu_value(repository_ctx)
# Copy the library and header files.
_update_tensorrt_headers(trt_version)
libraries = [lib_name(lib, cpu_value, trt_version) for lib in _TF_TENSORRT_LIBS]
library_dir = config["tensorrt_library_dir"] + "/"
headers = _TF_TENSORRT_HEADERS