PSv2: Add checks that `ParameterServerStrategy`'s `run`, `reduce`, `experimental_distribute_dataset`, and `distribute_datasets_from_function` are used with a `ClusterCoordinator`, and that `run` and `reduce` are used within a function passed to `ClusterCoordinator.schedule`.
PiperOrigin-RevId: 344297761 Change-Id: Ia7c101ed04e2db24564e83b47db7762b096ec267
commit 287d8116bf (parent cc9e5d8075)
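For context, a minimal sketch of the usage pattern these checks enforce, assuming a valid TF_CONFIG-based cluster; the resolver setup, `worker_fn`, and dataset names below are illustrative and not part of the commit:

import tensorflow as tf

# Assumed cluster setup; in a real job the resolver is built from TF_CONFIG.
cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()

strategy = tf.distribute.experimental.ParameterServerStrategy(cluster_resolver)
# Pairing the strategy with a coordinator is what the new checks look for;
# without it, the dataset-distribution methods raise NotImplementedError.
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(strategy)

with strategy.scope():
  v = tf.Variable(1, dtype=tf.int64)

@tf.function
def worker_fn(iterator):
  def replica_fn(x):
    return x + v
  # `strategy.run` (and `strategy.reduce`) are only valid inside a function
  # passed to `coordinator.schedule`; calling them eagerly now raises
  # NotImplementedError.
  return strategy.run(replica_fn, args=(next(iterator),))

per_worker_dataset = coordinator.create_per_worker_dataset(
    lambda: tf.data.Dataset.range(10))
result = coordinator.schedule(worker_fn, args=(iter(per_worker_dataset),))
coordinator.join()
print(result.fetch())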
@@ -870,6 +870,7 @@ class Cluster(object):
     self.workers = [
         Worker(i, w, self) for i, w in enumerate(worker_device_strings)
     ]
+    self._strategy = strategy
 
   def _record_and_ignore_transient_ps_failure(self, e):
     """Records potential PS failures and return if failure should be ignored."""
@@ -899,11 +900,13 @@ class Cluster(object):
     Returns:
       A `RemoteValue` object.
     """
+    self._strategy.extended._being_scheduled = True  # pylint: disable=protected-access
     closure = Closure(
         function,
         self._closure_queue._cancellation_mgr,  # pylint: disable=protected-access
         args=args,
         kwargs=kwargs)
+    self._strategy.extended._being_scheduled = False  # pylint: disable=protected-access
     self._closure_queue.put(closure)
     return closure.output_remote_value
 
@@ -990,6 +993,7 @@ class ClusterCoordinator(object):
           "`tf.distribute.experimental.coordinator.ClusterCoordinator` "
           "currently.")
     self._strategy = strategy
+    self._strategy.extended._used_with_coordinator = True
     self.cluster = Cluster(strategy)
 
   @property
@@ -17,8 +17,8 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
 import collections
+import contextlib
 import functools
 import os
 import platform
@@ -678,6 +678,47 @@ class ClusterCoordinatorTest(TestCaseWithErrorReportingThread):
         'error message is Failed copying input tensor from'):
       self.coordinator.join()
 
+  def testRunNotUsedWithClusterCoordinatorSchedule(self):
+
+    @def_function.function
+    def input_fn():
+      return dataset_ops.DatasetV2.range(1, 10)
+
+    with self.strategy.scope():
+      v = variables.Variable(initial_value=1, dtype=dtypes.int64)
+
+    def replica_fn(input_tensor):
+      return input_tensor + v, input_tensor - v
+
+    @def_function.function
+    def worker_fn(iterator):
+      return self.strategy.run(replica_fn, args=(next(iterator),))
+
+    per_worker_dataset = self.coordinator.create_per_worker_dataset(input_fn)
+
+    @contextlib.contextmanager
+    def _assert_raises_usage_error():
+      with self.assertRaisesRegexp(
+          NotImplementedError,
+          "`tf.distribute.experimental.ParameterServerStrategy`'s `run` or "
+          '`reduce` must be used within a function passed to '
+          '`tf.distribute.experimental.coordinator.ClusterCoordinator.schedule`'
+          '.'):
+        yield
+
+    with _assert_raises_usage_error():
+      # Invoking `run` without `coordinator.schedule` should error.
+      self.strategy.run(replica_fn, args=(next(iter(input_fn())),))
+
+    # A proper `schedule` should succeed.
+    rv = self.coordinator.schedule(worker_fn, args=(iter(per_worker_dataset),))
+
+    with _assert_raises_usage_error():
+      # Invoking `run` without `coordinator.schedule` again should error.
+      self.strategy.run(replica_fn, args=(next(iter(input_fn())),))
+
+    self.assertEqual((2, 0), rv.fetch())
+
+
 class LimitedClosureQueueSizeBasicTest(ClusterCoordinatorTest):
   """Test basic functionality works with explicit maximum closure queue size.
@@ -517,6 +517,11 @@ class ParameterServerStrategyV2Extended(
     self._variable_count = 0
     self._variable_partitioner = variable_partitioner
 
+    # The following two attrs are to verify that `ParameterServerStrategy`
+    # methods are properly used with a `ClusterCoordinator`.
+    self._used_with_coordinator = False
+    self._being_scheduled = False
+
   def _create_variable(self, next_creator, **kwargs):
     """Implements StrategyExtendedV2._create_variable.
 
@@ -670,7 +675,22 @@ class ParameterServerStrategyV2Extended(
     self._variable_count += 1
     return var
 
+  def _assert_used_with_cluster_coordinator(self):
+    if not self._used_with_coordinator:
+      raise NotImplementedError(
+          "`tf.distribute.experimental.ParameterServerStrategy` must be used "
+          "with `tf.distribute.experimental.coordinator.ClusterCoordinator`.")
+
+  def _assert_being_scheduled_by_cluster_coordinator(self):
+    if not self._being_scheduled:
+      raise NotImplementedError(
+          "`tf.distribute.experimental.ParameterServerStrategy`'s `run` or "
+          "`reduce` must be used within a function passed to `"
+          "tf.distribute.experimental.coordinator.ClusterCoordinator.schedule"
+          "`.")
+
   def _experimental_distribute_dataset(self, dataset, options):
+    self._assert_used_with_cluster_coordinator()
     if not ops.get_default_graph().building_function:
       raise ValueError(
           "The `experimental_distribute_dataset` method must be called inside "
@@ -679,6 +699,7 @@ class ParameterServerStrategyV2Extended(
     return dataset
 
   def _distribute_datasets_from_function(self, dataset_fn, options):
+    self._assert_used_with_cluster_coordinator()
     if not ops.get_default_graph().building_function:
       raise ValueError(
           "The `distribute_datasets_from_function` method must be called "
@@ -687,6 +708,7 @@ class ParameterServerStrategyV2Extended(
     return dataset_fn(distribute_lib.InputContext())
 
   def _call_for_each_replica(self, fn, args, kwargs):
+    self._assert_being_scheduled_by_cluster_coordinator()
     with distribute_lib.ReplicaContext(
         self._container_strategy(),
         replica_id_in_sync_group=constant_op.constant(0, dtypes.int32)):
@@ -694,6 +716,7 @@ class ParameterServerStrategyV2Extended(
       return distribute_utils.regroup((fn(*args, **kwargs),))
 
   def _reduce(self, reduce_op, value):
+    self._assert_being_scheduled_by_cluster_coordinator()
     # TODO(rchao): Provide implementation for multi-replica. Also look into why
     # the default implementation is not working.
     return value
@@ -19,11 +19,12 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import contextlib
 import functools
 import os
 
 from absl.testing import parameterized
+from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.distribute import multi_worker_test_base
 from tensorflow.python.distribute import parameter_server_strategy_v2
 from tensorflow.python.distribute import sharded_variable
@@ -73,6 +74,61 @@ class ParameterServerStrategyV2Test(test.TestCase):
     self.assertEqual(v4.device, "/job:ps/replica:0/task:2/device:CPU:0")
     self.assertEqual(v5.device, "/job:ps/replica:0/task:0/device:CPU:0")
 
+  @contextlib.contextmanager
+  def _assertRaisesUsageError(self):
+    with self.assertRaisesRegexp(
+        NotImplementedError,
+        "`tf.distribute.experimental.ParameterServerStrategy` must be used "
+        "with `tf.distribute.experimental.coordinator.ClusterCoordinator`."):
+      yield
+
+  @contextlib.contextmanager
+  def _assertRaisesUsageErrorWithSchedule(self):
+    with self.assertRaisesRegexp(
+        NotImplementedError,
+        "`tf.distribute.experimental.ParameterServerStrategy`'s `run` or "
+        "`reduce` must be used within a function passed to `"
+        "tf.distribute.experimental.coordinator.ClusterCoordinator.schedule`."):
+      yield
+
+  def testRunNotUsedWithClusterCoordinator(self):
+    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
+        self.cluster_resolver)
+    dataset = dataset_ops.DatasetV2.range(3)
+    with strategy.scope():
+      v = variables.Variable(1, dtype=dtypes.int64)
+
+    def step_fn(iterator):
+      return next(iterator) + v
+
+    with self._assertRaisesUsageErrorWithSchedule():
+      strategy.run(step_fn, args=(iter(dataset),))
+
+  def testReduceNotUsedWithClusterCoordinator(self):
+    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
+        self.cluster_resolver)
+    with self._assertRaisesUsageErrorWithSchedule():
+      strategy.reduce("SUM", None, axis=None)
+
+  def testDistributeDatasetNotUsedWithClusterCoordinator(self):
+    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
+        self.cluster_resolver)
+    dataset = dataset_ops.DatasetV2.range(3)
+    with self._assertRaisesUsageError():
+      def_function.function(
+          lambda: strategy.experimental_distribute_dataset(dataset))()
+
+  def testDistributeDatasetFromFunctionNotUsedWithClusterCoordinator(self):
+    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
+        self.cluster_resolver)
+
+    def dataset_fn(_):
+      return dataset_ops.DatasetV2.range(3)
+
+    with self._assertRaisesUsageError():
+      def_function.function(
+          lambda: strategy.distribute_datasets_from_function(dataset_fn))()
+
+
 class PartitionAwareIdentity(object):