From d648d7e6e12774d5c60418a899d15b81a387c770 Mon Sep 17 00:00:00 2001 From: Guangda Lai Date: Tue, 21 Aug 2018 12:35:42 -0700 Subject: [PATCH] Initialize TRTOptimizationPass members in the constructor, and use a util function to get the precision mode. PiperOrigin-RevId: 209641428 --- .../tensorrt/convert/trt_optimization_pass.cc | 23 ++++--------------- .../tensorrt/convert/trt_optimization_pass.h | 8 ++++--- 2 files changed, 10 insertions(+), 21 deletions(-) diff --git a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc index f33f2cc4d68..ff4fba58bfc 100644 --- a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc +++ b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc @@ -14,6 +14,7 @@ limitations under the License. #include "tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h" #include "tensorflow/contrib/tensorrt/convert/convert_graph.h" +#include "tensorflow/contrib/tensorrt/convert/utils.h" #include "tensorflow/core/grappler/clusters/cluster.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" @@ -37,7 +38,6 @@ tensorflow::Status TRTOptimizationPass::Init( const tensorflow::RewriterConfig_CustomGraphOptimizer* config) { VLOG(1) << "Called INIT for " << name_ << " with config = " << config; if (config == nullptr) { - maximum_workspace_size_ = 2 << 30; return tensorflow::Status::OK(); } const auto params = config->parameter_map(); @@ -47,7 +47,6 @@ tensorflow::Status TRTOptimizationPass::Init( if (params.count("max_batch_size")) { maximum_batch_size_ = params.at("max_batch_size").i(); } - is_dynamic_op_ = false; if (params.count("is_dynamic_op")) { is_dynamic_op_ = params.at("is_dynamic_op").b(); } @@ -58,27 +57,15 @@ tensorflow::Status TRTOptimizationPass::Init( batches_.push_back(i); } } - max_cached_batches_ = 1; if (params.count("maximum_cached_engines")) { max_cached_batches_ = params.at("maximum_cached_engines").i(); } if (params.count("max_workspace_size_bytes")) { - maximum_workspace_size_ = params.at("max_workspace_size_bytes").i(); + max_workspace_size_bytes_ = params.at("max_workspace_size_bytes").i(); } if (params.count("precision_mode")) { - string pm = Uppercase(params.at("precision_mode").s()); - if (pm == "FP32") { - precision_mode_ = 0; - } else if (pm == "FP16") { - precision_mode_ = 1; - } else if (pm == "INT8") { - precision_mode_ = 2; - } else { - LOG(ERROR) << "Unknown precision mode '" << pm << "'"; - return tensorflow::errors::InvalidArgument( - "Unknown precision mode argument" + pm + - " Valid values are FP32, FP16, INT8"); - } + TF_RETURN_IF_ERROR(GetPrecisionMode( + Uppercase(params.at("precision_mode").s()), &precision_mode_)); } return tensorflow::Status::OK(); } @@ -255,7 +242,7 @@ tensorflow::Status TRTOptimizationPass::Optimize( cp.input_graph_def = &item.graph; cp.output_names = &nodes_to_preserve; cp.max_batch_size = maximum_batch_size_; - cp.max_workspace_size_bytes = maximum_workspace_size_; + cp.max_workspace_size_bytes = max_workspace_size_bytes_; cp.output_graph_def = optimized_graph; cp.precision_mode = precision_mode_; cp.minimum_segment_size = minimum_segment_size_; diff --git a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h index 463ed3883e4..71b51d13681 100644 --- a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h +++ b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h @@ -36,7 +36,9 @@ class TRTOptimizationPass : public tensorflow::grappler::CustomGraphOptimizer { minimum_segment_size_(3), precision_mode_(0), maximum_batch_size_(-1), - maximum_workspace_size_(-1) { + is_dynamic_op_(false), + max_cached_batches_(1), + max_workspace_size_bytes_(256LL << 20) { VLOG(1) << "Constructing " << name_; } @@ -57,14 +59,14 @@ class TRTOptimizationPass : public tensorflow::grappler::CustomGraphOptimizer { const tensorflow::grappler::GrapplerItem& item); private: - string name_; + const string name_; int minimum_segment_size_; int precision_mode_; int maximum_batch_size_; bool is_dynamic_op_; std::vector batches_; int max_cached_batches_; - int64_t maximum_workspace_size_; + int64_t max_workspace_size_bytes_; }; } // namespace convert