PSv2: Dedup the legacy ParameterServerStrategy class (as the estimator usage of it uses ParameterServerStrategyV1).
PiperOrigin-RevId: 338310081 Change-Id: Icff445e322b22ee4ac7f3e69327c7969444eeb93
This commit is contained in:
parent
43c9b64f53
commit
dbf191bb17
@ -47,9 +47,8 @@ from tensorflow.python.util.tf_export import tf_export
|
|||||||
_LOCAL_CPU = "/device:CPU:0"
|
_LOCAL_CPU = "/device:CPU:0"
|
||||||
|
|
||||||
|
|
||||||
# TODO(yuefengz): maybe cache variables on local CPU.
|
@tf_export(v1=["distribute.experimental.ParameterServerStrategy"]) # pylint: disable=missing-docstring
|
||||||
# TODO(b/171250971): Remove this and change all symbol usage of this to V1.
|
class ParameterServerStrategyV1(distribute_lib.StrategyV1):
|
||||||
class ParameterServerStrategy(distribute_lib.Strategy):
|
|
||||||
"""An asynchronous multi-worker parameter server tf.distribute strategy.
|
"""An asynchronous multi-worker parameter server tf.distribute strategy.
|
||||||
|
|
||||||
This strategy requires two roles: workers and parameter servers. Variables and
|
This strategy requires two roles: workers and parameter servers. Variables and
|
||||||
@ -112,11 +111,11 @@ class ParameterServerStrategy(distribute_lib.Strategy):
|
|||||||
"""
|
"""
|
||||||
if cluster_resolver is None:
|
if cluster_resolver is None:
|
||||||
cluster_resolver = TFConfigClusterResolver()
|
cluster_resolver = TFConfigClusterResolver()
|
||||||
if not cluster_resolver.cluster_spec():
|
super(ParameterServerStrategyV1, self).__init__(
|
||||||
raise ValueError("Cluster spec must be non-empty in `cluster_resolver`.")
|
ParameterServerStrategyExtended(
|
||||||
extended = ParameterServerStrategyExtended(
|
self, cluster_resolver=cluster_resolver))
|
||||||
self, cluster_resolver=cluster_resolver)
|
distribute_lib.distribution_strategy_gauge.get_cell("V1").set(
|
||||||
super(ParameterServerStrategy, self).__init__(extended)
|
"ParameterServerStrategy")
|
||||||
|
|
||||||
def experimental_distribute_dataset(self, dataset, options=None):
|
def experimental_distribute_dataset(self, dataset, options=None):
|
||||||
if (options and options.experimental_replication_mode ==
|
if (options and options.experimental_replication_mode ==
|
||||||
@ -127,7 +126,7 @@ class ParameterServerStrategy(distribute_lib.Strategy):
|
|||||||
"`experimental_distribute_datasets_from_function`."
|
"`experimental_distribute_datasets_from_function`."
|
||||||
)
|
)
|
||||||
self._raise_pss_error_if_eager()
|
self._raise_pss_error_if_eager()
|
||||||
super(ParameterServerStrategy,
|
super(ParameterServerStrategyV1,
|
||||||
self).experimental_distribute_dataset(dataset=dataset,
|
self).experimental_distribute_dataset(dataset=dataset,
|
||||||
options=options)
|
options=options)
|
||||||
|
|
||||||
@ -140,17 +139,17 @@ class ParameterServerStrategy(distribute_lib.Strategy):
|
|||||||
"`experimental_distribute_datasets_from_function` "
|
"`experimental_distribute_datasets_from_function` "
|
||||||
"of tf.distribute.MirroredStrategy")
|
"of tf.distribute.MirroredStrategy")
|
||||||
self._raise_pss_error_if_eager()
|
self._raise_pss_error_if_eager()
|
||||||
super(ParameterServerStrategy, self).distribute_datasets_from_function(
|
super(ParameterServerStrategyV1, self).distribute_datasets_from_function(
|
||||||
dataset_fn=dataset_fn, options=options)
|
dataset_fn=dataset_fn, options=options)
|
||||||
|
|
||||||
def run(self, fn, args=(), kwargs=None, options=None):
|
def run(self, fn, args=(), kwargs=None, options=None):
|
||||||
self._raise_pss_error_if_eager()
|
self._raise_pss_error_if_eager()
|
||||||
super(ParameterServerStrategy, self).run(
|
super(ParameterServerStrategyV1, self).run(
|
||||||
fn, args=args, kwargs=kwargs, options=options)
|
fn, args=args, kwargs=kwargs, options=options)
|
||||||
|
|
||||||
def scope(self):
|
def scope(self):
|
||||||
self._raise_pss_error_if_eager()
|
self._raise_pss_error_if_eager()
|
||||||
return super(ParameterServerStrategy, self).scope()
|
return super(ParameterServerStrategyV1, self).scope()
|
||||||
|
|
||||||
def _raise_pss_error_if_eager(self):
|
def _raise_pss_error_if_eager(self):
|
||||||
if context.executing_eagerly():
|
if context.executing_eagerly():
|
||||||
@ -159,22 +158,6 @@ class ParameterServerStrategy(distribute_lib.Strategy):
|
|||||||
"currently only works with the tf.Estimator API")
|
"currently only works with the tf.Estimator API")
|
||||||
|
|
||||||
|
|
||||||
@tf_export(v1=["distribute.experimental.ParameterServerStrategy"]) # pylint: disable=missing-docstring
|
|
||||||
class ParameterServerStrategyV1(distribute_lib.StrategyV1):
|
|
||||||
|
|
||||||
__doc__ = ParameterServerStrategy.__doc__
|
|
||||||
|
|
||||||
def __init__(self, cluster_resolver=None):
|
|
||||||
"""Initializes this strategy."""
|
|
||||||
super(ParameterServerStrategyV1, self).__init__(
|
|
||||||
ParameterServerStrategyExtended(
|
|
||||||
self, cluster_resolver=cluster_resolver))
|
|
||||||
distribute_lib.distribution_strategy_gauge.get_cell("V1").set(
|
|
||||||
"ParameterServerStrategy")
|
|
||||||
|
|
||||||
__init__.__doc__ = ParameterServerStrategy.__init__.__doc__
|
|
||||||
|
|
||||||
|
|
||||||
# TODO(josh11b): Switch to V2 when we no longer need to support tf.compat.v1.
|
# TODO(josh11b): Switch to V2 when we no longer need to support tf.compat.v1.
|
||||||
class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
|
class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
|
||||||
"""Implementation of ParameterServerStrategy and CentralStorageStrategy."""
|
"""Implementation of ParameterServerStrategy and CentralStorageStrategy."""
|
||||||
|
|||||||
@ -84,7 +84,7 @@ def create_test_objects(cluster_spec=None,
|
|||||||
task_type=task_type,
|
task_type=task_type,
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
num_accelerators={'GPU': num_gpus})
|
num_accelerators={'GPU': num_gpus})
|
||||||
distribution = parameter_server_strategy.ParameterServerStrategy(
|
distribution = parameter_server_strategy.ParameterServerStrategyV1(
|
||||||
cluster_resolver)
|
cluster_resolver)
|
||||||
target = 'grpc://' + cluster_spec[WORKER][task_id]
|
target = 'grpc://' + cluster_spec[WORKER][task_id]
|
||||||
else:
|
else:
|
||||||
@ -748,7 +748,7 @@ class ParameterServerStrategyTest(
|
|||||||
task_type='worker',
|
task_type='worker',
|
||||||
task_id=1,
|
task_id=1,
|
||||||
num_accelerators={'GPU': 0})
|
num_accelerators={'GPU': 0})
|
||||||
strategy = parameter_server_strategy.ParameterServerStrategy(
|
strategy = parameter_server_strategy.ParameterServerStrategyV1(
|
||||||
cluster_resolver)
|
cluster_resolver)
|
||||||
dataset = dataset_ops.DatasetV2.from_tensor_slices([5., 6., 7., 8.])
|
dataset = dataset_ops.DatasetV2.from_tensor_slices([5., 6., 7., 8.])
|
||||||
|
|
||||||
|
|||||||
@ -2449,12 +2449,11 @@ class TestDistributionStrategyWithKerasModels(test.TestCase,
|
|||||||
task_type='worker',
|
task_type='worker',
|
||||||
task_id=1,
|
task_id=1,
|
||||||
num_accelerators={'GPU': 0})
|
num_accelerators={'GPU': 0})
|
||||||
distribution = parameter_server_strategy.ParameterServerStrategy(
|
distribution = parameter_server_strategy.ParameterServerStrategyV1(
|
||||||
cluster_resolver)
|
cluster_resolver)
|
||||||
|
|
||||||
self.assertIsInstance(distribution,
|
self.assertIsInstance(distribution,
|
||||||
(parameter_server_strategy.ParameterServerStrategyV1,
|
parameter_server_strategy.ParameterServerStrategyV1)
|
||||||
parameter_server_strategy.ParameterServerStrategy))
|
|
||||||
|
|
||||||
with self.assertRaisesRegex(NotImplementedError,
|
with self.assertRaisesRegex(NotImplementedError,
|
||||||
'ParameterServerStrategy*'):
|
'ParameterServerStrategy*'):
|
||||||
|
|||||||
@ -360,8 +360,7 @@ class Model(training_lib.Model):
|
|||||||
distribution_strategy_context.get_strategy())
|
distribution_strategy_context.get_strategy())
|
||||||
|
|
||||||
if isinstance(self._distribution_strategy,
|
if isinstance(self._distribution_strategy,
|
||||||
(parameter_server_strategy.ParameterServerStrategyV1,
|
parameter_server_strategy.ParameterServerStrategyV1):
|
||||||
parameter_server_strategy.ParameterServerStrategy)):
|
|
||||||
raise NotImplementedError('`tf.compat.v1.distribute.experimental.Paramet'
|
raise NotImplementedError('`tf.compat.v1.distribute.experimental.Paramet'
|
||||||
'erServerStrategy` currently only works '
|
'erServerStrategy` currently only works '
|
||||||
'with the tf.Estimator API')
|
'with the tf.Estimator API')
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user