Fix None gradient test

Matt Conley 2019-04-16 11:47:25 -07:00
parent e73a77d19d
commit 4896cc21ab


@@ -253,10 +253,9 @@ class DynamicLossScaleTest(test.TestCase, parameterized.TestCase):
     self._test_helper(inputs, expected_outputs, init_loss_scale)
 
   @test_util.run_in_graph_and_eager_modes
-  def test_update_with_none_gradients(self, strategy_fn):
-    with strategy_fn().scope():
-      loss_scaler = loss_scale_module.DynamicLossScale()
-      loss_scaler.update([None])
+  def test_update_with_none_gradients(self):
+    loss_scaler = loss_scale_module.DynamicLossScale()
+    loss_scaler.update([None])
 
   @test_util.run_in_graph_and_eager_modes
   def test_get(self):
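
For context, a minimal standalone sketch of the test after this change. The import paths (tensorflow.python.training.experimental.loss_scale, tensorflow.python.framework.test_util) and the NoneGradientTest class name are assumptions for illustration; in the actual file the method lives inside the existing DynamicLossScaleTest class.

from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
# Assumed module path for the DynamicLossScale implementation under test.
from tensorflow.python.training.experimental import loss_scale as loss_scale_module


class NoneGradientTest(test.TestCase):

  @test_util.run_in_graph_and_eager_modes
  def test_update_with_none_gradients(self):
    # After this change the loss scale is built directly, without the
    # strategy_fn parameter or a distribution-strategy scope, and update()
    # is called with a gradients list containing None to check it tolerates
    # missing gradients.
    loss_scaler = loss_scale_module.DynamicLossScale()
    loss_scaler.update([None])


if __name__ == '__main__':
  test.main()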