In MovingAverageOptimizer, delegate compute_gradients() to the wrapped optimizer.

This fixes a bug in the case where the wrapped optimizer (or any other
optimizer in the stack) does something non-standard in its
compute_gradients() method.

PiperOrigin-RevId: 178447959
A. Unique TensorFlower 2017-12-08 17:07:39 -08:00 committed by TensorFlower Gardener
parent b1c64c61ad
commit 74780531e9
2 changed files with 34 additions and 0 deletions
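
To see why the delegation matters, here is a minimal sketch of the failure mode, using hypothetical plain-Python classes rather than the real TensorFlow API: a wrapper that does not forward compute_gradients() inherits the base-class version, which silently bypasses any override on the inner optimizer. The compute_gradients() override added in the diff below is the FixedWrapper behavior.

# Minimal sketch (hypothetical classes, not the real TensorFlow API) of the
# bug: a wrapper that does not forward compute_gradients() inherits the
# base-class version and bypasses the inner optimizer's override.

class BaseOptimizer(object):
  """Stands in for the default compute_gradients() of the base class."""

  def compute_gradients(self, loss):
    return 'standard gradients for %r' % loss


class NonStandardOptimizer(BaseOptimizer):
  """An optimizer in the stack that customizes compute_gradients()."""

  def compute_gradients(self, loss):
    return 'custom gradients for %r' % loss


class BuggyWrapper(BaseOptimizer):
  """No compute_gradients() override, so the inherited default runs."""

  def __init__(self, opt):
    self._optimizer = opt


class FixedWrapper(BaseOptimizer):
  """Delegates compute_gradients(), so the inner override is honored."""

  def __init__(self, opt):
    self._optimizer = opt

  def compute_gradients(self, loss):
    return self._optimizer.compute_gradients(loss)


inner = NonStandardOptimizer()
print(BuggyWrapper(inner).compute_gradients('loss'))  # -> standard gradients
print(FixedWrapper(inner).compute_gradients('loss'))  # -> custom gradients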


@@ -86,6 +86,9 @@ class MovingAverageOptimizer(optimizer.Optimizer):
     self._variable_map = None
     self._sequential_update = sequential_update
 
+  def compute_gradients(self, *args, **kwargs):
+    return self._optimizer.compute_gradients(*args, **kwargs)
+
   def apply_gradients(self, grads_and_vars, global_step=None, name=None):
     train_op = self._optimizer.apply_gradients(
         grads_and_vars, global_step=global_step, name=name)


@@ -116,6 +116,37 @@ class MovingAverageOptimizerTest(test.TestCase):
     with self.assertRaises(RuntimeError):
       _ = opt.swapping_saver([var])
 
+  def testCorrectOverride(self):
+
+    class WrapperOptimizer(gradient_descent.GradientDescentOptimizer):
+
+      def compute_gradients(self, *args, **kwargs):
+        self.compute_gradients_called = True
+        return super(WrapperOptimizer, self).compute_gradients(
+            *args, **kwargs)
+
+      def apply_gradients(self, *args, **kwargs):
+        self.apply_gradients_called = True
+        return super(WrapperOptimizer, self).apply_gradients(*args, **kwargs)
+
+    with self.test_session() as sess:
+      var = variables.Variable([1.2], name='var', dtype=dtypes.float32)
+      loss = var ** 2
+      wrapper_opt = WrapperOptimizer(learning_rate=2.0)
+      opt = moving_average_optimizer.MovingAverageOptimizer(wrapper_opt)
+      train_op = opt.minimize(loss)
+
+      # Check that both methods were called on the underlying optimizer.
+      self.assertTrue(wrapper_opt.compute_gradients_called)
+      self.assertTrue(wrapper_opt.apply_gradients_called)
+      # Run train_op once and verify that the variable was updated.
+      variables.global_variables_initializer().run()
+      sess.run(train_op)
+      var_value = sess.run(var)
+      # Started at 1.2; the gradient is 2 * 1.2 = 2.4 and lr = 2.0,
+      # so the value should now be 1.2 - 2.0 * 2.4 = -3.6.
+      self.assertNear(-3.6, var_value, 1e-6)
+
 
 if __name__ == '__main__':
   test.main()