Private implementation of Adam using XLA for fusion.
Because we can rely on tf.function I also refactored and cleaned up the updates a bit; removing stray control dependencies and using methods to update variables and avoiding chaining assignment operations. PiperOrigin-RevId: 301822384 Change-Id: If4cb54e3d7b27c916912d39e5a01c1ff7905b4ba
This commit is contained in:
parent
1dffd2d117
commit
0cfab2b1fa
@ -17,6 +17,7 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.python.eager import def_function
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.keras import backend_config
|
from tensorflow.python.keras import backend_config
|
||||||
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
|
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
|
||||||
@ -278,3 +279,226 @@ class Adam(optimizer_v2.OptimizerV2):
|
|||||||
'amsgrad': self.amsgrad,
|
'amsgrad': self.amsgrad,
|
||||||
})
|
})
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
class NonFusedAdam(optimizer_v2.OptimizerV2):
|
||||||
|
r"""Optimizer that implements the Adam algorithm without fused kernels.
|
||||||
|
|
||||||
|
Adam optimization is a stochastic gradient descent method that is based on
|
||||||
|
adaptive estimation of first-order and second-order moments.
|
||||||
|
According to the paper
|
||||||
|
[Adam: A Method for Stochastic Optimization. Kingma et al.,
|
||||||
|
2014](http://arxiv.org/abs/1412.6980), the method is "*computationally
|
||||||
|
efficient, has little memory requirement, invariant to diagonal rescaling of
|
||||||
|
gradients, and is well suited for problems that are large in terms of
|
||||||
|
data/parameters*".
|
||||||
|
|
||||||
|
For AMSGrad see [On The Convergence Of Adam And Beyond.
|
||||||
|
Reddi et al., 5-8](https://openreview.net/pdf?id=ryQu7f-RZ).
|
||||||
|
|
||||||
|
**If amsgrad = False**:
|
||||||
|
|
||||||
|
initialize $m_0$ as 1st moment vector
|
||||||
|
initialize $v_0$ as 2nd moment vector
|
||||||
|
|
||||||
|
The update rule for $\theta$ with gradient $g$ uses an optimization
|
||||||
|
described at the end of section 2 of the paper:
|
||||||
|
|
||||||
|
$$lr_t = \mathrm{learning\_rate} *
|
||||||
|
\sqrt{1 - \beta_2^t} / (1 - \beta_1^t)$$
|
||||||
|
$$m_t = \beta_1 * m_{t-1} + (1 - \beta_1) * g$$
|
||||||
|
$$v_t = \beta_2 * v_{t-1} + (1 - \beta_2) * g^2$$
|
||||||
|
$$\theta_t = \theta_{t-1} - lr_t * m_t / (\sqrt{v_t} + \epsilon)$$
|
||||||
|
|
||||||
|
**If amsgrad = True**:
|
||||||
|
|
||||||
|
initialize $m_0$ as 1st moment vector
|
||||||
|
initialize $v_0$ as 2nd moment vector
|
||||||
|
initialize $\hat{v}_0$ as 2nd moment vector
|
||||||
|
|
||||||
|
The update rule for $\theta$ with gradient $g$ uses an optimization
|
||||||
|
described at the end of section 2 of the paper:
|
||||||
|
|
||||||
|
$$lr_t = \mathrm{learning\_rate} *
|
||||||
|
\sqrt{1 - \beta_2^t} / (1 - \beta_1^t)$$
|
||||||
|
|
||||||
|
$$m_t = \beta_1 * m_{t-1} + (1 - \beta_1) * g$$
|
||||||
|
$$v_t = \beta_2 * v_{t-1} + (1 - \beta_2) * g^2$$
|
||||||
|
$$\hat{v}_t = \max(\hat{v}_{t-1}, v_t)$$
|
||||||
|
$$\theta_t = \theta_{t-1} - lr_t * m_t / (\sqrt{\hat{v}_t} + \epsilon)$$
|
||||||
|
|
||||||
|
The default value of 1e-7 for epsilon might not be a good default in
|
||||||
|
general. For example, when training an Inception network on ImageNet a
|
||||||
|
current good choice is 1.0 or 0.1. Note that since AdamOptimizer uses the
|
||||||
|
formulation just before Section 2.1 of the Kingma and Ba paper rather than
|
||||||
|
the formulation in Algorithm 1, the "epsilon" referred to here is "epsilon
|
||||||
|
hat" in the paper.
|
||||||
|
|
||||||
|
The sparse implementation of this algorithm (used when the gradient is an
|
||||||
|
IndexedSlices object, typically because of `tf.gather` or an embedding
|
||||||
|
lookup in the forward pass) does apply momentum to variable slices even if
|
||||||
|
they were not used in the forward pass (meaning they have a gradient equal
|
||||||
|
to zero). Momentum decay (beta1) is also applied to the entire momentum
|
||||||
|
accumulator. This means that the sparse behavior is equivalent to the dense
|
||||||
|
behavior (in contrast to some momentum implementations which ignore momentum
|
||||||
|
unless a variable slice was actually used).
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
|
||||||
|
>>> opt = tf.keras.optimizers.Adam(learning_rate=0.1)
|
||||||
|
>>> var1 = tf.Variable(10.0)
|
||||||
|
>>> loss = lambda: (var1 ** 2)/2.0 # d(loss)/d(var1) == var1
|
||||||
|
>>> step_count = opt.minimize(loss, [var1]).numpy()
|
||||||
|
>>> # The first step is `-learning_rate*sign(grad)`
|
||||||
|
>>> var1.numpy()
|
||||||
|
9.9
|
||||||
|
"""
|
||||||
|
|
||||||
|
_HAS_ALL_REDUCE_SUM_GRAD = True
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
learning_rate=0.001,
|
||||||
|
beta_1=0.9,
|
||||||
|
beta_2=0.999,
|
||||||
|
epsilon=1e-7,
|
||||||
|
amsgrad=False,
|
||||||
|
name='Adam',
|
||||||
|
**kwargs):
|
||||||
|
"""Construct a new Adam optimizer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
learning_rate: A `Tensor`, floating point value, or a schedule that is a
|
||||||
|
`tf.keras.optimizers.schedules.LearningRateSchedule`, or a callable that
|
||||||
|
takes no arguments and returns the actual value to use, The learning
|
||||||
|
rate. Defaults to 0.001.
|
||||||
|
beta_1: A float value or a constant float tensor, or a callable that takes
|
||||||
|
no arguments and returns the actual value to use. The exponential decay
|
||||||
|
rate for the 1st moment estimates. Defaults to 0.9.
|
||||||
|
beta_2: A float value or a constant float tensor, or a callable that takes
|
||||||
|
no arguments and returns the actual value to use, The exponential decay
|
||||||
|
rate for the 2nd moment estimates. Defaults to 0.999.
|
||||||
|
epsilon: A small constant for numerical stability. This epsilon is
|
||||||
|
"epsilon hat" in the Kingma and Ba paper (in the formula just before
|
||||||
|
Section 2.1), not the epsilon in Algorithm 1 of the paper. Defaults to
|
||||||
|
1e-7.
|
||||||
|
amsgrad: Boolean. Whether to apply AMSGrad variant of this algorithm from
|
||||||
|
the paper "On the Convergence of Adam and beyond". Defaults to `False`.
|
||||||
|
name: Optional name for the operations created when applying gradients.
|
||||||
|
Defaults to "Adam".
|
||||||
|
**kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, `lr`,
|
||||||
|
`decay`}. `clipnorm` is clip gradients by norm; `clipvalue` is clip
|
||||||
|
gradients by value, `decay` is included for backward compatibility to
|
||||||
|
allow time inverse decay of learning rate. `lr` is included for backward
|
||||||
|
compatibility, recommended to use `learning_rate` instead.
|
||||||
|
"""
|
||||||
|
|
||||||
|
super(NonFusedAdam, self).__init__(name, **kwargs)
|
||||||
|
self._set_hyper('learning_rate', kwargs.get('lr', learning_rate))
|
||||||
|
self._set_hyper('decay', self._initial_decay)
|
||||||
|
self._set_hyper('beta_1', beta_1)
|
||||||
|
self._set_hyper('beta_2', beta_2)
|
||||||
|
self.epsilon = epsilon or backend_config.epsilon()
|
||||||
|
self.amsgrad = amsgrad
|
||||||
|
|
||||||
|
def _create_slots(self, var_list):
|
||||||
|
# Create slots for the first and second moments.
|
||||||
|
# Separate for-loops to respect the ordering of slot variables from v1.
|
||||||
|
for var in var_list:
|
||||||
|
self.add_slot(var, 'm')
|
||||||
|
for var in var_list:
|
||||||
|
self.add_slot(var, 'v')
|
||||||
|
if self.amsgrad:
|
||||||
|
for var in var_list:
|
||||||
|
self.add_slot(var, 'vhat')
|
||||||
|
|
||||||
|
def _prepare_local(self, var_device, var_dtype, apply_state):
|
||||||
|
super(NonFusedAdam, self)._prepare_local(var_device, var_dtype, apply_state)
|
||||||
|
|
||||||
|
local_step = math_ops.cast(self.iterations + 1, var_dtype)
|
||||||
|
beta_1_t = array_ops.identity(self._get_hyper('beta_1', var_dtype))
|
||||||
|
beta_2_t = array_ops.identity(self._get_hyper('beta_2', var_dtype))
|
||||||
|
beta_1_power = math_ops.pow(beta_1_t, local_step)
|
||||||
|
beta_2_power = math_ops.pow(beta_2_t, local_step)
|
||||||
|
lr = (
|
||||||
|
apply_state[(var_device, var_dtype)]['lr_t'] *
|
||||||
|
(math_ops.sqrt(1 - beta_2_power) / (1 - beta_1_power)))
|
||||||
|
apply_state[(var_device, var_dtype)].update(
|
||||||
|
dict(
|
||||||
|
lr=lr,
|
||||||
|
epsilon=ops.convert_to_tensor_v2(self.epsilon, var_dtype),
|
||||||
|
beta_1_t=beta_1_t,
|
||||||
|
beta_1_power=beta_1_power,
|
||||||
|
one_minus_beta_1_t=1 - beta_1_t,
|
||||||
|
beta_2_t=beta_2_t,
|
||||||
|
beta_2_power=beta_2_power,
|
||||||
|
one_minus_beta_2_t=1 - beta_2_t))
|
||||||
|
|
||||||
|
def set_weights(self, weights):
|
||||||
|
params = self.weights
|
||||||
|
# If the weights are generated by Keras V1 optimizer, it includes vhats
|
||||||
|
# even without amsgrad, i.e, V1 optimizer has 3x + 1 variables, while V2
|
||||||
|
# optimizer has 2x + 1 variables. Filter vhats out for compatibility.
|
||||||
|
num_vars = int((len(params) - 1) / 2)
|
||||||
|
if len(weights) == 3 * num_vars + 1:
|
||||||
|
weights = weights[:len(params)]
|
||||||
|
super(NonFusedAdam, self).set_weights(weights)
|
||||||
|
|
||||||
|
@def_function.function(experimental_compile=True)
|
||||||
|
def _resource_apply_dense(self, grad, var, apply_state=None):
|
||||||
|
var_device, var_dtype = var.device, var.dtype.base_dtype
|
||||||
|
coefficients = ((apply_state or {}).get((var_device, var_dtype)) or
|
||||||
|
self._fallback_apply_state(var_device, var_dtype))
|
||||||
|
|
||||||
|
m = self.get_slot(var, 'm')
|
||||||
|
v = self.get_slot(var, 'v')
|
||||||
|
|
||||||
|
alpha = (
|
||||||
|
coefficients['lr_t'] * math_ops.sqrt(1 - coefficients['beta_2_power']) /
|
||||||
|
(1 - coefficients['beta_1_power']))
|
||||||
|
m.assign_add((grad - m) * (1 - coefficients['beta_1_t']))
|
||||||
|
v.assign_add((math_ops.square(grad) - v) * (1 - coefficients['beta_2_t']))
|
||||||
|
if self.amsgrad:
|
||||||
|
vhat = self.get_slot(var, 'vhat')
|
||||||
|
vhat.assign(math_ops.maximum(vhat, v))
|
||||||
|
v = vhat
|
||||||
|
var.assign_sub(
|
||||||
|
(m * alpha) / (math_ops.sqrt(v) - coefficients['epsilon']))
|
||||||
|
|
||||||
|
@def_function.function(experimental_compile=True)
|
||||||
|
def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
|
||||||
|
var_device, var_dtype = var.device, var.dtype.base_dtype
|
||||||
|
coefficients = ((apply_state or {}).get((var_device, var_dtype)) or
|
||||||
|
self._fallback_apply_state(var_device, var_dtype))
|
||||||
|
|
||||||
|
# m_t = beta1 * m + (1 - beta1) * g_t
|
||||||
|
m = self.get_slot(var, 'm')
|
||||||
|
m_scaled_g_values = grad * coefficients['one_minus_beta_1_t']
|
||||||
|
m.assign(m * coefficients['beta_1_t'])
|
||||||
|
m.scatter_add(ops.IndexedSlices(m_scaled_g_values, indices))
|
||||||
|
|
||||||
|
# v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
|
||||||
|
v = self.get_slot(var, 'v')
|
||||||
|
v_scaled_g_values = (grad * grad) * coefficients['one_minus_beta_2_t']
|
||||||
|
v.assign(v * coefficients['beta_2_t'])
|
||||||
|
v.scatter_add(ops.IndexedSlices(v_scaled_g_values, indices))
|
||||||
|
|
||||||
|
if not self.amsgrad:
|
||||||
|
var.assign_sub(coefficients['lr'] * m /
|
||||||
|
(math_ops.sqrt(v) + coefficients['epsilon']))
|
||||||
|
else:
|
||||||
|
v_hat = self.get_slot(var, 'vhat')
|
||||||
|
v_hat.assign(math_ops.maximum(v_hat, v))
|
||||||
|
var.assign_sub(coefficients['lr'] * m /
|
||||||
|
(math_ops.sqrt(v_hat) + coefficients['epsilon']))
|
||||||
|
|
||||||
|
def get_config(self):
|
||||||
|
config = super(NonFusedAdam, self).get_config()
|
||||||
|
config.update({
|
||||||
|
'learning_rate': self._serialize_hyperparameter('learning_rate'),
|
||||||
|
'decay': self._serialize_hyperparameter('decay'),
|
||||||
|
'beta_1': self._serialize_hyperparameter('beta_1'),
|
||||||
|
'beta_2': self._serialize_hyperparameter('beta_2'),
|
||||||
|
'epsilon': self.epsilon,
|
||||||
|
'amsgrad': self.amsgrad,
|
||||||
|
})
|
||||||
|
return config
|
||||||
|
@ -569,5 +569,435 @@ class AdamOptimizerTest(test.TestCase, parameterized.TestCase):
|
|||||||
self.assertAllClose(self.evaluate(opt_3.lr), (0.1))
|
self.assertAllClose(self.evaluate(opt_3.lr), (0.1))
|
||||||
|
|
||||||
|
|
||||||
|
class NonFusedAdamOptimizerTest(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
|
def testSparse(self):
|
||||||
|
# TODO(tanzheny, omalleyt): Fix test in eager mode.
|
||||||
|
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
|
||||||
|
with ops.Graph().as_default(), self.cached_session(use_gpu=True):
|
||||||
|
# Initialize variables for numpy implementation.
|
||||||
|
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
|
||||||
|
var0_np = np.array([1.0, 1.0, 2.0], dtype=dtype.as_numpy_dtype)
|
||||||
|
grads0_np = np.array([0.1, 0.0, 0.1], dtype=dtype.as_numpy_dtype)
|
||||||
|
var1_np = np.array([3.0, 3.0, 4.0], dtype=dtype.as_numpy_dtype)
|
||||||
|
grads1_np = np.array([0.01, 0.0, 0.01], dtype=dtype.as_numpy_dtype)
|
||||||
|
|
||||||
|
var0 = resource_variable_ops.ResourceVariable(var0_np)
|
||||||
|
var1 = resource_variable_ops.ResourceVariable(var1_np)
|
||||||
|
grads0_np_indices = np.array([0, 2], dtype=np.int32)
|
||||||
|
grads0 = ops.IndexedSlices(
|
||||||
|
constant_op.constant(grads0_np[grads0_np_indices]),
|
||||||
|
constant_op.constant(grads0_np_indices), constant_op.constant([3]))
|
||||||
|
grads1_np_indices = np.array([0, 2], dtype=np.int32)
|
||||||
|
grads1 = ops.IndexedSlices(
|
||||||
|
constant_op.constant(grads1_np[grads1_np_indices]),
|
||||||
|
constant_op.constant(grads1_np_indices), constant_op.constant([3]))
|
||||||
|
opt = adam.NonFusedAdam()
|
||||||
|
update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
|
||||||
|
variables.global_variables_initializer().run()
|
||||||
|
|
||||||
|
# Fetch params to validate initial values
|
||||||
|
self.assertAllClose([1.0, 1.0, 2.0], self.evaluate(var0))
|
||||||
|
self.assertAllClose([3.0, 3.0, 4.0], self.evaluate(var1))
|
||||||
|
|
||||||
|
beta_1_power, beta_2_power = get_beta_accumulators(opt, dtype)
|
||||||
|
# Run 3 steps of NonFusedAdam
|
||||||
|
for t in range(3):
|
||||||
|
self.assertAllCloseAccordingToType(0.9**(t + 1),
|
||||||
|
self.evaluate(beta_1_power))
|
||||||
|
self.assertAllCloseAccordingToType(0.999**(t + 1),
|
||||||
|
self.evaluate(beta_2_power))
|
||||||
|
update.run()
|
||||||
|
|
||||||
|
var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
|
||||||
|
var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
|
||||||
|
|
||||||
|
# Validate updated params
|
||||||
|
self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
|
||||||
|
self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
|
||||||
|
|
||||||
|
def testSparseDevicePlacement(self):
|
||||||
|
# TODO(tanzheny, omalleyt): Fix test in eager mode.
|
||||||
|
for index_dtype in [dtypes.int32, dtypes.int64]:
|
||||||
|
with ops.Graph().as_default(), self.cached_session(
|
||||||
|
force_gpu=test.is_gpu_available()):
|
||||||
|
# If a GPU is available, tests that all optimizer ops can be placed on
|
||||||
|
# it (i.e. they have GPU kernels).
|
||||||
|
var = variables.Variable([[1.0], [2.0]])
|
||||||
|
indices = constant_op.constant([0, 1], dtype=index_dtype)
|
||||||
|
g_sum = lambda: math_ops.reduce_sum(array_ops.gather(var, indices)) # pylint: disable=cell-var-from-loop
|
||||||
|
optimizer = adam.NonFusedAdam(3.0)
|
||||||
|
minimize_op = optimizer.minimize(g_sum, var_list=[var])
|
||||||
|
variables.global_variables_initializer().run()
|
||||||
|
minimize_op.run()
|
||||||
|
|
||||||
|
def testSparseRepeatedIndices(self):
|
||||||
|
# TODO(tanzheny, omalleyt): Fix test in eager mode.
|
||||||
|
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
|
||||||
|
with ops.Graph().as_default(), self.cached_session():
|
||||||
|
repeated_index_update_var = variables.Variable(
|
||||||
|
[[1.0], [2.0]], dtype=dtype)
|
||||||
|
aggregated_update_var = variables.Variable(
|
||||||
|
[[1.0], [2.0]], dtype=dtype)
|
||||||
|
grad_repeated_index = ops.IndexedSlices(
|
||||||
|
constant_op.constant(
|
||||||
|
[0.1, 0.1], shape=[2, 1], dtype=dtype),
|
||||||
|
constant_op.constant([1, 1]),
|
||||||
|
constant_op.constant([2, 1]))
|
||||||
|
grad_aggregated = ops.IndexedSlices(
|
||||||
|
constant_op.constant(
|
||||||
|
[0.2], shape=[1, 1], dtype=dtype),
|
||||||
|
constant_op.constant([1]),
|
||||||
|
constant_op.constant([2, 1]))
|
||||||
|
repeated_update = adam.NonFusedAdam().apply_gradients(
|
||||||
|
[(grad_repeated_index, repeated_index_update_var)])
|
||||||
|
aggregated_update = adam.NonFusedAdam().apply_gradients(
|
||||||
|
[(grad_aggregated, aggregated_update_var)])
|
||||||
|
variables.global_variables_initializer().run()
|
||||||
|
self.assertAllClose(aggregated_update_var.eval(),
|
||||||
|
self.evaluate(repeated_index_update_var))
|
||||||
|
for _ in range(3):
|
||||||
|
repeated_update.run()
|
||||||
|
aggregated_update.run()
|
||||||
|
self.assertAllClose(aggregated_update_var.eval(),
|
||||||
|
self.evaluate(repeated_index_update_var))
|
||||||
|
|
||||||
|
def doTestBasic(self, use_callable_params=False):
|
||||||
|
for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
|
||||||
|
with self.cached_session(use_gpu=True):
|
||||||
|
# Initialize variables for numpy implementation.
|
||||||
|
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
|
||||||
|
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
|
||||||
|
grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
|
||||||
|
var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
|
||||||
|
grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
|
||||||
|
|
||||||
|
var0 = resource_variable_ops.ResourceVariable(
|
||||||
|
var0_np, name="var0_%d" % i)
|
||||||
|
var1 = resource_variable_ops.ResourceVariable(
|
||||||
|
var1_np, name="var1_%d" % i)
|
||||||
|
grads0 = constant_op.constant(grads0_np)
|
||||||
|
grads1 = constant_op.constant(grads1_np)
|
||||||
|
|
||||||
|
learning_rate = lambda: 0.001
|
||||||
|
beta1 = lambda: 0.9
|
||||||
|
beta2 = lambda: 0.999
|
||||||
|
epsilon = lambda: 1e-8
|
||||||
|
if not use_callable_params:
|
||||||
|
learning_rate = learning_rate()
|
||||||
|
beta1 = beta1()
|
||||||
|
beta2 = beta2()
|
||||||
|
epsilon = epsilon()
|
||||||
|
|
||||||
|
opt = adam.NonFusedAdam(learning_rate=learning_rate)
|
||||||
|
if not context.executing_eagerly():
|
||||||
|
update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
|
||||||
|
|
||||||
|
self.evaluate(variables.global_variables_initializer())
|
||||||
|
# Run 3 steps of NonFusedAdam
|
||||||
|
for t in range(3):
|
||||||
|
beta_1_power, beta_2_power = get_beta_accumulators(opt, dtype)
|
||||||
|
self.assertAllCloseAccordingToType(0.9**(t + 1),
|
||||||
|
self.evaluate(beta_1_power))
|
||||||
|
self.assertAllCloseAccordingToType(0.999**(t + 1),
|
||||||
|
self.evaluate(beta_2_power))
|
||||||
|
if not context.executing_eagerly():
|
||||||
|
self.evaluate(update)
|
||||||
|
else:
|
||||||
|
opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
|
||||||
|
|
||||||
|
var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
|
||||||
|
var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
|
||||||
|
|
||||||
|
# Validate updated params
|
||||||
|
self.assertAllCloseAccordingToType(
|
||||||
|
var0_np, self.evaluate(var0), rtol=1e-4, atol=1e-4)
|
||||||
|
self.assertAllCloseAccordingToType(
|
||||||
|
var1_np, self.evaluate(var1), rtol=1e-4, atol=1e-4)
|
||||||
|
|
||||||
|
@combinations.generate(combinations.combine(mode=["graph", "eager"]))
|
||||||
|
def testResourceBasic(self):
|
||||||
|
self.doTestBasic()
|
||||||
|
|
||||||
|
def testBasicCallableParams(self):
|
||||||
|
with context.eager_mode():
|
||||||
|
self.doTestBasic(use_callable_params=True)
|
||||||
|
|
||||||
|
@combinations.generate(combinations.combine(mode=["graph", "eager"]))
|
||||||
|
def testBasicWithAmsgrad(self):
|
||||||
|
for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
|
||||||
|
with self.cached_session(use_gpu=True):
|
||||||
|
# Initialize variables for numpy implementation.
|
||||||
|
m0, v0, v0hat, m1, v1, v1hat = 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
|
||||||
|
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
|
||||||
|
grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
|
||||||
|
var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
|
||||||
|
grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
|
||||||
|
|
||||||
|
var0 = resource_variable_ops.ResourceVariable(
|
||||||
|
var0_np, name="var0_%d" % i)
|
||||||
|
var1 = resource_variable_ops.ResourceVariable(
|
||||||
|
var1_np, name="var1_%d" % i)
|
||||||
|
grads0 = constant_op.constant(grads0_np)
|
||||||
|
grads1 = constant_op.constant(grads1_np)
|
||||||
|
|
||||||
|
opt = adam.NonFusedAdam(amsgrad=True)
|
||||||
|
if not context.executing_eagerly():
|
||||||
|
update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
|
||||||
|
|
||||||
|
self.evaluate(variables.global_variables_initializer())
|
||||||
|
# Run 3 steps of NonFusedAdam
|
||||||
|
for t in range(3):
|
||||||
|
beta_1_power, beta_2_power = get_beta_accumulators(opt, dtype)
|
||||||
|
self.assertAllCloseAccordingToType(0.9**(t + 1),
|
||||||
|
self.evaluate(beta_1_power))
|
||||||
|
self.assertAllCloseAccordingToType(0.999**(t + 1),
|
||||||
|
self.evaluate(beta_2_power))
|
||||||
|
if not context.executing_eagerly():
|
||||||
|
self.evaluate(update)
|
||||||
|
else:
|
||||||
|
opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
|
||||||
|
|
||||||
|
var0_np, m0, v0, v0hat = adam_update_numpy_amsgrad(
|
||||||
|
var0_np, grads0_np, t, m0, v0, v0hat)
|
||||||
|
var1_np, m1, v1, v1hat = adam_update_numpy_amsgrad(
|
||||||
|
var1_np, grads1_np, t, m1, v1, v1hat)
|
||||||
|
|
||||||
|
# Validate updated params
|
||||||
|
self.assertAllCloseAccordingToType(
|
||||||
|
var0_np, self.evaluate(var0), rtol=1e-4, atol=1e-4)
|
||||||
|
self.assertAllCloseAccordingToType(
|
||||||
|
var1_np, self.evaluate(var1), rtol=1e-4, atol=1e-4)
|
||||||
|
|
||||||
|
@combinations.generate(combinations.combine(mode=["graph", "eager"]))
|
||||||
|
def testSparseWithAmsgrad(self):
|
||||||
|
# dtypes.half does not work on gpu + eager.
|
||||||
|
for dtype in [dtypes.float32, dtypes.float64]:
|
||||||
|
with self.cached_session():
|
||||||
|
m0 = np.array([[0.0], [0.0]])
|
||||||
|
v0 = np.array([[0.0], [0.0]])
|
||||||
|
v0hat = np.array([[0.0], [0.0]])
|
||||||
|
indices_np = np.array([1])
|
||||||
|
indices = constant_op.constant(indices_np, dtype=dtypes.int32)
|
||||||
|
var0_np = np.array([[1.0], [2.0]], dtype=dtype.as_numpy_dtype)
|
||||||
|
repeated_index_update_var = variables.Variable(var0_np, dtype=dtype)
|
||||||
|
aggregated_update_var = variables.Variable(var0_np, dtype=dtype)
|
||||||
|
grads0_np = np.array([[0.2]], dtype=dtype.as_numpy_dtype)
|
||||||
|
grad_repeated_index = ops.IndexedSlices(
|
||||||
|
constant_op.constant([0.1, 0.1], shape=[2, 1], dtype=dtype),
|
||||||
|
constant_op.constant([1, 1]), constant_op.constant([2, 1]))
|
||||||
|
grad_aggregated = ops.IndexedSlices(grads0_np, indices,
|
||||||
|
constant_op.constant([2, 1]))
|
||||||
|
opt_repeated = adam.NonFusedAdam(amsgrad=True)
|
||||||
|
opt_aggregated = adam.NonFusedAdam(amsgrad=True)
|
||||||
|
if not context.executing_eagerly():
|
||||||
|
repeated_update = opt_repeated.apply_gradients(
|
||||||
|
[(grad_repeated_index, repeated_index_update_var)])
|
||||||
|
aggregated_update = opt_aggregated.apply_gradients(
|
||||||
|
[(grad_aggregated, aggregated_update_var)])
|
||||||
|
self.evaluate(variables.global_variables_initializer())
|
||||||
|
self.assertAllClose(
|
||||||
|
self.evaluate(aggregated_update_var),
|
||||||
|
self.evaluate(repeated_index_update_var))
|
||||||
|
for t in range(3):
|
||||||
|
if not context.executing_eagerly():
|
||||||
|
self.evaluate(repeated_update)
|
||||||
|
self.evaluate(aggregated_update)
|
||||||
|
else:
|
||||||
|
opt_repeated.apply_gradients(
|
||||||
|
[(grad_repeated_index, repeated_index_update_var)])
|
||||||
|
opt_aggregated.apply_gradients(
|
||||||
|
[(grad_aggregated, aggregated_update_var)])
|
||||||
|
|
||||||
|
var0_np, m0, v0, v0hat = adam_sparse_update_numpy_amsgrad(
|
||||||
|
var0_np, indices_np, grads0_np, t, m0, v0, v0hat)
|
||||||
|
|
||||||
|
# Validate updated params
|
||||||
|
self.assertAllCloseAccordingToType(
|
||||||
|
var0_np, self.evaluate(aggregated_update_var))
|
||||||
|
self.assertAllCloseAccordingToType(
|
||||||
|
self.evaluate(aggregated_update_var),
|
||||||
|
self.evaluate(repeated_index_update_var))
|
||||||
|
|
||||||
|
def testBasicWithLearningRateDecay(self):
|
||||||
|
# TODO(tanzheny, omalleyt): Fix test in eager mode.
|
||||||
|
for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
|
||||||
|
with ops.Graph().as_default(), self.cached_session(use_gpu=True):
|
||||||
|
# Initialize variables for numpy implementation.
|
||||||
|
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
|
||||||
|
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
|
||||||
|
grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
|
||||||
|
var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
|
||||||
|
grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
|
||||||
|
|
||||||
|
var0 = resource_variable_ops.ResourceVariable(
|
||||||
|
var0_np, name="var0_%d" % i)
|
||||||
|
var1 = resource_variable_ops.ResourceVariable(
|
||||||
|
var1_np, name="var1_%d" % i)
|
||||||
|
grads0 = constant_op.constant(grads0_np)
|
||||||
|
grads1 = constant_op.constant(grads1_np)
|
||||||
|
|
||||||
|
learning_rate = 0.001
|
||||||
|
beta_1 = 0.9
|
||||||
|
beta_2 = 0.999
|
||||||
|
epsilon = 1e-7
|
||||||
|
decay = 0.5
|
||||||
|
|
||||||
|
opt = adam.NonFusedAdam(
|
||||||
|
learning_rate=learning_rate,
|
||||||
|
beta_1=beta_1,
|
||||||
|
beta_2=beta_2,
|
||||||
|
epsilon=epsilon,
|
||||||
|
decay=decay)
|
||||||
|
update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
|
||||||
|
|
||||||
|
self.evaluate(variables.global_variables_initializer())
|
||||||
|
# Run 3 steps of NonFusedAdam
|
||||||
|
for t in range(3):
|
||||||
|
self.evaluate(update)
|
||||||
|
lr_np = learning_rate / (1 + decay * t)
|
||||||
|
|
||||||
|
var0_np, m0, v0 = adam_update_numpy(
|
||||||
|
var0_np, grads0_np, t, m0, v0, lr=lr_np)
|
||||||
|
var1_np, m1, v1 = adam_update_numpy(
|
||||||
|
var1_np, grads1_np, t, m1, v1, lr=lr_np)
|
||||||
|
|
||||||
|
# Validate updated params
|
||||||
|
self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
|
||||||
|
self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
|
||||||
|
|
||||||
|
def testBasicWithLearningRateInverseTimeDecay(self):
|
||||||
|
# TODO(tanzheny, omalleyt): Fix test in eager mode.
|
||||||
|
for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
|
||||||
|
with ops.Graph().as_default(), self.cached_session(use_gpu=True):
|
||||||
|
# Initialize variables for numpy implementation.
|
||||||
|
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
|
||||||
|
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
|
||||||
|
grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
|
||||||
|
var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
|
||||||
|
grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
|
||||||
|
|
||||||
|
var0 = resource_variable_ops.ResourceVariable(
|
||||||
|
var0_np, name="var0_%d" % i)
|
||||||
|
var1 = resource_variable_ops.ResourceVariable(
|
||||||
|
var1_np, name="var1_%d" % i)
|
||||||
|
grads0 = constant_op.constant(grads0_np)
|
||||||
|
grads1 = constant_op.constant(grads1_np)
|
||||||
|
|
||||||
|
learning_rate = 0.001
|
||||||
|
decay = 0.5
|
||||||
|
lr_schedule = learning_rate_schedule.InverseTimeDecay(
|
||||||
|
learning_rate, decay_steps=1.0, decay_rate=decay)
|
||||||
|
beta_1 = 0.9
|
||||||
|
beta_2 = 0.999
|
||||||
|
epsilon = 1e-7
|
||||||
|
|
||||||
|
opt = adam.NonFusedAdam(
|
||||||
|
learning_rate=lr_schedule,
|
||||||
|
beta_1=beta_1,
|
||||||
|
beta_2=beta_2,
|
||||||
|
epsilon=epsilon)
|
||||||
|
update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
|
||||||
|
|
||||||
|
self.evaluate(variables.global_variables_initializer())
|
||||||
|
# Run 3 steps of NonFusedAdam
|
||||||
|
for t in range(3):
|
||||||
|
self.evaluate(update)
|
||||||
|
|
||||||
|
lr_np = learning_rate / (1 + decay * t)
|
||||||
|
|
||||||
|
var0_np, m0, v0 = adam_update_numpy(
|
||||||
|
var0_np, grads0_np, t, m0, v0, lr=lr_np)
|
||||||
|
var1_np, m1, v1 = adam_update_numpy(
|
||||||
|
var1_np, grads1_np, t, m1, v1, lr=lr_np)
|
||||||
|
|
||||||
|
# Validate updated params
|
||||||
|
self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
|
||||||
|
self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
|
||||||
|
|
||||||
|
def testTensorLearningRate(self):
|
||||||
|
# TODO(tanzheny, omalleyt): Fix test in eager mode.
|
||||||
|
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
|
||||||
|
with ops.Graph().as_default(), self.cached_session(use_gpu=True):
|
||||||
|
# Initialize variables for numpy implementation.
|
||||||
|
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
|
||||||
|
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
|
||||||
|
grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
|
||||||
|
var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
|
||||||
|
grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
|
||||||
|
|
||||||
|
var0 = variables.Variable(var0_np)
|
||||||
|
var1 = variables.Variable(var1_np)
|
||||||
|
grads0 = constant_op.constant(grads0_np)
|
||||||
|
grads1 = constant_op.constant(grads1_np)
|
||||||
|
opt = adam.NonFusedAdam(constant_op.constant(0.001))
|
||||||
|
update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
|
||||||
|
variables.global_variables_initializer().run()
|
||||||
|
|
||||||
|
# Fetch params to validate initial values
|
||||||
|
self.assertAllClose([1.0, 2.0], self.evaluate(var0))
|
||||||
|
self.assertAllClose([3.0, 4.0], self.evaluate(var1))
|
||||||
|
|
||||||
|
beta_1_power, beta_2_power = get_beta_accumulators(opt, dtype)
|
||||||
|
# Run 3 steps of NonFusedAdam
|
||||||
|
for t in range(3):
|
||||||
|
self.assertAllCloseAccordingToType(0.9**(t + 1),
|
||||||
|
self.evaluate(beta_1_power))
|
||||||
|
self.assertAllCloseAccordingToType(0.999**(t + 1),
|
||||||
|
self.evaluate(beta_2_power))
|
||||||
|
update.run()
|
||||||
|
|
||||||
|
var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
|
||||||
|
var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
|
||||||
|
|
||||||
|
# Validate updated params
|
||||||
|
self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
|
||||||
|
self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
|
||||||
|
|
||||||
|
def testSharing(self):
|
||||||
|
# TODO(tanzheny, omalleyt): Fix test in eager mode.
|
||||||
|
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
|
||||||
|
with ops.Graph().as_default(), self.cached_session(use_gpu=True):
|
||||||
|
# Initialize variables for numpy implementation.
|
||||||
|
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
|
||||||
|
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
|
||||||
|
grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
|
||||||
|
var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
|
||||||
|
grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
|
||||||
|
|
||||||
|
var0 = variables.Variable(var0_np)
|
||||||
|
var1 = variables.Variable(var1_np)
|
||||||
|
grads0 = constant_op.constant(grads0_np)
|
||||||
|
grads1 = constant_op.constant(grads1_np)
|
||||||
|
opt = adam.NonFusedAdam()
|
||||||
|
update1 = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
|
||||||
|
update2 = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
|
||||||
|
variables.global_variables_initializer().run()
|
||||||
|
|
||||||
|
beta_1_power, beta_2_power = get_beta_accumulators(opt, dtype)
|
||||||
|
|
||||||
|
# Fetch params to validate initial values
|
||||||
|
self.assertAllClose([1.0, 2.0], self.evaluate(var0))
|
||||||
|
self.assertAllClose([3.0, 4.0], self.evaluate(var1))
|
||||||
|
|
||||||
|
# Run 3 steps of intertwined NonFusedAdam1 and NonFusedAdam2.
|
||||||
|
for t in range(3):
|
||||||
|
self.assertAllCloseAccordingToType(0.9**(t + 1),
|
||||||
|
self.evaluate(beta_1_power))
|
||||||
|
self.assertAllCloseAccordingToType(0.999**(t + 1),
|
||||||
|
self.evaluate(beta_2_power))
|
||||||
|
if t % 2 == 0:
|
||||||
|
update1.run()
|
||||||
|
else:
|
||||||
|
update2.run()
|
||||||
|
|
||||||
|
var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
|
||||||
|
var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
|
||||||
|
|
||||||
|
# Validate updated params
|
||||||
|
self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
|
||||||
|
self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user