Garbage collect old WorkerSession when the restarted master task create new one.

PiperOrigin-RevId: 324643608
Change-Id: I10165604d7ae03b25f15a31676d90f62aa6181be
This commit is contained in:
Haoyu Zhang 2020-08-03 11:23:13 -07:00 committed by TensorFlower Gardener
parent cf43fd2af6
commit dbc843d6ec
7 changed files with 226 additions and 1 deletions

View File

@ -57,6 +57,7 @@ limitations under the License.
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
@ -1314,12 +1315,21 @@ Status MasterSession::CreateWorkerSessions(
}
});
string task_name;
string local_device_name;
DeviceNameUtils::SplitDeviceName(devices_->client_device()->name(),
&task_name, &local_device_name);
const int64 client_device_incarnation =
devices_->client_device()->attributes().incarnation();
Status status = Status::OK();
// Create all the workers & kick off the computations.
for (size_t i = 0; i < worker_names.size(); ++i) {
workers[i].name = &worker_names[i];
workers[i].worker = worker_cache->GetOrCreateWorker(worker_names[i]);
workers[i].request.set_session_handle(handle_);
workers[i].request.set_master_task(task_name);
workers[i].request.set_master_incarnation(client_device_incarnation);
if (session_opts_.config.share_cluster_devices_in_session() ||
session_opts_.config.experimental()
.share_cluster_devices_in_session()) {

View File

@ -62,11 +62,46 @@ Status SessionMgr::CreateSession(
const protobuf::RepeatedPtrField<DeviceAttributes>&
cluster_device_attributes,
bool isolate_session_state) {
return CreateSession(session, server_def, cluster_device_attributes,
isolate_session_state, /*master_task=*/"",
/*master_incarnation=*/0);
}
Status SessionMgr::CreateSession(
const string& session, const ServerDef& server_def,
const protobuf::RepeatedPtrField<DeviceAttributes>&
cluster_device_attributes,
bool isolate_session_state, string master_task, int64 master_incarnation) {
mutex_lock l(mu_);
if (session.empty()) {
return errors::InvalidArgument("Session must be non-empty.");
}
// For given master task name, check if one or more `WorkerSession`s have been
// created previously on this worker, and if so garbage collect the expired
// `WorkerSession`s. This happens when the master fails before sending
// `DeleteSession` requests, which can cause `WorkerSession`s to be leaked.
if (!master_task.empty()) {
auto it_range = master_to_associated_sessions_.equal_range(master_task);
if (it_range.first != it_range.second &&
it_range.first->second.master_incarnation != master_incarnation) {
LOG(INFO) << "When creating WorkerSession for master task " << master_task
<< ", found old WorkerSessions created by the same master task "
<< "with a different incarnation. These sessions will "
<< "be garbage collected. Current WorkerSession count: "
<< sessions_.size();
auto it = it_range.first;
while (it != it_range.second) {
auto session_it = sessions_.find(it->second.session_handle);
if (session_it != sessions_.end()) {
sessions_.erase(session_it);
}
it = master_to_associated_sessions_.erase(it);
}
}
}
WorkerCacheInterface* worker_cache = nullptr;
string worker_name;
if (server_def.cluster().job().empty()) {
@ -141,6 +176,10 @@ Status SessionMgr::CreateSession(
}
sessions_.insert(std::make_pair(session, std::move(worker_session)));
if (!master_task.empty()) {
MasterAssociatedSession s{master_incarnation, session};
master_to_associated_sessions_.emplace(master_task, s);
}
return Status::OK();
}

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/worker_session.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
#include "tensorflow/core/protobuf/worker.pb.h"
@ -53,6 +54,18 @@ class SessionMgr {
const protobuf::RepeatedPtrField<DeviceAttributes>& device_attributes,
bool isolate_session_state);
// Create WorkerSession from the master with the given `master_task` and
// `master_incarnation`. We first look for existing WorkerSessions associated
// with the specified master task. If there are sessions created by the same
// master but with a different incarnation, it indicates that the remote
// master has restarted before deleting the sessions on worker. When it
// happens, old sessions associated with the master will be automatically
// removed before the new session is created.
Status CreateSession(
const string& session, const ServerDef& server_def,
const protobuf::RepeatedPtrField<DeviceAttributes>& device_attributes,
bool isolate_session_state, string master_task, int64 master_incarnation);
void ResetDefaultWorkerCache(WorkerCacheInterface* worker_cache);
// Updates state (worker cache, devices) of worker session identified by
@ -107,6 +120,15 @@ class SessionMgr {
mutex mu_;
// A map from session identifier to internal session structure.
std::map<string, std::shared_ptr<WorkerSession>> sessions_ TF_GUARDED_BY(mu_);
// Incarnation and WorkerSession handle associated with a master task.
struct MasterAssociatedSession {
const int64 master_incarnation;
const string session_handle;
};
// A map from master task name to its associated worker sessions.
std::unordered_multimap<string, MasterAssociatedSession>
master_to_associated_sessions_ TF_GUARDED_BY(mu_);
};
} // namespace tensorflow

View File

@ -152,6 +152,90 @@ TEST_F(SessionMgrTest, CreateSessionIsolateSessionState) {
EXPECT_NE(devices_3[0]->resource_manager(), devices_4[0]->resource_manager());
}
TEST_F(SessionMgrTest, CreateSessionWithMasterName) {
ServerDef server_def;
server_def.set_job_name("worker");
server_def.set_task_index(3);
auto job = server_def.mutable_cluster()->add_job();
job->set_name("worker");
job->mutable_tasks()->insert({3, "localhost:3333"});
protobuf::RepeatedPtrField<DeviceAttributes> cluster_device_attributes;
const string master_name = "/job:master/replica:0/task:1";
const int64 old_incarnation = random::New64();
const int64 new_incarnation = random::New64();
// Allow multiple worker sessions to be created by the same master
string sess_handle1 = "test_session_handle_1";
TF_EXPECT_OK(mgr_.CreateSession(sess_handle1, server_def,
cluster_device_attributes, true, master_name,
old_incarnation));
string sess_handle2 = "test_session_handle_2";
TF_EXPECT_OK(mgr_.CreateSession(sess_handle2, server_def,
cluster_device_attributes, true, master_name,
old_incarnation));
std::shared_ptr<WorkerSession> session;
TF_EXPECT_OK(mgr_.WorkerSessionForSession(sess_handle1, &session));
EXPECT_NE(nullptr, session) << "Session for " << sess_handle1 << "was null";
TF_EXPECT_OK(mgr_.WorkerSessionForSession(sess_handle2, &session));
EXPECT_NE(nullptr, session) << "Session for " << sess_handle2 << "was null";
// When the master creates a WorkerSession with new incarnation, the old
// WorkerSessions should be garbage collected.
string sess_handle3 = "test_session_handle_3";
TF_EXPECT_OK(mgr_.CreateSession(sess_handle3, server_def,
cluster_device_attributes, true, master_name,
new_incarnation));
EXPECT_NE(mgr_.WorkerSessionForSession(sess_handle1, &session),
tensorflow::Status::OK())
<< "Session for " << sess_handle1
<< " should have been garbage collected.";
EXPECT_NE(mgr_.WorkerSessionForSession(sess_handle2, &session),
tensorflow::Status::OK())
<< "Session for " << sess_handle2
<< " should have been garbage collected.";
TF_EXPECT_OK(mgr_.WorkerSessionForSession(sess_handle3, &session));
EXPECT_NE(nullptr, session) << "Session for " << sess_handle3 << "was null";
TF_EXPECT_OK(mgr_.DeleteSession(sess_handle2));
TF_EXPECT_OK(mgr_.DeleteSession(sess_handle3));
}
TEST_F(SessionMgrTest, CreateSessionWithoutMasterName) {
ServerDef server_def;
server_def.set_job_name("worker");
server_def.set_task_index(3);
auto job = server_def.mutable_cluster()->add_job();
job->set_name("worker");
job->mutable_tasks()->insert({3, "localhost:3333"});
protobuf::RepeatedPtrField<DeviceAttributes> cluster_device_attributes;
// WorkerSession will NOT be garbage collected for empty master names.
string sess_handle1 = "test_session_handle_no_master_1";
TF_EXPECT_OK(mgr_.CreateSession(sess_handle1, server_def,
cluster_device_attributes, true, "", 0));
string sess_handle2 = "test_session_handle_no_master_2";
TF_EXPECT_OK(mgr_.CreateSession(sess_handle2, server_def,
cluster_device_attributes, true, "", 0));
std::shared_ptr<WorkerSession> session;
TF_EXPECT_OK(mgr_.WorkerSessionForSession(sess_handle1, &session));
EXPECT_NE(nullptr, session) << "Session for " << sess_handle1 << "was null";
TF_EXPECT_OK(mgr_.WorkerSessionForSession(sess_handle2, &session));
EXPECT_NE(nullptr, session) << "Session for " << sess_handle2 << "was null";
TF_EXPECT_OK(mgr_.DeleteSession(sess_handle1));
TF_EXPECT_OK(mgr_.DeleteSession(sess_handle2));
}
TEST_F(SessionMgrTest, LegacySession) {
string session_handle = "";
std::shared_ptr<WorkerSession> session;

View File

@ -53,7 +53,8 @@ void Worker::CreateWorkerSessionAsync(const CreateWorkerSessionRequest* request,
StatusCallback done) {
Status s = env_->session_mgr->CreateSession(
request->session_handle(), request->server_def(),
request->cluster_device_attributes(), request->isolate_session_state());
request->cluster_device_attributes(), request->isolate_session_state(),
request->master_task(), request->master_incarnation());
done(s);
}

View File

@ -70,6 +70,17 @@ message CreateWorkerSessionRequest {
// The device attributes of all the devices in the cluster.
repeated DeviceAttributes cluster_device_attributes = 4;
// The master task name from which the request is sent.
string master_task = 5;
// The incarnation ID of the master task local CPU device.
// If the target worker already has a WorkerSession created previously with
// the same master task name but a different incarnation, it usually indicates
// that the previous master failed before deleting the WorkerSession on the
// worker. To prevent memory leaks, the worker should garbage collect the old
// WorkerSessions.
int64 master_incarnation = 6;
}
message CreateWorkerSessionResponse {}

View File

@ -22,6 +22,7 @@ import time
import numpy as np
from tensorflow.core.protobuf import cluster_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import tensorflow_server_pb2
from tensorflow.python.client import session
@ -202,6 +203,63 @@ class GrpcServerTest(test.TestCase):
self.assertEqual(0.1, server.server_def.default_session_config.gpu_options.
per_process_gpu_memory_fraction)
def testRestartedMaster(self):
master_old = server_lib.Server.create_local_server()
master_new = server_lib.Server.create_local_server()
worker = self._cached_server
def get_cluster_def(master, worker):
cluster_def = cluster_pb2.ClusterDef()
job = cluster_def.job.add()
job.name = "master"
job.tasks[0] = master.target[len("grpc://"):]
job = cluster_def.job.add()
job.name = "worker"
job.tasks[0] = worker.target[len("grpc://"):]
return cluster_def
def check_session_devices(sess):
# Make sure we have the correct set of cluster devices
devices = sess.list_devices()
device_names = set(d.name for d in devices)
self.assertIn("/job:master/replica:0/task:0/device:CPU:0", device_names)
self.assertIn("/job:worker/replica:0/task:0/device:CPU:0", device_names)
with ops.Graph().as_default():
# Construct a simple graph that runs ops on remote worker
with ops.device("/job:worker/replica:0/task:0/device:CPU:0"):
a = constant_op.constant([1.0])
b = a + a
config = config_pb2.ConfigProto(
cluster_def=get_cluster_def(master_old, worker))
sess_old = session.Session(master_old.target, config=config)
check_session_devices(sess_old)
# Create a session with the new master and the worker.
# The new master has the same task name ('/job:master/replica:0/task:0')
# as the old master, but is initiated from a different server thus has a
# different incarnation. This triggers the WorkerSession on worker with
# the old master incarnation to be garbage collected.
config = config_pb2.ConfigProto(
cluster_def=get_cluster_def(master_new, worker))
sess_new = session.Session(master_new.target, config=config)
check_session_devices(sess_new)
# Running on worker with the new session should work as expected
v = sess_new.run(b)
self.assertAllEqual(v, [2.0])
# Running on worker with the old session should raise an exception since
# the WorkerSession of the old session has been garbage collected
with self.assertRaisesRegex(errors_impl.AbortedError,
"Session handle is not found"):
sess_old.run(b)
sess_old.close()
sess_new.close()
def testInvalidHostname(self):
with self.assertRaisesRegex(errors_impl.InvalidArgumentError, "port"):
_ = server_lib.Server(