Keras Fixit: Remove usage of parameter_server_strategy.ParameterServerStrategyExtended in Keras.

PiperOrigin-RevId: 339379608
Change-Id: I6c56ca9236ed2c04c3d60d82f07e27b7f9048499
Rick Chao 2020-10-27 19:36:08 -07:00 committed by TensorFlower Gardener
parent 9af8f9866d
commit f34ada9573
5 changed files with 21 additions and 90 deletions
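
The diff below replaces isinstance checks on `strategy.extended` against `ParameterServerStrategyExtended` with isinstance checks on the strategy instance itself. A minimal sketch of the pattern, assuming the internal TF modules shown in the hunks are importable; the helper name is illustrative and not part of the commit:

from tensorflow.python.distribute import central_storage_strategy
from tensorflow.python.distribute import parameter_server_strategy
from tensorflow.python.distribute import parameter_server_strategy_v2


def _is_parameter_server_style_strategy(strategy):
  # Old pattern (removed by this commit):
  #   isinstance(strategy.extended,
  #              parameter_server_strategy.ParameterServerStrategyExtended)
  # New pattern: test the strategy object against the concrete classes.
  return isinstance(
      strategy,
      (parameter_server_strategy.ParameterServerStrategyV1,
       parameter_server_strategy_v2.ParameterServerStrategyV2,
       central_storage_strategy.CentralStorageStrategy,
       central_storage_strategy.CentralStorageStrategyV1))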


@@ -318,6 +318,7 @@ py_library(
         "//tensorflow/python/distribute:distribute_lib",
         "//tensorflow/python/distribute:mirrored_strategy",
         "//tensorflow/python/distribute:parameter_server_strategy",
+        "//tensorflow/python/distribute:parameter_server_strategy_v2",
         "//tensorflow/python/distribute:strategy_combinations",
         "//tensorflow/python/distribute:tpu_strategy",
         "//tensorflow/python/eager:context",


@@ -35,6 +35,7 @@ from tensorflow.python.distribute import mirrored_strategy
 from tensorflow.python.distribute import multi_process_runner
 from tensorflow.python.distribute import multi_worker_test_base
 from tensorflow.python.distribute import parameter_server_strategy
+from tensorflow.python.distribute import parameter_server_strategy_v2
 from tensorflow.python.distribute import reduce_util
 from tensorflow.python.distribute import strategy_combinations
 from tensorflow.python.distribute import tpu_strategy
@@ -527,8 +528,11 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
   @ds_combinations.generate(all_strategy_combinations())
   def test_calling_model_with_mixed_precision(self, distribution):
-    if isinstance(distribution.extended,
-                  parameter_server_strategy.ParameterServerStrategyExtended):
+    if isinstance(distribution,
+                  (parameter_server_strategy.ParameterServerStrategyV1,
+                   parameter_server_strategy_v2.ParameterServerStrategyV2,
+                   central_storage_strategy.CentralStorageStrategy,
+                   central_storage_strategy.CentralStorageStrategyV1)):
       self.skipTest('b/152097775')
     if _is_tpu_strategy(distribution):
       policy_name = 'mixed_bfloat16'
@@ -576,8 +580,11 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
     # AutoCastVariable to a tensor on a TPU, where the variable was the LHS of
     # the '+' operator, used to cause the gradient w.r.t. the variable to be
     # None.
-    if isinstance(distribution.extended,
-                  parameter_server_strategy.ParameterServerStrategyExtended):
+    if isinstance(distribution,
+                  (parameter_server_strategy.ParameterServerStrategyV1,
+                   parameter_server_strategy_v2.ParameterServerStrategyV2,
+                   central_storage_strategy.CentralStorageStrategy,
+                   central_storage_strategy.CentralStorageStrategyV1)):
       self.skipTest('b/152097775')
     if _is_tpu_strategy(distribution):


@@ -33,10 +33,7 @@ from tensorflow.python import keras
 from tensorflow.python.distribute import collective_all_reduce_strategy as collective_strategy
 from tensorflow.python.distribute import combinations as ds_combinations
 from tensorflow.python.distribute import distribute_coordinator as dc
-from tensorflow.python.distribute import distribute_lib
 from tensorflow.python.distribute import multi_worker_test_base as test_base
-from tensorflow.python.distribute import parameter_server_strategy
-from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
 from tensorflow.python.framework import test_combinations as combinations
 from tensorflow.python.keras import backend
 from tensorflow.python.keras import callbacks
@@ -48,23 +45,6 @@ from tensorflow.python.platform import test
 from tensorflow.python.util import nest
-# TODO(b/130375202): remove this class which is a temporary solution before we
-# get rid of configure method.
-class ParameterServerStrategy(distribute_lib.Strategy):
-  """Temporarily mock the original strategy to bypass cluster_spec check."""
-
-  def __init__(self, cluster_resolver=None):
-    """Initializes this strategy."""
-    # The `cluster_resolver` must be set so that
-    # `ParameterServerStrategyExtended` will keep num_gpus for `configure`
-    # method.
-    if cluster_resolver is None:
-      cluster_resolver = TFConfigClusterResolver()
-    extended = parameter_server_strategy.ParameterServerStrategyExtended(
-        self, cluster_resolver=cluster_resolver)
-    super(ParameterServerStrategy, self).__init__(extended)
 def _clone_and_build_model(model, strategy):
   # The new "original" model in worker 0.
   with strategy.scope():
@@ -262,69 +242,6 @@ class KerasMultiWorkerTestIndependentWorker(test_base.IndependentWorkerTestBase,
     self.join_independent_workers(threads_to_join)
     verification_callback.verify(self)
-  @ds_combinations.generate(
-      combinations.combine(
-          mode=['graph'],
-          strategy_cls=[ParameterServerStrategy],
-          required_gpus=[0, 1]))
-  def testSimpleModelIndependentWorkerAsync(self, strategy_cls):
-    num_workers = 2
-    num_epoch = 2
-    cluster_spec = test_base.create_cluster_spec(
-        num_workers=num_workers, num_ps=2)
-    self._barrier = dc._Barrier(4)
-    # The verification callback will be shared by multiple threads.
-    verification_callback = MultiWorkerVerificationCallback(
-        num_epoch=num_epoch, num_worker=num_workers)
-    def _independent_worker_fn(*args, **kwargs):  # pylint: disable=unused-argument
-      """Simulates an Independent Worker inside of a thread."""
-      # TODO(rchao/yuefengz): The following is run by both worker and ps
-      # threads. The distribute coordinator should run std server immediately
-      # without configuring the session (or building the graph) on PS.
-      with test.mock.patch.object(dc, '_run_std_server',
-                                  self._make_mock_run_std_server()):
-        batch_size = 64
-        steps = 2
-        strategy = strategy_cls()
-        verification_callback.is_between_graph = \
-            strategy.extended.experimental_between_graph
-        train_ds, _ = multi_worker_testing_utils.mnist_synthetic_dataset(
-            batch_size, steps)
-        val_ds, _ = multi_worker_testing_utils.mnist_synthetic_dataset(
-            batch_size, steps)
-        with strategy.scope():
-          model = multi_worker_testing_utils.get_mnist_model((28, 28, 1))
-        # TODO(b/123868066): Verify callback for model.evaluate().
-        callbacks_for_fit = nest.flatten(
-            kwargs.get('verification_callback', []))
-        history = model.fit(
-            x=train_ds,
-            epochs=num_epoch,
-            steps_per_epoch=steps,
-            validation_data=val_ds,
-            validation_steps=steps,
-            callbacks=callbacks_for_fit)
-        self.assertIsInstance(history, keras.callbacks.History)
-    threads = self.run_multiple_tasks_in_threads(
-        _independent_worker_fn,
-        cluster_spec,
-        verification_callback=verification_callback)
-    threads_to_join = []
-    for task_type, ts in threads.items():
-      # This test can finish once the worker threads complete, and thus
-      # the ps threads don't need to be joined.
-      if task_type == 'ps':
-        continue
-      threads_to_join.extend(ts)
-    self.join_independent_workers(threads_to_join)
-    verification_callback.verify(self)
 if __name__ == '__main__':
   # Enable manual variable initialization to make sure variables are initialized


@@ -47,6 +47,7 @@ py_library(
         "//tensorflow/python/distribute:central_storage_strategy",
        "//tensorflow/python/distribute:distribute_lib",
         "//tensorflow/python/distribute:parameter_server_strategy",
+        "//tensorflow/python/distribute:parameter_server_strategy_v2",
         "//tensorflow/python/distribute:reduce_util",
         "//tensorflow/python/distribute:values",
         "//tensorflow/python/keras:backend",


@@ -25,8 +25,10 @@ import functools
 import six
+from tensorflow.python.distribute import central_storage_strategy
 from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx
 from tensorflow.python.distribute import parameter_server_strategy
+from tensorflow.python.distribute import parameter_server_strategy_v2
 from tensorflow.python.distribute import values as ds_values
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
@@ -615,9 +617,12 @@ class OptimizerV2(trackable.Trackable):
           "context.")
     strategy = distribute_ctx.get_strategy()
-    if (not experimental_aggregate_gradients and strategy and isinstance(
-        strategy.extended,
-        parameter_server_strategy.ParameterServerStrategyExtended)):
+    if (not experimental_aggregate_gradients and strategy and
+        isinstance(strategy,
+                   (parameter_server_strategy.ParameterServerStrategyV1,
+                    parameter_server_strategy_v2.ParameterServerStrategyV2,
+                    central_storage_strategy.CentralStorageStrategy,
+                    central_storage_strategy.CentralStorageStrategyV1))):
       raise NotImplementedError(
           "`experimental_aggregate_gradients=False is not supported for "
           "ParameterServerStrategy and CentralStorageStrategy")