Added beta parameter from FTRL paper to optimizer classes (such as the one in Keras).
PiperOrigin-RevId: 326529913 Change-Id: Ibb57acc7ea33a7c1b893487bfb58ca5befa22a81
parent f893e8ed94 → commit e9ca78c331
@@ -106,7 +106,8 @@
 *   Error messages when Functional API construction goes wrong (and when ops cannot be converted to Keras layers automatically) should be clearer and easier to understand.
 *   `Optimizer.minimize` can now accept a loss `Tensor` and a `GradientTape`
     as an alternative to accepting a `callable` loss.
-*   Added `beta` parameter to FTRL optimizer to match paper.
+*   Added `beta` hyperparameter to FTRL optimizer classes (Keras and others)
+    to match FTRL paper (https://research.google.com/pubs/archive/41159.pdf).
 *   Added `mobilenet_v3` to keras application model.
 *   `Optimizer.__init__` now accepts a `gradient_aggregator` to allow for
     customization of how gradients are aggregated across devices, as well as
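A minimal usage sketch of the new release-note item (not part of this commit; it assumes a TensorFlow build that includes this change, and the model and loss below are purely illustrative):

    import tensorflow as tf

    # beta defaults to 0.0, so existing FTRL call sites keep their old behavior;
    # a non-zero beta is folded into the effective per-coordinate learning rate.
    opt = tf.keras.optimizers.Ftrl(
        learning_rate=0.001,
        l1_regularization_strength=0.01,
        l2_regularization_strength=0.01,
        beta=0.1)

    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    model.compile(optimizer=opt, loss='mse')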
@@ -54,6 +54,8 @@ class Ftrl(optimizer_v2.OptimizerV2):
       or equal to zero. This differs from L2 above in that the L2 above is a
       stabilization penalty, whereas this L2 shrinkage is a magnitude penalty.
       When input is sparse shrinkage will only happen on the active weights.
+    beta: A float value, representing the beta value from the paper
+      (https://research.google.com/pubs/archive/41159.pdf).
     **kwargs: Keyword arguments. Allowed to be one of
       `"clipnorm"` or `"clipvalue"`.
       `"clipnorm"` (float) clips gradients by norm; `"clipvalue"` (float) clips
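As a reading aid (a paraphrase of the linked paper, not text from this diff): in McMahan et al.'s FTRL-Proximal, beta appears in the per-coordinate step size, roughly alpha / (beta + sqrt(sum of squared gradients seen for that coordinate)). It behaves like a smoothing constant on the adaptive learning rate rather than a new penalty term; the implementation below folds it into the L2 argument only as an algebraic convenience.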
@@ -72,6 +74,7 @@ class Ftrl(optimizer_v2.OptimizerV2):
                l2_regularization_strength=0.0,
                name='Ftrl',
                l2_shrinkage_regularization_strength=0.0,
+               beta=0.0,
                **kwargs):
     super(Ftrl, self).__init__(name, **kwargs)
 
@@ -100,6 +103,7 @@ class Ftrl(optimizer_v2.OptimizerV2):
     self._set_hyper('learning_rate_power', learning_rate_power)
     self._set_hyper('l1_regularization_strength', l1_regularization_strength)
     self._set_hyper('l2_regularization_strength', l2_regularization_strength)
+    self._set_hyper('beta', beta)
     self._initial_accumulator_value = initial_accumulator_value
     self._l2_shrinkage_regularization_strength = (
         l2_shrinkage_regularization_strength)
@@ -115,22 +119,29 @@ class Ftrl(optimizer_v2.OptimizerV2):
 
   def _prepare_local(self, var_device, var_dtype, apply_state):
     super(Ftrl, self)._prepare_local(var_device, var_dtype, apply_state)
-    apply_state[(var_device, var_dtype)].update(dict(
-        learning_rate_power=array_ops.identity(
-            self._get_hyper('learning_rate_power', var_dtype)),
-        l1_regularization_strength=array_ops.identity(
-            self._get_hyper('l1_regularization_strength', var_dtype)),
-        l2_regularization_strength=array_ops.identity(
-            self._get_hyper('l2_regularization_strength', var_dtype)),
-        l2_shrinkage_regularization_strength=math_ops.cast(
-            self._l2_shrinkage_regularization_strength, var_dtype)
-        ))
+    apply_state[(var_device, var_dtype)].update(
+        dict(
+            learning_rate_power=array_ops.identity(
+                self._get_hyper('learning_rate_power', var_dtype)),
+            l1_regularization_strength=array_ops.identity(
+                self._get_hyper('l1_regularization_strength', var_dtype)),
+            l2_regularization_strength=array_ops.identity(
+                self._get_hyper('l2_regularization_strength', var_dtype)),
+            beta=array_ops.identity(self._get_hyper('beta', var_dtype)),
+            l2_shrinkage_regularization_strength=math_ops.cast(
+                self._l2_shrinkage_regularization_strength, var_dtype)))
 
   def _resource_apply_dense(self, grad, var, apply_state=None):
     var_device, var_dtype = var.device, var.dtype.base_dtype
     coefficients = ((apply_state or {}).get((var_device, var_dtype))
                     or self._fallback_apply_state(var_device, var_dtype))
 
+    # Adjust L2 regularization strength to include beta to avoid the underlying
+    # TensorFlow ops needing to include it.
+    adjusted_l2_regularization_strength = (
+        coefficients['l2_regularization_strength'] + coefficients['beta'] /
+        (2. * coefficients['lr_t']))
+
     accum = self.get_slot(var, 'accumulator')
     linear = self.get_slot(var, 'linear')
 
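Why folding beta into L2 works (a gloss on the change, not part of the diff): with the default learning_rate_power of -0.5, the FTRL kernel's closed-form weight update divides by roughly sqrt(accum)/lr + 2*l2, while the paper's denominator is (beta + sqrt(accum))/lr + 2*l2. Passing l2 + beta/(2*lr) to the existing op yields sqrt(accum)/lr + 2*l2 + beta/lr, which is the paper's expression, so the kernels need no new argument. A quick sanity check of that algebra in plain Python (values are hypothetical):

    lr, l2, beta, accum = 3.0, 0.1, 0.1, 2.5  # illustrative numbers only

    denominator_paper = (beta + accum ** 0.5) / lr + 2.0 * l2
    adjusted_l2 = l2 + beta / (2.0 * lr)
    denominator_op = accum ** 0.5 / lr + 2.0 * adjusted_l2

    assert abs(denominator_paper - denominator_op) < 1e-12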
@@ -142,7 +153,7 @@ class Ftrl(optimizer_v2.OptimizerV2):
           grad=grad,
           lr=coefficients['lr_t'],
           l1=coefficients['l1_regularization_strength'],
-          l2=coefficients['l2_regularization_strength'],
+          l2=adjusted_l2_regularization_strength,
           lr_power=coefficients['learning_rate_power'],
           use_locking=self._use_locking)
     else:
@@ -153,7 +164,7 @@ class Ftrl(optimizer_v2.OptimizerV2):
           grad=grad,
           lr=coefficients['lr_t'],
           l1=coefficients['l1_regularization_strength'],
-          l2=coefficients['l2_regularization_strength'],
+          l2=adjusted_l2_regularization_strength,
           l2_shrinkage=coefficients['l2_shrinkage_regularization_strength'],
           lr_power=coefficients['learning_rate_power'],
           use_locking=self._use_locking)
@@ -163,6 +174,12 @@ class Ftrl(optimizer_v2.OptimizerV2):
     coefficients = ((apply_state or {}).get((var_device, var_dtype))
                     or self._fallback_apply_state(var_device, var_dtype))
 
+    # Adjust L2 regularization strength to include beta to avoid the underlying
+    # TensorFlow ops needing to include it.
+    adjusted_l2_regularization_strength = (
+        coefficients['l2_regularization_strength'] + coefficients['beta'] /
+        (2. * coefficients['lr_t']))
+
     accum = self.get_slot(var, 'accumulator')
     linear = self.get_slot(var, 'linear')
 
@@ -175,7 +192,7 @@ class Ftrl(optimizer_v2.OptimizerV2):
           indices=indices,
           lr=coefficients['lr_t'],
           l1=coefficients['l1_regularization_strength'],
-          l2=coefficients['l2_regularization_strength'],
+          l2=adjusted_l2_regularization_strength,
           lr_power=coefficients['learning_rate_power'],
           use_locking=self._use_locking)
     else:
@@ -187,7 +204,7 @@ class Ftrl(optimizer_v2.OptimizerV2):
           indices=indices,
           lr=coefficients['lr_t'],
           l1=coefficients['l1_regularization_strength'],
-          l2=coefficients['l2_regularization_strength'],
+          l2=adjusted_l2_regularization_strength,
           l2_shrinkage=coefficients['l2_shrinkage_regularization_strength'],
           lr_power=coefficients['learning_rate_power'],
           use_locking=self._use_locking)
@@ -207,6 +224,8 @@ class Ftrl(optimizer_v2.OptimizerV2):
             self._serialize_hyperparameter('l1_regularization_strength'),
         'l2_regularization_strength':
             self._serialize_hyperparameter('l2_regularization_strength'),
+        'beta':
+            self._serialize_hyperparameter('beta'),
         'l2_shrinkage_regularization_strength':
             self._l2_shrinkage_regularization_strength,
     })
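Because `beta` is now part of `get_config`, it should survive a config round trip. A small sketch (hypothetical values; `ftrl` is the same module alias used in the Keras code above):

    opt = ftrl.Ftrl(learning_rate=0.001, beta=0.1)
    config = opt.get_config()
    assert config['beta'] == 0.1

    # Rebuilding the optimizer from its config should preserve the new field.
    restored = ftrl.Ftrl.from_config(config)
    assert restored.get_config()['beta'] == 0.1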
@@ -156,6 +156,63 @@ class FtrlOptimizerTest(test.TestCase):
       self.assertAllCloseAccordingToType(
           np.array([-0.93460727, -1.86147261]), v1_val)
 
+  def testFtrlWithBeta(self):
+    # TODO(tanzheny, omalleyt): Fix test in eager mode.
+    for dtype in [dtypes.half, dtypes.float32]:
+      with ops.Graph().as_default(), self.cached_session(use_gpu=True):
+        var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+        var1 = variables.Variable([4.0, 3.0], dtype=dtype)
+        grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
+        grads1 = constant_op.constant([0.01, 0.02], dtype=dtype)
+
+        opt = ftrl.Ftrl(3.0, initial_accumulator_value=0.1, beta=0.1)
+        update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+        self.evaluate(variables.global_variables_initializer())
+
+        v0_val, v1_val = self.evaluate([var0, var1])
+        self.assertAllCloseAccordingToType([1.0, 2.0], v0_val)
+        self.assertAllCloseAccordingToType([4.0, 3.0], v1_val)
+
+        # Run 10 steps FTRL
+        for _ in range(10):
+          update.run()
+        v0_val, v1_val = self.evaluate([var0, var1])
+        self.assertAllCloseAccordingToType(
+            np.array([-6.096838, -9.162214]), v0_val)
+        self.assertAllCloseAccordingToType(
+            np.array([-0.717741, -1.425132]), v1_val)
+
+  def testFtrlWithL2_Beta(self):
+    # TODO(tanzheny, omalleyt): Fix test in eager mode.
+    for dtype in [dtypes.half, dtypes.float32]:
+      with ops.Graph().as_default(), self.cached_session(use_gpu=True):
+        var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+        var1 = variables.Variable([4.0, 3.0], dtype=dtype)
+        grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
+        grads1 = constant_op.constant([0.01, 0.02], dtype=dtype)
+
+        opt = ftrl.Ftrl(
+            3.0,
+            initial_accumulator_value=0.1,
+            l1_regularization_strength=0.0,
+            l2_regularization_strength=0.1,
+            beta=0.1)
+        update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+        self.evaluate(variables.global_variables_initializer())
+
+        v0_val, v1_val = self.evaluate([var0, var1])
+        self.assertAllCloseAccordingToType([1.0, 2.0], v0_val)
+        self.assertAllCloseAccordingToType([4.0, 3.0], v1_val)
+
+        # Run 10 steps FTRL
+        for _ in range(10):
+          update.run()
+        v0_val, v1_val = self.evaluate([var0, var1])
+        self.assertAllCloseAccordingToType(
+            np.array([-2.735487, -4.704625]), v0_val)
+        self.assertAllCloseAccordingToType(
+            np.array([-0.294335, -0.586556]), v1_val)
+
   def testFtrlWithL1_L2(self):
     # TODO(tanzheny, omalleyt): Fix test in eager mode.
     for dtype in [dtypes.half, dtypes.float32]:
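For intuition about what these tests exercise, here is a minimal per-coordinate FTRL-Proximal reference in plain NumPy, following the update rule in the linked paper. It is only a sketch under simplifying assumptions (fixed learning_rate_power of -0.5, no l2_shrinkage, gradient accumulator starting at zero rather than initial_accumulator_value), so it is not expected to reproduce the test constants above bit for bit:

    import numpy as np

    def ftrl_proximal_step(w, z, n, g, alpha, beta, l1, l2):
      """One FTRL-Proximal update, vectorized over coordinates."""
      n_new = n + g * g
      sigma = (np.sqrt(n_new) - np.sqrt(n)) / alpha
      z_new = z + g - sigma * w
      # Closed-form solution of the per-coordinate proximal problem;
      # beta enters the denominator alongside the accumulated gradient norm.
      w_new = np.where(
          np.abs(z_new) <= l1,
          0.0,
          -(z_new - np.sign(z_new) * l1) /
          ((beta + np.sqrt(n_new)) / alpha + 2.0 * l2))
      return w_new, z_new, n_new

    # Toy usage with a constant gradient, mirroring the shapes used above.
    w = np.array([1.0, 2.0])
    z = np.zeros_like(w)
    n = np.zeros_like(w)
    g = np.array([0.1, 0.2])
    for _ in range(10):
      w, z, n = ftrl_proximal_step(w, z, n, g, alpha=3.0, beta=0.1, l1=0.0, l2=0.0)
    print(w)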
@@ -22,7 +22,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'learning_rate\', \'learning_rate_power\', \'initial_accumulator_value\', \'l1_regularization_strength\', \'l2_regularization_strength\', \'name\', \'l2_shrinkage_regularization_strength\'], varargs=None, keywords=kwargs, defaults=[\'0.001\', \'-0.5\', \'0.1\', \'0.0\', \'0.0\', \'Ftrl\', \'0.0\'], "
+    argspec: "args=[\'self\', \'learning_rate\', \'learning_rate_power\', \'initial_accumulator_value\', \'l1_regularization_strength\', \'l2_regularization_strength\', \'name\', \'l2_shrinkage_regularization_strength\', \'beta\'], varargs=None, keywords=kwargs, defaults=[\'0.001\', \'-0.5\', \'0.1\', \'0.0\', \'0.0\', \'Ftrl\', \'0.0\', \'0.0\'], "
   }
   member_method {
     name: "add_slot"

@@ -22,7 +22,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'learning_rate\', \'learning_rate_power\', \'initial_accumulator_value\', \'l1_regularization_strength\', \'l2_regularization_strength\', \'name\', \'l2_shrinkage_regularization_strength\'], varargs=None, keywords=kwargs, defaults=[\'0.001\', \'-0.5\', \'0.1\', \'0.0\', \'0.0\', \'Ftrl\', \'0.0\'], "
+    argspec: "args=[\'self\', \'learning_rate\', \'learning_rate_power\', \'initial_accumulator_value\', \'l1_regularization_strength\', \'l2_regularization_strength\', \'name\', \'l2_shrinkage_regularization_strength\', \'beta\'], varargs=None, keywords=kwargs, defaults=[\'0.001\', \'-0.5\', \'0.1\', \'0.0\', \'0.0\', \'Ftrl\', \'0.0\', \'0.0\'], "
   }
   member_method {
     name: "add_slot"

@@ -22,7 +22,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'learning_rate\', \'learning_rate_power\', \'initial_accumulator_value\', \'l1_regularization_strength\', \'l2_regularization_strength\', \'name\', \'l2_shrinkage_regularization_strength\'], varargs=None, keywords=kwargs, defaults=[\'0.001\', \'-0.5\', \'0.1\', \'0.0\', \'0.0\', \'Ftrl\', \'0.0\'], "
+    argspec: "args=[\'self\', \'learning_rate\', \'learning_rate_power\', \'initial_accumulator_value\', \'l1_regularization_strength\', \'l2_regularization_strength\', \'name\', \'l2_shrinkage_regularization_strength\', \'beta\'], varargs=None, keywords=kwargs, defaults=[\'0.001\', \'-0.5\', \'0.1\', \'0.0\', \'0.0\', \'Ftrl\', \'0.0\', \'0.0\'], "
  }
   member_method {
     name: "add_slot"