Expose Weight Decay to Python TPU Embedding API
Makes it possible to use the same optimizer as BERT for the embeddings. See also Decoupled Weight Decay Regularization (https://arxiv.org/abs/1711.05101).

PiperOrigin-RevId: 293474777
Change-Id: I6ca37d5699ed39e5983f82ce32cde910e0ada164
parent b7907e9465
commit e436a14249
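For context on the linked paper: decoupled weight decay shrinks the weights directly at each step instead of adding an L2 penalty to the loss, which is what the new `weight_decay_factor` / `multiply_weight_decay_factor_by_learning_rate` arguments control. A minimal NumPy sketch of that update rule, with made-up values, purely illustrative and not the embedding engine's actual implementation:

```python
import numpy as np

learning_rate = 5e-5
weight_decay_factor = 0.01
multiply_weight_decay_factor_by_learning_rate = True

weights = np.random.randn(768).astype(np.float32)  # one embedding row
grad = np.random.randn(768).astype(np.float32)     # its gradient

# Ordinary gradient step (plain SGD here for simplicity).
weights -= learning_rate * grad

# Decoupled weight decay: shrink the weights directly instead of folding
# an L2 term into the gradient.
decay = weight_decay_factor
if multiply_weight_decay_factor_by_learning_rate:
  decay *= learning_rate
weights -= decay * weights
```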
@@ -257,12 +257,16 @@ VariablesAndOps = collections.namedtuple(
 class _OptimizationParameters(object):
   """Parameters common to all optimizations."""
 
-  def __init__(self, learning_rate, use_gradient_accumulation,
-               clip_weight_min, clip_weight_max):
+  def __init__(self, learning_rate, use_gradient_accumulation, clip_weight_min,
+               clip_weight_max, weight_decay_factor,
+               multiply_weight_decay_factor_by_learning_rate):
     self.learning_rate = learning_rate
     self.use_gradient_accumulation = use_gradient_accumulation
     self.clip_weight_min = clip_weight_min
     self.clip_weight_max = clip_weight_max
+    self.weight_decay_factor = weight_decay_factor
+    self.multiply_weight_decay_factor_by_learning_rate = (
+        multiply_weight_decay_factor_by_learning_rate)
 
 
 @tf_export(v1=['tpu.experimental.AdagradParameters'])
@@ -290,7 +294,9 @@ class AdagradParameters(_OptimizationParameters):
                initial_accumulator=0.1,
                use_gradient_accumulation=True,
                clip_weight_min=None,
-               clip_weight_max=None):
+               clip_weight_max=None,
+               weight_decay_factor=None,
+               multiply_weight_decay_factor_by_learning_rate=None):
     """Optimization parameters for Adagrad.
 
     Args:
@@ -302,10 +308,15 @@ class AdagradParameters(_OptimizationParameters):
         for details.
       clip_weight_min: the minimum value to clip by; None means -infinity.
       clip_weight_max: the maximum value to clip by; None means +infinity.
+      weight_decay_factor: amount of weight decay to apply; None means that the
+        weights are not decayed.
+      multiply_weight_decay_factor_by_learning_rate: if true,
+        `weight_decay_factor` is multiplied by the current learning rate.
     """
     super(AdagradParameters,
           self).__init__(learning_rate, use_gradient_accumulation,
-                         clip_weight_min, clip_weight_max)
+                         clip_weight_min, clip_weight_max, weight_decay_factor,
+                         multiply_weight_decay_factor_by_learning_rate)
     if initial_accumulator <= 0:
       raise ValueError('Adagrad initial_accumulator must be positive')
     self.initial_accumulator = initial_accumulator
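A quick illustration of the new arguments documented in the Adagrad docstring above. The hyperparameter values are hypothetical; the public symbol path follows from the `tf_export(v1=[...])` decorator in this file:

```python
import tensorflow as tf

# Hypothetical values; only the last two arguments are new in this change.
adagrad = tf.compat.v1.tpu.experimental.AdagradParameters(
    learning_rate=0.1,
    initial_accumulator=0.1,
    weight_decay_factor=0.01,
    # Effective per-step decay becomes weight_decay_factor * learning_rate.
    multiply_weight_decay_factor_by_learning_rate=True)
```

The same two arguments are threaded through `AdamParameters`, `FtrlParameters`, and `StochasticGradientDescentParameters` in the hunks below.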
@@ -340,7 +351,9 @@ class AdamParameters(_OptimizationParameters):
                sum_inside_sqrt=True,
                use_gradient_accumulation=True,
                clip_weight_min=None,
-               clip_weight_max=None):
+               clip_weight_max=None,
+               weight_decay_factor=None,
+               multiply_weight_decay_factor_by_learning_rate=None):
     """Optimization parameters for Adam.
 
     Args:
@@ -360,10 +373,15 @@ class AdamParameters(_OptimizationParameters):
         for details.
       clip_weight_min: the minimum value to clip by; None means -infinity.
       clip_weight_max: the maximum value to clip by; None means +infinity.
+      weight_decay_factor: amount of weight decay to apply; None means that the
+        weights are not decayed.
+      multiply_weight_decay_factor_by_learning_rate: if true,
+        `weight_decay_factor` is multiplied by the current learning rate.
     """
     super(AdamParameters,
           self).__init__(learning_rate, use_gradient_accumulation,
-                         clip_weight_min, clip_weight_max)
+                         clip_weight_min, clip_weight_max, weight_decay_factor,
+                         multiply_weight_decay_factor_by_learning_rate)
     if beta1 < 0. or beta1 >= 1.:
       raise ValueError('beta1 must be between 0. and 1; got {}.'.format(beta1))
     if beta2 < 0. or beta2 >= 1.:
@@ -409,7 +427,9 @@ class FtrlParameters(_OptimizationParameters):
                l2_regularization_strength=0.0,
                use_gradient_accumulation=True,
                clip_weight_min=None,
-               clip_weight_max=None):
+               clip_weight_max=None,
+               weight_decay_factor=None,
+               multiply_weight_decay_factor_by_learning_rate=None):
     """Optimization parameters for Ftrl.
 
     Args:
@@ -430,10 +450,15 @@ class FtrlParameters(_OptimizationParameters):
         for details.
       clip_weight_min: the minimum value to clip by; None means -infinity.
       clip_weight_max: the maximum value to clip by; None means +infinity.
+      weight_decay_factor: amount of weight decay to apply; None means that the
+        weights are not decayed.
+      multiply_weight_decay_factor_by_learning_rate: if true,
+        `weight_decay_factor` is multiplied by the current learning rate.
     """
     super(FtrlParameters,
           self).__init__(learning_rate, use_gradient_accumulation,
-                         clip_weight_min, clip_weight_max)
+                         clip_weight_min, clip_weight_max, weight_decay_factor,
+                         multiply_weight_decay_factor_by_learning_rate)
     if learning_rate_power > 0.:
       raise ValueError('learning_rate_power must be less than or equal to 0. '
                        'got {}.'.format(learning_rate_power))
@@ -477,17 +502,27 @@ class StochasticGradientDescentParameters(_OptimizationParameters):
 
   """
 
-  def __init__(self, learning_rate, clip_weight_min=None,
-               clip_weight_max=None):
+  def __init__(self,
+               learning_rate,
+               clip_weight_min=None,
+               clip_weight_max=None,
+               weight_decay_factor=None,
+               multiply_weight_decay_factor_by_learning_rate=None):
     """Optimization parameters for stochastic gradient descent.
 
     Args:
       learning_rate: a floating point value. The learning rate.
       clip_weight_min: the minimum value to clip by; None means -infinity.
      clip_weight_max: the maximum value to clip by; None means +infinity.
+      weight_decay_factor: amount of weight decay to apply; None means that the
+        weights are not decayed.
+      multiply_weight_decay_factor_by_learning_rate: if true,
+        `weight_decay_factor` is multiplied by the current learning rate.
     """
     super(StochasticGradientDescentParameters,
-          self).__init__(learning_rate, False, clip_weight_min, clip_weight_max)
+          self).__init__(learning_rate, False, clip_weight_min, clip_weight_max,
+                         weight_decay_factor,
+                         multiply_weight_decay_factor_by_learning_rate)
 
 
 DeviceConfig = collections.namedtuple('DeviceConfig',
@@ -557,6 +592,40 @@ class TPUEmbedding(object):
       sess.run(enqueue_ops)
       loss_val = sess.run(loss)
   ```
+
+  Example with weight decay:
+
+  >>> def learning_rate_fn(global_step):
+  ...   return tf.compat.v1.train.polynomial_decay(
+  ...     learning_rate=5e-5,
+  ...     global_step=global_step,
+  ...     decay_steps=100000,
+  ...     end_learning_rate=0.0)
+  >>> wordpiece_table_config = TableConfig(
+  ...   vocabulary_size=119547,
+  ...   dimension=768,
+  ...   learning_rate_fn=learning_rate_fn)
+  >>> wordpiece_feature_config = FeatureConfig(
+  ...   table_id='bert/embeddings/word_embeddings',
+  ...   max_sequence_length=512)
+  >>> optimization_parameters = AdamParameters(
+  ...   learning_rate=5e-5,
+  ...   epsilon=1e-6,
+  ...   weight_decay_factor=0.01,
+  ...   multiply_weight_decay_factor_by_learning_rate=True)
+  >>> tpu_embedding = TPUEmbedding(
+  ...   table_to_config_dict={
+  ...     'bert/embeddings/word_embeddings': wordpiece_table_config,
+  ...   },
+  ...   feature_to_config_dict={'input_ids': wordpiece_feature_config},
+  ...   batch_size=128,
+  ...   mode=TRAINING,
+  ...   optimization_parameters=optimization_parameters,
+  ...   device_config=DeviceConfig(
+  ...     num_cores=64, num_hosts=4, job_name='tpu_worker'))
+  >>> with tf.Graph().as_default():
+  ...   init_tpu_op = tf.compat.v1.tpu.initialize_system(
+  ...     embedding_config=tpu_embedding.config_proto, job='tpu_worker')
   """
 
   # TODO(shizhiw): Consider adding a field to FeatureConfig that indicates that
@@ -814,6 +883,12 @@ class TPUEmbedding(object):
       if self._optimization_parameters.clip_weight_max is not None:
         parameters.clipping_limits.upper.value = (
             self._optimization_parameters.clip_weight_max)
+      if self._optimization_parameters.weight_decay_factor:
+        parameters.weight_decay_factor = (
+            self._optimization_parameters.weight_decay_factor)
+        if (self._optimization_parameters
+            .multiply_weight_decay_factor_by_learning_rate):
+          parameters.multiply_weight_decay_factor_by_learning_rate = True
       if table_config.hot_id_replication:
         parameters.hot_id_replication_configuration.status = (
             optimization_parameters_pb2.HotIdReplicationConfiguration.ENABLED)
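Downstream of the Python classes, the hunk above copies the new values onto each table's optimization parameters proto only when a decay factor is set. A hedged sketch of those proto fields populated directly; the import path and message name are assumptions, and only the two field names and the `optimization_parameters_pb2` module come from the diff:

```python
# Assumption: the OptimizationParameters message generated from
# tensorflow/core/protobuf/tpu/optimization_parameters.proto carries the
# two fields written by the diff above.
from tensorflow.core.protobuf.tpu import optimization_parameters_pb2

params = optimization_parameters_pb2.OptimizationParameters()
params.weight_decay_factor = 0.01
params.multiply_weight_decay_factor_by_learning_rate = True
print(params)  # prints the populated weight-decay fields
```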
@@ -5,6 +5,6 @@ tf_class {
   is_instance: "<type \'object\'>"
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'learning_rate\', \'initial_accumulator\', \'use_gradient_accumulation\', \'clip_weight_min\', \'clip_weight_max\'], varargs=None, keywords=None, defaults=[\'0.1\', \'True\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'learning_rate\', \'initial_accumulator\', \'use_gradient_accumulation\', \'clip_weight_min\', \'clip_weight_max\', \'weight_decay_factor\', \'multiply_weight_decay_factor_by_learning_rate\'], varargs=None, keywords=None, defaults=[\'0.1\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
   }
 }
@@ -5,6 +5,6 @@ tf_class {
   is_instance: "<type \'object\'>"
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'learning_rate\', \'beta1\', \'beta2\', \'epsilon\', \'lazy_adam\', \'sum_inside_sqrt\', \'use_gradient_accumulation\', \'clip_weight_min\', \'clip_weight_max\'], varargs=None, keywords=None, defaults=[\'0.9\', \'0.999\', \'1e-08\', \'True\', \'True\', \'True\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'learning_rate\', \'beta1\', \'beta2\', \'epsilon\', \'lazy_adam\', \'sum_inside_sqrt\', \'use_gradient_accumulation\', \'clip_weight_min\', \'clip_weight_max\', \'weight_decay_factor\', \'multiply_weight_decay_factor_by_learning_rate\'], varargs=None, keywords=None, defaults=[\'0.9\', \'0.999\', \'1e-08\', \'True\', \'True\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
   }
 }
@@ -5,6 +5,6 @@ tf_class {
   is_instance: "<type \'object\'>"
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'learning_rate\', \'learning_rate_power\', \'initial_accumulator_value\', \'l1_regularization_strength\', \'l2_regularization_strength\', \'use_gradient_accumulation\', \'clip_weight_min\', \'clip_weight_max\'], varargs=None, keywords=None, defaults=[\'-0.5\', \'0.1\', \'0.0\', \'0.0\', \'True\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'learning_rate\', \'learning_rate_power\', \'initial_accumulator_value\', \'l1_regularization_strength\', \'l2_regularization_strength\', \'use_gradient_accumulation\', \'clip_weight_min\', \'clip_weight_max\', \'weight_decay_factor\', \'multiply_weight_decay_factor_by_learning_rate\'], varargs=None, keywords=None, defaults=[\'-0.5\', \'0.1\', \'0.0\', \'0.0\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
   }
 }
@@ -5,6 +5,6 @@ tf_class {
   is_instance: "<type \'object\'>"
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'learning_rate\', \'clip_weight_min\', \'clip_weight_max\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+    argspec: "args=[\'self\', \'learning_rate\', \'clip_weight_min\', \'clip_weight_max\', \'weight_decay_factor\', \'multiply_weight_decay_factor_by_learning_rate\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
   }
 }