diff --git a/tensorflow/core/data/service/data_service.cc b/tensorflow/core/data/service/data_service.cc
index d425dab46dc..cc50e0f3e5a 100644
--- a/tensorflow/core/data/service/data_service.cc
+++ b/tensorflow/core/data/service/data_service.cc
@@ -54,19 +54,26 @@ std::string ProcessingModeToString(ProcessingMode mode) {
   }
 }
 
-Status DataServiceDispatcherClient::RegisterWorker(
-    const std::string& worker_address, std::vector<TaskDef>& tasks) {
+Status DataServiceDispatcherClient::WorkerHeartbeat(
+    const std::string& worker_address, const std::vector<int64>& current_tasks,
+    std::vector<TaskDef>& new_tasks, std::vector<int64>& tasks_to_delete) {
   TF_RETURN_IF_ERROR(EnsureInitialized());
-  RegisterWorkerRequest req;
+  WorkerHeartbeatRequest req;
   req.set_worker_address(worker_address);
-  RegisterWorkerResponse resp;
-  grpc::ClientContext client_ctx;
-  grpc::Status status = stub_->RegisterWorker(&client_ctx, req, &resp);
-  if (!status.ok()) {
-    return grpc_util::WrapError("Failed to register worker", status);
+  for (int64 task : current_tasks) {
+    req.add_current_tasks(task);
   }
-  for (const auto& task : resp.tasks()) {
-    tasks.push_back(task);
+  WorkerHeartbeatResponse resp;
+  grpc::ClientContext client_ctx;
+  grpc::Status status = stub_->WorkerHeartbeat(&client_ctx, req, &resp);
+  if (!status.ok()) {
+    return grpc_util::WrapError("Failed to perform worker heartbeat", status);
+  }
+  for (const auto& task : resp.new_tasks()) {
+    new_tasks.push_back(task);
+  }
+  for (int64 task_to_delete : resp.tasks_to_delete()) {
+    tasks_to_delete.push_back(task_to_delete);
   }
   return Status::OK();
 }
diff --git a/tensorflow/core/data/service/data_service.h b/tensorflow/core/data/service/data_service.h
index c5eb6a37269..f0adbb3d4eb 100644
--- a/tensorflow/core/data/service/data_service.h
+++ b/tensorflow/core/data/service/data_service.h
@@ -73,10 +73,15 @@ class DataServiceDispatcherClient : public DataServiceClientBase {
                               const std::string& protocol)
       : DataServiceClientBase(address, protocol) {}
 
-  // Registers a worker with the dispatcher. The dispatcher returns a list of
-  // initial tasks for the worker to run, storing them in `tasks`.
-  Status RegisterWorker(const std::string& worker_address,
-                        std::vector<TaskDef>& tasks);
+  // Sends a heartbeat to the dispatcher. If the worker isn't already
+  // registered with the dispatcher, this will register the worker. The
+  // dispatcher reports which new tasks the worker should run and which tasks
+  // it should delete, storing them in `new_tasks` and `tasks_to_delete`.
+  Status WorkerHeartbeat(const std::string& worker_address,
+                         const std::vector<int64>& current_tasks,
+                         std::vector<TaskDef>& new_tasks,
+                         std::vector<int64>& tasks_to_delete);
 
   // Updates the dispatcher with information about the worker's state.
   Status WorkerUpdate(const std::string& worker_address,
diff --git a/tensorflow/core/data/service/dispatcher.proto b/tensorflow/core/data/service/dispatcher.proto
index cf8c4c20c70..ffa3eb6b5ca 100644
--- a/tensorflow/core/data/service/dispatcher.proto
+++ b/tensorflow/core/data/service/dispatcher.proto
@@ -4,16 +4,6 @@ package tensorflow.data;
 
 import "tensorflow/core/data/service/common.proto";
 
-message RegisterWorkerRequest {
-  // The address of the registering worker.
-  string worker_address = 1;
-}
-
-message RegisterWorkerResponse {
-  // Tasks to begin processing.
-  repeated TaskDef tasks = 2;
-}
-
 message TaskProgress {
   // The task that this message is about.
   int64 task_id = 1;
@@ -21,6 +11,16 @@ message TaskProgress {
   bool completed = 2;
 }
 
+message WorkerHeartbeatRequest {
+  string worker_address = 1;
+  repeated int64 current_tasks = 2;
+}
+
+message WorkerHeartbeatResponse {
+  repeated TaskDef new_tasks = 1;
+  repeated int64 tasks_to_delete = 2;
+}
+
 message WorkerUpdateRequest {
   string worker_address = 1;
   repeated TaskProgress updates = 2;
@@ -110,8 +110,8 @@ message GetWorkersResponse {
 }
 
 service DispatcherService {
-  // Registers a worker with the dispatcher.
-  rpc RegisterWorker(RegisterWorkerRequest) returns (RegisterWorkerResponse);
+  // Performs a periodic worker heartbeat.
+  rpc WorkerHeartbeat(WorkerHeartbeatRequest) returns (WorkerHeartbeatResponse);
 
   // Updates the dispatcher with information about the worker's state.
   rpc WorkerUpdate(WorkerUpdateRequest) returns (WorkerUpdateResponse);
diff --git a/tensorflow/core/data/service/dispatcher_impl.cc b/tensorflow/core/data/service/dispatcher_impl.cc
index da9f99e093c..4477848d024 100644
--- a/tensorflow/core/data/service/dispatcher_impl.cc
+++ b/tensorflow/core/data/service/dispatcher_impl.cc
@@ -85,7 +85,7 @@ Status CreateWorkerStub(const std::string& address, const std::string& protocol,
 
 DataServiceDispatcherImpl::DataServiceDispatcherImpl(
     const experimental::DispatcherConfig& config)
-    : config_(config) {
+    : config_(config), env_(Env::Default()) {
   if (config_.work_dir().empty()) {
     dataset_store_ = absl::make_unique<MemoryDatasetStore>();
   } else {
@@ -94,8 +94,19 @@ DataServiceDispatcherImpl::DataServiceDispatcherImpl(
   }
 }
 
+DataServiceDispatcherImpl::~DataServiceDispatcherImpl() {
+  {
+    mutex_lock l(mu_);
+    cancelled_ = true;
+    job_gc_thread_cv_.notify_all();
+  }
+  job_gc_thread_.reset();
+}
+
 Status DataServiceDispatcherImpl::Start() {
   mutex_lock l(mu_);
+  job_gc_thread_ = absl::WrapUnique(
+      env_->StartThread({}, "job-gc-thread", [&] { JobGcThread(); }));
   if (config_.work_dir().empty()) {
     if (config_.fault_tolerant_mode()) {
       return errors::InvalidArgument(
@@ -103,7 +114,7 @@ Status DataServiceDispatcherImpl::Start() {
     }
   } else {
     TF_RETURN_IF_ERROR(
-        Env::Default()->RecursivelyCreateDir(DatasetsDir(config_.work_dir())));
+        env_->RecursivelyCreateDir(DatasetsDir(config_.work_dir())));
   }
   if (!config_.fault_tolerant_mode()) {
     LOG(INFO) << "Running with fault_tolerant_mode=False. The dispatcher will "
@@ -111,12 +122,12 @@ Status DataServiceDispatcherImpl::Start() {
     return Status::OK();
   }
   journal_writer_ = absl::make_unique<FileJournalWriter>(
-      Env::Default(), JournalDir(config_.work_dir()));
-  LOG(INFO) << "Restoring dispatcher state from journal in "
+      env_, JournalDir(config_.work_dir()));
+  LOG(INFO) << "Attempting to restore dispatcher state from journal in "
             << JournalDir(config_.work_dir());
   Update update;
   bool end_of_journal = false;
-  FileJournalReader reader(Env::Default(), JournalDir(config_.work_dir()));
+  FileJournalReader reader(env_, JournalDir(config_.work_dir()));
   Status s = reader.Read(update, end_of_journal);
   if (errors::IsNotFound(s)) {
     LOG(INFO) << "No journal found. Starting dispatcher from new state.";
@@ -134,45 +145,38 @@ Status DataServiceDispatcherImpl::Start() {
   return Status::OK();
 }
 
-Status DataServiceDispatcherImpl::RegisterWorker(
-    const RegisterWorkerRequest* request, RegisterWorkerResponse* response) {
-  VLOG(3) << "Received register worker request";
+Status DataServiceDispatcherImpl::WorkerHeartbeat(
+    const WorkerHeartbeatRequest* request, WorkerHeartbeatResponse* response) {
+  VLOG(3) << "Received worker heartbeat request from worker "
+          << request->worker_address();
   mutex_lock l(mu_);
-  std::string worker_address = request->worker_address();
-  std::vector<std::shared_ptr<const Task>> tasks;
-  Status s = state_.TasksForWorker(worker_address, tasks);
-  if (errors::IsNotFound(s)) {
+  const std::string& worker_address = request->worker_address();
+  std::vector<std::shared_ptr<const Task>> correct_tasks;
+  Status s = state_.TasksForWorker(worker_address, correct_tasks);
+  if (!s.ok()) {
+    if (!errors::IsNotFound(s)) {
+      return s;
+    }
     Update update;
     update.mutable_register_worker()->set_worker_address(worker_address);
     TF_RETURN_IF_ERROR(Apply(update));
-  } else if (!s.ok()) {
-    return s;
+    TF_RETURN_IF_ERROR(CreateTasksForWorker(worker_address));
+    TF_RETURN_IF_ERROR(state_.TasksForWorker(worker_address, correct_tasks));
   }
-  absl::flat_hash_map<int64, std::shared_ptr<const Task>> tasks_by_job;
-  for (const auto& task : tasks) {
-    // Should never have multiple tasks on the same worker for the same job.
-    auto& task_for_job = tasks_by_job[task->job_id];
-    DCHECK(task_for_job == nullptr);
-    task_for_job = task;
-  }
+  absl::flat_hash_set<int64> current_tasks;
+  current_tasks.insert(request->current_tasks().cbegin(),
+                       request->current_tasks().cend());
+  absl::flat_hash_set<int64> correct_tasks_set;
 
-  std::vector<std::shared_ptr<const Job>> jobs = state_.ListJobs();
-  // Allocate tasks to the worker.
-  for (const auto& job : jobs) {
-    if (job->finished) {
+  for (const auto& task : correct_tasks) {
+    correct_tasks_set.insert(task->task_id);
+    if (current_tasks.contains(task->task_id)) {
       continue;
     }
-    std::shared_ptr<const Task> task;
-    auto it = tasks_by_job.find(job->job_id);
-    if (it != tasks_by_job.end()) {
-      task = it->second;
-    } else {
-      TF_RETURN_IF_ERROR(CreateTask(job, worker_address, task));
-    }
-    TaskDef* task_def = response->add_tasks();
+    TaskDef* task_def = response->add_new_tasks();
     std::shared_ptr<const Dataset> dataset;
-    TF_RETURN_IF_ERROR(state_.DatasetFromId(job->dataset_id, dataset));
+    TF_RETURN_IF_ERROR(state_.DatasetFromId(task->dataset_id, dataset));
     std::string dataset_key =
         DatasetKey(dataset->dataset_id, dataset->fingerprint);
     if (config_.work_dir().empty()) {
@@ -184,12 +188,18 @@ Status DataServiceDispatcherImpl::RegisterWorker(
           io::JoinPath(DatasetsDir(config_.work_dir()), dataset_key);
       task_def->set_path(path);
     }
-    task_def->set_dataset_id(job->dataset_id);
-    task_def->set_job_id(job->job_id);
+    task_def->set_dataset_id(task->dataset_id);
+    task_def->set_job_id(task->job_id);
     task_def->set_task_id(task->task_id);
   }
+  for (int64 current_task : current_tasks) {
+    if (!correct_tasks_set.contains(current_task)) {
+      response->add_tasks_to_delete(current_task);
+    }
+  }
 
-  VLOG(1) << "Registered worker at address " << request->worker_address();
+  VLOG(1) << "Finished worker heartbeat for worker at address "
+          << request->worker_address();
   return Status::OK();
 }
 
@@ -346,7 +356,7 @@ Status DataServiceDispatcherImpl::ReleaseJobClient(
   ReleaseJobClientUpdate* release_job_client =
       update.mutable_release_job_client();
   release_job_client->set_job_client_id(job_client_id);
-  release_job_client->set_time_micros(Env::Default()->NowMicros());
+  release_job_client->set_time_micros(env_->NowMicros());
   TF_RETURN_IF_ERROR(Apply(update));
   return Status::OK();
 }
@@ -412,6 +422,19 @@ Status DataServiceDispatcherImpl::CreateJob(
   return Status::OK();
 }
 
+Status DataServiceDispatcherImpl::CreateTasksForWorker(
+    const std::string& worker_address) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+  std::vector<std::shared_ptr<const Job>> jobs = state_.ListJobs();
+  for (const auto& job : jobs) {
+    if (job->finished) {
+      continue;
+    }
+    std::shared_ptr<const Task> task;
+    TF_RETURN_IF_ERROR(CreateTask(job, worker_address, task));
+  }
+  return Status::OK();
+}
+
 Status DataServiceDispatcherImpl::AcquireJobClientId(
     const std::shared_ptr<const Job>& job, int64& job_client_id)
     EXCLUSIVE_LOCKS_REQUIRED(mu_) {
@@ -576,5 +599,51 @@ Status DataServiceDispatcherImpl::Apply(const Update& update)
   return state_.Apply(update);
 }
 
+void DataServiceDispatcherImpl::JobGcThread() {
+  int64 next_check_micros = 0;
+  while (true) {
+    mutex_lock l(mu_);
+    while (!cancelled_ && env_->NowMicros() < next_check_micros) {
+      int64 remaining_micros = next_check_micros - env_->NowMicros();
+      job_gc_thread_cv_.wait_for(l,
+                                 std::chrono::microseconds(remaining_micros));
+    }
+    if (cancelled_) {
+      return;
+    }
+    Status s = GcOldJobs();
+    if (!s.ok()) {
+      LOG(WARNING) << "Error garbage collecting old jobs: " << s;
+    }
+    next_check_micros =
+        env_->NowMicros() + (config_.job_gc_check_interval_ms() * 1000);
+  }
+}
+
+Status DataServiceDispatcherImpl::GcOldJobs() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+  std::vector<std::shared_ptr<const Job>> jobs = state_.ListJobs();
+  int64 now = env_->NowMicros();
+  for (const auto& job : jobs) {
+    if (job->finished || job->num_clients > 0 ||
+        job->last_client_released_micros < 0 ||
+        now < job->last_client_released_micros +
+                  (config_.job_gc_timeout_ms() * 1000)) {
+      continue;
+    }
+    std::vector<std::shared_ptr<const Task>> tasks;
+    TF_RETURN_IF_ERROR(state_.TasksForJob(job->job_id, tasks));
+    for (const auto& task : tasks) {
+      if (task->finished) {
+        continue;
+      }
+      Update update;
+      update.mutable_finish_task()->set_task_id(task->task_id);
+      TF_RETURN_IF_ERROR(state_.Apply(update));
+    }
+    DCHECK(job->finished);
+  }
+  return Status::OK();
+}
+
 }  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/data/service/dispatcher_impl.h b/tensorflow/core/data/service/dispatcher_impl.h
index 2cf341a812e..2ce367735e5 100644
--- a/tensorflow/core/data/service/dispatcher_impl.h
+++ b/tensorflow/core/data/service/dispatcher_impl.h
@@ -48,6 +48,8 @@ class DataServiceDispatcherImpl {
   explicit DataServiceDispatcherImpl(
       const experimental::DispatcherConfig& config);
 
+  ~DataServiceDispatcherImpl();
+
   // Starts the dispatcher. If there is a journal, this will read from the
   // journal to restore the dispatcher's state.
   Status Start();
@@ -55,8 +57,8 @@ class DataServiceDispatcherImpl {
   // See dispatcher.proto for API documentation.
 
   /// Worker-facing API.
-  Status RegisterWorker(const RegisterWorkerRequest* request,
-                        RegisterWorkerResponse* response);
+  Status WorkerHeartbeat(const WorkerHeartbeatRequest* request,
+                         WorkerHeartbeatResponse* response);
   Status WorkerUpdate(const WorkerUpdateRequest* request,
                       WorkerUpdateResponse* response);
   Status GetDatasetDef(const GetDatasetDefRequest* request,
@@ -92,6 +94,8 @@ class DataServiceDispatcherImpl {
                    absl::optional<NamedJobKey> named_job_key,
                    std::shared_ptr<const Job>& job)
       EXCLUSIVE_LOCKS_REQUIRED(mu_);
+  // Creates tasks for the specified worker, one task for every unfinished job.
+  Status CreateTasksForWorker(const std::string& worker_address);
   // Acquires a job client id to read from the given job and sets
   // `job_client_id`.
   Status AcquireJobClientId(
@@ -128,12 +132,16 @@ class DataServiceDispatcherImpl {
   // used when recovering state when the dispatcher starts.
   Status ApplyWithoutJournaling(const Update& update)
       EXCLUSIVE_LOCKS_REQUIRED(mu_);
+  // A thread which periodically checks for jobs to clean up.
+  void JobGcThread();
+  // Scans for old jobs and marks them as finished.
+  Status GcOldJobs() EXCLUSIVE_LOCKS_REQUIRED(mu_);
 
   const experimental::DispatcherConfig& config_;
+  Env* env_;
 
   mutex mu_;
-
-  int64 next_task_id_ TF_GUARDED_BY(mu_) = 0;
+  bool cancelled_ TF_GUARDED_BY(mu_) = false;
 
   // Cached worker stubs for communicating with workers.
   absl::flat_hash_map<std::string, std::unique_ptr<WorkerService::Stub>>
@@ -144,6 +152,9 @@ class DataServiceDispatcherImpl {
   absl::optional<std::unique_ptr<JournalWriter>> journal_writer_
       TF_GUARDED_BY(mu_);
   DispatcherState state_ TF_GUARDED_BY(mu_);
+  // Condition variable for waking up the job gc thread.
+  condition_variable job_gc_thread_cv_;
+  std::unique_ptr<Thread> job_gc_thread_;
 
   TF_DISALLOW_COPY_AND_ASSIGN(DataServiceDispatcherImpl);
 };
diff --git a/tensorflow/core/data/service/dispatcher_state.cc b/tensorflow/core/data/service/dispatcher_state.cc
index 3afee88262e..749f9bce340 100644
--- a/tensorflow/core/data/service/dispatcher_state.cc
+++ b/tensorflow/core/data/service/dispatcher_state.cc
@@ -72,7 +72,8 @@ void DispatcherState::RegisterWorker(
   std::string address = register_worker.worker_address();
   DCHECK(!workers_.contains(address));
   workers_[address] = std::make_shared<Worker>(address);
-  tasks_by_worker_[address] = std::vector<std::shared_ptr<Task>>();
+  tasks_by_worker_[address] =
+      absl::flat_hash_map<int64, std::shared_ptr<Task>>();
 }
 
 void DispatcherState::CreateJob(const CreateJobUpdate& create_job) {
@@ -126,7 +127,7 @@ void DispatcherState::CreateTask(const CreateTaskUpdate& create_task) {
       create_task.dataset_id(), create_task.worker_address());
   tasks_by_job_[create_task.job_id()].push_back(task);
-  tasks_by_worker_[create_task.worker_address()].push_back(task);
+  tasks_by_worker_[create_task.worker_address()][task->task_id] = task;
   next_available_task_id_ = std::max(next_available_task_id_, task_id + 1);
 }
 
@@ -136,6 +137,7 @@ void DispatcherState::FinishTask(const FinishTaskUpdate& finish_task) {
   auto& task = tasks_[task_id];
   DCHECK(task != nullptr);
   task->finished = true;
+  tasks_by_worker_[task->worker_address].erase(task->task_id);
   bool all_finished = true;
   for (const auto& task_for_job : tasks_by_job_[task->job_id]) {
     if (!task_for_job->finished) {
@@ -269,10 +271,11 @@ Status DispatcherState::TasksForWorker(
   if (it == tasks_by_worker_.end()) {
     return errors::NotFound("Worker ", worker_address, " not found");
   }
-  std::vector<std::shared_ptr<Task>> worker_tasks = it->second;
+  const absl::flat_hash_map<int64, std::shared_ptr<Task>>& worker_tasks =
+      it->second;
   tasks.reserve(worker_tasks.size());
   for (const auto& task : worker_tasks) {
-    tasks.push_back(task);
+    tasks.push_back(task.second);
   }
   return Status::OK();
 }
diff --git a/tensorflow/core/data/service/dispatcher_state.h b/tensorflow/core/data/service/dispatcher_state.h
index 59d7f192fb1..9f307e92ae3 100644
--- a/tensorflow/core/data/service/dispatcher_state.h
+++ b/tensorflow/core/data/service/dispatcher_state.h
@@ -179,7 +179,7 @@ class DispatcherState {
   void CreateTask(const CreateTaskUpdate& create_task);
   void FinishTask(const FinishTaskUpdate& finish_task);
 
-  int64 next_available_dataset_id_ = 0;
+  int64 next_available_dataset_id_ = 1000;
   // Registered datasets, keyed by dataset ids.
   absl::flat_hash_map<int64, std::shared_ptr<Dataset>> datasets_by_id_;
   // Registered datasets, keyed by dataset fingerprints.
@@ -189,24 +189,26 @@ class DispatcherState {
 
   // Registered workers, keyed by address.
   absl::flat_hash_map<std::string, std::shared_ptr<Worker>> workers_;
 
-  int64 next_available_job_id_ = 0;
+  int64 next_available_job_id_ = 2000;
   // Jobs, keyed by job ids.
   absl::flat_hash_map<int64, std::shared_ptr<Job>> jobs_;
   // Named jobs, keyed by their names and indices. Not all jobs have names, so
   // this is a subset of the jobs stored in `jobs_`.
   absl::flat_hash_map<NamedJobKey, std::shared_ptr<Job>> named_jobs_;
 
-  int64 next_available_job_client_id_ = 0;
+  int64 next_available_job_client_id_ = 3000;
   // Mapping from client ids to the jobs they are associated with.
   absl::flat_hash_map<int64, std::shared_ptr<Job>> jobs_for_client_ids_;
 
-  int64 next_available_task_id_ = 0;
+  int64 next_available_task_id_ = 4000;
   // Tasks, keyed by task ids.
   absl::flat_hash_map<int64, std::shared_ptr<Task>> tasks_;
   // Tasks, keyed by job ids.
   absl::flat_hash_map<int64, std::vector<std::shared_ptr<Task>>> tasks_by_job_;
-  // Tasks, keyed by worker addresses.
-  absl::flat_hash_map<std::string, std::vector<std::shared_ptr<Task>>>
+  // Tasks, keyed by worker addresses. The values are a map from task id to
+  // task.
+  absl::flat_hash_map<std::string, absl::flat_hash_map<int64, std::shared_ptr<Task>>>
       tasks_by_worker_;
 };
diff --git a/tensorflow/core/data/service/dispatcher_state_test.cc b/tensorflow/core/data/service/dispatcher_state_test.cc
index 43a47f8581f..299ff2c8feb 100644
--- a/tensorflow/core/data/service/dispatcher_state_test.cc
+++ b/tensorflow/core/data/service/dispatcher_state_test.cc
@@ -125,9 +125,9 @@ Status FinishTask(int64 task_id, DispatcherState& state) {
 }  // namespace
 
 TEST(DispatcherState, RegisterDataset) {
-  int64 id = 10;
   uint64 fingerprint = 20;
   DispatcherState state;
+  int64 id = state.NextAvailableDatasetId();
   TF_EXPECT_OK(RegisterDataset(id, fingerprint, state));
   EXPECT_EQ(state.NextAvailableDatasetId(), id + 1);
 
@@ -210,9 +210,9 @@ TEST(DispatcherState, UnknownUpdate) {
 }
 
 TEST(DispatcherState, AnonymousJob) {
-  int64 job_id = 3;
   int64 dataset_id = 10;
   DispatcherState state;
+  int64 job_id = state.NextAvailableJobId();
   TF_EXPECT_OK(RegisterDataset(dataset_id, state));
   TF_EXPECT_OK(CreateAnonymousJob(job_id, dataset_id, state));
   std::shared_ptr<const Job> job;
@@ -227,9 +227,9 @@ TEST(DispatcherState, AnonymousJob) {
 }
 
 TEST(DispatcherState, NamedJob) {
-  int64 job_id = 3;
   int64 dataset_id = 10;
   DispatcherState state;
+  int64 job_id = state.NextAvailableJobId();
   TF_EXPECT_OK(RegisterDataset(dataset_id, state));
   NamedJobKey named_job_key("test", 1);
   TF_EXPECT_OK(CreateNamedJob(job_id, dataset_id, named_job_key, state));
@@ -244,9 +244,9 @@ TEST(DispatcherState, NamedJob) {
 TEST(DispatcherState, CreateTask) {
   int64 job_id = 3;
   int64 dataset_id = 10;
-  int64 task_id = 8;
   std::string worker_address = "test_worker_address";
   DispatcherState state;
+  int64 task_id = state.NextAvailableTaskId();
   TF_EXPECT_OK(RegisterDataset(dataset_id, state));
   TF_EXPECT_OK(CreateAnonymousJob(job_id, dataset_id, state));
   TF_EXPECT_OK(CreateTask(task_id, job_id, dataset_id, worker_address, state));
diff --git a/tensorflow/core/data/service/grpc_dispatcher_impl.cc b/tensorflow/core/data/service/grpc_dispatcher_impl.cc
index fbfc5d20665..89ae6d4fd50 100644
--- a/tensorflow/core/data/service/grpc_dispatcher_impl.cc
+++ b/tensorflow/core/data/service/grpc_dispatcher_impl.cc
@@ -40,7 +40,7 @@ Status GrpcDispatcherImpl::Start() { return impl_.Start(); }
                          method##Response* response) {              \
     return ToGrpcStatus(impl_.method(request, response));           \
   }
-HANDLER(RegisterWorker);
+HANDLER(WorkerHeartbeat);
 HANDLER(WorkerUpdate);
 HANDLER(GetDatasetDef);
 HANDLER(GetOrRegisterDataset);
diff --git a/tensorflow/core/data/service/grpc_dispatcher_impl.h b/tensorflow/core/data/service/grpc_dispatcher_impl.h
index 171deed4792..6269148a5f9 100644
--- a/tensorflow/core/data/service/grpc_dispatcher_impl.h
+++ b/tensorflow/core/data/service/grpc_dispatcher_impl.h
@@ -39,7 +39,7 @@ class GrpcDispatcherImpl : public DispatcherService::Service {
   ::grpc::Status method(::grpc::ServerContext* context,   \
                         const method##Request* request,   \
                         method##Response* response) override;
-  HANDLER(RegisterWorker);
+  HANDLER(WorkerHeartbeat);
   HANDLER(WorkerUpdate);
   HANDLER(GetDatasetDef);
   HANDLER(GetOrRegisterDataset);
diff --git a/tensorflow/core/data/service/grpc_worker_impl.cc b/tensorflow/core/data/service/grpc_worker_impl.cc
index ef386be4640..3c3a81d0daf 100644
--- a/tensorflow/core/data/service/grpc_worker_impl.cc
+++ b/tensorflow/core/data/service/grpc_worker_impl.cc
@@ -43,6 +43,7 @@ Status GrpcWorkerImpl::Start(const std::string& worker_address) {
   }
 HANDLER(ProcessTask);
 HANDLER(GetElement);
+HANDLER(GetWorkerTasks);
 #undef HANDLER
 
 }  // namespace data
diff --git a/tensorflow/core/data/service/grpc_worker_impl.h b/tensorflow/core/data/service/grpc_worker_impl.h
index 3d30af9a806..734865e3447 100644
--- a/tensorflow/core/data/service/grpc_worker_impl.h
+++ b/tensorflow/core/data/service/grpc_worker_impl.h
@@ -41,6 +41,7 @@ class GrpcWorkerImpl : public WorkerService::Service {
                         method##Response* response) override;
   HANDLER(ProcessTask);
   HANDLER(GetElement);
+  HANDLER(GetWorkerTasks);
 #undef HANDLER
 
  private:
diff --git a/tensorflow/core/data/service/server_lib.cc b/tensorflow/core/data/service/server_lib.cc
index cb85693cbf7..af940fe54a3 100644
--- a/tensorflow/core/data/service/server_lib.cc
+++ b/tensorflow/core/data/service/server_lib.cc
@@ -138,6 +138,18 @@ Status WorkerGrpcDataServer::StartServiceInternal() {
   return Status::OK();
 }
 
+Status WorkerGrpcDataServer::NumTasks(int* num_tasks) {
+  GetWorkerTasksRequest req;
+  GetWorkerTasksResponse resp;
+  ::grpc::ServerContext ctx;
+  ::grpc::Status s = service_->GetWorkerTasks(&ctx, &req, &resp);
+  if (!s.ok()) {
+    return grpc_util::WrapError("Failed to get tasks", s);
+  }
+  *num_tasks = resp.tasks_size();
+  return Status::OK();
+}
+
 Status NewDispatchServer(const experimental::DispatcherConfig& config,
                          std::unique_ptr<DispatchGrpcDataServer>& out_server) {
   out_server = absl::make_unique<DispatchGrpcDataServer>(config);
diff --git a/tensorflow/core/data/service/server_lib.h b/tensorflow/core/data/service/server_lib.h
index c45ec144652..c9981008248 100644
--- a/tensorflow/core/data/service/server_lib.h
+++ b/tensorflow/core/data/service/server_lib.h
@@ -98,6 +98,9 @@ class WorkerGrpcDataServer : public GrpcDataServerBase {
   explicit WorkerGrpcDataServer(const experimental::WorkerConfig& config);
   ~WorkerGrpcDataServer() override;
 
+  // Returns the number of tasks currently being executed by the worker.
+  Status NumTasks(int* num_tasks);
+
  protected:
   void AddDataServiceToBuilder(::grpc::ServerBuilder& builder) override;
   Status StartServiceInternal() override;
diff --git a/tensorflow/core/data/service/worker.proto b/tensorflow/core/data/service/worker.proto
index 51c6899f540..32d3b79a78e 100644
--- a/tensorflow/core/data/service/worker.proto
+++ b/tensorflow/core/data/service/worker.proto
@@ -23,10 +23,20 @@ message GetElementResponse {
   bool end_of_sequence = 2;
 }
 
+// Named GetWorkerTasks to avoid conflicting with GetTasks in dispatcher.proto
+message GetWorkerTasksRequest {}
+
+message GetWorkerTasksResponse {
+  repeated TaskInfo tasks = 1;
+}
+
 service WorkerService {
   // Processes an task for a dataset, making elements available to clients.
   rpc ProcessTask(ProcessTaskRequest) returns (ProcessTaskResponse);
 
   // Gets the next dataset element.
   rpc GetElement(GetElementRequest) returns (GetElementResponse);
+
+  // Gets the tasks currently being executed by the worker.
+  rpc GetWorkerTasks(GetWorkerTasksRequest) returns (GetWorkerTasksResponse);
 }
diff --git a/tensorflow/core/data/service/worker_impl.cc b/tensorflow/core/data/service/worker_impl.cc
index cc61c481d7c..f0790f6961e 100644
--- a/tensorflow/core/data/service/worker_impl.cc
+++ b/tensorflow/core/data/service/worker_impl.cc
@@ -56,7 +56,8 @@ DataServiceWorkerImpl::DataServiceWorkerImpl(
 DataServiceWorkerImpl::~DataServiceWorkerImpl() {
   mutex_lock l(mu_);
   cancelled_ = true;
-  background_cv_.notify_one();
+  task_completion_cv_.notify_one();
+  heartbeat_cv_.notify_one();
 }
 
 Status DataServiceWorkerImpl::Start(const std::string& worker_address) {
@@ -67,18 +68,20 @@ Status DataServiceWorkerImpl::Start(const std::string& worker_address) {
       config_.dispatcher_address(), config_.protocol());
   TF_RETURN_IF_ERROR(dispatcher_->Initialize());
 
-  Status s = Register();
+  Status s = Heartbeat();
   while (!s.ok()) {
     LOG(WARNING) << "Failed to register with dispatcher at "
                  << config_.dispatcher_address() << ": " << s;
     Env::Default()->SleepForMicroseconds(kRetryIntervalMicros);
-    s = Register();
+    s = Heartbeat();
   }
-  Thread* thread = Env::Default()->StartThread(
-      {}, "data-service-worker-background", [this]() { BackgroundThread(); });
   LOG(INFO) << "Worker registered with dispatcher running at "
             << config_.dispatcher_address();
-  background_thread_.reset(thread);
+  task_completion_thread_ = absl::WrapUnique(
+      Env::Default()->StartThread({}, "data-service-worker-task-completion",
+                                  [this]() { TaskCompletionThread(); }));
+  heartbeat_thread_ = absl::WrapUnique(Env::Default()->StartThread(
+      {}, "data-service-worker-heartbeat", [this]() { HeartbeatThread(); }));
   mutex_lock l(mu_);
   registered_ = true;
   return Status::OK();
@@ -96,8 +99,9 @@ Status DataServiceWorkerImpl::ProcessTaskInternal(const TaskDef& task_def)
     EXCLUSIVE_LOCKS_REQUIRED(mu_) {
   std::unique_ptr<Task>& task = tasks_[task_def.task_id()];
   if (task) {
-    return errors::AlreadyExists("A task with id ", task_def.task_id(),
-                                 " already exists.");
+    VLOG(1) << "Received request to process already-processed task "
+            << task->task_def.task_id();
+    return Status::OK();
   }
   task = absl::make_unique<Task>(task_def);
   VLOG(3) << "Began processing for task " << task_def.task_id();
@@ -156,24 +160,17 @@ Status DataServiceWorkerImpl::GetElement(const GetElementRequest* request,
     }
     auto it = tasks_.find(request->task_id());
     if (it == tasks_.end()) {
-      return errors::NotFound("DataServiceWorkerImpl::GetElement failed. ",
-                              "Task id ", request->task_id(), " not found");
-    }
-    auto& task = it->second;
-    TF_RETURN_IF_ERROR(EnsureTaskInitialized(*task));
-    std::unique_ptr<standalone::Iterator>& iter = task->iterator;
-    if (iter == nullptr) {
-      VLOG(3) << "Task " << request->task_id() << " is already finished";
       response->set_end_of_sequence(true);
       return Status::OK();
     }
-    TF_RETURN_IF_ERROR(iter->GetNext(&outputs, &end_of_sequence));
+    auto& task = it->second;
+    TF_RETURN_IF_ERROR(EnsureTaskInitialized(*task));
+    TF_RETURN_IF_ERROR(task->iterator->GetNext(&outputs, &end_of_sequence));
     if (end_of_sequence) {
      VLOG(3) << "Reached end_of_sequence for task " << request->task_id();
-      // Release iterator memory and leave a null entry as a tombstone.
-      iter.reset();
+      tasks_.erase(request->task_id());
       pending_completed_tasks_.insert(request->task_id());
-      background_cv_.notify_one();
+      task_completion_cv_.notify_one();
     }
   }
 
@@ -212,27 +209,28 @@ Status DataServiceWorkerImpl::GetElement(const GetElementRequest* request,
   return Status::OK();
 }
 
-Status DataServiceWorkerImpl::Register() LOCKS_EXCLUDED(mu_) {
-  VLOG(3) << "Registering with dispatcher at " << config_.dispatcher_address();
-  std::vector<TaskDef> tasks;
-  TF_RETURN_IF_ERROR(dispatcher_->RegisterWorker(worker_address_, tasks));
-  for (const TaskDef& task : tasks) {
-    mutex_lock l(mu_);
-    TF_RETURN_IF_ERROR(ProcessTaskInternal(task));
+Status DataServiceWorkerImpl::GetWorkerTasks(
+    const GetWorkerTasksRequest* request, GetWorkerTasksResponse* response) {
+  mutex_lock l(mu_);
+  for (const auto& it : tasks_) {
+    Task* task = it.second.get();
+    TaskInfo* task_info = response->add_tasks();
+    task_info->set_worker_address(worker_address_);
+    task_info->set_task_id(task->task_def.task_id());
+    task_info->set_job_id(task->task_def.job_id());
   }
-  VLOG(3) << "Registered worker with address " << worker_address_;
   return Status::OK();
 }
 
-void DataServiceWorkerImpl::BackgroundThread() LOCKS_EXCLUDED(mu_) {
+void DataServiceWorkerImpl::TaskCompletionThread() LOCKS_EXCLUDED(mu_) {
   while (true) {
     {
       mutex_lock l(mu_);
       while (!cancelled_ && pending_completed_tasks_.empty()) {
-        background_cv_.wait(l);
+        task_completion_cv_.wait(l);
       }
       if (cancelled_) {
-        VLOG(3) << "Background thread shutting down";
+        VLOG(3) << "Task completion thread shutting down";
         return;
       }
     }
@@ -241,7 +239,7 @@ void DataServiceWorkerImpl::BackgroundThread() LOCKS_EXCLUDED(mu_) {
       LOG(WARNING) << "Failed to send task updates to dispatcher: " << s;
       mutex_lock l(mu_);
       if (!cancelled_) {
-        background_cv_.wait_for(
+        task_completion_cv_.wait_for(
             l, std::chrono::microseconds(kRetryIntervalMicros));
       }
     }
@@ -271,5 +269,62 @@ Status DataServiceWorkerImpl::SendTaskUpdates() LOCKS_EXCLUDED(mu_) {
   return Status::OK();
 }
 
+void DataServiceWorkerImpl::HeartbeatThread() LOCKS_EXCLUDED(mu_) {
+  while (true) {
+    int64 next_heartbeat_micros =
+        Env::Default()->NowMicros() + (config_.heartbeat_interval_ms() * 1000);
+    {
+      mutex_lock l(mu_);
+      while (!cancelled_ &&
+             Env::Default()->NowMicros() < next_heartbeat_micros) {
+        int64 time_to_wait_micros =
+            next_heartbeat_micros - Env::Default()->NowMicros();
+        heartbeat_cv_.wait_for(l,
+                               std::chrono::microseconds(time_to_wait_micros));
+      }
+      if (cancelled_) {
+        VLOG(3) << "Heartbeat thread shutting down";
+        return;
+      }
+      if (!registered_) {
+        VLOG(1) << "Not performing heartbeat; worker is not yet registered";
+        continue;
+      }
+    }
+    Status s = Heartbeat();
+    if (!s.ok()) {
+      LOG(WARNING) << "Failed to send heartbeat to dispatcher: " << s;
+    }
+  }
+}
+
+Status DataServiceWorkerImpl::Heartbeat() LOCKS_EXCLUDED(mu_) {
+  std::vector<int64> current_tasks;
+  {
+    mutex_lock l(mu_);
+    for (const auto& task : tasks_) {
+      current_tasks.push_back(task.first);
+    }
+  }
+  std::vector<TaskDef> new_tasks;
+  std::vector<int64> tasks_to_delete;
+  TF_RETURN_IF_ERROR(dispatcher_->WorkerHeartbeat(
+      worker_address_, current_tasks, new_tasks, tasks_to_delete));
+  mutex_lock l(mu_);
+  for (const auto& task : new_tasks) {
+    Status s = ProcessTaskInternal(task);
+    if (!s.ok() && !errors::IsAlreadyExists(s)) {
+      LOG(WARNING) << "Failed to start processing task " << task.task_id()
+                   << ": " << s;
+    }
+  }
+  for (int64 task_id : tasks_to_delete) {
+    VLOG(3) << "Deleting task " << task_id
+            << " at the request of the dispatcher";
+    tasks_.erase(task_id);
+  }
+  return Status::OK();
+}
+
 }  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/data/service/worker_impl.h b/tensorflow/core/data/service/worker_impl.h
index 27a7da34c1d..5f05275622b 100644
--- a/tensorflow/core/data/service/worker_impl.h
+++ b/tensorflow/core/data/service/worker_impl.h
@@ -50,6 +50,8 @@ class DataServiceWorkerImpl {
   /// Client-facing API.
   Status GetElement(const GetElementRequest* request,
                     GetElementResponse* response);
+  Status GetWorkerTasks(const GetWorkerTasksRequest* request,
+                        GetWorkerTasksResponse* response);
 
  private:
   struct Task {
@@ -64,16 +66,17 @@ class DataServiceWorkerImpl {
     std::unique_ptr<standalone::Iterator> iterator;
   };
 
-  // Registers the worker with the dispatcher.
-  Status Register() LOCKS_EXCLUDED(mu_);
   // Sends task status to the dispatcher and checks for dispatcher commands.
   Status SendTaskUpdates() LOCKS_EXCLUDED(mu_);
   // Creates an iterator to process a task.
   Status ProcessTaskInternal(const TaskDef& task) EXCLUSIVE_LOCKS_REQUIRED(mu_);
   Status EnsureTaskInitialized(Task& task);
-  // A thread for doing async background processing not associated with a
-  // specific RPC, such as reporting finished tasks.
-  void BackgroundThread() LOCKS_EXCLUDED(mu_);
+  // A thread for notifying the dispatcher when tasks complete.
+  void TaskCompletionThread() LOCKS_EXCLUDED(mu_);
+  // A thread for doing periodic heartbeats to the dispatcher.
+  void HeartbeatThread() LOCKS_EXCLUDED(mu_);
+  // Performs a heartbeat to the dispatcher.
+  Status Heartbeat() LOCKS_EXCLUDED(mu_);
 
   const experimental::WorkerConfig config_;
   // The worker's own address.
@@ -88,9 +91,12 @@ class DataServiceWorkerImpl {
   bool cancelled_ TF_GUARDED_BY(mu_) = false;
   // Whether the worker has registered with the dispatcher yet.
   bool registered_ TF_GUARDED_BY(mu_) = false;
-  // Condition variable for notifying the background thread.
-  condition_variable background_cv_ TF_GUARDED_BY(mu_);
-  std::unique_ptr<Thread> background_thread_;
+  // A thread for notifying the dispatcher when tasks complete.
+  std::unique_ptr<Thread> task_completion_thread_;
+  condition_variable task_completion_cv_ TF_GUARDED_BY(mu_);
+  // A thread for performing regular heartbeats to the dispatcher.
+  std::unique_ptr<Thread> heartbeat_thread_;
+  condition_variable heartbeat_cv_ TF_GUARDED_BY(mu_);
 
   TF_DISALLOW_COPY_AND_ASSIGN(DataServiceWorkerImpl);
 };
diff --git a/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc b/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc
index d89392598d5..b17d2008778 100644
--- a/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc
@@ -200,6 +200,9 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
       for (auto& worker_thread : worker_threads_) {
         worker_thread.reset();
       }
+
+      VLOG(1) << "Destroyed data service dataset iterator for job id "
+              << job_client_id_;
     }
 
     void CancelThreads() TF_LOCKS_EXCLUDED(mu_) {
diff --git a/tensorflow/core/protobuf/data/experimental/service_config.proto b/tensorflow/core/protobuf/data/experimental/service_config.proto
index 233c18076cf..7a0aa16e2c4 100644
--- a/tensorflow/core/protobuf/data/experimental/service_config.proto
+++ b/tensorflow/core/protobuf/data/experimental/service_config.proto
@@ -15,6 +15,11 @@ message DispatcherConfig {
   // Whether to run in fault tolerant mode, where dispatcher state is saved
   // across restarts. Requires that `work_dir` is nonempty.
   bool fault_tolerant_mode = 4;
+  // How often the dispatcher should scan through to delete old and unused
+  // jobs.
+  int64 job_gc_check_interval_ms = 5;
+  // How long a job needs to be unused before it becomes a candidate for
+  // garbage collection.
+  int64 job_gc_timeout_ms = 6;
 }
 
 // Configuration for a tf.data service WorkerServer.
@@ -30,4 +35,6 @@ message WorkerConfig {
   // will be replaced with the worker's bound port. This is useful when the port
   // is set to `0`.
   string worker_address = 4;
+  // How often the worker should heartbeat to the dispatcher.
+  int64 heartbeat_interval_ms = 5;
 }
diff --git a/tensorflow/python/data/experimental/service/server_lib.py b/tensorflow/python/data/experimental/service/server_lib.py
index 8c5c85a169a..d964080ebba 100644
--- a/tensorflow/python/data/experimental/service/server_lib.py
+++ b/tensorflow/python/data/experimental/service/server_lib.py
@@ -29,9 +29,10 @@ from tensorflow.python.util.tf_export import tf_export
 
 @tf_export("data.experimental.service.DispatcherConfig")
 class DispatcherConfig(
-    collections.namedtuple(
-        "DispatcherConfig",
-        ["port", "protocol", "work_dir", "fault_tolerant_mode"])):
+    collections.namedtuple("DispatcherConfig", [
+        "port", "protocol", "work_dir", "fault_tolerant_mode",
+        "job_gc_check_interval_ms", "job_gc_timeout_ms"
+    ])):
   """Configuration class for tf.data service dispatchers.
 
   Fields:
@@ -47,15 +48,34 @@ class DispatcherConfig(
       registered datasets and created jobs, is synchronously written to the
       journal before responding to RPCs. If `True`, `work_dir` must also be
       specified.
+    job_gc_check_interval_ms: How often the dispatcher should scan through to
+      delete old and unused jobs, in milliseconds. If not set, the runtime will
+      select a reasonable default. A higher value will reduce load on the
+      dispatcher, while a lower value will reduce the time it takes for the
+      dispatcher to garbage collect expired jobs.
+    job_gc_timeout_ms: How long a job needs to be unused before it becomes a
+      candidate for garbage collection, in milliseconds. If not set, the runtime
+      will select a reasonable default. A higher value will cause jobs to stay
+      around longer with no consumers. This is useful if there is a large gap in
+      time between when consumers read from the job. A lower value will reduce
+      the time it takes to reclaim the resources from expired jobs.
   """
 
   def __new__(cls,
               port=0,
               protocol="grpc",
               work_dir=None,
-              fault_tolerant_mode=False):
-    return super(DispatcherConfig, cls).__new__(cls, port, protocol, work_dir,
-                                                fault_tolerant_mode)
+              fault_tolerant_mode=False,
+              job_gc_check_interval_ms=None,
+              job_gc_timeout_ms=None):
+    if job_gc_check_interval_ms is None:
+      job_gc_check_interval_ms = 10 * 60 * 1000  # 10 minutes.
+    if job_gc_timeout_ms is None:
+      job_gc_timeout_ms = 5 * 60 * 1000  # 5 minutes.
+    return super(DispatcherConfig,
+                 cls).__new__(cls, port, protocol, work_dir,
+                              fault_tolerant_mode, job_gc_check_interval_ms,
+                              job_gc_timeout_ms)
@@ -116,7 +136,9 @@ class DispatchServer(object):
         port=config.port,
         protocol=config.protocol,
         work_dir=config.work_dir,
-        fault_tolerant_mode=config.fault_tolerant_mode)
+        fault_tolerant_mode=config.fault_tolerant_mode,
+        job_gc_check_interval_ms=config.job_gc_check_interval_ms,
+        job_gc_timeout_ms=config.job_gc_timeout_ms)
     self._server = _pywrap_server_lib.TF_DATA_NewDispatchServer(
         config_proto.SerializeToString())
     if start:
@@ -193,9 +215,10 @@ class DispatchServer(object):
 
 @tf_export("data.experimental.service.WorkerConfig")
 class WorkerConfig(
-    collections.namedtuple(
-        "WorkerConfig",
-        ["dispatcher_address", "worker_address", "port", "protocol"])):
+    collections.namedtuple("WorkerConfig", [
+        "dispatcher_address", "worker_address", "port", "protocol",
+        "heartbeat_interval_ms"
+    ])):
   """Configuration class for tf.data service dispatchers.
 
   Fields:
@@ -205,19 +228,29 @@ class WorkerConfig(
       connect to this worker.
     port: Specifies the port to bind to. A value of 0 indicates that the worker
      can bind to any available port.
-    protocol: (Optional.) Specifies the protocol to be used by the server.
+    protocol: Specifies the protocol to be used by the server.
       Acceptable values include `"grpc" and "grpc+local"`.
+    heartbeat_interval_ms: How often the worker should heartbeat to the
+      dispatcher, in milliseconds. If not set, the runtime will select a
+      reasonable default. A higher value will reduce the load on the dispatcher,
+      while a lower value will reduce the time it takes to reclaim resources
+      from finished jobs.
   """
 
   def __new__(cls,
               dispatcher_address,
               worker_address=None,
               port=0,
-              protocol="grpc"):
-    worker_address = ("localhost:%port%"
-                      if worker_address is None else worker_address)
-    return super(WorkerConfig, cls).__new__(cls, dispatcher_address,
-                                            worker_address, port, protocol)
+              protocol="grpc",
+              heartbeat_interval_ms=None):
+    if worker_address is None:
+      worker_address = "localhost:%port%"
+    if heartbeat_interval_ms is None:
+      heartbeat_interval_ms = 30 * 1000  # 30 seconds
+
+    return super(WorkerConfig,
+                 cls).__new__(cls, dispatcher_address, worker_address, port,
+                              protocol, heartbeat_interval_ms)
 
 
 @tf_export("data.experimental.service.WorkerServer", v1=[])
@@ -264,7 +297,8 @@ class WorkerServer(object):
         dispatcher_address=config.dispatcher_address,
         worker_address=config.worker_address,
         port=config.port,
-        protocol=config.protocol)
+        protocol=config.protocol,
+        heartbeat_interval_ms=config.heartbeat_interval_ms)
     self._server = _pywrap_server_lib.TF_DATA_NewWorkerServer(
         config_proto.SerializeToString())
     if start:
@@ -317,3 +351,7 @@ class WorkerServer(object):
       The returned string will be in the form address:port, e.g. "localhost:1000".
     """
     return "localhost:{0}".format(self._server.bound_port())
+
+  def _num_tasks(self):
+    """Returns the number of tasks currently being executed on the worker."""
+    return self._server.num_tasks()
diff --git a/tensorflow/python/data/experimental/service/server_lib_wrapper.cc b/tensorflow/python/data/experimental/service/server_lib_wrapper.cc
index 8ce904eecba..5a229f88d92 100644
--- a/tensorflow/python/data/experimental/service/server_lib_wrapper.cc
+++ b/tensorflow/python/data/experimental/service/server_lib_wrapper.cc
@@ -50,7 +50,14 @@ PYBIND11_MODULE(_pywrap_server_lib, m) {
       .def("stop", &tensorflow::data::WorkerGrpcDataServer::Stop)
       .def("join", &tensorflow::data::WorkerGrpcDataServer::Join,
           py::call_guard<py::gil_scoped_release>())
-      .def("bound_port", &tensorflow::data::WorkerGrpcDataServer::BoundPort);
+      .def("bound_port", &tensorflow::data::WorkerGrpcDataServer::BoundPort)
+      .def("num_tasks",
+           [](tensorflow::data::WorkerGrpcDataServer* server) -> int {
+             int num_tasks;
+             tensorflow::Status status = server->NumTasks(&num_tasks);
+             tensorflow::MaybeRaiseFromStatus(status);
+             return num_tasks;
+           });
 
   m.def(
       "TF_DATA_NewDispatchServer",
diff --git a/tensorflow/python/data/kernel_tests/data_service_ops_test.py b/tensorflow/python/data/kernel_tests/data_service_ops_test.py
index eb7a67d69ec..0733c2ffbd8 100644
--- a/tensorflow/python/data/kernel_tests/data_service_ops_test.py
+++ b/tensorflow/python/data/kernel_tests/data_service_ops_test.py
@@ -101,7 +101,9 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
                             name="",
                             port=0,
                             work_dir=None,
-                            fault_tolerant_mode=True):
+                            fault_tolerant_mode=True,
+                            job_gc_check_interval_ms=None,
+                            job_gc_timeout_ms=None):
    # If a test starts multiple independent dispatch servers, it should give
    # them different `name` values.
    work_dir = os.path.join(self.get_temp_dir(), "work_dir_",
@@ -110,13 +112,16 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
         server_lib.DispatcherConfig(
             port=port,
             work_dir=work_dir,
-            fault_tolerant_mode=fault_tolerant_mode))
+            fault_tolerant_mode=fault_tolerant_mode,
+            job_gc_check_interval_ms=job_gc_check_interval_ms,
+            job_gc_timeout_ms=job_gc_timeout_ms))
 
   def start_worker_server(self, dispatcher, port=0):
     return server_lib.WorkerServer(
         server_lib.WorkerConfig(
             dispatcher_address=_address_from_target(dispatcher.target),
-            port=port))
+            port=port,
+            heartbeat_interval_ms=200))
 
   def restart_dispatcher(self, dispatcher):
     """Stops `dispatcher` and returns a new dispatcher with the same port."""
@@ -535,6 +540,47 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
       results.append(elem.numpy())
     self.assertCountEqual(num_repetitions * list(range(num_elements)), results)
 
+  @combinations.generate(
+      combinations.times(test_base.eager_only_combinations(),
+                         combinations.combine(job_name=[None, "test"])))
+  def testGcUnusedJob(self, job_name):
+    dispatcher = self.start_dispatch_server(
+        job_gc_check_interval_ms=50, job_gc_timeout_ms=20)
+    worker = self.start_worker_server(dispatcher)  # pylint: disable=unused-variable
+    num_elements = 10
+    ds = _make_distributed_range_dataset(
+        num_elements, dispatcher, job_name=job_name)
+    it = iter(ds)
+    self.assertEqual(0, next(it).numpy())
+    self.assertEqual(1, worker._num_tasks())
+    del it
+    while worker._num_tasks() > 0:
+      time.sleep(0.1)
+
+  @combinations.generate(test_base.eager_only_combinations())
+  def testDontGcUsedJob(self):
+    dispatcher = self.start_dispatch_server(
+        job_gc_check_interval_ms=50, job_gc_timeout_ms=20)
+    worker = self.start_worker_server(dispatcher)  # pylint: disable=unused-variable
+    num_elements = 10
+    it1 = iter(
+        _make_distributed_range_dataset(
+            num_elements, dispatcher, job_name="test1"))
+    it2 = iter(
+        _make_distributed_range_dataset(
+            num_elements, dispatcher, job_name="test2"))
+    it3 = iter(  # this iterator keeps the task alive. pylint: disable=unused-variable
+        _make_distributed_range_dataset(
+            num_elements, dispatcher, job_name="test2"))
+    self.assertEqual(2, worker._num_tasks())
+    del it1
+    del it2
+    # Check that only the first job is gced. The second job will not be gced
+    # because there is still an outstanding iterator for it.
+    while worker._num_tasks() > 1:
+      time.sleep(0.1)
+    self.assertEqual(1, worker._num_tasks())
+
   @combinations.generate(test_base.eager_only_combinations())
   def testApplyDeterminismOption(self):
     elements = list(range(10))
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.service.-dispatcher-config.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.service.-dispatcher-config.pbtxt
index 3efc1c97797..2364e41e4c5 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.service.-dispatcher-config.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.service.-dispatcher-config.pbtxt
@@ -7,6 +7,14 @@ tf_class {
     name: "fault_tolerant_mode"
     mtype: "<type \'property\'>"
   }
+  member {
+    name: "job_gc_check_interval_ms"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "job_gc_timeout_ms"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "port"
     mtype: "<type \'property\'>"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.service.-worker-config.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.service.-worker-config.pbtxt
index 7bf0bc136e6..d8eaf9bc7d7 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.service.-worker-config.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.service.-worker-config.pbtxt
@@ -7,6 +7,10 @@ tf_class {
     name: "dispatcher_address"
     mtype: "<type \'property\'>"
   }
+  member {
+    name: "heartbeat_interval_ms"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "port"
     mtype: "<type \'property\'>"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.service.-dispatcher-config.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.service.-dispatcher-config.pbtxt
index 3efc1c97797..2364e41e4c5 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.service.-dispatcher-config.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.service.-dispatcher-config.pbtxt
@@ -7,6 +7,14 @@ tf_class {
     name: "fault_tolerant_mode"
     mtype: "<type \'property\'>"
   }
+  member {
+    name: "job_gc_check_interval_ms"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "job_gc_timeout_ms"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "port"
     mtype: "<type \'property\'>"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.service.-worker-config.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.service.-worker-config.pbtxt
index 7bf0bc136e6..d8eaf9bc7d7 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.service.-worker-config.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.service.-worker-config.pbtxt
@@ -7,6 +7,10 @@ tf_class {
     name: "dispatcher_address"
     mtype: "<type \'property\'>"
  }
+  member {
+    name: "heartbeat_interval_ms"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "port"
     mtype: "<type \'property\'>"
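
Usage note (not part of the patch): the sketch below shows how the new configuration knobs added by this change fit together. It is a minimal, hedged example, assuming the internal import path used by the new tests; the aggressive GC and heartbeat intervals are illustrative values mirroring the tests, not the production defaults (10 min GC check, 5 min GC timeout, 30 s heartbeat), and the parsing of `dispatcher.target` is an assumption about its "protocol://address" form (the tests use a private helper for this).

from tensorflow.python.data.experimental.service import server_lib

# Dispatcher that scans for unused jobs every 50 ms and garbage collects a job
# once it has had no clients for 20 ms.
dispatcher = server_lib.DispatchServer(
    server_lib.DispatcherConfig(
        port=0, job_gc_check_interval_ms=50, job_gc_timeout_ms=20))

# Assumed: `target` has the form "<protocol>://<address>"; strip the protocol
# prefix to obtain the address the worker should dial.
dispatcher_address = dispatcher.target.split("://")[1]

# Worker that heartbeats every 200 ms. Each heartbeat reports the worker's
# current task ids; the dispatcher replies with new tasks to start and stale
# tasks to delete.
worker = server_lib.WorkerServer(
    server_lib.WorkerConfig(
        dispatcher_address=dispatcher_address, heartbeat_interval_ms=200))

# Private test hook added by this change: the number of tasks the worker is
# currently executing (it drops to zero once its jobs are garbage collected).
print(worker._num_tasks())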