Add TensorRT stub for 5.0 and 5.1.

This is the first step of enabling dynamic loading of TensorRT library in
open source build.

PiperOrigin-RevId: 253585107
This commit is contained in:
Guangda Lai 2019-06-17 08:29:47 -07:00 committed by TensorFlower Gardener
parent a20602f88f
commit a5b860a7cc
9 changed files with 416 additions and 0 deletions

View File

@ -31,6 +31,20 @@ package(
exports_files(["LICENSE"])
cc_library(
name = "tensorrt_stub",
srcs = if_tensorrt([
"stub/nvinfer_stub.cc",
"stub/nvinfer_plugin_stub.cc",
]),
textual_hdrs = glob(["stub/*.inc"]),
deps = if_tensorrt([
"@local_config_tensorrt//:tensorrt_headers",
"//tensorflow/core:lib",
"//tensorflow/stream_executor/platform:dso_loader",
]),
)
tf_cuda_cc_test(
name = "tensorrt_test_cc",
size = "small",

View File

@ -0,0 +1,87 @@
// Auto-generated, do not edit.
extern "C" {
nvinfer1::IPluginV2* createRPNROIPlugin(int featureStride, int preNmsTop,
int nmsMaxOut, float iouThreshold,
float minBoxSize, float spatialScale,
nvinfer1::DimsHW pooling,
nvinfer1::Weights anchorRatios,
nvinfer1::Weights anchorScales) {
using FuncPtr = nvinfer1::IPluginV2 * ( *)(int, int, int, float, float, float, nvinfer1::DimsHW, nvinfer1::Weights, nvinfer1::Weights);
static auto func_ptr = LoadSymbol<FuncPtr>("createRPNROIPlugin");
if (!func_ptr) LogFatalSymbolNotFound("createRPNROIPlugin");
return func_ptr(featureStride, preNmsTop, nmsMaxOut, iouThreshold, minBoxSize, spatialScale, pooling, anchorRatios, anchorScales);
}
nvinfer1::IPluginV2* createNormalizePlugin(const nvinfer1::Weights* scales,
bool acrossSpatial,
bool channelShared, float eps) {
using FuncPtr = nvinfer1::IPluginV2 * ( *)(const nvinfer1::Weights *, bool, bool, float);
static auto func_ptr = LoadSymbol<FuncPtr>("createNormalizePlugin");
if (!func_ptr) LogFatalSymbolNotFound("createNormalizePlugin");
return func_ptr(scales, acrossSpatial, channelShared, eps);
}
nvinfer1::IPluginV2* createPriorBoxPlugin(
nvinfer1::plugin::PriorBoxParameters param) {
using FuncPtr = nvinfer1::IPluginV2 * ( *)(nvinfer1::plugin::PriorBoxParameters);
static auto func_ptr = LoadSymbol<FuncPtr>("createPriorBoxPlugin");
if (!func_ptr) LogFatalSymbolNotFound("createPriorBoxPlugin");
return func_ptr(param);
}
nvinfer1::IPluginV2* createAnchorGeneratorPlugin(
nvinfer1::plugin::GridAnchorParameters* param, int numLayers) {
using FuncPtr = nvinfer1::IPluginV2 * ( *)(nvinfer1::plugin::GridAnchorParameters *, int);
static auto func_ptr = LoadSymbol<FuncPtr>("createAnchorGeneratorPlugin");
if (!func_ptr) LogFatalSymbolNotFound("createAnchorGeneratorPlugin");
return func_ptr(param, numLayers);
}
nvinfer1::IPluginV2* createNMSPlugin(
nvinfer1::plugin::DetectionOutputParameters param) {
using FuncPtr = nvinfer1::IPluginV2 * ( *)(nvinfer1::plugin::DetectionOutputParameters);
static auto func_ptr = LoadSymbol<FuncPtr>("createNMSPlugin");
if (!func_ptr) LogFatalSymbolNotFound("createNMSPlugin");
return func_ptr(param);
}
nvinfer1::IPluginV2* createLReLUPlugin(float negSlope) {
using FuncPtr = nvinfer1::IPluginV2 * ( *)(float);
static auto func_ptr = LoadSymbol<FuncPtr>("createLReLUPlugin");
if (!func_ptr) LogFatalSymbolNotFound("createLReLUPlugin");
return func_ptr(negSlope);
}
nvinfer1::IPluginV2* createReorgPlugin(int stride) {
using FuncPtr = nvinfer1::IPluginV2 * ( *)(int);
static auto func_ptr = LoadSymbol<FuncPtr>("createReorgPlugin");
if (!func_ptr) LogFatalSymbolNotFound("createReorgPlugin");
return func_ptr(stride);
}
nvinfer1::IPluginV2* createRegionPlugin(
nvinfer1::plugin::RegionParameters params) {
using FuncPtr = nvinfer1::IPluginV2 * ( *)(nvinfer1::plugin::RegionParameters);
static auto func_ptr = LoadSymbol<FuncPtr>("createRegionPlugin");
if (!func_ptr) LogFatalSymbolNotFound("createRegionPlugin");
return func_ptr(params);
}
nvinfer1::IPluginV2* createClipPlugin(const char* layerName, float clipMin,
float clipMax) {
using FuncPtr = nvinfer1::IPluginV2 * ( *)(const char *, float, float);
static auto func_ptr = LoadSymbol<FuncPtr>("createClipPlugin");
if (!func_ptr) LogFatalSymbolNotFound("createClipPlugin");
return func_ptr(layerName, clipMin, clipMax);
}
bool initLibNvInferPlugins(void* logger, const char* libNamespace) {
using FuncPtr = bool ( *)(void *, const char *);
static auto func_ptr = LoadSymbol<FuncPtr>("initLibNvInferPlugins");
if (!func_ptr) LogFatalSymbolNotFound("initLibNvInferPlugins");
return func_ptr(logger, libNamespace);
}
} // extern "C"

View File

@ -0,0 +1,95 @@
// Auto-generated, do not edit.
extern "C" {
nvinfer1::IPluginV2* createRPNROIPlugin(int featureStride, int preNmsTop,
int nmsMaxOut, float iouThreshold,
float minBoxSize, float spatialScale,
nvinfer1::DimsHW pooling,
nvinfer1::Weights anchorRatios,
nvinfer1::Weights anchorScales) {
using FuncPtr = nvinfer1::IPluginV2 * ( *)(int, int, int, float, float, float, nvinfer1::DimsHW, nvinfer1::Weights, nvinfer1::Weights);
static auto func_ptr = LoadSymbol<FuncPtr>("createRPNROIPlugin");
if (!func_ptr) LogFatalSymbolNotFound("createRPNROIPlugin");
return func_ptr(featureStride, preNmsTop, nmsMaxOut, iouThreshold, minBoxSize, spatialScale, pooling, anchorRatios, anchorScales);
}
nvinfer1::IPluginV2* createNormalizePlugin(const nvinfer1::Weights* scales,
bool acrossSpatial,
bool channelShared, float eps) {
using FuncPtr = nvinfer1::IPluginV2 * ( *)(const nvinfer1::Weights *, bool, bool, float);
static auto func_ptr = LoadSymbol<FuncPtr>("createNormalizePlugin");
if (!func_ptr) LogFatalSymbolNotFound("createNormalizePlugin");
return func_ptr(scales, acrossSpatial, channelShared, eps);
}
nvinfer1::IPluginV2* createPriorBoxPlugin(
nvinfer1::plugin::PriorBoxParameters param) {
using FuncPtr = nvinfer1::IPluginV2 * ( *)(nvinfer1::plugin::PriorBoxParameters);
static auto func_ptr = LoadSymbol<FuncPtr>("createPriorBoxPlugin");
if (!func_ptr) LogFatalSymbolNotFound("createPriorBoxPlugin");
return func_ptr(param);
}
nvinfer1::IPluginV2* createAnchorGeneratorPlugin(
nvinfer1::plugin::GridAnchorParameters* param, int numLayers) {
using FuncPtr = nvinfer1::IPluginV2 * ( *)(nvinfer1::plugin::GridAnchorParameters *, int);
static auto func_ptr = LoadSymbol<FuncPtr>("createAnchorGeneratorPlugin");
if (!func_ptr) LogFatalSymbolNotFound("createAnchorGeneratorPlugin");
return func_ptr(param, numLayers);
}
nvinfer1::IPluginV2* createNMSPlugin(
nvinfer1::plugin::DetectionOutputParameters param) {
using FuncPtr = nvinfer1::IPluginV2 * ( *)(nvinfer1::plugin::DetectionOutputParameters);
static auto func_ptr = LoadSymbol<FuncPtr>("createNMSPlugin");
if (!func_ptr) LogFatalSymbolNotFound("createNMSPlugin");
return func_ptr(param);
}
nvinfer1::IPluginV2* createLReLUPlugin(float negSlope) {
using FuncPtr = nvinfer1::IPluginV2 * ( *)(float);
static auto func_ptr = LoadSymbol<FuncPtr>("createLReLUPlugin");
if (!func_ptr) LogFatalSymbolNotFound("createLReLUPlugin");
return func_ptr(negSlope);
}
nvinfer1::IPluginV2* createReorgPlugin(int stride) {
using FuncPtr = nvinfer1::IPluginV2 * ( *)(int);
static auto func_ptr = LoadSymbol<FuncPtr>("createReorgPlugin");
if (!func_ptr) LogFatalSymbolNotFound("createReorgPlugin");
return func_ptr(stride);
}
nvinfer1::IPluginV2* createRegionPlugin(
nvinfer1::plugin::RegionParameters params) {
using FuncPtr = nvinfer1::IPluginV2 * ( *)(nvinfer1::plugin::RegionParameters);
static auto func_ptr = LoadSymbol<FuncPtr>("createRegionPlugin");
if (!func_ptr) LogFatalSymbolNotFound("createRegionPlugin");
return func_ptr(params);
}
nvinfer1::IPluginV2* createClipPlugin(const char* layerName, float clipMin,
float clipMax) {
using FuncPtr = nvinfer1::IPluginV2 * ( *)(const char *, float, float);
static auto func_ptr = LoadSymbol<FuncPtr>("createClipPlugin");
if (!func_ptr) LogFatalSymbolNotFound("createClipPlugin");
return func_ptr(layerName, clipMin, clipMax);
}
nvinfer1::IPluginV2* createBatchedNMSPlugin(
nvinfer1::plugin::NMSParameters param) {
using FuncPtr = nvinfer1::IPluginV2 * ( *)(nvinfer1::plugin::NMSParameters);
static auto func_ptr = LoadSymbol<FuncPtr>("createBatchedNMSPlugin");
if (!func_ptr) LogFatalSymbolNotFound("createBatchedNMSPlugin");
return func_ptr(param);
}
bool initLibNvInferPlugins(void* logger, const char* libNamespace) {
using FuncPtr = bool ( *)(void *, const char *);
static auto func_ptr = LoadSymbol<FuncPtr>("initLibNvInferPlugins");
if (!func_ptr) LogFatalSymbolNotFound("initLibNvInferPlugins");
return func_ptr(logger, libNamespace);
}
} // extern "C"

View File

@ -0,0 +1,40 @@
// Auto-generated, do not edit.
extern "C" {
void* createInferBuilder_INTERNAL(void* logger, int version) {
using FuncPtr = void * (*)(void *, int);
static auto func_ptr = LoadSymbol<FuncPtr>("createInferBuilder_INTERNAL");
if (!func_ptr) LogFatalSymbolNotFound("createInferBuilder_INTERNAL");
return func_ptr(logger, version);
}
void* createInferRuntime_INTERNAL(void* logger, int version) {
using FuncPtr = void * (*)(void *, int);
static auto func_ptr = LoadSymbol<FuncPtr>("createInferRuntime_INTERNAL");
if (!func_ptr) LogFatalSymbolNotFound("createInferRuntime_INTERNAL");
return func_ptr(logger, version);
}
nvinfer1::ILogger* getLogger() {
using FuncPtr = nvinfer1::ILogger * (*)();
static auto func_ptr = LoadSymbol<FuncPtr>("getLogger");
if (!func_ptr) LogFatalSymbolNotFound("getLogger");
return func_ptr();
}
int getInferLibVersion() {
using FuncPtr = int (*)();
static auto func_ptr = LoadSymbol<FuncPtr>("getInferLibVersion");
if (!func_ptr) LogFatalSymbolNotFound("getInferLibVersion");
return func_ptr();
}
nvinfer1::IPluginRegistry* getPluginRegistry() {
using FuncPtr = nvinfer1::IPluginRegistry * (*)();
static auto func_ptr = LoadSymbol<FuncPtr>("getPluginRegistry");
if (!func_ptr) LogFatalSymbolNotFound("getPluginRegistry");
return func_ptr();
}
} // extern "C"

View File

@ -0,0 +1,47 @@
// Auto-generated, do not edit.
extern "C" {
void* createInferBuilder_INTERNAL(void* logger, int version) {
using FuncPtr = void * (*)(void *, int);
static auto func_ptr = LoadSymbol<FuncPtr>("createInferBuilder_INTERNAL");
if (!func_ptr) LogFatalSymbolNotFound("createInferBuilder_INTERNAL");
return func_ptr(logger, version);
}
void* createInferRefitter_INTERNAL(void* engine, void* logger, int version) {
using FuncPtr = void * (*)(void *, void *, int);
static auto func_ptr = LoadSymbol<FuncPtr>("createInferRefitter_INTERNAL");
if (!func_ptr) LogFatalSymbolNotFound("createInferRefitter_INTERNAL");
return func_ptr(engine, logger, version);
}
void* createInferRuntime_INTERNAL(void* logger, int version) {
using FuncPtr = void * (*)(void *, int);
static auto func_ptr = LoadSymbol<FuncPtr>("createInferRuntime_INTERNAL");
if (!func_ptr) LogFatalSymbolNotFound("createInferRuntime_INTERNAL");
return func_ptr(logger, version);
}
nvinfer1::ILogger* getLogger() {
using FuncPtr = nvinfer1::ILogger * (*)();
static auto func_ptr = LoadSymbol<FuncPtr>("getLogger");
if (!func_ptr) LogFatalSymbolNotFound("getLogger");
return func_ptr();
}
int getInferLibVersion() {
using FuncPtr = int (*)();
static auto func_ptr = LoadSymbol<FuncPtr>("getInferLibVersion");
if (!func_ptr) LogFatalSymbolNotFound("getInferLibVersion");
return func_ptr();
}
nvinfer1::IPluginRegistry* getPluginRegistry() {
using FuncPtr = nvinfer1::IPluginRegistry * (*)();
static auto func_ptr = LoadSymbol<FuncPtr>("getPluginRegistry");
if (!func_ptr) LogFatalSymbolNotFound("getPluginRegistry");
return func_ptr();
}
} // extern "C"

View File

@ -0,0 +1,59 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/platform/env.h"
#include "tensorflow/stream_executor/platform/dso_loader.h"
#include "third_party/tensorrt/NvInferPlugin.h"
// Implements the TensorRT API by forwarding to TensorRT loaded from the DSO.
namespace {
// Returns DSO handle or null if loading the DSO fails.
void* GetDsoHandle() {
#ifdef PLATFORM_GOOGLE
return nullptr;
#else
static auto handle = []() -> void* {
auto handle_or =
stream_executor::internal::DsoLoader::GetNvInferPluginDsoHandle();
if (!handle_or.ok()) return nullptr;
return handle_or.ValueOrDie();
}();
return handle;
#endif
}
template <typename T>
T LoadSymbol(const char* symbol_name) {
void* symbol = nullptr;
if (auto handle = GetDsoHandle()) {
tensorflow::Env::Default()
->GetSymbolFromLibrary(handle, symbol_name, &symbol)
.IgnoreError();
}
return reinterpret_cast<T>(symbol);
}
void LogFatalSymbolNotFound(const char* symbol_name) {
LOG(FATAL) << symbol_name << " symbol not found.";
}
} // namespace
#if NV_TENSORRT_MAJOR < 5
#error TensorRT version earlier than 5 is not supported.
#elif NV_TENSORRT_MINOR < 1
#include "tensorflow/compiler/tf2tensorrt/stub/NvInferPlugin_5_0.inc"
#else
#include "tensorflow/compiler/tf2tensorrt/stub/NvInferPlugin_5_1.inc"
#endif

View File

@ -0,0 +1,59 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/platform/env.h"
#include "tensorflow/stream_executor/platform/dso_loader.h"
#include "third_party/tensorrt/NvInfer.h"
// Implements the TensorRT API by forwarding to TensorRT loaded from the DSO.
namespace {
// Returns DSO handle or null if loading the DSO fails.
void* GetDsoHandle() {
#ifdef PLATFORM_GOOGLE
return nullptr;
#else
static auto handle = []() -> void* {
auto handle_or =
stream_executor::internal::DsoLoader::GetNvInferDsoHandle();
if (!handle_or.ok()) return nullptr;
return handle_or.ValueOrDie();
}();
return handle;
#endif
}
template <typename T>
T LoadSymbol(const char* symbol_name) {
void* symbol = nullptr;
if (auto handle = GetDsoHandle()) {
tensorflow::Env::Default()
->GetSymbolFromLibrary(handle, symbol_name, &symbol)
.IgnoreError();
}
return reinterpret_cast<T>(symbol);
}
void LogFatalSymbolNotFound(const char* symbol_name) {
LOG(FATAL) << symbol_name << " symbol not found.";
}
} // namespace
#if NV_TENSORRT_MAJOR < 5
#error TensorRT version earlier than 5 is not supported.
#elif NV_TENSORRT_MINOR < 1
#include "tensorflow/compiler/tf2tensorrt/stub/NvInfer_5_0.inc"
#else
#include "tensorflow/compiler/tf2tensorrt/stub/NvInfer_5_1.inc"
#endif

View File

@ -33,6 +33,11 @@ string GetCudaVersion() { return TF_CUDA_VERSION; }
string GetCudaLibVersion() { return TF_CUDA_LIB_VERSION; }
string GetCudnnVersion() { return TF_CUDNN_VERSION; }
// TODO(laigd): populate the version string during configuration process. For
// now hardcoded version 5 since 4.0 is not supported anyway.
#define TF_TENSORRT_VERSION "5"
string GetTensorRTVersion() { return TF_TENSORRT_VERSION; }
port::StatusOr<void*> GetDsoHandle(const string& name, const string& version) {
auto filename = port::Env::Default()->FormatLibraryFileName(name, version);
void* dso_handle;
@ -108,6 +113,14 @@ port::StatusOr<void*> GetCudnnDsoHandle() {
return GetDsoHandle("cudnn", GetCudnnVersion());
}
port::StatusOr<void*> GetNvInferDsoHandle() {
return GetDsoHandle("nvinfer", GetTensorRTVersion());
}
port::StatusOr<void*> GetNvInferPluginDsoHandle() {
return GetDsoHandle("nvinfer_plugin", GetTensorRTVersion());
}
port::StatusOr<void*> GetRocblasDsoHandle() {
return GetDsoHandle("rocblas", "");
}

View File

@ -43,6 +43,8 @@ port::StatusOr<void*> GetCusolverDsoHandle();
port::StatusOr<void*> GetCusparseDsoHandle();
port::StatusOr<void*> GetCuptiDsoHandle();
port::StatusOr<void*> GetCudnnDsoHandle();
port::StatusOr<void*> GetNvInferDsoHandle();
port::StatusOr<void*> GetNvInferPluginDsoHandle();
port::StatusOr<void*> GetRocblasDsoHandle();
port::StatusOr<void*> GetMiopenDsoHandle();