Minor fix to allow iterations variable to update in eager mode
PiperOrigin-RevId: 209644988
This commit is contained in:
parent
d648d7e6e1
commit
fce0a4eaab
@@ -699,7 +699,7 @@ class TFOptimizer(Optimizer, checkpointable.CheckpointableBase):
     self.iterations = K.variable(0, dtype='int64', name='iterations')

   def apply_gradients(self, grads):
-    self.optimizer.apply_gradients(grads)
+    self.optimizer.apply_gradients(grads, global_step=self.iterations)

   def get_grads(self, loss, params):
     return self.optimizer.compute_gradients(loss, params)
@@ -21,6 +21,8 @@ from __future__ import print_function

 import numpy as np

 from tensorflow.python import keras
+from tensorflow.python.eager import context
 from tensorflow.python.framework import test_util
 from tensorflow.python.keras import testing_utils
 from tensorflow.python.platform import test
+from tensorflow.python.training.adam import AdamOptimizer
@@ -153,6 +155,7 @@ class KerasOptimizersTest(test.TestCase):
     with self.assertRaises(NotImplementedError):
       optimizer.from_config(None)

+  @test_util.run_in_graph_and_eager_modes
   def test_tfoptimizer_iterations(self):
     with self.test_session():
       optimizer = keras.optimizers.TFOptimizer(AdamOptimizer(0.01))
@@ -169,11 +172,15 @@ class KerasOptimizersTest(test.TestCase):
               verbose=0)
     self.assertEqual(keras.backend.get_value(model.optimizer.iterations), 11)

-    model.fit(np.random.random((20, 3)),
-              np.random.random((20, 2)),
-              steps_per_epoch=8,
-              verbose=0)
-    self.assertEqual(keras.backend.get_value(model.optimizer.iterations), 19)
+    if not context.executing_eagerly():
+      # TODO(kathywu): investigate why training with an array input and
+      # setting the argument steps_per_epoch does not work in eager mode.
+      model.fit(np.random.random((20, 3)),
+                np.random.random((20, 2)),
+                steps_per_epoch=8,
+                verbose=0)
+      self.assertEqual(
+          keras.backend.get_value(model.optimizer.iterations), 19)

   def test_negative_clipvalue_or_clipnorm(self):
     with self.assertRaises(ValueError):
|
Loading…
Reference in New Issue
Block a user