Get rid of a number of gratuitous graph copies in Grappler code.
PiperOrigin-RevId: 249376925
This commit is contained in:
parent
74f30571be
commit
731409f8d2
@ -128,6 +128,11 @@ class Cluster {
|
||||
const std::vector<string>& fetch,
|
||||
RunMetadata* metadata) = 0;
|
||||
|
||||
// Run the specified GrapplerItem and return the corresponding metadata.
|
||||
virtual Status Run(const GrapplerItem& item, RunMetadata* metadata) {
|
||||
return Run(item.graph, item.feed, item.fetch, metadata);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::unordered_map<string, DeviceProperties> devices_;
|
||||
const int timeout_s_;
|
||||
|
@ -67,14 +67,17 @@ Status VirtualCluster::Run(const GraphDef& graph,
|
||||
const std::vector<std::pair<string, Tensor>>& feed,
|
||||
const std::vector<string>& fetch,
|
||||
RunMetadata* metadata) {
|
||||
// Packs the graph, feeds and fetches into a GrapplerItem and delegates to
|
||||
// the GrapplerItem overload of Run() below, so both entry points share one
|
||||
// implementation.
|
||||
GrapplerItem item;
|
||||
item.graph = graph;
|
||||
item.feed = feed;
|
||||
item.fetch = fetch;
|
||||
return Run(item, metadata);
|
||||
}
|
||||
|
||||
Status VirtualCluster::Run(const GrapplerItem& item, RunMetadata* metadata) {
|
||||
// Initializes an analytical cost estimator to estimate the graph cost. Makes
|
||||
// sure to use static shape inference to prevent the virtual scheduler from
|
||||
// calling the Run method on the cluster and creating an infinite loop.
|
||||
if (metadata) {
|
||||
metadata->clear_step_stats();
|
||||
metadata->clear_cost_graph();
|
||||
|
@ -45,9 +45,10 @@ class VirtualCluster : public Cluster {
|
||||
|
||||
Status Provision() override;
|
||||
Status Initialize(const GrapplerItem& item) override;
|
||||
Status Run(const GraphDef& item,
|
||||
Status Run(const GraphDef& graph,
|
||||
const std::vector<std::pair<string, Tensor>>& feed,
|
||||
const std::vector<string>& fetch, RunMetadata* metadata) override;
|
||||
Status Run(const GrapplerItem& item, RunMetadata* metadata) override;
|
||||
// Accessor for device_set_; the stored pointer is returned as-is.
const DeviceSet* GetDeviceSet() const override {
  return device_set_;
}
|
||||
|
||||
private:
|
||||
|
@ -142,15 +142,15 @@ AnalyticalCostEstimator::AnalyticalCostEstimator(
|
||||
}
|
||||
|
||||
Status AnalyticalCostEstimator::Initialize(const GrapplerItem& item) {
|
||||
item_ = item;
|
||||
item_ = &item;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status AnalyticalCostEstimator::PredictCosts(const GraphDef& optimized_graph,
|
||||
RunMetadata* run_metadata,
|
||||
Costs* costs) const {
|
||||
GrapplerItem item = item_;
|
||||
item.graph = optimized_graph;
|
||||
GraphDef graph_copy = optimized_graph;
|
||||
GrapplerItem item = item_->WithGraph(std::move(graph_copy));
|
||||
|
||||
auto status = scheduler_->Init(&item);
|
||||
if (!status.ok()) {
|
||||
|
@ -55,7 +55,6 @@ class AnalyticalCostEstimator : public CostEstimator {
|
||||
bool use_aggressive_shape_inference);
|
||||
~AnalyticalCostEstimator() override {}
|
||||
|
||||
// Initializes the estimator for the specified grappler item.
|
||||
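// This implementation always returns OK.
// NOTE(review): Initialize() appears to store only the address of `item`
// (item_ = &item), so `item` must stay alive for as long as the estimator
// uses it -- confirm against callers.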
|
||||
Status Initialize(const GrapplerItem& item) override;
|
||||
|
||||
@ -68,7 +67,7 @@ class AnalyticalCostEstimator : public CostEstimator {
|
||||
// Exposes the scheduler held by scheduler_ as a raw, read-only pointer.
const VirtualScheduler* GetScheduler() const {
  return scheduler_.get();
}
|
||||
|
||||
private:
|
||||
GrapplerItem item_;
|
||||
const GrapplerItem* item_;
|
||||
std::unique_ptr<OpLevelCostEstimator> node_estimator_;
|
||||
std::unique_ptr<ReadyNodeManager> node_manager_;
|
||||
std::unique_ptr<VirtualScheduler> scheduler_;
|
||||
|
@ -37,7 +37,7 @@ Status GraphMemory::InferStatically(
|
||||
TF_RETURN_IF_ERROR(cluster.Provision());
|
||||
TF_RETURN_IF_ERROR(cluster.Initialize(item_));
|
||||
RunMetadata metadata;
|
||||
Status s = cluster.Run(item_.graph, item_.feed, item_.fetch, &metadata);
|
||||
Status s = cluster.Run(item_, &metadata);
|
||||
// The virtual cluster returns the RESOURCE_EXHAUSTED error when it detects
|
||||
// that the model would run out of memory. We still get the metadata we need
|
||||
// out of the simulation, so we just ignore this error.
|
||||
@ -55,8 +55,7 @@ Status GraphMemory::InferDynamically(Cluster* cluster) {
|
||||
|
||||
TF_RETURN_IF_ERROR(cluster->Initialize(item_));
|
||||
RunMetadata metadata;
|
||||
TF_RETURN_IF_ERROR(
|
||||
cluster->Run(item_.graph, item_.feed, item_.fetch, &metadata));
|
||||
TF_RETURN_IF_ERROR(cluster->Run(item_, &metadata));
|
||||
InferFromTrace(metadata.step_stats());
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -69,7 +69,7 @@ class GraphMemory {
|
||||
|
||||
void InferFromTrace(const StepStats& timeline);
|
||||
|
||||
GrapplerItem item_;
|
||||
const GrapplerItem& item_;
|
||||
std::unordered_map<string, int64> worst_case_memory_usage_;
|
||||
std::unordered_map<string, MemoryUsage> peak_usage_;
|
||||
const MemoryUsage unknown_usage_;
|
||||
|
Loading…
Reference in New Issue
Block a user