STT-tensorflow/tensorflow/python/training/momentum_test.py
Cesar Crusius 2819e2a272 Remove reset_test argument from test_util.run_in_graph_and_eager_modes
That argument was only used with its default `True` value. Removing it
makes reasoning about the code, and improving its logic, easier. In
particular, it will make it easier to enforce the resetting of the
eager context in between test calls.

PiperOrigin-RevId: 309306414
Change-Id: Ie98b6586f51cb01a5cdf6e76f76194bd77220d85
2020-04-30 15:09:55 -07:00

606 lines
27 KiB
Python

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Momentum."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import momentum as momentum_lib
class MomentumOptimizerTest(test.TestCase):
def _update_nesterov_momentum_numpy(self, var, accum, g, lr, momentum):
var = var + accum * lr * momentum
accum = accum * momentum + g
var = var - lr * accum
var = var - accum * lr * momentum
return var, accum
def doTestBasic(self, use_resource=False, use_callable_params=False):
for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
if use_resource:
var0 = resource_variable_ops.ResourceVariable(
[1.0, 2.0], dtype=dtype, name="var0_%d" % i)
var1 = resource_variable_ops.ResourceVariable(
[3.0, 4.0], dtype=dtype, name="var1_%d" % i)
else:
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
learning_rate = lambda: 2.0
momentum = lambda: 0.9
if not use_callable_params:
learning_rate = learning_rate()
momentum = momentum()
mom_opt = momentum_lib.MomentumOptimizer(
learning_rate=learning_rate, momentum=momentum)
mom_update = mom_opt.apply_gradients(
zip([grads0, grads1], [var0, var1]))
if not context.executing_eagerly():
self.evaluate(variables.global_variables_initializer())
# Fetch params to validate initial values
self.assertAllClose([1.0, 2.0], self.evaluate(var0))
self.assertAllClose([3.0, 4.0], self.evaluate(var1))
# Check we have slots
self.assertEqual(["momentum"], mom_opt.get_slot_names())
slot0 = mom_opt.get_slot(var0, "momentum")
self.assertEquals(slot0.get_shape(), var0.get_shape())
slot1 = mom_opt.get_slot(var1, "momentum")
self.assertEquals(slot1.get_shape(), var1.get_shape())
if not context.executing_eagerly():
self.assertFalse(slot0 in variables.trainable_variables())
self.assertFalse(slot1 in variables.trainable_variables())
# Step 1: the momentum accumulators where 0. So we should see a normal
# update: v -= grad * learning_rate
if not context.executing_eagerly():
self.evaluate(mom_update)
# Check that the momentum accumulators have been updated.
self.assertAllCloseAccordingToType(np.array([0.1, 0.1]),
self.evaluate(slot0))
self.assertAllCloseAccordingToType(np.array([0.01, 0.01]),
self.evaluate(slot1))
# Check that the parameters have been updated.
self.assertAllCloseAccordingToType(
np.array([1.0 - (0.1 * 2.0), 2.0 - (0.1 * 2.0)]),
self.evaluate(var0))
self.assertAllCloseAccordingToType(
np.array([3.0 - (0.01 * 2.0), 4.0 - (0.01 * 2.0)]),
self.evaluate(var1))
# Step 2: the momentum accumulators contain the previous update.
if context.executing_eagerly():
mom_opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
else:
self.evaluate(mom_update)
# Check that the momentum accumulators have been updated.
self.assertAllCloseAccordingToType(
np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]),
self.evaluate(slot0))
self.assertAllCloseAccordingToType(
np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]),
self.evaluate(slot1))
# Check that the parameters have been updated.
self.assertAllCloseAccordingToType(
np.array([
1.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0),
2.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0)
]), self.evaluate(var0))
self.assertAllCloseAccordingToType(
np.array([
2.98 - ((0.9 * 0.01 + 0.01) * 2.0), 3.98 - (
(0.9 * 0.01 + 0.01) * 2.0)
]), self.evaluate(var1))
def testBasic(self):
with self.cached_session():
self.doTestBasic(use_resource=False)
@test_util.run_in_graph_and_eager_modes
def testResourceBasic(self):
self.doTestBasic(use_resource=True)
def testBasicCallableParams(self):
with context.eager_mode():
self.doTestBasic(use_resource=True, use_callable_params=True)
def testVariablesAcrossGraphs(self):
optimizer = momentum_lib.MomentumOptimizer(0.01, 0.5)
with ops.Graph().as_default():
var0 = resource_variable_ops.ResourceVariable(
[1.0, 2.0], dtype=dtypes.float32, name="var0")
var1 = resource_variable_ops.ResourceVariable(
[3.0, 4.0], dtype=dtypes.float32, name="var1")
loss = math_ops.reduce_sum(var0 + var1)
optimizer.minimize(loss)
optimizer_variables = optimizer.variables()
self.assertStartsWith(optimizer_variables[0].name, "var0")
self.assertStartsWith(optimizer_variables[1].name, "var1")
self.assertEquals(2, len(optimizer_variables))
with ops.Graph().as_default():
var2 = resource_variable_ops.ResourceVariable(
[1.0, 2.0], dtype=dtypes.float32, name="var2")
var3 = resource_variable_ops.ResourceVariable(
[3.0, 4.0], dtype=dtypes.float32, name="var3")
loss = math_ops.reduce_sum(var2 + var3)
optimizer.minimize(loss)
optimizer_variables = optimizer.variables()
self.assertStartsWith(optimizer_variables[0].name, "var2")
self.assertStartsWith(optimizer_variables[1].name, "var3")
self.assertEquals(2, len(optimizer_variables))
@test_util.run_deprecated_v1
def testNesterovMomentum(self):
for dtype in [dtypes.float32, dtypes.float64]:
with self.cached_session():
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
accum0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
accum1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
cost = 5 * var0 * var0 + 3 * var1
global_step = variables.Variable(
array_ops.zeros([], dtypes.int64), name="global_step")
mom_op = momentum_lib.MomentumOptimizer(
learning_rate=2.0, momentum=0.9, use_nesterov=True)
opt_op = mom_op.minimize(cost, global_step, [var0, var1])
variables.global_variables_initializer().run()
for t in range(1, 5):
opt_op.run()
var0_np, accum0_np = self._update_nesterov_momentum_numpy(
var0_np, accum0_np, var0_np * 10, 2.0, 0.9)
var1_np, accum1_np = self._update_nesterov_momentum_numpy(var1_np,
accum1_np,
3, 2.0, 0.9)
self.assertAllClose(var0_np, self.evaluate(var0))
self.assertAllClose(var1_np, self.evaluate(var1))
@test_util.run_deprecated_v1
def testSparseNesterovMomentum(self):
for dtype in [dtypes.float32, dtypes.float64]:
with self.cached_session():
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
accum0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
accum1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
grads = []
for t in range(1, 5):
grads.append(var0_np * 10)
var0_np, accum0_np = self._update_nesterov_momentum_numpy(
var0_np, accum0_np, var0_np * 10, 2.0, 0.9)
var1_np, accum1_np = self._update_nesterov_momentum_numpy(var1_np,
accum1_np,
3, 2.0, 0.9)
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
accum0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
accum1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
var0 = variables.Variable(var0_np)
var1 = variables.Variable(var1_np)
loss = 5 * var0 * var0 + 3 * var1
mom_op = momentum_lib.MomentumOptimizer(
learning_rate=2.0, momentum=0.9, use_nesterov=True)
x_feed = array_ops.placeholder(dtype)
y_feed = ops.IndexedSlices(
x_feed, constant_op.constant([0, 1]), constant_op.constant([2]))
grads_and_vars = [(y_feed, var0), (constant_op.constant(
[3.0, 3.0], dtype=dtype), var1)]
opt_update = mom_op.apply_gradients(grads_and_vars)
variables.global_variables_initializer().run()
for t in range(1, 5):
opt_update.run(feed_dict={x_feed: grads[t - 1]})
var0_np, accum0_np = self._update_nesterov_momentum_numpy(
var0_np, accum0_np, var0_np * 10, 2.0, 0.9)
var1_np, accum1_np = self._update_nesterov_momentum_numpy(var1_np,
accum1_np,
3, 2.0, 0.9)
self.assertAllClose(var0_np, self.evaluate(var0))
self.assertAllClose(var1_np, self.evaluate(var1))
@test_util.run_in_graph_and_eager_modes
def testMinimizeSparseResourceVariable(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
# This test invokes the ResourceSparseApplyMomentum operation, which
# did not have a registered GPU kernel as of April 2018. With graph
# execution, the placement algorithm notices this and automatically
# places the variable in CPU (host) memory. With eager execution,
# the variable would be placed in GPU memory if available, which
# would then conflict with the future invocation of the
# ResourceSparseApplyMomentum operation.
# To work around this discrepancy, for now we force the variable
# to be placed on CPU.
with ops.device("/cpu:0"):
var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
# pylint: disable=cell-var-from-loop
def loss():
x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
return pred * pred
# pylint: enable=cell-var-from-loop
opt = momentum_lib.MomentumOptimizer(learning_rate=1.0, momentum=0.0)
sgd_op = opt.minimize(loss)
self.evaluate(variables.global_variables_initializer())
# Run 1 step of sgd
self.evaluate(sgd_op)
# Validate updated params
self.assertAllCloseAccordingToType([[-111, -138]], self.evaluate(var0))
@test_util.run_in_graph_and_eager_modes
def testMinimizeWith2DIndicesForEmbeddingLookup(self):
# This test invokes the ResourceSparseApplyMomentum operation, which
# did not have a registered GPU kernel as of April 2018. With graph
# execution, the placement algorithm notices this and automatically
# places the variable in CPU (host) memory. With eager execution,
# the variable would be placed in GPU memory if available, which
# would then conflict with the future invocation of the
# ResourceSparseApplyMomentum operation.
# To work around this discrepancy, for now we force the variable
# to be placed on CPU.
with ops.device("/cpu:0"):
var0 = resource_variable_ops.ResourceVariable(array_ops.ones([2, 2]))
def loss():
return math_ops.reduce_sum(embedding_ops.embedding_lookup(var0, [[1]]))
opt = momentum_lib.MomentumOptimizer(learning_rate=1.0, momentum=0.0)
sgd_op = opt.minimize(loss)
self.evaluate(variables.global_variables_initializer())
self.evaluate(sgd_op)
self.assertAllCloseAccordingToType([[1, 1], [0, 0]], self.evaluate(var0))
@test_util.run_deprecated_v1
def testTensorLearningRateAndMomentum(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
with self.cached_session():
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
mom_opt = momentum_lib.MomentumOptimizer(
learning_rate=constant_op.constant(2.0),
momentum=constant_op.constant(0.9))
mom_update = mom_opt.apply_gradients(
zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run()
# Check we have slots
self.assertEqual(["momentum"], mom_opt.get_slot_names())
slot0 = mom_opt.get_slot(var0, "momentum")
self.assertEquals(slot0.get_shape(), var0.get_shape())
self.assertFalse(slot0 in variables.trainable_variables())
slot1 = mom_opt.get_slot(var1, "momentum")
self.assertEquals(slot1.get_shape(), var1.get_shape())
self.assertFalse(slot1 in variables.trainable_variables())
# Fetch params to validate initial values
self.assertAllClose([1.0, 2.0], self.evaluate(var0))
self.assertAllClose([3.0, 4.0], self.evaluate(var1))
# Step 1: the momentum accumulators where 0. So we should see a normal
# update: v -= grad * learning_rate
mom_update.run()
# Check that the momentum accumulators have been updated.
self.assertAllCloseAccordingToType(
np.array([0.1, 0.1]), self.evaluate(slot0))
self.assertAllCloseAccordingToType(
np.array([0.01, 0.01]), self.evaluate(slot1))
# Check that the parameters have been updated.
self.assertAllCloseAccordingToType(
np.array([1.0 - (0.1 * 2.0), 2.0 - (0.1 * 2.0)]),
self.evaluate(var0))
self.assertAllCloseAccordingToType(
np.array([3.0 - (0.01 * 2.0), 4.0 - (0.01 * 2.0)]),
self.evaluate(var1))
# Step 2: the momentum accumulators contain the previous update.
mom_update.run()
# Check that the momentum accumulators have been updated.
self.assertAllCloseAccordingToType(
np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]),
self.evaluate(slot0))
self.assertAllCloseAccordingToType(
np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]),
self.evaluate(slot1))
# Check that the parameters have been updated.
self.assertAllCloseAccordingToType(
np.array([
1.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0),
2.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0)
]), self.evaluate(var0))
self.assertAllCloseAccordingToType(
np.array([
2.98 - ((0.9 * 0.01 + 0.01) * 2.0),
3.98 - ((0.9 * 0.01 + 0.01) * 2.0)
]), self.evaluate(var1))
def _dbParamsMom01(self):
"""Return dist-belief momentum values.
Return values been generated from the dist-belief momentum unittest,
running with a learning rate of 0.1 and a momentum of 0.1.
These values record how a parameter vector of size 10, initialized with 0.0,
gets updated with 10 consecutive momentum steps. It uses random gradients.
Returns:
db_grad: The gradients to apply
db_out: The parameters after the momentum update.
"""
db_grad = [[]] * 10
db_out = [[]] * 10
# pylint: disable=line-too-long
db_grad[0] = [
0.00096264342, 0.17914793, 0.93945462, 0.41396621, 0.53037018,
0.93197989, 0.78648776, 0.50036013, 0.55345792, 0.96722615
]
db_out[0] = [
-9.6264346e-05, -0.017914793, -0.093945466, -0.041396622, -0.053037018,
-0.093197994, -0.078648776, -0.050036013, -0.055345792, -0.096722618
]
db_grad[1] = [
0.17075552, 0.88821375, 0.20873757, 0.25236958, 0.57578111, 0.15312378,
0.5513742, 0.94687688, 0.16012503, 0.22159521
]
db_out[1] = [
-0.017181443, -0.10852765, -0.12421377, -0.070773244, -0.11591884,
-0.11783017, -0.14165108, -0.14972731, -0.076892875, -0.1285544
]
db_grad[2] = [
0.35077485, 0.47304362, 0.44412705, 0.44368884, 0.078527533, 0.81223965,
0.31168157, 0.43203235, 0.16792089, 0.24644311
]
db_out[2] = [
-0.053967446, -0.1648933, -0.1716533, -0.1180798, -0.13005978,
-0.20151734, -0.17911947, -0.20289968, -0.095839672, -0.15638189
]
db_grad[3] = [
0.9694621, 0.75035888, 0.28171822, 0.83813518, 0.53807181, 0.3728098,
0.81454384, 0.03848977, 0.89759839, 0.93665648
]
db_out[3] = [
-0.15459226, -0.24556576, -0.20456907, -0.20662397, -0.18528105,
-0.24716705, -0.2643207, -0.21206589, -0.18749419, -0.2528303
]
db_grad[4] = [
0.38578293, 0.8536852, 0.88722926, 0.66276771, 0.13678469, 0.94036359,
0.69107032, 0.81897682, 0.5433259, 0.67860287
]
db_out[4] = [
-0.20323303, -0.33900154, -0.29658359, -0.28175515, -0.20448165,
-0.34576839, -0.34194785, -0.29488021, -0.25099224, -0.33033544
]
db_grad[5] = [
0.27885768, 0.76100707, 0.24625534, 0.81354135, 0.18959245, 0.48038563,
0.84163809, 0.41172323, 0.83259648, 0.44941229
]
db_out[5] = [
-0.23598288, -0.42444581, -0.33041057, -0.3706224, -0.22536094,
-0.40366709, -0.43387437, -0.34433398, -0.34060168, -0.38302717
]
db_grad[6] = [
0.27233034, 0.056316052, 0.5039115, 0.24105175, 0.35697976, 0.75913221,
0.73577434, 0.16014607, 0.57500273, 0.071136251
]
db_out[6] = [
-0.26649091, -0.43862185, -0.38418442, -0.40361428, -0.26314685,
-0.48537019, -0.51664448, -0.36529395, -0.40706289, -0.39540997
]
db_grad[7] = [
0.58697265, 0.2494842, 0.08106143, 0.39954534, 0.15892942, 0.12683646,
0.74053431, 0.16033, 0.66625422, 0.73515922
]
db_out[7] = [
-0.32823896, -0.46498787, -0.39766794, -0.446868, -0.28281838,
-0.50622416, -0.59897494, -0.38342294, -0.48033443, -0.47016418
]
db_grad[8] = [
0.8215279, 0.41994119, 0.95172721, 0.68000203, 0.79439718, 0.43384039,
0.55561525, 0.22567581, 0.93331909, 0.29438227
]
db_out[8] = [
-0.41656655, -0.50961858, -0.49418902, -0.51919359, -0.36422527,
-0.55169362, -0.6627695, -0.40780342, -0.58099347, -0.50707781
]
db_grad[9] = [
0.68297005, 0.67758518, 0.1748755, 0.13266537, 0.70697063, 0.055731893,
0.68593478, 0.50580865, 0.12602448, 0.093537711
]
db_out[9] = [
-0.49369633, -0.58184016, -0.52132869, -0.5396927, -0.44306302,
-0.56181377, -0.73774242, -0.46082234, -0.60366184, -0.52012295
]
# pylint: enable=line-too-long
return db_grad, db_out
@test_util.run_deprecated_v1
def testLikeDistBeliefMom01(self):
with self.cached_session():
db_grad, db_out = self._dbParamsMom01()
num_samples = len(db_grad)
var0 = variables.Variable([0.0] * num_samples)
grads0 = constant_op.constant([0.0] * num_samples)
mom_opt = momentum_lib.MomentumOptimizer(learning_rate=0.1, momentum=0.1)
mom_update = mom_opt.apply_gradients(zip([grads0], [var0]))
variables.global_variables_initializer().run()
for i in xrange(num_samples):
mom_update.run(feed_dict={grads0: db_grad[i]})
self.assertAllClose(np.array(db_out[i]), self.evaluate(var0))
@test_util.run_deprecated_v1
def testSparse(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
with self.cached_session():
var0 = variables.Variable(array_ops.zeros([4, 2], dtype=dtype))
var1 = variables.Variable(constant_op.constant(1.0, dtype, [4, 2]))
grads0 = ops.IndexedSlices(
constant_op.constant(
[[.1, .1]], dtype=dtype),
constant_op.constant([1]),
constant_op.constant([4, 2]))
grads1 = ops.IndexedSlices(
constant_op.constant(
[[.01, .01], [.01, .01]], dtype=dtype),
constant_op.constant([2, 3]),
constant_op.constant([4, 2]))
mom_opt = momentum_lib.MomentumOptimizer(
learning_rate=2.0, momentum=0.9)
mom_update = mom_opt.apply_gradients(
zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run()
# Check we have slots
self.assertEqual(["momentum"], mom_opt.get_slot_names())
slot0 = mom_opt.get_slot(var0, "momentum")
self.assertEquals(slot0.get_shape(), var0.get_shape())
slot1 = mom_opt.get_slot(var1, "momentum")
self.assertEquals(slot1.get_shape(), var1.get_shape())
# Fetch params to validate initial values
self.assertAllClose([0, 0], self.evaluate(var0)[0])
self.assertAllClose([0, 0], self.evaluate(var0)[1])
self.assertAllClose([1, 1], self.evaluate(var1)[2])
# Step 1: the momentum accumulators are 0. So we should see a normal
# update: v -= grad * learning_rate
mom_update.run()
# Check that the momentum accumulators have been updated.
self.assertAllCloseAccordingToType(
np.array([0, 0]),
self.evaluate(slot0)[0])
self.assertAllCloseAccordingToType(
np.array([.1, .1]),
self.evaluate(slot0)[1])
self.assertAllCloseAccordingToType(
np.array([.01, .01]),
self.evaluate(slot1)[2])
# Check that the parameters have been updated.
self.assertAllCloseAccordingToType(
np.array([0, 0]),
self.evaluate(var0)[0])
self.assertAllCloseAccordingToType(
np.array([-(0.1 * 2.0), -(0.1 * 2.0)]),
self.evaluate(var0)[1])
self.assertAllCloseAccordingToType(
np.array([1.0 - (0.01 * 2.0), 1.0 - (0.01 * 2.0)]),
self.evaluate(var1)[2])
# Step 2: the momentum accumulators contain the previous update.
mom_update.run()
# Check that the momentum accumulators have been updated.
self.assertAllClose(np.array([0, 0]), self.evaluate(slot0)[0])
self.assertAllCloseAccordingToType(
np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]),
self.evaluate(slot0)[1])
self.assertAllCloseAccordingToType(
np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]),
self.evaluate(slot1)[2])
# Check that the parameters have been updated.
self.assertAllClose(np.array([0, 0]), self.evaluate(var0)[0])
self.assertAllCloseAccordingToType(
np.array([
-(0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0),
-(0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0)
]),
self.evaluate(var0)[1])
self.assertAllCloseAccordingToType(
np.array([
0.98 - ((0.9 * 0.01 + 0.01) * 2.0),
0.98 - ((0.9 * 0.01 + 0.01) * 2.0)
]),
self.evaluate(var1)[2])
@test_util.run_deprecated_v1
def testSharing(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
with self.cached_session():
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
mom_opt = momentum_lib.MomentumOptimizer(
learning_rate=2.0, momentum=0.9)
mom_update1 = mom_opt.apply_gradients(
zip([grads0, grads1], [var0, var1]))
mom_update2 = mom_opt.apply_gradients(
zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run()
self.assertEqual(["momentum"], mom_opt.get_slot_names())
slot0 = mom_opt.get_slot(var0, "momentum")
self.assertEquals(slot0.get_shape(), var0.get_shape())
slot1 = mom_opt.get_slot(var1, "momentum")
self.assertEquals(slot1.get_shape(), var1.get_shape())
# Fetch params to validate initial values
self.assertAllClose([1.0, 2.0], self.evaluate(var0))
self.assertAllClose([3.0, 4.0], self.evaluate(var1))
# Step 1: the momentum accumulators where 0. So we should see a normal
# update: v -= grad * learning_rate
mom_update1.run()
# Check that the momentum accumulators have been updated.
self.assertAllCloseAccordingToType(
np.array([0.1, 0.1]), self.evaluate(slot0))
self.assertAllCloseAccordingToType(
np.array([0.01, 0.01]), self.evaluate(slot1))
# Check that the parameters have been updated.
self.assertAllCloseAccordingToType(
np.array([1.0 - (0.1 * 2.0), 2.0 - (0.1 * 2.0)]),
self.evaluate(var0))
self.assertAllCloseAccordingToType(
np.array([3.0 - (0.01 * 2.0), 4.0 - (0.01 * 2.0)]),
self.evaluate(var1))
# Step 2: the second momentum accumulators contain the previous update.
mom_update2.run()
# Check that the momentum accumulators have been updated.
self.assertAllCloseAccordingToType(
np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]),
self.evaluate(slot0))
self.assertAllCloseAccordingToType(
np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]),
self.evaluate(slot1))
# Check that the parameters have been updated.
self.assertAllCloseAccordingToType(
np.array([
1.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0),
2.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0)
]), self.evaluate(var0))
self.assertAllCloseAccordingToType(
np.array([
2.98 - ((0.9 * 0.01 + 0.01) * 2.0),
3.98 - ((0.9 * 0.01 + 0.01) * 2.0)
]), self.evaluate(var1))
if __name__ == "__main__":
test.main()