diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc
index cacaf838165..1918eae8751 100644
--- a/tensorflow/core/distributed_runtime/master_session.cc
+++ b/tensorflow/core/distributed_runtime/master_session.cc
@@ -970,7 +970,8 @@ MasterSession::MasterSession(const SessionOptions& opt, const MasterEnv* env,
       handle_(strings::FpToString(random::New64())),
       stats_publisher_factory_(std::move(stats_publisher_factory)),
       graph_version_(0),
-      runs_(5),
+      run_graphs_(5),
+      partial_run_graphs_(5),
       cancellation_manager_(new CancellationManager) {
   UpdateLastAccessTime();
@@ -996,8 +997,8 @@ MasterSession::MasterSession(const SessionOptions& opt, const MasterEnv* env,
 MasterSession::~MasterSession() {
   delete cancellation_manager_;
-  for (const auto& iter : runs_) iter.second->Unref();
-  for (const auto& iter : obsolete_) iter.second->Unref();
+  for (const auto& iter : run_graphs_) iter.second->Unref();
+  for (const auto& iter : partial_run_graphs_) iter.second->Unref();
   for (Device* dev : remote_devs_) delete dev;
 }
@@ -1065,23 +1066,23 @@ Status MasterSession::StartStep(const BuildGraphOptions& opts, int64* count,
     // this session.
     int64* c = &subgraph_execution_counts_[hash];
     *count = (*c)++;
-    auto iter = runs_.find(hash);
-    if (iter == runs_.end()) {
+    // TODO(suharshs): We cache partial run graphs and run graphs separately
+    // because there is preprocessing that needs to only be run for partial
+    // run calls.
+    RCGMap* m = is_partial ? &partial_run_graphs_ : &run_graphs_;
+    auto iter = m->find(hash);
+    if (iter == m->end()) {
       // We have not seen this subgraph before. Build the subgraph and
       // cache it.
       VLOG(1) << "Unseen hash " << hash << " for "
-              << BuildGraphOptionsString(opts);
+              << BuildGraphOptionsString(opts) << " is_partial = " << is_partial
+              << "\n";
       std::unique_ptr<ClientGraph> client_graph;
       TF_RETURN_IF_ERROR(execution_state_->BuildGraph(opts, &client_graph));
       auto entry = new ReffedClientGraph(
           handle_, opts, std::move(client_graph), session_opts_,
           stats_publisher_factory_, execution_state_.get(), is_partial);
-      iter = runs_.insert({hash, entry}).first;
-      auto obs_iter = obsolete_.find(hash);
-      if (obs_iter != obsolete_.end()) {
-        to_unref = obs_iter->second;
-        obsolete_.erase(obs_iter);
-      }
+      iter = m->insert({hash, entry}).first;
       VLOG(1) << "Preparing to execute new graph";
     }
     *rcg = iter->second;
@@ -1383,8 +1384,8 @@ Status MasterSession::Close() {
     while (num_running_ != 0) {
      num_running_is_zero_.wait(l);
     }
-    ClearRunsTable(&to_unref, &runs_);
-    ClearRunsTable(&to_unref, &obsolete_);
+    ClearRunsTable(&to_unref, &run_graphs_);
+    ClearRunsTable(&to_unref, &partial_run_graphs_);
   }
   for (ReffedClientGraph* rcg : to_unref) rcg->Unref();
   delete this;
diff --git a/tensorflow/core/distributed_runtime/master_session.h b/tensorflow/core/distributed_runtime/master_session.h
index 96d759d9c8d..4af6ab66819 100644
--- a/tensorflow/core/distributed_runtime/master_session.h
+++ b/tensorflow/core/distributed_runtime/master_session.h
@@ -119,8 +119,8 @@ class MasterSession {
   // scope and lose their state.
   class ReffedClientGraph;
   typedef std::unordered_map<uint64, ReffedClientGraph*> RCGMap;
-  RCGMap runs_ GUARDED_BY(mu_);
-  RCGMap obsolete_ GUARDED_BY(mu_);
+  RCGMap run_graphs_ GUARDED_BY(mu_);
+  RCGMap partial_run_graphs_ GUARDED_BY(mu_);
 
   struct PerStepState {
     bool collect_costs = false;
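
The core of this change is the find-or-create pattern in StartStep: pick one of two hash-keyed caches based on is_partial, then probe and insert on a miss. A minimal standalone sketch of that pattern follows; CachedGraph, GraphLookup, and FindOrCreate are hypothetical stand-ins for illustration, not the real ReffedClientGraph/RCGMap machinery (which is ref-counted and guarded by mu_).

    #include <cstdint>
    #include <memory>
    #include <unordered_map>

    // Hypothetical stand-in for the cached graph object.
    struct CachedGraph { /* built from BuildGraphOptions */ };

    using GraphCache =
        std::unordered_map<uint64_t, std::unique_ptr<CachedGraph>>;

    class GraphLookup {
     public:
      // Selects the cache to probe based on is_partial, mirroring
      // "RCGMap* m = is_partial ? &partial_run_graphs_ : &run_graphs_;".
      CachedGraph* FindOrCreate(uint64_t hash, bool is_partial) {
        GraphCache* cache = is_partial ? &partial_run_graphs_ : &run_graphs_;
        auto iter = cache->find(hash);
        if (iter == cache->end()) {
          // Miss: build and cache a new graph for this options hash.
          iter = cache->emplace(hash, std::make_unique<CachedGraph>()).first;
        }
        return iter->second.get();
      }

     private:
      // Two separate tables, since partial runs need extra preprocessing.
      GraphCache run_graphs_;
      GraphCache partial_run_graphs_;
    };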