diff --git a/RELEASE.md b/RELEASE.md
index 62bdc11aa68..0eb673b0db7 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -95,6 +95,7 @@
     * Error messages when Functional API construction goes wrong (and when ops
       cannot be converted to Keras layers automatically) should be clearer and
       easier to understand.
     * `Optimizer.minimize` can now accept a loss `Tensor` and a `GradientTape`
       as an alternative to accepting a `callable` loss.
+    * Added `beta` parameter to FTRL optimizer to match paper.
 * `tf.function` / AutoGraph:
     * Added `experimental_follow_type_hints` argument for `tf.function`. When
       True, the function may use type annotations to optimize the tracing
diff --git a/tensorflow/python/training/ftrl.py b/tensorflow/python/training/ftrl.py
index c7b3867631d..6c8a6ceadc5 100644
--- a/tensorflow/python/training/ftrl.py
+++ b/tensorflow/python/training/ftrl.py
@@ -49,7 +49,8 @@ class FtrlOptimizer(optimizer.Optimizer):
                name="Ftrl",
                accum_name=None,
                linear_name=None,
-               l2_shrinkage_regularization_strength=0.0):
+               l2_shrinkage_regularization_strength=0.0,
+               beta=None):
     r"""Construct a new FTRL optimizer.
 
     Args:
@@ -79,10 +80,11 @@ class FtrlOptimizer(optimizer.Optimizer):
         function w.r.t. the weights w.
         Specifically, in the absence of L1 regularization, it is equivalent to
         the following update rule:
-        w_{t+1} = w_t - lr_t / (1 + 2*L2*lr_t) * g_t -
-                  2*L2_shrinkage*lr_t / (1 + 2*L2*lr_t) * w_t
+        w_{t+1} = w_t - lr_t / (beta + 2*L2*lr_t) * g_t -
+                  2*L2_shrinkage*lr_t / (beta + 2*L2*lr_t) * w_t
         where lr_t is the learning rate at t.
         When input is sparse shrinkage will only happen on the active weights.
+      beta: A float value; corresponds to the beta parameter in the paper.
 
     Raises:
       ValueError: If one of the arguments is invalid.
@@ -119,12 +121,13 @@ class FtrlOptimizer(optimizer.Optimizer):
     self._initial_accumulator_value = initial_accumulator_value
     self._l1_regularization_strength = l1_regularization_strength
     self._l2_regularization_strength = l2_regularization_strength
+    self._beta = (0.0 if beta is None else beta)
     self._l2_shrinkage_regularization_strength = (
         l2_shrinkage_regularization_strength)
     self._learning_rate_tensor = None
     self._learning_rate_power_tensor = None
     self._l1_regularization_strength_tensor = None
-    self._l2_regularization_strength_tensor = None
+    self._adjusted_l2_regularization_strength_tensor = None
     self._l2_shrinkage_regularization_strength_tensor = None
     self._accum_name = accum_name
     self._linear_name = linear_name
@@ -142,8 +145,14 @@ class FtrlOptimizer(optimizer.Optimizer):
         self._learning_rate, name="learning_rate")
     self._l1_regularization_strength_tensor = ops.convert_to_tensor(
         self._l1_regularization_strength, name="l1_regularization_strength")
-    self._l2_regularization_strength_tensor = ops.convert_to_tensor(
-        self._l2_regularization_strength, name="l2_regularization_strength")
+    # L2 regularization strength with beta added in so that the underlying
+    # TensorFlow ops do not need to include that parameter.
+    self._adjusted_l2_regularization_strength_tensor = ops.convert_to_tensor(
+        self._l2_regularization_strength + self._beta /
+        (2. * self._learning_rate),
+        name="adjusted_l2_regularization_strength")
+    assert self._adjusted_l2_regularization_strength_tensor is not None
+    self._beta_tensor = ops.convert_to_tensor(self._beta, name="beta")
     self._l2_shrinkage_regularization_strength_tensor = ops.convert_to_tensor(
         self._l2_shrinkage_regularization_strength,
         name="l2_shrinkage_regularization_strength")
@@ -162,7 +171,7 @@ class FtrlOptimizer(optimizer.Optimizer):
           math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
           math_ops.cast(self._l1_regularization_strength_tensor,
                         var.dtype.base_dtype),
-          math_ops.cast(self._l2_regularization_strength_tensor,
+          math_ops.cast(self._adjusted_l2_regularization_strength_tensor,
                         var.dtype.base_dtype),
           math_ops.cast(self._learning_rate_power_tensor, var.dtype.base_dtype),
           use_locking=self._use_locking)
@@ -175,7 +184,7 @@ class FtrlOptimizer(optimizer.Optimizer):
           math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
           math_ops.cast(self._l1_regularization_strength_tensor,
                         var.dtype.base_dtype),
-          math_ops.cast(self._l2_regularization_strength_tensor,
+          math_ops.cast(self._adjusted_l2_regularization_strength_tensor,
                         var.dtype.base_dtype),
           math_ops.cast(self._l2_shrinkage_regularization_strength_tensor,
                         var.dtype.base_dtype),
@@ -194,7 +203,7 @@ class FtrlOptimizer(optimizer.Optimizer):
           math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
           math_ops.cast(self._l1_regularization_strength_tensor,
                         var.dtype.base_dtype),
-          math_ops.cast(self._l2_regularization_strength_tensor,
+          math_ops.cast(self._adjusted_l2_regularization_strength_tensor,
                         var.dtype.base_dtype),
           math_ops.cast(self._learning_rate_power_tensor, var.dtype.base_dtype),
           use_locking=self._use_locking)
@@ -207,7 +216,7 @@ class FtrlOptimizer(optimizer.Optimizer):
           math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
           math_ops.cast(self._l1_regularization_strength_tensor,
                         var.dtype.base_dtype),
-          math_ops.cast(self._l2_regularization_strength_tensor,
+          math_ops.cast(self._adjusted_l2_regularization_strength_tensor,
                         var.dtype.base_dtype),
           math_ops.cast(self._l2_shrinkage_regularization_strength_tensor,
                         var.dtype.base_dtype),
@@ -227,7 +236,7 @@ class FtrlOptimizer(optimizer.Optimizer):
           math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
           math_ops.cast(self._l1_regularization_strength_tensor,
                         var.dtype.base_dtype),
-          math_ops.cast(self._l2_regularization_strength_tensor,
+          math_ops.cast(self._adjusted_l2_regularization_strength_tensor,
                         var.dtype.base_dtype),
           math_ops.cast(self._learning_rate_power_tensor, var.dtype.base_dtype),
           use_locking=self._use_locking)
@@ -241,7 +250,7 @@ class FtrlOptimizer(optimizer.Optimizer):
           math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
           math_ops.cast(self._l1_regularization_strength_tensor,
                         var.dtype.base_dtype),
-          math_ops.cast(self._l2_regularization_strength_tensor,
+          math_ops.cast(self._adjusted_l2_regularization_strength_tensor,
                         var.dtype.base_dtype),
           math_ops.cast(self._l2_shrinkage_regularization_strength_tensor,
                         grad.dtype.base_dtype),
@@ -260,7 +269,8 @@ class FtrlOptimizer(optimizer.Optimizer):
           indices,
           math_ops.cast(self._learning_rate_tensor, grad.dtype),
           math_ops.cast(self._l1_regularization_strength_tensor, grad.dtype),
-          math_ops.cast(self._l2_regularization_strength_tensor, grad.dtype),
+          math_ops.cast(self._adjusted_l2_regularization_strength_tensor,
+                        grad.dtype),
           math_ops.cast(self._learning_rate_power_tensor, grad.dtype),
           use_locking=self._use_locking)
     else:
@@ -272,7 +282,8 @@ class FtrlOptimizer(optimizer.Optimizer):
           indices,
           math_ops.cast(self._learning_rate_tensor, grad.dtype),
           math_ops.cast(self._l1_regularization_strength_tensor, grad.dtype),
-          math_ops.cast(self._l2_regularization_strength_tensor, grad.dtype),
+          math_ops.cast(self._adjusted_l2_regularization_strength_tensor,
+                        grad.dtype),
           math_ops.cast(self._l2_shrinkage_regularization_strength_tensor,
                         grad.dtype),
           math_ops.cast(self._learning_rate_power_tensor, grad.dtype),
diff --git a/tensorflow/python/training/ftrl_test.py b/tensorflow/python/training/ftrl_test.py
index f0cbe13e037..ff1bf177a72 100644
--- a/tensorflow/python/training/ftrl_test.py
+++ b/tensorflow/python/training/ftrl_test.py
@@ -161,6 +161,65 @@ class FtrlOptimizerTest(test.TestCase):
           self.assertAllCloseAccordingToType(
               np.array([-0.93460727, -1.86147261]), v1_val)
 
+  def testFtrlWithBeta(self):
+    # The v1 optimizers do not support eager execution
+    with ops.Graph().as_default():
+      for dtype in [dtypes.half, dtypes.float32]:
+        with self.cached_session():
+          var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+          var1 = variables.Variable([4.0, 3.0], dtype=dtype)
+          grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
+          grads1 = constant_op.constant([0.01, 0.02], dtype=dtype)
+
+          opt = ftrl.FtrlOptimizer(3.0, initial_accumulator_value=0.1, beta=0.1)
+          update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+          self.evaluate(variables.global_variables_initializer())
+
+          v0_val, v1_val = self.evaluate([var0, var1])
+          self.assertAllCloseAccordingToType([1.0, 2.0], v0_val)
+          self.assertAllCloseAccordingToType([4.0, 3.0], v1_val)
+
+          # Run 10 steps FTRL
+          for _ in range(10):
+            update.run()
+          v0_val, v1_val = self.evaluate([var0, var1])
+          self.assertAllCloseAccordingToType(
+              np.array([-6.096838, -9.162214]), v0_val)
+          self.assertAllCloseAccordingToType(
+              np.array([-0.717741, -1.425132]), v1_val)
+
+  def testFtrlWithL2_Beta(self):
+    # The v1 optimizers do not support eager execution
+    with ops.Graph().as_default():
+      for dtype in [dtypes.half, dtypes.float32]:
+        with self.cached_session():
+          var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+          var1 = variables.Variable([4.0, 3.0], dtype=dtype)
+          grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
+          grads1 = constant_op.constant([0.01, 0.02], dtype=dtype)
+
+          opt = ftrl.FtrlOptimizer(
+              3.0,
+              initial_accumulator_value=0.1,
+              l1_regularization_strength=0.0,
+              l2_regularization_strength=0.1,
+              beta=0.1)
+          update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+          self.evaluate(variables.global_variables_initializer())
+
+          v0_val, v1_val = self.evaluate([var0, var1])
+          self.assertAllCloseAccordingToType([1.0, 2.0], v0_val)
+          self.assertAllCloseAccordingToType([4.0, 3.0], v1_val)
+
+          # Run 10 steps FTRL
+          for _ in range(10):
+            update.run()
+          v0_val, v1_val = self.evaluate([var0, var1])
+          self.assertAllCloseAccordingToType(
+              np.array([-2.735487, -4.704625]), v0_val)
+          self.assertAllCloseAccordingToType(
+              np.array([-0.294335, -0.586556]), v1_val)
+
   def testFtrlWithL1_L2(self):
     # The v1 optimizers do not support eager execution
     with ops.Graph().as_default():
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.train.-ftrl-optimizer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-ftrl-optimizer.pbtxt
index 1d1aceb0138..9e12ae9b71f 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.train.-ftrl-optimizer.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.train.-ftrl-optimizer.pbtxt
@@ -18,7 +18,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'learning_rate\', \'learning_rate_power\', \'initial_accumulator_value\', \'l1_regularization_strength\', \'l2_regularization_strength\', \'use_locking\', \'name\', \'accum_name\', \'linear_name\', \'l2_shrinkage_regularization_strength\'], varargs=None, keywords=None, defaults=[\'-0.5\', \'0.1\', \'0.0\', \'0.0\', \'False\', \'Ftrl\', \'None\', \'None\', \'0.0\'], "
+    argspec: "args=[\'self\', \'learning_rate\', \'learning_rate_power\', \'initial_accumulator_value\', \'l1_regularization_strength\', \'l2_regularization_strength\', \'use_locking\', \'name\', \'accum_name\', \'linear_name\', \'l2_shrinkage_regularization_strength\', \'beta\'], varargs=None, keywords=None, defaults=[\'-0.5\', \'0.1\', \'0.0\', \'0.0\', \'False\', \'Ftrl\', \'None\', \'None\', \'0.0\', \'None\'], "
   }
   member_method {
     name: "apply_gradients"
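
Note on the beta handling in this patch: `beta` is never passed to the FTRL kernels; `_prepare` folds it into the L2 strength as `l2 + beta / (2 * learning_rate)` (see the "adjusted_l2_regularization_strength" tensor above). The following is a minimal standalone sketch of the algebra behind that fold-in, not part of the change itself; it uses illustrative values and assumes the kernels' per-coordinate quadratic term has the form sqrt(n)/lr + 2*l2 for the default learning_rate_power=-0.5.

import math

# Illustrative values only (not taken from the tests above).
lr = 3.0     # learning_rate
l2 = 0.1     # l2_regularization_strength
beta = 0.1   # new beta argument
n = 0.5      # a per-coordinate gradient-squared accumulator value

# What FtrlOptimizer._prepare now computes and hands to the unchanged kernels.
adjusted_l2 = l2 + beta / (2. * lr)

# Assumed kernel quadratic term (learning_rate_power=-0.5): sqrt(n)/lr + 2*l2.
with_beta = (beta + math.sqrt(n)) / lr + 2. * l2  # term with an explicit beta
folded = math.sqrt(n) / lr + 2. * adjusted_l2     # beta folded into the L2 strength
assert math.isclose(with_beta, folded)
print(adjusted_l2, with_beta, folded)

Because the two expressions are identical, only the Python wrapper needs the new argument and the existing op signatures stay untouched.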