Fork the mp test util function to break the dependency between Keras and TF.
PiperOrigin-RevId: 315955944
Change-Id: I40dae53dd634ad5bdd236bbe0efaa9a260807c54
parent 04f033b847
commit 221d47c3b9
@@ -25,8 +25,11 @@ from absl.testing import parameterized
 from tensorflow.python.distribute import distribution_strategy_context
 from tensorflow.python.distribute import mirrored_strategy
 from tensorflow.python.eager import context
+from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
-from tensorflow.python.keras.mixed_precision.experimental import test_util as mp_test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import custom_gradient
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
@@ -63,6 +66,50 @@ def get_gradients(opt, loss, params):
   return grads
 
 
+def create_identity_with_grad_check_fn(expected_gradient, expected_dtype=None):
+  """Returns a function that asserts its gradient has a certain value.
+
+  This serves as a hook to assert intermediate gradients have a certain value.
+  This returns an identity function. The identity's gradient function is also
+  the identity function, except it asserts that the gradient equals
+  `expected_gradient` and has dtype `expected_dtype`.
+
+  Args:
+    expected_gradient: The gradient function asserts that the gradient is this
+      value.
+    expected_dtype: The gradient function asserts the gradient has this dtype.
+
+  Returns:
+    An identity function whose gradient function asserts the gradient has a
+    certain value.
+  """
+  @custom_gradient.custom_gradient
+  def _identity_with_grad_check(x):
+    """Function that asserts its gradient has a certain value."""
+    x = array_ops.identity(x)
+    def grad(dx):
+      """Gradient function that asserts the gradient has a certain value."""
+      if expected_dtype:
+        assert dx.dtype == expected_dtype, (
+            'dx.dtype should be %s but is: %s' % (expected_dtype, dx.dtype))
+      expected_tensor = ops.convert_to_tensor_v2(
+          expected_gradient, dtype=dx.dtype, name='expected_gradient')
+      # Control dependency is to ensure input is available. It's possible the
+      # dataset will throw a StopIteration to indicate there is no more data, in
+      # which case we don't want to run the assertion.
+      with ops.control_dependencies([x]):
+        assert_op = check_ops.assert_equal(dx, expected_tensor)
+      with ops.control_dependencies([assert_op]):
+        dx = array_ops.identity(dx)
+      return dx
+    return x, grad
+  # Keras sometimes has trouble serializing Lambda layers with a decorated
+  # function. So we define and return a non-decorated function.
+  def identity_with_grad_check(x):
+    return _identity_with_grad_check(x)
+  return identity_with_grad_check
+
+
 class MixedPrecisionLossScaleOptimizerTest(test.TestCase,
                                            parameterized.TestCase):
 
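The helper forked in the hunk above is an identity whose custom gradient runs an equality assertion on the incoming gradient. A minimal usage sketch, assuming eager execution and that create_identity_with_grad_check_fn from the hunk above is in scope; the constant 3.0, the 2.0 multiplier, and the tape setup are illustrative assumptions, not part of the commit:

# Sketch only: values and setup below are assumptions for illustration.
from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op

identity_fn = create_identity_with_grad_check_fn(2.0)  # expect a gradient of 2.0
x = constant_op.constant(3.0)
with backprop.GradientTape() as tape:
  tape.watch(x)
  y = 2.0 * identity_fn(x)  # the gradient flowing back into the identity is 2.0
# Computing the gradient triggers the check_ops.assert_equal op added in grad().
dx = tape.gradient(y, x)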
@@ -75,7 +122,7 @@ class MixedPrecisionLossScaleOptimizerTest(test.TestCase,
     self.evaluate(val)
 
   def _run_fn_with_grad_check(self, strategy, var, opt, expected_grad):
-    grad_check_fn = mp_test_util.create_identity_with_grad_check_fn(
+    grad_check_fn = create_identity_with_grad_check_fn(
         expected_grad)
     loss = lambda: grad_check_fn(var) / strategy.num_replicas_in_sync
     return lambda: opt.minimize(loss, var_list=[var])
@@ -108,12 +155,12 @@ class MixedPrecisionLossScaleOptimizerTest(test.TestCase,
     opt = gradient_descent.GradientDescentOptimizer(1.0)
     loss_scale = 10.
     opt = loss_scale_optimizer.MixedPrecisionLossScaleOptimizer(opt, loss_scale)
-    grad_check_fn = mp_test_util.create_identity_with_grad_check_fn(loss_scale)
+    grad_check_fn = create_identity_with_grad_check_fn(loss_scale)
     loss = grad_check_fn(var)
     run_op = get_gradients(opt, loss, [var])
     self.evaluate(variables.global_variables_initializer())
     # This will cause an assertion to run, as
-    # mp_test_util.create_identity_with_grad_check_fn added an assertion op.
+    # create_identity_with_grad_check_fn added an assertion op.
     self.evaluate(run_op)
 
   @parameterized.named_parameters(*TESTCASES)
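In the hunk above the expected gradient is exactly loss_scale: the loss scale optimizer multiplies the loss by the loss scale before computing gradients, so the gradient that reaches the identity's backward function is the loss scale itself. A back-of-the-envelope illustration in plain Python; the values mirror the test, but the snippet itself is not part of the commit:

# Illustrative arithmetic only, not part of the commit.
loss_scale = 10.
d_loss_d_identity = 1.0  # loss = grad_check_fn(var) is an identity of var
grad_seen_by_check = loss_scale * d_loss_d_identity
assert grad_seen_by_check == loss_scale  # the value assert_equal checks for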