diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc index f3db42509ec..3d46163f75b 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc @@ -61,6 +61,7 @@ limitations under the License. #if GOOGLE_TENSORRT #include "cuda/include/cuda_runtime_api.h" #include "tensorrt/include/NvInfer.h" +#include "tensorrt/include/NvInferPlugin.h" namespace tensorflow { namespace tensorrt { namespace convert { @@ -957,6 +958,26 @@ tensorflow::Status ConvertAfterShapes(ConversionParams& params) { LOG(INFO) << "Number of TensorRT candidate segments: " << initial_segments.size(); + // Check if plugins can be accessed. + int num_trt_plugins = 0; + nvinfer1::IPluginCreator* const* trt_plugin_creator_list = + getPluginRegistry()->getPluginCreatorList(&num_trt_plugins); + if (!trt_plugin_creator_list) { + LOG(WARNING) << "Cannot find any TensorRT plugins in registry."; + } + else { + VLOG(1) << "Found the following " << num_trt_plugins << " TensorRT plugins in registry:"; + for (int i = 0; i < num_trt_plugins; ++i) { + if (!trt_plugin_creator_list[i]) { + LOG(WARNING) << "TensorRT plugin at index " << i << + " is not accessible (null pointer returned by getPluginCreatorList for this plugin)"; + } + else { + VLOG(1) << " " << trt_plugin_creator_list[i]->getPluginName(); + } + } + } + // Get the EngineInfo for each segment. 
std::unordered_map<string, Node*> node_map; TF_RETURN_IF_ERROR(BuildNodeMap(graph, &node_map)); diff --git a/tensorflow/python/compiler/tensorrt/trt_convert.py b/tensorflow/python/compiler/tensorrt/trt_convert.py index 33b5e50418f..90ef73621bb 100644 --- a/tensorflow/python/compiler/tensorrt/trt_convert.py +++ b/tensorflow/python/compiler/tensorrt/trt_convert.py @@ -45,6 +45,7 @@ from tensorflow.python.saved_model import builder from tensorflow.python.saved_model import loader from tensorflow.python.saved_model import tag_constants from tensorflow.python.training import saver +import ctypes def _to_bytes(s): @@ -482,6 +483,18 @@ class TrtGraphConverter(GraphConverter): tf_logging.info("Running against TensorRT version %s" % ".".join( [str(x) for x in loaded_version])) + # Load TensorRT plugins. + try: + plugin_lib = ctypes.CDLL("libnvinfer_plugin.so") + except Exception as e: + tf_logging.warn("Failed to load libnvinfer_plugin.so: " + str(e)) + else: + # Initialize and register TensorRT plugins. + plugin_lib_registered = plugin_lib.initLibNvInferPlugins(None, b"") + if plugin_lib_registered != 1: + tf_logging.warn("Failed to initialize and register TensorRT plugins " + "with initLibNvInferPlugins") + # Check input arguments. if precision_mode.upper() not in TrtPrecisionMode.supported_precision_modes( ): diff --git a/third_party/tensorrt/tensorrt_configure.bzl b/third_party/tensorrt/tensorrt_configure.bzl index c6de25b33e3..ee32905eca3 100644 --- a/third_party/tensorrt/tensorrt_configure.bzl +++ b/third_party/tensorrt/tensorrt_configure.bzl @@ -24,7 +24,7 @@ _TF_TENSORRT_CONFIG_REPO = "TF_TENSORRT_CONFIG_REPO" _TF_TENSORRT_VERSION = "TF_TENSORRT_VERSION" _TF_TENSORRT_LIBS = ["nvinfer"] -_TF_TENSORRT_HEADERS = ["NvInfer.h", "NvUtils.h"] +_TF_TENSORRT_HEADERS = ["NvInfer.h", "NvUtils.h", "NvInferPlugin.h"] _DEFINE_TENSORRT_SONAME_MAJOR = "#define NV_TENSORRT_SONAME_MAJOR" _DEFINE_TENSORRT_SONAME_MINOR = "#define NV_TENSORRT_SONAME_MINOR"