Added `beta` parameter from the FTRL paper to the `FtrlOptimizer` class.

PiperOrigin-RevId: 325290409
Change-Id: I0aa85a26b188b9ab3e1faa7462dca4c5d81f8712
A. Unique TensorFlower 2020-08-06 12:53:49 -07:00 committed by TensorFlower Gardener
parent 3cb0418477
commit 7df6aa0253
4 changed files with 86 additions and 15 deletions


@@ -95,6 +95,7 @@
 *   Error messages when Functional API construction goes wrong (and when ops cannot be converted to Keras layers automatically) should be clearer and easier to understand.
 *   `Optimizer.minimize` can now accept a loss `Tensor` and a `GradientTape`
     as an alternative to accepting a `callable` loss.
+*   Added `beta` parameter to FTRL optimizer to match paper.
 *   `tf.function` / AutoGraph:
     *   Added `experimental_follow_type_hints` argument for `tf.function`. When
         True, the function may use type annotations to optimize the tracing


@@ -49,7 +49,8 @@ class FtrlOptimizer(optimizer.Optimizer):
                name="Ftrl",
                accum_name=None,
                linear_name=None,
-               l2_shrinkage_regularization_strength=0.0):
+               l2_shrinkage_regularization_strength=0.0,
+               beta=None):
     r"""Construct a new FTRL optimizer.

     Args:
@@ -79,10 +80,11 @@ class FtrlOptimizer(optimizer.Optimizer):
         function w.r.t. the weights w.
         Specifically, in the absence of L1 regularization, it is equivalent to
         the following update rule:
-          w_{t+1} = w_t - lr_t / (1 + 2*L2*lr_t) * g_t -
-                    2*L2_shrinkage*lr_t / (1 + 2*L2*lr_t) * w_t
+          w_{t+1} = w_t - lr_t / (beta + 2*L2*lr_t) * g_t -
+                    2*L2_shrinkage*lr_t / (beta + 2*L2*lr_t) * w_t
         where lr_t is the learning rate at t.
         When input is sparse shrinkage will only happen on the active weights.
+      beta: A float value; corresponds to the beta parameter in the paper.

     Raises:
       ValueError: If one of the arguments is invalid.
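For intuition, the following is a minimal standalone sketch (not part of the commit) that evaluates the documented equivalent update rule for a single step with illustrative values; the optimizer itself implements FTRL with per-coordinate accumulators rather than this closed form:

    import numpy as np

    # Illustrative values only; lr_t is the learning rate at step t.
    lr_t, l2, l2_shrinkage, beta = 3.0, 0.1, 0.0, 0.1
    w_t = np.array([1.0, 2.0])
    g_t = np.array([0.1, 0.2])

    # w_{t+1} = w_t - lr_t / (beta + 2*L2*lr_t) * g_t
    #               - 2*L2_shrinkage*lr_t / (beta + 2*L2*lr_t) * w_t
    denom = beta + 2.0 * l2 * lr_t
    w_next = w_t - lr_t / denom * g_t - 2.0 * l2_shrinkage * lr_t / denom * w_t
    print(w_next)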
@@ -119,12 +121,13 @@ class FtrlOptimizer(optimizer.Optimizer):
     self._initial_accumulator_value = initial_accumulator_value
     self._l1_regularization_strength = l1_regularization_strength
     self._l2_regularization_strength = l2_regularization_strength
+    self._beta = (0.0 if beta is None else beta)
     self._l2_shrinkage_regularization_strength = (
         l2_shrinkage_regularization_strength)
     self._learning_rate_tensor = None
     self._learning_rate_power_tensor = None
     self._l1_regularization_strength_tensor = None
-    self._l2_regularization_strength_tensor = None
+    self._adjusted_l2_regularization_strength_tensor = None
     self._l2_shrinkage_regularization_strength_tensor = None
     self._accum_name = accum_name
     self._linear_name = linear_name
@@ -142,8 +145,14 @@ class FtrlOptimizer(optimizer.Optimizer):
         self._learning_rate, name="learning_rate")
     self._l1_regularization_strength_tensor = ops.convert_to_tensor(
         self._l1_regularization_strength, name="l1_regularization_strength")
-    self._l2_regularization_strength_tensor = ops.convert_to_tensor(
-        self._l2_regularization_strength, name="l2_regularization_strength")
+    # L2 regularization strength with beta added in so that the underlying
+    # TensorFlow ops do not need to include that parameter.
+    self._adjusted_l2_regularization_strength_tensor = ops.convert_to_tensor(
+        self._l2_regularization_strength + self._beta /
+        (2. * self._learning_rate),
+        name="adjusted_l2_regularization_strength")
+    assert self._adjusted_l2_regularization_strength_tensor is not None
+    self._beta_tensor = ops.convert_to_tensor(self._beta, name="beta")
     self._l2_shrinkage_regularization_strength_tensor = ops.convert_to_tensor(
         self._l2_shrinkage_regularization_strength,
         name="l2_shrinkage_regularization_strength")
@@ -162,7 +171,7 @@ class FtrlOptimizer(optimizer.Optimizer):
          math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
          math_ops.cast(self._l1_regularization_strength_tensor,
                        var.dtype.base_dtype),
-         math_ops.cast(self._l2_regularization_strength_tensor,
+         math_ops.cast(self._adjusted_l2_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._learning_rate_power_tensor, var.dtype.base_dtype),
          use_locking=self._use_locking)
@@ -175,7 +184,7 @@ class FtrlOptimizer(optimizer.Optimizer):
          math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
          math_ops.cast(self._l1_regularization_strength_tensor,
                        var.dtype.base_dtype),
-         math_ops.cast(self._l2_regularization_strength_tensor,
+         math_ops.cast(self._adjusted_l2_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._l2_shrinkage_regularization_strength_tensor,
                        var.dtype.base_dtype),
@@ -194,7 +203,7 @@ class FtrlOptimizer(optimizer.Optimizer):
          math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
          math_ops.cast(self._l1_regularization_strength_tensor,
                        var.dtype.base_dtype),
-         math_ops.cast(self._l2_regularization_strength_tensor,
+         math_ops.cast(self._adjusted_l2_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._learning_rate_power_tensor, var.dtype.base_dtype),
          use_locking=self._use_locking)
@@ -207,7 +216,7 @@ class FtrlOptimizer(optimizer.Optimizer):
          math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
          math_ops.cast(self._l1_regularization_strength_tensor,
                        var.dtype.base_dtype),
-         math_ops.cast(self._l2_regularization_strength_tensor,
+         math_ops.cast(self._adjusted_l2_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._l2_shrinkage_regularization_strength_tensor,
                        var.dtype.base_dtype),
@@ -227,7 +236,7 @@ class FtrlOptimizer(optimizer.Optimizer):
          math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
          math_ops.cast(self._l1_regularization_strength_tensor,
                        var.dtype.base_dtype),
-         math_ops.cast(self._l2_regularization_strength_tensor,
+         math_ops.cast(self._adjusted_l2_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._learning_rate_power_tensor, var.dtype.base_dtype),
          use_locking=self._use_locking)
@@ -241,7 +250,7 @@ class FtrlOptimizer(optimizer.Optimizer):
          math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
          math_ops.cast(self._l1_regularization_strength_tensor,
                        var.dtype.base_dtype),
-         math_ops.cast(self._l2_regularization_strength_tensor,
+         math_ops.cast(self._adjusted_l2_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._l2_shrinkage_regularization_strength_tensor,
                        grad.dtype.base_dtype),
@@ -260,7 +269,8 @@ class FtrlOptimizer(optimizer.Optimizer):
           indices,
           math_ops.cast(self._learning_rate_tensor, grad.dtype),
           math_ops.cast(self._l1_regularization_strength_tensor, grad.dtype),
-          math_ops.cast(self._l2_regularization_strength_tensor, grad.dtype),
+          math_ops.cast(self._adjusted_l2_regularization_strength_tensor,
+                        grad.dtype),
           math_ops.cast(self._learning_rate_power_tensor, grad.dtype),
           use_locking=self._use_locking)
     else:
@@ -272,7 +282,8 @@ class FtrlOptimizer(optimizer.Optimizer):
           indices,
           math_ops.cast(self._learning_rate_tensor, grad.dtype),
           math_ops.cast(self._l1_regularization_strength_tensor, grad.dtype),
-          math_ops.cast(self._l2_regularization_strength_tensor, grad.dtype),
+          math_ops.cast(self._adjusted_l2_regularization_strength_tensor,
+                        grad.dtype),
           math_ops.cast(self._l2_shrinkage_regularization_strength_tensor,
                         grad.dtype),
           math_ops.cast(self._learning_rate_power_tensor, grad.dtype),
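As a usage sketch (not part of the commit), the new argument is passed like any other constructor parameter; this assumes a TensorFlow build that includes the change and uses the public `tf.compat.v1.train.FtrlOptimizer` alias of this class, in graph mode since the v1 optimizers do not support eager execution:

    import tensorflow.compat.v1 as tf

    tf.disable_eager_execution()

    var = tf.Variable([1.0, 2.0])
    loss = tf.reduce_sum(tf.square(var))
    opt = tf.train.FtrlOptimizer(
        learning_rate=3.0,
        initial_accumulator_value=0.1,
        l2_regularization_strength=0.1,
        beta=0.1)  # parameter added by this commit
    train_op = opt.minimize(loss)

    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      for _ in range(10):
        sess.run(train_op)
      print(sess.run(var))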


@@ -161,6 +161,65 @@ class FtrlOptimizerTest(test.TestCase):
           self.assertAllCloseAccordingToType(
               np.array([-0.93460727, -1.86147261]), v1_val)

+  def testFtrlWithBeta(self):
+    # The v1 optimizers do not support eager execution
+    with ops.Graph().as_default():
+      for dtype in [dtypes.half, dtypes.float32]:
+        with self.cached_session():
+          var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+          var1 = variables.Variable([4.0, 3.0], dtype=dtype)
+          grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
+          grads1 = constant_op.constant([0.01, 0.02], dtype=dtype)
+          opt = ftrl.FtrlOptimizer(3.0, initial_accumulator_value=0.1, beta=0.1)
+          update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+          self.evaluate(variables.global_variables_initializer())
+
+          v0_val, v1_val = self.evaluate([var0, var1])
+          self.assertAllCloseAccordingToType([1.0, 2.0], v0_val)
+          self.assertAllCloseAccordingToType([4.0, 3.0], v1_val)
+
+          # Run 10 steps FTRL
+          for _ in range(10):
+            update.run()
+
+          v0_val, v1_val = self.evaluate([var0, var1])
+          self.assertAllCloseAccordingToType(
+              np.array([-6.096838, -9.162214]), v0_val)
+          self.assertAllCloseAccordingToType(
+              np.array([-0.717741, -1.425132]), v1_val)
+
+  def testFtrlWithL2_Beta(self):
+    # The v1 optimizers do not support eager execution
+    with ops.Graph().as_default():
+      for dtype in [dtypes.half, dtypes.float32]:
+        with self.cached_session():
+          var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+          var1 = variables.Variable([4.0, 3.0], dtype=dtype)
+          grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
+          grads1 = constant_op.constant([0.01, 0.02], dtype=dtype)
+          opt = ftrl.FtrlOptimizer(
+              3.0,
+              initial_accumulator_value=0.1,
+              l1_regularization_strength=0.0,
+              l2_regularization_strength=0.1,
+              beta=0.1)
+          update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+          self.evaluate(variables.global_variables_initializer())
+
+          v0_val, v1_val = self.evaluate([var0, var1])
+          self.assertAllCloseAccordingToType([1.0, 2.0], v0_val)
+          self.assertAllCloseAccordingToType([4.0, 3.0], v1_val)
+
+          # Run 10 steps FTRL
+          for _ in range(10):
+            update.run()
+
+          v0_val, v1_val = self.evaluate([var0, var1])
+          self.assertAllCloseAccordingToType(
+              np.array([-2.735487, -4.704625]), v0_val)
+          self.assertAllCloseAccordingToType(
+              np.array([-0.294335, -0.586556]), v1_val)
+
   def testFtrlWithL1_L2(self):
     # The v1 optimizers do not support eager execution
     with ops.Graph().as_default():


@@ -18,7 +18,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'learning_rate\', \'learning_rate_power\', \'initial_accumulator_value\', \'l1_regularization_strength\', \'l2_regularization_strength\', \'use_locking\', \'name\', \'accum_name\', \'linear_name\', \'l2_shrinkage_regularization_strength\'], varargs=None, keywords=None, defaults=[\'-0.5\', \'0.1\', \'0.0\', \'0.0\', \'False\', \'Ftrl\', \'None\', \'None\', \'0.0\'], "
+    argspec: "args=[\'self\', \'learning_rate\', \'learning_rate_power\', \'initial_accumulator_value\', \'l1_regularization_strength\', \'l2_regularization_strength\', \'use_locking\', \'name\', \'accum_name\', \'linear_name\', \'l2_shrinkage_regularization_strength\', \'beta\'], varargs=None, keywords=None, defaults=[\'-0.5\', \'0.1\', \'0.0\', \'0.0\', \'False\', \'Ftrl\', \'None\', \'None\', \'0.0\', \'None\'], "
   }
   member_method {
     name: "apply_gradients"