Expose Weight Decay to Python TPU Embedding API

Makes it possible to use the same optimizer as BERT for the embeddings. See also
Decoupled Weight Decay Regularization (https://arxiv.org/abs/1711.05101).

PiperOrigin-RevId: 293474777
Change-Id: I6ca37d5699ed39e5983f82ce32cde910e0ada164
Authored by Philip Pham on 2020-02-05 16:07:37 -08:00; committed by TensorFlower Gardener
parent b7907e9465
commit e436a14249
5 changed files with 90 additions and 15 deletions
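
For background, the sketch below (plain Python, not part of this change; the helper name is illustrative only) shows the decoupled weight-decay update from the linked paper and what the new `multiply_weight_decay_factor_by_learning_rate` option controls: the decay is applied directly to the weights after the optimizer step, rather than being folded into the gradients, and can optionally be scaled by the current learning rate.

    # Illustrative sketch of decoupled weight decay (Loshchilov & Hutter);
    # this helper is hypothetical and not part of the TensorFlow API.
    def apply_decoupled_weight_decay(weight, gradient, learning_rate,
                                     weight_decay_factor=None,
                                     multiply_weight_decay_factor_by_learning_rate=False):
      weight = weight - learning_rate * gradient  # ordinary optimizer step
      if weight_decay_factor:  # None (or 0.0) means no decay
        decay = weight_decay_factor
        if multiply_weight_decay_factor_by_learning_rate:
          decay *= learning_rate  # decay follows the learning-rate schedule
        weight = weight - decay * weight  # decay the weight itself, not the gradient
      return weight

With the flag set, a decaying learning-rate schedule (such as the polynomial decay in the docstring example below) also anneals the effective weight decay.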


@@ -257,12 +257,16 @@ VariablesAndOps = collections.namedtuple(
 class _OptimizationParameters(object):
   """Parameters common to all optimizations."""
 
-  def __init__(self, learning_rate, use_gradient_accumulation,
-               clip_weight_min, clip_weight_max):
+  def __init__(self, learning_rate, use_gradient_accumulation, clip_weight_min,
+               clip_weight_max, weight_decay_factor,
+               multiply_weight_decay_factor_by_learning_rate):
     self.learning_rate = learning_rate
     self.use_gradient_accumulation = use_gradient_accumulation
     self.clip_weight_min = clip_weight_min
     self.clip_weight_max = clip_weight_max
+    self.weight_decay_factor = weight_decay_factor
+    self.multiply_weight_decay_factor_by_learning_rate = (
+        multiply_weight_decay_factor_by_learning_rate)
 
 
 @tf_export(v1=['tpu.experimental.AdagradParameters'])
@@ -290,7 +294,9 @@ class AdagradParameters(_OptimizationParameters):
                initial_accumulator=0.1,
                use_gradient_accumulation=True,
                clip_weight_min=None,
-               clip_weight_max=None):
+               clip_weight_max=None,
+               weight_decay_factor=None,
+               multiply_weight_decay_factor_by_learning_rate=None):
     """Optimization parameters for Adagrad.
 
     Args:
@@ -302,10 +308,15 @@ class AdagradParameters(_OptimizationParameters):
         for details.
       clip_weight_min: the minimum value to clip by; None means -infinity.
       clip_weight_max: the maximum value to clip by; None means +infinity.
+      weight_decay_factor: amount of weight decay to apply; None means that the
+        weights are not decayed.
+      multiply_weight_decay_factor_by_learning_rate: if true,
+        `weight_decay_factor` is multiplied by the current learning rate.
     """
     super(AdagradParameters,
           self).__init__(learning_rate, use_gradient_accumulation,
-                         clip_weight_min, clip_weight_max)
+                         clip_weight_min, clip_weight_max, weight_decay_factor,
+                         multiply_weight_decay_factor_by_learning_rate)
     if initial_accumulator <= 0:
       raise ValueError('Adagrad initial_accumulator must be positive')
     self.initial_accumulator = initial_accumulator
@@ -340,7 +351,9 @@ class AdamParameters(_OptimizationParameters):
                sum_inside_sqrt=True,
                use_gradient_accumulation=True,
                clip_weight_min=None,
-               clip_weight_max=None):
+               clip_weight_max=None,
+               weight_decay_factor=None,
+               multiply_weight_decay_factor_by_learning_rate=None):
     """Optimization parameters for Adam.
 
     Args:
@@ -360,10 +373,15 @@ class AdamParameters(_OptimizationParameters):
         for details.
       clip_weight_min: the minimum value to clip by; None means -infinity.
       clip_weight_max: the maximum value to clip by; None means +infinity.
+      weight_decay_factor: amount of weight decay to apply; None means that the
+        weights are not decayed.
+      multiply_weight_decay_factor_by_learning_rate: if true,
+        `weight_decay_factor` is multiplied by the current learning rate.
     """
     super(AdamParameters,
           self).__init__(learning_rate, use_gradient_accumulation,
-                         clip_weight_min, clip_weight_max)
+                         clip_weight_min, clip_weight_max, weight_decay_factor,
+                         multiply_weight_decay_factor_by_learning_rate)
     if beta1 < 0. or beta1 >= 1.:
       raise ValueError('beta1 must be between 0. and 1; got {}.'.format(beta1))
     if beta2 < 0. or beta2 >= 1.:
@@ -409,7 +427,9 @@ class FtrlParameters(_OptimizationParameters):
                l2_regularization_strength=0.0,
                use_gradient_accumulation=True,
                clip_weight_min=None,
-               clip_weight_max=None):
+               clip_weight_max=None,
+               weight_decay_factor=None,
+               multiply_weight_decay_factor_by_learning_rate=None):
     """Optimization parameters for Ftrl.
 
     Args:
@@ -430,10 +450,15 @@ class FtrlParameters(_OptimizationParameters):
         for details.
       clip_weight_min: the minimum value to clip by; None means -infinity.
       clip_weight_max: the maximum value to clip by; None means +infinity.
+      weight_decay_factor: amount of weight decay to apply; None means that the
+        weights are not decayed.
+      multiply_weight_decay_factor_by_learning_rate: if true,
+        `weight_decay_factor` is multiplied by the current learning rate.
     """
     super(FtrlParameters,
           self).__init__(learning_rate, use_gradient_accumulation,
-                         clip_weight_min, clip_weight_max)
+                         clip_weight_min, clip_weight_max, weight_decay_factor,
+                         multiply_weight_decay_factor_by_learning_rate)
     if learning_rate_power > 0.:
       raise ValueError('learning_rate_power must be less than or equal to 0. '
                        'got {}.'.format(learning_rate_power))
@@ -477,17 +502,27 @@ class StochasticGradientDescentParameters(_OptimizationParameters):
   """
 
-  def __init__(self, learning_rate, clip_weight_min=None,
-               clip_weight_max=None):
+  def __init__(self,
+               learning_rate,
+               clip_weight_min=None,
+               clip_weight_max=None,
+               weight_decay_factor=None,
+               multiply_weight_decay_factor_by_learning_rate=None):
     """Optimization parameters for stochastic gradient descent.
 
     Args:
       learning_rate: a floating point value. The learning rate.
       clip_weight_min: the minimum value to clip by; None means -infinity.
       clip_weight_max: the maximum value to clip by; None means +infinity.
+      weight_decay_factor: amount of weight decay to apply; None means that the
+        weights are not decayed.
+      multiply_weight_decay_factor_by_learning_rate: if true,
+        `weight_decay_factor` is multiplied by the current learning rate.
     """
     super(StochasticGradientDescentParameters,
-          self).__init__(learning_rate, False, clip_weight_min, clip_weight_max)
+          self).__init__(learning_rate, False, clip_weight_min, clip_weight_max,
+                         weight_decay_factor,
+                         multiply_weight_decay_factor_by_learning_rate)
 
 
 DeviceConfig = collections.namedtuple('DeviceConfig',
@@ -557,6 +592,40 @@ class TPUEmbedding(object):
       sess.run(enqueue_ops)
       loss_val = sess.run(loss)
   ```
 
+  Example with weight decay:
+
+  >>> def learning_rate_fn(global_step):
+  ...   return tf.compat.v1.train.polynomial_decay(
+  ...       learning_rate=5e-5,
+  ...       global_step=global_step,
+  ...       decay_steps=100000,
+  ...       end_learning_rate=0.0)
+  >>> wordpiece_table_config = TableConfig(
+  ...   vocabulary_size=119547,
+  ...   dimension=768,
+  ...   learning_rate_fn=learning_rate_fn)
+  >>> wordpiece_feature_config = FeatureConfig(
+  ...   table_id='bert/embeddings/word_embeddings',
+  ...   max_sequence_length=512)
+  >>> optimization_parameters = AdamParameters(
+  ...   learning_rate=5e-5,
+  ...   epsilon=1e-6,
+  ...   weight_decay_factor=0.01,
+  ...   multiply_weight_decay_factor_by_learning_rate=True)
+  >>> tpu_embedding = TPUEmbedding(
+  ...   table_to_config_dict={
+  ...       'bert/embeddings/word_embeddings': wordpiece_table_config,
+  ...   },
+  ...   feature_to_config_dict={'input_ids': wordpiece_feature_config},
+  ...   batch_size=128,
+  ...   mode=TRAINING,
+  ...   optimization_parameters=optimization_parameters,
+  ...   device_config=DeviceConfig(
+  ...       num_cores=64, num_hosts=4, job_name='tpu_worker'))
+  >>> with tf.Graph().as_default():
+  ...   init_tpu_op = tf.compat.v1.tpu.initialize_system(
+  ...       embedding_config=tpu_embedding.config_proto, job='tpu_worker')
+
   """
   # TODO(shizhiw): Consider adding a field to FeatureConfig that indicates that
@@ -814,6 +883,12 @@ class TPUEmbedding(object):
     if self._optimization_parameters.clip_weight_max is not None:
       parameters.clipping_limits.upper.value = (
           self._optimization_parameters.clip_weight_max)
+    if self._optimization_parameters.weight_decay_factor:
+      parameters.weight_decay_factor = (
+          self._optimization_parameters.weight_decay_factor)
+      if (self._optimization_parameters
+          .multiply_weight_decay_factor_by_learning_rate):
+        parameters.multiply_weight_decay_factor_by_learning_rate = True
     if table_config.hot_id_replication:
       parameters.hot_id_replication_configuration.status = (
           optimization_parameters_pb2.HotIdReplicationConfiguration.ENABLED)
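
For reference, a minimal usage sketch (values are illustrative, assuming a TensorFlow build that includes this change) of the new keyword arguments on the v1 parameter classes; `TPUEmbedding` copies them into the per-table optimization parameters proto as shown in the hunk above:

    import tensorflow.compat.v1 as tf

    # Hypothetical settings; all four parameter classes accept the same two
    # new keyword arguments.
    adam = tf.tpu.experimental.AdamParameters(
        learning_rate=5e-5,
        weight_decay_factor=0.01,
        multiply_weight_decay_factor_by_learning_rate=True)

    sgd = tf.tpu.experimental.StochasticGradientDescentParameters(
        learning_rate=0.1,
        weight_decay_factor=1e-4)  # decay is not scaled by the learning rate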


@@ -5,6 +5,6 @@ tf_class {
   is_instance: "<type \'object\'>"
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'learning_rate\', \'initial_accumulator\', \'use_gradient_accumulation\', \'clip_weight_min\', \'clip_weight_max\'], varargs=None, keywords=None, defaults=[\'0.1\', \'True\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'learning_rate\', \'initial_accumulator\', \'use_gradient_accumulation\', \'clip_weight_min\', \'clip_weight_max\', \'weight_decay_factor\', \'multiply_weight_decay_factor_by_learning_rate\'], varargs=None, keywords=None, defaults=[\'0.1\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
   }
 }


@@ -5,6 +5,6 @@ tf_class {
   is_instance: "<type \'object\'>"
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'learning_rate\', \'beta1\', \'beta2\', \'epsilon\', \'lazy_adam\', \'sum_inside_sqrt\', \'use_gradient_accumulation\', \'clip_weight_min\', \'clip_weight_max\'], varargs=None, keywords=None, defaults=[\'0.9\', \'0.999\', \'1e-08\', \'True\', \'True\', \'True\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'learning_rate\', \'beta1\', \'beta2\', \'epsilon\', \'lazy_adam\', \'sum_inside_sqrt\', \'use_gradient_accumulation\', \'clip_weight_min\', \'clip_weight_max\', \'weight_decay_factor\', \'multiply_weight_decay_factor_by_learning_rate\'], varargs=None, keywords=None, defaults=[\'0.9\', \'0.999\', \'1e-08\', \'True\', \'True\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
   }
 }


@@ -5,6 +5,6 @@ tf_class {
   is_instance: "<type \'object\'>"
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'learning_rate\', \'learning_rate_power\', \'initial_accumulator_value\', \'l1_regularization_strength\', \'l2_regularization_strength\', \'use_gradient_accumulation\', \'clip_weight_min\', \'clip_weight_max\'], varargs=None, keywords=None, defaults=[\'-0.5\', \'0.1\', \'0.0\', \'0.0\', \'True\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'learning_rate\', \'learning_rate_power\', \'initial_accumulator_value\', \'l1_regularization_strength\', \'l2_regularization_strength\', \'use_gradient_accumulation\', \'clip_weight_min\', \'clip_weight_max\', \'weight_decay_factor\', \'multiply_weight_decay_factor_by_learning_rate\'], varargs=None, keywords=None, defaults=[\'-0.5\', \'0.1\', \'0.0\', \'0.0\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
   }
 }


@@ -5,6 +5,6 @@ tf_class {
   is_instance: "<type \'object\'>"
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'learning_rate\', \'clip_weight_min\', \'clip_weight_max\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+    argspec: "args=[\'self\', \'learning_rate\', \'clip_weight_min\', \'clip_weight_max\', \'weight_decay_factor\', \'multiply_weight_decay_factor_by_learning_rate\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
  }
}