diff --git a/tensorflow/python/training/experimental/loss_scale_optimizer_test.py b/tensorflow/python/training/experimental/loss_scale_optimizer_test.py
index ef82696c997..7e4e6983905 100644
--- a/tensorflow/python/training/experimental/loss_scale_optimizer_test.py
+++ b/tensorflow/python/training/experimental/loss_scale_optimizer_test.py
@@ -25,8 +25,11 @@ from absl.testing import parameterized
 from tensorflow.python.distribute import distribution_strategy_context
 from tensorflow.python.distribute import mirrored_strategy
 from tensorflow.python.eager import context
+from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
-from tensorflow.python.keras.mixed_precision.experimental import test_util as mp_test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import custom_gradient
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
@@ -63,6 +66,50 @@ def get_gradients(opt, loss, params):
   return grads
 
 
+def create_identity_with_grad_check_fn(expected_gradient, expected_dtype=None):
+  """Returns a function that asserts its gradient has a certain value.
+
+  This serves as a hook to assert intermediate gradients have a certain value.
+  This returns an identity function. The identity's gradient function is also
+  the identity function, except it asserts that the gradient equals
+  `expected_gradient` and has dtype `expected_dtype`.
+
+  Args:
+    expected_gradient: The gradient function asserts that the gradient is this
+      value.
+    expected_dtype: The gradient function asserts the gradient has this dtype.
+
+  Returns:
+    An identity function whose gradient function asserts the gradient has a
+    certain value.
+  """
+  @custom_gradient.custom_gradient
+  def _identity_with_grad_check(x):
+    """Function that asserts its gradient has a certain value."""
+    x = array_ops.identity(x)
+    def grad(dx):
+      """Gradient function that asserts the gradient has a certain value."""
+      if expected_dtype:
+        assert dx.dtype == expected_dtype, (
+            'dx.dtype should be %s but is: %s' % (expected_dtype, dx.dtype))
+      expected_tensor = ops.convert_to_tensor_v2(
+          expected_gradient, dtype=dx.dtype, name='expected_gradient')
+      # Control dependency is to ensure input is available. It's possible the
+      # dataset will throw a StopIteration to indicate there is no more data,
+      # in which case we don't want to run the assertion.
+      with ops.control_dependencies([x]):
+        assert_op = check_ops.assert_equal(dx, expected_tensor)
+      with ops.control_dependencies([assert_op]):
+        dx = array_ops.identity(dx)
+      return dx
+    return x, grad
+  # Keras sometimes has trouble serializing Lambda layers with a decorated
+  # function. So we define and return a non-decorated function.
+  def identity_with_grad_check(x):
+    return _identity_with_grad_check(x)
+  return identity_with_grad_check
+
+
 class MixedPrecisionLossScaleOptimizerTest(test.TestCase,
                                            parameterized.TestCase):
@@ -75,7 +122,7 @@ class MixedPrecisionLossScaleOptimizerTest(test.TestCase,
       self.evaluate(val)
 
   def _run_fn_with_grad_check(self, strategy, var, opt, expected_grad):
-    grad_check_fn = mp_test_util.create_identity_with_grad_check_fn(
+    grad_check_fn = create_identity_with_grad_check_fn(
        expected_grad)
     loss = lambda: grad_check_fn(var) / strategy.num_replicas_in_sync
     return lambda: opt.minimize(loss, var_list=[var])
@@ -108,12 +155,12 @@ class MixedPrecisionLossScaleOptimizerTest(test.TestCase,
     opt = gradient_descent.GradientDescentOptimizer(1.0)
     loss_scale = 10.
     opt = loss_scale_optimizer.MixedPrecisionLossScaleOptimizer(opt, loss_scale)
-    grad_check_fn = mp_test_util.create_identity_with_grad_check_fn(loss_scale)
+    grad_check_fn = create_identity_with_grad_check_fn(loss_scale)
    loss = grad_check_fn(var)
     run_op = get_gradients(opt, loss, [var])
     self.evaluate(variables.global_variables_initializer())
     # This will cause an assertion to run, as
-    # mp_test_util.create_identity_with_grad_check_fn added an assertion op.
+    # create_identity_with_grad_check_fn added an assertion op.
     self.evaluate(run_op)
 
   @parameterized.named_parameters(*TESTCASES)
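Note (not part of the patch): the sketch below illustrates, using the public TensorFlow 2.x eager API rather than the internal modules referenced above, how an identity-with-gradient-check helper like the inlined create_identity_with_grad_check_fn catches an unexpected gradient. The names identity_with_grad_check and check_fn and the scale value 10.0 are illustrative assumptions, not code from this change.

# Minimal sketch, assuming public TF 2.x eager APIs (tf.custom_gradient,
# tf.GradientTape); the patch itself uses TensorFlow-internal modules.
import tensorflow as tf


def identity_with_grad_check(expected_gradient):
  """Returns an identity fn whose incoming gradient must equal `expected_gradient`."""

  @tf.custom_gradient
  def _identity(x):
    def grad(dx):
      # Raise if the incoming gradient differs from the expected (scaled) value.
      tf.debugging.assert_equal(dx, tf.cast(expected_gradient, dx.dtype))
      return dx
    return tf.identity(x), grad

  return _identity


x = tf.constant(3.0)
loss_scale = 10.0  # illustrative value, mirroring the loss_scale used in the test
check_fn = identity_with_grad_check(loss_scale)

with tf.GradientTape() as tape:
  tape.watch(x)
  # Scaling the loss by loss_scale is what a loss-scale optimizer does internally,
  # so the gradient arriving at the identity op should equal loss_scale.
  scaled_loss = check_fn(x) * loss_scale
grad = tape.gradient(scaled_loss, x)  # assertion inside grad() runs here; grad == 10.0

If the loss were not actually scaled (or were unscaled too early), the equality assertion inside grad() would fail, which is the same mechanism the test above relies on when it notes that evaluating run_op "will cause an assertion to run".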