Automated rollback of commit d68a02f6d1

PiperOrigin-RevId: 272811870
This commit is contained in:
Brian Zhao 2019-10-03 22:45:45 -07:00 committed by TensorFlower Gardener
parent 9b5f9aabd8
commit 4de76f66eb
12 changed files with 17 additions and 205 deletions

View File

@@ -211,7 +211,6 @@ tf_cuda_library(
hdrs = ["utils/trt_logger.h"],
visibility = ["//visibility:public"],
deps = [
":logger_registry",
"//tensorflow/core:lib_proto_parsing",
] + if_tensorrt([":tensorrt_lib"]),
)
@@ -301,19 +300,6 @@ tf_cc_test(
],
)
tf_cuda_library(
name = "logger_registry",
srcs = ["convert/logger_registry.cc"],
hdrs = [
"convert/logger_registry.h",
],
copts = tf_copts(),
deps = [
"@com_google_absl//absl/strings",
"//tensorflow/core:lib",
] + if_tensorrt([":tensorrt_lib"]),
)
# Library for the node-level conversion portion of TensorRT operation creation
tf_cuda_library(
name = "trt_conversion",
@@ -328,7 +314,6 @@ tf_cuda_library(
"convert/trt_optimization_pass.h",
],
deps = [
":logger_registry",
":segment",
":trt_allocator",
":trt_plugins",

View File

@@ -26,7 +26,6 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
#include "tensorflow/compiler/tf2tensorrt/convert/logger_registry.h"
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
#include "tensorflow/compiler/tf2tensorrt/segment/segment.h"
#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
@@ -418,15 +417,16 @@ Status CreateTRTNode(const ConversionParams& params,
// Build the engine and get its serialized representation.
string segment_string;
if (info.engine_type == EngineInfo::EngineType::TRTStatic) {
auto trt_logger = GetLoggerRegistry()->LookUp(params.trt_logger_name);
// Create static engine for fp32/fp16 mode.
Logger trt_logger;
TrtUniquePtrType<nvinfer1::ICudaEngine> engine;
// TODO(sami): What happens if 1st dim is not batch?
TF_RETURN_IF_ERROR(ConvertGraphDefToEngine(
info.segment_graph_def,
calibrate_int8 ? TrtPrecisionMode::FP32 : info.precision_mode,
max_batch_size, info.max_workspace_size_bytes, input_shapes, trt_logger,
alloc, /*calibrator=*/nullptr, &engine, info.use_calibration,
max_batch_size, info.max_workspace_size_bytes, input_shapes,
&trt_logger, alloc, /*calibrator=*/nullptr, &engine,
info.use_calibration,
/*convert_successfully=*/nullptr));
TrtUniquePtrType<nvinfer1::IHostMemory> engine_data(engine->serialize());
segment_string = string(static_cast<const char*>(engine_data->data()),

View File

@@ -35,7 +35,6 @@ namespace convert {
struct ConversionParams {
const GraphDef* input_graph_def = nullptr;
const std::vector<string>* output_names = nullptr;
string trt_logger_name;
size_t max_batch_size = 1;
size_t max_workspace_size_bytes = 1 << 30;
GraphDef* output_graph_def = nullptr;

View File

@@ -1176,13 +1176,14 @@ Status TrtNodeValidator::ConvertConstToWeights(
return status;
}
static void InitializeTrtPlugins(nvinfer1::ILogger* trt_logger) {
static void InitializeTrtPlugins() {
static mutex plugin_mutex(LINKER_INITIALIZED);
static bool plugin_initialized = false;
static Logger trt_logger;
mutex_lock lock(plugin_mutex);
if (plugin_initialized) return;
plugin_initialized = initLibNvInferPlugins(trt_logger, "");
plugin_initialized = initLibNvInferPlugins(&trt_logger, "");
if (!plugin_initialized) {
LOG(ERROR) << "Failed to initialize TensorRT plugins, and conversion may "
"fail later.";
@@ -1209,12 +1210,11 @@ static void InitializeTrtPlugins(nvinfer1::ILogger* trt_logger) {
}
Converter::Converter(nvinfer1::INetworkDefinition* trt_network,
TrtPrecisionMode precision_mode, bool use_calibration,
nvinfer1::ILogger* trt_logger)
TrtPrecisionMode precision_mode, bool use_calibration)
: trt_network_(trt_network),
precision_mode_(precision_mode),
use_calibration_(use_calibration) {
InitializeTrtPlugins(trt_logger);
InitializeTrtPlugins();
this->RegisterOpConverters();
}
@@ -5476,9 +5476,8 @@ void Converter::RegisterOpConverters() {
Status ConvertGraphDefToEngine(
const GraphDef& gdef, TrtPrecisionMode precision_mode, int max_batch_size,
size_t max_workspace_size_bytes,
const std::vector<PartialTensorShape>& input_shapes,
nvinfer1::ILogger* trt_logger, nvinfer1::IGpuAllocator* allocator,
TRTInt8Calibrator* calibrator,
const std::vector<PartialTensorShape>& input_shapes, Logger* logger,
nvinfer1::IGpuAllocator* allocator, TRTInt8Calibrator* calibrator,
TrtUniquePtrType<nvinfer1::ICudaEngine>* engine, bool use_calibration,
bool* convert_successfully) {
engine->reset();
@@ -5486,7 +5485,7 @@ Status ConvertGraphDefToEngine(
// Create the builder.
TrtUniquePtrType<nvinfer1::IBuilder> builder(
nvinfer1::createInferBuilder(*trt_logger));
nvinfer1::createInferBuilder(*logger));
builder->setMaxBatchSize(max_batch_size);
builder->setMaxWorkspaceSize(max_workspace_size_bytes);
builder->setGpuAllocator(allocator);
@@ -5518,8 +5517,7 @@ Status ConvertGraphDefToEngine(
TF_RETURN_IF_ERROR(TrtPrecisionModeToName(precision_mode, &mode_str));
VLOG(1) << "Starting engine conversion, precision mode: " << mode_str;
}
Converter converter(trt_network.get(), precision_mode, use_calibration,
trt_logger);
Converter converter(trt_network.get(), precision_mode, use_calibration);
std::vector<Converter::EngineOutputInfo> output_tensors;
// Graph nodes are already topologically sorted during construction
for (const auto& node_def : gdef.node()) {

View File

@@ -147,9 +147,8 @@ Status ConvertSegmentToGraphDef(
Status ConvertGraphDefToEngine(
const GraphDef& gdef, TrtPrecisionMode precision_mode, int max_batch_size,
size_t max_workspace_size_bytes,
const std::vector<PartialTensorShape>& input_shapes,
nvinfer1::ILogger* logger, nvinfer1::IGpuAllocator* allocator,
TRTInt8Calibrator* calibrator,
const std::vector<PartialTensorShape>& input_shapes, Logger* logger,
nvinfer1::IGpuAllocator* allocator, TRTInt8Calibrator* calibrator,
TrtUniquePtrType<nvinfer1::ICudaEngine>* engine, bool use_calibration,
bool* convert_successfully);
@@ -444,8 +443,7 @@ class Converter {
};
Converter(nvinfer1::INetworkDefinition* trt_network,
TrtPrecisionMode precision_mode, bool use_calibration,
nvinfer1::ILogger* trt_logger);
TrtPrecisionMode precision_mode, bool use_calibration);
//////////////////////////////////////////////////////////////////////////////
// Methods used by the TRT engine builder to build a TRT network from a TF

View File

@@ -1,60 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
#include "tensorflow/compiler/tf2tensorrt/convert/logger_registry.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/mutex.h"
namespace tensorflow {
namespace tensorrt {
class LoggerRegistryImpl : public LoggerRegistry {
Status Register(const string& name, nvinfer1::ILogger* logger) override {
mutex_lock lock(mu_);
if (!registry_.emplace(name, std::unique_ptr<nvinfer1::ILogger>(logger))
.second) {
return errors::AlreadyExists("Logger ", name, " already registered");
}
return Status::OK();
}
nvinfer1::ILogger* LookUp(const string& name) override {
mutex_lock lock(mu_);
const auto found = registry_.find(name);
if (found == registry_.end()) {
return nullptr;
}
return found->second.get();
}
private:
mutable mutex mu_;
mutable std::unordered_map<string, std::unique_ptr<nvinfer1::ILogger>>
registry_ GUARDED_BY(mu_);
};
LoggerRegistry* GetLoggerRegistry() {
static LoggerRegistryImpl* registry = new LoggerRegistryImpl;
return registry;
}
} // namespace tensorrt
} // namespace tensorflow
#endif // GOOGLE_TENSORRT
#endif // GOOGLE_CUDA

View File

@@ -1,57 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_LOGGER_REGISTRY_H_
#define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_LOGGER_REGISTRY_H_
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#if GOOGLE_CUDA
#include "third_party/tensorrt/NvInfer.h"
namespace tensorflow {
namespace tensorrt {
class LoggerRegistry {
public:
virtual Status Register(const string& name, nvinfer1::ILogger* logger) = 0;
virtual nvinfer1::ILogger* LookUp(const string& name) = 0;
virtual ~LoggerRegistry() {}
};
LoggerRegistry* GetLoggerRegistry();
class RegisterLogger {
public:
RegisterLogger(const string& name, nvinfer1::ILogger* logger) {
TF_CHECK_OK(GetLoggerRegistry()->Register(name, logger));
}
};
#define REGISTER_TENSORRT_LOGGER(name, logger) \
REGISTER_TENSORRT_LOGGER_UNIQ_HELPER(__COUNTER__, name, logger)
#define REGISTER_TENSORRT_LOGGER_UNIQ_HELPER(ctr, name, logger) \
REGISTER_TENSORRT_LOGGER_UNIQ(ctr, name, logger)
#define REGISTER_TENSORRT_LOGGER_UNIQ(ctr, name, logger) \
static ::tensorflow::tensorrt::RegisterLogger register_trt_logger##ctr \
TF_ATTRIBUTE_UNUSED = \
::tensorflow::tensorrt::RegisterLogger(name, logger)
} // namespace tensorrt
} // namespace tensorflow
#endif // GOOGLE_CUDA
#endif // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_LOGGER_REGISTRY_H_

View File

@@ -1,34 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <gmock/gmock.h>
#include <gtest/gtest.h>
namespace {
class TestLogger : public nvinfer1::ILogger {
void log(nvinfer1::ILogger::Severity severity, const char* msg) override {}
};
TestLogger test_logger;
REGISTER_TENSORRT_LOGGER("test_logger", &test_logger);
TEST(LoggerRegistryTest, RegistersCorrectly) {
auto registered_logger = GetLoggerRegistry()->LookUp("test_logger");
EXPECT_THAT(registered_logger, Eq(&test_logger));
}
} // namespace

View File

@@ -67,9 +67,6 @@ Status TRTOptimizationPass::Init(
if (params.count("use_calibration")) {
use_calibration_ = params.at("use_calibration").b();
}
if (params.count("trt_logger")) {
trt_logger_name_ = params.at("trt_logger").s();
}
return Status::OK();
}
@@ -248,7 +245,6 @@ Status TRTOptimizationPass::Optimize(grappler::Cluster* cluster,
}
cp.input_graph_def = &item.graph;
cp.output_names = &nodes_to_preserve;
cp.trt_logger_name = trt_logger_name_;
cp.max_batch_size = maximum_batch_size_;
cp.max_workspace_size_bytes = max_workspace_size_bytes_;
cp.output_graph_def = optimized_graph;

View File

@@ -34,7 +34,6 @@ class TRTOptimizationPass : public grappler::CustomGraphOptimizer {
public:
TRTOptimizationPass(const string& name = "TRTOptimizationPass")
: name_(name),
trt_logger_name_("DefaultLogger"),
minimum_segment_size_(3),
precision_mode_(TrtPrecisionMode::FP32),
maximum_batch_size_(-1),
@@ -64,7 +63,6 @@ class TRTOptimizationPass : public grappler::CustomGraphOptimizer {
private:
const string name_;
string trt_logger_name_;
int minimum_segment_size_;
TrtPrecisionMode precision_mode_;
int maximum_batch_size_;
@@ -73,6 +71,7 @@ class TRTOptimizationPass : public grappler::CustomGraphOptimizer {
int max_cached_batches_;
int64_t max_workspace_size_bytes_;
bool use_calibration_;
};
} // namespace convert

View File

@@ -17,7 +17,6 @@ limitations under the License.
#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
#include "tensorflow/compiler/tf2tensorrt/convert/logger_registry.h"
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
@@ -55,15 +54,6 @@ void Logger::log(Severity severity, const char* msg) {
}
}
}
// static
Logger* Logger::GetLogger() {
static Logger* logger = new Logger("DefaultLogger");
return logger;
}
REGISTER_TENSORRT_LOGGER("DefaultLogger", Logger::GetLogger());
} // namespace tensorrt
} // namespace tensorflow

View File

@@ -31,8 +31,6 @@ class Logger : public nvinfer1::ILogger {
Logger(string name = "DefaultLogger") : name_(name) {}
void log(nvinfer1::ILogger::Severity severity, const char* msg) override;
static Logger* GetLogger();
private:
string name_;
};