Fix a bug where the optimizer's clipvalue and clipnorm settings were silently ignored by the Keras training loop. Also raise an error when clipvalue or clipnorm is set together with a distribution strategy, because global gradient clipping with a distribution strategy is not yet supported.
We are reworking the optimizers to make this possible in a different way.
PiperOrigin-RevId: 294575398
Change-Id: I3d1bb69857d4ced857928e7dc83729c315ed00f6
commit 69da929ad4 (parent 01cd184620)
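For illustration, a minimal sketch of the user-facing behaviour this change targets, written against the public tf.keras API rather than the internal modules touched below; the model and data are placeholders. Clipping arguments passed to the optimizer are honoured again during fit(), and creating an optimizer with clipnorm or clipvalue inside a distribution strategy scope is expected to raise a ValueError.

import numpy as np
import tensorflow as tf

# clipvalue/clipnorm set on the optimizer are applied by the training loop.
model = tf.keras.Sequential([tf.keras.layers.Dense(2, activation='softmax')])
model.compile(loss='sparse_categorical_crossentropy',
              optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.001,
                                                    clipvalue=0.5))
model.fit(np.ones((16, 4)), np.zeros((16, 1)), epochs=1, verbose=0)

# Inside a strategy scope, setting clipnorm/clipvalue is rejected up front.
with tf.distribute.MirroredStrategy().scope():
    try:
        tf.keras.optimizers.RMSprop(learning_rate=0.001, clipnorm=1.0)
    except ValueError as e:
        print(e)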
@@ -274,6 +274,7 @@ def _process_single_batch(model,
         if isinstance(model.optimizer,
                       loss_scale_optimizer.LossScaleOptimizer):
           grads = model.optimizer.get_unscaled_gradients(grads)
+        grads = model.optimizer._clip_gradients(grads)
         model.optimizer.apply_gradients(zip(grads, trainable_weights))
     else:
       logging.warning('The list of trainable weights is empty. Make sure that'
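The hunk above inserts the clipping step into the eager training path. Roughly, the order of operations is now: compute gradients, unscale them if a LossScaleOptimizer is in use, clip them via the optimizer's clipnorm/clipvalue, then apply them. A simplified sketch of that ordering (the function name and the hasattr check are illustrative, not the exact Keras internals):

def train_step(optimizer, tape, loss, trainable_weights):
    # Illustrative ordering only: gradient -> unscale -> clip -> apply.
    grads = tape.gradient(loss, trainable_weights)
    if hasattr(optimizer, 'get_unscaled_gradients'):
        # LossScaleOptimizer case: undo the loss scaling before clipping.
        grads = optimizer.get_unscaled_gradients(grads)
    # New in this change: clipping is applied by the training loop.
    grads = optimizer._clip_gradients(grads)
    optimizer.apply_gradients(zip(grads, trainable_weights))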
@@ -237,7 +237,12 @@ class CorrectnessTest(keras_parameterized.TestCase):
 
   @keras_parameterized.run_with_all_model_types
   @keras_parameterized.run_all_keras_modes
-  def test_loss_correctness(self):
+  @parameterized.named_parameters([
+      ('', dict()),
+      ('_clipvalue_inf', {'clipvalue': 999999}),
+      ('_clipnorm_inf', {'clipnorm': 999999}),
+  ])
+  def test_loss_correctness(self, optimizer_kwargs):
     # Test that training loss is the same in eager and graph
     # (by comparing it to a reference value in a deterministic case)
     layers = [
@@ -247,7 +252,7 @@ class CorrectnessTest(keras_parameterized.TestCase):
     model = testing_utils.get_model_from_layers(layers, input_shape=(4,))
     model.compile(
         loss='sparse_categorical_crossentropy',
-        optimizer=rmsprop.RMSprop(learning_rate=0.001),
+        optimizer=rmsprop.RMSprop(learning_rate=0.001, **optimizer_kwargs),
         run_eagerly=testing_utils.should_run_eagerly(),
         experimental_run_tf_function=testing_utils.should_run_tf_function())
     x = np.ones((100, 4))
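The parameterized cases above use thresholds large enough (999999) that clipping becomes a no-op, so the same reference loss value should hold for all three variants. A quick sanity check of that assumption with the public clip ops (the sample values are arbitrary):

import tensorflow as tf

g = tf.constant([0.3, -1.2, 4.5])
print(tf.clip_by_value(g, -999999.0, 999999.0).numpy())  # unchanged
print(tf.clip_by_norm(g, 999999.0).numpy())              # unchanged up to rounding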
@@ -256,6 +261,30 @@ class CorrectnessTest(keras_parameterized.TestCase):
     history = model.fit(x, y, epochs=1, batch_size=10)
     self.assertAlmostEqual(history.history['loss'][-1], 0.5836, 4)
 
+  @keras_parameterized.run_with_all_model_types
+  @keras_parameterized.run_all_keras_modes
+  def test_loss_correctness_clipvalue_zero(self):
+    # Test that training loss is the same in eager and graph
+    # (by comparing it to a reference value in a deterministic case)
+    # And confirm that setting clipvalue to zero stops all training
+    layers = [
+        keras.layers.Dense(3, activation='relu',
+                           kernel_initializer='ones'),
+        keras.layers.Dense(2, activation='softmax', kernel_initializer='ones')]
+    model = testing_utils.get_model_from_layers(layers, input_shape=(4,))
+    model.compile(
+        loss='sparse_categorical_crossentropy',
+        optimizer=rmsprop.RMSprop(learning_rate=0.001, clipvalue=0.0),
+        run_eagerly=testing_utils.should_run_eagerly(),
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
+    x = np.ones((100, 4))
+    np.random.seed(123)
+    y = np.random.randint(0, 1, size=(100, 1))
+    history = model.fit(x, y, epochs=3, batch_size=10)
+    self.assertAlmostEqual(history.history['loss'][-3], 0.6931, 4)
+    self.assertAlmostEqual(history.history['loss'][-2], 0.6931, 4)
+    self.assertAlmostEqual(history.history['loss'][-1], 0.6931, 4)
+
   @keras_parameterized.run_with_all_model_types
   @keras_parameterized.run_all_keras_modes
   def test_loss_correctness_with_iterator(self):
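Why the clipvalue-zero test expects a constant loss of 0.6931: with clipvalue=0.0 every gradient is clamped into [-0.0, 0.0], so apply_gradients never moves the weights; and with 'ones' initializers the two softmax logits stay equal, so the model predicts a uniform two-class distribution whose cross-entropy is ln(2) ≈ 0.6931. A short sketch of both facts:

import numpy as np
import tensorflow as tf

# All gradients collapse to zero under clipvalue=0.0.
g = tf.constant([0.7, -2.3, 0.01])
print(tf.clip_by_value(g, -0.0, 0.0).numpy())  # [0. 0. 0.]

# Cross-entropy of a uniform prediction over two classes.
print(np.log(2.0))  # 0.6931...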
@@ -117,16 +117,19 @@ class LossScaleOptimizer(optimizer_v2.OptimizerV2):
     if not isinstance(optimizer, optimizer_v2.OptimizerV2):
       raise ValueError('"optimizer" must be an instance of OptimizerV2, but '
                        'got: %s' % optimizer)
-    if hasattr(optimizer, 'clipnorm'):
+    if optimizer.clipnorm is not None:
       raise ValueError('LossScaleOptimizer does not support wrapping '
                        'optimizers with a clipnorm. Optimizer %s has clipnorm '
                        '%s' % (optimizer, optimizer.clipnorm))
 
-    if hasattr(optimizer, 'clipvalue'):
+    if optimizer.clipvalue is not None:
       raise ValueError('LossScaleOptimizer does not support wrapping '
                        'optimizers with a clipvalue. Optimizer %s has '
                        'clipvalue %s' % (optimizer, optimizer.clipvalue))
 
+    self.clipnorm = None
+    self.clipvalue = None
+
     self._optimizer = optimizer
     self._loss_scale = keras_loss_scale_module.get(loss_scale)
     if self._loss_scale is None:
@@ -279,10 +279,15 @@ class OptimizerV2(trackable.Trackable):
     if decay < 0.:
       raise ValueError("decay cannot be less than 0: {}".format(decay))
     self._initial_decay = decay
-    if "clipnorm" in kwargs:
-      self.clipnorm = kwargs.pop("clipnorm")
-    if "clipvalue" in kwargs:
-      self.clipvalue = kwargs.pop("clipvalue")
+
+    # Set the gradient clipping properties
+    self.clipnorm = kwargs.pop("clipnorm", None)
+    self.clipvalue = kwargs.pop("clipvalue", None)
+    if ((self.clipnorm is not None or self.clipvalue is not None)
+        and distribute_ctx.has_strategy()):
+      raise ValueError("Gradient clipping in the optimizer "
+                       "(by setting clipnorm or clipvalue) is currently "
+                       "unsupported when using a distribution strategy.")
 
     self._hypers_created = False
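With this change clipnorm and clipvalue are always defined on an OptimizerV2 instance (defaulting to None) instead of only existing when passed, which is why the hasattr checks elsewhere in the diff become 'is not None' checks. A small sketch of the expected attribute behaviour, assuming a build that includes this change:

import tensorflow as tf

opt = tf.keras.optimizers.RMSprop(learning_rate=0.001)
print(opt.clipnorm, opt.clipvalue)  # None None

opt = tf.keras.optimizers.RMSprop(learning_rate=0.001, clipvalue=0.5)
print(opt.clipvalue)  # 0.5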
@@ -317,6 +322,25 @@ class OptimizerV2(trackable.Trackable):
 
     return self.apply_gradients(grads_and_vars, name=name)
 
+  def _clip_gradients(self, grads):
+    """Clip gradients according to the clipnorm and clipvalue attributes."""
+    if self.clipnorm is not None:
+      if distribute_ctx.has_strategy():
+        raise ValueError("Gradient clipping in the optimizer "
+                         "(by setting clipnorm or clipvalue) is currently "
+                         "unsupported when using a distribution strategy.")
+      grads = [clip_ops.clip_by_norm(g, self.clipnorm) for g in grads]
+    if self.clipvalue is not None:
+      if distribute_ctx.has_strategy():
+        raise ValueError("Gradient clipping in the optimizer "
+                         "(by setting clipnorm or clipvalue) is currently "
+                         "unsupported when using a distribution strategy.")
+      grads = [
+          clip_ops.clip_by_value(g, -self.clipvalue, self.clipvalue)
+          for g in grads
+      ]
+    return grads
+
   def _compute_gradients(self, loss, var_list, grad_loss=None):
     """Compute gradients of `loss` for the variables in `var_list`.
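The new _clip_gradients helper clips each gradient tensor independently: clip_by_norm rescales any tensor whose own L2 norm exceeds clipnorm, and clip_by_value clamps elements into [-clipvalue, clipvalue]. This is per-tensor clipping, not global-norm clipping across all gradients, which is part of why it is disallowed under a distribution strategy for now. A standalone sketch of the same logic with public ops (the helper name and arguments here are illustrative):

import tensorflow as tf

def clip_gradients(grads, clipnorm=None, clipvalue=None):
    # Mirrors the helper above, minus the distribution-strategy check.
    if clipnorm is not None:
        grads = [tf.clip_by_norm(g, clipnorm) for g in grads]
    if clipvalue is not None:
        grads = [tf.clip_by_value(g, -clipvalue, clipvalue) for g in grads]
    return grads

grads = [tf.constant([3.0, -4.0]), tf.constant([0.2, 0.1])]
print([g.numpy() for g in clip_gradients(grads, clipnorm=1.0)])   # norms capped at 1
print([g.numpy() for g in clip_gradients(grads, clipvalue=0.5)])  # elements in [-0.5, 0.5]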
@@ -353,14 +377,7 @@ class OptimizerV2(trackable.Trackable):
     var_list = nest.flatten(var_list)
     with backend.name_scope(self._name + "/gradients"):
       grads = tape.gradient(loss_value, var_list, grad_loss)
-
-      if hasattr(self, "clipnorm"):
-        grads = [clip_ops.clip_by_norm(g, self.clipnorm) for g in grads]
-      if hasattr(self, "clipvalue"):
-        grads = [
-            clip_ops.clip_by_value(g, -self.clipvalue, self.clipvalue)
-            for g in grads
-        ]
+      grads = self._clip_gradients(grads)
 
     grads_and_vars = list(zip(grads, var_list))
     self._assert_valid_dtypes([
@@ -395,13 +412,7 @@ class OptimizerV2(trackable.Trackable):
                           "gradient defined (i.e. are differentiable). "
                           "Common ops without gradient: "
                           "K.argmax, K.round, K.eval.".format(param))
-    if hasattr(self, "clipnorm"):
-      grads = [clip_ops.clip_by_norm(g, self.clipnorm) for g in grads]
-    if hasattr(self, "clipvalue"):
-      grads = [
-          clip_ops.clip_by_value(g, -self.clipvalue, self.clipvalue)
-          for g in grads
-      ]
+    grads = self._clip_gradients(grads)
     return grads
 
   def apply_gradients(self, grads_and_vars, name=None):
@@ -704,9 +715,9 @@ class OptimizerV2(trackable.Trackable):
       Python dictionary.
     """
     config = {"name": self._name}
-    if hasattr(self, "clipnorm"):
+    if self.clipnorm is not None:
       config["clipnorm"] = self.clipnorm
-    if hasattr(self, "clipvalue"):
+    if self.clipvalue is not None:
       config["clipvalue"] = self.clipvalue
     return config
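With clipnorm and clipvalue now always present as attributes, get_config switches from hasattr to 'is not None' so that unset options stay out of the serialized config while explicitly set ones round-trip. A sketch of the expected serialization behaviour, again assuming a build with this change:

import tensorflow as tf

opt = tf.keras.optimizers.RMSprop(learning_rate=0.001, clipnorm=1.0)
config = opt.get_config()
print(config.get('clipnorm'), 'clipvalue' in config)  # 1.0 False

restored = tf.keras.optimizers.RMSprop.from_config(config)
print(restored.clipnorm)  # 1.0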
@@ -721,6 +721,11 @@ class TFOptimizer(Optimizer, trackable.Trackable):
     self.iterations = iterations
     self._track_trackable(self.iterations, name='global_step')
 
+  def _clip_gradients(self, grads):
+    """Clip gradients according to the clipnorm and clipvalue attributes."""
+    # TFOptimizer wrapper has no gradient clipping options.
+    return grads
+
   def apply_gradients(self, grads):
     self.optimizer.apply_gradients(grads, global_step=self.iterations)