[tf.data service] Add dataset_sharing_mode option.

Previously, datasets were always shared over RPC (what is now the "rpc" mode), with entire (potentially large) dataset graphs being sent in RPC messages. This CL adds a second mode, "shared_filesystem", which shares datasets by writing them to the dispatcher's work_dir and transmitting only the filesystem path instead of the full dataset graph.

PiperOrigin-RevId: 327130518
Change-Id: I8565689de2ce35448e8944ecc39e7ba8bb053ff9
Andrew Audibert 2020-08-17 17:09:29 -07:00 committed by TensorFlower Gardener
parent 33e968d542
commit 94ca496b8a
8 changed files with 181 additions and 71 deletions
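
To make the behavior concrete, here is a minimal sketch of how the new mode surfaces through the Python API changed below in server_lib.py. Everything except the import path and the temporary-directory handling is taken directly from this CL; _dataset_sharing_mode is an internal attribute, shown only for illustration.

# The dispatcher derives its dataset_sharing_mode from work_dir: a non-empty
# work_dir selects "shared_filesystem", an empty one falls back to "rpc".
import tempfile
from tensorflow.python.data.experimental.service import server_lib

work_dir = tempfile.mkdtemp()
dispatcher = server_lib.DispatchServer(
    port=0, work_dir=work_dir, fault_tolerant_mode=True)
print(dispatcher._dataset_sharing_mode)  # "shared_filesystem"

rpc_dispatcher = server_lib.DispatchServer(port=0)  # no work_dir configured
print(rpc_dispatcher._dataset_sharing_mode)  # "rpc"

In shared_filesystem mode the dispatcher writes each registered DatasetDef under its work_dir and tasks carry only the file path, so workers must be able to read the dispatcher's work_dir.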

View File

@ -12,11 +12,13 @@ message DatasetDef {
message TaskDef {
// The dataset to iterate over.
// TODO(aaudibert): load the dataset from disk instead of passing it here.
DatasetDef dataset = 1;
int64 dataset_id = 2;
int64 task_id = 3;
int64 job_id = 4;
oneof dataset {
DatasetDef dataset_def = 1;
string path = 2;
}
int64 dataset_id = 3;
int64 task_id = 4;
int64 job_id = 5;
}
message TaskInfo {
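
The TaskDef change above replaces the embedded DatasetDef with a oneof, so a task now carries either an inline dataset graph or a filesystem path, never both. A hedged Python sketch of the resulting wire contract; the generated-module path (common_pb2) is an assumption not shown in this diff, and the path value is a placeholder.

from tensorflow.core.data.service import common_pb2  # module path is an assumption

task = common_pb2.TaskDef(dataset_id=1, task_id=1, job_id=1)
task.path = "/dispatcher_work_dir/datasets/<dataset_key>"  # shared_filesystem mode
assert task.WhichOneof("dataset") == "path"

task.dataset_def.CopyFrom(common_pb2.DatasetDef())  # rpc mode; clears `path`
assert task.WhichOneof("dataset") == "dataset_def"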

View File

@ -54,6 +54,8 @@ using Worker = DispatcherState::Worker;
using NamedJobKey = DispatcherState::NamedJobKey;
using Job = DispatcherState::Job;
using Task = DispatcherState::Task;
using ::tensorflow::data::experimental::RPC;
using ::tensorflow::data::experimental::SHARED_FILESYSTEM;
std::string JournalDir(const std::string& work_dir) {
return io::JoinPath(work_dir, kJournalDir);
@ -93,7 +95,17 @@ DataServiceDispatcherImpl::DataServiceDispatcherImpl(
Status DataServiceDispatcherImpl::Start() {
mutex_lock l(mu_);
if (!config_.work_dir().empty()) {
if (config_.work_dir().empty()) {
if (config_.fault_tolerant_mode()) {
return errors::InvalidArgument(
"fault_tolerant_mode is True, but no work_dir is configured.");
}
if (config_.dataset_sharing_mode() == SHARED_FILESYSTEM) {
return errors::InvalidArgument(
"dataset sharing mode is shared_filesystem, but no work_dir is "
"configured.");
}
} else {
TF_RETURN_IF_ERROR(
Env::Default()->RecursivelyCreateDir(DatasetsDir(config_.work_dir())));
}
@ -102,10 +114,6 @@ Status DataServiceDispatcherImpl::Start() {
"not be able to recover its state on restart.";
return Status::OK();
}
if (config_.work_dir().empty()) {
return errors::InvalidArgument(
"fault_tolerant_mode is True, but no work_dir is configured.");
}
journal_writer_ = absl::make_unique<FileJournalWriter>(
Env::Default(), JournalDir(config_.work_dir()));
LOG(INFO) << "Restoring dispatcher state from journal in "
@ -169,10 +177,25 @@ Status DataServiceDispatcherImpl::RegisterWorker(
TaskDef* task_def = response->add_tasks();
std::shared_ptr<const Dataset> dataset;
TF_RETURN_IF_ERROR(state_.DatasetFromId(job->dataset_id, &dataset));
std::shared_ptr<const DatasetDef> dataset_def;
TF_RETURN_IF_ERROR(dataset_store_->Get(
DatasetKey(dataset->dataset_id, dataset->fingerprint), dataset_def));
*(task_def->mutable_dataset()) = *dataset_def;
std::string dataset_key =
DatasetKey(dataset->dataset_id, dataset->fingerprint);
switch (config_.dataset_sharing_mode()) {
case SHARED_FILESYSTEM: {
std::string path =
io::JoinPath(DatasetsDir(config_.work_dir()), dataset_key);
task_def->set_path(path);
break;
}
case RPC: {
std::shared_ptr<const DatasetDef> dataset_def;
TF_RETURN_IF_ERROR(dataset_store_->Get(dataset_key, dataset_def));
*task_def->mutable_dataset_def() = *dataset_def;
break;
}
default:
return errors::Internal("Unrecognized dataset sharing mode: ",
config_.dataset_sharing_mode());
}
task_def->set_dataset_id(job->dataset_id);
task_def->set_job_id(job->job_id);
task_def->set_task_id(task->task_id);
@ -458,6 +481,8 @@ Status DataServiceDispatcherImpl::GetOrCreateWorkerStub(
Status DataServiceDispatcherImpl::AssignTask(std::shared_ptr<const Task> task)
LOCKS_EXCLUDED(mu_) {
VLOG(2) << "Started assigning task " << task->task_id << " to worker "
<< task->worker_address;
grpc::ClientContext client_ctx;
ProcessTaskRequest req;
TaskDef* task_def = req.mutable_task();
@ -466,10 +491,25 @@ Status DataServiceDispatcherImpl::AssignTask(std::shared_ptr<const Task> task)
mutex_lock l(mu_);
std::shared_ptr<const Dataset> dataset;
TF_RETURN_IF_ERROR(state_.DatasetFromId(task->dataset_id, &dataset));
std::shared_ptr<const DatasetDef> dataset_def;
TF_RETURN_IF_ERROR(dataset_store_->Get(
DatasetKey(dataset->dataset_id, dataset->fingerprint), dataset_def));
*task_def->mutable_dataset() = *dataset_def;
std::string dataset_key =
DatasetKey(dataset->dataset_id, dataset->fingerprint);
switch (config_.dataset_sharing_mode()) {
case SHARED_FILESYSTEM: {
std::string path =
io::JoinPath(DatasetsDir(config_.work_dir()), dataset_key);
task_def->set_path(path);
break;
}
case RPC: {
std::shared_ptr<const DatasetDef> dataset_def;
TF_RETURN_IF_ERROR(dataset_store_->Get(dataset_key, dataset_def));
*task_def->mutable_dataset_def() = *dataset_def;
break;
}
default:
return errors::Internal("Unrecognized dataset sharing mode: ",
config_.dataset_sharing_mode());
}
}
task_def->set_task_id(task->task_id);
ProcessTaskResponse resp;
@ -481,6 +521,8 @@ Status DataServiceDispatcherImpl::AssignTask(std::shared_ptr<const Task> task)
absl::StrCat("Failed to submit task to worker ", task->worker_address),
s);
}
VLOG(2) << "Finished assigning task " << task->task_id << " to worker "
<< task->worker_address;
return Status::OK();
}
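
RegisterWorker and AssignTask above now share the same per-task decision. A compact Python rendering of that logic, for orientation only: the "datasets" subdirectory name is inferred from the DatasetsDir() helper (its definition is not shown in this diff), and the dataset_store interface is illustrative.

import os

def fill_task(task_def, sharing_mode, work_dir, dataset_key, dataset_store):
  if sharing_mode == "shared_filesystem":
    # Point the worker at the DatasetDef written under the dispatcher's work_dir.
    task_def.path = os.path.join(work_dir, "datasets", dataset_key)
  elif sharing_mode == "rpc":
    # Embed the full dataset graph in the RPC message, as before this CL.
    task_def.dataset_def.CopyFrom(dataset_store.get(dataset_key))
  else:
    raise ValueError("Unrecognized dataset sharing mode: %s" % sharing_mode)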

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
#include "tensorflow/core/data/service/dispatcher.pb.h"
#include "tensorflow/core/data/service/grpc_util.h"
#include "tensorflow/core/data/service/utils.h"
#include "tensorflow/core/data/standalone.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/lib/core/errors.h"
@ -94,27 +95,46 @@ Status DataServiceWorkerImpl::ProcessTask(const ProcessTaskRequest* request,
Status DataServiceWorkerImpl::ProcessTaskInternal(const TaskDef& task_def)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
VLOG(3) << "Received request to process task " << task_def.task_id();
standalone::Dataset::Params params;
std::unique_ptr<standalone::Dataset> dataset;
TF_RETURN_IF_ERROR(standalone::Dataset::FromGraph(
params, task_def.dataset().graph(), &dataset));
std::unique_ptr<standalone::Iterator> iterator;
TF_RETURN_IF_ERROR(dataset->MakeIterator(&iterator));
if (tasks_.contains(task_def.task_id())) {
std::unique_ptr<Task>& task = tasks_[task_def.task_id()];
if (task) {
return errors::AlreadyExists("A task with id ", task_def.task_id(),
" already exists.");
}
Task& task = tasks_[task_def.task_id()];
task.task_id = task_def.task_id();
task.dataset = std::move(dataset);
task.iterator = std::move(iterator);
task = absl::make_unique<Task>(task_def);
VLOG(3) << "Began processing for task " << task_def.task_id();
return Status::OK();
}
Status DataServiceWorkerImpl::EnsureTaskInitialized(
DataServiceWorkerImpl::Task& task) {
mutex_lock l(task.mu);
if (task.initialized) {
return Status::OK();
}
standalone::Dataset::Params params;
switch (task.task_def.dataset_case()) {
case TaskDef::kDatasetDef:
TF_RETURN_IF_ERROR(standalone::Dataset::FromGraph(
params, task.task_def.dataset_def().graph(), &task.dataset));
break;
case TaskDef::kPath: {
DatasetDef def;
TF_RETURN_IF_ERROR(ReadDatasetDef(task.task_def.path(), def));
TF_RETURN_IF_ERROR(
standalone::Dataset::FromGraph(params, def.graph(), &task.dataset));
break;
}
case TaskDef::DATASET_NOT_SET:
return errors::Internal("Unrecognized dataset case: ",
task.task_def.dataset_case());
}
TF_RETURN_IF_ERROR(task.dataset->MakeIterator(&task.iterator));
task.initialized = true;
VLOG(3) << "Created iterator for task " << task.task_def.task_id();
return Status::OK();
}
Status DataServiceWorkerImpl::GetElement(const GetElementRequest* request,
GetElementResponse* response) {
VLOG(3) << "Received GetElement request for task " << request->task_id();
@ -134,7 +154,9 @@ Status DataServiceWorkerImpl::GetElement(const GetElementRequest* request,
return errors::NotFound("DataServiceWorkerImpl::GetElement failed. ",
"Task id ", request->task_id(), " not found");
}
std::unique_ptr<standalone::Iterator>& iter = it->second.iterator;
auto& task = it->second;
TF_RETURN_IF_ERROR(EnsureTaskInitialized(*task));
std::unique_ptr<standalone::Iterator>& iter = task->iterator;
if (iter == nullptr) {
VLOG(3) << "Task " << request->task_id() << " is already finished";
response->set_end_of_sequence(true);
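
Note the shape of the worker change in this file: ProcessTaskInternal now only records the TaskDef, and the dataset and iterator are built lazily in EnsureTaskInitialized on the first GetElement, guarded by a per-task mutex. A hedged Python sketch of that pattern (names are illustrative, not TF APIs); with shared_filesystem this first use is also where the DatasetDef is read back from the dispatcher's work_dir.

import threading

class LazyTask:
  def __init__(self, task_def):
    self.task_def = task_def
    self.mu = threading.Lock()
    self.initialized = False
    self.dataset = None
    self.iterator = None

  def ensure_initialized(self, read_dataset_def, build_dataset):
    # read_dataset_def: loads a DatasetDef from a filesystem path.
    # build_dataset: turns a DatasetDef into an iterable dataset.
    with self.mu:
      if self.initialized:
        return
      if self.task_def.WhichOneof("dataset") == "path":
        dataset_def = read_dataset_def(self.task_def.path)  # shared filesystem
      else:
        dataset_def = self.task_def.dataset_def             # inline over RPC
      self.dataset = build_dataset(dataset_def)
      self.iterator = iter(self.dataset)
      self.initialized = True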

View File

@ -51,6 +51,18 @@ class DataServiceWorkerImpl {
GetElementResponse* response);
private:
struct Task {
explicit Task(TaskDef task_def) : task_def(std::move(task_def)) {}
TaskDef task_def;
mutex mu;
bool initialized TF_GUARDED_BY(mu) = false;
// TODO(aaudibert): Have standalone::Iterator own a reference to
// standalone::Dataset so that we don't need to store the dataset here.
std::unique_ptr<standalone::Dataset> dataset;
std::unique_ptr<standalone::Iterator> iterator;
};
Status MakeDispatcherStub(std::unique_ptr<DispatcherService::Stub>* stub);
// Registers the worker with the dispatcher.
Status Register(DispatcherService::Stub* dispatcher) LOCKS_EXCLUDED(mu_);
@ -59,6 +71,7 @@ class DataServiceWorkerImpl {
LOCKS_EXCLUDED(mu_);
// Creates an iterator to process a task.
Status ProcessTaskInternal(const TaskDef& task) EXCLUSIVE_LOCKS_REQUIRED(mu_);
Status EnsureTaskInitialized(Task& task);
// A thread for doing async background processing not associated with a
// specific RPC, such as reporting finished tasks. The thread takes
// ownership of the passed dispatcher_ptr. We use a raw pointer instead of
@ -66,21 +79,13 @@ class DataServiceWorkerImpl {
void BackgroundThread(DispatcherService::Stub* dispatcher_ptr)
LOCKS_EXCLUDED(mu_);
typedef struct Task {
int64 task_id;
// TODO(aaudibert): Have standalone::Iterator own a reference to
// standalone::Dataset so that we don't need to store the dataset here.
std::unique_ptr<standalone::Dataset> dataset;
std::unique_ptr<standalone::Iterator> iterator;
} Task;
const experimental::WorkerConfig config_;
// The worker's own address.
std::string worker_address_;
mutex mu_;
// Information about tasks, keyed by task ids.
absl::flat_hash_map<int64, Task> tasks_ TF_GUARDED_BY(mu_);
absl::flat_hash_map<int64, std::unique_ptr<Task>> tasks_ TF_GUARDED_BY(mu_);
// Completed tasks which haven't yet been communicated to the dispatcher.
absl::flat_hash_set<int64> pending_completed_tasks_ TF_GUARDED_BY(mu_);
bool cancelled_ TF_GUARDED_BY(mu_) = false;

View File

@ -2,6 +2,17 @@ syntax = "proto3";
package tensorflow.data.experimental;
enum DatasetSharingMode {
// Unknown default value.
UNKNOWN = 0;
// Pass dataset definitions over the wire.
RPC = 1;
// Write dataset definitions to a shared filesystem, then send only the path
// over the wire. This reduces the load on the dispatcher, but requires that
// the dispatcher's work_dir is accessible from the workers.
SHARED_FILESYSTEM = 2;
}
// Configuration for a tf.data service DispatchServer.
message DispatcherConfig {
// The port for the dispatcher to bind to. A value of 0 indicates that the
@ -15,6 +26,8 @@ message DispatcherConfig {
// Whether to run in fault tolerant mode, where dispatcher state is saved
// across restarts.
bool fault_tolerant_mode = 4;
// How to share datasets with workers.
DatasetSharingMode dataset_sharing_mode = 5;
}
// Configuration for a tf.data service WorkerServer.
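
The Python change in server_lib.py (next file) fills the new field via DatasetSharingMode.Value(...). A standalone equivalent is sketched here, assuming the generated service_config_pb2 module is importable; the import path and work_dir value below are assumptions, not part of this diff.

from tensorflow.core.protobuf.data.experimental import service_config_pb2  # path is an assumption

config = service_config_pb2.DispatcherConfig(
    port=0,
    protocol="grpc",
    work_dir="/tmp/dispatcher_work_dir",
    fault_tolerant_mode=True,
    dataset_sharing_mode=service_config_pb2.DatasetSharingMode.Value(
        "SHARED_FILESYSTEM"))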

View File

@ -92,17 +92,22 @@ class DispatchServer(object):
tf.errors.OpError: Or one of its subclasses if an error occurs while
creating the TensorFlow server.
"""
self._protocol = protocol or DEFAULT_PROTOCOL
work_dir = work_dir or ""
fault_tolerant_mode = fault_tolerant_mode or False
if fault_tolerant_mode and not work_dir:
self._protocol = DEFAULT_PROTOCOL if protocol is None else protocol
self._work_dir = "" if work_dir is None else work_dir
self._dataset_sharing_mode = ("shared_filesystem"
if self._work_dir else "rpc")
self._fault_tolerant_mode = (False if fault_tolerant_mode is None else
fault_tolerant_mode)
if self._fault_tolerant_mode and not self._work_dir:
raise ValueError(
"Cannot enable fault tolerant mode without configuring a work_dir")
config = service_config_pb2.DispatcherConfig(
port=port,
protocol=self._protocol,
work_dir=work_dir,
fault_tolerant_mode=fault_tolerant_mode)
work_dir=self._work_dir,
fault_tolerant_mode=self._fault_tolerant_mode,
dataset_sharing_mode=service_config_pb2.DatasetSharingMode.Value(
self._dataset_sharing_mode.upper()))
self._server = _pywrap_server_lib.TF_DATA_NewDispatchServer(
config.SerializeToString())
if start:

View File

@ -94,6 +94,7 @@ tf_py_test(
name = "data_service_ops_test",
size = "medium",
srcs = ["data_service_ops_test.py"],
shard_count = 10,
deps = [
"//tensorflow:tensorflow_py",
"//tensorflow/python:client_testlib",

View File

@ -65,6 +65,14 @@ def _make_distributed_dataset(dataset,
task_refresh_interval_hint_ms=20))
def _all_cluster_configurations():
with_work_dir = combinations.combine(
work_dir=None, fault_tolerant_mode=[True, False])
without_work_dir = combinations.combine(
work_dir="", fault_tolerant_mode=False)
return with_work_dir + without_work_dir
def _make_distributed_range_dataset(num_elements,
dispatcher,
job_name=None,
@ -89,15 +97,20 @@ def _make_distributed_range_dataset(num_elements,
class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
def start_dispatch_server(self, name="", port=0):
def start_dispatch_server(self,
name="",
port=0,
work_dir=None,
fault_tolerant_mode=True):
# If a test starts multiple independent dispatch servers, it should give
# them different `name` values.
work_dir = os.path.join(self.get_temp_dir(), "work_dir_", name)
work_dir = os.path.join(self.get_temp_dir(), "work_dir_",
name) if work_dir is None else work_dir
return server_lib.DispatchServer(
port=port,
protocol=server_lib.DEFAULT_PROTOCOL,
work_dir=work_dir,
fault_tolerant_mode=True)
fault_tolerant_mode=fault_tolerant_mode)
def start_worker_server(self, dispatcher, port=0):
return server_lib.WorkerServer(
@ -109,7 +122,10 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
"""Stops `dispatcher` and returns a new dispatcher with the same port."""
port = int(_address_from_target(dispatcher.target).split(":")[1])
dispatcher._stop()
return self.start_dispatch_server(port=port)
return self.start_dispatch_server(
port=port,
work_dir=dispatcher._work_dir,
fault_tolerant_mode=dispatcher._fault_tolerant_mode)
def restart_worker(self, worker, dispatcher, use_same_port=True):
"""Stops `worker` and returns a new worker."""
@ -119,23 +135,25 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
worker._stop()
return self.start_worker_server(dispatcher, port)
def start_cluster(self, num_workers, name=""):
"""Creates a cluster of tf.data service servers.
def start_cluster(self,
num_workers,
name="",
work_dir=None,
fault_tolerant_mode=True):
"""Creates and starts a tf.data service cluster."""
dispatcher = self.start_dispatch_server(
name=name, work_dir=work_dir, fault_tolerant_mode=fault_tolerant_mode)
workers = [self.start_worker_server(dispatcher) for _ in range(num_workers)]
return dispatcher, workers
Args:
num_workers: The number of workers in the cluster.
name: A name for the cluster.
Returns:
A tuple of (dispatcher, list_of_workers).
"""
dispatcher = self.start_dispatch_server(name=name)
servers = [self.start_worker_server(dispatcher) for _ in range(num_workers)]
return dispatcher, servers
@combinations.generate(test_base.eager_only_combinations())
def testDistributeBasic(self):
dispatcher, workers = self.start_cluster(1) # to avoid gcing workers, pylint: disable=unused-variable
@combinations.generate(
combinations.times(test_base.eager_only_combinations(),
_all_cluster_configurations()))
def testDistributeBasic(self, work_dir, fault_tolerant_mode):
dispatcher, workers = self.start_cluster( # to avoid gcing workers, pylint: disable=unused-variable
1,
work_dir=work_dir,
fault_tolerant_mode=fault_tolerant_mode)
num_elements = 10
ds = _make_distributed_range_dataset(10, dispatcher)
results = [elem.numpy() for elem in ds]
@ -387,9 +405,11 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(
combinations.times(test_base.eager_only_combinations(),
combinations.combine(use_same_port=[True, False])))
def testRestartWorker(self, use_same_port):
dispatcher, [worker] = self.start_cluster(1)
combinations.combine(use_same_port=[True, False]),
_all_cluster_configurations()))
def testRestartWorker(self, use_same_port, work_dir, fault_tolerant_mode):
dispatcher, [worker] = self.start_cluster(
1, work_dir=work_dir, fault_tolerant_mode=fault_tolerant_mode)
num_elements = 100
ds = _make_distributed_range_dataset(num_elements, dispatcher)
iterator = iter(ds)