Add TensorRT stub for 5.0 and 5.1.
This is the first step toward enabling dynamic loading of the TensorRT library in the open source build.
PiperOrigin-RevId: 253585107
parent a20602f88f
commit a5b860a7cc
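Background for the diff below: with the stub in place, TensorFlow no longer has to link against libnvinfer at build time; each TensorRT entry point becomes a thin wrapper that resolves the real symbol from a dynamically loaded library on first use and forwards the call. A minimal, self-contained sketch of that idea using plain dlopen/dlsym (the soname "libnvinfer.so.5" and the -1 fallback are illustrative assumptions; the actual change resolves symbols through the DsoLoader/Env helpers shown further down):

#include <dlfcn.h>
#include <cstdio>

// Sketch only: shows how a stubbed TensorRT entry point can forward to a
// dynamically loaded libnvinfer instead of linking against it. The soname
// "libnvinfer.so.5" is an assumption for illustration; the generated stubs
// in this change look up symbols via DsoLoader/Env instead.
extern "C" int getInferLibVersion() {
  using FuncPtr = int (*)();
  static FuncPtr real_fn = []() -> FuncPtr {
    void* handle = dlopen("libnvinfer.so.5", RTLD_LAZY | RTLD_LOCAL);
    if (handle == nullptr) return nullptr;  // TensorRT is not installed.
    return reinterpret_cast<FuncPtr>(dlsym(handle, "getInferLibVersion"));
  }();
  if (real_fn == nullptr) return -1;  // The real stubs LOG(FATAL) instead.
  return real_fn();                   // Forward to the loaded implementation.
}

int main() {
  std::printf("getInferLibVersion() -> %d\n", getInferLibVersion());
  return 0;
}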
@@ -31,6 +31,20 @@ package(

exports_files(["LICENSE"])

cc_library(
    name = "tensorrt_stub",
    srcs = if_tensorrt([
        "stub/nvinfer_stub.cc",
        "stub/nvinfer_plugin_stub.cc",
    ]),
    textual_hdrs = glob(["stub/*.inc"]),
    deps = if_tensorrt([
        "@local_config_tensorrt//:tensorrt_headers",
        "//tensorflow/core:lib",
        "//tensorflow/stream_executor/platform:dso_loader",
    ]),
)

tf_cuda_cc_test(
    name = "tensorrt_test_cc",
    size = "small",
tensorflow/compiler/tf2tensorrt/stub/NvInferPlugin_5_0.inc (new file, 87 lines)
@@ -0,0 +1,87 @@
// Auto-generated, do not edit.

extern "C" {

nvinfer1::IPluginV2* createRPNROIPlugin(int featureStride, int preNmsTop,
                                        int nmsMaxOut, float iouThreshold,
                                        float minBoxSize, float spatialScale,
                                        nvinfer1::DimsHW pooling,
                                        nvinfer1::Weights anchorRatios,
                                        nvinfer1::Weights anchorScales) {
  using FuncPtr = nvinfer1::IPluginV2* (*)(int, int, int, float, float, float,
                                           nvinfer1::DimsHW, nvinfer1::Weights,
                                           nvinfer1::Weights);
  static auto func_ptr = LoadSymbol<FuncPtr>("createRPNROIPlugin");
  if (!func_ptr) LogFatalSymbolNotFound("createRPNROIPlugin");
  return func_ptr(featureStride, preNmsTop, nmsMaxOut, iouThreshold,
                  minBoxSize, spatialScale, pooling, anchorRatios,
                  anchorScales);
}

nvinfer1::IPluginV2* createNormalizePlugin(const nvinfer1::Weights* scales,
                                           bool acrossSpatial,
                                           bool channelShared, float eps) {
  using FuncPtr = nvinfer1::IPluginV2* (*)(const nvinfer1::Weights*, bool,
                                           bool, float);
  static auto func_ptr = LoadSymbol<FuncPtr>("createNormalizePlugin");
  if (!func_ptr) LogFatalSymbolNotFound("createNormalizePlugin");
  return func_ptr(scales, acrossSpatial, channelShared, eps);
}

nvinfer1::IPluginV2* createPriorBoxPlugin(
    nvinfer1::plugin::PriorBoxParameters param) {
  using FuncPtr = nvinfer1::IPluginV2* (*)(nvinfer1::plugin::PriorBoxParameters);
  static auto func_ptr = LoadSymbol<FuncPtr>("createPriorBoxPlugin");
  if (!func_ptr) LogFatalSymbolNotFound("createPriorBoxPlugin");
  return func_ptr(param);
}

nvinfer1::IPluginV2* createAnchorGeneratorPlugin(
    nvinfer1::plugin::GridAnchorParameters* param, int numLayers) {
  using FuncPtr =
      nvinfer1::IPluginV2* (*)(nvinfer1::plugin::GridAnchorParameters*, int);
  static auto func_ptr = LoadSymbol<FuncPtr>("createAnchorGeneratorPlugin");
  if (!func_ptr) LogFatalSymbolNotFound("createAnchorGeneratorPlugin");
  return func_ptr(param, numLayers);
}

nvinfer1::IPluginV2* createNMSPlugin(
    nvinfer1::plugin::DetectionOutputParameters param) {
  using FuncPtr =
      nvinfer1::IPluginV2* (*)(nvinfer1::plugin::DetectionOutputParameters);
  static auto func_ptr = LoadSymbol<FuncPtr>("createNMSPlugin");
  if (!func_ptr) LogFatalSymbolNotFound("createNMSPlugin");
  return func_ptr(param);
}

nvinfer1::IPluginV2* createLReLUPlugin(float negSlope) {
  using FuncPtr = nvinfer1::IPluginV2* (*)(float);
  static auto func_ptr = LoadSymbol<FuncPtr>("createLReLUPlugin");
  if (!func_ptr) LogFatalSymbolNotFound("createLReLUPlugin");
  return func_ptr(negSlope);
}

nvinfer1::IPluginV2* createReorgPlugin(int stride) {
  using FuncPtr = nvinfer1::IPluginV2* (*)(int);
  static auto func_ptr = LoadSymbol<FuncPtr>("createReorgPlugin");
  if (!func_ptr) LogFatalSymbolNotFound("createReorgPlugin");
  return func_ptr(stride);
}

nvinfer1::IPluginV2* createRegionPlugin(
    nvinfer1::plugin::RegionParameters params) {
  using FuncPtr = nvinfer1::IPluginV2* (*)(nvinfer1::plugin::RegionParameters);
  static auto func_ptr = LoadSymbol<FuncPtr>("createRegionPlugin");
  if (!func_ptr) LogFatalSymbolNotFound("createRegionPlugin");
  return func_ptr(params);
}

nvinfer1::IPluginV2* createClipPlugin(const char* layerName, float clipMin,
                                      float clipMax) {
  using FuncPtr = nvinfer1::IPluginV2* (*)(const char*, float, float);
  static auto func_ptr = LoadSymbol<FuncPtr>("createClipPlugin");
  if (!func_ptr) LogFatalSymbolNotFound("createClipPlugin");
  return func_ptr(layerName, clipMin, clipMax);
}

bool initLibNvInferPlugins(void* logger, const char* libNamespace) {
  using FuncPtr = bool (*)(void*, const char*);
  static auto func_ptr = LoadSymbol<FuncPtr>("initLibNvInferPlugins");
  if (!func_ptr) LogFatalSymbolNotFound("initLibNvInferPlugins");
  return func_ptr(logger, libNamespace);
}

}  // extern "C"
tensorflow/compiler/tf2tensorrt/stub/NvInferPlugin_5_1.inc (new file, 95 lines)
@@ -0,0 +1,95 @@
// Auto-generated, do not edit.

extern "C" {

nvinfer1::IPluginV2* createRPNROIPlugin(int featureStride, int preNmsTop,
                                        int nmsMaxOut, float iouThreshold,
                                        float minBoxSize, float spatialScale,
                                        nvinfer1::DimsHW pooling,
                                        nvinfer1::Weights anchorRatios,
                                        nvinfer1::Weights anchorScales) {
  using FuncPtr = nvinfer1::IPluginV2* (*)(int, int, int, float, float, float,
                                           nvinfer1::DimsHW, nvinfer1::Weights,
                                           nvinfer1::Weights);
  static auto func_ptr = LoadSymbol<FuncPtr>("createRPNROIPlugin");
  if (!func_ptr) LogFatalSymbolNotFound("createRPNROIPlugin");
  return func_ptr(featureStride, preNmsTop, nmsMaxOut, iouThreshold,
                  minBoxSize, spatialScale, pooling, anchorRatios,
                  anchorScales);
}

nvinfer1::IPluginV2* createNormalizePlugin(const nvinfer1::Weights* scales,
                                           bool acrossSpatial,
                                           bool channelShared, float eps) {
  using FuncPtr = nvinfer1::IPluginV2* (*)(const nvinfer1::Weights*, bool,
                                           bool, float);
  static auto func_ptr = LoadSymbol<FuncPtr>("createNormalizePlugin");
  if (!func_ptr) LogFatalSymbolNotFound("createNormalizePlugin");
  return func_ptr(scales, acrossSpatial, channelShared, eps);
}

nvinfer1::IPluginV2* createPriorBoxPlugin(
    nvinfer1::plugin::PriorBoxParameters param) {
  using FuncPtr = nvinfer1::IPluginV2* (*)(nvinfer1::plugin::PriorBoxParameters);
  static auto func_ptr = LoadSymbol<FuncPtr>("createPriorBoxPlugin");
  if (!func_ptr) LogFatalSymbolNotFound("createPriorBoxPlugin");
  return func_ptr(param);
}

nvinfer1::IPluginV2* createAnchorGeneratorPlugin(
    nvinfer1::plugin::GridAnchorParameters* param, int numLayers) {
  using FuncPtr =
      nvinfer1::IPluginV2* (*)(nvinfer1::plugin::GridAnchorParameters*, int);
  static auto func_ptr = LoadSymbol<FuncPtr>("createAnchorGeneratorPlugin");
  if (!func_ptr) LogFatalSymbolNotFound("createAnchorGeneratorPlugin");
  return func_ptr(param, numLayers);
}

nvinfer1::IPluginV2* createNMSPlugin(
    nvinfer1::plugin::DetectionOutputParameters param) {
  using FuncPtr =
      nvinfer1::IPluginV2* (*)(nvinfer1::plugin::DetectionOutputParameters);
  static auto func_ptr = LoadSymbol<FuncPtr>("createNMSPlugin");
  if (!func_ptr) LogFatalSymbolNotFound("createNMSPlugin");
  return func_ptr(param);
}

nvinfer1::IPluginV2* createLReLUPlugin(float negSlope) {
  using FuncPtr = nvinfer1::IPluginV2* (*)(float);
  static auto func_ptr = LoadSymbol<FuncPtr>("createLReLUPlugin");
  if (!func_ptr) LogFatalSymbolNotFound("createLReLUPlugin");
  return func_ptr(negSlope);
}

nvinfer1::IPluginV2* createReorgPlugin(int stride) {
  using FuncPtr = nvinfer1::IPluginV2* (*)(int);
  static auto func_ptr = LoadSymbol<FuncPtr>("createReorgPlugin");
  if (!func_ptr) LogFatalSymbolNotFound("createReorgPlugin");
  return func_ptr(stride);
}

nvinfer1::IPluginV2* createRegionPlugin(
    nvinfer1::plugin::RegionParameters params) {
  using FuncPtr = nvinfer1::IPluginV2* (*)(nvinfer1::plugin::RegionParameters);
  static auto func_ptr = LoadSymbol<FuncPtr>("createRegionPlugin");
  if (!func_ptr) LogFatalSymbolNotFound("createRegionPlugin");
  return func_ptr(params);
}

nvinfer1::IPluginV2* createClipPlugin(const char* layerName, float clipMin,
                                      float clipMax) {
  using FuncPtr = nvinfer1::IPluginV2* (*)(const char*, float, float);
  static auto func_ptr = LoadSymbol<FuncPtr>("createClipPlugin");
  if (!func_ptr) LogFatalSymbolNotFound("createClipPlugin");
  return func_ptr(layerName, clipMin, clipMax);
}

nvinfer1::IPluginV2* createBatchedNMSPlugin(
    nvinfer1::plugin::NMSParameters param) {
  using FuncPtr = nvinfer1::IPluginV2* (*)(nvinfer1::plugin::NMSParameters);
  static auto func_ptr = LoadSymbol<FuncPtr>("createBatchedNMSPlugin");
  if (!func_ptr) LogFatalSymbolNotFound("createBatchedNMSPlugin");
  return func_ptr(param);
}

bool initLibNvInferPlugins(void* logger, const char* libNamespace) {
  using FuncPtr = bool (*)(void*, const char*);
  static auto func_ptr = LoadSymbol<FuncPtr>("initLibNvInferPlugins");
  if (!func_ptr) LogFatalSymbolNotFound("initLibNvInferPlugins");
  return func_ptr(logger, libNamespace);
}

}  // extern "C"
tensorflow/compiler/tf2tensorrt/stub/NvInfer_5_0.inc (new file, 40 lines)
@@ -0,0 +1,40 @@
// Auto-generated, do not edit.

extern "C" {

void* createInferBuilder_INTERNAL(void* logger, int version) {
  using FuncPtr = void* (*)(void*, int);
  static auto func_ptr = LoadSymbol<FuncPtr>("createInferBuilder_INTERNAL");
  if (!func_ptr) LogFatalSymbolNotFound("createInferBuilder_INTERNAL");
  return func_ptr(logger, version);
}

void* createInferRuntime_INTERNAL(void* logger, int version) {
  using FuncPtr = void* (*)(void*, int);
  static auto func_ptr = LoadSymbol<FuncPtr>("createInferRuntime_INTERNAL");
  if (!func_ptr) LogFatalSymbolNotFound("createInferRuntime_INTERNAL");
  return func_ptr(logger, version);
}

nvinfer1::ILogger* getLogger() {
  using FuncPtr = nvinfer1::ILogger* (*)();
  static auto func_ptr = LoadSymbol<FuncPtr>("getLogger");
  if (!func_ptr) LogFatalSymbolNotFound("getLogger");
  return func_ptr();
}

int getInferLibVersion() {
  using FuncPtr = int (*)();
  static auto func_ptr = LoadSymbol<FuncPtr>("getInferLibVersion");
  if (!func_ptr) LogFatalSymbolNotFound("getInferLibVersion");
  return func_ptr();
}

nvinfer1::IPluginRegistry* getPluginRegistry() {
  using FuncPtr = nvinfer1::IPluginRegistry* (*)();
  static auto func_ptr = LoadSymbol<FuncPtr>("getPluginRegistry");
  if (!func_ptr) LogFatalSymbolNotFound("getPluginRegistry");
  return func_ptr();
}

}  // extern "C"
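One note on getInferLibVersion() above: it reports the version of the library that was actually loaded at runtime, which can differ from the headers used at build time. A hedged decoding sketch, assuming TensorRT's usual NV_TENSORRT_VERSION encoding of major * 1000 + minor * 100 + patch:

#include <cstdio>

// Sketch only: decodes the integer returned by getInferLibVersion(), assuming
// the encoding major * 1000 + minor * 100 + patch (e.g. 5105 for 5.1.5).
void PrintTensorRTVersion(int version) {
  std::printf("TensorRT %d.%d.%d\n", version / 1000, (version / 100) % 10,
              version % 100);
}

int main() {
  PrintTensorRTVersion(5105);  // Hypothetical value -> "TensorRT 5.1.5"
  return 0;
}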
tensorflow/compiler/tf2tensorrt/stub/NvInfer_5_1.inc (new file, 47 lines)
@@ -0,0 +1,47 @@
// Auto-generated, do not edit.

extern "C" {

void* createInferBuilder_INTERNAL(void* logger, int version) {
  using FuncPtr = void* (*)(void*, int);
  static auto func_ptr = LoadSymbol<FuncPtr>("createInferBuilder_INTERNAL");
  if (!func_ptr) LogFatalSymbolNotFound("createInferBuilder_INTERNAL");
  return func_ptr(logger, version);
}

void* createInferRefitter_INTERNAL(void* engine, void* logger, int version) {
  using FuncPtr = void* (*)(void*, void*, int);
  static auto func_ptr = LoadSymbol<FuncPtr>("createInferRefitter_INTERNAL");
  if (!func_ptr) LogFatalSymbolNotFound("createInferRefitter_INTERNAL");
  return func_ptr(engine, logger, version);
}

void* createInferRuntime_INTERNAL(void* logger, int version) {
  using FuncPtr = void* (*)(void*, int);
  static auto func_ptr = LoadSymbol<FuncPtr>("createInferRuntime_INTERNAL");
  if (!func_ptr) LogFatalSymbolNotFound("createInferRuntime_INTERNAL");
  return func_ptr(logger, version);
}

nvinfer1::ILogger* getLogger() {
  using FuncPtr = nvinfer1::ILogger* (*)();
  static auto func_ptr = LoadSymbol<FuncPtr>("getLogger");
  if (!func_ptr) LogFatalSymbolNotFound("getLogger");
  return func_ptr();
}

int getInferLibVersion() {
  using FuncPtr = int (*)();
  static auto func_ptr = LoadSymbol<FuncPtr>("getInferLibVersion");
  if (!func_ptr) LogFatalSymbolNotFound("getInferLibVersion");
  return func_ptr();
}

nvinfer1::IPluginRegistry* getPluginRegistry() {
  using FuncPtr = nvinfer1::IPluginRegistry* (*)();
  static auto func_ptr = LoadSymbol<FuncPtr>("getPluginRegistry");
  if (!func_ptr) LogFatalSymbolNotFound("getPluginRegistry");
  return func_ptr();
}

}  // extern "C"
tensorflow/compiler/tf2tensorrt/stub/nvinfer_plugin_stub.cc (new file, 59 lines)
@@ -0,0 +1,59 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/platform/env.h"
#include "tensorflow/stream_executor/platform/dso_loader.h"
#include "third_party/tensorrt/NvInferPlugin.h"

// Implements the TensorRT API by forwarding to TensorRT loaded from the DSO.

namespace {
// Returns DSO handle or null if loading the DSO fails.
void* GetDsoHandle() {
#ifdef PLATFORM_GOOGLE
  return nullptr;
#else
  static auto handle = []() -> void* {
    auto handle_or =
        stream_executor::internal::DsoLoader::GetNvInferPluginDsoHandle();
    if (!handle_or.ok()) return nullptr;
    return handle_or.ValueOrDie();
  }();
  return handle;
#endif
}

template <typename T>
T LoadSymbol(const char* symbol_name) {
  void* symbol = nullptr;
  if (auto handle = GetDsoHandle()) {
    tensorflow::Env::Default()
        ->GetSymbolFromLibrary(handle, symbol_name, &symbol)
        .IgnoreError();
  }
  return reinterpret_cast<T>(symbol);
}

void LogFatalSymbolNotFound(const char* symbol_name) {
  LOG(FATAL) << symbol_name << " symbol not found.";
}
}  // namespace

#if NV_TENSORRT_MAJOR < 5
#error TensorRT version earlier than 5 is not supported.
#elif NV_TENSORRT_MINOR < 1
#include "tensorflow/compiler/tf2tensorrt/stub/NvInferPlugin_5_0.inc"
#else
#include "tensorflow/compiler/tf2tensorrt/stub/NvInferPlugin_5_1.inc"
#endif
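To show how this plugin stub is exercised: callers keep using the ordinary NvInferPlugin entry points, and the stub resolves them from the dynamically loaded plugin library behind the scenes. A hedged usage sketch, assuming the same TensorFlow include path used elsewhere in this change; the helper name is hypothetical:

#include "third_party/tensorrt/NvInferPlugin.h"

// Sketch only: registers the built-in TensorRT plugins through the stub.
// If libnvinfer_plugin cannot be loaded, the call aborts via LOG(FATAL)
// inside the stub rather than failing at link time.
bool RegisterTensorRTPlugins(nvinfer1::ILogger* logger) {
  // initLibNvInferPlugins takes the logger as void*, matching the stub above.
  return initLibNvInferPlugins(static_cast<void*>(logger),
                               /*libNamespace=*/"");
}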
tensorflow/compiler/tf2tensorrt/stub/nvinfer_stub.cc (new file, 59 lines)
@@ -0,0 +1,59 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/platform/env.h"
#include "tensorflow/stream_executor/platform/dso_loader.h"
#include "third_party/tensorrt/NvInfer.h"

// Implements the TensorRT API by forwarding to TensorRT loaded from the DSO.

namespace {
// Returns DSO handle or null if loading the DSO fails.
void* GetDsoHandle() {
#ifdef PLATFORM_GOOGLE
  return nullptr;
#else
  static auto handle = []() -> void* {
    auto handle_or =
        stream_executor::internal::DsoLoader::GetNvInferDsoHandle();
    if (!handle_or.ok()) return nullptr;
    return handle_or.ValueOrDie();
  }();
  return handle;
#endif
}

template <typename T>
T LoadSymbol(const char* symbol_name) {
  void* symbol = nullptr;
  if (auto handle = GetDsoHandle()) {
    tensorflow::Env::Default()
        ->GetSymbolFromLibrary(handle, symbol_name, &symbol)
        .IgnoreError();
  }
  return reinterpret_cast<T>(symbol);
}

void LogFatalSymbolNotFound(const char* symbol_name) {
  LOG(FATAL) << symbol_name << " symbol not found.";
}
}  // namespace

#if NV_TENSORRT_MAJOR < 5
#error TensorRT version earlier than 5 is not supported.
#elif NV_TENSORRT_MINOR < 1
#include "tensorflow/compiler/tf2tensorrt/stub/NvInfer_5_0.inc"
#else
#include "tensorflow/compiler/tf2tensorrt/stub/NvInfer_5_1.inc"
#endif
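Both stub translation units end with the same guard: a build configured against TensorRT 5.0.x pulls in the *_5_0.inc stubs, anything 5.1 or newer pulls in the *_5_1.inc stubs, and pre-5 versions are rejected at compile time. A small runtime re-statement of that compile-time decision, purely for illustration:

#include <stdexcept>
#include <string>

// Sketch only: mirrors the #if NV_TENSORRT_MAJOR/MINOR guard above as a
// runtime function. The real selection happens during preprocessing, using
// the macros from the TensorRT headers the build was configured with.
std::string SelectNvInferStub(int major, int minor) {
  if (major < 5) {
    throw std::runtime_error("TensorRT version earlier than 5 is not supported.");
  }
  return minor < 1 ? "NvInfer_5_0.inc" : "NvInfer_5_1.inc";
}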
@@ -33,6 +33,11 @@ string GetCudaVersion() { return TF_CUDA_VERSION; }
string GetCudaLibVersion() { return TF_CUDA_LIB_VERSION; }
string GetCudnnVersion() { return TF_CUDNN_VERSION; }

// TODO(laigd): populate the version string during configuration process. For
// now hardcoded version 5 since 4.0 is not supported anyway.
#define TF_TENSORRT_VERSION "5"
string GetTensorRTVersion() { return TF_TENSORRT_VERSION; }

port::StatusOr<void*> GetDsoHandle(const string& name, const string& version) {
  auto filename = port::Env::Default()->FormatLibraryFileName(name, version);
  void* dso_handle;
@@ -108,6 +113,14 @@ port::StatusOr<void*> GetCudnnDsoHandle() {
  return GetDsoHandle("cudnn", GetCudnnVersion());
}

port::StatusOr<void*> GetNvInferDsoHandle() {
  return GetDsoHandle("nvinfer", GetTensorRTVersion());
}

port::StatusOr<void*> GetNvInferPluginDsoHandle() {
  return GetDsoHandle("nvinfer_plugin", GetTensorRTVersion());
}

port::StatusOr<void*> GetRocblasDsoHandle() {
  return GetDsoHandle("rocblas", "");
}
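The new GetNvInferDsoHandle()/GetNvInferPluginDsoHandle() helpers combine the library name with the hardcoded "5" version string, so on Linux the loader ends up probing filenames along the following lines (a hedged approximation of port::Env::FormatLibraryFileName; other platforms use their own naming conventions):

#include <iostream>
#include <string>

// Sketch only: approximates the Linux behavior of FormatLibraryFileName(name,
// version); it is not the actual implementation used by the DSO loader.
std::string LinuxLibraryFileName(const std::string& name,
                                 const std::string& version) {
  std::string filename = "lib" + name + ".so";
  if (!version.empty()) filename += "." + version;
  return filename;
}

int main() {
  std::cout << LinuxLibraryFileName("nvinfer", "5") << "\n";         // libnvinfer.so.5
  std::cout << LinuxLibraryFileName("nvinfer_plugin", "5") << "\n";  // libnvinfer_plugin.so.5
  return 0;
}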
@@ -43,6 +43,8 @@ port::StatusOr<void*> GetCusolverDsoHandle();
port::StatusOr<void*> GetCusparseDsoHandle();
port::StatusOr<void*> GetCuptiDsoHandle();
port::StatusOr<void*> GetCudnnDsoHandle();
port::StatusOr<void*> GetNvInferDsoHandle();
port::StatusOr<void*> GetNvInferPluginDsoHandle();

port::StatusOr<void*> GetRocblasDsoHandle();
port::StatusOr<void*> GetMiopenDsoHandle();