From 8c521f81b10c60de8e690749b732c85197798d03 Mon Sep 17 00:00:00 2001 From: Brian Zhao Date: Thu, 19 Sep 2019 15:56:01 -0700 Subject: [PATCH] Automated rollback of commit c00a1c6dc1161e9467215938af703a5363bdcaff PiperOrigin-RevId: 270144377 --- tensorflow/c/eager/c_api.cc | 2 +- .../base_rendezvous_mgr.cc | 22 +++++----- .../cluster_function_library_runtime.cc | 12 +++--- .../eager/eager_service_impl.cc | 10 ++--- .../rpc/rpc_rendezvous_mgr.cc | 10 ++--- .../core/distributed_runtime/session_mgr.cc | 16 +++---- .../distributed_runtime/session_mgr_test.cc | 4 +- tensorflow/core/distributed_runtime/worker.cc | 18 ++++---- .../distributed_runtime/worker_session.cc | 26 +++++------ .../core/distributed_runtime/worker_session.h | 43 +++++++------------ 10 files changed, 75 insertions(+), 88 deletions(-) diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 9fcb9293ee7..10a1fa42f57 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -281,7 +281,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef( std::move(server), grpc_server->worker_env(), worker_session, std::move(remote_eager_workers), std::move(remote_device_mgr), remote_workers, context_id, r, device_mgr, keep_alive_secs, - worker_session->cluster_flr(), std::move(remote_mgr))); + worker_session->cluster_flr.get(), std::move(remote_mgr))); // NOTE: We start the server after all other initialization, because the // GrpcServer cannot be destroyed after it is started. diff --git a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc index 0c804db7eb4..797da69b02b 100644 --- a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc +++ b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc @@ -154,13 +154,13 @@ Status BaseRemoteRendezvous::Initialize(WorkerSession* session) { { mutex_lock l(mu_); if (session_ != nullptr) { - if (session_->worker_name() == session->worker_name()) { + if (session_->worker_name == session->worker_name) { LOG(INFO) << "Skipping rendezvous re-initialization."; return Status::OK(); } Status s = errors::Internal( "Double init! Worker names would have changed from: ", - session_->worker_name(), " -> ", session->worker_name()); + session_->worker_name, " -> ", session->worker_name); LOG(WARNING) << s; return s; } @@ -191,10 +191,10 @@ Status BaseRemoteRendezvous::Send(const Rendezvous::ParsedKey& parsed, tf_shared_lock l(mu_); if (!status_.ok()) return status_; DCHECK(is_initialized_locked()); - if (!IsLocalDevice(session_->worker_name(), parsed.src_device)) { + if (!IsLocalDevice(session_->worker_name, parsed.src_device)) { return errors::InvalidArgument( "Invalid rendezvous key (src): ", parsed.FullKey(), " @ ", - session_->worker_name()); + session_->worker_name); } } // Buffers "val" and "device_context" in local_. @@ -214,15 +214,13 @@ Status BaseRemoteRendezvous::ValidateDevices(const ParsedKey& parsed, } sess = session_; } - if (is_src && !IsLocalDevice(sess->worker_name(), parsed.src_device)) { - return errors::InvalidArgument( - "Invalid rendezvous key (src): ", parsed.FullKey(), " @ ", - sess->worker_name()); + if (is_src && !IsLocalDevice(sess->worker_name, parsed.src_device)) { + return errors::InvalidArgument("Invalid rendezvous key (src): ", + parsed.FullKey(), " @ ", sess->worker_name); } - if (!is_src && !IsLocalDevice(sess->worker_name(), parsed.dst_device)) { - return errors::InvalidArgument( - "Invalid rendezvous key (dst): ", parsed.FullKey(), " @ ", - sess->worker_name()); + if (!is_src && !IsLocalDevice(sess->worker_name, parsed.dst_device)) { + return errors::InvalidArgument("Invalid rendezvous key (dst): ", + parsed.FullKey(), " @ ", sess->worker_name); } return Status::OK(); } diff --git a/tensorflow/core/distributed_runtime/cluster_function_library_runtime.cc b/tensorflow/core/distributed_runtime/cluster_function_library_runtime.cc index 3492949dafa..e9133fd45c6 100644 --- a/tensorflow/core/distributed_runtime/cluster_function_library_runtime.cc +++ b/tensorflow/core/distributed_runtime/cluster_function_library_runtime.cc @@ -160,8 +160,8 @@ Status ClusterFunctionLibraryRuntime::ConstructFunctionGraph( ClusterFunctionLibraryRuntime::~ClusterFunctionLibraryRuntime() { for (auto& function_data : function_data_) { - worker_session_->worker_cache()->ReleaseWorker(function_data.target, - function_data.wi); + worker_session_->worker_cache->ReleaseWorker(function_data.target, + function_data.wi); } } @@ -172,11 +172,11 @@ Status ClusterFunctionLibraryRuntime::Instantiate( VLOG(1) << "CFLR::Instantiate: " << function_name << " on " << options.target << " (this: " << this << ")"; WorkerInterface* wi = - worker_session_->worker_cache()->GetOrCreateWorker(options.target); + worker_session_->worker_cache->GetOrCreateWorker(options.target); if (wi == nullptr) { std::vector workers; - worker_session_->worker_cache()->ListWorkers(&workers); + worker_session_->worker_cache->ListWorkers(&workers); return errors::InvalidArgument( "Could not find worker with target: ", options.target, " Available workers: ", absl::StrJoin(workers, ", ")); @@ -199,7 +199,7 @@ Status ClusterFunctionLibraryRuntime::Instantiate( } RegisterGraphRequest req; - req.set_session_handle(worker_session_->session_name()); + req.set_session_handle(worker_session_->session_name); req.set_create_worker_session_called(create_worker_session_called_); *req.mutable_graph_def() = std::move(gdef); req.mutable_graph_options() @@ -237,7 +237,7 @@ void ClusterFunctionLibraryRuntime::Run( } RunGraphRequest* req = new RunGraphRequest; - req->set_session_handle(worker_session_->session_name()); + req->set_session_handle(worker_session_->session_name); req->set_create_worker_session_called(create_worker_session_called_); req->set_graph_handle(function_data->graph_handle); req->set_step_id(opts.step_id); diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc index 05c0343f040..0b3c8b5d449 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc @@ -132,20 +132,20 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request, tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, tensorflow::ContextMirroringPolicy::MIRRORING_NONE, request->async(), device_mgr, false, r, GetDefaultCustomKernelCreator(), - worker_session->cluster_flr()); + worker_session->cluster_flr.get()); // Ownership will be transferred to the ServerContext, or else in an error // case ctx will be deleted by this unref. core::ScopedUnref unref_ctx(ctx); std::vector remote_workers; - worker_session->worker_cache()->ListWorkers(&remote_workers); + worker_session->worker_cache->ListWorkers(&remote_workers); remote_workers.erase(std::remove(remote_workers.begin(), remote_workers.end(), - worker_session->worker_name()), + worker_session->worker_name), remote_workers.end()); std::unique_ptr remote_eager_workers; - TF_RETURN_IF_ERROR(worker_session->worker_cache()->GetEagerClientCache( - &remote_eager_workers)); + TF_RETURN_IF_ERROR( + worker_session->worker_cache->GetEagerClientCache(&remote_eager_workers)); auto remote_mgr = absl::make_unique(/*is_master=*/false, ctx); diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc index 011d668f56a..ba68fe4afb8 100644 --- a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc +++ b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc @@ -225,10 +225,10 @@ void RpcRemoteRendezvous::RecvFromRemoteAsync( } WorkerSession* sess = session(); // The worker will be released in a subsequent call to - // `sess->worker_cache()->ReleaseWorker()` (if the call has not yet been + // `sess->worker_cache->ReleaseWorker()` (if the call has not yet been // initialized) or `call->ReleaseWorker()` (if it has been initialized). WorkerInterface* rwi = - sess->worker_cache()->GetOrCreateWorker(call->src_worker_); + sess->worker_cache->GetOrCreateWorker(call->src_worker_); if (s.ok() && rwi == nullptr) { s = errors::Internal("No worker known as ", call->src_worker_); } @@ -239,7 +239,7 @@ void RpcRemoteRendezvous::RecvFromRemoteAsync( } if (!s.ok()) { if (rwi != nullptr) { - sess->worker_cache()->ReleaseWorker(call->src_worker_, rwi); + sess->worker_cache->ReleaseWorker(call->src_worker_, rwi); } get_call_freelist()->Release(call); done(s, Args(), recv_args, Tensor{}, false); @@ -257,7 +257,7 @@ void RpcRemoteRendezvous::RecvFromRemoteAsync( // NOTE: `*sess` can potentially be deleted before we return from // `call->done()(...)`, so we must release the worker before calling the // callback. - call->ReleaseWorker(sess->worker_cache()); + call->ReleaseWorker(sess->worker_cache.get()); call->done()(call->status(), Args(), Args(), Tensor(), false); get_call_freelist()->Release(call); return; @@ -274,7 +274,7 @@ void RpcRemoteRendezvous::RecvFromRemoteAsync( // NOTE: `*session()` can potentially be deleted before we return from // `call->done()(...)`, so we must release the worker before calling the // callback. - call->ReleaseWorker(session()->worker_cache()); + call->ReleaseWorker(session()->worker_cache.get()); call->done()(s, Args(), call->recv_args(), call->tensor(), call->is_dead()); get_call_freelist()->Release(call); Unref(); diff --git a/tensorflow/core/distributed_runtime/session_mgr.cc b/tensorflow/core/distributed_runtime/session_mgr.cc index a9be90fc5a5..ce80cb7e048 100644 --- a/tensorflow/core/distributed_runtime/session_mgr.cc +++ b/tensorflow/core/distributed_runtime/session_mgr.cc @@ -71,7 +71,7 @@ Status SessionMgr::CreateSession( string worker_name; if (server_def.cluster().job().empty()) { worker_cache = new WorkerCacheWrapper(default_worker_cache_.get()); - worker_name = legacy_session_->worker_name(); + worker_name = legacy_session_->worker_name; } else { TF_RETURN_IF_ERROR(worker_cache_factory_(server_def, &worker_cache)); worker_name = WorkerNameFromServerDef(server_def); @@ -162,7 +162,7 @@ Status SessionMgr::WorkerSessionForSessionLocked( if (it == sessions_.end()) { return errors::Aborted("Session handle is not found: ", session_handle, ". Possibly this worker (\"", - legacy_session_->worker_name(), + legacy_session_->worker_name, "\") just restarted."); } else { *out_session = it->second; @@ -186,7 +186,7 @@ void SessionMgr::SetLogging(bool active) { this->is_logging_active_ = active; // Legacy Session if (legacy_session_) { - auto* worker_cache = legacy_session_->worker_cache(); + auto* worker_cache = legacy_session_->worker_cache.get(); if (worker_cache) { worker_cache->SetLogging(active); } @@ -195,7 +195,7 @@ void SessionMgr::SetLogging(bool active) { for (const auto& session_kv : sessions_) { auto session = session_kv.second.get(); if (session) { - auto* worker_cache = session->worker_cache(); + auto* worker_cache = session->worker_cache.get(); if (worker_cache) { worker_cache->SetLogging(active); } @@ -208,7 +208,7 @@ void SessionMgr::RetrieveLogs(tensorflow::int64 step_id, mutex_lock l(mu_); // Legacy Session if (legacy_session_) { - auto* worker_cache = legacy_session_->worker_cache(); + auto* worker_cache = legacy_session_->worker_cache.get(); if (worker_cache) { auto step_stats = StepStats(); if (worker_cache->RetrieveLogs(step_id, &step_stats)) { @@ -221,7 +221,7 @@ void SessionMgr::RetrieveLogs(tensorflow::int64 step_id, for (const auto& session_kv : sessions_) { auto session = session_kv.second.get(); if (session) { - auto* worker_cache = session->worker_cache(); + auto* worker_cache = session->worker_cache.get(); if (worker_cache) { auto step_stats = StepStats(); if (worker_cache->RetrieveLogs(step_id, &step_stats)) { @@ -238,7 +238,7 @@ void SessionMgr::ClearLogs() { mutex_lock l(mu_); // Legacy Session if (legacy_session_) { - auto* worker_cache = legacy_session_->worker_cache(); + auto* worker_cache = legacy_session_->worker_cache.get(); if (worker_cache) { worker_cache->ClearLogs(); } @@ -247,7 +247,7 @@ void SessionMgr::ClearLogs() { for (const auto& session_kv : sessions_) { auto session = session_kv.second.get(); if (session) { - auto* worker_cache = session->worker_cache(); + auto* worker_cache = session->worker_cache.get(); if (worker_cache) { worker_cache->ClearLogs(); } diff --git a/tensorflow/core/distributed_runtime/session_mgr_test.cc b/tensorflow/core/distributed_runtime/session_mgr_test.cc index f6e0551ff56..30a102b310f 100644 --- a/tensorflow/core/distributed_runtime/session_mgr_test.cc +++ b/tensorflow/core/distributed_runtime/session_mgr_test.cc @@ -102,7 +102,7 @@ TEST_F(SessionMgrTest, CreateSessionClusterDefWorkerName) { EXPECT_TRUE(device->IsLocal()); EXPECT_NE(nullptr, session) << "Session for " << session_handle << "was null"; - EXPECT_EQ("/job:worker/replica:0/task:3", session->worker_name()); + EXPECT_EQ("/job:worker/replica:0/task:3", session->worker_name); TF_EXPECT_OK(mgr_.DeleteSession(session_handle)); } @@ -113,7 +113,7 @@ TEST_F(SessionMgrTest, CreateSessionDefaultWorkerName) { std::shared_ptr session; TF_EXPECT_OK(mgr_.WorkerSessionForSession(session_handle, &session)); EXPECT_NE(nullptr, session) << "Session for " << session_handle << "was null"; - EXPECT_EQ("/job:mnist/replica:0/task:0", session->worker_name()); + EXPECT_EQ("/job:mnist/replica:0/task:0", session->worker_name); TF_EXPECT_OK(mgr_.DeleteSession(session_handle)); } diff --git a/tensorflow/core/distributed_runtime/worker.cc b/tensorflow/core/distributed_runtime/worker.cc index 908b3b9ff6f..686714bae84 100644 --- a/tensorflow/core/distributed_runtime/worker.cc +++ b/tensorflow/core/distributed_runtime/worker.cc @@ -77,10 +77,10 @@ void Worker::RegisterGraphAsync(const RegisterGraphRequest* request, session = env_->session_mgr->LegacySession(); } if (s.ok()) { - s = session->graph_mgr()->Register( + s = session->graph_mgr->Register( request->session_handle(), request->graph_def(), session.get(), request->graph_options(), request->debug_options(), - request->collective_graph_key(), session->cluster_flr(), + request->collective_graph_key(), session->cluster_flr.get(), response->mutable_graph_handle()); } done(s); @@ -98,7 +98,7 @@ void Worker::DeregisterGraphAsync(const DeregisterGraphRequest* request, session = env_->session_mgr->LegacySession(); } if (s.ok()) { - s = session->graph_mgr()->Deregister(request->graph_handle()); + s = session->graph_mgr->Deregister(request->graph_handle()); } done(s); @@ -218,14 +218,14 @@ void Worker::DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request, done(errors::Aborted("Call was aborted")); return; } - session->graph_mgr()->ExecuteAsync( + session->graph_mgr->ExecuteAsync( request->graph_handle(), step_id, session.get(), request->exec_opts(), collector, response, cm, in, [this, step_id, response, session, cm, out, token, collector, profiler_session, opts, done](const Status& status) { Status s = status; if (s.ok()) { - s = session->graph_mgr()->RecvOutputs(step_id, out); + s = session->graph_mgr->RecvOutputs(step_id, out); } opts->ClearCancelCallback(); @@ -311,7 +311,7 @@ void Worker::DoPartialRunGraph(CallOptions* opts, token = cancellation_manager_.get_cancellation_token(); cancellation_manager_.RegisterCallback(token, [cm]() { cm->StartCancel(); }); - session->graph_mgr()->ExecuteAsync( + session->graph_mgr->ExecuteAsync( graph_handle, step_id, session.get(), request->exec_opts(), nullptr /* collector */, nullptr /* response */, cm, in, [this, token, step_id, session](Status s) { @@ -320,14 +320,14 @@ void Worker::DoPartialRunGraph(CallOptions* opts, }); } else { // Send the partial run's new inputs. - s = session->graph_mgr()->SendInputs(step_id, in); + s = session->graph_mgr->SendInputs(step_id, in); if (!s.ok()) { finish(s); return; } } - session->graph_mgr()->RecvOutputsAsync( + session->graph_mgr->RecvOutputsAsync( step_id, out, [this, out, request, response, step_id, finish](Status s) { if (s.ok()) { // Construct and return the resp. @@ -442,7 +442,7 @@ Status Worker::PrepareRecvTensor(const Rendezvous::ParsedKey& parsed, "RecvTensor expects a different device incarnation: ", parsed.src_incarnation, " vs. ", (*src_dev)->attributes().incarnation(), ". Your worker job (\"", - env_->session_mgr->LegacySession()->worker_name(), + env_->session_mgr->LegacySession()->worker_name, "\") was probably restarted. Check your " "worker job for the reason why it was restarted."); } diff --git a/tensorflow/core/distributed_runtime/worker_session.cc b/tensorflow/core/distributed_runtime/worker_session.cc index 3a612c229ba..b3459b64ebb 100644 --- a/tensorflow/core/distributed_runtime/worker_session.cc +++ b/tensorflow/core/distributed_runtime/worker_session.cc @@ -109,11 +109,11 @@ WorkerSession::WorkerSession( std::unique_ptr worker_cache, std::unique_ptr device_mgr, std::unique_ptr graph_mgr, std::unique_ptr remote_device_mgr) - : session_name_(session_name), - worker_name_(worker_name), - worker_cache_(new WorkerFreeListCache(std::move(worker_cache))), - graph_mgr_(std::move(graph_mgr)), - cluster_flr_(new ClusterFunctionLibraryRuntime( + : session_name(session_name), + worker_name(worker_name), + worker_cache(new WorkerFreeListCache(std::move(worker_cache))), + graph_mgr(std::move(graph_mgr)), + cluster_flr(new ClusterFunctionLibraryRuntime( this, !session_name.empty(), remote_device_mgr ? remote_device_mgr.get() : nullptr)), device_mgr_(std::move(device_mgr)), @@ -142,12 +142,12 @@ WorkerSession::WorkerSession( std::unique_ptr worker_cache, DeviceMgr* borrowed_device_mgr, std::unique_ptr graph_mgr, std::unique_ptr remote_device_mgr) - : session_name_(session_name), - worker_name_(worker_name), - worker_cache_(new WorkerFreeListCache(std::move(worker_cache))), - graph_mgr_(std::move(graph_mgr)), - cluster_flr_(new ClusterFunctionLibraryRuntime( - this, !session_name.empty(), remote_device_mgr.get())), + : session_name(session_name), + worker_name(worker_name), + worker_cache(new WorkerFreeListCache(std::move(worker_cache))), + graph_mgr(std::move(graph_mgr)), + cluster_flr(new ClusterFunctionLibraryRuntime(this, !session_name.empty(), + remote_device_mgr.get())), device_mgr_(nullptr), borrowed_device_mgr_(borrowed_device_mgr), remote_device_mgr_(std::move(remote_device_mgr)) { @@ -159,8 +159,8 @@ WorkerSession::WorkerSession( } WorkerSession::~WorkerSession() { - if (graph_mgr_) { - Status s = graph_mgr_->DeregisterAll(); + if (graph_mgr) { + Status s = graph_mgr->DeregisterAll(); if (!s.ok()) { LOG(WARNING) << "Error during worker session deletion: " << s; } diff --git a/tensorflow/core/distributed_runtime/worker_session.h b/tensorflow/core/distributed_runtime/worker_session.h index 83dc39f9463..30e31645a21 100644 --- a/tensorflow/core/distributed_runtime/worker_session.h +++ b/tensorflow/core/distributed_runtime/worker_session.h @@ -30,8 +30,16 @@ class GraphMgr; class WorkerCacheInterface; // WorkerSession encapsulates all of the state relating to a given session. -class WorkerSession { - public: +struct WorkerSession { + // The name of the session. + const string session_name; + + // The name of the worker. E.g., /job:mnist/replica:0/task:1. + const string worker_name; + + // Object from which WorkerInterface instances can be obtained. + const std::unique_ptr worker_cache; + // Collection of local devices. These devices are typically // RenamedDevices in all except the SessionMgr.legacy_session_ and // sessions created with `isolate_session_state == false`. In the @@ -43,15 +51,13 @@ class WorkerSession { DynamicDeviceMgr* remote_device_mgr() { return remote_device_mgr_.get(); } - const string& session_name() const { return session_name_; } - const string& worker_name() const { return worker_name_; } + // graph_mgr keeps track of the registered graphs of this session. + // + // Note: graph_mgr must be deleted before rendezvous_mgr! + // Note: graph_mgr must be deleted before device_mgr! + const std::unique_ptr graph_mgr; - WorkerCacheInterface* worker_cache() const { return worker_cache_.get(); } - GraphMgr* graph_mgr() const { return graph_mgr_.get(); } - - ClusterFunctionLibraryRuntime* cluster_flr() const { - return cluster_flr_.get(); - } + std::unique_ptr cluster_flr; WorkerSession(const string& session_name, const string& worker_name, std::unique_ptr worker_cache, @@ -74,23 +80,6 @@ class WorkerSession { std::unique_ptr graph_mgr, std::unique_ptr remote_device_mgr); - // The name of the session. - const string session_name_; - - // The name of the worker. E.g., /job:mnist/replica:0/task:1. - const string worker_name_; - - // Object from which WorkerInterface instances can be obtained. - const std::unique_ptr worker_cache_; - - // graph_mgr keeps track of the registered graphs of this session. - // - // Note: graph_mgr must be deleted before rendezvous_mgr! - // Note: graph_mgr must be deleted before device_mgr! - const std::unique_ptr graph_mgr_; - - std::unique_ptr cluster_flr_; - const std::unique_ptr device_mgr_; DeviceMgr* const borrowed_device_mgr_; // Not owned. const std::unique_ptr remote_device_mgr_;