[tf.data service] Refactor master data structures to be classes instead of structs.
This ensures that they are constructed with all the required fields. This CL improves consistency within the master implementation and removes some unused fields.

PiperOrigin-RevId: 307985751
Change-Id: I30dcaa84d7ff3e214f0cf0320c1e993b30de5abd
This commit is contained in:
parent
e4341f21c4
commit
5d615f2f05
@ -15,6 +15,10 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/data/service/master_impl.h"
|
||||
|
||||
#include <memory>
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
|
||||
#include "grpcpp/create_channel.h"
|
||||
#include "grpcpp/impl/codegen/server_context.h"
|
||||
#include "grpcpp/security/credentials.h"
|
||||
@ -55,32 +59,23 @@ Status DataServiceMasterImpl::RegisterWorker(
|
||||
VLOG(3) << "Received register worker request";
|
||||
mutex_lock l(mu_);
|
||||
int64 worker_id = next_worker_id_++;
|
||||
workers_.emplace_back();
|
||||
workers_.back().address = request->worker_address();
|
||||
workers_.back().id = worker_id;
|
||||
workers_.emplace_back(worker_id, request->worker_address());
|
||||
response->set_worker_id(worker_id);
|
||||
|
||||
// Allocate tasks to the worker.
|
||||
for (auto& entry : jobs_) {
|
||||
Job& job = entry.second;
|
||||
if (job.finished) {
|
||||
if (job.finished()) {
|
||||
continue;
|
||||
}
|
||||
int64 task_id = next_task_id_++;
|
||||
DCHECK(!tasks_.contains(task_id));
|
||||
Task& task = tasks_[task_id];
|
||||
task.id = task_id;
|
||||
task.dataset_id = job.dataset_id;
|
||||
task.worker_address = request->worker_address();
|
||||
job.task_ids.push_back(task_id);
|
||||
job.total_tasks++;
|
||||
int64 task_id = CreateTask(&job, request->worker_address());
|
||||
|
||||
TaskDef* task_def = response->add_tasks();
|
||||
*task_def->mutable_dataset() =
|
||||
datasets_by_id_[task.dataset_id]->dataset_def;
|
||||
task_def->set_dataset_id(task.dataset_id);
|
||||
task_def->set_job_id(job.id);
|
||||
task_def->set_task_id(task.id);
|
||||
datasets_by_id_[job.dataset_id()]->dataset_def();
|
||||
task_def->set_dataset_id(job.dataset_id());
|
||||
task_def->set_job_id(job.job_id());
|
||||
task_def->set_task_id(task_id);
|
||||
}
|
||||
|
||||
VLOG(1) << "Registered worker " << workers_.back().DebugString();
|
||||
@ -96,7 +91,7 @@ Status DataServiceMasterImpl::GetOrRegisterDataset(
|
||||
VLOG(3) << "Registering dataset graph: "
|
||||
<< request->dataset().graph().DebugString();
|
||||
if (datasets_by_fingerprint_.contains(fingerprint)) {
|
||||
int64 id = datasets_by_fingerprint_[fingerprint]->id;
|
||||
int64 id = datasets_by_fingerprint_[fingerprint]->dataset_id();
|
||||
VLOG(3) << "Received duplicate RegisterDataset request with fingerprint "
|
||||
<< fingerprint << ". Returning id " << id;
|
||||
response->set_dataset_id(id);
|
||||
@ -112,11 +107,9 @@ Status DataServiceMasterImpl::GetOrRegisterDataset(
|
||||
int64 DataServiceMasterImpl::RegisterDataset(uint64 fingerprint,
|
||||
const DatasetDef& dataset)
|
||||
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
auto new_dataset = std::make_shared<Dataset>();
|
||||
int64 dataset_id = next_dataset_id_++;
|
||||
new_dataset->id = dataset_id;
|
||||
new_dataset->fingerprint = fingerprint;
|
||||
new_dataset->dataset_def = dataset;
|
||||
auto new_dataset =
|
||||
std::make_shared<Dataset>(dataset_id, fingerprint, dataset);
|
||||
|
||||
DCHECK(!datasets_by_id_.contains(dataset_id));
|
||||
datasets_by_id_[dataset_id] = new_dataset;
|
||||
@ -148,25 +141,18 @@ Status DataServiceMasterImpl::CreateJob(const CreateJobRequest* request,
|
||||
|
||||
int64 job_id = next_job_id_++;
|
||||
DCHECK(!jobs_.contains(job_id));
|
||||
Job& job = jobs_[job_id];
|
||||
job.id = job_id;
|
||||
job.dataset_id = request->dataset_id();
|
||||
auto result =
|
||||
jobs_.emplace(std::piecewise_construct, std::forward_as_tuple(job_id),
|
||||
std::forward_as_tuple(job_id, request->dataset_id()));
|
||||
DCHECK(result.second);
|
||||
Job& job = result.first->second;
|
||||
response->set_job_id(job_id);
|
||||
|
||||
for (auto& worker : workers_) {
|
||||
int64 task_id = next_task_id_++;
|
||||
DCHECK(!tasks_.contains(task_id));
|
||||
Task& task = tasks_[task_id];
|
||||
task.id = task_id;
|
||||
task.dataset_id = request->dataset_id();
|
||||
task.worker_address = worker.address;
|
||||
job.task_ids.push_back(task_id);
|
||||
int64 task_id = CreateTask(&job, worker.address());
|
||||
|
||||
std::unique_ptr<WorkerService::Stub> stub;
|
||||
TF_RETURN_IF_ERROR(CreateWorkerStub(worker.address, protocol_, &stub));
|
||||
// TODO(aaudibert): perform these calls asynchronously.
|
||||
TF_RETURN_IF_ERROR(AllocateTaskToWorker(task, &worker));
|
||||
job.total_tasks++;
|
||||
TF_RETURN_IF_ERROR(AllocateTaskToWorker(tasks_.at(task_id), &worker));
|
||||
}
|
||||
|
||||
VLOG(3) << "Beginning job " << job_id << " for dataset "
|
||||
@ -174,25 +160,45 @@ Status DataServiceMasterImpl::CreateJob(const CreateJobRequest* request,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DataServiceMasterImpl::AllocateTaskToWorker(const Task& task,
|
||||
WorkerInfo* worker)
|
||||
int64 DataServiceMasterImpl::CreateTask(Job* job,
|
||||
const std::string& worker_address)
|
||||
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
if (!worker->stub) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
CreateWorkerStub(worker->address, protocol_, &worker->stub));
|
||||
int64 task_id = next_task_id_++;
|
||||
DCHECK(!tasks_.contains(task_id));
|
||||
auto result =
|
||||
tasks_.emplace(std::piecewise_construct, std::forward_as_tuple(task_id),
|
||||
std::forward_as_tuple(task_id, job->job_id(),
|
||||
job->dataset_id(), worker_address));
|
||||
job->add_task_id(task_id);
|
||||
DCHECK(result.second);
|
||||
return task_id;
|
||||
}
|
||||
|
||||
// Makes sure `worker` has a stub attached. If `worker->stub()` is already
// non-null this is a no-op; otherwise a stub is created for the worker's
// address using `protocol_` and handed to the worker.
//
// Returns the error from CreateWorkerStub if stub creation fails (in which
// case the worker is left without a stub), and OK otherwise.
Status DataServiceMasterImpl::EnsureWorkerStubInitialized(Worker* worker) {
  // Fast path: a stub was attached by an earlier call.
  if (worker->stub()) {
    return Status::OK();
  }

  std::unique_ptr<WorkerService::Stub> new_stub;
  TF_RETURN_IF_ERROR(
      CreateWorkerStub(worker->address(), protocol_, &new_stub));
  // Ownership of the freshly created stub moves into the worker.
  worker->set_stub(std::move(new_stub));
  return Status::OK();
}
|
||||
|
||||
Status DataServiceMasterImpl::AllocateTaskToWorker(const Task& task,
|
||||
Worker* worker)
|
||||
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
TF_RETURN_IF_ERROR(EnsureWorkerStubInitialized(worker));
|
||||
grpc::ClientContext client_ctx;
|
||||
ProcessTaskRequest req;
|
||||
req.mutable_task()->set_dataset_id(task.dataset_id);
|
||||
DCHECK(datasets_by_id_.contains(task.dataset_id));
|
||||
req.mutable_task()->set_dataset_id(task.dataset_id());
|
||||
DCHECK(datasets_by_id_.contains(task.dataset_id()));
|
||||
*req.mutable_task()->mutable_dataset() =
|
||||
datasets_by_id_[task.dataset_id]->dataset_def;
|
||||
req.mutable_task()->set_task_id(task.id);
|
||||
datasets_by_id_.at(task.dataset_id())->dataset_def();
|
||||
req.mutable_task()->set_task_id(task.task_id());
|
||||
ProcessTaskResponse resp;
|
||||
grpc::Status s = worker->stub->ProcessTask(&client_ctx, req, &resp);
|
||||
grpc::Status s = worker->stub()->ProcessTask(&client_ctx, req, &resp);
|
||||
if (!s.ok()) {
|
||||
return grpc_util::WrapError(
|
||||
absl::StrCat("Failed to submit task to worker ", worker->address), s);
|
||||
absl::StrCat("Failed to submit task to worker ", worker->address()), s);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
@ -207,16 +213,15 @@ Status DataServiceMasterImpl::GetTasks(const GetTasksRequest* request,
|
||||
"> not found.");
|
||||
}
|
||||
Job& job = it->second;
|
||||
for (const auto& task_id : job.task_ids) {
|
||||
for (const auto& task_id : job.task_ids()) {
|
||||
auto task_iter = tasks_.find(task_id);
|
||||
DCHECK(task_iter != tasks_.end());
|
||||
Task& task = task_iter->second;
|
||||
TaskInfo* task_info = response->mutable_task_info()->Add();
|
||||
task_info->set_worker_address(task.worker_address);
|
||||
task_info->set_id(task.id);
|
||||
task_info->set_worker_address(task.worker_address());
|
||||
task_info->set_id(task.task_id());
|
||||
}
|
||||
job.finished = job.total_tasks > 0 && job.task_ids.empty();
|
||||
response->set_job_finished(job.finished);
|
||||
response->set_job_finished(false);
|
||||
VLOG(3) << "Found " << response->task_info_size() << " tasks for job id "
|
||||
<< request->job_id();
|
||||
return Status::OK();
|
||||
|
@ -57,41 +57,92 @@ class DataServiceMasterImpl {
|
||||
Status GetTasks(const GetTasksRequest* request, GetTasksResponse* response);
|
||||
|
||||
private:
|
||||
typedef struct WorkerInfo {
|
||||
std::string address;
|
||||
int64 id;
|
||||
std::unique_ptr<WorkerService::Stub> stub;
|
||||
class Worker {
|
||||
public:
|
||||
Worker(int64 worker_id, const std::string address)
|
||||
: worker_id_(worker_id), address_(address) {}
|
||||
|
||||
int64 worker_id() { return worker_id_; }
|
||||
std::string address() { return address_; }
|
||||
WorkerService::Stub* stub() { return stub_.get(); }
|
||||
void set_stub(std::unique_ptr<WorkerService::Stub> stub) {
|
||||
stub_ = std::move(stub);
|
||||
}
|
||||
|
||||
std::string DebugString() {
|
||||
return absl::StrCat("id: ", id, "address: ", address);
|
||||
return absl::StrCat("id: ", worker_id_, "address: ", address_);
|
||||
}
|
||||
} WorkerInfo;
|
||||
|
||||
typedef struct Dataset {
|
||||
int64 id;
|
||||
int64 fingerprint;
|
||||
DatasetDef dataset_def;
|
||||
} Dataset;
|
||||
private:
|
||||
const int64 worker_id_;
|
||||
const std::string address_;
|
||||
std::unique_ptr<WorkerService::Stub> stub_;
|
||||
};
|
||||
|
||||
typedef struct Job {
|
||||
int64 id;
|
||||
int64 dataset_id;
|
||||
std::vector<int64> task_ids;
|
||||
// The total number of tasks that have been created for this job.
|
||||
int64 total_tasks = 0;
|
||||
bool finished = false;
|
||||
} Job;
|
||||
class Dataset {
|
||||
public:
|
||||
Dataset(int64 dataset_id, int64 fingerprint, const DatasetDef& dataset_def)
|
||||
: dataset_id_(dataset_id),
|
||||
fingerprint_(fingerprint),
|
||||
dataset_def_(dataset_def) {}
|
||||
|
||||
typedef struct Task {
|
||||
int64 id;
|
||||
int64 dataset_id;
|
||||
std::string worker_address;
|
||||
} Task;
|
||||
int64 dataset_id() const { return dataset_id_; }
|
||||
int64 fingerprint() const { return fingerprint_; }
|
||||
const DatasetDef& dataset_def() { return dataset_def_; }
|
||||
|
||||
private:
|
||||
const int64 dataset_id_;
|
||||
const int64 fingerprint_;
|
||||
const DatasetDef dataset_def_;
|
||||
};
|
||||
|
||||
class Job {
|
||||
public:
|
||||
Job(int64 job_id, int64 dataset_id)
|
||||
: job_id_(job_id), dataset_id_(dataset_id) {}
|
||||
|
||||
int64 job_id() const { return job_id_; }
|
||||
int64 dataset_id() const { return dataset_id_; }
|
||||
const std::vector<int64>& task_ids() const { return task_ids_; }
|
||||
void add_task_id(int64 task_id) { task_ids_.push_back(task_id); }
|
||||
bool finished() const { return finished_; }
|
||||
|
||||
private:
|
||||
const int64 job_id_;
|
||||
const int64 dataset_id_;
|
||||
std::vector<int64> task_ids_;
|
||||
bool finished_ = false;
|
||||
};
|
||||
|
||||
class Task {
|
||||
public:
|
||||
Task(int64 task_id, int64 job_id, int64 dataset_id,
|
||||
const std::string& worker_address)
|
||||
: task_id_(task_id),
|
||||
job_id_(job_id),
|
||||
dataset_id_(dataset_id),
|
||||
worker_address_(worker_address) {}
|
||||
|
||||
int64 task_id() const { return task_id_; }
|
||||
int64 job_id() const { return job_id_; }
|
||||
int64 dataset_id() const { return dataset_id_; }
|
||||
std::string worker_address() const { return worker_address_; }
|
||||
|
||||
private:
|
||||
const int64 task_id_;
|
||||
const int64 job_id_;
|
||||
const int64 dataset_id_;
|
||||
const std::string worker_address_;
|
||||
};
|
||||
|
||||
// Registers a dataset with the given fingerprint, returning a new dataset id.
|
||||
int64 RegisterDataset(uint64 fingerprint, const DatasetDef& dataset);
|
||||
// Initializes a worker's stub, if it hasn't been initialized already.
|
||||
Status EnsureWorkerStubInitialized(Worker* worker);
|
||||
// Instructs a worker to begin processing a task.
|
||||
Status AllocateTaskToWorker(const Task& task_id, WorkerInfo* worker);
|
||||
Status AllocateTaskToWorker(const Task& task_id, Worker* worker);
|
||||
// Creates a new task for a job, returning the new task's id.
|
||||
int64 CreateTask(Job* job, const std::string& worker_address);
|
||||
|
||||
// Protocol to use for communicating with workers.
|
||||
const std::string protocol_;
|
||||
@ -104,7 +155,7 @@ class DataServiceMasterImpl {
|
||||
int64 next_task_id_ TF_GUARDED_BY(mu_) = 0;
|
||||
|
||||
// Registered workers.
|
||||
std::vector<WorkerInfo> workers_ TF_GUARDED_BY(mu_);
|
||||
std::vector<Worker> workers_ TF_GUARDED_BY(mu_);
|
||||
// Registered datasets, keyed by dataset ids.
|
||||
absl::flat_hash_map<int64, std::shared_ptr<Dataset>> datasets_by_id_
|
||||
TF_GUARDED_BY(mu_);
|
||||
|
Loading…
Reference in New Issue
Block a user