[tf.data service] Avoid holding locks during RPC calls.

This CL fixes a deadlock in which a worker holds its lock while making an RPC
to the master, while the master holds its own lock while making an RPC to the
worker. Each RPC handler needs the receiving server's lock in order to be
served, so the two calls block each other and neither side can make progress.
We avoid this by never holding a lock while performing an RPC. Concretely,
this CL narrows the master's locking so that the lock is released before
making the `ProcessTask` RPC to the worker.

This change shouldn't affect any functionality - it only reduces the scope of
some locking.

PiperOrigin-RevId: 317364346
Change-Id: I21e5ed8cdaced1192a89ffda4f8f93418e5dc4a5
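The pattern applied throughout the CL: snapshot whatever state the RPC needs while the lock is held, release the lock, then perform the RPC. Below is a minimal standalone sketch of that pattern, not the actual tf.data service code; the `Master`/`Worker` types here are hypothetical stand-ins and `std::mutex` stands in for TensorFlow's `mutex`:

#include <memory>
#include <mutex>
#include <string>
#include <vector>

struct Worker {
  std::string address;
  // Stand-in for the blocking ProcessTask RPC. The real call can re-enter
  // the master while the worker holds its own lock, so the master must not
  // issue it while holding mu_.
  void ProcessTask(int task_id) { (void)task_id; /* network call */ }
};

class Master {
 public:
  void AssignTaskToAllWorkers(int task_id) {
    std::vector<std::shared_ptr<Worker>> workers;
    {
      std::lock_guard<std::mutex> l(mu_);  // short critical section
      workers = workers_;  // snapshot: copies shared_ptrs, not Workers
    }
    // Lock released: the RPCs below can no longer close a lock cycle with a
    // worker that is simultaneously calling back into the master.
    for (const auto& worker : workers) {
      worker->ProcessTask(task_id);
    }
  }

 private:
  std::mutex mu_;
  std::vector<std::shared_ptr<Worker>> workers_;  // guarded by mu_
};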
commit 5ed030734c
parent bd98ba765a
tensorflow/core/data/service
@@ -61,7 +61,8 @@ Status DataServiceMasterImpl::RegisterWorker(
   VLOG(3) << "Received register worker request";
   mutex_lock l(mu_);
   int64 worker_id = next_worker_id_++;
-  workers_.emplace_back(worker_id, request->worker_address());
+  workers_.push_back(
+      std::make_shared<Worker>(worker_id, request->worker_address()));
   response->set_worker_id(worker_id);
 
   // Allocate tasks to the worker.
@@ -70,17 +71,18 @@ Status DataServiceMasterImpl::RegisterWorker(
     if (job->finished()) {
       continue;
     }
-    int64 task_id = CreateTask(job.get(), request->worker_address());
+    const Task& task = CreateTaskLocked(job.get(), request->worker_address());
 
     TaskDef* task_def = response->add_tasks();
     *task_def->mutable_dataset() =
         datasets_by_id_[job->dataset_id()]->dataset_def();
     task_def->set_dataset_id(job->dataset_id());
     task_def->set_job_id(job->job_id());
-    task_def->set_task_id(task_id);
+    task_def->set_task_id(task.task_id());
   }
 
-  VLOG(1) << "Registered worker " << workers_.back().DebugString();
+  VLOG(1) << "Registered worker at address " << request->worker_address()
+          << " with id " << worker_id;
   return Status::OK();
 }
 
@@ -145,7 +147,6 @@ Status DataServiceMasterImpl::CreateJob(const CreateJobRequest* request,
   VLOG(3) << "Received create job request for dataset id "
           << request->dataset_id();
   ProcessingMode processing_mode = ProcessingMode(request->processing_mode());
-  mutex_lock l(mu_);
   int64 job_id;
   TF_RETURN_IF_ERROR(CreateJob(request->dataset_id(), processing_mode,
                                absl::optional<std::string>(), &job_id));
@@ -161,25 +162,30 @@ Status DataServiceMasterImpl::GetOrCreateJob(
   VLOG(3) << "Received get or create job request for dataset id "
           << request->dataset_id() << " with name " << request->job_name()
           << " and index " << request->job_name_index();
-  mutex_lock l(mu_);
   NamedJobKey key(request->job_name(), request->job_name_index());
   ProcessingMode requested_processing_mode =
       ProcessingMode(request->processing_mode());
-  std::shared_ptr<Job>* job = gtl::FindOrNull(named_jobs_, key);
-  if (job != nullptr) {
-    TF_RETURN_IF_ERROR(ValidateMatchingJob(**job, requested_processing_mode,
-                                           request->dataset_id()));
-    int64 job_id = (*job)->job_id();
-    response->set_job_id(job_id);
-    VLOG(3) << "Found existing job for name=" << request->job_name()
-            << ", index=" << request->job_name_index()
-            << ". job_id: " << job_id;
-    return Status::OK();
+  {
+    mutex_lock l(mu_);
+    std::shared_ptr<Job>* job = gtl::FindOrNull(named_jobs_, key);
+    if (job != nullptr) {
+      TF_RETURN_IF_ERROR(ValidateMatchingJob(**job, requested_processing_mode,
+                                             request->dataset_id()));
+      int64 job_id = (*job)->job_id();
+      response->set_job_id(job_id);
+      VLOG(3) << "Found existing job for name=" << request->job_name()
+              << ", index=" << request->job_name_index()
+              << ". job_id: " << job_id;
+      return Status::OK();
+    }
   }
   int64 job_id;
   TF_RETURN_IF_ERROR(CreateJob(request->dataset_id(), requested_processing_mode,
                                request->job_name(), &job_id));
-  named_jobs_[key] = jobs_[job_id];
+  {
+    mutex_lock l(mu_);
+    named_jobs_[key] = jobs_[job_id];
+  }
   response->set_job_id(job_id);
   VLOG(3) << "Created job " << job_id << " for dataset "
           << request->dataset_id() << " and name " << request->job_name();
@@ -211,8 +217,7 @@ Status DataServiceMasterImpl::ValidateMatchingJob(
 Status DataServiceMasterImpl::CreateJob(int64 dataset_id,
                                         ProcessingMode processing_mode,
                                         absl::optional<std::string> job_name,
-                                        int64* out_job_id)
-    EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+                                        int64* out_job_id) LOCKS_EXCLUDED(mu_) {
   switch (processing_mode) {
     case ProcessingMode::PARALLEL_EPOCHS:
       break;
@@ -225,41 +230,64 @@ Status DataServiceMasterImpl::CreateJob(int64 dataset_id,
         ProcessingModeToString(processing_mode),
         " not recognized");
   }
-  if (!datasets_by_id_.contains(dataset_id)) {
-    return errors::NotFound("Dataset id: <", dataset_id, "> not found.");
+  std::shared_ptr<Job> job;
+  std::vector<std::shared_ptr<Worker>> workers;
+  {
+    mutex_lock l(mu_);
+    if (!datasets_by_id_.contains(dataset_id)) {
+      return errors::NotFound("Dataset id: <", dataset_id, "> not found.");
+    }
+
+    int64 job_id = next_job_id_++;
+    DCHECK(!jobs_.contains(job_id));
+    job = std::make_shared<Job>(job_id, dataset_id, processing_mode, job_name);
+    jobs_[job_id] = job;
+
+    // Copy workers_ so that we can iterate through the workers without holding
+    // the lock. When a new worker is added in `RegisterWorker`, we iterate
+    // through the jobs in `jobs_` and give it a task for each job. So even if a
+    // new worker is registered after we release the lock, because this job has
+    // been added to `jobs_`, it will still receive a task for this job.
+    workers = workers_;
+    const Dataset& dataset = *datasets_by_id_[dataset_id];
+    if (VLOG_IS_ON(1)) {
+      VLOG(1) << "Sending tasks to workers for job " << job->job_id()
+              << ". Dataset id: " << dataset_id
+              << ". Dataset fingerprint: " << dataset.fingerprint()
+              << ". Dataset definition size: "
+              << datasets_by_id_[dataset_id]->dataset_def().ByteSizeLong();
+    }
   }
 
-  int64 job_id = next_job_id_++;
-  DCHECK(!jobs_.contains(job_id));
-  auto job =
-      std::make_shared<Job>(job_id, dataset_id, processing_mode, job_name);
-  jobs_[job_id] = job;
-
-  for (auto& worker : workers_) {
-    int64 task_id = CreateTask(job.get(), worker.address());
-
-    // TODO(aaudibert): perform these calls asynchronously.
-    // TODO(aaudibert): clean up in case some calls succeed, but later calls
-    // fail
-    TF_RETURN_IF_ERROR(AllocateTaskToWorker(tasks_.at(task_id), &worker));
+  for (auto& worker : workers) {
+    const Task& task = CreateTask(job.get(), worker->address());
+    Status s = AllocateTaskToWorker(task, worker.get());
+    if (!s.ok()) {
+      LOG(WARNING) << "Failed to allocate task with id " << task.task_id()
+                   << " to worker at address " << worker->address() << ": "
+                   << s.error_message();
+    }
   }
+  VLOG(1) << "Done sending tasks to workers for job " << job->job_id();
 
-  *out_job_id = job_id;
+  *out_job_id = job->job_id();
   return Status::OK();
 }
 
-int64 DataServiceMasterImpl::CreateTask(Job* job,
-                                        const std::string& worker_address)
-    EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+const DataServiceMasterImpl::Task& DataServiceMasterImpl::CreateTask(
+    Job* job, const std::string& worker_address) LOCKS_EXCLUDED(mu_) {
+  mutex_lock l(mu_);
+  return CreateTaskLocked(job, worker_address);
+}
+
+const DataServiceMasterImpl::Task& DataServiceMasterImpl::CreateTaskLocked(
+    Job* job, const std::string& worker_address) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
   int64 task_id = next_task_id_++;
   DCHECK(!tasks_.contains(task_id));
-  auto result =
-      tasks_.emplace(std::piecewise_construct, std::forward_as_tuple(task_id),
-                     std::forward_as_tuple(task_id, job->job_id(),
-                                           job->dataset_id(), worker_address));
+  tasks_.insert({task_id, Task(task_id, job->job_id(), job->dataset_id(),
+                               worker_address)});
   job->add_task_id(task_id);
-  DCHECK(result.second);
-  return task_id;
+  return tasks_.at(task_id);
 }
 
 Status DataServiceMasterImpl::EnsureWorkerStubInitialized(Worker* worker) {
@@ -273,14 +301,17 @@ Status DataServiceMasterImpl::EnsureWorkerStubInitialized(Worker* worker) {
 
 Status DataServiceMasterImpl::AllocateTaskToWorker(const Task& task,
                                                    Worker* worker)
-    EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+    LOCKS_EXCLUDED(mu_) {
   TF_RETURN_IF_ERROR(EnsureWorkerStubInitialized(worker));
   grpc::ClientContext client_ctx;
   ProcessTaskRequest req;
   req.mutable_task()->set_dataset_id(task.dataset_id());
-  DCHECK(datasets_by_id_.contains(task.dataset_id()));
-  *req.mutable_task()->mutable_dataset() =
-      datasets_by_id_.at(task.dataset_id())->dataset_def();
+  {
+    mutex_lock l(mu_);
+    DCHECK(datasets_by_id_.contains(task.dataset_id()));
+    *req.mutable_task()->mutable_dataset() =
+        datasets_by_id_.at(task.dataset_id())->dataset_def();
+  }
   req.mutable_task()->set_task_id(task.task_id());
   ProcessTaskResponse resp;
   grpc::Status s = worker->stub()->ProcessTask(&client_ctx, req, &resp);
@@ -321,8 +352,8 @@ Status DataServiceMasterImpl::GetWorkers(const GetWorkersRequest* request,
   VLOG(3) << "Enter GetWorkers";
   for (auto& worker : workers_) {
     WorkerInfo* info = response->add_workers();
-    info->set_address(worker.address());
-    info->set_id(worker.worker_id());
+    info->set_address(worker->address());
+    info->set_id(worker->worker_id());
   }
   VLOG(3) << "Returning list of " << workers_.size()
           << " workers from GetWorkers";
@@ -177,16 +177,23 @@ class DataServiceMasterImpl {
   };
 
   // Registers a dataset with the given fingerprint, returning a new dataset id.
-  int64 RegisterDataset(uint64 fingerprint, const DatasetDef& dataset);
+  int64 RegisterDataset(uint64 fingerprint, const DatasetDef& dataset)
+      EXCLUSIVE_LOCKS_REQUIRED(mu_);
   // Initializes a workers stub, if it hasn't been initialized already.
   Status EnsureWorkerStubInitialized(Worker* worker);
   // Instructs a worker to begin processing a task.
-  Status AllocateTaskToWorker(const Task& task_id, Worker* worker);
+  Status AllocateTaskToWorker(const Task& task_id, Worker* worker)
+      LOCKS_EXCLUDED(mu_);
   // Creates a job and stores its job_id in `*job_id`.
   Status CreateJob(int64 dataset_id, ProcessingMode processing_mode,
-                   absl::optional<std::string> job_name, int64* out_job_id);
-  // Creates a new task for a job, returning the new task's id.
-  int64 CreateTask(Job* job, const std::string& worker_address);
+                   absl::optional<std::string> job_name, int64* out_job_id)
+      LOCKS_EXCLUDED(mu_);
+  // Creates a new task for a job, returning a reference to the task.
+  const Task& CreateTask(Job* job, const std::string& worker_address)
+      LOCKS_EXCLUDED(mu_);
+  // Same as `CreateTask`, but expects that the master lock is already held.
+  const Task& CreateTaskLocked(Job* job, const std::string& worker_address)
+      EXCLUSIVE_LOCKS_REQUIRED(mu_);
   // Validates that an existing job matches the given processing_mode and
   // dataset_id, returning an error status describing any difference.
   Status ValidateMatchingJob(const Job& job, ProcessingMode processing_mode,
@@ -202,7 +209,7 @@ class DataServiceMasterImpl {
   int64 next_task_id_ TF_GUARDED_BY(mu_) = 0;
 
   // Registered workers.
-  std::vector<Worker> workers_ TF_GUARDED_BY(mu_);
+  std::vector<std::shared_ptr<Worker>> workers_ TF_GUARDED_BY(mu_);
   // Registered datasets, keyed by dataset ids.
   absl::flat_hash_map<int64, std::shared_ptr<Dataset>> datasets_by_id_
       TF_GUARDED_BY(mu_);
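A note on the annotations this CL flips from EXCLUSIVE_LOCKS_REQUIRED(mu_) to LOCKS_EXCLUDED(mu_): these macros expand to Clang thread-safety attributes, so a build with -Wthread-safety statically checks that callers hold (or do not hold) mu_ at each call site. The sketch below shows the same `Foo`/`FooLocked` split this CL introduces for `CreateTask`/`CreateTaskLocked`, written against the raw Clang attributes rather than TensorFlow's macros; the `Registry` type is hypothetical and only illustrates the checking:

// Compile with: clang++ -std=c++17 -Wthread-safety -c sketch.cc
#include <mutex>

// std::mutex wrapper annotated as a capability so the analysis can track it.
class __attribute__((capability("mutex"))) Mutex {
 public:
  void Lock() __attribute__((acquire_capability())) { m_.lock(); }
  void Unlock() __attribute__((release_capability())) { m_.unlock(); }

 private:
  std::mutex m_;
};

class Registry {
 public:
  // LOCKS_EXCLUDED-style: callers must NOT hold mu_; acquired internally.
  int CreateId() __attribute__((locks_excluded(mu_))) {
    mu_.Lock();
    int id = CreateIdLocked();
    mu_.Unlock();
    return id;
  }

  // EXCLUSIVE_LOCKS_REQUIRED-style: callers must already hold mu_.
  int CreateIdLocked() __attribute__((exclusive_locks_required(mu_))) {
    return next_id_++;
  }

 private:
  Mutex mu_;
  int next_id_ __attribute__((guarded_by(mu_))) = 0;
};

Calling CreateIdLocked() without the lock, or CreateId() while already holding it, then becomes a compile-time warning rather than a latent deadlock.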