Avoid copying a GraphDef and GrapplerItem in AnalyticalCostEstimator::PredictCosts().

PredictCosts() is most often called from VirtualCluster::Run(), and that call has two special properties:

1. The `optimized_graph` argument is the same object as `this->item_->graph` (the two have the same address).
2. The `costs` output parameter is ignored.

Optimizing for this case avoids unnecessary GraphDef copies, which can be expensive at startup time.

PiperOrigin-RevId: 261120012
This commit is contained in:
Derek Murray 2019-08-01 07:45:21 -07:00 committed by TensorFlower Gardener
parent 48af54a586
commit 3f89bc5175
2 changed files with 22 additions and 7 deletions
tensorflow/core/grappler

View File

@ -85,9 +85,8 @@ Status VirtualCluster::Run(const GrapplerItem& item, RunMetadata* metadata) {
}
TF_RETURN_IF_ERROR(estimator_->Initialize(item));
Costs ignored_costs;
TF_RETURN_IF_ERROR(
estimator_->PredictCosts(item.graph, metadata, &ignored_costs));
estimator_->PredictCosts(item.graph, metadata, /*cost=*/nullptr));
const std::unordered_map<string, DeviceProperties>& device = GetDevices();
std::unordered_map<string, int64> peak_mem_usage =

View File

@ -149,12 +149,24 @@ Status AnalyticalCostEstimator::Initialize(const GrapplerItem& item) {
Status AnalyticalCostEstimator::PredictCosts(const GraphDef& optimized_graph,
RunMetadata* run_metadata,
Costs* costs) const {
GraphDef graph_copy = optimized_graph;
GrapplerItem item = item_->WithGraph(std::move(graph_copy));
std::unique_ptr<GrapplerItem> item_storage;
const GrapplerItem* item;
// Many callers to PredictCosts() pass the same optimized_graph as was used
// to initialize the estimator.
if (&optimized_graph == &item_->graph) {
item = item_;
} else {
GraphDef graph_copy = optimized_graph;
item_storage = absl::make_unique<GrapplerItem>(
item_->WithGraph(std::move(graph_copy)));
item = item_storage.get();
}
auto status = scheduler_->Init(&item);
auto status = scheduler_->Init(item);
if (!status.ok()) {
costs->execution_time = Costs::Duration::max();
if (costs) {
costs->execution_time = Costs::Duration::max();
}
return status;
}
@ -203,7 +215,11 @@ Status AnalyticalCostEstimator::PredictCosts(const GraphDef& optimized_graph,
}
// run_metadata gets step_stats and partition_graphs from Summary.
*costs = scheduler_->Summary(run_metadata);
if (costs) {
*costs = scheduler_->Summary(run_metadata);
} else if (run_metadata) {
scheduler_->GenerateRunMetadata(run_metadata);
}
if (VLOG_IS_ON(1)) {
bool verbose = VLOG_IS_ON(2);