# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A Python interface for creating dataset servers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# pylint: disable=invalid-import-order,g-bad-import-order, unused-import
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.data.experimental.service import _pywrap_server_lib
from tensorflow.python.util.tf_export import tf_export

@tf_export("data.experimental.service.DispatchServer", v1=[])
class DispatchServer(object):
  """An in-process tf.data service dispatch server.

  A `tf.data.experimental.service.DispatchServer` coordinates a cluster of
  `tf.data.experimental.service.WorkerServer`s. When the workers start, they
  register themselves with the dispatcher.

  >>> dispatcher = tf.data.experimental.service.DispatchServer(port=0)
  >>> dispatcher_address = dispatcher.target.split("://")[1]
  >>> worker = tf.data.experimental.service.WorkerServer(
  ...     port=0, dispatcher_address=dispatcher_address)
  >>> dataset = tf.data.Dataset.range(10)
  >>> dataset = dataset.apply(tf.data.experimental.service.distribute(
  ...     processing_mode="parallel_epochs", service=dispatcher.target))
  >>> print(list(dataset.as_numpy_iterator()))
  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

  When starting a dedicated tf.data dispatch process, use join() to block
  indefinitely after starting up the server.

  ```
  dispatcher = tf.data.experimental.service.DispatchServer(port=5050)
  dispatcher.join()
  ```
  """

  def __init__(self, port, protocol=None, start=True):
    """Creates a new dispatch server.

    Args:
      port: Specifies the port to bind to.
      protocol: (Optional.) Specifies the protocol to be used by the server.
        Acceptable values include `"grpc", "grpc+local"`. Defaults to `"grpc"`.
      start: (Optional.) Boolean, indicating whether to start the server after
        creating it. Defaults to `True`.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        creating the TensorFlow server.
    """
    if protocol is None:
      protocol = "grpc"
    self._protocol = protocol
    self._server = _pywrap_server_lib.TF_DATA_NewDispatchServer(port, protocol)
    if start:
      self._server.start()

  def start(self):
    """Starts this server.

    >>> dispatcher = tf.data.experimental.service.DispatchServer(port=0,
    ...                                                          start=False)
    >>> dispatcher.start()

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        starting the server.
    """
    self._server.start()

  def join(self):
    """Blocks until the server has shut down.

    This is useful when starting a dedicated dispatch process.

    ```
    dispatcher = tf.data.experimental.service.DispatchServer(port=5050)
    dispatcher.join()
    ```

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        joining the server.
    """
    self._server.join()

  @property
  def target(self):
    """Returns a target that can be used to connect to the server.

    >>> dispatcher = tf.data.experimental.service.DispatchServer(port=0)
    >>> dataset = tf.data.Dataset.range(10)
    >>> dataset = dataset.apply(tf.data.experimental.service.distribute(
    ...     processing_mode="parallel_epochs", service=dispatcher.target))

    The returned string will be in the form protocol://address, e.g.
    "grpc://localhost:5050".
    """
    return "{0}://localhost:{1}".format(self._protocol,
                                        self._server.bound_port())

  def _stop(self):
    """Stops the server.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        stopping the server.
    """
    self._server.stop()

  def __del__(self):
    self._stop()

  @property
  def _address(self):
    """Returns the address of the server.

    The returned string will be in the form address:port, e.g. "localhost:1000".
    """
    return "localhost:{0}".format(self._server.bound_port())

  def _num_workers(self):
    """Returns the number of workers registered with the dispatcher."""
    return self._server.num_workers()

@tf_export("data.experimental.service.WorkerServer", v1=[])
class WorkerServer(object):
  """An in-process tf.data service worker server.

  A `tf.data.experimental.service.WorkerServer` performs `tf.data.Dataset`
  processing for user-defined datasets, and provides the resulting elements
  over RPC. A worker is associated with a single
  `tf.data.experimental.service.DispatchServer`.

  >>> dispatcher = tf.data.experimental.service.DispatchServer(port=0)
  >>> dispatcher_address = dispatcher.target.split("://")[1]
  >>> worker = tf.data.experimental.service.WorkerServer(
  ...     port=0, dispatcher_address=dispatcher_address)
  >>> dataset = tf.data.Dataset.range(10)
  >>> dataset = dataset.apply(tf.data.experimental.service.distribute(
  ...     processing_mode="parallel_epochs", service=dispatcher.target))
  >>> print(list(dataset.as_numpy_iterator()))
  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

  When starting a dedicated tf.data worker process, use join() to block
  indefinitely after starting up the server.

  ```
  worker = tf.data.experimental.service.WorkerServer(
      port=5051, dispatcher_address="grpc://localhost:5050")
  worker.join()
  ```
  """

  def __init__(self,
               port,
               dispatcher_address,
               worker_address=None,
               protocol=None,
               start=True):
    """Creates a new worker server.

    Args:
      port: Specifies the port to bind to. A value of 0 indicates that the
        worker can bind to any available port.
      dispatcher_address: Specifies the address of the dispatcher.
      worker_address: (Optional.) Specifies the address of the worker server.
        This address is passed to the dispatcher so that the dispatcher can
        tell clients how to connect to this worker. Defaults to
        `"localhost:%port%"`, where `%port%` will be replaced with the port
        used by the worker.
      protocol: (Optional.) Specifies the protocol to be used by the server.
        Acceptable values include `"grpc", "grpc+local"`. Defaults to `"grpc"`.
      start: (Optional.) Boolean, indicating whether to start the server after
        creating it. Defaults to `True`.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        creating the TensorFlow server.
    """
    if worker_address is None:
      worker_address = "localhost:%port%"
    if protocol is None:
      protocol = "grpc"
    self._protocol = protocol
    self._server = _pywrap_server_lib.TF_DATA_NewWorkerServer(
        port, protocol, dispatcher_address, worker_address)
    if start:
      self._server.start()

  def start(self):
    """Starts this server.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        starting the server.
    """
    self._server.start()

  def join(self):
    """Blocks until the server has shut down.

    This is useful when starting a dedicated worker process.

    ```
    worker_server = tf.data.experimental.service.WorkerServer(
        port=5051, dispatcher_address="grpc://localhost:5050")
    worker_server.join()
    ```

    This method currently blocks forever.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        joining the server.
    """
    self._server.join()

  def _stop(self):
    """Stops the server.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        stopping the server.
    """
    self._server.stop()

  def __del__(self):
    self._stop()

  @property
  def _address(self):
    """Returns the address of the server.

    The returned string will be in the form address:port, e.g. "localhost:1000".
    """
    return "localhost:{0}".format(self._server.bound_port())
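

# A minimal end-to-end sketch of the workflow described in the docstrings
# above: start an in-process dispatcher, register a worker with it, and route
# a dataset through the service. The calls mirror the doctests in this module;
# the `__main__` guard and the `import tensorflow as tf` alias are assumptions
# made purely for illustration.
if __name__ == "__main__":
  import tensorflow as tf

  # Bind the dispatcher to an OS-chosen port, then derive its host:port
  # address from the "protocol://address" target string.
  dispatcher = tf.data.experimental.service.DispatchServer(port=0)
  dispatcher_address = dispatcher.target.split("://")[1]

  # The worker registers itself with the dispatcher and serves dataset
  # elements over RPC. Keep a reference so the server is not destroyed.
  worker = tf.data.experimental.service.WorkerServer(
      port=0, dispatcher_address=dispatcher_address)

  # Point an ordinary tf.data pipeline at the dispatcher to distribute it.
  dataset = tf.data.Dataset.range(10)
  dataset = dataset.apply(
      tf.data.experimental.service.distribute(
          processing_mode="parallel_epochs", service=dispatcher.target))
  print(list(dataset.as_numpy_iterator()))  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]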