Add support for sequence features to the TPUEmbedding classes.

PiperOrigin-RevId: 243712953
Bruce Fontaine 2019-04-15 16:43:22 -07:00 committed by TensorFlower Gardener
parent b67614f20e
commit 85709f9d72
3 changed files with 121 additions and 58 deletions


@ -115,7 +115,7 @@ def get_configs_from_feature_columns(feature_columns):
Returns:
A tuple of dicts, the first maps tables to their config, the second maps
features to tables, and the third maps features to weight key names.
features to their config, and the third maps features to weight key names.
"""
allowed = (tpu_fc._TPUEmbeddingColumn, tpu_fc._TPUSharedEmbeddingColumn) # pylint: disable=protected-access
@ -127,17 +127,18 @@ def get_configs_from_feature_columns(feature_columns):
type(column), allowed))
table_to_config = {}
feature_to_table = {}
feature_to_config = {}
feature_to_weight_key_name = {}
for column in feature_columns:
feature_name = column.get_feature_key_name()
table_name = _get_table_name_from_embedding_var_name(
column.get_embedding_var_name())
if feature_name in feature_to_table:
if feature_name in feature_to_config:
raise ValueError(
'Feature column {} is used with multiple embeddings and this is '
'not supported.'.format(feature_name))
feature_to_table[feature_name] = table_name
feature_to_config[feature_name] = tpu_embedding.FeatureConfig(
table_id=table_name)
feature_to_weight_key_name[feature_name] = column.get_weight_key_name()
vocabulary_size, dimension = column.get_embedding_table_size()
table_to_config[table_name] = tpu_embedding.TableConfig(
@ -146,7 +147,7 @@ def get_configs_from_feature_columns(feature_columns):
initializer=column.get_initializer(),
combiner=column.get_combiner())
return table_to_config, feature_to_table, feature_to_weight_key_name
return table_to_config, feature_to_config, feature_to_weight_key_name
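For orientation, here is a minimal sketch of the three dictionaries the rewritten helper now returns, assuming two hypothetical feature columns that share one embedding table; all names and the `None` weight keys below are invented for illustration.

```
# Illustrative only; not taken from the change itself.
table_to_config = {
    'video': tpu_embedding.TableConfig(
        vocabulary_size=1000, dimension=16,
        initializer=None, combiner='mean'),
}
feature_to_config = {
    'watched': tpu_embedding.FeatureConfig(table_id='video'),
    'favorited': tpu_embedding.FeatureConfig(table_id='video'),
}
# Weight key names come from the feature columns; shown as None here.
feature_to_weight_key_name = {'watched': None, 'favorited': None}
```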
class EmbeddingConfigSpec(
@ -238,7 +239,7 @@ class EmbeddingConfig(object):
self._num_cores = num_cores
self._run_config = run_config
(self._table_to_config_dict, self._feature_to_table_dict,
(self._table_to_config_dict, self._feature_to_config_dict,
self.feature_to_weight_key_name_dict) = (
get_configs_from_feature_columns(
embedding_config_spec.feature_columns))
@ -286,7 +287,7 @@ class EmbeddingConfig(object):
cluster_def = None
tpu_embedding_ = tpu_embedding.TPUEmbedding(
self._table_to_config_dict,
self._feature_to_table_dict,
self._feature_to_config_dict,
batch_size,
tpu_embedding_mode,
master,
@ -310,7 +311,7 @@ def split_inputs(ctx, features, labels):
tpu_embedding_ = ctx.embedding_config.tpu_embedding
feature_to_weight_key_name_dict = (
ctx.embedding_config.feature_to_weight_key_name_dict)
for feature_key in tpu_embedding_.feature_to_table_dict:
for feature_key in tpu_embedding_.feature_to_config_dict:
sparse_feature = _get_sparse_feature_from_feature(feature_key, features)
weight_key_name = feature_to_weight_key_name_dict[feature_key]
if isinstance(sparse_feature, sparse_tensor.SparseTensor):


@ -96,6 +96,37 @@ class TableConfig(
initializer, combiner)
class FeatureConfig(
collections.namedtuple(
'FeatureConfig',
['table_id', 'max_sequence_length'])):
"""Feature configuration."""
def __new__(cls,
table_id,
max_sequence_length=0):
"""Feature configuration.
Args:
table_id: Which table the feature uses for embedding lookups.
max_sequence_length: If positive, the feature is a sequence feature with
the corresponding maximum sequence length. If the sequence is longer
than this, it will be truncated. If 0, the feature is not a sequence
feature.
Returns:
`FeatureConfig`.
Raises:
ValueError: if `max_sequence_length` is not an integer or is negative.
"""
if not isinstance(max_sequence_length, int) or max_sequence_length < 0:
raise ValueError('Invalid max_sequence_length {}.'.format(
max_sequence_length))
return super(FeatureConfig, cls).__new__(cls, table_id, max_sequence_length)
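To illustrate the new class (a sketch with a made-up table name): a regular feature keeps the default `max_sequence_length=0`, a sequence feature passes its maximum length, and anything negative or non-integer is rejected.

```
# Plain feature: embeddings of its ids are combined into one vector per example.
watched = tpu_embedding.FeatureConfig(table_id='video')

# Sequence feature: up to 5 embeddings are returned per example; longer
# sequences are truncated to 5.
watched_history = tpu_embedding.FeatureConfig(table_id='video',
                                              max_sequence_length=5)

# tpu_embedding.FeatureConfig(table_id='video', max_sequence_length=-1)
# would raise ValueError.
```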
class EnqueueData(
collections.namedtuple(
'EnqueueData',
@ -281,15 +312,15 @@ class TPUEmbedding(object):
initializer=initializer, combiner='mean')
table_to_config_dict = {'video': table_config_video,
'user': table_config_user}
feature_to_table_dict = {'watched': 'video',
'favorited': 'video',
'friends': 'user'}
feature_to_config_dict = {'watched': tpu_embedding.FeatureConfig('video'),
'favorited': tpu_embedding.FeatureConfig('video'),
'friends': tpu_embedding.FeatureConfig('user')}
batch_size = 4
num_hosts = 1
optimization_parameters = tpu_embedding.AdagradParameters(1., 1.)
mode = tpu_embedding.TRAINING
embedding = tpu_embedding.TPUEmbedding(
table_to_config_dict, feature_to_table_dict,
table_to_config_dict, feature_to_config_dict,
batch_size, num_hosts, mode, optimization_parameters)
batch_size_per_core = embedding.batch_size_per_core
@ -336,17 +367,16 @@ class TPUEmbedding(object):
```
"""
# TODO(shizhiw): Instead of `feature_to_table_dict` which maps to table
# name, consider `feature_to_config_dict` which maps to `FeatureConfig`.
# `FeatureConfig` could have fields other than table name. For example, it
# could have a field to indicate that the feature should not be used to
# update embedding table (cr/204852758, cr/204940540). Also, this can support
# different combiners for different features within the same table.
# TODO(shizhiw): Consider adding a field to FeatureConfig that indicates that
# the feature should not be used to update the embedding table (cr/204852758,
# cr/204940540). Also, this can support different combiners for different
# features within the same table.
# TODO(shizhiw, b/118512626): Remove `batch_size` from `__init__` and move it
# to `FeatureConfig`?
# TODO(shizhiw): will it be cleaner to make `table_to_config_dict` and
# `feature_to_table_dict` lists of `TableSpec` and `FeatureSpec` respectively?
# `feature_to_config_dict` lists of `TableSpec` and `FeatureSpec`
# respectively?
# TODO(shizhiw): Consider adding `input_fn` as an option to remove boilerplate
# for-loops around construction of inputs.
@ -356,7 +386,7 @@ class TPUEmbedding(object):
# global setting.
def __init__(self,
table_to_config_dict,
feature_to_table_dict,
feature_to_config_dict,
batch_size,
mode,
master,
@ -369,9 +399,9 @@ class TPUEmbedding(object):
table_to_config_dict: A dictionary mapping from string of table name to
`TableConfig`. Table refers to an embedding table, e.g. `params`
argument to `tf.nn.embedding_lookup_sparse()`.
feature_to_table_dict: A dictionary mapping from string of feature name
to string of table name. Feature refers to ids to lookup in embedding
table, e.g. `sp_ids` argument to `tf.nn.embedding_lookup_sparse()`.
feature_to_config_dict: A dictionary mapping from string of feature name
to `FeatureConfig`. Feature refers to ids to lookup in embedding table,
e.g. `sp_ids` argument to `tf.nn.embedding_lookup_sparse()`.
batch_size: An `int` representing the global batch size.
mode: `TRAINING` or `INFERENCE`.
master: A `string` representing the TensorFlow master to use.
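A minimal construction sketch under the new argument, extending the docstring example above with one hypothetical sequence feature (`table_to_config_dict`, `batch_size`, `mode`, and `master` are assumed to be defined as in that example; optimizer and cluster arguments are omitted):

```
feature_to_config_dict = {
    'watched': tpu_embedding.FeatureConfig('video'),
    # Hypothetical sequence feature: keeps up to 10 embeddings per example.
    'watch_history': tpu_embedding.FeatureConfig('video',
                                                 max_sequence_length=10),
    'friends': tpu_embedding.FeatureConfig('user'),
}
embedding = tpu_embedding.TPUEmbedding(
    table_to_config_dict, feature_to_config_dict,
    batch_size, mode, master)
```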
@ -391,10 +421,12 @@ class TPUEmbedding(object):
# Avoid nondeterminism from `Dict` iteration order by using `OrderedDict`.
self._table_to_config_dict = _create_ordered_dict(table_to_config_dict)
_validate_feature_to_table_dict(table_to_config_dict, feature_to_table_dict)
self._feature_to_table_dict = _create_ordered_dict(feature_to_table_dict)
self._table_to_features_dict = _create_table_to_features_dict(
self._feature_to_table_dict)
_validate_feature_to_config_dict(table_to_config_dict,
feature_to_config_dict)
self._feature_to_config_dict = _create_ordered_dict(feature_to_config_dict)
self._table_to_features_dict, self._table_to_num_features_dict = (
_create_table_to_features_and_num_features_dicts(
self._feature_to_config_dict))
self._combiners = _create_combiners(self._table_to_config_dict,
self._table_to_features_dict)
@ -504,8 +536,8 @@ class TPUEmbedding(object):
return copy.copy(self._table_to_config_dict)
@property
def feature_to_table_dict(self):
return copy.copy(self._feature_to_table_dict)
def feature_to_config_dict(self):
return copy.copy(self._feature_to_config_dict)
@property
def table_to_features_dict(self):
@ -526,8 +558,7 @@ class TPUEmbedding(object):
table_descriptor.vocabulary_size = table_config.vocabulary_size
table_descriptor.dimension = table_config.dimension
features_for_table = self._table_to_features_dict[table]
table_descriptor.num_features = len(features_for_table)
table_descriptor.num_features = self._table_to_num_features_dict[table]
table_descriptor.optimization_parameters.learning_rate.constant = (
self._optimization_parameters.learning_rate)
@ -652,7 +683,7 @@ class TPUEmbedding(object):
def _validate_generate_enqueue_ops_enqueue_datas_list(self,
enqueue_datas_list):
"""Validate `enqueue_datas_list`."""
feature_set = set(self._feature_to_table_dict.keys())
feature_set = set(self._feature_to_config_dict.keys())
contiguous_device = None
for i, enqueue_datas in enumerate(enqueue_datas_list):
used_feature_set = set(enqueue_datas.keys())
@ -674,7 +705,7 @@ class TPUEmbedding(object):
device_feature = None
for feature, enqueue_data in six.iteritems(enqueue_datas):
combiner = self._table_to_config_dict[
self._feature_to_table_dict[feature]].combiner
self._feature_to_config_dict[feature].table_id].combiner
if not isinstance(enqueue_data, EnqueueData):
raise ValueError('`enqueue_datas_list[{}]` has a feature that is '
'not mapped to `EnqueueData`. `feature`: {}'.format(
@ -726,7 +757,7 @@ class TPUEmbedding(object):
enqueue_data0 = list(enqueue_datas.values())[0]
with ops.colocate_with(enqueue_data0.embedding_indices):
(sample_indices_list, embedding_indices_list, aggregation_weights_list,
table_ids) = (
table_ids, max_sequence_lengths) = (
self._format_for_tpu_embedding_sparse_tensor_batch(enqueue_datas))
return tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch(
sample_indices_list,
@ -734,7 +765,8 @@ class TPUEmbedding(object):
aggregation_weights_list,
table_ids,
device_ordinal=device_ordinal,
combiners=self._combiners)
combiners=self._combiners,
max_sequence_lengths=max_sequence_lengths)
def _format_for_tpu_embedding_sparse_tensor_batch(self, enqueue_datas):
"""Format sparse features for `enqueue_tpu_embedding_sparse_tensor_batch()`.
@ -748,7 +780,7 @@ class TPUEmbedding(object):
"""
(sample_indices_list, embedding_indices_list, aggregation_weights_list,
table_ids) = [], [], [], []
table_ids, max_sequence_lengths) = [], [], [], [], []
for table_id, table in enumerate(self._table_to_features_dict):
features = self._table_to_features_dict[table]
for feature in features:
@ -769,9 +801,11 @@ class TPUEmbedding(object):
embedding_indices_list.append(enqueue_data.embedding_indices)
table_ids.append(table_id)
max_sequence_lengths.append(
self._feature_to_config_dict[feature].max_sequence_length)
return (sample_indices_list, embedding_indices_list,
aggregation_weights_list, table_ids)
aggregation_weights_list, table_ids, max_sequence_lengths)
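To make the extra return value concrete, a sketch of the parallel lists for one table serving two hypothetical features, one of them a sequence feature of maximum length 3:

```
# Features enqueued for table 0: 'clicks' (non-sequence) and
# 'click_history' (sequence, max_sequence_length=3).
table_ids = [0, 0]
max_sequence_lengths = [0, 3]  # 0 marks a non-sequence feature
# sample_indices_list, embedding_indices_list and aggregation_weights_list
# carry the corresponding per-feature tensors in the same order.
```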
def get_activations(self):
"""Get activations for features.
@ -790,9 +824,21 @@ class TPUEmbedding(object):
activations = collections.OrderedDict()
for table_id, table in enumerate(self._table_to_features_dict):
features = self._table_to_features_dict[table]
for lookup_id, feature in enumerate(features):
stride = len(self._table_to_features_dict[table])
activations[feature] = recv_activations[table_id][lookup_id::stride, :]
num_features = self._table_to_num_features_dict[table]
feature_index = 0
table_activations = array_ops.reshape(
recv_activations[table_id],
[self.batch_size_per_core, num_features, -1])
for feature in features:
seq_length = self._feature_to_config_dict[feature].max_sequence_length
if not seq_length:
activations[feature] = table_activations[:, feature_index, :]
feature_index = feature_index + 1
else:
activations[feature] = (
table_activations[:, feature_index:(feature_index+seq_length), :])
feature_index = feature_index + seq_length
return activations
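A worked shape example for the new slicing, under assumed sizes: `batch_size_per_core=2`, table dimension 4, one plain feature plus one sequence feature of maximum length 3, so `num_features` is 4.

```
# recv_activations[table_id] has shape [2 * 4, 4] and is reshaped to
# table_activations with shape [2, 4, 4].
#
# The plain feature takes one slot and drops the middle axis:
#   activations['clicks'] = table_activations[:, 0, :]           # shape [2, 4]
# The sequence feature takes the next max_sequence_length slots:
#   activations['click_history'] = table_activations[:, 1:4, :]  # shape [2, 3, 4]
```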
def generate_send_gradients_op(self, feature_to_gradient_dict):
@ -815,12 +861,16 @@ class TPUEmbedding(object):
gradients = []
for table in self._table_to_features_dict:
features = self._table_to_features_dict[table]
table_gradients = [
feature_to_gradient_dict[feature] for feature in features
]
table_gradients = []
for feature in features:
gradient = feature_to_gradient_dict[feature]
# Expand dims for non-sequence feature to match sequence features.
if gradient.shape.ndims == 2:
gradient = array_ops.expand_dims(gradient, 1)
table_gradients.append(gradient)
interleaved_table_grads = array_ops.reshape(
array_ops.stack(table_gradients, axis=1),
[-1, table_gradients[0].shape[1]])
array_ops.concat(table_gradients, axis=1),
[-1, table_gradients[0].shape[-1]])
gradients.append(interleaved_table_grads)
return tpu_ops.send_tpu_embedding_gradients(
inputs=gradients, config=self.config_proto.SerializeToString())
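For the backward pass the shapes mirror the activations; a sketch with the same assumed sizes as above:

```
# Plain feature gradient:     [batch_size_per_core, dim]
#   expand_dims(..., 1)   ->  [batch_size_per_core, 1, dim]
# Sequence feature gradient:  [batch_size_per_core, max_sequence_length, dim]
# concat(axis=1)          ->  [batch_size_per_core, num_features, dim]
# reshape([-1, dim])      ->  [batch_size_per_core * num_features, dim]
# which is the per-table layout passed to send_tpu_embedding_gradients.
```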
@ -834,21 +884,22 @@ def _validate_table_to_config_dict(table_to_config_dict):
'`TableConfig`, got {} for {}.'.format(type(v), k))
def _validate_feature_to_table_dict(table_to_config_dict,
feature_to_table_dict):
"""Validate `feature_to_table_dict`."""
used_table_set = set(feature_to_table_dict.values())
def _validate_feature_to_config_dict(table_to_config_dict,
feature_to_config_dict):
"""Validate `feature_to_config_dict`."""
used_table_set = set([feature.table_id
for feature in feature_to_config_dict.values()])
table_set = set(table_to_config_dict.keys())
unused_table_set = table_set - used_table_set
if unused_table_set:
raise ValueError('`table_to_config_dict` specifies table that is not '
'used in `feature_to_table_dict`: {}.'
'used in `feature_to_config_dict`: {}.'
.format(unused_table_set))
extra_table_set = used_table_set - table_set
if extra_table_set:
raise ValueError('`feature_to_table_dict` refers to a table that is not '
raise ValueError('`feature_to_config_dict` refers to a table that is not '
'specified in `table_to_config_dict`: {}.'
.format(extra_table_set))
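Both failure modes of the renamed validator, sketched with throwaway names (`TableConfig` defaults assumed for brevity):

```
table_to_config_dict = {
    'video': tpu_embedding.TableConfig(vocabulary_size=100, dimension=8),
    'user': tpu_embedding.TableConfig(vocabulary_size=100, dimension=8),
}

# Table 'user' is never referenced by any feature -> ValueError (unused table).
feature_to_config_dict = {'watched': tpu_embedding.FeatureConfig('video')}

# Feature refers to table 'profile', which is not configured -> ValueError.
feature_to_config_dict = {'watched': tpu_embedding.FeatureConfig('profile')}
```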
@ -1135,19 +1186,30 @@ def _create_combiners(table_to_config_dict, table_to_features_dict):
return combiners
def _create_table_to_features_dict(feature_to_table_dict):
def _create_table_to_features_and_num_features_dicts(feature_to_config_dict):
"""Create mapping from table to a list of its features."""
table_to_features_dict_tmp = {}
for feature, table in six.iteritems(feature_to_table_dict):
if table in table_to_features_dict_tmp:
table_to_features_dict_tmp[table].append(feature)
table_to_num_features_dict_tmp = {}
for feature, feature_config in six.iteritems(feature_to_config_dict):
if feature_config.table_id in table_to_features_dict_tmp:
table_to_features_dict_tmp[feature_config.table_id].append(feature)
else:
table_to_features_dict_tmp[table] = [feature]
table_to_features_dict_tmp[feature_config.table_id] = [feature]
table_to_num_features_dict_tmp[feature_config.table_id] = 0
if feature_config.max_sequence_length == 0:
table_to_num_features_dict_tmp[feature_config.table_id] = (
table_to_num_features_dict_tmp[feature_config.table_id] + 1)
else:
table_to_num_features_dict_tmp[feature_config.table_id] = (
table_to_num_features_dict_tmp[feature_config.table_id] +
feature_config.max_sequence_length)
table_to_features_dict = collections.OrderedDict()
table_to_num_features_dict = collections.OrderedDict()
for table in sorted(table_to_features_dict_tmp):
table_to_features_dict[table] = sorted(table_to_features_dict_tmp[table])
return table_to_features_dict
table_to_num_features_dict[table] = table_to_num_features_dict_tmp[table]
return table_to_features_dict, table_to_num_features_dict
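For instance, under the counting rule above (features invented for illustration), a table serving one plain feature and one sequence feature of maximum length 4 ends up with five feature slots:

```
feature_to_config_dict = {
    'clicks': tpu_embedding.FeatureConfig('events'),        # counts as 1
    'click_history': tpu_embedding.FeatureConfig(
        'events', max_sequence_length=4),                    # counts as 4
}
# _create_table_to_features_and_num_features_dicts(feature_to_config_dict)
# returns:
#   table_to_features_dict     == {'events': ['click_history', 'clicks']}
#   table_to_num_features_dict == {'events': 5}
```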
def _create_device_fn(hosts):


@ -115,7 +115,7 @@ def hook_dummy_table_variables_to_activations(tpu_embedding, activations,
"""
new_activations = collections.OrderedDict()
for feature in activations:
table = tpu_embedding.feature_to_table_dict[feature]
table = tpu_embedding.feature_to_config_dict[feature].table_id
new_activations[feature] = tpu_ops.tpu_embedding_activations(
dummy_table_variables[table],
activations[feature],