In MovingAverageOptimizer, delegate compute_gradients() to the wrapped optimizer.

This is a bug fix for the case where the wrapped optimizer (or any other optimizer in the stack) does something non-standard in its compute_gradients() method.

PiperOrigin-RevId: 178447959
parent b1c64c61ad
commit 74780531e9
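Why the delegation matters, as a minimal sketch in plain Python (class names here are illustrative, not the TensorFlow code): Optimizer.minimize() calls self.compute_gradients(), so unless the wrapper overrides that method and forwards it, a custom compute_gradients() on the wrapped optimizer is silently skipped.

# Minimal sketch of the delegation pattern; names are illustrative only.
class BaseOptimizer(object):

  def compute_gradients(self, loss):
    return [('dloss/dvar', 'var')]  # stand-in for real gradients

  def apply_gradients(self, grads_and_vars):
    return 'train_op'

  def minimize(self, loss):
    # minimize() calls compute_gradients() on *self*, so a wrapper that does
    # not override it uses this base implementation and bypasses any custom
    # behavior of the optimizer it wraps.
    return self.apply_gradients(self.compute_gradients(loss))


class AveragingWrapper(BaseOptimizer):
  # Stands in for MovingAverageOptimizer in this sketch.

  def __init__(self, opt):
    self._optimizer = opt

  def compute_gradients(self, loss):
    # The fix: forward to the wrapped optimizer instead of inheriting.
    return self._optimizer.compute_gradients(loss)

  def apply_gradients(self, grads_and_vars):
    return self._optimizer.apply_gradients(grads_and_vars)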
@@ -86,6 +86,9 @@ class MovingAverageOptimizer(optimizer.Optimizer):
     self._variable_map = None
     self._sequential_update = sequential_update
 
+  def compute_gradients(self, *args, **kwargs):
+    return self._optimizer.compute_gradients(*args, **kwargs)
+
   def apply_gradients(self, grads_and_vars, global_step=None, name=None):
     train_op = self._optimizer.apply_gradients(
         grads_and_vars, global_step=global_step, name=name)
@@ -116,6 +116,37 @@ class MovingAverageOptimizerTest(test.TestCase):
       with self.assertRaises(RuntimeError):
         _ = opt.swapping_saver([var])
 
+  def testCorrectOverride(self):
+
+    class WrapperOptimizer(gradient_descent.GradientDescentOptimizer):
+
+      def compute_gradients(self, *args, **kwargs):
+        self.compute_gradients_called = True
+        return super(WrapperOptimizer, self).compute_gradients(
+            *args, **kwargs)
+
+      def apply_gradients(self, *args, **kwargs):
+        self.apply_gradients_called = True
+        return super(WrapperOptimizer, self).apply_gradients(*args, **kwargs)
+
+    with self.test_session() as sess:
+      var = variables.Variable([1.2], name='var', dtype=dtypes.float32)
+      loss = var ** 2
+      wrapper_opt = WrapperOptimizer(learning_rate=2.0)
+      opt = moving_average_optimizer.MovingAverageOptimizer(wrapper_opt)
+      train_op = opt.minimize(loss)
+
+      # Check that both methods are called on the underlying optimizer.
+      self.assertTrue(wrapper_opt.compute_gradients_called)
+      self.assertTrue(wrapper_opt.apply_gradients_called)
+
+      # Run train_op once, and verify that we've updated the variable.
+      variables.global_variables_initializer().run()
+      sess.run(train_op)
+      var_value = sess.run(var)
+      # Started at 1.2, gradient is 2*1.2=2.4, lr=2, so should now be -3.6.
+      self.assertNear(-3.6, var_value, 1e-6)
+
 
 if __name__ == '__main__':
   test.main()
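For context, a rough usage sketch of what this change enables through the public contrib API (TF 1.x paths assumed; ClippingSGD is a hypothetical optimizer invented for illustration, not part of the change):

# Rough usage sketch; TF 1.x contrib API assumed.
import tensorflow as tf
from tensorflow.contrib.opt import MovingAverageOptimizer

class ClippingSGD(tf.train.GradientDescentOptimizer):
  # Hypothetical optimizer with a non-standard compute_gradients().

  def compute_gradients(self, *args, **kwargs):
    grads_and_vars = super(ClippingSGD, self).compute_gradients(*args, **kwargs)
    return [(tf.clip_by_norm(g, 1.0), v)
            for g, v in grads_and_vars if g is not None]

var = tf.Variable([1.2])
loss = var ** 2
opt = MovingAverageOptimizer(ClippingSGD(learning_rate=0.1))
# With this change, minimize() goes through ClippingSGD.compute_gradients(),
# so the clipping is no longer silently bypassed by the wrapper.
train_op = opt.minimize(loss)
saver = opt.swapping_saver()  # saver that swaps in the averaged variables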