Fork the mp test util function to break the dependency between Keras and TF.

PiperOrigin-RevId: 315955944
Change-Id: I40dae53dd634ad5bdd236bbe0efaa9a260807c54
Author:    Scott Zhu
Date:      2020-06-11 12:39:11 -07:00
Committer: TensorFlower Gardener
Parent:    04f033b847
Commit:    221d47c3b9


@@ -25,8 +25,11 @@ from absl.testing import parameterized
 from tensorflow.python.distribute import distribution_strategy_context
 from tensorflow.python.distribute import mirrored_strategy
 from tensorflow.python.eager import context
+from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
-from tensorflow.python.keras.mixed_precision.experimental import test_util as mp_test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import custom_gradient
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
@@ -63,6 +66,50 @@ def get_gradients(opt, loss, params):
   return grads
 
 
+def create_identity_with_grad_check_fn(expected_gradient, expected_dtype=None):
+  """Returns a function that asserts its gradient has a certain value.
+
+  This serves as a hook to assert intermediate gradients have a certain value.
+  This returns an identity function. The identity's gradient function is also
+  the identity function, except it asserts that the gradient equals
+  `expected_gradient` and has dtype `expected_dtype`.
+
+  Args:
+    expected_gradient: The gradient function asserts that the gradient is this
+      value.
+    expected_dtype: The gradient function asserts the gradient has this dtype.
+
+  Returns:
+    An identity function whose gradient function asserts the gradient has a
+    certain value.
+  """
+
+  @custom_gradient.custom_gradient
+  def _identity_with_grad_check(x):
+    """Function that asserts its gradient has a certain value."""
+    x = array_ops.identity(x)
+
+    def grad(dx):
+      """Gradient function that asserts the gradient has a certain value."""
+      if expected_dtype:
+        assert dx.dtype == expected_dtype, (
+            'dx.dtype should be %s but is: %s' % (expected_dtype, dx.dtype))
+      expected_tensor = ops.convert_to_tensor_v2(
+          expected_gradient, dtype=dx.dtype, name='expected_gradient')
+      # Control dependency is to ensure input is available. It's possible the
+      # dataset will throw a StopIteration to indicate there is no more data, in
+      # which case we don't want to run the assertion.
+      with ops.control_dependencies([x]):
+        assert_op = check_ops.assert_equal(dx, expected_tensor)
+      with ops.control_dependencies([assert_op]):
+        dx = array_ops.identity(dx)
+      return dx
+
+    return x, grad
+
+  # Keras sometimes has trouble serializing Lambda layers with a decorated
+  # function. So we define and return a non-decorated function.
+  def identity_with_grad_check(x):
+    return _identity_with_grad_check(x)
+  return identity_with_grad_check
+
+
 class MixedPrecisionLossScaleOptimizerTest(test.TestCase,
                                            parameterized.TestCase):
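For reference, the following is a minimal sketch of how the forked helper behaves on its own. It assumes TF 2.x eager execution, the public `tf` API, and that the create_identity_with_grad_check_fn defined above is in scope; the constant and scale values are illustrative only, not taken from this test.

import tensorflow as tf

loss_scale = 10.
# The helper returns an identity function whose gradient function asserts that
# the incoming gradient equals `loss_scale`.
grad_check_fn = create_identity_with_grad_check_fn(loss_scale)

x = tf.constant(3.0)
with tf.GradientTape() as tape:
  tape.watch(x)
  # Scaling the loss by `loss_scale` makes the gradient flowing back into the
  # identity equal `loss_scale`, so the assert_equal added in grad() passes.
  loss = loss_scale * grad_check_fn(x)
# Computing the gradient runs the check; a mismatched value would raise an
# InvalidArgumentError from assert_equal.
grad = tape.gradient(loss, x)  # == 10.0

The non-decorated identity_with_grad_check wrapper is what gets returned (rather than the decorated inner function) because, as the comment above notes, Keras sometimes has trouble serializing Lambda layers built from a decorated function.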
@@ -75,7 +122,7 @@ class MixedPrecisionLossScaleOptimizerTest(test.TestCase,
       self.evaluate(val)
 
   def _run_fn_with_grad_check(self, strategy, var, opt, expected_grad):
-    grad_check_fn = mp_test_util.create_identity_with_grad_check_fn(
+    grad_check_fn = create_identity_with_grad_check_fn(
         expected_grad)
     loss = lambda: grad_check_fn(var) / strategy.num_replicas_in_sync
     return lambda: opt.minimize(loss, var_list=[var])
@@ -108,12 +155,12 @@ class MixedPrecisionLossScaleOptimizerTest(test.TestCase,
     opt = gradient_descent.GradientDescentOptimizer(1.0)
     loss_scale = 10.
     opt = loss_scale_optimizer.MixedPrecisionLossScaleOptimizer(opt, loss_scale)
-    grad_check_fn = mp_test_util.create_identity_with_grad_check_fn(loss_scale)
+    grad_check_fn = create_identity_with_grad_check_fn(loss_scale)
     loss = grad_check_fn(var)
     run_op = get_gradients(opt, loss, [var])
     self.evaluate(variables.global_variables_initializer())
     # This will cause an assertion to run, as
-    # mp_test_util.create_identity_with_grad_check_fn added an assertion op.
+    # create_identity_with_grad_check_fn added an assertion op.
     self.evaluate(run_op)
 
   @parameterized.named_parameters(*TESTCASES)