[tf.data service] Garbage collect old and unused jobs.
This CL adds a worker heartbeat which does the following:

- Registers the worker with the dispatcher if the worker is not yet registered.
- Reports to the dispatcher which tasks the worker is currently processing.
- Learns from the dispatcher which tasks should be added and which should be deleted.

Learning about new tasks is important in case the dispatcher's original ProcessTask request to the worker failed due to network issues. The heartbeat provides a backup mechanism where the worker can still learn about what tasks it should process.

Deleting old tasks is important for job lifecycle management. When the dispatcher detects that a job is old and unused, it will mark the job (and all of its tasks) as finished. The worker will learn about the task finishing during the next heartbeat.

PiperOrigin-RevId: 330990312
Change-Id: Ifc0dbb3465ff478de83bfb3a265d9a57072d852a
Parent: 660f6e0610
Commit: 8e789c3872
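For context, the configuration knobs introduced by this change can be exercised from Python roughly as follows. This is an illustrative sketch based only on the API changes in the diff below; the port and interval values are arbitrary and dataset/job setup is omitted.

import tensorflow as tf

dispatcher = tf.data.experimental.service.DispatchServer(
    tf.data.experimental.service.DispatcherConfig(
        port=5000,
        # Scan for old and unused jobs every minute (default is 10 minutes).
        job_gc_check_interval_ms=60 * 1000,
        # Collect a job once it has had no clients for a minute (default is 5 minutes).
        job_gc_timeout_ms=60 * 1000))
worker = tf.data.experimental.service.WorkerServer(
    tf.data.experimental.service.WorkerConfig(
        dispatcher_address="localhost:5000",
        # Heartbeat to the dispatcher every 5 seconds (default is 30 seconds).
        heartbeat_interval_ms=5 * 1000))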
@@ -54,19 +54,26 @@ std::string ProcessingModeToString(ProcessingMode mode) {
 }
 }

-Status DataServiceDispatcherClient::RegisterWorker(
-    const std::string& worker_address, std::vector<TaskDef>& tasks) {
+Status DataServiceDispatcherClient::WorkerHeartbeat(
+    const std::string& worker_address, const std::vector<int64>& current_tasks,
+    std::vector<TaskDef>& new_tasks, std::vector<int64>& tasks_to_delete) {
   TF_RETURN_IF_ERROR(EnsureInitialized());
-  RegisterWorkerRequest req;
+  WorkerHeartbeatRequest req;
   req.set_worker_address(worker_address);
-  RegisterWorkerResponse resp;
-  grpc::ClientContext client_ctx;
-  grpc::Status status = stub_->RegisterWorker(&client_ctx, req, &resp);
-  if (!status.ok()) {
-    return grpc_util::WrapError("Failed to register worker", status);
+  for (int64 task : current_tasks) {
+    req.add_current_tasks(task);
   }
-  for (const auto& task : resp.tasks()) {
-    tasks.push_back(task);
+  WorkerHeartbeatResponse resp;
+  grpc::ClientContext client_ctx;
+  grpc::Status status = stub_->WorkerHeartbeat(&client_ctx, req, &resp);
+  if (!status.ok()) {
+    return grpc_util::WrapError("Failed to perform worker heartbeat", status);
+  }
+  for (const auto& task : resp.new_tasks()) {
+    new_tasks.push_back(task);
+  }
+  for (int64 task_to_delete : resp.tasks_to_delete()) {
+    tasks_to_delete.push_back(task_to_delete);
   }
   return Status::OK();
 }
@@ -73,10 +73,15 @@ class DataServiceDispatcherClient : public DataServiceClientBase {
                               const std::string& protocol)
       : DataServiceClientBase(address, protocol) {}

-  // Registers a worker with the dispatcher. The dispatcher returns a list of
-  // initial tasks for the worker to run, storing them in `tasks`.
-  Status RegisterWorker(const std::string& worker_address,
-                        std::vector<TaskDef>& tasks);
+  // Sends a heartbeat to the dispatcher. If the worker wasn't already
+  // registered with the dispatcher, this will register the worker. The
+  // dispatcher will report which new tasks the worker should run, and which
+  // tasks it should delete. This is stored into `new_tasks` and
+  // `tasks_to_delete`.
+  Status WorkerHeartbeat(const std::string& worker_address,
+                         const std::vector<int64>& current_tasks,
+                         std::vector<TaskDef>& new_tasks,
+                         std::vector<int64>& tasks_to_delete);

   // Updates the dispatcher with information about the worker's state.
   Status WorkerUpdate(const std::string& worker_address,
@@ -4,16 +4,6 @@ package tensorflow.data;

 import "tensorflow/core/data/service/common.proto";

-message RegisterWorkerRequest {
-  // The address of the registering worker.
-  string worker_address = 1;
-}
-
-message RegisterWorkerResponse {
-  // Tasks to begin processing.
-  repeated TaskDef tasks = 2;
-}
-
 message TaskProgress {
   // The task that this message is about.
   int64 task_id = 1;
@@ -21,6 +11,16 @@ message TaskProgress {
   bool completed = 2;
 }

+message WorkerHeartbeatRequest {
+  string worker_address = 1;
+  repeated int64 current_tasks = 2;
+}
+
+message WorkerHeartbeatResponse {
+  repeated TaskDef new_tasks = 1;
+  repeated int64 tasks_to_delete = 2;
+}
+
 message WorkerUpdateRequest {
   string worker_address = 1;
   repeated TaskProgress updates = 2;
@@ -110,8 +110,8 @@ message GetWorkersResponse {
 }

 service DispatcherService {
-  // Registers a worker with the dispatcher.
-  rpc RegisterWorker(RegisterWorkerRequest) returns (RegisterWorkerResponse);
+  // Performs a periodic worker heartbeat.
+  rpc WorkerHeartbeat(WorkerHeartbeatRequest) returns (WorkerHeartbeatResponse);

   // Updates the dispatcher with information about the worker's state.
   rpc WorkerUpdate(WorkerUpdateRequest) returns (WorkerUpdateResponse);
@@ -85,7 +85,7 @@ Status CreateWorkerStub(const std::string& address, const std::string& protocol,

 DataServiceDispatcherImpl::DataServiceDispatcherImpl(
     const experimental::DispatcherConfig& config)
-    : config_(config) {
+    : config_(config), env_(Env::Default()) {
   if (config_.work_dir().empty()) {
     dataset_store_ = absl::make_unique<MemoryDatasetStore>();
   } else {
@@ -94,8 +94,19 @@ DataServiceDispatcherImpl::DataServiceDispatcherImpl(
   }
 }

+DataServiceDispatcherImpl::~DataServiceDispatcherImpl() {
+  {
+    mutex_lock l(mu_);
+    cancelled_ = true;
+    job_gc_thread_cv_.notify_all();
+  }
+  job_gc_thread_.reset();
+}
+
 Status DataServiceDispatcherImpl::Start() {
   mutex_lock l(mu_);
+  job_gc_thread_ = absl::WrapUnique(
+      env_->StartThread({}, "job-gc-thread", [&] { JobGcThread(); }));
   if (config_.work_dir().empty()) {
     if (config_.fault_tolerant_mode()) {
       return errors::InvalidArgument(
@@ -103,7 +114,7 @@ Status DataServiceDispatcherImpl::Start() {
     }
   } else {
     TF_RETURN_IF_ERROR(
-        Env::Default()->RecursivelyCreateDir(DatasetsDir(config_.work_dir())));
+        env_->RecursivelyCreateDir(DatasetsDir(config_.work_dir())));
   }
   if (!config_.fault_tolerant_mode()) {
     LOG(INFO) << "Running with fault_tolerant_mode=False. The dispatcher will "
@@ -111,12 +122,12 @@ Status DataServiceDispatcherImpl::Start() {
     return Status::OK();
   }
   journal_writer_ = absl::make_unique<FileJournalWriter>(
-      Env::Default(), JournalDir(config_.work_dir()));
+      env_, JournalDir(config_.work_dir()));
-  LOG(INFO) << "Restoring dispatcher state from journal in "
+  LOG(INFO) << "Attempting to restore dispatcher state from journal in "
             << JournalDir(config_.work_dir());
   Update update;
   bool end_of_journal = false;
-  FileJournalReader reader(Env::Default(), JournalDir(config_.work_dir()));
+  FileJournalReader reader(env_, JournalDir(config_.work_dir()));
   Status s = reader.Read(update, end_of_journal);
   if (errors::IsNotFound(s)) {
     LOG(INFO) << "No journal found. Starting dispatcher from new state.";
@@ -134,45 +145,38 @@ Status DataServiceDispatcherImpl::Start() {
   return Status::OK();
 }

-Status DataServiceDispatcherImpl::RegisterWorker(
-    const RegisterWorkerRequest* request, RegisterWorkerResponse* response) {
-  VLOG(3) << "Received register worker request";
+Status DataServiceDispatcherImpl::WorkerHeartbeat(
+    const WorkerHeartbeatRequest* request, WorkerHeartbeatResponse* response) {
+  VLOG(3) << "Received worker heartbeat request from worker "
+          << request->worker_address();
   mutex_lock l(mu_);
-  std::string worker_address = request->worker_address();
-  std::vector<std::shared_ptr<const Task>> tasks;
-  Status s = state_.TasksForWorker(worker_address, tasks);
-  if (errors::IsNotFound(s)) {
+  const std::string& worker_address = request->worker_address();
+  std::vector<std::shared_ptr<const Task>> correct_tasks;
+  Status s = state_.TasksForWorker(worker_address, correct_tasks);
+  if (!s.ok()) {
+    if (!errors::IsNotFound(s)) {
+      return s;
+    }
     Update update;
     update.mutable_register_worker()->set_worker_address(worker_address);
     TF_RETURN_IF_ERROR(Apply(update));
-  } else if (!s.ok()) {
-    return s;
+    TF_RETURN_IF_ERROR(CreateTasksForWorker(worker_address));
+    TF_RETURN_IF_ERROR(state_.TasksForWorker(worker_address, correct_tasks));
   }

-  absl::flat_hash_map<int64, std::shared_ptr<const Task>> tasks_by_job;
-  for (const auto& task : tasks) {
-    // Should never have multiple tasks on the same worker for the same job.
-    auto& task_for_job = tasks_by_job[task->job_id];
-    DCHECK(task_for_job == nullptr);
-    task_for_job = task;
-  }
+  absl::flat_hash_set<int64> current_tasks;
+  current_tasks.insert(request->current_tasks().cbegin(),
+                       request->current_tasks().cend());
+  absl::flat_hash_set<int64> correct_tasks_set;

-  std::vector<std::shared_ptr<const Job>> jobs = state_.ListJobs();
-  // Allocate tasks to the worker.
-  for (const auto& job : jobs) {
-    if (job->finished) {
+  for (const auto& task : correct_tasks) {
+    correct_tasks_set.insert(task->task_id);
+    if (current_tasks.contains(task->task_id)) {
       continue;
     }
-    std::shared_ptr<const Task> task;
-    auto it = tasks_by_job.find(job->job_id);
-    if (it != tasks_by_job.end()) {
-      task = it->second;
-    } else {
-      TF_RETURN_IF_ERROR(CreateTask(job, worker_address, task));
-    }
-    TaskDef* task_def = response->add_tasks();
+    TaskDef* task_def = response->add_new_tasks();
     std::shared_ptr<const Dataset> dataset;
-    TF_RETURN_IF_ERROR(state_.DatasetFromId(job->dataset_id, dataset));
+    TF_RETURN_IF_ERROR(state_.DatasetFromId(task->dataset_id, dataset));
     std::string dataset_key =
         DatasetKey(dataset->dataset_id, dataset->fingerprint);
     if (config_.work_dir().empty()) {
@@ -184,12 +188,18 @@ Status DataServiceDispatcherImpl::RegisterWorker(
           io::JoinPath(DatasetsDir(config_.work_dir()), dataset_key);
       task_def->set_path(path);
     }
-    task_def->set_dataset_id(job->dataset_id);
-    task_def->set_job_id(job->job_id);
+    task_def->set_dataset_id(task->dataset_id);
+    task_def->set_job_id(task->job_id);
     task_def->set_task_id(task->task_id);
   }
+  for (int64 current_task : current_tasks) {
+    if (!correct_tasks_set.contains(current_task)) {
+      response->add_tasks_to_delete(current_task);
+    }
+  }

-  VLOG(1) << "Registered worker at address " << request->worker_address();
+  VLOG(1) << "Finished worker heartbeat for worker at address "
+          << request->worker_address();
   return Status::OK();
 }

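The WorkerHeartbeat handler above amounts to a small set reconciliation between the tasks the worker reports and the tasks the dispatcher expects that worker to run. A standalone Python sketch of that logic (not TensorFlow code; task ids are plain integers here):

def reconcile(current_tasks, correct_tasks):
  """Illustrative sketch of the dispatcher-side heartbeat reconciliation.

  current_tasks: task ids the worker reported it is currently processing.
  correct_tasks: task ids the dispatcher believes the worker should process.
  Returns (new_tasks, tasks_to_delete).
  """
  current = set(current_tasks)
  correct = set(correct_tasks)
  new_tasks = sorted(correct - current)        # sent back as TaskDefs
  tasks_to_delete = sorted(current - correct)  # e.g. tasks of garbage-collected jobs
  return new_tasks, tasks_to_delete

# Example: worker runs {1, 2}; the dispatcher thinks it should run {2, 3}.
assert reconcile([1, 2], [2, 3]) == ([3], [1])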
@@ -346,7 +356,7 @@ Status DataServiceDispatcherImpl::ReleaseJobClient(
   ReleaseJobClientUpdate* release_job_client =
       update.mutable_release_job_client();
   release_job_client->set_job_client_id(job_client_id);
-  release_job_client->set_time_micros(Env::Default()->NowMicros());
+  release_job_client->set_time_micros(env_->NowMicros());
   TF_RETURN_IF_ERROR(Apply(update));
   return Status::OK();
 }
@@ -412,6 +422,19 @@ Status DataServiceDispatcherImpl::CreateJob(
   return Status::OK();
 }

+Status DataServiceDispatcherImpl::CreateTasksForWorker(
+    const std::string& worker_address) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+  std::vector<std::shared_ptr<const Job>> jobs = state_.ListJobs();
+  for (const auto& job : jobs) {
+    if (job->finished) {
+      continue;
+    }
+    std::shared_ptr<const Task> task;
+    TF_RETURN_IF_ERROR(CreateTask(job, worker_address, task));
+  }
+  return Status::OK();
+}
+
 Status DataServiceDispatcherImpl::AcquireJobClientId(
     const std::shared_ptr<const Job>& job, int64& job_client_id)
     EXCLUSIVE_LOCKS_REQUIRED(mu_) {
@@ -576,5 +599,51 @@ Status DataServiceDispatcherImpl::Apply(const Update& update)
   return state_.Apply(update);
 }

+void DataServiceDispatcherImpl::JobGcThread() {
+  int64 next_check_micros = 0;
+  while (true) {
+    mutex_lock l(mu_);
+    while (!cancelled_ && env_->NowMicros() < next_check_micros) {
+      int64 remaining_micros = next_check_micros - env_->NowMicros();
+      job_gc_thread_cv_.wait_for(l,
+                                 std::chrono::microseconds(remaining_micros));
+    }
+    if (cancelled_) {
+      return;
+    }
+    Status s = GcOldJobs();
+    if (!s.ok()) {
+      LOG(WARNING) << "Error garbage collecting old jobs: " << s;
+    }
+    next_check_micros =
+        env_->NowMicros() + (config_.job_gc_check_interval_ms() * 1000);
+  }
+}
+
+Status DataServiceDispatcherImpl::GcOldJobs() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+  std::vector<std::shared_ptr<const Job>> jobs = state_.ListJobs();
+  int64 now = env_->NowMicros();
+  for (const auto& job : jobs) {
+    if (job->finished || job->num_clients > 0 ||
+        job->last_client_released_micros < 0 ||
+        now < job->last_client_released_micros +
+                  (config_.job_gc_timeout_ms() * 1000)) {
+      continue;
+    }
+    std::vector<std::shared_ptr<const Task>> tasks;
+    TF_RETURN_IF_ERROR(state_.TasksForJob(job->job_id, tasks));
+    for (const auto& task : tasks) {
+      if (task->finished) {
+        continue;
+      }
+      Update update;
+      update.mutable_finish_task()->set_task_id(task->task_id);
+      TF_RETURN_IF_ERROR(state_.Apply(update));
+    }
+    DCHECK(job->finished);
+  }
+  return Status::OK();
+}
+
 } // namespace data
 } // namespace tensorflow
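GcOldJobs marks a job's remaining tasks finished only once the job is genuinely unused. Its eligibility test can be summarized with this hypothetical helper, written to mirror the C++ condition above (times in microseconds, `job_gc_timeout_ms` in milliseconds; `job` is any object with the fields used above):

def job_is_gc_candidate(job, now_micros, job_gc_timeout_ms):
  """Sketch of the eligibility check in GcOldJobs (not TensorFlow code)."""
  if job.finished:                          # nothing left to clean up
    return False
  if job.num_clients > 0:                   # still has active consumers
    return False
  if job.last_client_released_micros < 0:   # never had a client released
    return False
  expiry = job.last_client_released_micros + job_gc_timeout_ms * 1000
  return now_micros >= expiry               # unused for longer than the timeout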
@@ -48,6 +48,8 @@ class DataServiceDispatcherImpl {
   explicit DataServiceDispatcherImpl(
       const experimental::DispatcherConfig& config);

+  ~DataServiceDispatcherImpl();
+
   // Starts the dispatcher. If there is a journal, this will read from the
   // journal to restore the dispatcher's state.
   Status Start();
@@ -55,8 +57,8 @@ class DataServiceDispatcherImpl {
   // See dispatcher.proto for API documentation.

   /// Worker-facing API.
-  Status RegisterWorker(const RegisterWorkerRequest* request,
-                        RegisterWorkerResponse* response);
+  Status WorkerHeartbeat(const WorkerHeartbeatRequest* request,
+                         WorkerHeartbeatResponse* response);
   Status WorkerUpdate(const WorkerUpdateRequest* request,
                       WorkerUpdateResponse* response);
   Status GetDatasetDef(const GetDatasetDefRequest* request,
@@ -92,6 +94,8 @@ class DataServiceDispatcherImpl {
       absl::optional<DispatcherState::NamedJobKey> named_job_key,
       std::shared_ptr<const DispatcherState::Job>& job)
       EXCLUSIVE_LOCKS_REQUIRED(mu_);
+  // Creates tasks for the specified worker, one task for every unfinished job.
+  Status CreateTasksForWorker(const std::string& worker_address);
   // Acquires a job client id to read from the given job and sets
   // `job_client_id`.
   Status AcquireJobClientId(
@@ -128,12 +132,16 @@ class DataServiceDispatcherImpl {
   // used when recovering state when the dispatcher starts.
   Status ApplyWithoutJournaling(const Update& update)
       EXCLUSIVE_LOCKS_REQUIRED(mu_);
+  // A thread which periodically checks for jobs to clean up.
+  void JobGcThread();
+  // Scans for old jobs and marks them as finished.
+  Status GcOldJobs() EXCLUSIVE_LOCKS_REQUIRED(mu_);

   const experimental::DispatcherConfig& config_;
+  Env* env_;

   mutex mu_;
-  int64 next_task_id_ TF_GUARDED_BY(mu_) = 0;
+  bool cancelled_ TF_GUARDED_BY(mu_) = false;

   // Cached worker stubs for communicating with workers.
   absl::flat_hash_map<std::string, std::unique_ptr<WorkerService::Stub>>
@@ -144,6 +152,9 @@ class DataServiceDispatcherImpl {
   absl::optional<std::unique_ptr<JournalWriter>> journal_writer_
       TF_GUARDED_BY(mu_);
   DispatcherState state_ TF_GUARDED_BY(mu_);
+  // Condition variable for waking up the job gc thread.
+  condition_variable job_gc_thread_cv_;
+  std::unique_ptr<Thread> job_gc_thread_;

   TF_DISALLOW_COPY_AND_ASSIGN(DataServiceDispatcherImpl);
 };
@@ -72,7 +72,8 @@ void DispatcherState::RegisterWorker(
   std::string address = register_worker.worker_address();
   DCHECK(!workers_.contains(address));
   workers_[address] = std::make_shared<Worker>(address);
-  tasks_by_worker_[address] = std::vector<std::shared_ptr<Task>>();
+  tasks_by_worker_[address] =
+      absl::flat_hash_map<int64, std::shared_ptr<Task>>();
 }

 void DispatcherState::CreateJob(const CreateJobUpdate& create_job) {
@@ -126,7 +127,7 @@ void DispatcherState::CreateTask(const CreateTaskUpdate& create_task) {
                                      create_task.dataset_id(),
                                      create_task.worker_address());
   tasks_by_job_[create_task.job_id()].push_back(task);
-  tasks_by_worker_[create_task.worker_address()].push_back(task);
+  tasks_by_worker_[create_task.worker_address()][task->task_id] = task;
   next_available_task_id_ = std::max(next_available_task_id_, task_id + 1);
 }

@@ -136,6 +137,7 @@ void DispatcherState::FinishTask(const FinishTaskUpdate& finish_task) {
   auto& task = tasks_[task_id];
   DCHECK(task != nullptr);
   task->finished = true;
+  tasks_by_worker_[task->worker_address].erase(task->task_id);
   bool all_finished = true;
   for (const auto& task_for_job : tasks_by_job_[task->job_id]) {
     if (!task_for_job->finished) {
@@ -269,10 +271,11 @@ Status DispatcherState::TasksForWorker(
   if (it == tasks_by_worker_.end()) {
     return errors::NotFound("Worker ", worker_address, " not found");
   }
-  std::vector<std::shared_ptr<Task>> worker_tasks = it->second;
+  const absl::flat_hash_map<int64, std::shared_ptr<Task>>& worker_tasks =
+      it->second;
   tasks.reserve(worker_tasks.size());
   for (const auto& task : worker_tasks) {
-    tasks.push_back(task);
+    tasks.push_back(task.second);
   }
   return Status::OK();
 }
@@ -179,7 +179,7 @@ class DispatcherState {
   void CreateTask(const CreateTaskUpdate& create_task);
   void FinishTask(const FinishTaskUpdate& finish_task);

-  int64 next_available_dataset_id_ = 0;
+  int64 next_available_dataset_id_ = 1000;
   // Registered datasets, keyed by dataset ids.
   absl::flat_hash_map<int64, std::shared_ptr<Dataset>> datasets_by_id_;
   // Registered datasets, keyed by dataset fingerprints.
@@ -189,24 +189,26 @@ class DispatcherState {
   // Registered workers, keyed by address.
   absl::flat_hash_map<std::string, std::shared_ptr<Worker>> workers_;

-  int64 next_available_job_id_ = 0;
+  int64 next_available_job_id_ = 2000;
   // Jobs, keyed by job ids.
   absl::flat_hash_map<int64, std::shared_ptr<Job>> jobs_;
   // Named jobs, keyed by their names and indices. Not all jobs have names, so
   // this is a subset of the jobs stored in `jobs_`.
   absl::flat_hash_map<NamedJobKey, std::shared_ptr<Job>> named_jobs_;

-  int64 next_available_job_client_id_ = 0;
+  int64 next_available_job_client_id_ = 3000;
   // Mapping from client ids to the jobs they are associated with.
   absl::flat_hash_map<int64, std::shared_ptr<Job>> jobs_for_client_ids_;

-  int64 next_available_task_id_ = 0;
+  int64 next_available_task_id_ = 4000;
   // Tasks, keyed by task ids.
   absl::flat_hash_map<int64, std::shared_ptr<Task>> tasks_;
   // Tasks, keyed by job ids.
   absl::flat_hash_map<int64, std::vector<std::shared_ptr<Task>>> tasks_by_job_;
-  // Tasks, keyed by worker addresses.
-  absl::flat_hash_map<std::string, std::vector<std::shared_ptr<Task>>>
+  // Tasks, keyed by worker addresses. The values are a map from task id to
+  // task.
+  absl::flat_hash_map<std::string,
+                      absl::flat_hash_map<int64, std::shared_ptr<Task>>>
       tasks_by_worker_;
 };
@@ -125,9 +125,9 @@ Status FinishTask(int64 task_id, DispatcherState& state) {
 } // namespace

 TEST(DispatcherState, RegisterDataset) {
-  int64 id = 10;
   uint64 fingerprint = 20;
   DispatcherState state;
+  int64 id = state.NextAvailableDatasetId();
   TF_EXPECT_OK(RegisterDataset(id, fingerprint, state));
   EXPECT_EQ(state.NextAvailableDatasetId(), id + 1);

@@ -210,9 +210,9 @@ TEST(DispatcherState, UnknownUpdate) {
 }

 TEST(DispatcherState, AnonymousJob) {
-  int64 job_id = 3;
   int64 dataset_id = 10;
   DispatcherState state;
+  int64 job_id = state.NextAvailableJobId();
   TF_EXPECT_OK(RegisterDataset(dataset_id, state));
   TF_EXPECT_OK(CreateAnonymousJob(job_id, dataset_id, state));
   std::shared_ptr<const Job> job;
@@ -227,9 +227,9 @@ TEST(DispatcherState, AnonymousJob) {
 }

 TEST(DispatcherState, NamedJob) {
-  int64 job_id = 3;
   int64 dataset_id = 10;
   DispatcherState state;
+  int64 job_id = state.NextAvailableJobId();
   TF_EXPECT_OK(RegisterDataset(dataset_id, state));
   NamedJobKey named_job_key("test", 1);
   TF_EXPECT_OK(CreateNamedJob(job_id, dataset_id, named_job_key, state));
@@ -244,9 +244,9 @@ TEST(DispatcherState, NamedJob) {
 TEST(DispatcherState, CreateTask) {
   int64 job_id = 3;
   int64 dataset_id = 10;
-  int64 task_id = 8;
   std::string worker_address = "test_worker_address";
   DispatcherState state;
+  int64 task_id = state.NextAvailableTaskId();
   TF_EXPECT_OK(RegisterDataset(dataset_id, state));
   TF_EXPECT_OK(CreateAnonymousJob(job_id, dataset_id, state));
   TF_EXPECT_OK(CreateTask(task_id, job_id, dataset_id, worker_address, state));
@@ -40,7 +40,7 @@ Status GrpcDispatcherImpl::Start() { return impl_.Start(); }
                         method##Response* response) { \
     return ToGrpcStatus(impl_.method(request, response)); \
   }
-HANDLER(RegisterWorker);
+HANDLER(WorkerHeartbeat);
 HANDLER(WorkerUpdate);
 HANDLER(GetDatasetDef);
 HANDLER(GetOrRegisterDataset);
@@ -39,7 +39,7 @@ class GrpcDispatcherImpl : public DispatcherService::Service {
   ::grpc::Status method(::grpc::ServerContext* context, \
                         const method##Request* request, \
                         method##Response* response) override;
-  HANDLER(RegisterWorker);
+  HANDLER(WorkerHeartbeat);
   HANDLER(WorkerUpdate);
   HANDLER(GetDatasetDef);
   HANDLER(GetOrRegisterDataset);
@@ -43,6 +43,7 @@ Status GrpcWorkerImpl::Start(const std::string& worker_address) {
   }
 HANDLER(ProcessTask);
 HANDLER(GetElement);
+HANDLER(GetWorkerTasks);
 #undef HANDLER

 } // namespace data
@@ -41,6 +41,7 @@ class GrpcWorkerImpl : public WorkerService::Service {
                         method##Response* response) override;
   HANDLER(ProcessTask);
   HANDLER(GetElement);
+  HANDLER(GetWorkerTasks);
 #undef HANDLER

  private:
@@ -138,6 +138,18 @@ Status WorkerGrpcDataServer::StartServiceInternal() {
   return Status::OK();
 }

+Status WorkerGrpcDataServer::NumTasks(int* num_tasks) {
+  GetWorkerTasksRequest req;
+  GetWorkerTasksResponse resp;
+  ::grpc::ServerContext ctx;
+  ::grpc::Status s = service_->GetWorkerTasks(&ctx, &req, &resp);
+  if (!s.ok()) {
+    return grpc_util::WrapError("Failed to get tasks", s);
+  }
+  *num_tasks = resp.tasks_size();
+  return Status::OK();
+}
+
 Status NewDispatchServer(const experimental::DispatcherConfig& config,
                          std::unique_ptr<DispatchGrpcDataServer>& out_server) {
   out_server = absl::make_unique<DispatchGrpcDataServer>(config);
@@ -98,6 +98,9 @@ class WorkerGrpcDataServer : public GrpcDataServerBase {
   explicit WorkerGrpcDataServer(const experimental::WorkerConfig& config);
   ~WorkerGrpcDataServer() override;

+  // Returns the number of tasks currently being executed by the worker.
+  Status NumTasks(int* num_tasks);
+
  protected:
   void AddDataServiceToBuilder(::grpc::ServerBuilder& builder) override;
   Status StartServiceInternal() override;
@@ -23,10 +23,20 @@ message GetElementResponse {
   bool end_of_sequence = 2;
 }

+// Named GetWorkerTasks to avoid conflicting with GetTasks in dispatcher.proto
+message GetWorkerTasksRequest {}
+
+message GetWorkerTasksResponse {
+  repeated TaskInfo tasks = 1;
+}
+
 service WorkerService {
   // Processes an task for a dataset, making elements available to clients.
   rpc ProcessTask(ProcessTaskRequest) returns (ProcessTaskResponse);

   // Gets the next dataset element.
   rpc GetElement(GetElementRequest) returns (GetElementResponse);

+  // Gets the tasks currently being executed by the worker.
+  rpc GetWorkerTasks(GetWorkerTasksRequest) returns (GetWorkerTasksResponse);
 }
@@ -56,7 +56,8 @@ DataServiceWorkerImpl::DataServiceWorkerImpl(
 DataServiceWorkerImpl::~DataServiceWorkerImpl() {
   mutex_lock l(mu_);
   cancelled_ = true;
-  background_cv_.notify_one();
+  task_completion_cv_.notify_one();
+  heartbeat_cv_.notify_one();
 }

 Status DataServiceWorkerImpl::Start(const std::string& worker_address) {
@@ -67,18 +68,20 @@ Status DataServiceWorkerImpl::Start(const std::string& worker_address) {
       config_.dispatcher_address(), config_.protocol());
   TF_RETURN_IF_ERROR(dispatcher_->Initialize());

-  Status s = Register();
+  Status s = Heartbeat();
   while (!s.ok()) {
     LOG(WARNING) << "Failed to register with dispatcher at "
                  << config_.dispatcher_address() << ": " << s;
     Env::Default()->SleepForMicroseconds(kRetryIntervalMicros);
-    s = Register();
+    s = Heartbeat();
   }
-  Thread* thread = Env::Default()->StartThread(
-      {}, "data-service-worker-background", [this]() { BackgroundThread(); });
   LOG(INFO) << "Worker registered with dispatcher running at "
             << config_.dispatcher_address();
-  background_thread_.reset(thread);
+  task_completion_thread_ = absl::WrapUnique(
+      Env::Default()->StartThread({}, "data-service-worker-task-completion",
+                                  [this]() { TaskCompletionThread(); }));
+  heartbeat_thread_ = absl::WrapUnique(Env::Default()->StartThread(
+      {}, "data-service-worker-heartbeat", [this]() { HeartbeatThread(); }));
   mutex_lock l(mu_);
   registered_ = true;
   return Status::OK();
@@ -96,8 +99,9 @@ Status DataServiceWorkerImpl::ProcessTaskInternal(const TaskDef& task_def)
     EXCLUSIVE_LOCKS_REQUIRED(mu_) {
   std::unique_ptr<Task>& task = tasks_[task_def.task_id()];
   if (task) {
-    return errors::AlreadyExists("A task with id ", task_def.task_id(),
-                                 " already exists.");
+    VLOG(1) << "Received request to process already-processed task "
+            << task->task_def.task_id();
+    return Status::OK();
   }
   task = absl::make_unique<Task>(task_def);
   VLOG(3) << "Began processing for task " << task_def.task_id();
@@ -156,24 +160,17 @@ Status DataServiceWorkerImpl::GetElement(const GetElementRequest* request,
     }
     auto it = tasks_.find(request->task_id());
     if (it == tasks_.end()) {
-      return errors::NotFound("DataServiceWorkerImpl::GetElement failed. ",
-                              "Task id ", request->task_id(), " not found");
-    }
-    auto& task = it->second;
-    TF_RETURN_IF_ERROR(EnsureTaskInitialized(*task));
-    std::unique_ptr<standalone::Iterator>& iter = task->iterator;
-    if (iter == nullptr) {
-      VLOG(3) << "Task " << request->task_id() << " is already finished";
       response->set_end_of_sequence(true);
       return Status::OK();
     }
-    TF_RETURN_IF_ERROR(iter->GetNext(&outputs, &end_of_sequence));
+    auto& task = it->second;
+    TF_RETURN_IF_ERROR(EnsureTaskInitialized(*task));
+    TF_RETURN_IF_ERROR(task->iterator->GetNext(&outputs, &end_of_sequence));
     if (end_of_sequence) {
       VLOG(3) << "Reached end_of_sequence for task " << request->task_id();
-      // Release iterator memory and leave a null entry as a tombstone.
-      iter.reset();
+      tasks_.erase(request->task_id());
       pending_completed_tasks_.insert(request->task_id());
-      background_cv_.notify_one();
+      task_completion_cv_.notify_one();
     }
   }

@@ -212,27 +209,28 @@ Status DataServiceWorkerImpl::GetElement(const GetElementRequest* request,
   return Status::OK();
 }

-Status DataServiceWorkerImpl::Register() LOCKS_EXCLUDED(mu_) {
-  VLOG(3) << "Registering with dispatcher at " << config_.dispatcher_address();
-  std::vector<TaskDef> tasks;
-  TF_RETURN_IF_ERROR(dispatcher_->RegisterWorker(worker_address_, tasks));
-  for (const TaskDef& task : tasks) {
-    mutex_lock l(mu_);
-    TF_RETURN_IF_ERROR(ProcessTaskInternal(task));
+Status DataServiceWorkerImpl::GetWorkerTasks(
+    const GetWorkerTasksRequest* request, GetWorkerTasksResponse* response) {
+  mutex_lock l(mu_);
+  for (const auto& it : tasks_) {
+    Task* task = it.second.get();
+    TaskInfo* task_info = response->add_tasks();
+    task_info->set_worker_address(worker_address_);
+    task_info->set_task_id(task->task_def.task_id());
+    task_info->set_job_id(task->task_def.job_id());
   }
-  VLOG(3) << "Registered worker with address " << worker_address_;
   return Status::OK();
 }

-void DataServiceWorkerImpl::BackgroundThread() LOCKS_EXCLUDED(mu_) {
+void DataServiceWorkerImpl::TaskCompletionThread() LOCKS_EXCLUDED(mu_) {
   while (true) {
     {
       mutex_lock l(mu_);
       while (!cancelled_ && pending_completed_tasks_.empty()) {
-        background_cv_.wait(l);
+        task_completion_cv_.wait(l);
       }
       if (cancelled_) {
-        VLOG(3) << "Background thread shutting down";
+        VLOG(3) << "Task completion thread shutting down";
         return;
       }
     }
@@ -241,7 +239,7 @@ void DataServiceWorkerImpl::BackgroundThread() LOCKS_EXCLUDED(mu_) {
       LOG(WARNING) << "Failed to send task updates to dispatcher: " << s;
       mutex_lock l(mu_);
       if (!cancelled_) {
-        background_cv_.wait_for(
+        task_completion_cv_.wait_for(
             l, std::chrono::microseconds(kRetryIntervalMicros));
       }
     }
@@ -271,5 +269,62 @@ Status DataServiceWorkerImpl::SendTaskUpdates() LOCKS_EXCLUDED(mu_) {
   return Status::OK();
 }

+void DataServiceWorkerImpl::HeartbeatThread() LOCKS_EXCLUDED(mu_) {
+  while (true) {
+    int64 next_heartbeat_micros =
+        Env::Default()->NowMicros() + (config_.heartbeat_interval_ms() * 1000);
+    {
+      mutex_lock l(mu_);
+      while (!cancelled_ &&
+             Env::Default()->NowMicros() < next_heartbeat_micros) {
+        int64 time_to_wait_micros =
+            next_heartbeat_micros - Env::Default()->NowMicros();
+        heartbeat_cv_.wait_for(l,
+                               std::chrono::microseconds(time_to_wait_micros));
+      }
+      if (cancelled_) {
+        VLOG(3) << "Heartbeat thread shutting down";
+        return;
+      }
+      if (!registered_) {
+        VLOG(1) << "Not performing heartbeat; worker is not yet registered";
+        continue;
+      }
+    }
+    Status s = Heartbeat();
+    if (!s.ok()) {
+      LOG(WARNING) << "Failed to send heartbeat to dispatcher: " << s;
+    }
+  }
+}
+
+Status DataServiceWorkerImpl::Heartbeat() LOCKS_EXCLUDED(mu_) {
+  std::vector<int64> current_tasks;
+  {
+    mutex_lock l(mu_);
+    for (const auto& task : tasks_) {
+      current_tasks.push_back(task.first);
+    }
+  }
+  std::vector<TaskDef> new_tasks;
+  std::vector<int64> tasks_to_delete;
+  TF_RETURN_IF_ERROR(dispatcher_->WorkerHeartbeat(
+      worker_address_, current_tasks, new_tasks, tasks_to_delete));
+  mutex_lock l(mu_);
+  for (const auto& task : new_tasks) {
+    Status s = ProcessTaskInternal(task);
+    if (!s.ok() && !errors::IsAlreadyExists(s)) {
+      LOG(WARNING) << "Failed to start processing task " << task.task_id()
+                   << ": " << s;
+    }
+  }
+  for (int64 task_id : tasks_to_delete) {
+    VLOG(3) << "Deleting task " << task_id
+            << " at the request of the dispatcher";
+    tasks_.erase(task_id);
+  }
+  return Status::OK();
+}
+
 } // namespace data
 } // namespace tensorflow
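On the worker side, Heartbeat() reduces to reporting the ids of locally running tasks and then applying the dispatcher's reply. A hypothetical Python sketch of that bookkeeping, with `dispatcher.worker_heartbeat` standing in for the RPC (neither the helper nor the stub is TensorFlow code):

def worker_heartbeat(tasks, dispatcher):
  """Sketch of the worker-side heartbeat bookkeeping.

  tasks: dict mapping task id -> locally running task state.
  dispatcher: stub whose worker_heartbeat(current_task_ids) returns
    (new_tasks, tasks_to_delete).
  """
  new_tasks, tasks_to_delete = dispatcher.worker_heartbeat(list(tasks))
  for task_def in new_tasks:
    # Receiving a task twice is harmless; the real code treats it as a no-op.
    tasks.setdefault(task_def.task_id, task_def)
  for task_id in tasks_to_delete:
    tasks.pop(task_id, None)
  return tasks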
@@ -50,6 +50,8 @@ class DataServiceWorkerImpl {
   /// Client-facing API.
   Status GetElement(const GetElementRequest* request,
                     GetElementResponse* response);
+  Status GetWorkerTasks(const GetWorkerTasksRequest* request,
+                        GetWorkerTasksResponse* response);

  private:
   struct Task {
@@ -64,16 +66,17 @@ class DataServiceWorkerImpl {
     std::unique_ptr<standalone::Iterator> iterator;
   };

-  // Registers the worker with the dispatcher.
-  Status Register() LOCKS_EXCLUDED(mu_);
   // Sends task status to the dispatcher and checks for dispatcher commands.
   Status SendTaskUpdates() LOCKS_EXCLUDED(mu_);
   // Creates an iterator to process a task.
   Status ProcessTaskInternal(const TaskDef& task) EXCLUSIVE_LOCKS_REQUIRED(mu_);
   Status EnsureTaskInitialized(Task& task);
-  // A thread for doing async background processing not associated with a
-  // specific RPC, such as reporting finished tasks.
-  void BackgroundThread() LOCKS_EXCLUDED(mu_);
+  // A thread for notifying the dispatcher when tasks complete.
+  void TaskCompletionThread() LOCKS_EXCLUDED(mu_);
+  // A thread for doing periodic heartbeats to the dispatcher.
+  void HeartbeatThread() LOCKS_EXCLUDED(mu_);
+  // Performs a heartbeat to the dispatcher.
+  Status Heartbeat() LOCKS_EXCLUDED(mu_);

   const experimental::WorkerConfig config_;
   // The worker's own address.
@@ -88,9 +91,12 @@ class DataServiceWorkerImpl {
   bool cancelled_ TF_GUARDED_BY(mu_) = false;
   // Whether the worker has registered with the dispatcher yet.
   bool registered_ TF_GUARDED_BY(mu_) = false;
-  // Condition variable for notifying the background thread.
-  condition_variable background_cv_ TF_GUARDED_BY(mu_);
-  std::unique_ptr<Thread> background_thread_;
+  // A thread for notifying the dispatcher when tasks complete.
+  std::unique_ptr<Thread> task_completion_thread_;
+  condition_variable task_completion_cv_ TF_GUARDED_BY(mu_);
+  // A thread for performing regular heartbeats to the dispatcher.
+  std::unique_ptr<Thread> heartbeat_thread_;
+  condition_variable heartbeat_cv_ TF_GUARDED_BY(mu_);

   TF_DISALLOW_COPY_AND_ASSIGN(DataServiceWorkerImpl);
 };
@@ -200,6 +200,9 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
       for (auto& worker_thread : worker_threads_) {
         worker_thread.reset();
       }
+
+      VLOG(1) << "Destroyed data service dataset iterator for job id "
+              << job_client_id_;
     }

     void CancelThreads() TF_LOCKS_EXCLUDED(mu_) {
@@ -15,6 +15,11 @@ message DispatcherConfig {
   // Whether to run in fault tolerant mode, where dispatcher state is saved
   // across restarts. Requires that `work_dir` is nonempty.
   bool fault_tolerant_mode = 4;
+  // How often the dispatcher should scan through to delete old and unused jobs.
+  int64 job_gc_check_interval_ms = 5;
+  // How long a job needs to be unused before it becomes a candidate for garbage
+  // collection.
+  int64 job_gc_timeout_ms = 6;
 }

 // Configuration for a tf.data service WorkerServer.
@@ -30,4 +35,6 @@ message WorkerConfig {
   // will be replaced with the worker's bound port. This is useful when the port
   // is set to `0`.
   string worker_address = 4;
+  // How often the worker should heartbeat to the master.
+  int64 heartbeat_interval_ms = 5;
 }
@@ -29,9 +29,10 @@ from tensorflow.python.util.tf_export import tf_export

 @tf_export("data.experimental.service.DispatcherConfig")
 class DispatcherConfig(
-    collections.namedtuple(
-        "DispatcherConfig",
-        ["port", "protocol", "work_dir", "fault_tolerant_mode"])):
+    collections.namedtuple("DispatcherConfig", [
+        "port", "protocol", "work_dir", "fault_tolerant_mode",
+        "job_gc_check_interval_ms", "job_gc_timeout_ms"
+    ])):
   """Configuration class for tf.data service dispatchers.

   Fields:
@@ -47,15 +48,34 @@ class DispatcherConfig(
       registered datasets and created jobs, is synchronously written to the
       journal before responding to RPCs. If `True`, `work_dir` must also be
       specified.
+    job_gc_check_interval_ms: How often the dispatcher should scan through to
+      delete old and unused jobs, in milliseconds. If not set, the runtime will
+      select a reasonable default. A higher value will reduce load on the
+      dispatcher, while a lower value will reduce the time it takes for the
+      dispatcher to garbage collect expired jobs.
+    job_gc_timeout_ms: How long a job needs to be unused before it becomes a
+      candidate for garbage collection, in milliseconds. If not set, the runtime
+      will select a reasonable default. A higher value will cause jobs to stay
+      around longer with no consumers. This is useful if there is a large gap in
+      time between when consumers read from the job. A lower value will reduce
+      the time it takes to reclaim the resources from expired jobs.
   """

   def __new__(cls,
               port=0,
               protocol="grpc",
               work_dir=None,
-              fault_tolerant_mode=False):
-    return super(DispatcherConfig, cls).__new__(cls, port, protocol, work_dir,
-                                                fault_tolerant_mode)
+              fault_tolerant_mode=False,
+              job_gc_check_interval_ms=None,
+              job_gc_timeout_ms=None):
+    if job_gc_check_interval_ms is None:
+      job_gc_check_interval_ms = 10 * 60 * 1000  # 10 minutes.
+    if job_gc_timeout_ms is None:
+      job_gc_timeout_ms = 5 * 60 * 1000  # 5 minutes.
+    return super(DispatcherConfig,
+                 cls).__new__(cls, port, protocol, work_dir,
+                              fault_tolerant_mode, job_gc_check_interval_ms,
+                              job_gc_timeout_ms)


 @tf_export("data.experimental.service.DispatchServer", v1=[])
@@ -116,7 +136,9 @@ class DispatchServer(object):
         port=config.port,
         protocol=config.protocol,
         work_dir=config.work_dir,
-        fault_tolerant_mode=config.fault_tolerant_mode)
+        fault_tolerant_mode=config.fault_tolerant_mode,
+        job_gc_check_interval_ms=config.job_gc_check_interval_ms,
+        job_gc_timeout_ms=config.job_gc_timeout_ms)
     self._server = _pywrap_server_lib.TF_DATA_NewDispatchServer(
         config_proto.SerializeToString())
     if start:
@ -193,9 +215,10 @@ class DispatchServer(object):
|
|||||||
|
|
||||||
@tf_export("data.experimental.service.WorkerConfig")
|
@tf_export("data.experimental.service.WorkerConfig")
|
||||||
class WorkerConfig(
|
class WorkerConfig(
|
||||||
collections.namedtuple(
|
collections.namedtuple("WorkerConfig", [
|
||||||
"WorkerConfig",
|
"dispatcher_address", "worker_address", "port", "protocol",
|
||||||
["dispatcher_address", "worker_address", "port", "protocol"])):
|
"heartbeat_interval_ms"
|
||||||
|
])):
|
||||||
"""Configuration class for tf.data service dispatchers.
|
"""Configuration class for tf.data service dispatchers.
|
||||||
|
|
||||||
Fields:
|
Fields:
|
||||||
@ -205,19 +228,29 @@ class WorkerConfig(
|
|||||||
connect to this worker.
|
connect to this worker.
|
||||||
port: Specifies the port to bind to. A value of 0 indicates that the worker
|
port: Specifies the port to bind to. A value of 0 indicates that the worker
|
||||||
can bind to any available port.
|
can bind to any available port.
|
||||||
protocol: (Optional.) Specifies the protocol to be used by the server.
|
protocol: Specifies the protocol to be used by the server.
|
||||||
Acceptable values include `"grpc" and "grpc+local"`.
|
Acceptable values include `"grpc" and "grpc+local"`.
|
||||||
|
heartbeat_interval_ms: How often the worker should heartbeat to the
|
||||||
|
dispatcher, in milliseconds. If not set, the runtime will select a
|
||||||
|
reasonable default. A higher value will reduce the load on the dispatcher,
|
||||||
|
while a lower value will reduce the time it takes to reclaim resources
|
||||||
|
from finished jobs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __new__(cls,
|
def __new__(cls,
|
||||||
dispatcher_address,
|
dispatcher_address,
|
||||||
worker_address=None,
|
worker_address=None,
|
||||||
port=0,
|
port=0,
|
||||||
protocol="grpc"):
|
protocol="grpc",
|
||||||
worker_address = ("localhost:%port%"
|
heartbeat_interval_ms=None):
|
||||||
if worker_address is None else worker_address)
|
if worker_address is None:
|
||||||
return super(WorkerConfig, cls).__new__(cls, dispatcher_address,
|
worker_address = "localhost:%port%"
|
||||||
worker_address, port, protocol)
|
if heartbeat_interval_ms is None:
|
||||||
|
heartbeat_interval_ms = 30 * 1000 # 30 seconds
|
||||||
|
|
||||||
|
return super(WorkerConfig,
|
||||||
|
cls).__new__(cls, dispatcher_address, worker_address, port,
|
||||||
|
protocol, heartbeat_interval_ms)
|
||||||
|
|
||||||
|
|
||||||
@tf_export("data.experimental.service.WorkerServer", v1=[])
|
@tf_export("data.experimental.service.WorkerServer", v1=[])
|
||||||
@@ -264,7 +297,8 @@ class WorkerServer(object):
         dispatcher_address=config.dispatcher_address,
         worker_address=config.worker_address,
         port=config.port,
-        protocol=config.protocol)
+        protocol=config.protocol,
+        heartbeat_interval_ms=config.heartbeat_interval_ms)
     self._server = _pywrap_server_lib.TF_DATA_NewWorkerServer(
         config_proto.SerializeToString())
     if start:
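
With the field plumbed into the worker's config proto, a worker can be started with a faster heartbeat so that new and deleted tasks propagate from the dispatcher sooner. A sketch, assuming a running `DispatchServer` named `dispatcher` as in the earlier example; the 1-second interval is illustrative:

    worker = server_lib.WorkerServer(
        server_lib.WorkerConfig(
            dispatcher_address=dispatcher.target.split("://")[1],  # strip "grpc://"
            heartbeat_interval_ms=1000))  # heartbeat once per second
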
@@ -317,3 +351,7 @@ class WorkerServer(object):
     The returned string will be in the form address:port, e.g. "localhost:1000".
     """
     return "localhost:{0}".format(self._server.bound_port())
+
+  def _num_tasks(self):
+    """Returns the number of tasks currently being executed on the worker."""
+    return self._server.num_tasks()
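
`_num_tasks()` is a private, test-only hook: the GC tests later in this diff poll it until the dispatcher has reclaimed a job and a worker heartbeat has deleted the corresponding task. A small polling helper in that spirit (the helper name and timeout are illustrative, not part of the commit):

    import time

    def wait_until_tasks_reclaimed(worker, timeout_sec=30):
      """Polls the test-only _num_tasks() hook until the worker has no tasks."""
      deadline = time.time() + timeout_sec
      while worker._num_tasks() > 0:
        if time.time() > deadline:
          raise TimeoutError("worker still has %d task(s)" % worker._num_tasks())
        time.sleep(0.1)
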
@@ -50,7 +50,14 @@ PYBIND11_MODULE(_pywrap_server_lib, m) {
       .def("stop", &tensorflow::data::WorkerGrpcDataServer::Stop)
       .def("join", &tensorflow::data::WorkerGrpcDataServer::Join,
           py::call_guard<py::gil_scoped_release>())
-      .def("bound_port", &tensorflow::data::WorkerGrpcDataServer::BoundPort);
+      .def("bound_port", &tensorflow::data::WorkerGrpcDataServer::BoundPort)
+      .def("num_tasks",
+           [](tensorflow::data::WorkerGrpcDataServer* server) -> int {
+             int num_tasks;
+             tensorflow::Status status = server->NumTasks(&num_tasks);
+             tensorflow::MaybeRaiseFromStatus(status);
+             return num_tasks;
+           });
 
   m.def(
       "TF_DATA_NewDispatchServer",
@@ -101,7 +101,9 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
                             name="",
                             port=0,
                             work_dir=None,
-                            fault_tolerant_mode=True):
+                            fault_tolerant_mode=True,
+                            job_gc_check_interval_ms=None,
+                            job_gc_timeout_ms=None):
     # If a test starts multiple independent dispatch servers, it should give
     # them different `name` values.
     work_dir = os.path.join(self.get_temp_dir(), "work_dir_",
@@ -110,13 +112,16 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
         server_lib.DispatcherConfig(
             port=port,
             work_dir=work_dir,
-            fault_tolerant_mode=fault_tolerant_mode))
+            fault_tolerant_mode=fault_tolerant_mode,
+            job_gc_check_interval_ms=job_gc_check_interval_ms,
+            job_gc_timeout_ms=job_gc_timeout_ms))
 
   def start_worker_server(self, dispatcher, port=0):
     return server_lib.WorkerServer(
         server_lib.WorkerConfig(
             dispatcher_address=_address_from_target(dispatcher.target),
-            port=port))
+            port=port,
+            heartbeat_interval_ms=200))
 
   def restart_dispatcher(self, dispatcher):
     """Stops `dispatcher` and returns a new dispatcher with the same port."""
@@ -535,6 +540,47 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
       results.append(elem.numpy())
     self.assertCountEqual(num_repetitions * list(range(num_elements)), results)
 
+  @combinations.generate(
+      combinations.times(test_base.eager_only_combinations(),
+                         combinations.combine(job_name=[None, "test"])))
+  def testGcUnusedJob(self, job_name):
+    dispatcher = self.start_dispatch_server(
+        job_gc_check_interval_ms=50, job_gc_timeout_ms=20)
+    worker = self.start_worker_server(dispatcher)  # pylint: disable=unused-variable
+    num_elements = 10
+    ds = _make_distributed_range_dataset(
+        num_elements, dispatcher, job_name=job_name)
+    it = iter(ds)
+    self.assertEqual(0, next(it).numpy())
+    self.assertEqual(1, worker._num_tasks())
+    del it
+    while worker._num_tasks() > 0:
+      time.sleep(0.1)
+
+  @combinations.generate(test_base.eager_only_combinations())
+  def testDontGcUsedJob(self):
+    dispatcher = self.start_dispatch_server(
+        job_gc_check_interval_ms=50, job_gc_timeout_ms=20)
+    worker = self.start_worker_server(dispatcher)  # pylint: disable=unused-variable
+    num_elements = 10
+    it1 = iter(
+        _make_distributed_range_dataset(
+            num_elements, dispatcher, job_name="test1"))
+    it2 = iter(
+        _make_distributed_range_dataset(
+            num_elements, dispatcher, job_name="test2"))
+    it3 = iter(  # this iterator keeps the task alive. pylint: disable=unused-variable
+        _make_distributed_range_dataset(
+            num_elements, dispatcher, job_name="test2"))
+    self.assertEqual(2, worker._num_tasks())
+    del it1
+    del it2
+    # Check that only the first job is gced. The second job will not be gced
+    # because there is still an outstanding iterator for it.
+    while worker._num_tasks() > 1:
+      time.sleep(0.1)
+    self.assertEqual(1, worker._num_tasks())
+
   @combinations.generate(test_base.eager_only_combinations())
   def testApplyDeterminismOption(self):
     elements = list(range(10))
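
Together, the two tests above cover the full lifecycle this commit enables: a job with no remaining clients is reclaimed once `job_gc_timeout_ms` elapses and the worker drops its task at the next heartbeat, while a job that still has an outstanding iterator stays alive. A rough end-to-end sketch of the same flow outside the test harness (the `parallel_epochs` mode and the tiny intervals are illustrative):

    import time
    import tensorflow as tf
    from tensorflow.python.data.experimental.service import server_lib

    dispatcher = server_lib.DispatchServer(
        server_lib.DispatcherConfig(
            port=0, job_gc_check_interval_ms=50, job_gc_timeout_ms=20))
    worker = server_lib.WorkerServer(
        server_lib.WorkerConfig(
            dispatcher_address=dispatcher.target.split("://")[1],
            heartbeat_interval_ms=200))

    ds = tf.data.Dataset.range(10).apply(
        tf.data.experimental.service.distribute(
            processing_mode="parallel_epochs", service=dispatcher.target))
    it = iter(ds)
    next(it)   # reading from the iterator starts a task on the worker
    del it     # no clients remain, so the job becomes eligible for GC
    while worker._num_tasks() > 0:  # a later heartbeat deletes the task
      time.sleep(0.1)
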
@@ -7,6 +7,14 @@ tf_class {
     name: "fault_tolerant_mode"
     mtype: "<type \'property\'>"
   }
+  member {
+    name: "job_gc_check_interval_ms"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "job_gc_timeout_ms"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "port"
     mtype: "<type \'property\'>"
@@ -7,6 +7,10 @@ tf_class {
     name: "dispatcher_address"
     mtype: "<type \'property\'>"
   }
+  member {
+    name: "heartbeat_interval_ms"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "port"
     mtype: "<type \'property\'>"
@@ -7,6 +7,14 @@ tf_class {
     name: "fault_tolerant_mode"
     mtype: "<type \'property\'>"
   }
+  member {
+    name: "job_gc_check_interval_ms"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "job_gc_timeout_ms"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "port"
     mtype: "<type \'property\'>"
@@ -7,6 +7,10 @@ tf_class {
     name: "dispatcher_address"
     mtype: "<type \'property\'>"
   }
+  member {
+    name: "heartbeat_interval_ms"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "port"
     mtype: "<type \'property\'>"
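
The last four hunks appear to be the v1 and v2 API golden updates that register the new fields as public properties on the exported config classes. A tiny sketch of reading them through the `tf.data.experimental.service` namespace (addresses and values are illustrative, and the constructor defaults are assumed to follow this commit):

    import tensorflow as tf

    d_config = tf.data.experimental.service.DispatcherConfig(
        job_gc_check_interval_ms=60 * 1000, job_gc_timeout_ms=10 * 60 * 1000)
    w_config = tf.data.experimental.service.WorkerConfig(
        dispatcher_address="localhost:5050", heartbeat_interval_ms=1000)
    print(d_config.job_gc_timeout_ms, w_config.heartbeat_interval_ms)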