Add TensorRT stub for 5.0 and 5.1.
This is the first step of enabling dynamic loading of TensorRT library in open source build. PiperOrigin-RevId: 253585107
This commit is contained in:
parent
a20602f88f
commit
a5b860a7cc
@ -31,6 +31,20 @@ package(
|
|||||||
|
|
||||||
exports_files(["LICENSE"])
|
exports_files(["LICENSE"])
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "tensorrt_stub",
|
||||||
|
srcs = if_tensorrt([
|
||||||
|
"stub/nvinfer_stub.cc",
|
||||||
|
"stub/nvinfer_plugin_stub.cc",
|
||||||
|
]),
|
||||||
|
textual_hdrs = glob(["stub/*.inc"]),
|
||||||
|
deps = if_tensorrt([
|
||||||
|
"@local_config_tensorrt//:tensorrt_headers",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/stream_executor/platform:dso_loader",
|
||||||
|
]),
|
||||||
|
)
|
||||||
|
|
||||||
tf_cuda_cc_test(
|
tf_cuda_cc_test(
|
||||||
name = "tensorrt_test_cc",
|
name = "tensorrt_test_cc",
|
||||||
size = "small",
|
size = "small",
|
||||||
|
87
tensorflow/compiler/tf2tensorrt/stub/NvInferPlugin_5_0.inc
Normal file
87
tensorflow/compiler/tf2tensorrt/stub/NvInferPlugin_5_0.inc
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
// Auto-generated, do not edit.
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
|
||||||
|
nvinfer1::IPluginV2* createRPNROIPlugin(int featureStride, int preNmsTop,
|
||||||
|
int nmsMaxOut, float iouThreshold,
|
||||||
|
float minBoxSize, float spatialScale,
|
||||||
|
nvinfer1::DimsHW pooling,
|
||||||
|
nvinfer1::Weights anchorRatios,
|
||||||
|
nvinfer1::Weights anchorScales) {
|
||||||
|
using FuncPtr = nvinfer1::IPluginV2 * ( *)(int, int, int, float, float, float, nvinfer1::DimsHW, nvinfer1::Weights, nvinfer1::Weights);
|
||||||
|
static auto func_ptr = LoadSymbol<FuncPtr>("createRPNROIPlugin");
|
||||||
|
if (!func_ptr) LogFatalSymbolNotFound("createRPNROIPlugin");
|
||||||
|
return func_ptr(featureStride, preNmsTop, nmsMaxOut, iouThreshold, minBoxSize, spatialScale, pooling, anchorRatios, anchorScales);
|
||||||
|
}
|
||||||
|
|
||||||
|
nvinfer1::IPluginV2* createNormalizePlugin(const nvinfer1::Weights* scales,
|
||||||
|
bool acrossSpatial,
|
||||||
|
bool channelShared, float eps) {
|
||||||
|
using FuncPtr = nvinfer1::IPluginV2 * ( *)(const nvinfer1::Weights *, bool, bool, float);
|
||||||
|
static auto func_ptr = LoadSymbol<FuncPtr>("createNormalizePlugin");
|
||||||
|
if (!func_ptr) LogFatalSymbolNotFound("createNormalizePlugin");
|
||||||
|
return func_ptr(scales, acrossSpatial, channelShared, eps);
|
||||||
|
}
|
||||||
|
|
||||||
|
nvinfer1::IPluginV2* createPriorBoxPlugin(
|
||||||
|
nvinfer1::plugin::PriorBoxParameters param) {
|
||||||
|
using FuncPtr = nvinfer1::IPluginV2 * ( *)(nvinfer1::plugin::PriorBoxParameters);
|
||||||
|
static auto func_ptr = LoadSymbol<FuncPtr>("createPriorBoxPlugin");
|
||||||
|
if (!func_ptr) LogFatalSymbolNotFound("createPriorBoxPlugin");
|
||||||
|
return func_ptr(param);
|
||||||
|
}
|
||||||
|
|
||||||
|
nvinfer1::IPluginV2* createAnchorGeneratorPlugin(
|
||||||
|
nvinfer1::plugin::GridAnchorParameters* param, int numLayers) {
|
||||||
|
using FuncPtr = nvinfer1::IPluginV2 * ( *)(nvinfer1::plugin::GridAnchorParameters *, int);
|
||||||
|
static auto func_ptr = LoadSymbol<FuncPtr>("createAnchorGeneratorPlugin");
|
||||||
|
if (!func_ptr) LogFatalSymbolNotFound("createAnchorGeneratorPlugin");
|
||||||
|
return func_ptr(param, numLayers);
|
||||||
|
}
|
||||||
|
|
||||||
|
nvinfer1::IPluginV2* createNMSPlugin(
|
||||||
|
nvinfer1::plugin::DetectionOutputParameters param) {
|
||||||
|
using FuncPtr = nvinfer1::IPluginV2 * ( *)(nvinfer1::plugin::DetectionOutputParameters);
|
||||||
|
static auto func_ptr = LoadSymbol<FuncPtr>("createNMSPlugin");
|
||||||
|
if (!func_ptr) LogFatalSymbolNotFound("createNMSPlugin");
|
||||||
|
return func_ptr(param);
|
||||||
|
}
|
||||||
|
|
||||||
|
nvinfer1::IPluginV2* createLReLUPlugin(float negSlope) {
|
||||||
|
using FuncPtr = nvinfer1::IPluginV2 * ( *)(float);
|
||||||
|
static auto func_ptr = LoadSymbol<FuncPtr>("createLReLUPlugin");
|
||||||
|
if (!func_ptr) LogFatalSymbolNotFound("createLReLUPlugin");
|
||||||
|
return func_ptr(negSlope);
|
||||||
|
}
|
||||||
|
|
||||||
|
nvinfer1::IPluginV2* createReorgPlugin(int stride) {
|
||||||
|
using FuncPtr = nvinfer1::IPluginV2 * ( *)(int);
|
||||||
|
static auto func_ptr = LoadSymbol<FuncPtr>("createReorgPlugin");
|
||||||
|
if (!func_ptr) LogFatalSymbolNotFound("createReorgPlugin");
|
||||||
|
return func_ptr(stride);
|
||||||
|
}
|
||||||
|
|
||||||
|
nvinfer1::IPluginV2* createRegionPlugin(
|
||||||
|
nvinfer1::plugin::RegionParameters params) {
|
||||||
|
using FuncPtr = nvinfer1::IPluginV2 * ( *)(nvinfer1::plugin::RegionParameters);
|
||||||
|
static auto func_ptr = LoadSymbol<FuncPtr>("createRegionPlugin");
|
||||||
|
if (!func_ptr) LogFatalSymbolNotFound("createRegionPlugin");
|
||||||
|
return func_ptr(params);
|
||||||
|
}
|
||||||
|
|
||||||
|
nvinfer1::IPluginV2* createClipPlugin(const char* layerName, float clipMin,
|
||||||
|
float clipMax) {
|
||||||
|
using FuncPtr = nvinfer1::IPluginV2 * ( *)(const char *, float, float);
|
||||||
|
static auto func_ptr = LoadSymbol<FuncPtr>("createClipPlugin");
|
||||||
|
if (!func_ptr) LogFatalSymbolNotFound("createClipPlugin");
|
||||||
|
return func_ptr(layerName, clipMin, clipMax);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool initLibNvInferPlugins(void* logger, const char* libNamespace) {
|
||||||
|
using FuncPtr = bool ( *)(void *, const char *);
|
||||||
|
static auto func_ptr = LoadSymbol<FuncPtr>("initLibNvInferPlugins");
|
||||||
|
if (!func_ptr) LogFatalSymbolNotFound("initLibNvInferPlugins");
|
||||||
|
return func_ptr(logger, libNamespace);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // extern "C"
|
95
tensorflow/compiler/tf2tensorrt/stub/NvInferPlugin_5_1.inc
Normal file
95
tensorflow/compiler/tf2tensorrt/stub/NvInferPlugin_5_1.inc
Normal file
@ -0,0 +1,95 @@
|
|||||||
|
// Auto-generated, do not edit.
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
|
||||||
|
nvinfer1::IPluginV2* createRPNROIPlugin(int featureStride, int preNmsTop,
|
||||||
|
int nmsMaxOut, float iouThreshold,
|
||||||
|
float minBoxSize, float spatialScale,
|
||||||
|
nvinfer1::DimsHW pooling,
|
||||||
|
nvinfer1::Weights anchorRatios,
|
||||||
|
nvinfer1::Weights anchorScales) {
|
||||||
|
using FuncPtr = nvinfer1::IPluginV2 * ( *)(int, int, int, float, float, float, nvinfer1::DimsHW, nvinfer1::Weights, nvinfer1::Weights);
|
||||||
|
static auto func_ptr = LoadSymbol<FuncPtr>("createRPNROIPlugin");
|
||||||
|
if (!func_ptr) LogFatalSymbolNotFound("createRPNROIPlugin");
|
||||||
|
return func_ptr(featureStride, preNmsTop, nmsMaxOut, iouThreshold, minBoxSize, spatialScale, pooling, anchorRatios, anchorScales);
|
||||||
|
}
|
||||||
|
|
||||||
|
nvinfer1::IPluginV2* createNormalizePlugin(const nvinfer1::Weights* scales,
|
||||||
|
bool acrossSpatial,
|
||||||
|
bool channelShared, float eps) {
|
||||||
|
using FuncPtr = nvinfer1::IPluginV2 * ( *)(const nvinfer1::Weights *, bool, bool, float);
|
||||||
|
static auto func_ptr = LoadSymbol<FuncPtr>("createNormalizePlugin");
|
||||||
|
if (!func_ptr) LogFatalSymbolNotFound("createNormalizePlugin");
|
||||||
|
return func_ptr(scales, acrossSpatial, channelShared, eps);
|
||||||
|
}
|
||||||
|
|
||||||
|
nvinfer1::IPluginV2* createPriorBoxPlugin(
|
||||||
|
nvinfer1::plugin::PriorBoxParameters param) {
|
||||||
|
using FuncPtr = nvinfer1::IPluginV2 * ( *)(nvinfer1::plugin::PriorBoxParameters);
|
||||||
|
static auto func_ptr = LoadSymbol<FuncPtr>("createPriorBoxPlugin");
|
||||||
|
if (!func_ptr) LogFatalSymbolNotFound("createPriorBoxPlugin");
|
||||||
|
return func_ptr(param);
|
||||||
|
}
|
||||||
|
|
||||||
|
nvinfer1::IPluginV2* createAnchorGeneratorPlugin(
|
||||||
|
nvinfer1::plugin::GridAnchorParameters* param, int numLayers) {
|
||||||
|
using FuncPtr = nvinfer1::IPluginV2 * ( *)(nvinfer1::plugin::GridAnchorParameters *, int);
|
||||||
|
static auto func_ptr = LoadSymbol<FuncPtr>("createAnchorGeneratorPlugin");
|
||||||
|
if (!func_ptr) LogFatalSymbolNotFound("createAnchorGeneratorPlugin");
|
||||||
|
return func_ptr(param, numLayers);
|
||||||
|
}
|
||||||
|
|
||||||
|
nvinfer1::IPluginV2* createNMSPlugin(
|
||||||
|
nvinfer1::plugin::DetectionOutputParameters param) {
|
||||||
|
using FuncPtr = nvinfer1::IPluginV2 * ( *)(nvinfer1::plugin::DetectionOutputParameters);
|
||||||
|
static auto func_ptr = LoadSymbol<FuncPtr>("createNMSPlugin");
|
||||||
|
if (!func_ptr) LogFatalSymbolNotFound("createNMSPlugin");
|
||||||
|
return func_ptr(param);
|
||||||
|
}
|
||||||
|
|
||||||
|
nvinfer1::IPluginV2* createLReLUPlugin(float negSlope) {
|
||||||
|
using FuncPtr = nvinfer1::IPluginV2 * ( *)(float);
|
||||||
|
static auto func_ptr = LoadSymbol<FuncPtr>("createLReLUPlugin");
|
||||||
|
if (!func_ptr) LogFatalSymbolNotFound("createLReLUPlugin");
|
||||||
|
return func_ptr(negSlope);
|
||||||
|
}
|
||||||
|
|
||||||
|
nvinfer1::IPluginV2* createReorgPlugin(int stride) {
|
||||||
|
using FuncPtr = nvinfer1::IPluginV2 * ( *)(int);
|
||||||
|
static auto func_ptr = LoadSymbol<FuncPtr>("createReorgPlugin");
|
||||||
|
if (!func_ptr) LogFatalSymbolNotFound("createReorgPlugin");
|
||||||
|
return func_ptr(stride);
|
||||||
|
}
|
||||||
|
|
||||||
|
nvinfer1::IPluginV2* createRegionPlugin(
|
||||||
|
nvinfer1::plugin::RegionParameters params) {
|
||||||
|
using FuncPtr = nvinfer1::IPluginV2 * ( *)(nvinfer1::plugin::RegionParameters);
|
||||||
|
static auto func_ptr = LoadSymbol<FuncPtr>("createRegionPlugin");
|
||||||
|
if (!func_ptr) LogFatalSymbolNotFound("createRegionPlugin");
|
||||||
|
return func_ptr(params);
|
||||||
|
}
|
||||||
|
|
||||||
|
nvinfer1::IPluginV2* createClipPlugin(const char* layerName, float clipMin,
|
||||||
|
float clipMax) {
|
||||||
|
using FuncPtr = nvinfer1::IPluginV2 * ( *)(const char *, float, float);
|
||||||
|
static auto func_ptr = LoadSymbol<FuncPtr>("createClipPlugin");
|
||||||
|
if (!func_ptr) LogFatalSymbolNotFound("createClipPlugin");
|
||||||
|
return func_ptr(layerName, clipMin, clipMax);
|
||||||
|
}
|
||||||
|
|
||||||
|
nvinfer1::IPluginV2* createBatchedNMSPlugin(
|
||||||
|
nvinfer1::plugin::NMSParameters param) {
|
||||||
|
using FuncPtr = nvinfer1::IPluginV2 * ( *)(nvinfer1::plugin::NMSParameters);
|
||||||
|
static auto func_ptr = LoadSymbol<FuncPtr>("createBatchedNMSPlugin");
|
||||||
|
if (!func_ptr) LogFatalSymbolNotFound("createBatchedNMSPlugin");
|
||||||
|
return func_ptr(param);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool initLibNvInferPlugins(void* logger, const char* libNamespace) {
|
||||||
|
using FuncPtr = bool ( *)(void *, const char *);
|
||||||
|
static auto func_ptr = LoadSymbol<FuncPtr>("initLibNvInferPlugins");
|
||||||
|
if (!func_ptr) LogFatalSymbolNotFound("initLibNvInferPlugins");
|
||||||
|
return func_ptr(logger, libNamespace);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // extern "C"
|
40
tensorflow/compiler/tf2tensorrt/stub/NvInfer_5_0.inc
Normal file
40
tensorflow/compiler/tf2tensorrt/stub/NvInfer_5_0.inc
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
// Auto-generated, do not edit.
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
|
||||||
|
void* createInferBuilder_INTERNAL(void* logger, int version) {
|
||||||
|
using FuncPtr = void * (*)(void *, int);
|
||||||
|
static auto func_ptr = LoadSymbol<FuncPtr>("createInferBuilder_INTERNAL");
|
||||||
|
if (!func_ptr) LogFatalSymbolNotFound("createInferBuilder_INTERNAL");
|
||||||
|
return func_ptr(logger, version);
|
||||||
|
}
|
||||||
|
|
||||||
|
void* createInferRuntime_INTERNAL(void* logger, int version) {
|
||||||
|
using FuncPtr = void * (*)(void *, int);
|
||||||
|
static auto func_ptr = LoadSymbol<FuncPtr>("createInferRuntime_INTERNAL");
|
||||||
|
if (!func_ptr) LogFatalSymbolNotFound("createInferRuntime_INTERNAL");
|
||||||
|
return func_ptr(logger, version);
|
||||||
|
}
|
||||||
|
|
||||||
|
nvinfer1::ILogger* getLogger() {
|
||||||
|
using FuncPtr = nvinfer1::ILogger * (*)();
|
||||||
|
static auto func_ptr = LoadSymbol<FuncPtr>("getLogger");
|
||||||
|
if (!func_ptr) LogFatalSymbolNotFound("getLogger");
|
||||||
|
return func_ptr();
|
||||||
|
}
|
||||||
|
|
||||||
|
int getInferLibVersion() {
|
||||||
|
using FuncPtr = int (*)();
|
||||||
|
static auto func_ptr = LoadSymbol<FuncPtr>("getInferLibVersion");
|
||||||
|
if (!func_ptr) LogFatalSymbolNotFound("getInferLibVersion");
|
||||||
|
return func_ptr();
|
||||||
|
}
|
||||||
|
|
||||||
|
nvinfer1::IPluginRegistry* getPluginRegistry() {
|
||||||
|
using FuncPtr = nvinfer1::IPluginRegistry * (*)();
|
||||||
|
static auto func_ptr = LoadSymbol<FuncPtr>("getPluginRegistry");
|
||||||
|
if (!func_ptr) LogFatalSymbolNotFound("getPluginRegistry");
|
||||||
|
return func_ptr();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // extern "C"
|
47
tensorflow/compiler/tf2tensorrt/stub/NvInfer_5_1.inc
Normal file
47
tensorflow/compiler/tf2tensorrt/stub/NvInfer_5_1.inc
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
// Auto-generated, do not edit.
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
|
||||||
|
void* createInferBuilder_INTERNAL(void* logger, int version) {
|
||||||
|
using FuncPtr = void * (*)(void *, int);
|
||||||
|
static auto func_ptr = LoadSymbol<FuncPtr>("createInferBuilder_INTERNAL");
|
||||||
|
if (!func_ptr) LogFatalSymbolNotFound("createInferBuilder_INTERNAL");
|
||||||
|
return func_ptr(logger, version);
|
||||||
|
}
|
||||||
|
|
||||||
|
void* createInferRefitter_INTERNAL(void* engine, void* logger, int version) {
|
||||||
|
using FuncPtr = void * (*)(void *, void *, int);
|
||||||
|
static auto func_ptr = LoadSymbol<FuncPtr>("createInferRefitter_INTERNAL");
|
||||||
|
if (!func_ptr) LogFatalSymbolNotFound("createInferRefitter_INTERNAL");
|
||||||
|
return func_ptr(engine, logger, version);
|
||||||
|
}
|
||||||
|
|
||||||
|
void* createInferRuntime_INTERNAL(void* logger, int version) {
|
||||||
|
using FuncPtr = void * (*)(void *, int);
|
||||||
|
static auto func_ptr = LoadSymbol<FuncPtr>("createInferRuntime_INTERNAL");
|
||||||
|
if (!func_ptr) LogFatalSymbolNotFound("createInferRuntime_INTERNAL");
|
||||||
|
return func_ptr(logger, version);
|
||||||
|
}
|
||||||
|
|
||||||
|
nvinfer1::ILogger* getLogger() {
|
||||||
|
using FuncPtr = nvinfer1::ILogger * (*)();
|
||||||
|
static auto func_ptr = LoadSymbol<FuncPtr>("getLogger");
|
||||||
|
if (!func_ptr) LogFatalSymbolNotFound("getLogger");
|
||||||
|
return func_ptr();
|
||||||
|
}
|
||||||
|
|
||||||
|
int getInferLibVersion() {
|
||||||
|
using FuncPtr = int (*)();
|
||||||
|
static auto func_ptr = LoadSymbol<FuncPtr>("getInferLibVersion");
|
||||||
|
if (!func_ptr) LogFatalSymbolNotFound("getInferLibVersion");
|
||||||
|
return func_ptr();
|
||||||
|
}
|
||||||
|
|
||||||
|
nvinfer1::IPluginRegistry* getPluginRegistry() {
|
||||||
|
using FuncPtr = nvinfer1::IPluginRegistry * (*)();
|
||||||
|
static auto func_ptr = LoadSymbol<FuncPtr>("getPluginRegistry");
|
||||||
|
if (!func_ptr) LogFatalSymbolNotFound("getPluginRegistry");
|
||||||
|
return func_ptr();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // extern "C"
|
59
tensorflow/compiler/tf2tensorrt/stub/nvinfer_plugin_stub.cc
Normal file
59
tensorflow/compiler/tf2tensorrt/stub/nvinfer_plugin_stub.cc
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include "tensorflow/core/platform/env.h"
|
||||||
|
#include "tensorflow/stream_executor/platform/dso_loader.h"
|
||||||
|
#include "third_party/tensorrt/NvInferPlugin.h"
|
||||||
|
|
||||||
|
// Implements the TensorRT API by forwarding to TensorRT loaded from the DSO.
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// Returns DSO handle or null if loading the DSO fails.
|
||||||
|
void* GetDsoHandle() {
|
||||||
|
#ifdef PLATFORM_GOOGLE
|
||||||
|
return nullptr;
|
||||||
|
#else
|
||||||
|
static auto handle = []() -> void* {
|
||||||
|
auto handle_or =
|
||||||
|
stream_executor::internal::DsoLoader::GetNvInferPluginDsoHandle();
|
||||||
|
if (!handle_or.ok()) return nullptr;
|
||||||
|
return handle_or.ValueOrDie();
|
||||||
|
}();
|
||||||
|
return handle;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
T LoadSymbol(const char* symbol_name) {
|
||||||
|
void* symbol = nullptr;
|
||||||
|
if (auto handle = GetDsoHandle()) {
|
||||||
|
tensorflow::Env::Default()
|
||||||
|
->GetSymbolFromLibrary(handle, symbol_name, &symbol)
|
||||||
|
.IgnoreError();
|
||||||
|
}
|
||||||
|
return reinterpret_cast<T>(symbol);
|
||||||
|
}
|
||||||
|
|
||||||
|
void LogFatalSymbolNotFound(const char* symbol_name) {
|
||||||
|
LOG(FATAL) << symbol_name << " symbol not found.";
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
#if NV_TENSORRT_MAJOR < 5
|
||||||
|
#error TensorRT version earlier than 5 is not supported.
|
||||||
|
#elif NV_TENSORRT_MINOR < 1
|
||||||
|
#include "tensorflow/compiler/tf2tensorrt/stub/NvInferPlugin_5_0.inc"
|
||||||
|
#else
|
||||||
|
#include "tensorflow/compiler/tf2tensorrt/stub/NvInferPlugin_5_1.inc"
|
||||||
|
#endif
|
59
tensorflow/compiler/tf2tensorrt/stub/nvinfer_stub.cc
Normal file
59
tensorflow/compiler/tf2tensorrt/stub/nvinfer_stub.cc
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include "tensorflow/core/platform/env.h"
|
||||||
|
#include "tensorflow/stream_executor/platform/dso_loader.h"
|
||||||
|
#include "third_party/tensorrt/NvInfer.h"
|
||||||
|
|
||||||
|
// Implements the TensorRT API by forwarding to TensorRT loaded from the DSO.
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// Returns DSO handle or null if loading the DSO fails.
|
||||||
|
void* GetDsoHandle() {
|
||||||
|
#ifdef PLATFORM_GOOGLE
|
||||||
|
return nullptr;
|
||||||
|
#else
|
||||||
|
static auto handle = []() -> void* {
|
||||||
|
auto handle_or =
|
||||||
|
stream_executor::internal::DsoLoader::GetNvInferDsoHandle();
|
||||||
|
if (!handle_or.ok()) return nullptr;
|
||||||
|
return handle_or.ValueOrDie();
|
||||||
|
}();
|
||||||
|
return handle;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
T LoadSymbol(const char* symbol_name) {
|
||||||
|
void* symbol = nullptr;
|
||||||
|
if (auto handle = GetDsoHandle()) {
|
||||||
|
tensorflow::Env::Default()
|
||||||
|
->GetSymbolFromLibrary(handle, symbol_name, &symbol)
|
||||||
|
.IgnoreError();
|
||||||
|
}
|
||||||
|
return reinterpret_cast<T>(symbol);
|
||||||
|
}
|
||||||
|
|
||||||
|
void LogFatalSymbolNotFound(const char* symbol_name) {
|
||||||
|
LOG(FATAL) << symbol_name << " symbol not found.";
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
#if NV_TENSORRT_MAJOR < 5
|
||||||
|
#error TensorRT version earlier than 5 is not supported.
|
||||||
|
#elif NV_TENSORRT_MINOR < 1
|
||||||
|
#include "tensorflow/compiler/tf2tensorrt/stub/NvInfer_5_0.inc"
|
||||||
|
#else
|
||||||
|
#include "tensorflow/compiler/tf2tensorrt/stub/NvInfer_5_1.inc"
|
||||||
|
#endif
|
@ -33,6 +33,11 @@ string GetCudaVersion() { return TF_CUDA_VERSION; }
|
|||||||
string GetCudaLibVersion() { return TF_CUDA_LIB_VERSION; }
|
string GetCudaLibVersion() { return TF_CUDA_LIB_VERSION; }
|
||||||
string GetCudnnVersion() { return TF_CUDNN_VERSION; }
|
string GetCudnnVersion() { return TF_CUDNN_VERSION; }
|
||||||
|
|
||||||
|
// TODO(laigd): populate the version string during configuration process. For
|
||||||
|
// now hardcoded version 5 since 4.0 is not supported anyway.
|
||||||
|
#define TF_TENSORRT_VERSION "5"
|
||||||
|
string GetTensorRTVersion() { return TF_TENSORRT_VERSION; }
|
||||||
|
|
||||||
port::StatusOr<void*> GetDsoHandle(const string& name, const string& version) {
|
port::StatusOr<void*> GetDsoHandle(const string& name, const string& version) {
|
||||||
auto filename = port::Env::Default()->FormatLibraryFileName(name, version);
|
auto filename = port::Env::Default()->FormatLibraryFileName(name, version);
|
||||||
void* dso_handle;
|
void* dso_handle;
|
||||||
@ -108,6 +113,14 @@ port::StatusOr<void*> GetCudnnDsoHandle() {
|
|||||||
return GetDsoHandle("cudnn", GetCudnnVersion());
|
return GetDsoHandle("cudnn", GetCudnnVersion());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
port::StatusOr<void*> GetNvInferDsoHandle() {
|
||||||
|
return GetDsoHandle("nvinfer", GetTensorRTVersion());
|
||||||
|
}
|
||||||
|
|
||||||
|
port::StatusOr<void*> GetNvInferPluginDsoHandle() {
|
||||||
|
return GetDsoHandle("nvinfer_plugin", GetTensorRTVersion());
|
||||||
|
}
|
||||||
|
|
||||||
port::StatusOr<void*> GetRocblasDsoHandle() {
|
port::StatusOr<void*> GetRocblasDsoHandle() {
|
||||||
return GetDsoHandle("rocblas", "");
|
return GetDsoHandle("rocblas", "");
|
||||||
}
|
}
|
||||||
|
@ -43,6 +43,8 @@ port::StatusOr<void*> GetCusolverDsoHandle();
|
|||||||
port::StatusOr<void*> GetCusparseDsoHandle();
|
port::StatusOr<void*> GetCusparseDsoHandle();
|
||||||
port::StatusOr<void*> GetCuptiDsoHandle();
|
port::StatusOr<void*> GetCuptiDsoHandle();
|
||||||
port::StatusOr<void*> GetCudnnDsoHandle();
|
port::StatusOr<void*> GetCudnnDsoHandle();
|
||||||
|
port::StatusOr<void*> GetNvInferDsoHandle();
|
||||||
|
port::StatusOr<void*> GetNvInferPluginDsoHandle();
|
||||||
|
|
||||||
port::StatusOr<void*> GetRocblasDsoHandle();
|
port::StatusOr<void*> GetRocblasDsoHandle();
|
||||||
port::StatusOr<void*> GetMiopenDsoHandle();
|
port::StatusOr<void*> GetMiopenDsoHandle();
|
||||||
|
Loading…
Reference in New Issue
Block a user