[tf.data service] Refactor master data structures to be classes instead of structs.

This ensures that they are constructed with all the required fields.

This CL improves consistency within the master implementation, and removes some unused fields.

PiperOrigin-RevId: 307985751
Change-Id: I30dcaa84d7ff3e214f0cf0320c1e993b30de5abd
This commit is contained in:
Andrew Audibert 2020-04-22 23:16:31 -07:00 committed by TensorFlower Gardener
parent e4341f21c4
commit 5d615f2f05
2 changed files with 133 additions and 77 deletions

View File

@ -15,6 +15,10 @@ limitations under the License.
#include "tensorflow/core/data/service/master_impl.h" #include "tensorflow/core/data/service/master_impl.h"
#include <memory>
#include <tuple>
#include <utility>
#include "grpcpp/create_channel.h" #include "grpcpp/create_channel.h"
#include "grpcpp/impl/codegen/server_context.h" #include "grpcpp/impl/codegen/server_context.h"
#include "grpcpp/security/credentials.h" #include "grpcpp/security/credentials.h"
@ -55,32 +59,23 @@ Status DataServiceMasterImpl::RegisterWorker(
VLOG(3) << "Received register worker request"; VLOG(3) << "Received register worker request";
mutex_lock l(mu_); mutex_lock l(mu_);
int64 worker_id = next_worker_id_++; int64 worker_id = next_worker_id_++;
workers_.emplace_back(); workers_.emplace_back(worker_id, request->worker_address());
workers_.back().address = request->worker_address();
workers_.back().id = worker_id;
response->set_worker_id(worker_id); response->set_worker_id(worker_id);
// Allocate tasks to the worker. // Allocate tasks to the worker.
for (auto& entry : jobs_) { for (auto& entry : jobs_) {
Job& job = entry.second; Job& job = entry.second;
if (job.finished) { if (job.finished()) {
continue; continue;
} }
int64 task_id = next_task_id_++; int64 task_id = CreateTask(&job, request->worker_address());
DCHECK(!tasks_.contains(task_id));
Task& task = tasks_[task_id];
task.id = task_id;
task.dataset_id = job.dataset_id;
task.worker_address = request->worker_address();
job.task_ids.push_back(task_id);
job.total_tasks++;
TaskDef* task_def = response->add_tasks(); TaskDef* task_def = response->add_tasks();
*task_def->mutable_dataset() = *task_def->mutable_dataset() =
datasets_by_id_[task.dataset_id]->dataset_def; datasets_by_id_[job.dataset_id()]->dataset_def();
task_def->set_dataset_id(task.dataset_id); task_def->set_dataset_id(job.dataset_id());
task_def->set_job_id(job.id); task_def->set_job_id(job.job_id());
task_def->set_task_id(task.id); task_def->set_task_id(task_id);
} }
VLOG(1) << "Registered worker " << workers_.back().DebugString(); VLOG(1) << "Registered worker " << workers_.back().DebugString();
@ -96,7 +91,7 @@ Status DataServiceMasterImpl::GetOrRegisterDataset(
VLOG(3) << "Registering dataset graph: " VLOG(3) << "Registering dataset graph: "
<< request->dataset().graph().DebugString(); << request->dataset().graph().DebugString();
if (datasets_by_fingerprint_.contains(fingerprint)) { if (datasets_by_fingerprint_.contains(fingerprint)) {
int64 id = datasets_by_fingerprint_[fingerprint]->id; int64 id = datasets_by_fingerprint_[fingerprint]->dataset_id();
VLOG(3) << "Received duplicate RegisterDataset request with fingerprint " VLOG(3) << "Received duplicate RegisterDataset request with fingerprint "
<< fingerprint << ". Returning id " << id; << fingerprint << ". Returning id " << id;
response->set_dataset_id(id); response->set_dataset_id(id);
@ -112,11 +107,9 @@ Status DataServiceMasterImpl::GetOrRegisterDataset(
int64 DataServiceMasterImpl::RegisterDataset(uint64 fingerprint, int64 DataServiceMasterImpl::RegisterDataset(uint64 fingerprint,
const DatasetDef& dataset) const DatasetDef& dataset)
EXCLUSIVE_LOCKS_REQUIRED(mu_) { EXCLUSIVE_LOCKS_REQUIRED(mu_) {
auto new_dataset = std::make_shared<Dataset>();
int64 dataset_id = next_dataset_id_++; int64 dataset_id = next_dataset_id_++;
new_dataset->id = dataset_id; auto new_dataset =
new_dataset->fingerprint = fingerprint; std::make_shared<Dataset>(dataset_id, fingerprint, dataset);
new_dataset->dataset_def = dataset;
DCHECK(!datasets_by_id_.contains(dataset_id)); DCHECK(!datasets_by_id_.contains(dataset_id));
datasets_by_id_[dataset_id] = new_dataset; datasets_by_id_[dataset_id] = new_dataset;
@ -148,25 +141,18 @@ Status DataServiceMasterImpl::CreateJob(const CreateJobRequest* request,
int64 job_id = next_job_id_++; int64 job_id = next_job_id_++;
DCHECK(!jobs_.contains(job_id)); DCHECK(!jobs_.contains(job_id));
Job& job = jobs_[job_id]; auto result =
job.id = job_id; jobs_.emplace(std::piecewise_construct, std::forward_as_tuple(job_id),
job.dataset_id = request->dataset_id(); std::forward_as_tuple(job_id, request->dataset_id()));
DCHECK(result.second);
Job& job = result.first->second;
response->set_job_id(job_id); response->set_job_id(job_id);
for (auto& worker : workers_) { for (auto& worker : workers_) {
int64 task_id = next_task_id_++; int64 task_id = CreateTask(&job, worker.address());
DCHECK(!tasks_.contains(task_id));
Task& task = tasks_[task_id];
task.id = task_id;
task.dataset_id = request->dataset_id();
task.worker_address = worker.address;
job.task_ids.push_back(task_id);
std::unique_ptr<WorkerService::Stub> stub;
TF_RETURN_IF_ERROR(CreateWorkerStub(worker.address, protocol_, &stub));
// TODO(aaudibert): perform these calls asynchronously. // TODO(aaudibert): perform these calls asynchronously.
TF_RETURN_IF_ERROR(AllocateTaskToWorker(task, &worker)); TF_RETURN_IF_ERROR(AllocateTaskToWorker(tasks_.at(task_id), &worker));
job.total_tasks++;
} }
VLOG(3) << "Beginning job " << job_id << " for dataset " VLOG(3) << "Beginning job " << job_id << " for dataset "
@ -174,25 +160,45 @@ Status DataServiceMasterImpl::CreateJob(const CreateJobRequest* request,
return Status::OK(); return Status::OK();
} }
Status DataServiceMasterImpl::AllocateTaskToWorker(const Task& task, int64 DataServiceMasterImpl::CreateTask(Job* job,
WorkerInfo* worker) const std::string& worker_address)
EXCLUSIVE_LOCKS_REQUIRED(mu_) { EXCLUSIVE_LOCKS_REQUIRED(mu_) {
if (!worker->stub) { int64 task_id = next_task_id_++;
TF_RETURN_IF_ERROR( DCHECK(!tasks_.contains(task_id));
CreateWorkerStub(worker->address, protocol_, &worker->stub)); auto result =
tasks_.emplace(std::piecewise_construct, std::forward_as_tuple(task_id),
std::forward_as_tuple(task_id, job->job_id(),
job->dataset_id(), worker_address));
job->add_task_id(task_id);
DCHECK(result.second);
return task_id;
}
Status DataServiceMasterImpl::EnsureWorkerStubInitialized(Worker* worker) {
if (!worker->stub()) {
std::unique_ptr<WorkerService::Stub> stub;
TF_RETURN_IF_ERROR(CreateWorkerStub(worker->address(), protocol_, &stub));
worker->set_stub(std::move(stub));
} }
return Status::OK();
}
Status DataServiceMasterImpl::AllocateTaskToWorker(const Task& task,
Worker* worker)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
TF_RETURN_IF_ERROR(EnsureWorkerStubInitialized(worker));
grpc::ClientContext client_ctx; grpc::ClientContext client_ctx;
ProcessTaskRequest req; ProcessTaskRequest req;
req.mutable_task()->set_dataset_id(task.dataset_id); req.mutable_task()->set_dataset_id(task.dataset_id());
DCHECK(datasets_by_id_.contains(task.dataset_id)); DCHECK(datasets_by_id_.contains(task.dataset_id()));
*req.mutable_task()->mutable_dataset() = *req.mutable_task()->mutable_dataset() =
datasets_by_id_[task.dataset_id]->dataset_def; datasets_by_id_.at(task.dataset_id())->dataset_def();
req.mutable_task()->set_task_id(task.id); req.mutable_task()->set_task_id(task.task_id());
ProcessTaskResponse resp; ProcessTaskResponse resp;
grpc::Status s = worker->stub->ProcessTask(&client_ctx, req, &resp); grpc::Status s = worker->stub()->ProcessTask(&client_ctx, req, &resp);
if (!s.ok()) { if (!s.ok()) {
return grpc_util::WrapError( return grpc_util::WrapError(
absl::StrCat("Failed to submit task to worker ", worker->address), s); absl::StrCat("Failed to submit task to worker ", worker->address()), s);
} }
return Status::OK(); return Status::OK();
} }
@ -207,16 +213,15 @@ Status DataServiceMasterImpl::GetTasks(const GetTasksRequest* request,
"> not found."); "> not found.");
} }
Job& job = it->second; Job& job = it->second;
for (const auto& task_id : job.task_ids) { for (const auto& task_id : job.task_ids()) {
auto task_iter = tasks_.find(task_id); auto task_iter = tasks_.find(task_id);
DCHECK(task_iter != tasks_.end()); DCHECK(task_iter != tasks_.end());
Task& task = task_iter->second; Task& task = task_iter->second;
TaskInfo* task_info = response->mutable_task_info()->Add(); TaskInfo* task_info = response->mutable_task_info()->Add();
task_info->set_worker_address(task.worker_address); task_info->set_worker_address(task.worker_address());
task_info->set_id(task.id); task_info->set_id(task.task_id());
} }
job.finished = job.total_tasks > 0 && job.task_ids.empty(); response->set_job_finished(false);
response->set_job_finished(job.finished);
VLOG(3) << "Found " << response->task_info_size() << " tasks for job id " VLOG(3) << "Found " << response->task_info_size() << " tasks for job id "
<< request->job_id(); << request->job_id();
return Status::OK(); return Status::OK();

View File

@ -57,41 +57,92 @@ class DataServiceMasterImpl {
Status GetTasks(const GetTasksRequest* request, GetTasksResponse* response); Status GetTasks(const GetTasksRequest* request, GetTasksResponse* response);
private: private:
typedef struct WorkerInfo { class Worker {
std::string address; public:
int64 id; Worker(int64 worker_id, const std::string address)
std::unique_ptr<WorkerService::Stub> stub; : worker_id_(worker_id), address_(address) {}
int64 worker_id() { return worker_id_; }
std::string address() { return address_; }
WorkerService::Stub* stub() { return stub_.get(); }
void set_stub(std::unique_ptr<WorkerService::Stub> stub) {
stub_ = std::move(stub);
}
std::string DebugString() { std::string DebugString() {
return absl::StrCat("id: ", id, "address: ", address); return absl::StrCat("id: ", worker_id_, "address: ", address_);
} }
} WorkerInfo;
typedef struct Dataset { private:
int64 id; const int64 worker_id_;
int64 fingerprint; const std::string address_;
DatasetDef dataset_def; std::unique_ptr<WorkerService::Stub> stub_;
} Dataset; };
typedef struct Job { class Dataset {
int64 id; public:
int64 dataset_id; Dataset(int64 dataset_id, int64 fingerprint, const DatasetDef& dataset_def)
std::vector<int64> task_ids; : dataset_id_(dataset_id),
// The total number of tasks that have been created for this job. fingerprint_(fingerprint),
int64 total_tasks = 0; dataset_def_(dataset_def) {}
bool finished = false;
} Job;
typedef struct Task { int64 dataset_id() const { return dataset_id_; }
int64 id; int64 fingerprint() const { return fingerprint_; }
int64 dataset_id; const DatasetDef& dataset_def() { return dataset_def_; }
std::string worker_address;
} Task; private:
const int64 dataset_id_;
const int64 fingerprint_;
const DatasetDef dataset_def_;
};
class Job {
public:
Job(int64 job_id, int64 dataset_id)
: job_id_(job_id), dataset_id_(dataset_id) {}
int64 job_id() const { return job_id_; }
int64 dataset_id() const { return dataset_id_; }
const std::vector<int64>& task_ids() const { return task_ids_; }
void add_task_id(int64 task_id) { task_ids_.push_back(task_id); }
bool finished() const { return finished_; }
private:
const int64 job_id_;
const int64 dataset_id_;
std::vector<int64> task_ids_;
bool finished_ = false;
};
class Task {
public:
Task(int64 task_id, int64 job_id, int64 dataset_id,
const std::string& worker_address)
: task_id_(task_id),
job_id_(job_id),
dataset_id_(dataset_id),
worker_address_(worker_address) {}
int64 task_id() const { return task_id_; }
int64 job_id() const { return job_id_; }
int64 dataset_id() const { return dataset_id_; }
std::string worker_address() const { return worker_address_; }
private:
const int64 task_id_;
const int64 job_id_;
const int64 dataset_id_;
const std::string worker_address_;
};
// Registers a dataset with the given fingerprint, returning a new dataset id. // Registers a dataset with the given fingerprint, returning a new dataset id.
int64 RegisterDataset(uint64 fingerprint, const DatasetDef& dataset); int64 RegisterDataset(uint64 fingerprint, const DatasetDef& dataset);
// Initializes a worker's stub, if it hasn't been initialized already.
Status EnsureWorkerStubInitialized(Worker* worker);
// Instructs a worker to begin processing a task. // Instructs a worker to begin processing a task.
Status AllocateTaskToWorker(const Task& task_id, WorkerInfo* worker); Status AllocateTaskToWorker(const Task& task_id, Worker* worker);
// Creates a new task for a job, returning the new task's id.
int64 CreateTask(Job* job, const std::string& worker_address);
// Protocol to use for communicating with workers. // Protocol to use for communicating with workers.
const std::string protocol_; const std::string protocol_;
@ -104,7 +155,7 @@ class DataServiceMasterImpl {
int64 next_task_id_ TF_GUARDED_BY(mu_) = 0; int64 next_task_id_ TF_GUARDED_BY(mu_) = 0;
// Registered workers. // Registered workers.
std::vector<WorkerInfo> workers_ TF_GUARDED_BY(mu_); std::vector<Worker> workers_ TF_GUARDED_BY(mu_);
// Registered datasets, keyed by dataset ids. // Registered datasets, keyed by dataset ids.
absl::flat_hash_map<int64, std::shared_ptr<Dataset>> datasets_by_id_ absl::flat_hash_map<int64, std::shared_ptr<Dataset>> datasets_by_id_
TF_GUARDED_BY(mu_); TF_GUARDED_BY(mu_);