Change TPU Embedding API to allow passing functions in the initializer rather than keys. This more closely maps to the feature column API.
PiperOrigin-RevId: 281926084
Change-Id: I6f653d048aa0c00940b70a4616dbc63376ab25c2
parent a30111665f
commit 5fb6dcaea8
@@ -32,6 +32,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import partitioned_variables
 from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variable_scope
@@ -49,7 +50,7 @@ INFERENCE = elc.TPUEmbeddingConfiguration.INFERENCE
 class TableConfig(
     collections.namedtuple('TableConfig', [
         'vocabulary_size', 'dimension', 'initializer', 'combiner',
-        'hot_id_replication', 'learning_rate', 'learning_rate_key'
+        'hot_id_replication', 'learning_rate', 'learning_rate_fn'
     ])):
   """Embedding table configuration."""

@@ -60,7 +61,7 @@ class TableConfig(
               combiner='mean',
               hot_id_replication=False,
               learning_rate=None,
-              learning_rate_key=None):
+              learning_rate_fn=None):
    """Embedding table configuration.

    Args:
@@ -79,17 +80,16 @@ class TableConfig(
       hot_id_replication: If true, enables hot id replication, which can make
         embedding lookups faster if there are some hot rows in the table.
       learning_rate: float, static learning rate for this table. If
-        learning_rate and learning_rate_key are both `None`, global
+        learning_rate and learning_rate_fn are both `None`, global
         static learning rate as specified in `optimization_parameters` in
-        `TPUEmbedding` constructor will be used. `learning_rate_key` must be
+        `TPUEmbedding` constructor will be used. `learning_rate_fn` must be
         `None` if `learning_rate` is not `None`.
-      learning_rate_key: string, use dynamic learning rate of
-        `learning_rates[learning_rate_key]` for this table, where
-        `learning_rates` is the second argument of
-        `generate_send_gradients_op()`. If learning_rate and learning_rate_key
-        are both `None`, global static learning rate as specified in
-        `optimization_parameters` in `TPUEmbedding` constructor will be used.
-        `learning_rate` must be `None` if `learning_rate_key` is not `None.
+      learning_rate_fn: function, takes the current global step and returns
+        the dynamic learning rate for this table. If learning_rate and
+        learning_rate_fn are both `None`, global static learning rate as
+        specified in `optimization_parameters` in `TPUEmbedding` constructor
+        will be used. `learning_rate` must be `None` if `learning_rate_fn`
+        is not `None`.

    Returns:
      `TableConfig`.
@@ -99,7 +99,7 @@ class TableConfig(
       ValueError: if `dimension` is not positive integer.
       ValueError: if `initializer` is specified and is not callable.
       ValueError: if `combiner` is not supported.
-      ValueError: if `learning_rate` and `learning_rate_key` are both not
+      ValueError: if `learning_rate` and `learning_rate_fn` are both not
        `None`.
    """
    if not isinstance(vocabulary_size, int) or vocabulary_size < 1:
@@ -117,14 +117,14 @@ class TableConfig(
    if combiner not in ('mean', 'sum', 'sqrtn', None):
      raise ValueError('Invalid combiner {}'.format(combiner))

-    if learning_rate is not None and learning_rate_key is not None:
-      raise ValueError('At most one of learning_rate and learning_rate_key '
+    if learning_rate is not None and learning_rate_fn is not None:
+      raise ValueError('At most one of learning_rate and learning_rate_fn '
                       'can be None; got {} and {}'
-                      .format(learning_rate, learning_rate_key))
+                      .format(learning_rate, learning_rate_fn))

    return super(TableConfig, cls).__new__(
        cls, vocabulary_size, dimension, initializer, combiner,
-        hot_id_replication, learning_rate, learning_rate_key)
+        hot_id_replication, learning_rate, learning_rate_fn)


 class FeatureConfig(
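
For illustration, a minimal sketch of the new `TableConfig` usage (the import
path is assumed from this file's location; the decay function is an invented
example, not part of this diff):

    from tensorflow.python.tpu import tpu_embedding

    def decay_fn(step):
      # Any callable taking the global step and returning a scalar learning
      # rate works; a constant is returned here for brevity.
      del step  # unused in this toy schedule
      return 0.01

    table = tpu_embedding.TableConfig(
        vocabulary_size=100000,
        dimension=64,
        learning_rate_fn=decay_fn)  # replaces the old learning_rate_key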
@@ -694,6 +694,11 @@ class TPUEmbedding(object):
          self._optimization_parameters)
    self._pipeline_execution_with_tensor_core = (
        pipeline_execution_with_tensor_core)
+    self._learning_rate_fn = list(set(
+        c.learning_rate_fn for c in self._table_to_config_dict.values()
+        if c.learning_rate_fn is not None))
+    self._learning_rate_fn_to_tag = {
+        fn: id for id, fn in enumerate(self._learning_rate_fn)}

    self._config_proto = self._create_config_proto()

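
The constructor change above deduplicates learning-rate functions: tables that
share the same function are assigned a single dynamic-learning-rate tag. A
plain-Python sketch of that mapping (names are illustrative):

    def decay_fn(step):
      return 0.01

    table_fns = [decay_fn, None, decay_fn]  # one entry per table config
    unique_fns = list(set(fn for fn in table_fns if fn is not None))
    fn_to_tag = {fn: tag for tag, fn in enumerate(unique_fns)}
    assert fn_to_tag[decay_fn] == 0  # both tables resolve to the same tag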
@@ -767,10 +772,6 @@ class TPUEmbedding(object):

  def _create_config_proto(self):
    """Create `TPUEmbeddingConfiguration`."""
-    self._learning_rate_keys = list(
-        set(c.learning_rate_key
-            for c in self._table_to_config_dict.values()
-            if c.learning_rate_key is not None))
    config_proto = elc.TPUEmbeddingConfiguration()
    for table in self._table_to_config_dict:
      table_descriptor = config_proto.table_descriptor.add()
@@ -788,9 +789,9 @@ class TPUEmbedding(object):
      parameters = table_descriptor.optimization_parameters
      if table_config.learning_rate:
        parameters.learning_rate.constant = (table_config.learning_rate)
-      elif table_config.learning_rate_key:
+      elif table_config.learning_rate_fn:
        parameters.learning_rate.dynamic.tag = (
-            self._learning_rate_keys.index(table_config.learning_rate_key))
+            self._learning_rate_fn_to_tag[table_config.learning_rate_fn])
      else:
        parameters.learning_rate.constant = (
            self._optimization_parameters.learning_rate)
@@ -1097,14 +1098,13 @@ class TPUEmbedding(object):

  def generate_send_gradients_op(self,
                                 feature_to_gradient_dict,
-                                learning_rates=None):
+                                step=None):
    """Send gradient to TPU embedding.

    Args:
      feature_to_gradient_dict: dict mapping feature names to gradient wrt
        activations.
-      learning_rates: dict mapping from learning rate key to dynamic learning
-        rate. Defaults to `None`.
+      step: the current global step, used for dynamic learning rate.

    Returns:
      SendTPUEmbeddingGradients Op.
@@ -1116,9 +1116,8 @@ class TPUEmbedding(object):
      raise RuntimeError('Only in training mode gradients need to '
                         'be sent to TPU embedding; got mode {}.'
                         .format(self._mode))

-    if learning_rates is None:
-      learning_rates = dict()
+    if step is None and self._learning_rate_fn:
+      raise ValueError('There are dynamic learning rates but step is None.')

    gradients = []
    for table in self._table_to_features_dict:
@@ -1137,9 +1136,8 @@ class TPUEmbedding(object):

    return tpu_ops.send_tpu_embedding_gradients(
        inputs=gradients,
-        learning_rates=[
-            learning_rates[tag] for tag in self._learning_rate_keys
-        ],
+        learning_rates=[math_ops.cast(fn(step), dtype=dtypes.float32)
+                        for fn in self._learning_rate_fn],
        config=self.config_proto.SerializeToString())
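
A hedged sketch of the new call pattern, assuming an already-constructed
`TPUEmbedding` instance `embedding` and a `feature_to_gradient_dict` built
elsewhere (neither appears in this diff):

    import tensorflow.compat.v1 as tf

    def decay_fn(step):
      # Exponential decay driven by the global step.
      return 0.1 * tf.pow(0.99, tf.cast(step, tf.float32))

    step = tf.train.get_or_create_global_step()
    # `step` is now passed directly; each table's learning_rate_fn is
    # evaluated at this step instead of being looked up by key.
    send_op = embedding.generate_send_gradients_op(
        feature_to_gradient_dict, step=step)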