Initialize TRTOptimizationPass members in the constructor, and use a util
function to get the precision mode. PiperOrigin-RevId: 209641428
This commit is contained in:
parent
4f41091f88
commit
d648d7e6e1
@ -14,6 +14,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h"
|
||||
#include "tensorflow/contrib/tensorrt/convert/convert_graph.h"
|
||||
#include "tensorflow/contrib/tensorrt/convert/utils.h"
|
||||
#include "tensorflow/core/grappler/clusters/cluster.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
|
||||
@ -37,7 +38,6 @@ tensorflow::Status TRTOptimizationPass::Init(
|
||||
const tensorflow::RewriterConfig_CustomGraphOptimizer* config) {
|
||||
VLOG(1) << "Called INIT for " << name_ << " with config = " << config;
|
||||
if (config == nullptr) {
|
||||
maximum_workspace_size_ = 2 << 30;
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
const auto params = config->parameter_map();
|
||||
@ -47,7 +47,6 @@ tensorflow::Status TRTOptimizationPass::Init(
|
||||
if (params.count("max_batch_size")) {
|
||||
maximum_batch_size_ = params.at("max_batch_size").i();
|
||||
}
|
||||
is_dynamic_op_ = false;
|
||||
if (params.count("is_dynamic_op")) {
|
||||
is_dynamic_op_ = params.at("is_dynamic_op").b();
|
||||
}
|
||||
@ -58,27 +57,15 @@ tensorflow::Status TRTOptimizationPass::Init(
|
||||
batches_.push_back(i);
|
||||
}
|
||||
}
|
||||
max_cached_batches_ = 1;
|
||||
if (params.count("maximum_cached_engines")) {
|
||||
max_cached_batches_ = params.at("maximum_cached_engines").i();
|
||||
}
|
||||
if (params.count("max_workspace_size_bytes")) {
|
||||
maximum_workspace_size_ = params.at("max_workspace_size_bytes").i();
|
||||
max_workspace_size_bytes_ = params.at("max_workspace_size_bytes").i();
|
||||
}
|
||||
if (params.count("precision_mode")) {
|
||||
string pm = Uppercase(params.at("precision_mode").s());
|
||||
if (pm == "FP32") {
|
||||
precision_mode_ = 0;
|
||||
} else if (pm == "FP16") {
|
||||
precision_mode_ = 1;
|
||||
} else if (pm == "INT8") {
|
||||
precision_mode_ = 2;
|
||||
} else {
|
||||
LOG(ERROR) << "Unknown precision mode '" << pm << "'";
|
||||
return tensorflow::errors::InvalidArgument(
|
||||
"Unknown precision mode argument" + pm +
|
||||
" Valid values are FP32, FP16, INT8");
|
||||
}
|
||||
TF_RETURN_IF_ERROR(GetPrecisionMode(
|
||||
Uppercase(params.at("precision_mode").s()), &precision_mode_));
|
||||
}
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
@ -255,7 +242,7 @@ tensorflow::Status TRTOptimizationPass::Optimize(
|
||||
cp.input_graph_def = &item.graph;
|
||||
cp.output_names = &nodes_to_preserve;
|
||||
cp.max_batch_size = maximum_batch_size_;
|
||||
cp.max_workspace_size_bytes = maximum_workspace_size_;
|
||||
cp.max_workspace_size_bytes = max_workspace_size_bytes_;
|
||||
cp.output_graph_def = optimized_graph;
|
||||
cp.precision_mode = precision_mode_;
|
||||
cp.minimum_segment_size = minimum_segment_size_;
|
||||
|
@ -36,7 +36,9 @@ class TRTOptimizationPass : public tensorflow::grappler::CustomGraphOptimizer {
|
||||
minimum_segment_size_(3),
|
||||
precision_mode_(0),
|
||||
maximum_batch_size_(-1),
|
||||
maximum_workspace_size_(-1) {
|
||||
is_dynamic_op_(false),
|
||||
max_cached_batches_(1),
|
||||
max_workspace_size_bytes_(256LL << 20) {
|
||||
VLOG(1) << "Constructing " << name_;
|
||||
}
|
||||
|
||||
@ -57,14 +59,14 @@ class TRTOptimizationPass : public tensorflow::grappler::CustomGraphOptimizer {
|
||||
const tensorflow::grappler::GrapplerItem& item);
|
||||
|
||||
private:
|
||||
string name_;
|
||||
const string name_;
|
||||
int minimum_segment_size_;
|
||||
int precision_mode_;
|
||||
int maximum_batch_size_;
|
||||
bool is_dynamic_op_;
|
||||
std::vector<int> batches_;
|
||||
int max_cached_batches_;
|
||||
int64_t maximum_workspace_size_;
|
||||
int64_t max_workspace_size_bytes_;
|
||||
};
|
||||
|
||||
} // namespace convert
|
||||
|
Loading…
Reference in New Issue
Block a user