Expose gradient clipping for TPU embeddings.

PiperOrigin-RevId: 337941995
Change-Id: I6c8970a71520a9425654110babdb41aee3c861ba
This commit is contained in:
A. Unique TensorFlower 2020-10-19 15:03:30 -07:00 committed by TensorFlower Gardener
parent 19fda561d8
commit ea13fb0c5a
5 changed files with 118 additions and 28 deletions

View File

@ -370,6 +370,8 @@ class _OptimizationParameters(object):
clip_weight_max: Optional[float],
weight_decay_factor: Optional[float],
multiply_weight_decay_factor_by_learning_rate: Optional[bool],
clip_gradient_min: Optional[float] = None,
clip_gradient_max: Optional[float] = None,
):
self.learning_rate = learning_rate
self.use_gradient_accumulation = use_gradient_accumulation
@ -378,6 +380,8 @@ class _OptimizationParameters(object):
self.weight_decay_factor = weight_decay_factor
self.multiply_weight_decay_factor_by_learning_rate = (
multiply_weight_decay_factor_by_learning_rate)
self.clip_gradient_min = clip_gradient_min
self.clip_gradient_max = clip_gradient_max
@tf_export(v1=['tpu.experimental.AdagradParameters'])
@ -409,6 +413,8 @@ class AdagradParameters(_OptimizationParameters):
clip_weight_max: Optional[float] = None,
weight_decay_factor: Optional[float] = None,
multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
clip_gradient_min: Optional[float] = None,
clip_gradient_max: Optional[float] = None,
):
"""Optimization parameters for Adagrad.
@ -425,11 +431,20 @@ class AdagradParameters(_OptimizationParameters):
weights are not decayed.
multiply_weight_decay_factor_by_learning_rate: if true,
`weight_decay_factor` is multiplied by the current learning rate.
clip_gradient_min: the minimum value to clip by; None means -infinity.
clip_gradient_max: the maximum value to clip by; None means +infinity.
"""
super(AdagradParameters,
self).__init__(learning_rate, use_gradient_accumulation,
clip_weight_min, clip_weight_max, weight_decay_factor,
multiply_weight_decay_factor_by_learning_rate)
super(AdagradParameters, self).__init__(
learning_rate=learning_rate,
use_gradient_accumulation=use_gradient_accumulation,
clip_weight_min=clip_weight_min,
clip_weight_max=clip_weight_max,
weight_decay_factor=weight_decay_factor,
multiply_weight_decay_factor_by_learning_rate=(
multiply_weight_decay_factor_by_learning_rate),
clip_gradient_min=clip_gradient_min,
clip_gradient_max=clip_gradient_max,
)
if initial_accumulator <= 0:
raise ValueError('Adagrad initial_accumulator must be positive')
self.initial_accumulator = initial_accumulator
@ -455,6 +470,8 @@ class ProximalAdagradParameters(_OptimizationParameters):
clip_weight_max: Optional[float] = None,
weight_decay_factor: Optional[float] = None,
multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
clip_gradient_min: Optional[float] = None,
clip_gradient_max: Optional[float] = None,
):
"""Optimization parameters for Adagrad.
@ -474,11 +491,20 @@ class ProximalAdagradParameters(_OptimizationParameters):
weights are not decayed.
multiply_weight_decay_factor_by_learning_rate: if true,
`weight_decay_factor` is multiplied by the current learning rate.
clip_gradient_min: the minimum value to clip by; None means -infinity.
clip_gradient_max: the maximum value to clip by; None means +infinity.
"""
super(ProximalAdagradParameters,
self).__init__(learning_rate, use_gradient_accumulation,
clip_weight_min, clip_weight_max, weight_decay_factor,
multiply_weight_decay_factor_by_learning_rate)
super(ProximalAdagradParameters, self).__init__(
learning_rate=learning_rate,
use_gradient_accumulation=use_gradient_accumulation,
clip_weight_min=clip_weight_min,
clip_weight_max=clip_weight_max,
weight_decay_factor=weight_decay_factor,
multiply_weight_decay_factor_by_learning_rate=(
multiply_weight_decay_factor_by_learning_rate),
clip_gradient_min=clip_gradient_min,
clip_gradient_max=clip_gradient_max,
)
if initial_accumulator <= 0:
raise ValueError('Adagrad initial_accumulator must be positive')
if l1_regularization_strength < 0.:
@ -527,6 +553,8 @@ class AdamParameters(_OptimizationParameters):
clip_weight_max: Optional[float] = None,
weight_decay_factor: Optional[float] = None,
multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
clip_gradient_min: Optional[float] = None,
clip_gradient_max: Optional[float] = None,
):
"""Optimization parameters for Adam.
@ -551,11 +579,20 @@ class AdamParameters(_OptimizationParameters):
weights are not decayed.
multiply_weight_decay_factor_by_learning_rate: if true,
`weight_decay_factor` is multiplied by the current learning rate.
clip_gradient_min: the minimum value to clip by; None means -infinity.
clip_gradient_max: the maximum value to clip by; None means +infinity.
"""
super(AdamParameters,
self).__init__(learning_rate, use_gradient_accumulation,
clip_weight_min, clip_weight_max, weight_decay_factor,
multiply_weight_decay_factor_by_learning_rate)
super(AdamParameters, self).__init__(
learning_rate=learning_rate,
use_gradient_accumulation=use_gradient_accumulation,
clip_weight_min=clip_weight_min,
clip_weight_max=clip_weight_max,
weight_decay_factor=weight_decay_factor,
multiply_weight_decay_factor_by_learning_rate=(
multiply_weight_decay_factor_by_learning_rate),
clip_gradient_min=clip_gradient_min,
clip_gradient_max=clip_gradient_max,
)
if beta1 < 0. or beta1 >= 1.:
raise ValueError('beta1 must be between 0. and 1; got {}.'.format(beta1))
if beta2 < 0. or beta2 >= 1.:
@ -608,6 +645,8 @@ class FtrlParameters(_OptimizationParameters):
multiply_linear_by_learning_rate: bool = False,
beta: float = 0,
allow_zero_accumulator: bool = False,
clip_gradient_min: Optional[float] = None,
clip_gradient_max: Optional[float] = None,
):
"""Optimization parameters for Ftrl.
@ -644,11 +683,20 @@ class FtrlParameters(_OptimizationParameters):
allow_zero_accumulator: Changes the implementation of the square root to
allow for the case of initial_accumulator_value being zero. This will
cause a slight performance drop.
clip_gradient_min: the minimum value to clip by; None means -infinity.
clip_gradient_max: the maximum value to clip by; None means +infinity.
"""
super(FtrlParameters,
self).__init__(learning_rate, use_gradient_accumulation,
clip_weight_min, clip_weight_max, weight_decay_factor,
multiply_weight_decay_factor_by_learning_rate)
super(FtrlParameters, self).__init__(
learning_rate=learning_rate,
use_gradient_accumulation=use_gradient_accumulation,
clip_weight_min=clip_weight_min,
clip_weight_max=clip_weight_max,
weight_decay_factor=weight_decay_factor,
multiply_weight_decay_factor_by_learning_rate=(
multiply_weight_decay_factor_by_learning_rate),
clip_gradient_min=clip_gradient_min,
clip_gradient_max=clip_gradient_max,
)
if learning_rate_power > 0.:
raise ValueError('learning_rate_power must be less than or equal to 0. '
'got {}.'.format(learning_rate_power))
@ -703,6 +751,8 @@ class ProximalYogiParameters(_OptimizationParameters):
clip_weight_max: Optional[float] = None,
weight_decay_factor: Optional[float] = None,
multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
clip_gradient_min: Optional[float] = None,
clip_gradient_max: Optional[float] = None,
):
"""Optimization parameters for Proximal Yogi.
@ -728,11 +778,20 @@ class ProximalYogiParameters(_OptimizationParameters):
weights are not decayed.
multiply_weight_decay_factor_by_learning_rate: if true,
`weight_decay_factor` is multiplied by the current learning rate.
clip_gradient_min: the minimum value to clip by; None means -infinity.
clip_gradient_max: the maximum value to clip by; None means +infinity.
"""
super(ProximalYogiParameters,
self).__init__(learning_rate, use_gradient_accumulation,
clip_weight_min, clip_weight_max, weight_decay_factor,
multiply_weight_decay_factor_by_learning_rate)
super(ProximalYogiParameters, self).__init__(
learning_rate=learning_rate,
use_gradient_accumulation=use_gradient_accumulation,
clip_weight_min=clip_weight_min,
clip_weight_max=clip_weight_max,
weight_decay_factor=weight_decay_factor,
multiply_weight_decay_factor_by_learning_rate=(
multiply_weight_decay_factor_by_learning_rate),
clip_gradient_min=clip_gradient_min,
clip_gradient_max=clip_gradient_max,
)
if beta1 < 0. or beta1 >= 1.:
raise ValueError('beta1 must be between 0. and 1; got {}.'.format(beta1))
if beta2 < 0. or beta2 >= 1.:
@ -783,6 +842,8 @@ class MomentumParameters(_OptimizationParameters):
clip_weight_max: Optional[float] = None,
weight_decay_factor: Optional[float] = None,
multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
clip_gradient_min: Optional[float] = None,
clip_gradient_max: Optional[float] = None,
):
"""Optimization parameters for momentum.
@ -807,6 +868,8 @@ class MomentumParameters(_OptimizationParameters):
weights are not decayed.
multiply_weight_decay_factor_by_learning_rate: if true,
`weight_decay_factor` is multiplied by the current learning rate.
clip_gradient_min: the minimum value to clip by; None means -infinity.
clip_gradient_max: the maximum value to clip by; None means +infinity.
"""
super(MomentumParameters, self).__init__(
learning_rate=learning_rate,
@ -816,6 +879,8 @@ class MomentumParameters(_OptimizationParameters):
weight_decay_factor=weight_decay_factor,
multiply_weight_decay_factor_by_learning_rate=(
multiply_weight_decay_factor_by_learning_rate),
clip_gradient_min=clip_gradient_min,
clip_gradient_max=clip_gradient_max,
)
self.momentum = momentum
self.use_nesterov = use_nesterov
@ -851,6 +916,8 @@ class RMSPropParameters(_OptimizationParameters):
clip_weight_max: Optional[float] = None,
weight_decay_factor: Optional[float] = None,
multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
clip_gradient_min: Optional[float] = None,
clip_gradient_max: Optional[float] = None,
):
"""Optimization parameters for RMS prop.
@ -868,6 +935,8 @@ class RMSPropParameters(_OptimizationParameters):
weights are not decayed.
multiply_weight_decay_factor_by_learning_rate: if true,
`weight_decay_factor` is multiplied by the current learning rate.
clip_gradient_min: the minimum value to clip by; None means -infinity.
clip_gradient_max: the maximum value to clip by; None means +infinity.
"""
super(RMSPropParameters, self).__init__(
learning_rate=learning_rate,
@ -877,6 +946,8 @@ class RMSPropParameters(_OptimizationParameters):
weight_decay_factor=weight_decay_factor,
multiply_weight_decay_factor_by_learning_rate=(
multiply_weight_decay_factor_by_learning_rate),
clip_gradient_min=clip_gradient_min,
clip_gradient_max=clip_gradient_max,
)
self.rho = rho
self.momentum = momentum
@ -910,6 +981,8 @@ class StochasticGradientDescentParameters(_OptimizationParameters):
clip_weight_max: Optional[float] = None,
weight_decay_factor: Optional[float] = None,
multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
clip_gradient_min: Optional[float] = None,
clip_gradient_max: Optional[float] = None,
):
"""Optimization parameters for stochastic gradient descent.
@ -921,11 +994,20 @@ class StochasticGradientDescentParameters(_OptimizationParameters):
weights are not decayed.
multiply_weight_decay_factor_by_learning_rate: if true,
`weight_decay_factor` is multiplied by the current learning rate.
clip_gradient_min: the minimum value to clip by; None means -infinity.
clip_gradient_max: the maximum value to clip by; None means +infinity.
"""
super(StochasticGradientDescentParameters,
self).__init__(learning_rate, False, clip_weight_min, clip_weight_max,
weight_decay_factor,
multiply_weight_decay_factor_by_learning_rate)
super(StochasticGradientDescentParameters, self).__init__(
learning_rate=learning_rate,
use_gradient_accumulation=False,
clip_weight_min=clip_weight_min,
clip_weight_max=clip_weight_max,
weight_decay_factor=weight_decay_factor,
multiply_weight_decay_factor_by_learning_rate=(
multiply_weight_decay_factor_by_learning_rate),
clip_gradient_min=clip_gradient_min,
clip_gradient_max=clip_gradient_max,
)
DeviceConfig = collections.namedtuple('DeviceConfig',
@ -1285,6 +1367,14 @@ class TPUEmbedding(object):
optimization_parameters_pb2.GradientAccumulationStatus.ENABLED
if optimization_parameters.use_gradient_accumulation else
optimization_parameters_pb2.GradientAccumulationStatus.DISABLED)
if optimization_parameters.clip_gradient_min is not None:
parameters.gradient_clipping_limits.lower.value = (
optimization_parameters.clip_gradient_min)
if optimization_parameters.clip_gradient_max is not None:
parameters.gradient_clipping_limits.upper.value = (
optimization_parameters.clip_gradient_max)
if optimization_parameters.clip_weight_min is not None:
parameters.clipping_limits.lower.value = (
optimization_parameters.clip_weight_min)

