Garbage collect old WorkerSession when the restarted master task create new one.
PiperOrigin-RevId: 324643608 Change-Id: I10165604d7ae03b25f15a31676d90f62aa6181be
This commit is contained in:
parent
cf43fd2af6
commit
dbc843d6ec
tensorflow
core
distributed_runtime
protobuf
python/training
@ -57,6 +57,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/tracing.h"
|
||||
#include "tensorflow/core/public/session_options.h"
|
||||
#include "tensorflow/core/util/device_name_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
@ -1314,12 +1315,21 @@ Status MasterSession::CreateWorkerSessions(
|
||||
}
|
||||
});
|
||||
|
||||
string task_name;
|
||||
string local_device_name;
|
||||
DeviceNameUtils::SplitDeviceName(devices_->client_device()->name(),
|
||||
&task_name, &local_device_name);
|
||||
const int64 client_device_incarnation =
|
||||
devices_->client_device()->attributes().incarnation();
|
||||
|
||||
Status status = Status::OK();
|
||||
// Create all the workers & kick off the computations.
|
||||
for (size_t i = 0; i < worker_names.size(); ++i) {
|
||||
workers[i].name = &worker_names[i];
|
||||
workers[i].worker = worker_cache->GetOrCreateWorker(worker_names[i]);
|
||||
workers[i].request.set_session_handle(handle_);
|
||||
workers[i].request.set_master_task(task_name);
|
||||
workers[i].request.set_master_incarnation(client_device_incarnation);
|
||||
if (session_opts_.config.share_cluster_devices_in_session() ||
|
||||
session_opts_.config.experimental()
|
||||
.share_cluster_devices_in_session()) {
|
||||
|
@ -62,11 +62,46 @@ Status SessionMgr::CreateSession(
|
||||
const protobuf::RepeatedPtrField<DeviceAttributes>&
|
||||
cluster_device_attributes,
|
||||
bool isolate_session_state) {
|
||||
return CreateSession(session, server_def, cluster_device_attributes,
|
||||
isolate_session_state, /*master_task=*/"",
|
||||
/*master_incarnation=*/0);
|
||||
}
|
||||
|
||||
Status SessionMgr::CreateSession(
|
||||
const string& session, const ServerDef& server_def,
|
||||
const protobuf::RepeatedPtrField<DeviceAttributes>&
|
||||
cluster_device_attributes,
|
||||
bool isolate_session_state, string master_task, int64 master_incarnation) {
|
||||
mutex_lock l(mu_);
|
||||
if (session.empty()) {
|
||||
return errors::InvalidArgument("Session must be non-empty.");
|
||||
}
|
||||
|
||||
// For given master task name, check if one or more `WorkerSession`s have been
|
||||
// created previously on this worker, and if so garbage collect the expired
|
||||
// `WorkerSession`s. This happens when the master fails before sending
|
||||
// `DeleteSession` requests, which can cause `WorkerSession`s to be leaked.
|
||||
if (!master_task.empty()) {
|
||||
auto it_range = master_to_associated_sessions_.equal_range(master_task);
|
||||
if (it_range.first != it_range.second &&
|
||||
it_range.first->second.master_incarnation != master_incarnation) {
|
||||
LOG(INFO) << "When creating WorkerSession for master task " << master_task
|
||||
<< ", found old WorkerSessions created by the same master task "
|
||||
<< "with a different incarnation. These sessions will "
|
||||
<< "be garbage collected. Current WorkerSession count: "
|
||||
<< sessions_.size();
|
||||
|
||||
auto it = it_range.first;
|
||||
while (it != it_range.second) {
|
||||
auto session_it = sessions_.find(it->second.session_handle);
|
||||
if (session_it != sessions_.end()) {
|
||||
sessions_.erase(session_it);
|
||||
}
|
||||
it = master_to_associated_sessions_.erase(it);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
WorkerCacheInterface* worker_cache = nullptr;
|
||||
string worker_name;
|
||||
if (server_def.cluster().job().empty()) {
|
||||
@ -141,6 +176,10 @@ Status SessionMgr::CreateSession(
|
||||
}
|
||||
|
||||
sessions_.insert(std::make_pair(session, std::move(worker_session)));
|
||||
if (!master_task.empty()) {
|
||||
MasterAssociatedSession s{master_incarnation, session};
|
||||
master_to_associated_sessions_.emplace(master_task, s);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/distributed_runtime/worker_session.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/thread_annotations.h"
|
||||
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
|
||||
#include "tensorflow/core/protobuf/worker.pb.h"
|
||||
|
||||
@ -53,6 +54,18 @@ class SessionMgr {
|
||||
const protobuf::RepeatedPtrField<DeviceAttributes>& device_attributes,
|
||||
bool isolate_session_state);
|
||||
|
||||
// Create WorkerSession from the master with the given `master_task` and
|
||||
// `master_incarnation`. We first look for existing WorkerSessions associated
|
||||
// with the specified master task. If there are sessions created by the same
|
||||
// master but with a different incarnation, it indicates that the remote
|
||||
// master has restarted before deleting the sessions on worker. When it
|
||||
// happens, old sessions associated with the master will be automatically
|
||||
// removed before the new session is created.
|
||||
Status CreateSession(
|
||||
const string& session, const ServerDef& server_def,
|
||||
const protobuf::RepeatedPtrField<DeviceAttributes>& device_attributes,
|
||||
bool isolate_session_state, string master_task, int64 master_incarnation);
|
||||
|
||||
void ResetDefaultWorkerCache(WorkerCacheInterface* worker_cache);
|
||||
|
||||
// Updates state (worker cache, devices) of worker session identified by
|
||||
@ -107,6 +120,15 @@ class SessionMgr {
|
||||
mutex mu_;
|
||||
// A map from session identifier to internal session structure.
|
||||
std::map<string, std::shared_ptr<WorkerSession>> sessions_ TF_GUARDED_BY(mu_);
|
||||
|
||||
// Incarnation and WorkerSession handle associated with a master task.
|
||||
struct MasterAssociatedSession {
|
||||
const int64 master_incarnation;
|
||||
const string session_handle;
|
||||
};
|
||||
// A map from master task name to its associated worker sessions.
|
||||
std::unordered_multimap<string, MasterAssociatedSession>
|
||||
master_to_associated_sessions_ TF_GUARDED_BY(mu_);
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -152,6 +152,90 @@ TEST_F(SessionMgrTest, CreateSessionIsolateSessionState) {
|
||||
EXPECT_NE(devices_3[0]->resource_manager(), devices_4[0]->resource_manager());
|
||||
}
|
||||
|
||||
TEST_F(SessionMgrTest, CreateSessionWithMasterName) {
|
||||
ServerDef server_def;
|
||||
server_def.set_job_name("worker");
|
||||
server_def.set_task_index(3);
|
||||
auto job = server_def.mutable_cluster()->add_job();
|
||||
job->set_name("worker");
|
||||
job->mutable_tasks()->insert({3, "localhost:3333"});
|
||||
|
||||
protobuf::RepeatedPtrField<DeviceAttributes> cluster_device_attributes;
|
||||
|
||||
const string master_name = "/job:master/replica:0/task:1";
|
||||
const int64 old_incarnation = random::New64();
|
||||
const int64 new_incarnation = random::New64();
|
||||
|
||||
// Allow multiple worker sessions to be created by the same master
|
||||
string sess_handle1 = "test_session_handle_1";
|
||||
TF_EXPECT_OK(mgr_.CreateSession(sess_handle1, server_def,
|
||||
cluster_device_attributes, true, master_name,
|
||||
old_incarnation));
|
||||
string sess_handle2 = "test_session_handle_2";
|
||||
TF_EXPECT_OK(mgr_.CreateSession(sess_handle2, server_def,
|
||||
cluster_device_attributes, true, master_name,
|
||||
old_incarnation));
|
||||
|
||||
std::shared_ptr<WorkerSession> session;
|
||||
TF_EXPECT_OK(mgr_.WorkerSessionForSession(sess_handle1, &session));
|
||||
EXPECT_NE(nullptr, session) << "Session for " << sess_handle1 << "was null";
|
||||
|
||||
TF_EXPECT_OK(mgr_.WorkerSessionForSession(sess_handle2, &session));
|
||||
EXPECT_NE(nullptr, session) << "Session for " << sess_handle2 << "was null";
|
||||
|
||||
// When the master creates a WorkerSession with new incarnation, the old
|
||||
// WorkerSessions should be garbage collected.
|
||||
string sess_handle3 = "test_session_handle_3";
|
||||
TF_EXPECT_OK(mgr_.CreateSession(sess_handle3, server_def,
|
||||
cluster_device_attributes, true, master_name,
|
||||
new_incarnation));
|
||||
|
||||
EXPECT_NE(mgr_.WorkerSessionForSession(sess_handle1, &session),
|
||||
tensorflow::Status::OK())
|
||||
<< "Session for " << sess_handle1
|
||||
<< " should have been garbage collected.";
|
||||
|
||||
EXPECT_NE(mgr_.WorkerSessionForSession(sess_handle2, &session),
|
||||
tensorflow::Status::OK())
|
||||
<< "Session for " << sess_handle2
|
||||
<< " should have been garbage collected.";
|
||||
|
||||
TF_EXPECT_OK(mgr_.WorkerSessionForSession(sess_handle3, &session));
|
||||
EXPECT_NE(nullptr, session) << "Session for " << sess_handle3 << "was null";
|
||||
|
||||
TF_EXPECT_OK(mgr_.DeleteSession(sess_handle2));
|
||||
TF_EXPECT_OK(mgr_.DeleteSession(sess_handle3));
|
||||
}
|
||||
|
||||
TEST_F(SessionMgrTest, CreateSessionWithoutMasterName) {
|
||||
ServerDef server_def;
|
||||
server_def.set_job_name("worker");
|
||||
server_def.set_task_index(3);
|
||||
auto job = server_def.mutable_cluster()->add_job();
|
||||
job->set_name("worker");
|
||||
job->mutable_tasks()->insert({3, "localhost:3333"});
|
||||
|
||||
protobuf::RepeatedPtrField<DeviceAttributes> cluster_device_attributes;
|
||||
|
||||
// WorkerSession will NOT be garbage collected for empty master names.
|
||||
string sess_handle1 = "test_session_handle_no_master_1";
|
||||
TF_EXPECT_OK(mgr_.CreateSession(sess_handle1, server_def,
|
||||
cluster_device_attributes, true, "", 0));
|
||||
string sess_handle2 = "test_session_handle_no_master_2";
|
||||
TF_EXPECT_OK(mgr_.CreateSession(sess_handle2, server_def,
|
||||
cluster_device_attributes, true, "", 0));
|
||||
|
||||
std::shared_ptr<WorkerSession> session;
|
||||
TF_EXPECT_OK(mgr_.WorkerSessionForSession(sess_handle1, &session));
|
||||
EXPECT_NE(nullptr, session) << "Session for " << sess_handle1 << "was null";
|
||||
|
||||
TF_EXPECT_OK(mgr_.WorkerSessionForSession(sess_handle2, &session));
|
||||
EXPECT_NE(nullptr, session) << "Session for " << sess_handle2 << "was null";
|
||||
|
||||
TF_EXPECT_OK(mgr_.DeleteSession(sess_handle1));
|
||||
TF_EXPECT_OK(mgr_.DeleteSession(sess_handle2));
|
||||
}
|
||||
|
||||
TEST_F(SessionMgrTest, LegacySession) {
|
||||
string session_handle = "";
|
||||
std::shared_ptr<WorkerSession> session;
|
||||
|
@ -53,7 +53,8 @@ void Worker::CreateWorkerSessionAsync(const CreateWorkerSessionRequest* request,
|
||||
StatusCallback done) {
|
||||
Status s = env_->session_mgr->CreateSession(
|
||||
request->session_handle(), request->server_def(),
|
||||
request->cluster_device_attributes(), request->isolate_session_state());
|
||||
request->cluster_device_attributes(), request->isolate_session_state(),
|
||||
request->master_task(), request->master_incarnation());
|
||||
done(s);
|
||||
}
|
||||
|
||||
|
@ -70,6 +70,17 @@ message CreateWorkerSessionRequest {
|
||||
|
||||
// The device attributes of all the devices in the cluster.
|
||||
repeated DeviceAttributes cluster_device_attributes = 4;
|
||||
|
||||
// The master task name from which the request is sent.
|
||||
string master_task = 5;
|
||||
|
||||
// The incarnation ID of the master task local CPU device.
|
||||
// If the target worker already has a WorkerSession created previously with
|
||||
// the same master task name but a different incarnation, it usually indicates
|
||||
// that the previous master failed before deleting the WorkerSession on the
|
||||
// worker. To prevent memory leaks, the worker should garbage collect the old
|
||||
// WorkerSessions.
|
||||
int64 master_incarnation = 6;
|
||||
}
|
||||
|
||||
message CreateWorkerSessionResponse {}
|
||||
|
@ -22,6 +22,7 @@ import time
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.core.protobuf import cluster_pb2
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.core.protobuf import tensorflow_server_pb2
|
||||
from tensorflow.python.client import session
|
||||
@ -202,6 +203,63 @@ class GrpcServerTest(test.TestCase):
|
||||
self.assertEqual(0.1, server.server_def.default_session_config.gpu_options.
|
||||
per_process_gpu_memory_fraction)
|
||||
|
||||
def testRestartedMaster(self):
|
||||
master_old = server_lib.Server.create_local_server()
|
||||
master_new = server_lib.Server.create_local_server()
|
||||
worker = self._cached_server
|
||||
|
||||
def get_cluster_def(master, worker):
|
||||
cluster_def = cluster_pb2.ClusterDef()
|
||||
job = cluster_def.job.add()
|
||||
job.name = "master"
|
||||
job.tasks[0] = master.target[len("grpc://"):]
|
||||
job = cluster_def.job.add()
|
||||
job.name = "worker"
|
||||
job.tasks[0] = worker.target[len("grpc://"):]
|
||||
return cluster_def
|
||||
|
||||
def check_session_devices(sess):
|
||||
# Make sure we have the correct set of cluster devices
|
||||
devices = sess.list_devices()
|
||||
device_names = set(d.name for d in devices)
|
||||
self.assertIn("/job:master/replica:0/task:0/device:CPU:0", device_names)
|
||||
self.assertIn("/job:worker/replica:0/task:0/device:CPU:0", device_names)
|
||||
|
||||
with ops.Graph().as_default():
|
||||
# Construct a simple graph that runs ops on remote worker
|
||||
with ops.device("/job:worker/replica:0/task:0/device:CPU:0"):
|
||||
a = constant_op.constant([1.0])
|
||||
b = a + a
|
||||
|
||||
config = config_pb2.ConfigProto(
|
||||
cluster_def=get_cluster_def(master_old, worker))
|
||||
sess_old = session.Session(master_old.target, config=config)
|
||||
check_session_devices(sess_old)
|
||||
|
||||
# Create a session with the new master and the worker.
|
||||
# The new master has the same task name ('/job:master/replica:0/task:0')
|
||||
# as the old master, but is initiated from a different server thus has a
|
||||
# different incarnation. This triggers the WorkerSession on worker with
|
||||
# the old master incarnation to be garbage collected.
|
||||
|
||||
config = config_pb2.ConfigProto(
|
||||
cluster_def=get_cluster_def(master_new, worker))
|
||||
sess_new = session.Session(master_new.target, config=config)
|
||||
check_session_devices(sess_new)
|
||||
|
||||
# Running on worker with the new session should work as expected
|
||||
v = sess_new.run(b)
|
||||
self.assertAllEqual(v, [2.0])
|
||||
|
||||
# Running on worker with the old session should raise an exception since
|
||||
# the WorkerSession of the old session has been garbage collected
|
||||
with self.assertRaisesRegex(errors_impl.AbortedError,
|
||||
"Session handle is not found"):
|
||||
sess_old.run(b)
|
||||
|
||||
sess_old.close()
|
||||
sess_new.close()
|
||||
|
||||
def testInvalidHostname(self):
|
||||
with self.assertRaisesRegex(errors_impl.InvalidArgumentError, "port"):
|
||||
_ = server_lib.Server(
|
||||
|
Loading…
Reference in New Issue
Block a user