Initialize TRTOptimizationPass members in the constructor, and use a util function to get the precision mode.

PiperOrigin-RevId: 209641428
This commit is contained in:
parent
4f41091f88
commit
d648d7e6e1
@ -14,6 +14,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h"
|
#include "tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h"
|
||||||
#include "tensorflow/contrib/tensorrt/convert/convert_graph.h"
|
#include "tensorflow/contrib/tensorrt/convert/convert_graph.h"
|
||||||
|
#include "tensorflow/contrib/tensorrt/convert/utils.h"
|
||||||
#include "tensorflow/core/grappler/clusters/cluster.h"
|
#include "tensorflow/core/grappler/clusters/cluster.h"
|
||||||
#include "tensorflow/core/grappler/grappler_item.h"
|
#include "tensorflow/core/grappler/grappler_item.h"
|
||||||
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
|
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
|
||||||
@ -37,7 +38,6 @@ tensorflow::Status TRTOptimizationPass::Init(
|
|||||||
const tensorflow::RewriterConfig_CustomGraphOptimizer* config) {
|
const tensorflow::RewriterConfig_CustomGraphOptimizer* config) {
|
||||||
VLOG(1) << "Called INIT for " << name_ << " with config = " << config;
|
VLOG(1) << "Called INIT for " << name_ << " with config = " << config;
|
||||||
if (config == nullptr) {
|
if (config == nullptr) {
|
||||||
maximum_workspace_size_ = 2 << 30;
|
|
||||||
return tensorflow::Status::OK();
|
return tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
const auto params = config->parameter_map();
|
const auto params = config->parameter_map();
|
||||||
@ -47,7 +47,6 @@ tensorflow::Status TRTOptimizationPass::Init(
|
|||||||
if (params.count("max_batch_size")) {
|
if (params.count("max_batch_size")) {
|
||||||
maximum_batch_size_ = params.at("max_batch_size").i();
|
maximum_batch_size_ = params.at("max_batch_size").i();
|
||||||
}
|
}
|
||||||
is_dynamic_op_ = false;
|
|
||||||
if (params.count("is_dynamic_op")) {
|
if (params.count("is_dynamic_op")) {
|
||||||
is_dynamic_op_ = params.at("is_dynamic_op").b();
|
is_dynamic_op_ = params.at("is_dynamic_op").b();
|
||||||
}
|
}
|
||||||
@ -58,27 +57,15 @@ tensorflow::Status TRTOptimizationPass::Init(
|
|||||||
batches_.push_back(i);
|
batches_.push_back(i);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
max_cached_batches_ = 1;
|
|
||||||
if (params.count("maximum_cached_engines")) {
|
if (params.count("maximum_cached_engines")) {
|
||||||
max_cached_batches_ = params.at("maximum_cached_engines").i();
|
max_cached_batches_ = params.at("maximum_cached_engines").i();
|
||||||
}
|
}
|
||||||
if (params.count("max_workspace_size_bytes")) {
|
if (params.count("max_workspace_size_bytes")) {
|
||||||
maximum_workspace_size_ = params.at("max_workspace_size_bytes").i();
|
max_workspace_size_bytes_ = params.at("max_workspace_size_bytes").i();
|
||||||
}
|
}
|
||||||
if (params.count("precision_mode")) {
|
if (params.count("precision_mode")) {
|
||||||
string pm = Uppercase(params.at("precision_mode").s());
|
TF_RETURN_IF_ERROR(GetPrecisionMode(
|
||||||
if (pm == "FP32") {
|
Uppercase(params.at("precision_mode").s()), &precision_mode_));
|
||||||
precision_mode_ = 0;
|
|
||||||
} else if (pm == "FP16") {
|
|
||||||
precision_mode_ = 1;
|
|
||||||
} else if (pm == "INT8") {
|
|
||||||
precision_mode_ = 2;
|
|
||||||
} else {
|
|
||||||
LOG(ERROR) << "Unknown precision mode '" << pm << "'";
|
|
||||||
return tensorflow::errors::InvalidArgument(
|
|
||||||
"Unknown precision mode argument" + pm +
|
|
||||||
" Valid values are FP32, FP16, INT8");
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return tensorflow::Status::OK();
|
return tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
@ -255,7 +242,7 @@ tensorflow::Status TRTOptimizationPass::Optimize(
|
|||||||
cp.input_graph_def = &item.graph;
|
cp.input_graph_def = &item.graph;
|
||||||
cp.output_names = &nodes_to_preserve;
|
cp.output_names = &nodes_to_preserve;
|
||||||
cp.max_batch_size = maximum_batch_size_;
|
cp.max_batch_size = maximum_batch_size_;
|
||||||
cp.max_workspace_size_bytes = maximum_workspace_size_;
|
cp.max_workspace_size_bytes = max_workspace_size_bytes_;
|
||||||
cp.output_graph_def = optimized_graph;
|
cp.output_graph_def = optimized_graph;
|
||||||
cp.precision_mode = precision_mode_;
|
cp.precision_mode = precision_mode_;
|
||||||
cp.minimum_segment_size = minimum_segment_size_;
|
cp.minimum_segment_size = minimum_segment_size_;
|
||||||
|
@ -36,7 +36,9 @@ class TRTOptimizationPass : public tensorflow::grappler::CustomGraphOptimizer {
|
|||||||
minimum_segment_size_(3),
|
minimum_segment_size_(3),
|
||||||
precision_mode_(0),
|
precision_mode_(0),
|
||||||
maximum_batch_size_(-1),
|
maximum_batch_size_(-1),
|
||||||
maximum_workspace_size_(-1) {
|
is_dynamic_op_(false),
|
||||||
|
max_cached_batches_(1),
|
||||||
|
max_workspace_size_bytes_(256LL << 20) {
|
||||||
VLOG(1) << "Constructing " << name_;
|
VLOG(1) << "Constructing " << name_;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -57,14 +59,14 @@ class TRTOptimizationPass : public tensorflow::grappler::CustomGraphOptimizer {
|
|||||||
const tensorflow::grappler::GrapplerItem& item);
|
const tensorflow::grappler::GrapplerItem& item);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
string name_;
|
const string name_;
|
||||||
int minimum_segment_size_;
|
int minimum_segment_size_;
|
||||||
int precision_mode_;
|
int precision_mode_;
|
||||||
int maximum_batch_size_;
|
int maximum_batch_size_;
|
||||||
bool is_dynamic_op_;
|
bool is_dynamic_op_;
|
||||||
std::vector<int> batches_;
|
std::vector<int> batches_;
|
||||||
int max_cached_batches_;
|
int max_cached_batches_;
|
||||||
int64_t maximum_workspace_size_;
|
int64_t max_workspace_size_bytes_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace convert
|
} // namespace convert
|
||||||
|
Loading…
Reference in New Issue
Block a user