View File

@ -5,6 +5,6 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'learning_rate\', \'initial_accumulator\', \'use_gradient_accumulation\', \'clip_weight_min\', \'clip_weight_max\', \'weight_decay_factor\', \'multiply_weight_decay_factor_by_learning_rate\'], varargs=None, keywords=None, defaults=[\'0.1\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
argspec: "args=[\'self\', \'learning_rate\', \'initial_accumulator\', \'use_gradient_accumulation\', \'clip_weight_min\', \'clip_weight_max\', \'weight_decay_factor\', \'multiply_weight_decay_factor_by_learning_rate\', \'clip_gradient_min\', \'clip_gradient_max\'], varargs=None, keywords=None, defaults=[\'0.1\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
}
}

View File

@ -5,6 +5,6 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'learning_rate\', \'beta1\', \'beta2\', \'epsilon\', \'lazy_adam\', \'sum_inside_sqrt\', \'use_gradient_accumulation\', \'clip_weight_min\', \'clip_weight_max\', \'weight_decay_factor\', \'multiply_weight_decay_factor_by_learning_rate\'], varargs=None, keywords=None, defaults=[\'0.9\', \'0.999\', \'1e-08\', \'True\', \'True\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
argspec: "args=[\'self\', \'learning_rate\', \'beta1\', \'beta2\', \'epsilon\', \'lazy_adam\', \'sum_inside_sqrt\', \'use_gradient_accumulation\', \'clip_weight_min\', \'clip_weight_max\', \'weight_decay_factor\', \'multiply_weight_decay_factor_by_learning_rate\', \'clip_gradient_min\', \'clip_gradient_max\'], varargs=None, keywords=None, defaults=[\'0.9\', \'0.999\', \'1e-08\', \'True\', \'True\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
}
}

