diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD
index 5dffc4c9f7d..4e635478aa7 100755
--- a/tensorflow/python/keras/BUILD
+++ b/tensorflow/python/keras/BUILD
@@ -197,6 +197,7 @@ py_library(
         "//tensorflow/python/eager:monitoring",
         "//tensorflow/python/keras/distribute",
         "//tensorflow/python/keras/mixed_precision/experimental:autocast_variable",
+        "//tensorflow/python/keras/mixed_precision/experimental:loss_scale_optimizer",
         "//tensorflow/python/keras/mixed_precision/experimental:policy",
         "//tensorflow/python/module",
         "//tensorflow/python/training/tracking:data_structures",
diff --git a/tensorflow/python/keras/engine/training_eager.py b/tensorflow/python/keras/engine/training_eager.py
index c21758d7d0d..b2fbdffc029 100644
--- a/tensorflow/python/keras/engine/training_eager.py
+++ b/tensorflow/python/keras/engine/training_eager.py
@@ -28,6 +28,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.keras import backend
 from tensorflow.python.keras.engine import training_utils
+from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer
 from tensorflow.python.keras.utils import losses_utils
 from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import tf_logging as logging
@@ -234,13 +235,24 @@ def _process_single_batch(model,
       if total_loss is None:
         raise ValueError('The model cannot be run '
                          'because it has no loss to optimize.')
+      if isinstance(model.optimizer, loss_scale_optimizer.LossScaleOptimizer):
+        # TODO(reedwm): Make loss_scale public instead of accessing private
+        # _loss_scale attribute.
+        loss_scale = model.optimizer._loss_scale()
+        scaled_total_loss = loss_scale_optimizer.scale_loss(total_loss,
+                                                            loss_scale)
+      else:
+        loss_scale = None
+        scaled_total_loss = total_loss
     if training:
       if not model.trainable_weights:
         logging.warning('The list of trainable weights is empty. Make sure that'
                         ' you are not setting model.trainable to False before '
                         'compiling the model.')
       else:
-        grads = tape.gradient(total_loss, model.trainable_weights)
+        grads = tape.gradient(scaled_total_loss, model.trainable_weights)
+        if loss_scale is not None:
+          grads = loss_scale_optimizer.unscale_grads(grads, loss_scale)
         model.optimizer.apply_gradients(zip(grads,
                                             model.trainable_weights))
   return outs, total_loss, output_losses, masks
diff --git a/tensorflow/python/keras/mixed_precision/experimental/keras_test.py b/tensorflow/python/keras/mixed_precision/experimental/keras_test.py
index d4be5b79157..9c2d76e8efa 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/keras_test.py
+++ b/tensorflow/python/keras/mixed_precision/experimental/keras_test.py
@@ -293,9 +293,14 @@ class KerasModelTest(test.TestCase, parameterized.TestCase):
       'testcase_name': 'regularizer',
       'strategy_fn': create_mirrored_strategy,
       'use_regularizer': True
+  }, {
+      'testcase_name': 'nocloning',
+      'strategy_fn': create_mirrored_strategy,
+      'cloning': False
   })
   @test_util.run_in_graph_and_eager_modes
-  def test_model(self, strategy_fn, use_operator=False, use_regularizer=False):
+  def test_model(self, strategy_fn, use_operator=False, use_regularizer=False,
+                 cloning=True):
     regularizer = IdentityRegularizer() if use_regularizer else None
     with strategy_fn().scope():
       with policy.policy_scope('infer_float32_vars'):
@@ -314,7 +319,7 @@ class KerasModelTest(test.TestCase, parameterized.TestCase):
         # the variable will not change. So this tests the learning rate not
         # applied to a float16 value, but instead the float32 variable.
         opt = gradient_descent.SGD(2 ** -14)
-        model.compile(opt, loss=loss_fn)
+        model.compile(opt, loss=loss_fn, cloning=cloning)
 
         self.assertEqual(backend.eval(layer.v), 1)
         x = np.ones((2, 1))
@@ -329,6 +334,53 @@ class KerasModelTest(test.TestCase, parameterized.TestCase):
           expected -= 2 ** -14
         self.assertEqual(backend.eval(layer.v), expected)
 
+  @parameterized.named_parameters({
+      'testcase_name': 'base',
+      'strategy_fn': default_strategy_fn
+  }, {
+      'testcase_name': 'distribute',
+      'strategy_fn': create_mirrored_strategy,
+  }, {
+      'testcase_name': 'nocloning',
+      'strategy_fn': create_mirrored_strategy,
+      'cloning': False,
+  })
+  @test_util.run_in_graph_and_eager_modes
+  def test_fixed_loss_scaling(self, strategy_fn, cloning=True):
+    # Note: We do not test mixed precision in this method, only loss scaling.
+    loss_scale = 8.
+    batch_size = 4
+    with strategy_fn().scope():
+      x = layers.Input(shape=(1,), batch_size=batch_size)
+      layer = AddLayer()
+      y = layer(x)
+
+      # The gradient of 'y' at this point is 1. With loss scaling, the gradient
+      # is 'loss_scale'. We divide by the batch size since the loss is averaged
+      # across batch elements.
+      expected_gradient = loss_scale / batch_size
+      identity_with_grad_check_fn = (
+          mp_test_util.create_identity_with_grad_check_fn([expected_gradient]))
+      y = core.Lambda(identity_with_grad_check_fn)(y)
+      model = models.Model(inputs=x, outputs=y)
+
+      def loss_fn(y_true, y_pred):
+        del y_true
+        return math_ops.reduce_mean(y_pred)
+
+      opt = gradient_descent.SGD(1.)
+      opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale)
+      model.compile(opt, loss=loss_fn, cloning=cloning)
+
+    self.assertEqual(backend.eval(layer.v), 1)
+    x = np.ones((batch_size, 1))
+    y = np.ones((batch_size, 1))
+    dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).batch(batch_size)
+    model.fit(dataset)
+    # Variable starts at 1, and should have gradient of 1 subtracted from it.
+    expected = 0
+    self.assertEqual(backend.eval(layer.v), expected)
+
   @parameterized.named_parameters({
       'testcase_name': 'base',
       'strategy_fn': default_strategy_fn
@@ -413,9 +465,13 @@ class KerasModelTest(test.TestCase, parameterized.TestCase):
   }, {
       'testcase_name': 'distribute',
       'strategy_fn': create_mirrored_strategy,
+  }, {
+      'testcase_name': 'nocloning',
+      'strategy_fn': create_mirrored_strategy,
+      'cloning': False,
   })
   @test_util.run_in_graph_and_eager_modes
-  def test_dynamic_loss_scaling(self, strategy_fn):
+  def test_dynamic_loss_scaling(self, strategy_fn, cloning=True):
     strategy = strategy_fn()
     initial_loss_scale = 2.
     batch_size = 4
@@ -449,12 +505,12 @@ class KerasModelTest(test.TestCase, parameterized.TestCase):
       loss_scale = loss_scale_module.DynamicLossScale(
           initial_loss_scale=initial_loss_scale, increment_period=2)
       opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale)
-      model.compile(opt, loss=loss_fn)
+      model.compile(opt, loss=loss_fn, cloning=cloning)
 
       self.assertEqual(backend.eval(layer.v), 1)
-      x = np.ones((2, 1))
-      y = np.ones((2, 1))
-      dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).batch(2)
+      x = np.ones((batch_size, 1))
+      y = np.ones((batch_size, 1))
+      dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).batch(batch_size)
       model.fit(dataset)
       # The variables starts with 1 and has a gradient of 1, so will go down by 1
      # each step.
diff --git a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py b/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py
index 2963c4294b8..1f201221a3c 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py
+++ b/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py
@@ -41,6 +41,20 @@ class _UnwrapPreventer(object):
     self.value = value
 
 
+def scale_loss(loss, loss_scale):
+  """Scales the loss by the loss scale."""
+  if callable(loss):
+    return lambda: loss() * loss_scale
+  else:
+    return loss * loss_scale
+
+
+def unscale_grads(grads, loss_scale):
+  """Unscales the gradients by the loss scale."""
+  loss_scale_reciprocal = 1. / loss_scale
+  return [g * loss_scale_reciprocal if g is not None else None for g in grads]
+
+
 @keras_export('keras.mixed_precision.experimental.LossScaleOptimizer')
 class LossScaleOptimizer(optimizer_v2.OptimizerV2):
   """An optimizer that applies loss scaling.
@@ -101,31 +115,18 @@ class LossScaleOptimizer(optimizer_v2.OptimizerV2):
     self._track_trackable(self._loss_scale, 'loss_scale')
 
   def _compute_gradients(self, loss, var_list, grad_loss=None):
-    loss = self._scale_loss(loss)
+    loss = scale_loss(loss, self._loss_scale())
     grads_and_vars = self._optimizer._compute_gradients(loss, var_list,  # pylint: disable=protected-access
                                                         grad_loss)
     grads = [g for g, _ in grads_and_vars]
     variables = [v for _, v in grads_and_vars]
-    scaled_grads = self._scale_grads(grads)
-    return list(zip(scaled_grads, variables))
+    unscaled_grads = unscale_grads(grads, self._loss_scale())
+    return list(zip(unscaled_grads, variables))
 
   def get_gradients(self, loss, params):
-    loss = self._scale_loss(loss)
+    loss = scale_loss(loss, self._loss_scale())
     grads = self._optimizer.get_gradients(loss, params)
-    return self._scale_grads(grads)
-
-  def _scale_loss(self, loss):
-    # The loss is callable for `_compute_gradients`, but not `get_gradients`.
-    loss_scale = self._loss_scale()
-    if callable(loss):
-      return lambda: loss() * loss_scale
-    else:
-      return loss * loss_scale
-
-  def _scale_grads(self, grads):
-    loss_scale = self._loss_scale()
-    loss_scale_reciprocal = 1 / loss_scale
-    return [None if g is None else g * loss_scale_reciprocal for g in grads]
+    return unscale_grads(grads, self._loss_scale())
 
   def apply_gradients(self, grads_and_vars, name=None):
     if distribution_strategy_context.in_cross_replica_context():
diff --git a/tensorflow/python/training/experimental/loss_scale_optimizer.py b/tensorflow/python/training/experimental/loss_scale_optimizer.py
index dd8f27f0c5f..eaa3f022bee 100644
--- a/tensorflow/python/training/experimental/loss_scale_optimizer.py
+++ b/tensorflow/python/training/experimental/loss_scale_optimizer.py
@@ -119,8 +119,8 @@ class MixedPrecisionLossScaleOptimizer(optimizer.Optimizer):
 
     grads = [g for g, _ in grads_and_vars]
     variables = [v for _, v in grads_and_vars]
-    scaled_grads = self._scale_grads(grads)
-    return list(zip(scaled_grads, variables))
+    unscaled_grads = self._unscale_grads(grads)
+    return list(zip(unscaled_grads, variables))
 
   def _scale_loss(self, loss):
     loss_scale = self._loss_scale()
@@ -128,7 +128,7 @@ class MixedPrecisionLossScaleOptimizer(optimizer.Optimizer):
       return lambda: loss() * loss_scale
     return loss * loss_scale
 
-  def _scale_grads(self, grads):
+  def _unscale_grads(self, grads):
     loss_scale = self._loss_scale()
     loss_scale_reciprical = 1 / loss_scale
     return [
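
Editorial note (not part of the patch): the sketch below illustrates what the new module-level scale_loss/unscale_grads helpers do, and how training_eager.py is expected to use them around tape.gradient. Only the two helper bodies mirror the code added above; the toy loss, variable, and numbers are hypothetical stand-ins, not code from this change.

# Standalone sketch, assuming TF 2.x eager execution. The helpers mirror those
# added to keras/mixed_precision/experimental/loss_scale_optimizer.py; the
# usage below is a simplified stand-in for _process_single_batch.
import tensorflow as tf


def scale_loss(loss, loss_scale):
  """Multiplies the loss (a tensor, or a callable returning one) by loss_scale."""
  if callable(loss):
    return lambda: loss() * loss_scale
  return loss * loss_scale


def unscale_grads(grads, loss_scale):
  """Divides each gradient by loss_scale, passing None gradients through."""
  reciprocal = 1. / loss_scale
  return [g * reciprocal if g is not None else None for g in grads]


loss_scale = tf.constant(8.)  # stands in for optimizer._loss_scale()
v = tf.Variable(1.)

with tf.GradientTape() as tape:
  loss = 3. * v                                  # toy loss; dloss/dv == 3
  scaled_loss = scale_loss(loss, loss_scale)     # scaling recorded on the tape

scaled_grads = tape.gradient(scaled_loss, [v])   # [24.] == 3 * loss_scale
grads = unscale_grads(scaled_grads, loss_scale)  # back to [3.]

Scaling before tape.gradient and unscaling before apply_gradients keeps small float16 gradients from underflowing while leaving the applied update numerically unchanged, which is why the helpers are lifted to module level so training_eager.py can reuse them.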