Avoid re-creating cluster FLR and process FLR when updating worker context.

PiperOrigin-RevId: 308153596
Change-Id: I2f8bc377213855fd5eeb6307fd1a2e74d26f4140
Authored by Haoyu Zhang on 2020-04-23 16:31:01 -07:00; committed by TensorFlower Gardener
parent 3e2ea0fcbe
commit 70a0ea2da3
6 changed files with 30 additions and 51 deletions
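At a glance, the change shows up in the updated declarations below (copied verbatim from the EagerContext header hunk in this diff); this is only a convenience summary of the diff, not a separate API:

    Status UpdateRemoteMaster(
        uint64 context_id,
        std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
        const std::vector<string>& add_remote_contexts,
        const std::vector<string>& remove_remote_contexts);

    Status UpdateRemoteWorker(
        std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
        const std::vector<string>& remote_contexts, uint64 context_id);

The WorkerEnv, Rendezvous, DeviceMgr, and DistributedFunctionLibraryRuntime parameters are dropped because an update now reuses the cluster FLR and process FLR that the context (or, for a worker, its WorkerSession) already owns, instead of re-creating them on every cluster change.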

@@ -611,13 +611,12 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
     }
   }
-  tensorflow::RemoteRendezvous* r =
-      grpc_server->worker_env()->rendezvous_mgr->Find(context_id);
   auto session_name = tensorflow::strings::StrCat("eager_", context_id);
-  auto* device_mgr = grpc_server->worker_env()->device_mgr;
-  std::shared_ptr<tensorflow::WorkerSession> worker_session;
   if (reset_context) {
+    tensorflow::RemoteRendezvous* r =
+        grpc_server->worker_env()->rendezvous_mgr->Find(context_id);
+    auto* device_mgr = grpc_server->worker_env()->device_mgr;
+    std::shared_ptr<tensorflow::WorkerSession> worker_session;
     TF_RETURN_IF_ERROR(grpc_server->worker_env()->session_mgr->CreateSession(
         session_name, server_def, base_request.cluster_device_attributes(),
         true));
@@ -647,10 +646,10 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
     LOG_AND_RETURN_IF_ERROR(
         grpc_server->worker_env()->session_mgr->UpdateSession(
            session_name, server_def, base_request.cluster_device_attributes(),
-           true));
-    LOG_AND_RETURN_IF_ERROR(context->UpdateRemoteMaster(
-        grpc_server->worker_env(), std::move(remote_eager_workers),
-        added_workers, removed_workers, context_id, r));
+           /*isolate_session_state=*/true));
+    LOG_AND_RETURN_IF_ERROR(
+        context->UpdateRemoteMaster(context_id, std::move(remote_eager_workers),
+                                    added_workers, removed_workers));
   }
 #undef LOG_AND_RETURN_IF_ERROR


@@ -1096,11 +1096,10 @@ Status EagerContext::InitializeRemoteMaster(
 }
 Status EagerContext::UpdateRemoteMaster(
-    WorkerEnv* worker_env,
+    uint64 context_id,
     std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
     const std::vector<string>& add_remote_contexts,
-    const std::vector<string>& remove_remote_contexts, uint64 context_id,
-    Rendezvous* r) {
+    const std::vector<string>& remove_remote_contexts) {
   {
     tf_shared_lock l(remote_state_mu_);
     if (context_id != context_id_) {
@@ -1136,9 +1135,6 @@ Status EagerContext::UpdateRemoteMaster(
     mutex_lock l(remote_state_mu_);
     context_view_id_++;
-    worker_env_ = worker_env;
-    if (rendezvous_ != nullptr) rendezvous_->Unref();
-    rendezvous_ = r;
     remote_eager_workers_ = std::move(remote_eager_workers);
     pflr_->InitializeDeviceSet();
     InitPrioritizedDeviceTypeList();
@@ -1338,11 +1334,8 @@ Status EagerContext::InitializeRemoteWorker(
 }
 Status EagerContext::UpdateRemoteWorker(
-    const DeviceMgr* worker_session_device_mgr,
     std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
-    DynamicDeviceMgr* remote_device_mgr,
-    const std::vector<string>& remote_contexts, uint64 context_id,
-    DistributedFunctionLibraryRuntime* cluster_flr) {
+    const std::vector<string>& remote_contexts, uint64 context_id) {
   {
     mutex_lock l(remote_state_mu_);
     if (context_id != context_id_) {
@@ -1352,15 +1345,20 @@ Status EagerContext::UpdateRemoteWorker(
           " but current id = ", context_id_);
     }
     context_view_id_++;
+    remote_contexts_ = remote_contexts;
+    remote_eager_workers_ = std::move(remote_eager_workers);
+    InitPrioritizedDeviceTypeList();
+    pflr_->InitializeDeviceSet();
   }
-  remote_contexts_ = remote_contexts;
-  remote_eager_workers_ = std::move(remote_eager_workers);
-  ResetClusterFLR(cluster_flr);
-  remote_device_manager_.Reset(remote_device_mgr);
-  InitPrioritizedDeviceTypeList();
+  // No need to update remote_device_manager_ since it's not owned for remote
+  // worker context (owned by the corresponding worker session).
+  if (remote_device_manager_.Owned()) {
+    return errors::FailedPrecondition(
+        "EagerContext::UpdateRemoteWorker failed because the context was "
+        "initialized as a master context.");
+  }
   ClearCachesAndThreadExecutors();
   default_executor_.ClearError();
@@ -1370,13 +1368,6 @@ Status EagerContext::UpdateRemoteWorker(
       entry.second->ClearError();
     }
   }
-  SessionOptions options = SessionOptions();
-  const auto* config = pflr_->config();
-  ResetPFLR(worker_session_device_mgr, options.env, config,
-            TF_GRAPH_DEF_VERSION, FuncLibDef(),
-            config->graph_options().optimizer_options(), thread_pool_.get(),
-            cluster_flr_.Get(), custom_kernel_creator_);
   return Status::OK();
 }
 #endif  // !IS_MOBILE_PLATFORM


@@ -385,11 +385,10 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted {
   // can still be accessed, and will automatically register existing functions
   // if there are newly added hosts.
   Status UpdateRemoteMaster(
-      WorkerEnv* worker_env,
+      uint64 context_id,
       std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
       const std::vector<string>& add_remote_contexts,
-      const std::vector<string>& remove_remote_contexts, uint64 context_id,
-      Rendezvous* r);
+      const std::vector<string>& remove_remote_contexts);
   // Similar with InitializeRemoteMaster but this context will not kill remote
   // contexts in shutdown.
@@ -407,11 +406,8 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted {
   // Similar with InitializeRemoteWorker but will reuse existing context and
   // increment context_view_id.
   Status UpdateRemoteWorker(
-      const DeviceMgr* worker_session_device_mgr,
       std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
-      DynamicDeviceMgr* remote_device_mgr,
-      const std::vector<string>& remote_contexts, uint64 context_id,
-      DistributedFunctionLibraryRuntime* cluster_flr);
+      const std::vector<string>& remote_contexts, uint64 context_id);
   Status StoreCollectiveOpsServer(
       std::unique_ptr<ServerInterface> new_server, DeviceMgr* device_mgr,

@@ -286,14 +286,9 @@ Status EagerServiceImpl::UpdateContext(const UpdateContextRequest* request,
   TF_RETURN_IF_ERROR(worker_session->worker_cache()->GetEagerClientCache(
       &remote_eager_workers));
-  DistributedFunctionLibraryRuntime* cluster_flr =
-      eager::CreateClusterFLR(request->context_id(), ctx, worker_session.get());
   ctx->ClearCachesAndThreadExecutors();
-  Status s = ctx->UpdateRemoteWorker(
-      device_mgr, std::move(remote_eager_workers),
-      worker_session->remote_device_mgr(), remote_workers,
-      request->context_id(), cluster_flr);
+  Status s = ctx->UpdateRemoteWorker(std::move(remote_eager_workers),
+                                     remote_workers, request->context_id());
   if (!s.ok()) {
     VLOG(1) << "EagerContext::UpdateRemoteWorker failed with " << s.ToString();
     return s;

@@ -76,7 +76,8 @@ void RemoteExecuteNode::RunAsync(StatusCallback done) {
       if (!s.ok()) {
         LOG(ERROR) << "Ignoring an error encountered when setting "
                       "remote shape of tensor handle: "
-                   << retvals[i] << " with status: " << status.ToString()
+                   << retvals[i]
+                   << " with execute status: " << status.ToString()
                    << " and SetRemoteShape status: " << s.ToString()
                    << "\nThis should never happen. "
                       "Please file an issue with the TensorFlow Team.";


@@ -134,9 +134,6 @@ Status WorkerSession::UpdateWorkerCacheAndDevices(
   TF_RETURN_IF_ERROR(remote_device_mgr_->RemoveDevices(removed_remote_devices));
   TF_RETURN_IF_ERROR(
       remote_device_mgr_->AddDevices(std::move(added_remote_devices)));
-  cluster_flr_ = std::unique_ptr<ClusterFunctionLibraryRuntime>(
-      new ClusterFunctionLibraryRuntime(this, !session_name_.empty(),
-                                        remote_device_mgr()));
   return Status::OK();
 }