Keras Fixit: Remove usage of parameter_server_strategy.ParameterServerStrategyExtended in Keras.
PiperOrigin-RevId: 339379608
Change-Id: I6c56ca9236ed2c04c3d60d82f07e27b7f9048499
parent 9af8f9866d
commit f34ada9573
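
The recurring edit in this change is to stop dispatching on a strategy's `extended` object (via `ParameterServerStrategyExtended`) and to check the strategy instance itself against the concrete strategy classes. A minimal sketch of the new check, using the class names that appear in the hunks below; the helper name `_uses_parameter_server_style_variables` is illustrative only and is not part of the commit:

from tensorflow.python.distribute import central_storage_strategy
from tensorflow.python.distribute import parameter_server_strategy
from tensorflow.python.distribute import parameter_server_strategy_v2


def _uses_parameter_server_style_variables(strategy):
  """Illustrative helper mirroring the isinstance check used in the hunks below."""
  # The old code asked: isinstance(strategy.extended,
  #     parameter_server_strategy.ParameterServerStrategyExtended).
  # The new code checks the strategy object against the concrete classes.
  return isinstance(
      strategy,
      (parameter_server_strategy.ParameterServerStrategyV1,
       parameter_server_strategy_v2.ParameterServerStrategyV2,
       central_storage_strategy.CentralStorageStrategy,
       central_storage_strategy.CentralStorageStrategyV1))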
@@ -318,6 +318,7 @@ py_library(
         "//tensorflow/python/distribute:distribute_lib",
         "//tensorflow/python/distribute:mirrored_strategy",
         "//tensorflow/python/distribute:parameter_server_strategy",
+        "//tensorflow/python/distribute:parameter_server_strategy_v2",
         "//tensorflow/python/distribute:strategy_combinations",
         "//tensorflow/python/distribute:tpu_strategy",
         "//tensorflow/python/eager:context",
@@ -35,6 +35,7 @@ from tensorflow.python.distribute import mirrored_strategy
 from tensorflow.python.distribute import multi_process_runner
 from tensorflow.python.distribute import multi_worker_test_base
 from tensorflow.python.distribute import parameter_server_strategy
+from tensorflow.python.distribute import parameter_server_strategy_v2
 from tensorflow.python.distribute import reduce_util
 from tensorflow.python.distribute import strategy_combinations
 from tensorflow.python.distribute import tpu_strategy
@@ -527,8 +528,11 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
 
   @ds_combinations.generate(all_strategy_combinations())
   def test_calling_model_with_mixed_precision(self, distribution):
-    if isinstance(distribution.extended,
-                  parameter_server_strategy.ParameterServerStrategyExtended):
+    if isinstance(distribution,
+                  (parameter_server_strategy.ParameterServerStrategyV1,
+                   parameter_server_strategy_v2.ParameterServerStrategyV2,
+                   central_storage_strategy.CentralStorageStrategy,
+                   central_storage_strategy.CentralStorageStrategyV1)):
       self.skipTest('b/152097775')
     if _is_tpu_strategy(distribution):
       policy_name = 'mixed_bfloat16'
@@ -576,8 +580,11 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
     # AutoCastVariable to a tensor on a TPU, where the variable was the LHS of
     # the '+' operator, used to cause the gradient w.r.t. the variable to be
     # None.
-    if isinstance(distribution.extended,
-                  parameter_server_strategy.ParameterServerStrategyExtended):
+    if isinstance(distribution,
+                  (parameter_server_strategy.ParameterServerStrategyV1,
+                   parameter_server_strategy_v2.ParameterServerStrategyV2,
+                   central_storage_strategy.CentralStorageStrategy,
+                   central_storage_strategy.CentralStorageStrategyV1)):
       self.skipTest('b/152097775')
 
     if _is_tpu_strategy(distribution):
@@ -33,10 +33,7 @@ from tensorflow.python import keras
 from tensorflow.python.distribute import collective_all_reduce_strategy as collective_strategy
 from tensorflow.python.distribute import combinations as ds_combinations
 from tensorflow.python.distribute import distribute_coordinator as dc
-from tensorflow.python.distribute import distribute_lib
 from tensorflow.python.distribute import multi_worker_test_base as test_base
-from tensorflow.python.distribute import parameter_server_strategy
-from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
 from tensorflow.python.framework import test_combinations as combinations
 from tensorflow.python.keras import backend
 from tensorflow.python.keras import callbacks
@@ -48,23 +45,6 @@ from tensorflow.python.platform import test
 from tensorflow.python.util import nest
 
 
-# TODO(b/130375202): remove this class which is a temporary solution before we
-# get rid of configure method.
-class ParameterServerStrategy(distribute_lib.Strategy):
-  """Temporarily mock the original strategy to bypass cluster_spec check."""
-
-  def __init__(self, cluster_resolver=None):
-    """Initializes this strategy."""
-    # The `cluster_resolver` must be set so that
-    # `ParameterServerStrategyExtended` will keep num_gpus for `configure`
-    # method.
-    if cluster_resolver is None:
-      cluster_resolver = TFConfigClusterResolver()
-    extended = parameter_server_strategy.ParameterServerStrategyExtended(
-        self, cluster_resolver=cluster_resolver)
-    super(ParameterServerStrategy, self).__init__(extended)
-
-
 def _clone_and_build_model(model, strategy):
   # The new "original" model in worker 0.
   with strategy.scope():
@@ -262,69 +242,6 @@ class KerasMultiWorkerTestIndependentWorker(test_base.IndependentWorkerTestBase,
     self.join_independent_workers(threads_to_join)
     verification_callback.verify(self)
 
-  @ds_combinations.generate(
-      combinations.combine(
-          mode=['graph'],
-          strategy_cls=[ParameterServerStrategy],
-          required_gpus=[0, 1]))
-  def testSimpleModelIndependentWorkerAsync(self, strategy_cls):
-    num_workers = 2
-    num_epoch = 2
-    cluster_spec = test_base.create_cluster_spec(
-        num_workers=num_workers, num_ps=2)
-    self._barrier = dc._Barrier(4)
-
-    # The verification callback will be shared by multiple threads.
-    verification_callback = MultiWorkerVerificationCallback(
-        num_epoch=num_epoch, num_worker=num_workers)
-
-    def _independent_worker_fn(*args, **kwargs):  # pylint: disable=unused-argument
-      """Simulates an Independent Worker inside of a thread."""
-      # TODO(rchao/yuefengz): The following is run by both worker and ps
-      # threads. The distribute coordinator should run std server immediately
-      # without configuring the session (or building the graph) on PS.
-      with test.mock.patch.object(dc, '_run_std_server',
-                                  self._make_mock_run_std_server()):
-        batch_size = 64
-        steps = 2
-        strategy = strategy_cls()
-        verification_callback.is_between_graph = \
-            strategy.extended.experimental_between_graph
-
-        train_ds, _ = multi_worker_testing_utils.mnist_synthetic_dataset(
-            batch_size, steps)
-        val_ds, _ = multi_worker_testing_utils.mnist_synthetic_dataset(
-            batch_size, steps)
-        with strategy.scope():
-          model = multi_worker_testing_utils.get_mnist_model((28, 28, 1))
-
-        # TODO(b/123868066): Verify callback for model.evaluate().
-        callbacks_for_fit = nest.flatten(
-            kwargs.get('verification_callback', []))
-        history = model.fit(
-            x=train_ds,
-            epochs=num_epoch,
-            steps_per_epoch=steps,
-            validation_data=val_ds,
-            validation_steps=steps,
-            callbacks=callbacks_for_fit)
-        self.assertIsInstance(history, keras.callbacks.History)
-
-    threads = self.run_multiple_tasks_in_threads(
-        _independent_worker_fn,
-        cluster_spec,
-        verification_callback=verification_callback)
-
-    threads_to_join = []
-    for task_type, ts in threads.items():
-      # This test can finish once the worker threads complete, and thus
-      # the ps threads don't need to be joined.
-      if task_type == 'ps':
-        continue
-      threads_to_join.extend(ts)
-    self.join_independent_workers(threads_to_join)
-    verification_callback.verify(self)
-
 
 if __name__ == '__main__':
   # Enable manual variable initialization to make sure variables are initialized
@@ -47,6 +47,7 @@ py_library(
         "//tensorflow/python/distribute:central_storage_strategy",
         "//tensorflow/python/distribute:distribute_lib",
        "//tensorflow/python/distribute:parameter_server_strategy",
+        "//tensorflow/python/distribute:parameter_server_strategy_v2",
         "//tensorflow/python/distribute:reduce_util",
         "//tensorflow/python/distribute:values",
         "//tensorflow/python/keras:backend",
@@ -25,8 +25,10 @@ import functools
 
 import six
 
+from tensorflow.python.distribute import central_storage_strategy
 from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx
 from tensorflow.python.distribute import parameter_server_strategy
+from tensorflow.python.distribute import parameter_server_strategy_v2
 from tensorflow.python.distribute import values as ds_values
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
@@ -615,9 +617,12 @@ class OptimizerV2(trackable.Trackable):
             "context.")
 
       strategy = distribute_ctx.get_strategy()
-      if (not experimental_aggregate_gradients and strategy and isinstance(
-          strategy.extended,
-          parameter_server_strategy.ParameterServerStrategyExtended)):
+      if (not experimental_aggregate_gradients and strategy and
+          isinstance(strategy,
+                     (parameter_server_strategy.ParameterServerStrategyV1,
+                      parameter_server_strategy_v2.ParameterServerStrategyV2,
+                      central_storage_strategy.CentralStorageStrategy,
+                      central_storage_strategy.CentralStorageStrategyV1))):
         raise NotImplementedError(
             "`experimental_aggregate_gradients=False is not supported for "
            "ParameterServerStrategy and CentralStorageStrategy")
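
The optimizer hunk above keeps rejecting `experimental_aggregate_gradients=False` for parameter-server-style strategies, now via the direct isinstance check. A rough usage sketch of the guarded path follows, written against the public `tf.distribute` / `tf.keras` API rather than the internal modules in the diff; the expected NotImplementedError is assumed from the raise shown above:

import tensorflow as tf

# CentralStorageStrategy is one of the strategy types matched by the new check.
strategy = tf.distribute.experimental.CentralStorageStrategy()

with strategy.scope():
  var = tf.Variable(1.0)
  opt = tf.keras.optimizers.SGD(learning_rate=0.1)


def replica_fn():
  # Passing pre-aggregated gradients is the unsupported path guarded above.
  opt.apply_gradients([(tf.constant(0.5), var)],
                      experimental_aggregate_gradients=False)


try:
  strategy.run(replica_fn)
except NotImplementedError as e:
  print(e)  # "...not supported for ParameterServerStrategy and CentralStorageStrategy"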