Error when experimental_aggregate_gradients=False is used with CentralStorageStrategy

PiperOrigin-RevId: 302804311
Change-Id: Ibb27c529251390f40338cd296537cd98f8940b56

parent 1030e2aa58
commit c3d655f51b
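This change makes OptimizerV2.apply_gradients raise NotImplementedError when it is called with experimental_aggregate_gradients=False under ParameterServerStrategy or CentralStorageStrategy, and passes the strategy into Keras' _minimize helper so the training loop falls back to optimizer-side aggregation for those strategies. Below is a minimal sketch of the resulting user-facing behavior, written against the public TF API rather than the internal test imports used in the diff, and assuming TensorFlow 2.2 or later (where strategy.run and the experimental_aggregate_gradients argument are available):

import tensorflow as tf

# CentralStorageStrategy keeps variables on a single device while computing on
# the local devices; its `extended` object is a ParameterServerStrategyExtended,
# which is what the new check looks for.
strategy = tf.distribute.experimental.CentralStorageStrategy()

with strategy.scope():
  v = tf.Variable([0., 0.])
  optimizer = tf.keras.optimizers.SGD(0.1)

grads = tf.convert_to_tensor([1., 1.])


def step_fn(grads):
  # After this commit, asking the optimizer to skip gradient aggregation is
  # rejected for ParameterServerStrategy / CentralStorageStrategy.
  try:
    optimizer.apply_gradients([(grads, v)],
                              experimental_aggregate_gradients=False)
  except NotImplementedError as e:
    print("raised as expected:", e)


strategy.run(step_fn, args=(grads,))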
@@ -98,6 +98,24 @@ class OptimizerTest(test.TestCase, parameterized.TestCase):
 
     self.assertAllClose(optimize(), [[-0.1, -0.1]])
 
+  @combinations.generate(
+      combinations.combine(distribution=[
+          strategy_combinations.central_storage_strategy_with_gpu_and_cpu
+      ]))
+  def test_custom_aggregation_central_storage(self, distribution):
+    with distribution.scope():
+      v = variables.Variable([0., 0.])
+      optimizer = keras.optimizer_v2.gradient_descent.SGD(0.1)
+
+      grads = ops.convert_to_tensor([1., 1.])
+
+      def step_fn(grads):
+        with self.assertRaises(NotImplementedError):
+          optimizer.apply_gradients([(grads, v)],
+                                    experimental_aggregate_gradients=False)
+
+      return distribution.run(step_fn, args=(grads,))
+
 
 if __name__ == "__main__":
   test.main()

@@ -91,6 +91,7 @@ py_library(
         "//tensorflow/python/distribute:combinations",
         "//tensorflow/python/distribute:distribute_lib",
         "//tensorflow/python/distribute:mirrored_strategy",
+        "//tensorflow/python/distribute:parameter_server_strategy",
         "//tensorflow/python/distribute:strategy_combinations",
         "//tensorflow/python/distribute:tpu_strategy",
         "//tensorflow/python/eager:context",
@@ -122,7 +123,7 @@ distribute_py_test(
     srcs = ["distribute_strategy_test.py"],
     full_precision = True,
     main = "distribute_strategy_test.py",
-    shard_count = 8,
+    shard_count = 10,
     tags = [
         "multi_and_single_gpu",
         "no_rocm",  # times out on ROCm
@@ -283,6 +284,7 @@ distribute_py_test(
     ],
     deps = [
         ":keras_test_lib",
+        "//tensorflow/python/distribute:parameter_server_strategy",
         "//tensorflow/python/keras/distribute:distribute_strategy_test_lib",
     ],
 )

@@ -25,6 +25,7 @@ from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.distribute import combinations
 from tensorflow.python.distribute import distribution_strategy_context
 from tensorflow.python.distribute import mirrored_strategy
+from tensorflow.python.distribute import parameter_server_strategy
 from tensorflow.python.distribute import reduce_util
 from tensorflow.python.distribute import strategy_combinations
 from tensorflow.python.distribute import tpu_strategy
@@ -216,7 +217,8 @@ strategies_minus_default_minus_tpu = [
     strategy_combinations.one_device_strategy,
     strategy_combinations.one_device_strategy_gpu,
     strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
-    strategy_combinations.mirrored_strategy_with_two_gpus
+    strategy_combinations.mirrored_strategy_with_two_gpus,
+    strategy_combinations.central_storage_strategy_with_gpu_and_cpu
 ]
 
 strategies_minus_tpu = [
@@ -224,7 +226,8 @@ strategies_minus_tpu = [
     strategy_combinations.one_device_strategy,
     strategy_combinations.one_device_strategy_gpu,
     strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
-    strategy_combinations.mirrored_strategy_with_two_gpus
+    strategy_combinations.mirrored_strategy_with_two_gpus,
+    strategy_combinations.central_storage_strategy_with_gpu_and_cpu
 ]
 
 tpu_strategies = [
@@ -482,6 +485,9 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
 
   @combinations.generate(all_strategy_combinations_plus_run_distributed())
   def test_calling_model_with_mixed_precision(self, distribution):
+    if isinstance(distribution.extended,
+                  parameter_server_strategy.ParameterServerStrategyExtended):
+      self.skipTest('b/152097775')
     if isinstance(distribution,
                   (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)):
       policy_name = 'mixed_bfloat16'
@@ -529,6 +535,10 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
     # AutoCastVariable to a tensor on a TPU, where the variable was the LHS of
     # the '+' operator, used to cause the gradient w.r.t. the variable to be
     # None.
+    if isinstance(distribution.extended,
+                  parameter_server_strategy.ParameterServerStrategyExtended):
+      self.skipTest('b/152097775')
+
     if isinstance(distribution,
                   (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)):
       policy_name = 'mixed_bfloat16'

@@ -26,6 +26,7 @@ import numpy as np
 from tensorflow.python import keras
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.distribute import combinations
+from tensorflow.python.distribute import parameter_server_strategy
 from tensorflow.python.distribute import strategy_combinations
 from tensorflow.python.distribute import tpu_strategy
 from tensorflow.python.distribute import values
@@ -397,6 +398,9 @@ class TestDistributionStrategyWithNormalizationLayer(test.TestCase,
           optimizer=strategy_combinations
           .gradient_descent_optimizer_keras_v2_fn)))
   def test_batchnorm_correctness(self, distribution, fused, optimizer):
+    if isinstance(distribution.extended,
+                  parameter_server_strategy.ParameterServerStrategyExtended):
+      self.skipTest('b/152353796')
     with self.cached_session():
       with distribution.scope():
         model = keras.models.Sequential()

@@ -47,6 +47,7 @@ py_library(
         "//tensorflow/python/distribute:distribute_coordinator",
         "//tensorflow/python/distribute:distribute_lib",
         "//tensorflow/python/distribute:input_lib",
+        "//tensorflow/python/distribute:parameter_server_strategy",
         "//tensorflow/python/distribute:reduce_util",
         "//tensorflow/python/eager:monitoring",
         "//tensorflow/python/keras:activations",

@@ -24,6 +24,7 @@ import itertools
 from tensorflow.python.distribute import distribute_coordinator as dc
 from tensorflow.python.distribute import distribute_coordinator_context as dc_context
 from tensorflow.python.distribute import distribution_strategy_context as ds_context
+from tensorflow.python.distribute import parameter_server_strategy
 from tensorflow.python.distribute import values as ds_values
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
@@ -508,7 +509,8 @@ class Model(network.Network, version_utils.ModelVersionSelector):
     # self.optimizer.apply_gradients(zip(gradients, trainable_variables))
     # The _minimize call does a few extra steps unnecessary in most cases,
     # such as loss scaling and gradient clipping.
-    _minimize(tape, self.optimizer, loss, self.trainable_variables)
+    _minimize(self.distribute_strategy, tape, self.optimizer, loss,
+              self.trainable_variables)
 
     self.compiled_metrics.update_state(y, y_pred, sample_weight)
     return {m.name: m.result() for m in self.metrics}
@@ -1787,7 +1789,7 @@ def _tpu_multi_host_concat(v, strategy):
   return concat(ordered_replicas)
 
 
-def _minimize(tape, optimizer, loss, trainable_variables):
+def _minimize(strategy, tape, optimizer, loss, trainable_variables):
   """Minimizes loss for one step by updating `trainable_variables`.
 
   This is roughly equivalent to
@@ -1801,6 +1803,7 @@ def _minimize(tape, optimizer, loss, trainable_variables):
   optimizer is a LossScaleOptimizer.
 
   Args:
+    strategy: `tf.distribute.Strategy`.
     tape: A gradient tape. The loss must have been computed under this tape.
     optimizer: The optimizer used to minimize the loss.
     loss: The loss tensor.
@@ -1814,7 +1817,15 @@ def _minimize(tape, optimizer, loss, trainable_variables):
 
   gradients = tape.gradient(loss, trainable_variables)
 
-  if optimizer._HAS_AGGREGATE_GRAD:  # pylint: disable=protected-access
+  # Whether to aggregate gradients outside of optimizer. This requires support
+  # of the optimizer and doesn't work with ParameterServerStrategy and
+  # CentralStorageStrategy.
+  aggregate_grads_outside_optimizer = (
+      optimizer._HAS_AGGREGATE_GRAD and  # pylint: disable=protected-access
+      not isinstance(strategy.extended,
+                     parameter_server_strategy.ParameterServerStrategyExtended))
+
+  if aggregate_grads_outside_optimizer:
     # We aggregate gradients before unscaling them, in case a subclass of
     # LossScaleOptimizer all-reduces in fp16. All-reducing in fp16 can only be
    # done on scaled gradients, not unscaled gradients, for numeric stability.
@@ -1824,7 +1835,7 @@ def _minimize(tape, optimizer, loss, trainable_variables):
     gradients = optimizer.get_unscaled_gradients(gradients)
   gradients = optimizer._clip_gradients(gradients)  # pylint: disable=protected-access
   if trainable_variables:
-    if optimizer._HAS_AGGREGATE_GRAD:  # pylint: disable=protected-access
+    if aggregate_grads_outside_optimizer:
       optimizer.apply_gradients(
           zip(gradients, trainable_variables),
           experimental_aggregate_gradients=False)

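Taken together, the training.py hunks change _minimize so that the strategy decides whether gradients are aggregated outside the optimizer. A condensed, hypothetical sketch of that control flow (not the TF source; it reuses the internal names _HAS_AGGREGATE_GRAD and _aggregate_gradients that appear in the diff and omits loss scaling and gradient clipping):

from tensorflow.python.distribute import parameter_server_strategy


def minimize_sketch(strategy, tape, optimizer, loss, trainable_variables):
  """Illustrative helper mirroring the branch _minimize takes after this change."""
  gradients = tape.gradient(loss, trainable_variables)

  # Aggregating outside the optimizer needs optimizer support and does not
  # work for ParameterServerStrategy / CentralStorageStrategy, whose
  # `extended` object is a ParameterServerStrategyExtended.
  aggregate_outside = (
      optimizer._HAS_AGGREGATE_GRAD and
      not isinstance(strategy.extended,
                     parameter_server_strategy.ParameterServerStrategyExtended))

  if aggregate_outside:
    # All-reduce first, then tell apply_gradients not to aggregate again.
    gradients = optimizer._aggregate_gradients(
        zip(gradients, trainable_variables))
    optimizer.apply_gradients(
        zip(gradients, trainable_variables),
        experimental_aggregate_gradients=False)
  else:
    # Let the optimizer aggregate; this path remains valid for
    # CentralStorageStrategy after the new NotImplementedError check.
    optimizer.apply_gradients(zip(gradients, trainable_variables))
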
@@ -40,6 +40,7 @@ py_library(
         "//tensorflow/python:variable_scope",
         "//tensorflow/python:variables",
         "//tensorflow/python/distribute:distribute_lib",
+        "//tensorflow/python/distribute:parameter_server_strategy",
         "//tensorflow/python/distribute:reduce_util",
         "//tensorflow/python/distribute:values",
         "//tensorflow/python/keras:backend",

@@ -26,6 +26,7 @@ import functools
 import six
 
 from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx
+from tensorflow.python.distribute import parameter_server_strategy
 from tensorflow.python.distribute import reduce_util as ds_reduce_util
 from tensorflow.python.distribute import values as ds_values
 from tensorflow.python.eager import backprop
@@ -491,6 +492,14 @@ class OptimizerV2(trackable.Trackable):
           "Use `tf.distribute.Strategy.run` to enter replica "
           "context.")
 
+    strategy = distribute_ctx.get_strategy()
+    if (not experimental_aggregate_gradients and strategy and isinstance(
+        strategy.extended,
+        parameter_server_strategy.ParameterServerStrategyExtended)):
+      raise NotImplementedError(
+          "`experimental_aggregate_gradients=False` is not supported for "
+          "ParameterServerStrategy and CentralStorageStrategy")
+
     apply_state = self._prepare(var_list)
     if experimental_aggregate_gradients:
       reduced_grads = self._aggregate_gradients(grads_and_vars)
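For callers, the practical consequence of the new check is that skipping aggregation is no longer allowed under these strategies, while the default aggregation path still works. A brief follow-up sketch, reusing strategy, v, optimizer, and grads from the first example above; leaving experimental_aggregate_gradients at its default of True lets the optimizer aggregate the gradients itself, so no error is raised:

def default_aggregation_step(grads):
  # Default path: experimental_aggregate_gradients=True, so the optimizer
  # performs the aggregation and the new NotImplementedError is not hit.
  optimizer.apply_gradients([(grads, v)])


strategy.run(default_aggregation_step, args=(grads,))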