Use a round-robin approach to reading from tf.data service workers.
PiperOrigin-RevId: 311367134
Change-Id: I5408de5d85c13514c55681ecf09dcecec5c2168a

parent b97bf5ae0b
commit 59239ab499
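For context, here is a minimal, self-contained sketch of the round-robin task selection this change introduces: each worker thread scans the task list starting at next_task_index_, claims the first task that is neither in use nor at end_of_sequence, and advances the index so reads rotate evenly across tasks. This is not the actual DataServiceDatasetOp code; the class and names below are illustrative only.

// Sketch only: illustrates the round-robin claim/release cycle used by the
// worker threads in this change. Names mirror the diff, but this is not the
// real TensorFlow implementation.
#include <cstdint>
#include <iostream>
#include <memory>
#include <vector>

struct Task {
  int64_t task_id = 0;
  bool in_use = false;           // A worker thread is currently reading from this task.
  bool end_of_sequence = false;  // The worker has exhausted this task.
};

class RoundRobinTaskSelector {
 public:
  explicit RoundRobinTaskSelector(std::vector<std::shared_ptr<Task>> tasks)
      : tasks_(std::move(tasks)) {}

  // Claims the next available task in round-robin order, or returns nullptr
  // if every task is busy or finished.
  std::shared_ptr<Task> Claim() {
    const int num_tasks = static_cast<int>(tasks_.size());
    for (int i = 0; i < num_tasks; ++i) {
      const int index = (next_task_index_ + i) % num_tasks;
      std::shared_ptr<Task>& task = tasks_[index];
      if (!task->in_use && !task->end_of_sequence) {
        task->in_use = true;
        next_task_index_ = (index + 1) % num_tasks;
        return task;
      }
    }
    return nullptr;
  }

  // Releases a previously claimed task so other threads may read from it.
  void Release(const std::shared_ptr<Task>& task) { task->in_use = false; }

 private:
  std::vector<std::shared_ptr<Task>> tasks_;
  int next_task_index_ = 0;
};

int main() {
  std::vector<std::shared_ptr<Task>> tasks;
  for (int64_t id = 0; id < 3; ++id) {
    auto task = std::make_shared<Task>();
    task->task_id = id;
    tasks.push_back(task);
  }
  RoundRobinTaskSelector selector(tasks);
  // Claims rotate through tasks 0, 1, 2 rather than repeatedly hitting task 0.
  for (int i = 0; i < 3; ++i) {
    std::shared_ptr<Task> task = selector.Claim();
    std::cout << "claimed task " << task->task_id << "\n";
    selector.Release(task);
  }
  return 0;
}

The diff below applies the same idea inside RunWorkerThread, with the claim guarded by mu_ and coordinated through worker_thread_cv_ and get_next_cv_.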
@@ -189,7 +189,9 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
VLOG(1) << "Destroying data service dataset iterator for job id "
<< job_id_;
cancelled_ = true;
cv_.notify_all();
worker_thread_cv_.notify_all();
manager_thread_cv_.notify_all();
get_next_cv_.notify_all();
// Thread destructors will block until the threads finish, no need to wait
// here.
}
@@ -222,12 +224,16 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
});
}

while (results_.empty() && !job_finished_ && !cancelled_) {
cv_.wait(l);
while (results_.empty() && !job_finished_ && !cancelled_ &&
status_.ok()) {
get_next_cv_.wait(l);
}
if (cancelled_) {
return errors::Cancelled("Data service iterator was cancelled");
}
if (!status_.ok()) {
return status_;
}
if (results_.empty()) {
*end_of_sequence = true;
return Status::OK();
@@ -236,7 +242,7 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
*end_of_sequence = false;
out_tensors->swap(results_.front());
results_.pop();
cv_.notify_all();
worker_thread_cv_.notify_one();

return Status::OK();
}
@@ -259,16 +265,21 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
}

private:
typedef struct TaskThread {
int64 task_id;
// Cached address of the worker for task `task_id`.
std::string address;
std::unique_ptr<DataServiceWorkerClient> worker;
std::unique_ptr<Thread> thread;
bool end_of_sequence = false;
// Indicates that the thread has finished running.
bool finished = false;
} TaskThread;
struct Task {
Task(int64 task_id, const std::string& address,
std::unique_ptr<DataServiceWorkerClient> worker)
: task_id(task_id), address(address), worker(std::move(worker)) {}

const int64 task_id;
// Address of the tf.data service worker for task `task_id`.
const std::string address;
// Client for fetching task elements from the tf.data service worker.
const std::unique_ptr<DataServiceWorkerClient> worker;
// Indicates whether a worker thread is currently processing the task.
bool in_use TF_GUARDED_BY(&Iterator::mu_) = false;
// Indicates whether the worker has returned end_of_sequence for the task.
bool end_of_sequence TF_GUARDED_BY(&Iterator::mu_) = false;
};

// Periodically refresh the task list.
// Maintain one thread fetching elements for each task.
@@ -286,22 +297,23 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
int64 remaining_time = next_check - Env::Default()->NowMicros();
VLOG(3) << "Task thread manager waiting for " << remaining_time
<< "us";
cv_.wait_for(l, std::chrono::microseconds(remaining_time));
manager_thread_cv_.wait_for(
l, std::chrono::microseconds(remaining_time));
}
if (cancelled_) {
VLOG(3) << "Task thread manager finished";
return;
}
}
UpdateTaskThreads(&master, ctx.get());
UpdateTasks(&master);
UpdateWorkerThreads(ctx.get());
next_check = Env::Default()->NowMicros() +
dataset()->task_refresh_interval_ms_ * 1000;
}
}

void UpdateTaskThreads(DataServiceMasterClient* master,
IteratorContext* ctx) LOCKS_EXCLUDED(mu_) {
VLOG(3) << "Updating task threads";
void UpdateTasks(DataServiceMasterClient* master) LOCKS_EXCLUDED(mu_) {
VLOG(3) << "Updating tasks";
std::vector<TaskInfo> tasks;
bool job_finished;
Status s = master->GetTasks(job_id_, &tasks, &job_finished);
@@ -310,94 +322,119 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
<< s;
return;
}
absl::flat_hash_set<int64> task_ids;
absl::flat_hash_map<int64, TaskInfo> task_id_to_task;
for (auto& task : tasks) {
task_id_to_task[task.id()] = task;
}
mutex_lock l(mu_);
job_finished_ = job_finished;
for (auto& task : tasks) {
task_ids.insert(task.id());
if (task_threads_.contains(task.id())) {
continue;
}
task_threads_[task.id()] = absl::make_unique<TaskThread>();
TaskThread* task_thread = task_threads_[task.id()].get();
task_thread->task_id = task.id();
task_thread->address = task.worker_address();
num_unfinished_tasks_++;
outstanding_requests_++;
auto done = [this, task_thread]() {
mutex_lock l(mu_);
num_unfinished_tasks_--;
outstanding_requests_--;
cv_.notify_all();
task_thread->finished = true;
VLOG(3) << "Task thread " << task_thread->task_id << " finished";
};
task_thread->thread =
ctx->StartThread("tf-data-service-task_thread",
[this, task_thread, done = std::move(done)]() {
RunTaskThread(task_thread, std::move(done));
});
if (job_finished) {
get_next_cv_.notify_all();
return;
}
// Mark deleted tasks and clean up finished task threads.
for (auto it = task_threads_.begin(); it != task_threads_.end();) {
TaskThread* task_thread = it->second.get();
if (task_thread->finished) {
task_threads_.erase(it++);
for (int i = 0; i < tasks_.size(); ++i) {
std::shared_ptr<Task> task = tasks_[i];
if (task_id_to_task.contains(task->task_id)) {
// Remove already-known tasks from `task_id_to_task`, so that at the
// end of the loop, only new tasks remain.
task_id_to_task.erase(task->task_id);
} else {
// Task has been removed.
if (task->end_of_sequence) {
finished_tasks_--;
}
tasks_[i] = tasks_[tasks_.size() - 1];
tasks_.pop_back();
}
}
for (auto& new_task_entry : task_id_to_task) {
TaskInfo& task_info = new_task_entry.second;
std::unique_ptr<DataServiceWorkerClient> worker;
Status s = CreateDataServiceWorkerClient(task_info.worker_address(),
dataset()->protocol_, &worker);
if (!s.ok()) {
status_ = s;
get_next_cv_.notify_all();
continue;
}
if (!task_ids.contains(task_thread->task_id)) {
VLOG(3) << "Marking removed task thread " << task_thread->task_id
<< " as finished";
task_thread->end_of_sequence = true;
}
++it;
tasks_.push_back(std::make_shared<Task>(
task_info.id(), task_info.worker_address(), std::move(worker)));
}
if (dataset()->max_outstanding_requests_ == model::kAutotune) {
// Adjust max_outstanding_requests to account for newly added tasks.
max_outstanding_requests_ = task_threads_.size();
max_outstanding_requests_ = tasks_.size();
}
}

void RunTaskThread(TaskThread* task_thread, std::function<void()> done) {
void UpdateWorkerThreads(IteratorContext* ctx) LOCKS_EXCLUDED(mu_) {
mutex_lock l(mu_);
while (num_running_worker_threads_ < max_outstanding_requests_) {
num_running_worker_threads_++;
outstanding_requests_++;
auto done = [this]() {
mutex_lock l(mu_);
num_running_worker_threads_--;
outstanding_requests_--;
VLOG(3) << "Exiting worker thread";
};
worker_threads_.push_back(ctx->StartThread(
"tf-data-service-task_thread", [this, done = std::move(done)]() {
RunWorkerThread(std::move(done));
}));
}
}

void RunWorkerThread(std::function<void()> done) {
auto cleanup = gtl::MakeCleanup([done = std::move(done)]() { done(); });
VLOG(3) << "Starting task thread for task " << task_thread->task_id
<< " with worker address " << task_thread->address;
VLOG(3) << "Starting worker thread";
std::shared_ptr<Task> task_to_process;
while (true) {
if (!task_thread->worker) {
Status s = CreateDataServiceWorkerClient(
task_thread->address, dataset()->protocol_, &task_thread->worker);
if (!s.ok()) {
LOG(WARNING) << "Failed to create a worker client for "
<< task_thread->address << ": " << s;
}
}
{
mutex_lock l(mu_);
if (task_thread->end_of_sequence) {
VLOG(3) << "Task thread" << task_thread->task_id
<< " reached end_of_sequence";
return;
if (task_to_process) {
task_to_process->in_use = false;
task_to_process = nullptr;
worker_thread_cv_.notify_one();
}
outstanding_requests_--;
while (!cancelled_ && results_.size() + outstanding_requests_ >=
max_outstanding_requests_) {
VLOG(3) << "Task thread for task " << task_thread->task_id
<< " waiting. results_.size()=" << results_.size()
<< " outstanding_requests_=" << outstanding_requests_;
cv_.wait(l);
while (!cancelled_ && !(SpaceInBuffer() && TaskAvailable())) {
if (VLOG_IS_ON(3)) {
VLOG(3) << "Sleeping with results_.size=" << results_.size()
<< ", outstanding_requests_=" << outstanding_requests_
<< ", max_oustanding_requests="
<< max_outstanding_requests_
<< " finished_tasks=" << finished_tasks_
<< " tasks_.size()=" << tasks_.size();
}
worker_thread_cv_.wait(l);
}
outstanding_requests_++;
if (cancelled_) {
return;
}
outstanding_requests_++;
// Search for a task to update.
int num_tasks = tasks_.size();
for (int i = 0; i < num_tasks; ++i) {
int index = (next_task_index_ + i) % num_tasks;
std::shared_ptr<Task>& task = tasks_[index];
if (!task->in_use && !task->end_of_sequence) {
task->in_use = true;
task_to_process = task;
next_task_index_ = (index + 1) % num_tasks;
break;
}
}
DCHECK(task_to_process != nullptr);
VLOG(3) << "Processing task " << task_to_process->task_id;
}
// TODO(aaudibert): add backoff and max retries.
int64 deadline_micros =
Env::Default()->NowMicros() + kRetryTimeoutMicros;
Status s = GetElement(task_thread, deadline_micros);
Status s = GetElement(task_to_process.get(), deadline_micros);
if (!s.ok()) {
LOG(WARNING) << "Failed to get element from worker at "
<< task_thread->address << ": " << s;
mutex_lock l(mu_);
status_ = s;
get_next_cv_.notify_all();
return;
}
}
}
@@ -407,18 +444,27 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
// If the task reaches end_of_sequence or is cancelled (e.g. due to a
// worker dying), GetElement returns Status::OK() without adding to
// `results_`.
Status GetElement(TaskThread* task_thread, int64 deadline_micros) {
VLOG(3) << "Getting an element for task id " << task_thread->task_id;
Status GetElement(Task* task, int64 deadline_micros)
TF_LOCKS_EXCLUDED(mu_) {
VLOG(3) << "Getting an element for task id " << task->task_id;
tensorflow::profiler::TraceMe activity(
"GetElement", tensorflow::profiler::TraceMeLevel::kInfo);
CompressedElement compressed;
bool end_of_sequence;
for (int num_retries = 0;; ++num_retries) {
Status s = task_thread->worker->GetElement(
task_thread->task_id, &compressed, &end_of_sequence);
Status s = task->worker->GetElement(task->task_id, &compressed,
&end_of_sequence);
if (s.ok()) {
break;
}
if (errors::IsNotFound(s)) {
// This indicates that the worker was restarted. The restarted worker
// will get a new task, and the old task is lost.
mutex_lock l(mu_);
finished_tasks_++;
task->end_of_sequence = true;
return Status::OK();
}
// Retry all errors that could indicate preemption.
if (!errors::IsUnavailable(s) && !errors::IsCancelled(s) &&
!errors::IsAborted(s)) {
@@ -428,7 +474,7 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
mutex_lock l(mu_);
// If `UpdateTaskThreads` finds that the task has been cancelled, it
// will set end_of_sequence to `true`.
if (task_thread->end_of_sequence || cancelled_) {
if (task->end_of_sequence || cancelled_) {
return Status::OK();
}
}
@@ -454,21 +500,31 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
}
mutex_lock l(mu_);
if (end_of_sequence) {
task_thread->end_of_sequence = true;
task->end_of_sequence = true;
finished_tasks_++;
return Status::OK();
}
results_.push(std::move(element));
cv_.notify_all();
VLOG(3) << "Got an element for task id " << task_thread->task_id;
get_next_cv_.notify_all();
VLOG(3) << "Got an element for task id " << task->task_id;
return Status::OK();
}

bool SpaceInBuffer() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
return results_.size() + outstanding_requests_ <
max_outstanding_requests_;
}

bool TaskAvailable() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
return finished_tasks_ + outstanding_requests_ < tasks_.size();
}

const int64 iterator_index_;

mutex mu_;
// TODO(aaudibert): split this into a couple cvs for different conditions
// so that we can use notify_one and avoid unnecessary wakeups.
condition_variable cv_ TF_GUARDED_BY(mu_);
condition_variable get_next_cv_ TF_GUARDED_BY(mu_);
condition_variable worker_thread_cv_ TF_GUARDED_BY(mu_);
condition_variable manager_thread_cv_ TF_GUARDED_BY(mu_);
bool cancelled_ TF_GUARDED_BY(mu_) = false;

int64 outstanding_requests_ TF_GUARDED_BY(mu_) = 0;
@@ -476,17 +532,31 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
// at the same time. This count includes both in-progress requests for
// elements as well as completed requests which haven't yet been produced.
int64 max_outstanding_requests_ TF_GUARDED_BY(mu_);

// The number of threads in `worker_threads_` which are still running.
int64 num_running_worker_threads_ TF_GUARDED_BY(mu_) = 0;

// The index of the next task in `tasks_` to read from.
int64 next_task_index_ TF_GUARDED_BY(mu_) = 0;

// The number tasks in the `tasks_` list that have reached end_of_sequence.
int64 finished_tasks_ TF_GUARDED_BY(mu_) = 0;

// List of tasks to read from.
std::vector<std::shared_ptr<Task>> tasks_ TF_GUARDED_BY(mu_);

// A status to be returned from the next call to `GetNext`. This is set by
// asynchronous threads when they encounter errors.
Status status_ TF_GUARDED_BY(mu_) = Status::OK();
std::queue<std::vector<Tensor>> results_ TF_GUARDED_BY(mu_);

// Set once in Initialize().
int64 job_id_;
int64 num_unfinished_tasks_ TF_GUARDED_BY(mu_) = 0;

bool job_finished_ = false;
// Must come second to last so that task threads are joined before
// Must be ordered second to last so that worker threads are joined before
// destroying other fields.
absl::flat_hash_map<int64, std::unique_ptr<TaskThread>> task_threads_
TF_GUARDED_BY(mu_);
std::vector<std::unique_ptr<Thread>> worker_threads_ TF_GUARDED_BY(mu_);
// Must be ordered last so that the thread is joined before destroying other
// fields.
std::unique_ptr<Thread> task_thread_manager_ GUARDED_BY(mu_);