Support dynamic learning rate in mid-level API.

PiperOrigin-RevId: 260635422
Author: A. Unique TensorFlower, 2019-07-29 20:13:52 -07:00; committed by TensorFlower Gardener
parent bdbd1d27dc
commit e5bfe5636c

@@ -43,11 +43,13 @@ TRAINING = elc.TPUEmbeddingConfiguration.TRAINING
 INFERENCE = elc.TPUEmbeddingConfiguration.INFERENCE


+# TODO(shizhiw): a more future-proof way is to have optimization_parameter such
+# as AdagradParameters etc instead of learning_rate.
 class TableConfig(
-    collections.namedtuple(
-        'TableConfig',
-        ['vocabulary_size', 'dimension', 'initializer', 'combiner',
-         'hot_id_replication'])):
+    collections.namedtuple('TableConfig', [
+        'vocabulary_size', 'dimension', 'initializer', 'combiner',
+        'hot_id_replication', 'learning_rate', 'learning_rate_key'
+    ])):
   """Embedding table configuration."""

   def __new__(cls,
@@ -55,7 +57,9 @@ class TableConfig(
               dimension,
               initializer=None,
               combiner='mean',
-              hot_id_replication=False):
+              hot_id_replication=False,
+              learning_rate=None,
+              learning_rate_key=None):
     """Embedding table configuration.

     Args:
@@ -73,6 +77,18 @@ class TableConfig(
         than sparse tensors.
       hot_id_replication: If true, enables hot id replication, which can make
         embedding lookups faster if there are some hot rows in the table.
+      learning_rate: float, static learning rate for this table. If
+        learning_rate and learning_rate_key are both `None`, the global
+        static learning rate as specified in `optimization_parameters` in
+        the `TPUEmbedding` constructor will be used. `learning_rate_key` must
+        be `None` if `learning_rate` is not `None`.
+      learning_rate_key: string, use the dynamic learning rate of
+        `learning_rates[learning_rate_key]` for this table, where
+        `learning_rates` is the second argument of
+        `generate_send_gradients_op()`. If learning_rate and learning_rate_key
+        are both `None`, the global static learning rate as specified in
+        `optimization_parameters` in the `TPUEmbedding` constructor will be
+        used. `learning_rate` must be `None` if `learning_rate_key` is not
+        `None`.

     Returns:
       `TableConfig`.
@@ -82,6 +98,8 @@ class TableConfig(
       ValueError: if `dimension` is not a positive integer.
       ValueError: if `initializer` is specified and is not callable.
       ValueError: if `combiner` is not supported.
+      ValueError: if `learning_rate` and `learning_rate_key` are both not
+        `None`.
     """
     if not isinstance(vocabulary_size, int) or vocabulary_size < 1:
       raise ValueError('Invalid vocabulary_size {}.'.format(vocabulary_size))
@@ -98,9 +116,14 @@ class TableConfig(
     if combiner not in ('mean', 'sum', 'sqrtn', None):
       raise ValueError('Invalid combiner {}'.format(combiner))

-    return super(TableConfig, cls).__new__(cls, vocabulary_size, dimension,
-                                           initializer, combiner,
-                                           hot_id_replication)
+    if learning_rate is not None and learning_rate_key is not None:
+      raise ValueError('At most one of learning_rate and learning_rate_key '
+                       'can be specified; got {} and {}'
+                       .format(learning_rate, learning_rate_key))
+
+    return super(TableConfig, cls).__new__(
+        cls, vocabulary_size, dimension, initializer, combiner,
+        hot_id_replication, learning_rate, learning_rate_key)


 class FeatureConfig(
@@ -661,6 +684,10 @@ class TPUEmbedding(object):

   def _create_config_proto(self):
     """Create `TPUEmbeddingConfiguration`."""
+    self._learning_rate_keys = list(
+        set(c.learning_rate_key
+            for c in self._table_to_config_dict.values()
+            if c.learning_rate_key is not None))
     config_proto = elc.TPUEmbeddingConfiguration()
     for table in self._table_to_config_dict:
       table_descriptor = config_proto.table_descriptor.add()
@@ -676,8 +703,14 @@ class TPUEmbedding(object):
       table_descriptor.num_features = self._table_to_num_features_dict[table]

       parameters = table_descriptor.optimization_parameters
-      parameters.learning_rate.constant = (
-          self._optimization_parameters.learning_rate)
+      if table_config.learning_rate:
+        parameters.learning_rate.constant = (table_config.learning_rate)
+      elif table_config.learning_rate_key:
+        parameters.learning_rate.dynamic.tag = (
+            self._learning_rate_keys.index(table_config.learning_rate_key))
+      else:
+        parameters.learning_rate.constant = (
+            self._optimization_parameters.learning_rate)
       parameters.gradient_accumulation_status = (
           optimization_parameters_pb2.GradientAccumulationStatus.ENABLED
           if self._optimization_parameters.use_gradient_accumulation else
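A minimal standalone sketch of the key-to-tag mapping set up above; the table names and the stand-in config type are assumptions (the real code iterates over TableConfig objects inside TPUEmbedding).

import collections

FakeConfig = collections.namedtuple('FakeConfig', ['learning_rate_key'])
table_to_config = {
    'user': FakeConfig(learning_rate_key='decayed_lr'),
    'item': FakeConfig(learning_rate_key='decayed_lr'),
    'genre': FakeConfig(learning_rate_key=None),
}
# Deduplicate the keys; each distinct key becomes one dynamic-learning-rate tag.
learning_rate_keys = list(
    set(c.learning_rate_key for c in table_to_config.values()
        if c.learning_rate_key is not None))
# 'user' and 'item' share the same tag; 'genre' keeps a constant learning rate.
print(learning_rate_keys.index('decayed_lr'))  # -> 0

Note that the same `self._learning_rate_keys` list is reused when ordering the `learning_rates` values in `generate_send_gradients_op()`, so the tag assigned here and the value sent later stay consistent within one configuration.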
@@ -969,12 +1002,16 @@ class TPUEmbedding(object):

     return activations

-  def generate_send_gradients_op(self, feature_to_gradient_dict):
+  def generate_send_gradients_op(self,
+                                 feature_to_gradient_dict,
+                                 learning_rates=None):
     """Send gradient to TPU embedding.

     Args:
       feature_to_gradient_dict: dict mapping feature names to gradient wrt
         activations.
+      learning_rates: dict mapping from learning rate key to dynamic learning
+        rate. Defaults to `None`.

     Returns:
       SendTPUEmbeddingGradients Op.
@ -986,6 +1023,10 @@ class TPUEmbedding(object):
raise RuntimeError('Only in training mode gradients need to ' raise RuntimeError('Only in training mode gradients need to '
'be sent to TPU embedding; got mode {}.' 'be sent to TPU embedding; got mode {}.'
.format(self._mode)) .format(self._mode))
if learning_rates is None:
learning_rates = dict()
gradients = [] gradients = []
for table in self._table_to_features_dict: for table in self._table_to_features_dict:
features = self._table_to_features_dict[table] features = self._table_to_features_dict[table]
@@ -1000,8 +1041,13 @@ class TPUEmbedding(object):
           array_ops.concat(table_gradients, axis=1),
           [-1, array_ops.shape(table_gradients[0])[-1]])
       gradients.append(interleaved_table_grads)

     return tpu_ops.send_tpu_embedding_gradients(
-        inputs=gradients, config=self.config_proto.SerializeToString())
+        inputs=gradients,
+        learning_rates=[
+            learning_rates[tag] for tag in self._learning_rate_keys
+        ],
+        config=self.config_proto.SerializeToString())


 def _validate_table_to_config_dict(table_to_config_dict):
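Finally, a hedged sketch of the training-side call with a dynamic rate. The `embedding` object and `feature_to_gradient_dict` are assumed to come from the surrounding TPU training loop; they are not part of this diff.

import tensorflow.compat.v1 as tf

global_step = tf.train.get_or_create_global_step()
# A scalar tensor that decays over training, keyed by the same string used as
# learning_rate_key in the 'user' TableConfig sketch above.
decayed_lr = tf.train.exponential_decay(
    learning_rate=0.1, global_step=global_step,
    decay_steps=10000, decay_rate=0.9)

send_gradients_op = embedding.generate_send_gradients_op(
    feature_to_gradient_dict=feature_to_gradient_dict,
    learning_rates={'decayed_lr': decayed_lr})

Every `learning_rate_key` declared in the table configs needs an entry in `learning_rates`; a missing key would raise a `KeyError` in the list comprehension shown in the last hunk.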