Support dynamic learning rate in mid-level API.
PiperOrigin-RevId: 260635422
This commit is contained in:
parent
bdbd1d27dc
commit
e5bfe5636c
@ -43,11 +43,13 @@ TRAINING = elc.TPUEmbeddingConfiguration.TRAINING
|
|||||||
INFERENCE = elc.TPUEmbeddingConfiguration.INFERENCE
|
INFERENCE = elc.TPUEmbeddingConfiguration.INFERENCE
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(shizhiw): a more future-proof way is to have optimization_parameter such
|
||||||
|
# as AdagradParameters etc instead of learning_rate.
|
||||||
class TableConfig(
|
class TableConfig(
|
||||||
collections.namedtuple(
|
collections.namedtuple('TableConfig', [
|
||||||
'TableConfig',
|
'vocabulary_size', 'dimension', 'initializer', 'combiner',
|
||||||
['vocabulary_size', 'dimension', 'initializer', 'combiner',
|
'hot_id_replication', 'learning_rate', 'learning_rate_key'
|
||||||
'hot_id_replication'])):
|
])):
|
||||||
"""Embedding table configuration."""
|
"""Embedding table configuration."""
|
||||||
|
|
||||||
def __new__(cls,
|
def __new__(cls,
|
||||||
@ -55,7 +57,9 @@ class TableConfig(
|
|||||||
dimension,
|
dimension,
|
||||||
initializer=None,
|
initializer=None,
|
||||||
combiner='mean',
|
combiner='mean',
|
||||||
hot_id_replication=False):
|
hot_id_replication=False,
|
||||||
|
learning_rate=None,
|
||||||
|
learning_rate_key=None):
|
||||||
"""Embedding table configuration.
|
"""Embedding table configuration.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -73,6 +77,18 @@ class TableConfig(
|
|||||||
than sparse tensors.
|
than sparse tensors.
|
||||||
hot_id_replication: If true, enables hot id replication, which can make
|
hot_id_replication: If true, enables hot id replication, which can make
|
||||||
embedding lookups faster if there are some hot rows in the table.
|
embedding lookups faster if there are some hot rows in the table.
|
||||||
|
learning_rate: float, static learning rate for this table. If
|
||||||
|
learning_rate and learning_rate_key are both `None`, global
|
||||||
|
static learning rate as specified in `optimization_parameters` in
|
||||||
|
`TPUEmbedding` constructor will be used. `learning_rate_key` must be
|
||||||
|
`None` if `learning_rate` is not `None`.
|
||||||
|
learning_rate_key: string, use dynamic learning rate of
|
||||||
|
`learning_rates[learning_rate_key]` for this table, where
|
||||||
|
`learning_rates` is the second argument of
|
||||||
|
`generate_send_gradients_op()`. If learning_rate and learning_rate_key
|
||||||
|
are both `None`, global static learning rate as specified in
|
||||||
|
`optimization_parameters` in `TPUEmbedding` constructor will be used.
|
||||||
|
`learning_rate` must be `None` if `learning_rate_key` is not `None`.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
`TableConfig`.
|
`TableConfig`.
|
||||||
@ -82,6 +98,8 @@ class TableConfig(
|
|||||||
ValueError: if `dimension` is not positive integer.
|
ValueError: if `dimension` is not positive integer.
|
||||||
ValueError: if `initializer` is specified and is not callable.
|
ValueError: if `initializer` is specified and is not callable.
|
||||||
ValueError: if `combiner` is not supported.
|
ValueError: if `combiner` is not supported.
|
||||||
|
ValueError: if `learning_rate` and `learning_rate_key` are both not
|
||||||
|
`None`.
|
||||||
"""
|
"""
|
||||||
if not isinstance(vocabulary_size, int) or vocabulary_size < 1:
|
if not isinstance(vocabulary_size, int) or vocabulary_size < 1:
|
||||||
raise ValueError('Invalid vocabulary_size {}.'.format(vocabulary_size))
|
raise ValueError('Invalid vocabulary_size {}.'.format(vocabulary_size))
|
||||||
@ -98,9 +116,14 @@ class TableConfig(
|
|||||||
if combiner not in ('mean', 'sum', 'sqrtn', None):
|
if combiner not in ('mean', 'sum', 'sqrtn', None):
|
||||||
raise ValueError('Invalid combiner {}'.format(combiner))
|
raise ValueError('Invalid combiner {}'.format(combiner))
|
||||||
|
|
||||||
return super(TableConfig, cls).__new__(cls, vocabulary_size, dimension,
|
if learning_rate is not None and learning_rate_key is not None:
|
||||||
initializer, combiner,
|
raise ValueError('At most one of learning_rate and learning_rate_key '
|
||||||
hot_id_replication)
|
'can be None; got {} and {}'
|
||||||
|
.format(learning_rate, learning_rate_key))
|
||||||
|
|
||||||
|
return super(TableConfig, cls).__new__(
|
||||||
|
cls, vocabulary_size, dimension, initializer, combiner,
|
||||||
|
hot_id_replication, learning_rate, learning_rate_key)
|
||||||
|
|
||||||
|
|
||||||
class FeatureConfig(
|
class FeatureConfig(
|
||||||
@ -661,6 +684,10 @@ class TPUEmbedding(object):
|
|||||||
|
|
||||||
def _create_config_proto(self):
|
def _create_config_proto(self):
|
||||||
"""Create `TPUEmbeddingConfiguration`."""
|
"""Create `TPUEmbeddingConfiguration`."""
|
||||||
|
self._learning_rate_keys = list(
|
||||||
|
set(c.learning_rate_key
|
||||||
|
for c in self._table_to_config_dict.values()
|
||||||
|
if c.learning_rate_key is not None))
|
||||||
config_proto = elc.TPUEmbeddingConfiguration()
|
config_proto = elc.TPUEmbeddingConfiguration()
|
||||||
for table in self._table_to_config_dict:
|
for table in self._table_to_config_dict:
|
||||||
table_descriptor = config_proto.table_descriptor.add()
|
table_descriptor = config_proto.table_descriptor.add()
|
||||||
@ -676,6 +703,12 @@ class TPUEmbedding(object):
|
|||||||
table_descriptor.num_features = self._table_to_num_features_dict[table]
|
table_descriptor.num_features = self._table_to_num_features_dict[table]
|
||||||
|
|
||||||
parameters = table_descriptor.optimization_parameters
|
parameters = table_descriptor.optimization_parameters
|
||||||
|
if table_config.learning_rate:
|
||||||
|
parameters.learning_rate.constant = (table_config.learning_rate)
|
||||||
|
elif table_config.learning_rate_key:
|
||||||
|
parameters.learning_rate.dynamic.tag = (
|
||||||
|
self._learning_rate_keys.index(table_config.learning_rate_key))
|
||||||
|
else:
|
||||||
parameters.learning_rate.constant = (
|
parameters.learning_rate.constant = (
|
||||||
self._optimization_parameters.learning_rate)
|
self._optimization_parameters.learning_rate)
|
||||||
parameters.gradient_accumulation_status = (
|
parameters.gradient_accumulation_status = (
|
||||||
@ -969,12 +1002,16 @@ class TPUEmbedding(object):
|
|||||||
|
|
||||||
return activations
|
return activations
|
||||||
|
|
||||||
def generate_send_gradients_op(self, feature_to_gradient_dict):
|
def generate_send_gradients_op(self,
|
||||||
|
feature_to_gradient_dict,
|
||||||
|
learning_rates=None):
|
||||||
"""Send gradient to TPU embedding.
|
"""Send gradient to TPU embedding.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
feature_to_gradient_dict: dict mapping feature names to gradient wrt
|
feature_to_gradient_dict: dict mapping feature names to gradient wrt
|
||||||
activations.
|
activations.
|
||||||
|
learning_rates: dict mapping from learning rate key to dynamic learning
|
||||||
|
rate. Defaults to `None`.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
SendTPUEmbeddingGradients Op.
|
SendTPUEmbeddingGradients Op.
|
||||||
@ -986,6 +1023,10 @@ class TPUEmbedding(object):
|
|||||||
raise RuntimeError('Only in training mode gradients need to '
|
raise RuntimeError('Only in training mode gradients need to '
|
||||||
'be sent to TPU embedding; got mode {}.'
|
'be sent to TPU embedding; got mode {}.'
|
||||||
.format(self._mode))
|
.format(self._mode))
|
||||||
|
|
||||||
|
if learning_rates is None:
|
||||||
|
learning_rates = dict()
|
||||||
|
|
||||||
gradients = []
|
gradients = []
|
||||||
for table in self._table_to_features_dict:
|
for table in self._table_to_features_dict:
|
||||||
features = self._table_to_features_dict[table]
|
features = self._table_to_features_dict[table]
|
||||||
@ -1000,8 +1041,13 @@ class TPUEmbedding(object):
|
|||||||
array_ops.concat(table_gradients, axis=1),
|
array_ops.concat(table_gradients, axis=1),
|
||||||
[-1, array_ops.shape(table_gradients[0])[-1]])
|
[-1, array_ops.shape(table_gradients[0])[-1]])
|
||||||
gradients.append(interleaved_table_grads)
|
gradients.append(interleaved_table_grads)
|
||||||
|
|
||||||
return tpu_ops.send_tpu_embedding_gradients(
|
return tpu_ops.send_tpu_embedding_gradients(
|
||||||
inputs=gradients, config=self.config_proto.SerializeToString())
|
inputs=gradients,
|
||||||
|
learning_rates=[
|
||||||
|
learning_rates[tag] for tag in self._learning_rate_keys
|
||||||
|
],
|
||||||
|
config=self.config_proto.SerializeToString())
|
||||||
|
|
||||||
|
|
||||||
def _validate_table_to_config_dict(table_to_config_dict):
|
def _validate_table_to_config_dict(table_to_config_dict):
|
||||||
|
Loading…
Reference in New Issue
Block a user