Construct data service servers from config proto.

This will allow us to add new options without changing the constructor signature.

This breaks backwards compatibility of the experimental `DispatchServer` and `WorkerServer` APIs. To migrate, update call sites as follows:

`tf.data.experimental.service.DispatchServer(...)` becomes
`tf.data.experimental.service.DispatchServer(tf.data.experimental.service.DispatcherConfig(...))`.

`tf.data.experimental.service.WorkerServer(...)` becomes
`tf.data.experimental.service.WorkerServer(tf.data.experimental.service.WorkerConfig(...))`. Note that the parameter order has changed: `dispatcher_address` now comes first, since `port` is optional.
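A minimal before/after sketch of the migration (the port and address values are illustrative):

```python
# Before this change: servers took individual constructor arguments.
dispatcher = tf.data.experimental.service.DispatchServer(port=5050)
worker = tf.data.experimental.service.WorkerServer(
    port=0, dispatcher_address="localhost:5050")

# After this change: servers take a single config namedtuple.
dispatcher = tf.data.experimental.service.DispatchServer(
    tf.data.experimental.service.DispatcherConfig(port=5050))
worker = tf.data.experimental.service.WorkerServer(
    tf.data.experimental.service.WorkerConfig(
        dispatcher_address="localhost:5050"))
```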

This CL also updates argument defaults to follow the recommendation in https://github.com/tensorflow/community/pull/250/files: prefer concrete default values over `None` sentinels, reserving `None` for defaults that may need to change in a backwards-compatible way.
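As a concrete illustration (a sketch; the assertions reflect the `DispatcherConfig.__new__` defaults added below):

```python
config = tf.data.experimental.service.DispatcherConfig()
# Defaults are spelled out in the config class rather than resolved from
# `None` sentinels inside the server constructor.
assert config.port == 0
assert config.protocol == "grpc"
assert config.fault_tolerant_mode is False
```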

PiperOrigin-RevId: 329786653
Change-Id: I0d6483e2f18a73c80cfd68f9dd44f403c976d140
Author: Andrew Audibert 2020-09-02 14:19:02 -07:00
Committed by: TensorFlower Gardener
Parent: 6128d6cee1
Commit: b129d3dc1c
15 changed files with 302 additions and 122 deletions


@@ -40,6 +40,12 @@
 * `tf.raw_ops.Max` and `tf.raw_ops.Min` no longer accept inputs of type
   `tf.complex64` or `tf.complex128`, because the behavior of these ops is not
   well defined for complex types.
+* `tf.data.experimental.service.DispatchServer` now takes a config tuple
+  instead of individual arguments. Usages should be updated to
+  `tf.data.experimental.service.DispatchServer(dispatcher_config)`.
+* `tf.data.experimental.service.WorkerServer` now takes a config tuple
+  instead of individual arguments. Usages should be updated to
+  `tf.data.experimental.service.WorkerServer(worker_config)`.
 
 ## Known Caveats


@@ -9,11 +9,11 @@ message DispatcherConfig {
   int64 port = 1;
   // The protocol for the dispatcher to use when connecting to workers.
   string protocol = 2;
-  // An optional work directory to use for storing dispatcher state, and for
-  // recovering during restarts.
+  // A work directory to use for storing dispatcher state, and for recovering
+  // during restarts. The empty string indicates not to use any work directory.
   string work_dir = 3;
   // Whether to run in fault tolerant mode, where dispatcher state is saved
-  // across restarts.
+  // across restarts. Requires that `work_dir` is nonempty.
   bool fault_tolerant_mode = 4;
 }
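As the updated comments note, `fault_tolerant_mode` requires a nonempty `work_dir`. A sketch of a fault-tolerant dispatcher using the new Python API (the GCS path is illustrative, borrowed from the docstring example later in this change):

```python
# DispatchServer raises ValueError("Cannot enable fault tolerant mode
# without configuring a work_dir") if work_dir is left empty.
dispatcher = tf.data.experimental.service.DispatchServer(
    tf.data.experimental.service.DispatcherConfig(
        port=5050,
        work_dir="gs://my-bucket/dispatcher/work_dir",
        fault_tolerant_mode=True))
dispatcher.join()
```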


@@ -435,10 +435,11 @@ def register_dataset(service, dataset):
   If the dataset is already registered with the tf.data service,
   `register_dataset` returns the already-registered dataset's id.
 
-  >>> dispatcher = tf.data.experimental.service.DispatchServer(port=0)
+  >>> dispatcher = tf.data.experimental.service.DispatchServer()
   >>> dispatcher_address = dispatcher.target.split("://")[1]
   >>> worker = tf.data.experimental.service.WorkerServer(
-  ...     port=0, dispatcher_address=dispatcher_address)
+  ...     tf.data.experimental.service.WorkerConfig(
+  ...         dispatcher_address=dispatcher_address))
   >>> dataset = tf.data.Dataset.range(10)
   >>> dataset_id = tf.data.experimental.service.register_dataset(
   ...     dispatcher.target, dataset)

@@ -518,10 +519,11 @@ def from_dataset_id(processing_mode,
   See the documentation for `tf.data.experimental.service.distribute` for more
   detail about how `from_dataset_id` works.
 
-  >>> dispatcher = tf.data.experimental.service.DispatchServer(port=0)
+  >>> dispatcher = tf.data.experimental.service.DispatchServer()
   >>> dispatcher_address = dispatcher.target.split("://")[1]
   >>> worker = tf.data.experimental.service.WorkerServer(
-  ...     port=0, dispatcher_address=dispatcher_address)
+  ...     tf.data.experimental.service.WorkerConfig(
+  ...         dispatcher_address=dispatcher_address))
   >>> dataset = tf.data.Dataset.range(10)
   >>> dataset_id = tf.data.experimental.service.register_dataset(
   ...     dispatcher.target, dataset)


@@ -107,10 +107,11 @@ dataset = dataset.apply(tf.data.experimental.service.distribute(
 Below is a toy example that you can run yourself.
 
->>> dispatcher = tf.data.experimental.service.DispatchServer(port=0)
+>>> dispatcher = tf.data.experimental.service.DispatchServer()
 >>> dispatcher_address = dispatcher.target.split("://")[1]
 >>> worker = tf.data.experimental.service.WorkerServer(
-...     port=0, dispatcher_address=dispatcher_address)
+...     tf.data.experimental.service.WorkerConfig(
+...         dispatcher_address=dispatcher_address))
 >>> dataset = tf.data.Dataset.range(10)
 >>> dataset = dataset.apply(tf.data.experimental.service.distribute(
 ...     processing_mode="parallel_epochs", service=dispatcher.target))

@@ -128,5 +129,7 @@ from __future__ import print_function
 from tensorflow.python.data.experimental.ops.data_service_ops import distribute
 from tensorflow.python.data.experimental.ops.data_service_ops import from_dataset_id
 from tensorflow.python.data.experimental.ops.data_service_ops import register_dataset
+from tensorflow.python.data.experimental.service.server_lib import DispatcherConfig
 from tensorflow.python.data.experimental.service.server_lib import DispatchServer
+from tensorflow.python.data.experimental.service.server_lib import WorkerConfig
 from tensorflow.python.data.experimental.service.server_lib import WorkerServer


@@ -18,6 +18,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import collections
+
 # pylint: disable=invalid-import-order,g-bad-import-order, unused-import
 from tensorflow.core.protobuf.data.experimental import service_config_pb2
 from tensorflow.python import pywrap_tensorflow
@@ -25,7 +27,35 @@ from tensorflow.python.data.experimental.service import _pywrap_server_lib
 from tensorflow.python.util.tf_export import tf_export
 
-DEFAULT_PROTOCOL = "grpc"
+@tf_export("data.experimental.service.DispatcherConfig")
+class DispatcherConfig(
+    collections.namedtuple(
+        "DispatcherConfig",
+        ["port", "protocol", "work_dir", "fault_tolerant_mode"])):
+  """Configuration class for tf.data service dispatchers.
+
+  Fields:
+    port: Specifies the port to bind to. A value of 0 indicates that the server
+      may bind to any available port.
+    protocol: The protocol to use for communicating with the tf.data service.
+      Acceptable values include `"grpc"` and `"grpc+local"`.
+    work_dir: A directory to store dispatcher state in. This argument is
+      required for the dispatcher to be able to recover from restarts.
+    fault_tolerant_mode: Whether the dispatcher should write its state to a
+      journal so that it can recover from restarts. Dispatcher state, including
+      registered datasets and created jobs, is synchronously written to the
+      journal before responding to RPCs. If `True`, `work_dir` must also be
+      specified.
+  """
+
+  def __new__(cls,
+              port=0,
+              protocol="grpc",
+              work_dir=None,
+              fault_tolerant_mode=False):
+    return super(DispatcherConfig, cls).__new__(cls, port, protocol, work_dir,
+                                                fault_tolerant_mode)
 
 @tf_export("data.experimental.service.DispatchServer", v1=[])
@@ -36,10 +66,10 @@ class DispatchServer(object):
   `tf.data.experimental.service.WorkerServer`s. When the workers start, they
   register themselves with the dispatcher.
 
-  >>> dispatcher = tf.data.experimental.service.DispatchServer(port=0)
+  >>> dispatcher = tf.data.experimental.service.DispatchServer()
   >>> dispatcher_address = dispatcher.target.split("://")[1]
-  >>> worker = tf.data.experimental.service.WorkerServer(
-  ...     port=0, dispatcher_address=dispatcher_address)
+  >>> worker = tf.data.experimental.service.WorkerServer(WorkerConfig(
+  ...     dispatcher_address=dispatcher_address))
   >>> dataset = tf.data.Dataset.range(10)
   >>> dataset = dataset.apply(tf.data.experimental.service.distribute(
   ...     processing_mode="parallel_epochs", service=dispatcher.target))

@@ -50,7 +80,8 @@ class DispatchServer(object):
   indefinitely after starting up the server.
 
   ```
-  dispatcher = tf.data.experimental.service.DispatchServer(port=5050)
+  dispatcher = tf.data.experimental.service.DispatchServer(
+      tf.data.experimental.service.DispatcherConfig(port=5050))
   dispatcher.join()
   ```
@@ -59,61 +90,42 @@ class DispatchServer(object):
   ```
   dispatcher = tf.data.experimental.service.DispatchServer(
-      port=5050,
-      work_dir="gs://my-bucket/dispatcher/work_dir",
-      fault_tolerant_mode=True)
+      tf.data.experimental.service.DispatcherConfig(
+          port=5050,
+          work_dir="gs://my-bucket/dispatcher/work_dir",
+          fault_tolerant_mode=True))
   ```
   """
 
-  def __init__(self,
-               port,
-               protocol=None,
-               work_dir=None,
-               fault_tolerant_mode=None,
-               start=True):
+  def __init__(self, config=None, start=True):
     """Creates a new dispatch server.
 
     Args:
-      port: Specifies the port to bind to.
-      protocol: (Optional.) Specifies the protocol to be used by the server.
-        Acceptable values include `"grpc", "grpc+local"`. Defaults to `"grpc"`.
-      work_dir: (Optional.) A directory to store dispatcher state in. This
-        argument is required for the dispatcher to be able to recover from
-        restarts.
-      fault_tolerant_mode: (Optional.) Whether the dispatcher should write
-        its state to a journal so that it can recover from restarts. Dispatcher
-        state, including registered datasets and created jobs, is synchronously
-        written to the journal before responding to RPCs. If `True`, `work_dir`
-        must also be specified. Defaults to `False`.
+      config: (Optional.) A `tf.data.experimental.service.DispatcherConfig`
+        configuration. If `None`, the dispatcher will use default
+        configuration values.
       start: (Optional.) Boolean, indicating whether to start the server after
-        creating it. Defaults to `True`.
-
-    Raises:
-      tf.errors.OpError: Or one of its subclasses if an error occurs while
-      creating the TensorFlow server.
+        creating it.
     """
-    self._protocol = DEFAULT_PROTOCOL if protocol is None else protocol
-    self._work_dir = "" if work_dir is None else work_dir
-    self._fault_tolerant_mode = (False if fault_tolerant_mode is None else
-                                 fault_tolerant_mode)
-    if self._fault_tolerant_mode and not self._work_dir:
+    config = config or DispatcherConfig()
+    if config.fault_tolerant_mode and not config.work_dir:
       raise ValueError(
           "Cannot enable fault tolerant mode without configuring a work_dir")
-    config = service_config_pb2.DispatcherConfig(
-        port=port,
-        protocol=self._protocol,
-        work_dir=self._work_dir,
-        fault_tolerant_mode=self._fault_tolerant_mode)
+    self._config = config
+    config_proto = service_config_pb2.DispatcherConfig(
+        port=config.port,
+        protocol=config.protocol,
+        work_dir=config.work_dir,
+        fault_tolerant_mode=config.fault_tolerant_mode)
     self._server = _pywrap_server_lib.TF_DATA_NewDispatchServer(
-        config.SerializeToString())
+        config_proto.SerializeToString())
     if start:
       self._server.start()
 
   def start(self):
     """Starts this server.
 
-    >>> dispatcher = tf.data.experimental.service.DispatchServer(port=0,
-    ...                                                          start=False)
+    >>> dispatcher = tf.data.experimental.service.DispatchServer(start=False)
     >>> dispatcher.start()
 
     Raises:
@@ -128,7 +140,8 @@ class DispatchServer(object):
     This is useful when starting a dedicated dispatch process.
 
     ```
-    dispatcher = tf.data.experimental.service.DispatchServer(port=5050)
+    dispatcher = tf.data.experimental.service.DispatchServer(
+        tf.data.experimental.service.DispatcherConfig(port=5050))
     dispatcher.join()
     ```

@@ -142,7 +155,7 @@ class DispatchServer(object):
   def target(self):
     """Returns a target that can be used to connect to the server.
 
-    >>> dispatcher = tf.data.experimental.service.DispatchServer(port=0)
+    >>> dispatcher = tf.data.experimental.service.DispatchServer()
     >>> dataset = tf.data.Dataset.range(10)
     >>> dataset = dataset.apply(tf.data.experimental.service.distribute(
     ...     processing_mode="parallel_epochs", service=dispatcher.target))

@@ -150,7 +163,7 @@ class DispatchServer(object):
     The returned string will be in the form protocol://address, e.g.
     "grpc://localhost:5050".
     """
-    return "{0}://localhost:{1}".format(self._protocol,
+    return "{0}://localhost:{1}".format(self._config.protocol,
                                         self._server.bound_port())
 
   def _stop(self):
@@ -178,6 +191,35 @@ class DispatchServer(object):
     return self._server.num_workers()
 
 
+@tf_export("data.experimental.service.WorkerConfig")
+class WorkerConfig(
+    collections.namedtuple(
+        "WorkerConfig",
+        ["dispatcher_address", "worker_address", "port", "protocol"])):
+  """Configuration class for tf.data service workers.
+
+  Fields:
+    dispatcher_address: Specifies the address of the dispatcher.
+    worker_address: Specifies the address of the worker server. This address is
+      passed to the dispatcher so that the dispatcher can tell clients how to
+      connect to this worker.
+    port: Specifies the port to bind to. A value of 0 indicates that the worker
+      can bind to any available port.
+    protocol: (Optional.) Specifies the protocol to be used by the server.
+      Acceptable values include `"grpc"` and `"grpc+local"`.
+  """
+
+  def __new__(cls,
+              dispatcher_address,
+              worker_address=None,
+              port=0,
+              protocol="grpc"):
+    worker_address = ("localhost:%port%"
+                      if worker_address is None else worker_address)
+    return super(WorkerConfig, cls).__new__(cls, dispatcher_address,
+                                            worker_address, port, protocol)
+
+
 @tf_export("data.experimental.service.WorkerServer", v1=[])
 class WorkerServer(object):
   """An in-process tf.data service worker server.
@@ -187,10 +229,11 @@ class WorkerServer(object):
   RPC. A worker is associated with a single
   `tf.data.experimental.service.DispatchServer`.
 
-  >>> dispatcher = tf.data.experimental.service.DispatchServer(port=0)
+  >>> dispatcher = tf.data.experimental.service.DispatchServer()
   >>> dispatcher_address = dispatcher.target.split("://")[1]
   >>> worker = tf.data.experimental.service.WorkerServer(
-  ...     port=0, dispatcher_address=dispatcher_address)
+  ...     tf.data.experimental.service.WorkerConfig(
+  ...         dispatcher_address=dispatcher_address))
   >>> dataset = tf.data.Dataset.range(10)
   >>> dataset = dataset.apply(tf.data.experimental.service.distribute(
   ...     processing_mode="parallel_epochs", service=dispatcher.target))
@@ -207,45 +250,23 @@ class WorkerServer(object):
   ```
   """
 
-  def __init__(self,
-               port,
-               dispatcher_address,
-               worker_address=None,
-               protocol=None,
-               start=True):
+  def __init__(self, config, start=True):
     """Creates a new worker server.
 
     Args:
-      port: Specifies the port to bind to. A value of 0 indicates that the
-        worker can bind to any available port.
-      dispatcher_address: Specifies the address of the dispatcher.
-      worker_address: (Optional.) Specifies the address of the worker server.
-        This address is passed to the dispatcher so that the dispatcher can
-        tell clients how to connect to this worker. Defaults to
-        `"localhost:%port%"`, where `%port%` will be replaced with the port used
-        by the worker.
-      protocol: (Optional.) Specifies the protocol to be used by the server.
-        Acceptable values include `"grpc", "grpc+local"`. Defaults to `"grpc"`.
+      config: A `tf.data.experimental.service.WorkerConfig` configuration.
       start: (Optional.) Boolean, indicating whether to start the server after
-        creating it. Defaults to `True`.
-
-    Raises:
-      tf.errors.OpError: Or one of its subclasses if an error occurs while
-      creating the TensorFlow server.
+        creating it.
     """
-    if worker_address is None:
-      worker_address = "localhost:%port%"
-    if protocol is None:
-      protocol = "grpc"
-
-    self._protocol = protocol
-    config = service_config_pb2.WorkerConfig(
-        port=port,
-        protocol=protocol,
-        dispatcher_address=dispatcher_address,
-        worker_address=worker_address)
+    if config.dispatcher_address is None:
+      raise ValueError("must specify a dispatcher_address")
+    config_proto = service_config_pb2.WorkerConfig(
+        dispatcher_address=config.dispatcher_address,
+        worker_address=config.worker_address,
+        port=config.port,
+        protocol=config.protocol)
     self._server = _pywrap_server_lib.TF_DATA_NewWorkerServer(
-        config.SerializeToString())
+        config_proto.SerializeToString())
     if start:
       self._server.start()
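Putting the new pieces together, a minimal end-to-end sketch mirroring the updated doctests (in-process servers, OS-assigned ports):

```python
import tensorflow as tf

# DispatcherConfig defaults to port=0, i.e. bind any available port.
dispatcher = tf.data.experimental.service.DispatchServer()
dispatcher_address = dispatcher.target.split("://")[1]

# Only dispatcher_address is required; worker_address defaults to
# "localhost:%port%", where %port% is replaced with the bound port.
worker = tf.data.experimental.service.WorkerServer(
    tf.data.experimental.service.WorkerConfig(
        dispatcher_address=dispatcher_address))

dataset = tf.data.Dataset.range(10)
dataset = dataset.apply(tf.data.experimental.service.distribute(
    processing_mode="parallel_epochs", service=dispatcher.target))
print(list(dataset.as_numpy_iterator()))
```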


@@ -26,66 +26,73 @@ from tensorflow.python.platform import test
 class ServerLibTest(test.TestCase):
 
   def testStartDispatcher(self):
-    dispatcher = server_lib.DispatchServer(0, start=False)
+    dispatcher = server_lib.DispatchServer(start=False)
     dispatcher.start()
 
   def testMultipleStartDispatcher(self):
-    dispatcher = server_lib.DispatchServer(0, start=True)
+    dispatcher = server_lib.DispatchServer(start=True)
     dispatcher.start()
 
   def testStartWorker(self):
-    dispatcher = server_lib.DispatchServer(0)
-    worker = server_lib.WorkerServer(0, dispatcher._address, start=False)
+    dispatcher = server_lib.DispatchServer()
+    worker = server_lib.WorkerServer(
+        server_lib.WorkerConfig(dispatcher._address), start=False)
     worker.start()
 
   def testMultipleStartWorker(self):
-    dispatcher = server_lib.DispatchServer(0)
-    worker = server_lib.WorkerServer(0, dispatcher._address, start=True)
+    dispatcher = server_lib.DispatchServer()
+    worker = server_lib.WorkerServer(
+        server_lib.WorkerConfig(dispatcher._address), start=True)
     worker.start()
 
   def testStopDispatcher(self):
-    dispatcher = server_lib.DispatchServer(0)
+    dispatcher = server_lib.DispatchServer()
    dispatcher._stop()
     dispatcher._stop()
 
   def testStopWorker(self):
-    dispatcher = server_lib.DispatchServer(0)
-    worker = server_lib.WorkerServer(0, dispatcher._address)
+    dispatcher = server_lib.DispatchServer()
+    worker = server_lib.WorkerServer(
+        server_lib.WorkerConfig(dispatcher._address))
     worker._stop()
     worker._stop()
 
   def testStopStartDispatcher(self):
-    dispatcher = server_lib.DispatchServer(0)
+    dispatcher = server_lib.DispatchServer()
     dispatcher._stop()
     with self.assertRaisesRegex(
         RuntimeError, "Server cannot be started after it has been stopped"):
       dispatcher.start()
 
   def testStopStartWorker(self):
-    dispatcher = server_lib.DispatchServer(0)
-    worker = server_lib.WorkerServer(0, dispatcher._address)
+    dispatcher = server_lib.DispatchServer()
+    worker = server_lib.WorkerServer(
+        server_lib.WorkerConfig(dispatcher._address))
     worker._stop()
     with self.assertRaisesRegex(
         RuntimeError, "Server cannot be started after it has been stopped"):
       worker.start()
 
   def testJoinDispatcher(self):
-    dispatcher = server_lib.DispatchServer(0)
+    dispatcher = server_lib.DispatchServer()
     dispatcher._stop()
     dispatcher.join()
 
   def testJoinWorker(self):
-    dispatcher = server_lib.DispatchServer(0)
-    worker = server_lib.WorkerServer(0, dispatcher._address)
+    dispatcher = server_lib.DispatchServer()
+    worker = server_lib.WorkerServer(
+        server_lib.WorkerConfig(dispatcher._address))
     worker._stop()
     worker.join()
 
   def testDispatcherNumWorkers(self):
-    dispatcher = server_lib.DispatchServer(0)
+    dispatcher = server_lib.DispatchServer()
     self.assertEqual(0, dispatcher._num_workers())
-    worker1 = server_lib.WorkerServer(0, dispatcher._address)  # pylint: disable=unused-variable
+    worker1 = server_lib.WorkerServer(  # pylint: disable=unused-variable
+        server_lib.WorkerConfig(dispatcher._address))
     self.assertEqual(1, dispatcher._num_workers())
-    worker2 = server_lib.WorkerServer(0, dispatcher._address)  # pylint: disable=unused-variable
+    worker2 = server_lib.WorkerServer(  # pylint: disable=unused-variable
+        server_lib.WorkerConfig(dispatcher._address))
     self.assertEqual(2, dispatcher._num_workers())


@@ -107,16 +107,16 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
     work_dir = os.path.join(self.get_temp_dir(), "work_dir_",
                             name) if work_dir is None else work_dir
     return server_lib.DispatchServer(
-        port=port,
-        protocol=server_lib.DEFAULT_PROTOCOL,
-        work_dir=work_dir,
-        fault_tolerant_mode=fault_tolerant_mode)
+        server_lib.DispatcherConfig(
+            port=port,
+            work_dir=work_dir,
+            fault_tolerant_mode=fault_tolerant_mode))
 
   def start_worker_server(self, dispatcher, port=0):
     return server_lib.WorkerServer(
-        port=port,
-        dispatcher_address=_address_from_target(dispatcher.target),
-        protocol=server_lib.DEFAULT_PROTOCOL)
+        server_lib.WorkerConfig(
+            dispatcher_address=_address_from_target(dispatcher.target),
+            port=port))
 
   def restart_dispatcher(self, dispatcher):
     """Stops `dispatcher` and returns a new dispatcher with the same port."""

@@ -124,8 +124,8 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
     dispatcher._stop()
     return self.start_dispatch_server(
         port=port,
-        work_dir=dispatcher._work_dir,
-        fault_tolerant_mode=dispatcher._fault_tolerant_mode)
+        work_dir=dispatcher._config.work_dir,
+        fault_tolerant_mode=dispatcher._config.fault_tolerant_mode)
 
   def restart_worker(self, worker, dispatcher, use_same_port=True):
     """Stops `worker` and returns a new worker."""

@@ -362,10 +362,11 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
     except:
       raise self.skipTest("Flakes in portpicker library do not represent "
                           "TensorFlow errors.")
-    dispatcher = server_lib.DispatchServer(port=dispatcher_port, start=False)
+    dispatcher = server_lib.DispatchServer(
+        server_lib.DispatcherConfig(port=dispatcher_port), start=False)
     worker = server_lib.WorkerServer(
-        port=0,
-        dispatcher_address=_address_from_target(dispatcher.target),
+        server_lib.WorkerConfig(
+            dispatcher_address=_address_from_target(dispatcher.target), port=0),
         start=False)
 
     def start_servers():


@@ -0,0 +1,31 @@
+path: "tensorflow.data.experimental.service.DispatcherConfig"
+tf_class {
+  is_instance: "<class \'tensorflow.python.data.experimental.service.server_lib.DispatcherConfig\'>"
+  is_instance: "<class \'tensorflow.python.data.experimental.service.server_lib.DispatcherConfig\'>"
+  is_instance: "<type \'tuple\'>"
+  member {
+    name: "fault_tolerant_mode"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "port"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "protocol"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "work_dir"
+    mtype: "<type \'property\'>"
+  }
+  member_method {
+    name: "__init__"
+  }
+  member_method {
+    name: "count"
+  }
+  member_method {
+    name: "index"
+  }
+}


@@ -0,0 +1,31 @@
+path: "tensorflow.data.experimental.service.WorkerConfig"
+tf_class {
+  is_instance: "<class \'tensorflow.python.data.experimental.service.server_lib.WorkerConfig\'>"
+  is_instance: "<class \'tensorflow.python.data.experimental.service.server_lib.WorkerConfig\'>"
+  is_instance: "<type \'tuple\'>"
+  member {
+    name: "dispatcher_address"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "port"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "protocol"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "worker_address"
+    mtype: "<type \'property\'>"
+  }
+  member_method {
+    name: "__init__"
+  }
+  member_method {
+    name: "count"
+  }
+  member_method {
+    name: "index"
+  }
+}


@@ -1,5 +1,13 @@
 path: "tensorflow.data.experimental.service"
 tf_module {
+  member {
+    name: "DispatcherConfig"
+    mtype: "<type \'type\'>"
+  }
+  member {
+    name: "WorkerConfig"
+    mtype: "<type \'type\'>"
+  }
   member_method {
     name: "distribute"
     argspec: "args=[\'processing_mode\', \'service\', \'job_name\', \'max_outstanding_requests\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "


@@ -8,7 +8,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'port\', \'protocol\', \'work_dir\', \'fault_tolerant_mode\', \'start\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], "
+    argspec: "args=[\'self\', \'config\', \'start\'], varargs=None, keywords=None, defaults=[\'None\', \'True\'], "
   }
   member_method {
     name: "join"


@@ -0,0 +1,31 @@
+path: "tensorflow.data.experimental.service.DispatcherConfig"
+tf_class {
+  is_instance: "<class \'tensorflow.python.data.experimental.service.server_lib.DispatcherConfig\'>"
+  is_instance: "<class \'tensorflow.python.data.experimental.service.server_lib.DispatcherConfig\'>"
+  is_instance: "<type \'tuple\'>"
+  member {
+    name: "fault_tolerant_mode"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "port"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "protocol"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "work_dir"
+    mtype: "<type \'property\'>"
+  }
+  member_method {
+    name: "__init__"
+  }
+  member_method {
+    name: "count"
+  }
+  member_method {
+    name: "index"
+  }
+}


@@ -0,0 +1,31 @@
+path: "tensorflow.data.experimental.service.WorkerConfig"
+tf_class {
+  is_instance: "<class \'tensorflow.python.data.experimental.service.server_lib.WorkerConfig\'>"
+  is_instance: "<class \'tensorflow.python.data.experimental.service.server_lib.WorkerConfig\'>"
+  is_instance: "<type \'tuple\'>"
+  member {
+    name: "dispatcher_address"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "port"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "protocol"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "worker_address"
+    mtype: "<type \'property\'>"
+  }
+  member_method {
+    name: "__init__"
+  }
+  member_method {
+    name: "count"
+  }
+  member_method {
+    name: "index"
+  }
+}


@@ -4,7 +4,7 @@ tf_class {
   is_instance: "<type \'object\'>"
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'port\', \'dispatcher_address\', \'worker_address\', \'protocol\', \'start\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\'], "
+    argspec: "args=[\'self\', \'config\', \'start\'], varargs=None, keywords=None, defaults=[\'True\'], "
   }
   member_method {
     name: "join"


@@ -4,6 +4,14 @@ tf_module {
     name: "DispatchServer"
     mtype: "<type \'type\'>"
   }
+  member {
+    name: "DispatcherConfig"
+    mtype: "<type \'type\'>"
+  }
+  member {
+    name: "WorkerConfig"
+    mtype: "<type \'type\'>"
+  }
   member {
     name: "WorkerServer"
     mtype: "<type \'type\'>"