diff --git a/tensorflow/python/distribute/custom_training_loop_optimizer_test.py b/tensorflow/python/distribute/custom_training_loop_optimizer_test.py
index 942f83ed01d..a0c6cc1d01d 100644
--- a/tensorflow/python/distribute/custom_training_loop_optimizer_test.py
+++ b/tensorflow/python/distribute/custom_training_loop_optimizer_test.py
@@ -98,6 +98,24 @@ class OptimizerTest(test.TestCase, parameterized.TestCase):
 
     self.assertAllClose(optimize(), [[-0.1, -0.1]])
 
+  @combinations.generate(
+      combinations.combine(distribution=[
+          strategy_combinations.central_storage_strategy_with_gpu_and_cpu
+      ]))
+  def test_custom_aggregation_central_storage(self, distribution):
+    with distribution.scope():
+      v = variables.Variable([0., 0.])
+      optimizer = keras.optimizer_v2.gradient_descent.SGD(0.1)
+
+    grads = ops.convert_to_tensor([1., 1.])
+
+    def step_fn(grads):
+      with self.assertRaises(NotImplementedError):
+        optimizer.apply_gradients([(grads, v)],
+                                  experimental_aggregate_gradients=False)
+
+    return distribution.run(step_fn, args=(grads,))
+
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/keras/distribute/BUILD b/tensorflow/python/keras/distribute/BUILD
index 7f2339f2ff9..6a70c7c5c5c 100644
--- a/tensorflow/python/keras/distribute/BUILD
+++ b/tensorflow/python/keras/distribute/BUILD
@@ -91,6 +91,7 @@ py_library(
         "//tensorflow/python/distribute:combinations",
         "//tensorflow/python/distribute:distribute_lib",
         "//tensorflow/python/distribute:mirrored_strategy",
+        "//tensorflow/python/distribute:parameter_server_strategy",
         "//tensorflow/python/distribute:strategy_combinations",
         "//tensorflow/python/distribute:tpu_strategy",
         "//tensorflow/python/eager:context",
@@ -122,7 +123,7 @@ distribute_py_test(
     srcs = ["distribute_strategy_test.py"],
     full_precision = True,
     main = "distribute_strategy_test.py",
-    shard_count = 8,
+    shard_count = 10,
     tags = [
         "multi_and_single_gpu",
         "no_rocm",  # times out on ROCm
@@ -283,6 +284,7 @@ distribute_py_test(
     ],
     deps = [
         ":keras_test_lib",
+        "//tensorflow/python/distribute:parameter_server_strategy",
         "//tensorflow/python/keras/distribute:distribute_strategy_test_lib",
     ],
 )
diff --git a/tensorflow/python/keras/distribute/distribute_strategy_test.py b/tensorflow/python/keras/distribute/distribute_strategy_test.py
index fa653ce9f54..a7e25b77627 100644
--- a/tensorflow/python/keras/distribute/distribute_strategy_test.py
+++ b/tensorflow/python/keras/distribute/distribute_strategy_test.py
@@ -25,6 +25,7 @@ from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.distribute import combinations
 from tensorflow.python.distribute import distribution_strategy_context
 from tensorflow.python.distribute import mirrored_strategy
+from tensorflow.python.distribute import parameter_server_strategy
 from tensorflow.python.distribute import reduce_util
 from tensorflow.python.distribute import strategy_combinations
 from tensorflow.python.distribute import tpu_strategy
@@ -216,7 +217,8 @@ strategies_minus_default_minus_tpu = [
     strategy_combinations.one_device_strategy,
     strategy_combinations.one_device_strategy_gpu,
     strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
-    strategy_combinations.mirrored_strategy_with_two_gpus
+    strategy_combinations.mirrored_strategy_with_two_gpus,
+    strategy_combinations.central_storage_strategy_with_gpu_and_cpu
 ]
 
 strategies_minus_tpu = [
@@ -224,7 +226,8 @@ strategies_minus_tpu = [
     strategy_combinations.one_device_strategy,
     strategy_combinations.one_device_strategy_gpu,
     strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
-    strategy_combinations.mirrored_strategy_with_two_gpus
+    strategy_combinations.mirrored_strategy_with_two_gpus,
+    strategy_combinations.central_storage_strategy_with_gpu_and_cpu
 ]
 
 tpu_strategies = [
@@ -482,6 +485,9 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
 
   @combinations.generate(all_strategy_combinations_plus_run_distributed())
   def test_calling_model_with_mixed_precision(self, distribution):
+    if isinstance(distribution.extended,
+                  parameter_server_strategy.ParameterServerStrategyExtended):
+      self.skipTest('b/152097775')
     if isinstance(distribution, (tpu_strategy.TPUStrategy,
                                  tpu_strategy.TPUStrategyV1)):
       policy_name = 'mixed_bfloat16'
@@ -529,6 +535,10 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
     # AutoCastVariable to a tensor on a TPU, where the variable was the LHS of
     # the '+' operator, used to cause the gradient w.r.t. the variable to be
     # None.
+    if isinstance(distribution.extended,
+                  parameter_server_strategy.ParameterServerStrategyExtended):
+      self.skipTest('b/152097775')
+
     if isinstance(distribution, (tpu_strategy.TPUStrategy,
                                  tpu_strategy.TPUStrategyV1)):
       policy_name = 'mixed_bfloat16'
diff --git a/tensorflow/python/keras/distribute/keras_utils_test.py b/tensorflow/python/keras/distribute/keras_utils_test.py
index 0f65bbbf917..702d89d95f8 100644
--- a/tensorflow/python/keras/distribute/keras_utils_test.py
+++ b/tensorflow/python/keras/distribute/keras_utils_test.py
@@ -26,6 +26,7 @@ import numpy as np
 from tensorflow.python import keras
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.distribute import combinations
+from tensorflow.python.distribute import parameter_server_strategy
 from tensorflow.python.distribute import strategy_combinations
 from tensorflow.python.distribute import tpu_strategy
 from tensorflow.python.distribute import values
@@ -397,6 +398,9 @@ class TestDistributionStrategyWithNormalizationLayer(test.TestCase,
           optimizer=strategy_combinations
           .gradient_descent_optimizer_keras_v2_fn)))
   def test_batchnorm_correctness(self, distribution, fused, optimizer):
+    if isinstance(distribution.extended,
+                  parameter_server_strategy.ParameterServerStrategyExtended):
+      self.skipTest('b/152353796')
     with self.cached_session():
       with distribution.scope():
         model = keras.models.Sequential()
diff --git a/tensorflow/python/keras/engine/BUILD b/tensorflow/python/keras/engine/BUILD
index 7555228c20f..4aaf0061e43 100644
--- a/tensorflow/python/keras/engine/BUILD
+++ b/tensorflow/python/keras/engine/BUILD
@@ -47,6 +47,7 @@ py_library(
         "//tensorflow/python/distribute:distribute_coordinator",
         "//tensorflow/python/distribute:distribute_lib",
         "//tensorflow/python/distribute:input_lib",
+        "//tensorflow/python/distribute:parameter_server_strategy",
         "//tensorflow/python/distribute:reduce_util",
         "//tensorflow/python/eager:monitoring",
         "//tensorflow/python/keras:activations",
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index d337af77919..94570d96208 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -24,6 +24,7 @@ import itertools
 from tensorflow.python.distribute import distribute_coordinator as dc
 from tensorflow.python.distribute import distribute_coordinator_context as dc_context
 from tensorflow.python.distribute import distribution_strategy_context as ds_context
+from tensorflow.python.distribute import parameter_server_strategy
 from tensorflow.python.distribute import values as ds_values
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
@@ -508,7 +509,8 @@ class Model(network.Network, version_utils.ModelVersionSelector):
       #   self.optimizer.apply_gradients(zip(gradients, trainable_variables))
       # The _minimize call does a few extra steps unnecessary in most cases,
       # such as loss scaling and gradient clipping.
-      _minimize(tape, self.optimizer, loss, self.trainable_variables)
+      _minimize(self.distribute_strategy, tape, self.optimizer, loss,
+                self.trainable_variables)
 
     self.compiled_metrics.update_state(y, y_pred, sample_weight)
     return {m.name: m.result() for m in self.metrics}
@@ -1787,7 +1789,7 @@ def _tpu_multi_host_concat(v, strategy):
   return concat(ordered_replicas)
 
 
-def _minimize(tape, optimizer, loss, trainable_variables):
+def _minimize(strategy, tape, optimizer, loss, trainable_variables):
   """Minimizes loss for one step by updating `trainable_variables`.
 
   This is roughly equivalent to
@@ -1801,6 +1803,7 @@ def _minimize(tape, optimizer, loss, trainable_variables):
   optimizer is a LossScaleOptimizer.
 
   Args:
+    strategy: `tf.distribute.Strategy`.
     tape: A gradient tape. The loss must have been computed under this tape.
     optimizer: The optimizer used to minimize the loss.
     loss: The loss tensor.
@@ -1814,7 +1817,15 @@ def _minimize(tape, optimizer, loss, trainable_variables):
 
   gradients = tape.gradient(loss, trainable_variables)
 
-  if optimizer._HAS_AGGREGATE_GRAD:  # pylint: disable=protected-access
+  # Whether to aggregate gradients outside of optimizer. This requires support
+  # of the optimizer and doesn't work with ParameterServerStrategy and
+  # CentralStorageStrategy.
+  aggregate_grads_outside_optimizer = (
+      optimizer._HAS_AGGREGATE_GRAD and  # pylint: disable=protected-access
+      not isinstance(strategy.extended,
+                     parameter_server_strategy.ParameterServerStrategyExtended))
+
+  if aggregate_grads_outside_optimizer:
     # We aggregate gradients before unscaling them, in case a subclass of
     # LossScaleOptimizer all-reduces in fp16. All-reducing in fp16 can only be
     # done on scaled gradients, not unscaled gradients, for numeric stability.
@@ -1824,7 +1835,7 @@ def _minimize(tape, optimizer, loss, trainable_variables):
     gradients = optimizer.get_unscaled_gradients(gradients)
   gradients = optimizer._clip_gradients(gradients)  # pylint: disable=protected-access
   if trainable_variables:
-    if optimizer._HAS_AGGREGATE_GRAD:  # pylint: disable=protected-access
+    if aggregate_grads_outside_optimizer:
       optimizer.apply_gradients(
           zip(gradients, trainable_variables),
           experimental_aggregate_gradients=False)
diff --git a/tensorflow/python/keras/optimizer_v2/BUILD b/tensorflow/python/keras/optimizer_v2/BUILD
index afdb8bc04b3..c5eab79f6c2 100644
--- a/tensorflow/python/keras/optimizer_v2/BUILD
+++ b/tensorflow/python/keras/optimizer_v2/BUILD
@@ -40,6 +40,7 @@ py_library(
         "//tensorflow/python:variable_scope",
         "//tensorflow/python:variables",
         "//tensorflow/python/distribute:distribute_lib",
+        "//tensorflow/python/distribute:parameter_server_strategy",
        "//tensorflow/python/distribute:reduce_util",
         "//tensorflow/python/distribute:values",
         "//tensorflow/python/keras:backend",
diff --git a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py
index 98b42cbad97..20515beb0eb 100644
--- a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py
+++ b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py
@@ -26,6 +26,7 @@ import functools
 import six
 
 from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx
+from tensorflow.python.distribute import parameter_server_strategy
 from tensorflow.python.distribute import reduce_util as ds_reduce_util
 from tensorflow.python.distribute import values as ds_values
 from tensorflow.python.eager import backprop
@@ -491,6 +492,14 @@ class OptimizerV2(trackable.Trackable):
           "Use `tf.distribute.Strategy.run` to enter replica "
           "context.")
 
+    strategy = distribute_ctx.get_strategy()
+    if (not experimental_aggregate_gradients and strategy and isinstance(
+        strategy.extended,
+        parameter_server_strategy.ParameterServerStrategyExtended)):
+      raise NotImplementedError(
+          "`experimental_aggregate_gradients=False` is not supported for "
+          "ParameterServerStrategy and CentralStorageStrategy")
+
     apply_state = self._prepare(var_list)
     if experimental_aggregate_gradients:
       reduced_grads = self._aggregate_gradients(grads_and_vars)
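
Illustration (not part of the diff): a minimal sketch of the behavior the new
guard enforces, written against the public TF 2.x API rather than the internal
test fixtures above. The strategy, optimizer, and variable below are stand-ins
chosen for the example.

    import tensorflow as tf

    strategy = tf.distribute.experimental.CentralStorageStrategy()
    with strategy.scope():
      v = tf.Variable([0., 0.])
      optimizer = tf.keras.optimizers.SGD(0.1)

    grads = tf.convert_to_tensor([1., 1.])

    def step_fn(grads):
      try:
        # Pre-aggregated gradients are rejected for ParameterServerStrategy
        # and CentralStorageStrategy after this change.
        optimizer.apply_gradients([(grads, v)],
                                  experimental_aggregate_gradients=False)
      except NotImplementedError:
        pass
      # The default path still works: the optimizer aggregates the gradients
      # itself inside apply_gradients.
      optimizer.apply_gradients([(grads, v)])

    strategy.run(step_fn, args=(grads,))

This is also the fallback Keras takes in Model.train_step: _minimize only
pre-aggregates gradients (and passes experimental_aggregate_gradients=False)
when the optimizer opts in via _HAS_AGGREGATE_GRAD and the strategy is not
backed by ParameterServerStrategyExtended; otherwise aggregation is left to
apply_gradients.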