Avoid re-creating cluster FLR and process FLR when updating worker context.
PiperOrigin-RevId: 308153596 Change-Id: I2f8bc377213855fd5eeb6307fd1a2e74d26f4140
This commit is contained in:
parent
3e2ea0fcbe
commit
70a0ea2da3
tensorflow
c/eager
core
common_runtime/eager
distributed_runtime
@ -611,13 +611,12 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
||||
}
|
||||
}
|
||||
|
||||
tensorflow::RemoteRendezvous* r =
|
||||
grpc_server->worker_env()->rendezvous_mgr->Find(context_id);
|
||||
auto session_name = tensorflow::strings::StrCat("eager_", context_id);
|
||||
auto* device_mgr = grpc_server->worker_env()->device_mgr;
|
||||
std::shared_ptr<tensorflow::WorkerSession> worker_session;
|
||||
|
||||
if (reset_context) {
|
||||
tensorflow::RemoteRendezvous* r =
|
||||
grpc_server->worker_env()->rendezvous_mgr->Find(context_id);
|
||||
auto* device_mgr = grpc_server->worker_env()->device_mgr;
|
||||
std::shared_ptr<tensorflow::WorkerSession> worker_session;
|
||||
TF_RETURN_IF_ERROR(grpc_server->worker_env()->session_mgr->CreateSession(
|
||||
session_name, server_def, base_request.cluster_device_attributes(),
|
||||
true));
|
||||
@ -647,10 +646,10 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
||||
LOG_AND_RETURN_IF_ERROR(
|
||||
grpc_server->worker_env()->session_mgr->UpdateSession(
|
||||
session_name, server_def, base_request.cluster_device_attributes(),
|
||||
true));
|
||||
LOG_AND_RETURN_IF_ERROR(context->UpdateRemoteMaster(
|
||||
grpc_server->worker_env(), std::move(remote_eager_workers),
|
||||
added_workers, removed_workers, context_id, r));
|
||||
/*isolate_session_state=*/true));
|
||||
LOG_AND_RETURN_IF_ERROR(
|
||||
context->UpdateRemoteMaster(context_id, std::move(remote_eager_workers),
|
||||
added_workers, removed_workers));
|
||||
}
|
||||
#undef LOG_AND_RETURN_IF_ERROR
|
||||
|
||||
|
@ -1096,11 +1096,10 @@ Status EagerContext::InitializeRemoteMaster(
|
||||
}
|
||||
|
||||
Status EagerContext::UpdateRemoteMaster(
|
||||
WorkerEnv* worker_env,
|
||||
uint64 context_id,
|
||||
std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
|
||||
const std::vector<string>& add_remote_contexts,
|
||||
const std::vector<string>& remove_remote_contexts, uint64 context_id,
|
||||
Rendezvous* r) {
|
||||
const std::vector<string>& remove_remote_contexts) {
|
||||
{
|
||||
tf_shared_lock l(remote_state_mu_);
|
||||
if (context_id != context_id_) {
|
||||
@ -1136,9 +1135,6 @@ Status EagerContext::UpdateRemoteMaster(
|
||||
mutex_lock l(remote_state_mu_);
|
||||
context_view_id_++;
|
||||
|
||||
worker_env_ = worker_env;
|
||||
if (rendezvous_ != nullptr) rendezvous_->Unref();
|
||||
rendezvous_ = r;
|
||||
remote_eager_workers_ = std::move(remote_eager_workers);
|
||||
pflr_->InitializeDeviceSet();
|
||||
InitPrioritizedDeviceTypeList();
|
||||
@ -1338,11 +1334,8 @@ Status EagerContext::InitializeRemoteWorker(
|
||||
}
|
||||
|
||||
Status EagerContext::UpdateRemoteWorker(
|
||||
const DeviceMgr* worker_session_device_mgr,
|
||||
std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
|
||||
DynamicDeviceMgr* remote_device_mgr,
|
||||
const std::vector<string>& remote_contexts, uint64 context_id,
|
||||
DistributedFunctionLibraryRuntime* cluster_flr) {
|
||||
const std::vector<string>& remote_contexts, uint64 context_id) {
|
||||
{
|
||||
mutex_lock l(remote_state_mu_);
|
||||
if (context_id != context_id_) {
|
||||
@ -1352,15 +1345,20 @@ Status EagerContext::UpdateRemoteWorker(
|
||||
" but current id = ", context_id_);
|
||||
}
|
||||
context_view_id_++;
|
||||
|
||||
remote_contexts_ = remote_contexts;
|
||||
remote_eager_workers_ = std::move(remote_eager_workers);
|
||||
InitPrioritizedDeviceTypeList();
|
||||
pflr_->InitializeDeviceSet();
|
||||
}
|
||||
|
||||
remote_contexts_ = remote_contexts;
|
||||
|
||||
remote_eager_workers_ = std::move(remote_eager_workers);
|
||||
ResetClusterFLR(cluster_flr);
|
||||
|
||||
remote_device_manager_.Reset(remote_device_mgr);
|
||||
InitPrioritizedDeviceTypeList();
|
||||
// No need to update remote_device_manager_ since it's not owned for remote
|
||||
// worker context (owned by the corresponding worker session).
|
||||
if (remote_device_manager_.Owned()) {
|
||||
return errors::FailedPrecondition(
|
||||
"EagerContext::UpdateRemoteWorker failed because the context was "
|
||||
"initialized as a master context.");
|
||||
}
|
||||
|
||||
ClearCachesAndThreadExecutors();
|
||||
default_executor_.ClearError();
|
||||
@ -1370,13 +1368,6 @@ Status EagerContext::UpdateRemoteWorker(
|
||||
entry.second->ClearError();
|
||||
}
|
||||
}
|
||||
|
||||
SessionOptions options = SessionOptions();
|
||||
const auto* config = pflr_->config();
|
||||
ResetPFLR(worker_session_device_mgr, options.env, config,
|
||||
TF_GRAPH_DEF_VERSION, FuncLibDef(),
|
||||
config->graph_options().optimizer_options(), thread_pool_.get(),
|
||||
cluster_flr_.Get(), custom_kernel_creator_);
|
||||
return Status::OK();
|
||||
}
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
|
@ -385,11 +385,10 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted {
|
||||
// can still be accessed, and will automatically register existing functions
|
||||
// if there are newly added hosts.
|
||||
Status UpdateRemoteMaster(
|
||||
WorkerEnv* worker_env,
|
||||
uint64 context_id,
|
||||
std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
|
||||
const std::vector<string>& add_remote_contexts,
|
||||
const std::vector<string>& remove_remote_contexts, uint64 context_id,
|
||||
Rendezvous* r);
|
||||
const std::vector<string>& remove_remote_contexts);
|
||||
|
||||
// Similar with InitializeRemoteMaster but this context will not kill remote
|
||||
// contexts in shutdown.
|
||||
@ -407,11 +406,8 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted {
|
||||
// Similar with InitializeRemoteWorker but will reuse existing context and
|
||||
// increment context_view_id.
|
||||
Status UpdateRemoteWorker(
|
||||
const DeviceMgr* worker_session_device_mgr,
|
||||
std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
|
||||
DynamicDeviceMgr* remote_device_mgr,
|
||||
const std::vector<string>& remote_contexts, uint64 context_id,
|
||||
DistributedFunctionLibraryRuntime* cluster_flr);
|
||||
const std::vector<string>& remote_contexts, uint64 context_id);
|
||||
|
||||
Status StoreCollectiveOpsServer(
|
||||
std::unique_ptr<ServerInterface> new_server, DeviceMgr* device_mgr,
|
||||
|
@ -286,14 +286,9 @@ Status EagerServiceImpl::UpdateContext(const UpdateContextRequest* request,
|
||||
TF_RETURN_IF_ERROR(worker_session->worker_cache()->GetEagerClientCache(
|
||||
&remote_eager_workers));
|
||||
|
||||
DistributedFunctionLibraryRuntime* cluster_flr =
|
||||
eager::CreateClusterFLR(request->context_id(), ctx, worker_session.get());
|
||||
|
||||
ctx->ClearCachesAndThreadExecutors();
|
||||
Status s = ctx->UpdateRemoteWorker(
|
||||
device_mgr, std::move(remote_eager_workers),
|
||||
worker_session->remote_device_mgr(), remote_workers,
|
||||
request->context_id(), cluster_flr);
|
||||
Status s = ctx->UpdateRemoteWorker(std::move(remote_eager_workers),
|
||||
remote_workers, request->context_id());
|
||||
if (!s.ok()) {
|
||||
VLOG(1) << "EagerContext::UpdateRemoteWorker failed with " << s.ToString();
|
||||
return s;
|
||||
|
@ -76,7 +76,8 @@ void RemoteExecuteNode::RunAsync(StatusCallback done) {
|
||||
if (!s.ok()) {
|
||||
LOG(ERROR) << "Ignoring an error encountered when setting "
|
||||
"remote shape of tensor handle: "
|
||||
<< retvals[i] << " with status: " << status.ToString()
|
||||
<< retvals[i]
|
||||
<< " with execute status: " << status.ToString()
|
||||
<< " and SetRemoteShape status: " << s.ToString()
|
||||
<< "\nThis should never happen. "
|
||||
"Please file an issue with the TensorFlow Team.";
|
||||
|
@ -134,9 +134,6 @@ Status WorkerSession::UpdateWorkerCacheAndDevices(
|
||||
TF_RETURN_IF_ERROR(remote_device_mgr_->RemoveDevices(removed_remote_devices));
|
||||
TF_RETURN_IF_ERROR(
|
||||
remote_device_mgr_->AddDevices(std::move(added_remote_devices)));
|
||||
cluster_flr_ = std::unique_ptr<ClusterFunctionLibraryRuntime>(
|
||||
new ClusterFunctionLibraryRuntime(this, !session_name_.empty(),
|
||||
remote_device_mgr()));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user