Initialize TRTOptimizationPass members in the constructor, and use a util

function to get the precision mode.

PiperOrigin-RevId: 209641428
This commit is contained in:
Guangda Lai 2018-08-21 12:35:42 -07:00 committed by TensorFlower Gardener
parent 4f41091f88
commit d648d7e6e1
2 changed files with 10 additions and 21 deletions

View File

@ -14,6 +14,7 @@ limitations under the License.
#include "tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h" #include "tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h"
#include "tensorflow/contrib/tensorrt/convert/convert_graph.h" #include "tensorflow/contrib/tensorrt/convert/convert_graph.h"
#include "tensorflow/contrib/tensorrt/convert/utils.h"
#include "tensorflow/core/grappler/clusters/cluster.h" #include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
@ -37,7 +38,6 @@ tensorflow::Status TRTOptimizationPass::Init(
const tensorflow::RewriterConfig_CustomGraphOptimizer* config) { const tensorflow::RewriterConfig_CustomGraphOptimizer* config) {
VLOG(1) << "Called INIT for " << name_ << " with config = " << config; VLOG(1) << "Called INIT for " << name_ << " with config = " << config;
if (config == nullptr) { if (config == nullptr) {
maximum_workspace_size_ = 2 << 30;
return tensorflow::Status::OK(); return tensorflow::Status::OK();
} }
const auto params = config->parameter_map(); const auto params = config->parameter_map();
@ -47,7 +47,6 @@ tensorflow::Status TRTOptimizationPass::Init(
if (params.count("max_batch_size")) { if (params.count("max_batch_size")) {
maximum_batch_size_ = params.at("max_batch_size").i(); maximum_batch_size_ = params.at("max_batch_size").i();
} }
is_dynamic_op_ = false;
if (params.count("is_dynamic_op")) { if (params.count("is_dynamic_op")) {
is_dynamic_op_ = params.at("is_dynamic_op").b(); is_dynamic_op_ = params.at("is_dynamic_op").b();
} }
@ -58,27 +57,15 @@ tensorflow::Status TRTOptimizationPass::Init(
batches_.push_back(i); batches_.push_back(i);
} }
} }
max_cached_batches_ = 1;
if (params.count("maximum_cached_engines")) { if (params.count("maximum_cached_engines")) {
max_cached_batches_ = params.at("maximum_cached_engines").i(); max_cached_batches_ = params.at("maximum_cached_engines").i();
} }
if (params.count("max_workspace_size_bytes")) { if (params.count("max_workspace_size_bytes")) {
maximum_workspace_size_ = params.at("max_workspace_size_bytes").i(); max_workspace_size_bytes_ = params.at("max_workspace_size_bytes").i();
} }
if (params.count("precision_mode")) { if (params.count("precision_mode")) {
string pm = Uppercase(params.at("precision_mode").s()); TF_RETURN_IF_ERROR(GetPrecisionMode(
if (pm == "FP32") { Uppercase(params.at("precision_mode").s()), &precision_mode_));
precision_mode_ = 0;
} else if (pm == "FP16") {
precision_mode_ = 1;
} else if (pm == "INT8") {
precision_mode_ = 2;
} else {
LOG(ERROR) << "Unknown precision mode '" << pm << "'";
return tensorflow::errors::InvalidArgument(
"Unknown precision mode argument" + pm +
" Valid values are FP32, FP16, INT8");
}
} }
return tensorflow::Status::OK(); return tensorflow::Status::OK();
} }
@ -255,7 +242,7 @@ tensorflow::Status TRTOptimizationPass::Optimize(
cp.input_graph_def = &item.graph; cp.input_graph_def = &item.graph;
cp.output_names = &nodes_to_preserve; cp.output_names = &nodes_to_preserve;
cp.max_batch_size = maximum_batch_size_; cp.max_batch_size = maximum_batch_size_;
cp.max_workspace_size_bytes = maximum_workspace_size_; cp.max_workspace_size_bytes = max_workspace_size_bytes_;
cp.output_graph_def = optimized_graph; cp.output_graph_def = optimized_graph;
cp.precision_mode = precision_mode_; cp.precision_mode = precision_mode_;
cp.minimum_segment_size = minimum_segment_size_; cp.minimum_segment_size = minimum_segment_size_;

View File

@ -36,7 +36,9 @@ class TRTOptimizationPass : public tensorflow::grappler::CustomGraphOptimizer {
minimum_segment_size_(3), minimum_segment_size_(3),
precision_mode_(0), precision_mode_(0),
maximum_batch_size_(-1), maximum_batch_size_(-1),
maximum_workspace_size_(-1) { is_dynamic_op_(false),
max_cached_batches_(1),
max_workspace_size_bytes_(256LL << 20) {
VLOG(1) << "Constructing " << name_; VLOG(1) << "Constructing " << name_;
} }
@ -57,14 +59,14 @@ class TRTOptimizationPass : public tensorflow::grappler::CustomGraphOptimizer {
const tensorflow::grappler::GrapplerItem& item); const tensorflow::grappler::GrapplerItem& item);
private: private:
string name_; const string name_;
int minimum_segment_size_; int minimum_segment_size_;
int precision_mode_; int precision_mode_;
int maximum_batch_size_; int maximum_batch_size_;
bool is_dynamic_op_; bool is_dynamic_op_;
std::vector<int> batches_; std::vector<int> batches_;
int max_cached_batches_; int max_cached_batches_;
int64_t maximum_workspace_size_; int64_t max_workspace_size_bytes_;
}; };
} // namespace convert } // namespace convert