Expose gradient clipping for TPU embeddings.
PiperOrigin-RevId: 337941995 Change-Id: I6c8970a71520a9425654110babdb41aee3c861ba
This commit is contained in:
parent
19fda561d8
commit
ea13fb0c5a
@ -370,6 +370,8 @@ class _OptimizationParameters(object):
|
||||
clip_weight_max: Optional[float],
|
||||
weight_decay_factor: Optional[float],
|
||||
multiply_weight_decay_factor_by_learning_rate: Optional[bool],
|
||||
clip_gradient_min: Optional[float] = None,
|
||||
clip_gradient_max: Optional[float] = None,
|
||||
):
|
||||
self.learning_rate = learning_rate
|
||||
self.use_gradient_accumulation = use_gradient_accumulation
|
||||
@ -378,6 +380,8 @@ class _OptimizationParameters(object):
|
||||
self.weight_decay_factor = weight_decay_factor
|
||||
self.multiply_weight_decay_factor_by_learning_rate = (
|
||||
multiply_weight_decay_factor_by_learning_rate)
|
||||
self.clip_gradient_min = clip_gradient_min
|
||||
self.clip_gradient_max = clip_gradient_max
|
||||
|
||||
|
||||
@tf_export(v1=['tpu.experimental.AdagradParameters'])
|
||||
@ -409,6 +413,8 @@ class AdagradParameters(_OptimizationParameters):
|
||||
clip_weight_max: Optional[float] = None,
|
||||
weight_decay_factor: Optional[float] = None,
|
||||
multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
|
||||
clip_gradient_min: Optional[float] = None,
|
||||
clip_gradient_max: Optional[float] = None,
|
||||
):
|
||||
"""Optimization parameters for Adagrad.
|
||||
|
||||
@ -425,11 +431,20 @@ class AdagradParameters(_OptimizationParameters):
|
||||
weights are not decayed.
|
||||
multiply_weight_decay_factor_by_learning_rate: if true,
|
||||
`weight_decay_factor` is multiplied by the current learning rate.
|
||||
clip_gradient_min: the minimum value to clip by; None means -infinity.
|
||||
clip_gradient_max: the maximum value to clip by; None means +infinity.
|
||||
"""
|
||||
super(AdagradParameters,
|
||||
self).__init__(learning_rate, use_gradient_accumulation,
|
||||
clip_weight_min, clip_weight_max, weight_decay_factor,
|
||||
multiply_weight_decay_factor_by_learning_rate)
|
||||
super(AdagradParameters, self).__init__(
|
||||
learning_rate=learning_rate,
|
||||
use_gradient_accumulation=use_gradient_accumulation,
|
||||
clip_weight_min=clip_weight_min,
|
||||
clip_weight_max=clip_weight_max,
|
||||
weight_decay_factor=weight_decay_factor,
|
||||
multiply_weight_decay_factor_by_learning_rate=(
|
||||
multiply_weight_decay_factor_by_learning_rate),
|
||||
clip_gradient_min=clip_gradient_min,
|
||||
clip_gradient_max=clip_gradient_max,
|
||||
)
|
||||
if initial_accumulator <= 0:
|
||||
raise ValueError('Adagrad initial_accumulator must be positive')
|
||||
self.initial_accumulator = initial_accumulator
|
||||
@ -455,6 +470,8 @@ class ProximalAdagradParameters(_OptimizationParameters):
|
||||
clip_weight_max: Optional[float] = None,
|
||||
weight_decay_factor: Optional[float] = None,
|
||||
multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
|
||||
clip_gradient_min: Optional[float] = None,
|
||||
clip_gradient_max: Optional[float] = None,
|
||||
):
|
||||
"""Optimization parameters for Proximal Adagrad.
|
||||
|
||||
@ -474,11 +491,20 @@ class ProximalAdagradParameters(_OptimizationParameters):
|
||||
weights are not decayed.
|
||||
multiply_weight_decay_factor_by_learning_rate: if true,
|
||||
`weight_decay_factor` is multiplied by the current learning rate.
|
||||
clip_gradient_min: the minimum value to clip by; None means -infinity.
|
||||
clip_gradient_max: the maximum value to clip by; None means +infinity.
|
||||
"""
|
||||
super(ProximalAdagradParameters,
|
||||
self).__init__(learning_rate, use_gradient_accumulation,
|
||||
clip_weight_min, clip_weight_max, weight_decay_factor,
|
||||
multiply_weight_decay_factor_by_learning_rate)
|
||||
super(ProximalAdagradParameters, self).__init__(
|
||||
learning_rate=learning_rate,
|
||||
use_gradient_accumulation=use_gradient_accumulation,
|
||||
clip_weight_min=clip_weight_min,
|
||||
clip_weight_max=clip_weight_max,
|
||||
weight_decay_factor=weight_decay_factor,
|
||||
multiply_weight_decay_factor_by_learning_rate=(
|
||||
multiply_weight_decay_factor_by_learning_rate),
|
||||
clip_gradient_min=clip_gradient_min,
|
||||
clip_gradient_max=clip_gradient_max,
|
||||
)
|
||||
if initial_accumulator <= 0:
|
||||
raise ValueError('Adagrad initial_accumulator must be positive')
|
||||
if l1_regularization_strength < 0.:
|
||||
@ -527,6 +553,8 @@ class AdamParameters(_OptimizationParameters):
|
||||
clip_weight_max: Optional[float] = None,
|
||||
weight_decay_factor: Optional[float] = None,
|
||||
multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
|
||||
clip_gradient_min: Optional[float] = None,
|
||||
clip_gradient_max: Optional[float] = None,
|
||||
):
|
||||
"""Optimization parameters for Adam.
|
||||
|
||||
@ -551,11 +579,20 @@ class AdamParameters(_OptimizationParameters):
|
||||
weights are not decayed.
|
||||
multiply_weight_decay_factor_by_learning_rate: if true,
|
||||
`weight_decay_factor` is multiplied by the current learning rate.
|
||||
clip_gradient_min: the minimum value to clip by; None means -infinity.
|
||||
clip_gradient_max: the maximum value to clip by; None means +infinity.
|
||||
"""
|
||||
super(AdamParameters,
|
||||
self).__init__(learning_rate, use_gradient_accumulation,
|
||||
clip_weight_min, clip_weight_max, weight_decay_factor,
|
||||
multiply_weight_decay_factor_by_learning_rate)
|
||||
super(AdamParameters, self).__init__(
|
||||
learning_rate=learning_rate,
|
||||
use_gradient_accumulation=use_gradient_accumulation,
|
||||
clip_weight_min=clip_weight_min,
|
||||
clip_weight_max=clip_weight_max,
|
||||
weight_decay_factor=weight_decay_factor,
|
||||
multiply_weight_decay_factor_by_learning_rate=(
|
||||
multiply_weight_decay_factor_by_learning_rate),
|
||||
clip_gradient_min=clip_gradient_min,
|
||||
clip_gradient_max=clip_gradient_max,
|
||||
)
|
||||
if beta1 < 0. or beta1 >= 1.:
|
||||
raise ValueError('beta1 must be between 0. and 1; got {}.'.format(beta1))
|
||||
if beta2 < 0. or beta2 >= 1.:
|
||||
@ -608,6 +645,8 @@ class FtrlParameters(_OptimizationParameters):
|
||||
multiply_linear_by_learning_rate: bool = False,
|
||||
beta: float = 0,
|
||||
allow_zero_accumulator: bool = False,
|
||||
clip_gradient_min: Optional[float] = None,
|
||||
clip_gradient_max: Optional[float] = None,
|
||||
):
|
||||
"""Optimization parameters for Ftrl.
|
||||
|
||||
@ -644,11 +683,20 @@ class FtrlParameters(_OptimizationParameters):
|
||||
allow_zero_accumulator: Changes the implementation of the square root to
|
||||
allow for the case of initial_accumulator_value being zero. This will
|
||||
cause a slight performance drop.
|
||||
clip_gradient_min: the minimum value to clip by; None means -infinity.
|
||||
clip_gradient_max: the maximum value to clip by; None means +infinity.
|
||||
"""
|
||||
super(FtrlParameters,
|
||||
self).__init__(learning_rate, use_gradient_accumulation,
|
||||
clip_weight_min, clip_weight_max, weight_decay_factor,
|
||||
multiply_weight_decay_factor_by_learning_rate)
|
||||
super(FtrlParameters, self).__init__(
|
||||
learning_rate=learning_rate,
|
||||
use_gradient_accumulation=use_gradient_accumulation,
|
||||
clip_weight_min=clip_weight_min,
|
||||
clip_weight_max=clip_weight_max,
|
||||
weight_decay_factor=weight_decay_factor,
|
||||
multiply_weight_decay_factor_by_learning_rate=(
|
||||
multiply_weight_decay_factor_by_learning_rate),
|
||||
clip_gradient_min=clip_gradient_min,
|
||||
clip_gradient_max=clip_gradient_max,
|
||||
)
|
||||
if learning_rate_power > 0.:
|
||||
raise ValueError('learning_rate_power must be less than or equal to 0. '
|
||||
'got {}.'.format(learning_rate_power))
|
||||
@ -703,6 +751,8 @@ class ProximalYogiParameters(_OptimizationParameters):
|
||||
clip_weight_max: Optional[float] = None,
|
||||
weight_decay_factor: Optional[float] = None,
|
||||
multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
|
||||
clip_gradient_min: Optional[float] = None,
|
||||
clip_gradient_max: Optional[float] = None,
|
||||
):
|
||||
"""Optimization parameters for Proximal Yogi.
|
||||
|
||||
@ -728,11 +778,20 @@ class ProximalYogiParameters(_OptimizationParameters):
|
||||
weights are not decayed.
|
||||
multiply_weight_decay_factor_by_learning_rate: if true,
|
||||
`weight_decay_factor` is multiplied by the current learning rate.
|
||||
clip_gradient_min: the minimum value to clip by; None means -infinity.
|
||||
clip_gradient_max: the maximum value to clip by; None means +infinity.
|
||||
"""
|
||||
super(ProximalYogiParameters,
|
||||
self).__init__(learning_rate, use_gradient_accumulation,
|
||||
clip_weight_min, clip_weight_max, weight_decay_factor,
|
||||
multiply_weight_decay_factor_by_learning_rate)
|
||||
super(ProximalYogiParameters, self).__init__(
|
||||
learning_rate=learning_rate,
|
||||
use_gradient_accumulation=use_gradient_accumulation,
|
||||
clip_weight_min=clip_weight_min,
|
||||
clip_weight_max=clip_weight_max,
|
||||
weight_decay_factor=weight_decay_factor,
|
||||
multiply_weight_decay_factor_by_learning_rate=(
|
||||
multiply_weight_decay_factor_by_learning_rate),
|
||||
clip_gradient_min=clip_gradient_min,
|
||||
clip_gradient_max=clip_gradient_max,
|
||||
)
|
||||
if beta1 < 0. or beta1 >= 1.:
|
||||
raise ValueError('beta1 must be between 0. and 1; got {}.'.format(beta1))
|
||||
if beta2 < 0. or beta2 >= 1.:
|
||||
@ -783,6 +842,8 @@ class MomentumParameters(_OptimizationParameters):
|
||||
clip_weight_max: Optional[float] = None,
|
||||
weight_decay_factor: Optional[float] = None,
|
||||
multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
|
||||
clip_gradient_min: Optional[float] = None,
|
||||
clip_gradient_max: Optional[float] = None,
|
||||
):
|
||||
"""Optimization parameters for momentum.
|
||||
|
||||
@ -807,6 +868,8 @@ class MomentumParameters(_OptimizationParameters):
|
||||
weights are not decayed.
|
||||
multiply_weight_decay_factor_by_learning_rate: if true,
|
||||
`weight_decay_factor` is multiplied by the current learning rate.
|
||||
clip_gradient_min: the minimum value to clip by; None means -infinity.
|
||||
clip_gradient_max: the maximum value to clip by; None means +infinity.
|
||||
"""
|
||||
super(MomentumParameters, self).__init__(
|
||||
learning_rate=learning_rate,
|
||||
@ -816,6 +879,8 @@ class MomentumParameters(_OptimizationParameters):
|
||||
weight_decay_factor=weight_decay_factor,
|
||||
multiply_weight_decay_factor_by_learning_rate=(
|
||||
multiply_weight_decay_factor_by_learning_rate),
|
||||
clip_gradient_min=clip_gradient_min,
|
||||
clip_gradient_max=clip_gradient_max,
|
||||
)
|
||||
self.momentum = momentum
|
||||
self.use_nesterov = use_nesterov
|
||||
@ -851,6 +916,8 @@ class RMSPropParameters(_OptimizationParameters):
|
||||
clip_weight_max: Optional[float] = None,
|
||||
weight_decay_factor: Optional[float] = None,
|
||||
multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
|
||||
clip_gradient_min: Optional[float] = None,
|
||||
clip_gradient_max: Optional[float] = None,
|
||||
):
|
||||
"""Optimization parameters for RMS prop.
|
||||
|
||||
@ -868,6 +935,8 @@ class RMSPropParameters(_OptimizationParameters):
|
||||
weights are not decayed.
|
||||
multiply_weight_decay_factor_by_learning_rate: if true,
|
||||
`weight_decay_factor` is multiplied by the current learning rate.
|
||||
clip_gradient_min: the minimum value to clip by; None means -infinity.
|
||||
clip_gradient_max: the maximum value to clip by; None means +infinity.
|
||||
"""
|
||||
super(RMSPropParameters, self).__init__(
|
||||
learning_rate=learning_rate,
|
||||
@ -877,6 +946,8 @@ class RMSPropParameters(_OptimizationParameters):
|
||||
weight_decay_factor=weight_decay_factor,
|
||||
multiply_weight_decay_factor_by_learning_rate=(
|
||||
multiply_weight_decay_factor_by_learning_rate),
|
||||
clip_gradient_min=clip_gradient_min,
|
||||
clip_gradient_max=clip_gradient_max,
|
||||
)
|
||||
self.rho = rho
|
||||
self.momentum = momentum
|
||||
@ -910,6 +981,8 @@ class StochasticGradientDescentParameters(_OptimizationParameters):
|
||||
clip_weight_max: Optional[float] = None,
|
||||
weight_decay_factor: Optional[float] = None,
|
||||
multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
|
||||
clip_gradient_min: Optional[float] = None,
|
||||
clip_gradient_max: Optional[float] = None,
|
||||
):
|
||||
"""Optimization parameters for stochastic gradient descent.
|
||||
|
||||
@ -921,11 +994,20 @@ class StochasticGradientDescentParameters(_OptimizationParameters):
|
||||
weights are not decayed.
|
||||
multiply_weight_decay_factor_by_learning_rate: if true,
|
||||
`weight_decay_factor` is multiplied by the current learning rate.
|
||||
clip_gradient_min: the minimum value to clip by; None means -infinity.
|
||||
clip_gradient_max: the maximum value to clip by; None means +infinity.
|
||||
"""
|
||||
super(StochasticGradientDescentParameters,
|
||||
self).__init__(learning_rate, False, clip_weight_min, clip_weight_max,
|
||||
weight_decay_factor,
|
||||
multiply_weight_decay_factor_by_learning_rate)
|
||||
super(StochasticGradientDescentParameters, self).__init__(
|
||||
learning_rate=learning_rate,
|
||||
use_gradient_accumulation=False,
|
||||
clip_weight_min=clip_weight_min,
|
||||
clip_weight_max=clip_weight_max,
|
||||
weight_decay_factor=weight_decay_factor,
|
||||
multiply_weight_decay_factor_by_learning_rate=(
|
||||
multiply_weight_decay_factor_by_learning_rate),
|
||||
clip_gradient_min=clip_gradient_min,
|
||||
clip_gradient_max=clip_gradient_max,
|
||||
)
|
||||
|
||||
|
||||
DeviceConfig = collections.namedtuple('DeviceConfig',
|
||||
@ -1285,6 +1367,14 @@ class TPUEmbedding(object):
|
||||
optimization_parameters_pb2.GradientAccumulationStatus.ENABLED
|
||||
if optimization_parameters.use_gradient_accumulation else
|
||||
optimization_parameters_pb2.GradientAccumulationStatus.DISABLED)
|
||||
|
||||
if optimization_parameters.clip_gradient_min is not None:
|
||||
parameters.gradient_clipping_limits.lower.value = (
|
||||
optimization_parameters.clip_gradient_min)
|
||||
if optimization_parameters.clip_gradient_max is not None:
|
||||
parameters.gradient_clipping_limits.upper.value = (
|
||||
optimization_parameters.clip_gradient_max)
|
||||
|
||||
if optimization_parameters.clip_weight_min is not None:
|
||||
parameters.clipping_limits.lower.value = (
|
||||
optimization_parameters.clip_weight_min)
|
||||
|
@ -5,6 +5,6 @@ tf_class {
|
||||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'learning_rate\', \'initial_accumulator\', \'use_gradient_accumulation\', \'clip_weight_min\', \'clip_weight_max\', \'weight_decay_factor\', \'multiply_weight_decay_factor_by_learning_rate\'], varargs=None, keywords=None, defaults=[\'0.1\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'learning_rate\', \'initial_accumulator\', \'use_gradient_accumulation\', \'clip_weight_min\', \'clip_weight_max\', \'weight_decay_factor\', \'multiply_weight_decay_factor_by_learning_rate\', \'clip_gradient_min\', \'clip_gradient_max\'], varargs=None, keywords=None, defaults=[\'0.1\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
}
|
||||
|
@ -5,6 +5,6 @@ tf_class {
|
||||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'learning_rate\', \'beta1\', \'beta2\', \'epsilon\', \'lazy_adam\', \'sum_inside_sqrt\', \'use_gradient_accumulation\', \'clip_weight_min\', \'clip_weight_max\', \'weight_decay_factor\', \'multiply_weight_decay_factor_by_learning_rate\'], varargs=None, keywords=None, defaults=[\'0.9\', \'0.999\', \'1e-08\', \'True\', \'True\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'learning_rate\', \'beta1\', \'beta2\', \'epsilon\', \'lazy_adam\', \'sum_inside_sqrt\', \'use_gradient_accumulation\', \'clip_weight_min\', \'clip_weight_max\', \'weight_decay_factor\', \'multiply_weight_decay_factor_by_learning_rate\', \'clip_gradient_min\', \'clip_gradient_max\'], varargs=None, keywords=None, defaults=[\'0.9\', \'0.999\', \'1e-08\', \'True\', \'True\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
}
|
||||
|
@ -5,6 +5,6 @@ tf_class {
|
||||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'learning_rate\', \'learning_rate_power\', \'initial_accumulator_value\', \'l1_regularization_strength\', \'l2_regularization_strength\', \'use_gradient_accumulation\', \'clip_weight_min\', \'clip_weight_max\', \'weight_decay_factor\', \'multiply_weight_decay_factor_by_learning_rate\', \'multiply_linear_by_learning_rate\', \'beta\', \'allow_zero_accumulator\'], varargs=None, keywords=None, defaults=[\'-0.5\', \'0.1\', \'0.0\', \'0.0\', \'True\', \'None\', \'None\', \'None\', \'None\', \'False\', \'0\', \'False\'], "
|
||||
argspec: "args=[\'self\', \'learning_rate\', \'learning_rate_power\', \'initial_accumulator_value\', \'l1_regularization_strength\', \'l2_regularization_strength\', \'use_gradient_accumulation\', \'clip_weight_min\', \'clip_weight_max\', \'weight_decay_factor\', \'multiply_weight_decay_factor_by_learning_rate\', \'multiply_linear_by_learning_rate\', \'beta\', \'allow_zero_accumulator\', \'clip_gradient_min\', \'clip_gradient_max\'], varargs=None, keywords=None, defaults=[\'-0.5\', \'0.1\', \'0.0\', \'0.0\', \'True\', \'None\', \'None\', \'None\', \'None\', \'False\', \'0\', \'False\', \'None\', \'None\'], "
|
||||
}
|
||||
}
|
||||
|
@ -5,6 +5,6 @@ tf_class {
|
||||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'learning_rate\', \'clip_weight_min\', \'clip_weight_max\', \'weight_decay_factor\', \'multiply_weight_decay_factor_by_learning_rate\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'learning_rate\', \'clip_weight_min\', \'clip_weight_max\', \'weight_decay_factor\', \'multiply_weight_decay_factor_by_learning_rate\', \'clip_gradient_min\', \'clip_gradient_max\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user