[tf.data service] Use the journal to keep track of registered workers.

As part of this change, we stop using integer worker ids, and instead use worker addresses as their identifiers.

PiperOrigin-RevId: 324927652
Change-Id: If6ef5a08aac6bf32cc603108f9045887619488f1
This commit is contained in:
Andrew Audibert 2020-08-04 17:46:36 -07:00 committed by TensorFlower Gardener
parent 3f45c33ba5
commit f9bb629525
13 changed files with 206 additions and 92 deletions

View File

@ -19,6 +19,15 @@ message TaskDef {
int64 job_id = 4;
}
message TaskInfo {
// The address of the worker processing the task.
string worker_address = 1;
// The task id.
int64 task_id = 2;
// The id of the job that the task is part of.
int64 job_id = 3;
}
enum ProcessingModeDef {
// Each tf.data worker processes an entire epoch.
PARALLEL_EPOCHS = 0;

View File

@ -10,8 +10,6 @@ message RegisterWorkerRequest {
}
message RegisterWorkerResponse {
// An id for the worker.
int64 worker_id = 1;
// Tasks to begin processing.
repeated TaskDef tasks = 2;
}
@ -24,8 +22,7 @@ message TaskProgress {
}
message WorkerUpdateRequest {
// The worker id that the update is for.
int64 worker_id = 1;
string worker_address = 1;
repeated TaskProgress updates = 2;
}
@ -75,13 +72,6 @@ message GetTasksRequest {
int64 job_id = 1;
}
message TaskInfo {
// The address of the worker processing the task.
string worker_address = 1;
// The task id.
int64 id = 2;
}
message GetTasksResponse {
// A list of all tasks for a job.
repeated TaskInfo task_info = 1;

View File

@ -46,6 +46,7 @@ namespace {
constexpr char kJournalDir[] = "journal";
using Dataset = DispatcherState::Dataset;
using Worker = DispatcherState::Worker;
using NamedJobKey = DispatcherState::NamedJobKey;
using Job = DispatcherState::Job;
using Task = DispatcherState::Task;
@ -77,10 +78,16 @@ DataServiceDispatcherImpl::DataServiceDispatcherImpl(
}
Status DataServiceDispatcherImpl::Start() {
if (config_.work_dir().empty()) {
if (!config_.fault_tolerant_mode()) {
LOG(INFO) << "Running with fault_tolerant_mode=False. The dispatcher will "
"not be able to recover its state on restart.";
return Status::OK();
}
mutex_lock l(mu_);
if (config_.work_dir().empty()) {
return errors::InvalidArgument(
"fault_tolerant_mode is True, but no work_dir is configured.");
}
Update update;
bool end_of_journal = false;
FileJournalReader reader(Env::Default(), JournalDir(config_.work_dir()));
@ -104,12 +111,16 @@ Status DataServiceDispatcherImpl::RegisterWorker(
VLOG(3) << "Received register worker request";
mutex_lock l(mu_);
std::string worker_address = request->worker_address();
if (!workers_.contains(worker_address)) {
workers_[worker_address] =
std::make_shared<Worker>(next_worker_id_++, worker_address);
std::shared_ptr<const Worker> worker;
Status s = state_.WorkerFromAddress(worker_address, &worker);
if (errors::IsNotFound(s)) {
Update update;
update.mutable_register_worker()->set_worker_address(worker_address);
TF_RETURN_IF_ERROR(Apply(update));
} else if (!s.ok()) {
return s;
}
int64 worker_id = workers_[worker_address]->worker_id;
response->set_worker_id(worker_id);
std::vector<std::shared_ptr<const Job>> jobs = state_.ListJobs();
// Allocate tasks to the worker.
for (const auto& job : jobs) {
@ -127,8 +138,7 @@ Status DataServiceDispatcherImpl::RegisterWorker(
task_def->set_task_id(task->task_id);
}
VLOG(1) << "Registered worker at address " << request->worker_address()
<< " with id " << worker_id;
VLOG(1) << "Registered worker at address " << request->worker_address();
return Status::OK();
}
@ -146,8 +156,7 @@ Status DataServiceDispatcherImpl::WorkerUpdate(
continue;
}
Update update;
FinishTaskUpdate* finish_task = update.mutable_finish_task();
finish_task->set_task_id(task_id);
update.mutable_finish_task()->set_task_id(task_id);
TF_RETURN_IF_ERROR(Apply(update));
VLOG(3) << "Task " << task_id << " from job " << task->job_id
<< " completed";
@ -310,10 +319,10 @@ Status DataServiceDispatcherImpl::CreateTasksForJob(
std::shared_ptr<const Job> job,
std::vector<std::shared_ptr<const Task>>* tasks)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
std::vector<std::shared_ptr<const Worker>> workers = state_.ListWorkers();
tasks->clear();
tasks->reserve(workers_.size());
for (const auto& it : workers_) {
std::shared_ptr<Worker> worker = it.second;
tasks->reserve(workers.size());
for (const auto& worker : workers) {
std::shared_ptr<const Task> task;
TF_RETURN_IF_ERROR(CreateTask(job, worker->address, &task));
tasks->push_back(task);
@ -345,10 +354,28 @@ Status DataServiceDispatcherImpl::AssignTasks(
return Status::OK();
}
Status DataServiceDispatcherImpl::EnsureWorkerStubInitialized(Worker* worker) {
if (!worker->stub) {
TF_RETURN_IF_ERROR(
CreateWorkerStub(worker->address, config_.protocol(), &worker->stub));
// Returns (via `out_stub`) a gRPC stub for the worker at `worker_address`,
// creating and caching one in `worker_stubs_` on first use. Acquires `mu_`
// internally, so callers must not hold it (see LOCKS_EXCLUDED).
Status DataServiceDispatcherImpl::GetOrCreateWorkerStub(
    const std::string& worker_address, WorkerService::Stub** out_stub)
    LOCKS_EXCLUDED(mu_) {
  {
    mutex_lock l(mu_);
    // Fast path: a stub for this address was already created.
    auto it = worker_stubs_.find(worker_address);
    if (it != worker_stubs_.end()) {
      *out_stub = it->second.get();
      return Status::OK();
    }
  }
  // Create the stub without holding `mu_`, so stub construction doesn't
  // block other dispatcher operations.
  std::unique_ptr<WorkerService::Stub> stub;
  TF_RETURN_IF_ERROR(
      CreateWorkerStub(worker_address, config_.protocol(), &stub));
  {
    mutex_lock l(mu_);
    // A concurrent call could have already created the stub.
    auto& worker = worker_stubs_[worker_address];
    if (worker == nullptr) {
      worker = std::move(stub);
    }
    // Return whichever stub ended up in the cache (ours or the
    // concurrently created one).
    *out_stub = worker.get();
  }
  return Status::OK();
}
@ -359,25 +386,21 @@ Status DataServiceDispatcherImpl::AssignTask(std::shared_ptr<const Task> task)
ProcessTaskRequest req;
TaskDef* task_def = req.mutable_task();
task_def->set_dataset_id(task->dataset_id);
std::shared_ptr<Worker> worker;
{
mutex_lock l(mu_);
worker = workers_[task->worker_address];
std::shared_ptr<const Dataset> dataset;
TF_RETURN_IF_ERROR(state_.DatasetFromId(task->dataset_id, &dataset));
*task_def->mutable_dataset() = dataset->dataset_def;
}
if (!worker) {
return errors::NotFound("No worker found for address ",
task->worker_address);
}
task_def->set_task_id(task->task_id);
ProcessTaskResponse resp;
TF_RETURN_IF_ERROR(EnsureWorkerStubInitialized(worker.get()));
grpc::Status s = worker->stub->ProcessTask(&client_ctx, req, &resp);
WorkerService::Stub* stub;
TF_RETURN_IF_ERROR(GetOrCreateWorkerStub(task->worker_address, &stub));
grpc::Status s = stub->ProcessTask(&client_ctx, req, &resp);
if (!s.ok()) {
return grpc_util::WrapError(
absl::StrCat("Failed to submit task to worker ", worker->address), s);
absl::StrCat("Failed to submit task to worker ", task->worker_address),
s);
}
return Status::OK();
}
@ -391,7 +414,8 @@ Status DataServiceDispatcherImpl::GetTasks(const GetTasksRequest* request,
for (const auto& task : tasks) {
TaskInfo* task_info = response->mutable_task_info()->Add();
task_info->set_worker_address(task->worker_address);
task_info->set_id(task->task_id);
task_info->set_task_id(task->task_id);
task_info->set_job_id(task->job_id);
}
std::shared_ptr<const Job> job;
TF_RETURN_IF_ERROR(state_.JobFromId(request->job_id(), &job));
@ -405,13 +429,12 @@ Status DataServiceDispatcherImpl::GetWorkers(const GetWorkersRequest* request,
GetWorkersResponse* response) {
mutex_lock l(mu_);
VLOG(3) << "Enter GetWorkers";
for (const auto& it : workers_) {
std::shared_ptr<Worker> worker = it.second;
std::vector<std::shared_ptr<const Worker>> workers = state_.ListWorkers();
for (const auto& worker : workers) {
WorkerInfo* info = response->add_workers();
info->set_address(worker->address);
info->set_id(worker->worker_id);
}
VLOG(3) << "Returning list of " << workers_.size()
VLOG(3) << "Returning list of " << response->workers_size()
<< " workers from GetWorkers";
return Status::OK();
}

View File

@ -71,21 +71,15 @@ class DataServiceDispatcherImpl {
GetWorkersResponse* response);
private:
struct Worker {
Worker(int64 worker_id, const std::string& address)
: worker_id(worker_id), address(address) {}
const int64 worker_id;
const std::string address;
std::unique_ptr<WorkerService::Stub> stub;
};
// Registers a dataset with the given fingerprint, storing the new dataset's
// id in `*dataset-id`.
Status RegisterDataset(uint64 fingerprint, const DatasetDef& dataset,
int64* dataset_id) EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Initializes a workers stub, if it hasn't been initialized already.
Status EnsureWorkerStubInitialized(Worker* worker);
// Gets a worker's stub from `worker_stubs_`, or if none exists, creates a
// stub and stores it in `worker_stubs_`.
Status GetOrCreateWorkerStub(const std::string& worker_address,
WorkerService::Stub** out_stub)
LOCKS_EXCLUDED(mu_);
// Creates a job and stores it in `*job`. This method updates the
// dispatcher state with the new job, but does not assign tasks to workers.
Status CreateJob(int64 dataset_id, ProcessingMode processing_mode,
@ -128,12 +122,11 @@ class DataServiceDispatcherImpl {
mutex mu_;
int64 next_worker_id_ TF_GUARDED_BY(mu_) = 0;
int64 next_task_id_ TF_GUARDED_BY(mu_) = 0;
// Registered workers, keyed by their addresses.
absl::flat_hash_map<std::string, std::shared_ptr<Worker>> workers_
TF_GUARDED_BY(mu_);
// Cached worker stubs for communicating with workers.
absl::flat_hash_map<std::string, std::unique_ptr<WorkerService::Stub>>
worker_stubs_ TF_GUARDED_BY(mu_);
absl::optional<std::unique_ptr<JournalWriter>> journal_writer_
TF_GUARDED_BY(mu_);

View File

@ -30,6 +30,9 @@ Status DispatcherState::Apply(Update update) {
case Update::kRegisterDataset:
RegisterDataset(update.register_dataset());
break;
case Update::kRegisterWorker:
RegisterWorker(update.register_worker());
break;
case Update::kCreateJob:
CreateJob(update.create_job());
break;
@ -59,6 +62,13 @@ void DispatcherState::RegisterDataset(
next_available_dataset_id_ = std::max(next_available_dataset_id_, id + 1);
}
void DispatcherState::RegisterWorker(
const RegisterWorkerUpdate& register_worker) {
std::string address = register_worker.worker_address();
DCHECK(!workers_.contains(address));
workers_[address] = std::make_shared<Worker>(address);
}
void DispatcherState::CreateJob(const CreateJobUpdate& create_job) {
int64 job_id = create_job.job_id();
absl::optional<NamedJobKey> named_job_key;
@ -71,6 +81,7 @@ void DispatcherState::CreateJob(const CreateJobUpdate& create_job) {
named_job_key);
DCHECK(!jobs_.contains(job_id));
jobs_[job_id] = job;
tasks_by_job_[job_id] = std::vector<std::shared_ptr<Task>>();
if (named_job_key.has_value()) {
DCHECK(!named_jobs_.contains(named_job_key.value()));
named_jobs_[named_job_key.value()] = job;
@ -129,6 +140,26 @@ Status DispatcherState::DatasetFromFingerprint(
return Status::OK();
}
// Looks up the worker registered at `address`. On success, stores it in
// `*worker` and returns OK; otherwise returns a NOT_FOUND status.
Status DispatcherState::WorkerFromAddress(
    const std::string& address, std::shared_ptr<const Worker>* worker) const {
  const auto it = workers_.find(address);
  if (it != workers_.end()) {
    *worker = it->second;
    return Status::OK();
  }
  return errors::NotFound("Worker with address ", address, " not found.");
}
// Returns a snapshot of all registered workers as read-only shared
// pointers, in unspecified (hash map) order.
std::vector<std::shared_ptr<const DispatcherState::Worker>>
DispatcherState::ListWorkers() const {
  std::vector<std::shared_ptr<const Worker>> result;
  result.reserve(workers_.size());
  for (const auto& entry : workers_) {
    result.push_back(entry.second);
  }
  return result;
}
std::vector<std::shared_ptr<const DispatcherState::Job>>
DispatcherState::ListJobs() {
std::vector<std::shared_ptr<const DispatcherState::Job>> jobs;

View File

@ -48,11 +48,6 @@ namespace data {
// DispatcherImpl and for providing DispatcherImpl with read-only access to
// the state.
//
// Note that not all state needs to be journaled, and in general we journal
// as little state as possible. For example, worker and task state doesn't need
// to be journaled because we can recover that information from workers when
// they reconnect to a restarted dispatcher.
//
// DispatcherState is thread-compatible but not thread-safe.
class DispatcherState {
public:
@ -65,7 +60,8 @@ class DispatcherState {
// A dataset registered with the dispatcher.
struct Dataset {
Dataset(int64 dataset_id, int64 fingerprint, const DatasetDef& dataset_def)
explicit Dataset(int64 dataset_id, int64 fingerprint,
const DatasetDef& dataset_def)
: dataset_id(dataset_id),
fingerprint(fingerprint),
dataset_def(dataset_def) {}
@ -75,10 +71,17 @@ class DispatcherState {
const DatasetDef dataset_def;
};
// A worker registered with the dispatcher.
struct Worker {
explicit Worker(const std::string& address) : address(address) {}
const std::string address;
};
// A key for identifying a named job. The key contains a user-specified name,
// as well as an index describing which iteration of the job we are on.
struct NamedJobKey {
NamedJobKey(absl::string_view name, int64 index)
explicit NamedJobKey(absl::string_view name, int64 index)
: name(name), index(index) {}
friend bool operator==(const NamedJobKey& lhs, const NamedJobKey& rhs) {
@ -96,8 +99,8 @@ class DispatcherState {
// A job for processing a dataset.
struct Job {
Job(int64 job_id, int64 dataset_id, ProcessingMode processing_mode,
absl::optional<NamedJobKey> named_job_key)
explicit Job(int64 job_id, int64 dataset_id, ProcessingMode processing_mode,
absl::optional<NamedJobKey> named_job_key)
: job_id(job_id),
dataset_id(dataset_id),
processing_mode(processing_mode),
@ -111,8 +114,8 @@ class DispatcherState {
};
struct Task {
Task(int64 task_id, int64 job_id, int64 dataset_id,
const std::string& worker_address)
explicit Task(int64 task_id, int64 job_id, int64 dataset_id,
const std::string& worker_address)
: task_id(task_id),
job_id(job_id),
dataset_id(dataset_id),
@ -134,6 +137,12 @@ class DispatcherState {
Status DatasetFromFingerprint(uint64 fingerprint,
std::shared_ptr<const Dataset>* dataset) const;
// Gets a worker by address. Returns NOT_FOUND if there is no such worker.
Status WorkerFromAddress(const std::string& address,
std::shared_ptr<const Worker>* worker) const;
// Lists all workers registered with the dispatcher.
std::vector<std::shared_ptr<const Worker>> ListWorkers() const;
// Returns the next available job id.
int64 NextAvailableJobId() const;
// Returns a list of all jobs.
@ -153,8 +162,8 @@ class DispatcherState {
std::vector<std::shared_ptr<const Task>>* tasks) const;
private:
// Registers a dataset. The dataset must not already be registered.
void RegisterDataset(const RegisterDatasetUpdate& register_dataset);
void RegisterWorker(const RegisterWorkerUpdate& register_worker);
void CreateJob(const CreateJobUpdate& create_job);
void CreateTask(const CreateTaskUpdate& create_task);
void FinishTask(const FinishTaskUpdate& finish_task);
@ -166,6 +175,9 @@ class DispatcherState {
absl::flat_hash_map<uint64, std::shared_ptr<Dataset>>
datasets_by_fingerprint_;
// Registered workers, keyed by address.
absl::flat_hash_map<std::string, std::shared_ptr<Worker>> workers_;
int64 next_available_job_id_ = 0;
// Jobs, keyed by job ids.
absl::flat_hash_map<int64, std::shared_ptr<Job>> jobs_;

View File

@ -14,6 +14,8 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/data/service/dispatcher_state.h"
#include <memory>
#include "tensorflow/core/data/service/common.pb.h"
#include "tensorflow/core/data/service/journal.h"
#include "tensorflow/core/data/service/journal.pb.h"
@ -27,9 +29,11 @@ namespace data {
namespace {
using Dataset = DispatcherState::Dataset;
using Worker = DispatcherState::Worker;
using NamedJobKey = DispatcherState::NamedJobKey;
using Job = DispatcherState::Job;
using Task = DispatcherState::Task;
using ::testing::IsEmpty;
using ::testing::SizeIs;
Status RegisterDatasetWithIdAndFingerprint(int64 id, uint64 fingerprint,
@ -42,6 +46,13 @@ Status RegisterDatasetWithIdAndFingerprint(int64 id, uint64 fingerprint,
return Status::OK();
}
// Test helper: applies a RegisterWorkerUpdate for `worker_address` to
// `state`, returning the status of the Apply call.
Status RegisterWorker(std::string worker_address, DispatcherState* state) {
  Update update;
  RegisterWorkerUpdate* register_worker = update.mutable_register_worker();
  register_worker->set_worker_address(worker_address);
  return state->Apply(update);
}
Status CreateAnonymousJob(int64 job_id, int64 dataset_id,
DispatcherState* state) {
Update update;
@ -98,12 +109,12 @@ TEST(DispatcherState, RegisterDataset) {
{
std::shared_ptr<const Dataset> dataset;
TF_EXPECT_OK(state.DatasetFromFingerprint(fingerprint, &dataset));
EXPECT_EQ(id, dataset->dataset_id);
EXPECT_EQ(dataset->dataset_id, id);
}
{
std::shared_ptr<const Dataset> dataset;
TF_EXPECT_OK(state.DatasetFromId(id, &dataset));
EXPECT_EQ(fingerprint, dataset->fingerprint);
EXPECT_EQ(dataset->fingerprint, fingerprint);
}
}
@ -126,10 +137,46 @@ TEST(DispatcherState, NextAvailableDatasetId) {
int64 id = state.NextAvailableDatasetId();
uint64 fingerprint = 20;
TF_EXPECT_OK(RegisterDatasetWithIdAndFingerprint(id, fingerprint, &state));
EXPECT_NE(id, state.NextAvailableDatasetId());
EXPECT_NE(state.NextAvailableDatasetId(), id);
EXPECT_EQ(state.NextAvailableDatasetId(), state.NextAvailableDatasetId());
}
// Verifies that applying a RegisterWorkerUpdate makes the worker
// retrievable by its address via WorkerFromAddress.
TEST(DispatcherState, RegisterWorker) {
  DispatcherState state;
  std::string address = "test_worker_address";
  TF_EXPECT_OK(RegisterWorker(address, &state));
  std::shared_ptr<const Worker> worker;
  TF_EXPECT_OK(state.WorkerFromAddress(address, &worker));
  EXPECT_EQ(worker->address, address);
}
// Verifies that ListWorkers tracks registrations: empty before any worker
// registers, then growing by one per registered address.
TEST(DispatcherState, ListWorkers) {
  DispatcherState state;
  std::string address_1 = "address_1";
  std::string address_2 = "address_2";
  {
    // No workers registered yet.
    std::vector<std::shared_ptr<const Worker>> workers = state.ListWorkers();
    EXPECT_THAT(workers, IsEmpty());
  }
  TF_EXPECT_OK(RegisterWorker(address_1, &state));
  {
    // One worker after the first registration.
    std::vector<std::shared_ptr<const Worker>> workers = state.ListWorkers();
    EXPECT_THAT(workers, SizeIs(1));
  }
  TF_EXPECT_OK(RegisterWorker(address_2, &state));
  {
    // Two workers after the second registration.
    std::vector<std::shared_ptr<const Worker>> workers = state.ListWorkers();
    EXPECT_THAT(workers, SizeIs(2));
  }
}
// Verifies that looking up an address that was never registered yields a
// NOT_FOUND status.
TEST(DispatcherState, MissingWorker) {
  DispatcherState state;
  std::shared_ptr<const Worker> worker;
  Status s = state.WorkerFromAddress("test_worker_address", &worker);
  EXPECT_EQ(s.code(), error::NOT_FOUND);
}
TEST(DispatcherState, UnknownUpdate) {
DispatcherState state;
Update update;
@ -146,8 +193,11 @@ TEST(DispatcherState, AnonymousJob) {
std::shared_ptr<const Job> job;
TF_EXPECT_OK(state.JobFromId(job_id, &job));
EXPECT_EQ(state.NextAvailableJobId(), job_id + 1);
EXPECT_EQ(dataset_id, job->dataset_id);
EXPECT_EQ(job_id, job->job_id);
EXPECT_EQ(job->dataset_id, dataset_id);
EXPECT_EQ(job->job_id, job_id);
std::vector<std::shared_ptr<const Task>> tasks;
TF_EXPECT_OK(state.TasksForJob(job_id, &tasks));
EXPECT_THAT(tasks, IsEmpty());
EXPECT_FALSE(job->finished);
}
@ -161,8 +211,8 @@ TEST(DispatcherState, NamedJob) {
std::shared_ptr<const Job> job;
TF_EXPECT_OK(state.NamedJobByKey(named_job_key, &job));
EXPECT_EQ(state.NextAvailableJobId(), job_id + 1);
EXPECT_EQ(dataset_id, job->dataset_id);
EXPECT_EQ(job_id, job->job_id);
EXPECT_EQ(job->dataset_id, dataset_id);
EXPECT_EQ(job->job_id, job_id);
EXPECT_FALSE(job->finished);
}
@ -179,10 +229,10 @@ TEST(DispatcherState, CreateTask) {
{
std::shared_ptr<const Task> task;
TF_EXPECT_OK(state.TaskFromId(task_id, &task));
EXPECT_EQ(task_id, task->task_id);
EXPECT_EQ(job_id, task->job_id);
EXPECT_EQ(dataset_id, task->dataset_id);
EXPECT_EQ(worker_address, task->worker_address);
EXPECT_EQ(task->task_id, task_id);
EXPECT_EQ(task->job_id, job_id);
EXPECT_EQ(task->dataset_id, dataset_id);
EXPECT_EQ(task->worker_address, worker_address);
}
{
std::vector<std::shared_ptr<const Task>> tasks;
@ -207,7 +257,7 @@ TEST(DispatcherState, CreateTasksForSameJob) {
{
std::vector<std::shared_ptr<const Task>> tasks;
TF_EXPECT_OK(state.TasksForJob(job_id, &tasks));
EXPECT_EQ(2, tasks.size());
EXPECT_THAT(tasks, SizeIs(2));
}
}
@ -229,12 +279,12 @@ TEST(DispatcherState, CreateTasksForDifferentJobs) {
{
std::vector<std::shared_ptr<const Task>> tasks;
TF_EXPECT_OK(state.TasksForJob(job_id_1, &tasks));
EXPECT_EQ(1, tasks.size());
EXPECT_THAT(tasks, SizeIs(1));
}
{
std::vector<std::shared_ptr<const Task>> tasks;
TF_EXPECT_OK(state.TasksForJob(job_id_2, &tasks));
EXPECT_EQ(1, tasks.size());
EXPECT_THAT(tasks, SizeIs(1));
}
}

View File

@ -48,7 +48,7 @@ Status FileJournalWriter::EnsureInitialized() {
return Status::OK();
}
Status FileJournalWriter::Write(Update update) {
Status FileJournalWriter::Write(const Update& update) {
TF_RETURN_IF_ERROR(EnsureInitialized());
std::string s = update.SerializeAsString();
if (s.empty()) {

View File

@ -32,7 +32,7 @@ class JournalWriter {
public:
virtual ~JournalWriter() = default;
// Writes and syncs an update to the journal.
virtual Status Write(Update update) = 0;
virtual Status Write(const Update& update) = 0;
};
// FileJournalWriter is not thread-safe, requiring external synchronization when
@ -46,7 +46,7 @@ class FileJournalWriter : public JournalWriter {
FileJournalWriter(const FileJournalWriter&) = delete;
FileJournalWriter& operator=(const FileJournalWriter&) = delete;
Status Write(Update update) override;
Status Write(const Update& update) override;
private:
// Initializes the writer if it is not yet initialized.

View File

@ -10,6 +10,7 @@ import "tensorflow/core/data/service/common.proto";
message Update {
oneof update_type {
RegisterDatasetUpdate register_dataset = 1;
RegisterWorkerUpdate register_worker = 5;
CreateJobUpdate create_job = 2;
CreateTaskUpdate create_task = 3;
FinishTaskUpdate finish_task = 4;
@ -22,6 +23,10 @@ message RegisterDatasetUpdate {
uint64 fingerprint = 3;
}
// Journal update recording that a worker registered with the dispatcher.
message RegisterWorkerUpdate {
  // The address of the registering worker, used as its identifier.
  string worker_address = 1;
}
message NamedJobKeyDef {
string name = 1;
int64 index = 2;

View File

@ -197,8 +197,6 @@ Status DataServiceWorkerImpl::Register() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
for (const TaskDef& task : resp.tasks()) {
TF_RETURN_IF_ERROR(ProcessTaskInternal(task));
}
worker_id_ = resp.worker_id();
VLOG(3) << "Registered worker with id " << resp.worker_id();
return Status::OK();
}
@ -207,7 +205,6 @@ Status DataServiceWorkerImpl::SendTaskUpdate() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
<< " task updates to dispatcher";
TF_RETURN_IF_ERROR(EnsureDispatcherStubInitialized());
WorkerUpdateRequest req;
req.set_worker_id(worker_id_);
for (int task_id : pending_completed_tasks_) {
TaskProgress* update = req.add_updates();
update->set_task_id(task_id);

View File

@ -338,7 +338,7 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
}
absl::flat_hash_map<int64, TaskInfo> task_id_to_task;
for (auto& task : tasks) {
task_id_to_task[task.id()] = task;
task_id_to_task[task.task_id()] = task;
}
mutex_lock l(mu_);
job_finished_ = job_finished;
@ -371,8 +371,9 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
get_next_cv_.notify_all();
continue;
}
tasks_.push_back(std::make_shared<Task>(
task_info.id(), task_info.worker_address(), std::move(worker)));
tasks_.push_back(std::make_shared<Task>(task_info.task_id(),
task_info.worker_address(),
std::move(worker)));
}
if (dataset()->max_outstanding_requests_ == model::kAutotune) {
// Adjust max_outstanding_requests to account for newly added tasks.

View File

@ -12,6 +12,9 @@ message DispatcherConfig {
// An optional work directory to use for storing dispatcher state, and for
// recovering during restarts.
string work_dir = 3;
// Whether to run in fault tolerant mode, where dispatcher state is saved
// across restarts.
bool fault_tolerant_mode = 4;
}
// Configuration for a tf.data service WorkerServer.