[tf.data service] Garbage collect old and unused jobs.

This CL adds a worker heartbeat that does the following:
- Registers the worker with the dispatcher if the worker is not yet registered.
- Reports to the dispatcher which tasks the worker is currently processing.
- Learns from the dispatcher which tasks should be added and which should be deleted.

Learning about new tasks is important in case the dispatcher's original ProcessTask request to the worker failed due to network issues. The heartbeat provides a backup mechanism through which the worker can still learn which tasks it should process.
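
Roughly, the worker side of the exchange added in this change looks like the following sketch, simplified from the DataServiceWorkerImpl::Heartbeat() implementation further down (locking, logging, and retry handling are omitted):

  // Report the ids of the tasks this worker is currently processing.
  std::vector<int64> current_tasks;
  for (const auto& entry : tasks_) {
    current_tasks.push_back(entry.first);
  }
  // The dispatcher replies with tasks to start and tasks to drop.
  std::vector<TaskDef> new_tasks;
  std::vector<int64> tasks_to_delete;
  TF_RETURN_IF_ERROR(dispatcher_->WorkerHeartbeat(
      worker_address_, current_tasks, new_tasks, tasks_to_delete));
  for (const TaskDef& task : new_tasks) {
    // Start any task the worker did not yet know about (e.g. because the
    // original ProcessTask RPC was lost).
    TF_RETURN_IF_ERROR(ProcessTaskInternal(task));
  }
  for (int64 task_id : tasks_to_delete) {
    // Drop tasks whose jobs the dispatcher has finished or garbage collected.
    tasks_.erase(task_id);
  }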

Deleting old tasks is important for job lifecycle management. When the dispatcher
detects that a job is old and unused, it marks the job (and all of its tasks) as
finished. The worker learns that its tasks have finished during its next heartbeat.
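
On the dispatcher side, the eligibility check added below (in GcOldJobs) amounts to the condition sketched here, with the job and config fields taken from the diff further down:

  // A job becomes a GC candidate only after every client has released it and
  // job_gc_timeout_ms has elapsed since the last release.
  int64 now = env_->NowMicros();
  bool gc_candidate =
      !job->finished && job->num_clients == 0 &&
      job->last_client_released_micros >= 0 &&
      now >= job->last_client_released_micros +
                 (config_.job_gc_timeout_ms() * 1000);

Eligible jobs have all of their unfinished tasks marked finished, and workers drop those tasks on their next heartbeat.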

PiperOrigin-RevId: 330990312
Change-Id: Ifc0dbb3465ff478de83bfb3a265d9a57072d852a
Author: Andrew Audibert, 2020-09-10 12:19:07 -07:00 (committed by TensorFlower Gardener)
Commit: 8e789c3872 (parent: 660f6e0610)
26 changed files with 455 additions and 145 deletions

View File

@ -54,19 +54,26 @@ std::string ProcessingModeToString(ProcessingMode mode) {
}
}
Status DataServiceDispatcherClient::RegisterWorker(
const std::string& worker_address, std::vector<TaskDef>& tasks) {
Status DataServiceDispatcherClient::WorkerHeartbeat(
const std::string& worker_address, const std::vector<int64>& current_tasks,
std::vector<TaskDef>& new_tasks, std::vector<int64>& tasks_to_delete) {
TF_RETURN_IF_ERROR(EnsureInitialized());
RegisterWorkerRequest req;
WorkerHeartbeatRequest req;
req.set_worker_address(worker_address);
RegisterWorkerResponse resp;
grpc::ClientContext client_ctx;
grpc::Status status = stub_->RegisterWorker(&client_ctx, req, &resp);
if (!status.ok()) {
return grpc_util::WrapError("Failed to register worker", status);
for (int64 task : current_tasks) {
req.add_current_tasks(task);
}
for (const auto& task : resp.tasks()) {
tasks.push_back(task);
WorkerHeartbeatResponse resp;
grpc::ClientContext client_ctx;
grpc::Status status = stub_->WorkerHeartbeat(&client_ctx, req, &resp);
if (!status.ok()) {
return grpc_util::WrapError("Failed to perform worker heartbeat", status);
}
for (const auto& task : resp.new_tasks()) {
new_tasks.push_back(task);
}
for (int64 task_to_delete : resp.tasks_to_delete()) {
tasks_to_delete.push_back(task_to_delete);
}
return Status::OK();
}

View File

@ -73,10 +73,15 @@ class DataServiceDispatcherClient : public DataServiceClientBase {
const std::string& protocol)
: DataServiceClientBase(address, protocol) {}
// Registers a worker with the dispatcher. The dispatcher returns a list of
// initial tasks for the worker to run, storing them in `tasks`.
Status RegisterWorker(const std::string& worker_address,
std::vector<TaskDef>& tasks);
// Sends a heartbeat to the dispatcher. If the worker wasn't already
// registered with the dispatcher, this will register the worker. The
// dispatcher will report which new tasks the worker should run, and which
// tasks it should delete. These are stored in `new_tasks` and
// `tasks_to_delete`.
Status WorkerHeartbeat(const std::string& worker_address,
const std::vector<int64>& current_tasks,
std::vector<TaskDef>& new_tasks,
std::vector<int64>& tasks_to_delete);
// Updates the dispatcher with information about the worker's state.
Status WorkerUpdate(const std::string& worker_address,

View File

@ -4,16 +4,6 @@ package tensorflow.data;
import "tensorflow/core/data/service/common.proto";
message RegisterWorkerRequest {
// The address of the registering worker.
string worker_address = 1;
}
message RegisterWorkerResponse {
// Tasks to begin processing.
repeated TaskDef tasks = 2;
}
message TaskProgress {
// The task that this message is about.
int64 task_id = 1;
@ -21,6 +11,16 @@ message TaskProgress {
bool completed = 2;
}
message WorkerHeartbeatRequest {
string worker_address = 1;
repeated int64 current_tasks = 2;
}
message WorkerHeartbeatResponse {
repeated TaskDef new_tasks = 1;
repeated int64 tasks_to_delete = 2;
}
message WorkerUpdateRequest {
string worker_address = 1;
repeated TaskProgress updates = 2;
@ -110,8 +110,8 @@ message GetWorkersResponse {
}
service DispatcherService {
// Registers a worker with the dispatcher.
rpc RegisterWorker(RegisterWorkerRequest) returns (RegisterWorkerResponse);
// Performs a periodic worker heartbeat.
rpc WorkerHeartbeat(WorkerHeartbeatRequest) returns (WorkerHeartbeatResponse);
// Updates the dispatcher with information about the worker's state.
rpc WorkerUpdate(WorkerUpdateRequest) returns (WorkerUpdateResponse);

View File

@ -85,7 +85,7 @@ Status CreateWorkerStub(const std::string& address, const std::string& protocol,
DataServiceDispatcherImpl::DataServiceDispatcherImpl(
const experimental::DispatcherConfig& config)
: config_(config) {
: config_(config), env_(Env::Default()) {
if (config_.work_dir().empty()) {
dataset_store_ = absl::make_unique<MemoryDatasetStore>();
} else {
@ -94,8 +94,19 @@ DataServiceDispatcherImpl::DataServiceDispatcherImpl(
}
}
DataServiceDispatcherImpl::~DataServiceDispatcherImpl() {
{
mutex_lock l(mu_);
cancelled_ = true;
job_gc_thread_cv_.notify_all();
}
job_gc_thread_.reset();
}
Status DataServiceDispatcherImpl::Start() {
mutex_lock l(mu_);
job_gc_thread_ = absl::WrapUnique(
env_->StartThread({}, "job-gc-thread", [&] { JobGcThread(); }));
if (config_.work_dir().empty()) {
if (config_.fault_tolerant_mode()) {
return errors::InvalidArgument(
@ -103,7 +114,7 @@ Status DataServiceDispatcherImpl::Start() {
}
} else {
TF_RETURN_IF_ERROR(
Env::Default()->RecursivelyCreateDir(DatasetsDir(config_.work_dir())));
env_->RecursivelyCreateDir(DatasetsDir(config_.work_dir())));
}
if (!config_.fault_tolerant_mode()) {
LOG(INFO) << "Running with fault_tolerant_mode=False. The dispatcher will "
@ -111,12 +122,12 @@ Status DataServiceDispatcherImpl::Start() {
return Status::OK();
}
journal_writer_ = absl::make_unique<FileJournalWriter>(
Env::Default(), JournalDir(config_.work_dir()));
LOG(INFO) << "Restoring dispatcher state from journal in "
env_, JournalDir(config_.work_dir()));
LOG(INFO) << "Attempting to restore dispatcher state from journal in "
<< JournalDir(config_.work_dir());
Update update;
bool end_of_journal = false;
FileJournalReader reader(Env::Default(), JournalDir(config_.work_dir()));
FileJournalReader reader(env_, JournalDir(config_.work_dir()));
Status s = reader.Read(update, end_of_journal);
if (errors::IsNotFound(s)) {
LOG(INFO) << "No journal found. Starting dispatcher from new state.";
@ -134,45 +145,38 @@ Status DataServiceDispatcherImpl::Start() {
return Status::OK();
}
Status DataServiceDispatcherImpl::RegisterWorker(
const RegisterWorkerRequest* request, RegisterWorkerResponse* response) {
VLOG(3) << "Received register worker request";
Status DataServiceDispatcherImpl::WorkerHeartbeat(
const WorkerHeartbeatRequest* request, WorkerHeartbeatResponse* response) {
VLOG(3) << "Received worker heartbeat request from worker "
<< request->worker_address();
mutex_lock l(mu_);
std::string worker_address = request->worker_address();
std::vector<std::shared_ptr<const Task>> tasks;
Status s = state_.TasksForWorker(worker_address, tasks);
if (errors::IsNotFound(s)) {
const std::string& worker_address = request->worker_address();
std::vector<std::shared_ptr<const Task>> correct_tasks;
Status s = state_.TasksForWorker(worker_address, correct_tasks);
if (!s.ok()) {
if (!errors::IsNotFound(s)) {
return s;
}
Update update;
update.mutable_register_worker()->set_worker_address(worker_address);
TF_RETURN_IF_ERROR(Apply(update));
} else if (!s.ok()) {
return s;
TF_RETURN_IF_ERROR(CreateTasksForWorker(worker_address));
TF_RETURN_IF_ERROR(state_.TasksForWorker(worker_address, correct_tasks));
}
absl::flat_hash_map<int64, std::shared_ptr<const Task>> tasks_by_job;
for (const auto& task : tasks) {
// Should never have multiple tasks on the same worker for the same job.
auto& task_for_job = tasks_by_job[task->job_id];
DCHECK(task_for_job == nullptr);
task_for_job = task;
}
absl::flat_hash_set<int64> current_tasks;
current_tasks.insert(request->current_tasks().cbegin(),
request->current_tasks().cend());
absl::flat_hash_set<int64> correct_tasks_set;
std::vector<std::shared_ptr<const Job>> jobs = state_.ListJobs();
// Allocate tasks to the worker.
for (const auto& job : jobs) {
if (job->finished) {
for (const auto& task : correct_tasks) {
correct_tasks_set.insert(task->task_id);
if (current_tasks.contains(task->task_id)) {
continue;
}
std::shared_ptr<const Task> task;
auto it = tasks_by_job.find(job->job_id);
if (it != tasks_by_job.end()) {
task = it->second;
} else {
TF_RETURN_IF_ERROR(CreateTask(job, worker_address, task));
}
TaskDef* task_def = response->add_tasks();
TaskDef* task_def = response->add_new_tasks();
std::shared_ptr<const Dataset> dataset;
TF_RETURN_IF_ERROR(state_.DatasetFromId(job->dataset_id, dataset));
TF_RETURN_IF_ERROR(state_.DatasetFromId(task->dataset_id, dataset));
std::string dataset_key =
DatasetKey(dataset->dataset_id, dataset->fingerprint);
if (config_.work_dir().empty()) {
@ -184,12 +188,18 @@ Status DataServiceDispatcherImpl::RegisterWorker(
io::JoinPath(DatasetsDir(config_.work_dir()), dataset_key);
task_def->set_path(path);
}
task_def->set_dataset_id(job->dataset_id);
task_def->set_job_id(job->job_id);
task_def->set_dataset_id(task->dataset_id);
task_def->set_job_id(task->job_id);
task_def->set_task_id(task->task_id);
}
for (int64 current_task : current_tasks) {
if (!correct_tasks_set.contains(current_task)) {
response->add_tasks_to_delete(current_task);
}
}
VLOG(1) << "Registered worker at address " << request->worker_address();
VLOG(1) << "Finished worker heartbeat for worker at address "
<< request->worker_address();
return Status::OK();
}
@ -346,7 +356,7 @@ Status DataServiceDispatcherImpl::ReleaseJobClient(
ReleaseJobClientUpdate* release_job_client =
update.mutable_release_job_client();
release_job_client->set_job_client_id(job_client_id);
release_job_client->set_time_micros(Env::Default()->NowMicros());
release_job_client->set_time_micros(env_->NowMicros());
TF_RETURN_IF_ERROR(Apply(update));
return Status::OK();
}
@ -412,6 +422,19 @@ Status DataServiceDispatcherImpl::CreateJob(
return Status::OK();
}
Status DataServiceDispatcherImpl::CreateTasksForWorker(
const std::string& worker_address) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
std::vector<std::shared_ptr<const Job>> jobs = state_.ListJobs();
for (const auto& job : jobs) {
if (job->finished) {
continue;
}
std::shared_ptr<const Task> task;
TF_RETURN_IF_ERROR(CreateTask(job, worker_address, task));
}
return Status::OK();
}
Status DataServiceDispatcherImpl::AcquireJobClientId(
const std::shared_ptr<const Job>& job, int64& job_client_id)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
@ -576,5 +599,51 @@ Status DataServiceDispatcherImpl::Apply(const Update& update)
return state_.Apply(update);
}
void DataServiceDispatcherImpl::JobGcThread() {
int64 next_check_micros = 0;
while (true) {
mutex_lock l(mu_);
while (!cancelled_ && env_->NowMicros() < next_check_micros) {
int64 remaining_micros = next_check_micros - env_->NowMicros();
job_gc_thread_cv_.wait_for(l,
std::chrono::microseconds(remaining_micros));
}
if (cancelled_) {
return;
}
Status s = GcOldJobs();
if (!s.ok()) {
LOG(WARNING) << "Error garbage collecting old jobs: " << s;
}
next_check_micros =
env_->NowMicros() + (config_.job_gc_check_interval_ms() * 1000);
}
}
Status DataServiceDispatcherImpl::GcOldJobs() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
std::vector<std::shared_ptr<const Job>> jobs = state_.ListJobs();
int64 now = env_->NowMicros();
for (const auto& job : jobs) {
if (job->finished || job->num_clients > 0 ||
job->last_client_released_micros < 0 ||
now < job->last_client_released_micros +
(config_.job_gc_timeout_ms() * 1000)) {
continue;
}
std::vector<std::shared_ptr<const Task>> tasks;
TF_RETURN_IF_ERROR(state_.TasksForJob(job->job_id, tasks));
for (const auto& task : tasks) {
if (task->finished) {
continue;
}
Update update;
update.mutable_finish_task()->set_task_id(task->task_id);
TF_RETURN_IF_ERROR(state_.Apply(update));
}
DCHECK(job->finished);
}
return Status::OK();
}
} // namespace data
} // namespace tensorflow

View File

@ -48,6 +48,8 @@ class DataServiceDispatcherImpl {
explicit DataServiceDispatcherImpl(
const experimental::DispatcherConfig& config);
~DataServiceDispatcherImpl();
// Starts the dispatcher. If there is a journal, this will read from the
// journal to restore the dispatcher's state.
Status Start();
@ -55,8 +57,8 @@ class DataServiceDispatcherImpl {
// See dispatcher.proto for API documentation.
/// Worker-facing API.
Status RegisterWorker(const RegisterWorkerRequest* request,
RegisterWorkerResponse* response);
Status WorkerHeartbeat(const WorkerHeartbeatRequest* request,
WorkerHeartbeatResponse* response);
Status WorkerUpdate(const WorkerUpdateRequest* request,
WorkerUpdateResponse* response);
Status GetDatasetDef(const GetDatasetDefRequest* request,
@ -92,6 +94,8 @@ class DataServiceDispatcherImpl {
absl::optional<DispatcherState::NamedJobKey> named_job_key,
std::shared_ptr<const DispatcherState::Job>& job)
EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Creates tasks for the specified worker, one task for every unfinished job.
Status CreateTasksForWorker(const std::string& worker_address);
// Acquires a job client id to read from the given job and sets
// `job_client_id`.
Status AcquireJobClientId(
@ -128,12 +132,16 @@ class DataServiceDispatcherImpl {
// used when recovering state when the dispatcher starts.
Status ApplyWithoutJournaling(const Update& update)
EXCLUSIVE_LOCKS_REQUIRED(mu_);
// A thread which periodically checks for jobs to clean up.
void JobGcThread();
// Scans for old jobs and marks them as finished.
Status GcOldJobs() EXCLUSIVE_LOCKS_REQUIRED(mu_);
const experimental::DispatcherConfig& config_;
Env* env_;
mutex mu_;
int64 next_task_id_ TF_GUARDED_BY(mu_) = 0;
bool cancelled_ TF_GUARDED_BY(mu_) = false;
// Cached worker stubs for communicating with workers.
absl::flat_hash_map<std::string, std::unique_ptr<WorkerService::Stub>>
@ -144,6 +152,9 @@ class DataServiceDispatcherImpl {
absl::optional<std::unique_ptr<JournalWriter>> journal_writer_
TF_GUARDED_BY(mu_);
DispatcherState state_ TF_GUARDED_BY(mu_);
// Condition variable for waking up the job gc thread.
condition_variable job_gc_thread_cv_;
std::unique_ptr<Thread> job_gc_thread_;
TF_DISALLOW_COPY_AND_ASSIGN(DataServiceDispatcherImpl);
};

View File

@ -72,7 +72,8 @@ void DispatcherState::RegisterWorker(
std::string address = register_worker.worker_address();
DCHECK(!workers_.contains(address));
workers_[address] = std::make_shared<Worker>(address);
tasks_by_worker_[address] = std::vector<std::shared_ptr<Task>>();
tasks_by_worker_[address] =
absl::flat_hash_map<int64, std::shared_ptr<Task>>();
}
void DispatcherState::CreateJob(const CreateJobUpdate& create_job) {
@ -126,7 +127,7 @@ void DispatcherState::CreateTask(const CreateTaskUpdate& create_task) {
create_task.dataset_id(),
create_task.worker_address());
tasks_by_job_[create_task.job_id()].push_back(task);
tasks_by_worker_[create_task.worker_address()].push_back(task);
tasks_by_worker_[create_task.worker_address()][task->task_id] = task;
next_available_task_id_ = std::max(next_available_task_id_, task_id + 1);
}
@ -136,6 +137,7 @@ void DispatcherState::FinishTask(const FinishTaskUpdate& finish_task) {
auto& task = tasks_[task_id];
DCHECK(task != nullptr);
task->finished = true;
tasks_by_worker_[task->worker_address].erase(task->task_id);
bool all_finished = true;
for (const auto& task_for_job : tasks_by_job_[task->job_id]) {
if (!task_for_job->finished) {
@ -269,10 +271,11 @@ Status DispatcherState::TasksForWorker(
if (it == tasks_by_worker_.end()) {
return errors::NotFound("Worker ", worker_address, " not found");
}
std::vector<std::shared_ptr<Task>> worker_tasks = it->second;
const absl::flat_hash_map<int64, std::shared_ptr<Task>>& worker_tasks =
it->second;
tasks.reserve(worker_tasks.size());
for (const auto& task : worker_tasks) {
tasks.push_back(task);
tasks.push_back(task.second);
}
return Status::OK();
}

View File

@ -179,7 +179,7 @@ class DispatcherState {
void CreateTask(const CreateTaskUpdate& create_task);
void FinishTask(const FinishTaskUpdate& finish_task);
int64 next_available_dataset_id_ = 0;
int64 next_available_dataset_id_ = 1000;
// Registered datasets, keyed by dataset ids.
absl::flat_hash_map<int64, std::shared_ptr<Dataset>> datasets_by_id_;
// Registered datasets, keyed by dataset fingerprints.
@ -189,24 +189,26 @@ class DispatcherState {
// Registered workers, keyed by address.
absl::flat_hash_map<std::string, std::shared_ptr<Worker>> workers_;
int64 next_available_job_id_ = 0;
int64 next_available_job_id_ = 2000;
// Jobs, keyed by job ids.
absl::flat_hash_map<int64, std::shared_ptr<Job>> jobs_;
// Named jobs, keyed by their names and indices. Not all jobs have names, so
// this is a subset of the jobs stored in `jobs_`.
absl::flat_hash_map<NamedJobKey, std::shared_ptr<Job>> named_jobs_;
int64 next_available_job_client_id_ = 0;
int64 next_available_job_client_id_ = 3000;
// Mapping from client ids to the jobs they are associated with.
absl::flat_hash_map<int64, std::shared_ptr<Job>> jobs_for_client_ids_;
int64 next_available_task_id_ = 0;
int64 next_available_task_id_ = 4000;
// Tasks, keyed by task ids.
absl::flat_hash_map<int64, std::shared_ptr<Task>> tasks_;
// Tasks, keyed by job ids.
absl::flat_hash_map<int64, std::vector<std::shared_ptr<Task>>> tasks_by_job_;
// Tasks, keyed by worker addresses.
absl::flat_hash_map<std::string, std::vector<std::shared_ptr<Task>>>
// Tasks, keyed by worker addresses. The values are a map from task id to
// task.
absl::flat_hash_map<std::string,
absl::flat_hash_map<int64, std::shared_ptr<Task>>>
tasks_by_worker_;
};

View File

@ -125,9 +125,9 @@ Status FinishTask(int64 task_id, DispatcherState& state) {
} // namespace
TEST(DispatcherState, RegisterDataset) {
int64 id = 10;
uint64 fingerprint = 20;
DispatcherState state;
int64 id = state.NextAvailableDatasetId();
TF_EXPECT_OK(RegisterDataset(id, fingerprint, state));
EXPECT_EQ(state.NextAvailableDatasetId(), id + 1);
@ -210,9 +210,9 @@ TEST(DispatcherState, UnknownUpdate) {
}
TEST(DispatcherState, AnonymousJob) {
int64 job_id = 3;
int64 dataset_id = 10;
DispatcherState state;
int64 job_id = state.NextAvailableJobId();
TF_EXPECT_OK(RegisterDataset(dataset_id, state));
TF_EXPECT_OK(CreateAnonymousJob(job_id, dataset_id, state));
std::shared_ptr<const Job> job;
@ -227,9 +227,9 @@ TEST(DispatcherState, AnonymousJob) {
}
TEST(DispatcherState, NamedJob) {
int64 job_id = 3;
int64 dataset_id = 10;
DispatcherState state;
int64 job_id = state.NextAvailableJobId();
TF_EXPECT_OK(RegisterDataset(dataset_id, state));
NamedJobKey named_job_key("test", 1);
TF_EXPECT_OK(CreateNamedJob(job_id, dataset_id, named_job_key, state));
@ -244,9 +244,9 @@ TEST(DispatcherState, NamedJob) {
TEST(DispatcherState, CreateTask) {
int64 job_id = 3;
int64 dataset_id = 10;
int64 task_id = 8;
std::string worker_address = "test_worker_address";
DispatcherState state;
int64 task_id = state.NextAvailableTaskId();
TF_EXPECT_OK(RegisterDataset(dataset_id, state));
TF_EXPECT_OK(CreateAnonymousJob(job_id, dataset_id, state));
TF_EXPECT_OK(CreateTask(task_id, job_id, dataset_id, worker_address, state));

View File

@ -40,7 +40,7 @@ Status GrpcDispatcherImpl::Start() { return impl_.Start(); }
method##Response* response) { \
return ToGrpcStatus(impl_.method(request, response)); \
}
HANDLER(RegisterWorker);
HANDLER(WorkerHeartbeat);
HANDLER(WorkerUpdate);
HANDLER(GetDatasetDef);
HANDLER(GetOrRegisterDataset);

View File

@ -39,7 +39,7 @@ class GrpcDispatcherImpl : public DispatcherService::Service {
::grpc::Status method(::grpc::ServerContext* context, \
const method##Request* request, \
method##Response* response) override;
HANDLER(RegisterWorker);
HANDLER(WorkerHeartbeat);
HANDLER(WorkerUpdate);
HANDLER(GetDatasetDef);
HANDLER(GetOrRegisterDataset);

View File

@ -43,6 +43,7 @@ Status GrpcWorkerImpl::Start(const std::string& worker_address) {
}
HANDLER(ProcessTask);
HANDLER(GetElement);
HANDLER(GetWorkerTasks);
#undef HANDLER
} // namespace data

View File

@ -41,6 +41,7 @@ class GrpcWorkerImpl : public WorkerService::Service {
method##Response* response) override;
HANDLER(ProcessTask);
HANDLER(GetElement);
HANDLER(GetWorkerTasks);
#undef HANDLER
private:

View File

@ -138,6 +138,18 @@ Status WorkerGrpcDataServer::StartServiceInternal() {
return Status::OK();
}
Status WorkerGrpcDataServer::NumTasks(int* num_tasks) {
GetWorkerTasksRequest req;
GetWorkerTasksResponse resp;
::grpc::ServerContext ctx;
::grpc::Status s = service_->GetWorkerTasks(&ctx, &req, &resp);
if (!s.ok()) {
return grpc_util::WrapError("Failed to get tasks", s);
}
*num_tasks = resp.tasks_size();
return Status::OK();
}
Status NewDispatchServer(const experimental::DispatcherConfig& config,
std::unique_ptr<DispatchGrpcDataServer>& out_server) {
out_server = absl::make_unique<DispatchGrpcDataServer>(config);

View File

@ -98,6 +98,9 @@ class WorkerGrpcDataServer : public GrpcDataServerBase {
explicit WorkerGrpcDataServer(const experimental::WorkerConfig& config);
~WorkerGrpcDataServer() override;
// Returns the number of tasks currently being executed by the worker.
Status NumTasks(int* num_tasks);
protected:
void AddDataServiceToBuilder(::grpc::ServerBuilder& builder) override;
Status StartServiceInternal() override;

View File

@ -23,10 +23,20 @@ message GetElementResponse {
bool end_of_sequence = 2;
}
// Named GetWorkerTasks to avoid conflicting with GetTasks in dispatcher.proto
message GetWorkerTasksRequest {}
message GetWorkerTasksResponse {
repeated TaskInfo tasks = 1;
}
service WorkerService {
// Processes a task for a dataset, making elements available to clients.
rpc ProcessTask(ProcessTaskRequest) returns (ProcessTaskResponse);
// Gets the next dataset element.
rpc GetElement(GetElementRequest) returns (GetElementResponse);
// Gets the tasks currently being executed by the worker.
rpc GetWorkerTasks(GetWorkerTasksRequest) returns (GetWorkerTasksResponse);
}

View File

@ -56,7 +56,8 @@ DataServiceWorkerImpl::DataServiceWorkerImpl(
DataServiceWorkerImpl::~DataServiceWorkerImpl() {
mutex_lock l(mu_);
cancelled_ = true;
background_cv_.notify_one();
task_completion_cv_.notify_one();
heartbeat_cv_.notify_one();
}
Status DataServiceWorkerImpl::Start(const std::string& worker_address) {
@ -67,18 +68,20 @@ Status DataServiceWorkerImpl::Start(const std::string& worker_address) {
config_.dispatcher_address(), config_.protocol());
TF_RETURN_IF_ERROR(dispatcher_->Initialize());
Status s = Register();
Status s = Heartbeat();
while (!s.ok()) {
LOG(WARNING) << "Failed to register with dispatcher at "
<< config_.dispatcher_address() << ": " << s;
Env::Default()->SleepForMicroseconds(kRetryIntervalMicros);
s = Register();
s = Heartbeat();
}
Thread* thread = Env::Default()->StartThread(
{}, "data-service-worker-background", [this]() { BackgroundThread(); });
LOG(INFO) << "Worker registered with dispatcher running at "
<< config_.dispatcher_address();
background_thread_.reset(thread);
task_completion_thread_ = absl::WrapUnique(
Env::Default()->StartThread({}, "data-service-worker-task-completion",
[this]() { TaskCompletionThread(); }));
heartbeat_thread_ = absl::WrapUnique(Env::Default()->StartThread(
{}, "data-service-worker-heartbeat", [this]() { HeartbeatThread(); }));
mutex_lock l(mu_);
registered_ = true;
return Status::OK();
@ -96,8 +99,9 @@ Status DataServiceWorkerImpl::ProcessTaskInternal(const TaskDef& task_def)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
std::unique_ptr<Task>& task = tasks_[task_def.task_id()];
if (task) {
return errors::AlreadyExists("A task with id ", task_def.task_id(),
" already exists.");
VLOG(1) << "Received request to process already-processed task "
<< task->task_def.task_id();
return Status::OK();
}
task = absl::make_unique<Task>(task_def);
VLOG(3) << "Began processing for task " << task_def.task_id();
@ -156,24 +160,17 @@ Status DataServiceWorkerImpl::GetElement(const GetElementRequest* request,
}
auto it = tasks_.find(request->task_id());
if (it == tasks_.end()) {
return errors::NotFound("DataServiceWorkerImpl::GetElement failed. ",
"Task id ", request->task_id(), " not found");
}
auto& task = it->second;
TF_RETURN_IF_ERROR(EnsureTaskInitialized(*task));
std::unique_ptr<standalone::Iterator>& iter = task->iterator;
if (iter == nullptr) {
VLOG(3) << "Task " << request->task_id() << " is already finished";
response->set_end_of_sequence(true);
return Status::OK();
}
TF_RETURN_IF_ERROR(iter->GetNext(&outputs, &end_of_sequence));
auto& task = it->second;
TF_RETURN_IF_ERROR(EnsureTaskInitialized(*task));
TF_RETURN_IF_ERROR(task->iterator->GetNext(&outputs, &end_of_sequence));
if (end_of_sequence) {
VLOG(3) << "Reached end_of_sequence for task " << request->task_id();
// Release iterator memory and leave a null entry as a tombstone.
iter.reset();
tasks_.erase(request->task_id());
pending_completed_tasks_.insert(request->task_id());
background_cv_.notify_one();
task_completion_cv_.notify_one();
}
}
@ -212,27 +209,28 @@ Status DataServiceWorkerImpl::GetElement(const GetElementRequest* request,
return Status::OK();
}
Status DataServiceWorkerImpl::Register() LOCKS_EXCLUDED(mu_) {
VLOG(3) << "Registering with dispatcher at " << config_.dispatcher_address();
std::vector<TaskDef> tasks;
TF_RETURN_IF_ERROR(dispatcher_->RegisterWorker(worker_address_, tasks));
for (const TaskDef& task : tasks) {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(ProcessTaskInternal(task));
Status DataServiceWorkerImpl::GetWorkerTasks(
const GetWorkerTasksRequest* request, GetWorkerTasksResponse* response) {
mutex_lock l(mu_);
for (const auto& it : tasks_) {
Task* task = it.second.get();
TaskInfo* task_info = response->add_tasks();
task_info->set_worker_address(worker_address_);
task_info->set_task_id(task->task_def.task_id());
task_info->set_job_id(task->task_def.job_id());
}
VLOG(3) << "Registered worker with address " << worker_address_;
return Status::OK();
}
void DataServiceWorkerImpl::BackgroundThread() LOCKS_EXCLUDED(mu_) {
void DataServiceWorkerImpl::TaskCompletionThread() LOCKS_EXCLUDED(mu_) {
while (true) {
{
mutex_lock l(mu_);
while (!cancelled_ && pending_completed_tasks_.empty()) {
background_cv_.wait(l);
task_completion_cv_.wait(l);
}
if (cancelled_) {
VLOG(3) << "Background thread shutting down";
VLOG(3) << "Task completion thread shutting down";
return;
}
}
@ -241,7 +239,7 @@ void DataServiceWorkerImpl::BackgroundThread() LOCKS_EXCLUDED(mu_) {
LOG(WARNING) << "Failed to send task updates to dispatcher: " << s;
mutex_lock l(mu_);
if (!cancelled_) {
background_cv_.wait_for(
task_completion_cv_.wait_for(
l, std::chrono::microseconds(kRetryIntervalMicros));
}
}
@ -271,5 +269,62 @@ Status DataServiceWorkerImpl::SendTaskUpdates() LOCKS_EXCLUDED(mu_) {
return Status::OK();
}
void DataServiceWorkerImpl::HeartbeatThread() LOCKS_EXCLUDED(mu_) {
while (true) {
int64 next_heartbeat_micros =
Env::Default()->NowMicros() + (config_.heartbeat_interval_ms() * 1000);
{
mutex_lock l(mu_);
while (!cancelled_ &&
Env::Default()->NowMicros() < next_heartbeat_micros) {
int64 time_to_wait_micros =
next_heartbeat_micros - Env::Default()->NowMicros();
heartbeat_cv_.wait_for(l,
std::chrono::microseconds(time_to_wait_micros));
}
if (cancelled_) {
VLOG(3) << "Heartbeat thread shutting down";
return;
}
if (!registered_) {
VLOG(1) << "Not performing heartbeat; worker is not yet registered";
continue;
}
}
Status s = Heartbeat();
if (!s.ok()) {
LOG(WARNING) << "Failed to send heartbeat to dispatcher: " << s;
}
}
}
Status DataServiceWorkerImpl::Heartbeat() LOCKS_EXCLUDED(mu_) {
std::vector<int64> current_tasks;
{
mutex_lock l(mu_);
for (const auto& task : tasks_) {
current_tasks.push_back(task.first);
}
}
std::vector<TaskDef> new_tasks;
std::vector<int64> tasks_to_delete;
TF_RETURN_IF_ERROR(dispatcher_->WorkerHeartbeat(
worker_address_, current_tasks, new_tasks, tasks_to_delete));
mutex_lock l(mu_);
for (const auto& task : new_tasks) {
Status s = ProcessTaskInternal(task);
if (!s.ok() && !errors::IsAlreadyExists(s)) {
LOG(WARNING) << "Failed to start processing task " << task.task_id()
<< ": " << s;
}
}
for (int64 task_id : tasks_to_delete) {
VLOG(3) << "Deleting task " << task_id
<< " at the request of the dispatcher";
tasks_.erase(task_id);
}
return Status::OK();
}
} // namespace data
} // namespace tensorflow

View File

@ -50,6 +50,8 @@ class DataServiceWorkerImpl {
/// Client-facing API.
Status GetElement(const GetElementRequest* request,
GetElementResponse* response);
Status GetWorkerTasks(const GetWorkerTasksRequest* request,
GetWorkerTasksResponse* response);
private:
struct Task {
@ -64,16 +66,17 @@ class DataServiceWorkerImpl {
std::unique_ptr<standalone::Iterator> iterator;
};
// Registers the worker with the dispatcher.
Status Register() LOCKS_EXCLUDED(mu_);
// Sends task status to the dispatcher and checks for dispatcher commands.
Status SendTaskUpdates() LOCKS_EXCLUDED(mu_);
// Creates an iterator to process a task.
Status ProcessTaskInternal(const TaskDef& task) EXCLUSIVE_LOCKS_REQUIRED(mu_);
Status EnsureTaskInitialized(Task& task);
// A thread for doing async background processing not associated with a
// specific RPC, such as reporting finished tasks.
void BackgroundThread() LOCKS_EXCLUDED(mu_);
// A thread for notifying the dispatcher when tasks complete.
void TaskCompletionThread() LOCKS_EXCLUDED(mu_);
// A thread for doing periodic heartbeats to the dispatcher.
void HeartbeatThread() LOCKS_EXCLUDED(mu_);
// Performs a heartbeat to the dispatcher.
Status Heartbeat() LOCKS_EXCLUDED(mu_);
const experimental::WorkerConfig config_;
// The worker's own address.
@ -88,9 +91,12 @@ class DataServiceWorkerImpl {
bool cancelled_ TF_GUARDED_BY(mu_) = false;
// Whether the worker has registered with the dispatcher yet.
bool registered_ TF_GUARDED_BY(mu_) = false;
// Condition variable for notifying the background thread.
condition_variable background_cv_ TF_GUARDED_BY(mu_);
std::unique_ptr<Thread> background_thread_;
// A thread for notifying the dispatcher when tasks complete.
std::unique_ptr<Thread> task_completion_thread_;
condition_variable task_completion_cv_ TF_GUARDED_BY(mu_);
// A thread for performing regular heartbeats to the dispatcher.
std::unique_ptr<Thread> heartbeat_thread_;
condition_variable heartbeat_cv_ TF_GUARDED_BY(mu_);
TF_DISALLOW_COPY_AND_ASSIGN(DataServiceWorkerImpl);
};

View File

@ -200,6 +200,9 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
for (auto& worker_thread : worker_threads_) {
worker_thread.reset();
}
VLOG(1) << "Destroyed data service dataset iterator for job id "
<< job_client_id_;
}
void CancelThreads() TF_LOCKS_EXCLUDED(mu_) {

View File

@ -15,6 +15,11 @@ message DispatcherConfig {
// Whether to run in fault tolerant mode, where dispatcher state is saved
// across restarts. Requires that `work_dir` is nonempty.
bool fault_tolerant_mode = 4;
// How often the dispatcher should scan for old and unused jobs to delete.
int64 job_gc_check_interval_ms = 5;
// How long a job needs to be unused before it becomes a candidate for garbage
// collection.
int64 job_gc_timeout_ms = 6;
}
// Configuration for a tf.data service WorkerServer.
@ -30,4 +35,6 @@ message WorkerConfig {
// will be replaced with the worker's bound port. This is useful when the port
// is set to `0`.
string worker_address = 4;
// How often the worker should heartbeat to the dispatcher.
int64 heartbeat_interval_ms = 5;
}

View File

@ -29,9 +29,10 @@ from tensorflow.python.util.tf_export import tf_export
@tf_export("data.experimental.service.DispatcherConfig")
class DispatcherConfig(
collections.namedtuple(
"DispatcherConfig",
["port", "protocol", "work_dir", "fault_tolerant_mode"])):
collections.namedtuple("DispatcherConfig", [
"port", "protocol", "work_dir", "fault_tolerant_mode",
"job_gc_check_interval_ms", "job_gc_timeout_ms"
])):
"""Configuration class for tf.data service dispatchers.
Fields:
@ -47,15 +48,34 @@ class DispatcherConfig(
registered datasets and created jobs, is synchronously written to the
journal before responding to RPCs. If `True`, `work_dir` must also be
specified.
job_gc_check_interval_ms: How often the dispatcher should scan for old and
unused jobs to delete, in milliseconds. If not set, the runtime will
select a reasonable default. A higher value will reduce load on the
dispatcher, while a lower value will reduce the time it takes for the
dispatcher to garbage collect expired jobs.
job_gc_timeout_ms: How long a job needs to be unused before it becomes a
candidate for garbage collection, in milliseconds. If not set, the runtime
will select a reasonable default. A higher value will cause jobs to stay
around longer with no consumers, which is useful when there are long gaps
between consumer reads from the job. A lower value will reduce
the time it takes to reclaim the resources from expired jobs.
"""
def __new__(cls,
port=0,
protocol="grpc",
work_dir=None,
fault_tolerant_mode=False):
return super(DispatcherConfig, cls).__new__(cls, port, protocol, work_dir,
fault_tolerant_mode)
fault_tolerant_mode=False,
job_gc_check_interval_ms=None,
job_gc_timeout_ms=None):
if job_gc_check_interval_ms is None:
job_gc_check_interval_ms = 10 * 60 * 1000 # 10 minutes.
if job_gc_timeout_ms is None:
job_gc_timeout_ms = 5 * 60 * 1000 # 5 minutes.
return super(DispatcherConfig,
cls).__new__(cls, port, protocol, work_dir,
fault_tolerant_mode, job_gc_check_interval_ms,
job_gc_timeout_ms)
@tf_export("data.experimental.service.DispatchServer", v1=[])
@ -116,7 +136,9 @@ class DispatchServer(object):
port=config.port,
protocol=config.protocol,
work_dir=config.work_dir,
fault_tolerant_mode=config.fault_tolerant_mode)
fault_tolerant_mode=config.fault_tolerant_mode,
job_gc_check_interval_ms=config.job_gc_check_interval_ms,
job_gc_timeout_ms=config.job_gc_timeout_ms)
self._server = _pywrap_server_lib.TF_DATA_NewDispatchServer(
config_proto.SerializeToString())
if start:
@ -193,9 +215,10 @@ class DispatchServer(object):
@tf_export("data.experimental.service.WorkerConfig")
class WorkerConfig(
collections.namedtuple(
"WorkerConfig",
["dispatcher_address", "worker_address", "port", "protocol"])):
collections.namedtuple("WorkerConfig", [
"dispatcher_address", "worker_address", "port", "protocol",
"heartbeat_interval_ms"
])):
"""Configuration class for tf.data service dispatchers.
Fields:
@ -205,19 +228,29 @@ class WorkerConfig(
connect to this worker.
port: Specifies the port to bind to. A value of 0 indicates that the worker
can bind to any available port.
protocol: (Optional.) Specifies the protocol to be used by the server.
protocol: Specifies the protocol to be used by the server.
Acceptable values include `"grpc"` and `"grpc+local"`.
heartbeat_interval_ms: How often the worker should heartbeat to the
dispatcher, in milliseconds. If not set, the runtime will select a
reasonable default. A higher value will reduce the load on the dispatcher,
while a lower value will reduce the time it takes to reclaim resources
from finished jobs.
"""
def __new__(cls,
dispatcher_address,
worker_address=None,
port=0,
protocol="grpc"):
worker_address = ("localhost:%port%"
if worker_address is None else worker_address)
return super(WorkerConfig, cls).__new__(cls, dispatcher_address,
worker_address, port, protocol)
protocol="grpc",
heartbeat_interval_ms=None):
if worker_address is None:
worker_address = "localhost:%port%"
if heartbeat_interval_ms is None:
heartbeat_interval_ms = 30 * 1000 # 30 seconds
return super(WorkerConfig,
cls).__new__(cls, dispatcher_address, worker_address, port,
protocol, heartbeat_interval_ms)
@tf_export("data.experimental.service.WorkerServer", v1=[])
@ -264,7 +297,8 @@ class WorkerServer(object):
dispatcher_address=config.dispatcher_address,
worker_address=config.worker_address,
port=config.port,
protocol=config.protocol)
protocol=config.protocol,
heartbeat_interval_ms=config.heartbeat_interval_ms)
self._server = _pywrap_server_lib.TF_DATA_NewWorkerServer(
config_proto.SerializeToString())
if start:
@ -317,3 +351,7 @@ class WorkerServer(object):
The returned string will be in the form address:port, e.g. "localhost:1000".
"""
return "localhost:{0}".format(self._server.bound_port())
def _num_tasks(self):
"""Returns the number of tasks currently being executed on the worker."""
return self._server.num_tasks()

View File

@ -50,7 +50,14 @@ PYBIND11_MODULE(_pywrap_server_lib, m) {
.def("stop", &tensorflow::data::WorkerGrpcDataServer::Stop)
.def("join", &tensorflow::data::WorkerGrpcDataServer::Join,
py::call_guard<py::gil_scoped_release>())
.def("bound_port", &tensorflow::data::WorkerGrpcDataServer::BoundPort);
.def("bound_port", &tensorflow::data::WorkerGrpcDataServer::BoundPort)
.def("num_tasks",
[](tensorflow::data::WorkerGrpcDataServer* server) -> int {
int num_tasks;
tensorflow::Status status = server->NumTasks(&num_tasks);
tensorflow::MaybeRaiseFromStatus(status);
return num_tasks;
});
m.def(
"TF_DATA_NewDispatchServer",

View File

@ -101,7 +101,9 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
name="",
port=0,
work_dir=None,
fault_tolerant_mode=True):
fault_tolerant_mode=True,
job_gc_check_interval_ms=None,
job_gc_timeout_ms=None):
# If a test starts multiple independent dispatch servers, it should give
# them different `name` values.
work_dir = os.path.join(self.get_temp_dir(), "work_dir_",
@ -110,13 +112,16 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
server_lib.DispatcherConfig(
port=port,
work_dir=work_dir,
fault_tolerant_mode=fault_tolerant_mode))
fault_tolerant_mode=fault_tolerant_mode,
job_gc_check_interval_ms=job_gc_check_interval_ms,
job_gc_timeout_ms=job_gc_timeout_ms))
def start_worker_server(self, dispatcher, port=0):
return server_lib.WorkerServer(
server_lib.WorkerConfig(
dispatcher_address=_address_from_target(dispatcher.target),
port=port))
port=port,
heartbeat_interval_ms=200))
def restart_dispatcher(self, dispatcher):
"""Stops `dispatcher` and returns a new dispatcher with the same port."""
@ -535,6 +540,47 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
results.append(elem.numpy())
self.assertCountEqual(num_repetitions * list(range(num_elements)), results)
@combinations.generate(
combinations.times(test_base.eager_only_combinations(),
combinations.combine(job_name=[None, "test"])))
def testGcUnusedJob(self, job_name):
dispatcher = self.start_dispatch_server(
job_gc_check_interval_ms=50, job_gc_timeout_ms=20)
worker = self.start_worker_server(dispatcher) # pylint: disable=unused-variable
num_elements = 10
ds = _make_distributed_range_dataset(
num_elements, dispatcher, job_name=job_name)
it = iter(ds)
self.assertEqual(0, next(it).numpy())
self.assertEqual(1, worker._num_tasks())
del it
while worker._num_tasks() > 0:
time.sleep(0.1)
@combinations.generate(test_base.eager_only_combinations())
def testDontGcUsedJob(self):
dispatcher = self.start_dispatch_server(
job_gc_check_interval_ms=50, job_gc_timeout_ms=20)
worker = self.start_worker_server(dispatcher) # pylint: disable=unused-variable
num_elements = 10
it1 = iter(
_make_distributed_range_dataset(
num_elements, dispatcher, job_name="test1"))
it2 = iter(
_make_distributed_range_dataset(
num_elements, dispatcher, job_name="test2"))
it3 = iter( # this iterator keeps the task alive. pylint: disable=unused-variable
_make_distributed_range_dataset(
num_elements, dispatcher, job_name="test2"))
self.assertEqual(2, worker._num_tasks())
del it1
del it2
# Check that only the first job is gced. The second job will not be gced
# because there is still an outstanding iterator for it.
while worker._num_tasks() > 1:
time.sleep(0.1)
self.assertEqual(1, worker._num_tasks())
@combinations.generate(test_base.eager_only_combinations())
def testApplyDeterminismOption(self):
elements = list(range(10))

View File

@ -7,6 +7,14 @@ tf_class {
name: "fault_tolerant_mode"
mtype: "<type \'property\'>"
}
member {
name: "job_gc_check_interval_ms"
mtype: "<type \'property\'>"
}
member {
name: "job_gc_timeout_ms"
mtype: "<type \'property\'>"
}
member {
name: "port"
mtype: "<type \'property\'>"

View File

@ -7,6 +7,10 @@ tf_class {
name: "dispatcher_address"
mtype: "<type \'property\'>"
}
member {
name: "heartbeat_interval_ms"
mtype: "<type \'property\'>"
}
member {
name: "port"
mtype: "<type \'property\'>"

View File

@ -7,6 +7,14 @@ tf_class {
name: "fault_tolerant_mode"
mtype: "<type \'property\'>"
}
member {
name: "job_gc_check_interval_ms"
mtype: "<type \'property\'>"
}
member {
name: "job_gc_timeout_ms"
mtype: "<type \'property\'>"
}
member {
name: "port"
mtype: "<type \'property\'>"

View File

@ -7,6 +7,10 @@ tf_class {
name: "dispatcher_address"
mtype: "<type \'property\'>"
}
member {
name: "heartbeat_interval_ms"
mtype: "<type \'property\'>"
}
member {
name: "port"
mtype: "<type \'property\'>"