Convert WorkerSession from a struct to a class.

PiperOrigin-RevId: 270104608
This commit is contained in:
Haoyu Zhang 2019-09-19 12:54:01 -07:00 committed by TensorFlower Gardener
parent 39db8af832
commit c00a1c6dc1
10 changed files with 88 additions and 75 deletions

View File

@ -281,7 +281,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
std::move(server), grpc_server->worker_env(), worker_session, std::move(server), grpc_server->worker_env(), worker_session,
std::move(remote_eager_workers), std::move(remote_device_mgr), std::move(remote_eager_workers), std::move(remote_device_mgr),
remote_workers, context_id, r, device_mgr, keep_alive_secs, remote_workers, context_id, r, device_mgr, keep_alive_secs,
worker_session->cluster_flr.get(), std::move(remote_mgr))); worker_session->cluster_flr(), std::move(remote_mgr)));
// NOTE: We start the server after all other initialization, because the // NOTE: We start the server after all other initialization, because the
// GrpcServer cannot be destroyed after it is started. // GrpcServer cannot be destroyed after it is started.

View File

@ -154,13 +154,13 @@ Status BaseRemoteRendezvous::Initialize(WorkerSession* session) {
{ {
mutex_lock l(mu_); mutex_lock l(mu_);
if (session_ != nullptr) { if (session_ != nullptr) {
if (session_->worker_name == session->worker_name) { if (session_->worker_name() == session->worker_name()) {
LOG(INFO) << "Skipping rendezvous re-initialization."; LOG(INFO) << "Skipping rendezvous re-initialization.";
return Status::OK(); return Status::OK();
} }
Status s = errors::Internal( Status s = errors::Internal(
"Double init! Worker names would have changed from: ", "Double init! Worker names would have changed from: ",
session_->worker_name, " -> ", session->worker_name); session_->worker_name(), " -> ", session->worker_name());
LOG(WARNING) << s; LOG(WARNING) << s;
return s; return s;
} }
@ -191,10 +191,10 @@ Status BaseRemoteRendezvous::Send(const Rendezvous::ParsedKey& parsed,
tf_shared_lock l(mu_); tf_shared_lock l(mu_);
if (!status_.ok()) return status_; if (!status_.ok()) return status_;
DCHECK(is_initialized_locked()); DCHECK(is_initialized_locked());
if (!IsLocalDevice(session_->worker_name, parsed.src_device)) { if (!IsLocalDevice(session_->worker_name(), parsed.src_device)) {
return errors::InvalidArgument( return errors::InvalidArgument(
"Invalid rendezvous key (src): ", parsed.FullKey(), " @ ", "Invalid rendezvous key (src): ", parsed.FullKey(), " @ ",
session_->worker_name); session_->worker_name());
} }
} }
// Buffers "val" and "device_context" in local_. // Buffers "val" and "device_context" in local_.
@ -214,13 +214,15 @@ Status BaseRemoteRendezvous::ValidateDevices(const ParsedKey& parsed,
} }
sess = session_; sess = session_;
} }
if (is_src && !IsLocalDevice(sess->worker_name, parsed.src_device)) { if (is_src && !IsLocalDevice(sess->worker_name(), parsed.src_device)) {
return errors::InvalidArgument("Invalid rendezvous key (src): ", return errors::InvalidArgument(
parsed.FullKey(), " @ ", sess->worker_name); "Invalid rendezvous key (src): ", parsed.FullKey(), " @ ",
sess->worker_name());
} }
if (!is_src && !IsLocalDevice(sess->worker_name, parsed.dst_device)) { if (!is_src && !IsLocalDevice(sess->worker_name(), parsed.dst_device)) {
return errors::InvalidArgument("Invalid rendezvous key (dst): ", return errors::InvalidArgument(
parsed.FullKey(), " @ ", sess->worker_name); "Invalid rendezvous key (dst): ", parsed.FullKey(), " @ ",
sess->worker_name());
} }
return Status::OK(); return Status::OK();
} }

View File

