Fork the mp test util function to break the dependency between Keras and TF.
PiperOrigin-RevId: 315955944 Change-Id: I40dae53dd634ad5bdd236bbe0efaa9a260807c54
This commit is contained in:
parent
04f033b847
commit
221d47c3b9
@ -25,8 +25,11 @@ from absl.testing import parameterized
|
||||
from tensorflow.python.distribute import distribution_strategy_context
|
||||
from tensorflow.python.distribute import mirrored_strategy
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.keras.mixed_precision.experimental import test_util as mp_test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import check_ops
|
||||
from tensorflow.python.ops import custom_gradient
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
@ -63,6 +66,50 @@ def get_gradients(opt, loss, params):
|
||||
return grads
|
||||
|
||||
|
||||
def create_identity_with_grad_check_fn(expected_gradient, expected_dtype=None):
|
||||
"""Returns a function that asserts it's gradient has a certain value.
|
||||
|
||||
This serves as a hook to assert intermediate gradients have a certain value.
|
||||
This returns an identity function. The identity's gradient function is also
|
||||
the identity function, except it asserts that the gradient equals
|
||||
`expected_gradient` and has dtype `expected_dtype`.
|
||||
|
||||
Args:
|
||||
expected_gradient: The gradient function asserts that the gradient is this
|
||||
value.
|
||||
expected_dtype: The gradient function asserts the gradient has this dtype.
|
||||
|
||||
Returns:
|
||||
An identity function whose gradient function asserts the gradient has a
|
||||
certain value.
|
||||
"""
|
||||
@custom_gradient.custom_gradient
|
||||
def _identity_with_grad_check(x):
|
||||
"""Function that asserts it's gradient has a certain value."""
|
||||
x = array_ops.identity(x)
|
||||
def grad(dx):
|
||||
"""Gradient function that asserts the gradient has a certain value."""
|
||||
if expected_dtype:
|
||||
assert dx.dtype == expected_dtype, (
|
||||
'dx.dtype should be %s but is: %s' % (expected_dtype, dx.dtype))
|
||||
expected_tensor = ops.convert_to_tensor_v2(
|
||||
expected_gradient, dtype=dx.dtype, name='expected_gradient')
|
||||
# Control dependency is to ensure input is available. It's possible the
|
||||
# dataset will throw a StopIteration to indicate there is no more data, in
|
||||
# which case we don't want to run the assertion.
|
||||
with ops.control_dependencies([x]):
|
||||
assert_op = check_ops.assert_equal(dx, expected_tensor)
|
||||
with ops.control_dependencies([assert_op]):
|
||||
dx = array_ops.identity(dx)
|
||||
return dx
|
||||
return x, grad
|
||||
# Keras sometimes has trouble serializing Lambda layers with a decorated
|
||||
# function. So we define and return a non-decorated function.
|
||||
def identity_with_grad_check(x):
|
||||
return _identity_with_grad_check(x)
|
||||
return identity_with_grad_check
|
||||
|
||||
|
||||
class MixedPrecisionLossScaleOptimizerTest(test.TestCase,
|
||||
parameterized.TestCase):
|
||||
|
||||
@ -75,7 +122,7 @@ class MixedPrecisionLossScaleOptimizerTest(test.TestCase,
|
||||
self.evaluate(val)
|
||||
|
||||
def _run_fn_with_grad_check(self, strategy, var, opt, expected_grad):
|
||||
grad_check_fn = mp_test_util.create_identity_with_grad_check_fn(
|
||||
grad_check_fn = create_identity_with_grad_check_fn(
|
||||
expected_grad)
|
||||
loss = lambda: grad_check_fn(var) / strategy.num_replicas_in_sync
|
||||
return lambda: opt.minimize(loss, var_list=[var])
|
||||
@ -108,12 +155,12 @@ class MixedPrecisionLossScaleOptimizerTest(test.TestCase,
|
||||
opt = gradient_descent.GradientDescentOptimizer(1.0)
|
||||
loss_scale = 10.
|
||||
opt = loss_scale_optimizer.MixedPrecisionLossScaleOptimizer(opt, loss_scale)
|
||||
grad_check_fn = mp_test_util.create_identity_with_grad_check_fn(loss_scale)
|
||||
grad_check_fn = create_identity_with_grad_check_fn(loss_scale)
|
||||
loss = grad_check_fn(var)
|
||||
run_op = get_gradients(opt, loss, [var])
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
# This will cause an assertion to run, as
|
||||
# mp_test_util.create_identity_with_grad_check_fn added an assertion op.
|
||||
# create_identity_with_grad_check_fn added an assertion op.
|
||||
self.evaluate(run_op)
|
||||
|
||||
@parameterized.named_parameters(*TESTCASES)
|
||||
|
Loading…
Reference in New Issue
Block a user