View File

@ -5,6 +5,6 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'learning_rate\', \'learning_rate_power\', \'initial_accumulator_value\', \'l1_regularization_strength\', \'l2_regularization_strength\', \'use_gradient_accumulation\', \'clip_weight_min\', \'clip_weight_max\', \'weight_decay_factor\', \'multiply_weight_decay_factor_by_learning_rate\', \'multiply_linear_by_learning_rate\', \'beta\', \'allow_zero_accumulator\'], varargs=None, keywords=None, defaults=[\'-0.5\', \'0.1\', \'0.0\', \'0.0\', \'True\', \'None\', \'None\', \'None\', \'None\', \'False\', \'0\', \'False\'], "
argspec: "args=[\'self\', \'learning_rate\', \'learning_rate_power\', \'initial_accumulator_value\', \'l1_regularization_strength\', \'l2_regularization_strength\', \'use_gradient_accumulation\', \'clip_weight_min\', \'clip_weight_max\', \'weight_decay_factor\', \'multiply_weight_decay_factor_by_learning_rate\', \'multiply_linear_by_learning_rate\', \'beta\', \'allow_zero_accumulator\', \'clip_gradient_min\', \'clip_gradient_max\'], varargs=None, keywords=None, defaults=[\'-0.5\', \'0.1\', \'0.0\', \'0.0\', \'True\', \'None\', \'None\', \'None\', \'None\', \'False\', \'0\', \'False\', \'None\', \'None\'], "
}
}

View File

@ -5,6 +5,6 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'learning_rate\', \'clip_weight_min\', \'clip_weight_max\', \'weight_decay_factor\', \'multiply_weight_decay_factor_by_learning_rate\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
argspec: "args=[\'self\', \'learning_rate\', \'clip_weight_min\', \'clip_weight_max\', \'weight_decay_factor\', \'multiply_weight_decay_factor_by_learning_rate\', \'clip_gradient_min\', \'clip_gradient_max\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
}
}