[tf.data service] Add support for distributed_epoch processing mode.

This mode relies on the dataset being splittable, i.e. having MakeSplitProvider implemented for its source dataset.

Caveats:
- If the dispatcher restarts mid-job, we will begin processing splits again from the beginning.
- If a worker restarts mid-job, its in-progress splits will not be trained on.

These caveats will be addressed in later CLs by taking periodic checkpoints to reduce the impact of restarts.

PiperOrigin-RevId: 332930574
Change-Id: I3666273d4d8b07406561b678c7b7f64deb934955
parent 1a32039b93
commit 91c1955163
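
For review context, here is a minimal sketch of how the new mode is exercised from the Python API, modeled on the tests added in this change (the service address is a placeholder; the tests obtain it from an in-process test cluster):

```python
import tensorflow as tf

# Placeholder dispatcher address; in the tests this comes from a test cluster.
service = "grpc://localhost:5000"

ds = tf.data.Dataset.range(100)
ds = ds.apply(
    tf.data.experimental.service.distribute(
        processing_mode="distributed_epoch", service=service))

# With distributed_epoch, the epoch is partitioned across the workers, so each
# element is produced exactly once overall, though not necessarily in order.
for elem in ds:
  print(elem.numpy())
```
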
@@ -188,6 +188,7 @@ cc_library(
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/data:standalone",
        "//tensorflow/core/kernels/data:dataset_utils",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/memory",
@@ -359,6 +360,20 @@ cc_header_only_library(
    ],
)

+cc_library(
+    name = "split_provider",
+    srcs = ["split_provider.cc"],
+    hdrs = [
+        "split_provider.h",
+    ],
+    deps = [
+        ":data_service",
+        ":grpc_util",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+    ],
+)
+
cc_library(
    name = "test_cluster",
    testonly = True,
@@ -450,6 +465,7 @@ cc_library(
        ":dispatcher_cc_grpc_proto",
        ":dispatcher_proto_cc",
        ":grpc_util",
+       ":split_provider",
        ":utils",
        ":worker_proto_cc",
        "//tensorflow/c:c_api_internal",
@@ -19,6 +19,7 @@ message TaskDef {
  int64 dataset_id = 3;
  int64 task_id = 4;
  int64 job_id = 5;
+  ProcessingModeDef processing_mode = 6;
}

message TaskInfo {
@@ -31,8 +32,9 @@ message TaskInfo {
}

enum ProcessingModeDef {
+  INVALID = 0;
  // Each tf.data worker processes an entire epoch.
-  PARALLEL_EPOCHS = 0;
+  PARALLEL_EPOCHS = 1;
  // Processing of an epoch is distributed across all tf.data workers.
-  ONE_EPOCH = 1;
+  DISTRIBUTED_EPOCH = 2;
}

@@ -28,14 +28,14 @@ namespace data {

namespace {
constexpr const char kParallelEpochs[] = "parallel_epochs";
-constexpr const char kOneEpoch[] = "one_epoch";
+constexpr const char kDistributedEpoch[] = "distributed_epoch";
}  // namespace

Status ParseProcessingMode(const std::string& s, ProcessingMode& mode) {
  if (s == kParallelEpochs) {
    mode = ProcessingMode::PARALLEL_EPOCHS;
-  } else if (s == kOneEpoch) {
-    mode = ProcessingMode::ONE_EPOCH;
+  } else if (s == kDistributedEpoch) {
+    mode = ProcessingMode::DISTRIBUTED_EPOCH;
  } else {
    return errors::InvalidArgument("Unrecognized processing mode: ", s);
  }
@@ -46,8 +46,8 @@ std::string ProcessingModeToString(ProcessingMode mode) {
  switch (mode) {
    case ProcessingMode::PARALLEL_EPOCHS:
      return kParallelEpochs;
-    case ProcessingMode::ONE_EPOCH:
-      return kOneEpoch;
+    case ProcessingMode::DISTRIBUTED_EPOCH:
+      return kDistributedEpoch;
    default:
      DCHECK(false);
      return "Unknown";
@@ -111,6 +111,28 @@ Status DataServiceDispatcherClient::GetDatasetDef(int64 dataset_id,
  return Status::OK();
}

+Status DataServiceDispatcherClient::GetSplit(int64 job_id, int64 repetition,
+                                             Tensor& split,
+                                             bool& end_of_splits) {
+  TF_RETURN_IF_ERROR(EnsureInitialized());
+  GetSplitRequest req;
+  req.set_job_id(job_id);
+  req.set_repetition(repetition);
+  GetSplitResponse resp;
+  grpc::ClientContext client_ctx;
+  grpc::Status status = stub_->GetSplit(&client_ctx, req, &resp);
+  if (!status.ok()) {
+    return grpc_util::WrapError("Failed to get split", status);
+  }
+  end_of_splits = resp.end_of_splits();
+  if (!end_of_splits) {
+    if (!split.FromProto(resp.split())) {
+      return errors::Internal("Failed to parse split tensor proto");
+    }
+  }
+  return Status::OK();
+}
+
Status DataServiceDispatcherClient::RegisterDataset(GraphDef dataset,
                                                    int64& dataset_id) {
  TF_RETURN_IF_ERROR(EnsureInitialized());
@@ -26,11 +26,12 @@ namespace data {

// Modes for how a tf.data service job should process a dataset.
enum class ProcessingMode : int64 {
+  UNSET = 0,
  // Each tf.data worker processes an entire epoch. If a dataset contains 2
  // elements and there are 3 workers, the job will produce 6 elements.
-  PARALLEL_EPOCHS = 0,
+  PARALLEL_EPOCHS = 1,
  // Processing of a single epoch is distributed across all tf.data workers.
-  ONE_EPOCH = 1,
+  DISTRIBUTED_EPOCH = 2,
};

// Parses a string representing a processing mode and stores the result in
@@ -91,6 +92,10 @@ class DataServiceDispatcherClient : public DataServiceClientBase {
  // definition in `dataset_def`.
  Status GetDatasetDef(int64 dataset_id, DatasetDef& dataset_def);

+  // Gets the next split for the specified job id and repetition.
+  Status GetSplit(int64 job_id, int64 repetition, Tensor& split,
+                  bool& end_of_splits);
+
  // Registers a dataset with the tf.data service, and stores the generated
  // dataset id in `dataset_id`.
  Status RegisterDataset(GraphDef dataset, int64& dataset_id);
@@ -45,10 +45,10 @@ TEST(DataService, ParseParallelEpochsProcessingMode) {
  EXPECT_EQ(mode, ProcessingMode::PARALLEL_EPOCHS);
}

-TEST(DataService, ParseOneEpochProcessingMode) {
+TEST(DataService, ParseDistributedEpochProcessingMode) {
  ProcessingMode mode;
-  TF_ASSERT_OK(ParseProcessingMode("one_epoch", mode));
-  EXPECT_EQ(mode, ProcessingMode::ONE_EPOCH);
+  TF_ASSERT_OK(ParseProcessingMode("distributed_epoch", mode));
+  EXPECT_EQ(mode, ProcessingMode::DISTRIBUTED_EPOCH);
}

TEST(DataService, ParseInvalidProcessingMode) {
@@ -60,7 +60,8 @@ TEST(DataService, ParseInvalidProcessingMode) {
TEST(DataService, ProcessingModeToString) {
  EXPECT_EQ("parallel_epochs",
            ProcessingModeToString(ProcessingMode::PARALLEL_EPOCHS));
-  EXPECT_EQ("one_epoch", ProcessingModeToString(ProcessingMode::ONE_EPOCH));
+  EXPECT_EQ("distributed_epoch",
+            ProcessingModeToString(ProcessingMode::DISTRIBUTED_EPOCH));
}

TEST(DataService, GetWorkers) {
@@ -3,6 +3,7 @@ syntax = "proto3";

package tensorflow.data;

+import "tensorflow/core/data/service/common.proto";
import "tensorflow/core/framework/tensor.proto";

message TaskProgress {
  // The task that this message is about.
@@ -36,6 +37,16 @@ message GetDatasetDefResponse {
  DatasetDef dataset_def = 1;
}

+message GetSplitRequest {
+  int64 job_id = 1;
+  int64 repetition = 2;
+}
+
+message GetSplitResponse {
+  TensorProto split = 1;
+  bool end_of_splits = 2;
+}
+
message GetOrRegisterDatasetRequest {
  // The dataset to register.
  DatasetDef dataset = 1;
@@ -119,6 +130,9 @@ service DispatcherService {
  // Gets a dataset defintion.
  rpc GetDatasetDef(GetDatasetDefRequest) returns (GetDatasetDefResponse);

+  // Gets the next split for a given job.
+  rpc GetSplit(GetSplitRequest) returns (GetSplitResponse);
+
  // Registers a dataset with the server, or returns its id if it is already
  // registered.
  //
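
To make the new RPC flow easier to follow: a worker-side split provider repeatedly asks the dispatcher for the next split of a given job and repetition until `end_of_splits` is set. A rough Python pseudocode sketch of that loop (the real implementation is the C++ `DataServiceSplitProvider` added below; `dispatcher.get_split` is a hypothetical stand-in for the GetSplit RPC):

```python
def iterate_splits(dispatcher, job_id, repetition):
  """Yields splits for one repetition of a job until the source is exhausted."""
  while True:
    # Hypothetical wrapper around the GetSplit RPC defined above.
    split, end_of_splits = dispatcher.get_split(job_id, repetition)
    if end_of_splits:
      return
    yield split
```
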
@@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/core/data/service/grpc_util.h"
#include "tensorflow/core/data/service/journal.h"
#include "tensorflow/core/data/service/worker.grpc.pb.h"
#include "tensorflow/core/data/standalone.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -140,6 +141,11 @@ Status DataServiceDispatcherImpl::Start() {
      TF_RETURN_IF_ERROR(reader.Read(update, end_of_journal));
    }
  }
+  for (const auto& job : state_.ListJobs()) {
+    if (job->processing_mode == ProcessingMode::DISTRIBUTED_EPOCH) {
+      TF_RETURN_IF_ERROR(MakeDistributedEpochJob(job->job_id, job->dataset_id));
+    }
+  }
  // Initialize the journal writer in `Start` so that we fail fast in case it
  // can't be initialized.
  TF_RETURN_IF_ERROR(journal_writer_.value()->EnsureInitialized());
@@ -147,6 +153,19 @@ Status DataServiceDispatcherImpl::Start() {
  return Status::OK();
}

+Status DataServiceDispatcherImpl::MakeDistributedEpochJob(int64 job_id,
+                                                          int64 dataset_id)
+    EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+  std::unique_ptr<DistributedEpochJob>& distributed_epoch_job =
+      distributed_epoch_jobs_[job_id];
+  DCHECK(!distributed_epoch_job);
+  std::unique_ptr<SplitProvider> split_provider;
+  TF_RETURN_IF_ERROR(MakeSplitProvider(dataset_id, split_provider));
+  distributed_epoch_job = absl::make_unique<DistributedEpochJob>(
+      job_id, dataset_id, std::move(split_provider));
+  return Status::OK();
+}
+
Status DataServiceDispatcherImpl::WorkerHeartbeat(
    const WorkerHeartbeatRequest* request, WorkerHeartbeatResponse* response) {
  TF_RETURN_IF_ERROR(CheckStarted());
@@ -194,6 +213,7 @@ Status DataServiceDispatcherImpl::WorkerHeartbeat(
    task_def->set_dataset_id(task->dataset_id);
    task_def->set_job_id(task->job_id);
    task_def->set_task_id(task->task_id);
+    task_def->set_processing_mode(ProcessingModeDef(task->processing_mode));
  }
  for (int64 current_task : current_tasks) {
    if (!correct_tasks_set.contains(current_task)) {
@@ -243,6 +263,51 @@ Status DataServiceDispatcherImpl::GetDatasetDef(
  return Status::OK();
}

+Status DataServiceDispatcherImpl::GetSplit(const GetSplitRequest* request,
+                                           GetSplitResponse* response) {
+  TF_RETURN_IF_ERROR(CheckStarted());
+  mutex_lock l(mu_);
+  int64 job_id = request->job_id();
+  int64 repetition = request->repetition();
+  std::unique_ptr<DistributedEpochJob>& distributed_epoch_job =
+      distributed_epoch_jobs_[job_id];
+  if (!distributed_epoch_job) {
+    return errors::NotFound("distributed_epoch_job id not found: ", job_id);
+  }
+  std::unique_ptr<SplitProvider>& split_provider =
+      distributed_epoch_job->split_providers[repetition];
+  if (!split_provider) {
+    VLOG(1) << "Creating split provider for job "
+            << distributed_epoch_job->job_id << " repetition " << repetition;
+    TF_RETURN_IF_ERROR(
+        MakeSplitProvider(distributed_epoch_job->dataset_id, split_provider));
+  }
+  Tensor split;
+  bool end_of_splits = false;
+  TF_RETURN_IF_ERROR(split_provider->GetNext(&split, &end_of_splits));
+  response->set_end_of_splits(end_of_splits);
+  if (!end_of_splits) {
+    split.AsProtoTensorContent(response->mutable_split());
+  }
+  return Status::OK();
+}
+
+Status DataServiceDispatcherImpl::MakeSplitProvider(
+    int64 dataset_id, std::unique_ptr<SplitProvider>& split_provider)
+    EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+  std::shared_ptr<const Dataset> dataset;
+  TF_RETURN_IF_ERROR(state_.DatasetFromId(dataset_id, dataset));
+  std::string key = DatasetKey(dataset->dataset_id, dataset->fingerprint);
+  std::shared_ptr<const DatasetDef> dataset_def;
+  TF_RETURN_IF_ERROR(dataset_store_->Get(key, dataset_def));
+  standalone::Dataset::Params params;
+  std::unique_ptr<standalone::Dataset> standalone_dataset;
+  TF_RETURN_IF_ERROR(standalone::Dataset::FromGraph(
+      params, dataset_def->graph(), &standalone_dataset));
+  TF_RETURN_IF_ERROR(standalone_dataset->MakeSplitProvider(&split_provider));
+  return Status::OK();
+}
+
Status DataServiceDispatcherImpl::GetOrRegisterDataset(
    const GetOrRegisterDatasetRequest* request,
    GetOrRegisterDatasetResponse* response) {
@@ -405,17 +470,16 @@ Status DataServiceDispatcherImpl::CreateJob(
    EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  switch (processing_mode) {
    case ProcessingMode::PARALLEL_EPOCHS:
+    case ProcessingMode::DISTRIBUTED_EPOCH:
      break;
-    case ProcessingMode::ONE_EPOCH:
-      return errors::Unimplemented(
-          "CreateJob only supports the PARALLEL_EPOCHS job mode. "
-          "ONE_EPOCH is not currently supported.");
    default:
-      return errors::Internal(
-          absl::StrCat("ProcessingMode ", processing_mode, " not recognized"));
+      return errors::Unimplemented("ProcessingMode ",
+                                   ProcessingModeToString(processing_mode),
+                                   " not recognized");
  }
  int64 job_id = state_.NextAvailableJobId();
+  if (processing_mode == ProcessingMode::DISTRIBUTED_EPOCH) {
+    TF_RETURN_IF_ERROR(MakeDistributedEpochJob(job_id, dataset_id));
+  }
  Update update;
  CreateJobUpdate* create_job = update.mutable_create_job();
  create_job->set_job_id(job_id);
@@ -482,6 +546,7 @@ Status DataServiceDispatcherImpl::CreateTask(std::shared_ptr<const Job> job,
  create_task->set_task_id(task_id);
  create_task->set_job_id(job->job_id);
  create_task->set_dataset_id(job->dataset_id);
+  create_task->set_processing_mode(ProcessingModeDef(job->processing_mode));
  create_task->set_worker_address(worker_address);
  TF_RETURN_IF_ERROR(Apply(update));
  TF_RETURN_IF_ERROR(state_.TaskFromId(task_id, task));
@@ -530,6 +595,7 @@ Status DataServiceDispatcherImpl::AssignTask(std::shared_ptr<const Task> task)
  ProcessTaskRequest req;
  TaskDef* task_def = req.mutable_task();
  task_def->set_dataset_id(task->dataset_id);
+  task_def->set_job_id(task->job_id);
  {
    mutex_lock l(mu_);
    std::shared_ptr<const Dataset> dataset;
@@ -547,6 +613,7 @@ Status DataServiceDispatcherImpl::AssignTask(std::shared_ptr<const Task> task)
    }
  }
  task_def->set_task_id(task->task_id);
+  task_def->set_processing_mode(ProcessingModeDef(task->processing_mode));
  ProcessTaskResponse resp;
  WorkerService::Stub* stub;
  TF_RETURN_IF_ERROR(GetOrCreateWorkerStub(task->worker_address, stub));
@@ -63,6 +63,7 @@ class DataServiceDispatcherImpl {
                      WorkerUpdateResponse* response);
  Status GetDatasetDef(const GetDatasetDefRequest* request,
                       GetDatasetDefResponse* response);
+  Status GetSplit(const GetSplitRequest* request, GetSplitResponse* response);

  /// Client-facing API.
  Status GetOrRegisterDataset(const GetOrRegisterDatasetRequest* request,
@@ -78,6 +79,30 @@ class DataServiceDispatcherImpl {
                    GetWorkersResponse* response);

 private:
+  struct DistributedEpochJob {
+    // When the distributed epoch job is first created, we eagerly create the
+    // split provider to fail fast in case the dataset doesn't support
+    // splitting. Split providers for later repetitions are created on demand.
+    explicit DistributedEpochJob(int64 job_id, int64 dataset_id,
+                                 std::unique_ptr<SplitProvider> split_provider)
+        : job_id(job_id), dataset_id(dataset_id) {
+      split_providers[0] = std::move(split_provider);
+    }
+
+    const int64 job_id;
+    const int64 dataset_id;
+    // Map from repetition index to split provider.
+    absl::flat_hash_map<int64, std::unique_ptr<SplitProvider>> split_providers;
+  };
+
+  // Creates a new DistributedEpochJob in `distributed_epoch_jobs_`.
+  Status MakeDistributedEpochJob(int64 job_id, int64 dataset_id)
+      EXCLUSIVE_LOCKS_REQUIRED(mu_);
+  // Makes a split provider for the specified `dataset_id`, and stores it in
+  // `split_provider`.
+  Status MakeSplitProvider(int64 dataset_id,
+                           std::unique_ptr<SplitProvider>& split_provider)
+      EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Registers a dataset with the given fingerprint, storing the new dataset's
  // id in `dataset_id`.
  Status RegisterDataset(uint64 fingerprint, const DatasetDef& dataset,
@@ -151,6 +176,10 @@ class DataServiceDispatcherImpl {
      worker_stubs_ TF_GUARDED_BY(mu_);
  // Store of dataset definitions.
  std::unique_ptr<DatasetStore> dataset_store_ TF_GUARDED_BY(mu_);
+  // Mapping from job id to `DistributedEpochJob` for jobs with processing mode
+  // DISTRIBUTED_EPOCH.
+  absl::flat_hash_map<int64, std::unique_ptr<DistributedEpochJob>>
+      distributed_epoch_jobs_ TF_GUARDED_BY(mu_);

  absl::optional<std::unique_ptr<JournalWriter>> journal_writer_
      TF_GUARDED_BY(mu_);
@@ -125,6 +125,7 @@ void DispatcherState::CreateTask(const CreateTaskUpdate& create_task) {
  DCHECK_EQ(task, nullptr);
  task = std::make_shared<Task>(task_id, create_task.job_id(),
                                create_task.dataset_id(),
+                                ProcessingMode(create_task.processing_mode()),
                                create_task.worker_address());
  tasks_by_job_[create_task.job_id()].push_back(task);
  tasks_by_worker_[create_task.worker_address()][task->task_id] = task;
@@ -113,15 +113,18 @@ class DispatcherState {

  struct Task {
    explicit Task(int64 task_id, int64 job_id, int64 dataset_id,
+                  ProcessingMode processing_mode,
                  const std::string& worker_address)
        : task_id(task_id),
          job_id(job_id),
          dataset_id(dataset_id),
+          processing_mode(processing_mode),
          worker_address(worker_address) {}

    const int64 task_id;
    const int64 job_id;
    const int64 dataset_id;
+    const ProcessingMode processing_mode;
    const std::string worker_address;
    bool finished = false;
  };
@@ -43,6 +43,7 @@ Status GrpcDispatcherImpl::Start() { return impl_.Start(); }
  HANDLER(WorkerHeartbeat);
  HANDLER(WorkerUpdate);
  HANDLER(GetDatasetDef);
+  HANDLER(GetSplit);
  HANDLER(GetOrRegisterDataset);
  HANDLER(CreateJob);
  HANDLER(ReleaseJobClient);
@@ -42,6 +42,7 @@ class GrpcDispatcherImpl : public DispatcherService::Service {
  HANDLER(WorkerHeartbeat);
  HANDLER(WorkerUpdate);
  HANDLER(GetDatasetDef);
+  HANDLER(GetSplit);
  HANDLER(GetOrRegisterDataset);
  HANDLER(CreateJob);
  HANDLER(ReleaseJobClient);
@@ -57,6 +57,7 @@ message CreateTaskUpdate {
  int64 task_id = 1;
  int64 job_id = 2;
  int64 dataset_id = 3;
+  ProcessingModeDef processing_mode = 5;
  string worker_address = 4;
}
tensorflow/core/data/service/split_provider.cc (new file, 65 lines)
@@ -0,0 +1,65 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/data/service/split_provider.h"

#include "tensorflow/core/data/service/data_service.h"
#include "tensorflow/core/data/service/grpc_util.h"
#include "tensorflow/core/platform/env.h"

namespace tensorflow {
namespace data {

namespace {
const int64 kRetryTimeoutMicros = 1000LL * 1000 * 60 * 60;  // 60 minutes.
}  // namespace

Status DataServiceSplitProvider::GetNext(Tensor* split, bool* end_of_splits) {
  mutex_lock l(mu_);
  if (!dispatcher_) {
    dispatcher_ =
        absl::make_unique<DataServiceDispatcherClient>(address_, protocol_);
  }
  return grpc_util::Retry(
      [this, split, end_of_splits] {
        return dispatcher_->GetSplit(job_id_, repetition_, *split,
                                     *end_of_splits);
      },
      "get next split",
      /*deadline_micros=*/Env::Default()->NowMicros() + kRetryTimeoutMicros);
}

Status DataServiceSplitProvider::Reset() {
  mutex_lock l(mu_);
  repetition_++;
  return Status::OK();
}

Status DataServiceSplitProvider::Save(
    std::function<std::string(std::string)> full_name,
    IteratorStateWriter* writer) {
  return errors::Unimplemented(
      "Save is not implemented for DataServiceSplitProvider");
}

Status DataServiceSplitProvider::Restore(
    std::function<std::string(std::string)> full_name,
    IteratorStateReader* reader) {
  return errors::Unimplemented(
      "Restore is not implemented for DataServiceSplitProvider");
}

}  // namespace data
}  // namespace tensorflow
tensorflow/core/data/service/split_provider.h (new file, 54 lines)
@@ -0,0 +1,54 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_DATA_SERVICE_SPLIT_PROVIDER_H_
#define TENSORFLOW_CORE_DATA_SERVICE_SPLIT_PROVIDER_H_

#include <queue>

#include "tensorflow/core/data/service/data_service.h"
#include "tensorflow/core/framework/dataset.h"

namespace tensorflow {
namespace data {

// SplitProvider which reads splits from a tf.data service dispatcher over RPC.
class DataServiceSplitProvider : public SplitProvider {
 public:
  DataServiceSplitProvider(const std::string& address,
                           const std::string& protocol, int64 job_id)
      : address_(address), protocol_(protocol), job_id_(job_id) {}

  Status GetNext(Tensor* split, bool* end_of_splits) override;
  Status Reset() override;
  Status Save(std::function<std::string(std::string)> full_name,
              IteratorStateWriter* writer) override;
  Status Restore(std::function<std::string(std::string)> full_name,
                 IteratorStateReader* reader) override;

 private:
  const std::string address_;
  const std::string protocol_;
  const int64 job_id_;

  mutex mu_;
  int64 repetition_ = 0;
  std::unique_ptr<DataServiceDispatcherClient> dispatcher_;
};

}  // namespace data
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_DATA_SERVICE_SPLIT_PROVIDER_H_
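
Worth noting for reviewers: `Reset()` only increments `repetition_`, and the dispatcher lazily creates one split provider per repetition (the `split_providers` map in `DistributedEpochJob`). A hedged Python sketch of that dispatcher-side bookkeeping, with `make_split_provider` and `get_next` as hypothetical stand-ins for the C++ calls:

```python
# Hypothetical mirror of DistributedEpochJob's per-repetition bookkeeping.
split_providers = {}  # repetition index -> split provider

def get_split(job, repetition):
  # Create a fresh provider the first time a repetition is seen.
  if repetition not in split_providers:
    split_providers[repetition] = make_split_provider(job.dataset_id)
  return split_providers[repetition].get_next()
```
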
@@ -20,11 +20,13 @@ limitations under the License.
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/data/dataset.pb.h"
+#include "tensorflow/core/data/service/common.pb.h"
#include "tensorflow/core/data/service/credentials_factory.h"
#include "tensorflow/core/data/service/data_service.h"
#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
#include "tensorflow/core/data/service/dispatcher.pb.h"
#include "tensorflow/core/data/service/grpc_util.h"
+#include "tensorflow/core/data/service/split_provider.h"
#include "tensorflow/core/data/service/utils.h"
#include "tensorflow/core/data/standalone.h"
#include "tensorflow/core/framework/tensor.pb.h"
@@ -32,6 +34,7 @@ limitations under the License.
#include "tensorflow/core/lib/io/zlib_outputbuffer.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/platform/snappy.h"
#include "tensorflow/core/public/session_options.h"
@@ -108,7 +111,8 @@ Status DataServiceWorkerImpl::ProcessTaskInternal(const TaskDef& task_def)
    return Status::OK();
  }
  task = absl::make_unique<Task>(task_def);
-  VLOG(3) << "Began processing for task " << task_def.task_id();
+  VLOG(3) << "Began processing for task " << task_def.task_id()
+          << " with processing mode " << task_def.processing_mode();
  return Status::OK();
}
@@ -142,7 +146,22 @@ Status DataServiceWorkerImpl::EnsureTaskInitialized(
      return errors::Internal("Unrecognized dataset case: ",
                              task.task_def.dataset_case());
  }
-  TF_RETURN_IF_ERROR(task.dataset->MakeIterator(&task.iterator));
+  switch (task.task_def.processing_mode()) {
+    case DISTRIBUTED_EPOCH: {
+      auto split_provider = absl::make_unique<DataServiceSplitProvider>(
+          config_.dispatcher_address(), config_.protocol(),
+          task.task_def.job_id());
+      TF_RETURN_IF_ERROR(task.dataset->MakeIterator(std::move(split_provider),
+                                                    &task.iterator));
+      break;
+    }
+    case PARALLEL_EPOCHS:
+      TF_RETURN_IF_ERROR(task.dataset->MakeIterator(&task.iterator));
+      break;
+    default:
+      return errors::InvalidArgument("Unrecognized processing mode: ",
+                                     task.task_def.processing_mode());
+  }
  task.initialized = true;
  VLOG(3) << "Created iterator for task " << task.task_def.task_id();
  return Status::OK();
@@ -99,7 +99,8 @@ Status Dataset::FromGraph(Params params, const GraphDef& graph_def,
  return Status::OK();
}  // static

-Status Dataset::MakeIterator(std::unique_ptr<Iterator>* result) {
+Status Dataset::MakeIterator(std::unique_ptr<SplitProvider> split_provider,
+                             std::unique_ptr<Iterator>* result) {
  // Create an `IteratorContext`, which bundles together the necessary runtime
  // support to create and get elements from an iterator.
  std::unique_ptr<IteratorContext> ctx;
@@ -116,6 +117,7 @@ Status Dataset::MakeIterator(std::unique_ptr<Iterator>* result) {
    params.function_handle_cache = function_handle_cache_.get();
    params.resource_mgr = &resource_mgr_;
    params.cancellation_manager = &cancellation_manager_;
+    params.split_provider = std::move(split_provider);

    ctx = absl::make_unique<IteratorContext>(std::move(params));
  }
@@ -130,6 +132,14 @@ Status Dataset::MakeIterator(std::unique_ptr<Iterator>* result) {
  return Status::OK();
}

+Status Dataset::MakeIterator(std::unique_ptr<Iterator>* result) {
+  return MakeIterator(/*split_provider=*/nullptr, result);
+}
+
+Status Dataset::MakeSplitProvider(std::unique_ptr<SplitProvider>* result) {
+  return dataset_->MakeSplitProvider(result);
+}
+
Dataset::Dataset(DatasetBase* dataset, DeviceMgr* device_mgr,
                 ProcessFunctionLibraryRuntime* pflr,
                 FunctionLibraryDefinition* flib_def, thread::ThreadPool* pool)
@@ -98,6 +98,12 @@ class Dataset {

  // Creates an iterator for this dataset.
  Status MakeIterator(std::unique_ptr<Iterator>* result);
+  // Creates an iterator, optionally with a split provider.
+  Status MakeIterator(std::unique_ptr<SplitProvider> split_provider,
+                      std::unique_ptr<Iterator>* result);
+
+  // Creates a split provider for this dataset.
+  Status MakeSplitProvider(std::unique_ptr<SplitProvider>* result);

 private:
  Dataset(DatasetBase* dataset, DeviceMgr* device_mgr,
@@ -34,12 +34,17 @@ from tensorflow.python.util.tf_export import tf_export


class ProcessingMode(object):
  """tf.data service processing modes."""

  PARALLEL_EPOCHS = "parallel_epochs"
+  DISTRIBUTED_EPOCH = "distributed_epoch"

  @staticmethod
  def validate(mode):
    """Raises a ValueError if the given object is not a valid processing mode."""
-    valid_modes = [ProcessingMode.PARALLEL_EPOCHS]
+    valid_modes = [
+        ProcessingMode.PARALLEL_EPOCHS, ProcessingMode.DISTRIBUTED_EPOCH
+    ]
    if mode not in valid_modes:
      raise ValueError(
          "{0} is not a valid processing mode. Valid modes: {1}".format(
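
A quick illustration of the validation change (a sketch; `ProcessingMode` is the private helper class above, not public API):

```python
ProcessingMode.validate("parallel_epochs")    # OK before and after this change
ProcessingMode.validate("distributed_epoch")  # OK only after this change
ProcessingMode.validate("one_epoch")          # raises ValueError
```
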
@@ -315,7 +320,7 @@ def distribute(processing_mode,
  dataset in different orders.

  In the future, there will be additional processing modes. For example,
-  a "one_epoch" mode which partitions the dataset across the tf.data
+  a "distributed_epoch" mode which partitions the dataset across the tf.data
  workers, so that the consumers see each element of the dataset only once.

  ```
@@ -645,6 +645,40 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
    self.run_stateful(distribute_options.ExternalStatePolicy.FAIL)

+  @combinations.generate(test_base.eager_only_combinations())
+  def testDistributeDistributedEpochTensorSlices(self):
+    dispatcher, workers = self.start_cluster(2)  # to avoid gcing workers, pylint: disable=unused-variable
+    vals = [5, 1, 2, 4]
+    ds = dataset_ops.Dataset.from_tensor_slices(vals)
+    ds = ds.apply(
+        data_service_ops.distribute(
+            processing_mode="distributed_epoch", service=dispatcher.target))
+    self.assertDatasetProduces(ds, vals, assert_items_equal=True)
+
+  @combinations.generate(test_base.eager_only_combinations())
+  def testDistributeDistributedEpochRepeat(self):
+    dispatcher, workers = self.start_cluster(2)  # to avoid gcing workers, pylint: disable=unused-variable
+    num_repeats = 5
+    num_elements = 20
+    ds = dataset_ops.Dataset.range(num_elements).repeat(num_repeats)
+    ds = ds.apply(
+        data_service_ops.distribute(
+            processing_mode="distributed_epoch", service=dispatcher.target))
+    self.assertDatasetProduces(
+        ds, num_repeats * list(range(num_elements)), assert_items_equal=True)
+
+  @combinations.generate(test_base.eager_only_combinations())
+  def testDistributeDistributedEpochShuffleAndRepeat(self):
+    dispatcher, workers = self.start_cluster(2)  # to avoid gcing workers, pylint: disable=unused-variable
+    num_repeats = 5
+    num_elements = 20
+    ds = dataset_ops.Dataset.range(num_elements).shuffle(num_elements).repeat(
+        num_repeats)
+    ds = ds.apply(
+        data_service_ops.distribute(
+            processing_mode="distributed_epoch", service=dispatcher.target))
+    self.assertDatasetProduces(
+        ds, num_repeats * list(range(num_elements)), assert_items_equal=True)
+
  def testDistributeFromInterleave(self):
    dispatcher, workers = self.start_cluster(1)  # to avoid gcing workers, pylint: disable=unused-variable
    ds = dataset_ops.Dataset.range(2)
@@ -657,6 +691,40 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
    ds = ds.interleave(interleave_fn, cycle_length=2)
    self.assertDatasetProduces(ds, [0, 0, 1, 1])

+  @combinations.generate(test_base.eager_only_combinations())
+  def testDistributeDistributedEpoch(self):
+    dispatcher, workers = self.start_cluster(2)  # to avoid gcing workers, pylint: disable=unused-variable
+    num_elements = 100
+    ds = dataset_ops.Dataset.range(num_elements)
+    ds = ds.apply(
+        data_service_ops.distribute(
+            processing_mode="distributed_epoch", service=dispatcher.target))
+    self.assertDatasetProduces(
+        ds, list(range(num_elements)), assert_items_equal=True)
+
+  @combinations.generate(test_base.eager_only_combinations())
+  def testChangeProcessingModeAfterRestart(self):
+    dispatcher, workers = self.start_cluster(1)  # to avoid gcing workers, pylint: disable=unused-variable
+    num_elements = 100
+    range_dataset = dataset_ops.Dataset.range(num_elements)
+    ds = range_dataset.apply(
+        data_service_ops.distribute(
+            processing_mode="parallel_epochs",
+            service=dispatcher.target,
+            job_name="test"))
+    iterator = iter(ds)
+    for i in range(num_elements // 2):
+      self.assertEqual(i, next(iterator).numpy())
+    dispatcher = self.restart_dispatcher(dispatcher)
+    ds = range_dataset.apply(
+        data_service_ops.distribute(
+            processing_mode="distributed_epoch",
+            service=dispatcher.target,
+            job_name="test"))
+    with self.assertRaisesOpError("already an existing job with that name "
+                                  "using processing mode <parallel_epochs>"):
+      next(iter(ds)).numpy()
+
  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeNonStringAddresses(self):
    ds = dataset_ops.Dataset.range(10)