TFTRT: Load TensorRT plugin library
parent 4129fff6b4
commit 083eece828
@@ -61,6 +61,7 @@ limitations under the License.
 #if GOOGLE_TENSORRT
 #include "cuda/include/cuda_runtime_api.h"
 #include "tensorrt/include/NvInfer.h"
+#include "tensorrt/include/NvInferPlugin.h"
 namespace tensorflow {
 namespace tensorrt {
 namespace convert {
@@ -957,6 +958,26 @@ tensorflow::Status ConvertAfterShapes(ConversionParams& params) {
   LOG(INFO) << "Number of TensorRT candidate segments: "
             << initial_segments.size();

+  // Check if plugins can be accessed.
+  int num_trt_plugins = 0;
+  nvinfer1::IPluginCreator* const* trt_plugin_creator_list =
+      getPluginRegistry()->getPluginCreatorList(&num_trt_plugins);
+  if (!trt_plugin_creator_list) {
+    LOG(WARNING) << "Cannot find any TensorRT plugins in registry.";
+  } else {
+    VLOG(1) << "Found the following " << num_trt_plugins
+            << " TensorRT plugins in registry:";
+    for (int i = 0; i < num_trt_plugins; ++i) {
+      if (!trt_plugin_creator_list[i]) {
+        LOG(WARNING) << "TensorRT plugin at index " << i
+                     << " is not accessible (null pointer returned by "
+                        "getPluginCreatorList for this plugin)";
+      } else {
+        VLOG(1) << " " << trt_plugin_creator_list[i]->getPluginName();
+      }
+    }
+  }
+
   // Get the EngineInfo for each segment.
   std::unordered_map<string, tensorflow::Node*> node_map;
   TF_RETURN_IF_ERROR(BuildNodeMap(graph, &node_map));
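For reference, the registry check above can also be reproduced outside TF-TRT. Below is a minimal Python sketch, assuming the standalone `tensorrt` pip package is installed; `init_libnvinfer_plugins`, `get_plugin_registry`, `plugin_creator_list`, `name`, and `plugin_version` are names from that package's bindings (which wrap the C++ calls used in this hunk) and are not part of this commit.

# Sketch: enumerate registered TensorRT plugin creators from Python
# (illustrative only; not part of this commit).
import tensorrt as trt

trt_logger = trt.Logger(trt.Logger.WARNING)
# Register the plugins shipped in libnvinfer_plugin under the empty namespace.
trt.init_libnvinfer_plugins(trt_logger, "")

creators = trt.get_plugin_registry().plugin_creator_list
if not creators:
  print("Cannot find any TensorRT plugins in registry.")
else:
  print("Found %d TensorRT plugins in registry:" % len(creators))
  for creator in creators:
    print("  %s (version %s)" % (creator.name, creator.plugin_version))

If the plugin library has not been loaded and initialized, the creator list returned on the Python side is typically empty rather than null.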
@@ -45,6 +45,7 @@ from tensorflow.python.saved_model import builder
 from tensorflow.python.saved_model import loader
 from tensorflow.python.saved_model import tag_constants
 from tensorflow.python.training import saver
+import ctypes
 
 
 def _to_bytes(s):
@@ -482,6 +483,18 @@ class TrtGraphConverter(GraphConverter):
     tf_logging.info("Running against TensorRT version %s" % ".".join(
         [str(x) for x in loaded_version]))
 
+    # Load TensorRT plugins.
+    try:
+      plugin_lib = ctypes.CDLL("libnvinfer_plugin.so")
+    except Exception as e:
+      tf_logging.warn("Failed to load libnvinfer_plugin.so: " + str(e))
+    else:
+      # Initialize and register TensorRT plugins.
+      plugin_lib_registered = plugin_lib.initLibNvInferPlugins(None, "")
+      if plugin_lib_registered != 1:
+        tf_logging.warn("Failed to initialize and register TensorRT plugins "
+                        "with initLibNvInferPlugins")
+
     # Check input arguments.
     if precision_mode.upper() not in TrtPrecisionMode.supported_precision_modes(
     ):
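As a side note, the ctypes call in this hunk relies on ctypes' default argument conversion. Below is a minimal sketch of a more explicit equivalent (illustrative only, not part of the commit), based on the `initLibNvInferPlugins` prototype from NvInferPlugin.h, `bool initLibNvInferPlugins(void* logger, const char* libNamespace)`.

# Sketch: load libnvinfer_plugin via ctypes with explicit signatures
# (illustrative only; not part of this commit).
import ctypes

try:
  plugin_lib = ctypes.CDLL("libnvinfer_plugin.so")
except OSError as e:
  print("Failed to load libnvinfer_plugin.so: %s" % e)
else:
  # bool initLibNvInferPlugins(void* logger, const char* libNamespace);
  plugin_lib.initLibNvInferPlugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
  plugin_lib.initLibNvInferPlugins.restype = ctypes.c_bool
  # Null logger and empty plugin namespace, matching the call in the commit.
  if not plugin_lib.initLibNvInferPlugins(None, b""):
    print("Failed to initialize and register TensorRT plugins.")

Declaring argtypes and restype avoids passing a Python str where the C function expects a const char*, which is why the namespace argument is given as bytes here.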
third_party/tensorrt/tensorrt_configure.bzl
@@ -24,7 +24,7 @@ _TF_TENSORRT_CONFIG_REPO = "TF_TENSORRT_CONFIG_REPO"
 _TF_TENSORRT_VERSION = "TF_TENSORRT_VERSION"
 
 _TF_TENSORRT_LIBS = ["nvinfer"]
-_TF_TENSORRT_HEADERS = ["NvInfer.h", "NvUtils.h"]
+_TF_TENSORRT_HEADERS = ["NvInfer.h", "NvUtils.h", "NvInferPlugin.h"]
 
 _DEFINE_TENSORRT_SONAME_MAJOR = "#define NV_TENSORRT_SONAME_MAJOR"
 _DEFINE_TENSORRT_SONAME_MINOR = "#define NV_TENSORRT_SONAME_MINOR"