PSv2: Export a few tf.distribute symbols related to TF2 parameter server training.

This change exports the following class symbols, and adds relevant documentation and example code for them:

tf.distribute.experimental.ParameterServerStrategy
tf.distribute.experimental.coordinator.ClusterCoordinator
tf.distribute.experimental.coordinator.PerWorkerValues
tf.distribute.experimental.coordinator.RemoteValue
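
For context, here is a minimal sketch of how these symbols fit together, condensed from the docstring examples added in this change (the `cluster_resolver` setup is elided, as in those examples):

```python
import tensorflow as tf

# Assumes a running cluster of parameter servers and workers reachable via
# `cluster_resolver` (setup elided).
strategy = tf.distribute.experimental.ParameterServerStrategy(
    cluster_resolver=...)
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
    strategy)

with strategy.scope():
  v = tf.Variable(initial_value=0)

def dataset_fn():
  return tf.data.Dataset.from_tensor_slices([1, 1, 1])

@tf.function
def worker_fn(iterator):
  x = next(iterator)
  v.assign_add(x)
  return v.read_value()

# `iter` on the per-worker dataset yields a `PerWorkerValues` of iterators,
# one placed on each worker.
per_worker_dataset = coordinator.create_per_worker_dataset(dataset_fn)
per_worker_iterator = iter(per_worker_dataset)

# `schedule` dispatches `worker_fn` to a worker asynchronously and returns a
# `RemoteValue`; `fetch` blocks until the concrete result is available.
remote_value = coordinator.schedule(worker_fn, args=(per_worker_iterator,))
coordinator.join()           # Wait for all scheduled functions to finish.
print(remote_value.fetch())
```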

PiperOrigin-RevId: 338151262
Change-Id: If2d1c513d30a999c728cecc2e73b75adda1948c2
Rick Chao 2020-10-20 15:34:17 -07:00 committed by TensorFlower Gardener
parent 833b3a49a9
commit 32f35aabce
24 changed files with 911 additions and 309 deletions


@ -208,7 +208,16 @@
how many times the function is called, and independent of global seed how many times the function is called, and independent of global seed
settings. settings.
* `tf.distribute`: * `tf.distribute`:
* <ADD RELEASE NOTES HERE> * (Experimental) Parameter server training:
* Replaced the existing
`tf.distribute.experimental.ParameterServerStrategy` symbol with
a new class that is for parameter server training in TF2. Usage with
the old symbol, usually with Estimator, should be replaced with
`tf.compat.v1.distribute.experimental.ParameterServerStrategy`.
* Added `tf.distribute.experimental.coordinator.*` namespace,
including the main API `ClusterCoordinator` for coordinating the
training cluster, and the related data structures `RemoteValue`
and `PerWorkerValues`.
* `tf.keras`: * `tf.keras`:
* Improvements from the functional API refactoring: * Improvements from the functional API refactoring:
* Functional model construction does not need to maintain a global * Functional model construction does not need to maintain a global


@ -153,8 +153,9 @@ py_library(
":multi_process_runner", ":multi_process_runner",
":multi_worker_test_base", ":multi_worker_test_base",
":one_device_strategy", ":one_device_strategy",
":parameter_server_strategy_v2",
":sharded_variable", ":sharded_variable",
"//tensorflow/python/distribute/client", "//tensorflow/python/distribute/coordinator:cluster_coordinator",
"//tensorflow/python/distribute/experimental", "//tensorflow/python/distribute/experimental",
], ],
) )
@ -1880,6 +1881,7 @@ tf_py_test(
":multi_worker_test_base", ":multi_worker_test_base",
":parameter_server_strategy_v2", ":parameter_server_strategy_v2",
":sharded_variable", ":sharded_variable",
"//tensorflow:tensorflow_py",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:dtypes", "//tensorflow/python:dtypes",
"//tensorflow/python:extra_py_tests_deps", "//tensorflow/python:extra_py_tests_deps",


@ -8,8 +8,8 @@ package(
exports_files(["LICENSE"]) exports_files(["LICENSE"])
py_library( py_library(
name = "client", name = "cluster_coordinator",
srcs = ["client.py"], srcs = ["cluster_coordinator.py"],
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
deps = [ deps = [
":metric_utils", ":metric_utils",
@ -34,9 +34,9 @@ py_library(
) )
tf_py_test( tf_py_test(
name = "client_test", name = "cluster_coordinator_test",
size = "small", size = "small",
srcs = ["client_test.py"], srcs = ["cluster_coordinator_test.py"],
python_version = "PY3", python_version = "PY3",
shard_count = 50, shard_count = 50,
tags = [ tags = [
@ -44,7 +44,7 @@ tf_py_test(
"notsan", # TODO(b/171040359): Flaky timeout, even if maximum shards "notsan", # TODO(b/171040359): Flaky timeout, even if maximum shards
], ],
deps = [ deps = [
":client", ":cluster_coordinator",
"//tensorflow/python:check_ops", "//tensorflow/python:check_ops",
"//tensorflow/python:client_testlib", "//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op", "//tensorflow/python:constant_op",
@ -66,12 +66,13 @@ tf_py_test(
) )
tf_py_test( tf_py_test(
name = "client_mpr_test", name = "cluster_coordinator_mpr_test",
srcs = ["client_mpr_test.py"], srcs = ["cluster_coordinator_mpr_test.py"],
python_version = "PY3", python_version = "PY3",
shard_count = 2, shard_count = 2,
tags = ["no_oss"], # TODO(b/162119374) tags = ["no_oss"], # TODO(b/162119374)
deps = [ deps = [
":cluster_coordinator",
":remote_eager_lib", ":remote_eager_lib",
":utils", ":utils",
"//tensorflow/python:dtypes", "//tensorflow/python:dtypes",
@ -81,7 +82,6 @@ tf_py_test(
"//tensorflow/python/distribute:multi_worker_test_base", "//tensorflow/python/distribute:multi_worker_test_base",
"//tensorflow/python/distribute:parameter_server_strategy_v2", "//tensorflow/python/distribute:parameter_server_strategy_v2",
"//tensorflow/python/distribute:sharded_variable", "//tensorflow/python/distribute:sharded_variable",
"//tensorflow/python/distribute/client",
"//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib", "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
"//tensorflow/python/eager:def_function", "//tensorflow/python/eager:def_function",
"//tensorflow/python/eager:test", "//tensorflow/python/eager:test",
@ -102,7 +102,7 @@ tf_py_test(
srcs = ["metric_utils_test.py"], srcs = ["metric_utils_test.py"],
python_version = "PY3", python_version = "PY3",
deps = [ deps = [
":client", ":cluster_coordinator",
":metric_utils", ":metric_utils",
"//tensorflow/python:training_server_lib", "//tensorflow/python:training_server_lib",
"//tensorflow/python/distribute:multi_worker_test_base", "//tensorflow/python/distribute:multi_worker_test_base",


@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Module for `Client` and relevant cluster-worker related library. """Module for `ClusterCoordinator` and relevant cluster-worker related library.
This is currently under development and the API is subject to change. This is currently under development and the API is subject to change.
""" """
@ -35,7 +35,7 @@ from six.moves import queue
from tensorflow.python.data.ops import iterator_ops from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.distribute import input_lib from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import parameter_server_strategy_v2 from tensorflow.python.distribute import parameter_server_strategy_v2
from tensorflow.python.distribute.client import metric_utils from tensorflow.python.distribute.coordinator import metric_utils
from tensorflow.python.eager import cancellation from tensorflow.python.eager import cancellation
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.eager import def_function from tensorflow.python.eager import def_function
@ -48,6 +48,7 @@ from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
# Maximum time for failed worker to come back is 1 hour # Maximum time for failed worker to come back is 1 hour
_WORKER_MAXIMUM_RECOVERY_SEC = 3600 _WORKER_MAXIMUM_RECOVERY_SEC = 3600
@ -56,9 +57,10 @@ _WORKER_MAXIMUM_RECOVERY_SEC = 3600
# When the maximum queue size is reached, further schedule calls will become # When the maximum queue size is reached, further schedule calls will become
# blocking until some previously queued closures are executed on workers. # blocking until some previously queued closures are executed on workers.
# Note that using an "infinite" queue size can take a non-trivial portion of # Note that using an "infinite" queue size can take a non-trivial portion of
# memory, and even lead to client OOM. Modify the size to a smaller value for # memory, and even lead to coordinator OOM. Modify the size to a smaller value
# client with constrained memory resource (only recommended for advanced users). # for coordinator with constrained memory resource (only recommended for
# Also used in unit tests to ensure the correctness when the queue is full. # advanced users). Also used in unit tests to ensure the correctness when the
# queue is full.
_CLOSURE_QUEUE_MAX_SIZE = 256 * 1024 _CLOSURE_QUEUE_MAX_SIZE = 256 * 1024
# RPC error message from PS # RPC error message from PS
@ -99,18 +101,77 @@ class _RemoteValueStatus(enum.Enum):
READY = "READY" READY = "READY"
@tf_export("distribute.experimental.coordinator.RemoteValue", v1=[])
class RemoteValue(object): class RemoteValue(object):
"""An asynchronously available value of a remotely executed function. """An asynchronously available value of a remotely executed function.
`RemoteValue` class is used as the return value of `Client.schedule()` where `tf.distribute.experimental.coordinator.RemoteValue` class is used as the
the underlying concrete value comes at a later time once the function has been return value of
remotely executed. `RemoteValue` can be used as an input to a subsequent `tf.distribute.experimental.coordinator.ClusterCoordinator.schedule()` where
function scheduled with `Client.schedule()`. the underlying concrete value becomes available at a later time once the
function has been remotely executed. The underlying concrete value is the
`tf.Tensor.numpy()` result of the `tf.Tensor`, or the structure of
`tf.Tensor`s, returned from the `tf.function` that was scheduled.
Note: this class is not thread-safe. Using a `tf.distribute.experimental.coordinator.RemoteValue` as an input to
a subsequent function scheduled with
`tf.distribute.experimental.coordinator.ClusterCoordinator.schedule()` is
currently not supported.
Example:
```python
strategy = tf.distribute.experimental.ParameterServerStrategy(
cluster_resolver=...)
coordinator = (
tf.distribute.experimental.coordinator.ClusterCoordinator(strategy))
with strategy.scope():
v1 = tf.Variable(initial_value=0.0)
v2 = tf.Variable(initial_value=1.0)
@tf.function
def worker_fn():
v1.assign_add(0.1)
v2.assign_sub(0.2)
return v1.read_value() / v2.read_value()
result = coordinator.schedule(worker_fn)
# Note that `fetch()` gives the actual result instead of a `tf.Tensor`.
assert result.fetch() == 0.125
for _ in range(10):
# `worker_fn` will be run on arbitrary workers that are available. The
# `result` value will be non-deterministic because the workers are executing
# the functions asynchronously.
result = coordinator.schedule(worker_fn)
```
""" """
def __init__(self, closure, type_spec): def fetch(self):
"""Wait for the result of `RemoteValue` to be ready and return the result.
This makes the value concrete by copying the remote value to local.
Returns:
The actual output of the `tf.function` associated with this `RemoteValue`,
previously scheduled by a
`tf.distribute.experimental.coordinator.ClusterCoordinator.schedule` call.
This can be a single value, or a structure of values, depending on the
output of the `tf.function`.
Raises:
tf.errors.CancelledError: If the function that produces this `RemoteValue`
is aborted or cancelled due to failure, and the user should handle and
reschedule.
"""
raise NotImplementedError("Must be implemented in subclasses.")
class RemoteValueImpl(RemoteValue):
"""Implementation of `RemoteValue`."""
def __init__(self, closure, type_spec): # pylint: disable=super-init-not-called
self._closure = closure self._closure = closure
# The type spec for this `RemoteValue` which is used to trace functions that # The type spec for this `RemoteValue` which is used to trace functions that
# take this `RemoteValue` as input. # take this `RemoteValue` as input.
@ -157,16 +218,6 @@ class RemoteValue(object):
self._type_spec = func_graph.convert_structure_to_signature(type_spec) self._type_spec = func_graph.convert_structure_to_signature(type_spec)
def fetch(self): def fetch(self):
"""Wait for the result of RemoteValue to be ready and return the result.
Returns:
The remote value, as a numpy data type (if scalar) or ndarray.
Raises:
tf.errors.CancelledError: If the function that produces this `RemoteValue`
is aborted or cancelled due to failure, and the user should handle and
reschedule.
"""
self._status_available_event.wait() self._status_available_event.wait()
if self._status is _RemoteValueStatus.ABORTED: if self._status is _RemoteValueStatus.ABORTED:
raise errors.CancelledError( raise errors.CancelledError(
@ -241,8 +292,23 @@ def _maybe_as_type_spec(val):
return val return val
@tf_export("distribute.experimental.coordinator.PerWorkerValues", v1=[])
class PerWorkerValues(object): class PerWorkerValues(object):
"""Holds a list of per worker values.""" """A container that holds a list of values, one value per worker.
`tf.distribute.experimental.coordinator.PerWorkerValues` contains a collection
of values, where each value represents a resource located on an
individual worker. When used as one of the `args` or `kwargs` of
`tf.distribute.experimental.coordinator.ClusterCoordinator.schedule()`, the
value specific to a worker will be passed into the function being executed at
that particular worker.
Currently, the only supported path to create an object of
`tf.distribute.experimental.coordinator.PerWorkerValues` is through calling
`iter` on a `ClusterCoordinator.create_per_worker_dataset`-returned
distributed dataset instance. The mechanism to create a custom
`tf.distribute.experimental.coordinator.PerWorkerValues` is not yet supported.
"""
def __init__(self, values): def __init__(self, values):
self._values = tuple(values) self._values = tuple(values)
@ -262,9 +328,10 @@ def _disallow_remote_value_as_input(structured):
def _raise_if_remote_value(x): def _raise_if_remote_value(x):
if isinstance(x, RemoteValue): if isinstance(x, RemoteValue):
raise ValueError("RemoteValue cannot be used as an input to scheduled " raise ValueError(
"function. Please file a feature request if you need " "`tf.distribute.experimental.coordinator.RemoteValue` used "
"this feature.") "as an input to scheduled function is not yet "
"supported.")
nest.map_structure(_raise_if_remote_value, structured) nest.map_structure(_raise_if_remote_value, structured)
@ -274,8 +341,8 @@ class Closure(object):
def __init__(self, function, cancellation_mgr, args=None, kwargs=None): def __init__(self, function, cancellation_mgr, args=None, kwargs=None):
if not callable(function): if not callable(function):
raise ValueError("Function passed to `Client.schedule` must be a " raise ValueError("Function passed to `ClusterCoordinator.schedule` must "
"callable object.") "be a callable object.")
self._args = args or () self._args = args or ()
self._kwargs = kwargs or {} self._kwargs = kwargs or {}
@ -287,9 +354,9 @@ class Closure(object):
replica_kwargs = _select_worker_slice(0, self._kwargs) replica_kwargs = _select_worker_slice(0, self._kwargs)
# Note: no need to handle function registration failure since this kind of # Note: no need to handle function registration failure since this kind of
# failure will not raise exceptions as designed in the runtime. The client # failure will not raise exceptions as designed in the runtime. The
# has to rely on subsequent operations that raise to catch function # coordinator has to rely on subsequent operations that raise to catch
# registration failure. # function registration failure.
# Record the function tracing overhead. Note that we pass in the tracing # Record the function tracing overhead. Note that we pass in the tracing
# count of the def_function.Function as a state tracker, so that metrics # count of the def_function.Function as a state tracker, so that metrics
@ -303,17 +370,18 @@ class Closure(object):
self._function = cancellation_mgr.get_cancelable_function( self._function = cancellation_mgr.get_cancelable_function(
concrete_function) concrete_function)
self._output_remote_values = nest.map_structure( self._output_remote_values = nest.map_structure(
lambda x: RemoteValue(self, x), concrete_function.structured_outputs) lambda x: RemoteValueImpl(self, x),
concrete_function.structured_outputs)
elif isinstance(function, tf_function.ConcreteFunction): elif isinstance(function, tf_function.ConcreteFunction):
self._function = cancellation_mgr.get_cancelable_function(function) self._function = cancellation_mgr.get_cancelable_function(function)
self._output_remote_values = nest.map_structure( self._output_remote_values = nest.map_structure(
lambda x: RemoteValue(self, x), function.structured_outputs) lambda x: RemoteValueImpl(self, x), function.structured_outputs)
else: else:
# Regular python functions. # Regular python functions.
self._function = function self._function = function
# TODO(yuefengz): maybe we should trace python functions if their inputs # TODO(yuefengz): maybe we should trace python functions if their inputs
# are Python primitives, tensors and composite tensors. # are Python primitives, tensors and composite tensors.
self._output_remote_values = RemoteValue(self, None) self._output_remote_values = RemoteValueImpl(self, None)
def _fetch_output_remote_values(self): def _fetch_output_remote_values(self):
"""Temporary method used to sync the scheduler.""" """Temporary method used to sync the scheduler."""
@ -394,7 +462,7 @@ class _CoordinatedClosureQueue(object):
if _CLOSURE_QUEUE_MAX_SIZE <= 0: if _CLOSURE_QUEUE_MAX_SIZE <= 0:
logging.warning( logging.warning(
"In a `Client`, creating an infinite closure queue can " "In a `ClusterCoordinator`, creating an infinite closure queue can "
"consume a significant amount of memory and even lead to OOM.") "consume a significant amount of memory and even lead to OOM.")
self._queue = queue.Queue(maxsize=_CLOSURE_QUEUE_MAX_SIZE) self._queue = queue.Queue(maxsize=_CLOSURE_QUEUE_MAX_SIZE)
self._error = None self._error = None
@ -430,10 +498,10 @@ class _CoordinatedClosureQueue(object):
# The cancellation manager cannot be reused once cancelled. After all # The cancellation manager cannot be reused once cancelled. After all
# closures (queued or inflight) are cleaned up, recreate the cancellation # closures (queued or inflight) are cleaned up, recreate the cancellation
# manager with clean state. # manager with clean state.
# Note on thread-safety: this is triggered when one of theses client APIs # Note on thread-safety: this is triggered when one of these
# are called: `schedule`, `wait`, and `done`. At the same time, no new # ClusterCoordinator APIs are called: `schedule`, `wait`, and `done`. At the
# closures can be constructed (which reads the _cancellation_mgr to get # same time, no new closures can be constructed (which reads the
# cancellable functions). # _cancellation_mgr to get cancellable functions).
self._cancellation_mgr = cancellation.CancellationManager() self._cancellation_mgr = cancellation.CancellationManager()
def _raise_if_error(self): def _raise_if_error(self):
@ -742,8 +810,8 @@ class Worker(object):
def _register_resource(self, resource_remote_value): def _register_resource(self, resource_remote_value):
if not isinstance(resource_remote_value, RemoteValue): if not isinstance(resource_remote_value, RemoteValue):
raise ValueError( raise ValueError("Resource being registered is not of type "
"Resource being registered is not of type `RemoteValue`.") "`tf.distribute.experimental.coordinator.RemoteValue`.")
self._resource_remote_value_refs.append(weakref.ref(resource_remote_value)) self._resource_remote_value_refs.append(weakref.ref(resource_remote_value))
@ -753,7 +821,7 @@ class Cluster(object):
We assume all function errors are fatal and based on this assumption our We assume all function errors are fatal and based on this assumption our
error reporting logic is: error reporting logic is:
1) Both `schedule` and `join` can raise a non-retryable error which is the 1) Both `schedule` and `join` can raise a non-retryable error which is the
first error seen by the client from any previously scheduled functions. first error seen by the coordinator from any previously scheduled functions.
2) When an error is raised, there is no guarantee on how many previously 2) When an error is raised, there is no guarantee on how many previously
scheduled functions have been executed; functions that have not been executed scheduled functions have been executed; functions that have not been executed
will be thrown away and marked as cancelled. will be thrown away and marked as cancelled.
@ -775,17 +843,17 @@ class Cluster(object):
# Ignore PS failures reported by workers due to transient connection errors. # Ignore PS failures reported by workers due to transient connection errors.
# Transient connectivity issues between workers and PS are relayed by the # Transient connectivity issues between workers and PS are relayed by the
# workers to the client, leading the client to believe that there are PS # workers to the coordinator, leading the coordinator to believe that there
# failures. The difference between transient vs. permanent PS failure is the # are PS failures. The difference between transient vs. permanent PS failure
# number of reports from the workers. When this env var is set to a positive # is the number of reports from the workers. When this env var is set to a
# integer K, the client ignores up to K reports of a failed PS task. I.e., # positive integer K, the coordinator ignores up to K reports of a failed PS
# only when there are more than K trials of executing closures fail due to # task, i.e., only when there are more than K trials of executing closures
# errors from the same PS instance do we consider the PS instance encounters # fail due to errors from the same PS instance do we consider the PS
# a failure. # instance encounters a failure.
# TODO(b/164279603): Remove this workaround when the underlying connectivity # TODO(b/164279603): Remove this workaround when the underlying connectivity
# issue in gRPC server is resolved. # issue in gRPC server is resolved.
self._transient_ps_failures_threshold = int(os.environ.get( self._transient_ps_failures_threshold = int(
"TF_CLIENT_IGNORE_TRANSIENT_PS_FAILURES", 3)) os.environ.get("TF_COORDINATOR_IGNORE_TRANSIENT_PS_FAILURES", 3))
self._potential_ps_failures_lock = threading.Lock() self._potential_ps_failures_lock = threading.Lock()
self._potential_ps_failures_count = [0] * self._num_ps self._potential_ps_failures_count = [0] * self._num_ps
@ -825,7 +893,7 @@ class Cluster(object):
kwargs: Keyword arguments for `fn`. kwargs: Keyword arguments for `fn`.
Returns: Returns:
A structure of `RemoteValue` object. A `RemoteValue` object.
""" """
closure = Closure( closure = Closure(
function, function,
@ -844,68 +912,124 @@ class Cluster(object):
return self._closure_queue.done() return self._closure_queue.done()
class Client(object): @tf_export("distribute.experimental.coordinator.ClusterCoordinator", v1=[])
"""An object to schedule and orchestrate remote function execution. class ClusterCoordinator(object):
"""An object to schedule and coordinate remote function execution.
A `Client` object represents a program used to create dataset, schedule A `tf.distribute.experimental.coordinator.ClusterCoordinator` object
functions to be executed, and fetch the results of the functions. represents a program used to distribute datasets onto the workers, schedule
functions to be executed, and fetch the results of the functions. It expects
the cluster to contain some machines with processes running TensorFlow
servers.
Currently, `Client` is not supported to be used in a standalone manner. Currently, `tf.distribute.experimental.coordinator.ClusterCoordinator` is not
It should be used in conjunction with `ParameterServerStrategyV2`. supported to be used in a standalone manner. It should be used in conjunction
with a `tf.distribute` strategy that is designed to work with it. Currently,
only `tf.distribute.experimental.ParameterServerStrategy` is supported to work
with `tf.distribute.experimental.coordinator.ClusterCoordinator`.
__Fault tolerance__
`tf.distribute.experimental.coordinator.ClusterCoordinator`, when used with
`tf.distribute.experimental.ParameterServerStrategy`, comes with built-in
fault tolerance for worker failures. That is, when some workers cannot be
reached from the coordinator for any reason, training continues with the
remaining operating workers, without any additional handling needed in user
code.
On the other hand, parameter server and coordinator failures are not
automatically handled: when a parameter server fails, a
`tf.errors.UnavailableError` is raised by `ClusterCoordinator.schedule` or
`ClusterCoordinator.join`. In that case, the user should
restart the processes on the failed parameter servers when they become
available, *and* restart the process on the coordinator, so the coordinator
can re-create the variables on the parameter servers. If the coordinator fails
but all other machines continue to be operating, the user only needs to
restart the process on the coordinator, which will automatically connect to
the parameter servers and workers, and continue the progress.
It is thus essential that the user's custom training loop periodically saves a
checkpoint file, and restores from it at the start of the program.
* At-least-once semantics of `schedule`: `schedule` puts the `tf.function` in
a queue where it gets picked up by a worker to execute. If a worker picks up
a `tf.function`, has begun to execute it, but the worker process is
disconnected from the coordinator in the middle of execution, the
`tf.function` is deemed not completed, and is put back to the queue. This is
regardless of how much progress the `tf.function` has run. Once it is
re-queued, it will be picked up again by an arbitrary available worker at some
later time. As a result, the same function may be executed more than once, but
not less than once. If a `tf.keras.optimizers.Optimizer` is used,
`tf.keras.optimizers.Optimizer.iterations` roughly indicates the number of
times the gradients have been applied and can be used as an approximation of
the total number of steps. The number of times the function has been scheduled
by the `tf.distribute.experimental.coordinator.ClusterCoordinator`, on the
other hand, is not necessarily indicative of the actual number of steps run. See
`tf.distribute.experimental.coordinator.ClusterCoordinator.schedule` for more
information.
See `tf.distribute.experimental.ParameterServerStrategy` docstring for an
example usage of this API.
This is currently under development, and the API as well as implementation This is currently under development, and the API as well as implementation
is subject to changes. are subject to changes.
""" """
def __init__(self, strategy): def __init__(self, strategy):
"""Initialization of a `Client` instance. """Initialization of a `ClusterCoordinator` instance.
This connects the coordinator to remote workers and parameter servers, through
a `tf.config.experimental_connect_to_cluster` call.
Args: Args:
strategy: a `tf.distribute.Strategy` object. Currently, only strategy: a supported `tf.distribute.Strategy` object. Currently, only
`ParameterServerStrategyV2` is supported. `tf.distribute.experimental.ParameterServerStrategy` is supported.
Raises: Raises:
ValueError: if the strategy being used is not supported. ValueError: if the strategy being used is not supported.
""" """
if not isinstance(strategy, if not isinstance(strategy,
parameter_server_strategy_v2.ParameterServerStrategyV2): parameter_server_strategy_v2.ParameterServerStrategyV2):
raise ValueError("Only `ParameterServerStrategyV2` is supported in " raise ValueError(
"`Client` currently.") "Only `tf.distribute.experimental.ParameterServerStrategy` "
"is supported to work with "
"`tf.distribute.experimental.coordinator.ClusterCoordinator` "
"currently.")
self._strategy = strategy self._strategy = strategy
self.cluster = Cluster(strategy) self.cluster = Cluster(strategy)
@property @property
def strategy(self): def strategy(self):
"""Returns the `Strategy` associated with the `ClusterCoordinator`."""
return self._strategy return self._strategy
def schedule(self, fn, args=None, kwargs=None): def schedule(self, fn, args=None, kwargs=None):
"""Schedules `fn` to be dispatched to a worker for execution asynchronously. """Schedules `fn` to be dispatched to a worker for asynchronous execution.
When calling `schedule` with a function `fn`, `fn` will be executed on a When calling `schedule` with a function `fn`, `fn` will be executed on a
remote worker at some later time. The process is asynchronous, meaning remote worker at some later time. The process is asynchronous, meaning
`schedule` returns immediately, possibly without having the result ready `schedule` returns immediately, possibly without having the result ready
yet. `schedule` returns a structure of `RemoteValue` object, which wraps the yet. `schedule` returns a
output of the function. Call `fetch()` on `RemoteValue` to wait for the `tf.distribute.experimental.coordinator.RemoteValue` object, which wraps the
output of the function. After `schedule` is called, `fetch` can be called
on the `tf.distribute.experimental.coordinator.RemoteValue` to wait for the
function execution to finish and retrieve its output from the remote worker. function execution to finish and retrieve its output from the remote worker.
On the other hand, call
`tf.distribute.experimental.coordinator.ClusterCoordinator.join` to wait for
all scheduled functions to finish execution before proceeding.
`schedule` guarantees that `fn` will be executed on a worker at least once; `schedule` guarantees that `fn` will be executed on a worker at least once;
it could be more than once if its corresponding worker fails in the middle it could be more than once if its corresponding worker fails in the middle
of its execution. Note that since worker can fail at any point when of its execution. Note that since worker can fail at any point when
executing the function, it is possible that the function is partially executing the function, it is possible that the function is partially
executed, but `Client` guarantees that in those events, the function will executed, but `tf.distribute.experimental.coordinator.ClusterCoordinator`
eventually be fully executed, possibly on a different worker that is guarantees that in those events, the function will eventually be fully
available. executed, possibly on a different worker that is available.
If any previously scheduled function raises an error, `schedule` will fail If any previously scheduled function raises an error, `schedule` will fail
by raising any one of those errors, and clear the errors collected so far. by raising any one of those errors, and clear the errors collected so far.
There are two implications when this happens: 1) user should call `schedule` There are two implications when this happens: 1) user should call `schedule`
with `fn` again to re-schedule, and 2) some of the previously scheduled with `fn` again to re-schedule, and 2) some of the previously scheduled
functions may have not been executed. User can call `fetch` on the returned functions may have not been executed. User can call `fetch` on the returned
`RemoteValue` to inspect if they have executed, failed, or cancelled, and `tf.distribute.experimental.coordinator.RemoteValue` to inspect if they have
reschedule the corresponding function if needed. executed, failed, or cancelled, and reschedule the corresponding function if
needed.
When `schedule` raises, it guarantees that there is no function that is When `schedule` raises, it guarantees that there is no function that is
still being executed. still being executed.
@ -914,12 +1038,14 @@ class Client(object):
execution, or priority of the workers. execution, or priority of the workers.
`args` and `kwargs` are the arguments passed into `fn`, when `fn` is `args` and `kwargs` are the arguments passed into `fn`, when `fn` is
executed on a worker. They can be `PerWorkerValues`, which is a collection executed on a worker. They can be
of values, each of which represents a component specific to a worker; in `tf.distribute.experimental.coordinator.PerWorkerValues`, which is a
this case, the argument will be substituted with the corresponding component 'collection of values, each of which represents a component specific to a
on the target worker. Arguments that are not `PerWorkerValues` will be worker; in this case, the argument will be substituted with the
passed into `fn` as-is. Currently, `RemoteValue` is not supported to be corresponding component on the target worker. Arguments that are not
input `args` or `kwargs`. `tf.distribute.experimental.coordinator.PerWorkerValues` will be passed into
`fn` as-is. Currently, `tf.distribute.experimental.coordinator.RemoteValue`
is not supported to be input `args` or `kwargs`.
Args: Args:
fn: A `tf.function`; the function to be dispatched to a worker for fn: A `tf.function`; the function to be dispatched to a worker for
@ -928,12 +1054,13 @@ class Client(object):
kwargs: Keyword arguments for `fn`. kwargs: Keyword arguments for `fn`.
Returns: Returns:
A structure of `RemoteValue` object. A `tf.distribute.experimental.coordinator.RemoteValue` object that
represents the output of the function scheduled.
Raises: Raises:
Exception: one of the exceptions caught by the client by any previously Exception: one of the exceptions caught by the coordinator by any
scheduled function since the last time an error was thrown or since previously scheduled function since the last time an error was thrown or
the beginning of the program. since the beginning of the program.
""" """
# Slot variables are usually created during function tracing time; thus # Slot variables are usually created during function tracing time; thus
# `schedule` needs to be called within the `strategy.scope()`. # `schedule` needs to be called within the `strategy.scope()`.
@ -946,18 +1073,18 @@ class Client(object):
If any previously scheduled function raises an error, `join` will fail by If any previously scheduled function raises an error, `join` will fail by
raising any one of those errors, and clear the errors collected so far. If raising any one of those errors, and clear the errors collected so far. If
this happens, some of the previously scheduled functions may have not been this happens, some of the previously scheduled functions may have not been
executed. Users can call `fetch` on the returned `RemoteValue` to inspect if executed. Users can call `fetch` on the returned
they have executed, failed, or cancelled. If some that have been cancelled `tf.distribute.experimental.coordinator.RemoteValue` to inspect if they have
need to be rescheduled, users should call `schedule` with the function executed, failed, or cancelled. If some that have been cancelled need to be
again. rescheduled, users should call `schedule` with the function again.
When `join` returns or raises, it guarantees that there is no function that When `join` returns or raises, it guarantees that there is no function that
is still being executed. is still being executed.
Raises: Raises:
Exception: one of the exceptions caught by the client by any previously Exception: one of the exceptions caught by the coordinator by any
scheduled function since the last time an error was thrown or since previously scheduled function since the last time an error was thrown or
the beginning of the program. since the beginning of the program.
""" """
self.cluster.join() self.cluster.join()
@ -969,6 +1096,9 @@ class Client(object):
When `done` returns True or raises, it guarantees that there is no function When `done` returns True or raises, it guarantees that there is no function
that is still being executed. that is still being executed.
Returns:
Whether all the scheduled functions have finished execution.
""" """
return self.cluster.done() return self.cluster.done()
@ -978,12 +1108,15 @@ class Client(object):
This creates the given dataset generated by dataset_fn on the workers This creates the given dataset generated by dataset_fn on the workers
and returns an object that represents the collection of those individual and returns an object that represents the collection of those individual
datasets. Calling `iter` on such collection of dataset returns a datasets. Calling `iter` on such collection of dataset returns a
`PerWorkerValues`, which is a collection of iterators, where the iterators `tf.distribute.experimental.coordinator.PerWorkerValues`, which is a
have been placed on respective workers. collection of iterators, where the iterators have been placed on respective
workers.
Calling `next` on this `PerWorkerValues` of iterators is currently Calling `next` on this
unsupported; it is meant to be passed as an argument into `Client.schedule`. `tf.distribute.experimental.coordinator.PerWorkerValues` of iterators is
When the scheduled function is picked up and being executed by a worker, the currently unsupported; it is meant to be passed as an argument into
`tf.distribute.experimental.coordinator.ClusterCoordinator.schedule`. When
the scheduled function is picked up and being executed by a worker, the
function will receive the individual iterator that corresponds to the function will receive the individual iterator that corresponds to the
worker, and now `next` can be called on iterator to get the next (batch or worker, and now `next` can be called on iterator to get the next (batch or
example) of data. example) of data.
@ -993,6 +1126,26 @@ class Client(object):
examples may be skipped and not covered by other workers, if the dataset is examples may be skipped and not covered by other workers, if the dataset is
sharded. sharded.
Example:
```python
@tf.function
def worker_fn(iterator):
return next(iterator)
def dataset_fn():
return tf.data.Dataset.from_tensor_slices([3] * 3)
strategy = tf.distribute.experimental.ParameterServerStrategy(
cluster_resolver=...)
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
strategy=strategy)
per_worker_dataset = coordinator.create_per_worker_dataset(dataset_fn)
per_worker_iter = iter(per_worker_dataset)
remote_value = coordinator.schedule(worker_fn, args=(per_worker_iter,))
assert remote_value.fetch() == 3
```
Args: Args:
dataset_fn: The dataset function that returns a dataset. This is to be dataset_fn: The dataset function that returns a dataset. This is to be
executed on the workers. executed on the workers.
@ -1000,7 +1153,8 @@ class Client(object):
Returns: Returns:
An object that represents the collection of those individual An object that represents the collection of those individual
datasets. `iter` is expected to be called on this object that returns datasets. `iter` is expected to be called on this object that returns
a `PerWorkerValues` of the iterators (that are on the workers). a `tf.distribute.experimental.coordinator.PerWorkerValues` of the
iterators (that are on the workers).
""" """
input_workers = input_lib.InputWorkers([ input_workers = input_lib.InputWorkers([
(w.device_name, [w.device_name]) for w in self.cluster.workers (w.device_name, [w.device_name]) for w in self.cluster.workers
@ -1011,7 +1165,8 @@ class Client(object):
def _create_per_worker_resources(self, fn, args=None, kwargs=None): def _create_per_worker_resources(self, fn, args=None, kwargs=None):
"""Synchronously create resources on the workers. """Synchronously create resources on the workers.
The resources are represented by `RemoteValue`s. The resources are represented by
`tf.distribute.experimental.coordinator.RemoteValue`s.
Args: Args:
fn: The function to be dispatched to all workers for execution fn: The function to be dispatched to all workers for execution
@ -1020,7 +1175,9 @@ class Client(object):
kwargs: Keyword arguments for `fn`. kwargs: Keyword arguments for `fn`.
Returns: Returns:
A `PerWorkerValues` object, which wraps a tuple of `RemoteValue` objects. A `tf.distribute.experimental.coordinator.PerWorkerValues` object, which
wraps a tuple of `tf.distribute.experimental.coordinator.RemoteValue`
objects.
""" """
results = [] results = []
for w in self.cluster.workers: for w in self.cluster.workers:
@ -1028,21 +1185,52 @@ class Client(object):
return PerWorkerValues(tuple(results)) return PerWorkerValues(tuple(results))
def fetch(self, val): def fetch(self, val):
"""Blocking call to fetch results from `RemoteValue`s. """Blocking call to fetch results from the remote values.
This returns the execution result of `RemoteValue`s; if not ready, This is a wrapper around
waiting for it while blocking the caller. `tf.distribute.experimental.coordinator.RemoteValue.fetch` for a
`RemoteValue` structure; it returns the execution results of
`RemoteValue`s. If not ready, wait for them while blocking the caller.
Example:
```python
strategy = ...
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
strategy)
def dataset_fn():
return tf.data.Dataset.from_tensor_slices([1, 1, 1])
with strategy.scope():
v = tf.Variable(initial_value=0)
@tf.function
def worker_fn(iterator):
def replica_fn(x):
v.assign_add(x)
return v.read_value()
return strategy.run(replica_fn, args=(next(iterator),))
distributed_dataset = coordinator.create_per_worker_dataset(dataset_fn)
distributed_iterator = iter(distributed_dataset)
result = coordinator.schedule(worker_fn, args=(distributed_iterator,))
assert coordinator.fetch(result) == 1
```
Args: Args:
val: The value to fetch the results from. If this is structure of val: The value to fetch the results from. If this is structure of
`RemoteValue`, `fetch()` will be called on the individual `RemoteValue` `tf.distribute.experimental.coordinator.RemoteValue`, `fetch()` will be
to get the result. called on the individual
`tf.distribute.experimental.coordinator.RemoteValue` to get the result.
Returns: Returns:
If `val` is a `RemoteValue` or a structure of `RemoteValue`s, returns If `val` is a `tf.distribute.experimental.coordinator.RemoteValue` or a
the fetched `RemoteValue` value immediately if it's available, or blocks structure of `tf.distribute.experimental.coordinator.RemoteValue`s,
the call until it's available, and returns the fetched `RemoteValue` return the fetched `tf.distribute.experimental.coordinator.RemoteValue`
values with the same structure. If `val` is other types, return (`val`,). values immediately if they are available, or block the call until they are
available, and return the fetched
`tf.distribute.experimental.coordinator.RemoteValue` values with the same
structure. If `val` is other types, return it as-is.
""" """
def _maybe_fetch(val): def _maybe_fetch(val):
@ -1052,10 +1240,7 @@ class Client(object):
return val return val
# TODO(yuefengz): we should fetch values in a batch. # TODO(yuefengz): we should fetch values in a batch.
result = nest.map_structure(_maybe_fetch, val) return nest.map_structure(_maybe_fetch, val)
if not isinstance(result, tuple):
return (result,)
return result
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
@ -1075,13 +1260,14 @@ def handle_parameter_server_failure():
class _PerWorkerDistributedDataset(object): class _PerWorkerDistributedDataset(object):
"""Represents worker-distributed datasets created from dataset function.""" """Represents worker-distributed datasets created from dataset function."""
def __init__(self, dataset_fn, input_workers, client): def __init__(self, dataset_fn, input_workers, coordinator):
"""Makes an iterable from datasets created by the given function. """Makes an iterable from datasets created by the given function.
Args: Args:
dataset_fn: A function that returns a `Dataset`. dataset_fn: A function that returns a `Dataset`.
input_workers: an `InputWorkers` object. input_workers: an `InputWorkers` object.
client: a `Client` object, used to create dataset resources. coordinator: a `ClusterCoordinator` object, used to create dataset
resources.
""" """
def disallow_variable_creation(next_creator, **kwargs): def disallow_variable_creation(next_creator, **kwargs):
raise ValueError("Creating variables in `dataset_fn` is not allowed.") raise ValueError("Creating variables in `dataset_fn` is not allowed.")
@ -1094,7 +1280,7 @@ class _PerWorkerDistributedDataset(object):
dataset_fn = def_function.function(dataset_fn).get_concrete_function() dataset_fn = def_function.function(dataset_fn).get_concrete_function()
self._dataset_fn = dataset_fn self._dataset_fn = dataset_fn
self._input_workers = input_workers self._input_workers = input_workers
self._client = client self._coordinator = coordinator
self._element_spec = None self._element_spec = None
def __iter__(self): def __iter__(self):
@ -1112,7 +1298,7 @@ class _PerWorkerDistributedDataset(object):
# If _PerWorkerDistributedDataset.__iter__ is called multiple # If _PerWorkerDistributedDataset.__iter__ is called multiple
# times, for the same object it should only create and register resource # times, for the same object it should only create and register resource
# once. Using object id to distinguish different iterator resources. # once. Using object id to distinguish different iterator resources.
per_worker_iterator = self._client._create_per_worker_resources( per_worker_iterator = self._coordinator._create_per_worker_resources(
_create_per_worker_iterator) _create_per_worker_iterator)
# Setting type_spec of each RemoteValue so that functions taking these # Setting type_spec of each RemoteValue so that functions taking these
@ -1131,7 +1317,7 @@ class _PerWorkerDistributedDataset(object):
class _PerWorkerDistributedIterator(PerWorkerValues): class _PerWorkerDistributedIterator(PerWorkerValues):
"""Distributed iterator for `Client`.""" """Distributed iterator for `ClusterCoordinator`."""
def __next__(self): def __next__(self):
return self.get_next() return self.get_next()
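
The fault-tolerance notes in the `ClusterCoordinator` docstring above recommend periodically saving a checkpoint in the custom training loop and restoring it at program start. A minimal sketch of what that could look like, assuming `dataset_fn` and `worker_fn` are defined as in the earlier examples and using a placeholder checkpoint directory (none of this is part of the change itself):

```python
import tensorflow as tf

strategy = tf.distribute.experimental.ParameterServerStrategy(
    cluster_resolver=...)
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
    strategy)

with strategy.scope():
  model = ...  # Build the model and its variables under the strategy scope.
  optimizer = tf.keras.optimizers.SGD()

# Checkpoint the model and optimizer so training can resume after a
# coordinator or parameter server restart.
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
manager = tf.train.CheckpointManager(
    checkpoint, directory="/tmp/ckpt", max_to_keep=3)
if manager.latest_checkpoint:
  checkpoint.restore(manager.latest_checkpoint)

per_worker_dataset = coordinator.create_per_worker_dataset(dataset_fn)
per_worker_iterator = iter(per_worker_dataset)

num_epochs, steps_per_epoch = 5, 100
for _ in range(num_epochs):
  for _ in range(steps_per_epoch):
    coordinator.schedule(worker_fn, args=(per_worker_iterator,))
  coordinator.join()  # Wait for this epoch's scheduled functions to finish.
  manager.save()      # Save on the coordinator once per epoch.
```

Because `optimizer.iterations` is saved as part of the checkpoint, it also provides the approximate step count mentioned in the at-least-once discussion above, even across restarts.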


@ -13,21 +13,20 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Multi-process runner tests for `Client` with `ParameterServerStrategyV2`.""" """Multi-process runner tests for `ClusterCoordinator` with PSv2."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import time import time
from tensorflow.python.compat import v2_compat from tensorflow.python.compat import v2_compat
from tensorflow.python.distribute import multi_process_runner from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import parameter_server_strategy_v2 from tensorflow.python.distribute import parameter_server_strategy_v2
from tensorflow.python.distribute.client import client as client_lib
from tensorflow.python.distribute.client import utils
from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
from tensorflow.python.distribute.coordinator import cluster_coordinator as coordinator_lib
from tensorflow.python.distribute.coordinator import utils
from tensorflow.python.eager import def_function from tensorflow.python.eager import def_function
from tensorflow.python.eager import test from tensorflow.python.eager import test
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
@ -38,7 +37,7 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
class ClientMprTest(test.TestCase): class ClusterCoordinatorMprTest(test.TestCase):
def testScheduleTranslatePSFailureError(self): def testScheduleTranslatePSFailureError(self):
self._test_translate_ps_failure_error(test_schedule=True) self._test_translate_ps_failure_error(test_schedule=True)
@ -56,7 +55,7 @@ class ClientMprTest(test.TestCase):
utils.start_server(cluster_resolver, "grpc") utils.start_server(cluster_resolver, "grpc")
strategy = parameter_server_strategy_v2.ParameterServerStrategyV2( strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
cluster_resolver) cluster_resolver)
ps_client = client_lib.Client(strategy) ps_coordinator = coordinator_lib.ClusterCoordinator(strategy)
with strategy.scope(): with strategy.scope():
v = variables.Variable(initial_value=0, dtype=dtypes.int32) v = variables.Variable(initial_value=0, dtype=dtypes.int32)
@ -68,19 +67,19 @@ class ClientMprTest(test.TestCase):
v.assign_add(1) v.assign_add(1)
# Keep the two workers occupied. # Keep the two workers occupied.
ps_client.schedule(worker_fn) ps_coordinator.schedule(worker_fn)
ps_client.schedule(worker_fn) ps_coordinator.schedule(worker_fn)
# Now the main process can terminate. # Now the main process can terminate.
functions_scheduled_event.set() functions_scheduled_event.set()
# Verified that join and schedule indeed raise UnavailableError. # Verified that join and schedule indeed raise UnavailableError.
try: try:
if test_join: if test_join:
ps_client.join() ps_coordinator.join()
if test_schedule: if test_schedule:
while ps_client.cluster._closure_queue._error is None: while ps_coordinator.cluster._closure_queue._error is None:
time.sleep(1) time.sleep(1)
ps_client.schedule(worker_fn) ps_coordinator.schedule(worker_fn)
except errors.UnavailableError: except errors.UnavailableError:
# The following verifies that after PS fails, continue executing # The following verifies that after PS fails, continue executing
# functions on workers should fail and indicate it's PS failure. # functions on workers should fail and indicate it's PS failure.
@ -91,7 +90,7 @@ class ClientMprTest(test.TestCase):
# failure. # failure.
worker_fn() worker_fn()
except Exception as e: # pylint: disable=broad-except except Exception as e: # pylint: disable=broad-except
if client_lib._is_ps_failure(e): if coordinator_lib._is_ps_failure(e):
if worker_id < 2: if worker_id < 2:
continue continue
logging.info("_test_translate_ps_failure_error ends properly.") logging.info("_test_translate_ps_failure_error ends properly.")


@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Tests for client.py.""" """Tests for coordinator.py."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
@ -29,8 +29,8 @@ from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import distribution_strategy_context from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import multi_worker_test_base from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import parameter_server_strategy_v2 from tensorflow.python.distribute import parameter_server_strategy_v2
from tensorflow.python.distribute.client import client as client_lib
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.distribute.coordinator import cluster_coordinator as coordinator_lib
from tensorflow.python.eager import cancellation from tensorflow.python.eager import cancellation
from tensorflow.python.eager import def_function from tensorflow.python.eager import def_function
from tensorflow.python.eager import test from tensorflow.python.eager import test
@ -52,7 +52,7 @@ from tensorflow.python.util import nest
class CoordinatedClosureQueueTest(test.TestCase): class CoordinatedClosureQueueTest(test.TestCase):
def testBasic(self): def testBasic(self):
queue = client_lib._CoordinatedClosureQueue() queue = coordinator_lib._CoordinatedClosureQueue()
closure1 = self._create_closure(queue._cancellation_mgr) closure1 = self._create_closure(queue._cancellation_mgr)
queue.put(closure1) queue.put(closure1)
self.assertIs(closure1, queue.get()) self.assertIs(closure1, queue.get())
@ -64,7 +64,7 @@ class CoordinatedClosureQueueTest(test.TestCase):
queue.wait() queue.wait()
def testProcessAtLeaseOnce(self): def testProcessAtLeaseOnce(self):
closure_queue = client_lib._CoordinatedClosureQueue() closure_queue = coordinator_lib._CoordinatedClosureQueue()
labels = ['A', 'B', 'C', 'D', 'E'] labels = ['A', 'B', 'C', 'D', 'E']
processed_count = collections.defaultdict(int) processed_count = collections.defaultdict(int)
@ -94,7 +94,7 @@ class CoordinatedClosureQueueTest(test.TestCase):
cm = cancellation.CancellationManager() cm = cancellation.CancellationManager()
for label in labels: for label in labels:
closure_queue.put(client_lib.Closure(get_func(label), cm)) closure_queue.put(coordinator_lib.Closure(get_func(label), cm))
t1 = threading.Thread(target=process_queue, daemon=True) t1 = threading.Thread(target=process_queue, daemon=True)
t1.start() t1.start()
t2 = threading.Thread(target=process_queue, daemon=True) t2 = threading.Thread(target=process_queue, daemon=True)
@ -111,7 +111,7 @@ class CoordinatedClosureQueueTest(test.TestCase):
coord.join([t1, t2]) coord.join([t1, t2])
def testNotifyBeforeWait(self): def testNotifyBeforeWait(self):
closure_queue = client_lib._CoordinatedClosureQueue() closure_queue = coordinator_lib._CoordinatedClosureQueue()
def func(): def func():
logging.info('func running') logging.info('func running')
@ -123,7 +123,8 @@ class CoordinatedClosureQueueTest(test.TestCase):
closure_queue.get() closure_queue.get()
closure_queue.mark_finished() closure_queue.mark_finished()
closure_queue.put(client_lib.Closure(func, closure_queue._cancellation_mgr)) closure_queue.put(
coordinator_lib.Closure(func, closure_queue._cancellation_mgr))
t = threading.Thread(target=process_queue) t = threading.Thread(target=process_queue)
t.start() t.start()
coord.join([t]) coord.join([t])
@ -159,7 +160,7 @@ class CoordinatedClosureQueueTest(test.TestCase):
# TODO(b/165013260): Fix this # TODO(b/165013260): Fix this
self.skipTest('Test is currently broken on Windows with Python 3.8') self.skipTest('Test is currently broken on Windows with Python 3.8')
closure_queue = client_lib._CoordinatedClosureQueue() closure_queue = coordinator_lib._CoordinatedClosureQueue()
closure_queue.put(self._create_closure(closure_queue._cancellation_mgr)) closure_queue.put(self._create_closure(closure_queue._cancellation_mgr))
closure = closure_queue.get() closure = closure_queue.get()
@ -189,10 +190,10 @@ class CoordinatedClosureQueueTest(test.TestCase):
def some_function(): def some_function():
return 1.0 return 1.0
return client_lib.Closure(some_function, cancellation_mgr) return coordinator_lib.Closure(some_function, cancellation_mgr)
def _put_two_closures_and_get_one(self): def _put_two_closures_and_get_one(self):
closure_queue = client_lib._CoordinatedClosureQueue() closure_queue = coordinator_lib._CoordinatedClosureQueue()
closure1 = self._create_closure(closure_queue._cancellation_mgr) closure1 = self._create_closure(closure_queue._cancellation_mgr)
closure_queue.put(closure1) closure_queue.put(closure1)
@ -330,7 +331,7 @@ class CoordinatedClosureQueueTest(test.TestCase):
# Closure2 was an inflight closure when it got cancelled. # Closure2 was an inflight closure when it got cancelled.
self.assertEqual(closure2._output_remote_values._status, self.assertEqual(closure2._output_remote_values._status,
client_lib._RemoteValueStatus.READY) coordinator_lib._RemoteValueStatus.READY)
with self.assertRaisesRegex(ValueError, 'Fake cancellation error.'): with self.assertRaisesRegex(ValueError, 'Fake cancellation error.'):
closure2._fetch_output_remote_values() closure2._fetch_output_remote_values()
@ -361,7 +362,7 @@ class CoordinatedClosureQueueTest(test.TestCase):
def testThreadSafey(self): def testThreadSafey(self):
thread_count = 10 thread_count = 10
queue = client_lib._CoordinatedClosureQueue() queue = coordinator_lib._CoordinatedClosureQueue()
# Each thread performs 20 queue actions: 10 are `put_back` and 10 are # Each thread performs 20 queue actions: 10 are `put_back` and 10 are
# `mark_finished`. # `mark_finished`.
@ -427,7 +428,7 @@ class TestCaseWithErrorReportingThread(test.TestCase):
raise ErrorReportingThread.error # pylint: disable=raising-bad-type raise ErrorReportingThread.error # pylint: disable=raising-bad-type
def make_client(num_workers, num_ps): def make_coordinator(num_workers, num_ps):
# TODO(rchao): Test the internal rpc_layer version. # TODO(rchao): Test the internal rpc_layer version.
cluster_def = multi_worker_test_base.create_in_process_cluster( cluster_def = multi_worker_test_base.create_in_process_cluster(
num_workers=num_workers, num_ps=num_ps, rpc_layer='grpc') num_workers=num_workers, num_ps=num_ps, rpc_layer='grpc')
@ -438,16 +439,16 @@ def make_client(num_workers, num_ps):
ClusterSpec(cluster_def), rpc_layer='grpc') ClusterSpec(cluster_def), rpc_layer='grpc')
strategy = parameter_server_strategy_v2.ParameterServerStrategyV2( strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
cluster_resolver) cluster_resolver)
return client_lib.Client(strategy) return coordinator_lib.ClusterCoordinator(strategy)
class ClientTest(TestCaseWithErrorReportingThread): class ClusterCoordinatorTest(TestCaseWithErrorReportingThread):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
super(ClientTest, cls).setUpClass() super(ClusterCoordinatorTest, cls).setUpClass()
cls.client = make_client(num_workers=3, num_ps=2) cls.coordinator = make_coordinator(num_workers=3, num_ps=2)
cls.strategy = cls.client.strategy cls.strategy = cls.coordinator.strategy
def testFnReturnNestedValues(self): def testFnReturnNestedValues(self):
x = constant_op.constant(1) x = constant_op.constant(1)
@ -456,9 +457,9 @@ class ClientTest(TestCaseWithErrorReportingThread):
def f(): def f():
return x + 1, (x + 2, x + 3), [x + 4], {'v': x} return x + 1, (x + 2, x + 3), [x + 4], {'v': x}
got = self.client.schedule(f) got = self.coordinator.schedule(f)
want = 2, (3, 4), [5], {'v': 1} want = 2, (3, 4), [5], {'v': 1}
self.assertEqual(self.client.fetch(got), want) self.assertEqual(self.coordinator.fetch(got), want)
def testInputFunction(self): def testInputFunction(self):
@ -474,12 +475,14 @@ class ClientTest(TestCaseWithErrorReportingThread):
v.assign_add(x) v.assign_add(x)
return x return x
distributed_dataset = self.client.create_per_worker_dataset(input_fn) distributed_dataset = self.coordinator.create_per_worker_dataset(input_fn)
result = self.client.schedule(worker_fn, args=(iter(distributed_dataset),)) result = self.coordinator.schedule(
result = self.client.fetch(result) worker_fn, args=(iter(distributed_dataset),))
result = self.coordinator.fetch(result)
self.assertEqual(result, (1,)) self.assertEqual(result, (1,))
result = self.client.schedule(worker_fn, args=(iter(distributed_dataset),)) result = self.coordinator.schedule(
result = self.client.fetch(result) worker_fn, args=(iter(distributed_dataset),))
result = self.coordinator.fetch(result)
self.assertEqual(result, (1,)) self.assertEqual(result, (1,))
self.assertAlmostEqual(v.read_value().numpy(), 2, delta=1e-6) self.assertAlmostEqual(v.read_value().numpy(), 2, delta=1e-6)
@ -499,30 +502,30 @@ class ClientTest(TestCaseWithErrorReportingThread):
x = next(iterator) x = next(iterator)
v.assign_add(x) v.assign_add(x)
distributed_dataset = self.client.create_per_worker_dataset(input_fn) distributed_dataset = self.coordinator.create_per_worker_dataset(input_fn)
iterator = iter(distributed_dataset) iterator = iter(distributed_dataset)
# Verifying joining without any scheduling doesn't hang. # Verifying joining without any scheduling doesn't hang.
self.client.join() self.coordinator.join()
self.assertEqual(v.read_value().numpy(), 0) self.assertEqual(v.read_value().numpy(), 0)
for _ in range(5): for _ in range(5):
self.client.schedule(worker_fn, args=(iterator,)) self.coordinator.schedule(worker_fn, args=(iterator,))
self.client.join() self.coordinator.join()
# With 5 additions it should be 2*5 = 10. # With 5 additions it should be 2*5 = 10.
self.assertEqual(v.read_value().numpy(), 10) self.assertEqual(v.read_value().numpy(), 10)
for _ in range(5): for _ in range(5):
self.client.schedule(worker_fn, args=(iterator,)) self.coordinator.schedule(worker_fn, args=(iterator,))
# Verifying that calling join multiple times is fine. # Verifying that calling join multiple times is fine.
self.client.join() self.coordinator.join()
self.client.join() self.coordinator.join()
self.client.join() self.coordinator.join()
self.assertTrue(self.client.done()) self.assertTrue(self.coordinator.done())
# Likewise, it's now 20. # Likewise, it's now 20.
self.assertEqual(v.read_value().numpy(), 20) self.assertEqual(v.read_value().numpy(), 20)
@ -542,8 +545,9 @@ class ClientTest(TestCaseWithErrorReportingThread):
def worker_fn(iterator): def worker_fn(iterator):
return next(iterator) return next(iterator)
distributed_dataset = self.client.create_per_worker_dataset(input_fn) distributed_dataset = self.coordinator.create_per_worker_dataset(input_fn)
result = self.client.schedule(worker_fn, args=(iter(distributed_dataset),)) result = self.coordinator.schedule(
worker_fn, args=(iter(distributed_dataset),))
self.assertEqual(result.fetch(), (10,)) self.assertEqual(result.fetch(), (10,))
self.assertEqual(self._map_fn_tracing_count, 1) self.assertEqual(self._map_fn_tracing_count, 1)
@ -554,28 +558,28 @@ class ClientTest(TestCaseWithErrorReportingThread):
return v.read_value() return v.read_value()
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
self.client.create_per_worker_dataset(input_fn) self.coordinator.create_per_worker_dataset(input_fn)
def testDatasetsShuffledDifferently(self): def testDatasetsShuffledDifferently(self):
# This test requires at least two workers in the cluster. # This test requires at least two workers in the cluster.
self.assertGreaterEqual(len(self.client.cluster.workers), 2) self.assertGreaterEqual(len(self.coordinator.cluster.workers), 2)
random_seed.set_random_seed(None) random_seed.set_random_seed(None)
def input_fn(): def input_fn():
return dataset_ops.DatasetV2.range(0, 100).shuffle(100) return dataset_ops.DatasetV2.range(0, 100).shuffle(100)
distributed_dataset = self.client.create_per_worker_dataset(input_fn) distributed_dataset = self.coordinator.create_per_worker_dataset(input_fn)
distributed_iterator = iter(distributed_dataset) distributed_iterator = iter(distributed_dataset)
# Get elements from the first two iterators. # Get elements from the first two iterators.
iterator_1 = distributed_iterator._values[0] iterator_1 = distributed_iterator._values[0]
iterator_1._rebuild_on(self.client.cluster.workers[0]) iterator_1._rebuild_on(self.coordinator.cluster.workers[0])
iterator_1 = iterator_1.fetch() iterator_1 = iterator_1.fetch()
elements_in_iterator_1 = [e.numpy() for e in iterator_1] elements_in_iterator_1 = [e.numpy() for e in iterator_1]
iterator_2 = distributed_iterator._values[1] iterator_2 = distributed_iterator._values[1]
iterator_2._rebuild_on(self.client.cluster.workers[1]) iterator_2._rebuild_on(self.coordinator.cluster.workers[1])
iterator_2 = iterator_2.fetch() iterator_2 = iterator_2.fetch()
elements_in_iterator_2 = [e.numpy() for e in iterator_2] elements_in_iterator_2 = [e.numpy() for e in iterator_2]
@ -592,12 +596,12 @@ class ClientTest(TestCaseWithErrorReportingThread):
self.assertIn('worker', var.device) self.assertIn('worker', var.device)
return var return var
worker_local_var = self.client._create_per_worker_resources(create_var) worker_local_var = self.coordinator._create_per_worker_resources(create_var)
# The following is a workaround to allow `worker_local_var` to be passed in # The following is a workaround to allow `worker_local_var` to be passed in
# as args to the `client.schedule` method which requires tensor specs to # as args to the `coordinator.schedule` method which requires tensor specs
# trace tf.function but _create_worker_resources' return values don't have # to trace tf.function but _create_worker_resources' return values don't
# tensor specs. We can get rid of this workaround once # have tensor specs. We can get rid of this workaround once
# _create_worker_resources is able to infer the tensor spec of the return # _create_worker_resources is able to infer the tensor spec of the return
# value of the function passed in. See b/154675763. # value of the function passed in. See b/154675763.
for var in worker_local_var._values: for var in worker_local_var._values:
@ -609,10 +613,10 @@ class ClientTest(TestCaseWithErrorReportingThread):
for _ in range(10): for _ in range(10):
# Which slice of `worker_local_var` will be used will depend on which # Which slice of `worker_local_var` will be used will depend on which
# worker the `worker_fn` gets scheduled on. # worker the `worker_fn` gets scheduled on.
self.client.schedule(worker_fn, args=(worker_local_var,)) self.coordinator.schedule(worker_fn, args=(worker_local_var,))
self.client.join() self.coordinator.join()
var_sum = sum(self.client.fetch(worker_local_var._values)) var_sum = sum(self.coordinator.fetch(worker_local_var._values))
self.assertEqual(var_sum, 10.0) self.assertEqual(var_sum, 10.0)
def testDisallowRemoteValueAsInput(self): def testDisallowRemoteValueAsInput(self):
@ -625,28 +629,28 @@ class ClientTest(TestCaseWithErrorReportingThread):
def func_1(x): def func_1(x):
return x + 1.0 return x + 1.0
remote_v = self.client.schedule(func_0) remote_v = self.coordinator.schedule(func_0)
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
self.client.schedule(func_1, args=(remote_v,)) self.coordinator.schedule(func_1, args=(remote_v,))
class LimitedClosureQueueSizeBasicTest(ClientTest): class LimitedClosureQueueSizeBasicTest(ClusterCoordinatorTest):
"""Test basic functionality works with explicit maximum closure queue size. """Test basic functionality works with explicit maximum closure queue size.
Execute the same set of test cases as in `ClientTest`, with an Execute the same set of test cases as in `ClusterCoordinatorTest`, with an
explicit size limit for the closure queue. Note that even when the queue size explicit size limit for the closure queue. Note that even when the queue size
is set to infinite, there is still a maximum practical size (depending on the is set to infinite, there is still a maximum practical size (depending on the
host memory limit) that might cause the queue.put operations to block when host memory limit) that might cause the queue.put operations to block when
scheduling a large number of closures on a big cluster. These tests make sure scheduling a large number of closures on a big cluster. These tests make sure
that the client does not run into deadlocks in such scenario. that the coordinator does not run into deadlocks in such a scenario.
""" """
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
super(LimitedClosureQueueSizeBasicTest, cls).setUpClass() super(LimitedClosureQueueSizeBasicTest, cls).setUpClass()
client_lib._CLOSURE_QUEUE_MAX_SIZE = 2 coordinator_lib._CLOSURE_QUEUE_MAX_SIZE = 2
cls.client = make_client(num_workers=3, num_ps=2) cls.coordinator = make_coordinator(num_workers=3, num_ps=2)
cls.strategy = cls.client.strategy cls.strategy = cls.coordinator.strategy
class ErrorReportingTest(TestCaseWithErrorReportingThread): class ErrorReportingTest(TestCaseWithErrorReportingThread):
@ -654,8 +658,8 @@ class ErrorReportingTest(TestCaseWithErrorReportingThread):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
super(ErrorReportingTest, cls).setUpClass() super(ErrorReportingTest, cls).setUpClass()
cls.client = make_client(num_workers=3, num_ps=2) cls.coordinator = make_coordinator(num_workers=3, num_ps=2)
cls.strategy = cls.client.strategy cls.strategy = cls.coordinator.strategy
with cls.strategy.scope(): with cls.strategy.scope():
cls.iteration = variables.Variable(initial_value=0.0) cls.iteration = variables.Variable(initial_value=0.0)
@ -686,80 +690,80 @@ class ErrorReportingTest(TestCaseWithErrorReportingThread):
def testJoinRaiseError(self): def testJoinRaiseError(self):
for _ in range(3): for _ in range(3):
self.client.schedule(self._normal_function) self.coordinator.schedule(self._normal_function)
self.client.schedule(self._error_function) self.coordinator.schedule(self._error_function)
with self.assertRaises(errors.InvalidArgumentError): with self.assertRaises(errors.InvalidArgumentError):
self.client.join() self.coordinator.join()
def testScheduleRaiseError(self): def testScheduleRaiseError(self):
for _ in range(3): for _ in range(3):
self.client.schedule(self._normal_function) self.coordinator.schedule(self._normal_function)
self.client.schedule(self._error_function) self.coordinator.schedule(self._error_function)
with self.assertRaises(errors.InvalidArgumentError): with self.assertRaises(errors.InvalidArgumentError):
while True: while True:
self.client.schedule(self._normal_function) self.coordinator.schedule(self._normal_function)
def testScheduleRaiseErrorWithMultipleFailure(self): def testScheduleRaiseErrorWithMultipleFailure(self):
for _ in range(3): for _ in range(3):
self.client.schedule(self._normal_function) self.coordinator.schedule(self._normal_function)
self.client.schedule(self._error_function) self.coordinator.schedule(self._error_function)
with self.assertRaises(errors.InvalidArgumentError): with self.assertRaises(errors.InvalidArgumentError):
while True: while True:
self.client.schedule(self._error_function) self.coordinator.schedule(self._error_function)
self.client.join() self.coordinator.join()
def testErrorWillbeCleared(self): def testErrorWillbeCleared(self):
self.client.schedule(self._error_function) self.coordinator.schedule(self._error_function)
with self.assertRaises(errors.InvalidArgumentError): with self.assertRaises(errors.InvalidArgumentError):
self.client.join() self.coordinator.join()
for _ in range(3): for _ in range(3):
self.client.schedule(self._normal_function) self.coordinator.schedule(self._normal_function)
self.client.schedule(self._error_function) self.coordinator.schedule(self._error_function)
with self.assertRaises(errors.InvalidArgumentError): with self.assertRaises(errors.InvalidArgumentError):
self.client.join() self.coordinator.join()
def testRemoteValueReturnError(self): def testRemoteValueReturnError(self):
result = self.client.schedule(self._error_function) result = self.coordinator.schedule(self._error_function)
with self.assertRaises(errors.InvalidArgumentError): with self.assertRaises(errors.InvalidArgumentError):
result.fetch() result.fetch()
# Clear the error. # Clear the error.
with self.assertRaises(errors.InvalidArgumentError): with self.assertRaises(errors.InvalidArgumentError):
self.client.join() self.coordinator.join()
def testInputError(self): def testInputError(self):
worker_local_val = self.client._create_per_worker_resources( worker_local_val = self.coordinator._create_per_worker_resources(
self._error_function) self._error_function)
@def_function.function @def_function.function
def func(x): def func(x):
return x + 1 return x + 1
result = self.client.schedule(func, args=(worker_local_val,)) result = self.coordinator.schedule(func, args=(worker_local_val,))
with self.assertRaises(client_lib.InputError): with self.assertRaises(coordinator_lib.InputError):
self.client.join() self.coordinator.join()
with self.assertRaises(client_lib.InputError): with self.assertRaises(coordinator_lib.InputError):
result.fetch() result.fetch()
def testCancellation(self): def testCancellation(self):
for _ in range(3): for _ in range(3):
self.client.schedule(self._normal_function) self.coordinator.schedule(self._normal_function)
long_function = self.client.schedule(self._long_function) long_function = self.coordinator.schedule(self._long_function)
self.client.schedule(self._error_function) self.coordinator.schedule(self._error_function)
with self.assertRaises(errors.InvalidArgumentError): with self.assertRaises(errors.InvalidArgumentError):
self.client.join() self.coordinator.join()
with self.assertRaises(errors.CancelledError): with self.assertRaises(errors.CancelledError):
long_function.fetch() long_function.fetch()
for _ in range(3): for _ in range(3):
self.client.schedule(self._normal_function) self.coordinator.schedule(self._normal_function)
self.client.join() self.coordinator.join()
class LimitedClosureQueueErrorTest(ErrorReportingTest): class LimitedClosureQueueErrorTest(ErrorReportingTest):
@ -772,11 +776,11 @@ class LimitedClosureQueueErrorTest(ErrorReportingTest):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
super(LimitedClosureQueueErrorTest, cls).setUpClass() super(LimitedClosureQueueErrorTest, cls).setUpClass()
client_lib._CLOSURE_QUEUE_MAX_SIZE = 2 coordinator_lib._CLOSURE_QUEUE_MAX_SIZE = 2
cls.client = make_client(num_workers=3, num_ps=2) cls.coordinator = make_coordinator(num_workers=3, num_ps=2)
cls.strategy = cls.client.strategy cls.strategy = cls.coordinator.strategy
with cls.client.strategy.scope(): with cls.coordinator.strategy.scope():
cls.iteration = variables.Variable(initial_value=0.0) cls.iteration = variables.Variable(initial_value=0.0)
@ -785,8 +789,8 @@ class StrategyIntegrationTest(test.TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
super(StrategyIntegrationTest, cls).setUpClass() super(StrategyIntegrationTest, cls).setUpClass()
cls.client = make_client(num_workers=1, num_ps=1) cls.coordinator = make_coordinator(num_workers=1, num_ps=1)
cls.strategy = cls.client.strategy cls.strategy = cls.coordinator.strategy
def testBasicVariableAssignment(self): def testBasicVariableAssignment(self):
self.strategy.extended._variable_count = 0 self.strategy.extended._variable_count = 0
@ -801,9 +805,9 @@ class StrategyIntegrationTest(test.TestCase):
v2.assign_sub(0.2) v2.assign_sub(0.2)
return v1.read_value() / v2.read_value() return v1.read_value() / v2.read_value()
results = self.client.schedule(worker_fn) results = self.coordinator.schedule(worker_fn)
logging.info('Results of experimental_run_v2: %f', logging.info('Results of experimental_run_v2: %f',
self.client.fetch(results)) self.coordinator.fetch(results))
self.assertAlmostEqual(v1.read_value().numpy(), 0.1, delta=1e-6) self.assertAlmostEqual(v1.read_value().numpy(), 0.1, delta=1e-6)
self.assertAlmostEqual(v2.read_value().numpy(), 0.8, delta=1e-6) self.assertAlmostEqual(v2.read_value().numpy(), 0.8, delta=1e-6)
@ -826,12 +830,14 @@ class StrategyIntegrationTest(test.TestCase):
return self.strategy.run(replica_fn, args=(input_tensor,)) return self.strategy.run(replica_fn, args=(input_tensor,))
# Asserting scheduling in scope has the expected behavior. # Asserting scheduling in scope has the expected behavior.
result = self.client.schedule(worker_fn, args=(constant_op.constant(3),)) result = self.coordinator.schedule(
self.assertIsInstance(result, client_lib.RemoteValue) worker_fn, args=(constant_op.constant(3),))
self.assertIsInstance(result, coordinator_lib.RemoteValue)
self.assertEqual(result.fetch(), 4) self.assertEqual(result.fetch(), 4)
# Asserting scheduling out of scope has the expected behavior. # Asserting scheduling out of scope has the expected behavior.
result = self.client.schedule(worker_fn, args=(constant_op.constant(3),)) result = self.coordinator.schedule(
worker_fn, args=(constant_op.constant(3),))
self.assertEqual(result.fetch(), 4) self.assertEqual(result.fetch(), 4)

View File

@ -31,15 +31,15 @@ enable_metrics = False
_time_buckets = monitoring.ExponentialBuckets(0.001, 10, 6) _time_buckets = monitoring.ExponentialBuckets(0.001, 10, 6)
_function_tracing_sampler = monitoring.Sampler( _function_tracing_sampler = monitoring.Sampler(
'/tensorflow/api/ps_strategy/client/function_tracing', _time_buckets, '/tensorflow/api/ps_strategy/coordinator/function_tracing', _time_buckets,
'Sampler to track the time (in seconds) for tracing functions.') 'Sampler to track the time (in seconds) for tracing functions.')
_closure_execution_sampler = monitoring.Sampler( _closure_execution_sampler = monitoring.Sampler(
'/tensorflow/api/ps_strategy/client/closure_execution', _time_buckets, '/tensorflow/api/ps_strategy/coordinator/closure_execution', _time_buckets,
'Sampler to track the time (in seconds) for executing closures.') 'Sampler to track the time (in seconds) for executing closures.')
_remote_value_fetch_sampler = monitoring.Sampler( _remote_value_fetch_sampler = monitoring.Sampler(
'/tensorflow/api/ps_strategy/client/remote_value_fetch', _time_buckets, '/tensorflow/api/ps_strategy/coordinator/remote_value_fetch', _time_buckets,
'Sampler to track the time (in seconds) for fetching remote_value.') 'Sampler to track the time (in seconds) for fetching remote_value.')
_METRICS_MAPPING = { _METRICS_MAPPING = {

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Tests for metrics collecting in client.""" """Tests for metrics collecting in coordinator."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
@ -22,9 +22,9 @@ from __future__ import print_function
import time import time
from tensorflow.python.distribute import multi_worker_test_base from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import parameter_server_strategy_v2 from tensorflow.python.distribute import parameter_server_strategy_v2
from tensorflow.python.distribute.client import client
from tensorflow.python.distribute.client import metric_utils
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.distribute.coordinator import cluster_coordinator as coordinator_lib
from tensorflow.python.distribute.coordinator import metric_utils
from tensorflow.python.eager import def_function from tensorflow.python.eager import def_function
from tensorflow.python.eager import test from tensorflow.python.eager import test
from tensorflow.python.training.server_lib import ClusterSpec from tensorflow.python.training.server_lib import ClusterSpec
@ -35,7 +35,7 @@ class MetricUtilsTest(test.TestCase):
def get_rpc_layer(self): def get_rpc_layer(self):
return 'grpc' return 'grpc'
def testClientMetrics(self): def testClusterCoordinatorMetrics(self):
metric_utils.enable_metrics = True metric_utils.enable_metrics = True
@ -48,7 +48,7 @@ class MetricUtilsTest(test.TestCase):
ClusterSpec(cluster_def), rpc_layer=self.get_rpc_layer()) ClusterSpec(cluster_def), rpc_layer=self.get_rpc_layer())
strategy = parameter_server_strategy_v2.ParameterServerStrategyV2( strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
cluster_resolver) cluster_resolver)
cluster = client.Cluster(strategy) cluster = coordinator_lib.Cluster(strategy)
@def_function.function @def_function.function
def func(): def func():

View File

@ -48,7 +48,7 @@ _LOCAL_CPU = "/device:CPU:0"
# TODO(yuefengz): maybe cache variables on local CPU. # TODO(yuefengz): maybe cache variables on local CPU.
@tf_export("distribute.experimental.ParameterServerStrategy", v1=[]) # TODO(b/171250971): Remove this and change all symbol usage of this to V1.
class ParameterServerStrategy(distribute_lib.Strategy): class ParameterServerStrategy(distribute_lib.Strategy):
"""An asynchronous multi-worker parameter server tf.distribute strategy. """An asynchronous multi-worker parameter server tf.distribute strategy.
@ -154,8 +154,9 @@ class ParameterServerStrategy(distribute_lib.Strategy):
def _raise_pss_error_if_eager(self): def _raise_pss_error_if_eager(self):
if context.executing_eagerly(): if context.executing_eagerly():
raise NotImplementedError("ParameterServerStrategy currently only works " raise NotImplementedError(
"with the tf.Estimator API") "`tf.compat.v1.distribute.experimental.ParameterServerStrategy` "
"currently only works with the tf.Estimator API")
@tf_export(v1=["distribute.experimental.ParameterServerStrategy"]) # pylint: disable=missing-docstring @tf_export(v1=["distribute.experimental.ParameterServerStrategy"]) # pylint: disable=missing-docstring

View File

@ -37,82 +37,400 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib from tensorflow.python.training import server_lib
from tensorflow.python.training.tracking import base as trackable from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import tf_inspect from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export
# pylint: disable=protected-access @tf_export("distribute.experimental.ParameterServerStrategy", v1=[])
class ParameterServerStrategyV2(distribute_lib.Strategy): class ParameterServerStrategyV2(distribute_lib.Strategy):
"""An asynchronous multi-worker parameter server tf.distribute strategy. """An multi-worker tf.distribute strategy with parameter servers.
Currently, `ParameterServerStrategyV2` is not supported to be used as a Parameter server training refers to the distributed training architecture that
standalone tf.distribute strategy. It should be used in conjunction with requires two types of tasks in the cluster: workers (referred to as "worker"
`Client`. Please see `Client` for more information. task) and parameter servers (referred to as "ps" task). The variables and
updates to those variables are placed on ps, and most computation-intensive
operations are placed on workers.
This is currently under development, and the API as well as implementation In TF2, parameter server training makes use of one coordinator, with some
is subject to changes. number of workers, and (usually fewer) ps. The coordinator uses a
`tf.distribute.experimental.coordinator.ClusterCoordinator` to coordinate the
cluster, and a `tf.distribute.experimental.ParameterServerStrategy` for
variable distribution. The coordinator does not perform the actual training.
Each of the workers and ps runs a `tf.distribute.Server`, which the
coordinator connects to through the two APIs mentioned above.
For the training to work, the coordinator sends requests to workers for the
`tf.function`s to be executed on remote workers. Upon receiving requests from
the coordinator, a worker executes the `tf.function` by reading the variables
from parameter servers, executing the ops, and updating the variables on the
parameter servers. Each worker processes only the requests from the
coordinator, and communicates with parameter servers, without direct
interactions with any of the other workers in the cluster.
As a result, failures of some workers do not prevent the cluster from
continuing the work, and this allows the cluster to train with instances that
can be occasionally unavailable (e.g. preemptible or spot instances). The
coordinator and parameter servers, though, must be available at all times for
the cluster to make progress.
Note that the coordinator is not one of the training workers. Instead, its
responsibilities include placing variables on ps, remotely executing
`tf.function`s on workers, and saving checkpoints. Parameter server training
thus consists of a server cluster of workers and ps, and a coordinator which
connects to them to coordinate. Optionally, an evaluator can be run on the
side that periodically reads the checkpoints saved by the coordinator and,
for example, saves summaries.
`tf.distribute.experimental.ParameterServerStrategy` works closely with the
associated `tf.distribute.experimental.coordinator.ClusterCoordinator` object,
and should be used in conjunction with it. Standalone usage of
`tf.distribute.experimental.ParameterServerStrategy` without a
`tf.distribute.experimental.coordinator.ClusterCoordinator` indicates
a parameter server training scheme without a centralized coordinator, which is
not supported at this time.
__Example code for coordinator__
Here's an example usage of the API, with a custom training loop to train a
model. This code snippet is intended to be run on the one machine that is
designated as the coordinator. Note that `cluster_resolver`,
`variable_partitioner`, and `dataset_fn` arguments are explained in the
following "Cluster setup", "Variable partitioning", and "Dataset preparation"
sections.
Currently, the environment variable `GRPC_FAIL_FAST` needs to be set in all
tasks to work around a known hanging issue, as the following code illustrates:
```python
# Set the environment variable to allow reporting worker and ps failure to the
# coordinator.
os.environ["GRPC_FAIL_FAST"] = "use_caller"
# Prepare a strategy to use with the cluster and variable partitioning info.
strategy = tf.distribute.experimental.ParameterServerStrategy(
cluster_resolver=...,
variable_partitioner=...)
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
strategy=strategy)
# Prepare a distributed dataset that will place datasets on the workers.
distributed_dataset = coordinator.create_per_worker_dataset(dataset_fn=...)
with strategy.scope():
model = ...  # Variables created here may be containers of variables
optimizer, metrics = ... # Keras optimizer/metrics are great choices
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint, checkpoint_dir, max_to_keep=2)
# `load_checkpoint` infers initial epoch from `optimizer.iterations`.
initial_epoch = load_checkpoint(checkpoint_manager) or 0
@tf.function
def worker_fn(iterator):
def replica_fn(inputs):
batch_data, labels = inputs
# calculate gradient, applying gradient, metrics update etc.
strategy.run(replica_fn, args=(next(iterator),))
for epoch in range(initial_epoch, num_epoch):
distributed_iterator = iter(distributed_dataset) # Reset iterator state.
for step in range(steps_per_epoch):
# Asynchronously schedule the `worker_fn` to be executed on an arbitrary
# worker. This call returns immediately.
coordinator.schedule(worker_fn, args=(distributed_iterator,))
# `join` blocks until all scheduled `worker_fn`s finish execution. Once it
# returns, we can read the metrics and save checkpoints as needed.
coordinator.join()
logging.info('Metric result: %r', metrics.result())
metrics.reset_states()
checkpoint_manager.save()
```
__Example code for worker and parameter servers__
In addition to the coordinator, there should be multiple machines designated
as "worker" or "ps". They should run the following code to start a TensorFlow
server, waiting for the coordinator's requests to execute functions or place
variables:
```python
# Set the environment variable to allow reporting worker and ps failure to the
# coordinator.
os.environ["GRPC_FAIL_FAST"] = "use_caller"
# Provide a `tf.distribute.cluster_resolver.ClusterResolver` that serves
# the cluster information. See the "Cluster setup" section below.
cluster_resolver = ...
server = tf.distribute.Server(
cluster_resolver.cluster_spec().as_cluster_def(),
job_name=cluster_resolver.task_type,
task_index=cluster_resolver.task_id,
protocol=cluster_resolver.rpc_layer)
# Blocking the process that starts a server from exiting.
server.join()
```
__Cluster setup__
In order for the tasks in the cluster to know other tasks' addresses,
a `tf.distribute.cluster_resolver.ClusterResolver` must be used
in the coordinator, workers, and ps. The
`tf.distribute.cluster_resolver.ClusterResolver` is responsible for providing
the cluster information, as well as the task type and id of the current task.
See `tf.distribute.cluster_resolver.ClusterResolver` for more information.
If the `TF_CONFIG` environment variable is used for the processes to know the
cluster information, a
`tf.distribute.cluster_resolver.TFConfigClusterResolver` should be used. Note
that for legacy reasons, "chief" should be used as the task type for the
coordinator, as the following example demonstrates. Here we set `TF_CONFIG` as
an environment variable; this snippet is intended to be run on the machine
designated as the parameter server (task type "ps") with index 1 (the second),
in a cluster with 1 chief, 2 parameter servers, and 3 workers. Note that it
needs to be set before the use of
`tf.distribute.cluster_resolver.TFConfigClusterResolver`.
Example code for cluster setup:
```python
os.environ['TF_CONFIG'] = '''
{
"cluster": {
"chief": ["chief.example.com:2222"],
"ps": ["ps0.example.com:2222", "ps1.example.com:2222"],
"worker": ["worker0.example.com:2222", "worker1.example.com:2222",
"worker2.example.com:2222"]
},
"task": {
"type": "ps",
"index": 1
}
}
'''
os.environ["GRPC_FAIL_FAST"] = "use_caller"
cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
# If coordinator ("chief" task type), create a strategy
if cluster_resolver.task_type == 'chief':
strategy = tf.distribute.experimental.ParameterServerStrategy(
cluster_resolver)
...
# If worker/ps, create a server
elif cluster_resolver.task_type in ("worker", "ps"):
server = tf.distribute.Server(...)
...
```
__Variable creation with `strategy.scope()`__
`tf.distribute.experimental.ParameterServerStrategy` follows the
`tf.distribute` API contract where variable creation is expected to be inside
the context manager returned by `strategy.scope()`, in order to be correctly
placed on parameter servers in a round-robin manner:
```python
# In this example, we're assuming there are 3 ps tasks.
strategy = tf.distribute.experimental.ParameterServerStrategy(
cluster_resolver=...)
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
strategy=strategy)
# Variables should be created inside scope to be placed on parameter servers.
# If created outside the scope, such as `v1` here, it would be placed on
# the coordinator.
v1 = tf.Variable(initial_value=0.0)
with strategy.scope():
v2 = tf.Variable(initial_value=1.0)
v3 = tf.Variable(initial_value=2.0)
v4 = tf.Variable(initial_value=3.0)
v5 = tf.Variable(initial_value=4.0)
# v2 through v5 are created in scope and are distributed on parameter servers.
# Default placement is round-robin but the order should not be relied on.
assert v2.device == "/job:ps/replica:0/task:0/device:CPU:0"
assert v3.device == "/job:ps/replica:0/task:1/device:CPU:0"
assert v4.device == "/job:ps/replica:0/task:2/device:CPU:0"
assert v5.device == "/job:ps/replica:0/task:0/device:CPU:0"
```
See `tf.distribute.Strategy.scope` for more information.
__Variable partitioning__
Having dedicated servers to store variables means being able to divide up, or
"shard", the variables across the ps. Large embeddings that would otherwise
exceed the memory limit of a single machine can be used in a cluster with a
sufficient number of ps.
With `tf.distribute.experimental.ParameterServerStrategy`, if a
`variable_partitioner` is provided to `__init__` and certain conditions are
satisfied, the resulting variables created in scope are sharded across the
parameter servers, in a round-robin fashion. The variable reference returned
from `tf.Variable` becomes a type that serves as the container of the sharded
variables. Access the `variables` attribute of this container for the actual
variable components. See the arguments section of
`tf.distribute.experimental.ParameterServerStrategy.__init__` for more
information.
To initialize the sharded variables in a more memory-efficient way, use an
initializer whose `__call__` accepts a `shard_info` argument, and use
`shard_info.offset` and `shard_info.shape` to create and return a
partition-aware `tf.Tensor` to initialize the variable components.
```python
class PartitionAwareIdentity(object):
def __call__(self, shape, dtype, shard_info):
value = tf.eye(*shape, dtype=dtype)
if shard_info is not None:
value = tf.slice(value, shard_info.offset, shard_info.shape)
return value
cluster_resolver = ...
strategy = tf.distribute.experimental.ParameterServerStrategy(
cluster_resolver, tf.compat.v1.fixed_size_partitioner(2))
with strategy.scope():
initializer = PartitionAwareIdentity()
initial_value = functools.partial(initializer, shape=(4, 4), dtype=tf.int64)
v = tf.Variable(
initial_value=initial_value, shape=(4, 4), dtype=tf.int64)
# `v.variables` gives the actual variable components.
assert len(v.variables) == 2
assert v.variables[0].device == "/job:ps/replica:0/task:0/device:CPU:0"
assert v.variables[1].device == "/job:ps/replica:0/task:1/device:CPU:0"
assert np.array_equal(v.variables[0].numpy(), [[1, 0, 0, 0], [0, 1, 0, 0]])
assert np.array_equal(v.variables[1].numpy(), [[0, 0, 1, 0], [0, 0, 0, 1]])
```
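For illustration, a custom `variable_partitioner` conforming to the
`num_partitions = fn(shape, dtype)` signature described above could be written
as in the following sketch. The name `example_partitioner` and the size
threshold are arbitrary assumptions made here for illustration; the
`tf.compat.v1` partitioners mentioned earlier remain the simpler choice.

```python
def example_partitioner(shape, dtype):
  """A hypothetical partitioner: shard large variables along the first axis."""
  del dtype  # Unused in this sketch.
  # `shape` is a `tf.TensorShape`; assume it is fully defined and non-scalar.
  # Only the first (outermost) axis may be split, so every other entry is 1.
  if shape.num_elements() >= 1000000:  # Arbitrary threshold for illustration.
    return [2] + [1] * (shape.rank - 1)
  return [1] * shape.rank

strategy = tf.distribute.experimental.ParameterServerStrategy(
    cluster_resolver=..., variable_partitioner=example_partitioner)
```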
__Dataset preparation__
With `tf.distribute.experimental.ParameterServerStrategy`, a dataset is
created in each of the workers to be used for training. This is done by
creating a `dataset_fn` that takes no arguments and returns a
`tf.data.Dataset`, and passing the `dataset_fn` into
`tf.distribute.experimental.coordinator.
ClusterCoordinator.create_per_worker_dataset`. We recommend that the dataset be
shuffled and repeated so that the examples run through the training as evenly
as possible.
```python
def dataset_fn():
filenames = ...
dataset = tf.data.Dataset.from_tensor_slices(filenames)
# The dataset should be shuffled and repeated.
return dataset.shuffle(buffer_size=...).repeat().batch(batch_size=...)
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
    strategy=...)
distributed_dataset = coordinator.create_per_worker_dataset(dataset_fn)
```
__Limitations__
* `tf.distribute.experimental.ParameterServerStrategy` in TF2 is experimental,
and the API is subject to further changes.
* `tf.distribute.experimental.ParameterServerStrategy` does not yet support
training with GPU(s); support for this is under development.
* `tf.distribute.experimental.ParameterServerStrategy` currently only supports
the [custom training loop
API](https://www.tensorflow.org/tutorials/distribute/custom_training) in TF2.
Usage with the Keras `compile`/`fit` API is under development.
* `tf.distribute.experimental.ParameterServerStrategy` must be used with
`tf.distribute.experimental.coordinator.ClusterCoordinator`.
* This strategy is not intended for TPU. Use
`tf.distribute.experimental.TPUStrategy` instead.
""" """
# pyformat: disable
def __init__(self, cluster_resolver, variable_partitioner=None): def __init__(self, cluster_resolver, variable_partitioner=None):
"""Initializes the V2 parameter server strategy. """Initializes the TF2 parameter server strategy.
This also connects to the remote server cluster. This initializes the `tf.distribute.experimental.ParameterServerStrategy`
object to be ready for use with
`tf.distribute.experimental.coordinator.ClusterCoordinator`.
Args: Args:
cluster_resolver: a `tf.distribute.cluster_resolver.ClusterResolver` cluster_resolver: a `tf.distribute.cluster_resolver.ClusterResolver`
object. object.
variable_partitioner: a callable with the signature `num_partitions = variable_partitioner:
fn(shape, dtype)`, where `num_partitions` is a list/tuple representing a callable with the signature `num_partitions = fn(shape, dtype)`, where
the number of partitions on each axis, and `shape` and `dtype` are of `num_partitions` is a list/tuple representing the number of partitions
types `tf.TensorShape` and `tf.dtypes.Dtype`. If None, variables will on each axis, and `shape` and `dtype` are of types `tf.TensorShape` and
not be partitioned. * `variable_partitioner` will be called for all `tf.dtypes.Dtype`. If `None`, variables will not be partitioned.
variables created under strategy `scope` to instruct how the variables
should be partitioned. Variables will be partitioned if there are more * `variable_partitioner` will be called for all variables created under
than one partitions along the partitioning axis, otherwise it falls back strategy `scope` to instruct how the variables should be partitioned.
to normal `tf.Variable`. * Only the first / outermost axis partitioning Variables will be created in multiple partitions if there are more than
is supported, namely, elements in `num_partitions` must be 1 other than one partition along the partitioning axis, otherwise it falls back to
the first element. * Partitioner like `min_max_variable_partitioner`, normal `tf.Variable`.
`variable_axis_size_partitioner` and `fixed_size_partitioner` are also
supported since they conform to the required signature. * Div partition * Only the first / outermost axis partitioning is supported, namely,
elements in `num_partitions` must be 1 other than the first element.
* Partitioners like `tf.compat.v1.min_max_variable_partitioner`,
`tf.compat.v1.variable_axis_size_partitioner` and
`tf.compat.v1.fixed_size_partitioner` are also supported since they
conform to the required signature.
* Div partition
strategy is used to partition variables. Assuming we assign consecutive strategy is used to partition variables. Assuming we assign consecutive
integer ids along the first axis of a variable, then ids are assigned to integer ids along the first axis of a variable, then ids are assigned to
shards in a contiguous manner, while attempting to keep each shard size shards in a contiguous manner, while attempting to keep each shard size
identical. If the ids do not evenly divide the number of shards, each of identical. If the ids do not evenly divide the number of shards, each of
the first several shards will be assigned one more id. For instance, a the first several shards will be assigned one more id. For instance, a
variable whose first dimension is variable whose first dimension is 13 has 13 ids, and they are split
13 has 13 ids, and they are split across 5 shards as: `[[0, 1, 2], [3, across 5 shards as:
4, 5], [6, 7, 8], [9, 10], [11, 12]]`. * Variables created under `[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`.
`strategy.extended.colocate_vars_with` will not be partitioned, e.g,
optimizer's slot variables. * Variables created under `strategy.extended.colocate_vars_with` will
not be partitioned, e.g., optimizer's slot variables.
""" """
# pyformat: enable
self._cluster_resolver = cluster_resolver self._cluster_resolver = cluster_resolver
self._extended = ParameterServerStrategyV2Extended(self, cluster_resolver, self._extended = ParameterServerStrategyV2Extended(self, cluster_resolver,
variable_partitioner) variable_partitioner)
self._verify_args_and_config(cluster_resolver) self._verify_args_and_config(cluster_resolver)
logging.info( logging.info(
"ParameterServerStrategyV2 is initialized with cluster_spec: " "`tf.distribute.experimental.ParameterServerStrategy` is initialized "
"%s", cluster_resolver.cluster_spec()) "with cluster_spec: %s", cluster_resolver.cluster_spec())
# TODO(b/167894802): Make chief, worker, and ps names customizable. # TODO(b/167894802): Make coordinator, worker, and ps names customizable.
self._connect_to_cluster(client_name="chief") self._connect_to_cluster(coordinator_name="chief")
super(ParameterServerStrategyV2, self).__init__(self._extended) super(ParameterServerStrategyV2, self).__init__(self._extended)
distribute_lib.distribution_strategy_gauge.get_cell("V2").set( distribute_lib.distribution_strategy_gauge.get_cell("V2").set(
"ParameterServerStrategy") "ParameterServerStrategy")
def _connect_to_cluster(self, client_name): def _connect_to_cluster(self, coordinator_name):
if client_name in ["worker", "ps"]: if coordinator_name in ["worker", "ps"]:
raise ValueError("Client name should not be 'worker' or 'ps'.") raise ValueError("coordinator name should not be 'worker' or 'ps'.")
cluster_spec = self._cluster_resolver.cluster_spec() cluster_spec = self._cluster_resolver.cluster_spec()
self._num_workers = len(cluster_spec.as_dict().get("worker", ())) self._num_workers = len(cluster_spec.as_dict().get("worker", ()))
self._num_ps = len(cluster_spec.as_dict().get("ps", ())) self._num_ps = len(cluster_spec.as_dict().get("ps", ()))
device_filters = server_lib.ClusterDeviceFilters() device_filters = server_lib.ClusterDeviceFilters()
# For any worker, only the devices on PS and chief nodes are visible # For any worker, only the devices on ps and coordinator nodes are visible
for i in range(self._num_workers): for i in range(self._num_workers):
device_filters.set_device_filters( device_filters.set_device_filters(
"worker", i, ["/job:ps", "/job:%s" % client_name]) "worker", i, ["/job:ps", "/job:%s" % coordinator_name])
# Similarly for any ps, only the devices on workers and chief are visible # Similarly for any ps, only the devices on workers and coordinator are
# visible
for i in range(self._num_ps): for i in range(self._num_ps):
device_filters.set_device_filters( device_filters.set_device_filters(
"ps", i, ["/job:worker", "/job:%s" % client_name]) "ps", i, ["/job:worker", "/job:%s" % coordinator_name])
# Allow at most one outstanding RPC for each worker at a certain time. This # Allow at most one outstanding RPC for each worker at a certain time. This
# is to simplify worker failure handling in the runtime # is to simplify worker failure handling in the runtime
@ -122,7 +440,7 @@ class ParameterServerStrategyV2(distribute_lib.Strategy):
self.__class__.__name__, cluster_spec) self.__class__.__name__, cluster_spec)
remote.connect_to_cluster( remote.connect_to_cluster(
cluster_spec, cluster_spec,
job_name=client_name, job_name=coordinator_name,
protocol=self._cluster_resolver.rpc_layer, protocol=self._cluster_resolver.rpc_layer,
cluster_device_filters=device_filters) cluster_device_filters=device_filters)
@ -134,7 +452,7 @@ class ParameterServerStrategyV2(distribute_lib.Strategy):
def _verify_args_and_config(self, cluster_resolver): def _verify_args_and_config(self, cluster_resolver):
if not cluster_resolver.cluster_spec(): if not cluster_resolver.cluster_spec():
raise ValueError("Cluster spec must be non-empty in `cluster_resolver`.") raise ValueError("Cluster spec must be non-empty in `cluster_resolver`.")
if self.extended._num_gpus_per_worker > 1: if self.extended._num_gpus_per_worker > 1: # pylint: disable=protected-access
raise NotImplementedError("Multi-gpu is not supported yet.") raise NotImplementedError("Multi-gpu is not supported yet.")
@ -205,8 +523,8 @@ class ParameterServerStrategyV2Extended(
init_from_fn = False init_from_fn = False
initial_value = initial_value() initial_value = initial_value()
if not init_from_fn: if not init_from_fn:
# The initial_value is created on client, it will need to be sent to # The initial_value is created on the coordinator; it will need to be sent to
# PS for variable initialization, which can be inefficient and can # ps for variable initialization, which can be inefficient and can
# potentially hit the 2GB limit on protobuf serialization. # potentially hit the 2GB limit on protobuf serialization.
initial_value = ops.convert_to_tensor(initial_value, dtype=dtype) initial_value = ops.convert_to_tensor(initial_value, dtype=dtype)
dtype = initial_value.dtype dtype = initial_value.dtype

View File

@ -870,8 +870,8 @@ py_test(
"//tensorflow/python/distribute:multi_worker_test_base", "//tensorflow/python/distribute:multi_worker_test_base",
"//tensorflow/python/distribute:parameter_server_strategy_v2", "//tensorflow/python/distribute:parameter_server_strategy_v2",
"//tensorflow/python/distribute:sharded_variable", "//tensorflow/python/distribute:sharded_variable",
"//tensorflow/python/distribute/client",
"//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib", "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
"//tensorflow/python/distribute/coordinator:cluster_coordinator",
"//tensorflow/python/eager:backprop", "//tensorflow/python/eager:backprop",
"//tensorflow/python/eager:def_function", "//tensorflow/python/eager:def_function",
"//tensorflow/python/eager:test", "//tensorflow/python/eager:test",

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Tests for ParameterServerClient and Keras models.""" """Tests for ClusterCoordinator and Keras models."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
@ -27,8 +27,8 @@ from tensorflow.python.compat import v2_compat
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import multi_worker_test_base from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import parameter_server_strategy_v2 from tensorflow.python.distribute import parameter_server_strategy_v2
from tensorflow.python.distribute.client import client as client_lib
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.distribute.coordinator import cluster_coordinator as coordinator_lib
from tensorflow.python.eager import backprop from tensorflow.python.eager import backprop
from tensorflow.python.eager import def_function from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
@ -44,7 +44,7 @@ from tensorflow.python.platform import test
from tensorflow.python.training.server_lib import ClusterSpec from tensorflow.python.training.server_lib import ClusterSpec
def make_client(num_workers, num_ps): def make_coordinator(num_workers, num_ps):
cluster_def = multi_worker_test_base.create_in_process_cluster( cluster_def = multi_worker_test_base.create_in_process_cluster(
num_workers=num_workers, num_ps=num_ps, rpc_layer="grpc") num_workers=num_workers, num_ps=num_ps, rpc_layer="grpc")
cluster_def["chief"] = [ cluster_def["chief"] = [
@ -52,7 +52,7 @@ def make_client(num_workers, num_ps):
] ]
cluster_resolver = SimpleClusterResolver( cluster_resolver = SimpleClusterResolver(
ClusterSpec(cluster_def), rpc_layer="grpc") ClusterSpec(cluster_def), rpc_layer="grpc")
return client_lib.Client( return coordinator_lib.ClusterCoordinator(
parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)) parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver))
@ -61,7 +61,7 @@ class KPLTest(test.TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
super(KPLTest, cls).setUpClass() super(KPLTest, cls).setUpClass()
cls.client = make_client(num_workers=3, num_ps=2) cls.coordinator = make_coordinator(num_workers=3, num_ps=2)
def testTrainAndServe(self): def testTrainAndServe(self):
# These vocabularies usually come from TFT or a Beam pipeline. # These vocabularies usually come from TFT or a Beam pipeline.
@ -71,10 +71,10 @@ class KPLTest(test.TestCase):
] ]
label_vocab = ["yes", "no"] label_vocab = ["yes", "no"]
with self.client.strategy.scope(): with self.coordinator.strategy.scope():
# Define KPLs under strategy's scope. Right now, if they have look up # Define KPLs under strategy's scope. Right now, if they have look up
# tables, they will be created on the client. Their variables will be # tables, they will be created on the coordinator. Their variables will be
# created on PS. Ideally they should be cached on each worker since they # created on PS. Ideally they should be cached on each worker since they
# will not be changed in a training step. # will not be changed in a training step.
feature_lookup_layer = string_lookup.StringLookup() feature_lookup_layer = string_lookup.StringLookup()
@ -113,7 +113,7 @@ class KPLTest(test.TestCase):
label = "yes" if "avenger" in features else "no" label = "yes" if "avenger" in features else "no"
yield {"features": features, "label": label} yield {"features": features, "label": label}
# The dataset will be created on the client? # The dataset will be created on the coordinator?
raw_dataset = dataset_ops.Dataset.from_generator( raw_dataset = dataset_ops.Dataset.from_generator(
feature_and_label_gen, feature_and_label_gen,
output_types={ output_types={
@ -131,7 +131,8 @@ class KPLTest(test.TestCase):
}, [x["label"]])) }, [x["label"]]))
return train_dataset return train_dataset
distributed_dataset = self.client.create_per_worker_dataset(dataset_fn) distributed_dataset = self.coordinator.create_per_worker_dataset(
dataset_fn)
model_input = keras.layers.Input( model_input = keras.layers.Input(
shape=(3,), dtype=dtypes.int64, name="model_input") shape=(3,), dtype=dtypes.int64, name="model_input")
@ -163,12 +164,12 @@ class KPLTest(test.TestCase):
actual_pred = math_ops.cast(math_ops.greater(pred, 0.5), dtypes.int64) actual_pred = math_ops.cast(math_ops.greater(pred, 0.5), dtypes.int64)
accuracy.update_state(labels, actual_pred) accuracy.update_state(labels, actual_pred)
self.client._strategy.run(train_step, args=(iterator,)) self.coordinator._strategy.run(train_step, args=(iterator,))
distributed_iterator = iter(distributed_dataset) distributed_iterator = iter(distributed_dataset)
for _ in range(10): for _ in range(10):
self.client.schedule(worker_fn, args=(distributed_iterator,)) self.coordinator.schedule(worker_fn, args=(distributed_iterator,))
self.client.join() self.coordinator.join()
self.assertGreater(accuracy.result().numpy(), 0.0) self.assertGreater(accuracy.result().numpy(), 0.0)
# Create a saved model. # Create a saved model.

View File

@ -362,7 +362,8 @@ class Model(training_lib.Model):
if isinstance(self._distribution_strategy, if isinstance(self._distribution_strategy,
(parameter_server_strategy.ParameterServerStrategyV1, (parameter_server_strategy.ParameterServerStrategyV1,
parameter_server_strategy.ParameterServerStrategy)): parameter_server_strategy.ParameterServerStrategy)):
raise NotImplementedError('ParameterServerStrategy currently only works ' raise NotImplementedError('`tf.compat.v1.distribute.experimental.Paramet'
'erServerStrategy` currently only works '
'with the tf.Estimator API') 'with the tf.Estimator API')
if not self._experimental_run_tf_function: if not self._experimental_run_tf_function:

View File

@ -64,6 +64,9 @@ from tensorflow.python.framework.test_combinations import *
from tensorflow.python.util.tf_decorator import make_decorator from tensorflow.python.util.tf_decorator import make_decorator
from tensorflow.python.util.tf_decorator import unwrap from tensorflow.python.util.tf_decorator import unwrap
from tensorflow.python.distribute.parameter_server_strategy_v2 import *
from tensorflow.python.distribute.coordinator.cluster_coordinator import *
tf_export('__internal__.decorator.make_decorator', v1=[])(make_decorator) tf_export('__internal__.decorator.make_decorator', v1=[])(make_decorator)
tf_export('__internal__.decorator.unwrap', v1=[])(unwrap) tf_export('__internal__.decorator.unwrap', v1=[])(unwrap)

View File

@ -31,6 +31,7 @@ TENSORFLOW_API_INIT_FILES = [
"distribute/__init__.py", "distribute/__init__.py",
"distribute/cluster_resolver/__init__.py", "distribute/cluster_resolver/__init__.py",
"distribute/experimental/__init__.py", "distribute/experimental/__init__.py",
"distribute/experimental/coordinator/__init__.py",
"dtypes/__init__.py", "dtypes/__init__.py",
"errors/__init__.py", "errors/__init__.py",
"experimental/__init__.py", "experimental/__init__.py",

View File

@ -27,6 +27,8 @@ from tensorflow.lite.python import lite as _tflite_for_api_traversal
from tensorflow.python import modules_with_exports from tensorflow.python import modules_with_exports
from tensorflow.python.distribute import multi_process_runner from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import parameter_server_strategy_v2
from tensorflow.python.distribute.coordinator import cluster_coordinator
from tensorflow.python.framework import combinations from tensorflow.python.framework import combinations
from tensorflow.python.framework import test_combinations from tensorflow.python.framework import test_combinations
# pylint: enable=unused-import # pylint: enable=unused-import

View File

@@ -1,6 +1,6 @@
 path: "tensorflow.distribute.experimental.ParameterServerStrategy"
 tf_class {
-  is_instance: "<class \'tensorflow.python.distribute.parameter_server_strategy.ParameterServerStrategy\'>"
+  is_instance: "<class \'tensorflow.python.distribute.parameter_server_strategy_v2.ParameterServerStrategyV2\'>"
   is_instance: "<class \'tensorflow.python.distribute.distribute_lib.Strategy\'>"
   is_instance: "<class \'tensorflow.python.distribute.distribute_lib.StrategyBase\'>"
   is_instance: "<type \'object\'>"
@@ -18,7 +18,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'cluster_resolver\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'self\', \'cluster_resolver\', \'variable_partitioner\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "colocate_vars_with"


@@ -0,0 +1,33 @@
path: "tensorflow.distribute.experimental.coordinator.ClusterCoordinator"
tf_class {
  is_instance: "<class \'tensorflow.python.distribute.coordinator.cluster_coordinator.ClusterCoordinator\'>"
  is_instance: "<type \'object\'>"
  member {
    name: "strategy"
    mtype: "<type \'property\'>"
  }
  member_method {
    name: "__init__"
    argspec: "args=[\'self\', \'strategy\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "create_per_worker_dataset"
    argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "done"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "fetch"
    argspec: "args=[\'self\', \'val\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "join"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "schedule"
    argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
  }
}
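
Putting the methods above together, a minimal training-loop sketch, assuming the `strategy`, `model`, and `optimizer` from the earlier construction sketch and a cluster that is already running (`dataset_fn` and `step_fn` are illustrative names, not part of the API):

    coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(strategy)

    def dataset_fn():
      features = tf.random.uniform([64, 10])
      return tf.data.Dataset.from_tensor_slices(features).batch(8).repeat()

    per_worker_dataset = coordinator.create_per_worker_dataset(dataset_fn)
    per_worker_iterator = iter(per_worker_dataset)

    @tf.function
    def step_fn(iterator):
      def replica_fn(batch):
        with tf.GradientTape() as tape:
          loss = tf.reduce_mean(model(batch))
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        return loss
      return strategy.run(replica_fn, args=(next(iterator),))

    for _ in range(20):
      coordinator.schedule(step_fn, args=(per_worker_iterator,))
    coordinator.join()  # blocks until every scheduled step_fn has completed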


@@ -0,0 +1,9 @@
path: "tensorflow.distribute.experimental.coordinator.PerWorkerValues"
tf_class {
  is_instance: "<class \'tensorflow.python.distribute.coordinator.cluster_coordinator.PerWorkerValues\'>"
  is_instance: "<type \'object\'>"
  member_method {
    name: "__init__"
    argspec: "args=[\'self\', \'values\'], varargs=None, keywords=None, defaults=None"
  }
}
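
`PerWorkerValues` is rarely constructed by hand; in typical use it is what iterating a per-worker dataset yields (one component per worker), and `schedule` resolves it to the component belonging to whichever worker runs the function. A short sketch reusing the `coordinator`, `dataset_fn`, and `step_fn` names assumed above:

    per_worker_dataset = coordinator.create_per_worker_dataset(dataset_fn)
    # The per-worker iterator behaves as a PerWorkerValues: each worker sees
    # only its own iterator when the scheduled function executes there.
    per_worker_iterator = iter(per_worker_dataset)
    coordinator.schedule(step_fn, args=(per_worker_iterator,))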


@@ -0,0 +1,12 @@
path: "tensorflow.distribute.experimental.coordinator.RemoteValue"
tf_class {
  is_instance: "<class \'tensorflow.python.distribute.coordinator.cluster_coordinator.RemoteValue\'>"
  is_instance: "<type \'object\'>"
  member_method {
    name: "__init__"
  }
  member_method {
    name: "fetch"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
  }
}
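
`schedule` returns a `RemoteValue` immediately, before the function has necessarily run anywhere; `fetch` blocks until execution finishes and brings the result back to the coordinator. A short sketch, again reusing the names assumed above:

    loss_value = coordinator.schedule(step_fn, args=(per_worker_iterator,))
    # fetch() blocks until some worker has executed step_fn, then returns the
    # concrete result to the coordinator process.
    print('loss:', loss_value.fetch())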


@@ -0,0 +1,15 @@
path: "tensorflow.distribute.experimental.coordinator"
tf_module {
  member {
    name: "ClusterCoordinator"
    mtype: "<type \'type\'>"
  }
  member {
    name: "PerWorkerValues"
    mtype: "<type \'type\'>"
  }
  member {
    name: "RemoteValue"
    mtype: "<type \'type\'>"
  }
}


@@ -36,4 +36,8 @@ tf_module {
     name: "ValueContext"
     mtype: "<type \'type\'>"
   }
+  member {
+    name: "coordinator"
+    mtype: "<type \'module\'>"
+  }
 }


@@ -154,9 +154,9 @@ COMMON_PIP_DEPS = [
     "//tensorflow/tools/common:test_module1",
     "//tensorflow/tools/common:traverse",
     "//tensorflow/python/distribute:parameter_server_strategy_v2",
-    "//tensorflow/python/distribute/client:client",
-    "//tensorflow/python/distribute/client:remote_eager_lib",
-    "//tensorflow/python/distribute/client:metric_utils",
+    "//tensorflow/python/distribute/coordinator:cluster_coordinator",
+    "//tensorflow/python/distribute/coordinator:remote_eager_lib",
+    "//tensorflow/python/distribute/coordinator:metric_utils",
 ]
 # On Windows, python binary is a zip file of runfiles tree.