Change the TPU Embedding API to allow passing learning-rate functions in the constructor rather than learning-rate keys. This maps more closely to the feature column API.

PiperOrigin-RevId: 281926084
Change-Id: I6f653d048aa0c00940b70a4616dbc63376ab25c2
Author:  Bruce Fontaine (committed by TensorFlower Gardener)
Date:    2019-11-22 01:50:54 -08:00
Parent:  a30111665f
Commit:  5fb6dcaea8
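A minimal sketch of the usage this commit enables. The import path matches the module being patched; the table parameters and the decay schedule are illustrative assumptions, not part of this change:

    import tensorflow.compat.v1 as tf
    from tensorflow.python.tpu import tpu_embedding

    # Hypothetical dynamic learning-rate schedule: it receives the current
    # global step and returns a scalar learning rate.
    def decayed_lr(step):
      return tf.train.exponential_decay(
          learning_rate=0.1, global_step=step,
          decay_steps=10000, decay_rate=0.9)

    # A table now takes the function itself instead of a learning_rate_key.
    table = tpu_embedding.TableConfig(
        vocabulary_size=100000,
        dimension=64,
        learning_rate_fn=decayed_lr)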


@@ -32,6 +32,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import partitioned_variables
 from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variable_scope
@@ -49,7 +50,7 @@ INFERENCE = elc.TPUEmbeddingConfiguration.INFERENCE
 class TableConfig(
     collections.namedtuple('TableConfig', [
         'vocabulary_size', 'dimension', 'initializer', 'combiner',
-        'hot_id_replication', 'learning_rate', 'learning_rate_key'
+        'hot_id_replication', 'learning_rate', 'learning_rate_fn'
     ])):
   """Embedding table configuration."""
@@ -60,7 +61,7 @@ class TableConfig(
               combiner='mean',
               hot_id_replication=False,
               learning_rate=None,
-              learning_rate_key=None):
+              learning_rate_fn=None):
    """Embedding table configuration.

    Args:
@@ -79,17 +80,16 @@
      hot_id_replication: If true, enables hot id replication, which can make
        embedding lookups faster if there are some hot rows in the table.
      learning_rate: float, static learning rate for this table. If
-        learning_rate and learning_rate_key are both `None`, global
+        learning_rate and learning_rate_fn are both `None`, global
        static learning rate as specified in `optimization_parameters` in
-        `TPUEmbedding` constructor will be used. `learning_rate_key` must be
+        `TPUEmbedding` constructor will be used. `learning_rate_fn` must be
        `None` if `learning_rate` is not `None`.
-      learning_rate_key: string, use dynamic learning rate of
-        `learning_rates[learning_rate_key]` for this table, where
-        `learning_rates` is the second argument of
-        `generate_send_gradients_op()`. If learning_rate and learning_rate_key
-        are both `None`, global static learning rate as specified in
-        `optimization_parameters` in `TPUEmbedding` constructor will be used.
-        `learning_rate` must be `None` if `learning_rate_key` is not `None`.
+      learning_rate_fn: a function, used to compute the dynamic learning rate
+        for this table. The function will be passed the current global step.
+        If learning_rate and learning_rate_fn are both `None`, global static
+        learning rate as specified in `optimization_parameters` in
+        `TPUEmbedding` constructor will be used. `learning_rate` must be
+        `None` if `learning_rate_fn` is not `None`.

    Returns:
      `TableConfig`.
@@ -99,7 +99,7 @@ class TableConfig(
      ValueError: if `dimension` is not positive integer.
      ValueError: if `initializer` is specified and is not callable.
      ValueError: if `combiner` is not supported.
-      ValueError: if `learning_rate` and `learning_rate_key` are both not
+      ValueError: if `learning_rate` and `learning_rate_fn` are both not
        `None`.
    """
    if not isinstance(vocabulary_size, int) or vocabulary_size < 1:
@@ -117,14 +117,14 @@
    if combiner not in ('mean', 'sum', 'sqrtn', None):
      raise ValueError('Invalid combiner {}'.format(combiner))
-    if learning_rate is not None and learning_rate_key is not None:
-      raise ValueError('At most one of learning_rate and learning_rate_key '
-                       'can be None; got {} and {}'
-                       .format(learning_rate, learning_rate_key))
+    if learning_rate is not None and learning_rate_fn is not None:
+      raise ValueError('At most one of learning_rate and learning_rate_fn '
+                       'can be set; got {} and {}'
+                       .format(learning_rate, learning_rate_fn))
    return super(TableConfig, cls).__new__(
        cls, vocabulary_size, dimension, initializer, combiner,
-        hot_id_replication, learning_rate, learning_rate_key)
+        hot_id_replication, learning_rate, learning_rate_fn)

 class FeatureConfig(
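As a usage note, the check above makes the static and dynamic modes mutually exclusive; a short hypothetical illustration:

    # Valid: dynamic learning rate only.
    tpu_embedding.TableConfig(vocabulary_size=1000, dimension=16,
                              learning_rate_fn=lambda step: 0.05)

    # Invalid: both static and dynamic rates set; raises ValueError.
    tpu_embedding.TableConfig(vocabulary_size=1000, dimension=16,
                              learning_rate=0.05,
                              learning_rate_fn=lambda step: 0.05)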
@@ -694,6 +694,11 @@ class TPUEmbedding(object):
        self._optimization_parameters)
    self._pipeline_execution_with_tensor_core = (
        pipeline_execution_with_tensor_core)
+    self._learning_rate_fn = list(set(
+        c.learning_rate_fn for c in self._table_to_config_dict.values()
+        if c.learning_rate_fn is not None))
+    self._learning_rate_fn_to_tag = {
+        fn: i for i, fn in enumerate(self._learning_rate_fn)}

    self._config_proto = self._create_config_proto()
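The tag assignment above just deduplicates the functions and numbers them; a standalone sketch of the same idea (the helper name is hypothetical). Note that the set deduplicates by object identity, so tables that should share one dynamic learning rate must be given the same function object:

    def build_learning_rate_tags(table_configs):
      # Collect each distinct learning_rate_fn once; tables passing the same
      # function object end up sharing a single dynamic-learning-rate tag.
      fns = list(set(c.learning_rate_fn for c in table_configs
                     if c.learning_rate_fn is not None))
      return fns, {fn: i for i, fn in enumerate(fns)}

    fns, fn_to_tag = build_learning_rate_tags([table])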
@@ -767,10 +772,6 @@
   def _create_config_proto(self):
     """Create `TPUEmbeddingConfiguration`."""
-    self._learning_rate_keys = list(
-        set(c.learning_rate_key
-            for c in self._table_to_config_dict.values()
-            if c.learning_rate_key is not None))
     config_proto = elc.TPUEmbeddingConfiguration()
     for table in self._table_to_config_dict:
       table_descriptor = config_proto.table_descriptor.add()
@@ -788,9 +789,9 @@
      parameters = table_descriptor.optimization_parameters
      if table_config.learning_rate:
        parameters.learning_rate.constant = (table_config.learning_rate)
-      elif table_config.learning_rate_key:
+      elif table_config.learning_rate_fn:
        parameters.learning_rate.dynamic.tag = (
-            self._learning_rate_keys.index(table_config.learning_rate_key))
+            self._learning_rate_fn_to_tag[table_config.learning_rate_fn])
      else:
        parameters.learning_rate.constant = (
            self._optimization_parameters.learning_rate)
@@ -1097,14 +1098,13 @@ class TPUEmbedding(object):
  def generate_send_gradients_op(self,
                                 feature_to_gradient_dict,
-                                 learning_rates=None):
+                                 step=None):
    """Send gradient to TPU embedding.

    Args:
      feature_to_gradient_dict: dict mapping feature names to gradient wrt
        activations.
-      learning_rates: dict mapping from learning rate key to dynamic learning
-        rate. Defaults to `None`.
+      step: the current global step, used to compute dynamic learning rates.

    Returns:
      SendTPUEmbeddingGradients Op.
@ -1116,9 +1116,8 @@ class TPUEmbedding(object):
raise RuntimeError('Only in training mode gradients need to ' raise RuntimeError('Only in training mode gradients need to '
'be sent to TPU embedding; got mode {}.' 'be sent to TPU embedding; got mode {}.'
.format(self._mode)) .format(self._mode))
if step is None and self._learning_rate_fn:
if learning_rates is None: raise ValueError('There are dynamic learning rates but step is None.')
learning_rates = dict()
gradients = [] gradients = []
for table in self._table_to_features_dict: for table in self._table_to_features_dict:
@@ -1137,9 +1136,8 @@
    return tpu_ops.send_tpu_embedding_gradients(
        inputs=gradients,
-        learning_rates=[
-            learning_rates[tag] for tag in self._learning_rate_keys
-        ],
+        learning_rates=[math_ops.cast(fn(step), dtype=dtypes.float32)
+                        for fn in self._learning_rate_fn],
        config=self.config_proto.SerializeToString())
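After this change, the training loop passes the global step instead of a dict of dynamic learning rates; each registered learning_rate_fn is invoked with that step and the result is cast to float32 before being sent. A sketch of the call site, assuming the embedding object and gradient dict already exist:

    global_step = tf.train.get_or_create_global_step()
    send_gradients = embedding.generate_send_gradients_op(
        feature_to_gradient_dict, step=global_step)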