Get rid of a number of gratuitous graph copies in Grappler code.

PiperOrigin-RevId: 249376925
This commit is contained in:
A. Unique TensorFlower 2019-05-21 20:25:48 -07:00 committed by TensorFlower Gardener
parent 74f30571be
commit 731409f8d2
7 changed files with 20 additions and 13 deletions

View File

@ -128,6 +128,11 @@ class Cluster {
const std::vector<string>& fetch,
RunMetadata* metadata) = 0;
// Convenience overload: runs the GrapplerItem by forwarding its graph,
// feed, and fetch fields to the pure-virtual three-argument Run() above;
// execution results are reported through *metadata.
virtual Status Run(const GrapplerItem& item, RunMetadata* metadata) {
return Run(item.graph, item.feed, item.fetch, metadata);
}
protected:
std::unordered_map<string, DeviceProperties> devices_;
const int timeout_s_;

View File

@ -67,14 +67,17 @@ Status VirtualCluster::Run(const GraphDef& graph,
const std::vector<std::pair<string, Tensor>>& feed,
const std::vector<string>& fetch,
RunMetadata* metadata) {
// Initializes an analytical cost estimator to estimate the graph cost. Makes
// sure to use static shape inference to prevent the virtual scheduler from
// calling the Run method on the cluster and creating an infinite loop.
GrapplerItem item;
item.graph = graph;
item.feed = feed;
item.fetch = fetch;
return Run(item, metadata);
}
Status VirtualCluster::Run(const GrapplerItem& item, RunMetadata* metadata) {
// Initializes an analytical cost estimator to estimate the graph cost. Makes
// sure to use static shape inference to prevent the virtual scheduler from
// calling the Run method on the cluster and creating an infinite loop.
if (metadata) {
metadata->clear_step_stats();
metadata->clear_cost_graph();

View File

@ -45,9 +45,10 @@ class VirtualCluster : public Cluster {
Status Provision() override;
Status Initialize(const GrapplerItem& item) override;
Status Run(const GraphDef& item,
Status Run(const GraphDef& graph,
const std::vector<std::pair<string, Tensor>>& feed,
const std::vector<string>& fetch, RunMetadata* metadata) override;
Status Run(const GrapplerItem& item, RunMetadata* metadata) override;
// Returns the DeviceSet associated with this cluster (may be null;
// presumably set elsewhere in the class — confirm against full header).
const DeviceSet* GetDeviceSet() const override { return device_set_; }
private:

View File

@ -142,15 +142,15 @@ AnalyticalCostEstimator::AnalyticalCostEstimator(
}
Status AnalyticalCostEstimator::Initialize(const GrapplerItem& item) {
  // Keep only a pointer to the caller-owned item rather than copying the
  // whole GrapplerItem; the caller must keep `item` alive for the lifetime
  // of this estimator. The stale by-value assignment (`item_ = item;`) that
  // preceded this line was removed: it conflicted with the pointer
  // assignment and is ill-typed now that item_ is a `const GrapplerItem*`.
  item_ = &item;
  return Status::OK();
}
Status AnalyticalCostEstimator::PredictCosts(const GraphDef& optimized_graph,
RunMetadata* run_metadata,
Costs* costs) const {
GrapplerItem item = item_;
item.graph = optimized_graph;
GraphDef graph_copy = optimized_graph;
GrapplerItem item = item_->WithGraph(std::move(graph_copy));
auto status = scheduler_->Init(&item);
if (!status.ok()) {

View File

@ -55,7 +55,6 @@ class AnalyticalCostEstimator : public CostEstimator {
bool use_aggressive_shape_inference);
// Trivial destructor: owned members (the unique_ptr-held estimator,
// node manager, and scheduler) release themselves.
~AnalyticalCostEstimator() override {}
// Initializes the estimator for the specified grappler item.
// This implementation always returns OK.
Status Initialize(const GrapplerItem& item) override;
@ -68,7 +67,7 @@ class AnalyticalCostEstimator : public CostEstimator {
// Exposes the VirtualScheduler owned by this estimator (non-owning view;
// do not outlive the estimator).
const VirtualScheduler* GetScheduler() const { return scheduler_.get(); }
private:
GrapplerItem item_;
const GrapplerItem* item_;
std::unique_ptr<OpLevelCostEstimator> node_estimator_;
std::unique_ptr<ReadyNodeManager> node_manager_;
std::unique_ptr<VirtualScheduler> scheduler_;

View File

@ -37,7 +37,7 @@ Status GraphMemory::InferStatically(
TF_RETURN_IF_ERROR(cluster.Provision());
TF_RETURN_IF_ERROR(cluster.Initialize(item_));
RunMetadata metadata;
Status s = cluster.Run(item_.graph, item_.feed, item_.fetch, &metadata);
Status s = cluster.Run(item_, &metadata);
// The virtual cluster returns the RESOURCE_EXHAUSTED error when it detects
// that the model would run out of memory. We still get the metadata we need
// out of the simulation, so we just ignore this error.
@ -55,8 +55,7 @@ Status GraphMemory::InferDynamically(Cluster* cluster) {
TF_RETURN_IF_ERROR(cluster->Initialize(item_));
RunMetadata metadata;
TF_RETURN_IF_ERROR(
cluster->Run(item_.graph, item_.feed, item_.fetch, &metadata));
TF_RETURN_IF_ERROR(cluster->Run(item_, &metadata));
InferFromTrace(metadata.step_stats());
return Status::OK();
}

View File

@ -69,7 +69,7 @@ class GraphMemory {
void InferFromTrace(const StepStats& timeline);
GrapplerItem item_;
const GrapplerItem& item_;
std::unordered_map<string, int64> worst_case_memory_usage_;
std::unordered_map<string, MemoryUsage> peak_usage_;
const MemoryUsage unknown_usage_;