Use a round-robin approach to reading from tf.data service workers.

PiperOrigin-RevId: 311367134
Change-Id: I5408de5d85c13514c55681ecf09dcecec5c2168a
Andrew Audibert 2020-05-13 11:18:28 -07:00 committed by TensorFlower Gardener
parent b97bf5ae0b
commit 59239ab499
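
The core of this change is that a fixed pool of worker threads now claims tasks round-robin instead of dedicating one thread to each task. Below is a minimal, self-contained sketch of just the selection policy; the names (`Task`, `SelectTask`, `next_index`) are hypothetical, and the locking that the real iterator performs under `mu_` is omitted. The actual implementation is in the diff that follows.

#include <cstdint>
#include <iostream>
#include <memory>
#include <vector>

// Stand-in for the per-task state the iterator tracks.
struct Task {
  explicit Task(int64_t id) : task_id(id) {}
  int64_t task_id;
  bool in_use = false;           // A worker thread is currently reading from this task.
  bool end_of_sequence = false;  // The task has no more elements.
};

// Round-robin selection: scan starting at `next_index`, claim the first task
// that is neither in use nor finished, and advance `next_index` past it so the
// next caller continues with the following task.
std::shared_ptr<Task> SelectTask(std::vector<std::shared_ptr<Task>>& tasks,
                                 int& next_index) {
  const int num_tasks = static_cast<int>(tasks.size());
  for (int i = 0; i < num_tasks; ++i) {
    int index = (next_index + i) % num_tasks;
    std::shared_ptr<Task>& task = tasks[index];
    if (!task->in_use && !task->end_of_sequence) {
      task->in_use = true;
      next_index = (index + 1) % num_tasks;
      return task;
    }
  }
  return nullptr;  // No task is currently available to read from.
}

int main() {
  std::vector<std::shared_ptr<Task>> tasks;
  for (int64_t id = 0; id < 3; ++id) {
    tasks.push_back(std::make_shared<Task>(id));
  }
  int next_index = 0;
  auto a = SelectTask(tasks, next_index);  // Claims task 0.
  auto b = SelectTask(tasks, next_index);  // Claims task 1.
  a->in_use = false;                       // Release task 0.
  auto c = SelectTask(tasks, next_index);  // Claims task 2: the scan resumes after task 1.
  std::cout << a->task_id << " " << b->task_id << " " << c->task_id << "\n";
  return 0;
}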


@@ -189,7 +189,9 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
VLOG(1) << "Destroying data service dataset iterator for job id "
<< job_id_;
cancelled_ = true;
cv_.notify_all();
worker_thread_cv_.notify_all();
manager_thread_cv_.notify_all();
get_next_cv_.notify_all();
// Thread destructors will block until the threads finish, no need to wait
// here.
}
@@ -222,12 +224,16 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
});
}
while (results_.empty() && !job_finished_ && !cancelled_) {
cv_.wait(l);
while (results_.empty() && !job_finished_ && !cancelled_ &&
status_.ok()) {
get_next_cv_.wait(l);
}
if (cancelled_) {
return errors::Cancelled("Data service iterator was cancelled");
}
if (!status_.ok()) {
return status_;
}
if (results_.empty()) {
*end_of_sequence = true;
return Status::OK();
@@ -236,7 +242,7 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
*end_of_sequence = false;
out_tensors->swap(results_.front());
results_.pop();
cv_.notify_all();
worker_thread_cv_.notify_one();
return Status::OK();
}
@@ -259,16 +265,21 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
}
private:
typedef struct TaskThread {
int64 task_id;
// Cached address of the worker for task `task_id`.
std::string address;
std::unique_ptr<DataServiceWorkerClient> worker;
std::unique_ptr<Thread> thread;
bool end_of_sequence = false;
// Indicates that the thread has finished running.
bool finished = false;
} TaskThread;
struct Task {
Task(int64 task_id, const std::string& address,
std::unique_ptr<DataServiceWorkerClient> worker)
: task_id(task_id), address(address), worker(std::move(worker)) {}
const int64 task_id;
// Address of the tf.data service worker for task `task_id`.
const std::string address;
// Client for fetching task elements from the tf.data service worker.
const std::unique_ptr<DataServiceWorkerClient> worker;
// Indicates whether a worker thread is currently processing the task.
bool in_use TF_GUARDED_BY(&Iterator::mu_) = false;
// Indicates whether the worker has returned end_of_sequence for the task.
bool end_of_sequence TF_GUARDED_BY(&Iterator::mu_) = false;
};
// Periodically refresh the task list.
// Maintain a pool of worker threads which fetch elements from the tasks.
@@ -286,22 +297,23 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
int64 remaining_time = next_check - Env::Default()->NowMicros();
VLOG(3) << "Task thread manager waiting for " << remaining_time
<< "us";
cv_.wait_for(l, std::chrono::microseconds(remaining_time));
manager_thread_cv_.wait_for(
l, std::chrono::microseconds(remaining_time));
}
if (cancelled_) {
VLOG(3) << "Task thread manager finished";
return;
}
}
UpdateTaskThreads(&master, ctx.get());
UpdateTasks(&master);
UpdateWorkerThreads(ctx.get());
next_check = Env::Default()->NowMicros() +
dataset()->task_refresh_interval_ms_ * 1000;
}
}
void UpdateTaskThreads(DataServiceMasterClient* master,
IteratorContext* ctx) LOCKS_EXCLUDED(mu_) {
VLOG(3) << "Updating task threads";
void UpdateTasks(DataServiceMasterClient* master) LOCKS_EXCLUDED(mu_) {
VLOG(3) << "Updating tasks";
std::vector<TaskInfo> tasks;
bool job_finished;
Status s = master->GetTasks(job_id_, &tasks, &job_finished);
@@ -310,94 +322,119 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
<< s;
return;
}
absl::flat_hash_set<int64> task_ids;
absl::flat_hash_map<int64, TaskInfo> task_id_to_task;
for (auto& task : tasks) {
task_id_to_task[task.id()] = task;
}
mutex_lock l(mu_);
job_finished_ = job_finished;
for (auto& task : tasks) {
task_ids.insert(task.id());
if (task_threads_.contains(task.id())) {
continue;
}
task_threads_[task.id()] = absl::make_unique<TaskThread>();
TaskThread* task_thread = task_threads_[task.id()].get();
task_thread->task_id = task.id();
task_thread->address = task.worker_address();
num_unfinished_tasks_++;
outstanding_requests_++;
auto done = [this, task_thread]() {
mutex_lock l(mu_);
num_unfinished_tasks_--;
outstanding_requests_--;
cv_.notify_all();
task_thread->finished = true;
VLOG(3) << "Task thread " << task_thread->task_id << " finished";
};
task_thread->thread =
ctx->StartThread("tf-data-service-task_thread",
[this, task_thread, done = std::move(done)]() {
RunTaskThread(task_thread, std::move(done));
});
if (job_finished) {
get_next_cv_.notify_all();
return;
}
// Mark deleted tasks and clean up finished task threads.
for (auto it = task_threads_.begin(); it != task_threads_.end();) {
TaskThread* task_thread = it->second.get();
if (task_thread->finished) {
task_threads_.erase(it++);
for (int i = 0; i < tasks_.size(); ++i) {
std::shared_ptr<Task> task = tasks_[i];
if (task_id_to_task.contains(task->task_id)) {
// Remove already-known tasks from `task_id_to_task`, so that at the
// end of the loop, only new tasks remain.
task_id_to_task.erase(task->task_id);
} else {
// Task has been removed.
if (task->end_of_sequence) {
finished_tasks_--;
}
tasks_[i] = tasks_[tasks_.size() - 1];
tasks_.pop_back();
}
}
for (auto& new_task_entry : task_id_to_task) {
TaskInfo& task_info = new_task_entry.second;
std::unique_ptr<DataServiceWorkerClient> worker;
Status s = CreateDataServiceWorkerClient(task_info.worker_address(),
dataset()->protocol_, &worker);
if (!s.ok()) {
status_ = s;
get_next_cv_.notify_all();
continue;
}
if (!task_ids.contains(task_thread->task_id)) {
VLOG(3) << "Marking removed task thread " << task_thread->task_id
<< " as finished";
task_thread->end_of_sequence = true;
}
++it;
tasks_.push_back(std::make_shared<Task>(
task_info.id(), task_info.worker_address(), std::move(worker)));
}
if (dataset()->max_outstanding_requests_ == model::kAutotune) {
// Adjust max_outstanding_requests to account for newly added tasks.
max_outstanding_requests_ = task_threads_.size();
max_outstanding_requests_ = tasks_.size();
}
}
void RunTaskThread(TaskThread* task_thread, std::function<void()> done) {
void UpdateWorkerThreads(IteratorContext* ctx) LOCKS_EXCLUDED(mu_) {
mutex_lock l(mu_);
while (num_running_worker_threads_ < max_outstanding_requests_) {
num_running_worker_threads_++;
outstanding_requests_++;
auto done = [this]() {
mutex_lock l(mu_);
num_running_worker_threads_--;
outstanding_requests_--;
VLOG(3) << "Exiting worker thread";
};
worker_threads_.push_back(ctx->StartThread(
"tf-data-service-task_thread", [this, done = std::move(done)]() {
RunWorkerThread(std::move(done));
}));
}
}
void RunWorkerThread(std::function<void()> done) {
auto cleanup = gtl::MakeCleanup([done = std::move(done)]() { done(); });
VLOG(3) << "Starting task thread for task " << task_thread->task_id
<< " with worker address " << task_thread->address;
VLOG(3) << "Starting worker thread";
std::shared_ptr<Task> task_to_process;
while (true) {
if (!task_thread->worker) {
Status s = CreateDataServiceWorkerClient(
task_thread->address, dataset()->protocol_, &task_thread->worker);
if (!s.ok()) {
LOG(WARNING) << "Failed to create a worker client for "
<< task_thread->address << ": " << s;
}
}
{
mutex_lock l(mu_);
if (task_thread->end_of_sequence) {
VLOG(3) << "Task thread" << task_thread->task_id
<< " reached end_of_sequence";
return;
if (task_to_process) {
task_to_process->in_use = false;
task_to_process = nullptr;
worker_thread_cv_.notify_one();
}
outstanding_requests_--;
while (!cancelled_ && results_.size() + outstanding_requests_ >=
max_outstanding_requests_) {
VLOG(3) << "Task thread for task " << task_thread->task_id
<< " waiting. results_.size()=" << results_.size()
<< " outstanding_requests_=" << outstanding_requests_;
cv_.wait(l);
while (!cancelled_ && !(SpaceInBuffer() && TaskAvailable())) {
if (VLOG_IS_ON(3)) {
VLOG(3) << "Sleeping with results_.size=" << results_.size()
<< ", outstanding_requests_=" << outstanding_requests_
<< ", max_oustanding_requests="
<< max_outstanding_requests_
<< " finished_tasks=" << finished_tasks_
<< " tasks_.size()=" << tasks_.size();
}
worker_thread_cv_.wait(l);
}
outstanding_requests_++;
if (cancelled_) {
return;
}
outstanding_requests_++;
// Search round-robin for the next available task to read from.
int num_tasks = tasks_.size();
for (int i = 0; i < num_tasks; ++i) {
int index = (next_task_index_ + i) % num_tasks;
std::shared_ptr<Task>& task = tasks_[index];
if (!task->in_use && !task->end_of_sequence) {
task->in_use = true;
task_to_process = task;
next_task_index_ = (index + 1) % num_tasks;
break;
}
}
DCHECK(task_to_process != nullptr);
VLOG(3) << "Processing task " << task_to_process->task_id;
}
// TODO(aaudibert): add backoff and max retries.
int64 deadline_micros =
Env::Default()->NowMicros() + kRetryTimeoutMicros;
Status s = GetElement(task_thread, deadline_micros);
Status s = GetElement(task_to_process.get(), deadline_micros);
if (!s.ok()) {
LOG(WARNING) << "Failed to get element from worker at "
<< task_thread->address << ": " << s;
mutex_lock l(mu_);
status_ = s;
get_next_cv_.notify_all();
return;
}
}
}
@@ -407,18 +444,27 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
// If the task reaches end_of_sequence or is cancelled (e.g. due to a
// worker dying), GetElement returns Status::OK() without adding to
// `results_`.
Status GetElement(TaskThread* task_thread, int64 deadline_micros) {
VLOG(3) << "Getting an element for task id " << task_thread->task_id;
Status GetElement(Task* task, int64 deadline_micros)
TF_LOCKS_EXCLUDED(mu_) {
VLOG(3) << "Getting an element for task id " << task->task_id;
tensorflow::profiler::TraceMe activity(
"GetElement", tensorflow::profiler::TraceMeLevel::kInfo);
CompressedElement compressed;
bool end_of_sequence;
for (int num_retries = 0;; ++num_retries) {
Status s = task_thread->worker->GetElement(
task_thread->task_id, &compressed, &end_of_sequence);
Status s = task->worker->GetElement(task->task_id, &compressed,
&end_of_sequence);
if (s.ok()) {
break;
}
if (errors::IsNotFound(s)) {
// This indicates that the worker was restarted. The restarted worker
// will get a new task, and the old task is lost.
mutex_lock l(mu_);
finished_tasks_++;
task->end_of_sequence = true;
return Status::OK();
}
// Retry all errors that could indicate preemption.
if (!errors::IsUnavailable(s) && !errors::IsCancelled(s) &&
!errors::IsAborted(s)) {
@@ -428,7 +474,7 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
mutex_lock l(mu_);
// If the task has already reached end_of_sequence (e.g. because the worker
// was restarted) or the iterator has been cancelled, stop retrying.
if (task_thread->end_of_sequence || cancelled_) {
if (task->end_of_sequence || cancelled_) {
return Status::OK();
}
}
@@ -454,21 +500,31 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
}
mutex_lock l(mu_);
if (end_of_sequence) {
task_thread->end_of_sequence = true;
task->end_of_sequence = true;
finished_tasks_++;
return Status::OK();
}
results_.push(std::move(element));
cv_.notify_all();
VLOG(3) << "Got an element for task id " << task_thread->task_id;
get_next_cv_.notify_all();
VLOG(3) << "Got an element for task id " << task->task_id;
return Status::OK();
}
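// Returns true when buffered results plus outstanding requests are below
// `max_outstanding_requests_`, i.e. there is room to fetch another element.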
bool SpaceInBuffer() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
return results_.size() + outstanding_requests_ <
max_outstanding_requests_;
}
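// Returns true when at least one task is neither finished nor currently
// claimed by a worker thread (`outstanding_requests_` counts claimed tasks).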
bool TaskAvailable() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
return finished_tasks_ + outstanding_requests_ < tasks_.size();
}
const int64 iterator_index_;
mutex mu_;
// TODO(aaudibert): split this into a couple cvs for different conditions
// so that we can use notify_one and avoid unnecessary wakeups.
condition_variable cv_ TF_GUARDED_BY(mu_);
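// `get_next_cv_` wakes `GetNext` when a result, an error, job completion, or
// cancellation becomes observable; `worker_thread_cv_` wakes worker threads
// when buffer space or an unclaimed task becomes available; and
// `manager_thread_cv_` wakes the task thread manager early on cancellation.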
condition_variable get_next_cv_ TF_GUARDED_BY(mu_);
condition_variable worker_thread_cv_ TF_GUARDED_BY(mu_);
condition_variable manager_thread_cv_ TF_GUARDED_BY(mu_);
bool cancelled_ TF_GUARDED_BY(mu_) = false;
int64 outstanding_requests_ TF_GUARDED_BY(mu_) = 0;
@@ -476,17 +532,31 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
// at the same time. This count includes both in-progress requests for
// elements as well as completed requests which haven't yet been produced.
int64 max_outstanding_requests_ TF_GUARDED_BY(mu_);
// The number of threads in `worker_threads_` which are still running.
int64 num_running_worker_threads_ TF_GUARDED_BY(mu_) = 0;
// The index of the next task in `tasks_` to read from.
int64 next_task_index_ TF_GUARDED_BY(mu_) = 0;
// The number of tasks in the `tasks_` list that have reached end_of_sequence.
int64 finished_tasks_ TF_GUARDED_BY(mu_) = 0;
// List of tasks to read from.
std::vector<std::shared_ptr<Task>> tasks_ TF_GUARDED_BY(mu_);
// A status to be returned from the next call to `GetNext`. This is set by
// asynchronous threads when they encounter errors.
Status status_ TF_GUARDED_BY(mu_) = Status::OK();
std::queue<std::vector<Tensor>> results_ TF_GUARDED_BY(mu_);
// Set once in Initialize().
int64 job_id_;
int64 num_unfinished_tasks_ TF_GUARDED_BY(mu_) = 0;
bool job_finished_ = false;
// Must come second to last so that task threads are joined before
// Must be ordered second to last so that worker threads are joined before
// destroying other fields.
absl::flat_hash_map<int64, std::unique_ptr<TaskThread>> task_threads_
TF_GUARDED_BY(mu_);
std::vector<std::unique_ptr<Thread>> worker_threads_ TF_GUARDED_BY(mu_);
// Must be ordered last so that the thread is joined before destroying other
// fields.
std::unique_ptr<Thread> task_thread_manager_ GUARDED_BY(mu_);