PSv2: Add checks that ParameterServerStrategy's run, reduce, experimental_distribute_dataset, and distribute_datasets_from_function are used with a ClusterCoordinator, and that run and reduce are called within a function passed to schedule.

PiperOrigin-RevId: 344297761
Change-Id: Ia7c101ed04e2db24564e83b47db7762b096ec267
Author:       Rick Chao
Committed by: TensorFlower Gardener
Date:         2020-11-25 12:24:57 -08:00
Commit:       287d8116bf (parent cc9e5d8075)

4 changed files with 126 additions and 2 deletions
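For context, the usage pattern these checks enforce looks roughly like the following. This is a minimal sketch, not part of the commit; it assumes a TF 2.4-era API and that `cluster_resolver` describes an already-running parameter-server/worker cluster.

import tensorflow as tf

# Illustrative resolver; any cluster resolver describing the ps/worker jobs works.
cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()

strategy = tf.distribute.experimental.ParameterServerStrategy(cluster_resolver)
# Pairing the strategy with a ClusterCoordinator is what sets
# `_used_with_coordinator` in the changes below.
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(strategy)

with strategy.scope():
  v = tf.Variable(initial_value=1, dtype=tf.int64)

def dataset_fn():
  return tf.data.Dataset.range(1, 10)

per_worker_dataset = coordinator.create_per_worker_dataset(dataset_fn)

@tf.function
def worker_fn(iterator):

  def replica_fn(x):
    return x + v

  # `strategy.run` (and `strategy.reduce`) are only valid inside a function
  # passed to `coordinator.schedule`, which is what `_being_scheduled` tracks.
  return strategy.run(replica_fn, args=(next(iterator),))

result = coordinator.schedule(worker_fn, args=(iter(per_worker_dataset),))
coordinator.join()
print(result.fetch())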


@@ -870,6 +870,7 @@ class Cluster(object):
     self.workers = [
         Worker(i, w, self) for i, w in enumerate(worker_device_strings)
     ]
+    self._strategy = strategy
 
   def _record_and_ignore_transient_ps_failure(self, e):
     """Records potential PS failures and return if failure should be ignored."""
@@ -899,11 +900,13 @@ class Cluster(object):
     Returns:
       A `RemoteValue` object.
     """
+    self._strategy.extended._being_scheduled = True  # pylint: disable=protected-access
     closure = Closure(
         function,
         self._closure_queue._cancellation_mgr,  # pylint: disable=protected-access
         args=args,
         kwargs=kwargs)
+    self._strategy.extended._being_scheduled = False  # pylint: disable=protected-access
     self._closure_queue.put(closure)
     return closure.output_remote_value
@@ -990,6 +993,7 @@ class ClusterCoordinator(object):
           "`tf.distribute.experimental.coordinator.ClusterCoordinator` "
           "currently.")
     self._strategy = strategy
+    self._strategy.extended._used_with_coordinator = True
     self.cluster = Cluster(strategy)
 
   @property


@@ -17,8 +17,8 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
 import collections
+import contextlib
 import functools
 import os
 import platform
@@ -678,6 +678,47 @@ class ClusterCoordinatorTest(TestCaseWithErrorReportingThread):
         'error message is Failed copying input tensor from'):
       self.coordinator.join()
 
+  def testRunNotUsedWithClusterCoordinatorSchedule(self):
+
+    @def_function.function
+    def input_fn():
+      return dataset_ops.DatasetV2.range(1, 10)
+
+    with self.strategy.scope():
+      v = variables.Variable(initial_value=1, dtype=dtypes.int64)
+
+    def replica_fn(input_tensor):
+      return input_tensor + v, input_tensor - v
+
+    @def_function.function
+    def worker_fn(iterator):
+      return self.strategy.run(replica_fn, args=(next(iterator),))
+
+    per_worker_dataset = self.coordinator.create_per_worker_dataset(input_fn)
+
+    @contextlib.contextmanager
+    def _assert_raises_usage_error():
+      with self.assertRaisesRegexp(
+          NotImplementedError,
+          "`tf.distribute.experimental.ParameterServerStrategy`'s `run` or "
+          '`reduce` must be used within a function passed to '
+          '`tf.distribute.experimental.coordinator.ClusterCoordinator.schedule`'
+          '.'):
+        yield
+
+    with _assert_raises_usage_error():
+      # Invoking `run` without `coordinator.schedule` should error.
+      self.strategy.run(replica_fn, args=(next(iter(input_fn())),))
+
+    # A proper `schedule` should succeed.
+    rv = self.coordinator.schedule(worker_fn, args=(iter(per_worker_dataset),))
+
+    with _assert_raises_usage_error():
+      # Invoking `run` without `coordinator.schedule` again should error.
+      self.strategy.run(replica_fn, args=(next(iter(input_fn())),))
+
+    self.assertEqual((2, 0), rv.fetch())
+
+
 class LimitedClosureQueueSizeBasicTest(ClusterCoordinatorTest):
   """Test basic functionality works with explicit maximum closure queue size.


@@ -517,6 +517,11 @@ class ParameterServerStrategyV2Extended(
     self._variable_count = 0
     self._variable_partitioner = variable_partitioner
+
+    # The following two attrs are to verify that `ParameterServerStrategy`
+    # methods are properly used with a `ClusterCoordinator`.
+    self._used_with_coordinator = False
+    self._being_scheduled = False
 
   def _create_variable(self, next_creator, **kwargs):
     """Implements StrategyExtendedV2._create_variable.
@@ -670,7 +675,22 @@ class ParameterServerStrategyV2Extended(
     self._variable_count += 1
     return var
 
+  def _assert_used_with_cluster_coordinator(self):
+    if not self._used_with_coordinator:
+      raise NotImplementedError(
+          "`tf.distribute.experimental.ParameterServerStrategy` must be used "
+          "with `tf.distribute.experimental.coordinator.ClusterCoordinator`.")
+
+  def _assert_being_scheduled_by_cluster_coordinator(self):
+    if not self._being_scheduled:
+      raise NotImplementedError(
+          "`tf.distribute.experimental.ParameterServerStrategy`'s `run` or "
+          "`reduce` must be used within a function passed to `"
+          "tf.distribute.experimental.coordinator.ClusterCoordinator.schedule"
+          "`.")
+
   def _experimental_distribute_dataset(self, dataset, options):
+    self._assert_used_with_cluster_coordinator()
     if not ops.get_default_graph().building_function:
       raise ValueError(
           "The `experimental_distribute_dataset` method must be called inside "
@@ -679,6 +699,7 @@ class ParameterServerStrategyV2Extended(
     return dataset
 
   def _distribute_datasets_from_function(self, dataset_fn, options):
+    self._assert_used_with_cluster_coordinator()
     if not ops.get_default_graph().building_function:
       raise ValueError(
           "The `distribute_datasets_from_function` method must be called "
@@ -687,6 +708,7 @@ class ParameterServerStrategyV2Extended(
     return dataset_fn(distribute_lib.InputContext())
 
   def _call_for_each_replica(self, fn, args, kwargs):
+    self._assert_being_scheduled_by_cluster_coordinator()
     with distribute_lib.ReplicaContext(
         self._container_strategy(),
         replica_id_in_sync_group=constant_op.constant(0, dtypes.int32)):
@@ -694,6 +716,7 @@ class ParameterServerStrategyV2Extended(
       return distribute_utils.regroup((fn(*args, **kwargs),))
 
   def _reduce(self, reduce_op, value):
+    self._assert_being_scheduled_by_cluster_coordinator()
     # TODO(rchao): Provide implementation for multi-replica. Also look into why
     # the default implementation is not working.
     return value
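With the assertions above in place, misuse surfaces as a NotImplementedError. A rough sketch of the two failure modes follows (same assumptions and illustrative `cluster_resolver` as the earlier example; the printed messages are the ones defined in this diff):

import tensorflow as tf

cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()  # illustrative
strategy = tf.distribute.experimental.ParameterServerStrategy(cluster_resolver)

# 1) No ClusterCoordinator was created: distributing a dataset fails.
@tf.function
def distribute():
  return strategy.experimental_distribute_dataset(tf.data.Dataset.range(3))

try:
  distribute()
except NotImplementedError as e:
  print(e)  # "... must be used with ...coordinator.ClusterCoordinator."

# 2) Even with a coordinator, `run`/`reduce` must be called from a function
#    passed to `coordinator.schedule`.
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(strategy)
try:
  strategy.run(lambda x: x + 1, args=(tf.constant(1),))
except NotImplementedError as e:
  print(e)  # "... `run` or `reduce` must be used within a function passed to ...schedule`."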


@@ -19,11 +19,12 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import contextlib
 import functools
 import os
 
 from absl.testing import parameterized
 
+from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.distribute import multi_worker_test_base
 from tensorflow.python.distribute import parameter_server_strategy_v2
 from tensorflow.python.distribute import sharded_variable
@@ -73,6 +74,61 @@ class ParameterServerStrategyV2Test(test.TestCase):
     self.assertEqual(v4.device, "/job:ps/replica:0/task:2/device:CPU:0")
     self.assertEqual(v5.device, "/job:ps/replica:0/task:0/device:CPU:0")
 
+  @contextlib.contextmanager
+  def _assertRaisesUsageError(self):
+    with self.assertRaisesRegexp(
+        NotImplementedError,
+        "`tf.distribute.experimental.ParameterServerStrategy` must be used "
+        "with `tf.distribute.experimental.coordinator.ClusterCoordinator`."):
+      yield
+
+  @contextlib.contextmanager
+  def _assertRaisesUsageErrorWithSchedule(self):
+    with self.assertRaisesRegexp(
+        NotImplementedError,
+        "`tf.distribute.experimental.ParameterServerStrategy`'s `run` or "
+        "`reduce` must be used within a function passed to `"
+        "tf.distribute.experimental.coordinator.ClusterCoordinator.schedule`."):
+      yield
+
+  def testRunNotUsedWithClusterCoordinator(self):
+    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
+        self.cluster_resolver)
+    dataset = dataset_ops.DatasetV2.range(3)
+    with strategy.scope():
+      v = variables.Variable(1, dtype=dtypes.int64)
+
+    def step_fn(iterator):
+      return next(iterator) + v
+
+    with self._assertRaisesUsageErrorWithSchedule():
+      strategy.run(step_fn, args=(iter(dataset),))
+
+  def testReduceNotUsedWithClusterCoordinator(self):
+    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
+        self.cluster_resolver)
+    with self._assertRaisesUsageErrorWithSchedule():
+      strategy.reduce("SUM", None, axis=None)
+
+  def testDistributeDatasetNotUsedWithClusterCoordinator(self):
+    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
+        self.cluster_resolver)
+    dataset = dataset_ops.DatasetV2.range(3)
+    with self._assertRaisesUsageError():
+      def_function.function(
+          lambda: strategy.experimental_distribute_dataset(dataset))()
+
+  def testDistributeDatasetFromFunctionNotUsedWithClusterCoordinator(self):
+    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
+        self.cluster_resolver)
+
+    def dataset_fn(_):
+      return dataset_ops.DatasetV2.range(3)
+
+    with self._assertRaisesUsageError():
+      def_function.function(
+          lambda: strategy.distribute_datasets_from_function(dataset_fn))()
+
+
 class PartitionAwareIdentity(object):