[tf.data service] Add dataset_sharing_mode option.

Previously, the dataset_sharing_mode was always "rpc", with entire (potentially large) dataset graphs being sent over RPC. This CL adds another option "shared_filesystem", which shares datasets by writing them to the dispatcher's work_dir, then transmitting only the filesystem path, instead of the full dataset graph.

PiperOrigin-RevId: 327130518
Change-Id: I8565689de2ce35448e8944ecc39e7ba8bb053ff9
commit 94ca496b8a
parent 33e968d542
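To see what the two modes mean at the configuration level, here is a minimal sketch constructing a `DispatcherConfig` for each, matching the proto changes in this CL. The Python import path and the example port/paths are assumptions for illustration.

from tensorflow.core.protobuf.data.experimental import service_config_pb2

# "rpc" (the previous behavior): each task carries the full dataset graph.
rpc_config = service_config_pb2.DispatcherConfig(
    port=5000,
    protocol="grpc",
    dataset_sharing_mode=service_config_pb2.RPC)

# "shared_filesystem": datasets are serialized under work_dir and tasks
# carry only a path, so work_dir must be set and reachable from workers.
fs_config = service_config_pb2.DispatcherConfig(
    port=5000,
    protocol="grpc",
    work_dir="/shared/tf_data_work_dir",  # hypothetical shared path
    dataset_sharing_mode=service_config_pb2.SHARED_FILESYSTEM)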
tensorflow/core/data/service/common.proto
@@ -12,11 +12,13 @@ message DatasetDef {
 
 message TaskDef {
   // The dataset to iterate over.
-  // TODO(aaudibert): load the dataset from disk instead of passing it here.
-  DatasetDef dataset = 1;
-  int64 dataset_id = 2;
-  int64 task_id = 3;
-  int64 job_id = 4;
+  oneof dataset {
+    DatasetDef dataset_def = 1;
+    string path = 2;
+  }
+  int64 dataset_id = 3;
+  int64 task_id = 4;
+  int64 job_id = 5;
 }
 
 message TaskInfo {
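On the receiving side, a worker can branch on which member of the new `dataset` oneof is populated. A minimal sketch using the standard proto3 Python API; `read_dataset_def` is a hypothetical helper standing in for the worker's disk-read path:

def load_dataset_graph(task_def):
  # WhichOneof returns the name of the set field, or None if neither is set.
  which = task_def.WhichOneof("dataset")
  if which == "dataset_def":
    # The full DatasetDef was sent inline over RPC.
    return task_def.dataset_def.graph
  if which == "path":
    # Only a filesystem path was sent; load the DatasetDef from disk.
    return read_dataset_def(task_def.path).graph  # hypothetical helper
  raise ValueError("TaskDef has no dataset set")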
tensorflow/core/data/service/dispatcher_impl.cc
@@ -54,6 +54,8 @@ using Worker = DispatcherState::Worker;
 using NamedJobKey = DispatcherState::NamedJobKey;
 using Job = DispatcherState::Job;
 using Task = DispatcherState::Task;
+using ::tensorflow::data::experimental::RPC;
+using ::tensorflow::data::experimental::SHARED_FILESYSTEM;
 
 std::string JournalDir(const std::string& work_dir) {
   return io::JoinPath(work_dir, kJournalDir);
@@ -93,7 +95,17 @@ DataServiceDispatcherImpl::DataServiceDispatcherImpl(
 
 Status DataServiceDispatcherImpl::Start() {
   mutex_lock l(mu_);
-  if (!config_.work_dir().empty()) {
+  if (config_.work_dir().empty()) {
+    if (config_.fault_tolerant_mode()) {
+      return errors::InvalidArgument(
+          "fault_tolerant_mode is True, but no work_dir is configured.");
+    }
+    if (config_.dataset_sharing_mode() == SHARED_FILESYSTEM) {
+      return errors::InvalidArgument(
+          "dataset sharing mode is shared_filesystem, but no work_dir is "
+          "configured.");
+    }
+  } else {
     TF_RETURN_IF_ERROR(
         Env::Default()->RecursivelyCreateDir(DatasetsDir(config_.work_dir())));
   }
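The new checks boil down to one rule: an empty work_dir is acceptable only when both fault tolerance and filesystem sharing are disabled. A hedged Python rendering of the same rule, with illustrative names:

def validate_dispatcher_config(work_dir, fault_tolerant_mode,
                               dataset_sharing_mode):
  # Mirrors the validation in DataServiceDispatcherImpl::Start() above.
  if not work_dir:
    if fault_tolerant_mode:
      raise ValueError(
          "fault_tolerant_mode is True, but no work_dir is configured.")
    if dataset_sharing_mode == "shared_filesystem":
      raise ValueError("dataset sharing mode is shared_filesystem, but no "
                       "work_dir is configured.")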
@@ -102,10 +114,6 @@ Status DataServiceDispatcherImpl::Start() {
            "not be able to recover its state on restart.";
     return Status::OK();
   }
-  if (config_.work_dir().empty()) {
-    return errors::InvalidArgument(
-        "fault_tolerant_mode is True, but no work_dir is configured.");
-  }
   journal_writer_ = absl::make_unique<FileJournalWriter>(
       Env::Default(), JournalDir(config_.work_dir()));
   LOG(INFO) << "Restoring dispatcher state from journal in "
@@ -169,10 +177,25 @@ Status DataServiceDispatcherImpl::RegisterWorker(
     TaskDef* task_def = response->add_tasks();
     std::shared_ptr<const Dataset> dataset;
     TF_RETURN_IF_ERROR(state_.DatasetFromId(job->dataset_id, &dataset));
-    std::shared_ptr<const DatasetDef> dataset_def;
-    TF_RETURN_IF_ERROR(dataset_store_->Get(
-        DatasetKey(dataset->dataset_id, dataset->fingerprint), dataset_def));
-    *(task_def->mutable_dataset()) = *dataset_def;
+    std::string dataset_key =
+        DatasetKey(dataset->dataset_id, dataset->fingerprint);
+    switch (config_.dataset_sharing_mode()) {
+      case SHARED_FILESYSTEM: {
+        std::string path =
+            io::JoinPath(DatasetsDir(config_.work_dir()), dataset_key);
+        task_def->set_path(path);
+        break;
+      }
+      case RPC: {
+        std::shared_ptr<const DatasetDef> dataset_def;
+        TF_RETURN_IF_ERROR(dataset_store_->Get(dataset_key, dataset_def));
+        *task_def->mutable_dataset_def() = *dataset_def;
+        break;
+      }
+      default:
+        return errors::Internal("Unrecognized dataset sharing mode: ",
+                                config_.dataset_sharing_mode());
+    }
     task_def->set_dataset_id(job->dataset_id);
     task_def->set_job_id(job->job_id);
     task_def->set_task_id(task->task_id);
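In shared_filesystem mode, the path sent to the worker is `DatasetsDir(work_dir)` joined with the dataset key. A sketch of the resulting layout; both the "datasets" subdirectory name and the key format shown are assumptions for illustration, since the diff only shows the `DatasetsDir` and `DatasetKey` call sites:

import os

def dataset_path(work_dir, dataset_key):
  # Serialized DatasetDefs live in a subdirectory of work_dir; the key
  # combines the dataset id and fingerprint (exact format assumed).
  return os.path.join(work_dir, "datasets", dataset_key)

# e.g. dataset_path("/shared/work_dir", "dataset_7_fp_123")
#      -> "/shared/work_dir/datasets/dataset_7_fp_123"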
@@ -458,6 +481,8 @@ Status DataServiceDispatcherImpl::GetOrCreateWorkerStub(
 
 Status DataServiceDispatcherImpl::AssignTask(std::shared_ptr<const Task> task)
     LOCKS_EXCLUDED(mu_) {
+  VLOG(2) << "Started assigning task " << task->task_id << " to worker "
+          << task->worker_address;
   grpc::ClientContext client_ctx;
   ProcessTaskRequest req;
   TaskDef* task_def = req.mutable_task();
@@ -466,10 +491,25 @@ Status DataServiceDispatcherImpl::AssignTask(std::shared_ptr<const Task> task)
     mutex_lock l(mu_);
     std::shared_ptr<const Dataset> dataset;
     TF_RETURN_IF_ERROR(state_.DatasetFromId(task->dataset_id, &dataset));
-    std::shared_ptr<const DatasetDef> dataset_def;
-    TF_RETURN_IF_ERROR(dataset_store_->Get(
-        DatasetKey(dataset->dataset_id, dataset->fingerprint), dataset_def));
-    *task_def->mutable_dataset() = *dataset_def;
+    std::string dataset_key =
+        DatasetKey(dataset->dataset_id, dataset->fingerprint);
+    switch (config_.dataset_sharing_mode()) {
+      case SHARED_FILESYSTEM: {
+        std::string path =
+            io::JoinPath(DatasetsDir(config_.work_dir()), dataset_key);
+        task_def->set_path(path);
+        break;
+      }
+      case RPC: {
+        std::shared_ptr<const DatasetDef> dataset_def;
+        TF_RETURN_IF_ERROR(dataset_store_->Get(dataset_key, dataset_def));
+        *task_def->mutable_dataset_def() = *dataset_def;
+        break;
+      }
+      default:
+        return errors::Internal("Unrecognized dataset sharing mode: ",
+                                config_.dataset_sharing_mode());
+    }
   }
   task_def->set_task_id(task->task_id);
   ProcessTaskResponse resp;
@@ -481,6 +521,8 @@ Status DataServiceDispatcherImpl::AssignTask(std::shared_ptr<const Task> task)
         absl::StrCat("Failed to submit task to worker ", task->worker_address),
         s);
   }
+  VLOG(2) << "Finished assigning task " << task->task_id << " to worker "
+          << task->worker_address;
   return Status::OK();
 }
tensorflow/core/data/service/worker_impl.cc
@@ -24,6 +24,7 @@ limitations under the License.
 #include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
 #include "tensorflow/core/data/service/dispatcher.pb.h"
 #include "tensorflow/core/data/service/grpc_util.h"
+#include "tensorflow/core/data/service/utils.h"
 #include "tensorflow/core/data/standalone.h"
 #include "tensorflow/core/framework/tensor.pb.h"
 #include "tensorflow/core/lib/core/errors.h"
@@ -94,27 +95,46 @@ Status DataServiceWorkerImpl::ProcessTask(const ProcessTaskRequest* request,
 
 Status DataServiceWorkerImpl::ProcessTaskInternal(const TaskDef& task_def)
     EXCLUSIVE_LOCKS_REQUIRED(mu_) {
   VLOG(3) << "Received request to process task " << task_def.task_id();
-  standalone::Dataset::Params params;
-  std::unique_ptr<standalone::Dataset> dataset;
-  TF_RETURN_IF_ERROR(standalone::Dataset::FromGraph(
-      params, task_def.dataset().graph(), &dataset));
-
-  std::unique_ptr<standalone::Iterator> iterator;
-  TF_RETURN_IF_ERROR(dataset->MakeIterator(&iterator));
-
-  if (tasks_.contains(task_def.task_id())) {
+  std::unique_ptr<Task>& task = tasks_[task_def.task_id()];
+  if (task) {
     return errors::AlreadyExists("A task with id ", task_def.task_id(),
                                  " already exists.");
   }
-  Task& task = tasks_[task_def.task_id()];
-  task.task_id = task_def.task_id();
-  task.dataset = std::move(dataset);
-  task.iterator = std::move(iterator);
+  task = absl::make_unique<Task>(task_def);
   VLOG(3) << "Began processing for task " << task_def.task_id();
   return Status::OK();
 }
 
+Status DataServiceWorkerImpl::EnsureTaskInitialized(
+    DataServiceWorkerImpl::Task& task) {
+  mutex_lock l(task.mu);
+  if (task.initialized) {
+    return Status::OK();
+  }
+  standalone::Dataset::Params params;
+
+  switch (task.task_def.dataset_case()) {
+    case TaskDef::kDatasetDef:
+      TF_RETURN_IF_ERROR(standalone::Dataset::FromGraph(
+          params, task.task_def.dataset_def().graph(), &task.dataset));
+      break;
+    case TaskDef::kPath: {
+      DatasetDef def;
+      TF_RETURN_IF_ERROR(ReadDatasetDef(task.task_def.path(), def));
+      TF_RETURN_IF_ERROR(
+          standalone::Dataset::FromGraph(params, def.graph(), &task.dataset));
+      break;
+    }
+    case TaskDef::DATASET_NOT_SET:
+      return errors::Internal("Unrecognized dataset case: ",
+                              task.task_def.dataset_case());
+  }
+  TF_RETURN_IF_ERROR(task.dataset->MakeIterator(&task.iterator));
+  task.initialized = true;
+  VLOG(3) << "Created iterator for task " << task.task_def.task_id();
+  return Status::OK();
+}
+
 Status DataServiceWorkerImpl::GetElement(const GetElementRequest* request,
                                          GetElementResponse* response) {
   VLOG(3) << "Received GetElement request for task " << request->task_id();
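The worker now builds the dataset and iterator lazily, on the first GetElement for the task, guarded by a per-task mutex. A minimal Python sketch of the same lazy-initialization pattern (illustrative only, not the actual API):

import threading

class LazyTask:
  def __init__(self, task_def):
    self.task_def = task_def
    self.mu = threading.Lock()
    self.initialized = False
    self.iterator = None

  def ensure_initialized(self, build_iterator):
    # Mirrors EnsureTaskInitialized: the flag is checked under the lock, so
    # concurrent first reads block until one caller builds the iterator.
    with self.mu:
      if self.initialized:
        return
      self.iterator = build_iterator(self.task_def)
      self.initialized = True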
@@ -134,7 +154,9 @@ Status DataServiceWorkerImpl::GetElement(const GetElementRequest* request,
       return errors::NotFound("DataServiceWorkerImpl::GetElement failed. ",
                               "Task id ", request->task_id(), " not found");
     }
-    std::unique_ptr<standalone::Iterator>& iter = it->second.iterator;
+    auto& task = it->second;
+    TF_RETURN_IF_ERROR(EnsureTaskInitialized(*task));
+    std::unique_ptr<standalone::Iterator>& iter = task->iterator;
     if (iter == nullptr) {
       VLOG(3) << "Task " << request->task_id() << " is already finished";
       response->set_end_of_sequence(true);
tensorflow/core/data/service/worker_impl.h
@@ -51,6 +51,18 @@ class DataServiceWorkerImpl {
                     GetElementResponse* response);
 
  private:
+  struct Task {
+    explicit Task(TaskDef task_def) : task_def(std::move(task_def)) {}
+
+    TaskDef task_def;
+    mutex mu;
+    bool initialized TF_GUARDED_BY(mu) = false;
+    // TODO(aaudibert): Have standalone::Iterator own a reference to
+    // standalone::Dataset so that we don't need to store the dataset here.
+    std::unique_ptr<standalone::Dataset> dataset;
+    std::unique_ptr<standalone::Iterator> iterator;
+  };
+
   Status MakeDispatcherStub(std::unique_ptr<DispatcherService::Stub>* stub);
   // Registers the worker with the dispatcher.
   Status Register(DispatcherService::Stub* dispatcher) LOCKS_EXCLUDED(mu_);
@@ -59,6 +71,7 @@ class DataServiceWorkerImpl {
       LOCKS_EXCLUDED(mu_);
   // Creates an iterator to process a task.
   Status ProcessTaskInternal(const TaskDef& task) EXCLUSIVE_LOCKS_REQUIRED(mu_);
+  Status EnsureTaskInitialized(Task& task);
   // A thread for doing async background processing not associated with a
   // specific RPC, such as reporting finished tasks. The thread takes
   // ownership of the passed dispatcher_ptr. We use a raw pointer instead of
@@ -66,21 +79,13 @@ class DataServiceWorkerImpl {
   void BackgroundThread(DispatcherService::Stub* dispatcher_ptr)
       LOCKS_EXCLUDED(mu_);
 
-  typedef struct Task {
-    int64 task_id;
-    // TODO(aaudibert): Have standalone::Iterator own a reference to
-    // standalone::Dataset so that we don't need to store the dataset here.
-    std::unique_ptr<standalone::Dataset> dataset;
-    std::unique_ptr<standalone::Iterator> iterator;
-  } Task;
-
   const experimental::WorkerConfig config_;
   // The worker's own address.
   std::string worker_address_;
 
   mutex mu_;
   // Information about tasks, keyed by task ids.
-  absl::flat_hash_map<int64, Task> tasks_ TF_GUARDED_BY(mu_);
+  absl::flat_hash_map<int64, std::unique_ptr<Task>> tasks_ TF_GUARDED_BY(mu_);
   // Completed tasks which haven't yet been communicated to the dispatcher.
   absl::flat_hash_set<int64> pending_completed_tasks_ TF_GUARDED_BY(mu_);
   bool cancelled_ TF_GUARDED_BY(mu_) = false;
tensorflow/core/protobuf/data/experimental/service_config.proto
@@ -2,6 +2,17 @@ syntax = "proto3";
 
 package tensorflow.data.experimental;
 
+enum DatasetSharingMode {
+  // Unknown default value.
+  UNKNOWN = 0;
+  // Pass dataset definitions over the wire.
+  RPC = 1;
+  // Write dataset definitions to a shared filesystem, then send only the path
+  // over the wire. This reduces the load on the dispatcher, but requires that
+  // the dispatcher's work_dir is accessible from the workers.
+  SHARED_FILESYSTEM = 2;
+}
+
 // Configuration for a tf.data service DispatchServer.
 message DispatcherConfig {
   // The port for the dispatcher to bind to. A value of 0 indicates that the
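The generated Python module exposes a top-level enum like this through an EnumTypeWrapper, which is what lets server_lib.py (below) map the string mode to its enum value by name. A sketch, with the import path assumed:

from tensorflow.core.protobuf.data.experimental import service_config_pb2

mode = service_config_pb2.DatasetSharingMode.Value("SHARED_FILESYSTEM")
assert mode == service_config_pb2.SHARED_FILESYSTEM  # == 2
assert service_config_pb2.DatasetSharingMode.Name(mode) == "SHARED_FILESYSTEM"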
@@ -15,6 +26,8 @@ message DispatcherConfig {
   // Whether to run in fault tolerant mode, where dispatcher state is saved
   // across restarts.
   bool fault_tolerant_mode = 4;
+  // How to share datasets with workers.
+  DatasetSharingMode dataset_sharing_mode = 5;
 }
 
 // Configuration for a tf.data service WorkerServer.
tensorflow/python/data/experimental/service/server_lib.py
@@ -92,17 +92,22 @@ class DispatchServer(object):
       tf.errors.OpError: Or one of its subclasses if an error occurs while
         creating the TensorFlow server.
     """
-    self._protocol = protocol or DEFAULT_PROTOCOL
-    work_dir = work_dir or ""
-    fault_tolerant_mode = fault_tolerant_mode or False
-    if fault_tolerant_mode and not work_dir:
+    self._protocol = DEFAULT_PROTOCOL if protocol is None else protocol
+    self._work_dir = "" if work_dir is None else work_dir
+    self._dataset_sharing_mode = ("shared_filesystem"
+                                  if self._work_dir else "rpc")
+    self._fault_tolerant_mode = (False if fault_tolerant_mode is None else
+                                 fault_tolerant_mode)
+    if self._fault_tolerant_mode and not self._work_dir:
       raise ValueError(
           "Cannot enable fault tolerant mode without configuring a work_dir")
     config = service_config_pb2.DispatcherConfig(
         port=port,
         protocol=self._protocol,
-        work_dir=work_dir,
-        fault_tolerant_mode=fault_tolerant_mode)
+        work_dir=self._work_dir,
+        fault_tolerant_mode=self._fault_tolerant_mode,
+        dataset_sharing_mode=service_config_pb2.DatasetSharingMode.Value(
+            self._dataset_sharing_mode.upper()))
     self._server = _pywrap_server_lib.TF_DATA_NewDispatchServer(
         config.SerializeToString())
     if start:
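Note the defaulting logic: passing any work_dir opts the dispatcher into "shared_filesystem" mode, while omitting it falls back to "rpc". A usage sketch against the constructor as changed here; the paths are hypothetical:

from tensorflow.python.data.experimental.service import server_lib

# work_dir set -> _dataset_sharing_mode == "shared_filesystem".
dispatcher = server_lib.DispatchServer(
    port=0, work_dir="/shared/tf_data", fault_tolerant_mode=True)

# No work_dir -> _dataset_sharing_mode == "rpc"; fault tolerance must stay off.
rpc_dispatcher = server_lib.DispatchServer(port=0)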
BUILD
@@ -94,6 +94,7 @@ tf_py_test(
     name = "data_service_ops_test",
     size = "medium",
     srcs = ["data_service_ops_test.py"],
+    shard_count = 10,
     deps = [
         "//tensorflow:tensorflow_py",
         "//tensorflow/python:client_testlib",
data_service_ops_test.py
@@ -65,6 +65,14 @@ def _make_distributed_dataset(dataset,
         task_refresh_interval_hint_ms=20))
 
 
+def _all_cluster_configurations():
+  with_work_dir = combinations.combine(
+      work_dir=None, fault_tolerant_mode=[True, False])
+  without_work_dir = combinations.combine(
+      work_dir="", fault_tolerant_mode=False)
+  return with_work_dir + without_work_dir
+
+
 def _make_distributed_range_dataset(num_elements,
                                     dispatcher,
                                     job_name=None,
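The helper expands to three cluster configurations; the fourth combination (no work_dir with fault tolerance on) is deliberately excluded because the dispatcher rejects it. Spelled out, where work_dir=None means the test harness substitutes a temp directory:

# Configurations produced by _all_cluster_configurations():
configs = [
    dict(work_dir=None, fault_tolerant_mode=True),   # temp work_dir, journal on
    dict(work_dir=None, fault_tolerant_mode=False),  # temp work_dir, journal off
    dict(work_dir="", fault_tolerant_mode=False),    # no work_dir -> rpc sharing
]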
@@ -89,15 +97,20 @@ def _make_distributed_range_dataset(num_elements,
 
 class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
 
-  def start_dispatch_server(self, name="", port=0):
+  def start_dispatch_server(self,
+                            name="",
+                            port=0,
+                            work_dir=None,
+                            fault_tolerant_mode=True):
     # If a test starts multiple independent dispatch servers, it should give
     # them different `name` values.
-    work_dir = os.path.join(self.get_temp_dir(), "work_dir_", name)
+    work_dir = os.path.join(self.get_temp_dir(), "work_dir_",
+                            name) if work_dir is None else work_dir
     return server_lib.DispatchServer(
         port=port,
         protocol=server_lib.DEFAULT_PROTOCOL,
         work_dir=work_dir,
-        fault_tolerant_mode=True)
+        fault_tolerant_mode=fault_tolerant_mode)
 
   def start_worker_server(self, dispatcher, port=0):
     return server_lib.WorkerServer(
@@ -109,7 +122,10 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
     """Stops `dispatcher` and returns a new dispatcher with the same port."""
     port = int(_address_from_target(dispatcher.target).split(":")[1])
     dispatcher._stop()
-    return self.start_dispatch_server(port=port)
+    return self.start_dispatch_server(
+        port=port,
+        work_dir=dispatcher._work_dir,
+        fault_tolerant_mode=dispatcher._fault_tolerant_mode)
 
   def restart_worker(self, worker, dispatcher, use_same_port=True):
     """Stops `worker` and returns a new worker."""
@@ -119,23 +135,25 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
     worker._stop()
     return self.start_worker_server(dispatcher, port)
 
-  def start_cluster(self, num_workers, name=""):
-    """Creates a cluster of tf.data service servers.
-
-    Args:
-      num_workers: The number of workers in the cluster.
-      name: A name for the cluster.
-
-    Returns:
-      A tuple of (dispatcher, list_of_workers).
-    """
-    dispatcher = self.start_dispatch_server(name=name)
-    servers = [self.start_worker_server(dispatcher) for _ in range(num_workers)]
-    return dispatcher, servers
+  def start_cluster(self,
+                    num_workers,
+                    name="",
+                    work_dir=None,
+                    fault_tolerant_mode=True):
+    """Creates and starts a tf.data service cluster."""
+    dispatcher = self.start_dispatch_server(
+        name=name, work_dir=work_dir, fault_tolerant_mode=fault_tolerant_mode)
+    workers = [self.start_worker_server(dispatcher) for _ in range(num_workers)]
+    return dispatcher, workers
 
-  @combinations.generate(test_base.eager_only_combinations())
-  def testDistributeBasic(self):
-    dispatcher, workers = self.start_cluster(1)  # to avoid gcing workers, pylint: disable=unused-variable
+  @combinations.generate(
+      combinations.times(test_base.eager_only_combinations(),
+                         _all_cluster_configurations()))
+  def testDistributeBasic(self, work_dir, fault_tolerant_mode):
+    dispatcher, workers = self.start_cluster(  # to avoid gcing workers, pylint: disable=unused-variable
+        1,
+        work_dir=work_dir,
+        fault_tolerant_mode=fault_tolerant_mode)
     num_elements = 10
     ds = _make_distributed_range_dataset(10, dispatcher)
     results = [elem.numpy() for elem in ds]
@@ -387,9 +405,11 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
 
   @combinations.generate(
       combinations.times(test_base.eager_only_combinations(),
-                         combinations.combine(use_same_port=[True, False])))
-  def testRestartWorker(self, use_same_port):
-    dispatcher, [worker] = self.start_cluster(1)
+                         combinations.combine(use_same_port=[True, False]),
+                         _all_cluster_configurations()))
+  def testRestartWorker(self, use_same_port, work_dir, fault_tolerant_mode):
+    dispatcher, [worker] = self.start_cluster(
+        1, work_dir=work_dir, fault_tolerant_mode=fault_tolerant_mode)
     num_elements = 100
     ds = _make_distributed_range_dataset(num_elements, dispatcher)
     iterator = iter(ds)