parent
9b5f9aabd8
commit
4de76f66eb
@ -211,7 +211,6 @@ tf_cuda_library(
|
||||
hdrs = ["utils/trt_logger.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":logger_registry",
|
||||
"//tensorflow/core:lib_proto_parsing",
|
||||
] + if_tensorrt([":tensorrt_lib"]),
|
||||
)
|
||||
@ -301,19 +300,6 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
name = "logger_registry",
|
||||
srcs = ["convert/logger_registry.cc"],
|
||||
hdrs = [
|
||||
"convert/logger_registry.h",
|
||||
],
|
||||
copts = tf_copts(),
|
||||
deps = [
|
||||
"@com_google_absl//absl/strings",
|
||||
"//tensorflow/core:lib",
|
||||
] + if_tensorrt([":tensorrt_lib"]),
|
||||
)
|
||||
|
||||
# Library for the node-level conversion portion of TensorRT operation creation
|
||||
tf_cuda_library(
|
||||
name = "trt_conversion",
|
||||
@ -328,7 +314,6 @@ tf_cuda_library(
|
||||
"convert/trt_optimization_pass.h",
|
||||
],
|
||||
deps = [
|
||||
":logger_registry",
|
||||
":segment",
|
||||
":trt_allocator",
|
||||
":trt_plugins",
|
||||
|
@ -26,7 +26,6 @@ limitations under the License.
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
|
||||
#include "tensorflow/compiler/tf2tensorrt/convert/logger_registry.h"
|
||||
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
|
||||
#include "tensorflow/compiler/tf2tensorrt/segment/segment.h"
|
||||
#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
|
||||
@ -418,15 +417,16 @@ Status CreateTRTNode(const ConversionParams& params,
|
||||
// Build the engine and get its serialized representation.
|
||||
string segment_string;
|
||||
if (info.engine_type == EngineInfo::EngineType::TRTStatic) {
|
||||
auto trt_logger = GetLoggerRegistry()->LookUp(params.trt_logger_name);
|
||||
// Create static engine for fp32/fp16 mode.
|
||||
Logger trt_logger;
|
||||
TrtUniquePtrType<nvinfer1::ICudaEngine> engine;
|
||||
// TODO(sami): What happens if 1st dim is not batch?
|
||||
TF_RETURN_IF_ERROR(ConvertGraphDefToEngine(
|
||||
info.segment_graph_def,
|
||||
calibrate_int8 ? TrtPrecisionMode::FP32 : info.precision_mode,
|
||||
max_batch_size, info.max_workspace_size_bytes, input_shapes, trt_logger,
|
||||
alloc, /*calibrator=*/nullptr, &engine, info.use_calibration,
|
||||
max_batch_size, info.max_workspace_size_bytes, input_shapes,
|
||||
&trt_logger, alloc, /*calibrator=*/nullptr, &engine,
|
||||
info.use_calibration,
|
||||
/*convert_successfully=*/nullptr));
|
||||
TrtUniquePtrType<nvinfer1::IHostMemory> engine_data(engine->serialize());
|
||||
segment_string = string(static_cast<const char*>(engine_data->data()),
|
||||
|
@ -35,7 +35,6 @@ namespace convert {
|
||||
struct ConversionParams {
|
||||
const GraphDef* input_graph_def = nullptr;
|
||||
const std::vector<string>* output_names = nullptr;
|
||||
string trt_logger_name;
|
||||
size_t max_batch_size = 1;
|
||||
size_t max_workspace_size_bytes = 1 << 30;
|
||||
GraphDef* output_graph_def = nullptr;
|
||||
|
@ -1176,13 +1176,14 @@ Status TrtNodeValidator::ConvertConstToWeights(
|
||||
return status;
|
||||
}
|
||||
|
||||
static void InitializeTrtPlugins(nvinfer1::ILogger* trt_logger) {
|
||||
static void InitializeTrtPlugins() {
|
||||
static mutex plugin_mutex(LINKER_INITIALIZED);
|
||||
static bool plugin_initialized = false;
|
||||
static Logger trt_logger;
|
||||
mutex_lock lock(plugin_mutex);
|
||||
if (plugin_initialized) return;
|
||||
|
||||
plugin_initialized = initLibNvInferPlugins(trt_logger, "");
|
||||
plugin_initialized = initLibNvInferPlugins(&trt_logger, "");
|
||||
if (!plugin_initialized) {
|
||||
LOG(ERROR) << "Failed to initialize TensorRT plugins, and conversion may "
|
||||
"fail later.";
|
||||
@ -1209,12 +1210,11 @@ static void InitializeTrtPlugins(nvinfer1::ILogger* trt_logger) {
|
||||
}
|
||||
|
||||
Converter::Converter(nvinfer1::INetworkDefinition* trt_network,
|
||||
TrtPrecisionMode precision_mode, bool use_calibration,
|
||||
nvinfer1::ILogger* trt_logger)
|
||||
TrtPrecisionMode precision_mode, bool use_calibration)
|
||||
: trt_network_(trt_network),
|
||||
precision_mode_(precision_mode),
|
||||
use_calibration_(use_calibration) {
|
||||
InitializeTrtPlugins(trt_logger);
|
||||
InitializeTrtPlugins();
|
||||
this->RegisterOpConverters();
|
||||
}
|
||||
|
||||
@ -5476,9 +5476,8 @@ void Converter::RegisterOpConverters() {
|
||||
Status ConvertGraphDefToEngine(
|
||||
const GraphDef& gdef, TrtPrecisionMode precision_mode, int max_batch_size,
|
||||
size_t max_workspace_size_bytes,
|
||||
const std::vector<PartialTensorShape>& input_shapes,
|
||||
nvinfer1::ILogger* trt_logger, nvinfer1::IGpuAllocator* allocator,
|
||||
TRTInt8Calibrator* calibrator,
|
||||
const std::vector<PartialTensorShape>& input_shapes, Logger* logger,
|
||||
nvinfer1::IGpuAllocator* allocator, TRTInt8Calibrator* calibrator,
|
||||
TrtUniquePtrType<nvinfer1::ICudaEngine>* engine, bool use_calibration,
|
||||
bool* convert_successfully) {
|
||||
engine->reset();
|
||||
@ -5486,7 +5485,7 @@ Status ConvertGraphDefToEngine(
|
||||
|
||||
// Create the builder.
|
||||
TrtUniquePtrType<nvinfer1::IBuilder> builder(
|
||||
nvinfer1::createInferBuilder(*trt_logger));
|
||||
nvinfer1::createInferBuilder(*logger));
|
||||
builder->setMaxBatchSize(max_batch_size);
|
||||
builder->setMaxWorkspaceSize(max_workspace_size_bytes);
|
||||
builder->setGpuAllocator(allocator);
|
||||
@ -5518,8 +5517,7 @@ Status ConvertGraphDefToEngine(
|
||||
TF_RETURN_IF_ERROR(TrtPrecisionModeToName(precision_mode, &mode_str));
|
||||
VLOG(1) << "Starting engine conversion, precision mode: " << mode_str;
|
||||
}
|
||||
Converter converter(trt_network.get(), precision_mode, use_calibration,
|
||||
trt_logger);
|
||||
Converter converter(trt_network.get(), precision_mode, use_calibration);
|
||||
std::vector<Converter::EngineOutputInfo> output_tensors;
|
||||
// Graph nodes are already topologically sorted during construction
|
||||
for (const auto& node_def : gdef.node()) {
|
||||
|
@ -147,9 +147,8 @@ Status ConvertSegmentToGraphDef(
|
||||
Status ConvertGraphDefToEngine(
|
||||
const GraphDef& gdef, TrtPrecisionMode precision_mode, int max_batch_size,
|
||||
size_t max_workspace_size_bytes,
|
||||
const std::vector<PartialTensorShape>& input_shapes,
|
||||
nvinfer1::ILogger* logger, nvinfer1::IGpuAllocator* allocator,
|
||||
TRTInt8Calibrator* calibrator,
|
||||
const std::vector<PartialTensorShape>& input_shapes, Logger* logger,
|
||||
nvinfer1::IGpuAllocator* allocator, TRTInt8Calibrator* calibrator,
|
||||
TrtUniquePtrType<nvinfer1::ICudaEngine>* engine, bool use_calibration,
|
||||
bool* convert_successfully);
|
||||
|
||||
@ -444,8 +443,7 @@ class Converter {
|
||||
};
|
||||
|
||||
Converter(nvinfer1::INetworkDefinition* trt_network,
|
||||
TrtPrecisionMode precision_mode, bool use_calibration,
|
||||
nvinfer1::ILogger* trt_logger);
|
||||
TrtPrecisionMode precision_mode, bool use_calibration);
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// Methods used by the TRT engine builder to build a TRT network from a TF
|
||||
|
@ -1,60 +0,0 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_TENSORRT
|
||||
|
||||
#include "tensorflow/compiler/tf2tensorrt/convert/logger_registry.h"
|
||||
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tensorrt {
|
||||
|
||||
class LoggerRegistryImpl : public LoggerRegistry {
|
||||
Status Register(const string& name, nvinfer1::ILogger* logger) override {
|
||||
mutex_lock lock(mu_);
|
||||
if (!registry_.emplace(name, std::unique_ptr<nvinfer1::ILogger>(logger))
|
||||
.second) {
|
||||
return errors::AlreadyExists("Logger ", name, " already registered");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
nvinfer1::ILogger* LookUp(const string& name) override {
|
||||
mutex_lock lock(mu_);
|
||||
const auto found = registry_.find(name);
|
||||
if (found == registry_.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
return found->second.get();
|
||||
}
|
||||
|
||||
private:
|
||||
mutable mutex mu_;
|
||||
mutable std::unordered_map<string, std::unique_ptr<nvinfer1::ILogger>>
|
||||
registry_ GUARDED_BY(mu_);
|
||||
};
|
||||
|
||||
LoggerRegistry* GetLoggerRegistry() {
|
||||
static LoggerRegistryImpl* registry = new LoggerRegistryImpl;
|
||||
return registry;
|
||||
}
|
||||
|
||||
} // namespace tensorrt
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // GOOGLE_TENSORRT
|
||||
#endif // GOOGLE_CUDA
|
@ -1,57 +0,0 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_LOGGER_REGISTRY_H_
|
||||
#define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_LOGGER_REGISTRY_H_
|
||||
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#include "third_party/tensorrt/NvInfer.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tensorrt {
|
||||
|
||||
class LoggerRegistry {
|
||||
public:
|
||||
virtual Status Register(const string& name, nvinfer1::ILogger* logger) = 0;
|
||||
virtual nvinfer1::ILogger* LookUp(const string& name) = 0;
|
||||
virtual ~LoggerRegistry() {}
|
||||
};
|
||||
|
||||
LoggerRegistry* GetLoggerRegistry();
|
||||
|
||||
class RegisterLogger {
|
||||
public:
|
||||
RegisterLogger(const string& name, nvinfer1::ILogger* logger) {
|
||||
TF_CHECK_OK(GetLoggerRegistry()->Register(name, logger));
|
||||
}
|
||||
};
|
||||
|
||||
#define REGISTER_TENSORRT_LOGGER(name, logger) \
|
||||
REGISTER_TENSORRT_LOGGER_UNIQ_HELPER(__COUNTER__, name, logger)
|
||||
#define REGISTER_TENSORRT_LOGGER_UNIQ_HELPER(ctr, name, logger) \
|
||||
REGISTER_TENSORRT_LOGGER_UNIQ(ctr, name, logger)
|
||||
#define REGISTER_TENSORRT_LOGGER_UNIQ(ctr, name, logger) \
|
||||
static ::tensorflow::tensorrt::RegisterLogger register_trt_logger##ctr \
|
||||
TF_ATTRIBUTE_UNUSED = \
|
||||
::tensorflow::tensorrt::RegisterLogger(name, logger)
|
||||
|
||||
} // namespace tensorrt
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // GOOGLE_CUDA
|
||||
#endif // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_LOGGER_REGISTRY_H_
|
@ -1,34 +0,0 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
namespace {
|
||||
|
||||
class TestLogger : public nvinfer1::ILogger {
|
||||
void log(nvinfer1::ILogger::Severity severity, const char* msg) override {}
|
||||
};
|
||||
|
||||
TestLogger test_logger;
|
||||
|
||||
REGISTER_TENSORRT_LOGGER("test_logger", &test_logger);
|
||||
|
||||
TEST(LoggerRegistryTest, RegistersCorrectly) {
|
||||
auto registered_logger = GetLoggerRegistry()->LookUp("test_logger");
|
||||
EXPECT_THAT(registered_logger, Eq(&test_logger));
|
||||
}
|
||||
|
||||
} // namespace
|
@ -67,9 +67,6 @@ Status TRTOptimizationPass::Init(
|
||||
if (params.count("use_calibration")) {
|
||||
use_calibration_ = params.at("use_calibration").b();
|
||||
}
|
||||
if (params.count("trt_logger")) {
|
||||
trt_logger_name_ = params.at("trt_logger").s();
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -248,7 +245,6 @@ Status TRTOptimizationPass::Optimize(grappler::Cluster* cluster,
|
||||
}
|
||||
cp.input_graph_def = &item.graph;
|
||||
cp.output_names = &nodes_to_preserve;
|
||||
cp.trt_logger_name = trt_logger_name_;
|
||||
cp.max_batch_size = maximum_batch_size_;
|
||||
cp.max_workspace_size_bytes = max_workspace_size_bytes_;
|
||||
cp.output_graph_def = optimized_graph;
|
||||
|
@ -34,7 +34,6 @@ class TRTOptimizationPass : public grappler::CustomGraphOptimizer {
|
||||
public:
|
||||
TRTOptimizationPass(const string& name = "TRTOptimizationPass")
|
||||
: name_(name),
|
||||
trt_logger_name_("DefaultLogger"),
|
||||
minimum_segment_size_(3),
|
||||
precision_mode_(TrtPrecisionMode::FP32),
|
||||
maximum_batch_size_(-1),
|
||||
@ -64,7 +63,6 @@ class TRTOptimizationPass : public grappler::CustomGraphOptimizer {
|
||||
|
||||
private:
|
||||
const string name_;
|
||||
string trt_logger_name_;
|
||||
int minimum_segment_size_;
|
||||
TrtPrecisionMode precision_mode_;
|
||||
int maximum_batch_size_;
|
||||
@ -73,6 +71,7 @@ class TRTOptimizationPass : public grappler::CustomGraphOptimizer {
|
||||
int max_cached_batches_;
|
||||
int64_t max_workspace_size_bytes_;
|
||||
bool use_calibration_;
|
||||
|
||||
};
|
||||
|
||||
} // namespace convert
|
||||
|
@ -17,7 +17,6 @@ limitations under the License.
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_TENSORRT
|
||||
#include "tensorflow/compiler/tf2tensorrt/convert/logger_registry.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -55,15 +54,6 @@ void Logger::log(Severity severity, const char* msg) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// static
|
||||
Logger* Logger::GetLogger() {
|
||||
static Logger* logger = new Logger("DefaultLogger");
|
||||
return logger;
|
||||
}
|
||||
|
||||
REGISTER_TENSORRT_LOGGER("DefaultLogger", Logger::GetLogger());
|
||||
|
||||
} // namespace tensorrt
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -31,8 +31,6 @@ class Logger : public nvinfer1::ILogger {
|
||||
Logger(string name = "DefaultLogger") : name_(name) {}
|
||||
void log(nvinfer1::ILogger::Severity severity, const char* msg) override;
|
||||
|
||||
static Logger* GetLogger();
|
||||
|
||||
private:
|
||||
string name_;
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user