PSv2: Export a few tf.distribute symbols related to TF2 parameter server training.

This change exports the following class symbols, and adds relevant documentation and example code to

tf.distribute.experimental.ParameterServerStrategy
tf.distribute.experimental.coordinator.ClusterCoordinator
tf.distribute.experimental.coordinator.PerWorkerValues
tf.distribute.experimental.coordinator.RemoteValue

PiperOrigin-RevId: 338151262
Change-Id: If2d1c513d30a999c728cecc2e73b75adda1948c2
Rick Chao 2020-10-20 15:34:17 -07:00 committed by TensorFlower Gardener
parent 833b3a49a9
commit 32f35aabce
24 changed files with 911 additions and 309 deletions

View File

@ -208,7 +208,16 @@
how many times the function is called, and independent of global seed
settings.
* `tf.distribute`:
* <ADD RELEASE NOTES HERE>
* (Experimental) Parameter server training:
* Replaced the existing
`tf.distribute.experimental.ParameterServerStrategy` symbol with
a new class that is for parameter server training in TF2. Usage with
the old symbol, usually with Estimator, should be replaced with
`tf.compat.v1.distribute.experimental.ParameterServerStrategy`.
* Added `tf.distribute.experimental.coordinator.*` namespace,
including the main API `ClusterCoordinator` for coordinating the
training cluster, the related data structure `RemoteValue`
and `PerWorkerValues`.
* `tf.keras`:
* Improvements from the functional API refactoring:
* Functional model construction does not need to maintain a global

View File

@ -153,8 +153,9 @@ py_library(
":multi_process_runner",
":multi_worker_test_base",
":one_device_strategy",
":parameter_server_strategy_v2",
":sharded_variable",
"//tensorflow/python/distribute/client",
"//tensorflow/python/distribute/coordinator:cluster_coordinator",
"//tensorflow/python/distribute/experimental",
],
)
@ -1880,6 +1881,7 @@ tf_py_test(
":multi_worker_test_base",
":parameter_server_strategy_v2",
":sharded_variable",
"//tensorflow:tensorflow_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:extra_py_tests_deps",

View File

@ -8,8 +8,8 @@ package(
exports_files(["LICENSE"])
py_library(
name = "client",
srcs = ["client.py"],
name = "cluster_coordinator",
srcs = ["cluster_coordinator.py"],
srcs_version = "PY2AND3",
deps = [
":metric_utils",
@ -34,9 +34,9 @@ py_library(
)
tf_py_test(
name = "client_test",
name = "cluster_coordinator_test",
size = "small",
srcs = ["client_test.py"],
srcs = ["cluster_coordinator_test.py"],
python_version = "PY3",
shard_count = 50,
tags = [
@ -44,7 +44,7 @@ tf_py_test(
"notsan", # TODO(b/171040359): Flaky timeout, even if maximum shards
],
deps = [
":client",
":cluster_coordinator",
"//tensorflow/python:check_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
@ -66,12 +66,13 @@ tf_py_test(
)
tf_py_test(
name = "client_mpr_test",
srcs = ["client_mpr_test.py"],
name = "cluster_coordinator_mpr_test",
srcs = ["cluster_coordinator_mpr_test.py"],
python_version = "PY3",
shard_count = 2,
tags = ["no_oss"], # TODO(b/162119374)
deps = [
":cluster_coordinator",
":remote_eager_lib",
":utils",
"//tensorflow/python:dtypes",
@ -81,7 +82,6 @@ tf_py_test(
"//tensorflow/python/distribute:multi_worker_test_base",
"//tensorflow/python/distribute:parameter_server_strategy_v2",
"//tensorflow/python/distribute:sharded_variable",
"//tensorflow/python/distribute/client",
"//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
"//tensorflow/python/eager:def_function",
"//tensorflow/python/eager:test",
@ -102,7 +102,7 @@ tf_py_test(
srcs = ["metric_utils_test.py"],
python_version = "PY3",
deps = [
":client",
":cluster_coordinator",
":metric_utils",
"//tensorflow/python:training_server_lib",
"//tensorflow/python/distribute:multi_worker_test_base",

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Module for `Client` and relevant cluster-worker related library.
"""Module for `ClusterCoordinator` and relevant cluster-worker related library.
This is currently under development and the API is subject to change.
"""
@ -35,7 +35,7 @@ from six.moves import queue
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import parameter_server_strategy_v2
from tensorflow.python.distribute.client import metric_utils
from tensorflow.python.distribute.coordinator import metric_utils
from tensorflow.python.eager import cancellation
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
@ -48,6 +48,7 @@ from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
# Maximum time for failed worker to come back is 1 hour
_WORKER_MAXIMUM_RECOVERY_SEC = 3600
@ -56,9 +57,10 @@ _WORKER_MAXIMUM_RECOVERY_SEC = 3600
# When the maximum queue size is reached, further schedule calls will become
# blocking until some previously queued closures are executed on workers.
# Note that using an "infinite" queue size can take a non-trivial portion of
# memory, and even lead to client OOM. Modify the size to a smaller value for
# client with constrained memory resource (only recommended for advanced users).
# Also used in unit tests to ensure the correctness when the queue is full.
# memory, and even lead to coordinator OOM. Modify the size to a smaller value
# for coordinator with constrained memory resource (only recommended for
# advanced users). Also used in unit tests to ensure the correctness when the
# queue is full.
_CLOSURE_QUEUE_MAX_SIZE = 256 * 1024
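The tests in this change tune this limit directly (for example, `LimitedClosureQueueSizeBasicTest` sets `_CLOSURE_QUEUE_MAX_SIZE = 2`). A minimal sketch of lowering it in a memory-constrained coordinator process, with an illustrative value and assuming the coordinator is created afterwards:

```python
from tensorflow.python.distribute.coordinator import cluster_coordinator as coordinator_lib

# Illustrative value; must be set before any closures are scheduled.
# The module default is 256 * 1024.
coordinator_lib._CLOSURE_QUEUE_MAX_SIZE = 1024

# coordinator = coordinator_lib.ClusterCoordinator(strategy)  # created afterwards
```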
# RPC error message from PS
@ -99,18 +101,77 @@ class _RemoteValueStatus(enum.Enum):
READY = "READY"
@tf_export("distribute.experimental.coordinator.RemoteValue", v1=[])
class RemoteValue(object):
"""An asynchronously available value of a remotely executed function.
`RemoteValue` class is used as the return value of `Client.schedule()` where
the underlying concrete value comes at a later time once the function has been
remotely executed. `RemoteValue` can be used as an input to a subsequent
function scheduled with `Client.schedule()`.
`tf.distribute.experimental.coordinator.RemoteValue` class is used as the
return value of
`tf.distribute.experimental.coordinator.ClusterCoordinator.schedule()` where
the underlying concrete value becomes available at a later time once the
function has been remotely executed. The underlying concrete value is the
`tf.Tensor.numpy()` result of the `tf.Tensor`, or the structure of
`tf.Tensor`s, returned from the `tf.function` that was scheduled.
Note: this class is not thread-safe.
Using a `tf.distribute.experimental.coordinator.RemoteValue` as an input to
a subsequent function scheduled with
`tf.distribute.experimental.coordinator.ClusterCoordinator.schedule()` is
currently not supported.
Example:
```python
strategy = tf.distribute.experimental.ParameterServerStrategy(
    cluster_resolver=...)
coordinator = (
    tf.distribute.experimental.coordinator.ClusterCoordinator(strategy))
with strategy.scope():
  v1 = tf.Variable(initial_value=0.0)
  v2 = tf.Variable(initial_value=1.0)

@tf.function
def worker_fn():
  v1.assign_add(0.1)
  v2.assign_sub(0.2)
  return v1.read_value() / v2.read_value()

result = coordinator.schedule(worker_fn)
# Note that `fetch()` gives the actual result instead of a `tf.Tensor`.
assert result.fetch() == 0.125

for _ in range(10):
  # `worker_fn` will be run on arbitrary workers that are available. The
  # `result` value will be non-deterministic because the workers are
  # executing the functions asynchronously.
  result = coordinator.schedule(worker_fn)
```
"""
def __init__(self, closure, type_spec):
def fetch(self):
"""Wait for the result of `RemoteValue` to be ready and return the result.
This makes the value concrete by copying the remote value to the local host.
Returns:
The actual output of the `tf.function` associated with this `RemoteValue`,
previously scheduled by a
`tf.distribute.experimental.coordinator.ClusterCoordinator.schedule` call.
This can be a single value, or a structure of values, depending on the
output of the `tf.function`.
Raises:
tf.errors.CancelledError: If the function that produces this `RemoteValue`
is aborted or cancelled due to failure, and the user should handle and
reschedule.
"""
raise NotImplementedError("Must be implemented in subclasses.")
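Given the `Raises` note above that a cancelled or aborted function surfaces as `tf.errors.CancelledError` from `fetch`, a minimal sketch of the suggested handle-and-reschedule pattern; `coordinator` and `worker_fn` are assumed to come from the surrounding program:

```python
import tensorflow as tf

result = coordinator.schedule(worker_fn)
try:
  value = result.fetch()  # blocks until the remote execution finishes
except tf.errors.CancelledError:
  # The backing function was aborted or cancelled due to a failure;
  # reschedule it and fetch again if the result is still needed.
  value = coordinator.schedule(worker_fn).fetch()
```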
class RemoteValueImpl(RemoteValue):
"""Implementation of `RemoteValue`."""
def __init__(self, closure, type_spec): # pylint: disable=super-init-not-called
self._closure = closure
# The type spec for this `RemoteValue` which is used to trace functions that
# take this `RemoteValue` as input.
@ -157,16 +218,6 @@ class RemoteValue(object):
self._type_spec = func_graph.convert_structure_to_signature(type_spec)
def fetch(self):
"""Wait for the result of RemoteValue to be ready and return the result.
Returns:
The remote value, as a numpy data type (if scalar) or ndarray.
Raises:
tf.errors.CancelledError: If the function that produces this `RemoteValue`
is aborted or cancelled due to failure, and the user should handle and
reschedule.
"""
self._status_available_event.wait()
if self._status is _RemoteValueStatus.ABORTED:
raise errors.CancelledError(
@ -241,8 +292,23 @@ def _maybe_as_type_spec(val):
return val
@tf_export("distribute.experimental.coordinator.PerWorkerValues", v1=[])
class PerWorkerValues(object):
"""Holds a list of per worker values."""
"""A container that holds a list of values, one value per worker.
`tf.distribute.experimental.coordinator.PerWorkerValues` contains a collection
of values, where each value represents a resource located on an individual
worker. When used as one of the `args` or `kwargs` of
`tf.distribute.experimental.coordinator.ClusterCoordinator.schedule()`, the
value specific to a worker is passed into the function being executed at
that particular worker.

Currently, the only supported way to create a
`tf.distribute.experimental.coordinator.PerWorkerValues` object is by calling
`iter` on the distributed dataset returned by
`ClusterCoordinator.create_per_worker_dataset`. Creating a custom
`tf.distribute.experimental.coordinator.PerWorkerValues` is not yet supported.
"""
def __init__(self, values):
self._values = tuple(values)
@ -262,9 +328,10 @@ def _disallow_remote_value_as_input(structured):
def _raise_if_remote_value(x):
if isinstance(x, RemoteValue):
raise ValueError("RemoteValue cannot be used as an input to scheduled "
"function. Please file a feature request if you need "
"this feature.")
raise ValueError(
"`tf.distribute.experimental.coordinator.RemoteValue` used "
"as an input to scheduled function is not yet "
"supported.")
nest.map_structure(_raise_if_remote_value, structured)
@ -274,8 +341,8 @@ class Closure(object):
def __init__(self, function, cancellation_mgr, args=None, kwargs=None):
if not callable(function):
raise ValueError("Function passed to `Client.schedule` must be a "
"callable object.")
raise ValueError("Function passed to `ClusterCoordinator.schedule` must "
"be a callable object.")
self._args = args or ()
self._kwargs = kwargs or {}
@ -287,9 +354,9 @@ class Closure(object):
replica_kwargs = _select_worker_slice(0, self._kwargs)
# Note: no need to handle function registration failure since this kind of
# failure will not raise exceptions as designed in the runtime. The client
# has to rely on subsequent operations that raise to catch function
# registration failure.
# failure will not raise exceptions as designed in the runtime. The
# coordinator has to rely on subsequent operations that raise to catch
# function registration failure.
# Record the function tracing overhead. Note that we pass in the tracing
# count of the def_function.Function as a state tracker, so that metrics
@ -303,17 +370,18 @@ class Closure(object):
self._function = cancellation_mgr.get_cancelable_function(
concrete_function)
self._output_remote_values = nest.map_structure(
lambda x: RemoteValue(self, x), concrete_function.structured_outputs)
lambda x: RemoteValueImpl(self, x),
concrete_function.structured_outputs)
elif isinstance(function, tf_function.ConcreteFunction):
self._function = cancellation_mgr.get_cancelable_function(function)
self._output_remote_values = nest.map_structure(
lambda x: RemoteValue(self, x), function.structured_outputs)
lambda x: RemoteValueImpl(self, x), function.structured_outputs)
else:
# Regular python functions.
self._function = function
# TODO(yuefengz): maybe we should trace python functions if their inputs
# are Python primitives, tensors and composite tensors.
self._output_remote_values = RemoteValue(self, None)
self._output_remote_values = RemoteValueImpl(self, None)
def _fetch_output_remote_values(self):
"""Temporary method used to sync the scheduler."""
@ -394,7 +462,7 @@ class _CoordinatedClosureQueue(object):
if _CLOSURE_QUEUE_MAX_SIZE <= 0:
logging.warning(
"In a `Client`, creating an infinite closure queue can "
"In a `ClusterCoordinator`, creating an infinite closure queue can "
"consume a significant amount of memory and even lead to OOM.")
self._queue = queue.Queue(maxsize=_CLOSURE_QUEUE_MAX_SIZE)
self._error = None
@ -430,10 +498,10 @@ class _CoordinatedClosureQueue(object):
# The cancellation manager cannot be reused once cancelled. After all
# closures (queued or inflight) are cleaned up, recreate the cancellation
# manager with clean state.
# Note on thread-safety: this is triggered when one of theses client APIs
# are called: `schedule`, `wait`, and `done`. At the same time, no new
# closures can be constructed (which reads the _cancellation_mgr to get
# cancellable functions).
# Note on thread-safety: this is triggered when one of these
# ClusterCoordinator APIs is called: `schedule`, `wait`, and `done`. At the
# same time, no new closures can be constructed (which reads the
# _cancellation_mgr to get cancellable functions).
self._cancellation_mgr = cancellation.CancellationManager()
def _raise_if_error(self):
@ -742,8 +810,8 @@ class Worker(object):
def _register_resource(self, resource_remote_value):
if not isinstance(resource_remote_value, RemoteValue):
raise ValueError(
"Resource being registered is not of type `RemoteValue`.")
raise ValueError("Resource being registered is not of type "
"`tf.distribute.experimental.coordinator.RemoteValue`.")
self._resource_remote_value_refs.append(weakref.ref(resource_remote_value))
@ -753,7 +821,7 @@ class Cluster(object):
We assume all function errors are fatal and based on this assumption our
error reporting logic is:
1) Both `schedule` and `join` can raise a non-retryable error which is the
first error seen by the client from any previously scheduled functions.
first error seen by the coordinator from any previously scheduled functions.
2) When an error is raised, there is no guarantee on how many previously
scheduled functions have been executed; functions that have not been executed
will be thrown away and marked as cancelled.
@ -775,17 +843,17 @@ class Cluster(object):
# Ignore PS failures reported by workers due to transient connection errors.
# Transient connectivity issues between workers and PS are relayed by the
# workers to the client, leading the client to believe that there are PS
# failures. The difference between transient vs. permanent PS failure is the
# number of reports from the workers. When this env var is set to a positive
# integer K, the client ignores up to K reports of a failed PS task. I.e.,
# only when there are more than K trials of executing closures fail due to
# errors from the same PS instance do we consider the PS instance encounters
# a failure.
# workers to the coordinator, leading the coordinator to believe that there
# are PS failures. The difference between transient vs. permanent PS failure
# is the number of reports from the workers. When this env var is set to a
# positive integer K, the coordinator ignores up to K reports of a failed PS
# task, i.e., only when there are more than K trials of executing closures
# fail due to errors from the same PS instance do we consider the PS
# instance encounters a failure.
# TODO(b/164279603): Remove this workaround when the underlying connectivity
# issue in gRPC server is resolved.
self._transient_ps_failures_threshold = int(os.environ.get(
"TF_CLIENT_IGNORE_TRANSIENT_PS_FAILURES", 3))
self._transient_ps_failures_threshold = int(
os.environ.get("TF_COORDINATOR_IGNORE_TRANSIENT_PS_FAILURES", 3))
self._potential_ps_failures_lock = threading.Lock()
self._potential_ps_failures_count = [0] * self._num_ps
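A minimal sketch of adjusting the transient-failure threshold described above through the renamed environment variable, set in the coordinator process before the cluster object is constructed; the value 10 is illustrative (the default is 3):

```python
import os

# Ignore up to 10 worker-reported failures per parameter server task before
# treating that parameter server as having actually failed.
os.environ["TF_COORDINATOR_IGNORE_TRANSIENT_PS_FAILURES"] = "10"
```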
@ -825,7 +893,7 @@ class Cluster(object):
kwargs: Keyword arguments for `fn`.
Returns:
A structure of `RemoteValue` object.
A `RemoteValue` object.
"""
closure = Closure(
function,
@ -844,68 +912,124 @@ class Cluster(object):
return self._closure_queue.done()
class Client(object):
"""An object to schedule and orchestrate remote function execution.
@tf_export("distribute.experimental.coordinator.ClusterCoordinator", v1=[])
class ClusterCoordinator(object):
"""An object to schedule and coordinate remote function execution.
A `Client` object represents a program used to create dataset, schedule
functions to be executed, and fetch the results of the functions.
A `tf.distribute.experimental.coordinator.ClusterCoordinator` object
represents a program used to distribute datasets onto the workers, schedule
functions to be executed, and fetch the results of those functions. It expects
the cluster to contain some machines with processes running TensorFlow
servers.
Currently, `Client` is not supported to be used in a standalone manner.
It should be used in conjunction with `ParameterServerStrategyV2`.
Currently, `tf.distribute.experimental.coordinator.ClusterCoordinator` cannot
be used in a standalone manner. It should be used in conjunction with a
`tf.distribute` strategy that is designed to work with it; at the moment, only
`tf.distribute.experimental.ParameterServerStrategy` is supported to work with
`tf.distribute.experimental.coordinator.ClusterCoordinator`.
__Fault tolerance__
`tf.distribute.experimental.coordinator.ClusterCoordinator`, when used with
`tf.distribute.experimental.ParameterServerStrategy`, comes with built-in
fault tolerance for worker failures. That is, when some workers cannot be
reached from the coordinator for any reason, training progress continues to be
made with the remaining workers, without any additional handling needed in
user code.
On the other hand, when a parameter server or the coordinator fails, a
`tf.errors.UnavailableError` is raised by `ClusterCoordinator.schedule` or
`ClusterCoordinator.join`. If any parameter server fails, the user should
restart the processes on the failed parameter servers when they become
available, *and* restart the process on the coordinator, so the coordinator
can re-create the variables on the parameter servers. If the coordinator fails
but all other machines continue to operate, the user only needs to restart the
coordinator process, which will automatically connect to the parameter servers
and workers, and continue making progress.
It is thus essential that the user's custom training loop periodically saves a
checkpoint file, and restores it at the start of the program.
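A minimal sketch of the periodic checkpointing recommended above; `strategy`, `coordinator`, `model`, `optimizer`, and `worker_fn` are assumed to exist, and the directory, step counts, and loop structure are illustrative:

```python
import tensorflow as tf

checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
manager = tf.train.CheckpointManager(
    checkpoint, directory="/tmp/ckpt", max_to_keep=3)

# Restore at program start so a restarted coordinator resumes from the last
# saved state rather than from scratch.
if manager.latest_checkpoint:
  checkpoint.restore(manager.latest_checkpoint)

for _ in range(10):      # illustrative outer loop
  for _ in range(100):   # illustrative steps between checkpoints
    coordinator.schedule(worker_fn)
  coordinator.join()
  manager.save()         # periodic checkpoint, as recommended above
```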
* At-least-once semantics of `schedule`: `schedule` puts the `tf.function` in
a queue where it gets picked up by a worker to execute. If a worker picks up
a `tf.function`, has begun to execute it, but the worker process is
disconnected from the coordinator in the middle of execution, the
`tf.function` is deemed not completed, and is put back to the queue. This is
regardless of how much progress the `tf.function` has run. Once it is
re-queued, it will be picked up again by an arbitrary available worker at some
later time. As a result, the same function may be executed more than once, but
not less than once. If a `tf.keras.optimizers.Optimizer` is used,
`tf.keras.optimizers.Optimizer.iterations` roughly indicates the number of
times the gradients have been applied and can be used as an approximation of
the total number of steps. The number of times the function has been scheduled
by the `tf.distribute.experimental.coordinator.ClusterCoordinator`, on the
other hand, should not be indicative of actual steps run. See
`tf.distribute.experimental.coordinator.ClusterCoordinator.schedule` for more
information.
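A minimal sketch of using `optimizer.iterations` as the step counter, as suggested above, instead of counting `schedule` calls; `optimizer`, `coordinator`, and `step_fn` are assumed to exist and `total_steps` is illustrative:

```python
total_steps = 10000  # illustrative target number of gradient applications

# The optimizer's own counter approximates completed steps even when a
# function is re-executed after a worker failure.
while optimizer.iterations.numpy() < total_steps:
  coordinator.schedule(step_fn)
coordinator.join()
```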
See `tf.distribute.experimental.ParameterServerStrategy` docstring for an
example usage of this API.
This is currently under development, and the API as well as implementation
is subject to changes.
are subject to changes.
"""
def __init__(self, strategy):
"""Initialization of a `Client` instance.
This connects the client to remote workers and parameter servers, through
a `tf.config.experimental_connect_to_cluster` call.
"""Initialization of a `ClusterCoordinator` instance.
Args:
strategy: a `tf.distribute.Strategy` object. Currently, only
`ParameterServerStrategyV2` is supported.
strategy: a supported `tf.distribute.Strategy` object. Currently, only
`tf.distribute.experimental.ParameterServerStrategy` is supported.
Raises:
ValueError: if the strategy being used is not supported.
"""
if not isinstance(strategy,
parameter_server_strategy_v2.ParameterServerStrategyV2):
raise ValueError("Only `ParameterServerStrategyV2` is supported in "
"`Client` currently.")
raise ValueError(
"Only `tf.distribute.experimental.ParameterServerStrategy` "
"is supported to work with "
"`tf.distribute.experimental.coordinator.ClusterCoordinator` "
"currently.")
self._strategy = strategy
self.cluster = Cluster(strategy)
@property
def strategy(self):
"""Returns the `Strategy` associated with the `ClusterCoordinator`."""
return self._strategy
def schedule(self, fn, args=None, kwargs=None):
"""Schedules `fn` to be dispatched to a worker for execution asynchronously.
"""Schedules `fn` to be dispatched to a worker for asynchronous execution.
When calling `schedule` with a function `fn`, `fn` will be executed on a
remote worker at some later time. The process is asynchronous, meaning
`schedule` returns immediately, possibly without having the result ready
yet. `schedule` returns a structure of `RemoteValue` object, which wraps the
output of the function. Call `fetch()` on `RemoteValue` to wait for the
yet. `schedule` returns a
`tf.distribute.experimental.coordinator.RemoteValue` object, which wraps the
output of the function. After `schedule` is called, `fetch` can be called
on the `tf.distribute.experimental.coordinator.RemoteValue` to wait for the
function execution to finish and retrieve its output from the remote worker.
On the other hand, call
`tf.distribute.experimental.coordinator.ClusterCoordinator.join` to wait for
all scheduled functions to finish execution before proceeding.
`schedule` guarantees that `fn` will be executed on a worker at least once;
it could be more than once if its corresponding worker fails in the middle
of its execution. Note that since a worker can fail at any point when
executing the function, it is possible that the function is partially
executed, but `Client` guarantees that in those events, the function will
eventually be fully executed, possibly on a different worker that is
available.
executed, but `tf.distribute.experimental.coordinator.ClusterCoordinator`
guarantees that in those events, the function will eventually be fully
executed, possibly on a different worker that is available.
If any previously scheduled function raises an error, `schedule` will fail
by raising any one of those errors, and clear the errors collected so far.
There are two implications when this happens: 1) user should call `schedule`
with `fn` again to re-schedule, and 2) some of the previously scheduled
functions may have not been executed. User can call `fetch` on the returned
`RemoteValue` to inspect if they have executed, failed, or cancelled, and
reschedule the corresponding function if needed.
`tf.distribute.experimental.coordinator.RemoteValue` to inspect if they have
executed, failed, or cancelled, and reschedule the corresponding function if
needed.
When `schedule` raises, it guarantees that there is no function that is
still being executed.
@ -914,12 +1038,14 @@ class Client(object):
execution, or priority of the workers.
`args` and `kwargs` are the arguments passed into `fn`, when `fn` is
executed on a worker. They can be `PerWorkerValues`, which is a collection
of values, each of which represents a component specific to a worker; in
this case, the argument will be substituted with the corresponding component
on the target worker. Arguments that are not `PerWorkerValues` will be
passed into `fn` as-is. Currently, `RemoteValue` is not supported to be
input `args` or `kwargs`.
executed on a worker. They can be
`tf.distribute.experimental.coordinator.PerWorkerValues`, which is a
collection of values, each of which represents a component specific to a
worker; in this case, the argument will be substituted with the
corresponding component on the target worker. Arguments that are not
`tf.distribute.experimental.coordinator.PerWorkerValues` will be passed into
`fn` as-is. Currently, `tf.distribute.experimental.coordinator.RemoteValue`
is not yet supported as input `args` or `kwargs`.
Args:
fn: A `tf.function`; the function to be dispatched to a worker for
@ -928,12 +1054,13 @@ class Client(object):
kwargs: Keyword arguments for `fn`.
Returns:
A structure of `RemoteValue` object.
A `tf.distribute.experimental.coordinator.RemoteValue` object that
represents the output of the function scheduled.
Raises:
Exception: one of the exceptions caught by the client by any previously
scheduled function since the last time an error was thrown or since
the beginning of the program.
Exception: one of the exceptions caught by the coordinator by any
previously scheduled function since the last time an error was thrown or
since the beginning of the program.
"""
# Slot variables are usually created during function tracing time; thus
# `schedule` needs to be called within the `strategy.scope()`.
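A minimal sketch of the re-scheduling behavior described in the docstring above, where `schedule` may surface an error raised by a previously scheduled function and clears it; `coordinator` and `step_fn` are assumed to exist:

```python
try:
  result = coordinator.schedule(step_fn)
except Exception:  # an error from a previously scheduled function
  # The collected error has now been cleared, and `step_fn` was not scheduled
  # by the failed call, so schedule it again if it is still needed.
  result = coordinator.schedule(step_fn)
```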
@ -946,18 +1073,18 @@ class Client(object):
If any previously scheduled function raises an error, `join` will fail by
raising any one of those errors, and clear the errors collected so far. If
this happens, some of the previously scheduled functions may have not been
executed. Users can call `fetch` on the returned `RemoteValue` to inspect if
they have executed, failed, or cancelled. If some that have been cancelled
need to be rescheduled, users should call `schedule` with the function
again.
executed. Users can call `fetch` on the returned
`tf.distribute.experimental.coordinator.RemoteValue` to inspect if they have
executed, failed, or cancelled. If some that have been cancelled need to be
rescheduled, users should call `schedule` with the function again.
When `join` returns or raises, it guarantees that there is no function that
is still being executed.
Raises:
Exception: one of the exceptions caught by the client by any previously
scheduled function since the last time an error was thrown or since
the beginning of the program.
Exception: one of the exceptions caught by the coordinator by any
previously scheduled function since the last time an error was thrown or
since the beginning of the program.
"""
self.cluster.join()
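A minimal sketch of the schedule-then-join pattern described above, followed by a `done` check as the tests in this change also do; `coordinator` and `worker_fn` are assumed to exist and the step count is illustrative:

```python
for _ in range(100):       # illustrative number of scheduled steps
  coordinator.schedule(worker_fn)
coordinator.join()         # blocks until all scheduled functions finish
assert coordinator.done()  # no scheduled function is still pending
```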
@ -969,6 +1096,9 @@ class Client(object):
When `done` returns True or raises, it guarantees that there is no function
that is still being executed.
Returns:
Whether all the scheduled functions have finished execution.
"""
return self.cluster.done()
@ -978,12 +1108,15 @@ class Client(object):
This creates, on each worker, the dataset generated by `dataset_fn`, and
returns an object that represents the collection of those individual
datasets. Calling `iter` on such a collection of datasets returns a
`PerWorkerValues`, which is a collection of iterators, where the iterators
have been placed on respective workers.
`tf.distribute.experimental.coordinator.PerWorkerValues`, which is a
collection of iterators, where the iterators have been placed on respective
workers.
Calling `next` on this `PerWorkerValues` of iterators is currently
unsupported; it is meant to be passed as an argument into `Client.schedule`.
When the scheduled function is picked up and being executed by a worker, the
Calling `next` on this
`tf.distribute.experimental.coordinator.PerWorkerValues` of iterators is
currently unsupported; it is meant to be passed as an argument into
`tf.distribute.experimental.coordinator.ClusterCoordinator.schedule`. When
the scheduled function is picked up and being executed by a worker, the
function will receive the individual iterator that corresponds to the
worker, and `next` can then be called on the iterator to get the next batch or
example of data.
@ -993,6 +1126,26 @@ class Client(object):
examples may be skipped and not covered by other workers, if the dataset is
sharded.
Example:
```python
@tf.function
def worker_fn(iterator):
  return next(iterator)

def dataset_fn():
  return tf.data.Dataset.from_tensor_slices([3] * 3)

strategy = tf.distribute.experimental.ParameterServerStrategy(
    cluster_resolver=...)
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
    strategy=strategy)
per_worker_dataset = coordinator.create_per_worker_dataset(dataset_fn)
per_worker_iter = iter(per_worker_dataset)
remote_value = coordinator.schedule(worker_fn, args=(per_worker_iter,))
assert remote_value.fetch() == 3
```
Args:
dataset_fn: The dataset function that returns a dataset. This is to be
executed on the workers.
@ -1000,7 +1153,8 @@ class Client(object):
Returns:
An object that represents the collection of those individual
datasets. `iter` is expected to be called on this object that returns
a `PerWorkerValues` of the iterators (that are on the workers).
a `tf.distribute.experimental.coordinator.PerWorkerValues` of the
iterators (that are on the workers).
"""
input_workers = input_lib.InputWorkers([
(w.device_name, [w.device_name]) for w in self.cluster.workers
@ -1011,7 +1165,8 @@ class Client(object):
def _create_per_worker_resources(self, fn, args=None, kwargs=None):
"""Synchronously create resources on the workers.
The resources are represented by `RemoteValue`s.
The resources are represented by
`tf.distribute.experimental.coordinator.RemoteValue`s.
Args:
fn: The function to be dispatched to all workers for execution
@ -1020,7 +1175,9 @@ class Client(object):
kwargs: Keyword arguments for `fn`.
Returns:
A `PerWorkerValues` object, which wraps a tuple of `RemoteValue` objects.
A `tf.distribute.experimental.coordinator.PerWorkerValues` object, which
wraps a tuple of `tf.distribute.experimental.coordinator.RemoteValue`
objects.
"""
results = []
for w in self.cluster.workers:
@ -1028,21 +1185,52 @@ class Client(object):
return PerWorkerValues(tuple(results))
def fetch(self, val):
"""Blocking call to fetch results from `RemoteValue`s.
"""Blocking call to fetch results from the remote values.
This returns the execution result of `RemoteValue`s; if not ready,
waiting for it while blocking the caller.
This is a wrapper around
`tf.distribute.experimental.coordinator.RemoteValue.fetch` for a
`RemoteValue` structure; it returns the execution results of
`RemoteValue`s. If they are not yet ready, this call blocks until they are.
Example:
```python
strategy = ...
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
    strategy)

def dataset_fn():
  return tf.data.Dataset.from_tensor_slices([1, 1, 1])

with strategy.scope():
  v = tf.Variable(initial_value=0)

@tf.function
def worker_fn(iterator):

  def replica_fn(x):
    v.assign_add(x)
    return v.read_value()

  return strategy.run(replica_fn, args=(next(iterator),))

distributed_dataset = coordinator.create_per_worker_dataset(dataset_fn)
distributed_iterator = iter(distributed_dataset)
result = coordinator.schedule(worker_fn, args=(distributed_iterator,))
assert coordinator.fetch(result) == 1
```
Args:
val: The value to fetch the results from. If this is structure of
`RemoteValue`, `fetch()` will be called on the individual `RemoteValue`
to get the result.
`tf.distribute.experimental.coordinator.RemoteValue`, `fetch()` will be
called on the individual
`tf.distribute.experimental.coordinator.RemoteValue` to get the result.
Returns:
If `val` is a `RemoteValue` or a structure of `RemoteValue`s, returns
the fetched `RemoteValue` value immediately if it's available, or blocks
the call until it's available, and returns the fetched `RemoteValue`
values with the same structure. If `val` is other types, return (`val`,).
If `val` is a `tf.distribute.experimental.coordinator.RemoteValue` or a
structure of `tf.distribute.experimental.coordinator.RemoteValue`s,
return the fetched `tf.distribute.experimental.coordinator.RemoteValue`
values immediately if they are available, or block the call until they are
available, and return the fetched
`tf.distribute.experimental.coordinator.RemoteValue` values with the same
structure. If `val` is other types, return it as-is.
"""
def _maybe_fetch(val):
@ -1052,10 +1240,7 @@ class Client(object):
return val
# TODO(yuefengz): we should fetch values in a batch.
result = nest.map_structure(_maybe_fetch, val)
if not isinstance(result, tuple):
return (result,)
return result
return nest.map_structure(_maybe_fetch, val)
# pylint: disable=missing-function-docstring
@ -1075,13 +1260,14 @@ def handle_parameter_server_failure():
class _PerWorkerDistributedDataset(object):
"""Represents worker-distributed datasets created from dataset function."""
def __init__(self, dataset_fn, input_workers, client):
def __init__(self, dataset_fn, input_workers, coordinator):
"""Makes an iterable from datasets created by the given function.
Args:
dataset_fn: A function that returns a `Dataset`.
input_workers: an `InputWorkers` object.
client: a `Client` object, used to create dataset resources.
coordinator: a `ClusterCoordinator` object, used to create dataset
resources.
"""
def disallow_variable_creation(next_creator, **kwargs):
raise ValueError("Creating variables in `dataset_fn` is not allowed.")
@ -1094,7 +1280,7 @@ class _PerWorkerDistributedDataset(object):
dataset_fn = def_function.function(dataset_fn).get_concrete_function()
self._dataset_fn = dataset_fn
self._input_workers = input_workers
self._client = client
self._coordinator = coordinator
self._element_spec = None
def __iter__(self):
@ -1112,7 +1298,7 @@ class _PerWorkerDistributedDataset(object):
# If _PerWorkerDistributedDataset.__iter__ is called multiple
# times, for the same object it should only create and register resource
# once. Using object id to distinguish different iterator resources.
per_worker_iterator = self._client._create_per_worker_resources(
per_worker_iterator = self._coordinator._create_per_worker_resources(
_create_per_worker_iterator)
# Setting type_spec of each RemoteValue so that functions taking these
@ -1131,7 +1317,7 @@ class _PerWorkerDistributedDataset(object):
class _PerWorkerDistributedIterator(PerWorkerValues):
"""Distributed iterator for `Client`."""
"""Distributed iterator for `ClusterCoordinator`."""
def __next__(self):
return self.get_next()

View File

@ -13,21 +13,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Multi-process runner tests for `Client` with `ParameterServerStrategyV2`."""
"""Multi-process runner tests for `ClusterCoordinator` with PSv2."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
from tensorflow.python.compat import v2_compat
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import parameter_server_strategy_v2
from tensorflow.python.distribute.client import client as client_lib
from tensorflow.python.distribute.client import utils
from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
from tensorflow.python.distribute.coordinator import cluster_coordinator as coordinator_lib
from tensorflow.python.distribute.coordinator import utils
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import dtypes
@ -38,7 +37,7 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
class ClientMprTest(test.TestCase):
class ClusterCoordinatorMprTest(test.TestCase):
def testScheduleTranslatePSFailureError(self):
self._test_translate_ps_failure_error(test_schedule=True)
@ -56,7 +55,7 @@ class ClientMprTest(test.TestCase):
utils.start_server(cluster_resolver, "grpc")
strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
cluster_resolver)
ps_client = client_lib.Client(strategy)
ps_coordinator = coordinator_lib.ClusterCoordinator(strategy)
with strategy.scope():
v = variables.Variable(initial_value=0, dtype=dtypes.int32)
@ -68,19 +67,19 @@ class ClientMprTest(test.TestCase):
v.assign_add(1)
# Keep the two workers occupied.
ps_client.schedule(worker_fn)
ps_client.schedule(worker_fn)
ps_coordinator.schedule(worker_fn)
ps_coordinator.schedule(worker_fn)
# Now the main process can terminate.
functions_scheduled_event.set()
# Verified that join and schedule indeed raise UnavailableError.
try:
if test_join:
ps_client.join()
ps_coordinator.join()
if test_schedule:
while ps_client.cluster._closure_queue._error is None:
while ps_coordinator.cluster._closure_queue._error is None:
time.sleep(1)
ps_client.schedule(worker_fn)
ps_coordinator.schedule(worker_fn)
except errors.UnavailableError:
# The following verifies that after PS fails, continue executing
# functions on workers should fail and indicate it's PS failure.
@ -91,7 +90,7 @@ class ClientMprTest(test.TestCase):
# failure.
worker_fn()
except Exception as e: # pylint: disable=broad-except
if client_lib._is_ps_failure(e):
if coordinator_lib._is_ps_failure(e):
if worker_id < 2:
continue
logging.info("_test_translate_ps_failure_error ends properly.")

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for client.py."""
"""Tests for coordinator.py."""
from __future__ import absolute_import
from __future__ import division
@ -29,8 +29,8 @@ from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import parameter_server_strategy_v2
from tensorflow.python.distribute.client import client as client_lib
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.distribute.coordinator import cluster_coordinator as coordinator_lib
from tensorflow.python.eager import cancellation
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
@ -52,7 +52,7 @@ from tensorflow.python.util import nest
class CoordinatedClosureQueueTest(test.TestCase):
def testBasic(self):
queue = client_lib._CoordinatedClosureQueue()
queue = coordinator_lib._CoordinatedClosureQueue()
closure1 = self._create_closure(queue._cancellation_mgr)
queue.put(closure1)
self.assertIs(closure1, queue.get())
@ -64,7 +64,7 @@ class CoordinatedClosureQueueTest(test.TestCase):
queue.wait()
def testProcessAtLeaseOnce(self):
closure_queue = client_lib._CoordinatedClosureQueue()
closure_queue = coordinator_lib._CoordinatedClosureQueue()
labels = ['A', 'B', 'C', 'D', 'E']
processed_count = collections.defaultdict(int)
@ -94,7 +94,7 @@ class CoordinatedClosureQueueTest(test.TestCase):
cm = cancellation.CancellationManager()
for label in labels:
closure_queue.put(client_lib.Closure(get_func(label), cm))
closure_queue.put(coordinator_lib.Closure(get_func(label), cm))
t1 = threading.Thread(target=process_queue, daemon=True)
t1.start()
t2 = threading.Thread(target=process_queue, daemon=True)
@ -111,7 +111,7 @@ class CoordinatedClosureQueueTest(test.TestCase):
coord.join([t1, t2])
def testNotifyBeforeWait(self):
closure_queue = client_lib._CoordinatedClosureQueue()
closure_queue = coordinator_lib._CoordinatedClosureQueue()
def func():
logging.info('func running')
@ -123,7 +123,8 @@ class CoordinatedClosureQueueTest(test.TestCase):
closure_queue.get()
closure_queue.mark_finished()
closure_queue.put(client_lib.Closure(func, closure_queue._cancellation_mgr))
closure_queue.put(
coordinator_lib.Closure(func, closure_queue._cancellation_mgr))
t = threading.Thread(target=process_queue)
t.start()
coord.join([t])
@ -159,7 +160,7 @@ class CoordinatedClosureQueueTest(test.TestCase):
# TODO(b/165013260): Fix this
self.skipTest('Test is currently broken on Windows with Python 3.8')
closure_queue = client_lib._CoordinatedClosureQueue()
closure_queue = coordinator_lib._CoordinatedClosureQueue()
closure_queue.put(self._create_closure(closure_queue._cancellation_mgr))
closure = closure_queue.get()
@ -189,10 +190,10 @@ class CoordinatedClosureQueueTest(test.TestCase):
def some_function():
return 1.0
return client_lib.Closure(some_function, cancellation_mgr)
return coordinator_lib.Closure(some_function, cancellation_mgr)
def _put_two_closures_and_get_one(self):
closure_queue = client_lib._CoordinatedClosureQueue()
closure_queue = coordinator_lib._CoordinatedClosureQueue()
closure1 = self._create_closure(closure_queue._cancellation_mgr)
closure_queue.put(closure1)
@ -330,7 +331,7 @@ class CoordinatedClosureQueueTest(test.TestCase):
# Closure2 was an inflight closure when it got cancelled.
self.assertEqual(closure2._output_remote_values._status,
client_lib._RemoteValueStatus.READY)
coordinator_lib._RemoteValueStatus.READY)
with self.assertRaisesRegex(ValueError, 'Fake cancellation error.'):
closure2._fetch_output_remote_values()
@ -361,7 +362,7 @@ class CoordinatedClosureQueueTest(test.TestCase):
def testThreadSafey(self):
thread_count = 10
queue = client_lib._CoordinatedClosureQueue()
queue = coordinator_lib._CoordinatedClosureQueue()
# Each thread performs 20 queue actions: 10 are `put_back` and 10 are
# `mark_finished`.
@ -427,7 +428,7 @@ class TestCaseWithErrorReportingThread(test.TestCase):
raise ErrorReportingThread.error # pylint: disable=raising-bad-type
def make_client(num_workers, num_ps):
def make_coordinator(num_workers, num_ps):
# TODO(rchao): Test the internal rpc_layer version.
cluster_def = multi_worker_test_base.create_in_process_cluster(
num_workers=num_workers, num_ps=num_ps, rpc_layer='grpc')
@ -438,16 +439,16 @@ def make_client(num_workers, num_ps):
ClusterSpec(cluster_def), rpc_layer='grpc')
strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
cluster_resolver)
return client_lib.Client(strategy)
return coordinator_lib.ClusterCoordinator(strategy)
class ClientTest(TestCaseWithErrorReportingThread):
class ClusterCoordinatorTest(TestCaseWithErrorReportingThread):
@classmethod
def setUpClass(cls):
super(ClientTest, cls).setUpClass()
cls.client = make_client(num_workers=3, num_ps=2)
cls.strategy = cls.client.strategy
super(ClusterCoordinatorTest, cls).setUpClass()
cls.coordinator = make_coordinator(num_workers=3, num_ps=2)
cls.strategy = cls.coordinator.strategy
def testFnReturnNestedValues(self):
x = constant_op.constant(1)
@ -456,9 +457,9 @@ class ClientTest(TestCaseWithErrorReportingThread):
def f():
return x + 1, (x + 2, x + 3), [x + 4], {'v': x}
got = self.client.schedule(f)
got = self.coordinator.schedule(f)
want = 2, (3, 4), [5], {'v': 1}
self.assertEqual(self.client.fetch(got), want)
self.assertEqual(self.coordinator.fetch(got), want)
def testInputFunction(self):
@ -474,12 +475,14 @@ class ClientTest(TestCaseWithErrorReportingThread):
v.assign_add(x)
return x
distributed_dataset = self.client.create_per_worker_dataset(input_fn)
result = self.client.schedule(worker_fn, args=(iter(distributed_dataset),))
result = self.client.fetch(result)
distributed_dataset = self.coordinator.create_per_worker_dataset(input_fn)
result = self.coordinator.schedule(
worker_fn, args=(iter(distributed_dataset),))
result = self.coordinator.fetch(result)
self.assertEqual(result, (1,))
result = self.client.schedule(worker_fn, args=(iter(distributed_dataset),))
result = self.client.fetch(result)
result = self.coordinator.schedule(
worker_fn, args=(iter(distributed_dataset),))
result = self.coordinator.fetch(result)
self.assertEqual(result, (1,))
self.assertAlmostEqual(v.read_value().numpy(), 2, delta=1e-6)
@ -499,30 +502,30 @@ class ClientTest(TestCaseWithErrorReportingThread):
x = next(iterator)
v.assign_add(x)
distributed_dataset = self.client.create_per_worker_dataset(input_fn)
distributed_dataset = self.coordinator.create_per_worker_dataset(input_fn)
iterator = iter(distributed_dataset)
# Verifying joining without any scheduling doesn't hang.
self.client.join()
self.coordinator.join()
self.assertEqual(v.read_value().numpy(), 0)
for _ in range(5):
self.client.schedule(worker_fn, args=(iterator,))
self.client.join()
self.coordinator.schedule(worker_fn, args=(iterator,))
self.coordinator.join()
# With 5 addition it should be 2*5 = 10.
self.assertEqual(v.read_value().numpy(), 10)
for _ in range(5):
self.client.schedule(worker_fn, args=(iterator,))
self.coordinator.schedule(worker_fn, args=(iterator,))
# Verifying multiple join is fine.
self.client.join()
self.client.join()
self.client.join()
self.coordinator.join()
self.coordinator.join()
self.coordinator.join()
self.assertTrue(self.client.done())
self.assertTrue(self.coordinator.done())
# Likewise, it's now 20.
self.assertEqual(v.read_value().numpy(), 20)
@ -542,8 +545,9 @@ class ClientTest(TestCaseWithErrorReportingThread):
def worker_fn(iterator):
return next(iterator)
distributed_dataset = self.client.create_per_worker_dataset(input_fn)
result = self.client.schedule(worker_fn, args=(iter(distributed_dataset),))
distributed_dataset = self.coordinator.create_per_worker_dataset(input_fn)
result = self.coordinator.schedule(
worker_fn, args=(iter(distributed_dataset),))
self.assertEqual(result.fetch(), (10,))
self.assertEqual(self._map_fn_tracing_count, 1)
@ -554,28 +558,28 @@ class ClientTest(TestCaseWithErrorReportingThread):
return v.read_value()
with self.assertRaises(ValueError):
self.client.create_per_worker_dataset(input_fn)
self.coordinator.create_per_worker_dataset(input_fn)
def testDatasetsShuffledDifferently(self):
# This test requires at least two workers in the cluster.
self.assertGreaterEqual(len(self.client.cluster.workers), 2)
self.assertGreaterEqual(len(self.coordinator.cluster.workers), 2)
random_seed.set_random_seed(None)
def input_fn():
return dataset_ops.DatasetV2.range(0, 100).shuffle(100)
distributed_dataset = self.client.create_per_worker_dataset(input_fn)
distributed_dataset = self.coordinator.create_per_worker_dataset(input_fn)
distributed_iterator = iter(distributed_dataset)
# Get elements from the first two iterators.
iterator_1 = distributed_iterator._values[0]
iterator_1._rebuild_on(self.client.cluster.workers[0])
iterator_1._rebuild_on(self.coordinator.cluster.workers[0])
iterator_1 = iterator_1.fetch()
elements_in_iterator_1 = [e.numpy() for e in iterator_1]
iterator_2 = distributed_iterator._values[1]
iterator_2._rebuild_on(self.client.cluster.workers[1])
iterator_2._rebuild_on(self.coordinator.cluster.workers[1])
iterator_2 = iterator_2.fetch()
elements_in_iterator_2 = [e.numpy() for e in iterator_2]
@ -592,12 +596,12 @@ class ClientTest(TestCaseWithErrorReportingThread):
self.assertIn('worker', var.device)
return var
worker_local_var = self.client._create_per_worker_resources(create_var)
worker_local_var = self.coordinator._create_per_worker_resources(create_var)
# The following is a workaround to allow `worker_local_var` to be passed in
# as args to the `client.schedule` method which requires tensor specs to
# trace tf.function but _create_worker_resources' return values don't have
# tensor specs. We can get rid of this workaround once
# as args to the `coordinator.schedule` method which requires tensor specs
# to trace tf.function but _create_worker_resources' return values don't
# have tensor specs. We can get rid of this workaround once
# _create_worker_resources is able to infer the tensor spec of the return
# value of the function passed in. See b/154675763.
for var in worker_local_var._values:
@ -609,10 +613,10 @@ class ClientTest(TestCaseWithErrorReportingThread):
for _ in range(10):
# Which slice of `worker_local_var` will be used will depend on which
# worker the `worker_fn` gets scheduled on.
self.client.schedule(worker_fn, args=(worker_local_var,))
self.client.join()
self.coordinator.schedule(worker_fn, args=(worker_local_var,))
self.coordinator.join()
var_sum = sum(self.client.fetch(worker_local_var._values))
var_sum = sum(self.coordinator.fetch(worker_local_var._values))
self.assertEqual(var_sum, 10.0)
def testDisallowRemoteValueAsInput(self):
@ -625,28 +629,28 @@ class ClientTest(TestCaseWithErrorReportingThread):
def func_1(x):
return x + 1.0
remote_v = self.client.schedule(func_0)
remote_v = self.coordinator.schedule(func_0)
with self.assertRaises(ValueError):
self.client.schedule(func_1, args=(remote_v,))
self.coordinator.schedule(func_1, args=(remote_v,))
class LimitedClosureQueueSizeBasicTest(ClientTest):
class LimitedClosureQueueSizeBasicTest(ClusterCoordinatorTest):
"""Test basic functionality works with explicit maximum closure queue size.
Execute the same set of test cases as in `ClientTest`, with an
Execute the same set of test cases as in `ClusterCoordinatorTest`, with an
explicit size limit for the closure queue. Note that even when the queue size
is set to infinite, there is still a maximum practical size (depends on host
memory limit) that might cause the queue.put operations to be blocking when
scheduling a large number of closures on a big cluster. These tests make sure
that the client does not run into deadlocks in such scenario.
that the coordinator does not run into deadlocks in such a scenario.
"""
@classmethod
def setUpClass(cls):
super(LimitedClosureQueueSizeBasicTest, cls).setUpClass()
client_lib._CLOSURE_QUEUE_MAX_SIZE = 2
cls.client = make_client(num_workers=3, num_ps=2)
cls.strategy = cls.client.strategy
coordinator_lib._CLOSURE_QUEUE_MAX_SIZE = 2
cls.coordinator = make_coordinator(num_workers=3, num_ps=2)
cls.strategy = cls.coordinator.strategy
class ErrorReportingTest(TestCaseWithErrorReportingThread):
@ -654,8 +658,8 @@ class ErrorReportingTest(TestCaseWithErrorReportingThread):
@classmethod
def setUpClass(cls):
super(ErrorReportingTest, cls).setUpClass()
cls.client = make_client(num_workers=3, num_ps=2)
cls.strategy = cls.client.strategy
cls.coordinator = make_coordinator(num_workers=3, num_ps=2)
cls.strategy = cls.coordinator.strategy
with cls.strategy.scope():
cls.iteration = variables.Variable(initial_value=0.0)
@ -686,80 +690,80 @@ class ErrorReportingTest(TestCaseWithErrorReportingThread):
def testJoinRaiseError(self):
for _ in range(3):
self.client.schedule(self._normal_function)
self.client.schedule(self._error_function)
self.coordinator.schedule(self._normal_function)
self.coordinator.schedule(self._error_function)
with self.assertRaises(errors.InvalidArgumentError):
self.client.join()
self.coordinator.join()
def testScheduleRaiseError(self):
for _ in range(3):
self.client.schedule(self._normal_function)
self.client.schedule(self._error_function)
self.coordinator.schedule(self._normal_function)
self.coordinator.schedule(self._error_function)
with self.assertRaises(errors.InvalidArgumentError):
while True:
self.client.schedule(self._normal_function)
self.coordinator.schedule(self._normal_function)
def testScheduleRaiseErrorWithMultipleFailure(self):
for _ in range(3):
self.client.schedule(self._normal_function)
self.client.schedule(self._error_function)
self.coordinator.schedule(self._normal_function)
self.coordinator.schedule(self._error_function)
with self.assertRaises(errors.InvalidArgumentError):
while True:
self.client.schedule(self._error_function)
self.client.join()
self.coordinator.schedule(self._error_function)
self.coordinator.join()
def testErrorWillbeCleared(self):
self.client.schedule(self._error_function)
self.coordinator.schedule(self._error_function)
with self.assertRaises(errors.InvalidArgumentError):
self.client.join()
self.coordinator.join()
for _ in range(3):
self.client.schedule(self._normal_function)
self.client.schedule(self._error_function)
self.coordinator.schedule(self._normal_function)
self.coordinator.schedule(self._error_function)
with self.assertRaises(errors.InvalidArgumentError):
self.client.join()
self.coordinator.join()
def testRemoteValueReturnError(self):
result = self.client.schedule(self._error_function)
result = self.coordinator.schedule(self._error_function)
with self.assertRaises(errors.InvalidArgumentError):
result.fetch()
# Clear the error.
with self.assertRaises(errors.InvalidArgumentError):
self.client.join()
self.coordinator.join()
def testInputError(self):
worker_local_val = self.client._create_per_worker_resources(
worker_local_val = self.coordinator._create_per_worker_resources(
self._error_function)
@def_function.function
def func(x):
return x + 1
result = self.client.schedule(func, args=(worker_local_val,))
with self.assertRaises(client_lib.InputError):
self.client.join()
result = self.coordinator.schedule(func, args=(worker_local_val,))
with self.assertRaises(coordinator_lib.InputError):
self.coordinator.join()
with self.assertRaises(client_lib.InputError):
with self.assertRaises(coordinator_lib.InputError):
result.fetch()
def testCancellation(self):
for _ in range(3):
self.client.schedule(self._normal_function)
long_function = self.client.schedule(self._long_function)
self.client.schedule(self._error_function)
self.coordinator.schedule(self._normal_function)
long_function = self.coordinator.schedule(self._long_function)
self.coordinator.schedule(self._error_function)
with self.assertRaises(errors.InvalidArgumentError):
self.client.join()
self.coordinator.join()
with self.assertRaises(errors.CancelledError):
long_function.fetch()
for _ in range(3):
self.client.schedule(self._normal_function)
self.client.join()
self.coordinator.schedule(self._normal_function)
self.coordinator.join()
class LimitedClosureQueueErrorTest(ErrorReportingTest):
@ -772,11 +776,11 @@ class LimitedClosureQueueErrorTest(ErrorReportingTest):
@classmethod
def setUpClass(cls):
super(LimitedClosureQueueErrorTest, cls).setUpClass()
client_lib._CLOSURE_QUEUE_MAX_SIZE = 2
cls.client = make_client(num_workers=3, num_ps=2)
cls.strategy = cls.client.strategy
coordinator_lib._CLOSURE_QUEUE_MAX_SIZE = 2
cls.coordinator = make_coordinator(num_workers=3, num_ps=2)
cls.strategy = cls.coordinator.strategy
with cls.client.strategy.scope():
with cls.coordinator.strategy.scope():
cls.iteration = variables.Variable(initial_value=0.0)
@ -785,8 +789,8 @@ class StrategyIntegrationTest(test.TestCase):
@classmethod
def setUpClass(cls):
super(StrategyIntegrationTest, cls).setUpClass()
cls.client = make_client(num_workers=1, num_ps=1)
cls.strategy = cls.client.strategy
cls.coordinator = make_coordinator(num_workers=1, num_ps=1)
cls.strategy = cls.coordinator.strategy
def testBasicVariableAssignment(self):
self.strategy.extended._variable_count = 0
@ -801,9 +805,9 @@ class StrategyIntegrationTest(test.TestCase):
v2.assign_sub(0.2)
return v1.read_value() / v2.read_value()
results = self.client.schedule(worker_fn)
results = self.coordinator.schedule(worker_fn)
logging.info('Results of experimental_run_v2: %f',
self.client.fetch(results))
self.coordinator.fetch(results))
self.assertAlmostEqual(v1.read_value().numpy(), 0.1, delta=1e-6)
self.assertAlmostEqual(v2.read_value().numpy(), 0.8, delta=1e-6)
@ -826,12 +830,14 @@ class StrategyIntegrationTest(test.TestCase):
return self.strategy.run(replica_fn, args=(input_tensor,))
# Asserting scheduling in scope has the expected behavior.
result = self.client.schedule(worker_fn, args=(constant_op.constant(3),))
self.assertIsInstance(result, client_lib.RemoteValue)
result = self.coordinator.schedule(
worker_fn, args=(constant_op.constant(3),))
self.assertIsInstance(result, coordinator_lib.RemoteValue)
self.assertEqual(result.fetch(), 4)
# Asserting scheduling out of scope has the expected behavior.
result = self.client.schedule(worker_fn, args=(constant_op.constant(3),))
result = self.coordinator.schedule(
worker_fn, args=(constant_op.constant(3),))
self.assertEqual(result.fetch(), 4)

View File

@ -31,15 +31,15 @@ enable_metrics = False
_time_buckets = monitoring.ExponentialBuckets(0.001, 10, 6)
_function_tracing_sampler = monitoring.Sampler(
'/tensorflow/api/ps_strategy/client/function_tracing', _time_buckets,
'/tensorflow/api/ps_strategy/coordinator/function_tracing', _time_buckets,
'Sampler to track the time (in seconds) for tracing functions.')
_closure_execution_sampler = monitoring.Sampler(
'/tensorflow/api/ps_strategy/client/closure_execution', _time_buckets,
'/tensorflow/api/ps_strategy/coordinator/closure_execution', _time_buckets,
'Sampler to track the time (in seconds) for executing closures.')
_remote_value_fetch_sampler = monitoring.Sampler(
'/tensorflow/api/ps_strategy/client/remote_value_fetch', _time_buckets,
'/tensorflow/api/ps_strategy/coordinator/remote_value_fetch', _time_buckets,
'Sampler to track the time (in seconds) for fetching remote_value.')
_METRICS_MAPPING = {

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for metrics collecting in client."""
"""Tests for metrics collecting in coordinator."""
from __future__ import absolute_import
from __future__ import division
@ -22,9 +22,9 @@ from __future__ import print_function
import time
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import parameter_server_strategy_v2
from tensorflow.python.distribute.client import client
from tensorflow.python.distribute.client import metric_utils
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.distribute.coordinator import cluster_coordinator as coordinator_lib
from tensorflow.python.distribute.coordinator import metric_utils
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.training.server_lib import ClusterSpec
@ -35,7 +35,7 @@ class MetricUtilsTest(test.TestCase):
def get_rpc_layer(self):
return 'grpc'
def testClientMetrics(self):
def testClusterCoordinatorMetrics(self):
metric_utils.enable_metrics = True
@ -48,7 +48,7 @@ class MetricUtilsTest(test.TestCase):
ClusterSpec(cluster_def), rpc_layer=self.get_rpc_layer())
strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
cluster_resolver)
cluster = client.Cluster(strategy)
cluster = coordinator_lib.Cluster(strategy)
@def_function.function
def func():

View File

@ -48,7 +48,7 @@ _LOCAL_CPU = "/device:CPU:0"
# TODO(yuefengz): maybe cache variables on local CPU.
@tf_export("distribute.experimental.ParameterServerStrategy", v1=[])
# TODO(b/171250971): Remove this and change all symbol usage of this to V1.
class ParameterServerStrategy(distribute_lib.Strategy):
"""An asynchronous multi-worker parameter server tf.distribute strategy.
@ -154,8 +154,9 @@ class ParameterServerStrategy(distribute_lib.Strategy):
def _raise_pss_error_if_eager(self):
if context.executing_eagerly():
raise NotImplementedError("ParameterServerStrategy currently only works "
"with the tf.Estimator API")
raise NotImplementedError(
"`tf.compat.v1.distribute.experimental.ParameterServerStrategy` "
"currently only works with the tf.Estimator API")
@tf_export(v1=["distribute.experimental.ParameterServerStrategy"]) # pylint: disable=missing-docstring

View File

@ -37,82 +37,400 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export
# pylint: disable=protected-access
@tf_export("distribute.experimental.ParameterServerStrategy", v1=[])
class ParameterServerStrategyV2(distribute_lib.Strategy):
"""An asynchronous multi-worker parameter server tf.distribute strategy.
"""An multi-worker tf.distribute strategy with parameter servers.
Currently, `ParameterServerStrategyV2` is not supported to be used as a
standalone tf.distribute strategy. It should be used in conjunction with
`Client`. Please see `Client` for more information.
Parameter server training refers to the distributed training architecture that
requires two types of tasks in the cluster: workers (referred to as "worker"
task) and parameter servers (referred to as "ps" task). The variables and
updates to those variables are placed on ps, and most computation-intensive
operations are placed on workers.
This is currently under development, and the API as well as implementation
is subject to changes.
In TF2, parameter server training makes use of one coordinator, with some
number of workers, and (usually fewer) ps. The coordinator uses a
`tf.distribute.experimental.coordinator.ClusterCoordinator` to coordinate the
cluster, and a `tf.distribute.experimental.ParameterServerStrategy` for
variable distribution. The coordinator does not perform the actual training.
Each of the workers and ps runs a `tf.distribute.Server`, which the
coordinator connects to through the use of the aforementioned two APIs.
For the training to work, the coordinator sends requests to workers for the
`tf.function`s to be executed on remote workers. Upon receiving requests from
the coordinator, a worker executes the `tf.function` by reading the variables
from parameter servers, executing the ops, and updating the variables on the
parameter servers. Each worker only processes requests from the
coordinator, and communicates with parameter servers, without direct
interactions with any of the other workers in the cluster.
As a result, failures of some workers do not prevent the cluster from
continuing the work, and this allows the cluster to train with instances that
can be occasionally unavailable (e.g. preemptible or spot instances). The
coordinator and parameter servers, though, must be available at all times for
the cluster to make progress.
Note that the coordinator is not one of the training workers. Instead, its
responsibilities include placing variables on ps, remotely executing
`tf.function`s on workers, and saving checkpoints. Parameter server training
thus consists of a server cluster with workers and ps, and a coordinator that
connects to them to coordinate. Optionally, an evaluator can be run on the
side that periodically reads the checkpoints saved by the coordinator and,
for example, saves summaries.
`tf.distribute.experimental.ParameterServerStrategy` works closely with the
associated `tf.distribute.experimental.coordinator.ClusterCoordinator` object,
and should be used in conjunction with it. Standalone usage of
`tf.distribute.experimental.ParameterServerStrategy` without a
`tf.distribute.experimental.coordinator.ClusterCoordinator` indicates
a parameter server training scheme without a centralized coordinator, which is
not supported at this time.
__Example code for coordinator__
Here's an example usage of the API, with a custom training loop to train a
model. This code snippet is intended to be run on the single machine that
is designated as the coordinator. Note that `cluster_resolver`,
`variable_partitioner`, and `dataset_fn` arguments are explained in the
following "Cluster setup", "Variable partitioning", and "Dataset preparation"
sections.
Currently, the environment variable `GRPC_FAIL_FAST` needs to be set in all
tasks to work around a known hanging issue, as the following code illustrates:
```python
# Set the environment variable to allow reporting worker and ps failure to the
# coordinator.
os.environ["GRPC_FAIL_FAST"] = "use_caller"
# Prepare a strategy to use with the cluster and variable partitioning info.
strategy = tf.distribute.experimental.ParameterServerStrategy(
cluster_resolver=...,
variable_partitioner=...)
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
strategy=strategy)
# Prepare a distribute dataset that will place datasets on the workers.
distributed_dataset = coordinator.create_per_worker_dataset(dataset_fn=...)
with strategy.scope():
model = ...  # Variables created here may be containers of (sharded) variables
optimizer, metrics = ... # Keras optimizer/metrics are great choices
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint, checkpoint_dir, max_to_keep=2)
# `load_checkpoint` infers initial epoch from `optimizer.iterations`.
initial_epoch = load_checkpoint(checkpoint_manager) or 0
@tf.function
def worker_fn(iterator):
def replica_fn(inputs):
batch_data, labels = inputs
# Calculate gradients, apply gradients, update metrics, etc.
strategy.run(replica_fn, args=(next(iterator),))
for epoch in range(initial_epoch, num_epoch):
distributed_iterator = iter(distributed_dataset) # Reset iterator state.
for step in range(steps_per_epoch):
# Asynchronously schedule the `worker_fn` to be executed on an arbitrary
# worker. This call returns immediately.
coordinator.schedule(worker_fn, args=(distributed_iterator,))
# `join` blocks until all scheduled `worker_fn`s finish execution. Once it
# returns, we can read the metrics and save checkpoints as needed.
coordinator.join()
logging.info('Metric result: %r', metrics.result())
metrics.reset_states()
checkpoint_manager.save()
```
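The `load_checkpoint` helper in the snippet above is user-defined rather than
part of the API. A minimal sketch, assuming `optimizer.iterations` advances
once per training step and `steps_per_epoch` is known in the enclosing scope,
could look like:

```python
def load_checkpoint(checkpoint_manager):
  # Hypothetical helper for this sketch only: restore the latest checkpoint,
  # if any, and derive the epoch to resume from based on the optimizer's
  # step counter.
  latest = checkpoint_manager.latest_checkpoint
  if latest is None:
    return None
  checkpoint_manager.checkpoint.restore(latest)
  return optimizer.iterations.numpy() // steps_per_epoch
```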
__Example code for worker and parameter servers__
In addition to the coordinator, there should be multiple machines designated
as "worker" or "ps". They should run the following code to start a TensorFlow
server, waiting for the coordinator's requests to execute functions or place
variables:
```python
# Set the environment variable to allow reporting worker and ps failure to the
# coordinator.
os.environ["GRPC_FAIL_FAST"] = "use_caller"
# Provide a `tf.distribute.cluster_resolver.ClusterResolver` that serves
# the cluster information. See below "Cluster setup" section.
cluster_resolver = ...
server = tf.distribute.Server(
cluster_resolver.cluster_spec().as_cluster_def(),
job_name=cluster_resolver.task_type,
task_index=cluster_resolver.task_id,
protocol=cluster_resolver.rpc_layer)
# Block the process that starts the server from exiting.
server.join()
```
__Cluster setup__
In order for the tasks in the cluster to know other tasks' addresses,
a `tf.distribute.cluster_resolver.ClusterResolver` is required to be used
in the coordinator, workers, and ps. The
`tf.distribute.cluster_resolver.ClusterResolver` is responsible for providing
the cluster information, as well as the task type and id of the current task.
See `tf.distribute.cluster_resolver.ClusterResolver` for more information.
If the `TF_CONFIG` environment variable is used for the processes to learn the
cluster information, a
`tf.distribute.cluster_resolver.TFConfigClusterResolver` should be used. Note
that for legacy reasons, "chief" should be used as the task type for the
coordinator, as the following example demonstrates. Here we set `TF_CONFIG`
as an environment variable; this snippet is intended to be run on the machine
designated as the parameter server (task type "ps") with index 1 (the second),
in a cluster with 1 chief, 2 parameter servers, and 3 workers. Note that it
needs to be set before
`tf.distribute.cluster_resolver.TFConfigClusterResolver` is used.
Example code for cluster setup:
```python
os.environ['TF_CONFIG'] = '''
{
"cluster": {
"chief": ["chief.example.com:2222"],
"ps": ["ps0.example.com:2222", "ps1.example.com:2222"],
"worker": ["worker0.example.com:2222", "worker1.example.com:2222",
"worker2.example.com:2222"]
},
"task": {
"type": "ps",
"index": 1
}
}
'''
os.environ["GRPC_FAIL_FAST"] = "use_caller"
cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
# If coordinator ("chief" task type), create a strategy
if cluster_resolver.task_type == 'chief':
strategy = tf.distribute.experimental.ParameterServerStrategy(
cluster_resolver)
...
# If worker/ps, create a server
elif cluster_resolver.task_type in ("worker", "ps"):
server = tf.distribute.Server(...)
...
```
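As a minimal sketch (under the `TF_CONFIG` shown above), the worker/ps branch
can start a `tf.distribute.Server` from the resolved cluster information and
then block so the process keeps serving the coordinator's requests:

```python
cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
if cluster_resolver.task_type in ("worker", "ps"):
  # Start a server for this task; fall back to grpc if no rpc_layer is set.
  server = tf.distribute.Server(
      cluster_resolver.cluster_spec().as_cluster_def(),
      job_name=cluster_resolver.task_type,
      task_index=cluster_resolver.task_id,
      protocol=cluster_resolver.rpc_layer or "grpc")
  server.join()  # Block until the process is terminated.
```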
__Variable creation with `strategy.scope()`__
`tf.distribute.experimental.ParameterServerStrategy` follows the
`tf.distribute` API contract where variable creation is expected to be inside
the context manager returned by `strategy.scope()`, in order to be correctly
placed on parameter servers in a round-robin manner:
```python
# In this example, we assume the cluster has 3 ps.
strategy = tf.distribute.experimental.ParameterServerStrategy(
cluster_resolver=...)
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
strategy=strategy)
# Variables should be created inside scope to be placed on parameter servers.
# If created outside the scope, such as `v1` here, it would be placed on
# the coordinator.
v1 = tf.Variable(initial_value=0.0)
with strategy.scope():
v2 = tf.Variable(initial_value=1.0)
v3 = tf.Variable(initial_value=2.0)
v4 = tf.Variable(initial_value=3.0)
v5 = tf.Variable(initial_value=4.0)
# v2 through v5 are created in scope and are distributed on parameter servers.
# Default placement is round-robin but the order should not be relied on.
assert v2.device == "/job:ps/replica:0/task:0/device:CPU:0"
assert v3.device == "/job:ps/replica:0/task:1/device:CPU:0"
assert v4.device == "/job:ps/replica:0/task:2/device:CPU:0"
assert v5.device == "/job:ps/replica:0/task:0/device:CPU:0"
```
See `tf.distribute.Strategy.scope` for more information.
__Variable partitioning__
Having dedicated servers to store variables means being able to divide up, or
"shard", the variables across the ps. Large embeddings that would otherwise
exceed the memory limit of a single machine can be used in a cluster with a
sufficient number of ps.
With `tf.distribute.experimental.ParameterServerStrategy`, if a
`variable_partitioner` is provided to `__init__` and certain conditions are
satisfied, the resulting variables created in scope are sharded across the
parameter servers, in a round-robin fashion. The variable reference returned
from `tf.Variable` becomes a type that serves as the container of the sharded
variables. Access the `variables` attribute of this container for the actual
variable components. See the arguments section of
`tf.distribute.experimental.ParameterServerStrategy.__init__` for more
information.
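For illustration, a callable matching the documented
`num_partitions = fn(shape, dtype)` signature could be sketched as follows
(the function name and the size threshold are this sketch's own choices, not
part of the API):

```python
def my_variable_partitioner(shape, dtype):
  # Hypothetical partitioner for this sketch: shard variables whose first
  # dimension is large into 2 partitions along that axis, and leave all other
  # variables unpartitioned. Elements other than the first must stay 1.
  del dtype  # Not used in this sketch.
  if shape.rank and shape[0] is not None and shape[0] >= 1024:
    return [2] + [1] * (shape.rank - 1)
  return [1] * (shape.rank or 0)
```

Such a callable could then be passed as the `variable_partitioner` argument of
`tf.distribute.experimental.ParameterServerStrategy.__init__`.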
To initialize the sharded variables in a more memory-efficient way, use an
initializer whose `__call__` accepts a `shard_info` argument, and use
`shard_info.offset` and `shard_info.shape` to create and return a
partition-aware `tf.Tensor` to initialize the variable components.
```python
class PartitionAwareIdentity(object):
def __call__(self, shape, dtype, shard_info):
value = tf.eye(*shape, dtype=dtype)
if shard_info is not None:
value = tf.slice(value, shard_info.offset, shard_info.shape)
return value
cluster_resolver = ...
strategy = tf.distribute.experimental.ParameterServerStrategy(
cluster_resolver, tf.compat.v1.fixed_size_partitioner(2))
with strategy.scope():
initializer = PartitionAwareIdentity()
initial_value = functools.partial(initializer, shape=(4, 4), dtype=tf.int64)
v = tf.Variable(
initial_value=initial_value, shape=(4, 4), dtype=tf.int64)
# `v.variables` gives the actual variable components.
assert len(v.variables) == 2
assert v.variables[0].device == "/job:ps/replica:0/task:0/device:CPU:0"
assert v.variables[1].device == "/job:ps/replica:0/task:1/device:CPU:0"
assert np.array_equal(v.variables[0].numpy(), [[1, 0, 0, 0], [0, 1, 0, 0]])
assert np.array_equal(v.variables[1].numpy(), [[0, 0, 1, 0], [0, 0, 0, 1]])
```
__Dataset preparation__
With `tf.distribute.experimental.ParameterServerStrategy`, a dataset is
created in each of the workers to be used for training. This is done by
creating a `dataset_fn` that takes no argument and returns a
`tf.data.Dataset`, and passing the `dataset_fn` into
`tf.distribute.experimental.coordinator.
ClusterCoordinator.create_per_worker_dataset`. We recommend that the dataset be
shuffled and repeated so that the examples run through the training as evenly
as possible.
```python
def dataset_fn():
filenames = ...
dataset = tf.data.Dataset.from_tensor_slices(filenames)
# Dataset is recommended to be shuffled, and repeated.
return dataset.shuffle(buffer_size=...).repeat().batch(batch_size=...)
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
    strategy=...)
distributed_dataset = coordinator.create_per_worker_dataset(dataset_fn)
```
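To connect this with scheduling, here is a brief sketch of how the resulting
`distributed_dataset` is typically consumed on the coordinator (assuming a
`worker_fn` and a step count `num_steps`, as in the earlier coordinator
example):

```python
# Iterating the per-worker dataset on the coordinator yields a per-worker
# iterator; when `worker_fn` runs on a worker, the argument resolves to that
# worker's local iterator.
per_worker_iterator = iter(distributed_dataset)
for _ in range(num_steps):
  coordinator.schedule(worker_fn, args=(per_worker_iterator,))
coordinator.join()
```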
__Limitations__
* `tf.distribute.experimental.ParameterServerStrategy` in TF2 is experimental,
and the API is subject to further changes.
* `tf.distribute.experimental.ParameterServerStrategy` does not yet support
training with GPU(s). This is a feature request being developed.
* `tf.distribute.experimental.ParameterServerStrategy` currently only supports
the [custom training loop
API](https://www.tensorflow.org/tutorials/distribute/custom_training)
in TF2. Support for the Keras `compile`/`fit` API is being
developed.
* `tf.distribute.experimental.ParameterServerStrategy` must be used with
`tf.distribute.experimental.coordinator.ClusterCoordinator`.
* This strategy is not intended for TPU. Use
`tf.distribute.experimental.TPUStrategy` instead.
"""
# pyformat: disable
def __init__(self, cluster_resolver, variable_partitioner=None):
"""Initializes the V2 parameter server strategy.
"""Initializes the TF2 parameter server strategy.
This also connects to the remote server cluster.
This initializes the `tf.distribute.experimental.ParameterServerStrategy`
object to be ready for use with
`tf.distribute.experimental.coordinator.ClusterCoordinator`.
Args:
cluster_resolver: a `tf.distribute.cluster_resolver.ClusterResolver`
object.
variable_partitioner: a callable with the signature `num_partitions =
fn(shape, dtype)`, where `num_partitions` is a list/tuple representing
the number of partitions on each axis, and `shape` and `dtype` are of
types `tf.TensorShape` and `tf.dtypes.Dtype`. If None, variables will
not be partitioned. * `variable_partitioner` will be called for all
variables created under strategy `scope` to instruct how the variables
should be partitioned. Variables will be partitioned if there are more
than one partitions along the partitioning axis, otherwise it falls back
to normal `tf.Variable`. * Only the first / outermost axis partitioning
is supported, namely, elements in `num_partitions` must be 1 other than
the first element. * Partitioner like `min_max_variable_partitioner`,
`variable_axis_size_partitioner` and `fixed_size_partitioner` are also
supported since they conform to the required signature. * Div partition
variable_partitioner:
a callable with the signature `num_partitions = fn(shape, dtype)`, where
`num_partitions` is a list/tuple representing the number of partitions
on each axis, and `shape` and `dtype` are of types `tf.TensorShape` and
`tf.dtypes.Dtype`. If `None`, variables will not be partitioned.
* `variable_partitioner` will be called for all variables created under
strategy `scope` to instruct how the variables should be partitioned.
Variables will be created in multiple partitions if there are more than
one partition along the partitioning axis; otherwise it falls back to
normal `tf.Variable`.
* Only the first / outermost axis partitioning is supported; that is,
all elements in `num_partitions` other than the first must be 1.
* Partitioner like `tf.compat.v1.min_max_variable_partitioner`,
`tf.compat.v1.variable_axis_size_partitioner` and
`tf.compat.v1.fixed_size_partitioner` are also supported since they
conform to the required signature.
* Div partition
strategy is used to partition variables. Assuming we assign consecutive
integer ids along the first axis of a variable, then ids are assigned to
shards in a contiguous manner, while attempting to keep each shard size
identical. If the ids do not evenly divide the number of shards, each of
the first several shards will be assigned one more id. For instance, a
variable whose first dimension is
13 has 13 ids, and they are split across 5 shards as: `[[0, 1, 2], [3,
4, 5], [6, 7, 8], [9, 10], [11, 12]]`. * Variables created under
`strategy.extended.colocate_vars_with` will not be partitioned, e.g,
optimizer's slot variables.
variable whose first dimension is 13 has 13 ids, and they are split
across 5 shards as:
`[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]` (see the sketch
after this list).
* Variables created under `strategy.extended.colocate_vars_with` will
not be partitioned, e.g., optimizer's slot variables.
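The following sketch (not part of the API) illustrates the div partition
rule described above for a variable with 13 ids split across 5 shards:

```python
def div_partition(num_ids, num_shards):
  # Contiguous split: the first `num_ids % num_shards` shards get one
  # extra id each, keeping shard sizes as equal as possible.
  base, extra = divmod(num_ids, num_shards)
  shards, start = [], 0
  for shard_index in range(num_shards):
    size = base + (1 if shard_index < extra else 0)
    shards.append(list(range(start, start + size)))
    start += size
  return shards

assert div_partition(13, 5) == [
    [0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]
```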
"""
# pyformat: enable
self._cluster_resolver = cluster_resolver
self._extended = ParameterServerStrategyV2Extended(self, cluster_resolver,
variable_partitioner)
self._verify_args_and_config(cluster_resolver)
logging.info(
"ParameterServerStrategyV2 is initialized with cluster_spec: "
"%s", cluster_resolver.cluster_spec())
"`tf.distribute.experimental.ParameterServerStrategy` is initialized "
"with cluster_spec: %s", cluster_resolver.cluster_spec())
# TODO(b/167894802): Make chief, worker, and ps names customizable.
self._connect_to_cluster(client_name="chief")
# TODO(b/167894802): Make coordinator, worker, and ps names customizable.
self._connect_to_cluster(coordinator_name="chief")
super(ParameterServerStrategyV2, self).__init__(self._extended)
distribute_lib.distribution_strategy_gauge.get_cell("V2").set(
"ParameterServerStrategy")
def _connect_to_cluster(self, client_name):
if client_name in ["worker", "ps"]:
raise ValueError("Client name should not be 'worker' or 'ps'.")
def _connect_to_cluster(self, coordinator_name):
if coordinator_name in ["worker", "ps"]:
raise ValueError("coordinator name should not be 'worker' or 'ps'.")
cluster_spec = self._cluster_resolver.cluster_spec()
self._num_workers = len(cluster_spec.as_dict().get("worker", ()))
self._num_ps = len(cluster_spec.as_dict().get("ps", ()))
device_filters = server_lib.ClusterDeviceFilters()
# For any worker, only the devices on PS and chief nodes are visible
# For any worker, only the devices on ps and coordinator nodes are visible
for i in range(self._num_workers):
device_filters.set_device_filters(
"worker", i, ["/job:ps", "/job:%s" % client_name])
# Similarly for any ps, only the devices on workers and chief are visible
"worker", i, ["/job:ps", "/job:%s" % coordinator_name])
# Similarly for any ps, only the devices on workers and coordinator are
# visible
for i in range(self._num_ps):
device_filters.set_device_filters(
"ps", i, ["/job:worker", "/job:%s" % client_name])
"ps", i, ["/job:worker", "/job:%s" % coordinator_name])
# Allow at most one outstanding RPC for each worker at a certain time. This
# is to simplify worker failure handling in the runtime
@ -122,7 +440,7 @@ class ParameterServerStrategyV2(distribute_lib.Strategy):
self.__class__.__name__, cluster_spec)
remote.connect_to_cluster(
cluster_spec,
job_name=client_name,
job_name=coordinator_name,
protocol=self._cluster_resolver.rpc_layer,
cluster_device_filters=device_filters)
@ -134,7 +452,7 @@ class ParameterServerStrategyV2(distribute_lib.Strategy):
def _verify_args_and_config(self, cluster_resolver):
if not cluster_resolver.cluster_spec():
raise ValueError("Cluster spec must be non-empty in `cluster_resolver`.")
if self.extended._num_gpus_per_worker > 1:
if self.extended._num_gpus_per_worker > 1: # pylint: disable=protected-access
raise NotImplementedError("Multi-gpu is not supported yet.")
@ -205,8 +523,8 @@ class ParameterServerStrategyV2Extended(
init_from_fn = False
initial_value = initial_value()
if not init_from_fn:
# The initial_value is created on client, it will need to be sent to
# PS for variable initialization, which can be inefficient and can
# The initial_value is created on the coordinator; it will need to be sent to
# ps for variable initialization, which can be inefficient and can
# potentially hit the 2GB limit on protobuf serialization.
initial_value = ops.convert_to_tensor(initial_value, dtype=dtype)
dtype = initial_value.dtype

View File

@ -870,8 +870,8 @@ py_test(
"//tensorflow/python/distribute:multi_worker_test_base",
"//tensorflow/python/distribute:parameter_server_strategy_v2",
"//tensorflow/python/distribute:sharded_variable",
"//tensorflow/python/distribute/client",
"//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
"//tensorflow/python/distribute/coordinator:cluster_coordinator",
"//tensorflow/python/eager:backprop",
"//tensorflow/python/eager:def_function",
"//tensorflow/python/eager:test",

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ParameterServerClient and Keras models."""
"""Tests for ClusterCoordinator and Keras models."""
from __future__ import absolute_import
from __future__ import division
@ -27,8 +27,8 @@ from tensorflow.python.compat import v2_compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import parameter_server_strategy_v2
from tensorflow.python.distribute.client import client as client_lib
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.distribute.coordinator import cluster_coordinator as coordinator_lib
from tensorflow.python.eager import backprop
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
@ -44,7 +44,7 @@ from tensorflow.python.platform import test
from tensorflow.python.training.server_lib import ClusterSpec
def make_client(num_workers, num_ps):
def make_coordinator(num_workers, num_ps):
cluster_def = multi_worker_test_base.create_in_process_cluster(
num_workers=num_workers, num_ps=num_ps, rpc_layer="grpc")
cluster_def["chief"] = [
@ -52,7 +52,7 @@ def make_client(num_workers, num_ps):
]
cluster_resolver = SimpleClusterResolver(
ClusterSpec(cluster_def), rpc_layer="grpc")
return client_lib.Client(
return coordinator_lib.ClusterCoordinator(
parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver))
@ -61,7 +61,7 @@ class KPLTest(test.TestCase):
@classmethod
def setUpClass(cls):
super(KPLTest, cls).setUpClass()
cls.client = make_client(num_workers=3, num_ps=2)
cls.coordinator = make_coordinator(num_workers=3, num_ps=2)
def testTrainAndServe(self):
# These vocabularies usually come from TFT or a Beam pipeline.
@ -71,10 +71,10 @@ class KPLTest(test.TestCase):
]
label_vocab = ["yes", "no"]
with self.client.strategy.scope():
with self.coordinator.strategy.scope():
# Define KPLs under strategy's scope. Right now, if they have look up
# tables, they will be created on the client. Their variables will be
# tables, they will be created on the coordinator. Their variables will be
# created on PS. Ideally they should be cached on each worker since they
# will not be changed in a training step.
feature_lookup_layer = string_lookup.StringLookup()
@ -113,7 +113,7 @@ class KPLTest(test.TestCase):
label = "yes" if "avenger" in features else "no"
yield {"features": features, "label": label}
# The dataset will be created on the client?
# The dataset will be created on the coordinator?
raw_dataset = dataset_ops.Dataset.from_generator(
feature_and_label_gen,
output_types={
@ -131,7 +131,8 @@ class KPLTest(test.TestCase):
}, [x["label"]]))
return train_dataset
distributed_dataset = self.client.create_per_worker_dataset(dataset_fn)
distributed_dataset = self.coordinator.create_per_worker_dataset(
dataset_fn)
model_input = keras.layers.Input(
shape=(3,), dtype=dtypes.int64, name="model_input")
@ -163,12 +164,12 @@ class KPLTest(test.TestCase):
actual_pred = math_ops.cast(math_ops.greater(pred, 0.5), dtypes.int64)
accuracy.update_state(labels, actual_pred)
self.client._strategy.run(train_step, args=(iterator,))
self.coordinator._strategy.run(train_step, args=(iterator,))
distributed_iterator = iter(distributed_dataset)
for _ in range(10):
self.client.schedule(worker_fn, args=(distributed_iterator,))
self.client.join()
self.coordinator.schedule(worker_fn, args=(distributed_iterator,))
self.coordinator.join()
self.assertGreater(accuracy.result().numpy(), 0.0)
# Create a saved model.

View File

@ -362,7 +362,8 @@ class Model(training_lib.Model):
if isinstance(self._distribution_strategy,
(parameter_server_strategy.ParameterServerStrategyV1,
parameter_server_strategy.ParameterServerStrategy)):
raise NotImplementedError('ParameterServerStrategy currently only works '
raise NotImplementedError('`tf.compat.v1.distribute.experimental.Paramet'
'erServerStrategy` currently only works '
'with the tf.Estimator API')
if not self._experimental_run_tf_function:

View File

@ -64,6 +64,9 @@ from tensorflow.python.framework.test_combinations import *
from tensorflow.python.util.tf_decorator import make_decorator
from tensorflow.python.util.tf_decorator import unwrap
from tensorflow.python.distribute.parameter_server_strategy_v2 import *
from tensorflow.python.distribute.coordinator.cluster_coordinator import *
tf_export('__internal__.decorator.make_decorator', v1=[])(make_decorator)
tf_export('__internal__.decorator.unwrap', v1=[])(unwrap)

View File

@ -31,6 +31,7 @@ TENSORFLOW_API_INIT_FILES = [
"distribute/__init__.py",
"distribute/cluster_resolver/__init__.py",
"distribute/experimental/__init__.py",
"distribute/experimental/coordinator/__init__.py",
"dtypes/__init__.py",
"errors/__init__.py",
"experimental/__init__.py",

View File

@ -27,6 +27,8 @@ from tensorflow.lite.python import lite as _tflite_for_api_traversal
from tensorflow.python import modules_with_exports
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import parameter_server_strategy_v2
from tensorflow.python.distribute.coordinator import cluster_coordinator
from tensorflow.python.framework import combinations
from tensorflow.python.framework import test_combinations
# pylint: enable=unused-import

View File

@ -1,6 +1,6 @@
path: "tensorflow.distribute.experimental.ParameterServerStrategy"
tf_class {
is_instance: "<class \'tensorflow.python.distribute.parameter_server_strategy.ParameterServerStrategy\'>"
is_instance: "<class \'tensorflow.python.distribute.parameter_server_strategy_v2.ParameterServerStrategyV2\'>"
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.Strategy\'>"
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.StrategyBase\'>"
is_instance: "<type \'object\'>"
@ -18,7 +18,7 @@ tf_class {
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'cluster_resolver\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'self\', \'cluster_resolver\', \'variable_partitioner\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "colocate_vars_with"

View File

@ -0,0 +1,33 @@
path: "tensorflow.distribute.experimental.coordinator.ClusterCoordinator"
tf_class {
is_instance: "<class \'tensorflow.python.distribute.coordinator.cluster_coordinator.ClusterCoordinator\'>"
is_instance: "<type \'object\'>"
member {
name: "strategy"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'strategy\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "create_per_worker_dataset"
argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "done"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "fetch"
argspec: "args=[\'self\', \'val\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "join"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "schedule"
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
}

View File

@ -0,0 +1,9 @@
path: "tensorflow.distribute.experimental.coordinator.PerWorkerValues"
tf_class {
is_instance: "<class \'tensorflow.python.distribute.coordinator.cluster_coordinator.PerWorkerValues\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'values\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -0,0 +1,12 @@
path: "tensorflow.distribute.experimental.coordinator.RemoteValue"
tf_class {
is_instance: "<class \'tensorflow.python.distribute.coordinator.cluster_coordinator.RemoteValue\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
}
member_method {
name: "fetch"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -0,0 +1,15 @@
path: "tensorflow.distribute.experimental.coordinator"
tf_module {
member {
name: "ClusterCoordinator"
mtype: "<type \'type\'>"
}
member {
name: "PerWorkerValues"
mtype: "<type \'type\'>"
}
member {
name: "RemoteValue"
mtype: "<type \'type\'>"
}
}

View File

@ -36,4 +36,8 @@ tf_module {
name: "ValueContext"
mtype: "<type \'type\'>"
}
member {
name: "coordinator"
mtype: "<type \'module\'>"
}
}

View File

@ -154,9 +154,9 @@ COMMON_PIP_DEPS = [
"//tensorflow/tools/common:test_module1",
"//tensorflow/tools/common:traverse",
"//tensorflow/python/distribute:parameter_server_strategy_v2",
"//tensorflow/python/distribute/client:client",
"//tensorflow/python/distribute/client:remote_eager_lib",
"//tensorflow/python/distribute/client:metric_utils",
"//tensorflow/python/distribute/coordinator:cluster_coordinator",
"//tensorflow/python/distribute/coordinator:remote_eager_lib",
"//tensorflow/python/distribute/coordinator:metric_utils",
]
# On Windows, python binary is a zip file of runfiles tree.