Remove decay variable from optimizers.

PiperOrigin-RevId: 346172198
Change-Id: I2f98d52d47ffe81ee3baa5891e187c2c84f145cf
Author: Zhenyu Tan 2020-12-07 13:50:38 -08:00, committed by TensorFlower Gardener
parent eaa851346c
commit c50af433c1
10 changed files with 23 additions and 10 deletions
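
Every get_config hunk below is the same one-line change: decay is no longer registered as an optimizer hyperparameter, so serialization reports the plain Python attribute _initial_decay captured in the constructor instead of reading the value back through _serialize_hyperparameter. A minimal sketch of the before/after behavior, using a toy stand-in class rather than the real OptimizerV2 machinery:

# Hedged sketch -- illustrative stand-in, not TensorFlow source.
class BeforeChange:

  def __init__(self, decay=0.0):
    self._hyper = {'decay': decay}  # decay tracked as a hyperparameter slot

  def get_config(self):
    # Old path: read decay back out of the hyperparameter slot.
    return {'decay': self._hyper['decay']}


class AfterChange:

  def __init__(self, decay=0.0):
    self._initial_decay = decay  # plain Python value, no variable created

  def get_config(self):
    # New path: report the constructor value directly.
    return {'decay': self._initial_decay}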

@@ -153,7 +153,7 @@ class Adadelta(optimizer_v2.OptimizerV2):
     config = super(Adadelta, self).get_config()
     config.update({
         'learning_rate': self._serialize_hyperparameter('learning_rate'),
-        'decay': self._serialize_hyperparameter('decay'),
+        'decay': self._initial_decay,
         'rho': self._serialize_hyperparameter('rho'),
         'epsilon': self.epsilon,
     })

@@ -157,7 +157,7 @@ class Adagrad(optimizer_v2.OptimizerV2):
     config = super(Adagrad, self).get_config()
     config.update({
         'learning_rate': self._serialize_hyperparameter('learning_rate'),
-        'decay': self._serialize_hyperparameter('decay'),
+        'decay': self._initial_decay,
         'initial_accumulator_value': self._initial_accumulator_value,
         'epsilon': self.epsilon,
     })

@@ -244,7 +244,7 @@ class Adam(optimizer_v2.OptimizerV2):
     config = super(Adam, self).get_config()
     config.update({
         'learning_rate': self._serialize_hyperparameter('learning_rate'),
-        'decay': self._serialize_hyperparameter('decay'),
+        'decay': self._initial_decay,
         'beta_1': self._serialize_hyperparameter('beta_1'),
         'beta_2': self._serialize_hyperparameter('beta_2'),
         'epsilon': self.epsilon,
@@ -468,7 +468,7 @@ class NonFusedAdam(optimizer_v2.OptimizerV2):
     config = super(NonFusedAdam, self).get_config()
     config.update({
         'learning_rate': self._serialize_hyperparameter('learning_rate'),
-        'decay': self._serialize_hyperparameter('decay'),
+        'decay': self._initial_decay,
         'beta_1': self._serialize_hyperparameter('beta_1'),
         'beta_2': self._serialize_hyperparameter('beta_2'),
         'epsilon': self.epsilon,

@@ -180,7 +180,7 @@ class Adamax(optimizer_v2.OptimizerV2):
     config = super(Adamax, self).get_config()
     config.update({
         'learning_rate': self._serialize_hyperparameter('learning_rate'),
-        'decay': self._serialize_hyperparameter('decay'),
+        'decay': self._initial_decay,
         'beta_1': self._serialize_hyperparameter('beta_1'),
         'beta_2': self._serialize_hyperparameter('beta_2'),
         'epsilon': self.epsilon,

@@ -234,7 +234,7 @@ class Ftrl(optimizer_v2.OptimizerV2):
         'learning_rate':
             self._serialize_hyperparameter('learning_rate'),
         'decay':
-            self._serialize_hyperparameter('decay'),
+            self._initial_decay,
         'initial_accumulator_value':
             self._initial_accumulator_value,
         'learning_rate_power':

@@ -187,7 +187,7 @@ class SGD(optimizer_v2.OptimizerV2):
     config = super(SGD, self).get_config()
     config.update({
         "learning_rate": self._serialize_hyperparameter("learning_rate"),
-        "decay": self._serialize_hyperparameter("decay"),
+        "decay": self._initial_decay,
         "momentum": self._serialize_hyperparameter("momentum"),
         "nesterov": self.nesterov,
     })

@@ -214,7 +214,7 @@ class Nadam(optimizer_v2.OptimizerV2):
     config = super(Nadam, self).get_config()
     config.update({
         'learning_rate': self._serialize_hyperparameter('learning_rate'),
-        'decay': self._serialize_hyperparameter('decay'),
+        'decay': self._initial_decay,
         'beta_1': self._serialize_hyperparameter('beta_1'),
         'beta_2': self._serialize_hyperparameter('beta_2'),
         'epsilon': self.epsilon,

@@ -1004,7 +1004,7 @@ class OptimizerV2(trackable.Trackable):
       lr_t = math_ops.cast(lr_t(local_step), var_dtype)
     if self._initial_decay > 0.:
       local_step = math_ops.cast(self.iterations, var_dtype)
-      decay_t = self._get_hyper("decay", var_dtype)
+      decay_t = math_ops.cast(self._initial_decay, var_dtype)
       lr_t = lr_t / (1. + decay_t * local_step)
     return lr_t
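
The schedule itself is unchanged: for decay d and step t, the effective rate is lr / (1 + d * t); only the source of d moves from a hyperparameter lookup to a cast of the constant _initial_decay. A quick numeric check of that inverse time formula, as a hedged plain-Python sketch (names are illustrative, not TensorFlow source):

def decayed_lr(lr, decay, step):
  # lr_t = lr / (1 + decay * step); decay == 0 leaves lr untouched.
  return lr / (1. + decay * step)

assert decayed_lr(1.0, 0.5, 0) == 1.0  # step 0: no decay applied yet
assert decayed_lr(1.0, 0.5, 2) == 0.5  # 1 / (1 + 0.5 * 2)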

@@ -908,6 +908,19 @@ class OptimizerWithFunctionTest(test.TestCase, parameterized.TestCase):
     self.assertAllClose([0., 1.], fn(), atol=1e-4)
     self.assertAllClose([-1, 0.], fn(), atol=1e-4)
 
+  def testBasicWithConstantDecay(self):
+    var = variables.Variable([1.0, 2.0], dtype=dtypes.float32)
+    loss = lambda: 3 * var
+    opt = adam.Adam(learning_rate=1.0)
+
+    @def_function.function
+    def fn():
+      opt.minimize(loss, [var])
+      return var
+
+    self.assertAllClose([0., 1.], fn(), atol=1e-4)
+    self.assertAllClose([-1, 0.], fn(), atol=1e-4)
+
   def testVarKeyWithVarCreatedInEager(self):
     a = variables.Variable([1., 2.], name='var')
     b = variables.Variable([1.], name='var')
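
The expected values in the new test follow from Adam's update under a constant gradient: with gradient 3 everywhere, the bias-corrected ratio m-hat / sqrt(v-hat) is 3 / 3 = 1, so each minimize step moves var by roughly the learning rate of 1.0, taking [1., 2.] to [0., 1.] and then to [-1., 0.]. Running minimize inside a def_function.function presumably exercises the point of this change: with decay held as a constant Python value rather than a variable, tracing creates no extra decay state.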

@@ -290,7 +290,7 @@ class RMSprop(optimizer_v2.OptimizerV2):
     config = super(RMSprop, self).get_config()
     config.update({
         "learning_rate": self._serialize_hyperparameter("learning_rate"),
-        "decay": self._serialize_hyperparameter("decay"),
+        "decay": self._initial_decay,
         "rho": self._serialize_hyperparameter("rho"),
         "momentum": self._serialize_hyperparameter("momentum"),
         "epsilon": self.epsilon,