[tf.data service] Add support for shared job names.
Shared job names give users a way to share tf.data service job output across multiple datasets.

PiperOrigin-RevId: 310037762
Change-Id: I5a8f9130806a361ada46b3f0b62ea71b0f9b2e8e
Parent: 48baba71cf
Commit: 22546b562d
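For orientation before the diff, here is a minimal usage sketch of the new `job_name` argument added by this commit. It is not part of the commit itself; the `grpc://localhost:5000` address, the job name, and the printed elements are illustrative, and it assumes eager mode with a tf.data service master and at least one worker already running.

```
import tensorflow as tf

# Hypothetical service address; assumes a tf.data service master and worker
# are already running there.
SERVICE = "grpc://localhost:5000"

dataset = tf.data.Dataset.range(5)

# Two datasets created with the same job_name consume from one shared job,
# so the elements are split between them instead of each consumer receiving
# a full copy of the dataset.
ds1 = dataset.apply(tf.data.experimental.service.distribute(
    "parallel_epochs", SERVICE, job_name="shared_job"))
ds2 = dataset.apply(tf.data.experimental.service.distribute(
    "parallel_epochs", SERVICE, job_name="shared_job"))

it1, it2 = iter(ds1), iter(ds2)
print(next(it1).numpy())  # e.g. 0
print(next(it2).numpy())  # e.g. 1 -- same job as it1, so it continues the stream
```

Datasets created without a `job_name` keep the old behavior: each one gets its own anonymous, exclusively owned job.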
@@ -0,0 +1,4 @@
op {
  graph_op_name: "DummyIterationCounter"
  visibility: HIDDEN
}
@@ -56,6 +56,7 @@ cc_library(
    deps = [
        ":common_proto_cc",
        ":credentials_factory",
        ":data_service",
        ":grpc_util",
        ":master_proto_cc",
        ":worker_cc_grpc_proto",
@@ -31,7 +31,7 @@ constexpr const char kParallelEpochs[] = "parallel_epochs";
constexpr const char kOneEpoch[] = "one_epoch";
}  // namespace

Status ParseProcessingMode(absl::string_view s, ProcessingMode* mode) {
Status ParseProcessingMode(const std::string& s, ProcessingMode* mode) {
  if (s == kParallelEpochs) {
    *mode = ProcessingMode::PARALLEL_EPOCHS;
  } else if (s == kOneEpoch) {

@@ -54,6 +54,21 @@ std::string ProcessingModeToString(ProcessingMode mode) {
  }
}

Status DataServiceMasterClient::RegisterDataset(GraphDef dataset,
                                                int64* dataset_id) {
  TF_RETURN_IF_ERROR(EnsureInitialized());
  GetOrRegisterDatasetRequest req;
  *req.mutable_dataset()->mutable_graph() = dataset;
  GetOrRegisterDatasetResponse resp;
  grpc::ClientContext client_ctx;
  grpc::Status status = stub_->GetOrRegisterDataset(&client_ctx, req, &resp);
  if (!status.ok()) {
    return grpc_util::WrapError("Failed to register dataset", status);
  }
  *dataset_id = resp.dataset_id();
  return Status::OK();
}

Status DataServiceMasterClient::CreateJob(int64 dataset_id,
                                          ProcessingMode processing_mode,
                                          int64* job_id) {

@@ -73,18 +88,27 @@ Status DataServiceMasterClient::CreateJob(int64 dataset_id,
  return Status::OK();
}

Status DataServiceMasterClient::RegisterDataset(GraphDef dataset,
                                                int64* dataset_id) {
Status DataServiceMasterClient::GetOrCreateJob(int64 dataset_id,
                                               ProcessingMode processing_mode,
                                               const std::string& job_name,
                                               int job_name_index,
                                               int64* job_id) {
  TF_RETURN_IF_ERROR(EnsureInitialized());
  GetOrRegisterDatasetRequest req;
  *req.mutable_dataset()->mutable_graph() = dataset;
  GetOrRegisterDatasetResponse resp;
  GetOrCreateJobRequest req;
  req.set_dataset_id(dataset_id);
  req.set_processing_mode(ProcessingModeDef(processing_mode));
  req.set_job_name(job_name);
  req.set_job_name_index(job_name_index);
  GetOrCreateJobResponse resp;
  grpc::ClientContext client_ctx;
  grpc::Status status = stub_->GetOrRegisterDataset(&client_ctx, req, &resp);
  grpc::Status status = stub_->GetOrCreateJob(&client_ctx, req, &resp);
  if (!status.ok()) {
    return grpc_util::WrapError("Failed to register dataset", status);
    return grpc_util::WrapError(
        absl::StrCat("Failed to get or create job for dataset with id ",
                     dataset_id),
        status);
  }
  *dataset_id = resp.dataset_id();
  *job_id = resp.job_id();
  return Status::OK();
}

@@ -148,7 +172,7 @@ Status DataServiceWorkerClient::EnsureInitialized() {
}

Status CreateDataServiceMasterClient(
    absl::string_view address, absl::string_view protocol,
    const std::string& address, const std::string& protocol,
    std::unique_ptr<DataServiceMasterClient>* out) {
  auto client = absl::make_unique<DataServiceMasterClient>(address, protocol);
  TF_RETURN_IF_ERROR(client->Initialize());

@@ -157,7 +181,7 @@ Status CreateDataServiceMasterClient(
}

Status CreateDataServiceWorkerClient(
    absl::string_view address, absl::string_view protocol,
    const std::string& address, const std::string& protocol,
    std::unique_ptr<DataServiceWorkerClient>* out) {
  auto client = absl::make_unique<DataServiceWorkerClient>(address, protocol);
  TF_RETURN_IF_ERROR(client->Initialize());
@@ -35,7 +35,7 @@ enum class ProcessingMode : int64 {

// Parses a string representing a processing mode and stores the result in
// *mode. Returns an InvalidArgument status if the string is not recognized.
Status ParseProcessingMode(absl::string_view s, ProcessingMode* mode);
Status ParseProcessingMode(const std::string& s, ProcessingMode* mode);

// Converts a processing mode to its corresponding string.
std::string ProcessingModeToString(ProcessingMode mode);

@@ -45,7 +45,7 @@ std::string ProcessingModeToString(ProcessingMode mode);
// threads.
class DataServiceClientBase {
 public:
  DataServiceClientBase(absl::string_view address, absl::string_view protocol)
  DataServiceClientBase(const std::string& address, const std::string& protocol)
      : address_(address), protocol_(protocol) {}

  virtual ~DataServiceClientBase() = default;

@@ -70,7 +70,8 @@ class DataServiceClientBase {
// Client for communicating with the tf.data service master.
class DataServiceMasterClient : public DataServiceClientBase {
 public:
  DataServiceMasterClient(absl::string_view address, absl::string_view protocol)
  DataServiceMasterClient(const std::string& address,
                          const std::string& protocol)
      : DataServiceClientBase(address, protocol) {}

  // Registers a dataset with the tf.data service, and stores the generated

@@ -82,6 +83,13 @@ class DataServiceMasterClient : public DataServiceClientBase {
  Status CreateJob(int64 dataset_id, ProcessingMode processing_mode,
                   int64* job_id);

  // Gets the job id for the job represented by the tuple
  // (job_name, job_name_index), and stores the id in *job_id. If the
  // job doesn't exist yet, it will be created.
  Status GetOrCreateJob(int64 dataset_id, ProcessingMode processing_mode,
                        const std::string& job_name, int job_name_index,
                        int64* job_id);

  // Queries the master for the tasks associated with the specified job.
  // The tasks will be stored in *tasks, and whether the job is finished will
  // be stored in `*job_finished`.

@@ -98,7 +106,8 @@ class DataServiceMasterClient : public DataServiceClientBase {
// Client for communicating with the tf.data service worker.
class DataServiceWorkerClient : public DataServiceClientBase {
 public:
  DataServiceWorkerClient(absl::string_view address, absl::string_view protocol)
  DataServiceWorkerClient(const std::string& address,
                          const std::string& protocol)
      : DataServiceClientBase(address, protocol) {}

  // Fetches the next element for the specified task_id. The element's

@@ -116,12 +125,12 @@ class DataServiceWorkerClient : public DataServiceClientBase {

// Creates and initializes a new tf.data service master client.
Status CreateDataServiceMasterClient(
    absl::string_view address, absl::string_view protocol,
    const std::string& address, const std::string& protocol,
    std::unique_ptr<DataServiceMasterClient>* out);

// Creates and initializes a new tf.data service worker client.
Status CreateDataServiceWorkerClient(
    absl::string_view address, absl::string_view protocol,
    const std::string& address, const std::string& protocol,
    std::unique_ptr<DataServiceWorkerClient>* out);

}  // namespace data
@@ -42,6 +42,7 @@ HANDLER(RegisterWorker);
HANDLER(WorkerUpdate);
HANDLER(GetOrRegisterDataset);
HANDLER(CreateJob);
HANDLER(GetOrCreateJob);
HANDLER(GetTasks);
#undef HANDLER

@@ -46,6 +46,7 @@ class GrpcMasterImpl : public MasterService::Service {
  HANDLER(WorkerUpdate);
  HANDLER(GetOrRegisterDataset);
  HANDLER(CreateJob);
  HANDLER(GetOrCreateJob);
  HANDLER(GetTasks);
#undef HANDLER
@@ -56,7 +56,24 @@ message CreateJobRequest {
}

message CreateJobResponse {
  // An id for the begun job.
  // An id for the created job.
  int64 job_id = 1;
}

message GetOrCreateJobRequest {
  // The id of the dataset to create a job for.
  int64 dataset_id = 1;
  // A mode controlling how the tf.data service produces data for the job.
  ProcessingModeDef processing_mode = 2;
  // A name for the job.
  string job_name = 3;
  // An index for the job. Multiple jobs can be created for the same name, if
  // they have different indices.
  int64 job_name_index = 4;
}

message GetOrCreateJobResponse {
  // The id of the (potentially newly created) job.
  int64 job_id = 1;
}

@@ -96,6 +113,9 @@ service MasterService {
  rpc GetOrRegisterDataset(GetOrRegisterDatasetRequest)
      returns (GetOrRegisterDatasetResponse);

  // Gets a job if it already exists, otherwise creates it.
  rpc GetOrCreateJob(GetOrCreateJobRequest) returns (GetOrCreateJobResponse);

  // Creates a job for reading from the tf.data service.
  rpc CreateJob(CreateJobRequest) returns (CreateJobResponse);
@@ -25,6 +25,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "tensorflow/core/data/service/common.pb.h"
#include "tensorflow/core/data/service/credentials_factory.h"
#include "tensorflow/core/data/service/data_service.h"
#include "tensorflow/core/data/service/grpc_util.h"
#include "tensorflow/core/data/service/master.pb.h"
#include "tensorflow/core/data/service/worker.grpc.pb.h"

@@ -65,17 +66,17 @@ Status DataServiceMasterImpl::RegisterWorker(

  // Allocate tasks to the worker.
  for (auto& entry : jobs_) {
    Job& job = entry.second;
    if (job.finished()) {
    std::shared_ptr<Job> job = entry.second;
    if (job->finished()) {
      continue;
    }
    int64 task_id = CreateTask(&job, request->worker_address());
    int64 task_id = CreateTask(job.get(), request->worker_address());

    TaskDef* task_def = response->add_tasks();
    *task_def->mutable_dataset() =
        datasets_by_id_[job.dataset_id()]->dataset_def();
    task_def->set_dataset_id(job.dataset_id());
    task_def->set_job_id(job.job_id());
        datasets_by_id_[job->dataset_id()]->dataset_def();
    task_def->set_dataset_id(job->dataset_id());
    task_def->set_job_id(job->job_id());
    task_def->set_task_id(task_id);
  }

@@ -96,7 +97,7 @@ Status DataServiceMasterImpl::WorkerUpdate(const WorkerUpdateRequest* request,
    if (update.completed()) {
      int64 job_id = tasks_.at(task_id).job_id();
      DCHECK(jobs_.contains(job_id));
      jobs_.at(job_id).task_finished(task_id);
      jobs_.at(job_id)->task_finished(task_id);
      VLOG(3) << "Task " << task_id << " from job " << job_id << " completed";
    }
  }

@@ -135,49 +136,109 @@ int64 DataServiceMasterImpl::RegisterDataset(uint64 fingerprint,
  DCHECK(!datasets_by_id_.contains(dataset_id));
  datasets_by_id_[dataset_id] = new_dataset;
  DCHECK(!datasets_by_fingerprint_.contains(fingerprint));
  datasets_by_fingerprint_[dataset_id] = new_dataset;
  datasets_by_fingerprint_[fingerprint] = new_dataset;
  return dataset_id;
}

Status DataServiceMasterImpl::CreateJob(const CreateJobRequest* request,
                                        CreateJobResponse* response) {
  VLOG(3) << "Received begin job request for dataset id "
  VLOG(3) << "Received create job request for dataset id "
          << request->dataset_id();
  switch (request->processing_mode()) {
    case PARALLEL_EPOCHS:
  ProcessingMode processing_mode = ProcessingMode(request->processing_mode());
  mutex_lock l(mu_);
  int64 job_id;
  TF_RETURN_IF_ERROR(CreateJob(request->dataset_id(), processing_mode,
                               absl::optional<std::string>(), &job_id));
  response->set_job_id(job_id);

  VLOG(3) << "Creating job " << job_id << " for dataset "
          << request->dataset_id();
  return Status::OK();
}

Status DataServiceMasterImpl::GetOrCreateJob(
    const GetOrCreateJobRequest* request, GetOrCreateJobResponse* response) {
  VLOG(3) << "Received get or create job request for dataset id "
          << request->dataset_id() << " with name " << request->job_name()
          << " and index " << request->job_name_index();
  mutex_lock l(mu_);
  NamedJobKey key(request->job_name(), request->job_name_index());
  ProcessingMode requested_processing_mode =
      ProcessingMode(request->processing_mode());
  std::shared_ptr<Job>* job = gtl::FindOrNull(named_jobs_, key);
  if (job != nullptr) {
    TF_RETURN_IF_ERROR(ValidateMatchingJob(**job, requested_processing_mode,
                                           request->dataset_id()));
    response->set_job_id((*job)->job_id());
    return Status::OK();
  }
  int64 job_id;
  TF_RETURN_IF_ERROR(CreateJob(request->dataset_id(), requested_processing_mode,
                               request->job_name(), &job_id));
  named_jobs_[key] = jobs_[job_id];
  response->set_job_id(job_id);
  return Status::OK();
}

// Validates that the job matches the given processing_mode and dataset_id.
Status DataServiceMasterImpl::ValidateMatchingJob(
    const Job& job, ProcessingMode processing_mode, int64 dataset_id) {
  DCHECK(job.name().has_value());
  std::string job_name = job.name().value();
  if (job.processing_mode() != processing_mode) {
    std::string requested = ProcessingModeToString(processing_mode);
    std::string actual = ProcessingModeToString(job.processing_mode());
    return errors::FailedPrecondition(
        "Found a job with name ", job_name, ", but the processing mode <",
        actual, "> doesn't match the requested processing mode <", requested,
        ">.");
  }
  if (job.dataset_id() != dataset_id) {
    return errors::FailedPrecondition(
        "Found a job with name ", job_name, ", but the dataset id <",
        job.dataset_id(), "> doesn't match the requested dataset id <",
        dataset_id, ">.");
  }
  return Status::OK();
}

Status DataServiceMasterImpl::CreateJob(int64 dataset_id,
                                        ProcessingMode processing_mode,
                                        absl::optional<std::string> job_name,
                                        int64* out_job_id)
    EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  switch (processing_mode) {
    case ProcessingMode::PARALLEL_EPOCHS:
      break;
    case ONE_EPOCH:
    case ProcessingMode::ONE_EPOCH:
      return errors::Unimplemented(
          "CreateJob only supports the PARALLEL_EPOCHS job mode. "
          "ONE_EPOCH is not currently supported.");
    default:
      return errors::Unimplemented(
          "ProcessingMode ", request->processing_mode(), " not recognized");
      return errors::Unimplemented("ProcessingMode ",
                                   ProcessingModeToString(processing_mode),
                                   " not recognized");
  }
  mutex_lock l(mu_);
  if (!datasets_by_id_.contains(request->dataset_id())) {
    return errors::NotFound("CreateJob failed. Dataset id: <",
                            request->dataset_id(), "> not found.");
  if (!datasets_by_id_.contains(dataset_id)) {
    return errors::NotFound("Dataset id: <", dataset_id, "> not found.");
  }

  int64 job_id = next_job_id_++;
  DCHECK(!jobs_.contains(job_id));
  auto result =
      jobs_.emplace(std::piecewise_construct, std::forward_as_tuple(job_id),
                    std::forward_as_tuple(job_id, request->dataset_id()));
  DCHECK(result.second);
  Job& job = result.first->second;
  response->set_job_id(job_id);
  auto job =
      std::make_shared<Job>(job_id, dataset_id, processing_mode, job_name);
  jobs_[job_id] = job;

  for (auto& worker : workers_) {
    int64 task_id = CreateTask(&job, worker.address());
    int64 task_id = CreateTask(job.get(), worker.address());

    // TODO(aaudibert): perform these calls asynchronously.
    // TODO(aaudibert): clean up in case some calls succeed, but later calls
    // fail
    TF_RETURN_IF_ERROR(AllocateTaskToWorker(tasks_.at(task_id), &worker));
  }

  VLOG(3) << "Beginning job " << job_id << " for dataset "
          << request->dataset_id();
  *out_job_id = job_id;
  return Status::OK();
}

@@ -233,8 +294,8 @@ Status DataServiceMasterImpl::GetTasks(const GetTasksRequest* request,
    return errors::NotFound("GetTasks failed. Job id <", request->job_id(),
                            "> not found.");
  }
  Job& job = it->second;
  for (const auto& task_id : job.task_ids()) {
  std::shared_ptr<Job> job = it->second;
  for (const auto& task_id : job->task_ids()) {
    auto task_iter = tasks_.find(task_id);
    DCHECK(task_iter != tasks_.end());
    Task& task = task_iter->second;

@@ -242,7 +303,7 @@ Status DataServiceMasterImpl::GetTasks(const GetTasksRequest* request,
    task_info->set_worker_address(task.worker_address());
    task_info->set_id(task.task_id());
  }
  response->set_job_finished(job.finished());
  response->set_job_finished(job->finished());
  VLOG(3) << "Found " << response->task_info_size() << " tasks for job id "
          << request->job_id();
  return Status::OK();
@@ -18,6 +18,7 @@ limitations under the License.

#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/data/service/common.pb.h"
#include "tensorflow/core/data/service/data_service.h"
#include "tensorflow/core/data/service/master.pb.h"
#include "tensorflow/core/data/service/worker.grpc.pb.h"
#include "tensorflow/core/lib/core/status.h"

@@ -56,6 +57,8 @@ class DataServiceMasterImpl {
                                 GetOrRegisterDatasetResponse* response);
  Status CreateJob(const CreateJobRequest* request,
                   CreateJobResponse* response);
  Status GetOrCreateJob(const GetOrCreateJobRequest* request,
                        GetOrCreateJobResponse* response);
  Status GetTasks(const GetTasksRequest* request, GetTasksResponse* response);

 private:

@@ -100,11 +103,17 @@ class DataServiceMasterImpl {

  class Job {
   public:
    Job(int64 job_id, int64 dataset_id)
        : job_id_(job_id), dataset_id_(dataset_id) {}
    Job(int64 job_id, int64 dataset_id, ProcessingMode processing_mode,
        absl::optional<absl::string_view> job_name)
        : job_id_(job_id),
          dataset_id_(dataset_id),
          processing_mode_(processing_mode),
          job_name_(job_name) {}

    int64 job_id() const { return job_id_; }
    int64 dataset_id() const { return dataset_id_; }
    ProcessingMode processing_mode() const { return processing_mode_; }
    absl::optional<std::string> name() const { return job_name_; }
    const std::vector<int64>& task_ids() const { return task_ids_; }
    void add_task_id(int64 task_id) { task_ids_.push_back(task_id); }
    void task_finished(int64 task_id) {

@@ -118,11 +127,32 @@ class DataServiceMasterImpl {
   private:
    const int64 job_id_;
    const int64 dataset_id_;
    const ProcessingMode processing_mode_;
    const absl::optional<std::string> job_name_;
    std::vector<int64> task_ids_;
    std::vector<int64> finished_tasks_;
    bool finished_ = false;
  };

  class NamedJobKey {
   public:
    NamedJobKey(absl::string_view name, int64 index)
        : name_(name), index_(index) {}

    friend bool operator==(const NamedJobKey& lhs, const NamedJobKey& rhs) {
      return lhs.name_ == rhs.name_ && lhs.index_ == rhs.index_;
    }

    template <typename H>
    friend H AbslHashValue(H h, const NamedJobKey& k) {
      return H::combine(std::move(h), k.name_, k.index_);
    }

   private:
    const std::string name_;
    const int64 index_;
  };

  class Task {
   public:
    Task(int64 task_id, int64 job_id, int64 dataset_id,

@@ -150,9 +180,15 @@ class DataServiceMasterImpl {
  Status EnsureWorkerStubInitialized(Worker* worker);
  // Instructs a worker to begin processing a task.
  Status AllocateTaskToWorker(const Task& task_id, Worker* worker);
  // Creates a job and stores its job_id in `*job_id`.
  Status CreateJob(int64 dataset_id, ProcessingMode processing_mode,
                   absl::optional<std::string> job_name, int64* out_job_id);
  // Creates a new task for a job, returning the new task's id.
  int64 CreateTask(Job* job, const std::string& worker_address);

  // Validates that an existing job matches the given processing_mode and
  // dataset_id, returning an error status describing any difference.
  Status ValidateMatchingJob(const Job& job, ProcessingMode processing_mode,
                             int64 dataset_id);
  // Protocol to use for communicating with workers.
  const std::string protocol_;

@@ -172,9 +208,13 @@ class DataServiceMasterImpl {
  absl::flat_hash_map<uint64, std::shared_ptr<Dataset>> datasets_by_fingerprint_
      TF_GUARDED_BY(mu_);
  // Information about jobs, keyed by job ids.
  absl::flat_hash_map<int64, Job> jobs_ TF_GUARDED_BY(mu_);
  absl::flat_hash_map<int64, std::shared_ptr<Job>> jobs_ TF_GUARDED_BY(mu_);
  // Information about tasks, keyed by task ids.
  absl::flat_hash_map<int64, Task> tasks_ TF_GUARDED_BY(mu_);
  // Named jobs, keyed by their names and indices. Not all jobs have names, so
  // this is a subset of the jobs stored in `jobs_`.
  absl::flat_hash_map<NamedJobKey, std::shared_ptr<Job>> named_jobs_
      TF_GUARDED_BY(mu_);

  TF_DISALLOW_COPY_AND_ASSIGN(DataServiceMasterImpl);
};
@@ -47,8 +47,11 @@ namespace data {
/* static */ constexpr const char* const DataServiceDatasetOp::kProcessingMode;
/* static */ constexpr const char* const DataServiceDatasetOp::kAddress;
/* static */ constexpr const char* const DataServiceDatasetOp::kProtocol;
/* static */ constexpr const char* const DataServiceDatasetOp::kJobName;
/* static */ constexpr const char* const
    DataServiceDatasetOp::kMaxOutstandingRequests;
/* static */ constexpr const char* const
    DataServiceDatasetOp::kIterationCounter;
/* static */ constexpr const char* const DataServiceDatasetOp::kOutputTypes;
/* static */ constexpr const char* const DataServiceDatasetOp::kOutputShapes;

@@ -71,23 +74,45 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
 public:
  Dataset(OpKernelContext* ctx, int64 dataset_id,
          ProcessingMode processing_mode, const std::string& address,
          const std::string& protocol, int64 max_outstanding_requests,
          int64 task_refresh_interval_ms, const DataTypeVector& output_types,
          const std::string& protocol, const std::string& job_name,
          int64 max_outstanding_requests, int64 task_refresh_interval_ms,
          IterationCounter* iteration_counter, bool owns_resource,
          ResourceHandle iteration_counter_handle,
          const DataTypeVector& output_types,
          const std::vector<PartialTensorShape>& output_shapes)
      : DatasetBase(DatasetContext(ctx)),
        dataset_id_(dataset_id),
        processing_mode_(processing_mode),
        address_(address),
        protocol_(protocol),
        job_name_(job_name),
        max_outstanding_requests_(max_outstanding_requests),
        task_refresh_interval_ms_(task_refresh_interval_ms),
        iteration_counter_(iteration_counter),
        owns_resource_(owns_resource),
        iteration_counter_handle_(iteration_counter_handle),
        resource_mgr_(ctx->resource_manager()),
        output_types_(output_types),
        output_shapes_(output_shapes) {}

  ~Dataset() override {
    iteration_counter_->Unref();
    if (owns_resource_) {
      Status s = resource_mgr_->Delete<IterationCounter>(
          iteration_counter_handle_.container(),
          iteration_counter_handle_.name());
      if (!s.ok()) {
        LOG(WARNING) << "Failed to delete iteration counter resource: " << s;
      }
    }
  }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    return absl::make_unique<Iterator>(Iterator::Params{
        this, name_utils::IteratorPrefix(kDatasetType, prefix)});
    return absl::make_unique<Iterator>(
        Iterator::Params{this,
                         name_utils::IteratorPrefix(kDatasetType, prefix)},
        iteration_counter_->GetAndIncrement());
  }

  const DataTypeVector& output_dtypes() const override { return output_types_; }

@@ -123,18 +148,26 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
    Node* protocol;
    TF_RETURN_IF_ERROR(b->AddScalar(protocol_, &protocol));

    Node* job_name;
    TF_RETURN_IF_ERROR(b->AddScalar(job_name_, &job_name));

    Node* max_outstanding_requests;
    TF_RETURN_IF_ERROR(
        b->AddScalar(max_outstanding_requests_, &max_outstanding_requests));

    Node* iteration_counter_handle = nullptr;
    Tensor handle(DT_RESOURCE, TensorShape({}));
    handle.scalar<ResourceHandle>()() = iteration_counter_handle_;
    TF_RETURN_IF_ERROR(b->AddTensor(handle, &iteration_counter_handle));

    AttrValue task_refresh_interval_hint_ms;
    b->BuildAttrValue(task_refresh_interval_ms_,
                      &task_refresh_interval_hint_ms);

    TF_RETURN_IF_ERROR(
        b->AddDataset(this,
                      {dataset_id, processing_mode, address, protocol,
                       max_outstanding_requests},
                      {dataset_id, processing_mode, address, protocol, job_name,
                       max_outstanding_requests, iteration_counter_handle},
                      {std::make_pair(kTaskRefreshIntervalHintMs,
                                      task_refresh_interval_hint_ms)},
                      output));

@@ -144,8 +177,8 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
 private:
  class Iterator : public DatasetIterator<Dataset> {
   public:
    explicit Iterator(const Params& params)
        : DatasetIterator<Dataset>(params) {}
    explicit Iterator(const Params& params, int64 iterator_index)
        : DatasetIterator<Dataset>(params), iterator_index_(iterator_index) {}

    ~Iterator() override {
      mutex_lock l(mu_);

@@ -161,8 +194,14 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
      VLOG(3) << "Connecting to " << dataset()->address_
              << " in data service dataset op";
      DataServiceMasterClient master(dataset()->address_, dataset()->protocol_);
      if (dataset()->job_name_.empty()) {
        TF_RETURN_IF_ERROR(master.CreateJob(
            dataset()->dataset_id_, dataset()->processing_mode_, &job_id_));
      } else {
        TF_RETURN_IF_ERROR(master.GetOrCreateJob(
            dataset()->dataset_id_, dataset()->processing_mode_,
            dataset()->job_name_, iterator_index_, &job_id_));
      }
      VLOG(1) << "Created data service job with id " << job_id_;
      return Status::OK();
    }

@@ -418,6 +457,8 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
      return Status::OK();
    }

    const int64 iterator_index_;

    mutex mu_;
    // TODO(aaudibert): split this into a couple cvs for different conditions
    // so that we can use notify_one and avoid unnecessary wakeups.

@@ -449,8 +490,13 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
  const ProcessingMode processing_mode_;
  const tstring address_;
  const tstring protocol_;
  const tstring job_name_;
  const int64 max_outstanding_requests_;
  const int64 task_refresh_interval_ms_;
  IterationCounter* const iteration_counter_;  // Owned
  const bool owns_resource_;
  const ResourceHandle iteration_counter_handle_;
  ResourceMgr* const resource_mgr_;  // Not owned
  const DataTypeVector output_types_;
  const std::vector<PartialTensorShape> output_shapes_;
};

@@ -488,9 +534,41 @@ void DataServiceDatasetOp::MakeDataset(OpKernelContext* ctx,
  OP_REQUIRES(ctx, !protocol.empty(),
              errors::InvalidArgument(kProtocol, " must be non-empty."));

  tstring job_name;
  OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kJobName, &job_name));

  int64 max_outstanding_requests;
  OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kMaxOutstandingRequests,
                                          &max_outstanding_requests));

  ResourceHandle iteration_counter_handle;
  OP_REQUIRES_OK(
      ctx, HandleFromInput(ctx, kIterationCounter, &iteration_counter_handle));
  IterationCounter* iteration_counter = nullptr;
  Status s = ctx->resource_manager()->Lookup<IterationCounter>(
      iteration_counter_handle.container(), iteration_counter_handle.name(),
      &iteration_counter);
  bool owns_resource = false;
  if (errors::IsNotFound(s)) {
    owns_resource = true;
    static std::atomic<int64> resource_id_counter(0);
    const std::string& container = ctx->resource_manager()->default_container();
    std::string name =
        strings::StrCat(ctx->op_kernel().name(), "/", kIterationCounter, "_",
                        resource_id_counter.fetch_add(1));
    OP_REQUIRES_OK(ctx,
                   ctx->resource_manager()->LookupOrCreate<IterationCounter>(
                       container, name, &iteration_counter,
                       [](IterationCounter** counter) {
                         *counter = new IterationCounter();
                         return Status::OK();
                       }));
    iteration_counter_handle =
        MakeResourceHandle<IterationCounter>(ctx, container, name);
  } else {
    OP_REQUIRES_OK(ctx, s);
  }

  OP_REQUIRES(
      ctx,
      max_outstanding_requests == model::kAutotune ||

@@ -499,13 +577,16 @@ void DataServiceDatasetOp::MakeDataset(OpKernelContext* ctx,
                  model::kAutotune));

  *output =
      new Dataset(ctx, dataset_id, processing_mode, address, protocol,
      new Dataset(ctx, dataset_id, processing_mode, address, protocol, job_name,
                  max_outstanding_requests, task_refresh_interval_hint_ms_,
                  iteration_counter, owns_resource, iteration_counter_handle,
                  output_types_, output_shapes_);
}

REGISTER_KERNEL_BUILDER(Name("DataServiceDataset").Device(DEVICE_CPU),
                        DataServiceDatasetOp);
REGISTER_KERNEL_BUILDER(Name("DummyIterationCounter").Device(DEVICE_CPU),
                        DummyResourceOp<IterationCounter>);

}  // namespace data
}  // namespace tensorflow
@@ -15,11 +15,34 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_DATA_SERVICE_DATASET_OP_H_
#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_DATA_SERVICE_DATASET_OP_H_

#include "absl/strings/str_cat.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/resource_mgr.h"

namespace tensorflow {
namespace data {

// A resource which counts how many iterators have been created. This is used
// by the DataServiceDataset to coordinate jobs across multiple iterations.
class IterationCounter : public ResourceBase {
 public:
  IterationCounter() : counter_(0) {}

  std::string DebugString() const override {
    mutex_lock l(mu_);
    return absl::StrCat(counter_);
  }

  int64 GetAndIncrement() {
    mutex_lock l(mu_);
    return ++counter_;
  }

 private:
  mutable mutex mu_;
  int64 counter_ TF_GUARDED_BY(mu_) = 0;
};

// Creates a dataset for reading from the tf.data service.
class DataServiceDatasetOp : public DatasetOpKernel {
 public:

@@ -28,10 +51,12 @@ class DataServiceDatasetOp : public DatasetOpKernel {
  static constexpr const char* const kProcessingMode = "processing_mode";
  static constexpr const char* const kAddress = "address";
  static constexpr const char* const kProtocol = "protocol";
  static constexpr const char* const kJobName = "job_name";
  static constexpr const char* const kMaxOutstandingRequests =
      "max_outstanding_requests";
  static constexpr const char* const kTaskRefreshIntervalHintMs =
      "task_refresh_interval_hint_ms";
  static constexpr const char* const kIterationCounter = "iteration_counter";
  static constexpr const char* const kOutputTypes = "output_types";
  static constexpr const char* const kOutputShapes = "output_shapes";
@@ -1,47 +0,0 @@
op {
  name: "DataServiceDataset"
  input_arg {
    name: "dataset_id"
    type: DT_INT64
  }
  input_arg {
    name: "processing_mode"
    type: DT_STRING
  }
  input_arg {
    name: "address"
    type: DT_STRING
  }
  input_arg {
    name: "protocol"
    type: DT_STRING
  }
  input_arg {
    name: "max_outstanding_requests"
    type: DT_INT64
  }
  output_arg {
    name: "handle"
    type: DT_VARIANT
  }
  attr {
    name: "task_refresh_interval_hint_ms"
    type: "int"
    default_value {
      i: -1
    }
  }
  attr {
    name: "output_types"
    type: "list(type)"
    has_minimum: true
    minimum: 1
  }
  attr {
    name: "output_shapes"
    type: "list(shape)"
    has_minimum: true
    minimum: 1
  }
  is_stateful: true
}
@@ -1037,12 +1037,21 @@ REGISTER_OP("ExperimentalUniqueDataset")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::ScalarShape);

REGISTER_OP("DummyIterationCounter")
    .Output("handle: resource")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      c->set_output(0, c->Scalar());
      return Status::OK();
    });

REGISTER_OP("DataServiceDataset")
    .Input("dataset_id: int64")
    .Input("processing_mode: string")
    .Input("address: string")
    .Input("protocol: string")
    .Input("job_name: string")
    .Input("max_outstanding_requests: int64")
    .Input("iteration_counter: resource")
    .Output("handle: variant")
    .Attr("task_refresh_interval_hint_ms: int = -1")
    .Attr("output_types: list(type) >= 1")
@@ -51,6 +51,7 @@ class _DataServiceDatasetV2(dataset_ops.DatasetSource):
               processing_mode,
               address,
               protocol,
               job_name=None,
               max_outstanding_requests=None,
               task_refresh_interval_hint_ms=None):
    """Constructs a _DataServiceDatasetV2.

@@ -65,6 +66,9 @@ class _DataServiceDatasetV2(dataset_ops.DatasetSource):
      address: The tf.data service address, e.g. "localhost:5000".
      protocol: The protocol to use for communicating with the tf.data service,
        e.g. "grpc".
      job_name: (Optional.) The name of the job. This argument makes it
        possible for multiple datasets to share the same job. The default
        behavior is that the dataset creates anonymous, exclusively owned jobs.
      max_outstanding_requests: (Optional.) A limit on how many elements may be
        requested at the same time. You can use this option to control the
        amount of memory used, since `distribute` won't use more than

@@ -73,6 +77,8 @@ class _DataServiceDatasetV2(dataset_ops.DatasetSource):
        the master for task changes.
    """

    if job_name is None:
      job_name = ""
    if max_outstanding_requests is None:
      max_outstanding_requests = dataset_ops.AUTOTUNE
    if task_refresh_interval_hint_ms is None:

@@ -85,8 +91,11 @@ class _DataServiceDatasetV2(dataset_ops.DatasetSource):
        processing_mode=processing_mode,
        address=address,
        protocol=protocol,
        job_name=job_name,
        max_outstanding_requests=max_outstanding_requests,
        task_refresh_interval_hint_ms=task_refresh_interval_hint_ms,
        iteration_counter=gen_experimental_dataset_ops.dummy_iteration_counter(
        ),
        **self._flat_structure)
    super(_DataServiceDatasetV2, self).__init__(variant_tensor)

@@ -100,7 +109,7 @@ class _DataServiceDatasetV1(dataset_ops.DatasetV1Adapter):

  @functools.wraps(_DataServiceDatasetV2.__init__)
  def __init__(self, input_dataset, dataset_id, processing_mode, address,
               protocol, max_outstanding_requests,
               protocol, job_name, max_outstanding_requests,
               task_refresh_interval_hint_ms):

    self._wrapped = _DataServiceDatasetV2(

@@ -109,6 +118,7 @@ class _DataServiceDatasetV1(dataset_ops.DatasetV1Adapter):
        processing_mode=processing_mode,
        address=address,
        protocol=protocol,
        job_name=job_name,
        max_outstanding_requests=max_outstanding_requests,
        task_refresh_interval_hint_ms=task_refresh_interval_hint_ms)
    super(_DataServiceDatasetV1, self).__init__(self._wrapped)

@@ -122,6 +132,7 @@ else:

def _distribute(processing_mode,
                service,
                job_name=None,
                max_outstanding_requests=None,
                task_refresh_interval_hint_ms=None):
  """A transformation that moves dataset processing to the tf.data service.

@@ -136,6 +147,9 @@ def _distribute(processing_mode,
    service: A string indicating how to connect to the tf.data service. The
      string should be in the format <protocol>://<address>, e.g.
      grpc://localhost:5000.
    job_name: (Optional.) The name of the job. This argument makes it
      possible for multiple datasets to share the same job. The default behavior
      is that the dataset creates anonymous, exclusively owned jobs.
    max_outstanding_requests: (Optional.) A limit on how many elements may be
      requested at the same time. You can use this option to control the amount
      of memory used, since `distribute` won't use more than `element_size` *

@@ -147,6 +161,12 @@ def _distribute(processing_mode,
    Dataset: A `Dataset` of the elements produced by the data service.
  """
  ProcessingMode.validate(processing_mode)
  if job_name is not None:
    if not isinstance(job_name, six.string_types):
      raise ValueError("job_name must be a string, but job_name was of type "
                       "{0}. job_name={1}".format(type(job_name), job_name))
    if not job_name:
      raise ValueError("job_name must not be empty")
  if not isinstance(service, six.string_types):
    raise ValueError(
        "service must be a string, but service was of type {0}. service={1}"

@@ -182,41 +202,49 @@ def _distribute(processing_mode,
        processing_mode=processing_mode,
        address=address,
        protocol=protocol,
        job_name=job_name,
        max_outstanding_requests=max_outstanding_requests,
        task_refresh_interval_hint_ms=task_refresh_interval_hint_ms)

  return _apply_fn


def distribute(processing_mode, service, max_outstanding_requests=None):
def distribute(processing_mode,
               service,
               job_name=None,
               max_outstanding_requests=None):
  """A transformation that moves dataset processing to the tf.data service.

  The `processing_mode` argument controls how data is processed by the
  tf.data service. Currently, the only supported mode is "parallel_epochs".
  When you iterate over a dataset containing the `distribute` transformation,
  the tf.data service creates a "job" which produces data for the dataset
  iteration.

  The `processing_mode` argument controls what data is produced by a tf.data
  service job. Currently, the only supported mode is "parallel_epochs".

  processing_mode="parallel_epochs" means that multiple tf.data workers will
  iterate through the dataset in parallel, each producing all elements of the
  dataset. For example, if the dataset contains {0, 1, 2}, every tf.data worker
  used for execution will produce {0, 1, 2}. If there are 3 workers and one
  consumer, the consumer will receive the elements {0, 0, 0, 1, 1, 1, 2, 2, 2}
  (though not necessarily in that order). To account for this, it is recommended
  to randomly shuffle your dataset, so that different tf.data workers will
  iterate through the dataset in different orders.
  used for execution will produce {0, 1, 2}. If there are 3 workers, the job
  will produce the elements {0, 0, 0, 1, 1, 1, 2, 2, 2} (though not necessarily
  in that order). To account for this, it is recommended to randomly shuffle
  your dataset, so that different tf.data workers will iterate through the
  dataset in different orders.

  In the future, there will be additional epoch modes. For example,
  In the future, there will be additional processing modes. For example,
  a "one_epoch" mode which partitions the dataset across the tf.data
  workers, so that the consumers see each element of the dataset only once.

  ```
  dataset = tf.data.Dataset.range(10)
  dataset = tf.data.Dataset.range(5)
  dataset = dataset.map(lambda x: x*x)
  dataset = dataset.apply(
      tf.data.experimental.service.distribute("parallel_epochs",
                                              "grpc://dataservice:5000"))
  dataset = dataset.map(lambda x: x+10)
  dataset = dataset.map(lambda x: x+1)

  for element in dataset:
    # process element
    print(element) # prints { 1, 2, 5, 10, 17 }
  ```

  In the above example, the first two lines (before the call to `distribute`)

@@ -224,6 +252,33 @@ def distribute(processing_mode, service, max_outstanding_requests=None):
  RPC. The remaining transformations (after the call to `distribute`) will be
  executed locally.

  The `job_name` argument allows jobs to be shared across multiple
  datasets. Instead of each dataset creating its own job, all datasets with the
  same `job_name` will consume from the same job. A new job will
  be created for each iteration of the dataset (with each repetition of
  `Dataset.repeat` counting as a new iteration). The following example
  demonstrates shared iteration, with the assumption that the tf.data service is
  running with a single worker.

  ```
  range5_dataset = tf.data.Dataset.range(5)
  dataset1 = range5_dataset.apply(tf.data.experimental.service.distribute(
      "parallel_epochs", "my_job_name", "grpc://dataservice:5000"))
  dataset2 = range5_dataset.apply(tf.data.experimental.service.distribute(
      "parallel_epochs", "my_job_name", "grpc://dataservice:5000"))
  iter_1_1 = iter(dataset1)
  iter_1_2 = iter(dataset1)
  iter_2_1 = iter(dataset2)
  iter_2_2 = iter(dataset2)
  print(next(iter_1_1)) # Prints "0"
  # iter_1_2 consumes from the same job as iter_1_1
  print(next(iter_1_2)) # Prints "1"
  # iter_2_1 consumes from a new job
  print(next(iter_2_1)) # Prints "0"
  # iter_2_2 consumes from the same job as iter_2_1
  print(next(iter_2_2)) # Prints "1"
  ```

  Args:
    processing_mode: A string specifying the policy for how data should be
      processed by tf.data workers. Currently, the only supported value is

@@ -231,6 +286,9 @@ def distribute(processing_mode, service, max_outstanding_requests=None):
    service: A string indicating how to connect to the tf.data service. The
      string should be in the format <protocol>://<address>, e.g.
      grpc://localhost:5000.
    job_name: (Optional.) The name of the job. This argument makes it possible
      for multiple datasets to share the same job. The default behavior is that
      the dataset creates anonymous, exclusively owned jobs.
    max_outstanding_requests: (Optional.) A limit on how many elements may be
      requested at the same time. You can use this option to control the amount
      of memory used, since `distribute` won't use more than `element_size` *

@@ -239,4 +297,5 @@ def distribute(processing_mode, service, max_outstanding_requests=None):
  Returns:
    Dataset: A `Dataset` of the elements produced by the data service.
  """
  return _distribute(processing_mode, service, max_outstanding_requests)
  return _distribute(processing_mode, service, job_name,
                     max_outstanding_requests)
@@ -37,11 +37,14 @@ from tensorflow.python.platform import test
PROTOCOL = "grpc"


def _make_distributed_dataset(dataset, service):
def _make_distributed_dataset(dataset, service, job_name=None):
  """Creates a distributed dataset with a short task refresh interval."""
  return dataset.apply(
      data_service_ops._distribute(
          "parallel_epochs", service, task_refresh_interval_hint_ms=20))
          "parallel_epochs",
          service,
          job_name=job_name,
          task_refresh_interval_hint_ms=20))


class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):

@@ -233,6 +236,72 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
    result = list(f().numpy())
    self.assertCountEqual(num_workers * list(range(num_elements)), result)

  @combinations.generate(test_base.eager_only_combinations())
  def testSharedJobName(self):
    num_elements = 10
    service = self.create_cluster(1)
    ds = dataset_ops.Dataset.range(num_elements)
    ds1 = _make_distributed_dataset(ds, service, job_name="job_name")
    ds2 = _make_distributed_dataset(ds, service, job_name="job_name")
    iter1 = iter(ds1)
    iter2 = iter(ds2)
    results = []
    for _ in range(3):
      results.append(next(iter1).numpy())
      results.append(next(iter2).numpy())
    for elem in iter1:
      results.append(elem.numpy())
    for elem in iter2:
      results.append(elem.numpy())
    self.assertCountEqual(list(range(num_elements)), results)

  @combinations.generate(test_base.eager_only_combinations())
  def testDifferentJobNames(self):
    num_elements = 10
    service = self.create_cluster(1)
    ds = dataset_ops.Dataset.range(num_elements)
    ds1 = _make_distributed_dataset(ds, service, job_name="job_name1")
    ds2 = _make_distributed_dataset(ds, service, job_name="job_name2")
    self.assertDatasetProduces(ds1, list(range(num_elements)))
    self.assertDatasetProduces(ds2, list(range(num_elements)))

  @combinations.generate(test_base.eager_only_combinations())
  def testSharedJobNameMultiIteration(self):
    num_elements = 10
    service = self.create_cluster(1)
    ds = dataset_ops.Dataset.range(num_elements)
    ds1 = _make_distributed_dataset(ds, service, job_name="job_name")
    ds2 = _make_distributed_dataset(ds, service, job_name="job_name")
    # iteration 1
    self.assertDatasetProduces(ds1, list(range(num_elements)))
    self.assertDatasetProduces(ds2, [])
    # iteration 2
    self.assertDatasetProduces(ds2, list(range(num_elements)))
    self.assertDatasetProduces(ds1, [])

  @combinations.generate(test_base.eager_only_combinations())
  def testSharedJobNameRepeat(self):
    num_elements = 10
    num_repetitions = 3
    service = self.create_cluster(1)
    ds = dataset_ops.Dataset.range(num_elements)
    ds1 = _make_distributed_dataset(ds, service, job_name="job_name")
    ds1 = ds1.repeat(num_repetitions)
    ds2 = _make_distributed_dataset(ds, service, job_name="job_name")
    ds2 = ds2.repeat(num_repetitions)
    results = []
    iter1 = iter(ds1)
    iter2 = iter(ds2)
    for _ in range(((num_elements * num_repetitions) // 2) - 1):
      results.append(next(iter1).numpy())
    for _ in range(((num_elements * num_repetitions) // 2) - 1):
      results.append(next(iter2).numpy())
    for elem in iter1:
      results.append(elem.numpy())
    for elem in iter2:
      results.append(elem.numpy())
    self.assertCountEqual(num_repetitions * list(range(num_elements)), results)

  def run_stateful(self, external_state_policy):
    num_elements = 10
    ds = dataset_ops.Dataset.range(num_elements).map(
@@ -934,7 +934,7 @@ tf_module {
  }
  member_method {
    name: "DataServiceDataset"
    argspec: "args=[\'dataset_id\', \'processing_mode\', \'address\', \'protocol\', \'max_outstanding_requests\', \'output_types\', \'output_shapes\', \'task_refresh_interval_hint_ms\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], "
    argspec: "args=[\'dataset_id\', \'processing_mode\', \'address\', \'protocol\', \'job_name\', \'max_outstanding_requests\', \'iteration_counter\', \'output_types\', \'output_shapes\', \'task_refresh_interval_hint_ms\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], "
  }
  member_method {
    name: "DatasetCardinality"

@@ -1176,6 +1176,10 @@ tf_module {
    name: "DrawBoundingBoxesV2"
    argspec: "args=[\'images\', \'boxes\', \'colors\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "DummyIterationCounter"
    argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "DummyMemoryCache"
    argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@@ -934,7 +934,7 @@ tf_module {
  }
  member_method {
    name: "DataServiceDataset"
    argspec: "args=[\'dataset_id\', \'processing_mode\', \'address\', \'protocol\', \'max_outstanding_requests\', \'output_types\', \'output_shapes\', \'task_refresh_interval_hint_ms\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], "
    argspec: "args=[\'dataset_id\', \'processing_mode\', \'address\', \'protocol\', \'job_name\', \'max_outstanding_requests\', \'iteration_counter\', \'output_types\', \'output_shapes\', \'task_refresh_interval_hint_ms\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], "
  }
  member_method {
    name: "DatasetCardinality"

@@ -1176,6 +1176,10 @@ tf_module {
    name: "DrawBoundingBoxesV2"
    argspec: "args=[\'images\', \'boxes\', \'colors\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "DummyIterationCounter"
    argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "DummyMemoryCache"
    argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'None\'], "