From 8e789c38727f955fd9735e41c0155a536f0d30b6 Mon Sep 17 00:00:00 2001
From: Andrew Audibert
Date: Thu, 10 Sep 2020 12:19:07 -0700
Subject: [PATCH] [tf.data service] Garbage collect old and unused jobs.

This CL adds a worker heartbeat which does the following:
- Registers the worker with the dispatcher if the worker is not yet
  registered
- Reports to the dispatcher which tasks the worker is currently processing
- Learns from the dispatcher which tasks should be added and which should
  be deleted.

Learning about new tasks is important in case the dispatcher's original
ProcessTask request to the worker failed due to network issues. The
heartbeat provides a backup mechanism by which the worker can still learn
which tasks it should process.

Deleting old tasks is important for job lifecycle management. When the
dispatcher detects that a job is old and unused, it will mark the job (and
all of its tasks) as finished. The worker will learn about the finished
tasks during its next heartbeat.

PiperOrigin-RevId: 330990312
Change-Id: Ifc0dbb3465ff478de83bfb3a265d9a57072d852a
---
 tensorflow/core/data/service/data_service.cc  |  27 ++--
 tensorflow/core/data/service/data_service.h   |  13 +-
 tensorflow/core/data/service/dispatcher.proto |  24 +--
 .../core/data/service/dispatcher_impl.cc      | 145 +++++++++++++-----
 .../core/data/service/dispatcher_impl.h       |  19 ++-
 .../core/data/service/dispatcher_state.cc     |  11 +-
 .../core/data/service/dispatcher_state.h      |  14 +-
 .../data/service/dispatcher_state_test.cc     |   8 +-
 .../core/data/service/grpc_dispatcher_impl.cc |   2 +-
 .../core/data/service/grpc_dispatcher_impl.h  |   2 +-
 .../core/data/service/grpc_worker_impl.cc     |   1 +
 .../core/data/service/grpc_worker_impl.h      |   1 +
 tensorflow/core/data/service/server_lib.cc    |  12 ++
 tensorflow/core/data/service/server_lib.h     |   3 +
 tensorflow/core/data/service/worker.proto     |  10 ++
 tensorflow/core/data/service/worker_impl.cc   | 119 ++++++++++----
 tensorflow/core/data/service/worker_impl.h    |  22 ++-
 .../experimental/data_service_dataset_op.cc   |   3 +
 .../data/experimental/service_config.proto    |   7 +
 .../data/experimental/service/server_lib.py   |  72 +++++++--
 .../service/server_lib_wrapper.cc             |   9 +-
 .../kernel_tests/data_service_ops_test.py     |  52 ++++++-
 ...erimental.service.-dispatcher-config.pbtxt |   8 +
 ....experimental.service.-worker-config.pbtxt |   4 +
 ...erimental.service.-dispatcher-config.pbtxt |   8 +
 ....experimental.service.-worker-config.pbtxt |   4 +
 26 files changed, 455 insertions(+), 145 deletions(-)

diff --git a/tensorflow/core/data/service/data_service.cc b/tensorflow/core/data/service/data_service.cc
index d425dab46dc..cc50e0f3e5a 100644
--- a/tensorflow/core/data/service/data_service.cc
+++ b/tensorflow/core/data/service/data_service.cc
@@ -54,19 +54,26 @@ std::string ProcessingModeToString(ProcessingMode mode) {
   }
 }
 
-Status DataServiceDispatcherClient::RegisterWorker(
-    const std::string& worker_address, std::vector<TaskDef>& tasks) {
+Status DataServiceDispatcherClient::WorkerHeartbeat(
+    const std::string& worker_address, const std::vector<int64>& current_tasks,
+    std::vector<TaskDef>& new_tasks, std::vector<int64>& tasks_to_delete) {
   TF_RETURN_IF_ERROR(EnsureInitialized());
-  RegisterWorkerRequest req;
+  WorkerHeartbeatRequest req;
   req.set_worker_address(worker_address);
-  RegisterWorkerResponse resp;
-  grpc::ClientContext client_ctx;
-  grpc::Status status = stub_->RegisterWorker(&client_ctx, req, &resp);
-  if (!status.ok()) {
-    return grpc_util::WrapError("Failed to register worker", status);
+  for (int64 task : current_tasks) {
+    req.add_current_tasks(task);
   }
-  for (const auto& task : resp.tasks()) {
-    tasks.push_back(task);
+  WorkerHeartbeatResponse resp;
+  grpc::ClientContext client_ctx;
+  grpc::Status status = stub_->WorkerHeartbeat(&client_ctx, req, &resp);
+  if (!status.ok()) {
+    return grpc_util::WrapError("Failed to perform worker heartbeat", status);
+  }
+  for (const auto& task : resp.new_tasks()) {
+    new_tasks.push_back(task);
+  }
+  for (int64 task_to_delete : resp.tasks_to_delete()) {
+    tasks_to_delete.push_back(task_to_delete);
   }
   return Status::OK();
 }
diff --git a/tensorflow/core/data/service/data_service.h b/tensorflow/core/data/service/data_service.h
index c5eb6a37269..f0adbb3d4eb 100644
--- a/tensorflow/core/data/service/data_service.h
+++ b/tensorflow/core/data/service/data_service.h
@@ -73,10 +73,15 @@ class DataServiceDispatcherClient : public DataServiceClientBase {
                 const std::string& protocol)
       : DataServiceClientBase(address, protocol) {}
 
-  // Registers a worker with the dispatcher. The dispatcher returns a list of
-  // initial tasks for the worker to run, storing them in `tasks`.
-  Status RegisterWorker(const std::string& worker_address,
-                        std::vector<TaskDef>& tasks);
+  // Sends a heartbeat to the dispatcher. If the worker wasn't already
+  // registered with the dispatcher, this will register the worker. The
+  // dispatcher will report which new tasks the worker should run and which
+  // tasks it should delete, storing them in `new_tasks` and
+  // `tasks_to_delete`.
+  Status WorkerHeartbeat(const std::string& worker_address,
+                         const std::vector<int64>& current_tasks,
+                         std::vector<TaskDef>& new_tasks,
+                         std::vector<int64>& tasks_to_delete);
 
   // Updates the dispatcher with information about the worker's state.
   Status WorkerUpdate(const std::string& worker_address,
diff --git a/tensorflow/core/data/service/dispatcher.proto b/tensorflow/core/data/service/dispatcher.proto
index cf8c4c20c70..ffa3eb6b5ca 100644
--- a/tensorflow/core/data/service/dispatcher.proto
+++ b/tensorflow/core/data/service/dispatcher.proto
@@ -4,16 +4,6 @@ package tensorflow.data;
 
 import "tensorflow/core/data/service/common.proto";
 
-message RegisterWorkerRequest {
-  // The address of the registering worker.
-  string worker_address = 1;
-}
-
-message RegisterWorkerResponse {
-  // Tasks to begin processing.
-  repeated TaskDef tasks = 2;
-}
-
 message TaskProgress {
   // The task that this message is about.
   int64 task_id = 1;
@@ -21,6 +11,16 @@ message TaskProgress {
   bool completed = 2;
 }
 
+message WorkerHeartbeatRequest {
+  string worker_address = 1;
+  repeated int64 current_tasks = 2;
+}
+
+message WorkerHeartbeatResponse {
+  repeated TaskDef new_tasks = 1;
+  repeated int64 tasks_to_delete = 2;
+}
+
 message WorkerUpdateRequest {
   string worker_address = 1;
   repeated TaskProgress updates = 2;
@@ -110,8 +110,8 @@ message GetWorkersResponse {
 }
 
 service DispatcherService {
-  // Registers a worker with the dispatcher.
-  rpc RegisterWorker(RegisterWorkerRequest) returns (RegisterWorkerResponse);
+  // Performs a periodic worker heartbeat.
+  rpc WorkerHeartbeat(WorkerHeartbeatRequest) returns (WorkerHeartbeatResponse);
 
   // Updates the dispatcher with information about the worker's state.
   rpc WorkerUpdate(WorkerUpdateRequest) returns (WorkerUpdateResponse);
diff --git a/tensorflow/core/data/service/dispatcher_impl.cc b/tensorflow/core/data/service/dispatcher_impl.cc
index da9f99e093c..4477848d024 100644
--- a/tensorflow/core/data/service/dispatcher_impl.cc
+++ b/tensorflow/core/data/service/dispatcher_impl.cc
@@ -85,7 +85,7 @@ Status CreateWorkerStub(const std::string& address, const std::string& protocol,
 
 DataServiceDispatcherImpl::DataServiceDispatcherImpl(
     const experimental::DispatcherConfig& config)
-    : config_(config) {
+    : config_(config), env_(Env::Default()) {
   if (config_.work_dir().empty()) {
     dataset_store_ = absl::make_unique<MemoryDatasetStore>();
   } else {
@@ -94,8 +94,19 @@ DataServiceDispatcherImpl::DataServiceDispatcherImpl(
   }
 }
 
+DataServiceDispatcherImpl::~DataServiceDispatcherImpl() {
+  {
+    mutex_lock l(mu_);
+    cancelled_ = true;
+    job_gc_thread_cv_.notify_all();
+  }
+  job_gc_thread_.reset();
+}
+
 Status DataServiceDispatcherImpl::Start() {
   mutex_lock l(mu_);
+  job_gc_thread_ = absl::WrapUnique(
+      env_->StartThread({}, "job-gc-thread", [&] { JobGcThread(); }));
   if (config_.work_dir().empty()) {
     if (config_.fault_tolerant_mode()) {
       return errors::InvalidArgument(
@@ -103,7 +114,7 @@ Status DataServiceDispatcherImpl::Start() {
     }
   } else {
     TF_RETURN_IF_ERROR(
-        Env::Default()->RecursivelyCreateDir(DatasetsDir(config_.work_dir())));
+        env_->RecursivelyCreateDir(DatasetsDir(config_.work_dir())));
   }
   if (!config_.fault_tolerant_mode()) {
     LOG(INFO) << "Running with fault_tolerant_mode=False. The dispatcher will "
@@ -111,12 +122,12 @@ Status DataServiceDispatcherImpl::Start() {
     return Status::OK();
   }
   journal_writer_ = absl::make_unique<FileJournalWriter>(
-      Env::Default(), JournalDir(config_.work_dir()));
-  LOG(INFO) << "Restoring dispatcher state from journal in "
+      env_, JournalDir(config_.work_dir()));
+  LOG(INFO) << "Attempting to restore dispatcher state from journal in "
             << JournalDir(config_.work_dir());
   Update update;
   bool end_of_journal = false;
-  FileJournalReader reader(Env::Default(), JournalDir(config_.work_dir()));
+  FileJournalReader reader(env_, JournalDir(config_.work_dir()));
   Status s = reader.Read(update, end_of_journal);
   if (errors::IsNotFound(s)) {
     LOG(INFO) << "No journal found. Starting dispatcher from new state.";
@@ -134,45 +145,38 @@ Status DataServiceDispatcherImpl::Start() {
   return Status::OK();
 }
 
-Status DataServiceDispatcherImpl::RegisterWorker(
-    const RegisterWorkerRequest* request, RegisterWorkerResponse* response) {
-  VLOG(3) << "Received register worker request";
+Status DataServiceDispatcherImpl::WorkerHeartbeat(
+    const WorkerHeartbeatRequest* request, WorkerHeartbeatResponse* response) {
+  VLOG(3) << "Received worker heartbeat request from worker "
+          << request->worker_address();
   mutex_lock l(mu_);
-  std::string worker_address = request->worker_address();
-  std::vector<std::shared_ptr<const Task>> tasks;
-  Status s = state_.TasksForWorker(worker_address, tasks);
-  if (errors::IsNotFound(s)) {
+  const std::string& worker_address = request->worker_address();
+  std::vector<std::shared_ptr<const Task>> correct_tasks;
+  Status s = state_.TasksForWorker(worker_address, correct_tasks);
+  if (!s.ok()) {
+    if (!errors::IsNotFound(s)) {
+      return s;
+    }
     Update update;
     update.mutable_register_worker()->set_worker_address(worker_address);
     TF_RETURN_IF_ERROR(Apply(update));
-  } else if (!s.ok()) {
-    return s;
+    TF_RETURN_IF_ERROR(CreateTasksForWorker(worker_address));
+    TF_RETURN_IF_ERROR(state_.TasksForWorker(worker_address, correct_tasks));
   }
-  absl::flat_hash_map<int64, std::shared_ptr<const Task>> tasks_by_job;
-  for (const auto& task : tasks) {
-    // Should never have multiple tasks on the same worker for the same job.
-    auto& task_for_job = tasks_by_job[task->job_id];
-    DCHECK(task_for_job == nullptr);
-    task_for_job = task;
-  }
+  absl::flat_hash_set<int64> current_tasks;
+  current_tasks.insert(request->current_tasks().cbegin(),
+                       request->current_tasks().cend());
+  absl::flat_hash_set<int64> correct_tasks_set;
 
-  std::vector<std::shared_ptr<const Job>> jobs = state_.ListJobs();
-  // Allocate tasks to the worker.
-  for (const auto& job : jobs) {
-    if (job->finished) {
+  for (const auto& task : correct_tasks) {
+    correct_tasks_set.insert(task->task_id);
+    if (current_tasks.contains(task->task_id)) {
       continue;
     }
-    std::shared_ptr<const Task> task;
-    auto it = tasks_by_job.find(job->job_id);
-    if (it != tasks_by_job.end()) {
-      task = it->second;
-    } else {
-      TF_RETURN_IF_ERROR(CreateTask(job, worker_address, task));
-    }
-    TaskDef* task_def = response->add_tasks();
+    TaskDef* task_def = response->add_new_tasks();
     std::shared_ptr<const Dataset> dataset;
-    TF_RETURN_IF_ERROR(state_.DatasetFromId(job->dataset_id, dataset));
+    TF_RETURN_IF_ERROR(state_.DatasetFromId(task->dataset_id, dataset));
     std::string dataset_key =
         DatasetKey(dataset->dataset_id, dataset->fingerprint);
     if (config_.work_dir().empty()) {
@@ -184,12 +188,18 @@ Status DataServiceDispatcherImpl::RegisterWorker(
           io::JoinPath(DatasetsDir(config_.work_dir()), dataset_key);
       task_def->set_path(path);
     }
-    task_def->set_dataset_id(job->dataset_id);
-    task_def->set_job_id(job->job_id);
+    task_def->set_dataset_id(task->dataset_id);
+    task_def->set_job_id(task->job_id);
     task_def->set_task_id(task->task_id);
   }
+  for (int64 current_task : current_tasks) {
+    if (!correct_tasks_set.contains(current_task)) {
+      response->add_tasks_to_delete(current_task);
+    }
+  }
 
-  VLOG(1) << "Registered worker at address " << request->worker_address();
+  VLOG(1) << "Finished worker heartbeat for worker at address "
+          << request->worker_address();
   return Status::OK();
 }
 
@@ -346,7 +356,7 @@ Status DataServiceDispatcherImpl::ReleaseJobClient(
   ReleaseJobClientUpdate* release_job_client =
       update.mutable_release_job_client();
   release_job_client->set_job_client_id(job_client_id);
-  release_job_client->set_time_micros(Env::Default()->NowMicros());
+  release_job_client->set_time_micros(env_->NowMicros());
   TF_RETURN_IF_ERROR(Apply(update));
   return Status::OK();
 }
@@ -412,6 +422,19 @@ Status DataServiceDispatcherImpl::CreateJob(
   return Status::OK();
 }
 
+Status DataServiceDispatcherImpl::CreateTasksForWorker(
+    const std::string& worker_address) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+  std::vector<std::shared_ptr<const Job>> jobs = state_.ListJobs();
+  for (const auto& job : jobs) {
+    if (job->finished) {
+      continue;
+    }
+    std::shared_ptr<const Task> task;
+    TF_RETURN_IF_ERROR(CreateTask(job, worker_address, task));
+  }
+  return Status::OK();
+}
+
 Status DataServiceDispatcherImpl::AcquireJobClientId(
     const std::shared_ptr<const Job>& job, int64& job_client_id)
     EXCLUSIVE_LOCKS_REQUIRED(mu_) {
@@ -576,5 +599,51 @@ Status DataServiceDispatcherImpl::Apply(const Update& update)
   return state_.Apply(update);
 }
 
+void DataServiceDispatcherImpl::JobGcThread() {
+  int64 next_check_micros = 0;
+  while (true) {
+    mutex_lock l(mu_);
+    while (!cancelled_ && env_->NowMicros() < next_check_micros) {
+      int64 remaining_micros = next_check_micros - env_->NowMicros();
+      job_gc_thread_cv_.wait_for(l,
+                                 std::chrono::microseconds(remaining_micros));
+    }
+    if (cancelled_) {
+      return;
+    }
+    Status s = GcOldJobs();
+    if (!s.ok()) {
+      LOG(WARNING) << "Error garbage collecting old jobs: " << s;
+    }
+    next_check_micros =
+        env_->NowMicros() + (config_.job_gc_check_interval_ms() * 1000);
+  }
+}
+
+Status DataServiceDispatcherImpl::GcOldJobs() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+  std::vector<std::shared_ptr<const Job>> jobs = state_.ListJobs();
+  int64 now = env_->NowMicros();
+  for (const auto& job : jobs) {
+    if (job->finished || job->num_clients > 0 ||
+        job->last_client_released_micros < 0 ||
+        now < job->last_client_released_micros +
+                  (config_.job_gc_timeout_ms() * 1000)) {
+      continue;
+    }
+    std::vector<std::shared_ptr<const Task>> tasks;
+    TF_RETURN_IF_ERROR(state_.TasksForJob(job->job_id, tasks));
+    for (const auto& task : tasks) {
+      if (task->finished) {
+        continue;
+      }
+      Update update;
+      update.mutable_finish_task()->set_task_id(task->task_id);
+      TF_RETURN_IF_ERROR(state_.Apply(update));
+    }
+    DCHECK(job->finished);
+  }
+  return Status::OK();
+}
+
 }  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/data/service/dispatcher_impl.h b/tensorflow/core/data/service/dispatcher_impl.h
index 2cf341a812e..2ce367735e5 100644
--- a/tensorflow/core/data/service/dispatcher_impl.h
+++ b/tensorflow/core/data/service/dispatcher_impl.h
@@ -48,6 +48,8 @@ class DataServiceDispatcherImpl {
   explicit DataServiceDispatcherImpl(
       const experimental::DispatcherConfig& config);
 
+  ~DataServiceDispatcherImpl();
+
   // Starts the dispatcher. If there is a journal, this will read from the
   // journal to restore the dispatcher's state.
   Status Start();
@@ -55,8 +57,8 @@ class DataServiceDispatcherImpl {
   // See dispatcher.proto for API documentation.
 
   /// Worker-facing API.
-  Status RegisterWorker(const RegisterWorkerRequest* request,
-                        RegisterWorkerResponse* response);
+  Status WorkerHeartbeat(const WorkerHeartbeatRequest* request,
+                         WorkerHeartbeatResponse* response);
   Status WorkerUpdate(const WorkerUpdateRequest* request,
                       WorkerUpdateResponse* response);
   Status GetDatasetDef(const GetDatasetDefRequest* request,
@@ -92,6 +94,8 @@ class DataServiceDispatcherImpl {
                    absl::optional<NamedJobKey> named_job_key,
                    std::shared_ptr<const Job>& job)
       EXCLUSIVE_LOCKS_REQUIRED(mu_);
+  // Creates tasks for the specified worker, one task for every unfinished job.
+  Status CreateTasksForWorker(const std::string& worker_address);
   // Acquires a job client id to read from the given job and sets
   // `job_client_id`.
   Status AcquireJobClientId(
@@ -128,12 +132,16 @@ class DataServiceDispatcherImpl {
   // used when recovering state when the dispatcher starts.
   Status ApplyWithoutJournaling(const Update& update)
       EXCLUSIVE_LOCKS_REQUIRED(mu_);
+  // A thread which periodically checks for jobs to clean up.
+  void JobGcThread();
+  // Scans for old jobs and marks them as finished.
+  Status GcOldJobs() EXCLUSIVE_LOCKS_REQUIRED(mu_);
 
   const experimental::DispatcherConfig& config_;
+  Env* env_;
 
   mutex mu_;
-
-  int64 next_task_id_ TF_GUARDED_BY(mu_) = 0;
+  bool cancelled_ TF_GUARDED_BY(mu_) = false;
 
   // Cached worker stubs for communicating with workers.
   absl::flat_hash_map<std::string, std::unique_ptr<WorkerService::Stub>>
@@ -144,6 +152,9 @@ class DataServiceDispatcherImpl {
   absl::optional<std::unique_ptr<JournalWriter>> journal_writer_
       TF_GUARDED_BY(mu_);
   DispatcherState state_ TF_GUARDED_BY(mu_);
+  // Condition variable for waking up the job gc thread.
+  condition_variable job_gc_thread_cv_;
+  std::unique_ptr<Thread> job_gc_thread_;
 
   TF_DISALLOW_COPY_AND_ASSIGN(DataServiceDispatcherImpl);
 };
diff --git a/tensorflow/core/data/service/dispatcher_state.cc b/tensorflow/core/data/service/dispatcher_state.cc
index 3afee88262e..749f9bce340 100644
--- a/tensorflow/core/data/service/dispatcher_state.cc
+++ b/tensorflow/core/data/service/dispatcher_state.cc
@@ -72,7 +72,8 @@ void DispatcherState::RegisterWorker(
   std::string address = register_worker.worker_address();
   DCHECK(!workers_.contains(address));
   workers_[address] = std::make_shared<Worker>(address);
-  tasks_by_worker_[address] = std::vector<std::shared_ptr<Task>>();
+  tasks_by_worker_[address] =
+      absl::flat_hash_map<int64, std::shared_ptr<Task>>();
 }
 
 void DispatcherState::CreateJob(const CreateJobUpdate& create_job) {
@@ -126,7 +127,7 @@ void DispatcherState::CreateTask(const CreateTaskUpdate& create_task) {
                                   create_task.dataset_id(),
                                   create_task.worker_address());
   tasks_by_job_[create_task.job_id()].push_back(task);
-  tasks_by_worker_[create_task.worker_address()].push_back(task);
+  tasks_by_worker_[create_task.worker_address()][task->task_id] = task;
   next_available_task_id_ = std::max(next_available_task_id_, task_id + 1);
 }
 
@@ -136,6 +137,7 @@ void DispatcherState::FinishTask(const FinishTaskUpdate& finish_task) {
   auto& task = tasks_[task_id];
   DCHECK(task != nullptr);
   task->finished = true;
+  tasks_by_worker_[task->worker_address].erase(task->task_id);
   bool all_finished = true;
   for (const auto& task_for_job : tasks_by_job_[task->job_id]) {
     if (!task_for_job->finished) {
@@ -269,10 +271,11 @@ Status DispatcherState::TasksForWorker(
   if (it == tasks_by_worker_.end()) {
     return errors::NotFound("Worker ", worker_address, " not found");
   }
-  std::vector<std::shared_ptr<Task>> worker_tasks = it->second;
+  const absl::flat_hash_map<int64, std::shared_ptr<Task>>& worker_tasks =
+      it->second;
   tasks.reserve(worker_tasks.size());
   for (const auto& task : worker_tasks) {
-    tasks.push_back(task);
+    tasks.push_back(task.second);
   }
   return Status::OK();
 }
diff --git a/tensorflow/core/data/service/dispatcher_state.h b/tensorflow/core/data/service/dispatcher_state.h
index 59d7f192fb1..9f307e92ae3 100644
--- a/tensorflow/core/data/service/dispatcher_state.h
+++ b/tensorflow/core/data/service/dispatcher_state.h
@@ -179,7 +179,7 @@ class DispatcherState {
   void CreateTask(const CreateTaskUpdate& create_task);
   void FinishTask(const FinishTaskUpdate& finish_task);
 
-  int64 next_available_dataset_id_ = 0;
+  int64 next_available_dataset_id_ = 1000;
   // Registered datasets, keyed by dataset ids.
   absl::flat_hash_map<int64, std::shared_ptr<Dataset>> datasets_by_id_;
   // Registered datasets, keyed by dataset fingerprints.
@@ -189,24 +189,26 @@ class DispatcherState {
   // Registered workers, keyed by address.
   absl::flat_hash_map<std::string, std::shared_ptr<Worker>> workers_;
 
-  int64 next_available_job_id_ = 0;
+  int64 next_available_job_id_ = 2000;
   // Jobs, keyed by job ids.
   absl::flat_hash_map<int64, std::shared_ptr<Job>> jobs_;
   // Named jobs, keyed by their names and indices. Not all jobs have names, so
   // this is a subset of the jobs stored in `jobs_`.
   absl::flat_hash_map<NamedJobKey, std::shared_ptr<Job>> named_jobs_;
 
-  int64 next_available_job_client_id_ = 0;
+  int64 next_available_job_client_id_ = 3000;
   // Mapping from client ids to the jobs they are associated with.
   absl::flat_hash_map<int64, std::shared_ptr<Job>> jobs_for_client_ids_;
 
-  int64 next_available_task_id_ = 0;
+  int64 next_available_task_id_ = 4000;
   // Tasks, keyed by task ids.
   absl::flat_hash_map<int64, std::shared_ptr<Task>> tasks_;
   // Tasks, keyed by job ids.
   absl::flat_hash_map<int64, std::vector<std::shared_ptr<Task>>> tasks_by_job_;
-  // Tasks, keyed by worker addresses.
-  absl::flat_hash_map<std::string, std::vector<std::shared_ptr<Task>>>
+  // Tasks, keyed by worker addresses. The values are a map from task id to
+  // task.
+  absl::flat_hash_map<std::string,
+                      absl::flat_hash_map<int64, std::shared_ptr<Task>>>
       tasks_by_worker_;
 };
diff --git a/tensorflow/core/data/service/dispatcher_state_test.cc b/tensorflow/core/data/service/dispatcher_state_test.cc
index 43a47f8581f..299ff2c8feb 100644
--- a/tensorflow/core/data/service/dispatcher_state_test.cc
+++ b/tensorflow/core/data/service/dispatcher_state_test.cc
@@ -125,9 +125,9 @@ Status FinishTask(int64 task_id, DispatcherState& state) {
 }  // namespace
 
 TEST(DispatcherState, RegisterDataset) {
-  int64 id = 10;
   uint64 fingerprint = 20;
   DispatcherState state;
+  int64 id = state.NextAvailableDatasetId();
   TF_EXPECT_OK(RegisterDataset(id, fingerprint, state));
   EXPECT_EQ(state.NextAvailableDatasetId(), id + 1);
 
@@ -210,9 +210,9 @@ TEST(DispatcherState, UnknownUpdate) {
 }
 
 TEST(DispatcherState, AnonymousJob) {
-  int64 job_id = 3;
   int64 dataset_id = 10;
   DispatcherState state;
+  int64 job_id = state.NextAvailableJobId();
   TF_EXPECT_OK(RegisterDataset(dataset_id, state));
   TF_EXPECT_OK(CreateAnonymousJob(job_id, dataset_id, state));
   std::shared_ptr<const Job> job;
@@ -227,9 +227,9 @@ TEST(DispatcherState, AnonymousJob) {
 }
 
 TEST(DispatcherState, NamedJob) {
-  int64 job_id = 3;
   int64 dataset_id = 10;
   DispatcherState state;
+  int64 job_id = state.NextAvailableJobId();
   TF_EXPECT_OK(RegisterDataset(dataset_id, state));
   NamedJobKey named_job_key("test", 1);
   TF_EXPECT_OK(CreateNamedJob(job_id, dataset_id, named_job_key, state));
@@ -244,9 +244,9 @@ TEST(DispatcherState, NamedJob) {
 TEST(DispatcherState, CreateTask) {
   int64 job_id = 3;
   int64 dataset_id = 10;
-  int64 task_id = 8;
   std::string worker_address = "test_worker_address";
   DispatcherState state;
+  int64 task_id = state.NextAvailableTaskId();
   TF_EXPECT_OK(RegisterDataset(dataset_id, state));
   TF_EXPECT_OK(CreateAnonymousJob(job_id, dataset_id, state));
   TF_EXPECT_OK(CreateTask(task_id, job_id, dataset_id, worker_address, state));
diff --git a/tensorflow/core/data/service/grpc_dispatcher_impl.cc b/tensorflow/core/data/service/grpc_dispatcher_impl.cc
index fbfc5d20665..89ae6d4fd50 100644
--- a/tensorflow/core/data/service/grpc_dispatcher_impl.cc
+++ b/tensorflow/core/data/service/grpc_dispatcher_impl.cc
@@ -40,7 +40,7 @@ Status GrpcDispatcherImpl::Start() { return impl_.Start(); }
                            method##Response* response) {          \
     return ToGrpcStatus(impl_.method(request, response));         \
   }
-HANDLER(RegisterWorker);
+HANDLER(WorkerHeartbeat);
 HANDLER(WorkerUpdate);
 HANDLER(GetDatasetDef);
 HANDLER(GetOrRegisterDataset);
diff --git a/tensorflow/core/data/service/grpc_dispatcher_impl.h b/tensorflow/core/data/service/grpc_dispatcher_impl.h
index 171deed4792..6269148a5f9 100644
--- a/tensorflow/core/data/service/grpc_dispatcher_impl.h
+++ b/tensorflow/core/data/service/grpc_dispatcher_impl.h
@@ -39,7 +39,7 @@ class GrpcDispatcherImpl : public DispatcherService::Service {
   ::grpc::Status method(::grpc::ServerContext* context,          \
                         const method##Request* request,          \
                         method##Response* response) override;
-  HANDLER(RegisterWorker);
+  HANDLER(WorkerHeartbeat);
   HANDLER(WorkerUpdate);
   HANDLER(GetDatasetDef);
   HANDLER(GetOrRegisterDataset);
diff --git a/tensorflow/core/data/service/grpc_worker_impl.cc b/tensorflow/core/data/service/grpc_worker_impl.cc
index ef386be4640..3c3a81d0daf 100644
--- a/tensorflow/core/data/service/grpc_worker_impl.cc
+++ b/tensorflow/core/data/service/grpc_worker_impl.cc
@@ -43,6 +43,7 @@ Status GrpcWorkerImpl::Start(const std::string& worker_address) {
 }
 HANDLER(ProcessTask);
 HANDLER(GetElement);
+HANDLER(GetWorkerTasks);
 #undef HANDLER
 
 }  // namespace data
diff --git a/tensorflow/core/data/service/grpc_worker_impl.h b/tensorflow/core/data/service/grpc_worker_impl.h
index 3d30af9a806..734865e3447 100644
--- a/tensorflow/core/data/service/grpc_worker_impl.h
+++ b/tensorflow/core/data/service/grpc_worker_impl.h
@@ -41,6 +41,7 @@ class GrpcWorkerImpl : public WorkerService::Service {
                         method##Response* response) override;
   HANDLER(ProcessTask);
   HANDLER(GetElement);
+  HANDLER(GetWorkerTasks);
 #undef HANDLER
 
  private:
diff --git a/tensorflow/core/data/service/server_lib.cc b/tensorflow/core/data/service/server_lib.cc
index cb85693cbf7..af940fe54a3 100644
--- a/tensorflow/core/data/service/server_lib.cc
+++ b/tensorflow/core/data/service/server_lib.cc
@@ -138,6 +138,18 @@ Status WorkerGrpcDataServer::StartServiceInternal() {
   return Status::OK();
 }
 
+Status WorkerGrpcDataServer::NumTasks(int* num_tasks) {
+  GetWorkerTasksRequest req;
+  GetWorkerTasksResponse resp;
+  ::grpc::ServerContext ctx;
+  ::grpc::Status s = service_->GetWorkerTasks(&ctx, &req, &resp);
+  if (!s.ok()) {
+    return grpc_util::WrapError("Failed to get tasks", s);
+  }
+  *num_tasks = resp.tasks_size();
+  return Status::OK();
+}
+
 Status NewDispatchServer(const experimental::DispatcherConfig& config,
                          std::unique_ptr<DispatchGrpcDataServer>& out_server) {
   out_server = absl::make_unique<DispatchGrpcDataServer>(config);
diff --git a/tensorflow/core/data/service/server_lib.h b/tensorflow/core/data/service/server_lib.h
index c45ec144652..c9981008248 100644
--- a/tensorflow/core/data/service/server_lib.h
+++ b/tensorflow/core/data/service/server_lib.h
@@ -98,6 +98,9 @@ class WorkerGrpcDataServer : public GrpcDataServerBase {
   explicit WorkerGrpcDataServer(const experimental::WorkerConfig& config);
   ~WorkerGrpcDataServer() override;
 
+  // Returns the number of tasks currently being executed by the worker.
+  Status NumTasks(int* num_tasks);
+
  protected:
   void AddDataServiceToBuilder(::grpc::ServerBuilder& builder) override;
   Status StartServiceInternal() override;
diff --git a/tensorflow/core/data/service/worker.proto b/tensorflow/core/data/service/worker.proto
index 51c6899f540..32d3b79a78e 100644
--- a/tensorflow/core/data/service/worker.proto
+++ b/tensorflow/core/data/service/worker.proto
@@ -23,10 +23,20 @@ message GetElementResponse {
   bool end_of_sequence = 2;
 }
 
+// Named GetWorkerTasks to avoid conflicting with GetTasks in dispatcher.proto
+message GetWorkerTasksRequest {}
+
+message GetWorkerTasksResponse {
+  repeated TaskInfo tasks = 1;
+}
+
 service WorkerService {
   // Processes a task for a dataset, making elements available to clients.
   rpc ProcessTask(ProcessTaskRequest) returns (ProcessTaskResponse);
 
   // Gets the next dataset element.
   rpc GetElement(GetElementRequest) returns (GetElementResponse);
+
+  // Gets the tasks currently being executed by the worker.
+  rpc GetWorkerTasks(GetWorkerTasksRequest) returns (GetWorkerTasksResponse);
 }
diff --git a/tensorflow/core/data/service/worker_impl.cc b/tensorflow/core/data/service/worker_impl.cc
index cc61c481d7c..f0790f6961e 100644
--- a/tensorflow/core/data/service/worker_impl.cc
+++ b/tensorflow/core/data/service/worker_impl.cc
@@ -56,7 +56,8 @@ DataServiceWorkerImpl::DataServiceWorkerImpl(
 DataServiceWorkerImpl::~DataServiceWorkerImpl() {
   mutex_lock l(mu_);
   cancelled_ = true;
-  background_cv_.notify_one();
+  task_completion_cv_.notify_one();
+  heartbeat_cv_.notify_one();
 }
 
 Status DataServiceWorkerImpl::Start(const std::string& worker_address) {
@@ -67,18 +68,20 @@ Status DataServiceWorkerImpl::Start(const std::string& worker_address) {
       config_.dispatcher_address(), config_.protocol());
   TF_RETURN_IF_ERROR(dispatcher_->Initialize());
 
-  Status s = Register();
+  Status s = Heartbeat();
   while (!s.ok()) {
     LOG(WARNING) << "Failed to register with dispatcher at "
                  << config_.dispatcher_address() << ": " << s;
     Env::Default()->SleepForMicroseconds(kRetryIntervalMicros);
-    s = Register();
+    s = Heartbeat();
  }
-  Thread* thread = Env::Default()->StartThread(
-      {}, "data-service-worker-background", [this]() { BackgroundThread(); });
   LOG(INFO) << "Worker registered with dispatcher running at "
             << config_.dispatcher_address();
-  background_thread_.reset(thread);
+  task_completion_thread_ = absl::WrapUnique(
+      Env::Default()->StartThread({}, "data-service-worker-task-completion",
+                                  [this]() { TaskCompletionThread(); }));
+  heartbeat_thread_ = absl::WrapUnique(Env::Default()->StartThread(
+      {}, "data-service-worker-heartbeat", [this]() { HeartbeatThread(); }));
   mutex_lock l(mu_);
   registered_ = true;
   return Status::OK();
@@ -96,8 +99,9 @@ Status DataServiceWorkerImpl::ProcessTaskInternal(const TaskDef& task_def)
     EXCLUSIVE_LOCKS_REQUIRED(mu_) {
   std::unique_ptr<Task>& task = tasks_[task_def.task_id()];
   if (task) {
-    return errors::AlreadyExists("A task with id ", task_def.task_id(),
-                                 " already exists.");
+    VLOG(1) << "Received request to process already-processed task "
+            << task->task_def.task_id();
+    return Status::OK();
   }
   task = absl::make_unique<Task>(task_def);
   VLOG(3) << "Began processing for task " << task_def.task_id();
@@ -156,24 +160,17 @@ Status DataServiceWorkerImpl::GetElement(const GetElementRequest* request,
     }
     auto it = tasks_.find(request->task_id());
     if (it == tasks_.end()) {
-      return errors::NotFound("DataServiceWorkerImpl::GetElement failed. ",
-                              "Task id ", request->task_id(), " not found");
-    }
-    auto& task = it->second;
-    TF_RETURN_IF_ERROR(EnsureTaskInitialized(*task));
-    std::unique_ptr<standalone::Iterator>& iter = task->iterator;
-    if (iter == nullptr) {
-      VLOG(3) << "Task " << request->task_id() << " is already finished";
       response->set_end_of_sequence(true);
       return Status::OK();
     }
-    TF_RETURN_IF_ERROR(iter->GetNext(&outputs, &end_of_sequence));
+    auto& task = it->second;
+    TF_RETURN_IF_ERROR(EnsureTaskInitialized(*task));
+    TF_RETURN_IF_ERROR(task->iterator->GetNext(&outputs, &end_of_sequence));
     if (end_of_sequence) {
       VLOG(3) << "Reached end_of_sequence for task " << request->task_id();
-      // Release iterator memory and leave a null entry as a tombstone.
-      iter.reset();
+      tasks_.erase(request->task_id());
       pending_completed_tasks_.insert(request->task_id());
-      background_cv_.notify_one();
+      task_completion_cv_.notify_one();
     }
   }
 
@@ -212,27 +209,28 @@ Status DataServiceWorkerImpl::GetElement(const GetElementRequest* request,
   return Status::OK();
 }
 
-Status DataServiceWorkerImpl::Register() LOCKS_EXCLUDED(mu_) {
-  VLOG(3) << "Registering with dispatcher at " << config_.dispatcher_address();
-  std::vector<TaskDef> tasks;
-  TF_RETURN_IF_ERROR(dispatcher_->RegisterWorker(worker_address_, tasks));
-  for (const TaskDef& task : tasks) {
-    mutex_lock l(mu_);
-    TF_RETURN_IF_ERROR(ProcessTaskInternal(task));
+Status DataServiceWorkerImpl::GetWorkerTasks(
+    const GetWorkerTasksRequest* request, GetWorkerTasksResponse* response) {
+  mutex_lock l(mu_);
+  for (const auto& it : tasks_) {
+    Task* task = it.second.get();
+    TaskInfo* task_info = response->add_tasks();
+    task_info->set_worker_address(worker_address_);
+    task_info->set_task_id(task->task_def.task_id());
+    task_info->set_job_id(task->task_def.job_id());
   }
-  VLOG(3) << "Registered worker with address " << worker_address_;
   return Status::OK();
 }
 
-void DataServiceWorkerImpl::BackgroundThread() LOCKS_EXCLUDED(mu_) {
+void DataServiceWorkerImpl::TaskCompletionThread() LOCKS_EXCLUDED(mu_) {
   while (true) {
     {
       mutex_lock l(mu_);
       while (!cancelled_ && pending_completed_tasks_.empty()) {
-        background_cv_.wait(l);
+        task_completion_cv_.wait(l);
       }
       if (cancelled_) {
-        VLOG(3) << "Background thread shutting down";
+        VLOG(3) << "Task completion thread shutting down";
         return;
       }
     }
@@ -241,7 +239,7 @@ void DataServiceWorkerImpl::BackgroundThread() LOCKS_EXCLUDED(mu_) {
       LOG(WARNING) << "Failed to send task updates to dispatcher: " << s;
       mutex_lock l(mu_);
       if (!cancelled_) {
-        background_cv_.wait_for(
+        task_completion_cv_.wait_for(
             l, std::chrono::microseconds(kRetryIntervalMicros));
       }
     }
@@ -271,5 +269,62 @@ Status DataServiceWorkerImpl::SendTaskUpdates() LOCKS_EXCLUDED(mu_) {
   return Status::OK();
 }
 
+void DataServiceWorkerImpl::HeartbeatThread() LOCKS_EXCLUDED(mu_) {
+  while (true) {
+    int64 next_heartbeat_micros =
+        Env::Default()->NowMicros() + (config_.heartbeat_interval_ms() * 1000);
+    {
+      mutex_lock l(mu_);
+      while (!cancelled_ &&
+             Env::Default()->NowMicros() < next_heartbeat_micros) {
+        int64 time_to_wait_micros =
+            next_heartbeat_micros - Env::Default()->NowMicros();
+        heartbeat_cv_.wait_for(l,
+                               std::chrono::microseconds(time_to_wait_micros));
+      }
+      if (cancelled_) {
+        VLOG(3) << "Heartbeat thread shutting down";
+        return;
+      }
+      if (!registered_) {
+        VLOG(1) << "Not performing heartbeat; worker is not yet registered";
+        continue;
+      }
+    }
+    Status s = Heartbeat();
+    if (!s.ok()) {
+      LOG(WARNING) << "Failed to send heartbeat to dispatcher: " << s;
+    }
+  }
+}
+
+Status DataServiceWorkerImpl::Heartbeat() LOCKS_EXCLUDED(mu_) {
+  std::vector<int64> current_tasks;
+  {
+    mutex_lock l(mu_);
+    for (const auto& task : tasks_) {
+      current_tasks.push_back(task.first);
+    }
+  }
+  std::vector<TaskDef> new_tasks;
+  std::vector<int64> tasks_to_delete;
+  TF_RETURN_IF_ERROR(dispatcher_->WorkerHeartbeat(
+      worker_address_, current_tasks, new_tasks, tasks_to_delete));
+  mutex_lock l(mu_);
+  for (const auto& task : new_tasks) {
+    Status s = ProcessTaskInternal(task);
+    if (!s.ok() && !errors::IsAlreadyExists(s)) {
+      LOG(WARNING) << "Failed to start processing task " << task.task_id()
+                   << ": " << s;
+    }
+  }
+  for (int64 task_id : tasks_to_delete) {
+    VLOG(3) << "Deleting task " << task_id
+            << " at the request of the dispatcher";
+    tasks_.erase(task_id);
+  }
+  return Status::OK();
+}
+
 }  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/data/service/worker_impl.h b/tensorflow/core/data/service/worker_impl.h
index 27a7da34c1d..5f05275622b 100644
--- a/tensorflow/core/data/service/worker_impl.h
+++ b/tensorflow/core/data/service/worker_impl.h
@@ -50,6 +50,8 @@ class DataServiceWorkerImpl {
   /// Client-facing API.
   Status GetElement(const GetElementRequest* request,
                     GetElementResponse* response);
+  Status GetWorkerTasks(const GetWorkerTasksRequest* request,
+                        GetWorkerTasksResponse* response);
 
  private:
   struct Task {
@@ -64,16 +66,17 @@ class DataServiceWorkerImpl {
     std::unique_ptr<standalone::Iterator> iterator;
   };
 
-  // Registers the worker with the dispatcher.
-  Status Register() LOCKS_EXCLUDED(mu_);
   // Sends task status to the dispatcher and checks for dispatcher commands.
   Status SendTaskUpdates() LOCKS_EXCLUDED(mu_);
   // Creates an iterator to process a task.
   Status ProcessTaskInternal(const TaskDef& task) EXCLUSIVE_LOCKS_REQUIRED(mu_);
   Status EnsureTaskInitialized(Task& task);
-  // A thread for doing async background processing not associated with a
-  // specific RPC, such as reporting finished tasks.
-  void BackgroundThread() LOCKS_EXCLUDED(mu_);
+  // A thread for notifying the dispatcher when tasks complete.
+  void TaskCompletionThread() LOCKS_EXCLUDED(mu_);
+  // A thread for doing periodic heartbeats to the dispatcher.
+  void HeartbeatThread() LOCKS_EXCLUDED(mu_);
+  // Performs a heartbeat to the dispatcher.
+  Status Heartbeat() LOCKS_EXCLUDED(mu_);
 
   const experimental::WorkerConfig config_;
   // The worker's own address.
@@ -88,9 +91,12 @@ class DataServiceWorkerImpl {
   bool cancelled_ TF_GUARDED_BY(mu_) = false;
   // Whether the worker has registered with the dispatcher yet.
   bool registered_ TF_GUARDED_BY(mu_) = false;
-  // Condition variable for notifying the background thread.
-  condition_variable background_cv_ TF_GUARDED_BY(mu_);
-  std::unique_ptr<Thread> background_thread_;
+  // A thread for notifying the dispatcher when tasks complete.
+  std::unique_ptr<Thread> task_completion_thread_;
+  condition_variable task_completion_cv_ TF_GUARDED_BY(mu_);
+  // A thread for performing regular heartbeats to the dispatcher.
+  std::unique_ptr<Thread> heartbeat_thread_;
+  condition_variable heartbeat_cv_ TF_GUARDED_BY(mu_);
 
   TF_DISALLOW_COPY_AND_ASSIGN(DataServiceWorkerImpl);
 };
diff --git a/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc b/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc
index d89392598d5..b17d2008778 100644
--- a/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc
@@ -200,6 +200,9 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
       for (auto& worker_thread : worker_threads_) {
         worker_thread.reset();
       }
+
+      VLOG(1) << "Destroyed data service dataset iterator for job id "
+              << job_client_id_;
     }
 
     void CancelThreads() TF_LOCKS_EXCLUDED(mu_) {
diff --git a/tensorflow/core/protobuf/data/experimental/service_config.proto b/tensorflow/core/protobuf/data/experimental/service_config.proto
index 233c18076cf..7a0aa16e2c4 100644
--- a/tensorflow/core/protobuf/data/experimental/service_config.proto
+++ b/tensorflow/core/protobuf/data/experimental/service_config.proto
@@ -15,6 +15,11 @@ message DispatcherConfig {
   // Whether to run in fault tolerant mode, where dispatcher state is saved
   // across restarts. Requires that `work_dir` is nonempty.
   bool fault_tolerant_mode = 4;
+  // How often the dispatcher should scan through to delete old and unused jobs.
+  int64 job_gc_check_interval_ms = 5;
+  // How long a job needs to be unused before it becomes a candidate for garbage
+  // collection.
+  int64 job_gc_timeout_ms = 6;
 }
 
 // Configuration for a tf.data service WorkerServer.
@@ -30,4 +35,6 @@ message WorkerConfig {
   // will be replaced with the worker's bound port. This is useful when the port
   // is set to `0`.
   string worker_address = 4;
+  // How often the worker should heartbeat to the dispatcher.
+  int64 heartbeat_interval_ms = 5;
 }
diff --git a/tensorflow/python/data/experimental/service/server_lib.py b/tensorflow/python/data/experimental/service/server_lib.py
index 8c5c85a169a..d964080ebba 100644
--- a/tensorflow/python/data/experimental/service/server_lib.py
+++ b/tensorflow/python/data/experimental/service/server_lib.py
@@ -29,9 +29,10 @@ from tensorflow.python.util.tf_export import tf_export
 
 @tf_export("data.experimental.service.DispatcherConfig")
 class DispatcherConfig(
-    collections.namedtuple(
-        "DispatcherConfig",
-        ["port", "protocol", "work_dir", "fault_tolerant_mode"])):
+    collections.namedtuple("DispatcherConfig", [
+        "port", "protocol", "work_dir", "fault_tolerant_mode",
+        "job_gc_check_interval_ms", "job_gc_timeout_ms"
+    ])):
   """Configuration class for tf.data service dispatchers.
 
   Fields:
@@ -47,15 +48,34 @@ class DispatcherConfig(
       registered datasets and created jobs, is synchronously written to the
       journal before responding to RPCs. If `True`, `work_dir` must also be
       specified.
+    job_gc_check_interval_ms: How often the dispatcher should scan through to
+      delete old and unused jobs, in milliseconds. If not set, the runtime will
+      select a reasonable default. A higher value will reduce load on the
+      dispatcher, while a lower value will reduce the time it takes for the
+      dispatcher to garbage collect expired jobs.
+    job_gc_timeout_ms: How long a job needs to be unused before it becomes a
+      candidate for garbage collection, in milliseconds. If not set, the
+      runtime will select a reasonable default. A higher value will cause jobs
+      to stay around longer with no consumers, which is useful when there are
+      long gaps between reads from the job. A lower value will reduce the time
+      it takes to reclaim the resources from expired jobs.
   """
 
   def __new__(cls,
               port=0,
               protocol="grpc",
               work_dir=None,
-              fault_tolerant_mode=False):
-    return super(DispatcherConfig, cls).__new__(cls, port, protocol, work_dir,
-                                                fault_tolerant_mode)
+              fault_tolerant_mode=False,
+              job_gc_check_interval_ms=None,
+              job_gc_timeout_ms=None):
+    if job_gc_check_interval_ms is None:
+      job_gc_check_interval_ms = 10 * 60 * 1000  # 10 minutes.
+    if job_gc_timeout_ms is None:
+      job_gc_timeout_ms = 5 * 60 * 1000  # 5 minutes.
+    return super(DispatcherConfig,
+                 cls).__new__(cls, port, protocol, work_dir,
+                              fault_tolerant_mode, job_gc_check_interval_ms,
+                              job_gc_timeout_ms)
 
 
 @tf_export("data.experimental.service.DispatchServer", v1=[])
@@ -116,7 +136,9 @@ class DispatchServer(object):
         port=config.port,
         protocol=config.protocol,
         work_dir=config.work_dir,
-        fault_tolerant_mode=config.fault_tolerant_mode)
+        fault_tolerant_mode=config.fault_tolerant_mode,
+        job_gc_check_interval_ms=config.job_gc_check_interval_ms,
+        job_gc_timeout_ms=config.job_gc_timeout_ms)
     self._server = _pywrap_server_lib.TF_DATA_NewDispatchServer(
         config_proto.SerializeToString())
     if start:
@@ -193,9 +215,10 @@ class DispatchServer(object):
 
 @tf_export("data.experimental.service.WorkerConfig")
 class WorkerConfig(
-    collections.namedtuple(
-        "WorkerConfig",
-        ["dispatcher_address", "worker_address", "port", "protocol"])):
+    collections.namedtuple("WorkerConfig", [
+        "dispatcher_address", "worker_address", "port", "protocol",
+        "heartbeat_interval_ms"
+    ])):
   """Configuration class for tf.data service workers.
 
   Fields:
@@ -205,19 +228,29 @@ class WorkerConfig(
       connect to this worker.
     port: Specifies the port to bind to. A value of 0 indicates that the worker
       can bind to any available port.
-    protocol: (Optional.) Specifies the protocol to be used by the server.
+    protocol: Specifies the protocol to be used by the server.
       Acceptable values include `"grpc"` and `"grpc+local"`.
+    heartbeat_interval_ms: How often the worker should heartbeat to the
+      dispatcher, in milliseconds. If not set, the runtime will select a
+      reasonable default. A higher value will reduce the load on the dispatcher,
+      while a lower value will reduce the time it takes to reclaim resources
+      from finished jobs.
   """
 
   def __new__(cls,
               dispatcher_address,
              worker_address=None,
              port=0,
-              protocol="grpc"):
-    worker_address = ("localhost:%port%"
-                      if worker_address is None else worker_address)
-    return super(WorkerConfig, cls).__new__(cls, dispatcher_address,
-                                            worker_address, port, protocol)
+              protocol="grpc",
+              heartbeat_interval_ms=None):
+    if worker_address is None:
+      worker_address = "localhost:%port%"
+    if heartbeat_interval_ms is None:
+      heartbeat_interval_ms = 30 * 1000  # 30 seconds
+
+    return super(WorkerConfig,
+                 cls).__new__(cls, dispatcher_address, worker_address, port,
+                              protocol, heartbeat_interval_ms)
 
 
 @tf_export("data.experimental.service.WorkerServer", v1=[])
@@ -264,7 +297,8 @@ class WorkerServer(object):
         dispatcher_address=config.dispatcher_address,
         worker_address=config.worker_address,
         port=config.port,
-        protocol=config.protocol)
+        protocol=config.protocol,
+        heartbeat_interval_ms=config.heartbeat_interval_ms)
     self._server = _pywrap_server_lib.TF_DATA_NewWorkerServer(
         config_proto.SerializeToString())
     if start:
@@ -317,3 +351,7 @@ class WorkerServer(object):
       The returned string will be in the form address:port, e.g.
       "localhost:1000".
""" return "localhost:{0}".format(self._server.bound_port()) + + def _num_tasks(self): + """Returns the number of tasks currently being executed on the worker.""" + return self._server.num_tasks() diff --git a/tensorflow/python/data/experimental/service/server_lib_wrapper.cc b/tensorflow/python/data/experimental/service/server_lib_wrapper.cc index 8ce904eecba..5a229f88d92 100644 --- a/tensorflow/python/data/experimental/service/server_lib_wrapper.cc +++ b/tensorflow/python/data/experimental/service/server_lib_wrapper.cc @@ -50,7 +50,14 @@ PYBIND11_MODULE(_pywrap_server_lib, m) { .def("stop", &tensorflow::data::WorkerGrpcDataServer::Stop) .def("join", &tensorflow::data::WorkerGrpcDataServer::Join, py::call_guard()) - .def("bound_port", &tensorflow::data::WorkerGrpcDataServer::BoundPort); + .def("bound_port", &tensorflow::data::WorkerGrpcDataServer::BoundPort) + .def("num_tasks", + [](tensorflow::data::WorkerGrpcDataServer* server) -> int { + int num_tasks; + tensorflow::Status status = server->NumTasks(&num_tasks); + tensorflow::MaybeRaiseFromStatus(status); + return num_tasks; + }); m.def( "TF_DATA_NewDispatchServer", diff --git a/tensorflow/python/data/kernel_tests/data_service_ops_test.py b/tensorflow/python/data/kernel_tests/data_service_ops_test.py index eb7a67d69ec..0733c2ffbd8 100644 --- a/tensorflow/python/data/kernel_tests/data_service_ops_test.py +++ b/tensorflow/python/data/kernel_tests/data_service_ops_test.py @@ -101,7 +101,9 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase): name="", port=0, work_dir=None, - fault_tolerant_mode=True): + fault_tolerant_mode=True, + job_gc_check_interval_ms=None, + job_gc_timeout_ms=None): # If a test starts multiple independent dispatch servers, it should give # them different `name` values. 
     work_dir = os.path.join(self.get_temp_dir(), "work_dir_",
@@ -110,13 +112,16 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
         server_lib.DispatcherConfig(
             port=port,
             work_dir=work_dir,
-            fault_tolerant_mode=fault_tolerant_mode))
+            fault_tolerant_mode=fault_tolerant_mode,
+            job_gc_check_interval_ms=job_gc_check_interval_ms,
+            job_gc_timeout_ms=job_gc_timeout_ms))
 
   def start_worker_server(self, dispatcher, port=0):
     return server_lib.WorkerServer(
         server_lib.WorkerConfig(
             dispatcher_address=_address_from_target(dispatcher.target),
-            port=port))
+            port=port,
+            heartbeat_interval_ms=200))
 
   def restart_dispatcher(self, dispatcher):
     """Stops `dispatcher` and returns a new dispatcher with the same port."""
@@ -535,6 +540,47 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
         results.append(elem.numpy())
     self.assertCountEqual(num_repetitions * list(range(num_elements)), results)
 
+  @combinations.generate(
+      combinations.times(test_base.eager_only_combinations(),
+                         combinations.combine(job_name=[None, "test"])))
+  def testGcUnusedJob(self, job_name):
+    dispatcher = self.start_dispatch_server(
+        job_gc_check_interval_ms=50, job_gc_timeout_ms=20)
+    worker = self.start_worker_server(dispatcher)  # pylint: disable=unused-variable
+    num_elements = 10
+    ds = _make_distributed_range_dataset(
+        num_elements, dispatcher, job_name=job_name)
+    it = iter(ds)
+    self.assertEqual(0, next(it).numpy())
+    self.assertEqual(1, worker._num_tasks())
+    del it
+    while worker._num_tasks() > 0:
+      time.sleep(0.1)
+
+  @combinations.generate(test_base.eager_only_combinations())
+  def testDontGcUsedJob(self):
+    dispatcher = self.start_dispatch_server(
+        job_gc_check_interval_ms=50, job_gc_timeout_ms=20)
+    worker = self.start_worker_server(dispatcher)  # pylint: disable=unused-variable
+    num_elements = 10
+    it1 = iter(
+        _make_distributed_range_dataset(
+            num_elements, dispatcher, job_name="test1"))
+    it2 = iter(
+        _make_distributed_range_dataset(
+            num_elements, dispatcher, job_name="test2"))
+    it3 = iter(  # this iterator keeps the task alive. pylint: disable=unused-variable
+        _make_distributed_range_dataset(
+            num_elements, dispatcher, job_name="test2"))
+    self.assertEqual(2, worker._num_tasks())
+    del it1
+    del it2
+    # Check that only the first job is gced. The second job will not be gced
+    # because there is still an outstanding iterator for it.
+    while worker._num_tasks() > 1:
+      time.sleep(0.1)
+    self.assertEqual(1, worker._num_tasks())
+
   @combinations.generate(test_base.eager_only_combinations())
   def testApplyDeterminismOption(self):
     elements = list(range(10))
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.service.-dispatcher-config.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.service.-dispatcher-config.pbtxt
index 3efc1c97797..2364e41e4c5 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.service.-dispatcher-config.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.service.-dispatcher-config.pbtxt
@@ -7,6 +7,14 @@ tf_class {
     name: "fault_tolerant_mode"
     mtype: "<type \'property\'>"
   }
+  member {
+    name: "job_gc_check_interval_ms"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "job_gc_timeout_ms"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "port"
     mtype: "<type \'property\'>"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.service.-worker-config.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.service.-worker-config.pbtxt
index 7bf0bc136e6..d8eaf9bc7d7 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.service.-worker-config.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.service.-worker-config.pbtxt
@@ -7,6 +7,10 @@ tf_class {
     name: "dispatcher_address"
     mtype: "<type \'property\'>"
   }
+  member {
+    name: "heartbeat_interval_ms"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "port"
     mtype: "<type \'property\'>"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.service.-dispatcher-config.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.service.-dispatcher-config.pbtxt
index 3efc1c97797..2364e41e4c5 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.service.-dispatcher-config.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.service.-dispatcher-config.pbtxt
@@ -7,6 +7,14 @@ tf_class {
     name: "fault_tolerant_mode"
     mtype: "<type \'property\'>"
  }
+  member {
+    name: "job_gc_check_interval_ms"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "job_gc_timeout_ms"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "port"
     mtype: "<type \'property\'>"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.service.-worker-config.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.service.-worker-config.pbtxt
index 7bf0bc136e6..d8eaf9bc7d7 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.service.-worker-config.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.service.-worker-config.pbtxt
@@ -7,6 +7,10 @@ tf_class {
     name: "dispatcher_address"
     mtype: "<type \'property\'>"
   }
+  member {
+    name: "heartbeat_interval_ms"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "port"
     mtype: "<type \'property\'>"
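
Illustrative usage (editor's sketch, not part of the patch): the example below wires the new knobs together end to end, mirroring `testGcUnusedJob` above. It assumes TF2 eager mode, uses the private `_num_tasks()` helper added in this change, and uses `tf.data.experimental.service.distribute` in place of the test-only `_make_distributed_range_dataset` helper; the very short intervals are chosen only to make garbage collection observable quickly.

```python
import time

import tensorflow as tf
from tensorflow.python.data.experimental.service import server_lib

# Scan for expired jobs every 100ms; a job becomes a GC candidate once it
# has had no clients for 500ms (see GcOldJobs in dispatcher_impl.cc).
dispatcher = server_lib.DispatchServer(
    server_lib.DispatcherConfig(
        port=0, job_gc_check_interval_ms=100, job_gc_timeout_ms=500))

# Each heartbeat registers the worker if needed, reports its current task
# ids, and applies the dispatcher's new_tasks / tasks_to_delete response.
# Assumption: dispatcher.target has the form "grpc://localhost:<port>", so
# stripping the protocol prefix yields the dispatcher address (the tests
# use an _address_from_target helper for this).
worker = server_lib.WorkerServer(
    server_lib.WorkerConfig(
        dispatcher_address=dispatcher.target.split("://")[1],
        heartbeat_interval_ms=200))

ds = tf.data.Dataset.range(10).apply(
    tf.data.experimental.service.distribute(
        processing_mode="parallel_epochs", service=dispatcher.target))
it = iter(ds)
next(it)  # Creates a job and one task on the worker.
assert worker._num_tasks() == 1

del it  # Releases the job client; the GC timeout countdown starts now.
# After job_gc_timeout_ms the dispatcher marks the job's tasks finished,
# and the worker drops them on its next heartbeat.
while worker._num_tasks() > 0:
  time.sleep(0.1)
```

Note that the timeout is measured from `last_client_released_micros`, so a job is only collected after its last client goes away; keeping another iterator open on the same `job_name` keeps the job alive, as `testDontGcUsedJob` verifies.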