Added beta parameter from FTRL paper to optimizer classes (such as the one in Keras).
PiperOrigin-RevId: 326529913 Change-Id: Ibb57acc7ea33a7c1b893487bfb58ca5befa22a81
This commit is contained in:
parent f893e8ed94
commit e9ca78c331
@@ -106,7 +106,8 @@
 * Error messages when Functional API construction goes wrong (and when ops cannot be converted to Keras layers automatically) should be clearer and easier to understand.
 * `Optimizer.minimize` can now accept a loss `Tensor` and a `GradientTape`
   as an alternative to accepting a `callable` loss.
-* Added `beta` parameter to FTRL optimizer to match paper.
+* Added `beta` hyperparameter to FTRL optimizer classes (Keras and others)
+  to match FTRL paper (https://research.google.com/pubs/archive/41159.pdf).
 * Added `mobilenet_v3` to keras application model.
 * `Optimizer.__init__` now accepts a `gradient_aggregator` to allow for
   customization of how gradients are aggregated across devices, as well as
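Reader's note (not part of the diff): a minimal sketch of how the new argument is meant to be used from the Keras API, assuming a TensorFlow build that includes this change:

```python
import tensorflow as tf

# beta defaults to 0.0, which preserves the previous behaviour; a positive
# value enables the extra term from the FTRL paper.
opt = tf.keras.optimizers.Ftrl(
    learning_rate=0.001,
    l1_regularization_strength=0.01,
    l2_regularization_strength=0.01,
    beta=0.1)

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer=opt, loss='mse')
```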
@@ -54,6 +54,8 @@ class Ftrl(optimizer_v2.OptimizerV2):
       or equal to zero. This differs from L2 above in that the L2 above is a
       stabilization penalty, whereas this L2 shrinkage is a magnitude penalty.
       When input is sparse shrinkage will only happen on the active weights.
+    beta: A float value, representing the beta value from the paper
+      (https://research.google.com/pubs/archive/41159.pdf).
     **kwargs: Keyword arguments. Allowed to be one of
       `"clipnorm"` or `"clipvalue"`.
       `"clipnorm"` (float) clips gradients by norm; `"clipvalue"` (float) clips
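Reader's note: `beta` is the β in the paper's per-coordinate learning-rate schedule, where the step size for coordinate i is alpha / (beta + sqrt(sum of squared gradients seen for i)). A rough, illustrative sketch of that schedule (names chosen for illustration, not taken from the TensorFlow code):

```python
import numpy as np

def ftrl_per_coordinate_lr(grad_history, alpha, beta):
  """Per-coordinate FTRL learning rate, eta_i = alpha / (beta + sqrt(n_i))."""
  accum = np.sum(np.square(grad_history), axis=0)  # n_i: accumulated g^2
  return alpha / (beta + np.sqrt(accum))

# With beta > 0 the very first steps stay bounded even while the accumulator
# is still small; beta = 0 recovers the previous behaviour.
print(ftrl_per_coordinate_lr(np.array([[0.1, 2.0]]), alpha=0.1, beta=1.0))
```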
@@ -72,6 +74,7 @@ class Ftrl(optimizer_v2.OptimizerV2):
                l2_regularization_strength=0.0,
                name='Ftrl',
                l2_shrinkage_regularization_strength=0.0,
+               beta=0.0,
                **kwargs):
     super(Ftrl, self).__init__(name, **kwargs)
 
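Reader's note: `beta` is appended after `l2_shrinkage_regularization_strength` and defaults to `0.0`, so existing positional calls keep their meaning; only callers who want the new term need to pass it, typically by keyword. An illustrative sketch (assuming a build with this change):

```python
import tensorflow as tf

# Existing positional calls are unaffected because beta is appended last
# and defaults to 0.0 ...
opt_old_style = tf.keras.optimizers.Ftrl(0.001, -0.5, 0.1, 0.0, 0.0, 'Ftrl', 0.0)

# ... while callers who want the new term pass it explicitly by keyword.
opt_with_beta = tf.keras.optimizers.Ftrl(learning_rate=0.001, beta=0.1)
```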
@@ -100,6 +103,7 @@ class Ftrl(optimizer_v2.OptimizerV2):
     self._set_hyper('learning_rate_power', learning_rate_power)
     self._set_hyper('l1_regularization_strength', l1_regularization_strength)
     self._set_hyper('l2_regularization_strength', l2_regularization_strength)
+    self._set_hyper('beta', beta)
     self._initial_accumulator_value = initial_accumulator_value
     self._l2_shrinkage_regularization_strength = (
         l2_shrinkage_regularization_strength)
@@ -115,22 +119,29 @@ class Ftrl(optimizer_v2.OptimizerV2):
 
   def _prepare_local(self, var_device, var_dtype, apply_state):
     super(Ftrl, self)._prepare_local(var_device, var_dtype, apply_state)
-    apply_state[(var_device, var_dtype)].update(dict(
-        learning_rate_power=array_ops.identity(
-            self._get_hyper('learning_rate_power', var_dtype)),
-        l1_regularization_strength=array_ops.identity(
-            self._get_hyper('l1_regularization_strength', var_dtype)),
-        l2_regularization_strength=array_ops.identity(
-            self._get_hyper('l2_regularization_strength', var_dtype)),
-        l2_shrinkage_regularization_strength=math_ops.cast(
-            self._l2_shrinkage_regularization_strength, var_dtype)
-    ))
+    apply_state[(var_device, var_dtype)].update(
+        dict(
+            learning_rate_power=array_ops.identity(
+                self._get_hyper('learning_rate_power', var_dtype)),
+            l1_regularization_strength=array_ops.identity(
+                self._get_hyper('l1_regularization_strength', var_dtype)),
+            l2_regularization_strength=array_ops.identity(
+                self._get_hyper('l2_regularization_strength', var_dtype)),
+            beta=array_ops.identity(self._get_hyper('beta', var_dtype)),
+            l2_shrinkage_regularization_strength=math_ops.cast(
+                self._l2_shrinkage_regularization_strength, var_dtype)))
 
   def _resource_apply_dense(self, grad, var, apply_state=None):
     var_device, var_dtype = var.device, var.dtype.base_dtype
     coefficients = ((apply_state or {}).get((var_device, var_dtype))
                     or self._fallback_apply_state(var_device, var_dtype))
 
+    # Adjust L2 regularization strength to include beta to avoid the underlying
+    # TensorFlow ops needing to include it.
+    adjusted_l2_regularization_strength = (
+        coefficients['l2_regularization_strength'] + coefficients['beta'] /
+        (2. * coefficients['lr_t']))
+
     accum = self.get_slot(var, 'accumulator')
     linear = self.get_slot(var, 'linear')
 
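Reader's note on `adjusted_l2_regularization_strength`: with the default `learning_rate_power=-0.5`, the paper's closed-form update divides by (beta + sqrt(n)) / alpha + 2 * lambda2, while the existing FTRL kernels only see sqrt(n) / lr + 2 * l2, so beta can be folded into the L2 term as l2 + beta / (2 * lr). A small numeric sketch of that identity, with illustrative values:

```python
import numpy as np

alpha, beta, l2, n = 0.05, 0.1, 0.01, 4.0  # lr, beta, L2 strength, g^2 accumulator

# Denominator as written in the FTRL paper.
paper_denominator = (beta + np.sqrt(n)) / alpha + 2.0 * l2

# Denominator the existing kernel computes once beta is folded into l2.
adjusted_l2 = l2 + beta / (2.0 * alpha)
kernel_denominator = np.sqrt(n) / alpha + 2.0 * adjusted_l2

assert np.isclose(paper_denominator, kernel_denominator)
```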
@@ -142,7 +153,7 @@ class Ftrl(optimizer_v2.OptimizerV2):
           grad=grad,
           lr=coefficients['lr_t'],
           l1=coefficients['l1_regularization_strength'],
-          l2=coefficients['l2_regularization_strength'],
+          l2=adjusted_l2_regularization_strength,
           lr_power=coefficients['learning_rate_power'],
           use_locking=self._use_locking)
     else:
@@ -153,7 +164,7 @@ class Ftrl(optimizer_v2.OptimizerV2):
           grad=grad,
           lr=coefficients['lr_t'],
           l1=coefficients['l1_regularization_strength'],
-          l2=coefficients['l2_regularization_strength'],
+          l2=adjusted_l2_regularization_strength,
           l2_shrinkage=coefficients['l2_shrinkage_regularization_strength'],
           lr_power=coefficients['learning_rate_power'],
           use_locking=self._use_locking)
@@ -163,6 +174,12 @@ class Ftrl(optimizer_v2.OptimizerV2):
     coefficients = ((apply_state or {}).get((var_device, var_dtype))
                     or self._fallback_apply_state(var_device, var_dtype))
 
+    # Adjust L2 regularization strength to include beta to avoid the underlying
+    # TensorFlow ops needing to include it.
+    adjusted_l2_regularization_strength = (
+        coefficients['l2_regularization_strength'] + coefficients['beta'] /
+        (2. * coefficients['lr_t']))
+
     accum = self.get_slot(var, 'accumulator')
     linear = self.get_slot(var, 'linear')
 
@@ -175,7 +192,7 @@ class Ftrl(optimizer_v2.OptimizerV2):
           indices=indices,
           lr=coefficients['lr_t'],
           l1=coefficients['l1_regularization_strength'],
-          l2=coefficients['l2_regularization_strength'],
+          l2=adjusted_l2_regularization_strength,
           lr_power=coefficients['learning_rate_power'],
           use_locking=self._use_locking)
     else:
@@ -187,7 +204,7 @@ class Ftrl(optimizer_v2.OptimizerV2):
           indices=indices,
           lr=coefficients['lr_t'],
           l1=coefficients['l1_regularization_strength'],
-          l2=coefficients['l2_regularization_strength'],
+          l2=adjusted_l2_regularization_strength,
           l2_shrinkage=coefficients['l2_shrinkage_regularization_strength'],
           lr_power=coefficients['learning_rate_power'],
           use_locking=self._use_locking)
@@ -207,6 +224,8 @@ class Ftrl(optimizer_v2.OptimizerV2):
             self._serialize_hyperparameter('l1_regularization_strength'),
         'l2_regularization_strength':
             self._serialize_hyperparameter('l2_regularization_strength'),
+        'beta':
+            self._serialize_hyperparameter('beta'),
         'l2_shrinkage_regularization_strength':
             self._l2_shrinkage_regularization_strength,
     })
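Reader's note: because `beta` is now part of `get_config`, it should survive config round trips and optimizer serialization. A short illustrative sketch (assuming a build with this change):

```python
import tensorflow as tf

opt = tf.keras.optimizers.Ftrl(learning_rate=0.01, beta=0.1)
config = opt.get_config()
print(config['beta'])  # expected: 0.1

# Rebuilding from config should preserve the new hyperparameter.
restored = tf.keras.optimizers.Ftrl.from_config(config)
print(restored.get_config()['beta'])  # expected: 0.1
```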
@@ -156,6 +156,63 @@ class FtrlOptimizerTest(test.TestCase):
       self.assertAllCloseAccordingToType(
           np.array([-0.93460727, -1.86147261]), v1_val)
 
+  def testFtrlWithBeta(self):
+    # TODO(tanzheny, omalleyt): Fix test in eager mode.
+    for dtype in [dtypes.half, dtypes.float32]:
+      with ops.Graph().as_default(), self.cached_session(use_gpu=True):
+        var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+        var1 = variables.Variable([4.0, 3.0], dtype=dtype)
+        grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
+        grads1 = constant_op.constant([0.01, 0.02], dtype=dtype)
+
+        opt = ftrl.Ftrl(3.0, initial_accumulator_value=0.1, beta=0.1)
+        update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+        self.evaluate(variables.global_variables_initializer())
+
+        v0_val, v1_val = self.evaluate([var0, var1])
+        self.assertAllCloseAccordingToType([1.0, 2.0], v0_val)
+        self.assertAllCloseAccordingToType([4.0, 3.0], v1_val)
+
+        # Run 10 steps FTRL
+        for _ in range(10):
+          update.run()
+        v0_val, v1_val = self.evaluate([var0, var1])
+        self.assertAllCloseAccordingToType(
+            np.array([-6.096838, -9.162214]), v0_val)
+        self.assertAllCloseAccordingToType(
+            np.array([-0.717741, -1.425132]), v1_val)
+
+  def testFtrlWithL2_Beta(self):
+    # TODO(tanzheny, omalleyt): Fix test in eager mode.
+    for dtype in [dtypes.half, dtypes.float32]:
+      with ops.Graph().as_default(), self.cached_session(use_gpu=True):
+        var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+        var1 = variables.Variable([4.0, 3.0], dtype=dtype)
+        grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
+        grads1 = constant_op.constant([0.01, 0.02], dtype=dtype)
+
+        opt = ftrl.Ftrl(
+            3.0,
+            initial_accumulator_value=0.1,
+            l1_regularization_strength=0.0,
+            l2_regularization_strength=0.1,
+            beta=0.1)
+        update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+        self.evaluate(variables.global_variables_initializer())
+
+        v0_val, v1_val = self.evaluate([var0, var1])
+        self.assertAllCloseAccordingToType([1.0, 2.0], v0_val)
+        self.assertAllCloseAccordingToType([4.0, 3.0], v1_val)
+
+        # Run 10 steps FTRL
+        for _ in range(10):
+          update.run()
+        v0_val, v1_val = self.evaluate([var0, var1])
+        self.assertAllCloseAccordingToType(
+            np.array([-2.735487, -4.704625]), v0_val)
+        self.assertAllCloseAccordingToType(
+            np.array([-0.294335, -0.586556]), v1_val)
+
   def testFtrlWithL1_L2(self):
     # TODO(tanzheny, omalleyt): Fix test in eager mode.
     for dtype in [dtypes.half, dtypes.float32]:
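Reader's note: the new tests run in graph mode (see the TODO about eager mode). An eager-mode sketch of the same exercise, for illustration only; the exact end values quoted above come from the graph-mode tests:

```python
import tensorflow as tf

var0 = tf.Variable([1.0, 2.0])
var1 = tf.Variable([4.0, 3.0])
grads0 = tf.constant([0.1, 0.2])
grads1 = tf.constant([0.01, 0.02])

opt = tf.keras.optimizers.Ftrl(3.0, initial_accumulator_value=0.1, beta=0.1)
for _ in range(10):
  # Each call applies the (constant) gradients once, mirroring the 10 update
  # steps in the graph-mode test.
  opt.apply_gradients(zip([grads0, grads1], [var0, var1]))

print(var0.numpy(), var1.numpy())
```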
@@ -22,7 +22,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'learning_rate\', \'learning_rate_power\', \'initial_accumulator_value\', \'l1_regularization_strength\', \'l2_regularization_strength\', \'name\', \'l2_shrinkage_regularization_strength\'], varargs=None, keywords=kwargs, defaults=[\'0.001\', \'-0.5\', \'0.1\', \'0.0\', \'0.0\', \'Ftrl\', \'0.0\'], "
+    argspec: "args=[\'self\', \'learning_rate\', \'learning_rate_power\', \'initial_accumulator_value\', \'l1_regularization_strength\', \'l2_regularization_strength\', \'name\', \'l2_shrinkage_regularization_strength\', \'beta\'], varargs=None, keywords=kwargs, defaults=[\'0.001\', \'-0.5\', \'0.1\', \'0.0\', \'0.0\', \'Ftrl\', \'0.0\', \'0.0\'], "
   }
   member_method {
     name: "add_slot"
@@ -22,7 +22,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'learning_rate\', \'learning_rate_power\', \'initial_accumulator_value\', \'l1_regularization_strength\', \'l2_regularization_strength\', \'name\', \'l2_shrinkage_regularization_strength\'], varargs=None, keywords=kwargs, defaults=[\'0.001\', \'-0.5\', \'0.1\', \'0.0\', \'0.0\', \'Ftrl\', \'0.0\'], "
+    argspec: "args=[\'self\', \'learning_rate\', \'learning_rate_power\', \'initial_accumulator_value\', \'l1_regularization_strength\', \'l2_regularization_strength\', \'name\', \'l2_shrinkage_regularization_strength\', \'beta\'], varargs=None, keywords=kwargs, defaults=[\'0.001\', \'-0.5\', \'0.1\', \'0.0\', \'0.0\', \'Ftrl\', \'0.0\', \'0.0\'], "
   }
   member_method {
     name: "add_slot"
@@ -22,7 +22,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'learning_rate\', \'learning_rate_power\', \'initial_accumulator_value\', \'l1_regularization_strength\', \'l2_regularization_strength\', \'name\', \'l2_shrinkage_regularization_strength\'], varargs=None, keywords=kwargs, defaults=[\'0.001\', \'-0.5\', \'0.1\', \'0.0\', \'0.0\', \'Ftrl\', \'0.0\'], "
+    argspec: "args=[\'self\', \'learning_rate\', \'learning_rate_power\', \'initial_accumulator_value\', \'l1_regularization_strength\', \'l2_regularization_strength\', \'name\', \'l2_shrinkage_regularization_strength\', \'beta\'], varargs=None, keywords=kwargs, defaults=[\'0.001\', \'-0.5\', \'0.1\', \'0.0\', \'0.0\', \'Ftrl\', \'0.0\', \'0.0\'], "
   }
   member_method {
     name: "add_slot"
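Reader's note: the three golden `.pbtxt` argspecs above track the Python signature of the exported classes; a quick, illustrative way to confirm that `beta` shows up with its default:

```python
import inspect
import tensorflow as tf

sig = inspect.signature(tf.keras.optimizers.Ftrl.__init__)
print('beta' in sig.parameters)        # expected: True
print(sig.parameters['beta'].default)  # expected: 0.0
```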