Support clipping limits for TPU embedding in tpu_estimator.py.

PiperOrigin-RevId: 243928920
Authored by A. Unique TensorFlower on 2019-04-16 20:36:49 -07:00; committed by TensorFlower Gardener
parent 18daf67542
commit 7d2fd55ef6


@@ -216,16 +216,23 @@ VariablesAndOps = collections.namedtuple(
 class _OptimizationParameters(object):
   """Parameters common to all optimizations."""
 
-  def __init__(self, learning_rate, use_gradient_accumulation):
+  def __init__(self, learning_rate, use_gradient_accumulation,
+               clip_weight_min, clip_weight_max):
     self.learning_rate = learning_rate
     self.use_gradient_accumulation = use_gradient_accumulation
+    self.clip_weight_min = clip_weight_min
+    self.clip_weight_max = clip_weight_max
 
 
 class AdagradParameters(_OptimizationParameters):
   """Optimization parameters for Adagrad."""
 
-  def __init__(self, learning_rate, initial_accumulator=0.1,
-               use_gradient_accumulation=True):
+  def __init__(self,
+               learning_rate,
+               initial_accumulator=0.1,
+               use_gradient_accumulation=True,
+               clip_weight_min=None,
+               clip_weight_max=None):
     """Optimization parameters for Adagrad.
 
     Args:
@@ -235,9 +242,12 @@ class AdagradParameters(_OptimizationParameters):
         gradients calculation less accurate but faster. Please see
         `optimization_parameters.proto` for details.
         for details.
+      clip_weight_min: the minimum value to clip by; None means -infinity.
+      clip_weight_max: the maximum value to clip by; None means +infinity.
     """
-    super(AdagradParameters, self).__init__(learning_rate,
-                                            use_gradient_accumulation)
+    super(AdagradParameters,
+          self).__init__(learning_rate, use_gradient_accumulation,
+                         clip_weight_min, clip_weight_max)
     if initial_accumulator <= 0:
       raise ValueError('Adagrad initial_accumulator must be positive')
     self.initial_accumulator = initial_accumulator
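
With this change the clip bounds ride along with the rest of the Adagrad hyperparameters. A minimal usage sketch; the import path and the numeric values are assumptions for illustration, not part of this commit:

# Sketch only: the import path below is an assumption; AdagradParameters may
# be exposed under a different module in your TensorFlow build.
from tensorflow.python.tpu import tpu_embedding

adagrad = tpu_embedding.AdagradParameters(
    learning_rate=0.1,
    initial_accumulator=0.1,
    use_gradient_accumulation=True,
    clip_weight_min=-5.0,  # updated embedding weights are clipped to >= -5.0
    clip_weight_max=5.0)   # ... and to <= 5.0; None would mean no bound
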
@@ -246,13 +256,16 @@ class AdagradParameters(_OptimizationParameters):
 class AdamParameters(_OptimizationParameters):
   """Optimization parameters for Adam."""
 
-  def __init__(self, learning_rate,
+  def __init__(self,
+               learning_rate,
                beta1=0.9,
                beta2=0.999,
                epsilon=1e-08,
                lazy_adam=True,
                sum_inside_sqrt=True,
-               use_gradient_accumulation=True):
+               use_gradient_accumulation=True,
+               clip_weight_min=None,
+               clip_weight_max=None):
     """Optimization parameters for Adam.
 
     Args:
@@ -270,9 +283,12 @@ class AdamParameters(_OptimizationParameters):
         gradients calculation less accurate but faster. Please see
         `optimization_parameters.proto` for details.
         for details.
+      clip_weight_min: the minimum value to clip by; None means -infinity.
+      clip_weight_max: the maximum value to clip by; None means +infinity.
     """
-    super(AdamParameters, self).__init__(learning_rate,
-                                         use_gradient_accumulation)
+    super(AdamParameters,
+          self).__init__(learning_rate, use_gradient_accumulation,
+                         clip_weight_min, clip_weight_max)
     if beta1 < 0. or beta1 >= 1.:
       raise ValueError('beta1 must be between 0. and 1; got {}.'.format(beta1))
     if beta2 < 0. or beta2 >= 1.:
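
AdamParameters picks up the same pair of keyword arguments; only the two clip arguments are new in this commit. Another hedged sketch under the same assumed import path:

from tensorflow.python.tpu import tpu_embedding  # assumed import path

adam = tpu_embedding.AdamParameters(
    learning_rate=1e-3,
    beta1=0.9,
    beta2=0.999,
    epsilon=1e-08,
    lazy_adam=True,
    sum_inside_sqrt=True,
    use_gradient_accumulation=True,
    clip_weight_min=None,  # None keeps the lower bound at -infinity
    clip_weight_max=1.0)   # clip updated embedding weights to <= 1.0
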
@@ -291,15 +307,19 @@ class AdamParameters(_OptimizationParameters):
 class StochasticGradientDescentParameters(_OptimizationParameters):
   """Optimization parameters for stochastic gradient descent."""
 
-  def __init__(self, learning_rate):
+  def __init__(self, learning_rate, clip_weight_min=None,
+               clip_weight_max=None):
     """Optimization parameters for stochastic gradient descent.
 
     Args:
       learning_rate: a floating point value. The learning rate.
+      clip_weight_min: the minimum value to clip by; None means -infinity.
+      clip_weight_max: the maximum value to clip by; None means +infinity.
     """
-    super(StochasticGradientDescentParameters, self).__init__(
-        learning_rate, False)
+    super(StochasticGradientDescentParameters,
+          self).__init__(learning_rate, False, clip_weight_min, clip_weight_max)
 
 
 class TPUEmbedding(object):
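
StochasticGradientDescentParameters still hard-codes use_gradient_accumulation to False but now forwards both clip bounds to the base class. A sketch with illustrative values, again assuming the same import path:

from tensorflow.python.tpu import tpu_embedding  # assumed import path

sgd = tpu_embedding.StochasticGradientDescentParameters(
    learning_rate=0.05,
    clip_weight_min=0.0,    # e.g. keep embedding weights non-negative
    clip_weight_max=None)   # no upper bound
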
@@ -566,6 +586,12 @@ class TPUEmbedding(object):
           optimization_parameters_pb2.GradientAccumulationStatus.ENABLED
           if self._optimization_parameters.use_gradient_accumulation else
           optimization_parameters_pb2.GradientAccumulationStatus.DISABLED)
+      if self._optimization_parameters.clip_weight_min is not None:
+        table_descriptor.optimization_parameters.clipping_limits.lower.value = (
+            self._optimization_parameters.clip_weight_min)
+      if self._optimization_parameters.clip_weight_max is not None:
+        table_descriptor.optimization_parameters.clipping_limits.upper.value = (
+            self._optimization_parameters.clip_weight_max)
       self._optimizer_handler.set_optimization_parameters(table_descriptor)
 
     config_proto.mode = self._mode
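
The hunk above is where the new attributes are consumed: for each table, a bound that is not None is written into the table's clipping_limits sub-message, and the `.value` access suggests lower/upper are wrapper fields, so an unset bound stays absent and is treated as -infinity / +infinity. The same logic restated as a standalone helper purely for illustration; the function and argument names are hypothetical, not part of the TensorFlow API:

def _maybe_set_clipping_limits(table_descriptor, optimization_parameters):
  """Copies optional clip bounds into a table's optimization parameters proto.

  Hypothetical helper mirroring the diff above, shown only to isolate the
  clipping logic from the surrounding configuration code.
  """
  limits = table_descriptor.optimization_parameters.clipping_limits
  if optimization_parameters.clip_weight_min is not None:
    # Lower bound; leaving it unset means no clipping from below.
    limits.lower.value = optimization_parameters.clip_weight_min
  if optimization_parameters.clip_weight_max is not None:
    # Upper bound; leaving it unset means no clipping from above.
    limits.upper.value = optimization_parameters.clip_weight_max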