From c2926ab744228f2c35a74b562ed633f40bbe27a2 Mon Sep 17 00:00:00 2001
From: Reed Wanderman-Milne <reedwm@google.com>
Date: Wed, 23 Sep 2020 11:35:46 -0700
Subject: [PATCH] Support clipping in LossScaleOptimizer.

This has been made possible by previous refactors to the base Optimizer.

PiperOrigin-RevId: 333336359
Change-Id: Ie31d96ff9b27e3961940c9a221486963be02ced7
---
 .../experimental/loss_scale_optimizer.py      | 36 ++++++++++-----
 .../experimental/loss_scale_optimizer_test.py | 46 +++++++++++++++++++
 2 files changed, 70 insertions(+), 12 deletions(-)
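
Not part of the commit, just an editor's note: a minimal usage sketch of the
behavior this patch enables, assuming the public
tf.keras.mixed_precision.experimental.LossScaleOptimizer entry point that
corresponds to the module patched here. Before this change the constructor
rejected optimizers that had clipnorm or clipvalue set; with it, those
attributes (and global_clipnorm) are simply delegated to the wrapped
optimizer:

    import tensorflow as tf

    # An inner optimizer that clips gradients by norm. Previously,
    # LossScaleOptimizer raised a ValueError when wrapping it.
    inner = tf.keras.optimizers.SGD(learning_rate=2.0, clipnorm=2.0)
    opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
        inner, loss_scale='dynamic')

    print(opt.clipnorm)    # 2.0, reads are delegated to the inner optimizer
    opt.clipnorm = 3.0     # writes pass through as well
    print(inner.clipnorm)  # 3.0

Delegating through properties, rather than copying the values once at
construction, keeps the wrapper and the inner optimizer in sync when the clip
amount is changed after construction, which is what the new testClipping case
below exercises.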

diff --git a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py b/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py
index c309fafa4a4..eb31c647ca3 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py
+++ b/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py
@@ -248,20 +248,8 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2):
     if not isinstance(optimizer, optimizer_v2.OptimizerV2):
       raise ValueError('"optimizer" must be an instance of OptimizerV2, but '
                        'got: %s' % optimizer)
-    if optimizer.clipnorm is not None:
-      raise ValueError('LossScaleOptimizer does not support wrapping '
-                       'optimizers with a clipnorm. Optimizer %s has clipnorm '
-                       '%s' % (optimizer, optimizer.clipnorm))
-
-    if optimizer.clipvalue is not None:
-      raise ValueError('LossScaleOptimizer does not support wrapping '
-                       'optimizers with a clipvalue. Optimizer %s has '
-                       'clipvalue %s' % (optimizer, optimizer.clipvalue))
     self._raise_if_strategy_unsupported()
 
-    self.clipnorm = None
-    self.clipvalue = None
-
     self._optimizer = optimizer
     self._loss_scale = keras_loss_scale_module.get(loss_scale)
     if self._loss_scale is None:
@@ -486,6 +474,30 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2):
   def set_weights(self, weights):
     return self._optimizer.set_weights(weights)
 
+  @property
+  def clipnorm(self):
+    return self._optimizer.clipnorm
+
+  @clipnorm.setter
+  def clipnorm(self, val):
+    self._optimizer.clipnorm = val
+
+  @property
+  def global_clipnorm(self):
+    return self._optimizer.global_clipnorm
+
+  @global_clipnorm.setter
+  def global_clipnorm(self, val):
+    self._optimizer.global_clipnorm = val
+
+  @property
+  def clipvalue(self):
+    return self._optimizer.clipvalue
+
+  @clipvalue.setter
+  def clipvalue(self, val):
+    self._optimizer.clipvalue = val
+
   def _aggregate_gradients(self, grads_and_vars):
     return self._optimizer._aggregate_gradients(grads_and_vars)  # pylint: disable=protected-access
 
diff --git a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer_test.py b/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer_test.py
index 35c25c1da8c..e375fcf557b 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer_test.py
+++ b/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer_test.py
@@ -189,6 +189,52 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
       # 1.
       self.assertAllClose([1.], self.evaluate(var))
 
+  # pylint: disable=cell-var-from-loop
+  @parameterized.named_parameters(*TESTCASES)
+  def testClipping(self, strategy_fn):
+    strategy = strategy_fn()
+    learning_rate = 2.
+    for clip_type in ('clipnorm', 'global_clipnorm', 'clipvalue'):
+      with strategy.scope(), self.subTest(clip_type=clip_type):
+        var = variables.Variable([5.0])
+        opt = gradient_descent.SGD(learning_rate, **{clip_type: 2.0})
+        loss_scale = loss_scale_module.DynamicLossScale(
+            initial_loss_scale=2, increment_period=1, multiplier=2)
+        opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale)
+        self.assertEqual(getattr(opt, clip_type), 2.0)
+        self.assertEqual(
+            loss_scale.initial_loss_scale % strategy.num_replicas_in_sync, 0)
+
+        loss = lambda: var * 4 / strategy.num_replicas_in_sync
+        run_fn = lambda: opt.minimize(loss, var_list=[var])
+
+        # Test running with clipped gradients
+        run_op = strategy.experimental_run(run_fn)
+        self.evaluate(variables.global_variables_initializer())
+        self._run_if_in_graph_mode(run_op)
+        # The gradient is 4 but is clipped to 2, so the variable will be
+        # init_val - clipped_grad * lr == 5 - 2 * 2 == 1
+        self.assertAllClose([1.], self.evaluate(var))
+        self.assertEqual(self.evaluate(opt.loss_scale()), 4)
+
+        # Test changing the clip amount and running again
+        setattr(opt, clip_type, 3.0)
+        run_op = strategy.experimental_run(run_fn)
+        self._run_if_in_graph_mode(run_op)
+        # The gradient is 4 but is clipped to 3, so the variable will be
+        # prev_var - clipped_grad * lr == 1 - 3 * 2 == -5
+        self.assertAllClose([-5.], self.evaluate(var))
+        self.assertEqual(self.evaluate(opt.loss_scale()), 8)
+
+        # Test Inf gradients are still skipped instead of being clipped
+        loss = lambda: var * float('Inf')
+        run_fn = lambda: opt.minimize(loss, var_list=[var])
+        run_op = strategy.experimental_run(run_fn)
+        self._run_if_in_graph_mode(run_op)
+        self.assertAllClose([-5.], self.evaluate(var))  # Var does not change
+        self.assertEqual(self.evaluate(opt.loss_scale()), 4)
+  # pylint: enable=cell-var-from-loop
+
   @parameterized.named_parameters(*TESTCASES)
   def testDynamicUpdate(self, strategy_fn):
     with strategy_fn().scope() as strategy: