diff --git a/tensorflow/core/common_runtime/costmodel_manager.cc b/tensorflow/core/common_runtime/costmodel_manager.cc index 6119eeba5c1..53257d1f9b9 100644 --- a/tensorflow/core/common_runtime/costmodel_manager.cc +++ b/tensorflow/core/common_runtime/costmodel_manager.cc @@ -42,6 +42,17 @@ CostModel* CostModelManager::FindOrCreateCostModel(const Graph* graph) { return cost_model; } +bool CostModelManager::RemoveCostModelForGraph(const Graph* graph) { + mutex_lock l(mu_); + auto itr = cost_models_.find(graph); + if (itr == cost_models_.end()) { + return false; + } + delete itr->second; + cost_models_.erase(graph); + return true; +} + Status CostModelManager::AddToCostGraphDef(const Graph* graph, CostGraphDef* cost_graph) { mutex_lock l(mu_); diff --git a/tensorflow/core/common_runtime/costmodel_manager.h b/tensorflow/core/common_runtime/costmodel_manager.h index f5e945a6b87..e8ea2498b91 100644 --- a/tensorflow/core/common_runtime/costmodel_manager.h +++ b/tensorflow/core/common_runtime/costmodel_manager.h @@ -41,6 +41,8 @@ class CostModelManager { CostModel* FindOrCreateCostModel(const Graph* graph); + bool RemoveCostModelForGraph(const Graph* graph); + Status AddToCostGraphDef(const Graph* graph, CostGraphDef* cost_graph); private: diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc index 577f6617f79..5487785c8a8 100644 --- a/tensorflow/core/distributed_runtime/graph_mgr.cc +++ b/tensorflow/core/distributed_runtime/graph_mgr.cc @@ -53,6 +53,9 @@ GraphMgr::~GraphMgr() { GraphMgr::Item::~Item() { for (const auto& unit : this->units) { CHECK_NOTNULL(unit.device); + if (!graph_mgr->skip_cost_models_) { + graph_mgr->cost_model_manager_.RemoveCostModelForGraph(unit.graph); + } delete unit.root; delete unit.lib; unit.device->op_segment()->RemoveHold(this->session); @@ -139,6 +142,7 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef, Status s; item->units.reserve(partitions.size()); + item->graph_mgr = this; const auto& optimizer_opts = graph_options.optimizer_options(); GraphOptimizer optimizer(optimizer_opts); for (auto&& p : partitions) { diff --git a/tensorflow/core/distributed_runtime/graph_mgr.h b/tensorflow/core/distributed_runtime/graph_mgr.h index a8994f14834..a3771e67473 100644 --- a/tensorflow/core/distributed_runtime/graph_mgr.h +++ b/tensorflow/core/distributed_runtime/graph_mgr.h @@ -118,6 +118,10 @@ class GraphMgr { // A graph is partitioned over multiple devices. Each partition // has a root executor which may call into the runtime library. std::vector<ExecutionUnit> units; + + // Used to deresgister a cost model when cost model is requried in graph + // manager. + GraphMgr* graph_mgr; }; // Not owned.