[tf.data service] Add support for shared job names.

Shared job names give users a way to share tf.data service job output across multiple datasets.

PiperOrigin-RevId: 310037762
Change-Id: I5a8f9130806a361ada46b3f0b62ea71b0f9b2e8e
Author: Andrew Audibert, 2020-05-05 15:54:30 -07:00 (committed by TensorFlower Gardener)
Parent: 48baba71cf
Commit: 22546b562d
17 changed files with 493 additions and 128 deletions
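For orientation, here is a minimal sketch (not part of the commit itself) of how the new `job_name` argument is intended to be used from Python. It assumes a tf.data service master and worker are already running at the hypothetical address grpc://localhost:5000, and it uses the `distribute` transformation and `job_name` parameter added in the Python changes below; the job name "shared_job" is arbitrary.

```
import tensorflow as tf

# Two datasets that opt into the same job by passing the same job_name.
# With the default (anonymous) job, each dataset would receive its own
# full copy of the data instead of sharing one stream.
ds = tf.data.Dataset.range(10)
ds1 = ds.apply(
    tf.data.experimental.service.distribute(
        "parallel_epochs", "grpc://localhost:5000", job_name="shared_job"))
ds2 = ds.apply(
    tf.data.experimental.service.distribute(
        "parallel_epochs", "grpc://localhost:5000", job_name="shared_job"))

# Within a single iteration, the two datasets consume from the same job,
# so together they see each element once rather than twice.
it1, it2 = iter(ds1), iter(ds2)
print(next(it1).numpy(), next(it2).numpy())  # e.g. 0 1
```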

View File

@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "DummyIterationCounter"
+  visibility: HIDDEN
+}

View File

@@ -56,6 +56,7 @@ cc_library(
     deps = [
         ":common_proto_cc",
         ":credentials_factory",
+        ":data_service",
         ":grpc_util",
         ":master_proto_cc",
         ":worker_cc_grpc_proto",

View File

@@ -31,7 +31,7 @@ constexpr const char kParallelEpochs[] = "parallel_epochs";
 constexpr const char kOneEpoch[] = "one_epoch";
 }  // namespace
 
-Status ParseProcessingMode(absl::string_view s, ProcessingMode* mode) {
+Status ParseProcessingMode(const std::string& s, ProcessingMode* mode) {
   if (s == kParallelEpochs) {
     *mode = ProcessingMode::PARALLEL_EPOCHS;
   } else if (s == kOneEpoch) {

@@ -54,6 +54,21 @@ std::string ProcessingModeToString(ProcessingMode mode) {
   }
 }
 
+Status DataServiceMasterClient::RegisterDataset(GraphDef dataset,
+                                                int64* dataset_id) {
+  TF_RETURN_IF_ERROR(EnsureInitialized());
+  GetOrRegisterDatasetRequest req;
+  *req.mutable_dataset()->mutable_graph() = dataset;
+  GetOrRegisterDatasetResponse resp;
+  grpc::ClientContext client_ctx;
+  grpc::Status status = stub_->GetOrRegisterDataset(&client_ctx, req, &resp);
+  if (!status.ok()) {
+    return grpc_util::WrapError("Failed to register dataset", status);
+  }
+  *dataset_id = resp.dataset_id();
+  return Status::OK();
+}
+
 Status DataServiceMasterClient::CreateJob(int64 dataset_id,
                                           ProcessingMode processing_mode,
                                           int64* job_id) {

@@ -73,18 +88,27 @@ Status DataServiceMasterClient::CreateJob(int64 dataset_id,
   return Status::OK();
 }
 
-Status DataServiceMasterClient::RegisterDataset(GraphDef dataset,
-                                                int64* dataset_id) {
+Status DataServiceMasterClient::GetOrCreateJob(int64 dataset_id,
+                                               ProcessingMode processing_mode,
+                                               const std::string& job_name,
+                                               int job_name_index,
+                                               int64* job_id) {
   TF_RETURN_IF_ERROR(EnsureInitialized());
-  GetOrRegisterDatasetRequest req;
-  *req.mutable_dataset()->mutable_graph() = dataset;
-  GetOrRegisterDatasetResponse resp;
+  GetOrCreateJobRequest req;
+  req.set_dataset_id(dataset_id);
+  req.set_processing_mode(ProcessingModeDef(processing_mode));
+  req.set_job_name(job_name);
+  req.set_job_name_index(job_name_index);
+  GetOrCreateJobResponse resp;
   grpc::ClientContext client_ctx;
-  grpc::Status status = stub_->GetOrRegisterDataset(&client_ctx, req, &resp);
+  grpc::Status status = stub_->GetOrCreateJob(&client_ctx, req, &resp);
   if (!status.ok()) {
-    return grpc_util::WrapError("Failed to register dataset", status);
+    return grpc_util::WrapError(
+        absl::StrCat("Failed to get or create job for dataset with id ",
+                     dataset_id),
+        status);
   }
-  *dataset_id = resp.dataset_id();
+  *job_id = resp.job_id();
   return Status::OK();
 }

@@ -148,7 +172,7 @@ Status DataServiceWorkerClient::EnsureInitialized() {
 }
 
 Status CreateDataServiceMasterClient(
-    absl::string_view address, absl::string_view protocol,
+    const std::string& address, const std::string& protocol,
     std::unique_ptr<DataServiceMasterClient>* out) {
   auto client = absl::make_unique<DataServiceMasterClient>(address, protocol);
   TF_RETURN_IF_ERROR(client->Initialize());

@@ -157,7 +181,7 @@ Status CreateDataServiceMasterClient(
 }
 
 Status CreateDataServiceWorkerClient(
-    absl::string_view address, absl::string_view protocol,
+    const std::string& address, const std::string& protocol,
     std::unique_ptr<DataServiceWorkerClient>* out) {
   auto client = absl::make_unique<DataServiceWorkerClient>(address, protocol);
   TF_RETURN_IF_ERROR(client->Initialize());

View File

@@ -35,7 +35,7 @@ enum class ProcessingMode : int64 {
 // Parses a string representing a processing mode and stores the result in
 // *mode. Returns an InvalidArgument status if the string is not recognized.
-Status ParseProcessingMode(absl::string_view s, ProcessingMode* mode);
+Status ParseProcessingMode(const std::string& s, ProcessingMode* mode);
 
 // Converts a processing mode to its corresponding string.
 std::string ProcessingModeToString(ProcessingMode mode);

@@ -45,7 +45,7 @@ std::string ProcessingModeToString(ProcessingMode mode);
 // threads.
 class DataServiceClientBase {
  public:
-  DataServiceClientBase(absl::string_view address, absl::string_view protocol)
+  DataServiceClientBase(const std::string& address, const std::string& protocol)
       : address_(address), protocol_(protocol) {}
 
   virtual ~DataServiceClientBase() = default;

@@ -70,7 +70,8 @@ class DataServiceClientBase {
 // Client for communicating with the tf.data service master.
 class DataServiceMasterClient : public DataServiceClientBase {
  public:
-  DataServiceMasterClient(absl::string_view address, absl::string_view protocol)
+  DataServiceMasterClient(const std::string& address,
+                          const std::string& protocol)
       : DataServiceClientBase(address, protocol) {}
 
   // Registers a dataset with the tf.data service, and stores the generated

@@ -82,6 +83,13 @@ class DataServiceMasterClient : public DataServiceClientBase {
   Status CreateJob(int64 dataset_id, ProcessingMode processing_mode,
                    int64* job_id);
 
+  // Gets the job id for the job represented by the tuple
+  // (job_name, job_name_index), and stores the id in *job_id. If the
+  // job doesn't exist yet, it will be created.
+  Status GetOrCreateJob(int64 dataset_id, ProcessingMode processing_mode,
+                        const std::string& job_name, int job_name_index,
+                        int64* job_id);
+
   // Queries the master for the tasks associated with the specified job.
   // The tasks will be stored in *tasks, and whether the job is finished will
   // be stored in `*job_finished`.

@@ -98,7 +106,8 @@ class DataServiceMasterClient : public DataServiceClientBase {
 // Client for communicating with the tf.data service worker.
 class DataServiceWorkerClient : public DataServiceClientBase {
  public:
-  DataServiceWorkerClient(absl::string_view address, absl::string_view protocol)
+  DataServiceWorkerClient(const std::string& address,
+                          const std::string& protocol)
       : DataServiceClientBase(address, protocol) {}
 
   // Fetches the next element for the specified task_id. The element's

@@ -116,12 +125,12 @@ class DataServiceWorkerClient : public DataServiceClientBase {
 // Creates and initializes a new tf.data service master client.
 Status CreateDataServiceMasterClient(
-    absl::string_view address, absl::string_view protocol,
+    const std::string& address, const std::string& protocol,
     std::unique_ptr<DataServiceMasterClient>* out);
 
 // Creates and initializes a new tf.data service worker client.
 Status CreateDataServiceWorkerClient(
-    absl::string_view address, absl::string_view protocol,
+    const std::string& address, const std::string& protocol,
     std::unique_ptr<DataServiceWorkerClient>* out);
 
 }  // namespace data

View File

@@ -42,6 +42,7 @@ HANDLER(RegisterWorker);
 HANDLER(WorkerUpdate);
 HANDLER(GetOrRegisterDataset);
 HANDLER(CreateJob);
+HANDLER(GetOrCreateJob);
 HANDLER(GetTasks);
 #undef HANDLER

View File

@@ -46,6 +46,7 @@ class GrpcMasterImpl : public MasterService::Service {
   HANDLER(WorkerUpdate);
   HANDLER(GetOrRegisterDataset);
   HANDLER(CreateJob);
+  HANDLER(GetOrCreateJob);
   HANDLER(GetTasks);
 #undef HANDLER

View File

@@ -56,7 +56,24 @@ message CreateJobRequest {
 }
 
 message CreateJobResponse {
-  // An id for the begun job.
+  // An id for the created job.
+  int64 job_id = 1;
+}
+
+message GetOrCreateJobRequest {
+  // The id of the dataset to create a job for.
+  int64 dataset_id = 1;
+  // A mode controlling how the tf.data service produces data for the job.
+  ProcessingModeDef processing_mode = 2;
+  // A name for the job.
+  string job_name = 3;
+  // An index for the job. Multiple jobs can be created for the same name, if
+  // they have different indices.
+  int64 job_name_index = 4;
+}
+
+message GetOrCreateJobResponse {
+  // The id of the (potentially newly created) job.
   int64 job_id = 1;
 }

@@ -96,6 +113,9 @@ service MasterService {
   rpc GetOrRegisterDataset(GetOrRegisterDatasetRequest)
       returns (GetOrRegisterDatasetResponse);
 
+  // Gets a job if it already exists, otherwise creates it.
+  rpc GetOrCreateJob(GetOrCreateJobRequest) returns (GetOrCreateJobResponse);
+
   // Creates a job for reading from the tf.data service.
   rpc CreateJob(CreateJobRequest) returns (CreateJobResponse);

View File

@@ -25,6 +25,7 @@ limitations under the License.
 #include "absl/memory/memory.h"
 #include "tensorflow/core/data/service/common.pb.h"
 #include "tensorflow/core/data/service/credentials_factory.h"
+#include "tensorflow/core/data/service/data_service.h"
 #include "tensorflow/core/data/service/grpc_util.h"
 #include "tensorflow/core/data/service/master.pb.h"
 #include "tensorflow/core/data/service/worker.grpc.pb.h"

@@ -65,17 +66,17 @@ Status DataServiceMasterImpl::RegisterWorker(
   // Allocate tasks to the worker.
   for (auto& entry : jobs_) {
-    Job& job = entry.second;
-    if (job.finished()) {
+    std::shared_ptr<Job> job = entry.second;
+    if (job->finished()) {
       continue;
     }
-    int64 task_id = CreateTask(&job, request->worker_address());
+    int64 task_id = CreateTask(job.get(), request->worker_address());
     TaskDef* task_def = response->add_tasks();
     *task_def->mutable_dataset() =
-        datasets_by_id_[job.dataset_id()]->dataset_def();
-    task_def->set_dataset_id(job.dataset_id());
-    task_def->set_job_id(job.job_id());
+        datasets_by_id_[job->dataset_id()]->dataset_def();
+    task_def->set_dataset_id(job->dataset_id());
+    task_def->set_job_id(job->job_id());
     task_def->set_task_id(task_id);
   }

@@ -96,7 +97,7 @@ Status DataServiceMasterImpl::WorkerUpdate(const WorkerUpdateRequest* request,
     if (update.completed()) {
       int64 job_id = tasks_.at(task_id).job_id();
       DCHECK(jobs_.contains(job_id));
-      jobs_.at(job_id).task_finished(task_id);
+      jobs_.at(job_id)->task_finished(task_id);
       VLOG(3) << "Task " << task_id << " from job " << job_id << " completed";
     }
   }

@@ -135,49 +136,109 @@ int64 DataServiceMasterImpl::RegisterDataset(uint64 fingerprint,
   DCHECK(!datasets_by_id_.contains(dataset_id));
   datasets_by_id_[dataset_id] = new_dataset;
   DCHECK(!datasets_by_fingerprint_.contains(fingerprint));
-  datasets_by_fingerprint_[dataset_id] = new_dataset;
+  datasets_by_fingerprint_[fingerprint] = new_dataset;
   return dataset_id;
 }
 
 Status DataServiceMasterImpl::CreateJob(const CreateJobRequest* request,
                                         CreateJobResponse* response) {
-  VLOG(3) << "Received begin job request for dataset id "
+  VLOG(3) << "Received create job request for dataset id "
           << request->dataset_id();
-  switch (request->processing_mode()) {
-    case PARALLEL_EPOCHS:
+  ProcessingMode processing_mode = ProcessingMode(request->processing_mode());
+  mutex_lock l(mu_);
+  int64 job_id;
+  TF_RETURN_IF_ERROR(CreateJob(request->dataset_id(), processing_mode,
+                               absl::optional<std::string>(), &job_id));
+  response->set_job_id(job_id);
+  VLOG(3) << "Creating job " << job_id << " for dataset "
+          << request->dataset_id();
+  return Status::OK();
+}
+
+Status DataServiceMasterImpl::GetOrCreateJob(
+    const GetOrCreateJobRequest* request, GetOrCreateJobResponse* response) {
+  VLOG(3) << "Received get or create job request for dataset id "
+          << request->dataset_id() << " with name " << request->job_name()
+          << " and index " << request->job_name_index();
+  mutex_lock l(mu_);
+  NamedJobKey key(request->job_name(), request->job_name_index());
+  ProcessingMode requested_processing_mode =
+      ProcessingMode(request->processing_mode());
+  std::shared_ptr<Job>* job = gtl::FindOrNull(named_jobs_, key);
+  if (job != nullptr) {
+    TF_RETURN_IF_ERROR(ValidateMatchingJob(**job, requested_processing_mode,
+                                           request->dataset_id()));
+    response->set_job_id((*job)->job_id());
+    return Status::OK();
+  }
+  int64 job_id;
+  TF_RETURN_IF_ERROR(CreateJob(request->dataset_id(),
+                               requested_processing_mode,
+                               request->job_name(), &job_id));
+  named_jobs_[key] = jobs_[job_id];
+  response->set_job_id(job_id);
+  return Status::OK();
+}
+
+// Validates that the job matches the given processing_mode and dataset_id.
+Status DataServiceMasterImpl::ValidateMatchingJob(
+    const Job& job, ProcessingMode processing_mode, int64 dataset_id) {
+  DCHECK(job.name().has_value());
+  std::string job_name = job.name().value();
+  if (job.processing_mode() != processing_mode) {
+    std::string requested = ProcessingModeToString(processing_mode);
+    std::string actual = ProcessingModeToString(job.processing_mode());
+    return errors::FailedPrecondition(
+        "Found a job with name ", job_name, ", but the processing mode <",
+        actual, "> doesn't match the requested processing mode <", requested,
+        ">.");
+  }
+  if (job.dataset_id() != dataset_id) {
+    return errors::FailedPrecondition(
+        "Found a job with name ", job_name, ", but the dataset id <",
+        job.dataset_id(), "> doesn't match the requested dataset id <",
+        dataset_id, ">.");
+  }
+  return Status::OK();
+}
+
+Status DataServiceMasterImpl::CreateJob(int64 dataset_id,
+                                        ProcessingMode processing_mode,
+                                        absl::optional<std::string> job_name,
+                                        int64* out_job_id)
+    EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+  switch (processing_mode) {
+    case ProcessingMode::PARALLEL_EPOCHS:
       break;
-    case ONE_EPOCH:
+    case ProcessingMode::ONE_EPOCH:
       return errors::Unimplemented(
           "CreateJob only supports the PARALLEL_EPOCHS job mode. "
           "ONE_EPOCH is not currently supported.");
     default:
-      return errors::Unimplemented(
-          "ProcessingMode ", request->processing_mode(), " not recognized");
+      return errors::Unimplemented("ProcessingMode ",
+                                   ProcessingModeToString(processing_mode),
+                                   " not recognized");
   }
-  mutex_lock l(mu_);
-  if (!datasets_by_id_.contains(request->dataset_id())) {
-    return errors::NotFound("CreateJob failed. Dataset id: <",
-                            request->dataset_id(), "> not found.");
+  if (!datasets_by_id_.contains(dataset_id)) {
+    return errors::NotFound("Dataset id: <", dataset_id, "> not found.");
   }
   int64 job_id = next_job_id_++;
   DCHECK(!jobs_.contains(job_id));
-  auto result =
-      jobs_.emplace(std::piecewise_construct, std::forward_as_tuple(job_id),
-                    std::forward_as_tuple(job_id, request->dataset_id()));
-  DCHECK(result.second);
-  Job& job = result.first->second;
-  response->set_job_id(job_id);
+  auto job =
+      std::make_shared<Job>(job_id, dataset_id, processing_mode, job_name);
+  jobs_[job_id] = job;
 
   for (auto& worker : workers_) {
-    int64 task_id = CreateTask(&job, worker.address());
+    int64 task_id = CreateTask(job.get(), worker.address());
     // TODO(aaudibert): perform these calls asynchronously.
+    // TODO(aaudibert): clean up in case some calls succeed, but later calls
+    // fail
     TF_RETURN_IF_ERROR(AllocateTaskToWorker(tasks_.at(task_id), &worker));
   }
-  VLOG(3) << "Beginning job " << job_id << " for dataset "
-          << request->dataset_id();
+  *out_job_id = job_id;
   return Status::OK();
 }

@@ -233,8 +294,8 @@ Status DataServiceMasterImpl::GetTasks(const GetTasksRequest* request,
     return errors::NotFound("GetTasks failed. Job id <", request->job_id(),
                             "> not found.");
   }
-  Job& job = it->second;
-  for (const auto& task_id : job.task_ids()) {
+  std::shared_ptr<Job> job = it->second;
+  for (const auto& task_id : job->task_ids()) {
     auto task_iter = tasks_.find(task_id);
     DCHECK(task_iter != tasks_.end());
     Task& task = task_iter->second;

@@ -242,7 +303,7 @@ Status DataServiceMasterImpl::GetTasks(const GetTasksRequest* request,
     task_info->set_worker_address(task.worker_address());
     task_info->set_id(task.task_id());
   }
-  response->set_job_finished(job.finished());
+  response->set_job_finished(job->finished());
   VLOG(3) << "Found " << response->task_info_size() << " tasks for job id "
           << request->job_id();
   return Status::OK();

View File

@@ -18,6 +18,7 @@ limitations under the License.
 #include "absl/container/flat_hash_map.h"
 #include "tensorflow/core/data/service/common.pb.h"
+#include "tensorflow/core/data/service/data_service.h"
 #include "tensorflow/core/data/service/master.pb.h"
 #include "tensorflow/core/data/service/worker.grpc.pb.h"
 #include "tensorflow/core/lib/core/status.h"

@@ -56,6 +57,8 @@ class DataServiceMasterImpl {
                                 GetOrRegisterDatasetResponse* response);
   Status CreateJob(const CreateJobRequest* request,
                    CreateJobResponse* response);
+  Status GetOrCreateJob(const GetOrCreateJobRequest* request,
+                        GetOrCreateJobResponse* response);
   Status GetTasks(const GetTasksRequest* request, GetTasksResponse* response);
 
  private:

@@ -100,11 +103,17 @@ class DataServiceMasterImpl {
   class Job {
    public:
-    Job(int64 job_id, int64 dataset_id)
-        : job_id_(job_id), dataset_id_(dataset_id) {}
+    Job(int64 job_id, int64 dataset_id, ProcessingMode processing_mode,
+        absl::optional<absl::string_view> job_name)
+        : job_id_(job_id),
+          dataset_id_(dataset_id),
+          processing_mode_(processing_mode),
+          job_name_(job_name) {}
 
     int64 job_id() const { return job_id_; }
     int64 dataset_id() const { return dataset_id_; }
+    ProcessingMode processing_mode() const { return processing_mode_; }
+    absl::optional<std::string> name() const { return job_name_; }
     const std::vector<int64>& task_ids() const { return task_ids_; }
     void add_task_id(int64 task_id) { task_ids_.push_back(task_id); }
     void task_finished(int64 task_id) {

@@ -118,11 +127,32 @@ class DataServiceMasterImpl {
    private:
     const int64 job_id_;
     const int64 dataset_id_;
+    const ProcessingMode processing_mode_;
+    const absl::optional<std::string> job_name_;
     std::vector<int64> task_ids_;
     std::vector<int64> finished_tasks_;
     bool finished_ = false;
   };
 
+  class NamedJobKey {
+   public:
+    NamedJobKey(absl::string_view name, int64 index)
+        : name_(name), index_(index) {}
+
+    friend bool operator==(const NamedJobKey& lhs, const NamedJobKey& rhs) {
+      return lhs.name_ == rhs.name_ && lhs.index_ == rhs.index_;
+    }
+
+    template <typename H>
+    friend H AbslHashValue(H h, const NamedJobKey& k) {
+      return H::combine(std::move(h), k.name_, k.index_);
+    }
+
+   private:
+    const std::string name_;
+    const int64 index_;
+  };
+
   class Task {
    public:
     Task(int64 task_id, int64 job_id, int64 dataset_id,

@@ -150,9 +180,15 @@ class DataServiceMasterImpl {
   Status EnsureWorkerStubInitialized(Worker* worker);
   // Instructs a worker to begin processing a task.
   Status AllocateTaskToWorker(const Task& task_id, Worker* worker);
+  // Creates a job and stores its job_id in `*job_id`.
+  Status CreateJob(int64 dataset_id, ProcessingMode processing_mode,
+                   absl::optional<std::string> job_name, int64* out_job_id);
   // Creates a new task for a job, returning the new task's id.
   int64 CreateTask(Job* job, const std::string& worker_address);
+  // Validates that an existing job matches the given processing_mode and
+  // dataset_id, returning an error status describing any difference.
+  Status ValidateMatchingJob(const Job& job, ProcessingMode processing_mode,
+                             int64 dataset_id);
 
   // Protocol to use for communicating with workers.
   const std::string protocol_;

@@ -172,9 +208,13 @@ class DataServiceMasterImpl {
   absl::flat_hash_map<uint64, std::shared_ptr<Dataset>> datasets_by_fingerprint_
       TF_GUARDED_BY(mu_);
   // Information about jobs, keyed by job ids.
-  absl::flat_hash_map<int64, Job> jobs_ TF_GUARDED_BY(mu_);
+  absl::flat_hash_map<int64, std::shared_ptr<Job>> jobs_ TF_GUARDED_BY(mu_);
   // Information about tasks, keyed by task ids.
   absl::flat_hash_map<int64, Task> tasks_ TF_GUARDED_BY(mu_);
+  // Named jobs, keyed by their names and indices. Not all jobs have names, so
+  // this is a subset of the jobs stored in `jobs_`.
+  absl::flat_hash_map<NamedJobKey, std::shared_ptr<Job>> named_jobs_
+      TF_GUARDED_BY(mu_);
 
   TF_DISALLOW_COPY_AND_ASSIGN(DataServiceMasterImpl);
 };

View File

@@ -47,8 +47,11 @@ namespace data {
 /* static */ constexpr const char* const DataServiceDatasetOp::kProcessingMode;
 /* static */ constexpr const char* const DataServiceDatasetOp::kAddress;
 /* static */ constexpr const char* const DataServiceDatasetOp::kProtocol;
+/* static */ constexpr const char* const DataServiceDatasetOp::kJobName;
 /* static */ constexpr const char* const
     DataServiceDatasetOp::kMaxOutstandingRequests;
+/* static */ constexpr const char* const
+    DataServiceDatasetOp::kIterationCounter;
 /* static */ constexpr const char* const DataServiceDatasetOp::kOutputTypes;
 /* static */ constexpr const char* const DataServiceDatasetOp::kOutputShapes;

@@ -71,23 +74,45 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
  public:
   Dataset(OpKernelContext* ctx, int64 dataset_id,
           ProcessingMode processing_mode, const std::string& address,
-          const std::string& protocol, int64 max_outstanding_requests,
-          int64 task_refresh_interval_ms, const DataTypeVector& output_types,
+          const std::string& protocol, const std::string& job_name,
+          int64 max_outstanding_requests, int64 task_refresh_interval_ms,
+          IterationCounter* iteration_counter, bool owns_resource,
+          ResourceHandle iteration_counter_handle,
+          const DataTypeVector& output_types,
           const std::vector<PartialTensorShape>& output_shapes)
       : DatasetBase(DatasetContext(ctx)),
         dataset_id_(dataset_id),
         processing_mode_(processing_mode),
        address_(address),
        protocol_(protocol),
+        job_name_(job_name),
         max_outstanding_requests_(max_outstanding_requests),
         task_refresh_interval_ms_(task_refresh_interval_ms),
+        iteration_counter_(iteration_counter),
+        owns_resource_(owns_resource),
+        iteration_counter_handle_(iteration_counter_handle),
+        resource_mgr_(ctx->resource_manager()),
         output_types_(output_types),
         output_shapes_(output_shapes) {}
 
+  ~Dataset() override {
+    iteration_counter_->Unref();
+    if (owns_resource_) {
+      Status s = resource_mgr_->Delete<IterationCounter>(
+          iteration_counter_handle_.container(),
+          iteration_counter_handle_.name());
+      if (!s.ok()) {
+        LOG(WARNING) << "Failed to delete iteration counter resource: " << s;
+      }
+    }
+  }
+
   std::unique_ptr<IteratorBase> MakeIteratorInternal(
       const string& prefix) const override {
-    return absl::make_unique<Iterator>(Iterator::Params{
-        this, name_utils::IteratorPrefix(kDatasetType, prefix)});
+    return absl::make_unique<Iterator>(
+        Iterator::Params{this,
+                         name_utils::IteratorPrefix(kDatasetType, prefix)},
+        iteration_counter_->GetAndIncrement());
   }
 
   const DataTypeVector& output_dtypes() const override { return output_types_; }

@@ -123,18 +148,26 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
     Node* protocol;
     TF_RETURN_IF_ERROR(b->AddScalar(protocol_, &protocol));
 
+    Node* job_name;
+    TF_RETURN_IF_ERROR(b->AddScalar(job_name_, &job_name));
+
     Node* max_outstanding_requests;
     TF_RETURN_IF_ERROR(
         b->AddScalar(max_outstanding_requests_, &max_outstanding_requests));
 
+    Node* iteration_counter_handle = nullptr;
+    Tensor handle(DT_RESOURCE, TensorShape({}));
+    handle.scalar<ResourceHandle>()() = iteration_counter_handle_;
+    TF_RETURN_IF_ERROR(b->AddTensor(handle, &iteration_counter_handle));
+
     AttrValue task_refresh_interval_hint_ms;
     b->BuildAttrValue(task_refresh_interval_ms_,
                       &task_refresh_interval_hint_ms);
 
     TF_RETURN_IF_ERROR(
         b->AddDataset(this,
-                      {dataset_id, processing_mode, address, protocol,
-                       max_outstanding_requests},
+                      {dataset_id, processing_mode, address, protocol, job_name,
+                       max_outstanding_requests, iteration_counter_handle},
                       {std::make_pair(kTaskRefreshIntervalHintMs,
                                       task_refresh_interval_hint_ms)},
                       output));

@@ -144,8 +177,8 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
  private:
   class Iterator : public DatasetIterator<Dataset> {
    public:
-    explicit Iterator(const Params& params)
-        : DatasetIterator<Dataset>(params) {}
+    explicit Iterator(const Params& params, int64 iterator_index)
+        : DatasetIterator<Dataset>(params), iterator_index_(iterator_index) {}
 
     ~Iterator() override {
       mutex_lock l(mu_);

@@ -161,8 +194,14 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
       VLOG(3) << "Connecting to " << dataset()->address_
               << " in data service dataset op";
       DataServiceMasterClient master(dataset()->address_, dataset()->protocol_);
-      TF_RETURN_IF_ERROR(master.CreateJob(
-          dataset()->dataset_id_, dataset()->processing_mode_, &job_id_));
+      if (dataset()->job_name_.empty()) {
+        TF_RETURN_IF_ERROR(master.CreateJob(
+            dataset()->dataset_id_, dataset()->processing_mode_, &job_id_));
+      } else {
+        TF_RETURN_IF_ERROR(master.GetOrCreateJob(
+            dataset()->dataset_id_, dataset()->processing_mode_,
+            dataset()->job_name_, iterator_index_, &job_id_));
+      }
       VLOG(1) << "Created data service job with id " << job_id_;
       return Status::OK();
     }

@@ -418,6 +457,8 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
       return Status::OK();
     }
 
+    const int64 iterator_index_;
+
     mutex mu_;
     // TODO(aaudibert): split this into a couple cvs for different conditions
     // so that we can use notify_one and avoid unnecessary wakeups.

@@ -449,8 +490,13 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
   const ProcessingMode processing_mode_;
   const tstring address_;
   const tstring protocol_;
+  const tstring job_name_;
   const int64 max_outstanding_requests_;
   const int64 task_refresh_interval_ms_;
+  IterationCounter* const iteration_counter_;  // Owned
+  const bool owns_resource_;
+  const ResourceHandle iteration_counter_handle_;
+  ResourceMgr* const resource_mgr_;  // Not owned
   const DataTypeVector output_types_;
   const std::vector<PartialTensorShape> output_shapes_;
 };

@@ -488,9 +534,41 @@ void DataServiceDatasetOp::MakeDataset(OpKernelContext* ctx,
   OP_REQUIRES(ctx, !protocol.empty(),
               errors::InvalidArgument(kProtocol, " must be non-empty."));
 
+  tstring job_name;
+  OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kJobName, &job_name));
+
   int64 max_outstanding_requests;
   OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kMaxOutstandingRequests,
                                           &max_outstanding_requests));
 
+  ResourceHandle iteration_counter_handle;
+  OP_REQUIRES_OK(
+      ctx, HandleFromInput(ctx, kIterationCounter, &iteration_counter_handle));
+  IterationCounter* iteration_counter = nullptr;
+  Status s = ctx->resource_manager()->Lookup<IterationCounter>(
+      iteration_counter_handle.container(), iteration_counter_handle.name(),
+      &iteration_counter);
+  bool owns_resource = false;
+  if (errors::IsNotFound(s)) {
+    owns_resource = true;
+    static std::atomic<int64> resource_id_counter(0);
+    const std::string& container = ctx->resource_manager()->default_container();
+    std::string name =
+        strings::StrCat(ctx->op_kernel().name(), "/", kIterationCounter, "_",
+                        resource_id_counter.fetch_add(1));
+    OP_REQUIRES_OK(ctx,
+                   ctx->resource_manager()->LookupOrCreate<IterationCounter>(
+                       container, name, &iteration_counter,
+                       [](IterationCounter** counter) {
+                         *counter = new IterationCounter();
+                         return Status::OK();
+                       }));
+    iteration_counter_handle =
+        MakeResourceHandle<IterationCounter>(ctx, container, name);
+  } else {
+    OP_REQUIRES_OK(ctx, s);
+  }
+
   OP_REQUIRES(
       ctx,
       max_outstanding_requests == model::kAutotune ||

@@ -499,13 +577,16 @@ void DataServiceDatasetOp::MakeDataset(OpKernelContext* ctx,
           model::kAutotune));
 
   *output =
-      new Dataset(ctx, dataset_id, processing_mode, address, protocol,
+      new Dataset(ctx, dataset_id, processing_mode, address, protocol, job_name,
                   max_outstanding_requests, task_refresh_interval_hint_ms_,
+                  iteration_counter, owns_resource, iteration_counter_handle,
                   output_types_, output_shapes_);
 }
 
 REGISTER_KERNEL_BUILDER(Name("DataServiceDataset").Device(DEVICE_CPU),
                         DataServiceDatasetOp);
+REGISTER_KERNEL_BUILDER(Name("DummyIterationCounter").Device(DEVICE_CPU),
+                        DummyResourceOp<IterationCounter>);
 
 }  // namespace data
 }  // namespace tensorflow

View File

@@ -15,11 +15,34 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_DATA_SERVICE_DATASET_OP_H_
 #define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_DATA_SERVICE_DATASET_OP_H_
 
+#include "absl/strings/str_cat.h"
 #include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/resource_mgr.h"
 
 namespace tensorflow {
 namespace data {
 
+// A resource which counts how many iterators have been created. This is used
+// by the DataServiceDataset to coordinate jobs across multiple iterations.
+class IterationCounter : public ResourceBase {
+ public:
+  IterationCounter() : counter_(0) {}
+
+  std::string DebugString() const override {
+    mutex_lock l(mu_);
+    return absl::StrCat(counter_);
+  }
+
+  int64 GetAndIncrement() {
+    mutex_lock l(mu_);
+    return ++counter_;
+  }
+
+ private:
+  mutable mutex mu_;
+  int64 counter_ TF_GUARDED_BY(mu_) = 0;
+};
+
 // Creates a dataset for reading from the tf.data service.
 class DataServiceDatasetOp : public DatasetOpKernel {
  public:

@@ -28,10 +51,12 @@ class DataServiceDatasetOp : public DatasetOpKernel {
   static constexpr const char* const kProcessingMode = "processing_mode";
   static constexpr const char* const kAddress = "address";
   static constexpr const char* const kProtocol = "protocol";
+  static constexpr const char* const kJobName = "job_name";
   static constexpr const char* const kMaxOutstandingRequests =
       "max_outstanding_requests";
   static constexpr const char* const kTaskRefreshIntervalHintMs =
       "task_refresh_interval_hint_ms";
+  static constexpr const char* const kIterationCounter = "iteration_counter";
   static constexpr const char* const kOutputTypes = "output_types";
   static constexpr const char* const kOutputShapes = "output_shapes";

View File

@@ -1,47 +0,0 @@
-op {
-  name: "DataServiceDataset"
-  input_arg {
-    name: "dataset_id"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "processing_mode"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "address"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "protocol"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "max_outstanding_requests"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "task_refresh_interval_hint_ms"
-    type: "int"
-    default_value {
-      i: -1
-    }
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  is_stateful: true
-}

View File

@@ -1037,12 +1037,21 @@ REGISTER_OP("ExperimentalUniqueDataset")
     .Attr("output_shapes: list(shape) >= 1")
     .SetShapeFn(shape_inference::ScalarShape);
 
+REGISTER_OP("DummyIterationCounter")
+    .Output("handle: resource")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      c->set_output(0, c->Scalar());
+      return Status::OK();
+    });
+
 REGISTER_OP("DataServiceDataset")
     .Input("dataset_id: int64")
     .Input("processing_mode: string")
     .Input("address: string")
     .Input("protocol: string")
+    .Input("job_name: string")
     .Input("max_outstanding_requests: int64")
+    .Input("iteration_counter: resource")
     .Output("handle: variant")
     .Attr("task_refresh_interval_hint_ms: int = -1")
     .Attr("output_types: list(type) >= 1")

View File

@@ -51,6 +51,7 @@ class _DataServiceDatasetV2(dataset_ops.DatasetSource):
                processing_mode,
                address,
                protocol,
+               job_name=None,
               max_outstanding_requests=None,
               task_refresh_interval_hint_ms=None):
     """Constructs a _DataServiceDatasetV2.

@@ -65,6 +66,9 @@ class _DataServiceDatasetV2(dataset_ops.DatasetSource):
       address: The tf.data service address, e.g. "localhost:5000".
       protocol: The protocol to use for communicating with the tf.data service,
         e.g. "grpc".
+      job_name: (Optional.) The name of the job. This argument makes it
+        possible for multiple datasets to share the same job. The default
+        behavior is that the dataset creates anonymous, exclusively owned jobs.
       max_outstanding_requests: (Optional.) A limit on how many elements may be
         requested at the same time. You can use this option to control the
         amount of memory used, since `distribute` won't use more than

@@ -73,6 +77,8 @@ class _DataServiceDatasetV2(dataset_ops.DatasetSource):
         the master for task changes.
     """
+    if job_name is None:
+      job_name = ""
     if max_outstanding_requests is None:
       max_outstanding_requests = dataset_ops.AUTOTUNE
     if task_refresh_interval_hint_ms is None:

@@ -85,8 +91,11 @@ class _DataServiceDatasetV2(dataset_ops.DatasetSource):
         processing_mode=processing_mode,
         address=address,
         protocol=protocol,
+        job_name=job_name,
         max_outstanding_requests=max_outstanding_requests,
         task_refresh_interval_hint_ms=task_refresh_interval_hint_ms,
+        iteration_counter=gen_experimental_dataset_ops.dummy_iteration_counter(
+        ),
         **self._flat_structure)
     super(_DataServiceDatasetV2, self).__init__(variant_tensor)

@@ -100,7 +109,7 @@ class _DataServiceDatasetV1(dataset_ops.DatasetV1Adapter):
 
   @functools.wraps(_DataServiceDatasetV2.__init__)
   def __init__(self, input_dataset, dataset_id, processing_mode, address,
-               protocol, max_outstanding_requests,
+               protocol, job_name, max_outstanding_requests,
               task_refresh_interval_hint_ms):
 
     self._wrapped = _DataServiceDatasetV2(

@@ -109,6 +118,7 @@ class _DataServiceDatasetV1(dataset_ops.DatasetV1Adapter):
         processing_mode=processing_mode,
         address=address,
         protocol=protocol,
+        job_name=job_name,
         max_outstanding_requests=max_outstanding_requests,
         task_refresh_interval_hint_ms=task_refresh_interval_hint_ms)
     super(_DataServiceDatasetV1, self).__init__(self._wrapped)

@@ -122,6 +132,7 @@ else:
 
 def _distribute(processing_mode,
                 service,
+                job_name=None,
                 max_outstanding_requests=None,
                 task_refresh_interval_hint_ms=None):
   """A transformation that moves dataset processing to the tf.data service.

@@ -136,6 +147,9 @@ def _distribute(processing_mode,
     service: A string indicating how to connect to the tf.data service. The
       string should be in the format <protocol>://<address>, e.g.
       grpc://localhost:5000.
+    job_name: (Optional.) The name of the job. This argument makes it
+      possible for multiple datasets to share the same job. The default behavior
+      is that the dataset creates anonymous, exclusively owned jobs.
     max_outstanding_requests: (Optional.) A limit on how many elements may be
       requested at the same time. You can use this option to control the amount
       of memory used, since `distribute` won't use more than `element_size` *

@@ -147,6 +161,12 @@ def _distribute(processing_mode,
     Dataset: A `Dataset` of the elements produced by the data service.
   """
   ProcessingMode.validate(processing_mode)
+  if job_name is not None:
+    if not isinstance(job_name, six.string_types):
+      raise ValueError("job_name must be a string, but job_name was of type "
+                       "{0}. job_name={1}".format(type(job_name), job_name))
+    if not job_name:
+      raise ValueError("job_name must not be empty")
   if not isinstance(service, six.string_types):
     raise ValueError(
         "service must be a string, but service was of type {0}. service={1}"

@@ -182,41 +202,49 @@ def _distribute(processing_mode,
         processing_mode=processing_mode,
         address=address,
         protocol=protocol,
+        job_name=job_name,
         max_outstanding_requests=max_outstanding_requests,
         task_refresh_interval_hint_ms=task_refresh_interval_hint_ms)
 
   return _apply_fn
 
 
-def distribute(processing_mode, service, max_outstanding_requests=None):
+def distribute(processing_mode,
+               service,
+               job_name=None,
+               max_outstanding_requests=None):
   """A transformation that moves dataset processing to the tf.data service.
 
-  The `processing_mode` argument controls how data is processed by the
-  tf.data service. Currently, the only supported mode is "parallel_epochs".
+  When you iterate over a dataset containing the `distribute` transformation,
+  the tf.data service creates a "job" which produces data for the dataset
+  iteration.
+
+  The `processing_mode` argument controls what data is produced by a tf.data
+  service job. Currently, the only supported mode is "parallel_epochs".
   processing_mode="parallel_epochs" means that multiple tf.data workers will
   iterate through the dataset in parallel, each producing all elements of the
   dataset. For example, if the dataset contains {0, 1, 2}, every tf.data worker
-  used for execution will produce {0, 1, 2}. If there are 3 workers and one
-  consumer, the consumer will receive the elements {0, 0, 0, 1, 1, 1, 2, 2, 2}
-  (though not necessarily in that order). To account for this, it is recommended
-  to randomly shuffle your dataset, so that different tf.data workers will
-  iterate through the dataset in different orders.
+  used for execution will produce {0, 1, 2}. If there are 3 workers, the job
+  will produce the elements {0, 0, 0, 1, 1, 1, 2, 2, 2} (though not necessarily
+  in that order). To account for this, it is recommended to randomly shuffle
+  your dataset, so that different tf.data workers will iterate through the
+  dataset in different orders.
 
-  In the future, there will be additional epoch modes. For example,
+  In the future, there will be additional processing modes. For example,
   a "one_epoch" mode which partitions the dataset across the tf.data
   workers, so that the consumers see each element of the dataset only once.
 
   ```
-  dataset = tf.data.Dataset.range(10)
+  dataset = tf.data.Dataset.range(5)
   dataset = dataset.map(lambda x: x*x)
   dataset = dataset.apply(
       tf.data.experimental.service.distribute("parallel_epochs",
                                               "grpc://dataservice:5000"))
-  dataset = dataset.map(lambda x: x+10)
+  dataset = dataset.map(lambda x: x+1)
 
   for element in dataset:
-    # process element
+    print(element)  # prints { 1, 2, 5, 10, 17 }
   ```
 
   In the above example, the first two lines (before the call to `distribute`)

@@ -224,6 +252,33 @@ def distribute(processing_mode, service, max_outstanding_requests=None):
   RPC. The remaining transformations (after the call to `distribute`) will be
   executed locally.
 
+  The `job_name` argument allows jobs to be shared across multiple
+  datasets. Instead of each dataset creating its own job, all datasets with the
+  same `job_name` will consume from the same job. A new job will
+  be created for each iteration of the dataset (with each repetition of
+  `Dataset.repeat` counting as a new iteration). The following example
+  demonstrates shared iteration, with the assumption that the tf.data service is
+  running with a single worker.
+
+  ```
+  range5_dataset = tf.data.Dataset.range(5)
+  dataset1 = range5_dataset.apply(tf.data.experimental.service.distribute(
+      "parallel_epochs", "grpc://dataservice:5000", job_name="my_job_name"))
+  dataset2 = range5_dataset.apply(tf.data.experimental.service.distribute(
+      "parallel_epochs", "grpc://dataservice:5000", job_name="my_job_name"))
+  iter_1_1 = iter(dataset1)
+  iter_1_2 = iter(dataset1)
+  iter_2_1 = iter(dataset2)
+  iter_2_2 = iter(dataset2)
+  print(next(iter_1_1))  # Prints "0"
+  # iter_1_2 consumes from the same job as iter_1_1
+  print(next(iter_1_2))  # Prints "1"
+  # iter_2_1 consumes from a new job
+  print(next(iter_2_1))  # Prints "0"
+  # iter_2_2 consumes from the same job as iter_2_1
+  print(next(iter_2_2))  # Prints "1"
+  ```
+
   Args:
     processing_mode: A string specifying the policy for how data should be
       processed by tf.data workers. Currently, the only supported value is

@@ -231,6 +286,9 @@ def distribute(processing_mode, service, max_outstanding_requests=None):
     service: A string indicating how to connect to the tf.data service. The
       string should be in the format <protocol>://<address>, e.g.
       grpc://localhost:5000.
+    job_name: (Optional.) The name of the job. This argument makes it possible
+      for multiple datasets to share the same job. The default behavior is that
+      the dataset creates anonymous, exclusively owned jobs.
     max_outstanding_requests: (Optional.) A limit on how many elements may be
       requested at the same time. You can use this option to control the amount
      of memory used, since `distribute` won't use more than `element_size` *

@@ -239,4 +297,5 @@ def distribute(processing_mode, service, max_outstanding_requests=None):
   Returns:
     Dataset: A `Dataset` of the elements produced by the data service.
   """
-  return _distribute(processing_mode, service, max_outstanding_requests)
+  return _distribute(processing_mode, service, job_name,
+                     max_outstanding_requests)

View File

@@ -37,11 +37,14 @@ from tensorflow.python.platform import test
 PROTOCOL = "grpc"
 
 
-def _make_distributed_dataset(dataset, service):
+def _make_distributed_dataset(dataset, service, job_name=None):
   """Creates a distributed dataset with a short task refresh interval."""
   return dataset.apply(
       data_service_ops._distribute(
-          "parallel_epochs", service, task_refresh_interval_hint_ms=20))
+          "parallel_epochs",
+          service,
+          job_name=job_name,
+          task_refresh_interval_hint_ms=20))
 
 
 class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):

@@ -233,6 +236,72 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
     result = list(f().numpy())
     self.assertCountEqual(num_workers * list(range(num_elements)), result)
 
+  @combinations.generate(test_base.eager_only_combinations())
+  def testSharedJobName(self):
+    num_elements = 10
+    service = self.create_cluster(1)
+    ds = dataset_ops.Dataset.range(num_elements)
+    ds1 = _make_distributed_dataset(ds, service, job_name="job_name")
+    ds2 = _make_distributed_dataset(ds, service, job_name="job_name")
+    iter1 = iter(ds1)
+    iter2 = iter(ds2)
+    results = []
+    for _ in range(3):
+      results.append(next(iter1).numpy())
+      results.append(next(iter2).numpy())
+    for elem in iter1:
+      results.append(elem.numpy())
+    for elem in iter2:
+      results.append(elem.numpy())
+    self.assertCountEqual(list(range(num_elements)), results)
+
+  @combinations.generate(test_base.eager_only_combinations())
+  def testDifferentJobNames(self):
+    num_elements = 10
+    service = self.create_cluster(1)
+    ds = dataset_ops.Dataset.range(num_elements)
+    ds1 = _make_distributed_dataset(ds, service, job_name="job_name1")
+    ds2 = _make_distributed_dataset(ds, service, job_name="job_name2")
+    self.assertDatasetProduces(ds1, list(range(num_elements)))
+    self.assertDatasetProduces(ds2, list(range(num_elements)))
+
+  @combinations.generate(test_base.eager_only_combinations())
+  def testSharedJobNameMultiIteration(self):
+    num_elements = 10
+    service = self.create_cluster(1)
+    ds = dataset_ops.Dataset.range(num_elements)
+    ds1 = _make_distributed_dataset(ds, service, job_name="job_name")
+    ds2 = _make_distributed_dataset(ds, service, job_name="job_name")
+    # iteration 1
+    self.assertDatasetProduces(ds1, list(range(num_elements)))
+    self.assertDatasetProduces(ds2, [])
+    # iteration 2
+    self.assertDatasetProduces(ds2, list(range(num_elements)))
+    self.assertDatasetProduces(ds1, [])
+
+  @combinations.generate(test_base.eager_only_combinations())
+  def testSharedJobNameRepeat(self):
+    num_elements = 10
+    num_repetitions = 3
+    service = self.create_cluster(1)
+    ds = dataset_ops.Dataset.range(num_elements)
+    ds1 = _make_distributed_dataset(ds, service, job_name="job_name")
+    ds1 = ds1.repeat(num_repetitions)
+    ds2 = _make_distributed_dataset(ds, service, job_name="job_name")
+    ds2 = ds2.repeat(num_repetitions)
+    results = []
+    iter1 = iter(ds1)
+    iter2 = iter(ds2)
+    for _ in range(((num_elements * num_repetitions) // 2) - 1):
+      results.append(next(iter1).numpy())
+    for _ in range(((num_elements * num_repetitions) // 2) - 1):
+      results.append(next(iter2).numpy())
+    for elem in iter1:
+      results.append(elem.numpy())
+    for elem in iter2:
+      results.append(elem.numpy())
+    self.assertCountEqual(num_repetitions * list(range(num_elements)), results)
+
   def run_stateful(self, external_state_policy):
     num_elements = 10
     ds = dataset_ops.Dataset.range(num_elements).map(

View File

@@ -934,7 +934,7 @@ tf_module {
   }
   member_method {
     name: "DataServiceDataset"
-    argspec: "args=[\'dataset_id\', \'processing_mode\', \'address\', \'protocol\', \'max_outstanding_requests\', \'output_types\', \'output_shapes\', \'task_refresh_interval_hint_ms\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], "
+    argspec: "args=[\'dataset_id\', \'processing_mode\', \'address\', \'protocol\', \'job_name\', \'max_outstanding_requests\', \'iteration_counter\', \'output_types\', \'output_shapes\', \'task_refresh_interval_hint_ms\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], "
   }
   member_method {
     name: "DatasetCardinality"

@@ -1176,6 +1176,10 @@ tf_module {
     name: "DrawBoundingBoxesV2"
     argspec: "args=[\'images\', \'boxes\', \'colors\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
+  member_method {
+    name: "DummyIterationCounter"
+    argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
   member_method {
     name: "DummyMemoryCache"
     argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@@ -934,7 +934,7 @@ tf_module {
   }
   member_method {
     name: "DataServiceDataset"
-    argspec: "args=[\'dataset_id\', \'processing_mode\', \'address\', \'protocol\', \'max_outstanding_requests\', \'output_types\', \'output_shapes\', \'task_refresh_interval_hint_ms\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], "
+    argspec: "args=[\'dataset_id\', \'processing_mode\', \'address\', \'protocol\', \'job_name\', \'max_outstanding_requests\', \'iteration_counter\', \'output_types\', \'output_shapes\', \'task_refresh_interval_hint_ms\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], "
   }
   member_method {
     name: "DatasetCardinality"

@@ -1176,6 +1176,10 @@ tf_module {
     name: "DrawBoundingBoxesV2"
     argspec: "args=[\'images\', \'boxes\', \'colors\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
+  member_method {
+    name: "DummyIterationCounter"
+    argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
   member_method {
     name: "DummyMemoryCache"
     argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'None\'], "