TFTRT: Load TensorRT plugin library

This commit is contained in:
Pooya Davoodi 2019-01-08 14:42:09 -08:00
parent 4129fff6b4
commit 083eece828
3 changed files with 35 additions and 1 deletions

View File

@ -61,6 +61,7 @@ limitations under the License.
#if GOOGLE_TENSORRT
#include "cuda/include/cuda_runtime_api.h"
#include "tensorrt/include/NvInfer.h"
#include "tensorrt/include/NvInferPlugin.h"
namespace tensorflow {
namespace tensorrt {
namespace convert {
@ -957,6 +958,26 @@ tensorflow::Status ConvertAfterShapes(ConversionParams& params) {
LOG(INFO) << "Number of TensorRT candidate segments: "
<< initial_segments.size();
// Check if plugins can be aaccessed.
int num_trt_plugins = 0;
nvinfer1::IPluginCreator* const* trt_plugin_creator_list =
getPluginRegistry()->getPluginCreatorList(&num_trt_plugins);
if (!trt_plugin_creator_list) {
LOG(WARNING) << "Can not find any TensorRT plugins in registry.";
}
else {
VLOG(1) << "Found the following " << num_trt_plugins << " TensorRT plugins in registry:";
for (int i = 0; i < num_trt_plugins; ++i) {
if (!trt_plugin_creator_list[i]) {
LOG(WARNING) << "TensorRT plugin at index " << i <<
" is not accessible (null pointer returned by getPluginCreatorList for this plugin)";
}
else {
VLOG(1) << " " << trt_plugin_creator_list[i]->getPluginName();
}
}
}
// Get the EngineInfo for each segment.
std::unordered_map<string, tensorflow::Node*> node_map;
TF_RETURN_IF_ERROR(BuildNodeMap(graph, &node_map));

View File

@ -45,6 +45,7 @@ from tensorflow.python.saved_model import builder
from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.training import saver
import ctypes
def _to_bytes(s):
@ -482,6 +483,18 @@ class TrtGraphConverter(GraphConverter):
tf_logging.info("Running against TensorRT version %s" % ".".join(
[str(x) for x in loaded_version]))
# Load TensorRT plugins.
try:
plugin_lib = ctypes.CDLL("libnvinfer_plugin.so")
except Exception as e:
tf_logging.warn("Failed to load libnvinfer_plugin.so" + str(e))
else:
# Initialize and register TensorRT plugins.
plugin_lib_registered = plugin_lib.initLibNvInferPlugins(None, "")
if plugin_lib_registered != 1:
tf_logging.warn("Failed to initialize and register TensorRT plugins "
"with initLibNvInferPlugins")
# Check input arguments.
if precision_mode.upper() not in TrtPrecisionMode.supported_precision_modes(
):

View File

@ -24,7 +24,7 @@ _TF_TENSORRT_CONFIG_REPO = "TF_TENSORRT_CONFIG_REPO"
_TF_TENSORRT_VERSION = "TF_TENSORRT_VERSION"
_TF_TENSORRT_LIBS = ["nvinfer"]
_TF_TENSORRT_HEADERS = ["NvInfer.h", "NvUtils.h"]
_TF_TENSORRT_HEADERS = ["NvInfer.h", "NvUtils.h", "NvInferPlugin.h"]
_DEFINE_TENSORRT_SONAME_MAJOR = "#define NV_TENSORRT_SONAME_MAJOR"
_DEFINE_TENSORRT_SONAME_MINOR = "#define NV_TENSORRT_SONAME_MINOR"