Minor fix to allow iterations variable to update in eager mode

PiperOrigin-RevId: 209644988
This commit is contained in:
Katherine Wu 2018-08-21 12:59:45 -07:00 committed by TensorFlower Gardener
parent d648d7e6e1
commit fce0a4eaab
2 changed files with 13 additions and 6 deletions

View File

@ -699,7 +699,7 @@ class TFOptimizer(Optimizer, checkpointable.CheckpointableBase):
self.iterations = K.variable(0, dtype='int64', name='iterations')
def apply_gradients(self, grads):
self.optimizer.apply_gradients(grads)
self.optimizer.apply_gradients(grads, global_step=self.iterations)
def get_grads(self, loss, params):
return self.optimizer.compute_gradients(loss, params)

View File

@ -21,6 +21,8 @@ from __future__ import print_function
import numpy as np
from tensorflow.python import keras
from tensorflow.python.eager import context
from tensorflow.python.framework import test_util
from tensorflow.python.keras import testing_utils
from tensorflow.python.platform import test
from tensorflow.python.training.adam import AdamOptimizer
@ -153,6 +155,7 @@ class KerasOptimizersTest(test.TestCase):
with self.assertRaises(NotImplementedError):
optimizer.from_config(None)
@test_util.run_in_graph_and_eager_modes
def test_tfoptimizer_iterations(self):
with self.test_session():
optimizer = keras.optimizers.TFOptimizer(AdamOptimizer(0.01))
@ -169,11 +172,15 @@ class KerasOptimizersTest(test.TestCase):
verbose=0)
self.assertEqual(keras.backend.get_value(model.optimizer.iterations), 11)
model.fit(np.random.random((20, 3)),
np.random.random((20, 2)),
steps_per_epoch=8,
verbose=0)
self.assertEqual(keras.backend.get_value(model.optimizer.iterations), 19)
if not context.executing_eagerly():
# TODO(kathywu): investigate why training with an array input and
# setting the argument steps_per_epoch does not work in eager mode.
model.fit(np.random.random((20, 3)),
np.random.random((20, 2)),
steps_per_epoch=8,
verbose=0)
self.assertEqual(
keras.backend.get_value(model.optimizer.iterations), 19)
def test_negative_clipvalue_or_clipnorm(self):
with self.assertRaises(ValueError):