[tf.data service] Use the journal to keep track of registered workers.
As part of this change, we stop using integer worker ids and instead use worker addresses as their identifiers. PiperOrigin-RevId: 324927652 Change-Id: If6ef5a08aac6bf32cc603108f9045887619488f1
This commit is contained in:
parent
3f45c33ba5
commit
f9bb629525
@ -19,6 +19,15 @@ message TaskDef {
|
||||
int64 job_id = 4;
|
||||
}
|
||||
|
||||
// Identifies a task together with the worker processing it and the job it is
// part of.
message TaskInfo {
|
||||
// The address of the worker processing the task.
|
||||
string worker_address = 1;
|
||||
// The task id.
|
||||
int64 task_id = 2;
|
||||
// The id of the job that the task is part of.
|
||||
int64 job_id = 3;
|
||||
}
|
||||
|
||||
enum ProcessingModeDef {
|
||||
// Each tf.data worker processes an entire epoch.
|
||||
PARALLEL_EPOCHS = 0;
|
||||
|
||||
@ -10,8 +10,6 @@ message RegisterWorkerRequest {
|
||||
}
|
||||
|
||||
message RegisterWorkerResponse {
|
||||
// An id for the worker.
|
||||
int64 worker_id = 1;
|
||||
// Tasks to begin processing.
|
||||
repeated TaskDef tasks = 2;
|
||||
}
|
||||
@ -24,8 +22,7 @@ message TaskProgress {
|
||||
}
|
||||
|
||||
message WorkerUpdateRequest {
|
||||
// The worker id that the update is for.
|
||||
int64 worker_id = 1;
|
||||
string worker_address = 1;
|
||||
repeated TaskProgress updates = 2;
|
||||
}
|
||||
|
||||
@ -75,13 +72,6 @@ message GetTasksRequest {
|
||||
int64 job_id = 1;
|
||||
}
|
||||
|
||||
message TaskInfo {
|
||||
// The address of the worker processing the task.
|
||||
string worker_address = 1;
|
||||
// The task id.
|
||||
int64 id = 2;
|
||||
}
|
||||
|
||||
message GetTasksResponse {
|
||||
// A list of all tasks for a job.
|
||||
repeated TaskInfo task_info = 1;
|
||||
|
||||
@ -46,6 +46,7 @@ namespace {
|
||||
constexpr char kJournalDir[] = "journal";
|
||||
|
||||
using Dataset = DispatcherState::Dataset;
|
||||
using Worker = DispatcherState::Worker;
|
||||
using NamedJobKey = DispatcherState::NamedJobKey;
|
||||
using Job = DispatcherState::Job;
|
||||
using Task = DispatcherState::Task;
|
||||
@ -77,10 +78,16 @@ DataServiceDispatcherImpl::DataServiceDispatcherImpl(
|
||||
}
|
||||
|
||||
Status DataServiceDispatcherImpl::Start() {
|
||||
if (config_.work_dir().empty()) {
|
||||
if (!config_.fault_tolerant_mode()) {
|
||||
LOG(INFO) << "Running with fault_tolerant_mode=False. The dispatcher will "
|
||||
"not be able to recover its state on restart.";
|
||||
return Status::OK();
|
||||
}
|
||||
mutex_lock l(mu_);
|
||||
if (config_.work_dir().empty()) {
|
||||
return errors::InvalidArgument(
|
||||
"fault_tolerant_mode is True, but no work_dir is configured.");
|
||||
}
|
||||
Update update;
|
||||
bool end_of_journal = false;
|
||||
FileJournalReader reader(Env::Default(), JournalDir(config_.work_dir()));
|
||||
@ -104,12 +111,16 @@ Status DataServiceDispatcherImpl::RegisterWorker(
|
||||
VLOG(3) << "Received register worker request";
|
||||
mutex_lock l(mu_);
|
||||
std::string worker_address = request->worker_address();
|
||||
if (!workers_.contains(worker_address)) {
|
||||
workers_[worker_address] =
|
||||
std::make_shared<Worker>(next_worker_id_++, worker_address);
|
||||
std::shared_ptr<const Worker> worker;
|
||||
Status s = state_.WorkerFromAddress(worker_address, &worker);
|
||||
if (errors::IsNotFound(s)) {
|
||||
Update update;
|
||||
update.mutable_register_worker()->set_worker_address(worker_address);
|
||||
TF_RETURN_IF_ERROR(Apply(update));
|
||||
} else if (!s.ok()) {
|
||||
return s;
|
||||
}
|
||||
int64 worker_id = workers_[worker_address]->worker_id;
|
||||
response->set_worker_id(worker_id);
|
||||
|
||||
std::vector<std::shared_ptr<const Job>> jobs = state_.ListJobs();
|
||||
// Allocate tasks to the worker.
|
||||
for (const auto& job : jobs) {
|
||||
@ -127,8 +138,7 @@ Status DataServiceDispatcherImpl::RegisterWorker(
|
||||
task_def->set_task_id(task->task_id);
|
||||
}
|
||||
|
||||
VLOG(1) << "Registered worker at address " << request->worker_address()
|
||||
<< " with id " << worker_id;
|
||||
VLOG(1) << "Registered worker at address " << request->worker_address();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -146,8 +156,7 @@ Status DataServiceDispatcherImpl::WorkerUpdate(
|
||||
continue;
|
||||
}
|
||||
Update update;
|
||||
FinishTaskUpdate* finish_task = update.mutable_finish_task();
|
||||
finish_task->set_task_id(task_id);
|
||||
update.mutable_finish_task()->set_task_id(task_id);
|
||||
TF_RETURN_IF_ERROR(Apply(update));
|
||||
VLOG(3) << "Task " << task_id << " from job " << task->job_id
|
||||
<< " completed";
|
||||
@ -310,10 +319,10 @@ Status DataServiceDispatcherImpl::CreateTasksForJob(
|
||||
std::shared_ptr<const Job> job,
|
||||
std::vector<std::shared_ptr<const Task>>* tasks)
|
||||
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
std::vector<std::shared_ptr<const Worker>> workers = state_.ListWorkers();
|
||||
tasks->clear();
|
||||
tasks->reserve(workers_.size());
|
||||
for (const auto& it : workers_) {
|
||||
std::shared_ptr<Worker> worker = it.second;
|
||||
tasks->reserve(workers.size());
|
||||
for (const auto& worker : workers) {
|
||||
std::shared_ptr<const Task> task;
|
||||
TF_RETURN_IF_ERROR(CreateTask(job, worker->address, &task));
|
||||
tasks->push_back(task);
|
||||
@ -345,10 +354,28 @@ Status DataServiceDispatcherImpl::AssignTasks(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DataServiceDispatcherImpl::EnsureWorkerStubInitialized(Worker* worker) {
|
||||
if (!worker->stub) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
CreateWorkerStub(worker->address, config_.protocol(), &worker->stub));
|
||||
// Looks up the stub cached in `worker_stubs_` for `worker_address` and stores
// a non-owning pointer to it in `*out_stub`. If no stub is cached, a new one
// is created via CreateWorkerStub and cached. Returns the error from
// CreateWorkerStub if creation fails.
Status DataServiceDispatcherImpl::GetOrCreateWorkerStub(
    const std::string& worker_address, WorkerService::Stub** out_stub)
    LOCKS_EXCLUDED(mu_) {
  // First pass: return the cached stub when one already exists.
  {
    mutex_lock l(mu_);
    auto cached = worker_stubs_.find(worker_address);
    if (cached != worker_stubs_.end()) {
      *out_stub = cached->second.get();
      return Status::OK();
    }
  }

  // No cached stub; build one. `mu_` is not held during creation.
  std::unique_ptr<WorkerService::Stub> new_stub;
  TF_RETURN_IF_ERROR(
      CreateWorkerStub(worker_address, config_.protocol(), &new_stub));

  // Second pass: publish the new stub unless a concurrent call already
  // stored one for this address, in which case the existing stub wins.
  {
    mutex_lock l(mu_);
    auto& slot = worker_stubs_[worker_address];
    if (slot == nullptr) {
      slot = std::move(new_stub);
    }
    *out_stub = slot.get();
  }
  return Status::OK();
}
|
||||
@ -359,25 +386,21 @@ Status DataServiceDispatcherImpl::AssignTask(std::shared_ptr<const Task> task)
|
||||
ProcessTaskRequest req;
|
||||
TaskDef* task_def = req.mutable_task();
|
||||
task_def->set_dataset_id(task->dataset_id);
|
||||
std::shared_ptr<Worker> worker;
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
worker = workers_[task->worker_address];
|
||||
std::shared_ptr<const Dataset> dataset;
|
||||
TF_RETURN_IF_ERROR(state_.DatasetFromId(task->dataset_id, &dataset));
|
||||
*task_def->mutable_dataset() = dataset->dataset_def;
|
||||
}
|
||||
if (!worker) {
|
||||
return errors::NotFound("No worker found for address ",
|
||||
task->worker_address);
|
||||
}
|
||||
task_def->set_task_id(task->task_id);
|
||||
ProcessTaskResponse resp;
|
||||
TF_RETURN_IF_ERROR(EnsureWorkerStubInitialized(worker.get()));
|
||||
grpc::Status s = worker->stub->ProcessTask(&client_ctx, req, &resp);
|
||||
WorkerService::Stub* stub;
|
||||
TF_RETURN_IF_ERROR(GetOrCreateWorkerStub(task->worker_address, &stub));
|
||||
grpc::Status s = stub->ProcessTask(&client_ctx, req, &resp);
|
||||
if (!s.ok()) {
|
||||
return grpc_util::WrapError(
|
||||
absl::StrCat("Failed to submit task to worker ", worker->address), s);
|
||||
absl::StrCat("Failed to submit task to worker ", task->worker_address),
|
||||
s);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
@ -391,7 +414,8 @@ Status DataServiceDispatcherImpl::GetTasks(const GetTasksRequest* request,
|
||||
for (const auto& task : tasks) {
|
||||
TaskInfo* task_info = response->mutable_task_info()->Add();
|
||||
task_info->set_worker_address(task->worker_address);
|
||||
task_info->set_id(task->task_id);
|
||||
task_info->set_task_id(task->task_id);
|
||||
task_info->set_job_id(task->job_id);
|
||||
}
|
||||
std::shared_ptr<const Job> job;
|
||||
TF_RETURN_IF_ERROR(state_.JobFromId(request->job_id(), &job));
|
||||
@ -405,13 +429,12 @@ Status DataServiceDispatcherImpl::GetWorkers(const GetWorkersRequest* request,
|
||||
GetWorkersResponse* response) {
|
||||
mutex_lock l(mu_);
|
||||
VLOG(3) << "Enter GetWorkers";
|
||||
for (const auto& it : workers_) {
|
||||
std::shared_ptr<Worker> worker = it.second;
|
||||
std::vector<std::shared_ptr<const Worker>> workers = state_.ListWorkers();
|
||||
for (const auto& worker : workers) {
|
||||
WorkerInfo* info = response->add_workers();
|
||||
info->set_address(worker->address);
|
||||
info->set_id(worker->worker_id);
|
||||
}
|
||||
VLOG(3) << "Returning list of " << workers_.size()
|
||||
VLOG(3) << "Returning list of " << response->workers_size()
|
||||
<< " workers from GetWorkers";
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -71,21 +71,15 @@ class DataServiceDispatcherImpl {
|
||||
GetWorkersResponse* response);
|
||||
|
||||
private:
|
||||
struct Worker {
|
||||
Worker(int64 worker_id, const std::string& address)
|
||||
: worker_id(worker_id), address(address) {}
|
||||
|
||||
const int64 worker_id;
|
||||
const std::string address;
|
||||
std::unique_ptr<WorkerService::Stub> stub;
|
||||
};
|
||||
|
||||
// Registers a dataset with the given fingerprint, storing the new dataset's
|
||||
// id in `*dataset_id`.
|
||||
Status RegisterDataset(uint64 fingerprint, const DatasetDef& dataset,
|
||||
int64* dataset_id) EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||
// Initializes a worker's stub, if it hasn't been initialized already.
|
||||
Status EnsureWorkerStubInitialized(Worker* worker);
|
||||
// Gets a worker's stub from `worker_stubs_`, or if none exists, creates a
|
||||
// stub and stores it in `worker_stubs_`.
|
||||
Status GetOrCreateWorkerStub(const std::string& worker_address,
|
||||
WorkerService::Stub** out_stub)
|
||||
LOCKS_EXCLUDED(mu_);
|
||||
// Creates a job and stores it in `*job`. This method updates the
|
||||
// dispatcher state with the new job, but does not assign tasks to workers.
|
||||
Status CreateJob(int64 dataset_id, ProcessingMode processing_mode,
|
||||
@ -128,12 +122,11 @@ class DataServiceDispatcherImpl {
|
||||
|
||||
mutex mu_;
|
||||
|
||||
int64 next_worker_id_ TF_GUARDED_BY(mu_) = 0;
|
||||
int64 next_task_id_ TF_GUARDED_BY(mu_) = 0;
|
||||
|
||||
// Registered workers, keyed by their addresses.
|
||||
absl::flat_hash_map<std::string, std::shared_ptr<Worker>> workers_
|
||||
TF_GUARDED_BY(mu_);
|
||||
// Cached worker stubs for communicating with workers.
|
||||
absl::flat_hash_map<std::string, std::unique_ptr<WorkerService::Stub>>
|
||||
worker_stubs_ TF_GUARDED_BY(mu_);
|
||||
|
||||
absl::optional<std::unique_ptr<JournalWriter>> journal_writer_
|
||||
TF_GUARDED_BY(mu_);
|
||||
|
||||
@ -30,6 +30,9 @@ Status DispatcherState::Apply(Update update) {
|
||||
case Update::kRegisterDataset:
|
||||
RegisterDataset(update.register_dataset());
|
||||
break;
|
||||
case Update::kRegisterWorker:
|
||||
RegisterWorker(update.register_worker());
|
||||
break;
|
||||
case Update::kCreateJob:
|
||||
CreateJob(update.create_job());
|
||||
break;
|
||||
@ -59,6 +62,13 @@ void DispatcherState::RegisterDataset(
|
||||
next_available_dataset_id_ = std::max(next_available_dataset_id_, id + 1);
|
||||
}
|
||||
|
||||
void DispatcherState::RegisterWorker(
|
||||
const RegisterWorkerUpdate& register_worker) {
|
||||
std::string address = register_worker.worker_address();
|
||||
DCHECK(!workers_.contains(address));
|
||||
workers_[address] = std::make_shared<Worker>(address);
|
||||
}
|
||||
|
||||
void DispatcherState::CreateJob(const CreateJobUpdate& create_job) {
|
||||
int64 job_id = create_job.job_id();
|
||||
absl::optional<NamedJobKey> named_job_key;
|
||||
@ -71,6 +81,7 @@ void DispatcherState::CreateJob(const CreateJobUpdate& create_job) {
|
||||
named_job_key);
|
||||
DCHECK(!jobs_.contains(job_id));
|
||||
jobs_[job_id] = job;
|
||||
tasks_by_job_[job_id] = std::vector<std::shared_ptr<Task>>();
|
||||
if (named_job_key.has_value()) {
|
||||
DCHECK(!named_jobs_.contains(named_job_key.value()));
|
||||
named_jobs_[named_job_key.value()] = job;
|
||||
@ -129,6 +140,26 @@ Status DispatcherState::DatasetFromFingerprint(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Stores the worker registered under `address` in `*worker`. Returns
// NOT_FOUND, leaving `*worker` untouched, when no such worker is registered.
Status DispatcherState::WorkerFromAddress(
    const std::string& address, std::shared_ptr<const Worker>* worker) const {
  auto entry = workers_.find(address);
  if (entry != workers_.end()) {
    *worker = entry->second;
    return Status::OK();
  }
  return errors::NotFound("Worker with address ", address, " not found.");
}
|
||||
|
||||
// Returns read-only handles to every worker currently stored in `workers_`.
std::vector<std::shared_ptr<const DispatcherState::Worker>>
DispatcherState::ListWorkers() const {
  std::vector<std::shared_ptr<const Worker>> result;
  // Size the output once up front; one entry is appended per worker.
  result.reserve(workers_.size());
  for (const auto& entry : workers_) {
    result.push_back(entry.second);
  }
  return result;
}
|
||||
|
||||
std::vector<std::shared_ptr<const DispatcherState::Job>>
|
||||
DispatcherState::ListJobs() {
|
||||
std::vector<std::shared_ptr<const DispatcherState::Job>> jobs;
|
||||
|
||||
@ -48,11 +48,6 @@ namespace data {
|
||||
// DispatcherImpl and for providing DispatcherImpl with read-only access to
|
||||
// the state.
|
||||
//
|
||||
// Note that not all state needs to be journaled, and in general we journal
|
||||
// as little state as possible. For example, worker and task state doesn't need
|
||||
// to be journaled because we can recover that information from workers when
|
||||
// they reconnect to a restarted dispatcher.
|
||||
//
|
||||
// DispatcherState is thread-compatible but not thread-safe.
|
||||
class DispatcherState {
|
||||
public:
|
||||
@ -65,7 +60,8 @@ class DispatcherState {
|
||||
|
||||
// A dataset registered with the dispatcher.
|
||||
struct Dataset {
|
||||
Dataset(int64 dataset_id, int64 fingerprint, const DatasetDef& dataset_def)
|
||||
explicit Dataset(int64 dataset_id, int64 fingerprint,
|
||||
const DatasetDef& dataset_def)
|
||||
: dataset_id(dataset_id),
|
||||
fingerprint(fingerprint),
|
||||
dataset_def(dataset_def) {}
|
||||
@ -75,10 +71,17 @@ class DispatcherState {
|
||||
const DatasetDef dataset_def;
|
||||
};
|
||||
|
||||
// Record of a worker that has registered with the dispatcher.
struct Worker {
  // Builds the record for the worker located at `address`.
  explicit Worker(const std::string& address)
      : address(address) {}

  // The worker's address; fixed at construction.
  const std::string address;
};
|
||||
|
||||
// A key for identifying a named job. The key contains a user-specified name,
|
||||
// as well as an index describing which iteration of the job we are on.
|
||||
struct NamedJobKey {
|
||||
NamedJobKey(absl::string_view name, int64 index)
|
||||
explicit NamedJobKey(absl::string_view name, int64 index)
|
||||
: name(name), index(index) {}
|
||||
|
||||
friend bool operator==(const NamedJobKey& lhs, const NamedJobKey& rhs) {
|
||||
@ -96,8 +99,8 @@ class DispatcherState {
|
||||
|
||||
// A job for processing a dataset.
|
||||
struct Job {
|
||||
Job(int64 job_id, int64 dataset_id, ProcessingMode processing_mode,
|
||||
absl::optional<NamedJobKey> named_job_key)
|
||||
explicit Job(int64 job_id, int64 dataset_id, ProcessingMode processing_mode,
|
||||
absl::optional<NamedJobKey> named_job_key)
|
||||
: job_id(job_id),
|
||||
dataset_id(dataset_id),
|
||||
processing_mode(processing_mode),
|
||||
@ -111,8 +114,8 @@ class DispatcherState {
|
||||
};
|
||||
|
||||
struct Task {
|
||||
Task(int64 task_id, int64 job_id, int64 dataset_id,
|
||||
const std::string& worker_address)
|
||||
explicit Task(int64 task_id, int64 job_id, int64 dataset_id,
|
||||
const std::string& worker_address)
|
||||
: task_id(task_id),
|
||||
job_id(job_id),
|
||||
dataset_id(dataset_id),
|
||||
@ -134,6 +137,12 @@ class DispatcherState {
|
||||
Status DatasetFromFingerprint(uint64 fingerprint,
|
||||
std::shared_ptr<const Dataset>* dataset) const;
|
||||
|
||||
// Gets a worker by address. Returns NOT_FOUND if there is no such worker.
|
||||
Status WorkerFromAddress(const std::string& address,
|
||||
std::shared_ptr<const Worker>* worker) const;
|
||||
// Lists all workers registered with the dispatcher.
|
||||
std::vector<std::shared_ptr<const Worker>> ListWorkers() const;
|
||||
|
||||
// Returns the next available job id.
|
||||
int64 NextAvailableJobId() const;
|
||||
// Returns a list of all jobs.
|
||||
@ -153,8 +162,8 @@ class DispatcherState {
|
||||
std::vector<std::shared_ptr<const Task>>* tasks) const;
|
||||
|
||||
private:
|
||||
// Registers a dataset. The dataset must not already be registered.
|
||||
void RegisterDataset(const RegisterDatasetUpdate& register_dataset);
|
||||
void RegisterWorker(const RegisterWorkerUpdate& register_worker);
|
||||
void CreateJob(const CreateJobUpdate& create_job);
|
||||
void CreateTask(const CreateTaskUpdate& create_task);
|
||||
void FinishTask(const FinishTaskUpdate& finish_task);
|
||||
@ -166,6 +175,9 @@ class DispatcherState {
|
||||
absl::flat_hash_map<uint64, std::shared_ptr<Dataset>>
|
||||
datasets_by_fingerprint_;
|
||||
|
||||
// Registered workers, keyed by address.
|
||||
absl::flat_hash_map<std::string, std::shared_ptr<Worker>> workers_;
|
||||
|
||||
int64 next_available_job_id_ = 0;
|
||||
// Jobs, keyed by job ids.
|
||||
absl::flat_hash_map<int64, std::shared_ptr<Job>> jobs_;
|
||||
|
||||
@ -14,6 +14,8 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/core/data/service/dispatcher_state.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/core/data/service/common.pb.h"
|
||||
#include "tensorflow/core/data/service/journal.h"
|
||||
#include "tensorflow/core/data/service/journal.pb.h"
|
||||
@ -27,9 +29,11 @@ namespace data {
|
||||
|
||||
namespace {
|
||||
using Dataset = DispatcherState::Dataset;
|
||||
using Worker = DispatcherState::Worker;
|
||||
using NamedJobKey = DispatcherState::NamedJobKey;
|
||||
using Job = DispatcherState::Job;
|
||||
using Task = DispatcherState::Task;
|
||||
using ::testing::IsEmpty;
|
||||
using ::testing::SizeIs;
|
||||
|
||||
Status RegisterDatasetWithIdAndFingerprint(int64 id, uint64 fingerprint,
|
||||
@ -42,6 +46,13 @@ Status RegisterDatasetWithIdAndFingerprint(int64 id, uint64 fingerprint,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Test helper: applies a register-worker update for `worker_address` to
// `state` and returns the resulting status.
Status RegisterWorker(std::string worker_address, DispatcherState* state) {
  Update register_update;
  auto* payload = register_update.mutable_register_worker();
  payload->set_worker_address(worker_address);
  return state->Apply(register_update);
}
|
||||
|
||||
Status CreateAnonymousJob(int64 job_id, int64 dataset_id,
|
||||
DispatcherState* state) {
|
||||
Update update;
|
||||
@ -98,12 +109,12 @@ TEST(DispatcherState, RegisterDataset) {
|
||||
{
|
||||
std::shared_ptr<const Dataset> dataset;
|
||||
TF_EXPECT_OK(state.DatasetFromFingerprint(fingerprint, &dataset));
|
||||
EXPECT_EQ(id, dataset->dataset_id);
|
||||
EXPECT_EQ(dataset->dataset_id, id);
|
||||
}
|
||||
{
|
||||
std::shared_ptr<const Dataset> dataset;
|
||||
TF_EXPECT_OK(state.DatasetFromId(id, &dataset));
|
||||
EXPECT_EQ(fingerprint, dataset->fingerprint);
|
||||
EXPECT_EQ(dataset->fingerprint, fingerprint);
|
||||
}
|
||||
}
|
||||
|
||||
@ -126,10 +137,46 @@ TEST(DispatcherState, NextAvailableDatasetId) {
|
||||
int64 id = state.NextAvailableDatasetId();
|
||||
uint64 fingerprint = 20;
|
||||
TF_EXPECT_OK(RegisterDatasetWithIdAndFingerprint(id, fingerprint, &state));
|
||||
EXPECT_NE(id, state.NextAvailableDatasetId());
|
||||
EXPECT_NE(state.NextAvailableDatasetId(), id);
|
||||
EXPECT_EQ(state.NextAvailableDatasetId(), state.NextAvailableDatasetId());
|
||||
}
|
||||
|
||||
// A registered worker can be looked up by its address.
TEST(DispatcherState, RegisterWorker) {
  DispatcherState state;
  const std::string worker_address = "test_worker_address";
  TF_EXPECT_OK(RegisterWorker(worker_address, &state));

  std::shared_ptr<const Worker> registered;
  TF_EXPECT_OK(state.WorkerFromAddress(worker_address, &registered));
  EXPECT_EQ(registered->address, worker_address);
}
|
||||
|
||||
// ListWorkers grows by one entry for each newly registered worker.
TEST(DispatcherState, ListWorkers) {
  DispatcherState state;
  const std::string first_address = "address_1";
  const std::string second_address = "address_2";

  // Nothing is registered yet.
  EXPECT_THAT(state.ListWorkers(), IsEmpty());

  TF_EXPECT_OK(RegisterWorker(first_address, &state));
  EXPECT_THAT(state.ListWorkers(), SizeIs(1));

  TF_EXPECT_OK(RegisterWorker(second_address, &state));
  EXPECT_THAT(state.ListWorkers(), SizeIs(2));
}
|
||||
|
||||
// Looking up an address that was never registered yields NOT_FOUND.
TEST(DispatcherState, MissingWorker) {
  DispatcherState state;
  std::shared_ptr<const Worker> unused;
  Status lookup = state.WorkerFromAddress("test_worker_address", &unused);
  EXPECT_EQ(lookup.code(), error::NOT_FOUND);
}
|
||||
|
||||
TEST(DispatcherState, UnknownUpdate) {
|
||||
DispatcherState state;
|
||||
Update update;
|
||||
@ -146,8 +193,11 @@ TEST(DispatcherState, AnonymousJob) {
|
||||
std::shared_ptr<const Job> job;
|
||||
TF_EXPECT_OK(state.JobFromId(job_id, &job));
|
||||
EXPECT_EQ(state.NextAvailableJobId(), job_id + 1);
|
||||
EXPECT_EQ(dataset_id, job->dataset_id);
|
||||
EXPECT_EQ(job_id, job->job_id);
|
||||
EXPECT_EQ(job->dataset_id, dataset_id);
|
||||
EXPECT_EQ(job->job_id, job_id);
|
||||
std::vector<std::shared_ptr<const Task>> tasks;
|
||||
TF_EXPECT_OK(state.TasksForJob(job_id, &tasks));
|
||||
EXPECT_THAT(tasks, IsEmpty());
|
||||
EXPECT_FALSE(job->finished);
|
||||
}
|
||||
|
||||
@ -161,8 +211,8 @@ TEST(DispatcherState, NamedJob) {
|
||||
std::shared_ptr<const Job> job;
|
||||
TF_EXPECT_OK(state.NamedJobByKey(named_job_key, &job));
|
||||
EXPECT_EQ(state.NextAvailableJobId(), job_id + 1);
|
||||
EXPECT_EQ(dataset_id, job->dataset_id);
|
||||
EXPECT_EQ(job_id, job->job_id);
|
||||
EXPECT_EQ(job->dataset_id, dataset_id);
|
||||
EXPECT_EQ(job->job_id, job_id);
|
||||
EXPECT_FALSE(job->finished);
|
||||
}
|
||||
|
||||
@ -179,10 +229,10 @@ TEST(DispatcherState, CreateTask) {
|
||||
{
|
||||
std::shared_ptr<const Task> task;
|
||||
TF_EXPECT_OK(state.TaskFromId(task_id, &task));
|
||||
EXPECT_EQ(task_id, task->task_id);
|
||||
EXPECT_EQ(job_id, task->job_id);
|
||||
EXPECT_EQ(dataset_id, task->dataset_id);
|
||||
EXPECT_EQ(worker_address, task->worker_address);
|
||||
EXPECT_EQ(task->task_id, task_id);
|
||||
EXPECT_EQ(task->job_id, job_id);
|
||||
EXPECT_EQ(task->dataset_id, dataset_id);
|
||||
EXPECT_EQ(task->worker_address, worker_address);
|
||||
}
|
||||
{
|
||||
std::vector<std::shared_ptr<const Task>> tasks;
|
||||
@ -207,7 +257,7 @@ TEST(DispatcherState, CreateTasksForSameJob) {
|
||||
{
|
||||
std::vector<std::shared_ptr<const Task>> tasks;
|
||||
TF_EXPECT_OK(state.TasksForJob(job_id, &tasks));
|
||||
EXPECT_EQ(2, tasks.size());
|
||||
EXPECT_THAT(tasks, SizeIs(2));
|
||||
}
|
||||
}
|
||||
|
||||
@ -229,12 +279,12 @@ TEST(DispatcherState, CreateTasksForDifferentJobs) {
|
||||
{
|
||||
std::vector<std::shared_ptr<const Task>> tasks;
|
||||
TF_EXPECT_OK(state.TasksForJob(job_id_1, &tasks));
|
||||
EXPECT_EQ(1, tasks.size());
|
||||
EXPECT_THAT(tasks, SizeIs(1));
|
||||
}
|
||||
{
|
||||
std::vector<std::shared_ptr<const Task>> tasks;
|
||||
TF_EXPECT_OK(state.TasksForJob(job_id_2, &tasks));
|
||||
EXPECT_EQ(1, tasks.size());
|
||||
EXPECT_THAT(tasks, SizeIs(1));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -48,7 +48,7 @@ Status FileJournalWriter::EnsureInitialized() {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status FileJournalWriter::Write(Update update) {
|
||||
Status FileJournalWriter::Write(const Update& update) {
|
||||
TF_RETURN_IF_ERROR(EnsureInitialized());
|
||||
std::string s = update.SerializeAsString();
|
||||
if (s.empty()) {
|
||||
|
||||
@ -32,7 +32,7 @@ class JournalWriter {
|
||||
public:
|
||||
virtual ~JournalWriter() = default;
|
||||
// Writes and syncs an update to the journal.
|
||||
virtual Status Write(Update update) = 0;
|
||||
virtual Status Write(const Update& update) = 0;
|
||||
};
|
||||
|
||||
// FileJournalWriter is not thread-safe, requiring external synchronization when
|
||||
@ -46,7 +46,7 @@ class FileJournalWriter : public JournalWriter {
|
||||
FileJournalWriter(const FileJournalWriter&) = delete;
|
||||
FileJournalWriter& operator=(const FileJournalWriter&) = delete;
|
||||
|
||||
Status Write(Update update) override;
|
||||
Status Write(const Update& update) override;
|
||||
|
||||
private:
|
||||
// Initializes the writer if it is not yet initialized.
|
||||
|
||||
@ -10,6 +10,7 @@ import "tensorflow/core/data/service/common.proto";
|
||||
message Update {
|
||||
oneof update_type {
|
||||
RegisterDatasetUpdate register_dataset = 1;
|
||||
RegisterWorkerUpdate register_worker = 5;
|
||||
CreateJobUpdate create_job = 2;
|
||||
CreateTaskUpdate create_task = 3;
|
||||
FinishTaskUpdate finish_task = 4;
|
||||
@ -22,6 +23,10 @@ message RegisterDatasetUpdate {
|
||||
uint64 fingerprint = 3;
|
||||
}
|
||||
|
||||
// Journal update recording that a worker has registered with the dispatcher.
message RegisterWorkerUpdate {
|
||||
// The address of the registered worker.
string worker_address = 1;
|
||||
}
|
||||
|
||||
message NamedJobKeyDef {
|
||||
string name = 1;
|
||||
int64 index = 2;
|
||||
|
||||
@ -197,8 +197,6 @@ Status DataServiceWorkerImpl::Register() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
for (const TaskDef& task : resp.tasks()) {
|
||||
TF_RETURN_IF_ERROR(ProcessTaskInternal(task));
|
||||
}
|
||||
worker_id_ = resp.worker_id();
|
||||
VLOG(3) << "Registered worker with id " << resp.worker_id();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -207,7 +205,6 @@ Status DataServiceWorkerImpl::SendTaskUpdate() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
<< " task updates to dispatcher";
|
||||
TF_RETURN_IF_ERROR(EnsureDispatcherStubInitialized());
|
||||
WorkerUpdateRequest req;
|
||||
req.set_worker_id(worker_id_);
|
||||
for (int task_id : pending_completed_tasks_) {
|
||||
TaskProgress* update = req.add_updates();
|
||||
update->set_task_id(task_id);
|
||||
|
||||
@ -338,7 +338,7 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
|
||||
}
|
||||
absl::flat_hash_map<int64, TaskInfo> task_id_to_task;
|
||||
for (auto& task : tasks) {
|
||||
task_id_to_task[task.id()] = task;
|
||||
task_id_to_task[task.task_id()] = task;
|
||||
}
|
||||
mutex_lock l(mu_);
|
||||
job_finished_ = job_finished;
|
||||
@ -371,8 +371,9 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
|
||||
get_next_cv_.notify_all();
|
||||
continue;
|
||||
}
|
||||
tasks_.push_back(std::make_shared<Task>(
|
||||
task_info.id(), task_info.worker_address(), std::move(worker)));
|
||||
tasks_.push_back(std::make_shared<Task>(task_info.task_id(),
|
||||
task_info.worker_address(),
|
||||
std::move(worker)));
|
||||
}
|
||||
if (dataset()->max_outstanding_requests_ == model::kAutotune) {
|
||||
// Adjust max_outstanding_requests to account for newly added tasks.
|
||||
|
||||
@ -12,6 +12,9 @@ message DispatcherConfig {
|
||||
// An optional work directory to use for storing dispatcher state, and for
|
||||
// recovering during restarts.
|
||||
string work_dir = 3;
|
||||
// Whether to run in fault tolerant mode, where dispatcher state is saved
|
||||
// across restarts.
|
||||
bool fault_tolerant_mode = 4;
|
||||
}
|
||||
|
||||
// Configuration for a tf.data service WorkerServer.
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user