@ -160,7 +160,7 @@ Status ClusterFunctionLibraryRuntime::ConstructFunctionGraph(
ClusterFunctionLibraryRuntime::~ClusterFunctionLibraryRuntime() { ClusterFunctionLibraryRuntime::~ClusterFunctionLibraryRuntime() {
for (auto& function_data : function_data_) { for (auto& function_data : function_data_) {
worker_session_->worker_cache->ReleaseWorker(function_data.target, worker_session_->worker_cache()->ReleaseWorker(function_data.target,
function_data.wi); function_data.wi);
} }
} }
@ -172,11 +172,11 @@ Status ClusterFunctionLibraryRuntime::Instantiate(
VLOG(1) << "CFLR::Instantiate: " << function_name << " on " << options.target VLOG(1) << "CFLR::Instantiate: " << function_name << " on " << options.target
<< " (this: " << this << ")"; << " (this: " << this << ")";
WorkerInterface* wi = WorkerInterface* wi =
worker_session_->worker_cache->GetOrCreateWorker(options.target); worker_session_->worker_cache()->GetOrCreateWorker(options.target);
if (wi == nullptr) { if (wi == nullptr) {
std::vector<string> workers; std::vector<string> workers;
worker_session_->worker_cache->ListWorkers(&workers); worker_session_->worker_cache()->ListWorkers(&workers);
return errors::InvalidArgument( return errors::InvalidArgument(
"Could not find worker with target: ", options.target, "Could not find worker with target: ", options.target,
" Available workers: ", absl::StrJoin(workers, ", ")); " Available workers: ", absl::StrJoin(workers, ", "));
@ -199,7 +199,7 @@ Status ClusterFunctionLibraryRuntime::Instantiate(
} }
RegisterGraphRequest req; RegisterGraphRequest req;
req.set_session_handle(worker_session_->session_name); req.set_session_handle(worker_session_->session_name());
req.set_create_worker_session_called(create_worker_session_called_); req.set_create_worker_session_called(create_worker_session_called_);
*req.mutable_graph_def() = std::move(gdef); *req.mutable_graph_def() = std::move(gdef);
req.mutable_graph_options() req.mutable_graph_options()
@ -237,7 +237,7 @@ void ClusterFunctionLibraryRuntime::Run(
} }
RunGraphRequest* req = new RunGraphRequest; RunGraphRequest* req = new RunGraphRequest;
req->set_session_handle(worker_session_->session_name); req->set_session_handle(worker_session_->session_name());
req->set_create_worker_session_called(create_worker_session_called_); req->set_create_worker_session_called(create_worker_session_called_);
req->set_graph_handle(function_data->graph_handle); req->set_graph_handle(function_data->graph_handle);
req->set_step_id(opts.step_id); req->set_step_id(opts.step_id);

View File

@ -132,20 +132,20 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request,
tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
tensorflow::ContextMirroringPolicy::MIRRORING_NONE, request->async(), tensorflow::ContextMirroringPolicy::MIRRORING_NONE, request->async(),
device_mgr, false, r, GetDefaultCustomKernelCreator(), device_mgr, false, r, GetDefaultCustomKernelCreator(),
worker_session->cluster_flr.get()); worker_session->cluster_flr());
// Ownership will be transferred to the ServerContext, or else in an error // Ownership will be transferred to the ServerContext, or else in an error
// case ctx will be deleted by this unref. // case ctx will be deleted by this unref.
core::ScopedUnref unref_ctx(ctx); core::ScopedUnref unref_ctx(ctx);
std::vector<string> remote_workers; std::vector<string> remote_workers;
worker_session->worker_cache->ListWorkers(&remote_workers); worker_session->worker_cache()->ListWorkers(&remote_workers);
remote_workers.erase(std::remove(remote_workers.begin(), remote_workers.end(), remote_workers.erase(std::remove(remote_workers.begin(), remote_workers.end(),
worker_session->worker_name), worker_session->worker_name()),
remote_workers.end()); remote_workers.end());
std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers; std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(worker_session->worker_cache()->GetEagerClientCache(
worker_session->worker_cache->GetEagerClientCache(&remote_eager_workers)); &remote_eager_workers));
auto remote_mgr = auto remote_mgr =
absl::make_unique<tensorflow::eager::RemoteMgr>(/*is_master=*/false, ctx); absl::make_unique<tensorflow::eager::RemoteMgr>(/*is_master=*/false, ctx);

