Fix order-dependent mixed precision test.

PiperOrigin-RevId: 302992910
Change-Id: Ifc776a95eb6a69ba2c5806dc9afeacdfbff9bcb7
This commit is contained in:
Reed Wanderman-Milne 2020-03-25 16:12:42 -07:00 committed by TensorFlower Gardener
parent 4997cdbbaa
commit ae33348a36
1 changed files with 2 additions and 0 deletions

View File

@ -166,6 +166,8 @@ class PolicyTest(test.TestCase, parameterized.TestCase):
'not passing any loss_scale instead.') 'not passing any loss_scale instead.')
for policy_name in 'float16', 'mixed_float16': for policy_name in 'float16', 'mixed_float16':
# Trigger any other warnings that occur only once
mp_policy.Policy(policy_name, loss_scale=2.)
with test.mock.patch.object(tf_logging, 'warn') as mock_warn: with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
mp_policy.Policy(policy_name, loss_scale=2.) mp_policy.Policy(policy_name, loss_scale=2.)
mock_warn.assert_not_called() mock_warn.assert_not_called()