diff --git a/tensorflow/python/training/experimental/loss_scale_optimizer_test.py b/tensorflow/python/training/experimental/loss_scale_optimizer_test.py
index ef82696c997..7e4e6983905 100644
--- a/tensorflow/python/training/experimental/loss_scale_optimizer_test.py
+++ b/tensorflow/python/training/experimental/loss_scale_optimizer_test.py
@@ -25,8 +25,11 @@ from absl.testing import parameterized
 from tensorflow.python.distribute import distribution_strategy_context
 from tensorflow.python.distribute import mirrored_strategy
 from tensorflow.python.eager import context
+from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
-from tensorflow.python.keras.mixed_precision.experimental import test_util as mp_test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import custom_gradient
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
@@ -63,6 +66,59 @@ def get_gradients(opt, loss, params):
   return grads
 
 
+def create_identity_with_grad_check_fn(expected_gradient, expected_dtype=None):
+  """Returns a function that asserts it's gradient has a certain value.
+
+  This serves as a hook to assert intermediate gradients have a certain value.
+  This returns an identity function. The identity's gradient function is also
+  the identity function, except it asserts that the gradient equals
+  `expected_gradient` and has dtype `expected_dtype`.
+
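+  A minimal usage sketch (`expected_grad` and `var` are placeholders for the
+  expected gradient value and the variable being optimized):
+
+    grad_check_fn = create_identity_with_grad_check_fn(expected_grad)
+    loss = grad_check_fn(var)
+    # Computing gradients of `loss` runs an assert_equal op checking that
+    # the gradient flowing back into `grad_check_fn` equals `expected_grad`.
+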
+  Args:
+    expected_gradient: The gradient function asserts that the gradient is this
+      value.
+    expected_dtype: The gradient function asserts the gradient has this dtype.
+
+  Returns:
+    An identity function whose gradient function asserts the gradient has a
+    certain value.
+  """
+  @custom_gradient.custom_gradient
+  def _identity_with_grad_check(x):
+    """Function that asserts it's gradient has a certain value."""
+    x = array_ops.identity(x)
+    def grad(dx):
+      """Gradient function that asserts the gradient has a certain value."""
+      if expected_dtype:
+        assert dx.dtype == expected_dtype, (
+            'dx.dtype should be %s but is: %s' % (expected_dtype, dx.dtype))
+      expected_tensor = ops.convert_to_tensor_v2(
+          expected_gradient, dtype=dx.dtype, name='expected_gradient')
+      # Control dependency is to ensure input is available. It's possible the
+      # dataset will throw a StopIteration to indicate there is no more data,
+      # in which case we don't want to run the assertion.
+      with ops.control_dependencies([x]):
+        assert_op = check_ops.assert_equal(dx, expected_tensor)
+      with ops.control_dependencies([assert_op]):
+        dx = array_ops.identity(dx)
+      return dx
+    return x, grad
+  # Keras sometimes has trouble serializing Lambda layers with a decorated
+  # function. So we define and return a non-decorated function.
+  def identity_with_grad_check(x):
+    return _identity_with_grad_check(x)
+  return identity_with_grad_check
+
+
 class MixedPrecisionLossScaleOptimizerTest(test.TestCase,
                                            parameterized.TestCase):
 
@@ -75,7 +131,7 @@ class MixedPrecisionLossScaleOptimizerTest(test.TestCase,
       self.evaluate(val)
 
   def _run_fn_with_grad_check(self, strategy, var, opt, expected_grad):
-    grad_check_fn = mp_test_util.create_identity_with_grad_check_fn(
+    grad_check_fn = create_identity_with_grad_check_fn(
         expected_grad)
     loss = lambda: grad_check_fn(var) / strategy.num_replicas_in_sync
     return lambda: opt.minimize(loss, var_list=[var])
@@ -108,12 +164,12 @@ class MixedPrecisionLossScaleOptimizerTest(test.TestCase,
     opt = gradient_descent.GradientDescentOptimizer(1.0)
     loss_scale = 10.
     opt = loss_scale_optimizer.MixedPrecisionLossScaleOptimizer(opt, loss_scale)
-    grad_check_fn = mp_test_util.create_identity_with_grad_check_fn(loss_scale)
+    grad_check_fn = create_identity_with_grad_check_fn(loss_scale)
     loss = grad_check_fn(var)
     run_op = get_gradients(opt, loss, [var])
     self.evaluate(variables.global_variables_initializer())
     # This will cause an assertion to run, as
-    # mp_test_util.create_identity_with_grad_check_fn added an assertion op.
+    # create_identity_with_grad_check_fn added an assertion op.
     self.evaluate(run_op)
 
   @parameterized.named_parameters(*TESTCASES)