PSv2: Deduplicate the legacy ParameterServerStrategy class (the Estimator usage of it already goes through ParameterServerStrategyV1).

PiperOrigin-RevId: 338310081
Change-Id: Icff445e322b22ee4ac7f3e69327c7969444eeb93
Author: Rick Chao, 2020-10-21 11:56:40 -07:00 (committed by TensorFlower Gardener)
Parent: 43c9b64f53
Commit: dbf191bb17
4 changed files with 16 additions and 35 deletions
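
After this change, the `tf.compat.v1.distribute.experimental.ParameterServerStrategy` symbol is served by the single remaining `ParameterServerStrategyV1` class. A minimal usage sketch under graph mode (the cluster layout and addresses below are illustrative placeholders, not part of this commit):

    import json
    import os

    import tensorflow.compat.v1 as tf

    # Placeholder two-worker, one-ps cluster, read by TFConfigClusterResolver.
    os.environ["TF_CONFIG"] = json.dumps({
        "cluster": {
            "worker": ["localhost:12345", "localhost:12346"],
            "ps": ["localhost:12347"],
        },
        "task": {"type": "worker", "index": 0},
    })

    tf.disable_eager_execution()  # The strategy only supports graph mode.
    # Resolves to ParameterServerStrategyV1 after this commit.
    strategy = tf.distribute.experimental.ParameterServerStrategy()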

--- a/tensorflow/python/distribute/parameter_server_strategy.py
+++ b/tensorflow/python/distribute/parameter_server_strategy.py

@@ -47,9 +47,8 @@ from tensorflow.python.util.tf_export import tf_export
 _LOCAL_CPU = "/device:CPU:0"
 
 
-# TODO(yuefengz): maybe cache variables on local CPU.
-# TODO(b/171250971): Remove this and change all symbol usage of this to V1.
-class ParameterServerStrategy(distribute_lib.Strategy):
+@tf_export(v1=["distribute.experimental.ParameterServerStrategy"])  # pylint: disable=missing-docstring
+class ParameterServerStrategyV1(distribute_lib.StrategyV1):
   """An asynchronous multi-worker parameter server tf.distribute strategy.
 
   This strategy requires two roles: workers and parameter servers. Variables and
@@ -112,11 +111,11 @@ class ParameterServerStrategy(distribute_lib.Strategy):
     """
     if cluster_resolver is None:
       cluster_resolver = TFConfigClusterResolver()
-    if not cluster_resolver.cluster_spec():
-      raise ValueError("Cluster spec must be non-empty in `cluster_resolver`.")
-    extended = ParameterServerStrategyExtended(
-        self, cluster_resolver=cluster_resolver)
-    super(ParameterServerStrategy, self).__init__(extended)
+    super(ParameterServerStrategyV1, self).__init__(
+        ParameterServerStrategyExtended(
+            self, cluster_resolver=cluster_resolver))
+    distribute_lib.distribution_strategy_gauge.get_cell("V1").set(
+        "ParameterServerStrategy")
 
   def experimental_distribute_dataset(self, dataset, options=None):
     if (options and options.experimental_replication_mode ==
@@ -127,7 +126,7 @@ class ParameterServerStrategy(distribute_lib.Strategy):
           "`experimental_distribute_datasets_from_function`."
       )
     self._raise_pss_error_if_eager()
-    super(ParameterServerStrategy,
+    super(ParameterServerStrategyV1,
           self).experimental_distribute_dataset(dataset=dataset,
                                                 options=options)
 
@@ -140,17 +139,17 @@ class ParameterServerStrategy(distribute_lib.Strategy):
           "`experimental_distribute_datasets_from_function` "
           "of tf.distribute.MirroredStrategy")
     self._raise_pss_error_if_eager()
-    super(ParameterServerStrategy, self).distribute_datasets_from_function(
+    super(ParameterServerStrategyV1, self).distribute_datasets_from_function(
         dataset_fn=dataset_fn, options=options)
 
   def run(self, fn, args=(), kwargs=None, options=None):
     self._raise_pss_error_if_eager()
-    super(ParameterServerStrategy, self).run(
+    super(ParameterServerStrategyV1, self).run(
         fn, args=args, kwargs=kwargs, options=options)
 
   def scope(self):
     self._raise_pss_error_if_eager()
-    return super(ParameterServerStrategy, self).scope()
+    return super(ParameterServerStrategyV1, self).scope()
 
   def _raise_pss_error_if_eager(self):
     if context.executing_eagerly():
@@ -159,22 +158,6 @@ class ParameterServerStrategy(distribute_lib.Strategy):
           "currently only works with the tf.Estimator API")
 
 
-@tf_export(v1=["distribute.experimental.ParameterServerStrategy"])  # pylint: disable=missing-docstring
-class ParameterServerStrategyV1(distribute_lib.StrategyV1):
-
-  __doc__ = ParameterServerStrategy.__doc__
-
-  def __init__(self, cluster_resolver=None):
-    """Initializes this strategy."""
-    super(ParameterServerStrategyV1, self).__init__(
-        ParameterServerStrategyExtended(
-            self, cluster_resolver=cluster_resolver))
-    distribute_lib.distribution_strategy_gauge.get_cell("V1").set(
-        "ParameterServerStrategy")
-
-  __init__.__doc__ = ParameterServerStrategy.__init__.__doc__
-
-
 # TODO(josh11b): Switch to V2 when we no longer need to support tf.compat.v1.
 class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
   """Implementation of ParameterServerStrategy and CentralStorageStrategy."""

--- a/tensorflow/python/distribute/parameter_server_strategy_test.py
+++ b/tensorflow/python/distribute/parameter_server_strategy_test.py

@@ -84,7 +84,7 @@ def create_test_objects(cluster_spec=None,
         task_type=task_type,
         task_id=task_id,
         num_accelerators={'GPU': num_gpus})
-    distribution = parameter_server_strategy.ParameterServerStrategy(
+    distribution = parameter_server_strategy.ParameterServerStrategyV1(
         cluster_resolver)
     target = 'grpc://' + cluster_spec[WORKER][task_id]
   else:
@@ -748,7 +748,7 @@ class ParameterServerStrategyTest(
         task_type='worker',
         task_id=1,
         num_accelerators={'GPU': 0})
-    strategy = parameter_server_strategy.ParameterServerStrategy(
+    strategy = parameter_server_strategy.ParameterServerStrategyV1(
         cluster_resolver)
     dataset = dataset_ops.DatasetV2.from_tensor_slices([5., 6., 7., 8.])
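
The tests build the renamed class directly from an explicit resolver. For reference, a self-contained sketch of the same construction pattern using the public resolver API (the cluster addresses are placeholders):

    import tensorflow as tf
    from tensorflow.python.distribute import parameter_server_strategy

    cluster_spec = tf.train.ClusterSpec({
        'worker': ['localhost:12345', 'localhost:23456'],
        'ps': ['localhost:34567'],
    })
    cluster_resolver = tf.distribute.cluster_resolver.SimpleClusterResolver(
        cluster_spec, task_type='worker', task_id=1,
        num_accelerators={'GPU': 0})
    # The internal module import mirrors the tests; the public v1 symbol
    # resolves to the same class.
    strategy = parameter_server_strategy.ParameterServerStrategyV1(
        cluster_resolver)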

--- a/tensorflow/python/keras/distribute/distribute_strategy_test.py
+++ b/tensorflow/python/keras/distribute/distribute_strategy_test.py

@@ -2449,12 +2449,11 @@ class TestDistributionStrategyWithKerasModels(test.TestCase,
         task_type='worker',
         task_id=1,
         num_accelerators={'GPU': 0})
-    distribution = parameter_server_strategy.ParameterServerStrategy(
+    distribution = parameter_server_strategy.ParameterServerStrategyV1(
         cluster_resolver)
     self.assertIsInstance(distribution,
-                          (parameter_server_strategy.ParameterServerStrategyV1,
-                           parameter_server_strategy.ParameterServerStrategy))
+                          parameter_server_strategy.ParameterServerStrategyV1)
     with self.assertRaisesRegex(NotImplementedError,
                                 'ParameterServerStrategy*'):

--- a/tensorflow/python/keras/engine/training_v1.py
+++ b/tensorflow/python/keras/engine/training_v1.py

@@ -360,8 +360,7 @@ class Model(training_lib.Model):
         distribution_strategy_context.get_strategy())
     if isinstance(self._distribution_strategy,
-                  (parameter_server_strategy.ParameterServerStrategyV1,
-                   parameter_server_strategy.ParameterServerStrategy)):
+                  parameter_server_strategy.ParameterServerStrategyV1):
       raise NotImplementedError('`tf.compat.v1.distribute.experimental.Paramet'
                                 'erServerStrategy` currently only works '
                                 'with the tf.Estimator API')
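
With the duplicate class gone, call sites that previously had to match either class can use a single `isinstance` check. A hypothetical guard in the same spirit as the Keras one above (the helper name is illustrative, not from this commit):

    from tensorflow.python.distribute import parameter_server_strategy

    def _raise_if_ps_strategy(strategy):
      # Only one legacy parameter server strategy class remains to match.
      if isinstance(strategy,
                    parameter_server_strategy.ParameterServerStrategyV1):
        raise NotImplementedError(
            '`tf.compat.v1.distribute.experimental.ParameterServerStrategy` '
            'currently only works with the tf.Estimator API')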