Fix a bug where it sets context_id_ to 0 when closing any remote contexts.

The context_id_ should only be set to an invalid value (0) if it's closing and clearing all remote contexts.

PiperOrigin-RevId: 275535526
Change-Id: I069e4724c64ef0f3ca1782cf8b94bcf98e14a36a
This commit is contained in:
Haoyu Zhang 2019-10-18 13:32:13 -07:00 committed by TensorFlower Gardener
parent a450c64d09
commit 479131ef80
2 changed files with 17 additions and 10 deletions

View File

@ -214,14 +214,6 @@ bool EagerContext::MirrorTensors() const {
#if !defined(IS_MOBILE_PLATFORM)
void EagerContext::CloseAndClearAllRemoteContexts() {
CloseRemoteContexts(remote_contexts_);
remote_contexts_.clear();
}
void EagerContext::CloseRemoteContexts(
const std::vector<string>& remote_contexts) {
// Close all remote contexts.
eager::CloseContextRequest request;
uint64 context_id;
{
mutex_lock l(remote_state_mu_);
@ -229,6 +221,14 @@ void EagerContext::CloseRemoteContexts(
context_id = context_id_;
context_id_ = kInvalidContextId;
}
CloseRemoteContexts(remote_contexts_, context_id);
remote_contexts_.clear();
}
void EagerContext::CloseRemoteContexts(
const std::vector<string>& remote_contexts, uint64 context_id) {
// Close all remote contexts.
eager::CloseContextRequest request;
request.set_context_id(context_id);
// Setting context_id to a new value can avoid us issuing DestroyTensorHandle
// request to closed remote workers.
@ -763,7 +763,13 @@ Status EagerContext::UpdateRemoteMaster(
}
if (!remove_remote_contexts.empty()) {
CloseRemoteContexts(remove_remote_contexts);
// N.B. remove_remote_contexts include both removed and replaced workers. It
// is safe to send CloseContextRequest to them using the old copy of eager
// client cache (i.e., `remote_eager_workers_`) because the replaced workers
// will be resolved to the old eager clients. Thus, it correctly closes
// contexts on workers that are replaced by new ones. It must be called
// before overwriting `remote_eager_workers_` in current master context.
CloseRemoteContexts(remove_remote_contexts, context_id);
for (const string& remote_context : remove_remote_contexts) {
remote_contexts_.erase(
std::remove(remote_contexts_.begin(), remote_contexts_.end(),

View File

@ -460,7 +460,8 @@ class EagerContext : public core::RefCounted {
#if !defined(IS_MOBILE_PLATFORM)
void CloseAndClearAllRemoteContexts();
void CloseRemoteContexts(const std::vector<string>& remote_contexts);
void CloseRemoteContexts(const std::vector<string>& remote_contexts,
uint64 context_id);
Status SetMasterContextState(
std::unique_ptr<ServerInterface> server, WorkerEnv* worker_env,