PSv2: Export a few tf.distribute symbols related to TF2 parameter server training.
This change exports the following class symbols, and adds relevant documentation and example code to tf.distribute.experimental.ParameterServerStrategy tf.distribute.experimental.coordinator.ClusterCoordinator tf.distribute.experimental.coordinator.PerWorkerValues tf.distribute.experimental.coordinator.RemoteValue PiperOrigin-RevId: 338151262 Change-Id: If2d1c513d30a999c728cecc2e73b75adda1948c2
This commit is contained in:
parent
833b3a49a9
commit
32f35aabce
11
RELEASE.md
11
RELEASE.md
@ -208,7 +208,16 @@
|
||||
how many times the function is called, and independent of global seed
|
||||
settings.
|
||||
* `tf.distribute`:
|
||||
* <ADD RELEASE NOTES HERE>
|
||||
* (Experimental) Parameter server training:
|
||||
* Replaced the existing
|
||||
`tf.distribute.experimental.ParameterServerStrategy` symbol with
|
||||
a new class that is for parameter server training in TF2. Usage with
|
||||
the old symbol, usually with Estimator, should be replaced with
|
||||
`tf.compat.v1.distribute.experimental.ParameterServerStrategy`.
|
||||
* Added `tf.distribute.experimental.coordinator.*` namespace,
|
||||
including the main API `ClusterCoordinator` for coordinating the
|
||||
training cluster, the related data structure `RemoteValue`
|
||||
and `PerWorkerValue`.
|
||||
* `tf.keras`:
|
||||
* Improvements from the functional API refactoring:
|
||||
* Functional model construction does not need to maintain a global
|
||||
|
@ -153,8 +153,9 @@ py_library(
|
||||
":multi_process_runner",
|
||||
":multi_worker_test_base",
|
||||
":one_device_strategy",
|
||||
":parameter_server_strategy_v2",
|
||||
":sharded_variable",
|
||||
"//tensorflow/python/distribute/client",
|
||||
"//tensorflow/python/distribute/coordinator:cluster_coordinator",
|
||||
"//tensorflow/python/distribute/experimental",
|
||||
],
|
||||
)
|
||||
@ -1880,6 +1881,7 @@ tf_py_test(
|
||||
":multi_worker_test_base",
|
||||
":parameter_server_strategy_v2",
|
||||
":sharded_variable",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:extra_py_tests_deps",
|
||||
|
@ -8,8 +8,8 @@ package(
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
py_library(
|
||||
name = "client",
|
||||
srcs = ["client.py"],
|
||||
name = "cluster_coordinator",
|
||||
srcs = ["cluster_coordinator.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":metric_utils",
|
||||
@ -34,9 +34,9 @@ py_library(
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "client_test",
|
||||
name = "cluster_coordinator_test",
|
||||
size = "small",
|
||||
srcs = ["client_test.py"],
|
||||
srcs = ["cluster_coordinator_test.py"],
|
||||
python_version = "PY3",
|
||||
shard_count = 50,
|
||||
tags = [
|
||||
@ -44,7 +44,7 @@ tf_py_test(
|
||||
"notsan", # TODO(b/171040359): Flaky timeout, even if maximum shards
|
||||
],
|
||||
deps = [
|
||||
":client",
|
||||
":cluster_coordinator",
|
||||
"//tensorflow/python:check_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:constant_op",
|
||||
@ -66,12 +66,13 @@ tf_py_test(
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "client_mpr_test",
|
||||
srcs = ["client_mpr_test.py"],
|
||||
name = "cluster_coordinator_mpr_test",
|
||||
srcs = ["cluster_coordinator_mpr_test.py"],
|
||||
python_version = "PY3",
|
||||
shard_count = 2,
|
||||
tags = ["no_oss"], # TODO(b/162119374)
|
||||
deps = [
|
||||
":cluster_coordinator",
|
||||
":remote_eager_lib",
|
||||
":utils",
|
||||
"//tensorflow/python:dtypes",
|
||||
@ -81,7 +82,6 @@ tf_py_test(
|
||||
"//tensorflow/python/distribute:multi_worker_test_base",
|
||||
"//tensorflow/python/distribute:parameter_server_strategy_v2",
|
||||
"//tensorflow/python/distribute:sharded_variable",
|
||||
"//tensorflow/python/distribute/client",
|
||||
"//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
|
||||
"//tensorflow/python/eager:def_function",
|
||||
"//tensorflow/python/eager:test",
|
||||
@ -102,7 +102,7 @@ tf_py_test(
|
||||
srcs = ["metric_utils_test.py"],
|
||||
python_version = "PY3",
|
||||
deps = [
|
||||
":client",
|
||||
":cluster_coordinator",
|
||||
":metric_utils",
|
||||
"//tensorflow/python:training_server_lib",
|
||||
"//tensorflow/python/distribute:multi_worker_test_base",
|
@ -13,7 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Module for `Client` and relevant cluster-worker related library.
|
||||
"""Module for `ClusterCoordinator` and relevant cluster-worker related library.
|
||||
|
||||
This is currently under development and the API is subject to change.
|
||||
"""
|
||||
@ -35,7 +35,7 @@ from six.moves import queue
|
||||
from tensorflow.python.data.ops import iterator_ops
|
||||
from tensorflow.python.distribute import input_lib
|
||||
from tensorflow.python.distribute import parameter_server_strategy_v2
|
||||
from tensorflow.python.distribute.client import metric_utils
|
||||
from tensorflow.python.distribute.coordinator import metric_utils
|
||||
from tensorflow.python.eager import cancellation
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
@ -48,6 +48,7 @@ from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
# Maximum time for failed worker to come back is 1 hour
|
||||
_WORKER_MAXIMUM_RECOVERY_SEC = 3600
|
||||
@ -56,9 +57,10 @@ _WORKER_MAXIMUM_RECOVERY_SEC = 3600
|
||||
# When the maximum queue size is reached, further schedule calls will become
|
||||
# blocking until some previously queued closures are executed on workers.
|
||||
# Note that using an "infinite" queue size can take a non-trivial portion of
|
||||
# memory, and even lead to client OOM. Modify the size to a smaller value for
|
||||
# client with constrained memory resource (only recommended for advanced users).
|
||||
# Also used in unit tests to ensure the correctness when the queue is full.
|
||||
# memory, and even lead to coordinator OOM. Modify the size to a smaller value
|
||||
# for coordinator with constrained memory resource (only recommended for
|
||||
# advanced users). Also used in unit tests to ensure the correctness when the
|
||||
# queue is full.
|
||||
_CLOSURE_QUEUE_MAX_SIZE = 256 * 1024
|
||||
|
||||
# RPC error message from PS
|
||||
@ -99,18 +101,77 @@ class _RemoteValueStatus(enum.Enum):
|
||||
READY = "READY"
|
||||
|
||||
|
||||
@tf_export("distribute.experimental.coordinator.RemoteValue", v1=[])
|
||||
class RemoteValue(object):
|
||||
"""An asynchronously available value of a remotely executed function.
|
||||
|
||||
`RemoteValue` class is used as the return value of `Client.schedule()` where
|
||||
the underlying concrete value comes at a later time once the function has been
|
||||
remotely executed. `RemoteValue` can be used as an input to a subsequent
|
||||
function scheduled with `Client.schedule()`.
|
||||
`tf.distribute.experimental.coordinator.RemoteValue` class is used as the
|
||||
return value of
|
||||
`tf.distribute.experimental.coordinator.ClusterCoordinator.schedule()` where
|
||||
the underlying concrete value becomes available at a later time once the
|
||||
function has been remotely executed. The underlying concrete value is the
|
||||
`tf.Tensor.numpy()` result of the `tf.Tensor`, or the structure of
|
||||
`tf.Tensor`s, returned from the `tf.function` that was scheduled.
|
||||
|
||||
Note: this class is not thread-safe.
|
||||
`tf.distribute.experimental.coordinator.RemoteValue` to be used as an input to
|
||||
a subsequent function scheduled with
|
||||
`tf.distribute.experimental.coordinator.ClusterCoordinator.schedule()` is
|
||||
currently not supported.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
strategy = tf.distribute.experimental.ParameterServerStrategy(
|
||||
cluster_resolver=...)
|
||||
coordinator = (
|
||||
tf.distribute.experimental.coordinator.ClusterCoordinator(strategy))
|
||||
|
||||
with strategy.scope():
|
||||
v1 = tf.Variable(initial_value=0.0)
|
||||
v2 = tf.Variable(initial_value=1.0)
|
||||
|
||||
@tf.function
|
||||
def worker_fn():
|
||||
v1.assign_add(0.1)
|
||||
v2.assign_sub(0.2)
|
||||
return v1.read_value() / v2.read_value()
|
||||
|
||||
result = coordinator.schedule(worker_fn)
|
||||
# Note that `fetch()` gives the actual result instead of a `tf.Tensor`.
|
||||
assert result.fetch() == 0.125
|
||||
|
||||
for _ in range(10):
|
||||
# `worker_fn` will be run on arbitrary workers that are available. The
|
||||
# `result` value will be non-deterministic because the workers are executing
|
||||
# the functions asynchronously.
|
||||
result = coordinator.schedule(worker_fn)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, closure, type_spec):
|
||||
def fetch(self):
|
||||
"""Wait for the result of `RemoteValue` to be ready and return the result.
|
||||
|
||||
This makes the value concrete by copying the remote value to local.
|
||||
|
||||
Returns:
|
||||
The actual output of the `tf.function` associated with this `RemoteValue`,
|
||||
previously by a
|
||||
`tf.distribute.experimental.coordinator.ClusterCoordinator.schedule` call.
|
||||
This can be a single value, or a structure of values, depending on the
|
||||
output of the `tf.function`.
|
||||
|
||||
Raises:
|
||||
tf.errors.CancelledError: If the function that produces this `RemoteValue`
|
||||
is aborted or cancelled due to failure, and the user should handle and
|
||||
reschedule.
|
||||
"""
|
||||
raise NotImplementedError("Must be implemented in subclasses.")
|
||||
|
||||
|
||||
class RemoteValueImpl(RemoteValue):
|
||||
"""Implementation of `RemoteValue`."""
|
||||
|
||||
def __init__(self, closure, type_spec): # pylint: disable=super-init-not-called
|
||||
self._closure = closure
|
||||
# The type spec for this `RemoteValue` which is used to trace functions that
|
||||
# take this `RemoteValue` as input.
|
||||
@ -157,16 +218,6 @@ class RemoteValue(object):
|
||||
self._type_spec = func_graph.convert_structure_to_signature(type_spec)
|
||||
|
||||
def fetch(self):
|
||||
"""Wait for the result of RemoteValue to be ready and return the result.
|
||||
|
||||
Returns:
|
||||
The remote value, as a numpy data type (if scalar) or ndarray.
|
||||
|
||||
Raises:
|
||||
tf.errors.CancelledError: If the function that produces this `RemoteValue`
|
||||
is aborted or cancelled due to failure, and the user should handle and
|
||||
reschedule.
|
||||
"""
|
||||
self._status_available_event.wait()
|
||||
if self._status is _RemoteValueStatus.ABORTED:
|
||||
raise errors.CancelledError(
|
||||
@ -241,8 +292,23 @@ def _maybe_as_type_spec(val):
|
||||
return val
|
||||
|
||||
|
||||
@tf_export("distribute.experimental.coordinator.PerWorkerValues", v1=[])
|
||||
class PerWorkerValues(object):
|
||||
"""Holds a list of per worker values."""
|
||||
"""A container that holds a list of values, one value per worker.
|
||||
|
||||
`tf.distribute.experimental.coordinator.PerWorkerValues` contains a collection
|
||||
of values, where each of the values represents a resource that are located in
|
||||
individual workers, and upon being used as one of the `args` or `kwargs` of
|
||||
`tf.distribute.experimental.coordinator.ClusterCoordinator.schedule()`, the
|
||||
value specific to a worker will be passed into the function being executed at
|
||||
that particular worker.
|
||||
|
||||
Currently, the only supported path to create an object of
|
||||
`tf.distribute.experimental.coordinator.PerWorkerValues` is through calling
|
||||
`iter` on a `ClusterCoordinator.create_per_worker_dataset`-returned
|
||||
distributed dataset instance. The mechanism to create a custom
|
||||
`tf.distribute.experimental.coordinator.PerWorkerValues` is not yet supported.
|
||||
"""
|
||||
|
||||
def __init__(self, values):
|
||||
self._values = tuple(values)
|
||||
@ -262,9 +328,10 @@ def _disallow_remote_value_as_input(structured):
|
||||
|
||||
def _raise_if_remote_value(x):
|
||||
if isinstance(x, RemoteValue):
|
||||
raise ValueError("RemoteValue cannot be used as an input to scheduled "
|
||||
"function. Please file a feature request if you need "
|
||||
"this feature.")
|
||||
raise ValueError(
|
||||
"`tf.distribute.experimental.coordinator.RemoteValue` used "
|
||||
"as an input to scheduled function is not yet "
|
||||
"supported.")
|
||||
|
||||
nest.map_structure(_raise_if_remote_value, structured)
|
||||
|
||||
@ -274,8 +341,8 @@ class Closure(object):
|
||||
|
||||
def __init__(self, function, cancellation_mgr, args=None, kwargs=None):
|
||||
if not callable(function):
|
||||
raise ValueError("Function passed to `Client.schedule` must be a "
|
||||
"callable object.")
|
||||
raise ValueError("Function passed to `ClusterCoordinator.schedule` must "
|
||||
"be a callable object.")
|
||||
self._args = args or ()
|
||||
self._kwargs = kwargs or {}
|
||||
|
||||
@ -287,9 +354,9 @@ class Closure(object):
|
||||
replica_kwargs = _select_worker_slice(0, self._kwargs)
|
||||
|
||||
# Note: no need to handle function registration failure since this kind of
|
||||
# failure will not raise exceptions as designed in the runtime. The client
|
||||
# has to rely on subsequent operations that raise to catch function
|
||||
# registration failure.
|
||||
# failure will not raise exceptions as designed in the runtime. The
|
||||
# coordinator has to rely on subsequent operations that raise to catch
|
||||
# function registration failure.
|
||||
|
||||
# Record the function tracing overhead. Note that we pass in the tracing
|
||||
# count of the def_function.Function as a state tracker, so that metrics
|
||||
@ -303,17 +370,18 @@ class Closure(object):
|
||||
self._function = cancellation_mgr.get_cancelable_function(
|
||||
concrete_function)
|
||||
self._output_remote_values = nest.map_structure(
|
||||
lambda x: RemoteValue(self, x), concrete_function.structured_outputs)
|
||||
lambda x: RemoteValueImpl(self, x),
|
||||
concrete_function.structured_outputs)
|
||||
elif isinstance(function, tf_function.ConcreteFunction):
|
||||
self._function = cancellation_mgr.get_cancelable_function(function)
|
||||
self._output_remote_values = nest.map_structure(
|
||||
lambda x: RemoteValue(self, x), function.structured_outputs)
|
||||
lambda x: RemoteValueImpl(self, x), function.structured_outputs)
|
||||
else:
|
||||
# Regular python functions.
|
||||
self._function = function
|
||||
# TODO(yuefengz): maybe we should trace python functions if their inputs
|
||||
# are Python primitives, tensors and composite tensors.
|
||||
self._output_remote_values = RemoteValue(self, None)
|
||||
self._output_remote_values = RemoteValueImpl(self, None)
|
||||
|
||||
def _fetch_output_remote_values(self):
|
||||
"""Temporary method used to sync the scheduler."""
|
||||
@ -394,7 +462,7 @@ class _CoordinatedClosureQueue(object):
|
||||
|
||||
if _CLOSURE_QUEUE_MAX_SIZE <= 0:
|
||||
logging.warning(
|
||||
"In a `Client`, creating an infinite closure queue can "
|
||||
"In a `ClusterCoordinator`, creating an infinite closure queue can "
|
||||
"consume a significant amount of memory and even lead to OOM.")
|
||||
self._queue = queue.Queue(maxsize=_CLOSURE_QUEUE_MAX_SIZE)
|
||||
self._error = None
|
||||
@ -430,10 +498,10 @@ class _CoordinatedClosureQueue(object):
|
||||
# The cancellation manager cannot be reused once cancelled. After all
|
||||
# closures (queued or inflight) are cleaned up, recreate the cancellation
|
||||
# manager with clean state.
|
||||
# Note on thread-safety: this is triggered when one of theses client APIs
|
||||
# are called: `schedule`, `wait`, and `done`. At the same time, no new
|
||||
# closures can be constructed (which reads the _cancellation_mgr to get
|
||||
# cancellable functions).
|
||||
# Note on thread-safety: this is triggered when one of theses
|
||||
# ClusterCoordinator APIs are called: `schedule`, `wait`, and `done`. At the
|
||||
# same time, no new closures can be constructed (which reads the
|
||||
# _cancellation_mgr to get cancellable functions).
|
||||
self._cancellation_mgr = cancellation.CancellationManager()
|
||||
|
||||
def _raise_if_error(self):
|
||||
@ -742,8 +810,8 @@ class Worker(object):
|
||||
|
||||
def _register_resource(self, resource_remote_value):
|
||||
if not isinstance(resource_remote_value, RemoteValue):
|
||||
raise ValueError(
|
||||
"Resource being registered is not of type `RemoteValue`.")
|
||||
raise ValueError("Resource being registered is not of type "
|
||||
"`tf.distribute.experimental.coordinator.RemoteValue`.")
|
||||
self._resource_remote_value_refs.append(weakref.ref(resource_remote_value))
|
||||
|
||||
|
||||
@ -753,7 +821,7 @@ class Cluster(object):
|
||||
We assume all function errors are fatal and based on this assumption our
|
||||
error reporting logic is:
|
||||
1) Both `schedule` and `join` can raise a non-retryable error which is the
|
||||
first error seen by the client from any previously scheduled functions.
|
||||
first error seen by the coordinator from any previously scheduled functions.
|
||||
2) When an error is raised, there is no guarantee on how many previously
|
||||
scheduled functions have been executed; functions that have not been executed
|
||||
will be thrown away and marked as cancelled.
|
||||
@ -775,17 +843,17 @@ class Cluster(object):
|
||||
|
||||
# Ignore PS failures reported by workers due to transient connection errors.
|
||||
# Transient connectivity issues between workers and PS are relayed by the
|
||||
# workers to the client, leading the client to believe that there are PS
|
||||
# failures. The difference between transient vs. permanent PS failure is the
|
||||
# number of reports from the workers. When this env var is set to a positive
|
||||
# integer K, the client ignores up to K reports of a failed PS task. I.e.,
|
||||
# only when there are more than K trials of executing closures fail due to
|
||||
# errors from the same PS instance do we consider the PS instance encounters
|
||||
# a failure.
|
||||
# workers to the coordinator, leading the coordinator to believe that there
|
||||
# are PS failures. The difference between transient vs. permanent PS failure
|
||||
# is the number of reports from the workers. When this env var is set to a
|
||||
# positive integer K, the coordinator ignores up to K reports of a failed PS
|
||||
# task, i.e., only when there are more than K trials of executing closures
|
||||
# fail due to errors from the same PS instance do we consider the PS
|
||||
# instance encounters a failure.
|
||||
# TODO(b/164279603): Remove this workaround when the underlying connectivity
|
||||
# issue in gRPC server is resolved.
|
||||
self._transient_ps_failures_threshold = int(os.environ.get(
|
||||
"TF_CLIENT_IGNORE_TRANSIENT_PS_FAILURES", 3))
|
||||
self._transient_ps_failures_threshold = int(
|
||||
os.environ.get("TF_COORDINATOR_IGNORE_TRANSIENT_PS_FAILURES", 3))
|
||||
self._potential_ps_failures_lock = threading.Lock()
|
||||
self._potential_ps_failures_count = [0] * self._num_ps
|
||||
|
||||
@ -825,7 +893,7 @@ class Cluster(object):
|
||||
kwargs: Keyword arguments for `fn`.
|
||||
|
||||
Returns:
|
||||
A structure of `RemoteValue` object.
|
||||
A `RemoteValue` object.
|
||||
"""
|
||||
closure = Closure(
|
||||
function,
|
||||
@ -844,68 +912,124 @@ class Cluster(object):
|
||||
return self._closure_queue.done()
|
||||
|
||||
|
||||
class Client(object):
|
||||
"""An object to schedule and orchestrate remote function execution.
|
||||
@tf_export("distribute.experimental.coordinator.ClusterCoordinator", v1=[])
|
||||
class ClusterCoordinator(object):
|
||||
"""An object to schedule and coordinate remote function execution.
|
||||
|
||||
A `Client` object represents a program used to create dataset, schedule
|
||||
functions to be executed, and fetch the results of the functions.
|
||||
A `tf.distribute.experimental.coordinator.ClusterCoordinator` object
|
||||
represents a program used to distribute dataset onto the workers, schedule
|
||||
functions to be executed, and fetch the results of the functions. It expects
|
||||
the cluster to contain some machines with processes running TensorFlow
|
||||
servers.
|
||||
|
||||
Currently, `Client` is not supported to be used in a standalone manner.
|
||||
It should be used in conjunction with `ParameterServerStrategyV2`.
|
||||
Currently, `tf.distribute.experimental.coordinator.ClusterCoordinator` is not
|
||||
supported to be used in a standalone manner. It should be used in conjunction
|
||||
with a `tf.distribute` strategy that is designed to work with it. Currently,
|
||||
only `tf.distribute.experimental.ParameterServerStrategy` is supported to work
|
||||
with `tf.distribute.experimental.coordinator.ClusterCoordinator`.
|
||||
|
||||
__Fault tolerance__
|
||||
|
||||
`tf.distribute.experimental.coordinator.ClusterCoordinator`, when used with
|
||||
`tf.distribute.experimental.ParameterServerStrategy`, comes with built-in
|
||||
fault tolerance for worker failures. That is, when some workers are not
|
||||
available for any reason to be reached from the coordinator, the training
|
||||
progress continues to be made with the remaining operating workers, without
|
||||
the need of any additional treatment in user code.
|
||||
|
||||
On the other hand, when a parameter server or the coordinator fails, a
|
||||
`tf.errors.UnavailableError` is raised by `ClusterCoordinator.schedule` or
|
||||
`ClusterCoordinator.join`. If any parameter server fails, the user should
|
||||
restart the processes on the failed parameter servers when they become
|
||||
available, *and* restart the process on the coordinator, so the coordinator
|
||||
can re-create the variables on the parameter servers. If the coordinator fails
|
||||
but all other machines continue to be operating, the user only needs to
|
||||
restart the process on the coordinator, which will automatically connect to
|
||||
the parameter servers and workers, and continue the progress.
|
||||
|
||||
It is thus essential that in user's custom training loop, a checkpoint file is
|
||||
periodically saved, and restored at the start of the program.
|
||||
|
||||
* At-least-once semantics of `schedule`: `schedule` puts the `tf.function` in
|
||||
a queue where it gets picked up by a worker to execute. If a worker picks up
|
||||
a `tf.function`, has begun to execute it, but the worker process is
|
||||
disconnected from the coordinator in the middle of execution, the
|
||||
`tf.function` is deemed not completed, and is put back to the queue. This is
|
||||
regardless of how much progress the `tf.function` has run. Once it is
|
||||
re-queued, it will be picked up again by an arbitrary available worker at some
|
||||
later time. As a result, the same function may be executed more than once, but
|
||||
not less than once. If an `tf.keras.optimizers.Optimizer` is used,
|
||||
`tf.keras.optimizers.Optimizer.iterations` roughly indicates the number of
|
||||
times the gradients have been applied and can be used as an approximation of
|
||||
the total number of steps. The number of times the function has been scheduled
|
||||
by the `tf.distribute.experimental.coordinator.ClusterCoordinator`, on the
|
||||
other hand, should not be indicative of actual steps run. See
|
||||
`tf.distribute.experimental.coordinator.ClusterCoordinator.schedule` for more
|
||||
information.
|
||||
|
||||
See `tf.distribute.experimental.ParameterServerStrategy` docstring for an
|
||||
example usage of this API.
|
||||
|
||||
This is currently under development, and the API as well as implementation
|
||||
is subject to changes.
|
||||
are subject to changes.
|
||||
"""
|
||||
|
||||
def __init__(self, strategy):
|
||||
"""Initialization of a `Client` instance.
|
||||
|
||||
This connects the client to remote workers and parameter servers, through
|
||||
a `tf.config.experimental_connect_to_cluster` call.
|
||||
"""Initialization of a `ClusterCoordinator` instance.
|
||||
|
||||
Args:
|
||||
strategy: a `tf.distribute.Strategy` object. Currently, only
|
||||
`ParameterServerStrategyV2` is supported.
|
||||
strategy: a supported `tf.distribute.Strategy` object. Currently, only
|
||||
`tf.distribute.experimental.ParameterServerStrategy` is supported.
|
||||
|
||||
Raises:
|
||||
ValueError: if the strategy being used is not supported.
|
||||
"""
|
||||
if not isinstance(strategy,
|
||||
parameter_server_strategy_v2.ParameterServerStrategyV2):
|
||||
raise ValueError("Only `ParameterServerStrategyV2` is supported in "
|
||||
"`Client` currently.")
|
||||
raise ValueError(
|
||||
"Only `tf.distribute.experimental.ParameterServerStrategy` "
|
||||
"is supported to work with "
|
||||
"`tf.distribute.experimental.coordinator.ClusterCoordinator` "
|
||||
"currently.")
|
||||
self._strategy = strategy
|
||||
self.cluster = Cluster(strategy)
|
||||
|
||||
@property
|
||||
def strategy(self):
|
||||
"""Returns the `Strategy` associated with the `ClusterCoordinator`."""
|
||||
return self._strategy
|
||||
|
||||
def schedule(self, fn, args=None, kwargs=None):
|
||||
"""Schedules `fn` to be dispatched to a worker for execution asynchronously.
|
||||
"""Schedules `fn` to be dispatched to a worker for asynchronous execution.
|
||||
|
||||
When calling `schedule` with a function `fn`, `fn` will be executed on a
|
||||
remote worker at some later time. The process is asynchronous, meaning
|
||||
`schedule` returns immediately, possibly without having the result ready
|
||||
yet. `schedule` returns a structure of `RemoteValue` object, which wraps the
|
||||
output of the function. Call `fetch()` on `RemoteValue` to wait for the
|
||||
yet. `schedule` returns a
|
||||
`tf.distribute.experimental.coordinator.RemoteValue` object, which wraps the
|
||||
output of the function. After `schedule` is called, `fetch` can be called
|
||||
on the `tf.distribute.experimental.coordinator.RemoteValue` to wait for the
|
||||
function execution to finish and retrieve its output from the remote worker.
|
||||
On the other hand, call
|
||||
`tf.distribute.experimental.coordinator.ClusterCoordinator.join` to wait for
|
||||
all scheduled functions to finish execution before proceeding.
|
||||
|
||||
`schedule` guarantees that `fn` will be executed on a worker at least once;
|
||||
it could be more than once if its corresponding worker fails in the middle
|
||||
of its execution. Note that since worker can fail at any point when
|
||||
executing the function, it is possible that the function is partially
|
||||
executed, but `Client` guarantees that in those events, the function will
|
||||
eventually be fully executed, possibly on a different worker that is
|
||||
available.
|
||||
executed, but `tf.distribute.experimental.coordinator.ClusterCoordinator`
|
||||
guarantees that in those events, the function will eventually be fully
|
||||
executed, possibly on a different worker that is available.
|
||||
|
||||
If any previously scheduled function raises an error, `schedule` will fail
|
||||
by raising any one of those errors, and clear the errors collected so far.
|
||||
There are two implications when this happens: 1) user should call `schedule`
|
||||
with `fn` again to re-schedule, and 2) some of the previously scheduled
|
||||
functions may have not been executed. User can call `fetch` on the returned
|
||||
`RemoteValue` to inspect if they have executed, failed, or cancelled, and
|
||||
reschedule the corresponding function if needed.
|
||||
`tf.distribute.experimental.coordinator.RemoteValue` to inspect if they have
|
||||
executed, failed, or cancelled, and reschedule the corresponding function if
|
||||
needed.
|
||||
|
||||
When `schedule` raises, it guarantees that there is no function that is
|
||||
still being executed.
|
||||
@ -914,12 +1038,14 @@ class Client(object):
|
||||
execution, or priority of the workers.
|
||||
|
||||
`args` and `kwargs` are the arguments passed into `fn`, when `fn` is
|
||||
executed on a worker. They can be `PerWorkerValues`, which is a collection
|
||||
of values, each of which represents a component specific to a worker; in
|
||||
this case, the argument will be substituted with the corresponding component
|
||||
on the target worker. Arguments that are not `PerWorkerValues` will be
|
||||
passed into `fn` as-is. Currently, `RemoteValue` is not supported to be
|
||||
input `args` or `kwargs`.
|
||||
executed on a worker. They can be
|
||||
`tf.distribute.experimental.coordinator.PerWorkerValues`, which is a
|
||||
'collection of values, each of which represents a component specific to a
|
||||
worker; in this case, the argument will be substituted with the
|
||||
corresponding component on the target worker. Arguments that are not
|
||||
`tf.distribute.experimental.coordinator.PerWorkerValues` will be passed into
|
||||
`fn` as-is. Currently, `tf.distribute.experimental.coordinator.RemoteValue`
|
||||
is not supported to be input `args` or `kwargs`.
|
||||
|
||||
Args:
|
||||
fn: A `tf.function`; the function to be dispatched to a worker for
|
||||
@ -928,12 +1054,13 @@ class Client(object):
|
||||
kwargs: Keyword arguments for `fn`.
|
||||
|
||||
Returns:
|
||||
A structure of `RemoteValue` object.
|
||||
A `tf.distribute.experimental.coordinator.RemoteValue` object that
|
||||
represents the output of the function scheduled.
|
||||
|
||||
Raises:
|
||||
Exception: one of the exceptions caught by the client by any previously
|
||||
scheduled function since the last time an error was thrown or since
|
||||
the beginning of the program.
|
||||
Exception: one of the exceptions caught by the coordinator by any
|
||||
previously scheduled function since the last time an error was thrown or
|
||||
since the beginning of the program.
|
||||
"""
|
||||
# Slot variables are usually created during function tracing time; thus
|
||||
# `schedule` needs to be called within the `strategy.scope()`.
|
||||
@ -946,18 +1073,18 @@ class Client(object):
|
||||
If any previously scheduled function raises an error, `join` will fail by
|
||||
raising any one of those errors, and clear the errors collected so far. If
|
||||
this happens, some of the previously scheduled functions may have not been
|
||||
executed. Users can call `fetch` on the returned `RemoteValue` to inspect if
|
||||
they have executed, failed, or cancelled. If some that have been cancelled
|
||||
need to be rescheduled, users should call `schedule` with the function
|
||||
again.
|
||||
executed. Users can call `fetch` on the returned
|
||||
`tf.distribute.experimental.coordinator.RemoteValue` to inspect if they have
|
||||
executed, failed, or cancelled. If some that have been cancelled need to be
|
||||
rescheduled, users should call `schedule` with the function again.
|
||||
|
||||
When `join` returns or raises, it guarantees that there is no function that
|
||||
is still being executed.
|
||||
|
||||
Raises:
|
||||
Exception: one of the exceptions caught by the client by any previously
|
||||
scheduled function since the last time an error was thrown or since
|
||||
the beginning of the program.
|
||||
Exception: one of the exceptions caught by the coordinator by any
|
||||
previously scheduled function since the last time an error was thrown or
|
||||
since the beginning of the program.
|
||||
"""
|
||||
self.cluster.join()
|
||||
|
||||
@ -969,6 +1096,9 @@ class Client(object):
|
||||
|
||||
When `done` returns True or raises, it guarantees that there is no function
|
||||
that is still being executed.
|
||||
|
||||
Returns:
|
||||
Whether all the scheduled functions have finished execution.
|
||||
"""
|
||||
return self.cluster.done()
|
||||
|
||||
@ -978,12 +1108,15 @@ class Client(object):
|
||||
This creates the given dataset generated by dataset_fn on the workers
|
||||
and returns an object that represents the collection of those individual
|
||||
datasets. Calling `iter` on such collection of dataset returns a
|
||||
`PerWorkerValues`, which is a collection of iterators, where the iterators
|
||||
have been placed on respective workers.
|
||||
`tf.distribute.experimental.coordinator.PerWorkerValues`, which is a
|
||||
collection of iterators, where the iterators have been placed on respective
|
||||
workers.
|
||||
|
||||
Calling `next` on this `PerWorkerValues` of iterators is currently
|
||||
unsupported; it is meant to be passed as an argument into `Client.schedule`.
|
||||
When the scheduled function is picked up and being executed by a worker, the
|
||||
Calling `next` on this
|
||||
`tf.distribute.experimental.coordinator.PerWorkerValues` of iterators is
|
||||
currently unsupported; it is meant to be passed as an argument into
|
||||
`tf.distribute.experimental.coordinator.ClusterCoordinator.schedule`. When
|
||||
the scheduled function is picked up and being executed by a worker, the
|
||||
function will receive the individual iterator that corresponds to the
|
||||
worker, and now `next` can be called on iterator to get the next (batch or
|
||||
example) of data.
|
||||
@ -993,6 +1126,26 @@ class Client(object):
|
||||
examples may be skipped and not covered by other workers, if the dataset is
|
||||
sharded.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
@tf.function
|
||||
def worker_fn(iterator):
|
||||
return next(iterator)
|
||||
|
||||
def dataset_fn():
|
||||
return tf.data.from_tensor_slices([3] * 3)
|
||||
|
||||
strategy = tf.distribute.experimental.ParameterServerStrategy(
|
||||
cluster_resolver=...)
|
||||
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
|
||||
strategy=strategy)
|
||||
per_worker_dataset = coordinator.create_per_worker_dataset(dataset_fn)
|
||||
per_worker_iter = iter(per_worker_dataset)
|
||||
remote_value = coordinator.schedule(worker_fn, args=(per_worker_iter,))
|
||||
assert remote_value.fetch() == 3
|
||||
```
|
||||
|
||||
Args:
|
||||
dataset_fn: The dataset function that returns a dataset. This is to be
|
||||
executed on the workers.
|
||||
@ -1000,7 +1153,8 @@ class Client(object):
|
||||
Returns:
|
||||
An object that represents the collection of those individual
|
||||
datasets. `iter` is expected to be called on this object that returns
|
||||
a `PerWorkerValues` of the iterators (that are on the workers).
|
||||
a `tf.distribute.experimental.coordinator.PerWorkerValues` of the
|
||||
iterators (that are on the workers).
|
||||
"""
|
||||
input_workers = input_lib.InputWorkers([
|
||||
(w.device_name, [w.device_name]) for w in self.cluster.workers
|
||||
@ -1011,7 +1165,8 @@ class Client(object):
|
||||
def _create_per_worker_resources(self, fn, args=None, kwargs=None):
|
||||
"""Synchronously create resources on the workers.
|
||||
|
||||
The resources are represented by `RemoteValue`s.
|
||||
The resources are represented by
|
||||
`tf.distribute.experimental.coordinator.RemoteValue`s.
|
||||
|
||||
Args:
|
||||
fn: The function to be dispatched to all workers for execution
|
||||
@ -1020,7 +1175,9 @@ class Client(object):
|
||||
kwargs: Keyword arguments for `fn`.
|
||||
|
||||
Returns:
|
||||
A `PerWorkerValues` object, which wraps a tuple of `RemoteValue` objects.
|
||||
A `tf.distribute.experimental.coordinator.PerWorkerValues` object, which
|
||||
wraps a tuple of `tf.distribute.experimental.coordinator.RemoteValue`
|
||||
objects.
|
||||
"""
|
||||
results = []
|
||||
for w in self.cluster.workers:
|
||||
@ -1028,21 +1185,52 @@ class Client(object):
|
||||
return PerWorkerValues(tuple(results))
|
||||
|
||||
def fetch(self, val):
|
||||
"""Blocking call to fetch results from `RemoteValue`s.
|
||||
"""Blocking call to fetch results from the remote values.
|
||||
|
||||
This returns the execution result of `RemoteValue`s; if not ready,
|
||||
waiting for it while blocking the caller.
|
||||
This is a wrapper around
|
||||
`tf.distribute.experimental.coordinator.RemoteValue.fetch` for a
|
||||
`RemoteValue` structure; it returns the execution results of
|
||||
`RemoteValue`s. If not ready, wait for them while blocking the caller.
|
||||
|
||||
Example:
|
||||
```python
|
||||
strategy = ...
|
||||
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
|
||||
strategy)
|
||||
|
||||
def dataset_fn():
|
||||
return tf.data.Dataset.from_tensor_slices([1, 1, 1])
|
||||
|
||||
with strategy.scope():
|
||||
v = tf.Variable(initial_value=0)
|
||||
|
||||
@tf.function
|
||||
def worker_fn(iterator):
|
||||
def replica_fn(x):
|
||||
v.assign_add(x)
|
||||
return v.read_value()
|
||||
return strategy.run(replica_fn, args=(next(iterator),))
|
||||
|
||||
distributed_dataset = coordinator.create_per_worker_dataset(dataset_fn)
|
||||
distributed_iterator = iter(distributed_dataset)
|
||||
result = coordinator.schedule(worker_fn, args=(distributed_iterator,))
|
||||
assert coordinator.fetch(result) == 1
|
||||
```
|
||||
|
||||
Args:
|
||||
val: The value to fetch the results from. If this is structure of
|
||||
`RemoteValue`, `fetch()` will be called on the individual `RemoteValue`
|
||||
to get the result.
|
||||
`tf.distribute.experimental.coordinator.RemoteValue`, `fetch()` will be
|
||||
called on the individual
|
||||
`tf.distribute.experimental.coordinator.RemoteValue` to get the result.
|
||||
|
||||
Returns:
|
||||
If `val` is a `RemoteValue` or a structure of `RemoteValue`s, returns
|
||||
the fetched `RemoteValue` value immediately if it's available, or blocks
|
||||
the call until it's available, and returns the fetched `RemoteValue`
|
||||
values with the same structure. If `val` is other types, return (`val`,).
|
||||
If `val` is a `tf.distribute.experimental.coordinator.RemoteValue` or a
|
||||
structure of `tf.distribute.experimental.coordinator.RemoteValue`s,
|
||||
return the fetched `tf.distribute.experimental.coordinator.RemoteValue`
|
||||
values immediately if they are available, or block the call until they are
|
||||
available, and return the fetched
|
||||
`tf.distribute.experimental.coordinator.RemoteValue` values with the same
|
||||
structure. If `val` is other types, return it as-is.
|
||||
"""
|
||||
|
||||
def _maybe_fetch(val):
|
||||
@ -1052,10 +1240,7 @@ class Client(object):
|
||||
return val
|
||||
|
||||
# TODO(yuefengz): we should fetch values in a batch.
|
||||
result = nest.map_structure(_maybe_fetch, val)
|
||||
if not isinstance(result, tuple):
|
||||
return (result,)
|
||||
return result
|
||||
return nest.map_structure(_maybe_fetch, val)
|
||||
|
||||
|
||||
# pylint: disable=missing-function-docstring
|
||||
@ -1075,13 +1260,14 @@ def handle_parameter_server_failure():
|
||||
class _PerWorkerDistributedDataset(object):
|
||||
"""Represents worker-distributed datasets created from dataset function."""
|
||||
|
||||
def __init__(self, dataset_fn, input_workers, client):
|
||||
def __init__(self, dataset_fn, input_workers, coordinator):
|
||||
"""Makes an iterable from datasets created by the given function.
|
||||
|
||||
Args:
|
||||
dataset_fn: A function that returns a `Dataset`.
|
||||
input_workers: an `InputWorkers` object.
|
||||
client: a `Client` object, used to create dataset resources.
|
||||
coordinator: a `ClusterCoordinator` object, used to create dataset
|
||||
resources.
|
||||
"""
|
||||
def disallow_variable_creation(next_creator, **kwargs):
|
||||
raise ValueError("Creating variables in `dataset_fn` is not allowed.")
|
||||
@ -1094,7 +1280,7 @@ class _PerWorkerDistributedDataset(object):
|
||||
dataset_fn = def_function.function(dataset_fn).get_concrete_function()
|
||||
self._dataset_fn = dataset_fn
|
||||
self._input_workers = input_workers
|
||||
self._client = client
|
||||
self._coordinator = coordinator
|
||||
self._element_spec = None
|
||||
|
||||
def __iter__(self):
|
||||
@ -1112,7 +1298,7 @@ class _PerWorkerDistributedDataset(object):
|
||||
# If _PerWorkerDistributedDataset.__iter__ is called multiple
|
||||
# times, for the same object it should only create and register resource
|
||||
# once. Using object id to distinguish different iterator resources.
|
||||
per_worker_iterator = self._client._create_per_worker_resources(
|
||||
per_worker_iterator = self._coordinator._create_per_worker_resources(
|
||||
_create_per_worker_iterator)
|
||||
|
||||
# Setting type_spec of each RemoteValue so that functions taking these
|
||||
@ -1131,7 +1317,7 @@ class _PerWorkerDistributedDataset(object):
|
||||
|
||||
|
||||
class _PerWorkerDistributedIterator(PerWorkerValues):
|
||||
"""Distributed iterator for `Client`."""
|
||||
"""Distributed iterator for `ClusterCoordinator`."""
|
||||
|
||||
def __next__(self):
|
||||
return self.get_next()
|
@ -13,21 +13,20 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Multi-process runner tests for `Client` with `ParameterServerStrategyV2`."""
|
||||
"""Multi-process runner tests for `ClusterCoordinator` with PSv2."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import time
|
||||
|
||||
from tensorflow.python.compat import v2_compat
|
||||
from tensorflow.python.distribute import multi_process_runner
|
||||
from tensorflow.python.distribute import multi_worker_test_base
|
||||
from tensorflow.python.distribute import parameter_server_strategy_v2
|
||||
from tensorflow.python.distribute.client import client as client_lib
|
||||
from tensorflow.python.distribute.client import utils
|
||||
from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
|
||||
from tensorflow.python.distribute.coordinator import cluster_coordinator as coordinator_lib
|
||||
from tensorflow.python.distribute.coordinator import utils
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import test
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -38,7 +37,7 @@ from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
|
||||
|
||||
class ClientMprTest(test.TestCase):
|
||||
class ClusterCoordinatorMprTest(test.TestCase):
|
||||
|
||||
def testScheduleTranslatePSFailureError(self):
|
||||
self._test_translate_ps_failure_error(test_schedule=True)
|
||||
@ -56,7 +55,7 @@ class ClientMprTest(test.TestCase):
|
||||
utils.start_server(cluster_resolver, "grpc")
|
||||
strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
|
||||
cluster_resolver)
|
||||
ps_client = client_lib.Client(strategy)
|
||||
ps_coordinator = coordinator_lib.ClusterCoordinator(strategy)
|
||||
|
||||
with strategy.scope():
|
||||
v = variables.Variable(initial_value=0, dtype=dtypes.int32)
|
||||
@ -68,19 +67,19 @@ class ClientMprTest(test.TestCase):
|
||||
v.assign_add(1)
|
||||
|
||||
# Keep the two workers occupied.
|
||||
ps_client.schedule(worker_fn)
|
||||
ps_client.schedule(worker_fn)
|
||||
ps_coordinator.schedule(worker_fn)
|
||||
ps_coordinator.schedule(worker_fn)
|
||||
# Now the main process can terminate.
|
||||
functions_scheduled_event.set()
|
||||
|
||||
# Verified that join and schedule indeed raise UnavailableError.
|
||||
try:
|
||||
if test_join:
|
||||
ps_client.join()
|
||||
ps_coordinator.join()
|
||||
if test_schedule:
|
||||
while ps_client.cluster._closure_queue._error is None:
|
||||
while ps_coordinator.cluster._closure_queue._error is None:
|
||||
time.sleep(1)
|
||||
ps_client.schedule(worker_fn)
|
||||
ps_coordinator.schedule(worker_fn)
|
||||
except errors.UnavailableError:
|
||||
# The following verifies that after PS fails, continue executing
|
||||
# functions on workers should fail and indicate it's PS failure.
|
||||
@ -91,7 +90,7 @@ class ClientMprTest(test.TestCase):
|
||||
# failure.
|
||||
worker_fn()
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
if client_lib._is_ps_failure(e):
|
||||
if coordinator_lib._is_ps_failure(e):
|
||||
if worker_id < 2:
|
||||
continue
|
||||
logging.info("_test_translate_ps_failure_error ends properly.")
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for client.py."""
|
||||
"""Tests for coordinator.py."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
@ -29,8 +29,8 @@ from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.distribute import distribution_strategy_context
|
||||
from tensorflow.python.distribute import multi_worker_test_base
|
||||
from tensorflow.python.distribute import parameter_server_strategy_v2
|
||||
from tensorflow.python.distribute.client import client as client_lib
|
||||
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
|
||||
from tensorflow.python.distribute.coordinator import cluster_coordinator as coordinator_lib
|
||||
from tensorflow.python.eager import cancellation
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import test
|
||||
@ -52,7 +52,7 @@ from tensorflow.python.util import nest
|
||||
class CoordinatedClosureQueueTest(test.TestCase):
|
||||
|
||||
def testBasic(self):
|
||||
queue = client_lib._CoordinatedClosureQueue()
|
||||
queue = coordinator_lib._CoordinatedClosureQueue()
|
||||
closure1 = self._create_closure(queue._cancellation_mgr)
|
||||
queue.put(closure1)
|
||||
self.assertIs(closure1, queue.get())
|
||||
@ -64,7 +64,7 @@ class CoordinatedClosureQueueTest(test.TestCase):
|
||||
queue.wait()
|
||||
|
||||
def testProcessAtLeaseOnce(self):
|
||||
closure_queue = client_lib._CoordinatedClosureQueue()
|
||||
closure_queue = coordinator_lib._CoordinatedClosureQueue()
|
||||
labels = ['A', 'B', 'C', 'D', 'E']
|
||||
processed_count = collections.defaultdict(int)
|
||||
|
||||
@ -94,7 +94,7 @@ class CoordinatedClosureQueueTest(test.TestCase):
|
||||
|
||||
cm = cancellation.CancellationManager()
|
||||
for label in labels:
|
||||
closure_queue.put(client_lib.Closure(get_func(label), cm))
|
||||
closure_queue.put(coordinator_lib.Closure(get_func(label), cm))
|
||||
t1 = threading.Thread(target=process_queue, daemon=True)
|
||||
t1.start()
|
||||
t2 = threading.Thread(target=process_queue, daemon=True)
|
||||
@ -111,7 +111,7 @@ class CoordinatedClosureQueueTest(test.TestCase):
|
||||
coord.join([t1, t2])
|
||||
|
||||
def testNotifyBeforeWait(self):
|
||||
closure_queue = client_lib._CoordinatedClosureQueue()
|
||||
closure_queue = coordinator_lib._CoordinatedClosureQueue()
|
||||
|
||||
def func():
|
||||
logging.info('func running')
|
||||
@ -123,7 +123,8 @@ class CoordinatedClosureQueueTest(test.TestCase):
|
||||
closure_queue.get()
|
||||
closure_queue.mark_finished()
|
||||
|
||||
closure_queue.put(client_lib.Closure(func, closure_queue._cancellation_mgr))
|
||||
closure_queue.put(
|
||||
coordinator_lib.Closure(func, closure_queue._cancellation_mgr))
|
||||
t = threading.Thread(target=process_queue)
|
||||
t.start()
|
||||
coord.join([t])
|
||||
@ -159,7 +160,7 @@ class CoordinatedClosureQueueTest(test.TestCase):
|
||||
# TODO(b/165013260): Fix this
|
||||
self.skipTest('Test is currently broken on Windows with Python 3.8')
|
||||
|
||||
closure_queue = client_lib._CoordinatedClosureQueue()
|
||||
closure_queue = coordinator_lib._CoordinatedClosureQueue()
|
||||
closure_queue.put(self._create_closure(closure_queue._cancellation_mgr))
|
||||
closure = closure_queue.get()
|
||||
|
||||
@ -189,10 +190,10 @@ class CoordinatedClosureQueueTest(test.TestCase):
|
||||
def some_function():
|
||||
return 1.0
|
||||
|
||||
return client_lib.Closure(some_function, cancellation_mgr)
|
||||
return coordinator_lib.Closure(some_function, cancellation_mgr)
|
||||
|
||||
def _put_two_closures_and_get_one(self):
|
||||
closure_queue = client_lib._CoordinatedClosureQueue()
|
||||
closure_queue = coordinator_lib._CoordinatedClosureQueue()
|
||||
closure1 = self._create_closure(closure_queue._cancellation_mgr)
|
||||
closure_queue.put(closure1)
|
||||
|
||||
@ -330,7 +331,7 @@ class CoordinatedClosureQueueTest(test.TestCase):
|
||||
|
||||
# Closure2 was an inflight closure when it got cancelled.
|
||||
self.assertEqual(closure2._output_remote_values._status,
|
||||
client_lib._RemoteValueStatus.READY)
|
||||
coordinator_lib._RemoteValueStatus.READY)
|
||||
with self.assertRaisesRegex(ValueError, 'Fake cancellation error.'):
|
||||
closure2._fetch_output_remote_values()
|
||||
|
||||
@ -361,7 +362,7 @@ class CoordinatedClosureQueueTest(test.TestCase):
|
||||
|
||||
def testThreadSafey(self):
|
||||
thread_count = 10
|
||||
queue = client_lib._CoordinatedClosureQueue()
|
||||
queue = coordinator_lib._CoordinatedClosureQueue()
|
||||
|
||||
# Each thread performs 20 queue actions: 10 are `put_back` and 10 are
|
||||
# `mark_finished`.
|
||||
@ -427,7 +428,7 @@ class TestCaseWithErrorReportingThread(test.TestCase):
|
||||
raise ErrorReportingThread.error # pylint: disable=raising-bad-type
|
||||
|
||||
|
||||
def make_client(num_workers, num_ps):
|
||||
def make_coordinator(num_workers, num_ps):
|
||||
# TODO(rchao): Test the internal rpc_layer version.
|
||||
cluster_def = multi_worker_test_base.create_in_process_cluster(
|
||||
num_workers=num_workers, num_ps=num_ps, rpc_layer='grpc')
|
||||
@ -438,16 +439,16 @@ def make_client(num_workers, num_ps):
|
||||
ClusterSpec(cluster_def), rpc_layer='grpc')
|
||||
strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
|
||||
cluster_resolver)
|
||||
return client_lib.Client(strategy)
|
||||
return coordinator_lib.ClusterCoordinator(strategy)
|
||||
|
||||
|
||||
class ClientTest(TestCaseWithErrorReportingThread):
|
||||
class ClusterCoordinatorTest(TestCaseWithErrorReportingThread):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super(ClientTest, cls).setUpClass()
|
||||
cls.client = make_client(num_workers=3, num_ps=2)
|
||||
cls.strategy = cls.client.strategy
|
||||
super(ClusterCoordinatorTest, cls).setUpClass()
|
||||
cls.coordinator = make_coordinator(num_workers=3, num_ps=2)
|
||||
cls.strategy = cls.coordinator.strategy
|
||||
|
||||
def testFnReturnNestedValues(self):
|
||||
x = constant_op.constant(1)
|
||||
@ -456,9 +457,9 @@ class ClientTest(TestCaseWithErrorReportingThread):
|
||||
def f():
|
||||
return x + 1, (x + 2, x + 3), [x + 4], {'v': x}
|
||||
|
||||
got = self.client.schedule(f)
|
||||
got = self.coordinator.schedule(f)
|
||||
want = 2, (3, 4), [5], {'v': 1}
|
||||
self.assertEqual(self.client.fetch(got), want)
|
||||
self.assertEqual(self.coordinator.fetch(got), want)
|
||||
|
||||
def testInputFunction(self):
|
||||
|
||||
@ -474,12 +475,14 @@ class ClientTest(TestCaseWithErrorReportingThread):
|
||||
v.assign_add(x)
|
||||
return x
|
||||
|
||||
distributed_dataset = self.client.create_per_worker_dataset(input_fn)
|
||||
result = self.client.schedule(worker_fn, args=(iter(distributed_dataset),))
|
||||
result = self.client.fetch(result)
|
||||
distributed_dataset = self.coordinator.create_per_worker_dataset(input_fn)
|
||||
result = self.coordinator.schedule(
|
||||
worker_fn, args=(iter(distributed_dataset),))
|
||||
result = self.coordinator.fetch(result)
|
||||
self.assertEqual(result, (1,))
|
||||
result = self.client.schedule(worker_fn, args=(iter(distributed_dataset),))
|
||||
result = self.client.fetch(result)
|
||||
result = self.coordinator.schedule(
|
||||
worker_fn, args=(iter(distributed_dataset),))
|
||||
result = self.coordinator.fetch(result)
|
||||
self.assertEqual(result, (1,))
|
||||
|
||||
self.assertAlmostEqual(v.read_value().numpy(), 2, delta=1e-6)
|
||||
@ -499,30 +502,30 @@ class ClientTest(TestCaseWithErrorReportingThread):
|
||||
x = next(iterator)
|
||||
v.assign_add(x)
|
||||
|
||||
distributed_dataset = self.client.create_per_worker_dataset(input_fn)
|
||||
distributed_dataset = self.coordinator.create_per_worker_dataset(input_fn)
|
||||
|
||||
iterator = iter(distributed_dataset)
|
||||
|
||||
# Verifying joining without any scheduling doesn't hang.
|
||||
self.client.join()
|
||||
self.coordinator.join()
|
||||
self.assertEqual(v.read_value().numpy(), 0)
|
||||
|
||||
for _ in range(5):
|
||||
self.client.schedule(worker_fn, args=(iterator,))
|
||||
self.client.join()
|
||||
self.coordinator.schedule(worker_fn, args=(iterator,))
|
||||
self.coordinator.join()
|
||||
|
||||
# With 5 addition it should be 2*5 = 10.
|
||||
self.assertEqual(v.read_value().numpy(), 10)
|
||||
|
||||
for _ in range(5):
|
||||
self.client.schedule(worker_fn, args=(iterator,))
|
||||
self.coordinator.schedule(worker_fn, args=(iterator,))
|
||||
|
||||
# Verifying multiple join is fine.
|
||||
self.client.join()
|
||||
self.client.join()
|
||||
self.client.join()
|
||||
self.coordinator.join()
|
||||
self.coordinator.join()
|
||||
self.coordinator.join()
|
||||
|
||||
self.assertTrue(self.client.done())
|
||||
self.assertTrue(self.coordinator.done())
|
||||
|
||||
# Likewise, it's now 20.
|
||||
self.assertEqual(v.read_value().numpy(), 20)
|
||||
@ -542,8 +545,9 @@ class ClientTest(TestCaseWithErrorReportingThread):
|
||||
def worker_fn(iterator):
|
||||
return next(iterator)
|
||||
|
||||
distributed_dataset = self.client.create_per_worker_dataset(input_fn)
|
||||
result = self.client.schedule(worker_fn, args=(iter(distributed_dataset),))
|
||||
distributed_dataset = self.coordinator.create_per_worker_dataset(input_fn)
|
||||
result = self.coordinator.schedule(
|
||||
worker_fn, args=(iter(distributed_dataset),))
|
||||
self.assertEqual(result.fetch(), (10,))
|
||||
self.assertEqual(self._map_fn_tracing_count, 1)
|
||||
|
||||
@ -554,28 +558,28 @@ class ClientTest(TestCaseWithErrorReportingThread):
|
||||
return v.read_value()
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
self.client.create_per_worker_dataset(input_fn)
|
||||
self.coordinator.create_per_worker_dataset(input_fn)
|
||||
|
||||
def testDatasetsShuffledDifferently(self):
|
||||
# This test requires at least two workers in the cluster.
|
||||
self.assertGreaterEqual(len(self.client.cluster.workers), 2)
|
||||
self.assertGreaterEqual(len(self.coordinator.cluster.workers), 2)
|
||||
|
||||
random_seed.set_random_seed(None)
|
||||
|
||||
def input_fn():
|
||||
return dataset_ops.DatasetV2.range(0, 100).shuffle(100)
|
||||
|
||||
distributed_dataset = self.client.create_per_worker_dataset(input_fn)
|
||||
distributed_dataset = self.coordinator.create_per_worker_dataset(input_fn)
|
||||
distributed_iterator = iter(distributed_dataset)
|
||||
|
||||
# Get elements from the first two iterators.
|
||||
iterator_1 = distributed_iterator._values[0]
|
||||
iterator_1._rebuild_on(self.client.cluster.workers[0])
|
||||
iterator_1._rebuild_on(self.coordinator.cluster.workers[0])
|
||||
iterator_1 = iterator_1.fetch()
|
||||
elements_in_iterator_1 = [e.numpy() for e in iterator_1]
|
||||
|
||||
iterator_2 = distributed_iterator._values[1]
|
||||
iterator_2._rebuild_on(self.client.cluster.workers[1])
|
||||
iterator_2._rebuild_on(self.coordinator.cluster.workers[1])
|
||||
iterator_2 = iterator_2.fetch()
|
||||
elements_in_iterator_2 = [e.numpy() for e in iterator_2]
|
||||
|
||||
@ -592,12 +596,12 @@ class ClientTest(TestCaseWithErrorReportingThread):
|
||||
self.assertIn('worker', var.device)
|
||||
return var
|
||||
|
||||
worker_local_var = self.client._create_per_worker_resources(create_var)
|
||||
worker_local_var = self.coordinator._create_per_worker_resources(create_var)
|
||||
|
||||
# The following is a workaround to allow `worker_local_var` to be passed in
|
||||
# as args to the `client.schedule` method which requires tensor specs to
|
||||
# trace tf.function but _create_worker_resources' return values don't have
|
||||
# tensor specs. We can get rid of this workaround once
|
||||
# as args to the `coordinator.schedule` method which requires tensor specs
|
||||
# to trace tf.function but _create_worker_resources' return values don't
|
||||
# have tensor specs. We can get rid of this workaround once
|
||||
# _create_worker_resources is able to infer the tensor spec of the return
|
||||
# value of the function passed in. See b/154675763.
|
||||
for var in worker_local_var._values:
|
||||
@ -609,10 +613,10 @@ class ClientTest(TestCaseWithErrorReportingThread):
|
||||
for _ in range(10):
|
||||
# Which slice of `worker_local_var` will be used will depend on which
|
||||
# worker the `worker_fn` gets scheduled on.
|
||||
self.client.schedule(worker_fn, args=(worker_local_var,))
|
||||
self.client.join()
|
||||
self.coordinator.schedule(worker_fn, args=(worker_local_var,))
|
||||
self.coordinator.join()
|
||||
|
||||
var_sum = sum(self.client.fetch(worker_local_var._values))
|
||||
var_sum = sum(self.coordinator.fetch(worker_local_var._values))
|
||||
self.assertEqual(var_sum, 10.0)
|
||||
|
||||
def testDisallowRemoteValueAsInput(self):
|
||||
@ -625,28 +629,28 @@ class ClientTest(TestCaseWithErrorReportingThread):
|
||||
def func_1(x):
|
||||
return x + 1.0
|
||||
|
||||
remote_v = self.client.schedule(func_0)
|
||||
remote_v = self.coordinator.schedule(func_0)
|
||||
with self.assertRaises(ValueError):
|
||||
self.client.schedule(func_1, args=(remote_v,))
|
||||
self.coordinator.schedule(func_1, args=(remote_v,))
|
||||
|
||||
|
||||
class LimitedClosureQueueSizeBasicTest(ClientTest):
|
||||
class LimitedClosureQueueSizeBasicTest(ClusterCoordinatorTest):
|
||||
"""Test basic functionality works with explicit maximum closure queue size.
|
||||
|
||||
Execute the same set of test cases as in `ClientTest`, with an
|
||||
Execute the same set of test cases as in `ClusterCoordinatorTest`, with an
|
||||
explicit size limit for the closure queue. Note that even when the queue size
|
||||
is set to infinite, there is still a maximum practical size (depends on host
|
||||
memory limit) that might cause the queue.put operations to be blocking when
|
||||
scheduling a large number of closures on a big cluster. These tests make sure
|
||||
that the client does not run into deadlocks in such scenario.
|
||||
that the coordinator does not run into deadlocks in such a scenario.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super(LimitedClosureQueueSizeBasicTest, cls).setUpClass()
|
||||
client_lib._CLOSURE_QUEUE_MAX_SIZE = 2
|
||||
cls.client = make_client(num_workers=3, num_ps=2)
|
||||
cls.strategy = cls.client.strategy
|
||||
coordinator_lib._CLOSURE_QUEUE_MAX_SIZE = 2
|
||||
cls.coordinator = make_coordinator(num_workers=3, num_ps=2)
|
||||
cls.strategy = cls.coordinator.strategy
|
||||
|
||||
|
||||
class ErrorReportingTest(TestCaseWithErrorReportingThread):
|
||||
@ -654,8 +658,8 @@ class ErrorReportingTest(TestCaseWithErrorReportingThread):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super(ErrorReportingTest, cls).setUpClass()
|
||||
cls.client = make_client(num_workers=3, num_ps=2)
|
||||
cls.strategy = cls.client.strategy
|
||||
cls.coordinator = make_coordinator(num_workers=3, num_ps=2)
|
||||
cls.strategy = cls.coordinator.strategy
|
||||
|
||||
with cls.strategy.scope():
|
||||
cls.iteration = variables.Variable(initial_value=0.0)
|
||||
@ -686,80 +690,80 @@ class ErrorReportingTest(TestCaseWithErrorReportingThread):
|
||||
|
||||
def testJoinRaiseError(self):
|
||||
for _ in range(3):
|
||||
self.client.schedule(self._normal_function)
|
||||
self.client.schedule(self._error_function)
|
||||
self.coordinator.schedule(self._normal_function)
|
||||
self.coordinator.schedule(self._error_function)
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
self.client.join()
|
||||
self.coordinator.join()
|
||||
|
||||
def testScheduleRaiseError(self):
|
||||
for _ in range(3):
|
||||
self.client.schedule(self._normal_function)
|
||||
self.client.schedule(self._error_function)
|
||||
self.coordinator.schedule(self._normal_function)
|
||||
self.coordinator.schedule(self._error_function)
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
while True:
|
||||
self.client.schedule(self._normal_function)
|
||||
self.coordinator.schedule(self._normal_function)
|
||||
|
||||
def testScheduleRaiseErrorWithMultipleFailure(self):
|
||||
for _ in range(3):
|
||||
self.client.schedule(self._normal_function)
|
||||
self.client.schedule(self._error_function)
|
||||
self.coordinator.schedule(self._normal_function)
|
||||
self.coordinator.schedule(self._error_function)
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
while True:
|
||||
self.client.schedule(self._error_function)
|
||||
self.client.join()
|
||||
self.coordinator.schedule(self._error_function)
|
||||
self.coordinator.join()
|
||||
|
||||
def testErrorWillbeCleared(self):
|
||||
self.client.schedule(self._error_function)
|
||||
self.coordinator.schedule(self._error_function)
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
self.client.join()
|
||||
self.coordinator.join()
|
||||
|
||||
for _ in range(3):
|
||||
self.client.schedule(self._normal_function)
|
||||
self.client.schedule(self._error_function)
|
||||
self.coordinator.schedule(self._normal_function)
|
||||
self.coordinator.schedule(self._error_function)
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
self.client.join()
|
||||
self.coordinator.join()
|
||||
|
||||
def testRemoteValueReturnError(self):
|
||||
result = self.client.schedule(self._error_function)
|
||||
result = self.coordinator.schedule(self._error_function)
|
||||
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
result.fetch()
|
||||
|
||||
# Clear the error.
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
self.client.join()
|
||||
self.coordinator.join()
|
||||
|
||||
def testInputError(self):
|
||||
|
||||
worker_local_val = self.client._create_per_worker_resources(
|
||||
worker_local_val = self.coordinator._create_per_worker_resources(
|
||||
self._error_function)
|
||||
|
||||
@def_function.function
|
||||
def func(x):
|
||||
return x + 1
|
||||
|
||||
result = self.client.schedule(func, args=(worker_local_val,))
|
||||
with self.assertRaises(client_lib.InputError):
|
||||
self.client.join()
|
||||
result = self.coordinator.schedule(func, args=(worker_local_val,))
|
||||
with self.assertRaises(coordinator_lib.InputError):
|
||||
self.coordinator.join()
|
||||
|
||||
with self.assertRaises(client_lib.InputError):
|
||||
with self.assertRaises(coordinator_lib.InputError):
|
||||
result.fetch()
|
||||
|
||||
def testCancellation(self):
|
||||
for _ in range(3):
|
||||
self.client.schedule(self._normal_function)
|
||||
long_function = self.client.schedule(self._long_function)
|
||||
self.client.schedule(self._error_function)
|
||||
self.coordinator.schedule(self._normal_function)
|
||||
long_function = self.coordinator.schedule(self._long_function)
|
||||
self.coordinator.schedule(self._error_function)
|
||||
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
self.client.join()
|
||||
self.coordinator.join()
|
||||
|
||||
with self.assertRaises(errors.CancelledError):
|
||||
long_function.fetch()
|
||||
|
||||
for _ in range(3):
|
||||
self.client.schedule(self._normal_function)
|
||||
self.client.join()
|
||||
self.coordinator.schedule(self._normal_function)
|
||||
self.coordinator.join()
|
||||
|
||||
|
||||
class LimitedClosureQueueErrorTest(ErrorReportingTest):
|
||||
@ -772,11 +776,11 @@ class LimitedClosureQueueErrorTest(ErrorReportingTest):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super(LimitedClosureQueueErrorTest, cls).setUpClass()
|
||||
client_lib._CLOSURE_QUEUE_MAX_SIZE = 2
|
||||
cls.client = make_client(num_workers=3, num_ps=2)
|
||||
cls.strategy = cls.client.strategy
|
||||
coordinator_lib._CLOSURE_QUEUE_MAX_SIZE = 2
|
||||
cls.coordinator = make_coordinator(num_workers=3, num_ps=2)
|
||||
cls.strategy = cls.coordinator.strategy
|
||||
|
||||
with cls.client.strategy.scope():
|
||||
with cls.coordinator.strategy.scope():
|
||||
cls.iteration = variables.Variable(initial_value=0.0)
|
||||
|
||||
|
||||
@ -785,8 +789,8 @@ class StrategyIntegrationTest(test.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super(StrategyIntegrationTest, cls).setUpClass()
|
||||
cls.client = make_client(num_workers=1, num_ps=1)
|
||||
cls.strategy = cls.client.strategy
|
||||
cls.coordinator = make_coordinator(num_workers=1, num_ps=1)
|
||||
cls.strategy = cls.coordinator.strategy
|
||||
|
||||
def testBasicVariableAssignment(self):
|
||||
self.strategy.extended._variable_count = 0
|
||||
@ -801,9 +805,9 @@ class StrategyIntegrationTest(test.TestCase):
|
||||
v2.assign_sub(0.2)
|
||||
return v1.read_value() / v2.read_value()
|
||||
|
||||
results = self.client.schedule(worker_fn)
|
||||
results = self.coordinator.schedule(worker_fn)
|
||||
logging.info('Results of experimental_run_v2: %f',
|
||||
self.client.fetch(results))
|
||||
self.coordinator.fetch(results))
|
||||
|
||||
self.assertAlmostEqual(v1.read_value().numpy(), 0.1, delta=1e-6)
|
||||
self.assertAlmostEqual(v2.read_value().numpy(), 0.8, delta=1e-6)
|
||||
@ -826,12 +830,14 @@ class StrategyIntegrationTest(test.TestCase):
|
||||
return self.strategy.run(replica_fn, args=(input_tensor,))
|
||||
|
||||
# Asserting scheduling in scope has the expected behavior.
|
||||
result = self.client.schedule(worker_fn, args=(constant_op.constant(3),))
|
||||
self.assertIsInstance(result, client_lib.RemoteValue)
|
||||
result = self.coordinator.schedule(
|
||||
worker_fn, args=(constant_op.constant(3),))
|
||||
self.assertIsInstance(result, coordinator_lib.RemoteValue)
|
||||
self.assertEqual(result.fetch(), 4)
|
||||
|
||||
# Asserting scheduling out of scope has the expected behavior.
|
||||
result = self.client.schedule(worker_fn, args=(constant_op.constant(3),))
|
||||
result = self.coordinator.schedule(
|
||||
worker_fn, args=(constant_op.constant(3),))
|
||||
self.assertEqual(result.fetch(), 4)
|
||||
|
||||
|
@ -31,15 +31,15 @@ enable_metrics = False
|
||||
_time_buckets = monitoring.ExponentialBuckets(0.001, 10, 6)
|
||||
|
||||
_function_tracing_sampler = monitoring.Sampler(
|
||||
'/tensorflow/api/ps_strategy/client/function_tracing', _time_buckets,
|
||||
'/tensorflow/api/ps_strategy/coordinator/function_tracing', _time_buckets,
|
||||
'Sampler to track the time (in seconds) for tracing functions.')
|
||||
|
||||
_closure_execution_sampler = monitoring.Sampler(
|
||||
'/tensorflow/api/ps_strategy/client/closure_execution', _time_buckets,
|
||||
'/tensorflow/api/ps_strategy/coordinator/closure_execution', _time_buckets,
|
||||
'Sampler to track the time (in seconds) for executing closures.')
|
||||
|
||||
_remote_value_fetch_sampler = monitoring.Sampler(
|
||||
'/tensorflow/api/ps_strategy/client/remote_value_fetch', _time_buckets,
|
||||
'/tensorflow/api/ps_strategy/coordinator/remote_value_fetch', _time_buckets,
|
||||
'Sampler to track the time (in seconds) for fetching remote_value.')
|
||||
|
||||
_METRICS_MAPPING = {
|
@ -13,7 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for metrics collecting in client."""
|
||||
"""Tests for metrics collecting in coordinator."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
@ -22,9 +22,9 @@ from __future__ import print_function
|
||||
import time
|
||||
from tensorflow.python.distribute import multi_worker_test_base
|
||||
from tensorflow.python.distribute import parameter_server_strategy_v2
|
||||
from tensorflow.python.distribute.client import client
|
||||
from tensorflow.python.distribute.client import metric_utils
|
||||
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
|
||||
from tensorflow.python.distribute.coordinator import cluster_coordinator as coordinator_lib
|
||||
from tensorflow.python.distribute.coordinator import metric_utils
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import test
|
||||
from tensorflow.python.training.server_lib import ClusterSpec
|
||||
@ -35,7 +35,7 @@ class MetricUtilsTest(test.TestCase):
|
||||
def get_rpc_layer(self):
|
||||
return 'grpc'
|
||||
|
||||
def testClientMetrics(self):
|
||||
def testClusterCoordinatorMetrics(self):
|
||||
|
||||
metric_utils.enable_metrics = True
|
||||
|
||||
@ -48,7 +48,7 @@ class MetricUtilsTest(test.TestCase):
|
||||
ClusterSpec(cluster_def), rpc_layer=self.get_rpc_layer())
|
||||
strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
|
||||
cluster_resolver)
|
||||
cluster = client.Cluster(strategy)
|
||||
cluster = coordinator_lib.Cluster(strategy)
|
||||
|
||||
@def_function.function
|
||||
def func():
|
@ -48,7 +48,7 @@ _LOCAL_CPU = "/device:CPU:0"
|
||||
|
||||
|
||||
# TODO(yuefengz): maybe cache variables on local CPU.
|
||||
@tf_export("distribute.experimental.ParameterServerStrategy", v1=[])
|
||||
# TODO(b/171250971): Remove this and change all symbol usage of this to V1.
|
||||
class ParameterServerStrategy(distribute_lib.Strategy):
|
||||
"""An asynchronous multi-worker parameter server tf.distribute strategy.
|
||||
|
||||
@ -154,8 +154,9 @@ class ParameterServerStrategy(distribute_lib.Strategy):
|
||||
|
||||
def _raise_pss_error_if_eager(self):
|
||||
if context.executing_eagerly():
|
||||
raise NotImplementedError("ParameterServerStrategy currently only works "
|
||||
"with the tf.Estimator API")
|
||||
raise NotImplementedError(
|
||||
"`tf.compat.v1.distribute.experimental.ParameterServerStrategy` "
|
||||
"currently only works with the tf.Estimator API")
|
||||
|
||||
|
||||
@tf_export(v1=["distribute.experimental.ParameterServerStrategy"]) # pylint: disable=missing-docstring
|
||||
|
@ -37,82 +37,400 @@ from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training import server_lib
|
||||
from tensorflow.python.training.tracking import base as trackable
|
||||
from tensorflow.python.util import tf_inspect
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
# pylint: disable=protected-access
|
||||
@tf_export("distribute.experimental.ParameterServerStrategy", v1=[])
|
||||
class ParameterServerStrategyV2(distribute_lib.Strategy):
|
||||
"""An asynchronous multi-worker parameter server tf.distribute strategy.
|
||||
"""An multi-worker tf.distribute strategy with parameter servers.
|
||||
|
||||
Currently, `ParameterServerStrategyV2` is not supported to be used as a
|
||||
standalone tf.distribute strategy. It should be used in conjunction with
|
||||
`Client`. Please see `Client` for more information.
|
||||
Parameter server training refers to the distributed training architecture that
requires two types of tasks in the cluster: workers (referred to as the
"worker" task) and parameter servers (referred to as the "ps" task). The
variables and updates to those variables are placed on ps, and most
computation-intensive operations are placed on workers.
|
||||
|
||||
This is currently under development, and the API as well as implementation
|
||||
is subject to changes.
|
||||
In TF2, parameter server training uses one coordinator, some number of
workers, and (usually fewer) ps. The coordinator uses a
`tf.distribute.experimental.coordinator.ClusterCoordinator` to coordinate the
cluster, and a `tf.distribute.experimental.ParameterServerStrategy` for
variable distribution. The coordinator does not perform the actual training.
Each of the workers and ps runs a `tf.distribute.Server`, which the
coordinator connects to through the two aforementioned APIs.
|
||||
|
||||
For the training to work, the coordinator sends requests to workers for the
`tf.function`s to be executed on remote workers. Upon receiving a request from
the coordinator, a worker executes the `tf.function` by reading the variables
from parameter servers, executing the ops, and updating the variables on the
parameter servers. Each worker only processes requests from the coordinator,
and communicates with parameter servers, without direct interaction with any
of the other workers in the cluster.

As a result, failures of some workers do not prevent the cluster from
continuing the work, and this allows the cluster to train with instances that
can be occasionally unavailable (e.g. preemptible or spot instances). The
coordinator and parameter servers, though, must be available at all times for
the cluster to make progress.
|
||||
|
||||
Note that the coordinator is not one of the training workers. Instead, its
responsibilities include placing variables on ps, remotely executing
`tf.function`s on workers, and saving checkpoints. Parameter server training
thus consists of a server cluster with worker and ps tasks, and a coordinator
that connects to them to coordinate. Optionally, an evaluator can be run on
the side, periodically reading the checkpoints saved by the coordinator and,
for example, writing summaries.
|
||||
|
||||
`tf.distribute.experimental.ParameterServerStrategy` works closely with the
|
||||
associated `tf.distribute.experimental.coordinator.ClusterCoordinator` object,
|
||||
and should be used in conjunction with it. Standalone usage of
|
||||
`tf.distribute.experimental.ParameterServerStrategy` without a
|
||||
`tf.distribute.experimental.coordinator.ClusterCoordinator` indicates
|
||||
a parameter server training scheme without a centralized coordinator, which is
|
||||
not supported at this time.
|
||||
|
||||
__Example code for coordinator__
|
||||
|
||||
Here's an example usage of the API, with a custom training loop to train a
model. This code snippet is intended to be run on the one machine that is
designated as the coordinator. Note that the `cluster_resolver`,
`variable_partitioner`, and `dataset_fn` arguments are explained in the
following "Cluster setup", "Variable partitioning", and "Dataset preparation"
sections.

Currently, the environment variable `GRPC_FAIL_FAST` needs to be set in all
tasks to work around a known hanging issue, as the following code illustrates:
|
||||
|
||||
```python
|
||||
# Set the environment variable to allow reporting worker and ps failure to the
|
||||
# coordinator.
|
||||
os.environ["GRPC_FAIL_FAST"] = "use_caller"
|
||||
|
||||
# Prepare a strategy to use with the cluster and variable partitioning info.
|
||||
strategy = tf.distribute.experimental.ParameterServerStrategy(
|
||||
cluster_resolver=...,
|
||||
variable_partitioner=...)
|
||||
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
|
||||
strategy=strategy)
|
||||
|
||||
# Prepare a distribute dataset that will place datasets on the workers.
|
||||
distributed_dataset = coordinator.create_per_worker_dataset(dataset_fn=...)
|
||||
|
||||
with strategy.scope():
|
||||
model = ...  # Variables created here may be containers of variables (e.g. sharded)
|
||||
optimizer, metrics = ... # Keras optimizer/metrics are great choices
|
||||
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
|
||||
checkpoint_manager = tf.train.CheckpointManager(
|
||||
checkpoint, checkpoint_dir, max_to_keep=2)
|
||||
# `load_checkpoint` infers initial epoch from `optimizer.iterations`.
|
||||
initial_epoch = load_checkpoint(checkpoint_manager) or 0
|
||||
|
||||
@tf.function
|
||||
def worker_fn(iterator):
|
||||
|
||||
def replica_fn(inputs):
|
||||
batch_data, labels = inputs
|
||||
# calculate gradient, applying gradient, metrics update etc.
|
||||
|
||||
strategy.run(replica_fn, args=(next(iterator),))
|
||||
|
||||
for epoch in range(initial_epoch, num_epoch):
|
||||
distributed_iterator = iter(distributed_dataset) # Reset iterator state.
|
||||
for step in range(steps_per_epoch):
|
||||
|
||||
# Asynchronously schedule the `worker_fn` to be executed on an arbitrary
|
||||
# worker. This call returns immediately.
|
||||
coordinator.schedule(worker_fn, args=(distributed_iterator,))
|
||||
|
||||
# `join` blocks until all scheduled `worker_fn`s finish execution. Once it
|
||||
# returns, we can read the metrics and save checkpoints as needed.
|
||||
coordinator.join()
|
||||
logging.info('Metric result: %r', metrics.result())
|
||||
train_accuracy.reset_states()
|
||||
checkpoint_manager.save()
|
||||
```
|
||||
|
||||
__Example code for worker and parameter servers__
|
||||
|
||||
In addition to the coordinator, there should be multiple machines designated
as "worker" or "ps". They should run the following code to start a TensorFlow
server, waiting for the coordinator's requests to execute functions or place
variables:
|
||||
|
||||
```python
|
||||
# Set the environment variable to allow reporting worker and ps failure to the
|
||||
# coordinator.
|
||||
os.environ["GRPC_FAIL_FAST"] = "use_caller"
|
||||
|
||||
# Provide a `tf.distribute.cluster_resolver.ClusterResolver` that serves
|
||||
# the cluster information. See below "Cluster setup" section.
|
||||
cluster_resolver = ...
|
||||
|
||||
server = tf.distribute.Server(
|
||||
cluster_resolver.cluster_spec().as_cluster_def(),
|
||||
job_name=cluster_resolver.task_type,
|
||||
task_index=cluster_resolver.task_id,
|
||||
protocol=protocol)
|
||||
|
||||
# Blocking the process that starts a server from exiting.
|
||||
server.join()
|
||||
```
|
||||
|
||||
__Cluster setup__
|
||||
|
||||
In order for the tasks in the cluster to know each other's addresses, a
`tf.distribute.cluster_resolver.ClusterResolver` is required to be used in the
coordinator, workers, and ps. The
`tf.distribute.cluster_resolver.ClusterResolver` is responsible for providing
the cluster information, as well as the task type and id of the current task.
See `tf.distribute.cluster_resolver.ClusterResolver` for more information.
|
||||
|
||||
If the `TF_CONFIG` environment variable is used for the processes to know the
cluster information, a
`tf.distribute.cluster_resolver.TFConfigClusterResolver` should be used. Note
that for legacy reasons, "chief" should be used as the task type for the
coordinator, as the following example demonstrates. Here we set `TF_CONFIG`
via an environment variable, intended to be run by the process of the machine
designated as the parameter server (task type "ps") and index 1 (the second),
in a cluster with 1 chief, 2 parameter servers, and 3 workers. Note that it
needs to be set before the use of
`tf.distribute.cluster_resolver.TFConfigClusterResolver`.
|
||||
|
||||
Example code for cluster setup:
|
||||
```python
|
||||
os.environ['TF_CONFIG'] = '''
|
||||
{
|
||||
"cluster": {
|
||||
"chief": ["chief.example.com:2222"],
|
||||
"ps": ["ps0.example.com:2222", "ps1.example.com:2222"],
|
||||
"worker": ["worker0.example.com:2222", "worker1.example.com:2222",
|
||||
"worker2.example.com:2222"]
|
||||
},
|
||||
"task": {
|
||||
"type": "ps",
|
||||
"index": 1
|
||||
}
|
||||
}
|
||||
'''
|
||||
os.environ["GRPC_FAIL_FAST"] = "use_caller"
|
||||
cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
|
||||
|
||||
# If coordinator ("chief" task type), create a strategy
|
||||
if cluster_resolver.task_type == 'chief':
|
||||
strategy = tf.distribute.experimental.ParameterServerStrategy(
|
||||
cluster_resolver)
|
||||
...
|
||||
|
||||
# If worker/ps, create a server
|
||||
elif cluster_resolver.task_type in ("worker", "ps"):
|
||||
server = tf.distribute.Server(...)
|
||||
...
|
||||
```
|
||||
|
||||
__Variable creation with `strategy.scope()`__
|
||||
|
||||
`tf.distribute.experimental.ParameterServerStrategy` follows the
|
||||
`tf.distribute` API contract where variable creation is expected to be inside
|
||||
the context manager returned by `strategy.scope()`, in order to be correctly
|
||||
placed on parameter servers in a round-robin manner:
|
||||
|
||||
```python
|
||||
# In this example, we assume the cluster has 3 ps tasks.
|
||||
strategy = tf.distribute.experimental.ParameterServerStrategy(
|
||||
cluster_resolver=...)
|
||||
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
|
||||
strategy=strategy)
|
||||
|
||||
# Variables should be created inside scope to be placed on parameter servers.
|
||||
# If created outside scope, such as `v1` here, it would be placed on the
# coordinator.
|
||||
v1 = tf.Variable(initial_value=0.0)
|
||||
|
||||
with strategy.scope():
|
||||
v2 = tf.Variable(initial_value=1.0)
|
||||
v3 = tf.Variable(initial_value=2.0)
|
||||
v4 = tf.Variable(initial_value=3.0)
|
||||
v5 = tf.Variable(initial_value=4.0)
|
||||
|
||||
# v2 through v5 are created in scope and are distributed on parameter servers.
|
||||
# Default placement is round-robin but the order should not be relied on.
|
||||
assert v2.device == "/job:ps/replica:0/task:0/device:CPU:0"
|
||||
assert v3.device == "/job:ps/replica:0/task:1/device:CPU:0"
|
||||
assert v4.device == "/job:ps/replica:0/task:2/device:CPU:0"
|
||||
assert v5.device == "/job:ps/replica:0/task:0/device:CPU:0"
|
||||
```
|
||||
|
||||
See `distribute.Strategy.scope` for more information.
|
||||
|
||||
__Variable partitioning__
|
||||
|
||||
Having dedicated servers to store variables means being able to divide up, or
"shard", the variables across the ps. Large embeddings that would otherwise
exceed the memory limit of a single machine can be used in a cluster with a
sufficient number of ps.
|
||||
|
||||
With `tf.distribute.experimental.ParameterServerStrategy`, if a
|
||||
`variable_partitioner` is provided to `__init__` and certain conditions are
|
||||
satisfied, the resulting variables created in scope are sharded across the
|
||||
parameter servers, in a round-robin fashion. The variable reference returned
|
||||
from `tf.Variable` becomes a type that serves as the container of the sharded
|
||||
variables. Access the `variables` attribute of this container for the actual
|
||||
variable components. See arguments section of
|
||||
`tf.distribute.experimental.ParameterServerStrategy.__init__` for more
|
||||
information.
|
||||
|
||||
To initialize the sharded variables in a more memory-efficient way, use an
|
||||
initializer whose `__call__` accepts a `shard_info` argument, and use
|
||||
`shard_info.offset` and `shard_info.shape` to create and return a
|
||||
partition-aware `tf.Tensor` to initialize the variable components.
|
||||
|
||||
```python
|
||||
class PartitionAwareIdentity(object):
|
||||
|
||||
def __call__(self, shape, dtype, shard_info):
|
||||
value = tf.eye(*shape, dtype=dtype)
|
||||
if shard_info is not None:
|
||||
value = tf.slice(value, shard_info.offset, shard_info.shape)
|
||||
return value
|
||||
|
||||
cluster_resolver = ...
|
||||
strategy = tf.distribute.experimental.ParameterServerStrategy(
|
||||
cluster_resolver, tf.fixed_size_partitioner(2))
|
||||
with strategy.scope():
|
||||
initializer = PartitionAwareIdentity()
|
||||
initial_value = functools.partial(initializer, shape=(4, 4), dtype=tf.int64)
|
||||
v = tf.Variable(
|
||||
initial_value=initial_value, shape=(4, 4), dtype=tf.int64)
|
||||
|
||||
# `v.variables` gives the actual variable components.
|
||||
assert len(v.variables) == 2
|
||||
assert v.variables[0].device == "/job:ps/replica:0/task:0/device:CPU:0"
|
||||
assert v.variables[1].device == "/job:ps/replica:0/task:1/device:CPU:0"
|
||||
assert np.array_equal(v.variables[0].numpy(), [[1, 0, 0, 0], [0, 1, 0, 0]])
|
||||
assert np.array_equal(v.variables[1].numpy(), [[0, 0, 1, 0], [0, 0, 0, 1]])
|
||||
```
|
||||
|
||||
__Dataset preparation__
|
||||
|
||||
With `tf.distribute.experimental.ParameterServerStrategy`, a dataset is
created on each of the workers to be used for training. This is done by
creating a `dataset_fn` that takes no argument and returns a
`tf.data.Dataset`, and passing the `dataset_fn` into
`tf.distribute.experimental.coordinator.
ClusterCoordinator.create_per_worker_dataset`. We recommend that the dataset
be shuffled and repeated so that the examples run through the training as
evenly as possible.
|
||||
|
||||
```python
|
||||
def dataset_fn():
|
||||
filenames = ...
|
||||
dataset = tf.data.Dataset.from_tensor_slices(filenames)
|
||||
|
||||
# Dataset is recommended to be shuffled, and repeated.
|
||||
return dataset.shuffle(buffer_size=...).repeat().batch(batch_size=...)
|
||||
|
||||
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
    strategy=...)
|
||||
distributed_dataset = coordinator.create_per_worker_dataset(dataset_fn)
|
||||
|
||||
```
|
||||
|
||||
__Limitations__
|
||||
|
||||
* `tf.distribute.experimental.ParameterServerStrategy` in TF2 is experimental,
|
||||
and the API is subject to further changes.
|
||||
|
||||
* `tf.distribute.experimental.ParameterServerStrategy` does not yet support
|
||||
training with GPU(s). This is a feature request being developed.
|
||||
|
||||
* `tf.distribute.experimental.ParameterServerStrategy` currently only supports
the [custom training loop
API](https://www.tensorflow.org/tutorials/distribute/custom_training) in TF2.
Usage with the Keras `compile`/`fit` API is being developed.
|
||||
|
||||
* `tf.distribute.experimental.ParameterServerStrategy` must be used with
|
||||
`tf.distribute.experimental.coordinator.ClusterCoordinator`.
|
||||
|
||||
* This strategy is not intended for TPU. Use
|
||||
`tf.distribute.experimental.TPUStrategy` instead.
|
||||
"""
|
||||
|
||||
# pyformat: disable
|
||||
def __init__(self, cluster_resolver, variable_partitioner=None):
|
||||
"""Initializes the V2 parameter server strategy.
|
||||
"""Initializes the TF2 parameter server strategy.
|
||||
|
||||
This also connects to the remote server cluster.
|
||||
This initializes the `tf.distribute.experimental.ParameterServerStrategy`
|
||||
object to be ready for use with
|
||||
`tf.distribute.experimental.coordinator.ClusterCoordinator`.
|
||||
|
||||
Args:
|
||||
cluster_resolver: a `tf.distribute.cluster_resolver.ClusterResolver`
|
||||
object.
|
||||
variable_partitioner: a callable with the signature `num_partitions =
|
||||
fn(shape, dtype)`, where `num_partitions` is a list/tuple representing
|
||||
the number of partitions on each axis, and `shape` and `dtype` are of
|
||||
types `tf.TensorShape` and `tf.dtypes.Dtype`. If None, variables will
|
||||
not be partitioned. * `variable_partitioner` will be called for all
|
||||
variables created under strategy `scope` to instruct how the variables
|
||||
should be partitioned. Variables will be partitioned if there are more
|
||||
than one partitions along the partitioning axis, otherwise it falls back
|
||||
to normal `tf.Variable`. * Only the first / outermost axis partitioning
|
||||
is supported, namely, elements in `num_partitions` must be 1 other than
|
||||
the first element. * Partitioner like `min_max_variable_partitioner`,
|
||||
`variable_axis_size_partitioner` and `fixed_size_partitioner` are also
|
||||
supported since they conform to the required signature. * Div partition
|
||||
variable_partitioner:
|
||||
a callable with the signature `num_partitions = fn(shape, dtype)`, where
|
||||
`num_partitions` is a list/tuple representing the number of partitions
|
||||
on each axis, and `shape` and `dtype` are of types `tf.TensorShape` and
|
||||
`tf.dtypes.Dtype`. If `None`, variables will not be partitioned.
|
||||
|
||||
* `variable_partitioner` will be called for all variables created under
|
||||
strategy `scope` to instruct how the variables should be partitioned.
|
||||
Variables will be created in multiple partitions if there is more than
one partition along the partitioning axis; otherwise it falls back to a
normal `tf.Variable`.
|
||||
|
||||
* Only the first / outermost axis partitioning is supported, namely,
|
||||
elements in `num_partitions` must be 1 other than the first element.
|
||||
|
||||
* Partitioners like `tf.compat.v1.min_max_variable_partitioner`,
|
||||
`tf.compat.v1.variable_axis_size_partitioner` and
|
||||
`tf.compat.v1.fixed_size_partitioner` are also supported since they
|
||||
conform to the required signature.
|
||||
|
||||
* Div partition
|
||||
strategy is used to partition variables. Assuming we assign consecutive
|
||||
integer ids along the first axis of a variable, then ids are assigned to
|
||||
shards in a contiguous manner, while attempting to keep each shard size
|
||||
identical. If the ids do not evenly divide the number of shards, each of
|
||||
the first several shards will be assigned one more id. For instance, a
|
||||
variable whose first dimension is
|
||||
13 has 13 ids, and they are split across 5 shards as: `[[0, 1, 2], [3,
|
||||
4, 5], [6, 7, 8], [9, 10], [11, 12]]`. * Variables created under
|
||||
`strategy.extended.colocate_vars_with` will not be partitioned, e.g,
|
||||
optimizer's slot variables.
|
||||
variable whose first dimension is 13 has 13 ids, and they are split
|
||||
across 5 shards as:
|
||||
`[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`.
|
||||
|
||||
* Variables created under `strategy.extended.colocate_vars_with` will
|
||||
not be partitioned, e.g., optimizer's slot variables.
|
||||
"""
|
||||
# pyformat: enable
|
||||
self._cluster_resolver = cluster_resolver
|
||||
self._extended = ParameterServerStrategyV2Extended(self, cluster_resolver,
|
||||
variable_partitioner)
|
||||
self._verify_args_and_config(cluster_resolver)
|
||||
logging.info(
|
||||
"ParameterServerStrategyV2 is initialized with cluster_spec: "
|
||||
"%s", cluster_resolver.cluster_spec())
|
||||
"`tf.distribute.experimental.ParameterServerStrategy` is initialized "
|
||||
"with cluster_spec: %s", cluster_resolver.cluster_spec())
|
||||
|
||||
# TODO(b/167894802): Make chief, worker, and ps names customizable.
|
||||
self._connect_to_cluster(client_name="chief")
|
||||
# TODO(b/167894802): Make coordinator, worker, and ps names customizable.
|
||||
self._connect_to_cluster(coordinator_name="chief")
|
||||
super(ParameterServerStrategyV2, self).__init__(self._extended)
|
||||
distribute_lib.distribution_strategy_gauge.get_cell("V2").set(
|
||||
"ParameterServerStrategy")
|
||||
|
||||
def _connect_to_cluster(self, client_name):
|
||||
if client_name in ["worker", "ps"]:
|
||||
raise ValueError("Client name should not be 'worker' or 'ps'.")
|
||||
def _connect_to_cluster(self, coordinator_name):
|
||||
if coordinator_name in ["worker", "ps"]:
|
||||
raise ValueError("coordinator name should not be 'worker' or 'ps'.")
|
||||
cluster_spec = self._cluster_resolver.cluster_spec()
|
||||
self._num_workers = len(cluster_spec.as_dict().get("worker", ()))
|
||||
self._num_ps = len(cluster_spec.as_dict().get("ps", ()))
|
||||
|
||||
device_filters = server_lib.ClusterDeviceFilters()
|
||||
# For any worker, only the devices on PS and chief nodes are visible
|
||||
# For any worker, only the devices on ps and coordinator nodes are visible
|
||||
for i in range(self._num_workers):
|
||||
device_filters.set_device_filters(
|
||||
"worker", i, ["/job:ps", "/job:%s" % client_name])
|
||||
# Similarly for any ps, only the devices on workers and chief are visible
|
||||
"worker", i, ["/job:ps", "/job:%s" % coordinator_name])
|
||||
# Similarly for any ps, only the devices on workers and coordinator are
|
||||
# visible
|
||||
for i in range(self._num_ps):
|
||||
device_filters.set_device_filters(
|
||||
"ps", i, ["/job:worker", "/job:%s" % client_name])
|
||||
"ps", i, ["/job:worker", "/job:%s" % coordinator_name])
|
||||
|
||||
# Allow at most one outstanding RPC for each worker at a certain time. This
|
||||
# is to simplify worker failure handling in the runtime
|
||||
@ -122,7 +440,7 @@ class ParameterServerStrategyV2(distribute_lib.Strategy):
|
||||
self.__class__.__name__, cluster_spec)
|
||||
remote.connect_to_cluster(
|
||||
cluster_spec,
|
||||
job_name=client_name,
|
||||
job_name=coordinator_name,
|
||||
protocol=self._cluster_resolver.rpc_layer,
|
||||
cluster_device_filters=device_filters)
|
||||
|
||||
@ -134,7 +452,7 @@ class ParameterServerStrategyV2(distribute_lib.Strategy):
|
||||
def _verify_args_and_config(self, cluster_resolver):
|
||||
if not cluster_resolver.cluster_spec():
|
||||
raise ValueError("Cluster spec must be non-empty in `cluster_resolver`.")
|
||||
if self.extended._num_gpus_per_worker > 1:
|
||||
if self.extended._num_gpus_per_worker > 1: # pylint: disable=protected-access
|
||||
raise NotImplementedError("Multi-gpu is not supported yet.")
|
||||
|
||||
|
||||
@ -205,8 +523,8 @@ class ParameterServerStrategyV2Extended(
|
||||
init_from_fn = False
|
||||
initial_value = initial_value()
|
||||
if not init_from_fn:
|
||||
# The initial_value is created on client, it will need to be sent to
|
||||
# PS for variable initialization, which can be inefficient and can
|
||||
# The initial_value is created on coordinator, it will need to be sent to
|
||||
# ps for variable initialization, which can be inefficient and can
|
||||
# potentially hit the 2GB limit on protobuf serialization.
|
||||
initial_value = ops.convert_to_tensor(initial_value, dtype=dtype)
|
||||
dtype = initial_value.dtype
|
||||
|
@ -870,8 +870,8 @@ py_test(
|
||||
"//tensorflow/python/distribute:multi_worker_test_base",
|
||||
"//tensorflow/python/distribute:parameter_server_strategy_v2",
|
||||
"//tensorflow/python/distribute:sharded_variable",
|
||||
"//tensorflow/python/distribute/client",
|
||||
"//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
|
||||
"//tensorflow/python/distribute/coordinator:cluster_coordinator",
|
||||
"//tensorflow/python/eager:backprop",
|
||||
"//tensorflow/python/eager:def_function",
|
||||
"//tensorflow/python/eager:test",
|
||||
|
@ -13,7 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for ParameterServerClient and Keras models."""
|
||||
"""Tests for ClusterCoordinator and Keras models."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
@ -27,8 +27,8 @@ from tensorflow.python.compat import v2_compat
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.distribute import multi_worker_test_base
|
||||
from tensorflow.python.distribute import parameter_server_strategy_v2
|
||||
from tensorflow.python.distribute.client import client as client_lib
|
||||
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
|
||||
from tensorflow.python.distribute.coordinator import cluster_coordinator as coordinator_lib
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import constant_op
|
||||
@ -44,7 +44,7 @@ from tensorflow.python.platform import test
|
||||
from tensorflow.python.training.server_lib import ClusterSpec
|
||||
|
||||
|
||||
def make_client(num_workers, num_ps):
|
||||
def make_coordinator(num_workers, num_ps):
|
||||
cluster_def = multi_worker_test_base.create_in_process_cluster(
|
||||
num_workers=num_workers, num_ps=num_ps, rpc_layer="grpc")
|
||||
cluster_def["chief"] = [
|
||||
@ -52,7 +52,7 @@ def make_client(num_workers, num_ps):
|
||||
]
|
||||
cluster_resolver = SimpleClusterResolver(
|
||||
ClusterSpec(cluster_def), rpc_layer="grpc")
|
||||
return client_lib.Client(
|
||||
return coordinator_lib.ClusterCoordinator(
|
||||
parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver))
|
||||
|
||||
|
||||
@ -61,7 +61,7 @@ class KPLTest(test.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super(KPLTest, cls).setUpClass()
|
||||
cls.client = make_client(num_workers=3, num_ps=2)
|
||||
cls.coordinator = make_coordinator(num_workers=3, num_ps=2)
|
||||
|
||||
def testTrainAndServe(self):
|
||||
# These vocabularies usually come from TFT or a Beam pipeline.
|
||||
@ -71,10 +71,10 @@ class KPLTest(test.TestCase):
|
||||
]
|
||||
label_vocab = ["yes", "no"]
|
||||
|
||||
with self.client.strategy.scope():
|
||||
with self.coordinator.strategy.scope():
|
||||
|
||||
# Define KPLs under strategy's scope. Right now, if they have look up
|
||||
# tables, they will be created on the client. Their variables will be
|
||||
# tables, they will be created on the coordinator. Their variables will be
|
||||
# created on PS. Ideally they should be cached on each worker since they
|
||||
# will not be changed in a training step.
|
||||
feature_lookup_layer = string_lookup.StringLookup()
|
||||
@ -113,7 +113,7 @@ class KPLTest(test.TestCase):
|
||||
label = "yes" if "avenger" in features else "no"
|
||||
yield {"features": features, "label": label}
|
||||
|
||||
# The dataset will be created on the client?
|
||||
# The dataset will be created on the coordinator?
|
||||
raw_dataset = dataset_ops.Dataset.from_generator(
|
||||
feature_and_label_gen,
|
||||
output_types={
|
||||
@ -131,7 +131,8 @@ class KPLTest(test.TestCase):
|
||||
}, [x["label"]]))
|
||||
return train_dataset
|
||||
|
||||
distributed_dataset = self.client.create_per_worker_dataset(dataset_fn)
|
||||
distributed_dataset = self.coordinator.create_per_worker_dataset(
|
||||
dataset_fn)
|
||||
|
||||
model_input = keras.layers.Input(
|
||||
shape=(3,), dtype=dtypes.int64, name="model_input")
|
||||
@ -163,12 +164,12 @@ class KPLTest(test.TestCase):
|
||||
actual_pred = math_ops.cast(math_ops.greater(pred, 0.5), dtypes.int64)
|
||||
accuracy.update_state(labels, actual_pred)
|
||||
|
||||
self.client._strategy.run(train_step, args=(iterator,))
|
||||
self.coordinator._strategy.run(train_step, args=(iterator,))
|
||||
|
||||
distributed_iterator = iter(distributed_dataset)
|
||||
for _ in range(10):
|
||||
self.client.schedule(worker_fn, args=(distributed_iterator,))
|
||||
self.client.join()
|
||||
self.coordinator.schedule(worker_fn, args=(distributed_iterator,))
|
||||
self.coordinator.join()
|
||||
self.assertGreater(accuracy.result().numpy(), 0.0)
|
||||
|
||||
# Create a saved model.
|
||||
|
@ -362,7 +362,8 @@ class Model(training_lib.Model):
|
||||
if isinstance(self._distribution_strategy,
|
||||
(parameter_server_strategy.ParameterServerStrategyV1,
|
||||
parameter_server_strategy.ParameterServerStrategy)):
|
||||
raise NotImplementedError('ParameterServerStrategy currently only works '
|
||||
raise NotImplementedError('`tf.compat.v1.distribute.experimental.Paramet'
|
||||
'erServerStrategy` currently only works '
|
||||
'with the tf.Estimator API')
|
||||
|
||||
if not self._experimental_run_tf_function:
|
||||
|
@ -64,6 +64,9 @@ from tensorflow.python.framework.test_combinations import *
|
||||
from tensorflow.python.util.tf_decorator import make_decorator
|
||||
from tensorflow.python.util.tf_decorator import unwrap
|
||||
|
||||
from tensorflow.python.distribute.parameter_server_strategy_v2 import *
|
||||
from tensorflow.python.distribute.coordinator.cluster_coordinator import *
|
||||
|
||||
tf_export('__internal__.decorator.make_decorator', v1=[])(make_decorator)
|
||||
tf_export('__internal__.decorator.unwrap', v1=[])(unwrap)
|
||||
|
||||
|
@ -31,6 +31,7 @@ TENSORFLOW_API_INIT_FILES = [
|
||||
"distribute/__init__.py",
|
||||
"distribute/cluster_resolver/__init__.py",
|
||||
"distribute/experimental/__init__.py",
|
||||
"distribute/experimental/coordinator/__init__.py",
|
||||
"dtypes/__init__.py",
|
||||
"errors/__init__.py",
|
||||
"experimental/__init__.py",
|
||||
|
@ -27,6 +27,8 @@ from tensorflow.lite.python import lite as _tflite_for_api_traversal
|
||||
from tensorflow.python import modules_with_exports
|
||||
from tensorflow.python.distribute import multi_process_runner
|
||||
from tensorflow.python.distribute import multi_worker_test_base
|
||||
from tensorflow.python.distribute import parameter_server_strategy_v2
|
||||
from tensorflow.python.distribute.coordinator import cluster_coordinator
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.framework import test_combinations
|
||||
# pylint: enable=unused-import
|
||||
|
@ -1,6 +1,6 @@
|
||||
path: "tensorflow.distribute.experimental.ParameterServerStrategy"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.distribute.parameter_server_strategy.ParameterServerStrategy\'>"
|
||||
is_instance: "<class \'tensorflow.python.distribute.parameter_server_strategy_v2.ParameterServerStrategyV2\'>"
|
||||
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.Strategy\'>"
|
||||
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.StrategyBase\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
@ -18,7 +18,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'cluster_resolver\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'self\', \'cluster_resolver\', \'variable_partitioner\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "colocate_vars_with"
|
||||
|
@ -0,0 +1,33 @@
|
||||
path: "tensorflow.distribute.experimental.coordinator.ClusterCoordinator"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.distribute.coordinator.cluster_coordinator.ClusterCoordinator\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "strategy"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'strategy\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "create_per_worker_dataset"
|
||||
argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "done"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "fetch"
|
||||
argspec: "args=[\'self\', \'val\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "join"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "schedule"
|
||||
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
}
|
||||
}
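For orientation, the `ClusterCoordinator` surface listed in this golden file can be exercised together roughly as in the following minimal sketch. It is a hedged example, not taken from the commit: `strategy` is assumed to be an already-constructed `tf.distribute.experimental.ParameterServerStrategy` connected to a running cluster, and `dataset_fn`/`worker_fn` are illustrative stand-ins.

```python
import tensorflow as tf

# Assumption: `strategy` is a tf.distribute.experimental.ParameterServerStrategy
# that has already connected to a running cluster of workers and ps.
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(strategy)


def dataset_fn():
  # Illustrative dataset; real training data would go here.
  return tf.data.Dataset.range(100).repeat().batch(8)


# Datasets are created on each worker; the coordinator holds per-worker handles.
per_worker_dataset = coordinator.create_per_worker_dataset(dataset_fn)
per_worker_iterator = iter(per_worker_dataset)


@tf.function
def worker_fn(iterator):
  batch = next(iterator)
  return tf.reduce_sum(batch)


# `schedule` returns a RemoteValue immediately; the function runs on some worker.
result = coordinator.schedule(worker_fn, args=(per_worker_iterator,))

coordinator.join()         # Block until all scheduled functions have finished.
assert coordinator.done()  # No scheduled functions are pending or in flight.

print(coordinator.fetch(result))  # Materialize the RemoteValue's result.
```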
|
@ -0,0 +1,9 @@
|
||||
path: "tensorflow.distribute.experimental.coordinator.PerWorkerValues"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.distribute.coordinator.cluster_coordinator.PerWorkerValues\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'values\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
@ -0,0 +1,12 @@
|
||||
path: "tensorflow.distribute.experimental.coordinator.RemoteValue"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.distribute.coordinator.cluster_coordinator.RemoteValue\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
name: "__init__"
|
||||
}
|
||||
member_method {
|
||||
name: "fetch"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
@ -0,0 +1,15 @@
|
||||
path: "tensorflow.distribute.experimental.coordinator"
|
||||
tf_module {
|
||||
member {
|
||||
name: "ClusterCoordinator"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "PerWorkerValues"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "RemoteValue"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
}
|
@ -36,4 +36,8 @@ tf_module {
|
||||
name: "ValueContext"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "coordinator"
|
||||
mtype: "<type \'module\'>"
|
||||
}
|
||||
}
|
||||
|
@ -154,9 +154,9 @@ COMMON_PIP_DEPS = [
|
||||
"//tensorflow/tools/common:test_module1",
|
||||
"//tensorflow/tools/common:traverse",
|
||||
"//tensorflow/python/distribute:parameter_server_strategy_v2",
|
||||
"//tensorflow/python/distribute/client:client",
|
||||
"//tensorflow/python/distribute/client:remote_eager_lib",
|
||||
"//tensorflow/python/distribute/client:metric_utils",
|
||||
"//tensorflow/python/distribute/coordinator:cluster_coordinator",
|
||||
"//tensorflow/python/distribute/coordinator:remote_eager_lib",
|
||||
"//tensorflow/python/distribute/coordinator:metric_utils",
|
||||
]
|
||||
|
||||
# On Windows, python binary is a zip file of runfiles tree.
|
||||
|