View File

@ -225,10 +225,10 @@ void RpcRemoteRendezvous::RecvFromRemoteAsync(
} }
WorkerSession* sess = session(); WorkerSession* sess = session();
// The worker will be released in a subsequent call to // The worker will be released in a subsequent call to
// `sess->worker_cache->ReleaseWorker()` (if the call has not yet been // `sess->worker_cache()->ReleaseWorker()` (if the call has not yet been
// initialized) or `call->ReleaseWorker()` (if it has been initialized). // initialized) or `call->ReleaseWorker()` (if it has been initialized).
WorkerInterface* rwi = WorkerInterface* rwi =
sess->worker_cache->GetOrCreateWorker(call->src_worker_); sess->worker_cache()->GetOrCreateWorker(call->src_worker_);
if (s.ok() && rwi == nullptr) { if (s.ok() && rwi == nullptr) {
s = errors::Internal("No worker known as ", call->src_worker_); s = errors::Internal("No worker known as ", call->src_worker_);
} }
@ -239,7 +239,7 @@ void RpcRemoteRendezvous::RecvFromRemoteAsync(
} }
if (!s.ok()) { if (!s.ok()) {
if (rwi != nullptr) { if (rwi != nullptr) {
sess->worker_cache->ReleaseWorker(call->src_worker_, rwi); sess->worker_cache()->ReleaseWorker(call->src_worker_, rwi);
} }
get_call_freelist()->Release(call); get_call_freelist()->Release(call);
done(s, Args(), recv_args, Tensor{}, false); done(s, Args(), recv_args, Tensor{}, false);
@ -257,7 +257,7 @@ void RpcRemoteRendezvous::RecvFromRemoteAsync(
// NOTE: `*sess` can potentially be deleted before we return from // NOTE: `*sess` can potentially be deleted before we return from
// `call->done()(...)`, so we must release the worker before calling the // `call->done()(...)`, so we must release the worker before calling the
// callback. // callback.
call->ReleaseWorker(sess->worker_cache.get()); call->ReleaseWorker(sess->worker_cache());
call->done()(call->status(), Args(), Args(), Tensor(), false); call->done()(call->status(), Args(), Args(), Tensor(), false);
get_call_freelist()->Release(call); get_call_freelist()->Release(call);
return; return;
@ -274,7 +274,7 @@ void RpcRemoteRendezvous::RecvFromRemoteAsync(
// NOTE: `*session()` can potentially be deleted before we return from // NOTE: `*session()` can potentially be deleted before we return from
// `call->done()(...)`, so we must release the worker before calling the // `call->done()(...)`, so we must release the worker before calling the
// callback. // callback.
call->ReleaseWorker(session()->worker_cache.get()); call->ReleaseWorker(session()->worker_cache());
call->done()(s, Args(), call->recv_args(), call->tensor(), call->is_dead()); call->done()(s, Args(), call->recv_args(), call->tensor(), call->is_dead());
get_call_freelist()->Release(call); get_call_freelist()->Release(call);
Unref(); Unref();

View File

@ -71,7 +71,7 @@ Status SessionMgr::CreateSession(
string worker_name; string worker_name;
if (server_def.cluster().job().empty()) { if (server_def.cluster().job().empty()) {
worker_cache = new WorkerCacheWrapper(default_worker_cache_.get()); worker_cache = new WorkerCacheWrapper(default_worker_cache_.get());
worker_name = legacy_session_->worker_name; worker_name = legacy_session_->worker_name();
} else { } else {
TF_RETURN_IF_ERROR(worker_cache_factory_(server_def, &worker_cache)); TF_RETURN_IF_ERROR(worker_cache_factory_(server_def, &worker_cache));
worker_name = WorkerNameFromServerDef(server_def); worker_name = WorkerNameFromServerDef(server_def);
@ -162,7 +162,7 @@ Status SessionMgr::WorkerSessionForSessionLocked(
if (it == sessions_.end()) { if (it == sessions_.end()) {
return errors::Aborted("Session handle is not found: ", session_handle, return errors::Aborted("Session handle is not found: ", session_handle,
". Possibly this worker (\"", ". Possibly this worker (\"",
legacy_session_->worker_name, legacy_session_->worker_name(),
"\") just restarted."); "\") just restarted.");
} else { } else {
*out_session = it->second; *out_session = it->second;
@ -186,7 +186,7 @@ void SessionMgr::SetLogging(bool active) {
this->is_logging_active_ = active; this->is_logging_active_ = active;
// Legacy Session // Legacy Session
if (legacy_session_) { if (legacy_session_) {
auto* worker_cache = legacy_session_->worker_cache.get(); auto* worker_cache = legacy_session_->worker_cache();
if (worker_cache) { if (worker_cache) {
worker_cache->SetLogging(active); worker_cache->SetLogging(active);
} }
@ -195,7 +195,7 @@ void SessionMgr::SetLogging(bool active) {
for (const auto& session_kv : sessions_) { for (const auto& session_kv : sessions_) {
auto session = session_kv.second.get(); auto session = session_kv.second.get();
if (session) { if (session) {
auto* worker_cache = session->worker_cache.get(); auto* worker_cache = session->worker_cache();
if (worker_cache) { if (worker_cache) {
worker_cache->SetLogging(active); worker_cache->SetLogging(active);
} }
@ -208,7 +208,7 @@ void SessionMgr::RetrieveLogs(tensorflow::int64 step_id,
mutex_lock l(mu_); mutex_lock l(mu_);
// Legacy Session // Legacy Session
if (legacy_session_) { if (legacy_session_) {
auto* worker_cache = legacy_session_->worker_cache.get(); auto* worker_cache = legacy_session_->worker_cache();
if (worker_cache) { if (worker_cache) {
auto step_stats = StepStats(); auto step_stats = StepStats();
if (worker_cache->RetrieveLogs(step_id, &step_stats)) { if (worker_cache->RetrieveLogs(step_id, &step_stats)) {
@ -221,7 +221,7 @@ void SessionMgr::RetrieveLogs(tensorflow::int64 step_id,
for (const auto& session_kv : sessions_) { for (const auto& session_kv : sessions_) {
auto session = session_kv.second.get(); auto session = session_kv.second.get();
if (session) { if (session) {
auto* worker_cache = session->worker_cache.get(); auto* worker_cache = session->worker_cache();
if (worker_cache) { if (worker_cache) {
auto step_stats = StepStats(); auto step_stats = StepStats();
if (worker_cache->RetrieveLogs(step_id, &step_stats)) { if (worker_cache->RetrieveLogs(step_id, &step_stats)) {
@ -238,7 +238,7 @@ void SessionMgr::ClearLogs() {
mutex_lock l(mu_); mutex_lock l(mu_);
// Legacy Session // Legacy Session
if (legacy_session_) { if (legacy_session_) {
auto* worker_cache = legacy_session_->worker_cache.get(); auto* worker_cache = legacy_session_->worker_cache();
if (worker_cache) { if (worker_cache) {
worker_cache->ClearLogs(); worker_cache->ClearLogs();
} }
@ -247,7 +247,7 @@ void SessionMgr::ClearLogs() {
for (const auto& session_kv : sessions_) { for (const auto& session_kv : sessions_) {
auto session = session_kv.second.get(); auto session = session_kv.second.get();
if (session) { if (session) {
auto* worker_cache = session->worker_cache.get(); auto* worker_cache = session->worker_cache();
if (worker_cache) { if (worker_cache) {
worker_cache->ClearLogs(); worker_cache->ClearLogs();
} }

View File

@ -102,7 +102,7 @@ TEST_F(SessionMgrTest, CreateSessionClusterDefWorkerName) {
EXPECT_TRUE(device->IsLocal()); EXPECT_TRUE(device->IsLocal());
EXPECT_NE(nullptr, session) << "Session for " << session_handle << "was null"; EXPECT_NE(nullptr, session) << "Session for " << session_handle << "was null";
EXPECT_EQ("/job:worker/replica:0/task:3", session->worker_name); EXPECT_EQ("/job:worker/replica:0/task:3", session->worker_name());
TF_EXPECT_OK(mgr_.DeleteSession(session_handle)); TF_EXPECT_OK(mgr_.DeleteSession(session_handle));
} }
@ -113,7 +113,7 @@ TEST_F(SessionMgrTest, CreateSessionDefaultWorkerName) {
std::shared_ptr<WorkerSession> session; std::shared_ptr<WorkerSession> session;
TF_EXPECT_OK(mgr_.WorkerSessionForSession(session_handle, &session)); TF_EXPECT_OK(mgr_.WorkerSessionForSession(session_handle, &session));
EXPECT_NE(nullptr, session) << "Session for " << session_handle << "was null"; EXPECT_NE(nullptr, session) << "Session for " << session_handle << "was null";
EXPECT_EQ("/job:mnist/replica:0/task:0", session->worker_name); EXPECT_EQ("/job:mnist/replica:0/task:0", session->worker_name());
TF_EXPECT_OK(mgr_.DeleteSession(session_handle)); TF_EXPECT_OK(mgr_.DeleteSession(session_handle));
} }

View File

@ -77,10 +77,10 @@ void Worker::RegisterGraphAsync(const RegisterGraphRequest* request,
session = env_->session_mgr->LegacySession(); session = env_->session_mgr->LegacySession();
} }
if (s.ok()) { if (s.ok()) {
s = session->graph_mgr->Register( s = session->graph_mgr()->Register(
request->session_handle(), request->graph_def(), session.get(), request->session_handle(), request->graph_def(), session.get(),
request->graph_options(), request->debug_options(), request->graph_options(), request->debug_options(),
request->collective_graph_key(), session->cluster_flr.get(), request->collective_graph_key(), session->cluster_flr(),
response->mutable_graph_handle()); response->mutable_graph_handle());
} }
done(s); done(s);
@ -98,7 +98,7 @@ void Worker::DeregisterGraphAsync(const DeregisterGraphRequest* request,
session = env_->session_mgr->LegacySession(); session = env_->session_mgr->LegacySession();
} }
if (s.ok()) { if (s.ok()) {
s = session->graph_mgr->Deregister(request->graph_handle()); s = session->graph_mgr()->Deregister(request->graph_handle());
} }
done(s); done(s);
@ -218,14 +218,14 @@ void Worker::DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request,
done(errors::Aborted("Call was aborted")); done(errors::Aborted("Call was aborted"));
return; return;
} }
session->graph_mgr->ExecuteAsync( session->graph_mgr()->ExecuteAsync(
request->graph_handle(), step_id, session.get(), request->exec_opts(), request->graph_handle(), step_id, session.get(), request->exec_opts(),
collector, response, cm, in, collector, response, cm, in,
[this, step_id, response, session, cm, out, token, collector, [this, step_id, response, session, cm, out, token, collector,
profiler_session, opts, done](const Status& status) { profiler_session, opts, done](const Status& status) {
Status s = status; Status s = status;
if (s.ok()) { if (s.ok()) {
s = session->graph_mgr->RecvOutputs(step_id, out); s = session->graph_mgr()->RecvOutputs(step_id, out);
} }
opts->ClearCancelCallback(); opts->ClearCancelCallback();
@ -311,7 +311,7 @@ void Worker::DoPartialRunGraph(CallOptions* opts,
token = cancellation_manager_.get_cancellation_token(); token = cancellation_manager_.get_cancellation_token();
cancellation_manager_.RegisterCallback(token, cancellation_manager_.RegisterCallback(token,
[cm]() { cm->StartCancel(); }); [cm]() { cm->StartCancel(); });
session->graph_mgr->ExecuteAsync( session->graph_mgr()->ExecuteAsync(
graph_handle, step_id, session.get(), request->exec_opts(), graph_handle, step_id, session.get(), request->exec_opts(),
nullptr /* collector */, nullptr /* response */, cm, in, nullptr /* collector */, nullptr /* response */, cm, in,
[this, token, step_id, session](Status s) { [this, token, step_id, session](Status s) {
@ -320,14 +320,14 @@ void Worker::DoPartialRunGraph(CallOptions* opts,
}); });
} else { } else {
// Send the partial run's new inputs. // Send the partial run's new inputs.
s = session->graph_mgr->SendInputs(step_id, in); s = session->graph_mgr()->SendInputs(step_id, in);
if (!s.ok()) { if (!s.ok()) {
finish(s); finish(s);
return; return;
} }
} }
session->graph_mgr->RecvOutputsAsync( session->graph_mgr()->RecvOutputsAsync(
step_id, out, [this, out, request, response, step_id, finish](Status s) { step_id, out, [this, out, request, response, step_id, finish](Status s) {
if (s.ok()) { if (s.ok()) {
// Construct and return the resp. // Construct and return the resp.
@ -442,7 +442,7 @@ Status Worker::PrepareRecvTensor(const Rendezvous::ParsedKey& parsed,
"RecvTensor expects a different device incarnation: ", "RecvTensor expects a different device incarnation: ",
parsed.src_incarnation, " vs. ", (*src_dev)->attributes().incarnation(), parsed.src_incarnation, " vs. ", (*src_dev)->attributes().incarnation(),
". Your worker job (\"", ". Your worker job (\"",
env_->session_mgr->LegacySession()->worker_name, env_->session_mgr->LegacySession()->worker_name(),
"\") was probably restarted. Check your " "\") was probably restarted. Check your "
"worker job for the reason why it was restarted."); "worker job for the reason why it was restarted.");
} }

View File

@ -109,11 +109,11 @@ WorkerSession::WorkerSession(
std::unique_ptr<WorkerCacheInterface> worker_cache, std::unique_ptr<WorkerCacheInterface> worker_cache,
std::unique_ptr<DeviceMgr> device_mgr, std::unique_ptr<GraphMgr> graph_mgr, std::unique_ptr<DeviceMgr> device_mgr, std::unique_ptr<GraphMgr> graph_mgr,
std::unique_ptr<DynamicDeviceMgr> remote_device_mgr) std::unique_ptr<DynamicDeviceMgr> remote_device_mgr)
: session_name(session_name), : session_name_(session_name),
worker_name(worker_name), worker_name_(worker_name),
worker_cache(new WorkerFreeListCache(std::move(worker_cache))), worker_cache_(new WorkerFreeListCache(std::move(worker_cache))),
graph_mgr(std::move(graph_mgr)), graph_mgr_(std::move(graph_mgr)),
cluster_flr(new ClusterFunctionLibraryRuntime( cluster_flr_(new ClusterFunctionLibraryRuntime(
this, !session_name.empty(), this, !session_name.empty(),
remote_device_mgr ? remote_device_mgr.get() : nullptr)), remote_device_mgr ? remote_device_mgr.get() : nullptr)),
device_mgr_(std::move(device_mgr)), device_mgr_(std::move(device_mgr)),
@ -142,12 +142,12 @@ WorkerSession::WorkerSession(
std::unique_ptr<WorkerCacheInterface> worker_cache, std::unique_ptr<WorkerCacheInterface> worker_cache,
DeviceMgr* borrowed_device_mgr, std::unique_ptr<GraphMgr> graph_mgr, DeviceMgr* borrowed_device_mgr, std::unique_ptr<GraphMgr> graph_mgr,
std::unique_ptr<DynamicDeviceMgr> remote_device_mgr) std::unique_ptr<DynamicDeviceMgr> remote_device_mgr)
: session_name(session_name), : session_name_(session_name),
worker_name(worker_name), worker_name_(worker_name),
worker_cache(new WorkerFreeListCache(std::move(worker_cache))), worker_cache_(new WorkerFreeListCache(std::move(worker_cache))),
graph_mgr(std::move(graph_mgr)), graph_mgr_(std::move(graph_mgr)),
cluster_flr(new ClusterFunctionLibraryRuntime(this, !session_name.empty(), cluster_flr_(new ClusterFunctionLibraryRuntime(
remote_device_mgr.get())), this, !session_name.empty(), remote_device_mgr.get())),
device_mgr_(nullptr), device_mgr_(nullptr),
borrowed_device_mgr_(borrowed_device_mgr), borrowed_device_mgr_(borrowed_device_mgr),
remote_device_mgr_(std::move(remote_device_mgr)) { remote_device_mgr_(std::move(remote_device_mgr)) {
@ -159,8 +159,8 @@ WorkerSession::WorkerSession(
} }
WorkerSession::~WorkerSession() { WorkerSession::~WorkerSession() {
if (graph_mgr) { if (graph_mgr_) {
Status s = graph_mgr->DeregisterAll(); Status s = graph_mgr_->DeregisterAll();
if (!s.ok()) { if (!s.ok()) {
LOG(WARNING) << "Error during worker session deletion: " << s; LOG(WARNING) << "Error during worker session deletion: " << s;
} }

View File

@ -30,16 +30,8 @@ class GraphMgr;
class WorkerCacheInterface; class WorkerCacheInterface;
// WorkerSession encapsulates all of the state relating to a given session. // WorkerSession encapsulates all of the state relating to a given session.
struct WorkerSession { class WorkerSession {
// The name of the session. public:
const string session_name;
// The name of the worker. E.g., /job:mnist/replica:0/task:1.
const string worker_name;
// Object from which WorkerInterface instances can be obtained.
const std::unique_ptr<WorkerCacheInterface> worker_cache;
// Collection of local devices. These devices are typically // Collection of local devices. These devices are typically
// RenamedDevices in all except the SessionMgr.legacy_session_ and // RenamedDevices in all except the SessionMgr.legacy_session_ and
// sessions created with `isolate_session_state == false`. In the // sessions created with `isolate_session_state == false`. In the
@ -51,13 +43,15 @@ struct WorkerSession {
DynamicDeviceMgr* remote_device_mgr() { return remote_device_mgr_.get(); } DynamicDeviceMgr* remote_device_mgr() { return remote_device_mgr_.get(); }
// graph_mgr keeps track of the registered graphs of this session. const string& session_name() const { return session_name_; }
// const string& worker_name() const { return worker_name_; }
// Note: graph_mgr must be deleted before rendezvous_mgr!
// Note: graph_mgr must be deleted before device_mgr!
const std::unique_ptr<GraphMgr> graph_mgr;
std::unique_ptr<ClusterFunctionLibraryRuntime> cluster_flr; WorkerCacheInterface* worker_cache() const { return worker_cache_.get(); }
GraphMgr* graph_mgr() const { return graph_mgr_.get(); }
ClusterFunctionLibraryRuntime* cluster_flr() const {
return cluster_flr_.get();
}
WorkerSession(const string& session_name, const string& worker_name, WorkerSession(const string& session_name, const string& worker_name,
std::unique_ptr<WorkerCacheInterface> worker_cache, std::unique_ptr<WorkerCacheInterface> worker_cache,
@ -80,6 +74,23 @@ struct WorkerSession {
std::unique_ptr<GraphMgr> graph_mgr, std::unique_ptr<GraphMgr> graph_mgr,
std::unique_ptr<DynamicDeviceMgr> remote_device_mgr); std::unique_ptr<DynamicDeviceMgr> remote_device_mgr);
// The name of the session.
const string session_name_;
// The name of the worker. E.g., /job:mnist/replica:0/task:1.
const string worker_name_;
// Object from which WorkerInterface instances can be obtained.
const std::unique_ptr<WorkerCacheInterface> worker_cache_;
// graph_mgr keeps track of the registered graphs of this session.
//
// Note: graph_mgr must be deleted before rendezvous_mgr!
// Note: graph_mgr must be deleted before device_mgr!
const std::unique_ptr<GraphMgr> graph_mgr_;
std::unique_ptr<ClusterFunctionLibraryRuntime> cluster_flr_;
const std::unique_ptr<DeviceMgr> device_mgr_; const std::unique_ptr<DeviceMgr> device_mgr_;
DeviceMgr* const borrowed_device_mgr_; // Not owned. DeviceMgr* const borrowed_device_mgr_; // Not owned.
const std::unique_ptr<DynamicDeviceMgr> remote_device_mgr_; const std::unique_ptr<DynamicDeviceMgr> remote_device_mgr_;