Construct data service servers from config proto.

This will allow us to add new options without changing the constructor
signature. This breaks backwards compatibility of the experimental
`DispatchServer` and `WorkerServer` APIs. To migrate, update symbols as
follows:

  `tf.data.experimental.service.DispatchServer` becomes
  `tf.data.experimental.service.create_dispatcher`.
  `tf.data.experimental.service.WorkerServer(...)` becomes
  `tf.data.experimental.service.create_worker(...)`.

The parameter order has changed now that `port` is optional.

This CL also updates argument defaults to follow the recommendation in
https://github.com/tensorflow/community/pull/250/files to prefer concrete
default values over `None` when the values can't be changed in a
backwards-compatible way.

PiperOrigin-RevId: 329786653
Change-Id: I0d6483e2f18a73c80cfd68f9dd44f403c976d140
parent 6128d6cee1
commit b129d3dc1c
@@ -40,6 +40,12 @@
 * `tf.raw_ops.Max` and `tf.raw_ops.Min` no longer accept inputs of type
   `tf.complex64` or `tf.complex128`, because the behavior of these ops is not
   well defined for complex types.
+* `tf.data.experimental.service.DispatchServer` now takes a config tuple
+  instead of individual arguments. Usages should be updated to
+  `tf.data.experimental.service.DispatchServer(dispatcher_config)`.
+* `tf.data.experimental.service.WorkerServer` now takes a config tuple
+  instead of individual arguments. Usages should be updated to
+  `tf.data.experimental.service.WorkerServer(worker_config)`.

 ## Known Caveats

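For readers migrating call sites, here is a minimal before/after sketch of the constructor change described in the release note above. It assumes the `DispatcherConfig` and `WorkerConfig` classes exported by this commit; the port number and addresses are illustrative only.

```python
import tensorflow as tf

# Old style (pre-change): individual keyword arguments.
# dispatcher = tf.data.experimental.service.DispatchServer(port=5050)
# worker = tf.data.experimental.service.WorkerServer(
#     port=0, dispatcher_address="localhost:5050")

# New style: a single config namedtuple per server.
dispatcher = tf.data.experimental.service.DispatchServer(
    tf.data.experimental.service.DispatcherConfig(port=5050))
worker = tf.data.experimental.service.WorkerServer(
    tf.data.experimental.service.WorkerConfig(
        dispatcher_address="localhost:5050"))
```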
@@ -9,11 +9,11 @@ message DispatcherConfig {
   int64 port = 1;
   // The protocol for the dispatcher to use when connecting to workers.
   string protocol = 2;
-  // An optional work directory to use for storing dispatcher state, and for
-  // recovering during restarts.
+  // A work directory to use for storing dispatcher state, and for recovering
+  // during restarts. The empty string indicates not to use any work directory.
   string work_dir = 3;
   // Whether to run in fault tolerant mode, where dispatcher state is saved
-  // across restarts.
+  // across restarts. Requires that `work_dir` is nonempty.
   bool fault_tolerant_mode = 4;
 }

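The updated proto comments spell out the coupling between `fault_tolerant_mode` and `work_dir`. A hedged sketch of how that constraint surfaces in the Python API (the error message mirrors the check added to `DispatchServer.__init__` later in this diff; the path is illustrative):

```python
import tensorflow as tf

# fault_tolerant_mode without a work_dir is rejected at construction time.
try:
  tf.data.experimental.service.DispatchServer(
      tf.data.experimental.service.DispatcherConfig(fault_tolerant_mode=True))
except ValueError as e:
  print(e)  # Cannot enable fault tolerant mode without configuring a work_dir

# Supplying a work directory makes the same configuration valid.
dispatcher = tf.data.experimental.service.DispatchServer(
    tf.data.experimental.service.DispatcherConfig(
        work_dir="/tmp/dispatcher_state", fault_tolerant_mode=True))
```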
@@ -435,10 +435,11 @@ def register_dataset(service, dataset):
   If the dataset is already registered with the tf.data service,
   `register_dataset` returns the already-registered dataset's id.

-  >>> dispatcher = tf.data.experimental.service.DispatchServer(port=0)
+  >>> dispatcher = tf.data.experimental.service.DispatchServer()
   >>> dispatcher_address = dispatcher.target.split("://")[1]
   >>> worker = tf.data.experimental.service.WorkerServer(
-  ...     port=0, dispatcher_address=dispatcher_address)
+  ...     tf.data.experimental.service.WorkerConfig(
+  ...         dispatcher_address=dispatcher_address))
   >>> dataset = tf.data.Dataset.range(10)
   >>> dataset_id = tf.data.experimental.service.register_dataset(
   ...     dispatcher.target, dataset)
@@ -518,10 +519,11 @@ def from_dataset_id(processing_mode,
   See the documentation for `tf.data.experimental.service.distribute` for more
   detail about how `from_dataset_id` works.

-  >>> dispatcher = tf.data.experimental.service.DispatchServer(port=0)
+  >>> dispatcher = tf.data.experimental.service.DispatchServer()
   >>> dispatcher_address = dispatcher.target.split("://")[1]
   >>> worker = tf.data.experimental.service.WorkerServer(
-  ...     port=0, dispatcher_address=dispatcher_address)
+  ...     tf.data.experimental.service.WorkerConfig(
+  ...         dispatcher_address=dispatcher_address))
   >>> dataset = tf.data.Dataset.range(10)
   >>> dataset_id = tf.data.experimental.service.register_dataset(
   ...     dispatcher.target, dataset)
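Putting the two updated docstring snippets together, a hedged end-to-end sketch of registering a dataset and reading it back by id with the config-based servers. The `element_spec` argument to `from_dataset_id` reflects the public signature of this era and is an assumption here, not part of this diff.

```python
import tensorflow as tf

dispatcher = tf.data.experimental.service.DispatchServer()
dispatcher_address = dispatcher.target.split("://")[1]
worker = tf.data.experimental.service.WorkerServer(
    tf.data.experimental.service.WorkerConfig(
        dispatcher_address=dispatcher_address))

dataset = tf.data.Dataset.range(10)
dataset_id = tf.data.experimental.service.register_dataset(
    dispatcher.target, dataset)

# Consume the registered dataset by id from any client that can reach the
# dispatcher. element_spec tells the client what the elements look like.
result = tf.data.experimental.service.from_dataset_id(
    processing_mode="parallel_epochs",
    service=dispatcher.target,
    dataset_id=dataset_id,
    element_spec=dataset.element_spec)
print(sorted(result.as_numpy_iterator()))  # [0, 1, ..., 9]
```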
@@ -107,10 +107,11 @@ dataset = dataset.apply(tf.data.experimental.service.distribute(

 Below is a toy example that you can run yourself.

->>> dispatcher = tf.data.experimental.service.DispatchServer(port=0)
+>>> dispatcher = tf.data.experimental.service.DispatchServer()
 >>> dispatcher_address = dispatcher.target.split("://")[1]
 >>> worker = tf.data.experimental.service.WorkerServer(
-...     port=0, dispatcher_address=dispatcher_address)
+...     tf.data.experimental.service.WorkerConfig(
+...         dispatcher_address=dispatcher_address))
 >>> dataset = tf.data.Dataset.range(10)
 >>> dataset = dataset.apply(tf.data.experimental.service.distribute(
 ...     processing_mode="parallel_epochs", service=dispatcher.target))
@@ -128,5 +129,7 @@ from __future__ import print_function
 from tensorflow.python.data.experimental.ops.data_service_ops import distribute
 from tensorflow.python.data.experimental.ops.data_service_ops import from_dataset_id
 from tensorflow.python.data.experimental.ops.data_service_ops import register_dataset
+from tensorflow.python.data.experimental.service.server_lib import DispatcherConfig
 from tensorflow.python.data.experimental.service.server_lib import DispatchServer
+from tensorflow.python.data.experimental.service.server_lib import WorkerConfig
 from tensorflow.python.data.experimental.service.server_lib import WorkerServer
@@ -18,6 +18,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import collections
+
 # pylint: disable=invalid-import-order,g-bad-import-order, unused-import
 from tensorflow.core.protobuf.data.experimental import service_config_pb2
 from tensorflow.python import pywrap_tensorflow
@@ -25,7 +27,35 @@ from tensorflow.python.data.experimental.service import _pywrap_server_lib
 from tensorflow.python.util.tf_export import tf_export


-DEFAULT_PROTOCOL = "grpc"
+@tf_export("data.experimental.service.DispatcherConfig")
+class DispatcherConfig(
+    collections.namedtuple(
+        "DispatcherConfig",
+        ["port", "protocol", "work_dir", "fault_tolerant_mode"])):
+  """Configuration class for tf.data service dispatchers.
+
+  Fields:
+    port: Specifies the port to bind to. A value of 0 indicates that the server
+      may bind to any available port.
+    protocol: The protocol to use for communicating with the tf.data service.
+      Acceptable values include `"grpc" and "grpc+local"`.
+    work_dir: A directory to store dispatcher state in. This
+      argument is required for the dispatcher to be able to recover from
+      restarts.
+    fault_tolerant_mode: Whether the dispatcher should write its state to a
+      journal so that it can recover from restarts. Dispatcher state, including
+      registered datasets and created jobs, is synchronously written to the
+      journal before responding to RPCs. If `True`, `work_dir` must also be
+      specified.
+  """
+
+  def __new__(cls,
+              port=0,
+              protocol="grpc",
+              work_dir=None,
+              fault_tolerant_mode=False):
+    return super(DispatcherConfig, cls).__new__(cls, port, protocol, work_dir,
+                                                fault_tolerant_mode)


 @tf_export("data.experimental.service.DispatchServer", v1=[])
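Since `DispatcherConfig` is a plain `namedtuple` subclass whose defaults are supplied in `__new__`, it behaves like any namedtuple. A small sketch of what that buys callers; field access and `_replace` are standard namedtuple features, not new API added by this commit, and the path and port below are illustrative.

```python
import tensorflow as tf

config = tf.data.experimental.service.DispatcherConfig()
print(config.port, config.protocol, config.fault_tolerant_mode)
# 0 grpc False

# Namedtuples are immutable; derive a tweaked copy instead of mutating.
prod_config = config._replace(
    port=5050, work_dir="/tmp/dispatcher_state", fault_tolerant_mode=True)
dispatcher = tf.data.experimental.service.DispatchServer(prod_config)
```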
@@ -36,10 +66,10 @@ class DispatchServer(object):
   `tf.data.experimental.service.WorkerServer`s. When the workers start, they
   register themselves with the dispatcher.

-  >>> dispatcher = tf.data.experimental.service.DispatchServer(port=0)
+  >>> dispatcher = tf.data.experimental.service.DispatchServer()
   >>> dispatcher_address = dispatcher.target.split("://")[1]
-  >>> worker = tf.data.experimental.service.WorkerServer(
-  ...     port=0, dispatcher_address=dispatcher_address)
+  >>> worker = tf.data.experimental.service.WorkerServer(WorkerConfig(
+  ...     dispatcher_address=dispatcher_address))
   >>> dataset = tf.data.Dataset.range(10)
   >>> dataset = dataset.apply(tf.data.experimental.service.distribute(
   ...     processing_mode="parallel_epochs", service=dispatcher.target))
@@ -50,7 +80,8 @@ class DispatchServer(object):
   indefinitely after starting up the server.

   ```
-  dispatcher = tf.data.experimental.service.DispatchServer(port=5050)
+  dispatcher = tf.data.experimental.service.DispatchServer(
+      tf.data.experimental.service.DispatcherConfig(port=5050))
   dispatcher.join()
   ```

@@ -59,61 +90,42 @@ class DispatchServer(object):

   ```
   dispatcher = tf.data.experimental.service.DispatchServer(
-      port=5050,
-      work_dir="gs://my-bucket/dispatcher/work_dir",
-      fault_tolerant_mode=True)
+      tf.data.experimental.service.DispatcherConfig(
+          port=5050,
+          work_dir="gs://my-bucket/dispatcher/work_dir",
+          fault_tolerant_mode=True))
   ```
   """

-  def __init__(self,
-               port,
-               protocol=None,
-               work_dir=None,
-               fault_tolerant_mode=None,
-               start=True):
+  def __init__(self, config=None, start=True):
     """Creates a new dispatch server.

     Args:
-      port: Specifies the port to bind to.
-      protocol: (Optional.) Specifies the protocol to be used by the server.
-        Acceptable values include `"grpc", "grpc+local"`. Defaults to `"grpc"`.
-      work_dir: (Optional.) A directory to store dispatcher state in. This
-        argument is required for the dispatcher to be able to recover from
-        restarts.
-      fault_tolerant_mode: (Optional.) Whether the dispatcher should write
-        its state to a journal so that it can recover from restarts. Dispatcher
-        state, including registered datasets and created jobs, is synchronously
-        written to the journal before responding to RPCs. If `True`, `work_dir`
-        must also be specified. Defaults to `False`.
+      config: (Optional.) A `tf.data.experimental.service.DispatcherConfig`
+        configuration. If `None`, the dispatcher will use default
+        configuration values.
       start: (Optional.) Boolean, indicating whether to start the server after
-        creating it. Defaults to `True`.
+        creating it.

-    Raises:
-      tf.errors.OpError: Or one of its subclasses if an error occurs while
-        creating the TensorFlow server.
     """
-    self._protocol = DEFAULT_PROTOCOL if protocol is None else protocol
-    self._work_dir = "" if work_dir is None else work_dir
-    self._fault_tolerant_mode = (False if fault_tolerant_mode is None else
-                                 fault_tolerant_mode)
-    if self._fault_tolerant_mode and not self._work_dir:
+    config = config or DispatcherConfig()
+    if config.fault_tolerant_mode and not config.work_dir:
       raise ValueError(
           "Cannot enable fault tolerant mode without configuring a work_dir")
-    config = service_config_pb2.DispatcherConfig(
-        port=port,
-        protocol=self._protocol,
-        work_dir=self._work_dir,
-        fault_tolerant_mode=self._fault_tolerant_mode)
+    self._config = config
+    config_proto = service_config_pb2.DispatcherConfig(
+        port=config.port,
+        protocol=config.protocol,
+        work_dir=config.work_dir,
+        fault_tolerant_mode=config.fault_tolerant_mode)
     self._server = _pywrap_server_lib.TF_DATA_NewDispatchServer(
-        config.SerializeToString())
+        config_proto.SerializeToString())
     if start:
       self._server.start()

   def start(self):
     """Starts this server.

-    >>> dispatcher = tf.data.experimental.service.DispatchServer(port=0,
-    ...                                                          start=False)
+    >>> dispatcher = tf.data.experimental.service.DispatchServer(start=False)
     >>> dispatcher.start()

     Raises:
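A hedged sketch of the deferred-start path that the rewritten constructor and the updated `start()` docstring describe; the printed target format follows the `target` property shown further down, and the bound port shown in the comment is just an example.

```python
import tensorflow as tf

# Create the server without starting it, e.g. to finish other setup first.
dispatcher = tf.data.experimental.service.DispatchServer(
    tf.data.experimental.service.DispatcherConfig(port=0), start=False)

dispatcher.start()
# With port=0 the server binds to a free port; `target` reports the bound one.
print(dispatcher.target)  # e.g. "grpc://localhost:41605"
```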
@@ -128,7 +140,8 @@ class DispatchServer(object):
     This is useful when starting a dedicated dispatch process.

     ```
-    dispatcher = tf.data.experimental.service.DispatchServer(port=5050)
+    dispatcher = tf.data.experimental.service.DispatchServer(
+        tf.data.experimental.service.DispatcherConfig(port=5050))
     dispatcher.join()
     ```

@@ -142,7 +155,7 @@ class DispatchServer(object):
   def target(self):
     """Returns a target that can be used to connect to the server.

-    >>> dispatcher = tf.data.experimental.service.DispatchServer(port=0)
+    >>> dispatcher = tf.data.experimental.service.DispatchServer()
     >>> dataset = tf.data.Dataset.range(10)
     >>> dataset = dataset.apply(tf.data.experimental.service.distribute(
     ...     processing_mode="parallel_epochs", service=dispatcher.target))
@@ -150,7 +163,7 @@ class DispatchServer(object):
     The returned string will be in the form protocol://address, e.g.
     "grpc://localhost:5050".
     """
-    return "{0}://localhost:{1}".format(self._protocol,
+    return "{0}://localhost:{1}".format(self._config.protocol,
                                         self._server.bound_port())

   def _stop(self):
@@ -178,6 +191,35 @@ class DispatchServer(object):
     return self._server.num_workers()


+@tf_export("data.experimental.service.WorkerConfig")
+class WorkerConfig(
+    collections.namedtuple(
+        "WorkerConfig",
+        ["dispatcher_address", "worker_address", "port", "protocol"])):
+  """Configuration class for tf.data service workers.
+
+  Fields:
+    dispatcher_address: Specifies the address of the dispatcher.
+    worker_address: Specifies the address of the worker server. This address is
+      passed to the dispatcher so that the dispatcher can tell clients how to
+      connect to this worker.
+    port: Specifies the port to bind to. A value of 0 indicates that the worker
+      can bind to any available port.
+    protocol: (Optional.) Specifies the protocol to be used by the server.
+      Acceptable values include `"grpc" and "grpc+local"`.
+  """
+
+  def __new__(cls,
+              dispatcher_address,
+              worker_address=None,
+              port=0,
+              protocol="grpc"):
+    worker_address = ("localhost:%port%"
+                      if worker_address is None else worker_address)
+    return super(WorkerConfig, cls).__new__(cls, dispatcher_address,
+                                            worker_address, port, protocol)
+
+
 @tf_export("data.experimental.service.WorkerServer", v1=[])
 class WorkerServer(object):
   """An in-process tf.data service worker server.
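`WorkerConfig.__new__` substitutes the `"localhost:%port%"` placeholder when no explicit `worker_address` is given, and the worker fills in the bound port at runtime. A small sketch of the two ways to construct it; the hostnames and ports below are placeholders, not values from this commit.

```python
import tensorflow as tf

# Default: advertise "localhost:%port%", where %port% is replaced by the
# worker with whatever port it ends up binding.
config = tf.data.experimental.service.WorkerConfig(
    dispatcher_address="localhost:5050")
print(config.worker_address)          # localhost:%port%
print(config.port, config.protocol)   # 0 grpc

# Explicit: advertise a routable address and pin the port, e.g. when workers
# run on separate hosts.
remote_config = tf.data.experimental.service.WorkerConfig(
    dispatcher_address="dispatcher.example.com:5050",
    worker_address="worker0.example.com:6000",
    port=6000)
```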
@@ -187,10 +229,11 @@ class WorkerServer(object):
   RPC. A worker is associated with a single
   `tf.data.experimental.service.DispatchServer`.

-  >>> dispatcher = tf.data.experimental.service.DispatchServer(port=0)
+  >>> dispatcher = tf.data.experimental.service.DispatchServer()
   >>> dispatcher_address = dispatcher.target.split("://")[1]
   >>> worker = tf.data.experimental.service.WorkerServer(
-  ...     port=0, dispatcher_address=dispatcher_address)
+  ...     tf.data.experimental.service.WorkerConfig(
+  ...         dispatcher_address=dispatcher_address))
   >>> dataset = tf.data.Dataset.range(10)
   >>> dataset = dataset.apply(tf.data.experimental.service.distribute(
   ...     processing_mode="parallel_epochs", service=dispatcher.target))
@@ -207,45 +250,23 @@ class WorkerServer(object):
   ```
   """

-  def __init__(self,
-               port,
-               dispatcher_address,
-               worker_address=None,
-               protocol=None,
-               start=True):
+  def __init__(self, config, start=True):
     """Creates a new worker server.

     Args:
-      port: Specifies the port to bind to. A value of 0 indicates that the
-        worker can bind to any available port.
-      dispatcher_address: Specifies the address of the dispatcher.
-      worker_address: (Optional.) Specifies the address of the worker server.
-        This address is passed to the dispatcher so that the dispatcher can
-        tell clients how to connect to this worker. Defaults to
-        `"localhost:%port%"`, where `%port%` will be replaced with the port used
-        by the worker.
-      protocol: (Optional.) Specifies the protocol to be used by the server.
-        Acceptable values include `"grpc", "grpc+local"`. Defaults to `"grpc"`.
+      config: A `tf.data.experimental.service.WorkerConfig` configuration.
       start: (Optional.) Boolean, indicating whether to start the server after
-        creating it. Defaults to `True`.
-
-    Raises:
-      tf.errors.OpError: Or one of its subclasses if an error occurs while
-        creating the TensorFlow server.
+        creating it.
     """
-    if worker_address is None:
-      worker_address = "localhost:%port%"
-    if protocol is None:
-      protocol = "grpc"
-
-    self._protocol = protocol
-    config = service_config_pb2.WorkerConfig(
-        port=port,
-        protocol=protocol,
-        dispatcher_address=dispatcher_address,
-        worker_address=worker_address)
+    if config.dispatcher_address is None:
+      raise ValueError("must specify a dispatcher_address")
+    config_proto = service_config_pb2.WorkerConfig(
+        dispatcher_address=config.dispatcher_address,
+        worker_address=config.worker_address,
+        port=config.port,
+        protocol=config.protocol)
     self._server = _pywrap_server_lib.TF_DATA_NewWorkerServer(
-        config.SerializeToString())
+        config_proto.SerializeToString())
     if start:
       self._server.start()

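A hedged sketch of the new `WorkerServer` contract: the config argument is now required, and a missing `dispatcher_address` is rejected up front rather than failing later inside the server.

```python
import tensorflow as tf

dispatcher = tf.data.experimental.service.DispatchServer()
dispatcher_address = dispatcher.target.split("://")[1]

# Required positional config; the worker starts immediately by default.
worker = tf.data.experimental.service.WorkerServer(
    tf.data.experimental.service.WorkerConfig(
        dispatcher_address=dispatcher_address))

# A config without a dispatcher address is rejected at construction time.
try:
  tf.data.experimental.service.WorkerServer(
      tf.data.experimental.service.WorkerConfig(dispatcher_address=None))
except ValueError as e:
  print(e)  # must specify a dispatcher_address
```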
@@ -26,66 +26,73 @@ from tensorflow.python.platform import test
 class ServerLibTest(test.TestCase):

   def testStartDispatcher(self):
-    dispatcher = server_lib.DispatchServer(0, start=False)
+    dispatcher = server_lib.DispatchServer(start=False)
     dispatcher.start()

   def testMultipleStartDispatcher(self):
-    dispatcher = server_lib.DispatchServer(0, start=True)
+    dispatcher = server_lib.DispatchServer(start=True)
     dispatcher.start()

   def testStartWorker(self):
-    dispatcher = server_lib.DispatchServer(0)
-    worker = server_lib.WorkerServer(0, dispatcher._address, start=False)
+    dispatcher = server_lib.DispatchServer()
+    worker = server_lib.WorkerServer(
+        server_lib.WorkerConfig(dispatcher._address), start=False)
     worker.start()

   def testMultipleStartWorker(self):
-    dispatcher = server_lib.DispatchServer(0)
-    worker = server_lib.WorkerServer(0, dispatcher._address, start=True)
+    dispatcher = server_lib.DispatchServer()
+    worker = server_lib.WorkerServer(
+        server_lib.WorkerConfig(dispatcher._address), start=True)
     worker.start()

   def testStopDispatcher(self):
-    dispatcher = server_lib.DispatchServer(0)
+    dispatcher = server_lib.DispatchServer()
     dispatcher._stop()
     dispatcher._stop()

   def testStopWorker(self):
-    dispatcher = server_lib.DispatchServer(0)
-    worker = server_lib.WorkerServer(0, dispatcher._address)
+    dispatcher = server_lib.DispatchServer()
+    worker = server_lib.WorkerServer(
+        server_lib.WorkerConfig(dispatcher._address))
     worker._stop()
     worker._stop()

   def testStopStartDispatcher(self):
-    dispatcher = server_lib.DispatchServer(0)
+    dispatcher = server_lib.DispatchServer()
     dispatcher._stop()
     with self.assertRaisesRegex(
         RuntimeError, "Server cannot be started after it has been stopped"):
       dispatcher.start()

   def testStopStartWorker(self):
-    dispatcher = server_lib.DispatchServer(0)
-    worker = server_lib.WorkerServer(0, dispatcher._address)
+    dispatcher = server_lib.DispatchServer()
+    worker = server_lib.WorkerServer(
+        server_lib.WorkerConfig(dispatcher._address))
     worker._stop()
     with self.assertRaisesRegex(
         RuntimeError, "Server cannot be started after it has been stopped"):
       worker.start()

   def testJoinDispatcher(self):
-    dispatcher = server_lib.DispatchServer(0)
+    dispatcher = server_lib.DispatchServer()
     dispatcher._stop()
     dispatcher.join()

   def testJoinWorker(self):
-    dispatcher = server_lib.DispatchServer(0)
-    worker = server_lib.WorkerServer(0, dispatcher._address)
+    dispatcher = server_lib.DispatchServer()
+    worker = server_lib.WorkerServer(
+        server_lib.WorkerConfig(dispatcher._address))
     worker._stop()
     worker.join()

   def testDispatcherNumWorkers(self):
-    dispatcher = server_lib.DispatchServer(0)
+    dispatcher = server_lib.DispatchServer()
     self.assertEqual(0, dispatcher._num_workers())
-    worker1 = server_lib.WorkerServer(0, dispatcher._address)  # pylint: disable=unused-variable
+    worker1 = server_lib.WorkerServer(  # pylint: disable=unused-variable
+        server_lib.WorkerConfig(dispatcher._address))
     self.assertEqual(1, dispatcher._num_workers())
-    worker2 = server_lib.WorkerServer(0, dispatcher._address)  # pylint: disable=unused-variable
+    worker2 = server_lib.WorkerServer(  # pylint: disable=unused-variable
+        server_lib.WorkerConfig(dispatcher._address))
     self.assertEqual(2, dispatcher._num_workers())


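A hypothetical companion test, not part of this commit, showing how the new `work_dir` requirement could be exercised in the same style as `ServerLibTest`; the test class and method names below are invented for illustration.

```python
from tensorflow.python.data.experimental.service import server_lib
from tensorflow.python.platform import test


class ServerLibFaultToleranceTest(test.TestCase):
  """Hypothetical companion test, not part of this commit."""

  def testFaultTolerantModeRequiresWorkDir(self):
    # Matches the ValueError raised in DispatchServer.__init__ above.
    with self.assertRaisesRegex(
        ValueError, "Cannot enable fault tolerant mode"):
      server_lib.DispatchServer(
          server_lib.DispatcherConfig(fault_tolerant_mode=True))


if __name__ == "__main__":
  test.main()
```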
@@ -107,16 +107,16 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
     work_dir = os.path.join(self.get_temp_dir(), "work_dir_",
                             name) if work_dir is None else work_dir
     return server_lib.DispatchServer(
-        port=port,
-        protocol=server_lib.DEFAULT_PROTOCOL,
-        work_dir=work_dir,
-        fault_tolerant_mode=fault_tolerant_mode)
+        server_lib.DispatcherConfig(
+            port=port,
+            work_dir=work_dir,
+            fault_tolerant_mode=fault_tolerant_mode))

   def start_worker_server(self, dispatcher, port=0):
     return server_lib.WorkerServer(
-        port=port,
-        dispatcher_address=_address_from_target(dispatcher.target),
-        protocol=server_lib.DEFAULT_PROTOCOL)
+        server_lib.WorkerConfig(
+            dispatcher_address=_address_from_target(dispatcher.target),
+            port=port))

   def restart_dispatcher(self, dispatcher):
     """Stops `dispatcher` and returns a new dispatcher with the same port."""
@@ -124,8 +124,8 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
     dispatcher._stop()
     return self.start_dispatch_server(
         port=port,
-        work_dir=dispatcher._work_dir,
-        fault_tolerant_mode=dispatcher._fault_tolerant_mode)
+        work_dir=dispatcher._config.work_dir,
+        fault_tolerant_mode=dispatcher._config.fault_tolerant_mode)

   def restart_worker(self, worker, dispatcher, use_same_port=True):
     """Stops `worker` and returns a new worker."""
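The restart helpers above lean on the dispatcher keeping its state in `work_dir`. A hedged sketch of the same recovery idea through the public API: a dedicated dispatcher process is always built from one restart-safe config, so a restarted process recovers the journaled state instead of starting empty. The path and port are illustrative.

```python
import tensorflow as tf

# Shared, restart-safe dispatcher configuration.
config = tf.data.experimental.service.DispatcherConfig(
    port=5050,
    work_dir="/tmp/tf_data_dispatcher",
    fault_tolerant_mode=True)

# Each dispatcher process (initial start or restart after a crash) is built
# from the same config; registered datasets and created jobs are recovered
# from the journal in work_dir.
dispatcher = tf.data.experimental.service.DispatchServer(config)
dispatcher.join()  # Block, e.g. in a dedicated dispatcher process.
```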
@@ -362,10 +362,11 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
     except:
       raise self.skipTest("Flakes in portpicker library do not represent "
                           "TensorFlow errors.")
-    dispatcher = server_lib.DispatchServer(port=dispatcher_port, start=False)
+    dispatcher = server_lib.DispatchServer(
+        server_lib.DispatcherConfig(port=dispatcher_port), start=False)
     worker = server_lib.WorkerServer(
-        port=0,
-        dispatcher_address=_address_from_target(dispatcher.target),
+        server_lib.WorkerConfig(
+            dispatcher_address=_address_from_target(dispatcher.target), port=0),
         start=False)

     def start_servers():
@@ -0,0 +1,31 @@
+path: "tensorflow.data.experimental.service.DispatcherConfig"
+tf_class {
+  is_instance: "<class \'tensorflow.python.data.experimental.service.server_lib.DispatcherConfig\'>"
+  is_instance: "<class \'tensorflow.python.data.experimental.service.server_lib.DispatcherConfig\'>"
+  is_instance: "<type \'tuple\'>"
+  member {
+    name: "fault_tolerant_mode"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "port"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "protocol"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "work_dir"
+    mtype: "<type \'property\'>"
+  }
+  member_method {
+    name: "__init__"
+  }
+  member_method {
+    name: "count"
+  }
+  member_method {
+    name: "index"
+  }
+}
@@ -0,0 +1,31 @@
+path: "tensorflow.data.experimental.service.WorkerConfig"
+tf_class {
+  is_instance: "<class \'tensorflow.python.data.experimental.service.server_lib.WorkerConfig\'>"
+  is_instance: "<class \'tensorflow.python.data.experimental.service.server_lib.WorkerConfig\'>"
+  is_instance: "<type \'tuple\'>"
+  member {
+    name: "dispatcher_address"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "port"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "protocol"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "worker_address"
+    mtype: "<type \'property\'>"
+  }
+  member_method {
+    name: "__init__"
+  }
+  member_method {
+    name: "count"
+  }
+  member_method {
+    name: "index"
+  }
+}
@@ -1,5 +1,13 @@
 path: "tensorflow.data.experimental.service"
 tf_module {
+  member {
+    name: "DispatcherConfig"
+    mtype: "<type \'type\'>"
+  }
+  member {
+    name: "WorkerConfig"
+    mtype: "<type \'type\'>"
+  }
   member_method {
     name: "distribute"
     argspec: "args=[\'processing_mode\', \'service\', \'job_name\', \'max_outstanding_requests\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
@@ -8,7 +8,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'port\', \'protocol\', \'work_dir\', \'fault_tolerant_mode\', \'start\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], "
+    argspec: "args=[\'self\', \'config\', \'start\'], varargs=None, keywords=None, defaults=[\'None\', \'True\'], "
   }
   member_method {
     name: "join"
@@ -0,0 +1,31 @@
+path: "tensorflow.data.experimental.service.DispatcherConfig"
+tf_class {
+  is_instance: "<class \'tensorflow.python.data.experimental.service.server_lib.DispatcherConfig\'>"
+  is_instance: "<class \'tensorflow.python.data.experimental.service.server_lib.DispatcherConfig\'>"
+  is_instance: "<type \'tuple\'>"
+  member {
+    name: "fault_tolerant_mode"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "port"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "protocol"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "work_dir"
+    mtype: "<type \'property\'>"
+  }
+  member_method {
+    name: "__init__"
+  }
+  member_method {
+    name: "count"
+  }
+  member_method {
+    name: "index"
+  }
+}
@@ -0,0 +1,31 @@
+path: "tensorflow.data.experimental.service.WorkerConfig"
+tf_class {
+  is_instance: "<class \'tensorflow.python.data.experimental.service.server_lib.WorkerConfig\'>"
+  is_instance: "<class \'tensorflow.python.data.experimental.service.server_lib.WorkerConfig\'>"
+  is_instance: "<type \'tuple\'>"
+  member {
+    name: "dispatcher_address"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "port"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "protocol"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "worker_address"
+    mtype: "<type \'property\'>"
+  }
+  member_method {
+    name: "__init__"
+  }
+  member_method {
+    name: "count"
+  }
+  member_method {
+    name: "index"
+  }
+}
@@ -4,7 +4,7 @@ tf_class {
   is_instance: "<type \'object\'>"
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'port\', \'dispatcher_address\', \'worker_address\', \'protocol\', \'start\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\'], "
+    argspec: "args=[\'self\', \'config\', \'start\'], varargs=None, keywords=None, defaults=[\'True\'], "
   }
   member_method {
     name: "join"
@@ -4,6 +4,14 @@ tf_module {
     name: "DispatchServer"
     mtype: "<type \'type\'>"
   }
+  member {
+    name: "DispatcherConfig"
+    mtype: "<type \'type\'>"
+  }
+  member {
+    name: "WorkerConfig"
+    mtype: "<type \'type\'>"
+  }
   member {
     name: "WorkerServer"
     mtype: "<type \'type\'>"