[TF:TRT] Initialize TensorRT plugin registry before deserializing cuda engines.
When a TF-TRT converted graph with static cuda engines is executed, we call the TensorRT runtime to deserialize cuda engines without initializing the TensorRT plugin registry. This causes TensorRT runtime failure when the cuda engines contain plugins. Move InitializeTrtPlugins to common/utils.cc and replace the use of mutex with absl::call_once. PiperOrigin-RevId: 327679056 Change-Id: I5e50a01aa06f3b5a22a3114a2c54e3712461bd6b
This commit is contained in:
parent
0a8e341415
commit
dcf4c4f58b
tensorflow/compiler/tf2tensorrt
@ -16,8 +16,10 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2tensorrt/common/utils.h"
|
||||
|
||||
#if GOOGLE_CUDA && GOOGLE_TENSORRT
|
||||
#include "third_party/tensorrt/NvInfer.h"
|
||||
#endif // GOOGLE_CUDA && GOOGLE_TENSORRT
|
||||
#include "absl/base/call_once.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "third_party/tensorrt/NvInferPlugin.h"
|
||||
#endif
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tensorrt {
|
||||
@ -46,3 +48,52 @@ std::tuple<int, int, int> GetLoadedTensorRTVersion() {
|
||||
|
||||
} // namespace tensorrt
|
||||
} // namespace tensorflow
|
||||
|
||||
#if GOOGLE_CUDA && GOOGLE_TENSORRT
|
||||
namespace tensorflow {
|
||||
namespace tensorrt {
|
||||
namespace {
|
||||
|
||||
void InitializeTrtPlugins(nvinfer1::ILogger* trt_logger) {
|
||||
LOG(INFO) << "Linked TensorRT version: "
|
||||
<< absl::StrJoin(GetLinkedTensorRTVersion(), ".");
|
||||
LOG(INFO) << "Loaded TensorRT version: "
|
||||
<< absl::StrJoin(GetLoadedTensorRTVersion(), ".");
|
||||
|
||||
bool plugin_initialized = initLibNvInferPlugins(trt_logger, "");
|
||||
if (!plugin_initialized) {
|
||||
LOG(ERROR) << "Failed to initialize TensorRT plugins, and conversion may "
|
||||
"fail later.";
|
||||
}
|
||||
|
||||
int num_trt_plugins = 0;
|
||||
nvinfer1::IPluginCreator* const* trt_plugin_creator_list =
|
||||
getPluginRegistry()->getPluginCreatorList(&num_trt_plugins);
|
||||
if (!trt_plugin_creator_list) {
|
||||
LOG_WARNING_WITH_PREFIX << "Can not find any TensorRT plugins in registry.";
|
||||
} else {
|
||||
VLOG(1) << "Found the following " << num_trt_plugins
|
||||
<< " TensorRT plugins in registry:";
|
||||
for (int i = 0; i < num_trt_plugins; ++i) {
|
||||
if (!trt_plugin_creator_list[i]) {
|
||||
LOG_WARNING_WITH_PREFIX
|
||||
<< "TensorRT plugin at index " << i
|
||||
<< " is not accessible (null pointer returned by "
|
||||
"getPluginCreatorList for this plugin)";
|
||||
} else {
|
||||
VLOG(1) << " " << trt_plugin_creator_list[i]->getPluginName();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void MaybeInitializeTrtPlugins(nvinfer1::ILogger* trt_logger) {
|
||||
static absl::once_flag once;
|
||||
absl::call_once(once, InitializeTrtPlugins, trt_logger);
|
||||
}
|
||||
|
||||
} // namespace tensorrt
|
||||
} // namespace tensorflow
|
||||
#endif
|
||||
|
@ -33,12 +33,16 @@ std::tuple<int, int, int> GetLoadedTensorRTVersion();
|
||||
#if GOOGLE_CUDA && GOOGLE_TENSORRT
|
||||
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "third_party/tensorrt/NvInfer.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tensorrt {
|
||||
|
||||
#define LOG_WARNING_WITH_PREFIX LOG(WARNING) << "TF-TRT Warning: "
|
||||
|
||||
// Initializes the TensorRT plugin registry if this hasn't been done yet.
|
||||
void MaybeInitializeTrtPlugins(nvinfer1::ILogger* trt_logger);
|
||||
|
||||
} // namespace tensorrt
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -1197,44 +1197,6 @@ Status TrtNodeValidator::ConvertConstToWeights(
|
||||
return status;
|
||||
}
|
||||
|
||||
static void InitializeTrtPlugins(nvinfer1::ILogger* trt_logger) {
|
||||
static mutex plugin_mutex(LINKER_INITIALIZED);
|
||||
static bool plugin_initialized = false;
|
||||
mutex_lock lock(plugin_mutex);
|
||||
if (plugin_initialized) return;
|
||||
|
||||
LOG(INFO) << "Linked TensorRT version: "
|
||||
<< absl::StrJoin(GetLinkedTensorRTVersion(), ".");
|
||||
LOG(INFO) << "Loaded TensorRT version: "
|
||||
<< absl::StrJoin(GetLoadedTensorRTVersion(), ".");
|
||||
|
||||
plugin_initialized = initLibNvInferPlugins(trt_logger, "");
|
||||
if (!plugin_initialized) {
|
||||
LOG(ERROR) << "Failed to initialize TensorRT plugins, and conversion may "
|
||||
"fail later.";
|
||||
}
|
||||
|
||||
int num_trt_plugins = 0;
|
||||
nvinfer1::IPluginCreator* const* trt_plugin_creator_list =
|
||||
getPluginRegistry()->getPluginCreatorList(&num_trt_plugins);
|
||||
if (!trt_plugin_creator_list) {
|
||||
LOG_WARNING_WITH_PREFIX << "Can not find any TensorRT plugins in registry.";
|
||||
} else {
|
||||
VLOG(1) << "Found the following " << num_trt_plugins
|
||||
<< " TensorRT plugins in registry:";
|
||||
for (int i = 0; i < num_trt_plugins; ++i) {
|
||||
if (!trt_plugin_creator_list[i]) {
|
||||
LOG_WARNING_WITH_PREFIX
|
||||
<< "TensorRT plugin at index " << i
|
||||
<< " is not accessible (null pointer returned by "
|
||||
"getPluginCreatorList for this plugin)";
|
||||
} else {
|
||||
VLOG(1) << " " << trt_plugin_creator_list[i]->getPluginName();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// static
|
||||
StatusOr<std::unique_ptr<Converter>> Converter::Create(
|
||||
TrtPrecisionMode precision_mode, bool use_calibration,
|
||||
@ -1251,7 +1213,7 @@ Converter::Converter(TrtPrecisionMode precision_mode, bool use_calibration,
|
||||
: precision_mode_(precision_mode),
|
||||
use_calibration_(use_calibration),
|
||||
use_implicit_batch_(use_implicit_batch) {
|
||||
InitializeTrtPlugins(trt_logger);
|
||||
MaybeInitializeTrtPlugins(trt_logger);
|
||||
this->RegisterOpConverters();
|
||||
}
|
||||
|
||||
|
@ -800,6 +800,9 @@ StatusOr<std::pair<EngineContext*, int>> TRTEngineOp::GetEngine(
|
||||
|
||||
TrtUniquePtrType<IRuntime> infer(nvinfer1::createInferRuntime(logger));
|
||||
infer->setGpuAllocator(allocator);
|
||||
// Need to initialize plugins in order to deserialize engines that contain
|
||||
// plugins.
|
||||
MaybeInitializeTrtPlugins(&logger);
|
||||
TrtUniquePtrType<nvinfer1::ICudaEngine> static_engine(
|
||||
infer->deserializeCudaEngine(serialized_segment_.c_str(),
|
||||
serialized_segment_.size(), nullptr));
|
||||
|
Loading…
Reference in New Issue
Block a user