Minor fix to allow iterations variable to update in eager mode
PiperOrigin-RevId: 209644988
This commit is contained in:
parent
d648d7e6e1
commit
fce0a4eaab
@ -699,7 +699,7 @@ class TFOptimizer(Optimizer, checkpointable.CheckpointableBase):
|
|||||||
self.iterations = K.variable(0, dtype='int64', name='iterations')
|
self.iterations = K.variable(0, dtype='int64', name='iterations')
|
||||||
|
|
||||||
def apply_gradients(self, grads):
|
def apply_gradients(self, grads):
|
||||||
self.optimizer.apply_gradients(grads)
|
self.optimizer.apply_gradients(grads, global_step=self.iterations)
|
||||||
|
|
||||||
def get_grads(self, loss, params):
|
def get_grads(self, loss, params):
|
||||||
return self.optimizer.compute_gradients(loss, params)
|
return self.optimizer.compute_gradients(loss, params)
|
||||||
|
@ -21,6 +21,8 @@ from __future__ import print_function
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.python import keras
|
from tensorflow.python import keras
|
||||||
|
from tensorflow.python.eager import context
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.keras import testing_utils
|
from tensorflow.python.keras import testing_utils
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
from tensorflow.python.training.adam import AdamOptimizer
|
from tensorflow.python.training.adam import AdamOptimizer
|
||||||
@ -153,6 +155,7 @@ class KerasOptimizersTest(test.TestCase):
|
|||||||
with self.assertRaises(NotImplementedError):
|
with self.assertRaises(NotImplementedError):
|
||||||
optimizer.from_config(None)
|
optimizer.from_config(None)
|
||||||
|
|
||||||
|
@test_util.run_in_graph_and_eager_modes
|
||||||
def test_tfoptimizer_iterations(self):
|
def test_tfoptimizer_iterations(self):
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
optimizer = keras.optimizers.TFOptimizer(AdamOptimizer(0.01))
|
optimizer = keras.optimizers.TFOptimizer(AdamOptimizer(0.01))
|
||||||
@ -169,11 +172,15 @@ class KerasOptimizersTest(test.TestCase):
|
|||||||
verbose=0)
|
verbose=0)
|
||||||
self.assertEqual(keras.backend.get_value(model.optimizer.iterations), 11)
|
self.assertEqual(keras.backend.get_value(model.optimizer.iterations), 11)
|
||||||
|
|
||||||
|
if not context.executing_eagerly():
|
||||||
|
# TODO(kathywu): investigate why training with an array input and
|
||||||
|
# setting the argument steps_per_epoch does not work in eager mode.
|
||||||
model.fit(np.random.random((20, 3)),
|
model.fit(np.random.random((20, 3)),
|
||||||
np.random.random((20, 2)),
|
np.random.random((20, 2)),
|
||||||
steps_per_epoch=8,
|
steps_per_epoch=8,
|
||||||
verbose=0)
|
verbose=0)
|
||||||
self.assertEqual(keras.backend.get_value(model.optimizer.iterations), 19)
|
self.assertEqual(
|
||||||
|
keras.backend.get_value(model.optimizer.iterations), 19)
|
||||||
|
|
||||||
def test_negative_clipvalue_or_clipnorm(self):
|
def test_negative_clipvalue_or_clipnorm(self):
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
|
Loading…
Reference in New Issue
Block a user