Rewrite the Adam and LazyAdam optimizers to take the global step for computing the beta1 and beta2 accumulators, instead of having each optimizer instance keep its own independent beta1 and beta2 accumulators as non-slot variables.
PiperOrigin-RevId: 224948020
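For readers of the change, a minimal usage sketch of the new behavior (not part of the diff; it assumes a TF 1.x graph-mode setup in which tf.contrib.opt exposes the classes added below):

import tensorflow as tf

global_step = tf.train.get_or_create_global_step()
var = tf.Variable([1.0, 2.0])
loss = tf.reduce_sum(tf.square(var))

# The optimizer takes the global step at construction time; the beta1/beta2
# bias-correction powers are computed as beta ** (global_step + 1) instead of
# being stored in per-optimizer non-slot accumulator variables.
# LazyAdamGSOptimizer accepts the same arguments.
opt = tf.contrib.opt.AdamGSOptimizer(global_step=global_step, learning_rate=0.001)
train_op = opt.minimize(loss, global_step=global_step)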
commit c2255b0f32
parent 68834966da
tensorflow/contrib/opt/BUILD
@@ -14,6 +14,7 @@ py_library(
    name = "opt_py",
    srcs = [
        "__init__.py",
        "python/training/adam_gs_optimizer.py",
        "python/training/adamax.py",
        "python/training/addsign.py",
        "python/training/agn_optimizer.py",
@@ -22,6 +23,7 @@ py_library(
        "python/training/external_optimizer.py",
        "python/training/ggt.py",
        "python/training/lars_optimizer.py",
        "python/training/lazy_adam_gs_optimizer.py",
        "python/training/lazy_adam_optimizer.py",
        "python/training/matrix_functions.py",
        "python/training/model_average_optimizer.py",
@@ -60,6 +62,21 @@ py_library(
    ],
)

py_test(
    name = "adam_gs_optimizer_test",
    srcs = ["python/training/adam_gs_optimizer_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":opt_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework_for_generated_wrappers",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:training",
        "//third_party/py/numpy",
    ],
)

py_test(
    name = "adamax_test",
    srcs = ["python/training/adamax_test.py"],
@@ -148,6 +165,25 @@ py_test(
    ],
)

py_test(
    name = "lazy_adam_gs_optimizer_test",
    srcs = ["python/training/lazy_adam_gs_optimizer_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":opt_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:resource_variable_ops",
        "//tensorflow/python:variables",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)

py_test(
    name = "lazy_adam_optimizer_test",
    srcs = ["python/training/lazy_adam_optimizer_test.py"],
tensorflow/contrib/opt/__init__.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function

# pylint: disable=wildcard-import
from tensorflow.contrib.opt.python.training.adam_gs_optimizer import *
from tensorflow.contrib.opt.python.training.adamax import *
from tensorflow.contrib.opt.python.training.addsign import *
from tensorflow.contrib.opt.python.training.agn_optimizer import *
@@ -28,6 +29,7 @@ from tensorflow.contrib.opt.python.training.external_optimizer import *
from tensorflow.contrib.opt.python.training.lars_optimizer import *
from tensorflow.contrib.opt.python.training.ggt import *
from tensorflow.contrib.opt.python.training.lazy_adam_optimizer import *
from tensorflow.contrib.opt.python.training.lazy_adam_gs_optimizer import *
from tensorflow.contrib.opt.python.training.model_average_optimizer import *
from tensorflow.contrib.opt.python.training.moving_average_optimizer import *
from tensorflow.contrib.opt.python.training.multitask_optimizer_wrapper import *
@@ -44,12 +46,14 @@ from tensorflow.python.util.all_util import remove_undocumented

_allowed_symbols = [
    'AdaMaxOptimizer',
    'AdamGSOptimizer',
    'PowerSignOptimizer',
    'AddSignOptimizer',
    'DelayCompensatedGradientDescentOptimizer',
    'DropStaleGradientOptimizer',
    'ExternalOptimizerInterface',
    'LARSOptimizer',
    'LazyAdamGSOptimizer',
    'LazyAdamOptimizer',
    'NadamOptimizer',
    'MovingAverageOptimizer',
tensorflow/contrib/opt/python/training/adam_gs_optimizer.py (new file, 217 lines)
@@ -0,0 +1,217 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Adam rewrite to use global step for computing beta1 & beta2 accumulation."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.training import optimizer
|
||||
from tensorflow.python.training import training_ops
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
@tf_export("train.AdamOptimizer")
|
||||
class AdamGSOptimizer(optimizer.Optimizer):
|
||||
"""Optimizer that implements the Adam algorithm.
|
||||
|
||||
See [Kingma et al., 2014](http://arxiv.org/abs/1412.6980)
|
||||
([pdf](http://arxiv.org/pdf/1412.6980.pdf)).
|
||||
"""
|
||||
|
||||
def __init__(self, global_step=0, learning_rate=0.001,
|
||||
beta1=0.9, beta2=0.999, epsilon=1e-8,
|
||||
use_locking=False, name="Adam"):
|
||||
"""Construct a new Adam optimizer.
|
||||
|
||||
Branched from tf.train.AdamOptimizer. The only difference is to pass the
global step for computing the beta1 and beta2 accumulators, instead of having
the optimizer keep its own independent beta1 and beta2 accumulators as
non-slot variables.
|
||||
|
||||
Initialization:
|
||||
|
||||
$$m_0 := 0 \text{(Initialize initial 1st moment vector)}$$
|
||||
$$v_0 := 0 \text{(Initialize initial 2nd moment vector)}$$
|
||||
$$t := 0 \text{(Initialize timestep)}$$
|
||||
|
||||
The update rule for `variable` with gradient `g` uses an optimization
|
||||
described at the end of section 2 of the paper:
|
||||
|
||||
$$t := t + 1$$
|
||||
$$lr_t := \text{learning\_rate} * \sqrt{1 - beta_2^t} / (1 - beta_1^t)$$
|
||||
|
||||
$$m_t := beta_1 * m_{t-1} + (1 - beta_1) * g$$
|
||||
$$v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g$$
|
||||
$$variable := variable - lr_t * m_t / (\sqrt{v_t} + \epsilon)$$
|
||||
|
||||
The default value of 1e-8 for epsilon might not be a good default in
|
||||
general. For example, when training an Inception network on ImageNet a
|
||||
current good choice is 1.0 or 0.1. Note that since AdamOptimizer uses the
|
||||
formulation just before Section 2.1 of the Kingma and Ba paper rather than
|
||||
the formulation in Algorithm 1, the "epsilon" referred to here is "epsilon
|
||||
hat" in the paper.
|
||||
|
||||
The sparse implementation of this algorithm (used when the gradient is an
|
||||
IndexedSlices object, typically because of `tf.gather` or an embedding
|
||||
lookup in the forward pass) does apply momentum to variable slices even if
|
||||
they were not used in the forward pass (meaning they have a gradient equal
|
||||
to zero). Momentum decay (beta1) is also applied to the entire momentum
|
||||
accumulator. This means that the sparse behavior is equivalent to the dense
|
||||
behavior (in contrast to some momentum implementations which ignore momentum
|
||||
unless a variable slice was actually used).
|
||||
|
||||
Args:
|
||||
global_step: A TensorFlow variable holding the current training step, used to compute the beta1 and beta2 power accumulators.
|
||||
learning_rate: A Tensor or a floating point value. The learning rate.
|
||||
beta1: A float value or a constant float tensor.
|
||||
The exponential decay rate for the 1st moment estimates.
|
||||
beta2: A float value or a constant float tensor.
|
||||
The exponential decay rate for the 2nd moment estimates.
|
||||
epsilon: A small constant for numerical stability. This epsilon is
|
||||
"epsilon hat" in the Kingma and Ba paper (in the formula just before
|
||||
Section 2.1), not the epsilon in Algorithm 1 of the paper.
|
||||
use_locking: If True use locks for update operations.
|
||||
name: Optional name for the operations created when applying gradients.
|
||||
Defaults to "Adam".
|
||||
|
||||
@compatibility(eager)
|
||||
When eager execution is enabled, `learning_rate`, `beta1`, `beta2`, and
|
||||
`epsilon` can each be a callable that takes no arguments and returns the
|
||||
actual value to use. This can be useful for changing these values across
|
||||
different invocations of optimizer functions.
|
||||
@end_compatibility
|
||||
"""
|
||||
super(AdamGSOptimizer, self).__init__(use_locking, name)
|
||||
self._lr = learning_rate
|
||||
self._beta1 = beta1
|
||||
self._beta2 = beta2
|
||||
self._epsilon = epsilon
|
||||
self._global_step = global_step
|
||||
self._global_step_on_worker = None
|
||||
|
||||
# Tensor versions of the constructor arguments, created in _prepare().
|
||||
self._lr_t = None
|
||||
self._beta1_t = None
|
||||
self._beta2_t = None
|
||||
self._epsilon_t = None
|
||||
|
||||
# Created in SparseApply if needed.
|
||||
self._updated_lr = None
|
||||
|
||||
def _get_beta_accumulators(self):
|
||||
return (math_ops.pow(self._beta1_t, self._global_step_on_worker),
|
||||
math_ops.pow(self._beta2_t, self._global_step_on_worker))
|
||||
|
||||
def _create_slots(self, var_list):
|
||||
# Create slots for the first and second moments.
|
||||
for v in var_list:
|
||||
self._zeros_slot(v, "m", self._name)
|
||||
self._zeros_slot(v, "v", self._name)
|
||||
|
||||
def _prepare(self):
|
||||
lr = self._call_if_callable(self._lr)
|
||||
beta1 = self._call_if_callable(self._beta1)
|
||||
beta2 = self._call_if_callable(self._beta2)
|
||||
epsilon = self._call_if_callable(self._epsilon)
|
||||
|
||||
self._lr_t = ops.convert_to_tensor(lr, name="learning_rate")
|
||||
self._beta1_t = ops.convert_to_tensor(beta1, name="beta1")
|
||||
self._beta2_t = ops.convert_to_tensor(beta2, name="beta2")
|
||||
self._epsilon_t = ops.convert_to_tensor(epsilon, name="epsilon")
|
||||
|
||||
# Performance optimization so that worker creates a copy of the global step
|
||||
# to avoid overloading the parameter server holding the global step.
|
||||
self._global_step_on_worker = math_ops.cast(
|
||||
array_ops.identity(self._global_step) + 1, dtypes.float32)
|
||||
|
||||
def _apply_dense(self, grad, var):
|
||||
m = self.get_slot(var, "m")
|
||||
v = self.get_slot(var, "v")
|
||||
beta1_power, beta2_power = self._get_beta_accumulators()
|
||||
return training_ops.apply_adam(
|
||||
var, m, v,
|
||||
math_ops.cast(beta1_power, var.dtype.base_dtype),
|
||||
math_ops.cast(beta2_power, var.dtype.base_dtype),
|
||||
math_ops.cast(self._lr_t, var.dtype.base_dtype),
|
||||
math_ops.cast(self._beta1_t, var.dtype.base_dtype),
|
||||
math_ops.cast(self._beta2_t, var.dtype.base_dtype),
|
||||
math_ops.cast(self._epsilon_t, var.dtype.base_dtype),
|
||||
grad, use_locking=self._use_locking).op
|
||||
|
||||
def _resource_apply_dense(self, grad, var):
|
||||
m = self.get_slot(var, "m")
|
||||
v = self.get_slot(var, "v")
|
||||
beta1_power, beta2_power = self._get_beta_accumulators()
|
||||
return training_ops.resource_apply_adam(
|
||||
var.handle, m.handle, v.handle,
|
||||
math_ops.cast(beta1_power, grad.dtype.base_dtype),
|
||||
math_ops.cast(beta2_power, grad.dtype.base_dtype),
|
||||
math_ops.cast(self._lr_t, grad.dtype.base_dtype),
|
||||
math_ops.cast(self._beta1_t, grad.dtype.base_dtype),
|
||||
math_ops.cast(self._beta2_t, grad.dtype.base_dtype),
|
||||
math_ops.cast(self._epsilon_t, grad.dtype.base_dtype),
|
||||
grad, use_locking=self._use_locking)
|
||||
|
||||
def _apply_sparse_shared(self, grad, var, indices, scatter_add):
|
||||
beta1_power, beta2_power = self._get_beta_accumulators()
|
||||
beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
|
||||
beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
|
||||
lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
|
||||
beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
|
||||
beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
|
||||
epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
|
||||
lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))
|
||||
# m_t = beta1 * m + (1 - beta1) * g_t
|
||||
m = self.get_slot(var, "m")
|
||||
m_scaled_g_values = grad * (1 - beta1_t)
|
||||
m_t = state_ops.assign(m, m * beta1_t,
|
||||
use_locking=self._use_locking)
|
||||
with ops.control_dependencies([m_t]):
|
||||
m_t = scatter_add(m, indices, m_scaled_g_values)
|
||||
# v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
|
||||
v = self.get_slot(var, "v")
|
||||
v_scaled_g_values = (grad * grad) * (1 - beta2_t)
|
||||
v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking)
|
||||
with ops.control_dependencies([v_t]):
|
||||
v_t = scatter_add(v, indices, v_scaled_g_values)
|
||||
v_sqrt = math_ops.sqrt(v_t)
|
||||
var_update = state_ops.assign_sub(var,
|
||||
lr * m_t / (v_sqrt + epsilon_t),
|
||||
use_locking=self._use_locking)
|
||||
return control_flow_ops.group(*[var_update, m_t, v_t])
|
||||
|
||||
def _apply_sparse(self, grad, var):
|
||||
return self._apply_sparse_shared(
|
||||
grad.values, var, grad.indices,
|
||||
lambda x, i, v: state_ops.scatter_add( # pylint: disable=g-long-lambda
|
||||
x, i, v, use_locking=self._use_locking))
|
||||
|
||||
def _resource_scatter_add(self, x, i, v):
|
||||
with ops.control_dependencies(
|
||||
[resource_variable_ops.resource_scatter_add(
|
||||
x.handle, i, v)]):
|
||||
return x.value()
|
||||
|
||||
def _resource_apply_sparse(self, grad, var, indices):
|
||||
return self._apply_sparse_shared(
|
||||
grad, var, indices, self._resource_scatter_add)
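For reference, the dense update that _apply_dense hands to training_ops.apply_adam follows the recurrence documented in the class docstring; a numpy sketch (the helper name is ours, and it mirrors the adam_update_numpy helper used by the tests below):

import numpy as np

def adam_gs_step(param, g, global_step, m, v,
                 lr=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8):
  # _prepare() reads the step as global_step + 1, so the bias-correction
  # powers are beta1 ** t and beta2 ** t with t = global_step + 1.
  t = global_step + 1
  lr_t = lr * np.sqrt(1 - beta2**t) / (1 - beta1**t)
  m = beta1 * m + (1 - beta1) * g
  v = beta2 * v + (1 - beta2) * g * g
  param = param - lr_t * m / (np.sqrt(v) + epsilon)
  return param, m, v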
tensorflow/contrib/opt/python/training/adam_gs_optimizer_test.py (new file, 382 lines)
@@ -0,0 +1,382 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Tests for AdamGS."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.opt.python.training import adam_gs_optimizer
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
def adam_update_numpy(param,
|
||||
g_t,
|
||||
t,
|
||||
m,
|
||||
v,
|
||||
alpha=0.001,
|
||||
beta1=0.9,
|
||||
beta2=0.999,
|
||||
epsilon=1e-8):
|
||||
alpha_t = alpha * np.sqrt(1 - beta2**t) / (1 - beta1**t)
|
||||
|
||||
m_t = beta1 * m + (1 - beta1) * g_t
|
||||
v_t = beta2 * v + (1 - beta2) * g_t * g_t
|
||||
|
||||
param_t = param - alpha_t * m_t / (np.sqrt(v_t) + epsilon)
|
||||
return param_t, m_t, v_t
|
||||
|
||||
|
||||
class AdamGSOptimizerTest(test.TestCase):
|
||||
|
||||
def doTestSparse(self, use_resource=False):
|
||||
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
|
||||
with self.cached_session():
|
||||
# Initialize variables for numpy implementation.
|
||||
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
|
||||
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
|
||||
grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
|
||||
var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
|
||||
grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
|
||||
|
||||
if use_resource:
|
||||
global_step = resource_variable_ops.ResourceVariable(
|
||||
array_ops.zeros([], dtypes.int64))
|
||||
var0 = resource_variable_ops.ResourceVariable(var0_np)
|
||||
var1 = resource_variable_ops.ResourceVariable(var1_np)
|
||||
else:
|
||||
global_step = variables.Variable(array_ops.zeros([], dtypes.int64))
|
||||
var0 = variables.Variable(var0_np)
|
||||
var1 = variables.Variable(var1_np)
|
||||
grads0_np_indices = np.array([0, 1], dtype=np.int32)
|
||||
grads0 = ops.IndexedSlices(
|
||||
constant_op.constant(grads0_np),
|
||||
constant_op.constant(grads0_np_indices), constant_op.constant([2]))
|
||||
grads1_np_indices = np.array([0, 1], dtype=np.int32)
|
||||
grads1 = ops.IndexedSlices(
|
||||
constant_op.constant(grads1_np),
|
||||
constant_op.constant(grads1_np_indices), constant_op.constant([2]))
|
||||
opt = adam_gs_optimizer.AdamGSOptimizer(global_step=global_step)
|
||||
update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]),
|
||||
global_step=global_step)
|
||||
variables.global_variables_initializer().run()
|
||||
|
||||
# Fetch params to validate initial values
|
||||
self.assertAllClose([1.0, 2.0], self.evaluate(var0))
|
||||
self.assertAllClose([3.0, 4.0], self.evaluate(var1))
|
||||
|
||||
beta1_power, beta2_power = opt._get_beta_accumulators()
|
||||
|
||||
# Run 3 steps of Adam
|
||||
for t in range(1, 4):
|
||||
self.assertAllCloseAccordingToType(0.9**t, self.evaluate(beta1_power))
|
||||
self.assertAllCloseAccordingToType(0.999**t,
|
||||
self.evaluate(beta2_power))
|
||||
update.run()
|
||||
|
||||
var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
|
||||
var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
|
||||
|
||||
# Validate updated params
|
||||
self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
|
||||
self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
|
||||
|
||||
def testSparse(self):
|
||||
self.doTestSparse(use_resource=False)
|
||||
|
||||
def testResourceSparse(self):
|
||||
self.doTestSparse(use_resource=True)
|
||||
|
||||
def testSparseDevicePlacement(self):
|
||||
for index_dtype in [dtypes.int32, dtypes.int64]:
|
||||
with self.cached_session(force_gpu=test.is_gpu_available()):
|
||||
# If a GPU is available, tests that all optimizer ops can be placed on
|
||||
# it (i.e. they have GPU kernels).
|
||||
var = variables.Variable([[1.0], [2.0]])
|
||||
indices = constant_op.constant([0, 1], dtype=index_dtype)
|
||||
gathered_sum = math_ops.reduce_sum(array_ops.gather(var, indices))
|
||||
optimizer = adam_gs_optimizer.AdamGSOptimizer(3.0)
|
||||
minimize_op = optimizer.minimize(gathered_sum)
|
||||
variables.global_variables_initializer().run()
|
||||
minimize_op.run()
|
||||
|
||||
def testSparseRepeatedIndices(self):
|
||||
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
|
||||
with self.cached_session():
|
||||
repeated_index_global_step = variables.Variable(
|
||||
array_ops.zeros([], dtypes.int64))
|
||||
aggregated_global_step = variables.Variable(
|
||||
array_ops.zeros([], dtypes.int64))
|
||||
repeated_index_update_var = variables.Variable(
|
||||
[[1.0], [2.0]], dtype=dtype)
|
||||
aggregated_update_var = variables.Variable(
|
||||
[[1.0], [2.0]], dtype=dtype)
|
||||
grad_repeated_index = ops.IndexedSlices(
|
||||
constant_op.constant(
|
||||
[0.1, 0.1], shape=[2, 1], dtype=dtype),
|
||||
constant_op.constant([1, 1]),
|
||||
constant_op.constant([2, 1]))
|
||||
grad_aggregated = ops.IndexedSlices(
|
||||
constant_op.constant(
|
||||
[0.2], shape=[1, 1], dtype=dtype),
|
||||
constant_op.constant([1]),
|
||||
constant_op.constant([2, 1]))
|
||||
repeated_update = adam_gs_optimizer.AdamGSOptimizer(
|
||||
global_step=repeated_index_global_step).apply_gradients(
|
||||
[(grad_repeated_index, repeated_index_update_var)],
|
||||
global_step=repeated_index_global_step)
|
||||
aggregated_update = adam_gs_optimizer.AdamGSOptimizer(
|
||||
global_step=aggregated_global_step).apply_gradients(
|
||||
[(grad_aggregated, aggregated_update_var)],
|
||||
global_step=aggregated_global_step)
|
||||
variables.global_variables_initializer().run()
|
||||
self.assertAllClose(aggregated_update_var.eval(),
|
||||
self.evaluate(repeated_index_update_var))
|
||||
for _ in range(3):
|
||||
repeated_update.run()
|
||||
aggregated_update.run()
|
||||
self.assertAllClose(aggregated_update_var.eval(),
|
||||
self.evaluate(repeated_index_update_var))
|
||||
|
||||
def doTestBasic(self, use_resource=False, use_callable_params=False):
|
||||
for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
|
||||
with self.session(graph=ops.Graph()):
|
||||
# Initialize variables for numpy implementation.
|
||||
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
|
||||
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
|
||||
grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
|
||||
var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
|
||||
grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
|
||||
|
||||
if use_resource:
|
||||
global_step = resource_variable_ops.ResourceVariable(
|
||||
array_ops.zeros([], dtypes.int64), name="global_step_%d" % i)
|
||||
var0 = resource_variable_ops.ResourceVariable(
|
||||
var0_np, name="var0_%d" % i)
|
||||
var1 = resource_variable_ops.ResourceVariable(
|
||||
var1_np, name="var1_%d" % i)
|
||||
else:
|
||||
global_step = variables.Variable(array_ops.zeros([], dtypes.int64))
|
||||
var0 = variables.Variable(var0_np)
|
||||
var1 = variables.Variable(var1_np)
|
||||
grads0 = constant_op.constant(grads0_np)
|
||||
grads1 = constant_op.constant(grads1_np)
|
||||
|
||||
learning_rate = lambda: 0.001
|
||||
beta1 = lambda: 0.9
|
||||
beta2 = lambda: 0.999
|
||||
epsilon = lambda: 1e-8
|
||||
if not use_callable_params:
|
||||
learning_rate = learning_rate()
|
||||
beta1 = beta1()
|
||||
beta2 = beta2()
|
||||
epsilon = epsilon()
|
||||
|
||||
opt = adam_gs_optimizer.AdamGSOptimizer(global_step=global_step,
|
||||
learning_rate=learning_rate)
|
||||
update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]),
|
||||
global_step=global_step)
|
||||
opt_variables = opt.variables()
|
||||
beta1_power, beta2_power = opt._get_beta_accumulators()
|
||||
self.assertTrue(beta1_power is not None)
|
||||
self.assertTrue(beta2_power is not None)
|
||||
self.assertNotIn(beta1_power, opt_variables)
|
||||
self.assertNotIn(beta2_power, opt_variables)
|
||||
|
||||
if not context.executing_eagerly():
|
||||
with ops.Graph().as_default():
|
||||
# Shouldn't return non-slot variables from other graphs.
|
||||
self.assertEqual(0, len(opt.variables()))
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
# Fetch params to validate initial values
|
||||
self.assertAllClose([1.0, 2.0], self.evaluate(var0))
|
||||
self.assertAllClose([3.0, 4.0], self.evaluate(var1))
|
||||
|
||||
# Run 3 steps of Adam
|
||||
for t in range(1, 4):
|
||||
if not context.executing_eagerly():
|
||||
self.evaluate(update)
|
||||
self.assertAllCloseAccordingToType(
|
||||
0.9**(t + 1), self.evaluate(beta1_power))
|
||||
self.assertAllCloseAccordingToType(
|
||||
0.999**(t + 1), self.evaluate(beta2_power))
|
||||
else:
|
||||
if t > 1:
|
||||
opt.apply_gradients(zip([grads0, grads1], [var0, var1]),
|
||||
global_step=global_step)
|
||||
beta1_power, beta2_power = opt._get_beta_accumulators()
|
||||
self.assertAllCloseAccordingToType(
|
||||
0.9**t, self.evaluate(beta1_power))
|
||||
self.assertAllCloseAccordingToType(
|
||||
0.999**t, self.evaluate(beta2_power))
|
||||
|
||||
var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
|
||||
var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
|
||||
|
||||
# Validate updated params
|
||||
self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
|
||||
self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
|
||||
if use_resource:
|
||||
self.assertEqual("var0_%d/Adam:0" % (i,),
|
||||
opt.get_slot(var=var0, name="m").name)
|
||||
|
||||
def testBasic(self):
|
||||
with self.cached_session():
|
||||
self.doTestBasic(use_resource=False)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes(reset_test=True)
|
||||
def testResourceBasic(self):
|
||||
self.doTestBasic(use_resource=True)
|
||||
|
||||
def testBasicCallableParams(self):
|
||||
with context.eager_mode():
|
||||
self.doTestBasic(use_resource=True, use_callable_params=True)
|
||||
|
||||
def testTensorLearningRate(self):
|
||||
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
|
||||
with self.cached_session():
|
||||
global_step = variables.Variable(array_ops.zeros([], dtypes.int64))
|
||||
# Initialize variables for numpy implementation.
|
||||
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
|
||||
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
|
||||
grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
|
||||
var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
|
||||
grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
|
||||
|
||||
var0 = variables.Variable(var0_np)
|
||||
var1 = variables.Variable(var1_np)
|
||||
grads0 = constant_op.constant(grads0_np)
|
||||
grads1 = constant_op.constant(grads1_np)
|
||||
opt = adam_gs_optimizer.AdamGSOptimizer(
|
||||
global_step=global_step, learning_rate=constant_op.constant(0.001))
|
||||
update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]),
|
||||
global_step=global_step)
|
||||
variables.global_variables_initializer().run()
|
||||
|
||||
# Fetch params to validate initial values
|
||||
self.assertAllClose([1.0, 2.0], self.evaluate(var0))
|
||||
self.assertAllClose([3.0, 4.0], self.evaluate(var1))
|
||||
|
||||
beta1_power, beta2_power = opt._get_beta_accumulators()
|
||||
|
||||
# Run 3 steps of Adam
|
||||
for t in range(1, 4):
|
||||
self.assertAllCloseAccordingToType(0.9**t, self.evaluate(beta1_power))
|
||||
self.assertAllCloseAccordingToType(0.999**t,
|
||||
self.evaluate(beta2_power))
|
||||
update.run()
|
||||
|
||||
var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
|
||||
var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
|
||||
|
||||
# Validate updated params
|
||||
self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
|
||||
self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
|
||||
|
||||
def testSharing(self):
|
||||
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
|
||||
with self.cached_session():
|
||||
global_step = variables.Variable(array_ops.zeros([], dtypes.int64))
|
||||
# Initialize variables for numpy implementation.
|
||||
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
|
||||
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
|
||||
grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
|
||||
var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
|
||||
grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
|
||||
|
||||
var0 = variables.Variable(var0_np)
|
||||
var1 = variables.Variable(var1_np)
|
||||
grads0 = constant_op.constant(grads0_np)
|
||||
grads1 = constant_op.constant(grads1_np)
|
||||
opt = adam_gs_optimizer.AdamGSOptimizer(global_step=global_step)
|
||||
update1 = opt.apply_gradients(zip([grads0, grads1], [var0, var1]),
|
||||
global_step=global_step)
|
||||
update2 = opt.apply_gradients(zip([grads0, grads1], [var0, var1]),
|
||||
global_step=global_step)
|
||||
variables.global_variables_initializer().run()
|
||||
|
||||
beta1_power, beta2_power = opt._get_beta_accumulators()
|
||||
|
||||
# Fetch params to validate initial values
|
||||
self.assertAllClose([1.0, 2.0], self.evaluate(var0))
|
||||
self.assertAllClose([3.0, 4.0], self.evaluate(var1))
|
||||
|
||||
# Run 3 steps of intertwined Adam1 and Adam2.
|
||||
for t in range(1, 4):
|
||||
self.assertAllCloseAccordingToType(0.9**t, self.evaluate(beta1_power))
|
||||
self.assertAllCloseAccordingToType(0.999**t,
|
||||
self.evaluate(beta2_power))
|
||||
if t % 2 == 0:
|
||||
update1.run()
|
||||
else:
|
||||
update2.run()
|
||||
|
||||
var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
|
||||
var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
|
||||
|
||||
# Validate updated params
|
||||
self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
|
||||
self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
|
||||
|
||||
def testTwoSessions(self):
|
||||
optimizer = adam_gs_optimizer.AdamGSOptimizer()
|
||||
|
||||
with context.eager_mode():
|
||||
var0 = variables.Variable(np.array([1.0, 2.0]), name="v0")
|
||||
grads0 = constant_op.constant(np.array([0.1, 0.1]))
|
||||
optimizer.apply_gradients([(grads0, var0)])
|
||||
|
||||
g = ops.Graph()
|
||||
with g.as_default():
|
||||
with session.Session():
|
||||
var0 = variables.Variable(np.array([1.0, 2.0]), name="v0")
|
||||
grads0 = constant_op.constant(np.array([0.1, 0.1]))
|
||||
optimizer.apply_gradients([(grads0, var0)])
|
||||
|
||||
gg = ops.Graph()
|
||||
with gg.as_default():
|
||||
with session.Session():
|
||||
var0 = variables.Variable(np.array([1.0, 2.0]), name="v0")
|
||||
grads0 = constant_op.constant(np.array([0.1, 0.1]))
|
||||
|
||||
# If the optimizer saves any state not keyed by graph the following line
|
||||
# fails.
|
||||
optimizer.apply_gradients([(grads0, var0)])
|
||||
|
||||
def testSlotsUniqueEager(self):
|
||||
with context.eager_mode():
|
||||
v1 = resource_variable_ops.ResourceVariable(1.)
|
||||
v2 = resource_variable_ops.ResourceVariable(1.)
|
||||
opt = adam_gs_optimizer.AdamGSOptimizer(1.)
|
||||
opt.minimize(lambda: v1 + v2)
|
||||
# There should be two unique slot variables for v1 and v2 respectively.
|
||||
self.assertEqual(4, len(set(opt.variables())))
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
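A note on the exponents asserted above, since they differ between tests: the optimizer computes each power as beta ** (global_step + 1), so the expected value depends on whether the beta powers are read before or after the update that increments the global step (a worked restatement, not part of the test file):

# Read before the t-th update (global_step == t - 1), as in doTestSparse and
# testTensorLearningRate:
#   beta1_power == 0.9 ** ((t - 1) + 1) == 0.9 ** t
# Read after the t-th update has run (global_step == t), as in the graph-mode
# branch of doTestBasic:
#   beta1_power == 0.9 ** (t + 1)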
tensorflow/contrib/opt/python/training/lazy_adam_gs_optimizer.py (new file, 114 lines)
@@ -0,0 +1,114 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""LazyAdam rewrite to use global step for computing beta1 & beta2 accumulation.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.opt.python.training import adam_gs_optimizer
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.ops import state_ops
|
||||
|
||||
|
||||
class LazyAdamGSOptimizer(adam_gs_optimizer.AdamGSOptimizer):
|
||||
"""Variant of the Adam optimizer that handles sparse updates more efficiently.
|
||||
|
||||
Branched from tf.contrib.opt.LazyAdamOptimizer. The only difference is to
pass the global step for computing the beta1 and beta2 accumulators, instead
of having the optimizer keep its own independent beta1 and beta2 accumulators
as non-slot variables.
|
||||
|
||||
The original Adam algorithm maintains two moving-average accumulators for
|
||||
each trainable variable; the accumulators are updated at every step.
|
||||
This class provides lazier handling of gradient updates for sparse variables.
|
||||
It only updates moving-average accumulators for sparse variable indices that
|
||||
appear in the current batch, rather than updating the accumulators for all
|
||||
indices. Compared with the original Adam optimizer, it can provide large
|
||||
improvements in model training throughput for some applications. However, it
|
||||
provides slightly different semantics than the original Adam algorithm, and
|
||||
may lead to different empirical results.
|
||||
"""
|
||||
|
||||
def _apply_sparse(self, grad, var):
|
||||
beta1_power, beta2_power = self._get_beta_accumulators()
|
||||
beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
|
||||
beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
|
||||
lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
|
||||
beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
|
||||
beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
|
||||
epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
|
||||
lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))
|
||||
|
||||
# \\(m := beta1 * m + (1 - beta1) * g_t\\)
|
||||
m = self.get_slot(var, "m")
|
||||
m_t = state_ops.scatter_update(m, grad.indices,
|
||||
beta1_t * array_ops.gather(m, grad.indices) +
|
||||
(1 - beta1_t) * grad.values,
|
||||
use_locking=self._use_locking)
|
||||
|
||||
# \\(v := beta2 * v + (1 - beta2) * (g_t * g_t)\\)
|
||||
v = self.get_slot(var, "v")
|
||||
v_t = state_ops.scatter_update(v, grad.indices,
|
||||
beta2_t * array_ops.gather(v, grad.indices) +
|
||||
(1 - beta2_t) * math_ops.square(grad.values),
|
||||
use_locking=self._use_locking)
|
||||
|
||||
# \\(variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))\\)
|
||||
m_t_slice = array_ops.gather(m_t, grad.indices)
|
||||
v_t_slice = array_ops.gather(v_t, grad.indices)
|
||||
denominator_slice = math_ops.sqrt(v_t_slice) + epsilon_t
|
||||
var_update = state_ops.scatter_sub(var, grad.indices,
|
||||
lr * m_t_slice / denominator_slice,
|
||||
use_locking=self._use_locking)
|
||||
return control_flow_ops.group(var_update, m_t, v_t)
|
||||
|
||||
def _resource_apply_sparse(self, grad, var, indices):
|
||||
beta1_power, beta2_power = self._get_beta_accumulators()
|
||||
beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
|
||||
beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
|
||||
lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
|
||||
beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
|
||||
beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
|
||||
epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
|
||||
lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))
|
||||
|
||||
# \\(m := beta1 * m + (1 - beta1) * g_t\\)
|
||||
m = self.get_slot(var, "m")
|
||||
m_t_slice = beta1_t * array_ops.gather(m, indices) + (1 - beta1_t) * grad
|
||||
m_update_op = resource_variable_ops.resource_scatter_update(m.handle,
|
||||
indices,
|
||||
m_t_slice)
|
||||
|
||||
# \\(v := beta2 * v + (1 - beta2) * (g_t * g_t)\\)
|
||||
v = self.get_slot(var, "v")
|
||||
v_t_slice = (beta2_t * array_ops.gather(v, indices) +
|
||||
(1 - beta2_t) * math_ops.square(grad))
|
||||
v_update_op = resource_variable_ops.resource_scatter_update(v.handle,
|
||||
indices,
|
||||
v_t_slice)
|
||||
|
||||
# \\(variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))\\)
|
||||
var_slice = lr * m_t_slice / (math_ops.sqrt(v_t_slice) + epsilon_t)
|
||||
var_update_op = resource_variable_ops.resource_scatter_sub(var.handle,
|
||||
indices,
|
||||
var_slice)
|
||||
|
||||
return control_flow_ops.group(var_update_op, m_update_op, v_update_op)
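To make the lazy semantics concrete, a hedged numpy sketch of _apply_sparse / _resource_apply_sparse: only the rows listed in indices have their slots and parameter entries touched, and the bias-correction powers again come from the global step (the function name and array shapes are illustrative only):

import numpy as np

def lazy_adam_gs_sparse_step(param, grad_values, indices, global_step, m, v,
                             lr=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8):
  t = global_step + 1
  lr_t = lr * np.sqrt(1 - beta2**t) / (1 - beta1**t)
  m[indices] = beta1 * m[indices] + (1 - beta1) * grad_values
  v[indices] = beta2 * v[indices] + (1 - beta2) * grad_values * grad_values
  param[indices] -= lr_t * m[indices] / (np.sqrt(v[indices]) + epsilon)
  return param, m, v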
tensorflow/contrib/opt/python/training/lazy_adam_gs_optimizer_test.py (new file, 402 lines)
@@ -0,0 +1,402 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Tests for LazyAdamGSOptimizer."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.opt.python.training import lazy_adam_gs_optimizer
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
def adam_update_numpy(param,
|
||||
g_t,
|
||||
t,
|
||||
m,
|
||||
v,
|
||||
alpha=0.001,
|
||||
beta1=0.9,
|
||||
beta2=0.999,
|
||||
epsilon=1e-8):
|
||||
alpha_t = alpha * np.sqrt(1 - beta2**t) / (1 - beta1**t)
|
||||
|
||||
m_t = beta1 * m + (1 - beta1) * g_t
|
||||
v_t = beta2 * v + (1 - beta2) * g_t * g_t
|
||||
|
||||
param_t = param - alpha_t * m_t / (np.sqrt(v_t) + epsilon)
|
||||
return param_t, m_t, v_t
|
||||
|
||||
|
||||
class LazyAdamGSOptimizerTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
@parameterized.parameters([False, True])
|
||||
def testSparse(self, use_resource):
|
||||
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
|
||||
with self.cached_session():
|
||||
# Initialize variables for numpy implementation.
|
||||
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
|
||||
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
|
||||
grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
|
||||
var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
|
||||
grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
|
||||
|
||||
if use_resource:
|
||||
global_step = resource_variable_ops.ResourceVariable(
|
||||
array_ops.zeros([], dtypes.int64))
|
||||
var0 = resource_variable_ops.ResourceVariable(var0_np)
|
||||
var1 = resource_variable_ops.ResourceVariable(var1_np)
|
||||
else:
|
||||
global_step = variables.Variable(array_ops.zeros([], dtypes.int64))
|
||||
var0 = variables.Variable(var0_np)
|
||||
var1 = variables.Variable(var1_np)
|
||||
|
||||
grads0_np_indices = np.array([0, 1], dtype=np.int32)
|
||||
grads0 = ops.IndexedSlices(
|
||||
constant_op.constant(grads0_np),
|
||||
constant_op.constant(grads0_np_indices), constant_op.constant([2]))
|
||||
grads1_np_indices = np.array([0, 1], dtype=np.int32)
|
||||
grads1 = ops.IndexedSlices(
|
||||
constant_op.constant(grads1_np),
|
||||
constant_op.constant(grads1_np_indices), constant_op.constant([2]))
|
||||
opt = lazy_adam_gs_optimizer.LazyAdamGSOptimizer(
|
||||
global_step=global_step)
|
||||
update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]),
|
||||
global_step=global_step)
|
||||
variables.global_variables_initializer().run()
|
||||
|
||||
# Fetch params to validate initial values
|
||||
self.assertAllClose([1.0, 2.0], var0.eval())
|
||||
self.assertAllClose([3.0, 4.0], var1.eval())
|
||||
|
||||
beta1_power, beta2_power = opt._get_beta_accumulators()
|
||||
|
||||
# Run 3 steps of Adam
|
||||
for t in range(1, 4):
|
||||
self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval())
|
||||
self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval())
|
||||
update.run()
|
||||
|
||||
var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
|
||||
var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
|
||||
|
||||
# Validate updated params
|
||||
self.assertAllCloseAccordingToType(var0_np, var0.eval())
|
||||
self.assertAllCloseAccordingToType(var1_np, var1.eval())
|
||||
|
||||
@parameterized.parameters([False, True])
|
||||
def testSparseDevicePlacement(self, use_resource):
|
||||
for index_dtype in [dtypes.int32, dtypes.int64]:
|
||||
with self.cached_session(force_gpu=test.is_gpu_available()):
|
||||
# If a GPU is available, tests that all optimizer ops can be placed on
|
||||
# it (i.e. they have GPU kernels).
|
||||
if use_resource:
|
||||
global_step = resource_variable_ops.ResourceVariable(
|
||||
array_ops.zeros([], dtypes.int64))
|
||||
var = resource_variable_ops.ResourceVariable([[1.0], [2.0]])
|
||||
else:
|
||||
global_step = variables.Variable(array_ops.zeros([], dtypes.int64))
|
||||
var = variables.Variable([[1.0], [2.0]])
|
||||
|
||||
indices = constant_op.constant([0, 1], dtype=index_dtype)
|
||||
gathered_sum = math_ops.reduce_sum(array_ops.gather(var, indices))
|
||||
optimizer = lazy_adam_gs_optimizer.LazyAdamGSOptimizer(
|
||||
global_step=global_step, learning_rate=3.0)
|
||||
minimize_op = optimizer.minimize(gathered_sum, global_step=global_step)
|
||||
variables.global_variables_initializer().run()
|
||||
minimize_op.run()
|
||||
|
||||
@parameterized.parameters([False, True])
|
||||
def testSparseRepeatedIndices(self, use_resource):
|
||||
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
|
||||
with self.cached_session():
|
||||
if use_resource:
|
||||
repeated_index_global_step = resource_variable_ops.ResourceVariable(
|
||||
array_ops.zeros([], dtypes.int64))
|
||||
aggregated_global_step = resource_variable_ops.ResourceVariable(
|
||||
array_ops.zeros([], dtypes.int64))
|
||||
repeated_index_update_var = resource_variable_ops.ResourceVariable(
|
||||
[[1.0], [2.0]], dtype=dtype)
|
||||
aggregated_update_var = resource_variable_ops.ResourceVariable(
|
||||
[[1.0], [2.0]], dtype=dtype)
|
||||
else:
|
||||
repeated_index_global_step = variables.Variable(
|
||||
array_ops.zeros([], dtypes.int64))
|
||||
aggregated_global_step = variables.Variable(
|
||||
array_ops.zeros([], dtypes.int64))
|
||||
repeated_index_update_var = variables.Variable(
|
||||
[[1.0], [2.0]], dtype=dtype)
|
||||
aggregated_update_var = variables.Variable(
|
||||
[[1.0], [2.0]], dtype=dtype)
|
||||
|
||||
grad_repeated_index = ops.IndexedSlices(
|
||||
constant_op.constant(
|
||||
[0.1, 0.1], shape=[2, 1], dtype=dtype),
|
||||
constant_op.constant([1, 1]),
|
||||
constant_op.constant([2, 1]))
|
||||
grad_aggregated = ops.IndexedSlices(
|
||||
constant_op.constant(
|
||||
[0.2], shape=[1, 1], dtype=dtype),
|
||||
constant_op.constant([1]),
|
||||
constant_op.constant([2, 1]))
|
||||
repeated_update_opt = lazy_adam_gs_optimizer.LazyAdamGSOptimizer(
|
||||
global_step=repeated_index_global_step)
|
||||
repeated_update = repeated_update_opt.apply_gradients(
|
||||
[(grad_repeated_index, repeated_index_update_var)],
|
||||
global_step=repeated_index_global_step)
|
||||
aggregated_update_opt = lazy_adam_gs_optimizer.LazyAdamGSOptimizer(
|
||||
global_step=aggregated_global_step)
|
||||
aggregated_update = aggregated_update_opt.apply_gradients(
|
||||
[(grad_aggregated, aggregated_update_var)],
|
||||
global_step=aggregated_global_step)
|
||||
variables.global_variables_initializer().run()
|
||||
self.assertAllClose(aggregated_update_var.eval(),
|
||||
repeated_index_update_var.eval())
|
||||
for _ in range(3):
|
||||
repeated_update.run()
|
||||
aggregated_update.run()
|
||||
self.assertAllClose(aggregated_update_var.eval(),
|
||||
repeated_index_update_var.eval())
|
||||
|
||||
def doTestBasic(self, use_resource=False, use_callable_params=False):
|
||||
for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
|
||||
with self.session(graph=ops.Graph()):
|
||||
# Initialize variables for numpy implementation.
|
||||
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
|
||||
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
|
||||
grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
|
||||
var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
|
||||
grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
|
||||
|
||||
if use_resource:
|
||||
global_step = resource_variable_ops.ResourceVariable(
|
||||
array_ops.zeros([], dtypes.int64), name="global_step_%d" % i)
|
||||
var0 = resource_variable_ops.ResourceVariable(
|
||||
var0_np, name="var0_%d" % i)
|
||||
var1 = resource_variable_ops.ResourceVariable(
|
||||
var1_np, name="var1_%d" % i)
|
||||
else:
|
||||
global_step = variables.Variable(array_ops.zeros([], dtypes.int64))
|
||||
var0 = variables.Variable(var0_np)
|
||||
var1 = variables.Variable(var1_np)
|
||||
grads0 = constant_op.constant(grads0_np)
|
||||
grads1 = constant_op.constant(grads1_np)
|
||||
|
||||
learning_rate = lambda: 0.001
|
||||
beta1 = lambda: 0.9
|
||||
beta2 = lambda: 0.999
|
||||
epsilon = lambda: 1e-8
|
||||
if not use_callable_params:
|
||||
learning_rate = learning_rate()
|
||||
beta1 = beta1()
|
||||
beta2 = beta2()
|
||||
epsilon = epsilon()
|
||||
|
||||
opt = lazy_adam_gs_optimizer.LazyAdamGSOptimizer(
|
||||
global_step=global_step, learning_rate=learning_rate)
|
||||
update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]),
|
||||
global_step=global_step)
|
||||
opt_variables = opt.variables()
|
||||
beta1_power, beta2_power = opt._get_beta_accumulators()
|
||||
self.assertIsNotNone(beta1_power)
|
||||
self.assertIsNotNone(beta2_power)
|
||||
self.assertNotIn(beta1_power, opt_variables)
|
||||
self.assertNotIn(beta2_power, opt_variables)
|
||||
|
||||
if not context.executing_eagerly():
|
||||
with ops.Graph().as_default():
|
||||
# Shouldn't return non-slot variables from other graphs.
|
||||
self.assertEqual(0, len(opt.variables()))
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
# Fetch params to validate initial values
|
||||
self.assertAllClose([1.0, 2.0], self.evaluate(var0))
|
||||
self.assertAllClose([3.0, 4.0], self.evaluate(var1))
|
||||
|
||||
# Run 3 steps of Adam
|
||||
for t in range(1, 4):
|
||||
if not context.executing_eagerly():
|
||||
self.evaluate(update)
|
||||
self.assertAllCloseAccordingToType(
|
||||
0.9**(t + 1), self.evaluate(beta1_power))
|
||||
self.assertAllCloseAccordingToType(
|
||||
0.999**(t + 1), self.evaluate(beta2_power))
|
||||
else:
|
||||
if t > 1:
|
||||
opt.apply_gradients(zip([grads0, grads1], [var0, var1]),
|
||||
global_step=global_step)
|
||||
beta1_power, beta2_power = opt._get_beta_accumulators()
|
||||
self.assertAllCloseAccordingToType(
|
||||
0.9**t, self.evaluate(beta1_power))
|
||||
self.assertAllCloseAccordingToType(
|
||||
0.999**t, self.evaluate(beta2_power))
|
||||
|
||||
var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
|
||||
var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
|
||||
|
||||
# Validate updated params
|
||||
self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
|
||||
self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
|
||||
if use_resource:
|
||||
self.assertEqual("var0_%d/Adam:0" % (i,),
|
||||
opt.get_slot(var=var0, name="m").name)
|
||||
|
||||
def testBasic(self):
|
||||
with self.cached_session():
|
||||
self.doTestBasic(use_resource=False)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes(reset_test=True)
|
||||
def testResourceBasic(self):
|
||||
self.doTestBasic(use_resource=True)
|
||||
|
||||
def testBasicCallableParams(self):
|
||||
with context.eager_mode():
|
||||
self.doTestBasic(use_resource=True, use_callable_params=True)
|
||||
|
||||
def testTensorLearningRate(self):
|
||||
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
|
||||
with self.cached_session():
|
||||
global_step = variables.Variable(array_ops.zeros([], dtypes.int64))
|
||||
# Initialize variables for numpy implementation.
|
||||
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
|
||||
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
|
||||
grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
|
||||
var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
|
||||
grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
|
||||
|
||||
var0 = variables.Variable(var0_np)
|
||||
var1 = variables.Variable(var1_np)
|
||||
grads0 = constant_op.constant(grads0_np)
|
||||
grads1 = constant_op.constant(grads1_np)
|
||||
opt = lazy_adam_gs_optimizer.LazyAdamGSOptimizer(
|
||||
global_step=global_step, learning_rate=constant_op.constant(0.001))
|
||||
update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]),
|
||||
global_step=global_step)
|
||||
variables.global_variables_initializer().run()
|
||||
|
||||
# Fetch params to validate initial values
|
||||
self.assertAllClose([1.0, 2.0], var0.eval())
|
||||
self.assertAllClose([3.0, 4.0], var1.eval())
|
||||
|
||||
beta1_power, beta2_power = opt._get_beta_accumulators()
|
||||
|
||||
# Run 3 steps of Adam
|
||||
for t in range(1, 4):
|
||||
self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval())
|
||||
self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval())
|
||||
update.run()
|
||||
|
||||
var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
|
||||
var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
|
||||
|
||||
# Validate updated params
|
||||
self.assertAllCloseAccordingToType(var0_np, var0.eval())
|
||||
self.assertAllCloseAccordingToType(var1_np, var1.eval())
|
||||
|
||||
def testSharing(self):
|
||||
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
|
||||
with self.cached_session():
|
||||
global_step = variables.Variable(array_ops.zeros([], dtypes.int64))
|
||||
# Initialize variables for numpy implementation.
|
||||
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
|
||||
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
|
||||
grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
|
||||
var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
|
||||
grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
|
||||
|
||||
var0 = variables.Variable(var0_np)
|
||||
var1 = variables.Variable(var1_np)
|
||||
grads0 = constant_op.constant(grads0_np)
|
||||
grads1 = constant_op.constant(grads1_np)
|
||||
opt = lazy_adam_gs_optimizer.LazyAdamGSOptimizer(
|
||||
global_step=global_step)
|
||||
update1 = opt.apply_gradients(zip([grads0, grads1], [var0, var1]),
|
||||
global_step=global_step)
|
||||
update2 = opt.apply_gradients(zip([grads0, grads1], [var0, var1]),
|
||||
global_step=global_step)
|
||||
variables.global_variables_initializer().run()
|
||||
|
||||
beta1_power, beta2_power = opt._get_beta_accumulators()
|
||||
|
||||
# Fetch params to validate initial values
|
||||
self.assertAllClose([1.0, 2.0], var0.eval())
|
||||
self.assertAllClose([3.0, 4.0], var1.eval())
|
||||
|
||||
# Run 3 steps of intertwined Adam1 and Adam2.
|
||||
for t in range(1, 4):
|
||||
self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval())
|
||||
self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval())
|
||||
if t % 2 == 0:
|
||||
update1.run()
|
||||
else:
|
||||
update2.run()
|
||||
|
||||
var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
|
||||
var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
|
||||
|
||||
# Validate updated params
|
||||
self.assertAllCloseAccordingToType(var0_np, var0.eval())
|
||||
self.assertAllCloseAccordingToType(var1_np, var1.eval())
|
||||
|
||||
def testTwoSessions(self):
|
||||
optimizer = lazy_adam_gs_optimizer.LazyAdamGSOptimizer()
|
||||
|
||||
with context.eager_mode():
|
||||
var0 = variables.Variable(np.array([1.0, 2.0]), name="v0")
|
||||
grads0 = constant_op.constant(np.array([0.1, 0.1]))
|
||||
optimizer.apply_gradients([(grads0, var0)])
|
||||
|
||||
g = ops.Graph()
|
||||
with g.as_default():
|
||||
with self.session(graph=g):
|
||||
var0 = variables.Variable(np.array([1.0, 2.0]), name="v0")
|
||||
grads0 = constant_op.constant(np.array([0.1, 0.1]))
|
||||
optimizer.apply_gradients([(grads0, var0)])
|
||||
|
||||
gg = ops.Graph()
|
||||
with gg.as_default():
|
||||
with self.session(graph=gg):
|
||||
var0 = variables.Variable(np.array([1.0, 2.0]), name="v0")
|
||||
grads0 = constant_op.constant(np.array([0.1, 0.1]))
|
||||
|
||||
# If the optimizer saves any state not keyed by graph the following line
|
||||
# fails.
|
||||
optimizer.apply_gradients([(grads0, var0)])
|
||||
|
||||
def testSlotsUniqueEager(self):
|
||||
with context.eager_mode():
|
||||
v1 = resource_variable_ops.ResourceVariable(1.)
|
||||
v2 = resource_variable_ops.ResourceVariable(1.)
|
||||
opt = lazy_adam_gs_optimizer.LazyAdamGSOptimizer(1.)
|
||||
opt.minimize(lambda: v1 + v2)
|
||||
# There should be two unique slot variables for v1 and v2 respectively;
# the beta accumulators are no longer non-slot variables.
|
||||
self.assertLen(set(opt.variables()), 4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
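One reason the lazy tests above can reuse the same adam_update_numpy reference as the dense tests: in testSparse the IndexedSlices gradients carry indices [0, 1] over two-element variables, so every row is updated. A comment-style recap (ours, not part of the file):

# When indices cover every row of the variable, scatter-updating m, v and the
# parameter row by row is identical to the dense Adam step, so the dense numpy
# reference predicts the lazy optimizer's result exactly.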