Remove decay variable from optimizers.
PiperOrigin-RevId: 346172198
Change-Id: I2f98d52d47ffe81ee3baa5891e187c2c84f145cf
parent eaa851346c
commit c50af433c1
tensorflow/python/keras/optimizer_v2
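Context for the get_config() hunks below: the 'decay' entry is now read directly from the optimizer's stored self._initial_decay float rather than serialized as a hyperparameter. A minimal, illustrative sketch of the resulting config round trip, assuming the TF 2.x optimizer_v2 API of this era (the decay value is made up for illustration):

import tensorflow as tf

# 'decay' is kept as a plain float on the optimizer, so get_config() can
# report it directly instead of looking up a hyperparameter slot.
opt = tf.keras.optimizers.Adam(learning_rate=1.0, decay=0.5)
config = opt.get_config()
print(config['decay'])  # 0.5

# The config still round-trips through from_config() as before.
restored = tf.keras.optimizers.Adam.from_config(config)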
@@ -153,7 +153,7 @@ class Adadelta(optimizer_v2.OptimizerV2):
     config = super(Adadelta, self).get_config()
     config.update({
         'learning_rate': self._serialize_hyperparameter('learning_rate'),
-        'decay': self._serialize_hyperparameter('decay'),
+        'decay': self._initial_decay,
         'rho': self._serialize_hyperparameter('rho'),
         'epsilon': self.epsilon,
     })
@@ -157,7 +157,7 @@ class Adagrad(optimizer_v2.OptimizerV2):
     config = super(Adagrad, self).get_config()
     config.update({
         'learning_rate': self._serialize_hyperparameter('learning_rate'),
-        'decay': self._serialize_hyperparameter('decay'),
+        'decay': self._initial_decay,
         'initial_accumulator_value': self._initial_accumulator_value,
         'epsilon': self.epsilon,
     })
@@ -244,7 +244,7 @@ class Adam(optimizer_v2.OptimizerV2):
     config = super(Adam, self).get_config()
     config.update({
         'learning_rate': self._serialize_hyperparameter('learning_rate'),
-        'decay': self._serialize_hyperparameter('decay'),
+        'decay': self._initial_decay,
         'beta_1': self._serialize_hyperparameter('beta_1'),
         'beta_2': self._serialize_hyperparameter('beta_2'),
         'epsilon': self.epsilon,
@@ -468,7 +468,7 @@ class NonFusedAdam(optimizer_v2.OptimizerV2):
     config = super(NonFusedAdam, self).get_config()
     config.update({
         'learning_rate': self._serialize_hyperparameter('learning_rate'),
-        'decay': self._serialize_hyperparameter('decay'),
+        'decay': self._initial_decay,
         'beta_1': self._serialize_hyperparameter('beta_1'),
         'beta_2': self._serialize_hyperparameter('beta_2'),
         'epsilon': self.epsilon,
@@ -180,7 +180,7 @@ class Adamax(optimizer_v2.OptimizerV2):
     config = super(Adamax, self).get_config()
     config.update({
         'learning_rate': self._serialize_hyperparameter('learning_rate'),
-        'decay': self._serialize_hyperparameter('decay'),
+        'decay': self._initial_decay,
         'beta_1': self._serialize_hyperparameter('beta_1'),
         'beta_2': self._serialize_hyperparameter('beta_2'),
         'epsilon': self.epsilon,
@@ -234,7 +234,7 @@ class Ftrl(optimizer_v2.OptimizerV2):
         'learning_rate':
             self._serialize_hyperparameter('learning_rate'),
         'decay':
-            self._serialize_hyperparameter('decay'),
+            self._initial_decay,
         'initial_accumulator_value':
             self._initial_accumulator_value,
         'learning_rate_power':
@@ -187,7 +187,7 @@ class SGD(optimizer_v2.OptimizerV2):
     config = super(SGD, self).get_config()
     config.update({
         "learning_rate": self._serialize_hyperparameter("learning_rate"),
-        "decay": self._serialize_hyperparameter("decay"),
+        "decay": self._initial_decay,
         "momentum": self._serialize_hyperparameter("momentum"),
         "nesterov": self.nesterov,
     })
@@ -214,7 +214,7 @@ class Nadam(optimizer_v2.OptimizerV2):
     config = super(Nadam, self).get_config()
     config.update({
         'learning_rate': self._serialize_hyperparameter('learning_rate'),
-        'decay': self._serialize_hyperparameter('decay'),
+        'decay': self._initial_decay,
         'beta_1': self._serialize_hyperparameter('beta_1'),
         'beta_2': self._serialize_hyperparameter('beta_2'),
         'epsilon': self.epsilon,
@@ -1004,7 +1004,7 @@ class OptimizerV2(trackable.Trackable):
       lr_t = math_ops.cast(lr_t(local_step), var_dtype)
     if self._initial_decay > 0.:
       local_step = math_ops.cast(self.iterations, var_dtype)
-      decay_t = self._get_hyper("decay", var_dtype)
+      decay_t = math_ops.cast(self._initial_decay, var_dtype)
       lr_t = lr_t / (1. + decay_t * local_step)
     return lr_t
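The _decayed_lr() hunk above only changes where the decay factor is read from; the inverse-time schedule itself is untouched. A standalone, hypothetical sketch of that formula (illustrative helper, not TensorFlow API code):

def inverse_time_decayed_lr(initial_lr, initial_decay, step):
  # Mirrors lr_t / (1. + decay_t * local_step) from _decayed_lr above.
  return initial_lr / (1.0 + initial_decay * step)

inverse_time_decayed_lr(1.0, 0.5, 0)  # 1.0 (no decay applied at step 0)
inverse_time_decayed_lr(1.0, 0.5, 2)  # 0.5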
@@ -908,6 +908,19 @@ class OptimizerWithFunctionTest(test.TestCase, parameterized.TestCase):
     self.assertAllClose([0., 1.], fn(), atol=1e-4)
     self.assertAllClose([-1, 0.], fn(), atol=1e-4)
 
+  def testBasicWithConstantDecay(self):
+    var = variables.Variable([1.0, 2.0], dtype=dtypes.float32)
+    loss = lambda: 3 * var
+    opt = adam.Adam(learning_rate=1.0)
+
+    @def_function.function
+    def fn():
+      opt.minimize(loss, [var])
+      return var
+
+    self.assertAllClose([0., 1.], fn(), atol=1e-4)
+    self.assertAllClose([-1, 0.], fn(), atol=1e-4)
+
   def testVarKeyWithVarCreatedInEager(self):
     a = variables.Variable([1., 2.], name='var')
     b = variables.Variable([1.], name='var')
@@ -290,7 +290,7 @@ class RMSprop(optimizer_v2.OptimizerV2):
    config = super(RMSprop, self).get_config()
    config.update({
        "learning_rate": self._serialize_hyperparameter("learning_rate"),
-       "decay": self._serialize_hyperparameter("decay"),
+       "decay": self._initial_decay,
        "rho": self._serialize_hyperparameter("rho"),
        "momentum": self._serialize_hyperparameter("momentum"),
        "epsilon": self.epsilon,