[tf.data service] Add support for the distributed_epoch processing mode.

This mode relies on the dataset being splittable, i.e. on its source dataset implementing MakeSplitProvider.

Caveats:
- If the dispatcher restarts mid-job, splits will be processed again from the beginning.
- If a worker restarts mid-job, the splits it was processing will not be trained on.

These caveats will be addressed in later CLs by taking periodic checkpoints to reduce the impact of restarts.
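
For context, a minimal usage sketch of the new mode through the public tf.data.experimental.service.distribute transformation (not part of this change): the service address below is a placeholder, and the sketch assumes a dispatcher and at least one worker are already running and reachable there.

```
import tensorflow as tf

# Placeholder endpoint; assumes a tf.data service dispatcher and at least one
# worker are already running and reachable at this address.
SERVICE = "grpc://localhost:5000"

dataset = tf.data.Dataset.range(100)
dataset = dataset.apply(
    tf.data.experimental.service.distribute(
        processing_mode="distributed_epoch", service=SERVICE))

# In distributed_epoch mode the dispatcher hands out splits to workers, so the
# 100 elements are produced once across the cluster (subject to the restart
# caveats above).
for element in dataset:
  print(element.numpy())
```

The DataServiceOpsTest cases added below exercise the same transformation via the internal data_service_ops module.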

PiperOrigin-RevId: 332930574
Change-Id: I3666273d4d8b07406561b678c7b7f64deb934955
Andrew Audibert 2020-09-21 14:10:48 -07:00 committed by TensorFlower Gardener
parent 1a32039b93
commit 91c1955163
20 changed files with 415 additions and 25 deletions


@ -188,6 +188,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/data:standalone",
"//tensorflow/core/kernels/data:dataset_utils",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
@ -359,6 +360,20 @@ cc_header_only_library(
],
)
cc_library(
name = "split_provider",
srcs = ["split_provider.cc"],
hdrs = [
"split_provider.h",
],
deps = [
":data_service",
":grpc_util",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
],
)
cc_library(
name = "test_cluster",
testonly = True,
@ -450,6 +465,7 @@ cc_library(
":dispatcher_cc_grpc_proto",
":dispatcher_proto_cc",
":grpc_util",
":split_provider",
":utils",
":worker_proto_cc",
"//tensorflow/c:c_api_internal",


@ -19,6 +19,7 @@ message TaskDef {
int64 dataset_id = 3;
int64 task_id = 4;
int64 job_id = 5;
ProcessingModeDef processing_mode = 6;
}
message TaskInfo {
@ -31,8 +32,9 @@ message TaskInfo {
}
enum ProcessingModeDef {
INVALID = 0;
// Each tf.data worker processes an entire epoch.
PARALLEL_EPOCHS = 0;
PARALLEL_EPOCHS = 1;
// Processing of an epoch is distributed across all tf.data workers.
ONE_EPOCH = 1;
DISTRIBUTED_EPOCH = 2;
}


@ -28,14 +28,14 @@ namespace data {
namespace {
constexpr const char kParallelEpochs[] = "parallel_epochs";
constexpr const char kOneEpoch[] = "one_epoch";
constexpr const char kDistributedEpoch[] = "distributed_epoch";
} // namespace
Status ParseProcessingMode(const std::string& s, ProcessingMode& mode) {
if (s == kParallelEpochs) {
mode = ProcessingMode::PARALLEL_EPOCHS;
} else if (s == kOneEpoch) {
mode = ProcessingMode::ONE_EPOCH;
} else if (s == kDistributedEpoch) {
mode = ProcessingMode::DISTRIBUTED_EPOCH;
} else {
return errors::InvalidArgument("Unrecognized processing mode: ", s);
}
@ -46,8 +46,8 @@ std::string ProcessingModeToString(ProcessingMode mode) {
switch (mode) {
case ProcessingMode::PARALLEL_EPOCHS:
return kParallelEpochs;
case ProcessingMode::ONE_EPOCH:
return kOneEpoch;
case ProcessingMode::DISTRIBUTED_EPOCH:
return kDistributedEpoch;
default:
DCHECK(false);
return "Unknown";
@ -111,6 +111,28 @@ Status DataServiceDispatcherClient::GetDatasetDef(int64 dataset_id,
return Status::OK();
}
Status DataServiceDispatcherClient::GetSplit(int64 job_id, int64 repetition,
Tensor& split,
bool& end_of_splits) {
TF_RETURN_IF_ERROR(EnsureInitialized());
GetSplitRequest req;
req.set_job_id(job_id);
req.set_repetition(repetition);
GetSplitResponse resp;
grpc::ClientContext client_ctx;
grpc::Status status = stub_->GetSplit(&client_ctx, req, &resp);
if (!status.ok()) {
return grpc_util::WrapError("Failed to get split", status);
}
end_of_splits = resp.end_of_splits();
if (!end_of_splits) {
if (!split.FromProto(resp.split())) {
return errors::Internal("Failed to parse split tensor proto");
}
}
return Status::OK();
}
Status DataServiceDispatcherClient::RegisterDataset(GraphDef dataset,
int64& dataset_id) {
TF_RETURN_IF_ERROR(EnsureInitialized());


@ -26,11 +26,12 @@ namespace data {
// Modes for how a tf.data service job should process a dataset.
enum class ProcessingMode : int64 {
UNSET = 0,
// Each tf.data worker processes an entire epoch. If a dataset contains 2
// elements and there are 3 workers, the job will produce 6 elements.
PARALLEL_EPOCHS = 0,
PARALLEL_EPOCHS = 1,
// Processing of a single epoch is distributed across all tf.data workers.
ONE_EPOCH = 1,
DISTRIBUTED_EPOCH = 2,
};
// Parses a string representing a processing mode and stores the result in
@ -91,6 +92,10 @@ class DataServiceDispatcherClient : public DataServiceClientBase {
// definition in `dataset_def`.
Status GetDatasetDef(int64 dataset_id, DatasetDef& dataset_def);
// Gets the next split for the specified job id and repetition.
Status GetSplit(int64 job_id, int64 repetition, Tensor& split,
bool& end_of_splits);
// Registers a dataset with the tf.data service, and stores the generated
// dataset id in `dataset_id`.
Status RegisterDataset(GraphDef dataset, int64& dataset_id);


@ -45,10 +45,10 @@ TEST(DataService, ParseParallelEpochsProcessingMode) {
EXPECT_EQ(mode, ProcessingMode::PARALLEL_EPOCHS);
}
TEST(DataService, ParseOneEpochProcessingMode) {
TEST(DataService, ParseDistributedEpochProcessingMode) {
ProcessingMode mode;
TF_ASSERT_OK(ParseProcessingMode("one_epoch", mode));
EXPECT_EQ(mode, ProcessingMode::ONE_EPOCH);
TF_ASSERT_OK(ParseProcessingMode("distributed_epoch", mode));
EXPECT_EQ(mode, ProcessingMode::DISTRIBUTED_EPOCH);
}
TEST(DataService, ParseInvalidProcessingMode) {
@ -60,7 +60,8 @@ TEST(DataService, ParseInvalidProcessingMode) {
TEST(DataService, ProcessingModeToString) {
EXPECT_EQ("parallel_epochs",
ProcessingModeToString(ProcessingMode::PARALLEL_EPOCHS));
EXPECT_EQ("one_epoch", ProcessingModeToString(ProcessingMode::ONE_EPOCH));
EXPECT_EQ("distributed_epoch",
ProcessingModeToString(ProcessingMode::DISTRIBUTED_EPOCH));
}
TEST(DataService, GetWorkers) {


@ -3,6 +3,7 @@ syntax = "proto3";
package tensorflow.data;
import "tensorflow/core/data/service/common.proto";
import "tensorflow/core/framework/tensor.proto";
message TaskProgress {
// The task that this message is about.
@ -36,6 +37,16 @@ message GetDatasetDefResponse {
DatasetDef dataset_def = 1;
}
message GetSplitRequest {
int64 job_id = 1;
int64 repetition = 2;
}
message GetSplitResponse {
TensorProto split = 1;
bool end_of_splits = 2;
}
message GetOrRegisterDatasetRequest {
// The dataset to register.
DatasetDef dataset = 1;
@ -119,6 +130,9 @@ service DispatcherService {
// Gets a dataset definition.
rpc GetDatasetDef(GetDatasetDefRequest) returns (GetDatasetDefResponse);
// Gets the next split for a given job.
rpc GetSplit(GetSplitRequest) returns (GetSplitResponse);
// Registers a dataset with the server, or returns its id if it is already
// registered.
//


@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/core/data/service/grpc_util.h"
#include "tensorflow/core/data/service/journal.h"
#include "tensorflow/core/data/service/worker.grpc.pb.h"
#include "tensorflow/core/data/standalone.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/lib/core/errors.h"
@ -140,6 +141,11 @@ Status DataServiceDispatcherImpl::Start() {
TF_RETURN_IF_ERROR(reader.Read(update, end_of_journal));
}
}
for (const auto& job : state_.ListJobs()) {
if (job->processing_mode == ProcessingMode::DISTRIBUTED_EPOCH) {
TF_RETURN_IF_ERROR(MakeDistributedEpochJob(job->job_id, job->dataset_id));
}
}
// Initialize the journal writer in `Start` so that we fail fast in case it
// can't be initialized.
TF_RETURN_IF_ERROR(journal_writer_.value()->EnsureInitialized());
@ -147,6 +153,19 @@ Status DataServiceDispatcherImpl::Start() {
return Status::OK();
}
Status DataServiceDispatcherImpl::MakeDistributedEpochJob(int64 job_id,
int64 dataset_id)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
std::unique_ptr<DistributedEpochJob>& distributed_epoch_job =
distributed_epoch_jobs_[job_id];
DCHECK(!distributed_epoch_job);
std::unique_ptr<SplitProvider> split_provider;
TF_RETURN_IF_ERROR(MakeSplitProvider(dataset_id, split_provider));
distributed_epoch_job = absl::make_unique<DistributedEpochJob>(
job_id, dataset_id, std::move(split_provider));
return Status::OK();
}
Status DataServiceDispatcherImpl::WorkerHeartbeat(
const WorkerHeartbeatRequest* request, WorkerHeartbeatResponse* response) {
TF_RETURN_IF_ERROR(CheckStarted());
@ -194,6 +213,7 @@ Status DataServiceDispatcherImpl::WorkerHeartbeat(
task_def->set_dataset_id(task->dataset_id);
task_def->set_job_id(task->job_id);
task_def->set_task_id(task->task_id);
task_def->set_processing_mode(ProcessingModeDef(task->processing_mode));
}
for (int64 current_task : current_tasks) {
if (!correct_tasks_set.contains(current_task)) {
@ -243,6 +263,51 @@ Status DataServiceDispatcherImpl::GetDatasetDef(
return Status::OK();
}
Status DataServiceDispatcherImpl::GetSplit(const GetSplitRequest* request,
GetSplitResponse* response) {
TF_RETURN_IF_ERROR(CheckStarted());
mutex_lock l(mu_);
int64 job_id = request->job_id();
int64 repetition = request->repetition();
std::unique_ptr<DistributedEpochJob>& distributed_epoch_job =
distributed_epoch_jobs_[job_id];
if (!distributed_epoch_job) {
return errors::NotFound("distributed_epoch_job id not found: ", job_id);
}
std::unique_ptr<SplitProvider>& split_provider =
distributed_epoch_job->split_providers[repetition];
if (!split_provider) {
VLOG(1) << "Creating split provider for job "
<< distributed_epoch_job->job_id << " repetition " << repetition;
TF_RETURN_IF_ERROR(
MakeSplitProvider(distributed_epoch_job->dataset_id, split_provider));
}
Tensor split;
bool end_of_splits = false;
TF_RETURN_IF_ERROR(split_provider->GetNext(&split, &end_of_splits));
response->set_end_of_splits(end_of_splits);
if (!end_of_splits) {
split.AsProtoTensorContent(response->mutable_split());
}
return Status::OK();
}
Status DataServiceDispatcherImpl::MakeSplitProvider(
int64 dataset_id, std::unique_ptr<SplitProvider>& split_provider)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
std::shared_ptr<const Dataset> dataset;
TF_RETURN_IF_ERROR(state_.DatasetFromId(dataset_id, dataset));
std::string key = DatasetKey(dataset->dataset_id, dataset->fingerprint);
std::shared_ptr<const DatasetDef> dataset_def;
TF_RETURN_IF_ERROR(dataset_store_->Get(key, dataset_def));
standalone::Dataset::Params params;
std::unique_ptr<standalone::Dataset> standalone_dataset;
TF_RETURN_IF_ERROR(standalone::Dataset::FromGraph(
params, dataset_def->graph(), &standalone_dataset));
TF_RETURN_IF_ERROR(standalone_dataset->MakeSplitProvider(&split_provider));
return Status::OK();
}
Status DataServiceDispatcherImpl::GetOrRegisterDataset(
const GetOrRegisterDatasetRequest* request,
GetOrRegisterDatasetResponse* response) {
@ -405,17 +470,16 @@ Status DataServiceDispatcherImpl::CreateJob(
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
switch (processing_mode) {
case ProcessingMode::PARALLEL_EPOCHS:
case ProcessingMode::DISTRIBUTED_EPOCH:
break;
case ProcessingMode::ONE_EPOCH:
return errors::Unimplemented(
"CreateJob only supports the PARALLEL_EPOCHS job mode. "
"ONE_EPOCH is not currently supported.");
default:
return errors::Unimplemented("ProcessingMode ",
ProcessingModeToString(processing_mode),
" not recognized");
return errors::Internal(
absl::StrCat("ProcessingMode ", processing_mode, " not recognized"));
}
int64 job_id = state_.NextAvailableJobId();
if (processing_mode == ProcessingMode::DISTRIBUTED_EPOCH) {
TF_RETURN_IF_ERROR(MakeDistributedEpochJob(job_id, dataset_id));
}
Update update;
CreateJobUpdate* create_job = update.mutable_create_job();
create_job->set_job_id(job_id);
@ -482,6 +546,7 @@ Status DataServiceDispatcherImpl::CreateTask(std::shared_ptr<const Job> job,
create_task->set_task_id(task_id);
create_task->set_job_id(job->job_id);
create_task->set_dataset_id(job->dataset_id);
create_task->set_processing_mode(ProcessingModeDef(job->processing_mode));
create_task->set_worker_address(worker_address);
TF_RETURN_IF_ERROR(Apply(update));
TF_RETURN_IF_ERROR(state_.TaskFromId(task_id, task));
@ -530,6 +595,7 @@ Status DataServiceDispatcherImpl::AssignTask(std::shared_ptr<const Task> task)
ProcessTaskRequest req;
TaskDef* task_def = req.mutable_task();
task_def->set_dataset_id(task->dataset_id);
task_def->set_job_id(task->job_id);
{
mutex_lock l(mu_);
std::shared_ptr<const Dataset> dataset;
@ -547,6 +613,7 @@ Status DataServiceDispatcherImpl::AssignTask(std::shared_ptr<const Task> task)
}
}
task_def->set_task_id(task->task_id);
task_def->set_processing_mode(ProcessingModeDef(task->processing_mode));
ProcessTaskResponse resp;
WorkerService::Stub* stub;
TF_RETURN_IF_ERROR(GetOrCreateWorkerStub(task->worker_address, stub));


@ -63,6 +63,7 @@ class DataServiceDispatcherImpl {
WorkerUpdateResponse* response);
Status GetDatasetDef(const GetDatasetDefRequest* request,
GetDatasetDefResponse* response);
Status GetSplit(const GetSplitRequest* request, GetSplitResponse* response);
/// Client-facing API.
Status GetOrRegisterDataset(const GetOrRegisterDatasetRequest* request,
@ -78,6 +79,30 @@ class DataServiceDispatcherImpl {
GetWorkersResponse* response);
private:
struct DistributedEpochJob {
// When the distributed epoch job is first created, we eagerly create the
// split provider to fail fast in case the dataset doesn't support
// splitting. Split providers for later repetitions are created on demand.
explicit DistributedEpochJob(int64 job_id, int64 dataset_id,
std::unique_ptr<SplitProvider> split_provider)
: job_id(job_id), dataset_id(dataset_id) {
split_providers[0] = std::move(split_provider);
}
const int64 job_id;
const int64 dataset_id;
// Map from repetition index to split provider.
absl::flat_hash_map<int64, std::unique_ptr<SplitProvider>> split_providers;
};
// Creates a new DistributedEpochJob in `distributed_epoch_jobs_`.
Status MakeDistributedEpochJob(int64 job_id, int64 dataset_id)
EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Makes a split provider for the specified `dataset_id`, and stores it in
// `split_provider`.
Status MakeSplitProvider(int64 dataset_id,
std::unique_ptr<SplitProvider>& split_provider)
EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Registers a dataset with the given fingerprint, storing the new dataset's
// id in `dataset_id`.
Status RegisterDataset(uint64 fingerprint, const DatasetDef& dataset,
@ -151,6 +176,10 @@ class DataServiceDispatcherImpl {
worker_stubs_ TF_GUARDED_BY(mu_);
// Store of dataset definitions.
std::unique_ptr<DatasetStore> dataset_store_ TF_GUARDED_BY(mu_);
// Mapping from job id to `DistributedEpochJob` for jobs with processing mode
// DISTRIBUTED_EPOCH.
absl::flat_hash_map<int64, std::unique_ptr<DistributedEpochJob>>
distributed_epoch_jobs_ TF_GUARDED_BY(mu_);
absl::optional<std::unique_ptr<JournalWriter>> journal_writer_
TF_GUARDED_BY(mu_);


@ -125,6 +125,7 @@ void DispatcherState::CreateTask(const CreateTaskUpdate& create_task) {
DCHECK_EQ(task, nullptr);
task = std::make_shared<Task>(task_id, create_task.job_id(),
create_task.dataset_id(),
ProcessingMode(create_task.processing_mode()),
create_task.worker_address());
tasks_by_job_[create_task.job_id()].push_back(task);
tasks_by_worker_[create_task.worker_address()][task->task_id] = task;


@ -113,15 +113,18 @@ class DispatcherState {
struct Task {
explicit Task(int64 task_id, int64 job_id, int64 dataset_id,
ProcessingMode processing_mode,
const std::string& worker_address)
: task_id(task_id),
job_id(job_id),
dataset_id(dataset_id),
processing_mode(processing_mode),
worker_address(worker_address) {}
const int64 task_id;
const int64 job_id;
const int64 dataset_id;
const ProcessingMode processing_mode;
const std::string worker_address;
bool finished = false;
};


@ -43,6 +43,7 @@ Status GrpcDispatcherImpl::Start() { return impl_.Start(); }
HANDLER(WorkerHeartbeat);
HANDLER(WorkerUpdate);
HANDLER(GetDatasetDef);
HANDLER(GetSplit);
HANDLER(GetOrRegisterDataset);
HANDLER(CreateJob);
HANDLER(ReleaseJobClient);


@ -42,6 +42,7 @@ class GrpcDispatcherImpl : public DispatcherService::Service {
HANDLER(WorkerHeartbeat);
HANDLER(WorkerUpdate);
HANDLER(GetDatasetDef);
HANDLER(GetSplit);
HANDLER(GetOrRegisterDataset);
HANDLER(CreateJob);
HANDLER(ReleaseJobClient);


@ -57,6 +57,7 @@ message CreateTaskUpdate {
int64 task_id = 1;
int64 job_id = 2;
int64 dataset_id = 3;
ProcessingModeDef processing_mode = 5;
string worker_address = 4;
}


@ -0,0 +1,65 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/data/service/split_provider.h"
#include "tensorflow/core/data/service/data_service.h"
#include "tensorflow/core/data/service/grpc_util.h"
#include "tensorflow/core/platform/env.h"
namespace tensorflow {
namespace data {
namespace {
const int64 kRetryTimeoutMicros = 1000LL * 1000 * 60 * 60; // 60 minutes.
} // namespace
Status DataServiceSplitProvider::GetNext(Tensor* split, bool* end_of_splits) {
mutex_lock l(mu_);
if (!dispatcher_) {
dispatcher_ =
absl::make_unique<DataServiceDispatcherClient>(address_, protocol_);
}
return grpc_util::Retry(
[this, split, end_of_splits] {
return dispatcher_->GetSplit(job_id_, repetition_, *split,
*end_of_splits);
},
"get next split",
/*deadline_micros=*/Env::Default()->NowMicros() + kRetryTimeoutMicros);
}
Status DataServiceSplitProvider::Reset() {
mutex_lock l(mu_);
repetition_++;
return Status::OK();
}
Status DataServiceSplitProvider::Save(
std::function<std::string(std::string)> full_name,
IteratorStateWriter* writer) {
return errors::Unimplemented(
"Save is not implemented for DataServiceSplitProvider");
}
Status DataServiceSplitProvider::Restore(
std::function<std::string(std::string)> full_name,
IteratorStateReader* reader) {
return errors::Unimplemented(
"Restore is not implemented for DataServiceSplitProvider");
}
} // namespace data
} // namespace tensorflow


@ -0,0 +1,54 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_DATA_SERVICE_SPLIT_PROVIDER_H_
#define TENSORFLOW_CORE_DATA_SERVICE_SPLIT_PROVIDER_H_
#include <queue>
#include "tensorflow/core/data/service/data_service.h"
#include "tensorflow/core/framework/dataset.h"
namespace tensorflow {
namespace data {
// SplitProvider which reads splits from a tf.data service dispatcher over RPC.
class DataServiceSplitProvider : public SplitProvider {
public:
DataServiceSplitProvider(const std::string& address,
const std::string& protocol, int64 job_id)
: address_(address), protocol_(protocol), job_id_(job_id) {}
Status GetNext(Tensor* split, bool* end_of_splits) override;
Status Reset() override;
Status Save(std::function<std::string(std::string)> full_name,
IteratorStateWriter* writer) override;
Status Restore(std::function<std::string(std::string)> full_name,
IteratorStateReader* reader) override;
private:
const std::string address_;
const std::string protocol_;
const int64 job_id_;
mutex mu_;
int64 repetition_ = 0;
std::unique_ptr<DataServiceDispatcherClient> dispatcher_;
};
} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_DATA_SERVICE_SPLIT_PROVIDER_H_


@ -20,11 +20,13 @@ limitations under the License.
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/data/dataset.pb.h"
#include "tensorflow/core/data/service/common.pb.h"
#include "tensorflow/core/data/service/credentials_factory.h"
#include "tensorflow/core/data/service/data_service.h"
#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
#include "tensorflow/core/data/service/dispatcher.pb.h"
#include "tensorflow/core/data/service/grpc_util.h"
#include "tensorflow/core/data/service/split_provider.h"
#include "tensorflow/core/data/service/utils.h"
#include "tensorflow/core/data/standalone.h"
#include "tensorflow/core/framework/tensor.pb.h"
@ -32,6 +34,7 @@ limitations under the License.
#include "tensorflow/core/lib/io/zlib_outputbuffer.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/platform/snappy.h"
#include "tensorflow/core/public/session_options.h"
@ -108,7 +111,8 @@ Status DataServiceWorkerImpl::ProcessTaskInternal(const TaskDef& task_def)
return Status::OK();
}
task = absl::make_unique<Task>(task_def);
VLOG(3) << "Began processing for task " << task_def.task_id();
VLOG(3) << "Began processing for task " << task_def.task_id()
<< " with processing mode " << task_def.processing_mode();
return Status::OK();
}
@ -142,7 +146,22 @@ Status DataServiceWorkerImpl::EnsureTaskInitialized(
return errors::Internal("Unrecognized dataset case: ",
task.task_def.dataset_case());
}
TF_RETURN_IF_ERROR(task.dataset->MakeIterator(&task.iterator));
switch (task.task_def.processing_mode()) {
case DISTRIBUTED_EPOCH: {
auto split_provider = absl::make_unique<DataServiceSplitProvider>(
config_.dispatcher_address(), config_.protocol(),
task.task_def.job_id());
TF_RETURN_IF_ERROR(task.dataset->MakeIterator(std::move(split_provider),
&task.iterator));
break;
}
case PARALLEL_EPOCHS:
TF_RETURN_IF_ERROR(task.dataset->MakeIterator(&task.iterator));
break;
default:
return errors::InvalidArgument("Unrecognized processing mode: ",
task.task_def.processing_mode());
}
task.initialized = true;
VLOG(3) << "Created iterator for task " << task.task_def.task_id();
return Status::OK();


@ -99,7 +99,8 @@ Status Dataset::FromGraph(Params params, const GraphDef& graph_def,
return Status::OK();
} // static
Status Dataset::MakeIterator(std::unique_ptr<Iterator>* result) {
Status Dataset::MakeIterator(std::unique_ptr<SplitProvider> split_provider,
std::unique_ptr<Iterator>* result) {
// Create an `IteratorContext`, which bundles together the necessary runtime
// support to create and get elements from an iterator.
std::unique_ptr<IteratorContext> ctx;
@ -116,6 +117,7 @@ Status Dataset::MakeIterator(std::unique_ptr<Iterator>* result) {
params.function_handle_cache = function_handle_cache_.get();
params.resource_mgr = &resource_mgr_;
params.cancellation_manager = &cancellation_manager_;
params.split_provider = std::move(split_provider);
ctx = absl::make_unique<IteratorContext>(std::move(params));
}
@ -130,6 +132,14 @@ Status Dataset::MakeIterator(std::unique_ptr<Iterator>* result) {
return Status::OK();
}
Status Dataset::MakeIterator(std::unique_ptr<Iterator>* result) {
return MakeIterator(/*split_provider=*/nullptr, result);
}
Status Dataset::MakeSplitProvider(std::unique_ptr<SplitProvider>* result) {
return dataset_->MakeSplitProvider(result);
}
Dataset::Dataset(DatasetBase* dataset, DeviceMgr* device_mgr,
ProcessFunctionLibraryRuntime* pflr,
FunctionLibraryDefinition* flib_def, thread::ThreadPool* pool)


@ -98,6 +98,12 @@ class Dataset {
// Creates an iterator for this dataset.
Status MakeIterator(std::unique_ptr<Iterator>* result);
// Creates an iterator, optionally with a split provider.
Status MakeIterator(std::unique_ptr<SplitProvider> split_provider,
std::unique_ptr<Iterator>* result);
// Creates a split provider for this dataset.
Status MakeSplitProvider(std::unique_ptr<SplitProvider>* result);
private:
Dataset(DatasetBase* dataset, DeviceMgr* device_mgr,


@ -34,12 +34,17 @@ from tensorflow.python.util.tf_export import tf_export
class ProcessingMode(object):
"""tf.data service processing modes."""
PARALLEL_EPOCHS = "parallel_epochs"
DISTRIBUTED_EPOCH = "distributed_epoch"
@staticmethod
def validate(mode):
"""Raises a ValueError if the given object is not a valid processing mode."""
valid_modes = [ProcessingMode.PARALLEL_EPOCHS]
valid_modes = [
ProcessingMode.PARALLEL_EPOCHS, ProcessingMode.DISTRIBUTED_EPOCH
]
if mode not in valid_modes:
raise ValueError(
"{0} is not a valid processing mode. Valid modes: {1}".format(
@ -315,7 +320,7 @@ def distribute(processing_mode,
dataset in different orders.
In the future, there will be additional processing modes. For example,
a "one_epoch" mode which partitions the dataset across the tf.data
a "distributed_epoch" mode which partitions the dataset across the tf.data
workers, so that the consumers see each element of the dataset only once.
```


@ -645,6 +645,40 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
self.run_stateful(distribute_options.ExternalStatePolicy.FAIL)
@combinations.generate(test_base.eager_only_combinations())
def testDistributeDistributedEpochTensorSlices(self):
dispatcher, workers = self.start_cluster(2) # to avoid gcing workers, pylint: disable=unused-variable
vals = [5, 1, 2, 4]
ds = dataset_ops.Dataset.from_tensor_slices(vals)
ds = ds.apply(
data_service_ops.distribute(
processing_mode="distributed_epoch", service=dispatcher.target))
self.assertDatasetProduces(ds, vals, assert_items_equal=True)
@combinations.generate(test_base.eager_only_combinations())
def testDistributeDistributedEpochRepeat(self):
dispatcher, workers = self.start_cluster(2) # to avoid gcing workers, pylint: disable=unused-variable
num_repeats = 5
num_elements = 20
ds = dataset_ops.Dataset.range(num_elements).repeat(num_repeats)
ds = ds.apply(
data_service_ops.distribute(
processing_mode="distributed_epoch", service=dispatcher.target))
self.assertDatasetProduces(
ds, num_repeats * list(range(num_elements)), assert_items_equal=True)
@combinations.generate(test_base.eager_only_combinations())
def testDistributeDistributedEpochShuffleAndRepeat(self):
dispatcher, workers = self.start_cluster(2) # to avoid gcing workers, pylint: disable=unused-variable
num_repeats = 5
num_elements = 20
ds = dataset_ops.Dataset.range(num_elements).shuffle(num_elements).repeat(
num_repeats)
ds = ds.apply(
data_service_ops.distribute(
processing_mode="distributed_epoch", service=dispatcher.target))
self.assertDatasetProduces(
ds, num_repeats * list(range(num_elements)), assert_items_equal=True)
def testDistributeFromInterleave(self):
dispatcher, workers = self.start_cluster(1) # to avoid gcing workers, pylint: disable=unused-variable
ds = dataset_ops.Dataset.range(2)
@ -657,6 +691,40 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
ds = ds.interleave(interleave_fn, cycle_length=2)
self.assertDatasetProduces(ds, [0, 0, 1, 1])
@combinations.generate(test_base.eager_only_combinations())
def testDistributeDistributedEpoch(self):
dispatcher, workers = self.start_cluster(2) # to avoid gcing workers, pylint: disable=unused-variable
num_elements = 100
ds = dataset_ops.Dataset.range(num_elements)
ds = ds.apply(
data_service_ops.distribute(
processing_mode="distributed_epoch", service=dispatcher.target))
self.assertDatasetProduces(
ds, list(range(num_elements)), assert_items_equal=True)
@combinations.generate(test_base.eager_only_combinations())
def testChangeProcessingModeAfterRestart(self):
dispatcher, workers = self.start_cluster(1) # to avoid gcing workers, pylint: disable=unused-variable
num_elements = 100
range_dataset = dataset_ops.Dataset.range(num_elements)
ds = range_dataset.apply(
data_service_ops.distribute(
processing_mode="parallel_epochs",
service=dispatcher.target,
job_name="test"))
iterator = iter(ds)
for i in range(num_elements // 2):
self.assertEqual(i, next(iterator).numpy())
dispatcher = self.restart_dispatcher(dispatcher)
ds = range_dataset.apply(
data_service_ops.distribute(
processing_mode="distributed_epoch",
service=dispatcher.target,
job_name="test"))
with self.assertRaisesOpError("already an existing job with that name "
"using processing mode <parallel_epochs>"):
next(iter(ds)).numpy()
@combinations.generate(test_base.eager_only_combinations())
def testDistributeNonStringAddresses(self):
ds = dataset_ops.Dataset.range(10)