PSv2: Export a few tf.distribute symbols related to TF2 parameter server training.
This change exports the following class symbols, and adds relevant documentation and example code to:

- tf.distribute.experimental.ParameterServerStrategy
- tf.distribute.experimental.coordinator.ClusterCoordinator
- tf.distribute.experimental.coordinator.PerWorkerValues
- tf.distribute.experimental.coordinator.RemoteValue

PiperOrigin-RevId: 338151262
Change-Id: If2d1c513d30a999c728cecc2e73b75adda1948c2
Parent: 833b3a49a9
Commit: 32f35aabce
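For illustration, the newly exported symbols are intended to be used together roughly as follows. This is a sketch assembled from the docstrings added in this change, not part of the commit itself; it assumes a cluster described by the `TF_CONFIG` environment variable (via `TFConfigClusterResolver`), and the dataset contents are placeholders.

```python
import tensorflow as tf

# Assumption: TF_CONFIG describes the parameter server and worker tasks.
cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
strategy = tf.distribute.experimental.ParameterServerStrategy(
    cluster_resolver=cluster_resolver)
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(strategy)

# Variables are created on the parameter servers under the strategy's scope.
with strategy.scope():
  v = tf.Variable(initial_value=0)

def dataset_fn():
  return tf.data.Dataset.from_tensor_slices([1, 1, 1])

@tf.function
def worker_fn(iterator):
  # Each worker receives its own iterator out of the PerWorkerValues.
  v.assign_add(next(iterator))
  return v.read_value()

# `create_per_worker_dataset` places the dataset on every worker; calling
# `iter` on it yields a PerWorkerValues of per-worker iterators.
per_worker_dataset = coordinator.create_per_worker_dataset(dataset_fn)
per_worker_iterator = iter(per_worker_dataset)

# `schedule` returns a RemoteValue immediately; `fetch` blocks for the result.
result = coordinator.schedule(worker_fn, args=(per_worker_iterator,))
print(result.fetch())
coordinator.join()  # Wait for all scheduled functions to finish.
```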
RELEASE.md (11 changed lines)
@@ -208,7 +208,16 @@
     how many times the function is called, and independent of global seed
     settings.
 * `tf.distribute`:
-    * <ADD RELEASE NOTES HERE>
+    * (Experimental) Parameter server training:
+        * Replaced the existing
+          `tf.distribute.experimental.ParameterServerStrategy` symbol with
+          a new class that is for parameter server training in TF2. Usage with
+          the old symbol, usually with Estimator, should be replaced with
+          `tf.compat.v1.distribute.experimental.ParameterServerStrategy`.
+        * Added `tf.distribute.experimental.coordinator.*` namespace,
+          including the main API `ClusterCoordinator` for coordinating the
+          training cluster, the related data structure `RemoteValue`
+          and `PerWorkerValue`.
 * `tf.keras`:
     * Improvements from the functional API refactoring:
         * Functional model construction does not need to maintain a global
@@ -153,8 +153,9 @@ py_library(
         ":multi_process_runner",
         ":multi_worker_test_base",
         ":one_device_strategy",
+        ":parameter_server_strategy_v2",
         ":sharded_variable",
-        "//tensorflow/python/distribute/client",
+        "//tensorflow/python/distribute/coordinator:cluster_coordinator",
         "//tensorflow/python/distribute/experimental",
     ],
 )
@@ -1880,6 +1881,7 @@ tf_py_test(
         ":multi_worker_test_base",
         ":parameter_server_strategy_v2",
         ":sharded_variable",
+        "//tensorflow:tensorflow_py",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:dtypes",
         "//tensorflow/python:extra_py_tests_deps",
@@ -8,8 +8,8 @@ package(
 exports_files(["LICENSE"])

 py_library(
-    name = "client",
-    srcs = ["client.py"],
+    name = "cluster_coordinator",
+    srcs = ["cluster_coordinator.py"],
     srcs_version = "PY2AND3",
     deps = [
         ":metric_utils",
@@ -34,9 +34,9 @@ py_library(
 )

 tf_py_test(
-    name = "client_test",
+    name = "cluster_coordinator_test",
     size = "small",
-    srcs = ["client_test.py"],
+    srcs = ["cluster_coordinator_test.py"],
     python_version = "PY3",
     shard_count = 50,
     tags = [
@@ -44,7 +44,7 @@ tf_py_test(
         "notsan",  # TODO(b/171040359): Flaky timeout, even if maximum shards
     ],
     deps = [
-        ":client",
+        ":cluster_coordinator",
         "//tensorflow/python:check_ops",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:constant_op",
@@ -66,12 +66,13 @@ tf_py_test(
 )

 tf_py_test(
-    name = "client_mpr_test",
-    srcs = ["client_mpr_test.py"],
+    name = "cluster_coordinator_mpr_test",
+    srcs = ["cluster_coordinator_mpr_test.py"],
     python_version = "PY3",
     shard_count = 2,
     tags = ["no_oss"],  # TODO(b/162119374)
     deps = [
+        ":cluster_coordinator",
         ":remote_eager_lib",
         ":utils",
         "//tensorflow/python:dtypes",
@@ -81,7 +82,6 @@ tf_py_test(
         "//tensorflow/python/distribute:multi_worker_test_base",
         "//tensorflow/python/distribute:parameter_server_strategy_v2",
         "//tensorflow/python/distribute:sharded_variable",
-        "//tensorflow/python/distribute/client",
         "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
         "//tensorflow/python/eager:def_function",
         "//tensorflow/python/eager:test",
@@ -102,7 +102,7 @@ tf_py_test(
     srcs = ["metric_utils_test.py"],
     python_version = "PY3",
     deps = [
-        ":client",
+        ":cluster_coordinator",
         ":metric_utils",
         "//tensorflow/python:training_server_lib",
         "//tensorflow/python/distribute:multi_worker_test_base",
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Module for `Client` and relevant cluster-worker related library.
+"""Module for `ClusterCoordinator` and relevant cluster-worker related library.

 This is currently under development and the API is subject to change.
 """
@@ -35,7 +35,7 @@ from six.moves import queue
 from tensorflow.python.data.ops import iterator_ops
 from tensorflow.python.distribute import input_lib
 from tensorflow.python.distribute import parameter_server_strategy_v2
-from tensorflow.python.distribute.client import metric_utils
+from tensorflow.python.distribute.coordinator import metric_utils
 from tensorflow.python.eager import cancellation
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
@@ -48,6 +48,7 @@ from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import nest
+from tensorflow.python.util.tf_export import tf_export

 # Maximum time for failed worker to come back is 1 hour
 _WORKER_MAXIMUM_RECOVERY_SEC = 3600
@@ -56,9 +57,10 @@ _WORKER_MAXIMUM_RECOVERY_SEC = 3600
 # When the maximum queue size is reached, further schedule calls will become
 # blocking until some previously queued closures are executed on workers.
 # Note that using an "infinite" queue size can take a non-trivial portion of
-# memory, and even lead to client OOM. Modify the size to a smaller value for
-# client with constrained memory resource (only recommended for advanced users).
-# Also used in unit tests to ensure the correctness when the queue is full.
+# memory, and even lead to coordinator OOM. Modify the size to a smaller value
+# for coordinator with constrained memory resource (only recommended for
+# advanced users). Also used in unit tests to ensure the correctness when the
+# queue is full.
 _CLOSURE_QUEUE_MAX_SIZE = 256 * 1024

 # RPC error message from PS
@@ -99,18 +101,77 @@ class _RemoteValueStatus(enum.Enum):
   READY = "READY"


+@tf_export("distribute.experimental.coordinator.RemoteValue", v1=[])
 class RemoteValue(object):
   """An asynchronously available value of a remotely executed function.

-  `RemoteValue` class is used as the return value of `Client.schedule()` where
-  the underlying concrete value comes at a later time once the function has been
-  remotely executed. `RemoteValue` can be used as an input to a subsequent
-  function scheduled with `Client.schedule()`.
+  `tf.distribute.experimental.coordinator.RemoteValue` class is used as the
+  return value of
+  `tf.distribute.experimental.coordinator.ClusterCoordinator.schedule()` where
+  the underlying concrete value becomes available at a later time once the
+  function has been remotely executed. The underlying concrete value is the
+  `tf.Tensor.numpy()` result of the `tf.Tensor`, or the structure of
+  `tf.Tensor`s, returned from the `tf.function` that was scheduled.

-  Note: this class is not thread-safe.
+  `tf.distribute.experimental.coordinator.RemoteValue` to be used as an input to
+  a subsequent function scheduled with
+  `tf.distribute.experimental.coordinator.ClusterCoordinator.schedule()` is
+  currently not supported.
+
+  Example:
+
+  ```python
+  strategy = tf.distribute.experimental.ParameterServerStrategy(
+      cluster_resolver=...)
+  coordinator = (
+      tf.distribute.experimental.coordinator.ClusterCoordinator(strategy))
+
+  with strategy.scope():
+    v1 = tf.Variable(initial_value=0.0)
+    v2 = tf.Variable(initial_value=1.0)
+
+  @tf.function
+  def worker_fn():
+    v1.assign_add(0.1)
+    v2.assign_sub(0.2)
+    return v1.read_value() / v2.read_value()
+
+  result = coordinator.schedule(worker_fn)
+  # Note that `fetch()` gives the actual result instead of a `tf.Tensor`.
+  assert result.fetch() == 0.125
+
+  for _ in range(10):
+    # `worker_fn` will be run on arbitrary workers that are available. The
+    # `result` value will be non-deterministic because the workers are executing
+    # the functions asynchronously.
+    result = coordinator.schedule(worker_fn)
+  ```
   """

-  def __init__(self, closure, type_spec):
+  def fetch(self):
+    """Wait for the result of `RemoteValue` to be ready and return the result.
+
+    This makes the value concrete by copying the remote value to local.
+
+    Returns:
+      The actual output of the `tf.function` associated with this `RemoteValue`,
+      previously by a
+      `tf.distribute.experimental.coordinator.ClusterCoordinator.schedule` call.
+      This can be a single value, or a structure of values, depending on the
+      output of the `tf.function`.
+
+    Raises:
+      tf.errors.CancelledError: If the function that produces this `RemoteValue`
+      is aborted or cancelled due to failure, and the user should handle and
+      reschedule.
+    """
+    raise NotImplementedError("Must be implemented in subclasses.")
+
+
+class RemoteValueImpl(RemoteValue):
+  """Implementation of `RemoteValue`."""
+
+  def __init__(self, closure, type_spec):  # pylint: disable=super-init-not-called
     self._closure = closure
     # The type spec for this `RemoteValue` which is used to trace functions that
     # take this `RemoteValue` as input.
@@ -157,16 +218,6 @@ class RemoteValue(object):
     self._type_spec = func_graph.convert_structure_to_signature(type_spec)

   def fetch(self):
-    """Wait for the result of RemoteValue to be ready and return the result.
-
-    Returns:
-      The remote value, as a numpy data type (if scalar) or ndarray.
-
-    Raises:
-      tf.errors.CancelledError: If the function that produces this `RemoteValue`
-      is aborted or cancelled due to failure, and the user should handle and
-      reschedule.
-    """
     self._status_available_event.wait()
     if self._status is _RemoteValueStatus.ABORTED:
       raise errors.CancelledError(
@@ -241,8 +292,23 @@ def _maybe_as_type_spec(val):
     return val


+@tf_export("distribute.experimental.coordinator.PerWorkerValues", v1=[])
 class PerWorkerValues(object):
-  """Holds a list of per worker values."""
+  """A container that holds a list of values, one value per worker.
+
+  `tf.distribute.experimental.coordinator.PerWorkerValues` contains a collection
+  of values, where each of the values represents a resource that are located in
+  individual workers, and upon being used as one of the `args` or `kwargs` of
+  `tf.distribute.experimental.coordinator.ClusterCoordinator.schedule()`, the
+  value specific to a worker will be passed into the function being executed at
+  that particular worker.
+
+  Currently, the only supported path to create an object of
+  `tf.distribute.experimental.coordinator.PerWorkerValues` is through calling
+  `iter` on a `ClusterCoordinator.create_per_worker_dataset`-returned
+  distributed dataset instance. The mechanism to create a custom
+  `tf.distribute.experimental.coordinator.PerWorkerValues` is not yet supported.
+  """

   def __init__(self, values):
     self._values = tuple(values)
@@ -262,9 +328,10 @@ def _disallow_remote_value_as_input(structured):

  def _raise_if_remote_value(x):
    if isinstance(x, RemoteValue):
-      raise ValueError("RemoteValue cannot be used as an input to scheduled "
-                       "function. Please file a feature request if you need "
-                       "this feature.")
+      raise ValueError(
+          "`tf.distribute.experimental.coordinator.RemoteValue` used "
+          "as an input to scheduled function is not yet "
+          "supported.")

  nest.map_structure(_raise_if_remote_value, structured)

@@ -274,8 +341,8 @@ class Closure(object):

  def __init__(self, function, cancellation_mgr, args=None, kwargs=None):
    if not callable(function):
-      raise ValueError("Function passed to `Client.schedule` must be a "
-                       "callable object.")
+      raise ValueError("Function passed to `ClusterCoordinator.schedule` must "
+                       "be a callable object.")
    self._args = args or ()
    self._kwargs = kwargs or {}

@@ -287,9 +354,9 @@ class Closure(object):
     replica_kwargs = _select_worker_slice(0, self._kwargs)

     # Note: no need to handle function registration failure since this kind of
-    # failure will not raise exceptions as designed in the runtime. The client
-    # has to rely on subsequent operations that raise to catch function
-    # registration failure.
+    # failure will not raise exceptions as designed in the runtime. The
+    # coordinator has to rely on subsequent operations that raise to catch
+    # function registration failure.

     # Record the function tracing overhead. Note that we pass in the tracing
     # count of the def_function.Function as a state tracker, so that metrics
@@ -303,17 +370,18 @@ class Closure(object):
       self._function = cancellation_mgr.get_cancelable_function(
           concrete_function)
       self._output_remote_values = nest.map_structure(
-          lambda x: RemoteValue(self, x), concrete_function.structured_outputs)
+          lambda x: RemoteValueImpl(self, x),
+          concrete_function.structured_outputs)
     elif isinstance(function, tf_function.ConcreteFunction):
       self._function = cancellation_mgr.get_cancelable_function(function)
       self._output_remote_values = nest.map_structure(
-          lambda x: RemoteValue(self, x), function.structured_outputs)
+          lambda x: RemoteValueImpl(self, x), function.structured_outputs)
     else:
       # Regular python functions.
       self._function = function
       # TODO(yuefengz): maybe we should trace python functions if their inputs
       # are Python primitives, tensors and composite tensors.
-      self._output_remote_values = RemoteValue(self, None)
+      self._output_remote_values = RemoteValueImpl(self, None)

   def _fetch_output_remote_values(self):
     """Temporary method used to sync the scheduler."""
@@ -394,7 +462,7 @@ class _CoordinatedClosureQueue(object):

    if _CLOSURE_QUEUE_MAX_SIZE <= 0:
      logging.warning(
-          "In a `Client`, creating an infinite closure queue can "
+          "In a `ClusterCoordinator`, creating an infinite closure queue can "
          "consume a significant amount of memory and even lead to OOM.")
    self._queue = queue.Queue(maxsize=_CLOSURE_QUEUE_MAX_SIZE)
    self._error = None
@@ -430,10 +498,10 @@ class _CoordinatedClosureQueue(object):
    # The cancellation manager cannot be reused once cancelled. After all
    # closures (queued or inflight) are cleaned up, recreate the cancellation
    # manager with clean state.
-    # Note on thread-safety: this is triggered when one of theses client APIs
-    # are called: `schedule`, `wait`, and `done`. At the same time, no new
-    # closures can be constructed (which reads the _cancellation_mgr to get
-    # cancellable functions).
+    # Note on thread-safety: this is triggered when one of theses
+    # ClusterCoordinator APIs are called: `schedule`, `wait`, and `done`. At the
+    # same time, no new closures can be constructed (which reads the
+    # _cancellation_mgr to get cancellable functions).
    self._cancellation_mgr = cancellation.CancellationManager()

  def _raise_if_error(self):
@@ -742,8 +810,8 @@ class Worker(object):

  def _register_resource(self, resource_remote_value):
    if not isinstance(resource_remote_value, RemoteValue):
-      raise ValueError(
-          "Resource being registered is not of type `RemoteValue`.")
+      raise ValueError("Resource being registered is not of type "
+                       "`tf.distribute.experimental.coordinator.RemoteValue`.")
    self._resource_remote_value_refs.append(weakref.ref(resource_remote_value))


@@ -753,7 +821,7 @@ class Cluster(object):
  We assume all function errors are fatal and based on this assumption our
  error reporting logic is:
  1) Both `schedule` and `join` can raise a non-retryable error which is the
-  first error seen by the client from any previously scheduled functions.
+  first error seen by the coordinator from any previously scheduled functions.
  2) When an error is raised, there is no guarantee on how many previously
  scheduled functions have been executed; functions that have not been executed
  will be thrown away and marked as cancelled.
@@ -775,17 +843,17 @@ class Cluster(object):

    # Ignore PS failures reported by workers due to transient connection errors.
    # Transient connectivity issues between workers and PS are relayed by the
-    # workers to the client, leading the client to believe that there are PS
-    # failures. The difference between transient vs. permanent PS failure is the
-    # number of reports from the workers. When this env var is set to a positive
-    # integer K, the client ignores up to K reports of a failed PS task. I.e.,
-    # only when there are more than K trials of executing closures fail due to
-    # errors from the same PS instance do we consider the PS instance encounters
-    # a failure.
+    # workers to the coordinator, leading the coordinator to believe that there
+    # are PS failures. The difference between transient vs. permanent PS failure
+    # is the number of reports from the workers. When this env var is set to a
+    # positive integer K, the coordinator ignores up to K reports of a failed PS
+    # task, i.e., only when there are more than K trials of executing closures
+    # fail due to errors from the same PS instance do we consider the PS
+    # instance encounters a failure.
    # TODO(b/164279603): Remove this workaround when the underlying connectivity
    # issue in gRPC server is resolved.
-    self._transient_ps_failures_threshold = int(os.environ.get(
-        "TF_CLIENT_IGNORE_TRANSIENT_PS_FAILURES", 3))
+    self._transient_ps_failures_threshold = int(
+        os.environ.get("TF_COORDINATOR_IGNORE_TRANSIENT_PS_FAILURES", 3))
    self._potential_ps_failures_lock = threading.Lock()
    self._potential_ps_failures_count = [0] * self._num_ps

@@ -825,7 +893,7 @@ class Cluster(object):
      kwargs: Keyword arguments for `fn`.

    Returns:
-      A structure of `RemoteValue` object.
+      A `RemoteValue` object.
    """
    closure = Closure(
        function,
@@ -844,68 +912,124 @@ class Cluster(object):
     return self._closure_queue.done()


-class Client(object):
-  """An object to schedule and orchestrate remote function execution.
+@tf_export("distribute.experimental.coordinator.ClusterCoordinator", v1=[])
+class ClusterCoordinator(object):
+  """An object to schedule and coordinate remote function execution.

-  A `Client` object represents a program used to create dataset, schedule
-  functions to be executed, and fetch the results of the functions.
+  A `tf.distribute.experimental.coordinator.ClusterCoordinator` object
+  represents a program used to distribute dataset onto the workers, schedule
+  functions to be executed, and fetch the results of the functions. It expects
+  the cluster to contain some machines with processes running TensorFlow
+  servers.

-  Currently, `Client` is not supported to be used in a standalone manner.
-  It should be used in conjunction with `ParameterServerStrategyV2`.
+  Currently, `tf.distribute.experimental.coordinator.ClusterCoordinator` is not
+  supported to be used in a standalone manner. It should be used in conjunction
+  with a `tf.distribute` strategy that is designed to work with it. Currently,
+  only `tf.distribute.experimental.ParameterServerStrategy` is supported to work
+  with `tf.distribute.experimental.coordinator.ClusterCoordinator`.
+
+  __Fault tolerance__
+
+  `tf.distribute.experimental.coordinator.ClusterCoordinator`, when used with
+  `tf.distribute.experimental.ParameterServerStrategy`, comes with built-in
+  fault tolerance for worker failures. That is, when some workers are not
+  available for any reason to be reached from the coordinator, the training
+  progress continues to be made with the remaining operating workers, without
+  the need of any additional treatment in user code.
+
+  On the other hand, when a parameter server or the coordinator fails, a
+  `tf.errors.UnavailableError` is raised by `ClusterCoordinator.schedule` or
+  `ClusterCoordinator.join`. If any parameter server fails, the user should
+  restart the processes on the failed parameter servers when they become
+  available, *and* restart the process on the coordinator, so the coordinator
+  can re-create the variables on the parameter servers. If the coordinator fails
+  but all other machines continue to be operating, the user only needs to
+  restart the process on the coordinator, which will automatically connect to
+  the parameter servers and workers, and continue the progress.
+
+  It is thus essential that in user's custom training loop, a checkpoint file is
+  periodically saved, and restored at the start of the program.
+
+  * At-least-once semantics of `schedule`: `schedule` puts the `tf.function` in
+  a queue where it gets picked up by a worker to execute. If a worker picks up
+  a `tf.function`, has begun to execute it, but the worker process is
+  disconnected from the coordinator in the middle of execution, the
+  `tf.function` is deemed not completed, and is put back to the queue. This is
+  regardless of how much progress the `tf.function` has run. Once it is
+  re-queued, it will be picked up again by an arbitrary available worker at some
+  later time. As a result, the same function may be executed more than once, but
+  not less than once. If an `tf.keras.optimizers.Optimizer` is used,
+  `tf.keras.optimizers.Optimizer.iterations` roughly indicates the number of
+  times the gradients have been applied and can be used as an approximation of
+  the total number of steps. The number of times the function has been scheduled
+  by the `tf.distribute.experimental.coordinator.ClusterCoordinator`, on the
+  other hand, should not be indicative of actual steps run. See
+  `tf.distribute.experimental.coordinator.ClusterCoordinator.schedule` for more
+  information.
+
+  See `tf.distribute.experimental.ParameterServerStrategy` docstring for an
+  example usage of this API.

   This is currently under development, and the API as well as implementation
-  is subject to changes.
+  are subject to changes.
   """

   def __init__(self, strategy):
-    """Initialization of a `Client` instance.
+    """Initialization of a `ClusterCoordinator` instance.

-    This connects the client to remote workers and parameter servers, through
-    a `tf.config.experimental_connect_to_cluster` call.
-
     Args:
-      strategy: a `tf.distribute.Strategy` object. Currently, only
-        `ParameterServerStrategyV2` is supported.
+      strategy: a supported `tf.distribute.Strategy` object. Currently, only
+        `tf.distribute.experimental.ParameterServerStrategy` is supported.

     Raises:
       ValueError: if the strategy being used is not supported.
     """
     if not isinstance(strategy,
                       parameter_server_strategy_v2.ParameterServerStrategyV2):
-      raise ValueError("Only `ParameterServerStrategyV2` is supported in "
-                       "`Client` currently.")
+      raise ValueError(
+          "Only `tf.distribute.experimental.ParameterServerStrategy` "
+          "is supported to work with "
+          "`tf.distribute.experimental.coordinator.ClusterCoordinator` "
+          "currently.")
     self._strategy = strategy
     self.cluster = Cluster(strategy)

   @property
   def strategy(self):
+    """Returns the `Strategy` associated with the `ClusterCoordinator`."""
     return self._strategy

   def schedule(self, fn, args=None, kwargs=None):
-    """Schedules `fn` to be dispatched to a worker for execution asynchronously.
+    """Schedules `fn` to be dispatched to a worker for asynchronous execution.

     When calling `schedule` with a function `fn`, `fn` will be executed on a
     remote worker at some later time. The process is asynchronous, meaning
     `schedule` returns immediately, possibly without having the result ready
-    yet. `schedule` returns a structure of `RemoteValue` object, which wraps the
-    output of the function. Call `fetch()` on `RemoteValue` to wait for the
+    yet. `schedule` returns a
+    `tf.distribute.experimental.coordinator.RemoteValue` object, which wraps the
+    output of the function. After `schedule` is called, `fetch` can be called
+    on the `tf.distribute.experimental.coordinator.RemoteValue` to wait for the
     function execution to finish and retrieve its output from the remote worker.
+    On the other hand, call
+    `tf.distribute.experimental.coordinator.ClusterCoordinator.join` to wait for
+    all scheduled functions to finish execution before proceeding.

     `schedule` guarantees that `fn` will be executed on a worker at least once;
     it could be more than once if its corresponding worker fails in the middle
     of its execution. Note that since worker can fail at any point when
     executing the function, it is possible that the function is partially
-    executed, but `Client` guarantees that in those events, the function will
-    eventually be fully executed, possibly on a different worker that is
-    available.
+    executed, but `tf.distribute.experimental.coordinator.ClusterCoordinator`
+    guarantees that in those events, the function will eventually be fully
+    executed, possibly on a different worker that is available.

     If any previously scheduled function raises an error, `schedule` will fail
     by raising any one of those errors, and clear the errors collected so far.
     There are two implications when this happens: 1) user should call `schedule`
     with `fn` again to re-schedule, and 2) some of the previously scheduled
     functions may have not been executed. User can call `fetch` on the returned
-    `RemoteValue` to inspect if they have executed, failed, or cancelled, and
-    reschedule the corresponding function if needed.
+    `tf.distribute.experimental.coordinator.RemoteValue` to inspect if they have
+    executed, failed, or cancelled, and reschedule the corresponding function if
+    needed.

     When `schedule` raises, it guarantees that there is no function that is
     still being executed.
@@ -914,12 +1038,14 @@ class Client(object):
     execution, or priority of the workers.

     `args` and `kwargs` are the arguments passed into `fn`, when `fn` is
-    executed on a worker. They can be `PerWorkerValues`, which is a collection
-    of values, each of which represents a component specific to a worker; in
-    this case, the argument will be substituted with the corresponding component
-    on the target worker. Arguments that are not `PerWorkerValues` will be
-    passed into `fn` as-is. Currently, `RemoteValue` is not supported to be
-    input `args` or `kwargs`.
+    executed on a worker. They can be
+    `tf.distribute.experimental.coordinator.PerWorkerValues`, which is a
+    'collection of values, each of which represents a component specific to a
+    worker; in this case, the argument will be substituted with the
+    corresponding component on the target worker. Arguments that are not
+    `tf.distribute.experimental.coordinator.PerWorkerValues` will be passed into
+    `fn` as-is. Currently, `tf.distribute.experimental.coordinator.RemoteValue`
+    is not supported to be input `args` or `kwargs`.

     Args:
       fn: A `tf.function`; the function to be dispatched to a worker for
@@ -928,12 +1054,13 @@ class Client(object):
       kwargs: Keyword arguments for `fn`.

     Returns:
-      A structure of `RemoteValue` object.
+      A `tf.distribute.experimental.coordinator.RemoteValue` object that
+      represents the output of the function scheduled.

     Raises:
-      Exception: one of the exceptions caught by the client by any previously
-        scheduled function since the last time an error was thrown or since
-        the beginning of the program.
+      Exception: one of the exceptions caught by the coordinator by any
+        previously scheduled function since the last time an error was thrown or
+        since the beginning of the program.
     """
     # Slot variables are usually created during function tracing time; thus
     # `schedule` needs to be called within the `strategy.scope()`.
@@ -946,18 +1073,18 @@ class Client(object):
     If any previously scheduled function raises an error, `join` will fail by
     raising any one of those errors, and clear the errors collected so far. If
     this happens, some of the previously scheduled functions may have not been
-    executed. Users can call `fetch` on the returned `RemoteValue` to inspect if
-    they have executed, failed, or cancelled. If some that have been cancelled
-    need to be rescheduled, users should call `schedule` with the function
-    again.
+    executed. Users can call `fetch` on the returned
+    `tf.distribute.experimental.coordinator.RemoteValue` to inspect if they have
+    executed, failed, or cancelled. If some that have been cancelled need to be
+    rescheduled, users should call `schedule` with the function again.

     When `join` returns or raises, it guarantees that there is no function that
     is still being executed.

     Raises:
-      Exception: one of the exceptions caught by the client by any previously
-        scheduled function since the last time an error was thrown or since
-        the beginning of the program.
+      Exception: one of the exceptions caught by the coordinator by any
+        previously scheduled function since the last time an error was thrown or
+        since the beginning of the program.
     """
     self.cluster.join()

@@ -969,6 +1096,9 @@ class Client(object):

     When `done` returns True or raises, it guarantees that there is no function
     that is still being executed.
+
+    Returns:
+      Whether all the scheduled functions have finished execution.
     """
     return self.cluster.done()

@@ -978,12 +1108,15 @@ class Client(object):
     This creates the given dataset generated by dataset_fn on the workers
     and returns an object that represents the collection of those individual
     datasets. Calling `iter` on such collection of dataset returns a
-    `PerWorkerValues`, which is a collection of iterators, where the iterators
-    have been placed on respective workers.
+    `tf.distribute.experimental.coordinator.PerWorkerValues`, which is a
+    collection of iterators, where the iterators have been placed on respective
+    workers.

-    Calling `next` on this `PerWorkerValues` of iterators is currently
-    unsupported; it is meant to be passed as an argument into `Client.schedule`.
-    When the scheduled function is picked up and being executed by a worker, the
+    Calling `next` on this
+    `tf.distribute.experimental.coordinator.PerWorkerValues` of iterators is
+    currently unsupported; it is meant to be passed as an argument into
+    `tf.distribute.experimental.coordinator.ClusterCoordinator.schedule`. When
+    the scheduled function is picked up and being executed by a worker, the
     function will receive the individual iterator that corresponds to the
     worker, and now `next` can be called on iterator to get the next (batch or
     example) of data.
@@ -993,6 +1126,26 @@ class Client(object):
     examples may be skipped and not covered by other workers, if the dataset is
     sharded.

+    Example:
+
+    ```python
+    @tf.function
+    def worker_fn(iterator):
+      return next(iterator)
+
+    def dataset_fn():
+      return tf.data.from_tensor_slices([3] * 3)
+
+    strategy = tf.distribute.experimental.ParameterServerStrategy(
+        cluster_resolver=...)
+    coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
+        strategy=strategy)
+    per_worker_dataset = coordinator.create_per_worker_dataset(dataset_fn)
+    per_worker_iter = iter(per_worker_dataset)
+    remote_value = coordinator.schedule(worker_fn, args=(per_worker_iter,))
+    assert remote_value.fetch() == 3
+    ```
+
     Args:
       dataset_fn: The dataset function that returns a dataset. This is to be
         executed on the workers.
@@ -1000,7 +1153,8 @@ class Client(object):
     Returns:
       An object that represents the collection of those individual
       datasets. `iter` is expected to be called on this object that returns
-      a `PerWorkerValues` of the iterators (that are on the workers).
+      a `tf.distribute.experimental.coordinator.PerWorkerValues` of the
+      iterators (that are on the workers).
     """
     input_workers = input_lib.InputWorkers([
         (w.device_name, [w.device_name]) for w in self.cluster.workers
@@ -1011,7 +1165,8 @@ class Client(object):
   def _create_per_worker_resources(self, fn, args=None, kwargs=None):
     """Synchronously create resources on the workers.

-    The resources are represented by `RemoteValue`s.
+    The resources are represented by
+    `tf.distribute.experimental.coordinator.RemoteValue`s.

     Args:
       fn: The function to be dispatched to all workers for execution
@@ -1020,7 +1175,9 @@ class Client(object):
       kwargs: Keyword arguments for `fn`.

     Returns:
-      A `PerWorkerValues` object, which wraps a tuple of `RemoteValue` objects.
+      A `tf.distribute.experimental.coordinator.PerWorkerValues` object, which
+      wraps a tuple of `tf.distribute.experimental.coordinator.RemoteValue`
+      objects.
     """
     results = []
     for w in self.cluster.workers:
@@ -1028,21 +1185,52 @@ class Client(object):
     return PerWorkerValues(tuple(results))

   def fetch(self, val):
-    """Blocking call to fetch results from `RemoteValue`s.
+    """Blocking call to fetch results from the remote values.

-    This returns the execution result of `RemoteValue`s; if not ready,
-    waiting for it while blocking the caller.
+    This is a wrapper around
+    `tf.distribute.experimental.coordinator.RemoteValue.fetch` for a
+    `RemoteValue` structure; it returns the execution results of
+    `RemoteValue`s. If not ready, wait for them while blocking the caller.
+
+    Example:
+    ```python
+    strategy = ...
+    coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
+        strategy)
+
+    def dataset_fn():
+      return tf.data.Dataset.from_tensor_slices([1, 1, 1])
+
+    with strategy.scope():
+      v = tf.Variable(initial_value=0)
+
+    @tf.function
+    def worker_fn(iterator):
+      def replica_fn(x):
+        v.assign_add(x)
+        return v.read_value()
+      return strategy.run(replica_fn, args=(next(iterator),))
+
+    distributed_dataset = coordinator.create_per_worker_dataset(dataset_fn)
+    distributed_iterator = iter(distributed_dataset)
+    result = coordinator.schedule(worker_fn, args=(distributed_iterator,))
+    assert coordinator.fetch(result) == 1
+    ```

     Args:
       val: The value to fetch the results from. If this is structure of
-        `RemoteValue`, `fetch()` will be called on the individual `RemoteValue`
-        to get the result.
+        `tf.distribute.experimental.coordinator.RemoteValue`, `fetch()` will be
+        called on the individual
+        `tf.distribute.experimental.coordinator.RemoteValue` to get the result.

     Returns:
-      If `val` is a `RemoteValue` or a structure of `RemoteValue`s, returns
-      the fetched `RemoteValue` value immediately if it's available, or blocks
-      the call until it's available, and returns the fetched `RemoteValue`
-      values with the same structure. If `val` is other types, return (`val`,).
+      If `val` is a `tf.distribute.experimental.coordinator.RemoteValue` or a
+      structure of `tf.distribute.experimental.coordinator.RemoteValue`s,
+      return the fetched `tf.distribute.experimental.coordinator.RemoteValue`
+      values immediately if they are available, or block the call until they are
+      available, and return the fetched
+      `tf.distribute.experimental.coordinator.RemoteValue` values with the same
+      structure. If `val` is other types, return it as-is.
     """

     def _maybe_fetch(val):
@@ -1052,10 +1240,7 @@ class Client(object):
       return val

     # TODO(yuefengz): we should fetch values in a batch.
-    result = nest.map_structure(_maybe_fetch, val)
-    if not isinstance(result, tuple):
-      return (result,)
-    return result
+    return nest.map_structure(_maybe_fetch, val)


 # pylint: disable=missing-function-docstring
@@ -1075,13 +1260,14 @@ def handle_parameter_server_failure():
 class _PerWorkerDistributedDataset(object):
   """Represents worker-distributed datasets created from dataset function."""

-  def __init__(self, dataset_fn, input_workers, client):
+  def __init__(self, dataset_fn, input_workers, coordinator):
     """Makes an iterable from datasets created by the given function.

     Args:
       dataset_fn: A function that returns a `Dataset`.
       input_workers: an `InputWorkers` object.
-      client: a `Client` object, used to create dataset resources.
+      coordinator: a `ClusterCoordinator` object, used to create dataset
+        resources.
     """
     def disallow_variable_creation(next_creator, **kwargs):
       raise ValueError("Creating variables in `dataset_fn` is not allowed.")
@@ -1094,7 +1280,7 @@ class _PerWorkerDistributedDataset(object):
       dataset_fn = def_function.function(dataset_fn).get_concrete_function()
     self._dataset_fn = dataset_fn
     self._input_workers = input_workers
-    self._client = client
+    self._coordinator = coordinator
     self._element_spec = None

   def __iter__(self):
@@ -1112,7 +1298,7 @@ class _PerWorkerDistributedDataset(object):
     # If _PerWorkerDistributedDataset.__iter__ is called multiple
     # times, for the same object it should only create and register resource
     # once. Using object id to distinguish different iterator resources.
-    per_worker_iterator = self._client._create_per_worker_resources(
+    per_worker_iterator = self._coordinator._create_per_worker_resources(
         _create_per_worker_iterator)

     # Setting type_spec of each RemoteValue so that functions taking these
@@ -1131,7 +1317,7 @@ class _PerWorkerDistributedDataset(object):


 class _PerWorkerDistributedIterator(PerWorkerValues):
-  """Distributed iterator for `Client`."""
+  """Distributed iterator for `ClusterCoordinator`."""

   def __next__(self):
     return self.get_next()
@@ -13,21 +13,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Multi-process runner tests for `Client` with `ParameterServerStrategyV2`."""
+"""Multi-process runner tests for `ClusterCoordinator` with PSv2."""

 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

 import time

 from tensorflow.python.compat import v2_compat
 from tensorflow.python.distribute import multi_process_runner
 from tensorflow.python.distribute import multi_worker_test_base
 from tensorflow.python.distribute import parameter_server_strategy_v2
-from tensorflow.python.distribute.client import client as client_lib
-from tensorflow.python.distribute.client import utils
 from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
+from tensorflow.python.distribute.coordinator import cluster_coordinator as coordinator_lib
+from tensorflow.python.distribute.coordinator import utils
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import test
 from tensorflow.python.framework import dtypes
@@ -38,7 +37,7 @@ from tensorflow.python.ops import variables
 from tensorflow.python.platform import tf_logging as logging


-class ClientMprTest(test.TestCase):
+class ClusterCoordinatorMprTest(test.TestCase):

   def testScheduleTranslatePSFailureError(self):
     self._test_translate_ps_failure_error(test_schedule=True)
@@ -56,7 +55,7 @@ class ClientMprTest(test.TestCase):
       utils.start_server(cluster_resolver, "grpc")
       strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
           cluster_resolver)
-      ps_client = client_lib.Client(strategy)
+      ps_coordinator = coordinator_lib.ClusterCoordinator(strategy)

       with strategy.scope():
         v = variables.Variable(initial_value=0, dtype=dtypes.int32)
@@ -68,19 +67,19 @@ class ClientMprTest(test.TestCase):
         v.assign_add(1)

       # Keep the two workers occupied.
-      ps_client.schedule(worker_fn)
-      ps_client.schedule(worker_fn)
+      ps_coordinator.schedule(worker_fn)
+      ps_coordinator.schedule(worker_fn)
       # Now the main process can terminate.
       functions_scheduled_event.set()

       # Verified that join and schedule indeed raise UnavailableError.
       try:
         if test_join:
-          ps_client.join()
+          ps_coordinator.join()
         if test_schedule:
-          while ps_client.cluster._closure_queue._error is None:
+          while ps_coordinator.cluster._closure_queue._error is None:
            time.sleep(1)
-          ps_client.schedule(worker_fn)
+          ps_coordinator.schedule(worker_fn)
       except errors.UnavailableError:
         # The following verifies that after PS fails, continue executing
         # functions on workers should fail and indicate it's PS failure.
@@ -91,7 +90,7 @@ class ClientMprTest(test.TestCase):
          # failure.
          worker_fn()
        except Exception as e:  # pylint: disable=broad-except
-          if client_lib._is_ps_failure(e):
+          if coordinator_lib._is_ps_failure(e):
            if worker_id < 2:
              continue
        logging.info("_test_translate_ps_failure_error ends properly.")
@ -12,7 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""Tests for client.py."""
|
"""Tests for coordinator.py."""
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
@ -29,8 +29,8 @@ from tensorflow.python.data.ops import dataset_ops
|
|||||||
from tensorflow.python.distribute import distribution_strategy_context
|
from tensorflow.python.distribute import distribution_strategy_context
|
||||||
from tensorflow.python.distribute import multi_worker_test_base
|
from tensorflow.python.distribute import multi_worker_test_base
|
||||||
from tensorflow.python.distribute import parameter_server_strategy_v2
|
from tensorflow.python.distribute import parameter_server_strategy_v2
|
||||||
from tensorflow.python.distribute.client import client as client_lib
|
|
||||||
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
|
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
|
||||||
|
from tensorflow.python.distribute.coordinator import cluster_coordinator as coordinator_lib
|
||||||
from tensorflow.python.eager import cancellation
|
from tensorflow.python.eager import cancellation
|
||||||
from tensorflow.python.eager import def_function
|
from tensorflow.python.eager import def_function
|
||||||
from tensorflow.python.eager import test
|
from tensorflow.python.eager import test
|
||||||
@ -52,7 +52,7 @@ from tensorflow.python.util import nest
|
|||||||
class CoordinatedClosureQueueTest(test.TestCase):
|
class CoordinatedClosureQueueTest(test.TestCase):
|
||||||
|
|
||||||
def testBasic(self):
|
def testBasic(self):
|
||||||
queue = client_lib._CoordinatedClosureQueue()
|
queue = coordinator_lib._CoordinatedClosureQueue()
|
||||||
closure1 = self._create_closure(queue._cancellation_mgr)
|
closure1 = self._create_closure(queue._cancellation_mgr)
|
||||||
queue.put(closure1)
|
queue.put(closure1)
|
||||||
self.assertIs(closure1, queue.get())
|
self.assertIs(closure1, queue.get())
|
||||||
@ -64,7 +64,7 @@ class CoordinatedClosureQueueTest(test.TestCase):
|
|||||||
queue.wait()
|
queue.wait()
|
||||||
|
|
||||||
def testProcessAtLeaseOnce(self):
|
def testProcessAtLeaseOnce(self):
|
||||||
closure_queue = client_lib._CoordinatedClosureQueue()
|
closure_queue = coordinator_lib._CoordinatedClosureQueue()
|
||||||
labels = ['A', 'B', 'C', 'D', 'E']
|
labels = ['A', 'B', 'C', 'D', 'E']
|
||||||
processed_count = collections.defaultdict(int)
|
processed_count = collections.defaultdict(int)
|
||||||
|
|
||||||
@ -94,7 +94,7 @@ class CoordinatedClosureQueueTest(test.TestCase):
|
|||||||
|
|
||||||
cm = cancellation.CancellationManager()
|
cm = cancellation.CancellationManager()
|
||||||
for label in labels:
|
for label in labels:
|
||||||
closure_queue.put(client_lib.Closure(get_func(label), cm))
|
closure_queue.put(coordinator_lib.Closure(get_func(label), cm))
|
||||||
t1 = threading.Thread(target=process_queue, daemon=True)
|
t1 = threading.Thread(target=process_queue, daemon=True)
|
||||||
t1.start()
|
t1.start()
|
||||||
t2 = threading.Thread(target=process_queue, daemon=True)
|
t2 = threading.Thread(target=process_queue, daemon=True)
|
||||||
@ -111,7 +111,7 @@ class CoordinatedClosureQueueTest(test.TestCase):
|
|||||||
coord.join([t1, t2])
|
coord.join([t1, t2])
|
||||||
|
|
||||||
def testNotifyBeforeWait(self):
|
def testNotifyBeforeWait(self):
|
||||||
closure_queue = client_lib._CoordinatedClosureQueue()
|
closure_queue = coordinator_lib._CoordinatedClosureQueue()
|
||||||
|
|
||||||
def func():
|
def func():
|
||||||
logging.info('func running')
|
logging.info('func running')
|
||||||
@ -123,7 +123,8 @@ class CoordinatedClosureQueueTest(test.TestCase):
|
|||||||
closure_queue.get()
|
closure_queue.get()
|
||||||
closure_queue.mark_finished()
|
closure_queue.mark_finished()
|
||||||
|
|
||||||
closure_queue.put(client_lib.Closure(func, closure_queue._cancellation_mgr))
|
closure_queue.put(
|
||||||
|
coordinator_lib.Closure(func, closure_queue._cancellation_mgr))
|
||||||
t = threading.Thread(target=process_queue)
|
t = threading.Thread(target=process_queue)
|
||||||
t.start()
|
t.start()
|
||||||
coord.join([t])
|
coord.join([t])
|
||||||
@ -159,7 +160,7 @@ class CoordinatedClosureQueueTest(test.TestCase):
|
|||||||
# TODO(b/165013260): Fix this
|
# TODO(b/165013260): Fix this
|
||||||
self.skipTest('Test is currently broken on Windows with Python 3.8')
|
self.skipTest('Test is currently broken on Windows with Python 3.8')
|
||||||
|
|
||||||
closure_queue = client_lib._CoordinatedClosureQueue()
|
closure_queue = coordinator_lib._CoordinatedClosureQueue()
|
||||||
closure_queue.put(self._create_closure(closure_queue._cancellation_mgr))
|
closure_queue.put(self._create_closure(closure_queue._cancellation_mgr))
|
||||||
closure = closure_queue.get()
|
closure = closure_queue.get()
|
||||||
|
|
||||||
@ -189,10 +190,10 @@ class CoordinatedClosureQueueTest(test.TestCase):
|
|||||||
def some_function():
|
def some_function():
|
||||||
return 1.0
|
return 1.0
|
||||||
|
|
||||||
return client_lib.Closure(some_function, cancellation_mgr)
|
return coordinator_lib.Closure(some_function, cancellation_mgr)
|
||||||
|
|
||||||
def _put_two_closures_and_get_one(self):
|
def _put_two_closures_and_get_one(self):
|
||||||
closure_queue = client_lib._CoordinatedClosureQueue()
|
closure_queue = coordinator_lib._CoordinatedClosureQueue()
|
||||||
closure1 = self._create_closure(closure_queue._cancellation_mgr)
|
closure1 = self._create_closure(closure_queue._cancellation_mgr)
|
||||||
closure_queue.put(closure1)
|
closure_queue.put(closure1)
|
||||||
|
|
||||||
@ -330,7 +331,7 @@ class CoordinatedClosureQueueTest(test.TestCase):
|
|||||||
|
|
||||||
# Closure2 was an inflight closure when it got cancelled.
|
# Closure2 was an inflight closure when it got cancelled.
|
||||||
self.assertEqual(closure2._output_remote_values._status,
|
self.assertEqual(closure2._output_remote_values._status,
|
||||||
client_lib._RemoteValueStatus.READY)
|
coordinator_lib._RemoteValueStatus.READY)
|
||||||
with self.assertRaisesRegex(ValueError, 'Fake cancellation error.'):
|
with self.assertRaisesRegex(ValueError, 'Fake cancellation error.'):
|
||||||
closure2._fetch_output_remote_values()
|
closure2._fetch_output_remote_values()
|
||||||
|
|
||||||
@ -361,7 +362,7 @@ class CoordinatedClosureQueueTest(test.TestCase):
|
|||||||
|
|
||||||
def testThreadSafey(self):
|
def testThreadSafey(self):
|
||||||
thread_count = 10
|
thread_count = 10
|
||||||
queue = client_lib._CoordinatedClosureQueue()
|
queue = coordinator_lib._CoordinatedClosureQueue()
|
||||||
|
|
||||||
# Each thread performs 20 queue actions: 10 are `put_back` and 10 are
|
# Each thread performs 20 queue actions: 10 are `put_back` and 10 are
|
||||||
# `mark_finished`.
|
# `mark_finished`.
|
||||||
@ -427,7 +428,7 @@ class TestCaseWithErrorReportingThread(test.TestCase):
|
|||||||
raise ErrorReportingThread.error # pylint: disable=raising-bad-type
|
raise ErrorReportingThread.error # pylint: disable=raising-bad-type
|
||||||
|
|
||||||
|
|
||||||
def make_client(num_workers, num_ps):
|
def make_coordinator(num_workers, num_ps):
|
||||||
# TODO(rchao): Test the internal rpc_layer version.
|
# TODO(rchao): Test the internal rpc_layer version.
|
||||||
cluster_def = multi_worker_test_base.create_in_process_cluster(
|
cluster_def = multi_worker_test_base.create_in_process_cluster(
|
||||||
num_workers=num_workers, num_ps=num_ps, rpc_layer='grpc')
|
num_workers=num_workers, num_ps=num_ps, rpc_layer='grpc')
|
||||||
@ -438,16 +439,16 @@ def make_client(num_workers, num_ps):
|
|||||||
ClusterSpec(cluster_def), rpc_layer='grpc')
|
ClusterSpec(cluster_def), rpc_layer='grpc')
|
||||||
strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
|
strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
|
||||||
cluster_resolver)
|
cluster_resolver)
|
||||||
return client_lib.Client(strategy)
|
return coordinator_lib.ClusterCoordinator(strategy)
|
||||||
|
|
||||||
|
|
||||||
class ClientTest(TestCaseWithErrorReportingThread):
|
class ClusterCoordinatorTest(TestCaseWithErrorReportingThread):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
super(ClientTest, cls).setUpClass()
|
super(ClusterCoordinatorTest, cls).setUpClass()
|
||||||
cls.client = make_client(num_workers=3, num_ps=2)
|
cls.coordinator = make_coordinator(num_workers=3, num_ps=2)
|
||||||
cls.strategy = cls.client.strategy
|
cls.strategy = cls.coordinator.strategy
|
||||||
|
|
||||||
def testFnReturnNestedValues(self):
|
def testFnReturnNestedValues(self):
|
||||||
x = constant_op.constant(1)
|
x = constant_op.constant(1)
|
||||||
@ -456,9 +457,9 @@ class ClientTest(TestCaseWithErrorReportingThread):
|
|||||||
def f():
|
def f():
|
||||||
return x + 1, (x + 2, x + 3), [x + 4], {'v': x}
|
return x + 1, (x + 2, x + 3), [x + 4], {'v': x}
|
||||||
|
|
||||||
got = self.client.schedule(f)
|
got = self.coordinator.schedule(f)
|
||||||
want = 2, (3, 4), [5], {'v': 1}
|
want = 2, (3, 4), [5], {'v': 1}
|
||||||
self.assertEqual(self.client.fetch(got), want)
|
self.assertEqual(self.coordinator.fetch(got), want)
|
||||||
|
|
||||||
def testInputFunction(self):
|
def testInputFunction(self):
|
||||||
|
|
||||||
@ -474,12 +475,14 @@ class ClientTest(TestCaseWithErrorReportingThread):
|
|||||||
v.assign_add(x)
|
v.assign_add(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
distributed_dataset = self.client.create_per_worker_dataset(input_fn)
|
distributed_dataset = self.coordinator.create_per_worker_dataset(input_fn)
|
||||||
result = self.client.schedule(worker_fn, args=(iter(distributed_dataset),))
|
result = self.coordinator.schedule(
|
||||||
result = self.client.fetch(result)
|
worker_fn, args=(iter(distributed_dataset),))
|
||||||
|
result = self.coordinator.fetch(result)
|
||||||
self.assertEqual(result, (1,))
|
self.assertEqual(result, (1,))
|
||||||
result = self.client.schedule(worker_fn, args=(iter(distributed_dataset),))
|
result = self.coordinator.schedule(
|
||||||
result = self.client.fetch(result)
|
worker_fn, args=(iter(distributed_dataset),))
|
||||||
|
result = self.coordinator.fetch(result)
|
||||||
self.assertEqual(result, (1,))
|
self.assertEqual(result, (1,))
|
||||||
|
|
||||||
self.assertAlmostEqual(v.read_value().numpy(), 2, delta=1e-6)
|
self.assertAlmostEqual(v.read_value().numpy(), 2, delta=1e-6)
|
||||||
@ -499,30 +502,30 @@ class ClientTest(TestCaseWithErrorReportingThread):
|
|||||||
x = next(iterator)
|
x = next(iterator)
|
||||||
v.assign_add(x)
|
v.assign_add(x)
|
||||||
|
|
||||||
distributed_dataset = self.client.create_per_worker_dataset(input_fn)
|
distributed_dataset = self.coordinator.create_per_worker_dataset(input_fn)
|
||||||
|
|
||||||
iterator = iter(distributed_dataset)
|
iterator = iter(distributed_dataset)
|
||||||
|
|
||||||
# Verifying joining without any scheduling doesn't hang.
|
# Verifying joining without any scheduling doesn't hang.
|
||||||
self.client.join()
|
self.coordinator.join()
|
||||||
self.assertEqual(v.read_value().numpy(), 0)
|
self.assertEqual(v.read_value().numpy(), 0)
|
||||||
|
|
||||||
for _ in range(5):
|
for _ in range(5):
|
||||||
self.client.schedule(worker_fn, args=(iterator,))
|
self.coordinator.schedule(worker_fn, args=(iterator,))
|
||||||
self.client.join()
|
self.coordinator.join()
|
||||||
|
|
||||||
# With 5 addition it should be 2*5 = 10.
|
# With 5 addition it should be 2*5 = 10.
|
||||||
self.assertEqual(v.read_value().numpy(), 10)
|
self.assertEqual(v.read_value().numpy(), 10)
|
||||||
|
|
||||||
for _ in range(5):
|
for _ in range(5):
|
||||||
self.client.schedule(worker_fn, args=(iterator,))
|
self.coordinator.schedule(worker_fn, args=(iterator,))
|
||||||
|
|
||||||
# Verifying multiple join is fine.
|
# Verifying multiple join is fine.
|
||||||
self.client.join()
|
self.coordinator.join()
|
||||||
self.client.join()
|
self.coordinator.join()
|
||||||
self.client.join()
|
self.coordinator.join()
|
||||||
|
|
||||||
self.assertTrue(self.client.done())
|
self.assertTrue(self.coordinator.done())
|
||||||
|
|
||||||
# Likewise, it's now 20.
|
# Likewise, it's now 20.
|
||||||
self.assertEqual(v.read_value().numpy(), 20)
|
self.assertEqual(v.read_value().numpy(), 20)
|
||||||
@ -542,8 +545,9 @@ class ClientTest(TestCaseWithErrorReportingThread):
|
|||||||
def worker_fn(iterator):
|
def worker_fn(iterator):
|
||||||
return next(iterator)
|
return next(iterator)
|
||||||
|
|
||||||
distributed_dataset = self.client.create_per_worker_dataset(input_fn)
|
distributed_dataset = self.coordinator.create_per_worker_dataset(input_fn)
|
||||||
result = self.client.schedule(worker_fn, args=(iter(distributed_dataset),))
|
result = self.coordinator.schedule(
|
||||||
|
worker_fn, args=(iter(distributed_dataset),))
|
||||||
self.assertEqual(result.fetch(), (10,))
|
self.assertEqual(result.fetch(), (10,))
|
||||||
self.assertEqual(self._map_fn_tracing_count, 1)
|
self.assertEqual(self._map_fn_tracing_count, 1)
|
||||||
|
|
||||||
@ -554,28 +558,28 @@ class ClientTest(TestCaseWithErrorReportingThread):
|
|||||||
return v.read_value()
|
return v.read_value()
|
||||||
|
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
self.client.create_per_worker_dataset(input_fn)
|
self.coordinator.create_per_worker_dataset(input_fn)
|
||||||
|
|
||||||
def testDatasetsShuffledDifferently(self):
|
def testDatasetsShuffledDifferently(self):
|
||||||
# This test requires at least two workers in the cluster.
|
# This test requires at least two workers in the cluster.
|
||||||
self.assertGreaterEqual(len(self.client.cluster.workers), 2)
|
self.assertGreaterEqual(len(self.coordinator.cluster.workers), 2)
|
||||||
|
|
||||||
random_seed.set_random_seed(None)
|
random_seed.set_random_seed(None)
|
||||||
|
|
||||||
def input_fn():
|
def input_fn():
|
||||||
return dataset_ops.DatasetV2.range(0, 100).shuffle(100)
|
return dataset_ops.DatasetV2.range(0, 100).shuffle(100)
|
||||||
|
|
||||||
distributed_dataset = self.client.create_per_worker_dataset(input_fn)
|
distributed_dataset = self.coordinator.create_per_worker_dataset(input_fn)
|
||||||
distributed_iterator = iter(distributed_dataset)
|
distributed_iterator = iter(distributed_dataset)
|
||||||
|
|
||||||
# Get elements from the first two iterators.
|
# Get elements from the first two iterators.
|
||||||
iterator_1 = distributed_iterator._values[0]
|
iterator_1 = distributed_iterator._values[0]
|
||||||
iterator_1._rebuild_on(self.client.cluster.workers[0])
|
iterator_1._rebuild_on(self.coordinator.cluster.workers[0])
|
||||||
iterator_1 = iterator_1.fetch()
|
iterator_1 = iterator_1.fetch()
|
||||||
elements_in_iterator_1 = [e.numpy() for e in iterator_1]
|
elements_in_iterator_1 = [e.numpy() for e in iterator_1]
|
||||||
|
|
||||||
iterator_2 = distributed_iterator._values[1]
|
iterator_2 = distributed_iterator._values[1]
|
||||||
iterator_2._rebuild_on(self.client.cluster.workers[1])
|
iterator_2._rebuild_on(self.coordinator.cluster.workers[1])
|
||||||
iterator_2 = iterator_2.fetch()
|
iterator_2 = iterator_2.fetch()
|
||||||
elements_in_iterator_2 = [e.numpy() for e in iterator_2]
|
elements_in_iterator_2 = [e.numpy() for e in iterator_2]
|
||||||
|
|
||||||
@ -592,12 +596,12 @@ class ClientTest(TestCaseWithErrorReportingThread):
|
|||||||
self.assertIn('worker', var.device)
|
self.assertIn('worker', var.device)
|
||||||
return var
|
return var
|
||||||
|
|
||||||
worker_local_var = self.client._create_per_worker_resources(create_var)
|
worker_local_var = self.coordinator._create_per_worker_resources(create_var)
|
||||||
|
|
||||||
# The following is a workaround to allow `worker_local_var` to be passed in
|
# The following is a workaround to allow `worker_local_var` to be passed in
|
||||||
# as args to the `client.schedule` method which requires tensor specs to
|
# as args to the `coordinator.schedule` method which requires tensor specs
|
||||||
# trace tf.function but _create_worker_resources' return values don't have
|
# to trace tf.function but _create_worker_resources' return values don't
|
||||||
# tensor specs. We can get rid of this workaround once
|
# have tensor specs. We can get rid of this workaround once
|
||||||
# _create_worker_resources is able to infer the tensor spec of the return
|
# _create_worker_resources is able to infer the tensor spec of the return
|
||||||
# value of the function passed in. See b/154675763.
|
# value of the function passed in. See b/154675763.
|
||||||
for var in worker_local_var._values:
|
for var in worker_local_var._values:
|
||||||
@ -609,10 +613,10 @@ class ClientTest(TestCaseWithErrorReportingThread):
|
|||||||
for _ in range(10):
|
for _ in range(10):
|
||||||
# Which slice of `worker_local_var` will be used will depend on which
|
# Which slice of `worker_local_var` will be used will depend on which
|
||||||
# worker the `worker_fn` gets scheduled on.
|
# worker the `worker_fn` gets scheduled on.
|
||||||
self.client.schedule(worker_fn, args=(worker_local_var,))
|
self.coordinator.schedule(worker_fn, args=(worker_local_var,))
|
||||||
self.client.join()
|
self.coordinator.join()
|
||||||
|
|
||||||
var_sum = sum(self.client.fetch(worker_local_var._values))
|
var_sum = sum(self.coordinator.fetch(worker_local_var._values))
|
||||||
self.assertEqual(var_sum, 10.0)
|
self.assertEqual(var_sum, 10.0)
|
||||||
|
|
||||||
def testDisallowRemoteValueAsInput(self):
|
def testDisallowRemoteValueAsInput(self):
|
||||||
@ -625,28 +629,28 @@ class ClientTest(TestCaseWithErrorReportingThread):
|
|||||||
def func_1(x):
|
def func_1(x):
|
||||||
return x + 1.0
|
return x + 1.0
|
||||||
|
|
||||||
remote_v = self.client.schedule(func_0)
|
remote_v = self.coordinator.schedule(func_0)
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
self.client.schedule(func_1, args=(remote_v,))
|
self.coordinator.schedule(func_1, args=(remote_v,))
|
||||||
|
|
||||||
|
|
||||||
class LimitedClosureQueueSizeBasicTest(ClientTest):
|
class LimitedClosureQueueSizeBasicTest(ClusterCoordinatorTest):
|
||||||
"""Test basic functionality works with explicit maximum closure queue size.
|
"""Test basic functionality works with explicit maximum closure queue size.
|
||||||
|
|
||||||
Execute the same set of test cases as in `ClientTest`, with an
|
Execute the same set of test cases as in `ClusterCoordinatorTest`, with an
|
||||||
explicit size limit for the closure queue. Note that even when the queue size
|
explicit size limit for the closure queue. Note that even when the queue size
|
||||||
is set to infinite, there is still a maximum practical size (depends on host
|
is set to infinite, there is still a maximum practical size (depends on host
|
||||||
memory limit) that might cause the queue.put operations to be blocking when
|
memory limit) that might cause the queue.put operations to be blocking when
|
||||||
scheduling a large number of closures on a big cluster. These tests make sure
|
scheduling a large number of closures on a big cluster. These tests make sure
|
||||||
that the client does not run into deadlocks in such scenario.
|
that the coordinator does not run into deadlocks in such a scenario.
|
||||||
"""
|
"""
|
||||||
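The blocking behavior these tests guard against can be seen with a plain bounded queue. The following is a hedged, standalone illustration only (it is not the internal `_CoordinatedClosureQueue` and is not part of this change):

```python
import queue
import threading

q = queue.Queue(maxsize=2)
q.put("closure-1")
q.put("closure-2")

# A third put blocks until a consumer frees a slot, which is why the
# coordinator must keep draining the queue to avoid deadlock.
threading.Timer(0.1, q.get).start()  # simulated worker taking an item
q.put("closure-3", timeout=1.0)      # succeeds once the slot is freed
print(q.qsize())                     # -> 2
```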
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
super(LimitedClosureQueueSizeBasicTest, cls).setUpClass()
|
super(LimitedClosureQueueSizeBasicTest, cls).setUpClass()
|
||||||
client_lib._CLOSURE_QUEUE_MAX_SIZE = 2
|
coordinator_lib._CLOSURE_QUEUE_MAX_SIZE = 2
|
||||||
cls.client = make_client(num_workers=3, num_ps=2)
|
cls.coordinator = make_coordinator(num_workers=3, num_ps=2)
|
||||||
cls.strategy = cls.client.strategy
|
cls.strategy = cls.coordinator.strategy
|
||||||
|
|
||||||
|
|
||||||
class ErrorReportingTest(TestCaseWithErrorReportingThread):
|
class ErrorReportingTest(TestCaseWithErrorReportingThread):
|
||||||
@ -654,8 +658,8 @@ class ErrorReportingTest(TestCaseWithErrorReportingThread):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
super(ErrorReportingTest, cls).setUpClass()
|
super(ErrorReportingTest, cls).setUpClass()
|
||||||
cls.client = make_client(num_workers=3, num_ps=2)
|
cls.coordinator = make_coordinator(num_workers=3, num_ps=2)
|
||||||
cls.strategy = cls.client.strategy
|
cls.strategy = cls.coordinator.strategy
|
||||||
|
|
||||||
with cls.strategy.scope():
|
with cls.strategy.scope():
|
||||||
cls.iteration = variables.Variable(initial_value=0.0)
|
cls.iteration = variables.Variable(initial_value=0.0)
|
||||||
@ -686,80 +690,80 @@ class ErrorReportingTest(TestCaseWithErrorReportingThread):
|
|||||||
|
|
||||||
def testJoinRaiseError(self):
|
def testJoinRaiseError(self):
|
||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
self.client.schedule(self._normal_function)
|
self.coordinator.schedule(self._normal_function)
|
||||||
self.client.schedule(self._error_function)
|
self.coordinator.schedule(self._error_function)
|
||||||
with self.assertRaises(errors.InvalidArgumentError):
|
with self.assertRaises(errors.InvalidArgumentError):
|
||||||
self.client.join()
|
self.coordinator.join()
|
||||||
|
|
||||||
def testScheduleRaiseError(self):
|
def testScheduleRaiseError(self):
|
||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
self.client.schedule(self._normal_function)
|
self.coordinator.schedule(self._normal_function)
|
||||||
self.client.schedule(self._error_function)
|
self.coordinator.schedule(self._error_function)
|
||||||
with self.assertRaises(errors.InvalidArgumentError):
|
with self.assertRaises(errors.InvalidArgumentError):
|
||||||
while True:
|
while True:
|
||||||
self.client.schedule(self._normal_function)
|
self.coordinator.schedule(self._normal_function)
|
||||||
|
|
||||||
def testScheduleRaiseErrorWithMultipleFailure(self):
|
def testScheduleRaiseErrorWithMultipleFailure(self):
|
||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
self.client.schedule(self._normal_function)
|
self.coordinator.schedule(self._normal_function)
|
||||||
self.client.schedule(self._error_function)
|
self.coordinator.schedule(self._error_function)
|
||||||
with self.assertRaises(errors.InvalidArgumentError):
|
with self.assertRaises(errors.InvalidArgumentError):
|
||||||
while True:
|
while True:
|
||||||
self.client.schedule(self._error_function)
|
self.coordinator.schedule(self._error_function)
|
||||||
self.client.join()
|
self.coordinator.join()
|
||||||
|
|
||||||
def testErrorWillbeCleared(self):
|
def testErrorWillbeCleared(self):
|
||||||
self.client.schedule(self._error_function)
|
self.coordinator.schedule(self._error_function)
|
||||||
with self.assertRaises(errors.InvalidArgumentError):
|
with self.assertRaises(errors.InvalidArgumentError):
|
||||||
self.client.join()
|
self.coordinator.join()
|
||||||
|
|
||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
self.client.schedule(self._normal_function)
|
self.coordinator.schedule(self._normal_function)
|
||||||
self.client.schedule(self._error_function)
|
self.coordinator.schedule(self._error_function)
|
||||||
with self.assertRaises(errors.InvalidArgumentError):
|
with self.assertRaises(errors.InvalidArgumentError):
|
||||||
self.client.join()
|
self.coordinator.join()
|
||||||
|
|
||||||
def testRemoteValueReturnError(self):
|
def testRemoteValueReturnError(self):
|
||||||
result = self.client.schedule(self._error_function)
|
result = self.coordinator.schedule(self._error_function)
|
||||||
|
|
||||||
with self.assertRaises(errors.InvalidArgumentError):
|
with self.assertRaises(errors.InvalidArgumentError):
|
||||||
result.fetch()
|
result.fetch()
|
||||||
|
|
||||||
# Clear the error.
|
# Clear the error.
|
||||||
with self.assertRaises(errors.InvalidArgumentError):
|
with self.assertRaises(errors.InvalidArgumentError):
|
||||||
self.client.join()
|
self.coordinator.join()
|
||||||
|
|
||||||
def testInputError(self):
|
def testInputError(self):
|
||||||
|
|
||||||
worker_local_val = self.client._create_per_worker_resources(
|
worker_local_val = self.coordinator._create_per_worker_resources(
|
||||||
self._error_function)
|
self._error_function)
|
||||||
|
|
||||||
@def_function.function
|
@def_function.function
|
||||||
def func(x):
|
def func(x):
|
||||||
return x + 1
|
return x + 1
|
||||||
|
|
||||||
result = self.client.schedule(func, args=(worker_local_val,))
|
result = self.coordinator.schedule(func, args=(worker_local_val,))
|
||||||
with self.assertRaises(client_lib.InputError):
|
with self.assertRaises(coordinator_lib.InputError):
|
||||||
self.client.join()
|
self.coordinator.join()
|
||||||
|
|
||||||
with self.assertRaises(client_lib.InputError):
|
with self.assertRaises(coordinator_lib.InputError):
|
||||||
result.fetch()
|
result.fetch()
|
||||||
|
|
||||||
def testCancellation(self):
|
def testCancellation(self):
|
||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
self.client.schedule(self._normal_function)
|
self.coordinator.schedule(self._normal_function)
|
||||||
long_function = self.client.schedule(self._long_function)
|
long_function = self.coordinator.schedule(self._long_function)
|
||||||
self.client.schedule(self._error_function)
|
self.coordinator.schedule(self._error_function)
|
||||||
|
|
||||||
with self.assertRaises(errors.InvalidArgumentError):
|
with self.assertRaises(errors.InvalidArgumentError):
|
||||||
self.client.join()
|
self.coordinator.join()
|
||||||
|
|
||||||
with self.assertRaises(errors.CancelledError):
|
with self.assertRaises(errors.CancelledError):
|
||||||
long_function.fetch()
|
long_function.fetch()
|
||||||
|
|
||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
self.client.schedule(self._normal_function)
|
self.coordinator.schedule(self._normal_function)
|
||||||
self.client.join()
|
self.coordinator.join()
|
||||||
|
|
||||||
|
|
||||||
class LimitedClosureQueueErrorTest(ErrorReportingTest):
|
class LimitedClosureQueueErrorTest(ErrorReportingTest):
|
||||||
@ -772,11 +776,11 @@ class LimitedClosureQueueErrorTest(ErrorReportingTest):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
super(LimitedClosureQueueErrorTest, cls).setUpClass()
|
super(LimitedClosureQueueErrorTest, cls).setUpClass()
|
||||||
client_lib._CLOSURE_QUEUE_MAX_SIZE = 2
|
coordinator_lib._CLOSURE_QUEUE_MAX_SIZE = 2
|
||||||
cls.client = make_client(num_workers=3, num_ps=2)
|
cls.coordinator = make_coordinator(num_workers=3, num_ps=2)
|
||||||
cls.strategy = cls.client.strategy
|
cls.strategy = cls.coordinator.strategy
|
||||||
|
|
||||||
with cls.client.strategy.scope():
|
with cls.coordinator.strategy.scope():
|
||||||
cls.iteration = variables.Variable(initial_value=0.0)
|
cls.iteration = variables.Variable(initial_value=0.0)
|
||||||
|
|
||||||
|
|
||||||
@ -785,8 +789,8 @@ class StrategyIntegrationTest(test.TestCase):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
super(StrategyIntegrationTest, cls).setUpClass()
|
super(StrategyIntegrationTest, cls).setUpClass()
|
||||||
cls.client = make_client(num_workers=1, num_ps=1)
|
cls.coordinator = make_coordinator(num_workers=1, num_ps=1)
|
||||||
cls.strategy = cls.client.strategy
|
cls.strategy = cls.coordinator.strategy
|
||||||
|
|
||||||
def testBasicVariableAssignment(self):
|
def testBasicVariableAssignment(self):
|
||||||
self.strategy.extended._variable_count = 0
|
self.strategy.extended._variable_count = 0
|
||||||
@ -801,9 +805,9 @@ class StrategyIntegrationTest(test.TestCase):
|
|||||||
v2.assign_sub(0.2)
|
v2.assign_sub(0.2)
|
||||||
return v1.read_value() / v2.read_value()
|
return v1.read_value() / v2.read_value()
|
||||||
|
|
||||||
results = self.client.schedule(worker_fn)
|
results = self.coordinator.schedule(worker_fn)
|
||||||
logging.info('Results of experimental_run_v2: %f',
|
logging.info('Results of experimental_run_v2: %f',
|
||||||
self.client.fetch(results))
|
self.coordinator.fetch(results))
|
||||||
|
|
||||||
self.assertAlmostEqual(v1.read_value().numpy(), 0.1, delta=1e-6)
|
self.assertAlmostEqual(v1.read_value().numpy(), 0.1, delta=1e-6)
|
||||||
self.assertAlmostEqual(v2.read_value().numpy(), 0.8, delta=1e-6)
|
self.assertAlmostEqual(v2.read_value().numpy(), 0.8, delta=1e-6)
|
||||||
@ -826,12 +830,14 @@ class StrategyIntegrationTest(test.TestCase):
|
|||||||
return self.strategy.run(replica_fn, args=(input_tensor,))
|
return self.strategy.run(replica_fn, args=(input_tensor,))
|
||||||
|
|
||||||
# Asserting scheduling in scope has the expected behavior.
|
# Asserting scheduling in scope has the expected behavior.
|
||||||
result = self.client.schedule(worker_fn, args=(constant_op.constant(3),))
|
result = self.coordinator.schedule(
|
||||||
self.assertIsInstance(result, client_lib.RemoteValue)
|
worker_fn, args=(constant_op.constant(3),))
|
||||||
|
self.assertIsInstance(result, coordinator_lib.RemoteValue)
|
||||||
self.assertEqual(result.fetch(), 4)
|
self.assertEqual(result.fetch(), 4)
|
||||||
|
|
||||||
# Asserting scheduling out of scope has the expected behavior.
|
# Asserting scheduling out of scope has the expected behavior.
|
||||||
result = self.client.schedule(worker_fn, args=(constant_op.constant(3),))
|
result = self.coordinator.schedule(
|
||||||
|
worker_fn, args=(constant_op.constant(3),))
|
||||||
self.assertEqual(result.fetch(), 4)
|
self.assertEqual(result.fetch(), 4)
|
||||||
|
|
||||||
|
|
@ -31,15 +31,15 @@ enable_metrics = False
|
|||||||
_time_buckets = monitoring.ExponentialBuckets(0.001, 10, 6)
|
_time_buckets = monitoring.ExponentialBuckets(0.001, 10, 6)
|
||||||
|
|
||||||
_function_tracing_sampler = monitoring.Sampler(
|
_function_tracing_sampler = monitoring.Sampler(
|
||||||
'/tensorflow/api/ps_strategy/client/function_tracing', _time_buckets,
|
'/tensorflow/api/ps_strategy/coordinator/function_tracing', _time_buckets,
|
||||||
'Sampler to track the time (in seconds) for tracing functions.')
|
'Sampler to track the time (in seconds) for tracing functions.')
|
||||||
|
|
||||||
_closure_execution_sampler = monitoring.Sampler(
|
_closure_execution_sampler = monitoring.Sampler(
|
||||||
'/tensorflow/api/ps_strategy/client/closure_execution', _time_buckets,
|
'/tensorflow/api/ps_strategy/coordinator/closure_execution', _time_buckets,
|
||||||
'Sampler to track the time (in seconds) for executing closures.')
|
'Sampler to track the time (in seconds) for executing closures.')
|
||||||
|
|
||||||
_remote_value_fetch_sampler = monitoring.Sampler(
|
_remote_value_fetch_sampler = monitoring.Sampler(
|
||||||
'/tensorflow/api/ps_strategy/client/remote_value_fetch', _time_buckets,
|
'/tensorflow/api/ps_strategy/coordinator/remote_value_fetch', _time_buckets,
|
||||||
'Sampler to track the time (in seconds) for fetching remote_value.')
|
'Sampler to track the time (in seconds) for fetching remote_value.')
|
||||||
|
|
||||||
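The renamed samplers above record timing distributions. As a hedged illustration (not part of this change), a duration could be fed into such a sampler through the standard `monitoring` API; the `_timed` helper name below is hypothetical:

```python
import contextlib
import time

@contextlib.contextmanager
def _timed(sampler):
  """Records the wall time spent in the block into `sampler` (sketch only)."""
  start = time.time()
  try:
    yield
  finally:
    # Sampler cells accept float observations; here we add the elapsed seconds.
    sampler.get_cell().add(time.time() - start)

# Usage sketch, e.g. around tracing a tf.function:
# with _timed(_function_tracing_sampler):
#   concrete_fn = some_tf_function.get_concrete_function(...)
```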
_METRICS_MAPPING = {
|
_METRICS_MAPPING = {
|
@ -13,7 +13,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""Tests for metrics collecting in client."""
|
"""Tests for metrics collecting in coordinator."""
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
@ -22,9 +22,9 @@ from __future__ import print_function
|
|||||||
import time
|
import time
|
||||||
from tensorflow.python.distribute import multi_worker_test_base
|
from tensorflow.python.distribute import multi_worker_test_base
|
||||||
from tensorflow.python.distribute import parameter_server_strategy_v2
|
from tensorflow.python.distribute import parameter_server_strategy_v2
|
||||||
from tensorflow.python.distribute.client import client
|
|
||||||
from tensorflow.python.distribute.client import metric_utils
|
|
||||||
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
|
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
|
||||||
|
from tensorflow.python.distribute.coordinator import cluster_coordinator as coordinator_lib
|
||||||
|
from tensorflow.python.distribute.coordinator import metric_utils
|
||||||
from tensorflow.python.eager import def_function
|
from tensorflow.python.eager import def_function
|
||||||
from tensorflow.python.eager import test
|
from tensorflow.python.eager import test
|
||||||
from tensorflow.python.training.server_lib import ClusterSpec
|
from tensorflow.python.training.server_lib import ClusterSpec
|
||||||
@ -35,7 +35,7 @@ class MetricUtilsTest(test.TestCase):
|
|||||||
def get_rpc_layer(self):
|
def get_rpc_layer(self):
|
||||||
return 'grpc'
|
return 'grpc'
|
||||||
|
|
||||||
def testClientMetrics(self):
|
def testClusterCoordinatorMetrics(self):
|
||||||
|
|
||||||
metric_utils.enable_metrics = True
|
metric_utils.enable_metrics = True
|
||||||
|
|
||||||
@ -48,7 +48,7 @@ class MetricUtilsTest(test.TestCase):
|
|||||||
ClusterSpec(cluster_def), rpc_layer=self.get_rpc_layer())
|
ClusterSpec(cluster_def), rpc_layer=self.get_rpc_layer())
|
||||||
strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
|
strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
|
||||||
cluster_resolver)
|
cluster_resolver)
|
||||||
cluster = client.Cluster(strategy)
|
cluster = coordinator_lib.Cluster(strategy)
|
||||||
|
|
||||||
@def_function.function
|
@def_function.function
|
||||||
def func():
|
def func():
|
@ -48,7 +48,7 @@ _LOCAL_CPU = "/device:CPU:0"
|
|||||||
|
|
||||||
|
|
||||||
# TODO(yuefengz): maybe cache variables on local CPU.
|
# TODO(yuefengz): maybe cache variables on local CPU.
|
||||||
@tf_export("distribute.experimental.ParameterServerStrategy", v1=[])
|
# TODO(b/171250971): Remove this and change all symbol usage of this to V1.
|
||||||
class ParameterServerStrategy(distribute_lib.Strategy):
|
class ParameterServerStrategy(distribute_lib.Strategy):
|
||||||
"""An asynchronous multi-worker parameter server tf.distribute strategy.
|
"""An asynchronous multi-worker parameter server tf.distribute strategy.
|
||||||
|
|
||||||
@ -154,8 +154,9 @@ class ParameterServerStrategy(distribute_lib.Strategy):
|
|||||||
|
|
||||||
def _raise_pss_error_if_eager(self):
|
def _raise_pss_error_if_eager(self):
|
||||||
if context.executing_eagerly():
|
if context.executing_eagerly():
|
||||||
raise NotImplementedError("ParameterServerStrategy currently only works "
|
raise NotImplementedError(
|
||||||
"with the tf.Estimator API")
|
"`tf.compat.v1.distribute.experimental.ParameterServerStrategy` "
|
||||||
|
"currently only works with the tf.Estimator API")
|
||||||
|
|
||||||
|
|
||||||
@tf_export(v1=["distribute.experimental.ParameterServerStrategy"]) # pylint: disable=missing-docstring
|
@tf_export(v1=["distribute.experimental.ParameterServerStrategy"]) # pylint: disable=missing-docstring
|
||||||
|
@ -37,82 +37,400 @@ from tensorflow.python.platform import tf_logging as logging
|
|||||||
from tensorflow.python.training import server_lib
|
from tensorflow.python.training import server_lib
|
||||||
from tensorflow.python.training.tracking import base as trackable
|
from tensorflow.python.training.tracking import base as trackable
|
||||||
from tensorflow.python.util import tf_inspect
|
from tensorflow.python.util import tf_inspect
|
||||||
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=protected-access
|
@tf_export("distribute.experimental.ParameterServerStrategy", v1=[])
|
||||||
class ParameterServerStrategyV2(distribute_lib.Strategy):
|
class ParameterServerStrategyV2(distribute_lib.Strategy):
|
||||||
"""An asynchronous multi-worker parameter server tf.distribute strategy.
|
"""An multi-worker tf.distribute strategy with parameter servers.
|
||||||
|
|
||||||
Currently, `ParameterServerStrategyV2` is not supported to be used as a
|
Parameter server training refers to the distributed training architecture that
|
||||||
standalone tf.distribute strategy. It should be used in conjunction with
|
requires two types of tasks in the cluster: workers (referred to as "worker"
|
||||||
`Client`. Please see `Client` for more information.
|
task) and parameter servers (referred to as "ps" task). The variables and
|
||||||
|
updates to those variables are placed on ps, and most computation-intensive
|
||||||
|
operations are placed on workers.
|
||||||
|
|
||||||
This is currently under development, and the API as well as implementation
|
In TF2, parameter server training makes use of one coordinator, with some
|
||||||
is subject to changes.
|
number of workers, and (usually fewer) ps. The coordinator uses a
|
||||||
|
`tf.distribute.experimental.coordinator.ClusterCoordinator` to coordinate the
|
||||||
|
cluster, and a `tf.distribute.experimental.ParameterServerStrategy` for
|
||||||
|
variable distribution. The coordinator does not perform the actual training.
|
||||||
|
Each of the workers and ps runs a `tf.distribute.Server`, which the
|
||||||
|
coordinator connects to through the use of the aforementioned two APIs.
|
||||||
|
|
||||||
|
For the training to work, the coordinator sends requests to workers for the
|
||||||
|
`tf.function`s to be executed on remote workers. Upon receiving requests from
|
||||||
|
the coordinator, a worker executes the `tf.function` by reading the variables
|
||||||
|
from parameter servers, executing the ops, and updating the variables on the
|
||||||
|
parameter servers. Each worker only processes requests from the
|
||||||
|
coordinator, and communicates with parameter servers, without direct
|
||||||
|
interactions with any of the other workers in the cluster.
|
||||||
|
|
||||||
|
As a result, failures of some workers do not prevent the cluster from
|
||||||
|
continuing the work, and this allows the cluster to train with instances that
|
||||||
|
can be occasionally unavailable (e.g. preemptible or spot instances). The
|
||||||
|
coordinator and parameter servers, though, must be available at all times for
|
||||||
|
the cluster to make progress.
|
||||||
|
|
||||||
|
Note that the coordinator is not one of the training workers. Instead, its
|
||||||
|
responsibilities include placing variables on ps, remotely executing
|
||||||
|
`tf.function`s on workers, and saving checkpoints. Parameter server training
|
||||||
|
thus consists of a server cluster with workers and ps, and a coordinator which
|
||||||
|
connects to them to coordinate. Optionally, an evaluator can be run on the
|
||||||
|
side that periodically reads the checkpoints saved by the coordinator, and
|
||||||
|
saves summaries, for example.
|
||||||
|
|
||||||
|
`tf.distribute.experimental.ParameterServerStrategy` works closely with the
|
||||||
|
associated `tf.distribute.experimental.coordinator.ClusterCoordinator` object,
|
||||||
|
and should be used in conjunction with it. Standalone usage of
|
||||||
|
`tf.distribute.experimental.ParameterServerStrategy` without a
|
||||||
|
`tf.distribute.experimental.coordinator.ClusterCoordinator` indicates
|
||||||
|
a parameter server training scheme without a centralized coordinator, which is
|
||||||
|
not supported at this time.
|
||||||
|
|
||||||
|
__Example code for coordinator__
|
||||||
|
|
||||||
|
Here's an example usage of the API, with a custom training loop to train a
|
||||||
|
model. This code snippet is intended to be run on the single machine that
|
||||||
|
is designated as the coordinator. Note that `cluster_resolver`,
|
||||||
|
`variable_partitioner`, and `dataset_fn` arguments are explained in the
|
||||||
|
following "Cluster setup", "Variable partitioning", and "Dataset preparation"
|
||||||
|
sections.
|
||||||
|
|
||||||
|
Currently, the environment variable `GRPC_FAIL_FAST` needs to be set in all tasks
|
||||||
|
to work around a known hanging issue as the following code illustrates:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Set the environment variable to allow reporting worker and ps failure to the
|
||||||
|
# coordinator.
|
||||||
|
os.environ["GRPC_FAIL_FAST"] = "use_caller"
|
||||||
|
|
||||||
|
# Prepare a strategy to use with the cluster and variable partitioning info.
|
||||||
|
strategy = tf.distribute.experimental.ParameterServerStrategy(
|
||||||
|
cluster_resolver=...,
|
||||||
|
variable_partitioner=...)
|
||||||
|
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
|
||||||
|
strategy=strategy)
|
||||||
|
|
||||||
|
# Prepare a distribute dataset that will place datasets on the workers.
|
||||||
|
distributed_dataset = coordinator.create_per_worker_dataset(dataset_fn=...)
|
||||||
|
|
||||||
|
with strategy.scope():
|
||||||
|
model = ... # Variables created may be containers of variables
|
||||||
|
optimizer, metrics = ... # Keras optimizer/metrics are great choices
|
||||||
|
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
|
||||||
|
checkpoint_manager = tf.train.CheckpointManager(
|
||||||
|
checkpoint, checkpoint_dir, max_to_keep=2)
|
||||||
|
# `load_checkpoint` infers initial epoch from `optimizer.iterations`.
|
||||||
|
initial_epoch = load_checkpoint(checkpoint_manager) or 0
|
||||||
|
|
||||||
|
@tf.function
|
||||||
|
def worker_fn(iterator):
|
||||||
|
|
||||||
|
def replica_fn(inputs):
|
||||||
|
batch_data, labels = inputs
|
||||||
|
# calculate gradient, applying gradient, metrics update etc.
|
||||||
|
|
||||||
|
strategy.run(replica_fn, args=(next(iterator),))
|
||||||
|
|
||||||
|
for epoch in range(initial_epoch, num_epoch):
|
||||||
|
distributed_iterator = iter(distributed_dataset) # Reset iterator state.
|
||||||
|
for step in range(steps_per_epoch):
|
||||||
|
|
||||||
|
# Asynchronously schedule the `worker_fn` to be executed on an arbitrary
|
||||||
|
# worker. This call returns immediately.
|
||||||
|
coordinator.schedule(worker_fn, args=(distributed_iterator,))
|
||||||
|
|
||||||
|
# `join` blocks until all scheduled `worker_fn`s finish execution. Once it
|
||||||
|
# returns, we can read the metrics and save checkpoints as needed.
|
||||||
|
coordinator.join()
|
||||||
|
logging.info('Metric result: %r', metrics.result())
|
||||||
|
metrics.reset_states()
|
||||||
|
checkpoint_manager.save()
|
||||||
|
```
|
||||||
|
|
||||||
|
__Example code for worker and parameter servers__
|
||||||
|
|
||||||
|
In addition to the coordinator, there should be multiple machines designated
|
||||||
|
as "worker" or "ps". They should run the following code to start a TensorFlow
|
||||||
|
server, waiting for the coordinator's requests to execute functions or place
|
||||||
|
variables:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Set the environment variable to allow reporting worker and ps failure to the
|
||||||
|
# coordinator.
|
||||||
|
os.environ["GRPC_FAIL_FAST"] = "use_caller"
|
||||||
|
|
||||||
|
# Provide a `tf.distribute.cluster_resolver.ClusterResolver` that serves
|
||||||
|
# the cluster information. See below "Cluster setup" section.
|
||||||
|
cluster_resolver = ...
|
||||||
|
|
||||||
|
server = tf.distribute.Server(
|
||||||
|
cluster_resolver.cluster_spec().as_cluster_def(),
|
||||||
|
job_name=cluster_resolver.task_type,
|
||||||
|
task_index=cluster_resolver.task_id,
|
||||||
|
protocol="grpc")
|
||||||
|
|
||||||
|
# Block until the server terminates, keeping this process alive.
|
||||||
|
server.join()
|
||||||
|
```
|
||||||
|
|
||||||
|
__Cluster setup__
|
||||||
|
|
||||||
|
In order for the tasks in the cluster to know other tasks' addresses,
|
||||||
|
a `tf.distribute.cluster_resolver.ClusterResolver` is required to be used
|
||||||
|
in the coordinator, workers, and ps. The
|
||||||
|
`tf.distribute.cluster_resolver.ClusterResolver` is responsible for providing
|
||||||
|
the cluster information, as well as the task type and id of the current task.
|
||||||
|
See `tf.distribute.cluster_resolver.ClusterResolver` for more information.
|
||||||
|
|
||||||
|
If the `TF_CONFIG` environment variable is used for the processes to know the
|
||||||
|
cluster information, a
|
||||||
|
`tf.distribute.cluster_resolver.TFConfigClusterResolver` should be used. Note
|
||||||
|
that for legacy reasons, "chief" should be used as the task type for the
|
||||||
|
coordinator, as the following example demonstrates. Here we set `TF_CONFIG`
|
||||||
|
in an environment variable, intended to be run by the process on the machine
|
||||||
|
designated as the parameter server (task type "ps") and index 1 (the second),
|
||||||
|
in a cluster with 1 chief, 2 parameter servers, and 3 workers. Note that
|
||||||
|
it needs to be set before the use of
|
||||||
|
`tf.distribute.cluster_resolver.TFConfigClusterResolver`.
|
||||||
|
|
||||||
|
Example code for cluster setup:
|
||||||
|
```python
|
||||||
|
os.environ['TF_CONFIG'] = '''
|
||||||
|
{
|
||||||
|
"cluster": {
|
||||||
|
"chief": ["chief.example.com:2222"],
|
||||||
|
"ps": ["ps0.example.com:2222", "ps1.example.com:2222"],
|
||||||
|
"worker": ["worker0.example.com:2222", "worker1.example.com:2222",
|
||||||
|
"worker2.example.com:2222"]
|
||||||
|
},
|
||||||
|
"task": {
|
||||||
|
"type": "ps",
|
||||||
|
"index": 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
'''
|
||||||
|
os.environ["GRPC_FAIL_FAST"] = "use_caller"
|
||||||
|
cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
|
||||||
|
|
||||||
|
# If coordinator ("chief" task type), create a strategy
|
||||||
|
if cluster_resolver.task_type == 'chief':
|
||||||
|
strategy = tf.distribute.experimental.ParameterServerStrategy(
|
||||||
|
cluster_resolver)
|
||||||
|
...
|
||||||
|
|
||||||
|
# If worker/ps, create a server
|
||||||
|
elif cluster_resolver.task_type in ("worker", "ps"):
|
||||||
|
server = tf.distribute.Server(...)
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
__Variable creation with `strategy.scope()`__
|
||||||
|
|
||||||
|
`tf.distribute.experimental.ParameterServerStrategy` follows the
|
||||||
|
`tf.distribute` API contract where variable creation is expected to be inside
|
||||||
|
the context manager returned by `strategy.scope()`, in order to be correctly
|
||||||
|
placed on parameter servers in a round-robin manner:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# In this example, we assume the cluster has 3 ps.
|
||||||
|
strategy = tf.distribute.experimental.ParameterServerStrategy(
|
||||||
|
cluster_resolver=...)
|
||||||
|
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
|
||||||
|
strategy=strategy)
|
||||||
|
|
||||||
|
# Variables should be created inside scope to be placed on parameter servers.
|
||||||
|
# If created outside scope such as `v1` here, it would be placed on
|
||||||
|
# the coordinator.
|
||||||
|
v1 = tf.Variable(initial_value=0.0)
|
||||||
|
|
||||||
|
with strategy.scope():
|
||||||
|
v2 = tf.Variable(initial_value=1.0)
|
||||||
|
v3 = tf.Variable(initial_value=2.0)
|
||||||
|
v4 = tf.Variable(initial_value=3.0)
|
||||||
|
v5 = tf.Variable(initial_value=4.0)
|
||||||
|
|
||||||
|
# v2 through v5 are created in scope and are distributed on parameter servers.
|
||||||
|
# Default placement is round-robin but the order should not be relied on.
|
||||||
|
assert v2.device == "/job:ps/replica:0/task:0/device:CPU:0"
|
||||||
|
assert v3.device == "/job:ps/replica:0/task:1/device:CPU:0"
|
||||||
|
assert v4.device == "/job:ps/replica:0/task:2/device:CPU:0"
|
||||||
|
assert v5.device == "/job:ps/replica:0/task:0/device:CPU:0"
|
||||||
|
```
|
||||||
|
|
||||||
|
See `tf.distribute.Strategy.scope` for more information.
|
||||||
|
|
||||||
|
__Variable partitioning__
|
||||||
|
|
||||||
|
Having dedicated servers to store variables means being able to divide up, or
|
||||||
|
"shard" the variables across the ps. Large embeddings that would otherwise
|
||||||
|
exceed the memory limit of a single machine can be used in a cluster with a
|
||||||
|
sufficient number of ps.
|
||||||
|
|
||||||
|
With `tf.distribute.experimental.ParameterServerStrategy`, if a
|
||||||
|
`variable_partitioner` is provided to `__init__` and certain conditions are
|
||||||
|
satisfied, the resulting variables created in scope are sharded across the
|
||||||
|
parameter servers, in a round-robin fashion. The variable reference returned
|
||||||
|
from `tf.Variable` becomes a type that serves as the container of the sharded
|
||||||
|
variables. Access the `variables` attribute of this container for the actual
|
||||||
|
variable components. See arguments section of
|
||||||
|
`tf.distribute.experimental.ParameterServerStrategy.__init__` for more
|
||||||
|
information.
|
||||||
|
|
||||||
|
To initialize the sharded variables in a more memory-efficient way, use an
|
||||||
|
initializer whose `__call__` accepts a `shard_info` argument, and use
|
||||||
|
`shard_info.offset` and `shard_info.shape` to create and return a
|
||||||
|
partition-aware `tf.Tensor` to initialize the variable components.
|
||||||
|
|
||||||
|
```python
|
||||||
|
class PartitionAwareIdentity(object):
|
||||||
|
|
||||||
|
def __call__(self, shape, dtype, shard_info):
|
||||||
|
value = tf.eye(*shape, dtype=dtype)
|
||||||
|
if shard_info is not None:
|
||||||
|
value = tf.slice(value, shard_info.offset, shard_info.shape)
|
||||||
|
return value
|
||||||
|
|
||||||
|
cluster_resolver = ...
|
||||||
|
strategy = tf.distribute.experimental.ParameterServerStrategy(
|
||||||
|
cluster_resolver, tf.fixed_size_partitioner(2))
|
||||||
|
with strategy.scope():
|
||||||
|
initializer = PartitionAwareIdentity()
|
||||||
|
initial_value = functools.partial(initializer, shape=(4, 4), dtype=tf.int64)
|
||||||
|
v = tf.Variable(
|
||||||
|
initial_value=initial_value, shape=(4, 4), dtype=tf.int64)
|
||||||
|
|
||||||
|
# `v.variables` gives the actual variable components.
|
||||||
|
assert len(v.variables) == 2
|
||||||
|
assert v.variables[0].device == "/job:ps/replica:0/task:0/device:CPU:0"
|
||||||
|
assert v.variables[1].device == "/job:ps/replica:0/task:1/device:CPU:0"
|
||||||
|
assert np.array_equal(v.variables[0].numpy(), [[1, 0, 0, 0], [0, 1, 0, 0]])
|
||||||
|
assert np.array_equal(v.variables[1].numpy(), [[0, 0, 1, 0], [0, 0, 0, 1]])
|
||||||
|
```
|
||||||
|
|
||||||
|
__Dataset preparation__
|
||||||
|
|
||||||
|
With `tf.distribute.experimental.ParameterServerStrategy`, a dataset is
|
||||||
|
created in each of the workers to be used for training. This is done by
|
||||||
|
creating a `dataset_fn` that takes no argument and returns a
|
||||||
|
`tf.data.Dataset`, and passing the `dataset_fn` into
|
||||||
|
`tf.distribute.experimental.coordinator.
|
||||||
|
ClusterCoordinator.create_per_worker_dataset`. We recommend that the dataset be
|
||||||
|
shuffled and repeated so that the examples run through the training as evenly
|
||||||
|
as possible.
|
||||||
|
|
||||||
|
```python
|
||||||
|
def dataset_fn():
|
||||||
|
filenames = ...
|
||||||
|
dataset = tf.data.Dataset.from_tensor_slices(filenames)
|
||||||
|
|
||||||
|
# Dataset is recommended to be shuffled, and repeated.
|
||||||
|
return dataset.shuffle(buffer_size=...).repeat().batch(batch_size=...)
|
||||||
|
|
||||||
|
coordinator =
|
||||||
|
tf.distribute.experimental.coordinator.ClusterCoordinator(strategy=...)
|
||||||
|
distributed_dataset = coordinator.create_per_worker_dataset(dataset_fn)
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
__Limitations__
|
||||||
|
|
||||||
|
* `tf.distribute.experimental.ParameterServerStrategy` in TF2 is experimental,
|
||||||
|
and the API is subject to further changes.
|
||||||
|
|
||||||
|
* `tf.distribute.experimental.ParameterServerStrategy` does not yet support
|
||||||
|
training with GPU(s). This is a feature request being developed.
|
||||||
|
|
||||||
|
* `tf.distribute.experimental.ParameterServerStrategy` only supports
|
||||||
|
[custom training loop
|
||||||
|
API](https://www.tensorflow.org/tutorials/distribute/custom_training)
|
||||||
|
currently in TF2. Usage with the Keras `compile`/`fit` API is being
|
||||||
|
developed.
|
||||||
|
|
||||||
|
* `tf.distribute.experimental.ParameterServerStrategy` must be used with
|
||||||
|
`tf.distribute.experimental.coordinator.ClusterCoordinator`.
|
||||||
|
|
||||||
|
* This strategy is not intended for TPU. Use
|
||||||
|
`tf.distribute.experimental.TPUStrategy` instead.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# pyformat: disable
|
||||||
def __init__(self, cluster_resolver, variable_partitioner=None):
|
def __init__(self, cluster_resolver, variable_partitioner=None):
|
||||||
"""Initializes the V2 parameter server strategy.
|
"""Initializes the TF2 parameter server strategy.
|
||||||
|
|
||||||
This also connects to the remote server cluster.
|
This initializes the `tf.distribute.experimental.ParameterServerStrategy`
|
||||||
|
object to be ready for use with
|
||||||
|
`tf.distribute.experimental.coordinator.ClusterCoordinator`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cluster_resolver: a `tf.distribute.cluster_resolver.ClusterResolver`
|
cluster_resolver: a `tf.distribute.cluster_resolver.ClusterResolver`
|
||||||
object.
|
object.
|
||||||
variable_partitioner: a callable with the signature `num_partitions =
|
variable_partitioner:
|
||||||
fn(shape, dtype)`, where `num_partitions` is a list/tuple representing
|
a callable with the signature `num_partitions = fn(shape, dtype)`, where
|
||||||
the number of partitions on each axis, and `shape` and `dtype` are of
|
`num_partitions` is a list/tuple representing the number of partitions
|
||||||
types `tf.TensorShape` and `tf.dtypes.Dtype`. If None, variables will
|
on each axis, and `shape` and `dtype` are of types `tf.TensorShape` and
|
||||||
not be partitioned. * `variable_partitioner` will be called for all
|
`tf.dtypes.DType`. If `None`, variables will not be partitioned.
|
||||||
variables created under strategy `scope` to instruct how the variables
|
|
||||||
should be partitioned. Variables will be partitioned if there are more
|
* `variable_partitioner` will be called for all variables created under
|
||||||
than one partitions along the partitioning axis, otherwise it falls back
|
strategy `scope` to instruct how the variables should be partitioned.
|
||||||
to normal `tf.Variable`. * Only the first / outermost axis partitioning
|
Variables will be created in multiple partitions if there are more than
|
||||||
is supported, namely, elements in `num_partitions` must be 1 other than
|
one partition along the partitioning axis, otherwise it falls back to
|
||||||
the first element. * Partitioner like `min_max_variable_partitioner`,
|
normal `tf.Variable`.
|
||||||
`variable_axis_size_partitioner` and `fixed_size_partitioner` are also
|
|
||||||
supported since they conform to the required signature. * Div partition
|
* Only the first / outermost axis partitioning is supported, namely,
|
||||||
|
elements in `num_partitions` must be 1 other than the first element.
|
||||||
|
|
||||||
|
* Partitioner like `tf.compat.v1.min_max_variable_partitioner`,
|
||||||
|
`tf.compat.v1.variable_axis_size_partitioner` and
|
||||||
|
`tf.compat.v1.fixed_size_partitioner` are also supported since they
|
||||||
|
conform to the required signature.
|
||||||
|
|
||||||
|
* Div partition
|
||||||
strategy is used to partition variables. Assuming we assign consecutive
|
strategy is used to partition variables. Assuming we assign consecutive
|
||||||
integer ids along the first axis of a variable, then ids are assigned to
|
integer ids along the first axis of a variable, then ids are assigned to
|
||||||
shards in a contiguous manner, while attempting to keep each shard size
|
shards in a contiguous manner, while attempting to keep each shard size
|
||||||
identical. If the ids do not evenly divide the number of shards, each of
|
identical. If the ids do not evenly divide the number of shards, each of
|
||||||
the first several shards will be assigned one more id. For instance, a
|
the first several shards will be assigned one more id. For instance, a
|
||||||
variable whose first dimension is
|
variable whose first dimension is 13 has 13 ids, and they are split
|
||||||
13 has 13 ids, and they are split across 5 shards as: `[[0, 1, 2], [3,
|
across 5 shards as:
|
||||||
4, 5], [6, 7, 8], [9, 10], [11, 12]]`. * Variables created under
|
`[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`.
|
||||||
`strategy.extended.colocate_vars_with` will not be partitioned, e.g,
|
|
||||||
optimizer's slot variables.
|
* Variables created under `strategy.extended.colocate_vars_with` will
|
||||||
|
not be partitioned, e.g, optimizer's slot variables.
|
||||||
"""
|
"""
|
||||||
|
# pyformat: enable
|
||||||
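To make the `variable_partitioner` contract concrete, here is a minimal sketch of a partitioner conforming to the `num_partitions = fn(shape, dtype)` signature described above; the two-shard policy and the `[13, 4]` variable shape are arbitrary assumptions, not part of this change:

```python
def variable_partitioner(shape, dtype):
  """Toy policy: split only the first axis into at most two shards."""
  del dtype  # Unused in this sketch.
  # One entry per axis; only the first entry may be greater than 1.
  return [min(2, int(shape[0]))] + [1] * (len(shape) - 1)

strategy = tf.distribute.experimental.ParameterServerStrategy(
    cluster_resolver, variable_partitioner=variable_partitioner)

with strategy.scope():
  # With the div partition strategy described above, a [13, 4] variable is
  # split along the first axis into shards of sizes 7 and 6.
  v = tf.Variable(tf.random.uniform([13, 4]))
```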
    self._cluster_resolver = cluster_resolver
    self._extended = ParameterServerStrategyV2Extended(self, cluster_resolver,
                                                       variable_partitioner)
    self._verify_args_and_config(cluster_resolver)
    logging.info(
-        "ParameterServerStrategyV2 is initialized with cluster_spec: "
-        "%s", cluster_resolver.cluster_spec())
+        "`tf.distribute.experimental.ParameterServerStrategy` is initialized "
+        "with cluster_spec: %s", cluster_resolver.cluster_spec())

-    # TODO(b/167894802): Make chief, worker, and ps names customizable.
-    self._connect_to_cluster(client_name="chief")
+    # TODO(b/167894802): Make coordinator, worker, and ps names customizable.
+    self._connect_to_cluster(coordinator_name="chief")
    super(ParameterServerStrategyV2, self).__init__(self._extended)
    distribute_lib.distribution_strategy_gauge.get_cell("V2").set(
        "ParameterServerStrategy")

-  def _connect_to_cluster(self, client_name):
-    if client_name in ["worker", "ps"]:
-      raise ValueError("Client name should not be 'worker' or 'ps'.")
+  def _connect_to_cluster(self, coordinator_name):
+    if coordinator_name in ["worker", "ps"]:
+      raise ValueError("coordinator name should not be 'worker' or 'ps'.")
    cluster_spec = self._cluster_resolver.cluster_spec()
    self._num_workers = len(cluster_spec.as_dict().get("worker", ()))
    self._num_ps = len(cluster_spec.as_dict().get("ps", ()))

    device_filters = server_lib.ClusterDeviceFilters()
-    # For any worker, only the devices on PS and chief nodes are visible
+    # For any worker, only the devices on ps and coordinator nodes are visible
    for i in range(self._num_workers):
      device_filters.set_device_filters(
-          "worker", i, ["/job:ps", "/job:%s" % client_name])
+          "worker", i, ["/job:ps", "/job:%s" % coordinator_name])
-    # Similarly for any ps, only the devices on workers and chief are visible
+    # Similarly for any ps, only the devices on workers and coordinator are
+    # visible
    for i in range(self._num_ps):
      device_filters.set_device_filters(
-          "ps", i, ["/job:worker", "/job:%s" % client_name])
+          "ps", i, ["/job:worker", "/job:%s" % coordinator_name])

    # Allow at most one outstanding RPC for each worker at a certain time. This
    # is to simplify worker failure handling in the runtime
@@ -122,7 +440,7 @@ class ParameterServerStrategyV2(distribute_lib.Strategy):
        self.__class__.__name__, cluster_spec)
    remote.connect_to_cluster(
        cluster_spec,
-        job_name=client_name,
+        job_name=coordinator_name,
        protocol=self._cluster_resolver.rpc_layer,
        cluster_device_filters=device_filters)
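For reference, a minimal sketch of the kind of `cluster_resolver` this constructor consumes, patterned after the in-process test setup later in this change; the addresses and job counts are placeholders (in practice they often come from `TF_CONFIG` via `tf.distribute.cluster_resolver.TFConfigClusterResolver`):

```python
cluster_def = {
    "chief": ["chief.example.com:2222"],  # Runs the coordinator.
    "worker": ["worker0.example.com:2222",
               "worker1.example.com:2222"],
    "ps": ["ps0.example.com:2222"],
}
cluster_resolver = tf.distribute.cluster_resolver.SimpleClusterResolver(
    tf.train.ClusterSpec(cluster_def), rpc_layer="grpc")

strategy = tf.distribute.experimental.ParameterServerStrategy(cluster_resolver)
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
    strategy)
```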
@@ -134,7 +452,7 @@ class ParameterServerStrategyV2(distribute_lib.Strategy):
  def _verify_args_and_config(self, cluster_resolver):
    if not cluster_resolver.cluster_spec():
      raise ValueError("Cluster spec must be non-empty in `cluster_resolver`.")
-    if self.extended._num_gpus_per_worker > 1:
+    if self.extended._num_gpus_per_worker > 1:  # pylint: disable=protected-access
      raise NotImplementedError("Multi-gpu is not supported yet.")
@@ -205,8 +523,8 @@ class ParameterServerStrategyV2Extended(
      init_from_fn = False
      initial_value = initial_value()
    if not init_from_fn:
-      # The initial_value is created on client, it will need to be sent to
-      # PS for variable initialization, which can be inefficient and can
+      # The initial_value is created on coordinator, it will need to be sent to
+      # ps for variable initialization, which can be inefficient and can
      # potentially hit the 2GB limit on protobuf serialization.
      initial_value = ops.convert_to_tensor(initial_value, dtype=dtype)
      dtype = initial_value.dtype
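One practical consequence of the comment above, sketched under the assumption that a callable `initial_value` takes the `init_from_fn` path and therefore need not be materialized and serialized from the coordinator; the shapes are arbitrary:

```python
with strategy.scope():
  # Eager tensor: materialized on the coordinator first, then sent to a
  # parameter server; very large values risk the 2GB protobuf limit.
  v_eager = tf.Variable(tf.random.normal([16384, 1024]))

  # Callable: the initial value is produced lazily via the init_from_fn
  # path, avoiding serialization of a large tensor from the coordinator.
  v_lazy = tf.Variable(
      initial_value=lambda: tf.random.normal([16384, 1024]),
      dtype=tf.float32)
```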
@@ -870,8 +870,8 @@ py_test(
        "//tensorflow/python/distribute:multi_worker_test_base",
        "//tensorflow/python/distribute:parameter_server_strategy_v2",
        "//tensorflow/python/distribute:sharded_variable",
-        "//tensorflow/python/distribute/client",
        "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
+        "//tensorflow/python/distribute/coordinator:cluster_coordinator",
        "//tensorflow/python/eager:backprop",
        "//tensorflow/python/eager:def_function",
        "//tensorflow/python/eager:test",
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for ParameterServerClient and Keras models."""
+"""Tests for ClusterCoordinator and Keras models."""

from __future__ import absolute_import
from __future__ import division
@@ -27,8 +27,8 @@ from tensorflow.python.compat import v2_compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import parameter_server_strategy_v2
-from tensorflow.python.distribute.client import client as client_lib
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
+from tensorflow.python.distribute.coordinator import cluster_coordinator as coordinator_lib
from tensorflow.python.eager import backprop
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
@@ -44,7 +44,7 @@ from tensorflow.python.platform import test
from tensorflow.python.training.server_lib import ClusterSpec


-def make_client(num_workers, num_ps):
+def make_coordinator(num_workers, num_ps):
  cluster_def = multi_worker_test_base.create_in_process_cluster(
      num_workers=num_workers, num_ps=num_ps, rpc_layer="grpc")
  cluster_def["chief"] = [
@@ -52,7 +52,7 @@ def make_client(num_workers, num_ps):
  ]
  cluster_resolver = SimpleClusterResolver(
      ClusterSpec(cluster_def), rpc_layer="grpc")
-  return client_lib.Client(
+  return coordinator_lib.ClusterCoordinator(
      parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver))
@@ -61,7 +61,7 @@ class KPLTest(test.TestCase):
  @classmethod
  def setUpClass(cls):
    super(KPLTest, cls).setUpClass()
-    cls.client = make_client(num_workers=3, num_ps=2)
+    cls.coordinator = make_coordinator(num_workers=3, num_ps=2)

  def testTrainAndServe(self):
    # These vocabularies usually come from TFT or a Beam pipeline.
@@ -71,10 +71,10 @@ class KPLTest(test.TestCase):
    ]
    label_vocab = ["yes", "no"]

-    with self.client.strategy.scope():
+    with self.coordinator.strategy.scope():

      # Define KPLs under strategy's scope. Right now, if they have look up
-      # tables, they will be created on the client. Their variables will be
+      # tables, they will be created on the coordinator. Their variables will be
      # created on PS. Ideally they should be cached on each worker since they
      # will not be changed in a training step.
      feature_lookup_layer = string_lookup.StringLookup()
@@ -113,7 +113,7 @@ class KPLTest(test.TestCase):
        label = "yes" if "avenger" in features else "no"
        yield {"features": features, "label": label}

-      # The dataset will be created on the client?
+      # The dataset will be created on the coordinator?
      raw_dataset = dataset_ops.Dataset.from_generator(
          feature_and_label_gen,
          output_types={
@@ -131,7 +131,8 @@ class KPLTest(test.TestCase):
          }, [x["label"]]))
      return train_dataset

-    distributed_dataset = self.client.create_per_worker_dataset(dataset_fn)
+    distributed_dataset = self.coordinator.create_per_worker_dataset(
+        dataset_fn)

    model_input = keras.layers.Input(
        shape=(3,), dtype=dtypes.int64, name="model_input")
@@ -163,12 +164,12 @@ class KPLTest(test.TestCase):
        actual_pred = math_ops.cast(math_ops.greater(pred, 0.5), dtypes.int64)
        accuracy.update_state(labels, actual_pred)

-      self.client._strategy.run(train_step, args=(iterator,))
+      self.coordinator._strategy.run(train_step, args=(iterator,))

    distributed_iterator = iter(distributed_dataset)
    for _ in range(10):
-      self.client.schedule(worker_fn, args=(distributed_iterator,))
+      self.coordinator.schedule(worker_fn, args=(distributed_iterator,))
-    self.client.join()
+    self.coordinator.join()
    self.assertGreater(accuracy.result().numpy(), 0.0)

    # Create a saved model.
@@ -362,7 +362,8 @@ class Model(training_lib.Model):
    if isinstance(self._distribution_strategy,
                  (parameter_server_strategy.ParameterServerStrategyV1,
                   parameter_server_strategy.ParameterServerStrategy)):
-      raise NotImplementedError('ParameterServerStrategy currently only works '
+      raise NotImplementedError('`tf.compat.v1.distribute.experimental.Paramet'
+                                'erServerStrategy` currently only works '
                                'with the tf.Estimator API')

    if not self._experimental_run_tf_function:
@@ -64,6 +64,9 @@ from tensorflow.python.framework.test_combinations import *
from tensorflow.python.util.tf_decorator import make_decorator
from tensorflow.python.util.tf_decorator import unwrap

+from tensorflow.python.distribute.parameter_server_strategy_v2 import *
+from tensorflow.python.distribute.coordinator.cluster_coordinator import *
+
tf_export('__internal__.decorator.make_decorator', v1=[])(make_decorator)
tf_export('__internal__.decorator.unwrap', v1=[])(unwrap)
@@ -31,6 +31,7 @@ TENSORFLOW_API_INIT_FILES = [
    "distribute/__init__.py",
    "distribute/cluster_resolver/__init__.py",
    "distribute/experimental/__init__.py",
+    "distribute/experimental/coordinator/__init__.py",
    "dtypes/__init__.py",
    "errors/__init__.py",
    "experimental/__init__.py",
@@ -27,6 +27,8 @@ from tensorflow.lite.python import lite as _tflite_for_api_traversal
from tensorflow.python import modules_with_exports
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base
+from tensorflow.python.distribute import parameter_server_strategy_v2
+from tensorflow.python.distribute.coordinator import cluster_coordinator
from tensorflow.python.framework import combinations
from tensorflow.python.framework import test_combinations
# pylint: enable=unused-import
@@ -1,6 +1,6 @@
path: "tensorflow.distribute.experimental.ParameterServerStrategy"
tf_class {
-  is_instance: "<class \'tensorflow.python.distribute.parameter_server_strategy.ParameterServerStrategy\'>"
+  is_instance: "<class \'tensorflow.python.distribute.parameter_server_strategy_v2.ParameterServerStrategyV2\'>"
  is_instance: "<class \'tensorflow.python.distribute.distribute_lib.Strategy\'>"
  is_instance: "<class \'tensorflow.python.distribute.distribute_lib.StrategyBase\'>"
  is_instance: "<type \'object\'>"
@@ -18,7 +18,7 @@ tf_class {
  }
  member_method {
    name: "__init__"
-    argspec: "args=[\'self\', \'cluster_resolver\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'self\', \'cluster_resolver\', \'variable_partitioner\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "colocate_vars_with"
@@ -0,0 +1,33 @@
+path: "tensorflow.distribute.experimental.coordinator.ClusterCoordinator"
+tf_class {
+  is_instance: "<class \'tensorflow.python.distribute.coordinator.cluster_coordinator.ClusterCoordinator\'>"
+  is_instance: "<type \'object\'>"
+  member {
+    name: "strategy"
+    mtype: "<type \'property\'>"
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'strategy\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "create_per_worker_dataset"
+    argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "done"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "fetch"
+    argspec: "args=[\'self\', \'val\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "join"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "schedule"
+    argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+  }
+}
@@ -0,0 +1,9 @@
+path: "tensorflow.distribute.experimental.coordinator.PerWorkerValues"
+tf_class {
+  is_instance: "<class \'tensorflow.python.distribute.coordinator.cluster_coordinator.PerWorkerValues\'>"
+  is_instance: "<type \'object\'>"
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'values\'], varargs=None, keywords=None, defaults=None"
+  }
+}
@@ -0,0 +1,12 @@
+path: "tensorflow.distribute.experimental.coordinator.RemoteValue"
+tf_class {
+  is_instance: "<class \'tensorflow.python.distribute.coordinator.cluster_coordinator.RemoteValue\'>"
+  is_instance: "<type \'object\'>"
+  member_method {
+    name: "__init__"
+  }
+  member_method {
+    name: "fetch"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+}
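Taken together, these argspecs imply the basic scheduling round-trip; a minimal sketch, assuming `strategy` is a `ParameterServerStrategy` already connected to a cluster and with a placeholder function body:

```python
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
    strategy)

@tf.function
def worker_fn():
  return tf.constant(3.0)  # Placeholder computation executed on a worker.

remote_value = coordinator.schedule(worker_fn)  # Returns a RemoteValue.
coordinator.join()           # Wait until every scheduled function has run.
print(remote_value.fetch())  # Fetches the concrete result: 3.0.
```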
@@ -0,0 +1,15 @@
+path: "tensorflow.distribute.experimental.coordinator"
+tf_module {
+  member {
+    name: "ClusterCoordinator"
+    mtype: "<type \'type\'>"
+  }
+  member {
+    name: "PerWorkerValues"
+    mtype: "<type \'type\'>"
+  }
+  member {
+    name: "RemoteValue"
+    mtype: "<type \'type\'>"
+  }
+}
@@ -36,4 +36,8 @@ tf_module {
    name: "ValueContext"
    mtype: "<type \'type\'>"
  }
+  member {
+    name: "coordinator"
+    mtype: "<type \'module\'>"
+  }
}
@@ -154,9 +154,9 @@ COMMON_PIP_DEPS = [
    "//tensorflow/tools/common:test_module1",
    "//tensorflow/tools/common:traverse",
    "//tensorflow/python/distribute:parameter_server_strategy_v2",
-    "//tensorflow/python/distribute/client:client",
-    "//tensorflow/python/distribute/client:remote_eager_lib",
-    "//tensorflow/python/distribute/client:metric_utils",
+    "//tensorflow/python/distribute/coordinator:cluster_coordinator",
+    "//tensorflow/python/distribute/coordinator:remote_eager_lib",
+    "//tensorflow/python/distribute/coordinator:metric_utils",
]

# On Windows, python binary is a zip file of runfiles tree.
|
Loading…
x
Reference in New Issue
Block a user